Skip to content

trainer.base_trainer

BaseTrainer

Source code in src/python/easydel/trainer/base_trainer.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
class BaseTrainer:
    def __init__(
            self,
            arguments: TrainArguments,
            dataset_train: Dataset,
            dataset_eval: Dataset = None,
            finetune: bool = True,
            checkpoint_path: Union[str, os.PathLike] = None,
            _do_init_fns: bool = True
    ):
        """
        Store the training configuration and, optionally, build all trainer utilities.

        Any attribute listed below that a subclass already assigned before calling this
        ``__init__`` keeps its value; every other one starts out as ``None``.
        When ``_do_init_fns`` is true, ``initialize_trainer_utils`` is called to set up the
        wandb runtime, the timer, the dataloaders, the model/optimizer/scheduler/config
        and the sharded train/eval functions.

        :param self: Represent the instance of the class
        :param arguments: TrainArguments: Pass the arguments to the trainer
        :param dataset_train: Dataset: The training dataset
        :param dataset_eval: Dataset: The evaluation dataset (optional)
        :param finetune: bool: Stored on the instance; when true and ``checkpoint_path`` is
            None a warning asks you to pass parameters to the train function
        :param checkpoint_path: Union[str,os.PathLike] : Stored checkpoint path
        :param _do_init_fns: bool: Call ``initialize_trainer_utils`` right away
        :return: Nothing, it just initializes the class
        """
        # Loggers
        self.timer = getattr(self, "timer", None)
        self.wandb_runtime: Run | RunDisabled | None = getattr(self, "wandb_runtime", None)

        # Data
        self.dataloader_train = getattr(self, "dataloader_train", None)
        self.dataloader_eval = getattr(self, "dataloader_eval", None)
        self.max_training_steps = getattr(self, "max_training_steps", None)
        self.max_evaluation_steps = getattr(self, "max_evaluation_steps", None)
        self.dataset_train = dataset_train
        self.dataset_eval = dataset_eval

        # Model Related
        self.model = getattr(self, "model", None)
        self.config = getattr(self, "config", None)
        self.scheduler = getattr(self, "scheduler", None)
        self.tx = getattr(self, "tx", None)
        self.model_state = getattr(self, "model_state", None)

        # LoRA Related
        self.rapture = arguments.rapture
        self.lora_parameters = getattr(self, "lora_parameters", None)
        self.lora_model = getattr(self, "lora_model", None)
        self.lora_tx = getattr(self, "lora_tx", None)
        self.lora_opt_state = getattr(self, "lora_opt_state", None)
        self.lora_apply_fn = getattr(self, "lora_apply_fn", None)

        # PJit functions
        self.create_sharded_state_from_params_function = getattr(
            self,
            "create_sharded_state_from_params_function",
            None
        )
        self.sharded_train_step_function = getattr(self, "sharded_train_step_function", None)
        self.sharded_eval_step_function = getattr(self, "sharded_eval_step_function", None)
        self.initialize_state_function = getattr(self, "initialize_state_function", None)
        self.mesh = getattr(self, "mesh", None)

        # Checkpoint Managers
        self.checkpoint_manager: fjformer.CheckpointManager | None = getattr(self, "checkpoint_manager", None)

        # EasyState
        self.state_shape = getattr(self, "state_shape", None)
        self.state_partition_spec = getattr(self, "state_partition_spec", None)
        self.sharded_state = getattr(self, "sharded_state", None)

        # Rest

        self.arguments = arguments
        self.finetune = finetune
        self.checkpoint_path = checkpoint_path
        self.dtype = arguments.dtype
        self.param_dtype = arguments.param_dtype
        if self.arguments.track_memory:
            if not self.arguments.performance_mode:
                initialise_tracking()
                self.arguments._stop_capturing_memory = False
                self._start_capturing_memory().start()
        if finetune:
            if checkpoint_path is None:
                prefix_print(
                    "Warning",
                    "In case of using `finetune = True` and Passing `checkpoint_path = None`"
                    " you should pass parameters in train function"
                )
        if _do_init_fns:
            self.initialize_trainer_utils()
        else:
            prefix_print(
                "Warning",
                "you have set `_do_init_fns = False` so function will not me initialized you have "
                f"to do in manually (simply with `trainer.initialize_trainer_utils()` )"
            )

    def __str__(self):
        """
        Render the trainer as ``ClassName(key=value, ...)``, one attribute per line.

        Multi-line values are indented. Values whose string conversion raises
        ``TypeError`` are skipped.
        """
        parts = []
        for key, value in self.__dict__.items():
            try:
                rendered = str(value).replace("\n", "\n\t")
            except TypeError:
                # Some objects cannot be rendered; skip them rather than fail.
                continue
            parts.append(f"\n\t{key}={rendered}")
        body = "".join(parts)
        return f"{self.__class__.__name__}({body}\n)"

    def __repr__(self):
        return self.__str__()

    @staticmethod
    def finish():
        """
        Finish the active wandb run by calling ``wandb.finish()``.

        :return: None
        """
        wandb.finish()

    def _start_capturing_memory(self, dir_prefix: str = "/dev/shm" if sys.platform != "win32" else "."):
        """
        Create (but do not start) a thread that periodically samples accelerator memory.

        Every 1.5 seconds the thread reads ``get_capacity_matrix(dir_prefix=dir_prefix)``.
        For each device it converts the "Used" and "Usage Percent" entries to floats
        (stripping "%" and "GB") and stores the mapping in ``self.arguments._captured_memory``.
        The loop ends after the first sample taken once ``self.arguments.stop_capturing_memory``
        is truthy.

        :param dir_prefix: Directory passed to ``get_capacity_matrix``
            (``/dev/shm``, or ``.`` on Windows)
        :return: The unstarted ``threading.Thread``
        """
        # NOTE(review): __init__ sets `_stop_capturing_memory`, but this loop reads
        # `stop_capturing_memory`; confirm TrainArguments links the two. The thread is
        # non-daemon, so it keeps the process alive until that flag becomes truthy.
        def _start():
            while True:
                information_queries = {}
                for key in ["Used", "Usage Percent"]:
                    for device, info in get_capacity_matrix(dir_prefix=dir_prefix).items():
                        information_queries[f"accelerators/{device.replace('_', ' ')} ({key})"] = float(
                            info[key].replace("%", "").replace("GB", "")
                        )
                self.arguments._captured_memory = information_queries
                if self.arguments.stop_capturing_memory:
                    break
                time.sleep(1.5)

        return threading.Thread(target=_start)

    def initialize_trainer_utils(self):
        """
        Build every runtime utility the trainer needs and store it on ``self``.

        The utilities are built in this order:
            - ``wandb_runtime`` (only when ``arguments.use_wandb`` is true)
            - ``timer``, used to time and log each of the following stages
            - dataloaders and step counts, via ``configure_dataloader``
            - model, optimizer (``tx``), scheduler and config, via ``configure_model``;
              when ``rapture`` is set, the LoRA module, parameters, optimizer and state
              are also created
            - sharded state/train/eval functions, mesh and checkpoint manager,
              via ``configure_functions``

        :param self: Represent the instance of the class
        :return: None; results are stored as attributes
        """
        self.wandb_runtime = None
        if self.arguments.use_wandb:
            self.wandb_runtime = self.arguments.get_wandb_init()
        self.timer = Timers(
            use_wandb=False,
            tensorboard_writer=self.arguments.get_board()
        )

        self.timer("configure dataloaders").start()
        dataset_configurations = self.configure_dataloader()
        self.dataloader_train = dataset_configurations.dataloader_train
        self.max_training_steps = dataset_configurations.max_training_steps
        self.dataloader_eval = dataset_configurations.dataloader_eval
        self.max_evaluation_steps = dataset_configurations.max_evaluation_steps

        self.timer("configure dataloaders").stop()

        self.timer.log(["configure dataloaders"])

        # configure_model reads self.max_training_steps, so the dataloaders must be
        # configured first.
        self.timer("configure Model, Optimizer, Scheduler and Config").start()
        model_configurations = self.configure_model()
        model = model_configurations.model
        tx = model_configurations.tx
        scheduler = model_configurations.scheduler
        config = model_configurations.config
        self.model = model
        self.tx = tx
        self.scheduler = scheduler
        self.config = config
        if self.rapture is not None:
            lora_modules = self.rapture.apply_lora(
                module=model,
                parameters=self.arguments.rapture_config.parameters,
                tx=tx,
            )
            self.lora_parameters = lora_modules.lora_parameters
            self.lora_apply_fn = lora_modules.lora_module.__call__
            self.lora_opt_state = lora_modules.lora_opt_state
            self.lora_model = lora_modules.lora_module
            self.lora_tx = lora_modules.lora_tx

        self.timer("configure Model, Optimizer, Scheduler and Config").stop()
        self.timer.log(["configure Model, Optimizer, Scheduler and Config"])
        self.timer("configure functions and sharding them").start()
        function_configurations = self.configure_functions()
        self.create_sharded_state_from_params_function = \
            function_configurations.create_sharded_state_from_params_function
        self.sharded_train_step_function = function_configurations.sharded_train_step_function
        self.sharded_eval_step_function = function_configurations.sharded_eval_step_function
        self.mesh = function_configurations.mesh
        self.checkpoint_manager = function_configurations.checkpoint_manager
        self.initialize_state_function = function_configurations.initialize_state_function
        self.timer("configure functions and sharding them").stop()
        self.timer.log(["configure functions and sharding them"])

    @abc.abstractmethod
    def create_collate_function(
            self,
            max_sequence_length: int,
            truncation_mode: Literal["keep_end", "keep_start"]
    ) -> Callable:
        """
        Return the collate function passed to ``Dataset.to_tf_dataset`` in ``configure_dataloader``.

        :param max_sequence_length: Target sequence length
        :param truncation_mode: Which end of a sequence to keep when truncating
        :return: A collate callable
        """
        raise NotImplementedError

    @abc.abstractmethod
    def configure_functions(self) -> TrainerConfigureFunctionFuncOutput:
        """
        Build the (sharded) functions used for training and evaluation.

        Must be implemented by subclasses. ``initialize_trainer_utils`` reads these fields
        from the returned object: ``create_sharded_state_from_params_function``,
        ``sharded_train_step_function``, ``sharded_eval_step_function``, ``mesh``,
        ``checkpoint_manager`` and ``initialize_state_function``.

        :param self: Access the class attributes
        :return: A TrainerConfigureFunctionFuncOutput object
        """
        raise NotImplementedError

    def configure_dataloader(self) -> TrainerConfigureDataloaderFuncOutput:
        """
        Build the training (and, if enabled, evaluation) dataloaders and their step counts.

        Sized datasets (those with ``__len__``) go through ``Dataset.to_tf_dataset`` using
        ``create_collate_function``. Other datasets are wrapped with
        ``tf.data.Dataset.from_generator``. Training data is repeated ``num_train_epochs``
        times (evaluation data once), then prefetched and exposed as a numpy iterator.
        The evaluation loader is built only when ``dataset_eval`` is given and
        ``arguments.do_eval`` is true.

        :param self: Refer to the class instance itself
        :return: A TrainerConfigureDataloaderFuncOutput object
        :raises ValueError: if a dataset without ``__len__`` is used and the matching
            ``max_training_steps`` / ``max_evaluation_steps`` argument is not set
        """

        def create_tf_dataset(dataset: Dataset, is_train: bool) -> Iterator[ndarray[Any, Any]]:
            return (
                dataset.to_tf_dataset(
                    collate_fn=self.create_collate_function(
                        max_sequence_length=self.arguments.max_sequence_length,
                        truncation_mode=self.arguments.truncation_mode
                    ),
                    batch_size=self.arguments.total_batch_size,
                    drop_remainder=True,
                    # Shuffle training data only; evaluation order stays deterministic.
                    shuffle=is_train,
                    num_workers=self.arguments.dataloader_num_workers
                )
                .repeat(self.arguments.num_train_epochs if is_train else 1)
                .prefetch(tf.data.experimental.AUTOTUNE)
                .as_numpy_iterator()
            )

        def create_tf_dataset_from_iterable(dataset: IterableDataset, is_train: bool) -> Iterator[ndarray[Any, Any]]:
            # The first element is read once to discover the column names. Every column is
            # declared as an int32 vector of length max_sequence_length.
            return (
                tf.data.Dataset.from_generator(
                    lambda: dataset,
                    output_signature={
                        col: tf.TensorSpec(shape=(self.arguments.max_sequence_length,), dtype=tf.int32)
                        for col in next(iter(dataset)).keys()
                    }
                )
                .repeat(self.arguments.num_train_epochs if is_train else 1)
                .batch(self.arguments.total_batch_size, drop_remainder=False)
                .prefetch(tf.data.experimental.AUTOTUNE)
                .as_numpy_iterator()
            )

        def calculate_steps(dataset: Union[Dataset, IterableDataset], is_train: bool):
            """
            Return total number of steps (batches) to train or evaluate on.
            """
            if hasattr(dataset, "__len__"):
                # Sized datasets are batched with drop_remainder=True, so one pass yields
                # len(dataset) // total_batch_size batches (not len(dataset) samples).
                steps_per_pass = len(dataset) // self.arguments.total_batch_size
                num_steps = steps_per_pass * (self.arguments.num_train_epochs if is_train else 1)
                max_steps = self.arguments.max_training_steps if is_train else self.arguments.max_evaluation_steps
                return min(num_steps, max_steps) if max_steps else num_steps
            else:
                num_steps = self.arguments.max_training_steps if is_train else self.arguments.max_evaluation_steps
                if not num_steps:
                    raise ValueError(
                        f"Specify the number of {'training' if is_train else 'evaluation'} steps for a generator/streaming dataset.")
                return num_steps

        def to_tf_dataloader(dataset: Union[Dataset, IterableDataset], is_train: bool):
            if hasattr(dataset, "__len__"):
                return create_tf_dataset(dataset, is_train)
            else:
                return create_tf_dataset_from_iterable(dataset, is_train)

        max_training_steps = calculate_steps(self.dataset_train, is_train=True)
        dataloader_train = to_tf_dataloader(self.dataset_train, is_train=True)

        if self.dataset_eval is not None and self.arguments.do_eval:
            max_evaluation_steps = calculate_steps(self.dataset_eval, is_train=False)
            dataloader_eval = to_tf_dataloader(self.dataset_eval, is_train=False)
        else:
            dataloader_eval, max_evaluation_steps = None, 0

        return TrainerConfigureDataloaderFuncOutput(
            dataloader_train=dataloader_train,
            max_training_steps=max_training_steps,
            dataloader_eval=dataloader_eval,
            max_evaluation_steps=max_evaluation_steps
        )

    def configure_model(self) -> TrainerConfigureModelFuncOutput:
        """
        Create the model, optimizer and scheduler, and resolve the model config.

        If ``arguments.model_class`` is set, that class is instantiated from
        ``arguments.configs_to_initialize_model_class``. Otherwise the model is loaded with
        ``AutoEasyDeLModelForCausalLM.from_pretrained`` and ``extra_configs`` (plus
        ``gradient_checkpointing``) are set on its config.

        :param self: Represent the instance of the class
        :return: A model, optimizer, scheduler and config  in TrainerConfigureModelFuncOutput Object
        """
        # Copy so the caller's ``arguments.extra_configs`` dict is not mutated below.
        extra_configs = {} if self.arguments.extra_configs is None else dict(self.arguments.extra_configs)
        if self.arguments.model_class is not None:

            if not hasattr(self.arguments.configs_to_initialize_model_class["config"], "get_partition_rules"):
                assert self.arguments.custom_rule is not None, (
                    "if you are using custom model to init you must"
                    " pass custom_rule for partition rules "
                )

            self.arguments.configs_to_initialize_model_class["config"].axis_dims = self.arguments.sharding_array

            model = self.arguments.model_class(
                **self.arguments.configs_to_initialize_model_class,
                _do_init=False
            )

            config = self.arguments.configs_to_initialize_model_class["config"]

        else:
            extra_configs["gradient_checkpointing"] = self.arguments.gradient_checkpointing

            model = AutoEasyDeLModelForCausalLM.from_pretrained(
                self.arguments.model_huggingface_repo_id,
                dtype=self.arguments.dtype,
                param_dtype=self.arguments.param_dtype,
                _do_init=False
            )
            if hasattr(model, "config"):
                for k, v in extra_configs.items():
                    setattr(model.config, k, v)
                config = model.config
            else:
                config = None
                warnings.warn(
                    "Config is being set to None due to not detecting Model Configuration from taken Model "
                    "this will cause errors later."
                )
        tx, scheduler = self.arguments.get_optimizer_and_scheduler(self.max_training_steps)
        return TrainerConfigureModelFuncOutput(
            model=model,
            tx=tx,
            scheduler=scheduler,
            config=config
        )

    def _save_state(
            self,
            state: "EasyDeLState",
            gather_fns: Optional[Any | Mapping[str, Callable] | dict[str, Callable]],
            milestone: bool = False
    ) -> str:
        """
        Save ``state`` under ``<save_dir>/<model_name>/`` and return the file name used.

        The step in the name is ``state.step`` plus ``arguments.step_start_point`` (when set).
        Milestone saves add an extra ``_<step>`` suffix.

        :param state: The state to save
        :param gather_fns: Passed through to ``state.save_state``
        :param milestone: Whether to use the milestone file name
        :return: The checkpoint file name (ending in ``.easy``)
        """
        step = int(jax.device_get(state.step))
        if self.arguments.step_start_point is not None:
            step += self.arguments.step_start_point
        checkpoint_name = f"{self.arguments.model_name}-S{step}"
        filename = f"{checkpoint_name}_{step}" if milestone else f"{checkpoint_name}"
        filename += ".easy"
        termcolor.cprint(f"Saving Model {filename}.", color="cyan", force_color=True)
        state.save_state(
            filename=filename,
            checkpoint_dir=os.path.join(self.arguments.save_dir, self.arguments.model_name),
            gather_fns=gather_fns,
            float_dtype=self.dtype,
            verbose=self.arguments.verbose,
            save_optimizer=self.arguments.save_optimizer_state,
        )
        return filename

    @abc.abstractmethod
    def train(self):
        """
        abstract of Train Function to train model
        """

    @abc.abstractmethod
    def eval(self, state):
        """
        abstract of Eval Function to evaluate model
        """

__init__(arguments, dataset_train, dataset_eval=None, finetune=True, checkpoint_path=None, _do_init_fns=True)

The `__init__` function is called when the class is instantiated. It sets up all the variables that are needed for training, including: - The timer used to time and log each setup stage. - The dataloaders for both training and evaluation (if provided). - The model, optimizer, scheduler and config, plus the sharded train/eval functions, device mesh and checkpoint manager. Attributes a subclass has already assigned before calling `__init__` are preserved. The setup steps run only when `_do_init_fns` is true; otherwise call `trainer.initialize_trainer_utils()` manually.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
arguments TrainArguments

TrainArguments: Pass the arguments to the trainer

required
dataset_train Dataset

Dataset: Pass the training dataset to the trainer

required
dataset_eval Dataset

Dataset: Pass the validation dataset

None
finetune bool

bool: Load the model from a checkpoint

True
checkpoint_path Union[str, PathLike]

Union[str,os.PathLike] : Load the checkpoint path

None
_do_init_fns bool

bool: Initialize the functions

True

Returns:

Type Description

Nothing, it just initializes the class

Source code in src/python/easydel/trainer/base_trainer.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def __init__(
        self,
        arguments: TrainArguments,
        dataset_train: Dataset,
        dataset_eval: Dataset = None,
        finetune: bool = True,
        checkpoint_path: Union[str, os.PathLike] = None,
        _do_init_fns: bool = True
):
    """
    Store the training configuration and, optionally, build all trainer utilities.

    Every trainer attribute that a subclass already assigned before calling this
    initializer keeps its value; every other one starts out as ``None``.
    When ``_do_init_fns`` is true, ``initialize_trainer_utils`` is invoked at the end.
    Otherwise a warning is printed and that call is left to the user.

    :param self: Represent the instance of the class
    :param arguments: TrainArguments: Pass the arguments to the trainer
    :param dataset_train: Dataset: The training dataset
    :param dataset_eval: Dataset: The evaluation dataset (optional)
    :param finetune: bool: Stored on the instance; when true and ``checkpoint_path`` is
        None a warning asks you to pass parameters to the train function
    :param checkpoint_path: Union[str,os.PathLike] : Stored checkpoint path
    :param _do_init_fns: bool: Call ``initialize_trainer_utils`` right away
    :return: Nothing, it just initializes the class
    """

    def preserve(*attribute_names):
        # Keep any value assigned before this initializer ran; default the rest to None.
        for attribute_name in attribute_names:
            setattr(self, attribute_name, getattr(self, attribute_name, None))

    # Loggers
    preserve("timer")
    self.wandb_runtime: Run | RunDisabled | None = getattr(self, "wandb_runtime", None)

    # Data
    preserve("dataloader_train", "dataloader_eval", "max_training_steps", "max_evaluation_steps")
    self.dataset_train = dataset_train
    self.dataset_eval = dataset_eval

    # Model, optimizer and scheduler
    preserve("model", "config", "scheduler", "tx", "model_state")

    # LoRA
    self.rapture = arguments.rapture
    preserve("lora_parameters", "lora_model", "lora_tx", "lora_opt_state", "lora_apply_fn")

    # Sharded (pjit) functions and device mesh
    preserve(
        "create_sharded_state_from_params_function",
        "sharded_train_step_function",
        "sharded_eval_step_function",
        "initialize_state_function",
        "mesh",
    )

    # Checkpoint manager
    self.checkpoint_manager: fjformer.CheckpointManager | None = getattr(self, "checkpoint_manager", None)

    # State shape / partitioning / sharded state
    preserve("state_shape", "state_partition_spec", "sharded_state")

    # Run configuration
    self.arguments = arguments
    self.finetune = finetune
    self.checkpoint_path = checkpoint_path
    self.dtype = arguments.dtype
    self.param_dtype = arguments.param_dtype

    if self.arguments.track_memory and not self.arguments.performance_mode:
        initialise_tracking()
        self.arguments._stop_capturing_memory = False
        self._start_capturing_memory().start()

    if finetune and checkpoint_path is None:
        prefix_print(
            "Warning",
            "In case of using `finetune = True` and Passing `checkpoint_path = None`"
            " you should pass parameters in train function"
        )

    if not _do_init_fns:
        prefix_print(
            "Warning",
            "you have set `_do_init_fns = False` so function will not me initialized you have "
            f"to do in manually (simply with `trainer.initialize_trainer_utils()` )"
        )
        return
    self.initialize_trainer_utils()

configure_dataloader()

The configure_dataloader function is used to configure the dataloader for training and evaluation.

Parameters:

Name Type Description Default
self

Refer to the class instance itself

required

Returns:

Type Description
TrainerConfigureDataloaderFuncOutput

A TrainerConfigureDataloaderFuncOutput object

Source code in src/python/easydel/trainer/base_trainer.py
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
def configure_dataloader(self) -> TrainerConfigureDataloaderFuncOutput:
    """
    Build the training (and, if enabled, evaluation) dataloaders and their step counts.

    Sized datasets (those with ``__len__``) go through ``Dataset.to_tf_dataset`` using
    ``create_collate_function``. Other datasets are wrapped with
    ``tf.data.Dataset.from_generator``. Training data is repeated ``num_train_epochs``
    times (evaluation data once), then prefetched and exposed as a numpy iterator.
    The evaluation loader is built only when ``dataset_eval`` is given and
    ``arguments.do_eval`` is true.

    :param self: Refer to the class instance itself
    :return: A TrainerConfigureDataloaderFuncOutput object
    :raises ValueError: if a dataset without ``__len__`` is used and the matching
        ``max_training_steps`` / ``max_evaluation_steps`` argument is not set
    """

    def create_tf_dataset(dataset: Dataset, is_train: bool) -> Iterator[ndarray[Any, Any]]:
        return (
            dataset.to_tf_dataset(
                collate_fn=self.create_collate_function(
                    max_sequence_length=self.arguments.max_sequence_length,
                    truncation_mode=self.arguments.truncation_mode
                ),
                batch_size=self.arguments.total_batch_size,
                drop_remainder=True,
                # Shuffle training data only; evaluation order stays deterministic.
                shuffle=is_train,
                num_workers=self.arguments.dataloader_num_workers
            )
            .repeat(self.arguments.num_train_epochs if is_train else 1)
            .prefetch(tf.data.experimental.AUTOTUNE)
            .as_numpy_iterator()
        )

    def create_tf_dataset_from_iterable(dataset: IterableDataset, is_train: bool) -> Iterator[ndarray[Any, Any]]:
        # The first element is read once to discover the column names. Every column is
        # declared as an int32 vector of length max_sequence_length.
        return (
            tf.data.Dataset.from_generator(
                lambda: dataset,
                output_signature={
                    col: tf.TensorSpec(shape=(self.arguments.max_sequence_length,), dtype=tf.int32)
                    for col in next(iter(dataset)).keys()
                }
            )
            .repeat(self.arguments.num_train_epochs if is_train else 1)
            .batch(self.arguments.total_batch_size, drop_remainder=False)
            .prefetch(tf.data.experimental.AUTOTUNE)
            .as_numpy_iterator()
        )

    def calculate_steps(dataset: Union[Dataset, IterableDataset], is_train: bool):
        """
        Return total number of steps (batches) to train or evaluate on.
        """
        if hasattr(dataset, "__len__"):
            # Sized datasets are batched with drop_remainder=True, so one pass yields
            # len(dataset) // total_batch_size batches (not len(dataset) samples).
            steps_per_pass = len(dataset) // self.arguments.total_batch_size
            num_steps = steps_per_pass * (self.arguments.num_train_epochs if is_train else 1)
            max_steps = self.arguments.max_training_steps if is_train else self.arguments.max_evaluation_steps
            return min(num_steps, max_steps) if max_steps else num_steps
        else:
            num_steps = self.arguments.max_training_steps if is_train else self.arguments.max_evaluation_steps
            if not num_steps:
                raise ValueError(
                    f"Specify the number of {'training' if is_train else 'evaluation'} steps for a generator/streaming dataset.")
            return num_steps

    def to_tf_dataloader(dataset: Union[Dataset, IterableDataset], is_train: bool):
        if hasattr(dataset, "__len__"):
            return create_tf_dataset(dataset, is_train)
        else:
            return create_tf_dataset_from_iterable(dataset, is_train)

    max_training_steps = calculate_steps(self.dataset_train, is_train=True)
    dataloader_train = to_tf_dataloader(self.dataset_train, is_train=True)

    if self.dataset_eval is not None and self.arguments.do_eval:
        max_evaluation_steps = calculate_steps(self.dataset_eval, is_train=False)
        dataloader_eval = to_tf_dataloader(self.dataset_eval, is_train=False)
    else:
        dataloader_eval, max_evaluation_steps = None, 0

    return TrainerConfigureDataloaderFuncOutput(
        dataloader_train=dataloader_train,
        max_training_steps=max_training_steps,
        dataloader_eval=dataloader_eval,
        max_evaluation_steps=max_evaluation_steps
    )

configure_functions() abstractmethod

The configure_functions function is responsible for configuring the functions that will be used in training and evaluation, and must be implemented by subclasses. It returns a TrainerConfigureFunctionFuncOutput object from which the trainer reads: the function that creates a sharded state from parameters, the sharded train and eval step functions, the initialize-state function, the device mesh, and the checkpoint manager.

Parameters:

Name Type Description Default
self

Access the class attributes

required

Returns:

Type Description
TrainerConfigureFunctionFuncOutput

A TrainerConfigureFunctionFuncOutput object

Source code in src/python/easydel/trainer/base_trainer.py
280
281
282
283
284
285
286
287
288
289
290
291
@abc.abstractmethod
def configure_functions(self) -> TrainerConfigureFunctionFuncOutput:
    """
    Build the (sharded) functions used for training and evaluation.

    Subclasses must override this. ``initialize_trainer_utils`` reads these fields from
    the returned object: ``create_sharded_state_from_params_function``,
    ``sharded_train_step_function``, ``sharded_eval_step_function``, ``mesh``,
    ``checkpoint_manager`` and ``initialize_state_function``.

    :param self: Access the class attributes
    :return: A TrainerConfigureFunctionFuncOutput object
    :raises NotImplementedError: always, in this base implementation
    """
    # Base class provides no implementation.
    raise NotImplementedError

configure_model()

The configure_model function is responsible for creating the model, optimizer and scheduler.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required

Returns:

Type Description
TrainerConfigureModelFuncOutput

A model, optimizer, scheduler and config in TrainerConfigureModelFuncOutput Object

Source code in src/python/easydel/trainer/base_trainer.py
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
def configure_model(self) -> TrainerConfigureModelFuncOutput:
    """
    Create the model, optimizer and scheduler, and resolve the model config.

    Two paths are supported:

    * ``arguments.model_class`` is set: the model is instantiated from
      ``arguments.configs_to_initialize_model_class`` (with ``_do_init=False``), and that
      dict's ``"config"`` entry (with ``axis_dims`` set to ``arguments.sharding_array``)
      is used as the config. If the config has no ``get_partition_rules`` attribute,
      ``arguments.custom_rule`` must be provided.
    * otherwise: the model is loaded with ``AutoEasyDeLModelForCausalLM.from_pretrained``
      from ``arguments.model_huggingface_repo_id``. Every entry of
      ``arguments.extra_configs``, plus ``gradient_checkpointing``, is then set as an
      attribute on ``model.config``. If the model has no ``config``, a warning is issued
      and the config is ``None``.

    The optimizer (``tx``) and scheduler come from
    ``arguments.get_optimizer_and_scheduler(self.max_training_steps)``, so
    ``self.max_training_steps`` must already be set.

    :param self: Represent the instance of the class
    :return: A model, optimizer, scheduler and config in TrainerConfigureModelFuncOutput Object
    :raises AssertionError: If a custom ``model_class`` is used, its config lacks
        ``get_partition_rules`` and no ``custom_rule`` was given.
    """
    if self.arguments.model_class is not None:
        init_kwargs = self.arguments.configs_to_initialize_model_class
        config = init_kwargs["config"]

        # Raise explicitly rather than using ``assert`` so the check survives ``python -O``;
        # AssertionError is kept for backward compatibility with existing callers.
        if not hasattr(config, "get_partition_rules") and self.arguments.custom_rule is None:
            raise AssertionError(
                "if you are using custom model to init you must"
                " pass custom_rule for partition rules "
            )

        config.axis_dims = self.arguments.sharding_array

        model = self.arguments.model_class(
            **init_kwargs,
            _do_init=False
        )
    else:
        # Work on a copy so adding ``gradient_checkpointing`` below does not mutate the
        # caller's ``arguments.extra_configs`` dict.
        extra_configs = (
            {} if self.arguments.extra_configs is None
            else dict(self.arguments.extra_configs)
        )
        extra_configs["gradient_checkpointing"] = self.arguments.gradient_checkpointing

        model = AutoEasyDeLModelForCausalLM.from_pretrained(
            self.arguments.model_huggingface_repo_id,
            dtype=self.arguments.dtype,
            param_dtype=self.arguments.param_dtype,
            _do_init=False
        )
        if hasattr(model, "config"):
            for k, v in extra_configs.items():
                setattr(model.config, k, v)
            config = model.config
        else:
            config = None
            warnings.warn(
                "Config is being set to None due to not detecting Model Configuration from taken Model "
                "this will cause errors later."
            )
    tx, scheduler = self.arguments.get_optimizer_and_scheduler(self.max_training_steps)
    return TrainerConfigureModelFuncOutput(
        model=model,
        tx=tx,
        scheduler=scheduler,
        config=config
    )

eval(state) abstractmethod

Abstract evaluation method; subclasses must implement it to evaluate the model.

Source code in src/python/easydel/trainer/base_trainer.py
459
460
461
462
463
@abc.abstractmethod
def eval(self, state):
    """
    Evaluate the model; concrete trainers must provide the implementation.

    The base version does nothing and returns ``None``.

    :param state: The state to evaluate. Presumably the trainer's EasyDeLState;
        confirm against the concrete subclass.
    """
    return None
finish() staticmethod

The finish function is called when the experiment ends. It can be used to save data, upload files, or do any other cleanup tasks.

Returns:

Type Description

None. The method only calls wandb.finish() and does not return a value.

Source code in src/python/easydel/trainer/base_trainer.py
178
179
180
181
182
183
184
185
186
187
@staticmethod
def finish():
    """
    Close the active Weights & Biases run by calling ``wandb.finish()``.

    Intended to be called when the experiment ends. Per the wandb documentation,
    ``wandb.finish()`` marks the run as finished and completes uploading its data.

    :return: None. Nothing is returned; the result of ``wandb.finish()`` is not propagated.
    """
    wandb.finish()

initialize_trainer_utils()

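The initialize_trainer_utils function is responsible for initializing the following: - wandb_runtime (if use_wandb is True) - timer object (for logging time taken by various functions) - dataloader objects for training and evaluation data, along with the maximum training and evaluation steps (via the configure_dataloader function) - the model, optimizer, scheduler and config (via configure_model, plus LoRA modules when rapture is set) - the sharded training/evaluation functions, mesh and checkpoint manager (via configure_functions).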

Parameters:

Name Type Description Default
self

Represent the instance of the class

required

Returns:

Type Description

None. All results are stored as attributes on the trainer instance.

Source code in src/python/easydel/trainer/base_trainer.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def initialize_trainer_utils(self):
    """
    Set up the runtime pieces the trainer needs before training starts.

    First ``self.wandb_runtime`` is set (the result of ``arguments.get_wandb_init()`` when
    ``arguments.use_wandb`` is True, otherwise ``None``). Then ``self.timer`` is created as a
    ``Timers`` that writes to ``arguments.get_board()``. Three timed stages follow, in
    order, and each stores its results on ``self``:

    1. Dataloaders, via ``configure_dataloader``: ``self.dataloader_train``,
       ``self.max_training_steps``, ``self.dataloader_eval`` and ``self.max_evaluation_steps``.
    2. Model, via ``configure_model`` (which reads ``self.max_training_steps`` from stage 1):
       ``self.model``, ``self.tx``, ``self.scheduler`` and ``self.config``. When
       ``self.rapture`` is not ``None``, LoRA is applied as well. That populates
       ``self.lora_parameters``, ``self.lora_apply_fn``, ``self.lora_opt_state``,
       ``self.lora_model`` and ``self.lora_tx``.
    3. Functions, via ``configure_functions``:
       ``self.create_sharded_state_from_params_function``,
       ``self.sharded_train_step_function``, ``self.sharded_eval_step_function``,
       ``self.mesh``, ``self.checkpoint_manager`` and ``self.initialize_state_function``.

    :param self: Represent the instance of the class
    :return: None. All results are stored as attributes on the instance.
    """
    self.wandb_runtime = None
    if self.arguments.use_wandb:
        self.wandb_runtime = self.arguments.get_wandb_init()
    self.timer = Timers(
        use_wandb=False,
        tensorboard_writer=self.arguments.get_board()
    )

    # Each label is used for start/stop/log. Keeping it in one variable prevents the
    # three copies from drifting apart.
    dataloader_stage = "configure dataloaders"
    model_stage = "configure Model, Optimizer, Scheduler and Config"
    functions_stage = "configure functions and sharding them"

    # Stage 1: dataloaders and step counts.
    self.timer(dataloader_stage).start()
    dataset_configurations = self.configure_dataloader()
    self.dataloader_train = dataset_configurations.dataloader_train
    self.max_training_steps = dataset_configurations.max_training_steps
    self.dataloader_eval = dataset_configurations.dataloader_eval
    self.max_evaluation_steps = dataset_configurations.max_evaluation_steps

    self.timer(dataloader_stage).stop()

    self.timer.log([dataloader_stage])

    # Stage 2: model / optimizer / scheduler / config (needs max_training_steps above).
    self.timer(model_stage).start()
    model_configurations = self.configure_model()
    model = model_configurations.model
    tx = model_configurations.tx
    self.model = model
    self.tx = tx
    self.scheduler = model_configurations.scheduler
    self.config = model_configurations.config
    if self.rapture is not None:
        lora_modules = self.rapture.apply_lora(
            module=model,
            parameters=self.arguments.rapture_config.parameters,
            tx=tx,
        )
        self.lora_parameters = lora_modules.lora_parameters
        self.lora_apply_fn = lora_modules.lora_module.__call__
        self.lora_opt_state = lora_modules.lora_opt_state
        self.lora_model = lora_modules.lora_module
        self.lora_tx = lora_modules.lora_tx

    self.timer(model_stage).stop()
    self.timer.log([model_stage])

    # Stage 3: compiled/sharded functions, mesh and checkpointing.
    self.timer(functions_stage).start()
    function_configurations = self.configure_functions()
    self.create_sharded_state_from_params_function = \
        function_configurations.create_sharded_state_from_params_function
    self.sharded_train_step_function = function_configurations.sharded_train_step_function
    self.sharded_eval_step_function = function_configurations.sharded_eval_step_function
    self.mesh = function_configurations.mesh
    self.checkpoint_manager = function_configurations.checkpoint_manager
    self.initialize_state_function = function_configurations.initialize_state_function
    self.timer(functions_stage).stop()
    self.timer.log([functions_stage])

train() abstractmethod

Abstract training method; subclasses must implement it to train the model.

Source code in src/python/easydel/trainer/base_trainer.py
453
454
455
456
457
@abc.abstractmethod
def train(self):
    """
    Train the model; concrete trainers must provide the implementation.

    The base version does nothing and returns ``None``.
    """
    return None