Skip to content

checkpoint.streamer

CheckpointManager

Bases: object

Custom msgpack checkpointer that saves large train states by serializing and saving tensors one by one in a streaming fashion. Avoids running out of memory or local TPU disk with default flax checkpointer.

Source code in src/fjformer/checkpoint/streamer.py
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
class CheckpointManager(object):
    """
    Custom msgpack checkpointer that saves large train states by serializing
    and saving tensors one by one in a streaming fashion. Avoids running
    out of memory or local TPU disk with default flax checkpointer.

    :param checkpoint_dir: Directory that checkpoint file names are joined onto
    :param enable: When False every save is redirected to the null device
    :param float_dtype: Forwarded to ``get_dtype`` for each tensor before it is serialized
    :param save_optimizer_state: When True ``save_all`` stores the whole state,
        otherwise only ``state.params["params"]``
    :param verbose: Show a progress bar while a checkpoint is being written
    """

    def __init__(
            self,
            checkpoint_dir,
            enable=True,
            float_dtype: Union[str, jnp.dtype] = "bf16",
            save_optimizer_state: bool = True,
            verbose: bool = False
    ):
        self.float_dtype = float_dtype
        self.save_optimizer_state = save_optimizer_state
        self.checkpoint_dir = checkpoint_dir
        self.enable = enable
        self.verbose = verbose

    def save_checkpoint(
            self,
            state: struct.PyTreeNode,
            filename: Union[str, os.PathLike],
            gather_fns: dict[Callable] = None,
            mismatch_allowed: bool = True

    ):
        """
        Stream ``state`` into ``filename`` inside the checkpoint directory.

        :param state: Train state (or any pytree) to serialize
        :param filename: File name, joined onto ``self.checkpoint_dir``
        :param gather_fns: Optional pytree of callables applied to each tensor before it is written
        :param mismatch_allowed: Tolerate gather functions that are missing or None
        :return: Nothing
        """
        if self.enable:
            path = os.path.join(self.checkpoint_dir, filename)
        else:
            # The state is still walked (and the gather functions still run);
            # only the bytes are discarded. os.devnull is the portable null device.
            path = os.devnull
        self.save_state_to_file(
            state,
            path,
            gather_fns,
            self.float_dtype,
            verbose=self.verbose,  # previously stored on the instance but never used
            mismatch_allowed=mismatch_allowed
        )

    @staticmethod
    def save_state_to_file(
            state: struct.PyTreeNode,
            path: Union[str, os.PathLike],
            gather_fns: dict[Callable] = None,
            float_dtype=None,
            verbose: bool = False,
            mismatch_allowed: bool = True
    ):
        """
        Serialize ``state`` to ``path`` as a sequence of msgpack ``(key, bytes)`` records,
        one record per flattened leaf.

        :param state: Object convertible with ``to_state_dict``
        :param path: Destination file, opened in binary write mode
        :param gather_fns: Optional pytree of callables; the one stored under a leaf's
            flattened key is applied to that leaf before serialization
        :param float_dtype: Forwarded to ``get_dtype`` for every leaf
        :param verbose: Show a tqdm progress bar
        :param mismatch_allowed: When True a missing/None gather function is only counted,
            when False it raises
        :raises KeyError: If a gather function is missing or None and ``mismatch_allowed`` is False
        :return: Nothing
        """
        state = to_state_dict(state)
        packer = msgpack.Packer()
        flatten_state = flatten_dict(state)
        if gather_fns is not None:
            gather_fns = flatten_dict(to_state_dict(gather_fns))
        pbar = tqdm.tqdm(
            flatten_state.items(),
            disable=not verbose,
            desc="Saving State to File",
        )

        gather_functions_mismatch = 0

        with open(path, "wb") as stream:
            for key, value in pbar:
                if gather_fns is not None:
                    # Look the function up first and call it outside of any
                    # KeyError handler, so a KeyError raised *inside* the gather
                    # function is no longer mistaken for a missing function.
                    gather_fn = gather_fns.get(key)
                    if gather_fn is not None:
                        value = gather_fn(value)
                    elif mismatch_allowed:
                        gather_functions_mismatch += 1
                    elif key in gather_fns:
                        raise KeyError(f"Gather Function {key} is None and NoneType OBJ is not callable.")
                    else:
                        raise KeyError(f"Gather Function {key} is missing.")
                pbar.set_postfix(gather_functions_mismatch=gather_functions_mismatch)
                value = get_dtype(value, float_dtype)
                stream.write(packer.pack((key, to_bytes(value))))

    def save_pickle(
            self,
            obj,
            filename: Union[str, os.PathLike]
    ):
        """
        Pickle ``obj`` into ``filename`` inside the checkpoint directory.

        :param self: Represent the instance of the class
        :param obj: Pass the object that is to be pickled
        :param filename: Specify the name of the file to be saved
        :return: Nothing

        """
        import pickle

        # When the manager is disabled the pickle is written to the null device.
        path = os.path.join(self.checkpoint_dir, filename) if self.enable else os.devnull
        with open(path, "wb") as stream:
            pickle.dump(obj, stream)

    def save_all(
            self,
            state: struct.PyTreeNode,
            gather_fns,
            metadata=None,
            dataset=None,
            milestone=False
    ):
        """
        The save_all function saves the following:
            - metadata.pkl (a pickle file containing the metadata)
            - dataset.pkl (a pickle file containing the dataset object)
            - streaming_state or streaming_params (msgpack stream of the whole state,
                or of ``state.params["params"]`` when the optimizer state is not saved)

        For a milestone checkpoint every file name carries a ``_{step}`` suffix
        (``metadata_{step}.pkl``, ``dataset_{step}.pkl``, ``streaming_state_{step}``)
        so it will not be overwritten by future checkpoints.

        :param self: Access the attributes and methods of the class
        :param state: struct.PyTreeNode: Save the current state of the model
        :param gather_fns: Pytree of gather functions matching ``state`` (may be None)
        :param metadata: Save the metadata of the training
        :param dataset: Save the dataset to disk
        :param milestone: Determine whether the checkpoint is a milestone or not
        :return: Nothing

        """
        step = int(jax.device_get(state.step))
        if self.save_optimizer_state:
            checkpoint_state = state
            checkpoint_name = "streaming_state"
            checkpoint_gather_fns = gather_fns
        else:
            checkpoint_state = state.params["params"]
            checkpoint_name = "streaming_params"
            # gather_fns is optional further down the call chain, so do not
            # dereference it when the caller passed None.
            checkpoint_gather_fns = gather_fns.params["params"] if gather_fns is not None else None

        # Milestone files get a step suffix; regular files reuse the same names
        # and are therefore overwritten by the next save.
        suffix = f"_{step}" if milestone else ""
        self.save_pickle(metadata, f"metadata{suffix}.pkl")
        self.save_pickle(dataset, f"dataset{suffix}.pkl")
        self.save_checkpoint(
            checkpoint_state, f"{checkpoint_name}{suffix}", checkpoint_gather_fns
        )

    @staticmethod
    def load_checkpoint(
            path: Union[str, os.PathLike],
            target=None,
            shard_fns: dict[Callable] = None,
            remove_dict_prefix=None,
            verbose: bool = False,
            mismatch_allowed: bool = True,
    ):
        """
        Load a checkpoint written in the streaming msgpack format from disk.

        :param path: Specify the path to the checkpoint file
        :param target: Optional object the restored state is loaded into via ``from_state_dict``
        :param shard_fns: Pytree of callables; the one stored under a tensor's flattened key
            is applied to that tensor after it is deserialized
        :param remove_dict_prefix: Key prefix to strip; records whose key does not start
            with this prefix are skipped
        :param verbose: Show a tqdm progress bar
        :param mismatch_allowed: When True a missing/None shard function is only counted,
            when False it raises
        :raises KeyError: If a shard function is missing or None and ``mismatch_allowed`` is False
        :return: The unflattened state dict when ``target`` is None, otherwise
            ``from_state_dict(target, state)``

        """
        if shard_fns is not None:
            shard_fns = flatten_dict(
                to_state_dict(shard_fns)
            )
        if remove_dict_prefix is not None:
            remove_dict_prefix = tuple(remove_dict_prefix)
        flatten_state = {}

        shard_functions_mismatch = 0
        with open(path, "rb") as fin:
            # read_size is 80 MiB (80 * 1024 * 1024).
            # NOTE(review): max_buffer_size=0 is presumably meant to lift msgpack's
            # default buffer cap so large tensors fit - confirm against msgpack docs.
            unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
            pbar = tqdm.tqdm(
                unpacker,
                disable=not verbose,
                desc="Loading Checkpoints From File"
            )
            for key, value in pbar:
                key = tuple(key)
                if remove_dict_prefix is not None:
                    if key[:len(remove_dict_prefix)] != remove_dict_prefix:
                        continue
                    key = key[len(remove_dict_prefix):]

                tensor = from_bytes(None, value)
                if shard_fns is not None:
                    # Call the shard function outside of any KeyError handler so
                    # an error raised inside it is not counted as a mismatch.
                    shard_fn = shard_fns.get(key)
                    if shard_fn is not None:
                        tensor = shard_fn(tensor)
                    elif mismatch_allowed:
                        shard_functions_mismatch += 1
                    elif key in shard_fns:
                        raise KeyError(f"Shard Function {key} is None and NoneType OBJ is not callable.")
                    else:
                        raise KeyError(f"Shard Function {key} is missing.")
                flatten_state[key] = tensor
                pbar.set_postfix(shard_functions_mismatch=shard_functions_mismatch)
        if target is not None:
            flattened_target = flatten_dict(
                to_state_dict(target), keep_empty_nodes=True
            )
            # Re-insert empty nodes of the target that have no record in the file.
            for key, value in flattened_target.items():
                if key not in flatten_state and value == empty_node:
                    flatten_state[key] = value

        state = unflatten_dict(flatten_state)
        if target is None:
            return state

        return from_state_dict(target, state)

    @staticmethod
    def load_flax_checkpoint(
            path,
            target=None,
            shard_fns=None
    ):
        """
        Load a standard flax checkpoint that is not saved with the
        msgpack streaming format.

        :param path: File to read; its whole content is loaded into memory
        :param target: Optional object the restored state is loaded into via ``from_state_dict``
        :param shard_fns: Optional pytree of callables mapped leaf-wise over the restored state
        :return: The restored state dict when ``target`` is None, otherwise
            ``from_state_dict(target, state_dict)``
        """
        with open(path, "rb") as fin:
            encoded_bytes = fin.read()

        state_dict = flax.serialization.msgpack_restore(encoded_bytes)
        if shard_fns is not None:
            shard_fns = to_state_dict(shard_fns)
            state_dict = jax.tree_util.tree_map(lambda fn, x: fn(x), shard_fns, state_dict)

        if target is None:
            return state_dict
        return from_state_dict(target, state_dict)

    @classmethod
    def load_state_checkpoint(
            cls,
            load_type: Literal[
                "state",
                "state_params",
                "params",
                "flax_params"
            ],
            load_path: Union[str, os.PathLike],
            state_target=None,
            state_shard_fns=None,
            disallow_state=False,
            mismatch_allowed: bool = True
    ):
        """
        Load either a full train state or only the parameters from disk.

        :param cls: Class used to reach ``load_checkpoint`` / ``load_flax_checkpoint``
        :param load_type: ``"state"`` loads the whole state, ``"state_params"`` loads the
            ``("params", "params")`` subtree of a state file, ``"params"`` loads a
            parameters-only streaming file, ``"flax_params"`` loads a standard flax file
        :param load_path: Specify where to load the checkpoint from
        :param state_target: Specify the target for the train state
        :param state_shard_fns: Specify the sharding functions
        :param disallow_state: Prevent loading the entire state
        :param mismatch_allowed: Tolerate shard functions that are missing or None
        :raises AssertionError: If ``load_type`` is ``"state"`` while ``disallow_state`` is set
        :raises ValueError: If ``load_type`` is not one of the supported values
        :return: ``(state, restored_params)``; exactly one of the two is not None

        """
        if state_target is not None:
            params_target = state_target.params["params"]
        else:
            params_target = None

        if state_shard_fns is not None:
            params_shard_fns = state_shard_fns.params["params"]
        else:
            params_shard_fns = None

        if disallow_state and load_type == "state":
            # Raised explicitly (same exception type as before) so the guard
            # survives `python -O`, which strips assert statements.
            raise AssertionError("Loading full state is not allowed!")

        if load_type == "state":
            state = cls.load_checkpoint(
                path=load_path,
                target=state_target,
                shard_fns=state_shard_fns,
                mismatch_allowed=mismatch_allowed
            )
            return state, None

        if load_type == "state_params":
            restored_params = cls.load_checkpoint(
                path=load_path,
                target=params_target,
                shard_fns=params_shard_fns,
                remove_dict_prefix=("params", "params"),
                mismatch_allowed=mismatch_allowed
            )
        elif load_type == "params":
            restored_params = cls.load_checkpoint(
                path=load_path,
                target=params_target,
                shard_fns=params_shard_fns,
                mismatch_allowed=mismatch_allowed
            )
        elif load_type == "flax_params":
            restored_params = cls.load_flax_checkpoint(
                path=load_path,
                target=params_target,
                shard_fns=params_shard_fns,
            )
        else:
            raise ValueError(f"Invalid load_from type: {load_type}")

        # Every parameters-only path re-wraps the result under a "params" key.
        restored_params = flax.core.frozen_dict.freeze(
            {"params": restored_params}
        )
        return None, restored_params

load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None, verbose=False, mismatch_allowed=True) staticmethod

Load a checkpoint saved in the streaming msgpack format from disk.

Parameters:

Name Type Description Default
path Union[str, PathLike]

Specify the path to the checkpoint file

required
target

Specify the target object that the restored checkpoint is loaded into

None
shard_fns dict[Callable]

Specify a function that will be applied to each tensor in the checkpoint

None
remove_dict_prefix

Remove the prefix of a dictionary

None
verbose bool

print state and other stuff

False
mismatch_allowed bool

Whether to tolerate shard functions that are missing or None (they are counted instead of raising an error)

True

Returns:

Type Description

The restored state: a nested dictionary of tensors when target is None, otherwise the target populated from that dictionary

Source code in src/fjformer/checkpoint/streamer.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
@staticmethod
def load_checkpoint(
        path: Union[str, os.PathLike],
        target=None,
        shard_fns: dict[Callable] = None,
        remove_dict_prefix=None,
        verbose: bool = False,
        mismatch_allowed: bool = True,
):
    """
    Load a checkpoint written in the streaming msgpack format from disk.

    :param path: Specify the path to the checkpoint file
    :param target: Optional object the restored state is loaded into via ``from_state_dict``
    :param shard_fns: Pytree of callables; the one stored under a tensor's flattened key
        is applied to that tensor after it is deserialized
    :param remove_dict_prefix: Key prefix to strip; records whose key does not start
        with this prefix are skipped
    :param verbose: Show a tqdm progress bar
    :param mismatch_allowed: When True a missing/None shard function is only counted,
        when False it raises
    :raises KeyError: If a shard function is missing or None and ``mismatch_allowed`` is False
    :return: The unflattened state dict when ``target`` is None, otherwise
        ``from_state_dict(target, state)``

    """
    if shard_fns is not None:
        shard_fns = flatten_dict(
            to_state_dict(shard_fns)
        )
    if remove_dict_prefix is not None:
        remove_dict_prefix = tuple(remove_dict_prefix)
    flatten_state = {}

    shard_functions_mismatch = 0
    with open(path, "rb") as fin:
        # read_size is 80 MiB (80 * 1024 * 1024).
        # NOTE(review): max_buffer_size=0 is presumably meant to lift msgpack's
        # default buffer cap so large tensors fit - confirm against msgpack docs.
        unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
        pbar = tqdm.tqdm(
            unpacker,
            disable=not verbose,
            desc="Loading Checkpoints From File"
        )
        for key, value in pbar:
            key = tuple(key)
            if remove_dict_prefix is not None:
                if key[:len(remove_dict_prefix)] != remove_dict_prefix:
                    continue
                key = key[len(remove_dict_prefix):]

            tensor = from_bytes(None, value)
            if shard_fns is not None:
                # Call the shard function outside of any KeyError handler so
                # an error raised inside it is not counted as a mismatch.
                shard_fn = shard_fns.get(key)
                if shard_fn is not None:
                    tensor = shard_fn(tensor)
                elif mismatch_allowed:
                    shard_functions_mismatch += 1
                elif key in shard_fns:
                    raise KeyError(f"Shard Function {key} is None and NoneType OBJ is not callable.")
                else:
                    raise KeyError(f"Shard Function {key} is missing.")
            flatten_state[key] = tensor
            pbar.set_postfix(shard_functions_mismatch=shard_functions_mismatch)
    if target is not None:
        flattened_target = flatten_dict(
            to_state_dict(target), keep_empty_nodes=True
        )
        # Re-insert empty nodes of the target that have no record in the file.
        for key, value in flattened_target.items():
            if key not in flatten_state and value == empty_node:
                flatten_state[key] = value

    state = unflatten_dict(flatten_state)
    if target is None:
        return state

    return from_state_dict(target, state)

load_flax_checkpoint(path, target=None, shard_fns=None) staticmethod

Load a standard flax checkpoint that's not saved with the msgpack streaming format.

Source code in src/fjformer/checkpoint/streamer.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
@staticmethod
def load_flax_checkpoint(
        path,
        target=None,
        shard_fns=None
):
    """
    Restore a regular flax msgpack checkpoint, i.e. one that was not written
    with the streaming format.

    :param path: File to read; its whole content is loaded into memory
    :param target: Optional object the restored state is loaded into via ``from_state_dict``
    :param shard_fns: Optional pytree of callables mapped leaf-wise over the restored state
    :return: The restored state dict when ``target`` is None, otherwise
        ``from_state_dict(target, restored)``
    """
    with open(path, "rb") as checkpoint_file:
        raw = checkpoint_file.read()
    restored = flax.serialization.msgpack_restore(raw)

    if shard_fns is not None:
        # Pair every shard function with the leaf stored at the same position.
        restored = jax.tree_util.tree_map(
            lambda shard, leaf: shard(leaf),
            to_state_dict(shard_fns),
            restored,
        )

    return restored if target is None else from_state_dict(target, restored)

load_state_checkpoint(load_type, load_path, state_target=None, state_shard_fns=None, disallow_state=False, mismatch_allowed=True) classmethod

Load either the full train state or only the model parameters from a checkpoint on disk.

Parameters:

Name Type Description Default
cls

Call the load_checkpoint function

required
load_type Literal['state', 'state_params', 'params', 'flax_params']

Specify which part of state to checkpoint

required
load_path Union[str, PathLike]

Specify where to checkpoint the model from

required
state_target

Specify the target for the train state

None
state_shard_fns

Specify the sharding function

None
disallow_state

Prevent loading the entire state

False
mismatch_allowed bool

when ever to allow shard func to be None

True

Returns:

Type Description

A tuple of two objects, the state and restored_params

Source code in src/fjformer/checkpoint/streamer.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
@classmethod
def load_state_checkpoint(
        cls,
        load_type: Literal[
            "state",
            "state_params",
            "params",
            "flax_params"
        ],
        load_path: Union[str, os.PathLike],
        state_target=None,
        state_shard_fns=None,
        disallow_state=False,
        mismatch_allowed: bool = True
):
    """
    Load either a full train state or only the parameters from disk.

    :param cls: Class used to reach ``load_checkpoint`` / ``load_flax_checkpoint``
    :param load_type: ``"state"`` loads the whole state, ``"state_params"`` loads the
        ``("params", "params")`` subtree of a state file, ``"params"`` loads a
        parameters-only streaming file, ``"flax_params"`` loads a standard flax file
    :param load_path: Specify where to load the checkpoint from
    :param state_target: Specify the target for the train state
    :param state_shard_fns: Specify the sharding functions
    :param disallow_state: Prevent loading the entire state
    :param mismatch_allowed: Tolerate shard functions that are missing or None
    :raises AssertionError: If ``load_type`` is ``"state"`` while ``disallow_state`` is set
    :raises ValueError: If ``load_type`` is not one of the supported values
    :return: ``(state, restored_params)``; exactly one of the two is not None

    """
    if state_target is not None:
        params_target = state_target.params["params"]
    else:
        params_target = None

    if state_shard_fns is not None:
        params_shard_fns = state_shard_fns.params["params"]
    else:
        params_shard_fns = None

    if disallow_state and load_type == "state":
        # Raised explicitly (same exception type as before) so the guard
        # survives `python -O`, which strips assert statements.
        raise AssertionError("Loading full state is not allowed!")

    if load_type == "state":
        state = cls.load_checkpoint(
            path=load_path,
            target=state_target,
            shard_fns=state_shard_fns,
            mismatch_allowed=mismatch_allowed
        )
        return state, None

    if load_type == "state_params":
        restored_params = cls.load_checkpoint(
            path=load_path,
            target=params_target,
            shard_fns=params_shard_fns,
            remove_dict_prefix=("params", "params"),
            mismatch_allowed=mismatch_allowed
        )
    elif load_type == "params":
        restored_params = cls.load_checkpoint(
            path=load_path,
            target=params_target,
            shard_fns=params_shard_fns,
            mismatch_allowed=mismatch_allowed
        )
    elif load_type == "flax_params":
        restored_params = cls.load_flax_checkpoint(
            path=load_path,
            target=params_target,
            shard_fns=params_shard_fns,
        )
    else:
        raise ValueError(f"Invalid load_from type: {load_type}")

    # Every parameters-only path re-wraps the result under a "params" key.
    restored_params = flax.core.frozen_dict.freeze(
        {"params": restored_params}
    )
    return None, restored_params

save_all(state, gather_fns, metadata=None, dataset=None, milestone=False)

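The save_all function saves the following: - metadata.pkl (a pickle file containing a dictionary of metadata) - dataset.pkl (a pickle file containing the training data) - streaming_state or streaming_params (the msgpack stream of the full state, or of the parameters only, depending on whether we want to save optimizer state or not). For a milestone checkpoint every file name carries a _{step} suffix (metadata_{step}.pkl, dataset_{step}.pkl, streaming_state_{step} or streaming_params_{step}), so it is a checkpoint that will not be overwritten by future checkpoints.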

Parameters:

Name Type Description Default
self

Access the attributes and methods of the class

required
state PyTreeNode

struct.PyTreeNode: Save the current state of the model

required
gather_fns

Gather the state of the optimizer

required
metadata

Save the metadata of the training

None
dataset

Save the dataset to disk

None
milestone

Determine whether the checkpoint is a milestone or not

False

Returns:

Type Description

Nothing

Source code in src/fjformer/checkpoint/streamer.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def save_all(
        self,
        state: struct.PyTreeNode,
        gather_fns,
        metadata=None,
        dataset=None,
        milestone=False
):
    """
    The save_all function saves the following:
        - metadata.pkl (a pickle file containing the metadata)
        - dataset.pkl (a pickle file containing the dataset object)
        - streaming_state or streaming_params (msgpack stream of the whole state,
            or of ``state.params["params"]`` when the optimizer state is not saved)

    For a milestone checkpoint every file name carries a ``_{step}`` suffix
    (``metadata_{step}.pkl``, ``dataset_{step}.pkl``, ``streaming_state_{step}``)
    so it will not be overwritten by future checkpoints.

    :param self: Access the attributes and methods of the class
    :param state: struct.PyTreeNode: Save the current state of the model
    :param gather_fns: Pytree of gather functions matching ``state`` (may be None)
    :param metadata: Save the metadata of the training
    :param dataset: Save the dataset to disk
    :param milestone: Determine whether the checkpoint is a milestone or not
    :return: Nothing

    """
    step = int(jax.device_get(state.step))
    if self.save_optimizer_state:
        checkpoint_state = state
        checkpoint_name = "streaming_state"
        checkpoint_gather_fns = gather_fns
    else:
        checkpoint_state = state.params["params"]
        checkpoint_name = "streaming_params"
        # gather_fns is optional further down the call chain, so do not
        # dereference it when the caller passed None.
        checkpoint_gather_fns = gather_fns.params["params"] if gather_fns is not None else None

    # Milestone files get a step suffix; regular files reuse the same names
    # and are therefore overwritten by the next save.
    suffix = f"_{step}" if milestone else ""
    self.save_pickle(metadata, f"metadata{suffix}.pkl")
    self.save_pickle(dataset, f"dataset{suffix}.pkl")
    self.save_checkpoint(
        checkpoint_state, f"{checkpoint_name}{suffix}", checkpoint_gather_fns
    )

save_pickle(obj, filename)

The save_pickle function saves a Python object to disk using the pickle module.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
obj

Pass the object that is to be pickled

required
filename Union[str, PathLike]

Specify the name of the file to be saved

required

Returns:

Type Description

Nothing (the object is written to disk as a pickle file)

Source code in src/fjformer/checkpoint/streamer.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def save_pickle(
        self,
        obj,
        filename: Union[str, os.PathLike]
):
    """
    Pickle ``obj`` into ``filename`` inside the checkpoint directory.

    :param self: Represent the instance of the class
    :param obj: Pass the object that is to be pickled
    :param filename: Specify the name of the file to be saved
    :return: Nothing

    """
    import pickle

    # When the manager is disabled the pickle is written to the null device;
    # os.devnull is the portable spelling of "/dev/null".
    path = os.path.join(self.checkpoint_dir, filename) if self.enable else os.devnull
    with open(path, "wb") as stream:
        pickle.dump(obj, stream)