
etils.easystate

EasyDeLState

Bases: PyTreeNode

Source code in src/python/easydel/etils/easystate.py
class EasyDeLState(struct.PyTreeNode):
    step: int
    module: Optional["EasyDeLFlaxPretrainedModel"] = struct.field(pytree_node=False)  # type:ignore
    module_config: Optional["EasyDeLPretrainedConfig"] = struct.field(pytree_node=False)  # type:ignore
    module_config_args: Optional[dict] = struct.field(pytree_node=True)
    apply_fn: Callable = struct.field(pytree_node=False)
    params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: Optional[optax.OptState] = struct.field(pytree_node=True)
    tx_init: Optional[dict] = struct.field(pytree_node=True)
    hyperparameters: Optional[dict] = struct.field(pytree_node=True)

    def apply_gradients(self, *, grads, **kwargs):

        """
        The apply_gradients function is the core of the optimizer. It takes in a dictionary of gradients,
        and returns an updated version of itself with new parameters and state. The function also updates
        the step count.

        :param self: Refer to the current instance of the class
        :param *: Unpack the grads dictionary into positional arguments
        :param grads: Pass in the gradients of the loss function with respect to each parameter
        :param kwargs: Pass in additional arguments to the function
        :return: A new State with the updated parameters and params
        """
        if OVERWRITE_WITH_GRADIENT in grads:
            grads_with_opt = grads['params']
            params_with_opt = self.params['params']
        else:
            grads_with_opt = grads
            params_with_opt = self.params

        updates, new_opt_state = self.tx.update(
            grads_with_opt, self.opt_state, params_with_opt
        )
        new_params_with_opt = optax.apply_updates(params_with_opt, updates)
        if OVERWRITE_WITH_GRADIENT in grads:
            new_params = {
                'params': new_params_with_opt,
                OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT]
            }
        else:
            new_params = new_params_with_opt
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    @classmethod
    def create(
            cls,
            *,
            apply_fn: Callable,
            params: Union[core.FrozenDict[str, Any], Mapping[str, Any]],
            tx: optax.GradientTransformation,
            tx_init: Optional[dict] = None,
            hyperparameters: Optional[dict] = None,
            module: Optional["EasyDeLFlaxPretrainedModel"] = None,  # type:ignore
            module_config: Optional["EasyDeLPretrainedConfig"] = None,  # type:ignore
            module_config_args: Optional[dict] = None,
            **kwargs
    ):

        """
        The create function is used to create a new instance of the class.

        :param cls: Create a new instance of the class
        :param *: Pass a list of parameters to the function
        :param apply_fn: Callable: Apply the model to a batch of data
        :param params: core.FrozenDict[str,Any] | Mapping[str,Any]: Pass in the parameters of the model
        :param tx: optax.GradientTransformation: Initialize the optimizer
        :param tx_init: Optional[dict]: Initialize the optimizer
        :param hyperparameters: Optional[dict]: Pass hyperparameters to the state for init
        :param module: Optional[EasyDeLFlaxPretrainedModel]: Pass the module to be used int state
        :param module_config: Optional[EasyDeLPretrainedConfig]: Pass in the module config
        :param module_config_args: Optional[dict]: Store the config args of the model
        :param kwargs: Pass in additional parameters to the
        :return: A EasyDeLState object
        """
        if hyperparameters is None:
            hyperparameters = {}
        params_with_opt = (
            params['params'] if OVERWRITE_WITH_GRADIENT in params else params
        )
        opt_state = tx.init(params_with_opt)
        if module_config is not None:
            module_config = copy.deepcopy(module_config)
            cls.safe_dict(module_config.__dict__)
        return cls(
            step=0,
            apply_fn=apply_fn,
            module=module,
            params=params,
            tx=tx,
            opt_state=opt_state,
            tx_init=cls.safe_dict(tx_init),
            hyperparameters=hyperparameters,
            module_config=module_config,
            module_config_args=None,
            **kwargs,
        )

    @classmethod
    def load(
            cls,
            *,
            apply_fn: Callable,
            params: Union[core.FrozenDict[str, Any], Mapping[str, Any]],
            step: int = 0,
            opt_state: Optional[optax.OptState] = None,
            tx_init: Optional[dict] = None,
            hyperparameters: Optional[dict] = None,
            module: Optional["EasyDeLFlaxPretrainedModel"] = None,  # type:ignore
            module_config: Optional["EasyDeLPretrainedConfig"] = None,  # type:ignore
            module_config_args: Optional[dict] = None,
            **kwargs
    ):

        """
        The load function is used to load a saved state of the Model and optimizer or Model Only.

        :param cls: Make the function a class method
        :param *: Pass in a variable number of arguments
        :param step: int: Keep track of the number of steps that have been taken
        :param apply_fn: Callable: Apply the optimizer to the model
        :param params: core.FrozenDict[str,Any] | Mapping[str,Any]: Pass in the parameters of the model
        :param opt_state: Optional[optax.OptState]: optimizer state
        :param tx_init: Optional[dict]: Pass the hyperparameters to the optimizer
        :param hyperparameters: Optional[dict]: Load hyperparameters from the state dict
        :param module: Optional[EasyDeLFlaxPretrainedModel]: Pass in the module
        :param module_config: Optional[EasyDeLPretrainedConfig]: Pass the module config
        :param module_config_args: Optional[dict]: Pass the config_args to the model
        :param kwargs: Pass in any additional parameters that may be needed for the model
        :return: A new instance of the class
        """
        if module_config is not None:
            module_config = copy.deepcopy(module_config)

        if tx_init is None:
            tx_init = {}
        tx_init = copy.deepcopy(tx_init)
        tx_init = cls.unsafe_dict(tx_init)

        tx_init["optimizer"] = cls.search("optimizer", tx_init, "adamw")
        tx_init["scheduler"] = cls.search("scheduler", tx_init, "none")
        tx_init["steps"] = cls.search("steps", tx_init, 1e6)

        def fix_dict_types(input_dict):
            fixed_dict = input_dict.copy()

            # Fix extra_optimizer_kwargs
            if 'extra_optimizer_kwargs' in fixed_dict:
                fixed_dict['extra_optimizer_kwargs'] = eval(fixed_dict['extra_optimizer_kwargs'])

            # Fix gradient_accumulation_steps
            if 'gradient_accumulation_steps' in fixed_dict:
                fixed_dict['gradient_accumulation_steps'] = int(fixed_dict['gradient_accumulation_steps'])

            # Fix steps
            if 'steps' in fixed_dict:
                fixed_dict['steps'] = int(fixed_dict['steps'])

            # Fix warmup_steps
            if 'warmup_steps' in fixed_dict:
                fixed_dict['warmup_steps'] = int(fixed_dict['warmup_steps'])

            return fixed_dict

        try:
            tx, sc = get_optimizer_and_scheduler(
                **tx_init
            )
        except TypeError:
            tx, sc = get_optimizer_and_scheduler(
                **fix_dict_types(tx_init)
            )
        if hyperparameters is None:
            hyperparameters = {}

        if module_config is not None:
            hyperparameters = cls.create_hyperparameters(module_config.model_type)
            cls.safe_dict(module_config.__dict__)
        return cls(
            step=step,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=opt_state,
            tx_init=cls.safe_dict(tx_init),
            hyperparameters=hyperparameters,
            module=module,
            module_config=module_config,
            module_config_args=None,
            **kwargs,
        )

    @classmethod
    def load_state(
            cls,
            checkpoint_path: Union[str, os.PathLike],
            dtype: jnp.dtype = jnp.float32,
            param_dtype: jnp.dtype = jnp.float32,
            precision: Optional[Union[str, jax.lax.Precision]] = None,
            init_optimizer_state: bool = False,
            state_shard_fns: Optional[Mapping[str, Callable]] = None,
            verbose: bool = False,
            input_shape: Tuple = (1, 1),
            config_kwargs: Optional[dict] = None
    ):

        """    
        The load_state function is a class method that loads the state of an EasyDeLModel from a checkpoint.

        :param cls: Create an instance of the class
        :param checkpoint_path: str | os.PathLike: Specify the path to the checkpoint file
        :param dtype: jnp.dtype: The dtype of the model
        :param param_dtype: jnp.dtype: The dtype of the model parameters
        :param precision: Optional[Union[str, jax.lax.Precision]]: precision of the model
        :param init_optimizer_state: bool: Initialize the optimizer if it's not Initialized yet (if it Initialized the option
        will be ignored )
        :param state_shard_fns: Optional[Mapping[str,Callable]]: Specify the function that will be used 
        to shard the loaded state
        :param verbose: bool: Print out the progress of loading
        :param input_shape: Tuple: input_shape to init module
        :param config_kwargs: Optional[dict] : config kwargs to be passed to model config
        :return: A state object
        """
        from ..modules.auto_easydel_model import get_modules_by_type

        checkpoint = fjformer.CheckpointManager.load_checkpoint(
            path=checkpoint_path,
            shard_fns=state_shard_fns,
            verbose=verbose,
        )
        hyperparameters = checkpoint.get("hyperparameters")
        cfg, module, convertor = get_modules_by_type(model_type=cls.get_model_type(hyperparameters))
        checkpoint.pop("module_config", None)
        if checkpoint["module_config_args"] is not None:
            cfg_behave = cls.unsafe_dict(checkpoint.get("module_config_args", {}))
            cfg_behave.pop("id2label", None)
            cfg_behave.pop("label2id", None)
            cfg_behave.pop("torch_dtype", None)
            for k, v in cfg_behave.items():
                if v is None:
                    cfg_behave.pop(k, None)
                elif v == "None":
                    cfg_behave[k] = None
                elif isinstance(v, str):
                    if v.startswith("{") or v.startswith("(") or v.startswith("PartitionSpec"):
                        cfg_behave[k] = eval(v)
            module_config = cfg.from_dict(cfg_behave)
            if config_kwargs is not None:
                for k, v in config_kwargs.items():
                    setattr(module_config, k, v)
            module_in = module(
                config=module_config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                input_shape=input_shape
            )
        else:
            raise TypeError(
                "Om seems like i couldn't read model correctly ;("
            )
        state = cls.load(
            apply_fn=module_in.__call__,
            module=module_in,
            module_config=module_config,
            **checkpoint
        )
        state = state.replace(
            module_config_args=None  # removing because it's not needed anymore
        )
        if init_optimizer_state:
            state = state.init_opt_state()
        return state

    @classmethod
    def get_model_type(cls, dictionary):
        return cls.find_key("model_type", dictionary)

    def save_state(
            self,
            filename: Union[str, os.PathLike],
            save_optimizer: bool = False,
            checkpoint_dir: Optional[Union[str, os.PathLike]] = None,
            verbose: bool = False,
            gather_fns: dict[Callable] = None,
            float_dtype: Union[str, jax.numpy.dtype] = None,
    ):

        """
        The save_state function saves the state of a model to disk.

        :param self: Pass the object itself to the function
        :param filename: str | os.PathLike: Specify the name of the file to save
        :param save_optimizer: bool: Determine whether to save the optimizer state or not
        :param checkpoint_dir: Optional[str | os.PathLike]: Specify the directory where the checkpoint is saved
        :param verbose: bool: Print out the path of the saved file
        :param gather_fns: dict[Callable]: Specify a dictionary of functions that can be used to gather
        :param float_dtype: str | jax.numpy.dtype: Specify the precision of the saved model
        :param : Save the optimizer state
        :return: None
        """
        state = self
        if not save_optimizer:
            state = self.replace(
                opt_state=None
            )
        state = state.replace(
            module_config_args={
                k: v for k, v in state.module.config.__dict__.items() if
                isinstance(
                    v, (int, bool, float)
                )
            }
        )
        fjformer.CheckpointManager.save_state_to_file(
            state=state,
            path=os.path.join(checkpoint_dir, filename) if checkpoint_dir is not None else filename,
            verbose=verbose,
            gather_fns=gather_fns,
            float_dtype=float_dtype,
        )

    def free_opt_state(self) -> "EasyDeLState":

        """
        The free_opt_state function is used to free the memory allocated by a previous call to setopt.
        It should be called after all the options have been set, and before you perform any of the transfers.


        :param self: Represent the instance of the class
        :return: A new state with the opt_state field set to none
        """
        return self.replace(
            opt_state=None
        )

    def init_opt_state(self) -> "EasyDeLState":

        """
        The init_opt_state function initializes the optimizer state.
        :param self: Make the object callable, and params is used to pass in a dictionary of parameters
        :return: A new instance of the class with opt_state initialized
        """
        if self.opt_state is None:
            params_with_opt = (
                self.params['params'] if OVERWRITE_WITH_GRADIENT in self.params else self.params
            )
            opt_state = self.tx.init(params_with_opt)

            return self.replace(
                opt_state=opt_state
            )
        return self

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: str,
            filename: Optional[str] = None,
            optimizer: AVAILABLE_OPTIMIZERS = "adamw",
            scheduler: AVAILABLE_SCHEDULERS = "none",
            tx_init: Optional[dict] = None,
            device=jax.devices('cpu')[0],
            dtype: jax.numpy.dtype = jax.numpy.float32,
            param_dtype: jax.numpy.dtype = jax.numpy.float32,
            precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"),
            sharding_axis_dims: Sequence[int] = (1, -1, 1, 1),
            sharding_axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"),
            query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            generation_query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "tp", None, None),
            key_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            value_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
            generation_bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
            attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            shard_attention_computation: bool = True,
            input_shape: Sequence[int] = (1, 1),
            backend: Optional[str] = None,
            init_optimizer_state: bool = False,
            free_optimizer_state: bool = True,
            verbose: bool = True,
            state_shard_fns: Optional[Mapping[str, Callable]] = None,
            config_kwargs: Optional[Mapping[str, Any]] = None,
            **kwargs
    ) -> "EasyDeLState":

        """
        The from_pretrained function is a helper function to quickly load a pretrained model and its associated configuration.
        This method takes care of returning the correct model class instance based on the `model_type` property in the
        config object, or when it's missing, falling back to using pattern matching on the
         `pretrained_model_name_or_path` string:

        :param cls: Refer to the class that is being defined
        :param pretrained_model_name_or_path: str: Load the pretrained model
        :param filename: Optional[str]: Specify the name of the file to download from huggingface hub
        :param optimizer: AVAILABLE_OPTIMIZERS: Specify the optimizer used for training
        :param scheduler: AVAILABLE_SCHEDULERS: Specify the name of the scheduler to use
        :param tx_init: Optional[dict]: Pass the hyperparameters of the optimizer
        :param device: Specify the device on which to run the model
        :param dtype: jax.numpy.dtype: Specify the dtype of the model parameters
        :param param_dtype: jax.numpy.dtype: Specify the data type of the parameters
        :param precision: jax.lax.Precision: Control the precision of the calculation
        :param sharding_axis_dims: Sequence[int]: Specify the dimension of each axis
        :param sharding_axis_names: Sequence[str]: Specify the names of the axes in each shard
        :param query_partition_spec: PartitionSpec: Specify the partitioning of the query matrix
        :param generation_query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor
        in the generation process
        :param key_partition_spec: PartitionSpec: Specify the partitioning of the key matrix
        :param value_partition_spec: PartitionSpec: Specify the partitioning of the value tensor
        :param bias_partition_spec: PartitionSpec: Specify the partitioning of the bias
        :param generation_bias_partition_spec: PartitionSpec: Specify the partitioning of the bias in the generation process
        :param attention_partition_spec: PartitionSpec: Partition the attention weights
        :param shard_attention_computation: bool: Determine whether to use shard_map or not
        :param input_shape: Sequence[int]: Specify the shape of the input to be used for training
        :param backend: Optional[str]: Specify the backend used for the model
        :param init_optimizer_state: bool: Initialize the optimizer state
        :param free_optimizer_state: bool: Free the optimizer state from memory
        :param verbose: bool: Print the progress of loading the model
        :param state_shard_fns: Optional[Mapping[str,Callable]]: Specify the function to use for sharding the state
        :param kwargs: Pass keyword arguments to the function
        :param config_kwargs: Optional[Mapping[str, Any]]: Config kwargs to be added to config before creating module
        :return: An `EasyDeLState` object
        """
        if free_optimizer_state and init_optimizer_state:
            raise EasyDeLRuntimeError(
                "You can't use `free_optimizer_state` and `init_optimizer_state` True at same Time"
            )

        if filename is None:
            from ..modules.auto_easydel_model import AutoEasyDeLModelForCausalLM

            model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                device=device,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                sharding_axis_dims=sharding_axis_dims,
                sharding_axis_names=sharding_axis_names,
                query_partition_spec=query_partition_spec,
                generation_query_partition_spec=generation_query_partition_spec,
                generation_bias_partition_spec=generation_bias_partition_spec,
                key_partition_spec=key_partition_spec,
                value_partition_spec=value_partition_spec,
                bias_partition_spec=bias_partition_spec,
                attention_partition_spec=attention_partition_spec,
                shard_attention_computation=shard_attention_computation,
                input_shape=input_shape,
                backend=backend,
                config_kwargs=config_kwargs,
                **kwargs
            )
            if tx_init is None:
                tx_init = {}

            tx_init["optimizer"] = optimizer
            tx_init["scheduler"] = scheduler

            state = cls.load(
                apply_fn=model.__call__,
                params=FrozenDict({'params': params}),
                step=0,
                opt_state=None,
                tx_init=tx_init,
                hyperparameters=None,
                module=model,
                module_config=model.config,
                module_config_args=model.config.to_dict()
            )
        else:
            with jax.default_device(device):
                from huggingface_hub import hf_hub_download
                checkpoint_path = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename=filename,
                )
                state = cls.load_state(
                    checkpoint_path=checkpoint_path,
                    init_optimizer_state=init_optimizer_state,
                    verbose=verbose,
                    state_shard_fns=state_shard_fns,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    precision=precision,
                    input_shape=input_shape
                )
        if init_optimizer_state:
            with jax.default_device(device):
                state = state.init_opt_state()
        if free_optimizer_state:
            state = state.free_opt_state()
        return state

    def shard_params(
            self,
            fully_sharded_data_parallel: bool = True,
            shard_fns: Optional[Mapping[str, Callable]] = None,
            dtype: Union[jax.numpy.dtype, str] = "bf16",
            mesh: Optional[Mesh] = None,
            rules: Optional[Sequence[Mapping[str, PartitionSpec]]] = None
    ):
        dtype = fjformer.get_dtype(dtype)
        if shard_fns is None and self.module_config is None and rules is None:
            raise EasyDeLRuntimeError(
                "the model doesn't carrying `module_config` you should pass `shard_fns` or `rules`"
            )
        elif shard_fns is None and rules is not None or self.module_config is not None:
            from fjformer import match_partition_rules, make_shard_and_gather_fns
            rules = rules or self.module_config.get_partition_rules(fully_sharded_data_parallel)
            partition_specs = match_partition_rules(
                rules=rules, params=self.params
            )
            shard_fns, gather_fns = make_shard_and_gather_fns(
                partition_specs=partition_specs,
                dtype_specs=dtype
            )
        if mesh is None:
            mesh = self.module_config.jax_mesh()
        with mesh:
            return self.replace(
                params=jax.tree_util.tree_map(
                    lambda f, p: f(p), shard_fns, self.params
                )
            )

    @staticmethod
    def create_hyperparameters(model_type: str):
        """
        it's the only way we can dump xla compiler
        """
        return {
            STRING_REP.format(
                type="str",
                key="model_type",
                value=model_type
            ): DEFAULT_ES_VAL
        }

    @staticmethod
    def safe_dict(dictionary: dict):
        for k in list(dictionary.keys()):
            val = dictionary.get(k)
            if not isinstance(val, (int, bool)):
                val = dictionary.pop(k)
                string_value_format = STRING_REP.format(
                    type=type(val).__name__,
                    key=k,
                    value=val
                )
                dictionary[string_value_format] = DEFAULT_ES_VAL
        return dictionary

    @staticmethod
    def unsafe_dict(dictionary: dict):
        result = {}
        for k in list(dictionary.keys()):
            if VALUE_SEP in k and TYPE_SEP in k:
                v = dictionary[k]
                key, value = break_format(key=k, value=v)
                result[key] = value
            else:
                result[k] = dictionary[k]
        return result

    def __str__(self):

        """
        The __str__ function is called when you call str(object) or print(object).
        The __repr__ function is called when you type the object name in the interpreter.
        If no __str__ method exists, Python will use __repr__ as a fallback.

        :param self: Refer to the object itself
        :return: string
        """
        params_size = sum(getattr(n, "size", 0) for n in jax.tree_util.tree_flatten(self.params)[0])
        opt_state_size = sum(getattr(n, "size", 0) for n in jax.tree_util.tree_flatten(self.opt_state)[0])

        def make_depth(mdl=None):
            if mdl is not None:
                try:
                    return mdl.__str__().replace(
                        "\n",
                        "\n\t"
                        ""
                    ) if hasattr(mdl, "__str__") else None
                except TypeError:
                    ...
            return mdl

        optimizer = self.tx_init.get("optimizer", None)
        scheduler = self.tx_init.get("scheduler", None)

        if optimizer is None:
            optimizer = self.find_key(
                "optimizer",
                self.tx_init
            )
        if scheduler is None:
            scheduler = self.find_key(
                "scheduler",
                self.tx_init
            )

        string = (
            f"{self.__class__.__name__}("
            f"\n\tstep = {self.step}"
            f"\n\tmodule = {make_depth(self.module)}"
            f"\n\tmodule_config = {make_depth(self.module_config)}"
            f"\n\tapply_fn: Callable = {make_depth(self.apply_fn)}"
            f"\n\tparams : {params_size} Parameters"
            f"\n\ttx = {optimizer} Optimizer with {scheduler} Scheduler"
            f"\n\topt_state : {opt_state_size} Parameters"
            f"\n\thyperparameters : {self.hyperparameters}"
            f"\n)"
        )
        return string

    @classmethod
    def search(cls, key, dictionary: dict, default: Any = None):
        req = dictionary.get(key, None)
        if req is None:
            req = cls.find_key(key, dictionary)
        return req or default

    @staticmethod
    def find_key(key, dictionary: dict) -> Union[str, None]:
        result = None
        for k, v in dictionary.items():
            k_, v_ = break_format(key=k, value=v)
            if k_ == key:
                result = v_
                break
        return result

    def __repr__(self):

        """
        The __repr__ function is the "official" string representation of an object.
        It's what you get when you type the object name at the Python prompt, or pass it to str().
        The goal of __repr__ is to be unambiguous: if eval(repr(x)) == x, then __repr__ should return a string that
        looks like a valid Python expression that could be used to recreate an object with the same value (
        given an appropriate environment). If this is not possible, a string formatted using %s
        formatting is also acceptable.

        :param self: Represent the instance of the class
        :return: A string that is a valid python expression
        """
        return self.__str__()
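
Because EasyDeLState is a flax struct.PyTreeNode, the fields declared with pytree_node=True (params, opt_state, tx_init, hyperparameters, module_config_args) travel through JAX transformations as pytree leaves, while module, apply_fn, tx, and module_config are static metadata. A minimal sketch of what that implies, assuming state is an already-built EasyDeLState whose params contain only arrays:

import jax
import jax.numpy as jnp

# `state` is assumed to be an existing EasyDeLState instance.
# Flattening the state yields only the pytree-node leaves (params, opt_state, ...);
# static fields such as `module`, `apply_fn` and `tx` live in the treedef.
leaves, treedef = jax.tree_util.tree_flatten(state)
print(f"{len(leaves)} leaves in the state pytree")

# Cast every parameter to bfloat16 without touching the static fields.
state = state.replace(
    params=jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), state.params)
)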

__repr__()

The repr function is the "official" string representation of an object. It's what you get when you type the object name at the Python prompt, or pass it to str(). The goal of repr is to be unambiguous: if eval(repr(x)) == x, then repr should return a string that looks like a valid Python expression that could be used to recreate an object with the same value ( given an appropriate environment). If this is not possible, a string formatted using %s formatting is also acceptable.

Returns:

- A string representation of the state (delegates to __str__).

__str__()

The str function is called when you call str(object) or print(object). The repr function is called when you type the object name in the interpreter. If no str method exists, Python will use repr as a fallback.

Returns:

- A string summarizing the state (step, module, parameter counts, optimizer, and scheduler).

apply_gradients(*, grads, **kwargs)

The apply_gradients function is the core of the optimizer step. It takes a pytree of gradients, applies the optimizer update, and returns a new state with updated parameters, a new optimizer state, and an incremented step count.

Parameters:

- grads (required): Gradients of the loss with respect to each parameter.
- **kwargs: Additional fields to pass through to replace.

Returns:

- A new EasyDeLState with the updated parameters and optimizer state.
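
As a rough illustration, a hedged training-step sketch of how apply_gradients is usually driven; the call signature of apply_fn, the batch layout, and the next-token loss below are placeholders, and the optimizer state is assumed to be initialized (see init_opt_state):

import jax
import optax

def train_step(state, batch):
    # The exact inputs `apply_fn` expects depend on the wrapped module;
    # `input_ids` and the causal-LM loss below are illustrative placeholders.
    def loss_fn(params):
        logits = state.apply_fn(params=params, input_ids=batch["input_ids"]).logits
        return optax.softmax_cross_entropy_with_integer_labels(
            logits[:, :-1], batch["input_ids"][:, 1:]
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # apply_gradients applies the optax update and increments `step`.
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss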

create(*, apply_fn, params, tx, tx_init=None, hyperparameters=None, module=None, module_config=None, module_config_args=None, **kwargs) classmethod

The create function builds a new EasyDeLState from a model, its parameters, and an optax optimizer, initializing the optimizer state in the process.

Parameters:

- apply_fn (Callable, required): Apply the model to a batch of data.
- params (FrozenDict[str, Any] | Mapping[str, Any], required): The parameters of the model.
- tx (optax.GradientTransformation, required): The optimizer to use.
- tx_init (Optional[dict], default None): Arguments used to initialize the optimizer.
- hyperparameters (Optional[dict], default None): Hyperparameters to store in the state.
- module (Optional[EasyDeLFlaxPretrainedModel], default None): The module to be used in the state.
- module_config (Optional[EasyDeLPretrainedConfig], default None): The module config.
- module_config_args (Optional[dict], default None): The config args of the model.
- **kwargs: Additional fields passed to the constructor.

Returns:

- An EasyDeLState object.
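
A minimal sketch of building a state directly from a module and an optax optimizer; model and params are placeholders for an already-initialized EasyDeL Flax module and its parameter pytree:

import optax

# `model` and `params` are placeholders (e.g. returned by
# AutoEasyDeLModelForCausalLM.from_pretrained).
tx = optax.adamw(learning_rate=1e-5)

state = EasyDeLState.create(
    apply_fn=model.__call__,
    params=params,
    tx=tx,
    tx_init={"optimizer": "adamw", "scheduler": "none", "steps": 10_000},
    module=model,
    module_config=model.config,
)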

create_hyperparameters(model_type) staticmethod

Encodes model_type as a hyperparameters entry; this is the only way to carry it through the XLA compiler as part of the state.

free_opt_state()

The free_opt_state function frees the memory held by the optimizer state by returning a copy of the state with opt_state set to None.

Returns:

- EasyDeLState: A new state with the opt_state field set to None.
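
A small sketch of dropping the optimizer state, for example before inference or a parameters-only save; state stands in for an existing EasyDeLState:

# Freeing the optimizer state releases the memory held by optimizer moments
# (e.g. Adam's first and second moments).
state = state.free_opt_state()
assert state.opt_state is None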

from_pretrained(pretrained_model_name_or_path, filename=None, optimizer='adamw', scheduler='none', tx_init=None, device=jax.devices('cpu')[0], dtype=jax.numpy.float32, param_dtype=jax.numpy.float32, precision=jax.lax.Precision('fastest'), sharding_axis_dims=(1, -1, 1, 1), sharding_axis_names=('dp', 'fsdp', 'tp', 'sp'), query_partition_spec=PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), generation_query_partition_spec=PartitionSpec(('dp', 'fsdp'), 'tp', None, None), key_partition_spec=PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), value_partition_spec=PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), bias_partition_spec=PartitionSpec(('dp', 'fsdp'), None, None, None), generation_bias_partition_spec=PartitionSpec(('dp', 'fsdp'), None, None, None), attention_partition_spec=PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), shard_attention_computation=True, input_shape=(1, 1), backend=None, init_optimizer_state=False, free_optimizer_state=True, verbose=True, state_shard_fns=None, config_kwargs=None, **kwargs) classmethod

The from_pretrained function is a helper function to quickly load a pretrained model and its associated configuration. This method takes care of returning the correct model class instance based on the model_type property in the config object, or, when it's missing, falling back to pattern matching on the pretrained_model_name_or_path string.

Parameters:

- pretrained_model_name_or_path (str, required): The pretrained model to load.
- filename (Optional[str], default None): The name of the file to download from the HuggingFace Hub.
- optimizer (AVAILABLE_OPTIMIZERS, default 'adamw'): The optimizer used for training.
- scheduler (AVAILABLE_SCHEDULERS, default 'none'): The name of the scheduler to use.
- tx_init (Optional[dict], default None): Hyperparameters of the optimizer.
- device (default jax.devices('cpu')[0]): The device on which to load the model.
- dtype (jax.numpy.dtype, default float32): The dtype of the model.
- param_dtype (jax.numpy.dtype, default float32): The data type of the parameters.
- precision (Optional[jax.lax.Precision], default Precision('fastest')): The precision of the computation.
- sharding_axis_dims (Sequence[int], default (1, -1, 1, 1)): The dimension of each sharding axis.
- sharding_axis_names (Sequence[str], default ('dp', 'fsdp', 'tp', 'sp')): The names of the sharding axes.
- query_partition_spec (PartitionSpec, default PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None)): Partitioning of the query matrix.
- generation_query_partition_spec (PartitionSpec, default PartitionSpec(('dp', 'fsdp'), 'tp', None, None)): Partitioning of the query tensor in the generation process.
- key_partition_spec (PartitionSpec, default PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None)): Partitioning of the key matrix.
- value_partition_spec (PartitionSpec, default PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None)): Partitioning of the value tensor.
- bias_partition_spec (PartitionSpec, default PartitionSpec(('dp', 'fsdp'), None, None, None)): Partitioning of the bias.
- generation_bias_partition_spec (PartitionSpec, default PartitionSpec(('dp', 'fsdp'), None, None, None)): Partitioning of the bias in the generation process.
- attention_partition_spec (PartitionSpec, default PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None)): Partitioning of the attention weights.
- shard_attention_computation (bool, default True): Whether to use shard_map for the attention computation.
- input_shape (Sequence[int], default (1, 1)): The shape of the input used to initialize the module.
- backend (Optional[str], default None): The backend used for the model.
- init_optimizer_state (bool, default False): Initialize the optimizer state.
- free_optimizer_state (bool, default True): Free the optimizer state from memory.
- verbose (bool, default True): Print the progress of loading the model.
- state_shard_fns (Optional[Mapping[str, Callable]], default None): Functions used to shard the state.
- config_kwargs (Optional[Mapping[str, Any]], default None): Config kwargs added to the config before creating the module.
- **kwargs: Additional keyword arguments.

Returns:

- EasyDeLState: An EasyDeLState object.
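
A minimal loading sketch; "organization/model-id" is a placeholder Hub repository id and the top-level import path is assumed from the package layout:

import jax
from easydel import EasyDeLState  # import path assumed

state = EasyDeLState.from_pretrained(
    "organization/model-id",          # placeholder HuggingFace Hub repo id
    dtype=jax.numpy.bfloat16,
    param_dtype=jax.numpy.bfloat16,
    free_optimizer_state=True,        # keep memory low; call init_opt_state() later for training
    verbose=True,
)
print(state)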

init_opt_state()

The init_opt_state function initializes the optimizer state.

Returns:

- EasyDeLState: A new instance of the class with opt_state initialized (or self if it already is).
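
A small sketch of re-materializing the optimizer state before training resumes; this is a no-op when opt_state is already present:

# `state` is an existing EasyDeLState whose optimizer state was freed or never built.
state = state.init_opt_state()
assert state.opt_state is not None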

load(*, apply_fn, params, step=0, opt_state=None, tx_init=None, hyperparameters=None, module=None, module_config=None, module_config_args=None, **kwargs) classmethod

The load function is used to load a saved state of the model and optimizer, or of the model only.

Parameters:

- apply_fn (Callable, required): Apply the model to a batch of data.
- params (FrozenDict[str, Any] | Mapping[str, Any], required): The parameters of the model.
- step (int, default 0): The number of steps that have been taken.
- opt_state (Optional[optax.OptState], default None): The optimizer state.
- tx_init (Optional[dict], default None): The hyperparameters passed to the optimizer.
- hyperparameters (Optional[dict], default None): Hyperparameters loaded from the state dict.
- module (Optional[EasyDeLFlaxPretrainedModel], default None): The module.
- module_config (Optional[EasyDeLPretrainedConfig], default None): The module config.
- module_config_args (Optional[dict], default None): The config_args passed to the model.
- **kwargs: Any additional fields that may be needed for the model.

Returns:

- A new instance of the class.
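
A minimal sketch of rebuilding a state from raw pieces, for example after restoring the weights yourself; model and params are placeholders, and the tx_init keys shown mirror the defaults the loader falls back to:

from flax.core import FrozenDict

# `model` and `params` are placeholders for a restored module and its weights.
state = EasyDeLState.load(
    apply_fn=model.__call__,
    params=FrozenDict({"params": params}),
    step=0,
    tx_init={"optimizer": "adamw", "scheduler": "none", "steps": 50_000},
    module=model,
    module_config=model.config,
)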

Source code in src/python/easydel/etils/easystate.py
@classmethod
def load(
        cls,
        *,
        apply_fn: Callable,
        params: Union[core.FrozenDict[str, Any], Mapping[str, Any]],
        step: int = 0,
        opt_state: Optional[optax.OptState] = None,
        tx_init: Optional[dict] = None,
        hyperparameters: Optional[dict] = None,
        module: Optional["EasyDeLFlaxPretrainedModel"] = None,  # type:ignore
        module_config: Optional["EasyDeLPretrainedConfig"] = None,  # type:ignore
        module_config_args: Optional[dict] = None,
        **kwargs
):

    """
    The load function is used to load a saved state of the Model and optimizer or Model Only.

    :param cls: Make the function a class method
    :param *: All following arguments are keyword-only
    :param step: int: Keep track of the number of steps that have been taken
    :param apply_fn: Callable: The model's apply (forward) function
    :param params: core.FrozenDict[str,Any] | Mapping[str,Any]: Pass in the parameters of the model
    :param opt_state: Optional[optax.OptState]: optimizer state
    :param tx_init: Optional[dict]: Pass the hyperparameters to the optimizer
    :param hyperparameters: Optional[dict]: Load hyperparameters from the state dict
    :param module: Optional[EasyDeLFlaxPretrainedModel]: Pass in the module
    :param module_config: Optional[EasyDeLPretrainedConfig]: Pass the module config
    :param module_config_args: Optional[dict]: Pass the config_args to the model
    :param kwargs: Pass in any additional parameters that may be needed for the model
    :return: A new instance of the class
    """
    if module_config is not None:
        module_config = copy.deepcopy(module_config)

    if tx_init is None:
        tx_init = {}
    tx_init = copy.deepcopy(tx_init)
    tx_init = cls.unsafe_dict(tx_init)

    tx_init["optimizer"] = cls.search("optimizer", tx_init, "adamw")
    tx_init["scheduler"] = cls.search("scheduler", tx_init, "none")
    tx_init["steps"] = cls.search("steps", tx_init, 1e6)

    def fix_dict_types(input_dict):
        fixed_dict = input_dict.copy()

        # Fix extra_optimizer_kwargs
        if 'extra_optimizer_kwargs' in fixed_dict:
            fixed_dict['extra_optimizer_kwargs'] = eval(fixed_dict['extra_optimizer_kwargs'])

        # Fix gradient_accumulation_steps
        if 'gradient_accumulation_steps' in fixed_dict:
            fixed_dict['gradient_accumulation_steps'] = int(fixed_dict['gradient_accumulation_steps'])

        # Fix steps
        if 'steps' in fixed_dict:
            fixed_dict['steps'] = int(fixed_dict['steps'])

        # Fix warmup_steps
        if 'warmup_steps' in fixed_dict:
            fixed_dict['warmup_steps'] = int(fixed_dict['warmup_steps'])

        return fixed_dict

    try:
        tx, sc = get_optimizer_and_scheduler(
            **tx_init
        )
    except TypeError:
        tx, sc = get_optimizer_and_scheduler(
            **fix_dict_types(tx_init)
        )
    if hyperparameters is None:
        hyperparameters = {}

    if module_config is not None:
        hyperparameters = cls.create_hyperparameters(module_config.model_type)
        cls.safe_dict(module_config.__dict__)
    return cls(
        step=step,
        apply_fn=apply_fn,
        params=params,
        tx=tx,
        opt_state=opt_state,
        tx_init=cls.safe_dict(tx_init),
        hyperparameters=hyperparameters,
        module=module,
        module_config=module_config,
        module_config_args=None,
        **kwargs,
    )
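
A minimal sketch of calling load directly, under stated assumptions: model and params stand for an existing EasyDeL Flax model and its parameters, and the tx_init keys mirror the defaults used in the source above.

# model / params are assumed to exist; tx_init keys follow the defaults shown above.
state = EasyDeLState.load(
    apply_fn=model.__call__,
    params={"params": params},
    tx_init={"optimizer": "adamw", "scheduler": "none", "steps": 1_000_000},
    module=model,
    module_config=model.config,
)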

load_state(checkpoint_path, dtype=jnp.float32, param_dtype=jnp.float32, precision=None, init_optimizer_state=False, state_shard_fns=None, verbose=False, input_shape=(1, 1), config_kwargs=None) classmethod

The load_state function is a class method that loads the state of an EasyDeL model from a checkpoint.

Parameters:

cls: the class on which load_state is called (classmethod).
checkpoint_path (Union[str, os.PathLike], required): the path to the checkpoint file.
dtype (jnp.dtype, default float32): the dtype of the model.
param_dtype (jnp.dtype, default float32): the dtype of the model parameters.
precision (Optional[Union[str, jax.lax.Precision]], default None): the computation precision of the model.
init_optimizer_state (bool, default False): initialize the optimizer state if it is not already initialized (ignored if it is).
state_shard_fns (Optional[Mapping[str, Callable]], default None): the functions used to shard the loaded state.
verbose (bool, default False): print the progress of loading.
input_shape (Tuple, default (1, 1)): the input shape used to initialize the module.
config_kwargs (Optional[dict], default None): config kwargs passed to the model config.

Returns:

A state object.

Source code in src/python/easydel/etils/easystate.py
@classmethod
def load_state(
        cls,
        checkpoint_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: Optional[Union[str, jax.lax.Precision]] = None,
        init_optimizer_state: bool = False,
        state_shard_fns: Optional[Mapping[str, Callable]] = None,
        verbose: bool = False,
        input_shape: Tuple = (1, 1),
        config_kwargs: Optional[dict] = None
):

    """    
    The load_state function is a class method that loads the state of an EasyDeLModel from a checkpoint.

    :param cls: Create an instance of the class
    :param checkpoint_path: str | os.PathLike: Specify the path to the checkpoint file
    :param dtype: jnp.dtype: The dtype of the model
    :param param_dtype: jnp.dtype: The dtype of the model parameters
    :param precision: Optional[Union[str, jax.lax.Precision]]: precision of the model
    :param init_optimizer_state: bool: Initialize the optimizer if it's not Initialized yet (if it Initialized the option
    will be ignored )
    :param state_shard_fns: Optional[Mapping[str,Callable]]: Specify the function that will be used 
    to shard the loaded state
    :param verbose: bool: Print out the progress of loading
    :param input_shape: Tuple: input_shape to init module
    :param config_kwargs: Optional[dict] : config kwargs to be passed to model config
    :return: A state object
    """
    from ..modules.auto_easydel_model import get_modules_by_type

    checkpoint = fjformer.CheckpointManager.load_checkpoint(
        path=checkpoint_path,
        shard_fns=state_shard_fns,
        verbose=verbose,
    )
    hyperparameters = checkpoint.get("hyperparameters")
    cfg, module, convertor = get_modules_by_type(model_type=cls.get_model_type(hyperparameters))
    checkpoint.pop("module_config", None)
    if checkpoint["module_config_args"] is not None:
        cfg_behave = cls.unsafe_dict(checkpoint.get("module_config_args", {}))
        cfg_behave.pop("id2label", None)
        cfg_behave.pop("label2id", None)
        cfg_behave.pop("torch_dtype", None)
        for k, v in list(cfg_behave.items()):
            if v is None:
                cfg_behave.pop(k, None)
            elif v == "None":
                cfg_behave[k] = None
            elif isinstance(v, str):
                if v.startswith("{") or v.startswith("(") or v.startswith("PartitionSpec"):
                    cfg_behave[k] = eval(v)
        module_config = cfg.from_dict(cfg_behave)
        if config_kwargs is not None:
            for k, v in config_kwargs.items():
                setattr(module_config, k, v)
        module_in = module(
            config=module_config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            input_shape=input_shape
        )
    else:
        raise TypeError(
            "Om seems like i couldn't read model correctly ;("
        )
    state = cls.load(
        apply_fn=module_in.__call__,
        module=module_in,
        module_config=module_config,
        **checkpoint
    )
    state = state.replace(
        module_config_args=None  # removing because it's not needed anymore
    )
    if init_optimizer_state:
        state = state.init_opt_state()
    return state
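
A minimal usage sketch under stated assumptions: the checkpoint filename is a placeholder produced earlier by save_state, and the import path mirrors the source tree shown above.

from easydel import EasyDeLState

# "model.easy" is a placeholder checkpoint file.
state = EasyDeLState.load_state(
    checkpoint_path="model.easy",
    init_optimizer_state=True,  # also build opt_state so training can resume
    verbose=True,
)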

save_state(filename, save_optimizer=False, checkpoint_dir=None, verbose=False, gather_fns=None, float_dtype=None)

The save_state function saves the state of a model to disk.

Parameters:

self: the state instance to save.
filename (Union[str, os.PathLike], required): the name of the file to save.
save_optimizer (bool, default False): whether to save the optimizer state.
checkpoint_dir (Optional[Union[str, os.PathLike]], default None): the directory where the checkpoint is saved.
verbose (bool, default False): print the path of the saved file.
gather_fns (dict[Callable], default None): a dictionary of functions used to gather (un-shard) the state before saving.
float_dtype (Union[str, jax.numpy.dtype], default None): the precision with which the model is saved.

Returns:

None

Source code in src/python/easydel/etils/easystate.py
def save_state(
        self,
        filename: Union[str, os.PathLike],
        save_optimizer: bool = False,
        checkpoint_dir: Optional[Union[str, os.PathLike]] = None,
        verbose: bool = False,
        gather_fns: dict[Callable] = None,
        float_dtype: Union[str, jax.numpy.dtype] = None,
):

    """
    The save_state function saves the state of a model to disk.

    :param self: Pass the object itself to the function
    :param filename: str | os.PathLike: Specify the name of the file to save
    :param save_optimizer: bool: Determine whether to save the optimizer state or not
    :param checkpoint_dir: Optional[str | os.PathLike]: Specify the directory where the checkpoint is saved
    :param verbose: bool: Print out the path of the saved file
    :param gather_fns: dict[Callable]: Specify a dictionary of functions that can be used to gather
    :param float_dtype: str | jax.numpy.dtype: Specify the precision of the saved model
    :return: None
    """
    state = self
    if not save_optimizer:
        state = self.replace(
            opt_state=None
        )
    state = state.replace(
        module_config_args={
            k: v for k, v in state.module.config.__dict__.items() if
            isinstance(
                v, (int, bool, float)
            )
        }
    )
    fjformer.CheckpointManager.save_state_to_file(
        state=state,
        path=os.path.join(checkpoint_dir, filename) if checkpoint_dir is not None else filename,
        verbose=verbose,
        gather_fns=gather_fns,
        float_dtype=float_dtype,
    )
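
Finally, a hedged sketch of saving a state to disk; the file and directory names are placeholders, and gather_fns is omitted as it would be for an unsharded, single-device state.

# Placeholder filename / directory; save_optimizer=True also writes opt_state.
state.save_state(
    filename="model.easy",
    checkpoint_dir="checkpoints",
    save_optimizer=True,
    verbose=True,
)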