modules.jetmoe.jetmoe_configuration
JetMoEConfig
Bases: EasyDeLPretrainedConfig
Source code in src/python/easydel/modules/jetmoe/jetmoe_configuration.py (lines 6-105)
add_jax_args(tie_word_embeddings=False, gradient_checkpointing='nothing_saveable', bits=None, **kwargs)
The add_jax_args function adds the following JAX-specific arguments to the model configuration:
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self | | Refers to the current object | required |
tie_word_embeddings | bool | Whether to tie the input word embeddings to the output (decoder) layer | False |
gradient_checkpointing | str | Gradient checkpointing policy; controls how much activation memory JAX keeps for the backward pass | 'nothing_saveable' |
bits | Optional[int] | Number of bits used for quantization | None |
Source code in src/python/easydel/modules/jetmoe/jetmoe_configuration.py (lines 80-97)
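A minimal sketch of what add_jax_args is described as doing, assuming it simply stores each argument as an attribute on the configuration object. ConfigSketch is a hypothetical stand-in, not EasyDeL's class, and the treatment of **kwargs is an assumption; the real method may differ (for example, by skipping attributes that are already set).

```python
from typing import Optional


class ConfigSketch:
    """Hypothetical stand-in for JetMoEConfig, used only to illustrate add_jax_args."""

    def add_jax_args(
        self,
        tie_word_embeddings: bool = False,
        gradient_checkpointing: str = "nothing_saveable",
        bits: Optional[int] = None,
        **kwargs,
    ) -> None:
        # Store the JAX-specific options as plain attributes on the config.
        self.tie_word_embeddings = tie_word_embeddings
        self.gradient_checkpointing = gradient_checkpointing
        self.bits = bits
        # Assumption: extra keyword arguments are attached the same way.
        for key, value in kwargs.items():
            setattr(self, key, value)


config = ConfigSketch()
config.add_jax_args(gradient_checkpointing="nothing_saveable", bits=8)
print(config.tie_word_embeddings, config.gradient_checkpointing, config.bits)
```

After the call, the options are read back as ordinary attributes, so downstream model code only needs the config object rather than a separate set of JAX options.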
get_partition_rules(fully_sharded_data_parallel=True)
The get_partition_rules function defines the partitioning scheme for the model. It returns a list of tuples, each containing two elements: 1) a regex string that matches the names of one or more model parameters, and 2) a PartitionSpec object that defines how those parameters are partitioned across devices.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fully_sharded_data_parallel | bool | Whether to return rules for fully sharded data parallel (FSDP) partitioning of the model | True |
Returns:
Type | Description |
---|---|
 | A list of tuples |
Source code in src/python/easydel/modules/jetmoe/jetmoe_configuration.py (lines 65-78)
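The (regex, partition spec) pairing described for get_partition_rules can be illustrated with a small lookup. This is a hypothetical sketch: the rules, parameter names, and mesh axis names ("fsdp", "tp") are invented for illustration, specs are plain tuples standing in for jax.sharding.PartitionSpec, and the first-match-wins re.search lookup is an assumed convention rather than confirmed EasyDeL behavior.

```python
import re
from typing import Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: a tuple of mesh axis names.
# An empty tuple plays the role of PartitionSpec(), i.e. fully replicated.
Spec = Tuple[str, ...]

# Hypothetical rules in the (regex, spec) shape; not JetMoE's real rules.
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/.*/kernel", ("fsdp", "tp")),
    (r"lm_head/kernel", ("fsdp", "tp")),
    (r".*", ()),  # catch-all: replicate anything not matched above
)


def match_rule(param_name: str, rules: Sequence[Tuple[str, Spec]]) -> Spec:
    """Return the spec of the first rule whose regex matches param_name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_name!r}")


for name in (
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/norm/scale",
):
    print(name, "->", match_rule(name, RULES))
```

Under this convention rule order matters: a catch-all pattern such as ".*" has to come last, otherwise it would shadow every more specific rule.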