Skip to content

modules.easydel_modelling_utils

EasyDeLFlaxPretrainedModel

Bases: FlaxPreTrainedModel

Source code in src/python/easydel/modules/easydel_modelling_utils.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
class EasyDeLFlaxPretrainedModel(FlaxPreTrainedModel):
    """
    Base class for EasyDeL Flax models built on top of ``FlaxPreTrainedModel``.

    It provides shared generation helpers (``prepare_inputs_for_generation`` and
    ``update_inputs_for_generation``), a compact attribute-listing ``__repr__`` and
    ``to_easydel_state``. Accessors such as ``get_input_embeddings``, ``init_cache`` and
    ``__call__`` are hooks that raise ``NotImplementedError`` here and must be overridden
    by concrete models.
    """

    def __init__(
            self,
            config: PretrainedConfig,
            module: flax.linen.Module,
            input_shape: Tuple = (1, 1),
            seed: int = 0,
            dtype: jnp.dtype = jnp.float32,
            param_dtype: jnp.dtype = jnp.float32,  # Ignored
            precision: Optional[Union[jax.lax.Precision, str]] = None,  # Ignored
            _do_init: bool = True,
    ):
        """
        Forward construction to ``FlaxPreTrainedModel``.

        :param config: Model configuration, forwarded to the base class.
        :param module: The Flax module implementing the model, forwarded to the base class.
        :param input_shape: Input shape forwarded to the base class.
        :param seed: Seed forwarded to the base class.
        :param dtype: Computation dtype forwarded to the base class.
        :param param_dtype: Accepted for signature compatibility only; not forwarded.
        :param precision: Accepted for signature compatibility only; not forwarded.
        :param _do_init: Forwarded to the base class.
        """
        super().__init__(
            config=config,
            module=module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init
        )

    def get_input_embeddings(self):
        """
        Hook for returning the model's input embedding layer.

        :param self: Refer to the current object
        :return: The embedding layer of the model (in subclasses that implement it)
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement get_input_embeddings()")

    def set_input_embeddings(self, value):
        """
        Hook for replacing the model's input embedding layer.

        :param self: Represent the instance of the class
        :param value: The new input embeddings
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement set_input_embeddings()")

    def get_output_embeddings(self):
        """
        Hook for returning the model's output embedding layer.

        :param self: Represent the instance of the class
        :return: The output embeddings of the model (in subclasses that implement it)
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement get_output_embeddings()")

    def set_output_embeddings(self, new_embeddings):
        """
        Hook for replacing the model's output embedding layer.

        This is intended, for example, for adapting a pretrained model to a different
        vocabulary or downstream task. Models that support it must override this method.

        :param self: Represent the instance of the class
        :param new_embeddings: The new output embedding layer
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement set_output_embeddings()")

    def set_decoder(self, decoder):
        """
        Hook for replacing the model's decoder.

        :param self: Refer to the object itself
        :param decoder: The new decoder
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement set_decoder()")

    def get_decoder(self):
        """
        Hook for returning the model's decoder.

        :param self: Represent the instance of the class
        :return: The decoder (in subclasses that implement it)
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement get_decoder()")

    def init_cache(self, batch_size: int, max_length: int):
        """
        Hook for allocating the generation cache; used by ``prepare_inputs_for_generation``.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("init_cache is not Implemented Yet!")

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
        """
        Build the extra model inputs needed to start generation.

        Allocates a cache for ``max_length`` positions via ``init_cache``, expands the
        attention mask to ``(batch_size, max_length)`` and derives position ids.

        :param self: Access variables that belong to the class
        :param input_ids: Prompt token ids of shape ``(batch_size, seq_length)``
        :param max_length: Total number of positions covered by the cache and the mask
        :param attention_mask: Optional[chex.Array]: Prompt mask, presumably 1 for real tokens
            and 0 for padding. Any integer or boolean dtype is accepted (it is cast to int32).
        :return: A dictionary with ``past_key_values``, ``attention_mask`` and ``position_ids``
        """
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Slots beyond the prompt start as 1; the prompt's own mask is copied into the front.
        extended_attention_mask = jnp.ones(
            (batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # lax.dynamic_update_slice requires the update to have the operand's dtype,
            # so normalise bool/float/int64 masks to int32 before splicing them in.
            attention_mask = jnp.asarray(attention_mask, dtype="i4")
            # Running count of attended tokens minus one: the first real token gets
            # position 0, while zero-masked slots before it get -1.
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = jax.lax.dynamic_update_slice(
                extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[
                                            None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Advance the generation inputs after one decoding step.

        Stores the cache returned by the model and replaces ``position_ids`` with the last
        position id plus one. ``model_kwargs`` is updated in place and also returned.

        :param model_outputs: Model outputs exposing ``past_key_values``
        :param model_kwargs: Dict holding at least ``position_ids``
        :return: The updated ``model_kwargs``
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: Optional[chex.Array] = None,
            position_ids: Optional[chex.Array] = None,
            params: dict = None,
            past_key_values: dict = None,
            dropout_rng: jax.random.PRNGKey = None,
            train: bool = False,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
            add_params_field: bool = False,
            vision_mask: Optional[chex.Array] = None,
            **kwargs
    ):
        """
        Forward pass; concrete models must override it.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Not Implemented Yet")

    def __repr__(self):
        """
        Build a multi-line, human-readable listing of the object's public attributes.

        Each instance attribute whose name does not start with ``_`` is rendered on its own
        tab-indented line as ``name : str(value)``. Newlines inside the value are re-indented
        by one tab. An entry whose line would be 500 characters or longer is shortened to
        ``name : ValueClass(...)``, as is any value whose ``str()`` raises.

        The result is for display only; it is not valid Python and cannot be passed to
        ``eval`` to recreate the object.

        :param self: Refer to the instance of the class
        :return: A string representation of the object
        """
        lines = [f"{self.__class__.__name__}("]
        for name, value in self.__dict__.items():
            if name.startswith("_"):
                continue
            placeholder = f"{value.__class__.__name__}(...)"
            try:
                # str(value) also works when ``value`` is a class, where the previous
                # ``value.__str__()`` raised TypeError and silently dropped the attribute.
                text = str(value)
            except Exception:  # best-effort display: never let one attribute break repr()
                text = placeholder
            entry = f"\t{name} : " + text.replace("\n", "\n\t")
            # The +1 accounts for the line's trailing newline, preserving the 500-char limit.
            if len(entry) + 1 >= 500:
                entry = f"\t{name} : {placeholder}"
            lines.append(entry)
        return "\n".join(lines) + "\n)"

    def __str__(self):
        """
        Return the same text as ``__repr__``, so ``print()`` and ``str()`` show the attribute listing.

        :param self: Refer to the instance of the class
        :return: The object's string representation
        """
        return self.__repr__()

    @property
    def config(self) -> EasyDeLPretrainedConfig:
        """Return the configuration object stored on the instance as ``_config``."""
        return self._config  # type:ignore

    def to_easydel_state(
            self,
            params: flax.core.FrozenDict,
    ):
        """
        Wrap ``params`` into an ``EasyDeLState`` via ``EasyDeLState.load``.

        The state uses this model's ``__call__`` as ``apply_fn``, has no optimizer state
        (``opt_state=None``) and carries this model's ``config`` as ``module_config``.

        :param params: Model parameters
        :return: The result of ``EasyDeLState.load``
        """
        return EasyDeLState.load(
            apply_fn=self.__call__,
            params=params,
            opt_state=None,
            module_config=self.config,
        )

__repr__()

`__repr__` builds a multi-line, human-readable listing of the object's public attributes (one `name : value` line each; overly long values are abbreviated to `ClassName(...)`). The output is for display only and cannot be evaluated to recreate the object. It is used by `repr()` and by the REPL when echoing the object. Because `__str__` delegates to it, it is also used by `print()`.

Parameters:

Name Type Description Default
self

Refer to the instance of the class

required

Returns:

Type Description

A string representation of the object

Source code in src/python/easydel/modules/easydel_modelling_utils.py
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
def __repr__(self):

    """
    The __repr__ function is used to generate a string representation of an object.
    This function should return a string that can be parsed by the Python interpreter
    to recreate the object. The __repr__ function is called when you use print() on an
    object, or when you type its name in the REPL.

    :param self: Refer to the instance of the class
    :return: A string representation of the object
    """
    string = f"{self.__class__.__name__}(\n"
    for k, v in self.__dict__.items():
        if not k.startswith("_"):
            try:
                repr_src = f"\t{k} : " + v.__str__().replace("\n", "\n\t") + "\n"
                string += repr_src if len(repr_src) < 500 else f"\t{k} : " + f"{v.__class__.__name__}(...)" + "\n"
            except TypeError:
                pass
    return string + ")"

__str__()

`__str__` is called by `print()` and `str()`. It simply delegates to `__repr__`, so both produce the same attribute listing.

Parameters:

Name Type Description Default
self

Refer to the instance of the class

required

Returns:

Type Description

The object's string representation

Source code in src/python/easydel/modules/easydel_modelling_utils.py
614
615
616
617
618
619
620
621
622
623
def __str__(self):

    """
    The __str__ function is called when you use the print function or when str() is used.
    It should return a string representation of the object.

    :param self: Refer to the instance of the class
    :return: The object's string representation
    """
    return self.__repr__()

get_decoder()

`get_decoder` is a hook for returning the model's decoder. The base implementation raises `NotImplementedError`, so concrete models must override it.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required

Returns:

Type Description

A decoder object

Source code in src/python/easydel/modules/easydel_modelling_utils.py
527
528
529
530
531
532
533
534
def get_decoder(self):
    """
    Hook for returning the model's decoder.

    :param self: Represent the instance of the class
    :return: The decoder (in subclasses that implement it)
    :raises NotImplementedError: Always, in this base implementation.
    """
    # Name the concrete class and hook so the failure is actionable, like init_cache's message.
    raise NotImplementedError(f"{type(self).__name__} does not implement get_decoder()")

get_input_embeddings()

The get_input_embeddings function returns the embedding layer of the model.

Parameters:

Name Type Description Default
self

Refer to the current object

required

Returns:

Type Description

The embedding layer of the model

Source code in src/python/easydel/modules/easydel_modelling_utils.py
476
477
478
479
480
481
482
483
def get_input_embeddings(self):
    """
    Hook for returning the model's input embedding layer.

    :param self: Refer to the current object
    :return: The embedding layer of the model (in subclasses that implement it)
    :raises NotImplementedError: Always, in this base implementation.
    """
    # Name the concrete class and hook so the failure is actionable, like init_cache's message.
    raise NotImplementedError(f"{type(self).__name__} does not implement get_input_embeddings()")

get_output_embeddings()

The get_output_embeddings function returns the output embeddings of a model.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required

Returns:

Type Description

The output embeddings of the model

Source code in src/python/easydel/modules/easydel_modelling_utils.py
494
495
496
497
498
499
500
501
def get_output_embeddings(self):
    """
    Hook for returning the model's output embedding layer.

    :param self: Represent the instance of the class
    :return: The output embeddings of the model (in subclasses that implement it)
    :raises NotImplementedError: Always, in this base implementation.
    """
    # Name the concrete class and hook so the failure is actionable, like init_cache's message.
    raise NotImplementedError(f"{type(self).__name__} does not implement get_output_embeddings()")

prepare_inputs_for_generation(input_ids, max_length, attention_mask=None)

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
input_ids

Pass in the input tokens

required
max_length

Set the length of the sequence to be generated

required
attention_mask Optional[Array]

Optional[chex.Array]: Mask the attention weights

None

Returns:

Type Description

A dictionary of the past_key_values, attention_mask and position ids

Source code in src/python/easydel/modules/easydel_modelling_utils.py
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
    """
    Build the extra model inputs needed to start generation.

    Allocates a cache for ``max_length`` positions via ``self.init_cache``, expands the
    attention mask to ``(batch_size, max_length)`` and derives position ids.

    :param self: Access variables that belong to the class
    :param input_ids: Prompt token ids of shape ``(batch_size, seq_length)``
    :param max_length: Total number of positions covered by the cache and the mask
    :param attention_mask: Optional[chex.Array]: Prompt mask, presumably 1 for real tokens
        and 0 for padding. Any integer or boolean dtype is accepted (it is cast to int32).
    :return: A dictionary with ``past_key_values``, ``attention_mask`` and ``position_ids``
    """
    batch_size, seq_length = input_ids.shape

    past_key_values = self.init_cache(batch_size, max_length)
    # Slots beyond the prompt start as 1; the prompt's own mask is copied into the front.
    extended_attention_mask = jnp.ones(
        (batch_size, max_length), dtype="i4")
    if attention_mask is not None:
        # lax.dynamic_update_slice requires the update to have the operand's dtype,
        # so normalise bool/float/int64 masks to int32 before splicing them in.
        attention_mask = jnp.asarray(attention_mask, dtype="i4")
        # Running count of attended tokens minus one: the first real token gets
        # position 0, while zero-masked slots before it get -1.
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = jax.lax.dynamic_update_slice(
            extended_attention_mask, attention_mask, (0, 0))
    else:
        position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[
                                        None, :], (batch_size, seq_length))

    return {
        "past_key_values": past_key_values,
        "attention_mask": extended_attention_mask,
        "position_ids": position_ids,
    }

set_decoder(decoder)

`set_decoder` is a hook for replacing the model's decoder. The base implementation raises `NotImplementedError`, so concrete models must override it.

Parameters:

Name Type Description Default
self

Refer to the object itself

required
decoder

Set the decoder for a given encoder

required

Returns:

Type Description

A decoder

Source code in src/python/easydel/modules/easydel_modelling_utils.py
517
518
519
520
521
522
523
524
525
def set_decoder(self, decoder):
    """
    Hook for replacing the model's decoder.

    :param self: Refer to the object itself
    :param decoder: The new decoder
    :raises NotImplementedError: Always, in this base implementation.
    """
    # Name the concrete class and hook so the failure is actionable, like init_cache's message.
    raise NotImplementedError(f"{type(self).__name__} does not implement set_decoder()")

set_input_embeddings(value)

The set_input_embeddings function is used to set the embedding module of the model.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
value

Set the embeddings of the model

required
Source code in src/python/easydel/modules/easydel_modelling_utils.py
485
486
487
488
489
490
491
492
def set_input_embeddings(self, value):
    """
    Hook for replacing the model's input embedding layer.

    :param self: Represent the instance of the class
    :param value: The new input embeddings
    :raises NotImplementedError: Always, in this base implementation.
    """
    # Name the concrete class and hook so the failure is actionable, like init_cache's message.
    raise NotImplementedError(f"{type(self).__name__} does not implement set_input_embeddings()")

set_output_embeddings(new_embeddings)

`set_output_embeddings` is a hook for replacing the model's output embedding layer, for example when adapting a pretrained model to a different vocabulary or downstream task. The base implementation raises `NotImplementedError`; models that support replacing the layer must override this method.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
new_embeddings

Set the embeddings of the output layer

required

Returns:

Type Description

A new embedding layer

Source code in src/python/easydel/modules/easydel_modelling_utils.py
503
504
505
506
507
508
509
510
511
512
513
514
515
def set_output_embeddings(self, new_embeddings):
    """
    Hook for replacing the model's output embedding layer.

    This is intended, for example, for adapting a pretrained model to a different
    vocabulary or downstream task. Models that support it must override this method.

    :param self: Represent the instance of the class
    :param new_embeddings: The new output embedding layer
    :raises NotImplementedError: Always, in this base implementation.
    """
    # Name the concrete class and hook so the failure is actionable, like init_cache's message.
    raise NotImplementedError(f"{type(self).__name__} does not implement set_output_embeddings()")

EasyDeLPretrainedConfig

Bases: PretrainedConfig

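Base configuration class for EasyDeL models. Its constructor stores the mesh, sharding, attention-kernel and runtime options listed below as attributes, and forwards any remaining keyword arguments to `PretrainedConfig`.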

Parameters:

Name Type Description Default
self

Refer to the instance of the class

required
axis_dims Sequence[int]

Sequence[int]: Specify the number of dimensions for each axis

(1, -1, 1, 1)
axis_names Sequence[str]

Sequence[str]: Set the names of the axes

('dp', 'fsdp', 'tp', 'sp')
attn_mechanism AVAILABLE_ATTENTION_MECHANISMS

AVAILABLE_ATTENTION_MECHANISMS: attention mechanism to use (e.g. "vanilla", "flash", "splash", "ring"; the default is "sharded_vanilla")

'sharded_vanilla'
block_k int

int: block size of key_states

128
block_q int

int: block size of query_states

128
block_b int

int: block size of bias

1
block_q_major_dkv int | None

int: block size of block_q_major_dkv

None
block_k_major_dkv int | None

int: block size of block_k_major_dkv

None
block_k_dkv int | None

int: block size of block_k_dkv

None
block_q_dkv int | None

int: block size of block_q_dkv

None
block_k_major_dq int | None

int: block size of block_k_major_dq

None
block_k_dq int | None

int: block size of block_k_dq

None
block_q_dq int | None

int: block size of block_q_dq

None
query_partition_spec PartitionSpec

PartitionSpec: Specify the partitioning of the query tensor

PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None)
key_partition_spec PartitionSpec

PartitionSpec: Partition the key matrix

PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None)
value_partition_spec PartitionSpec

PartitionSpec: Specify the partitioning of the value tensor

PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None)
bias_partition_spec PartitionSpec

PartitionSpec: Specify the Attention Bias partition spec

PartitionSpec(('dp', 'fsdp'), None, None, None)
attention_partition_spec PartitionSpec

PartitionSpec: Specify the partitioning of the attention weights

PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None)
shard_attention_computation bool

bool: whether to shard the query/key/value and bias arrays during the attention computation

True
use_sharding_constraint bool

bool: whether to use sharding constraint for the arrays

False
use_scan_mlp bool

bool: Determine whether to use scan_mlp or not

True
backend Optional[None]

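Optional[str]: the JAX backend to use (e.g. "cpu", "gpu", "tpu"). The default, `jax.default_backend()`, is evaluated once at import time; `None` is stored as "" (use the default backend).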

default_backend()
flash_attention_backward_pass_impl Literal['triton', 'xla']

Literal["triton", "xla"]: Specify the backward pass kernel for flash attention

'triton'
Source code in src/python/easydel/modules/easydel_modelling_utils.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
class EasyDeLPretrainedConfig(PretrainedConfig):
    """
    It initializes all the attributes of an object, and it's called when you create a new instance of that class.
    :param self: Refer to the instance of the class
    :param axis_dims: Sequence[int]: Specify the number of dimensions for each axis
    :param axis_names: Sequence[str]: Set the names of the axes
    :param attn_mechanism: Literal["vanilla", "flash", "splash", "ring"]: attention mechanism to use
    :param block_k: int: block size of key_states
    :param block_q: int: block size of query_states
    :param block_b: int: block size of bias
    :param block_q_major_dkv: int: block size of block_q_major_dkv
    :param block_k_major_dkv: int: block size of block_k_major_dkv
    :param block_k_dkv: int: block size of block_k_dkv
    :param block_q_dkv: int: block size of block_q_dkv
    :param block_k_major_dq: int: block size of block_k_major_dq
    :param block_k_dq: int: block size of block_k_dq
    :param block_q_dq: int: block size of block_q_dq
    :param query_partition_spec: PartitionSpec: Specify the partitioning of the query tensor
    :param key_partition_spec: PartitionSpec: Partition the key matrix
    :param value_partition_spec: PartitionSpec: Specify the partitioning of the value tensor
    :param bias_partition_spec: PartitionSpec: Specify the Attention Bias partition spec
    :param attention_partition_spec: PartitionSpec: Specify the partitioning of the attention weights
    :param shard_attention_computation: bool: whenever to shard qkv b for attention
    :param use_sharding_constraint: bool: whether to use sharding constraint for the arrays
    :param use_scan_mlp: bool: Determine whether to use scan_mlp or not
    :param backend: Optional[None]: Specify the backend to use
    :param flash_attention_backward_pass_impl: Literal["triton", "xla"]: Specify the backward pass kernel for flash attention
    """

    def __init__(
            self,
            axis_dims: Sequence[int] = (1, -1, 1, 1),
            axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"),
            attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = "sharded_vanilla",
            block_k: int = 128,
            block_q: int = 128,
            block_b: int = 1,
            block_k_major: int = 128,
            block_q_major_dkv: int | None = None,
            block_k_major_dkv: int | None = None,
            block_k_dkv: int | None = None,
            block_q_dkv: int | None = None,
            block_k_major_dq: int | None = None,
            block_k_dq: int | None = None,
            block_q_dq: int | None = None,
            query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            generation_query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, "tp", None),
            key_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            value_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
            generation_bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
            attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
            generation_attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, "tp", None),
            shard_attention_computation: bool = True,
            use_sharded_kv_caching: bool = True,
            use_sharding_constraint: bool = False,
            # NOTE: this default is evaluated once, when the class body runs (module import time).
            backend: Optional[str] = jax.default_backend(),
            easy_method: Literal["train", "serve", "convert"] = EasyMethod.TRAIN,
            bits: Optional[int] = None,
            scan_ring_attention: bool = True,
            scan_attention_layers: bool = False,
            use_scan_mlp: bool = True,
            scan_mlp_chunk_size: int = 1024,
            attention_axis_name: str = "sp",
            quantize_kv_cache: bool = False,
            flash_attention_backward_pass_impl: Literal["triton", "xla"] = "triton",
            **kwargs
    ):
        """
        Store EasyDeL sharding, attention-kernel and runtime options on the config.

        Every named option is saved as an attribute of the same name. Any remaining keyword
        arguments are forwarded to ``PretrainedConfig.__init__``, which is called last.

        The ``*_dkv`` / ``*_dq`` block sizes fall back to ``block_q`` or ``block_k`` when left
        as ``None`` (``or`` is used, so ``0`` also falls back). ``backend=None`` is stored as ``""``.
        """
        self.query_partition_spec = query_partition_spec
        self.generation_query_partition_spec = generation_query_partition_spec
        self.key_partition_spec = key_partition_spec
        self.value_partition_spec = value_partition_spec
        self.bias_partition_spec = bias_partition_spec
        self.generation_bias_partition_spec = generation_bias_partition_spec
        self.attention_partition_spec = attention_partition_spec
        self.generation_attention_partition_spec = generation_attention_partition_spec
        self.shard_attention_computation = shard_attention_computation
        self.axis_dims = axis_dims
        self.axis_names = axis_names
        # "" is the "use JAX's default backend" marker understood by get_backend/create_mesh.
        self.backend = backend if backend is not None else ""
        self.easy_method = easy_method
        self.attn_mechanism = attn_mechanism
        self.block_b = block_b
        self.block_k = block_k
        self.block_q = block_q
        self.block_k_major = block_k_major
        # Unset backward-kernel block sizes inherit the forward block_q / block_k values.
        self.block_q_major_dkv = block_q_major_dkv or block_q
        self.block_k_major_dkv = block_k_major_dkv or block_k
        self.block_k_dkv = block_k_dkv or block_k
        self.block_q_dkv = block_q_dkv or block_q
        self.block_k_major_dq = block_k_major_dq or block_k
        self.block_k_dq = block_k_dq or block_k
        self.block_q_dq = block_q_dq or block_q
        self.bits = bits
        self.scan_attention_layers = scan_attention_layers
        self.scan_ring_attention = scan_ring_attention
        self.use_sharded_kv_caching = use_sharded_kv_caching
        self.use_scan_mlp = use_scan_mlp
        self.scan_mlp_chunk_size = scan_mlp_chunk_size
        self.use_sharding_constraint = use_sharding_constraint
        self.attention_axis_name = attention_axis_name
        self.quantize_kv_cache = quantize_kv_cache
        self.flash_attention_backward_pass_impl = flash_attention_backward_pass_impl
        # Remaining kwargs are handled by transformers' PretrainedConfig.
        super().__init__(**kwargs)

    @staticmethod
    def create_mesh(
            axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"), backend=""
    ):
        """
        Create a ``Mesh`` over the devices of ``backend`` that can be used to shard arrays.

        :param axis_dims: Sequence[int]: Mesh shape. A single ``-1`` entry is inferred from the
            device count. A string holding a Python literal (e.g. ``"(1, -1, 1, 1)"``) is also
            accepted, parsed safely, and reported with a warning.
        :param axis_names: Sequence[str]: Names of the mesh axes; string literals are parsed likewise.
        :param backend: Backend whose devices are used; ``""`` means JAX's default devices.
        :return: A mesh object
        :raises ValueError: (or another ``ast.literal_eval`` error) if a string argument is not
            a plain Python literal.
        """
        import ast

        devices = jax.devices() if backend == "" else jax.devices(backend)
        if isinstance(axis_dims, str):
            # Serialized configs may carry the axes as text. ast.literal_eval only accepts
            # Python literals, so a crafted config cannot execute code the way eval() could.
            axis_dims = ast.literal_eval(axis_dims)
            warnings.warn(
                "axis_dims argument is not a Sequence of int and it's an string. "
                "(backbone Warning in EasyDeLModuleConfig)\n"
                f"\tchanged to {axis_dims}"
            )
        if isinstance(axis_names, str):
            axis_names = ast.literal_eval(axis_names)
            warnings.warn(
                "axis_names argument is not a Sequence of strings and it's an string class. "
                "(backbone Warning in EasyDeLModuleConfig)\n"
                f"\tchanged to {axis_names}"
            )
        # Reshaping a dummy (num_devices, 1) array resolves a -1 entry in axis_dims
        # (and fails if the dims do not multiply out to the device count).
        mesh_shape = jax.numpy.ones((len(devices), 1)).reshape(axis_dims).shape

        # Build the mesh from the same device list the shape was computed from, so a
        # non-default ``backend`` is honoured instead of silently using default devices.
        return Mesh(
            create_device_mesh(mesh_shape, devices=devices), axis_names
        )

    def jax_mesh(self) -> Mesh:
        """
        Build a ``Mesh`` from this config's ``axis_dims``, ``axis_names`` and ``backend``.

        Dict-valued ``axis_dims`` / ``axis_names`` contribute their values in insertion order.
        A missing or ``None`` backend is passed on as ``""`` (JAX's default devices).

        :param self: Refer to the object itself
        :return: A jaxMesh
        """

        def _flatten(spec):
            # Dict-valued specs contribute their values; anything else is used as-is.
            return list(spec.values()) if isinstance(spec, dict) else spec

        backend = getattr(self, "backend", None)
        if backend is None:
            backend = ""
        return self.create_mesh(
            axis_dims=_flatten(self.axis_dims),
            axis_names=_flatten(self.axis_names),
            backend=backend,
        )

    def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
        """
        Return the default ``(regex, PartitionSpec)`` rules for the model parameters.

        Only fully sharded data parallelism is supported. A single catch-all rule (``'.*'``)
        shards the first dimension of every parameter over the combined ``("fsdp", "sp")``
        mesh axes.

        :param self: Access the attributes of the class
        :param fully_sharded_data_parallel: bool: Must be True; other modes are not implemented
        :return: A tuple of ``(pattern, PartitionSpec)`` pairs
        :raises NotImplementedError: If ``fully_sharded_data_parallel`` is False
        """
        if fully_sharded_data_parallel:
            catch_all_rule = ('.*', PartitionSpec(("fsdp", "sp")))
            return (catch_all_rule,)
        raise NotImplementedError()

    def get_axis_dims(self) -> Sequence[int]:
        """
        Return the mesh shape stored on this config.

        :param self: Represent the instance of the class
        :return: ``self.axis_dims`` exactly as stored (not copied or normalised)
        """
        mesh_dims = self.axis_dims
        return mesh_dims

    def get_axis_names(self) -> Sequence[str]:
        """
        Return the mesh axis names stored on this config.

        :param self: Represent the instance of the class
        :return: ``self.axis_names`` exactly as stored (not copied or normalised)
        """
        mesh_axis_names = self.axis_names
        return mesh_axis_names

    def get_backend(self) -> str:
        """
        Return the name of the backend platform this config targets.

        If no backend is set (``""`` or ``None``), the platform name of JAX's current
        default backend is returned instead.

        :param self: Bind the method to an object
        :return: The backend platform
        """
        # None is treated like "" (as jax_mesh does). jax.default_backend() replaces the
        # deprecated jax.lib.xla_bridge.get_backend().platform; both name the default platform.
        backend = getattr(self, "backend", None)
        return backend if backend else jax.default_backend()

    def add_basic_configurations(
            self,
            axis_dims: Sequence[int] = ...,
            axis_names: Sequence[str] = ...,
            attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = ...,
            block_k: int = ...,
            block_q: int = ...,
            block_b: int = ...,
            block_k_major: int = ...,
            block_q_major_dkv: int | None = ...,
            block_k_major_dkv: int | None = ...,
            block_k_dkv: int | None = ...,
            block_q_dkv: int | None = ...,
            block_k_major_dq: int | None = ...,
            block_k_dq: int | None = ...,
            block_q_dq: int | None = ...,
            query_partition_spec: PartitionSpec = ...,
            generation_query_partition_spec: PartitionSpec = ...,
            key_partition_spec: PartitionSpec = ...,
            value_partition_spec: PartitionSpec = ...,
            bias_partition_spec: PartitionSpec = ...,
            attention_partition_spec: PartitionSpec = ...,
            generation_bias_partition_spec: PartitionSpec = ...,
            generation_attention_partition_spec: PartitionSpec = ...,
            shard_attention_computation: bool = ...,
            use_sharded_kv_caching: bool = ...,
            backend: Optional[str] = ...,
            easy_method: Literal["train", "serve", "convert"] = ...,
            bits: Optional[int] = ...,
            scan_ring_attention: bool = ...,
            scan_attention_layers: bool = ...,
            use_sharding_constraint: bool = ...,
            use_scan_mlp: bool = ...,
            scan_mlp_chunk_size: int = ...,
            attention_axis_name: str = ...,
            quantize_kv_cache: bool = ...,
            flash_attention_backward_pass_impl: Literal["triton", "xla"] = ...
    ):
        """
        Populate the EasyDeL-specific sharding, attention and quantization attributes.

        Every argument defaults to ``...`` (Ellipsis), which marks it as "not provided".
        For each attribute, ``set_attrs_smartly`` receives the fallback value listed below
        together with the caller's argument.

        :param self: The configuration instance to update
        :param axis_dims: Sequence[int]: Size of each device-mesh axis; fallback ``(1, -1, 1, 1)``
        :param axis_names: Sequence[str]: Names of the mesh axes; fallback ``("dp", "fsdp", "tp", "sp")``
        :param attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS: Attention implementation to use;
            fallback ``"sharded_vanilla"``
        :param block_k: int: Block size of key states; fallback 1024
        :param block_q: int: Block size of query states; fallback 1024
        :param block_b: int: Block size of the attention bias; fallback 1024
        :param block_k_major: int: Key-major block size; fallback ``block_k``
        :param block_q_major_dkv: int | None: Kernel block size ``block_q_major_dkv``; fallback ``block_q``
        :param block_k_major_dkv: int | None: Kernel block size ``block_k_major_dkv``; fallback ``block_k``
        :param block_k_dkv: int | None: Kernel block size ``block_k_dkv``; fallback ``block_k``
        :param block_q_dkv: int | None: Kernel block size ``block_q_dkv``; fallback ``block_q``
        :param block_k_major_dq: int | None: Kernel block size ``block_k_major_dq``; fallback ``block_k``
        :param block_k_dq: int | None: Kernel block size ``block_k_dq``; fallback ``block_k``
        :param block_q_dq: int | None: Kernel block size ``block_q_dq``; fallback ``block_q``
        :param query_partition_spec: PartitionSpec: Partitioning of the query tensor
        :param generation_query_partition_spec: PartitionSpec: Partitioning of the query tensor
            in the generation process
        :param key_partition_spec: PartitionSpec: Partitioning of the key tensor
        :param value_partition_spec: PartitionSpec: Partitioning of the value tensor
        :param bias_partition_spec: PartitionSpec: Partitioning of the attention bias
        :param attention_partition_spec: PartitionSpec: Partitioning of the attention weights
        :param generation_bias_partition_spec: PartitionSpec: Partitioning of the attention bias
            in the generation process
        :param generation_attention_partition_spec: PartitionSpec: Partitioning of the attention
            weights in the generation process
        :param shard_attention_computation: bool: Whether to use shard_map for attention; fallback True
        :param use_sharded_kv_caching: bool: Whether to use shard_map and sharding for key/value
            caching; fallback True
        :param backend: Optional[str]: Backend platform name; fallback ``jax.default_backend()``
        :param easy_method: Literal["train", "serve", "convert"]: EasyDeL method to be applied
            (relevant for quantization); fallback ``EasyMethod.TRAIN``
        :param bits: Optional[int]: Model bits for quantization; fallback None
        :param scan_ring_attention: bool: Whether to use scan for ring attention; fallback True
        :param scan_attention_layers: bool: Whether to use scan for attention layers; fallback True
        :param use_sharding_constraint: bool: Whether to apply sharding constraints to arrays;
            fallback False
        :param use_scan_mlp: bool: Whether to use scan_mlp; fallback True
        :param scan_mlp_chunk_size: int: Size of chunks in the scanned MLP; fallback 1024
        :param attention_axis_name: str: Name of the mesh axis used for attention; fallback ``"sp"``
        :param quantize_kv_cache: bool: Whether to quantize key/value in attention during
            generation; fallback False
        :param flash_attention_backward_pass_impl: Literal["triton", "xla"]: Backward-pass kernel
            for flash attention; fallback ``"triton"``
        """
        set_attrs_smartly(self, "axis_dims", (1, -1, 1, 1), axis_dims)
        set_attrs_smartly(self, "axis_names", ("dp", "fsdp", "tp", "sp"), axis_names)

        # The base block sizes must be resolved first: the kernel-specific block sizes
        # below fall back to them.
        set_attrs_smartly(self, "block_q", 1024, block_q)
        set_attrs_smartly(self, "block_k", 1024, block_k)
        set_attrs_smartly(self, "block_b", 1024, block_b)

        # (attribute name, fallback PartitionSpec, caller-supplied value)
        partition_specs = (
            ("query_partition_spec", PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
             query_partition_spec),
            ("generation_query_partition_spec", PartitionSpec(("dp", "fsdp"), None, "tp", None),
             generation_query_partition_spec),
            ("generation_bias_partition_spec", PartitionSpec(("dp", "fsdp"), None, None, None),
             generation_bias_partition_spec),
            ("key_partition_spec", PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
             key_partition_spec),
            ("value_partition_spec", PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
             value_partition_spec),
            ("bias_partition_spec", PartitionSpec(("dp", "fsdp"), None, None, None),
             bias_partition_spec),
            ("attention_partition_spec", PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
             attention_partition_spec),
            ("generation_attention_partition_spec", PartitionSpec(("dp", "fsdp"), None, "tp", None),
             generation_attention_partition_spec),
        )
        for attr_name, fallback, value in partition_specs:
            set_attrs_smartly(self, attr_name, fallback, value)

        set_attrs_smartly(self, "use_sharding_constraint", False, use_sharding_constraint)
        set_attrs_smartly(self, "backend", jax.default_backend(), backend)
        set_attrs_smartly(self, "shard_attention_computation", True, shard_attention_computation)
        set_attrs_smartly(self, "use_sharded_kv_caching", True, use_sharded_kv_caching)
        set_attrs_smartly(self, "attn_mechanism", "sharded_vanilla", attn_mechanism)

        # Kernel-specific block sizes fall back to the base block_q / block_k resolved
        # above. Do not write the fallback as ``block_x or self.block_k``: an omitted
        # argument arrives as ``...`` (Ellipsis), which is truthy, so that expression
        # would yield ``...`` instead of the base block size.
        block_fallbacks = (
            ("block_k_dkv", self.block_k, block_k_dkv),
            ("block_q_dkv", self.block_q, block_q_dkv),
            ("block_q_major_dkv", self.block_q, block_q_major_dkv),
            ("block_k_major_dkv", self.block_k, block_k_major_dkv),
            ("block_k_major", self.block_k, block_k_major),
            ("block_k_major_dq", self.block_k, block_k_major_dq),
            ("block_k_dq", self.block_k, block_k_dq),
            ("block_q_dq", self.block_q, block_q_dq),
        )
        for attr_name, fallback, value in block_fallbacks:
            set_attrs_smartly(self, attr_name, fallback, value)

        set_attrs_smartly(self, "easy_method", EasyMethod.TRAIN, easy_method)
        set_attrs_smartly(self, "bits", None, bits)
        set_attrs_smartly(self, "scan_attention_layers", True, scan_attention_layers)
        set_attrs_smartly(self, "scan_ring_attention", True, scan_ring_attention)
        set_attrs_smartly(self, "use_scan_mlp", True, use_scan_mlp)
        set_attrs_smartly(self, "scan_mlp_chunk_size", 1024, scan_mlp_chunk_size)
        set_attrs_smartly(self, "attention_axis_name", "sp", attention_axis_name)
        set_attrs_smartly(self, "quantize_kv_cache", False, quantize_kv_cache)
        set_attrs_smartly(self, "flash_attention_backward_pass_impl", "triton", flash_attention_backward_pass_impl)

    def __repr__(self):
        """
        Build a multi-line, human-readable summary of the object's public attributes.

        The result starts with the class name and an opening parenthesis. It then has one
        line per attribute in ``self.__dict__`` whose name does not start with ``_``, in
        the form ``<TAB>name : value``, and ends with ``)``. Line breaks inside a value are
        re-indented with a tab. Any entry of 500 characters or more is collapsed to
        ``name : ClassName(...)``. The output is meant for reading and cannot be
        ``eval``-ed back into an object.

        :param self: The instance to describe
        :return: The summary string
        """
        parts = [f"{self.__class__.__name__}(\n"]
        for key, value in self.__dict__.items():
            if key.startswith("_"):
                continue
            prefix = f"\t{key} : "
            try:
                # ``value.__str__()`` (not ``str(value)``) raises TypeError for objects
                # such as classes stored as attributes; those entries are skipped.
                entry = prefix + value.__str__().replace("\n", "\n\t") + "\n"
            except TypeError:
                continue
            if len(entry) >= 500:
                entry = prefix + f"{value.__class__.__name__}(...)" + "\n"
            parts.append(entry)
        # Joining once avoids the quadratic cost of repeated ``+=`` on a string.
        return "".join(parts) + ")"

    def add_jax_args(self, **kwargs):
        """
        Store each keyword argument on ``self`` as an attribute with the same name.

        :param self: The instance to update
        :param kwargs: Attribute names mapped to the values to store
        """
        for key, value in kwargs.items():
            # The attribute name must be the keyword itself. Passing the literal "k" made
            # every keyword overwrite one attribute called ``k``.
            set_attrs_smartly(self, key, value, value)

    def __str__(self):
        """
        Render the object with the same text that ``__repr__`` produces.

        This makes ``print(obj)`` and ``str(obj)`` show the attribute summary.

        :param self: The instance to render
        :return: The string returned by ``self.__repr__()``
        """
        rendered = self.__repr__()
        return rendered

__repr__()

The __repr__ function generates a human-readable, multi-line summary of the object's public attributes (one `name : value` line per attribute). Unlike a conventional __repr__, the result cannot be parsed by the Python interpreter to recreate the object. It is shown when you type the object's name in the REPL and, via __str__, when you call print() on the object.

Parameters:

Name Type Description Default
self

Refer to the instance of the class

required

Returns:

Type Description

A string representation of the object

Source code in src/python/easydel/modules/easydel_modelling_utils.py
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
def __repr__(self):
    """
    Build a multi-line, human-readable summary of the object's public attributes.

    The result starts with the class name and an opening parenthesis. It then has one
    line per attribute in ``self.__dict__`` whose name does not start with ``_``, in the
    form ``<TAB>name : value``, and ends with ``)``. Line breaks inside a value are
    re-indented with a tab. Any entry of 500 characters or more is collapsed to
    ``name : ClassName(...)``. The output is meant for reading and cannot be ``eval``-ed
    back into an object.

    :param self: The instance to describe
    :return: The summary string
    """
    parts = [f"{self.__class__.__name__}(\n"]
    for key, value in self.__dict__.items():
        if key.startswith("_"):
            continue
        prefix = f"\t{key} : "
        try:
            # ``value.__str__()`` (not ``str(value)``) raises TypeError for objects such
            # as classes stored as attributes; those entries are skipped.
            entry = prefix + value.__str__().replace("\n", "\n\t") + "\n"
        except TypeError:
            continue
        if len(entry) >= 500:
            entry = prefix + f"{value.__class__.__name__}(...)" + "\n"
        parts.append(entry)
    # Joining once avoids the quadratic cost of repeated ``+=`` on a string.
    return "".join(parts) + ")"

__str__()

The str function is called when you use the print function or when str() is used. It should return a string representation of the object.

Parameters:

Name Type Description Default
self

Refer to the instance of the class

required

Returns:

Type Description

The object's string representation

Source code in src/python/easydel/modules/easydel_modelling_utils.py
443
444
445
446
447
448
449
450
451
452
def __str__(self):
    """
    Render the object with the same text that ``__repr__`` produces.

    This makes ``print(obj)`` and ``str(obj)`` show the attribute summary.

    :param self: The instance to render
    :return: The string returned by ``self.__repr__()``
    """
    rendered = self.__repr__()
    return rendered

add_basic_configurations(axis_dims=..., axis_names=..., attn_mechanism=..., block_k=..., block_q=..., block_b=..., block_k_major=..., block_q_major_dkv=..., block_k_major_dkv=..., block_k_dkv=..., block_q_dkv=..., block_k_major_dq=..., block_k_dq=..., block_q_dq=..., query_partition_spec=..., generation_query_partition_spec=..., key_partition_spec=..., value_partition_spec=..., bias_partition_spec=..., attention_partition_spec=..., generation_bias_partition_spec=..., generation_attention_partition_spec=..., shard_attention_computation=..., use_sharded_kv_caching=..., backend=..., easy_method=..., bits=..., scan_ring_attention=..., scan_attention_layers=..., use_sharding_constraint=..., use_scan_mlp=..., scan_mlp_chunk_size=..., attention_axis_name=..., quantize_kv_cache=..., flash_attention_backward_pass_impl=...)

It sets the EasyDeL-specific sharding, attention, and quantization attributes on an existing configuration object (it is not the constructor). Arguments left as `...` are treated as not provided.

Parameters:

Name Type Description Default
self

Refer to the instance of the class

required
axis_dims Sequence[int]

Sequence[int]: Specify the number of dimensions for each axis

...
axis_names Sequence[str]

Sequence[str]: Set the names of the axes

...
attn_mechanism AVAILABLE_ATTENTION_MECHANISMS

AVAILABLE_ATTENTION_MECHANISMS: attention mechanism to use (fallback "sharded_vanilla")

...
block_k int

int: block size of key_states

...
block_q int

int: block size of query_states

...
block_b int

int: block size of bias

...
block_k_major int

int: block size of key major

...
block_q_major_dkv int | None

int: block size of block_q_major_dkv

...
block_k_major_dkv int | None

int: block size of block_k_major_dkv

...
block_k_dkv int | None

int: block size of block_k_dkv

...
block_q_dkv int | None

int: block size of block_q_dkv

...
block_k_major_dq int | None

int: block size of block_k_major_dq

...
block_k_dq int | None

int: block size of block_k_dq

...
block_q_dq int | None

int: block size of block_q_dq

...
query_partition_spec PartitionSpec

PartitionSpec: Specify the partitioning of the query tensor

...
key_partition_spec PartitionSpec

PartitionSpec: Partition the key matrix

...
value_partition_spec PartitionSpec

PartitionSpec: Specify the partitioning of the value tensor

...
bias_partition_spec PartitionSpec

PartitionSpec: Specify the Attention Bias partition spec

...
attention_partition_spec PartitionSpec

PartitionSpec: Specify the partitioning of the attention weights

...
generation_attention_partition_spec PartitionSpec

: PartitionSpec: Specify the partitioning of the attention weights in generation process

...
generation_bias_partition_spec PartitionSpec

: PartitionSpec: Specify the partitioning of the Attention Bias partition spec in generation process

...
generation_query_partition_spec PartitionSpec

: PartitionSpec: Specify the partitioning of the query tensor in generation process

...
shard_attention_computation bool

bool: whether to use shard_map for attention

...
use_sharded_kv_caching bool

bool: whether to use shard_map and sharding for key and value caching

...
backend Optional[None]

Optional[None]: Specify the backend to use

...
easy_method Literal['train', 'serve', 'convert']

Literal["train", "serve", "convert"]: EasyDeL method to be applied (relevant for quantization)

...
bits Optional[int]

Optional[int]: Model bits for quantization

...
use_sharding_constraint bool

bool: whether to use sharding constraint for the arrays

...
scan_ring_attention bool

bool: Whether to use scan for ring attention

...
scan_attention_layers bool

bool: Whether to use scan for attention layers

...
use_scan_mlp bool

bool: Determine whether to use scan_mlp or not

...
scan_mlp_chunk_size int

int: Size of chunks in scan MLP.

...
attention_axis_name str

str: Name of the mesh axis used for attention

...
quantize_kv_cache bool

bool: Whether to quantize Key/Value in attention for generation process.

...
flash_attention_backward_pass_impl Literal['triton', 'xla']

Literal["triton", "xla"]: Specify the backward pass kernel for flash attention

...
Source code in src/python/easydel/modules/easydel_modelling_utils.py
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
def add_basic_configurations(
        self,
        axis_dims: Sequence[int] = ...,
        axis_names: Sequence[str] = ...,
        attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = ...,
        block_k: int = ...,
        block_q: int = ...,
        block_b: int = ...,
        block_k_major: int = ...,
        block_q_major_dkv: int | None = ...,
        block_k_major_dkv: int | None = ...,
        block_k_dkv: int | None = ...,
        block_q_dkv: int | None = ...,
        block_k_major_dq: int | None = ...,
        block_k_dq: int | None = ...,
        block_q_dq: int | None = ...,
        query_partition_spec: PartitionSpec = ...,
        generation_query_partition_spec: PartitionSpec = ...,
        key_partition_spec: PartitionSpec = ...,
        value_partition_spec: PartitionSpec = ...,
        bias_partition_spec: PartitionSpec = ...,
        attention_partition_spec: PartitionSpec = ...,
        generation_bias_partition_spec: PartitionSpec = ...,
        generation_attention_partition_spec: PartitionSpec = ...,
        shard_attention_computation: bool = ...,
        use_sharded_kv_caching: bool = ...,
        backend: Optional[str] = ...,
        easy_method: Literal["train", "serve", "convert"] = ...,
        bits: Optional[int] = ...,
        scan_ring_attention: bool = ...,
        scan_attention_layers: bool = ...,
        use_sharding_constraint: bool = ...,
        use_scan_mlp: bool = ...,
        scan_mlp_chunk_size: int = ...,
        attention_axis_name: str = ...,
        quantize_kv_cache: bool = ...,
        flash_attention_backward_pass_impl: Literal["triton", "xla"] = ...
):
    """
    Populate the EasyDeL-specific sharding, attention and quantization attributes.

    Every argument defaults to ``...`` (Ellipsis), which marks it as "not provided".
    For each attribute, ``set_attrs_smartly`` receives the fallback value listed below
    together with the caller's argument.

    :param self: The configuration instance to update
    :param axis_dims: Sequence[int]: Size of each device-mesh axis; fallback ``(1, -1, 1, 1)``
    :param axis_names: Sequence[str]: Names of the mesh axes; fallback ``("dp", "fsdp", "tp", "sp")``
    :param attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS: Attention implementation to use;
        fallback ``"sharded_vanilla"``
    :param block_k: int: Block size of key states; fallback 1024
    :param block_q: int: Block size of query states; fallback 1024
    :param block_b: int: Block size of the attention bias; fallback 1024
    :param block_k_major: int: Key-major block size; fallback ``block_k``
    :param block_q_major_dkv: int | None: Kernel block size ``block_q_major_dkv``; fallback ``block_q``
    :param block_k_major_dkv: int | None: Kernel block size ``block_k_major_dkv``; fallback ``block_k``
    :param block_k_dkv: int | None: Kernel block size ``block_k_dkv``; fallback ``block_k``
    :param block_q_dkv: int | None: Kernel block size ``block_q_dkv``; fallback ``block_q``
    :param block_k_major_dq: int | None: Kernel block size ``block_k_major_dq``; fallback ``block_k``
    :param block_k_dq: int | None: Kernel block size ``block_k_dq``; fallback ``block_k``
    :param block_q_dq: int | None: Kernel block size ``block_q_dq``; fallback ``block_q``
    :param query_partition_spec: PartitionSpec: Partitioning of the query tensor
    :param generation_query_partition_spec: PartitionSpec: Partitioning of the query tensor
        in the generation process
    :param key_partition_spec: PartitionSpec: Partitioning of the key tensor
    :param value_partition_spec: PartitionSpec: Partitioning of the value tensor
    :param bias_partition_spec: PartitionSpec: Partitioning of the attention bias
    :param attention_partition_spec: PartitionSpec: Partitioning of the attention weights
    :param generation_bias_partition_spec: PartitionSpec: Partitioning of the attention bias
        in the generation process
    :param generation_attention_partition_spec: PartitionSpec: Partitioning of the attention
        weights in the generation process
    :param shard_attention_computation: bool: Whether to use shard_map for attention; fallback True
    :param use_sharded_kv_caching: bool: Whether to use shard_map and sharding for key/value
        caching; fallback True
    :param backend: Optional[str]: Backend platform name; fallback ``jax.default_backend()``
    :param easy_method: Literal["train", "serve", "convert"]: EasyDeL method to be applied
        (relevant for quantization); fallback ``EasyMethod.TRAIN``
    :param bits: Optional[int]: Model bits for quantization; fallback None
    :param scan_ring_attention: bool: Whether to use scan for ring attention; fallback True
    :param scan_attention_layers: bool: Whether to use scan for attention layers; fallback True
    :param use_sharding_constraint: bool: Whether to apply sharding constraints to arrays;
        fallback False
    :param use_scan_mlp: bool: Whether to use scan_mlp; fallback True
    :param scan_mlp_chunk_size: int: Size of chunks in the scanned MLP; fallback 1024
    :param attention_axis_name: str: Name of the mesh axis used for attention; fallback ``"sp"``
    :param quantize_kv_cache: bool: Whether to quantize key/value in attention during
        generation; fallback False
    :param flash_attention_backward_pass_impl: Literal["triton", "xla"]: Backward-pass kernel
        for flash attention; fallback ``"triton"``
    """
    set_attrs_smartly(self, "axis_dims", (1, -1, 1, 1), axis_dims)
    set_attrs_smartly(self, "axis_names", ("dp", "fsdp", "tp", "sp"), axis_names)

    # The base block sizes must be resolved first: the kernel-specific block sizes
    # below fall back to them.
    set_attrs_smartly(self, "block_q", 1024, block_q)
    set_attrs_smartly(self, "block_k", 1024, block_k)
    set_attrs_smartly(self, "block_b", 1024, block_b)

    # (attribute name, fallback PartitionSpec, caller-supplied value)
    partition_specs = (
        ("query_partition_spec", PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
         query_partition_spec),
        ("generation_query_partition_spec", PartitionSpec(("dp", "fsdp"), None, "tp", None),
         generation_query_partition_spec),
        ("generation_bias_partition_spec", PartitionSpec(("dp", "fsdp"), None, None, None),
         generation_bias_partition_spec),
        ("key_partition_spec", PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
         key_partition_spec),
        ("value_partition_spec", PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
         value_partition_spec),
        ("bias_partition_spec", PartitionSpec(("dp", "fsdp"), None, None, None),
         bias_partition_spec),
        ("attention_partition_spec", PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
         attention_partition_spec),
        ("generation_attention_partition_spec", PartitionSpec(("dp", "fsdp"), None, "tp", None),
         generation_attention_partition_spec),
    )
    for attr_name, fallback, value in partition_specs:
        set_attrs_smartly(self, attr_name, fallback, value)

    set_attrs_smartly(self, "use_sharding_constraint", False, use_sharding_constraint)
    set_attrs_smartly(self, "backend", jax.default_backend(), backend)
    set_attrs_smartly(self, "shard_attention_computation", True, shard_attention_computation)
    set_attrs_smartly(self, "use_sharded_kv_caching", True, use_sharded_kv_caching)
    set_attrs_smartly(self, "attn_mechanism", "sharded_vanilla", attn_mechanism)

    # Kernel-specific block sizes fall back to the base block_q / block_k resolved
    # above. Do not write the fallback as ``block_x or self.block_k``: an omitted
    # argument arrives as ``...`` (Ellipsis), which is truthy, so that expression
    # would yield ``...`` instead of the base block size.
    block_fallbacks = (
        ("block_k_dkv", self.block_k, block_k_dkv),
        ("block_q_dkv", self.block_q, block_q_dkv),
        ("block_q_major_dkv", self.block_q, block_q_major_dkv),
        ("block_k_major_dkv", self.block_k, block_k_major_dkv),
        ("block_k_major", self.block_k, block_k_major),
        ("block_k_major_dq", self.block_k, block_k_major_dq),
        ("block_k_dq", self.block_k, block_k_dq),
        ("block_q_dq", self.block_q, block_q_dq),
    )
    for attr_name, fallback, value in block_fallbacks:
        set_attrs_smartly(self, attr_name, fallback, value)

    set_attrs_smartly(self, "easy_method", EasyMethod.TRAIN, easy_method)
    set_attrs_smartly(self, "bits", None, bits)
    set_attrs_smartly(self, "scan_attention_layers", True, scan_attention_layers)
    set_attrs_smartly(self, "scan_ring_attention", True, scan_ring_attention)
    set_attrs_smartly(self, "use_scan_mlp", True, use_scan_mlp)
    set_attrs_smartly(self, "scan_mlp_chunk_size", 1024, scan_mlp_chunk_size)
    set_attrs_smartly(self, "attention_axis_name", "sp", attention_axis_name)
    set_attrs_smartly(self, "quantize_kv_cache", False, quantize_kv_cache)
    set_attrs_smartly(self, "flash_attention_backward_pass_impl", "triton", flash_attention_backward_pass_impl)

create_mesh(axis_dims=(1, -1, 1, 1), axis_names=('dp', 'fsdp', 'tp', 'sp'), backend='') staticmethod

The create_mesh function creates a mesh object that can be used to shard arrays.

Parameters:

Name Type Description Default
axis_dims Sequence[int]

Sequence[int]: Specify the dimensions of the mesh

(1, -1, 1, 1)
axis_names Sequence[str]

Sequence[str]: Name the axes of the mesh

('dp', 'fsdp', 'tp', 'sp')
backend

Specify the backend to use

''

Returns:

Type Description

A mesh object

Source code in src/python/easydel/modules/easydel_modelling_utils.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
@staticmethod
def create_mesh(
        axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"), backend=""
):
    """
    Create a ``Mesh`` over the available devices, shaped by ``axis_dims``.

    A ``-1`` entry in ``axis_dims`` is resolved (as in ``reshape``) so that the product
    of the dimensions equals the number of devices on the selected backend.

    :param axis_dims: Sequence[int]: Size of each mesh axis. A string holding a Python
        literal (e.g. ``"(1, -1, 1, 1)"``) is parsed, with a warning.
    :param axis_names: Sequence[str]: Names of the mesh axes. A string holding a Python
        literal is parsed, with a warning.
    :param backend: Backend to take devices from; ``""`` uses ``jax.devices()``
    :return: A ``Mesh`` whose axes are named ``axis_names``
    :raises ValueError: If a string argument is not a plain Python literal
        (``SyntaxError`` if it is not valid Python at all)
    """
    import ast

    num_devices = len(jax.devices() if backend == "" else jax.devices(backend))
    array_devices = jax.numpy.ones((num_devices, 1))
    if isinstance(axis_dims, str):
        # ``ast.literal_eval`` accepts literals only; plain ``eval`` would execute
        # arbitrary code that arrives through a config string.
        axis_dims = ast.literal_eval(axis_dims)
        warnings.warn(
            "axis_dims argument is not a Sequence of int and it's an string. "
            "(backbone Warning in EasyDeLModuleConfig)\n"
            f"\tchanged to {axis_dims}"
        )
    if isinstance(axis_names, str):
        axis_names = ast.literal_eval(axis_names)
        warnings.warn(
            "axis_names argument is not a Sequence of strings and it's an string class. "
            "(backbone Warning in EasyDeLModuleConfig)\n"
            f"\tchanged to {axis_names}"
        )
    # Reshaping a (num_devices, 1) array resolves any -1 in axis_dims and validates
    # that the requested shape matches the device count.
    mesh_shape = array_devices.reshape(axis_dims).shape

    return Mesh(
        create_device_mesh(mesh_shape), axis_names
    )

get_axis_dims()

The get_axis_dims function returns a sequence of integers representing the dimensions of each axis.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required

Returns:

Type Description
Sequence[int]

The dimensions of the axes

Source code in src/python/easydel/modules/easydel_modelling_utils.py
224
225
226
227
228
229
230
231
232
def get_axis_dims(self) -> Sequence[int]:
    """
    Return the per-axis sizes of the device mesh stored on this configuration.

    :param self: The configuration instance being queried
    :return: The value of ``self.axis_dims``, returned unchanged
    """
    dims = self.axis_dims
    return dims

get_axis_names()

The get_axis_names function returns a list of the names of the axes.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required

Returns:

Type Description
Sequence[str]

A list of the names of all axes

Source code in src/python/easydel/modules/easydel_modelling_utils.py
234
235
236
237
238
239
240
241
242
def get_axis_names(self) -> Sequence[str]:
    """
    Return the names assigned to the device-mesh axes of this configuration.

    :param self: The configuration instance being queried
    :return: The value of ``self.axis_names``, returned unchanged
    """
    names = self.axis_names
    return names

get_backend()

The get_backend function returns the backend that is currently being used. If no backend has been set, it will return the default JAX backend.

Parameters:

Name Type Description Default
self

Bind the method to an object

required

Returns:

Type Description
str

The backend platform

Source code in src/python/easydel/modules/easydel_modelling_utils.py
244
245
246
247
248
249
250
251
252
253
def get_backend(self) -> str:
    """
    Return the name of the backend platform this configuration targets.

    An explicitly configured ``self.backend`` wins. When it is unset, meaning the empty
    string or ``None``, the platform of JAX's default backend is returned.

    :param self: The configuration instance being queried
    :return: The backend platform name
    """
    # ``None`` also means "not configured" (``jax_mesh`` treats it that way), so it must
    # fall back just like "". ``jax.default_backend()`` is the public replacement for
    # the deprecated ``jax.lib.xla_bridge.get_backend().platform``.
    if self.backend is None or self.backend == "":
        return jax.default_backend()
    return self.backend

get_partition_rules(fully_sharded_data_parallel=True)

The get_partition_rules function is used to specify how the parameters of a model are partitioned across devices.

Parameters:

Name Type Description Default
self

Access the attributes of the class

required
fully_sharded_data_parallel bool

bool: Determine whether the model is fully sharded or not

True

Returns:

Type Description

A tuple of tuples

Source code in src/python/easydel/modules/easydel_modelling_utils.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
    """
    Return the (pattern, PartitionSpec) rules describing how model parameters are sharded.

    Only the fully sharded layout is implemented. It is a single catch-all rule
    (pattern ``".*"``) that shards over the combined ``("fsdp", "sp")`` mesh axes.

    :param self: The configuration instance (not read by this implementation)
    :param fully_sharded_data_parallel: bool: Must be truthy; other layouts are not implemented
    :return: A tuple of ``(pattern, PartitionSpec)`` tuples
    :raises NotImplementedError: If ``fully_sharded_data_parallel`` is falsy
    """
    if fully_sharded_data_parallel:
        catch_all = ('.*', PartitionSpec(("fsdp", "sp"), ),)
        return (catch_all,)
    raise NotImplementedError()

jax_mesh()

The jax_mesh function is a helper function that creates a Mesh object from the axis_dims and axis_names attributes of an object. Each of these may be a sequence (used as-is) or a dict (its values are used in insertion order). The backend attribute is also used if it exists and is not None; otherwise an empty backend string (the default devices) is passed.

Parameters:

Name Type Description Default
self

Refer to the object itself

required

Returns:

Type Description
Mesh

A jax Mesh

Source code in src/python/easydel/modules/easydel_modelling_utils.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def jax_mesh(self) -> Mesh:
    """
    Build the device ``Mesh`` described by this configuration via ``self.create_mesh``.

    ``axis_dims`` and ``axis_names`` may each be a sequence (passed through unchanged)
    or a dict, in which case its values are used in insertion order. A missing or
    ``None`` ``backend`` attribute is passed on as ``""``.

    :param self: The configuration instance providing ``axis_dims``, ``axis_names``,
        ``create_mesh`` and, optionally, ``backend``
    :return: The ``Mesh`` returned by ``self.create_mesh``
    """
    def _as_sequence(spec):
        # Dict-valued specs contribute their values (keys are only labels).
        return list(spec.values()) if isinstance(spec, dict) else spec

    backend = getattr(self, "backend", None)
    return self.create_mesh(
        axis_dims=_as_sequence(self.axis_dims),
        axis_names=_as_sequence(self.axis_names),
        backend=backend if backend is not None else "",
    )