modules.llama.llama_configuration

LlamaConfig

Bases: EasyDeLPretrainedConfig

Source code in src/python/easydel/modules/llama/llama_configuration.py
class LlamaConfig(EasyDeLPretrainedConfig):
    model_type: str = "llama"

    def __init__(
            self,
            vocab_size: int = 32000,
            hidden_size: int = 4096,
            intermediate_size: int = 11008,
            num_hidden_layers: int = 32,
            num_attention_heads: int = 32,
            number_rep_kv: int = 1,
            num_key_value_heads: Optional[int] = None,
            max_position_embeddings: int = 2048,
            rms_norm_eps: float = 1e-6,
            initializer_range: float = 0.02,
            use_cache: bool = True,
            bos_token_id: int = 0,
            eos_token_id: int = 1,
            resid_pdrop: float = 0.0,
            embd_pdrop: float = 0.0,
            attention_dropout: float = 0.0,
            rope_theta: float = 10000.,
            attention_bias: bool = False,
            tie_word_embeddings: bool = False,
            gradient_checkpointing: str = "nothing_saveable",
            fcm_min_ratio: float = -1,
            fcm_max_ratio: float = -1,
            rope_scaling: Optional[Dict[str, Union[str, float]]] = None,
            scan_mlp_chunk_size: int = 1024,
            bits: Optional[int] = None,
            hidden_act: str = 'silu',
            pretraining_tp: int = 1,
            scan_layers: bool = False,
            **kwargs,
    ):
        """
        The __init__ function is called when the class is instantiated.
        It sets up the attributes of an object, which are sometimes called fields or properties.
        The __init__ function can accept arguments, but self must be the first one.

        :param self: Refer to the object itself
        :param vocab_size: int: Set the size of the vocabulary
        :param hidden_size: int: Set the size of the hidden layers in each transformer block
        :param intermediate_size: int: Set the size of the intermediate layer
        :param num_hidden_layers: int: Determine the number of layers in the transformer
        :param num_attention_heads: int: Determine the number of attention heads
        :param number_rep_kv: int: Set the number of times to repeat the key and value vectors
        :param num_key_value_heads: Optional[int]: Define the number of key-value heads
        :param max_position_embeddings: int: Set the maximum length of a sequence
        :param rms_norm_eps: float: Prevent division by zero in the rms normalization
        :param initializer_range: float: Initialize the weights of the model
        :param use_cache: bool: Determine whether the attention layer should use a cache for faster computation
        :param bos_token_id: int: Set the beginning of sequence token
        :param eos_token_id: int: Specify the end of sentence token
        :param resid_pdrop: float: Set the dropout rate for residual connections
        :param embd_pdrop: float: Dropout the embedding layer
        :param attention_dropout: float: Dropout the attention weights
        :param tie_word_embeddings: bool: Tie the word embeddings and output layer weights
        :param gradient_checkpointing: str: Specify how to checkpoint the gradients
        :param fcm_min_ratio: float: Set the minimum ratio of the number of elements in a tensor to be processed by flash
        :param fcm_max_ratio: float: Determine the maximum ratio of
        :param rope_scaling: Dict[str: Define the scaling of the rope
        :param Union[str: Specify the type of the parameter
        :param float]]: Specify the type of the parameter
        :param shard_attention_computation: bool: when ever to use shard_map for attention
        :param bits: Optional[int]: Specify the number of bits used to quantize the weights
        :param rope_theta: float : rope_theta for compute rope
        :param attention_bias: bool : whenever to use attention bias or no
        :param hidden_act: str : hidden_act for mlp
        :param axis_dims: Sequence[int]: Specify the dimensions of each axis
        :param axis_names: Sequence[str]: Specify the names of the axes in a tensor
        :param scan_layers: bool: Determine whether to use the scan_layers or not
        :param kwargs: Pass a variable number of keyword arguments to a function
        :param : Define the number of layers in the model
        :return: Nothing

        """
        num_key_value_heads = num_key_value_heads or number_rep_kv * num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.vocab_size = vocab_size

        self.number_rep_kv = number_rep_kv
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.pretraining_tp = pretraining_tp
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attention_dropout = attention_dropout
        self.gradient_checkpointing = gradient_checkpointing
        self.fcm_min_ratio = fcm_min_ratio
        self.hidden_act = hidden_act
        self.fcm_max_ratio = fcm_max_ratio
        self.rope_scaling = rope_scaling
        self.bits = bits
        self.scan_layers = scan_layers
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            scan_mlp_chunk_size=scan_mlp_chunk_size,
            bits=bits,
            **kwargs,
        )

    def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
        """
        The get_partition_rules function is used to define the partitioning scheme for a model.
        It returns a list of tuples, where each tuple contains two elements:
            1) A regex string that matches the name of one or more parameters in the model.
            2) A PartitionScheme object that defines how those parameters should be partitioned across devices.

        :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not
        :return: A list of tuples

        """
        return (

            ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            (".*", PartitionSpec(None)),
        ) if not fully_sharded_data_parallel else (

            ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/down_proj/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"))),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))),
            (".*", PartitionSpec(("fsdp", "sp"))),
        )

    def add_jax_args(
            self,
            resid_pdrop: float = 0.0,
            embd_pdrop: float = 0.0,
            attention_dropout: float = 0.0,
            tie_word_embeddings: bool = False,
            gradient_checkpointing: str = "nothing_saveable",
            fcm_min_ratio: float = 0.0,
            fcm_max_ratio: float = 0.0,
            number_rep_kv: int = 1,
            bits: Optional[int] = None,
            rope_theta: float = 10000.,
            attention_bias: bool = False,
            hidden_act: str = 'silu',
            scan_layers: bool = True,
            **kwargs,
    ):
        """
        The add_jax_args function adds the following arguments to the Transformer class:

        :param self: Refer to the current object
        :param resid_pdrop: float: Set the dropout rate for residual connections
        :param embd_pdrop: float: Set the probability of dropping an embedding
        :param attention_dropout: float: Set the probability of dropping out the attention layer
        :param tie_word_embeddings: bool: Tie the word embeddings to the decoder
        :param gradient_checkpointing: str: Control the amount of memory used by jax
        :param fcm_min_ratio: float: Control the minimum ratio of the number of chunks to be used in flash-based computation
        :param fcm_max_ratio: float: Set the maximum ratio of the number of input tokens to output tokens
        :param number_rep_kv: int: Determine how many times the key and value vectors are repeated
        :param bits: Optional[int]: Determine the number of bits used in the quantization
        :param rope_theta: float : rope_theta for compute rope
        :param attention_bias: bool : whenever to use attention bias or no
        :param hidden_act: str : hidden_act for mlp
        :param scan_layers: bool: Determine whether to use scan layers or not
        """
        self.scan_layers = scan_layers
        self.embd_pdrop = embd_pdrop
        self.number_rep_kv = number_rep_kv
        self.resid_pdrop = resid_pdrop
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_act = hidden_act
        self.tie_word_embeddings = tie_word_embeddings
        self.gradient_checkpointing = gradient_checkpointing
        self.fcm_min_ratio = fcm_min_ratio
        self.fcm_max_ratio = fcm_max_ratio
        self.bits = bits

    @staticmethod
    def get_weight_decay_exclusions():
        return tuple()

    @staticmethod
    def rng_keys():
        return 'params', 'dropout', 'fcm'
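
The key/value head bookkeeping in __init__ is easy to check by hand. Below is a small sketch that mirrors the constructor's fallback for num_key_value_heads; the head_dim line assumes the usual Llama layout (hidden_size split evenly across attention heads), which is not stated on this page:

    # With number_rep_kv = 1 and no explicit num_key_value_heads, the fallback
    # num_key_value_heads = number_rep_kv * num_attention_heads reduces to
    # ordinary multi-head attention (one KV head per query head).
    hidden_size = 4096
    num_attention_heads = 32
    number_rep_kv = 1
    num_key_value_heads = None

    num_key_value_heads = num_key_value_heads or number_rep_kv * num_attention_heads
    head_dim = hidden_size // num_attention_heads  # assumed layout: 4096 // 32 = 128

    assert num_key_value_heads == 32
    assert head_dim == 128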

__init__(vocab_size=32000, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, number_rep_kv=1, num_key_value_heads=None, max_position_embeddings=2048, rms_norm_eps=1e-06, initializer_range=0.02, use_cache=True, bos_token_id=0, eos_token_id=1, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, rope_theta=10000.0, attention_bias=False, tie_word_embeddings=False, gradient_checkpointing='nothing_saveable', fcm_min_ratio=-1, fcm_max_ratio=-1, rope_scaling=None, scan_mlp_chunk_size=1024, bits=None, hidden_act='silu', pretraining_tp=1, scan_layers=False, **kwargs)

Initialize a LlamaConfig with the model hyperparameters described below.

Parameters:

    vocab_size (int, default 32000): Size of the vocabulary.
    hidden_size (int, default 4096): Size of the hidden layers in each transformer block.
    intermediate_size (int, default 11008): Size of the MLP intermediate layer.
    num_hidden_layers (int, default 32): Number of transformer layers.
    num_attention_heads (int, default 32): Number of attention heads.
    number_rep_kv (int, default 1): Number of times to repeat the key and value heads.
    num_key_value_heads (Optional[int], default None): Number of key-value heads; defaults to number_rep_kv * num_attention_heads.
    max_position_embeddings (int, default 2048): Maximum sequence length.
    rms_norm_eps (float, default 1e-06): Epsilon that prevents division by zero in RMS normalization.
    initializer_range (float, default 0.02): Standard deviation used to initialize the weights.
    use_cache (bool, default True): Whether the attention layers use a key-value cache for faster decoding.
    bos_token_id (int, default 0): Beginning-of-sequence token id.
    eos_token_id (int, default 1): End-of-sequence token id.
    resid_pdrop (float, default 0.0): Dropout rate for residual connections.
    embd_pdrop (float, default 0.0): Dropout rate for the embedding layer.
    attention_dropout (float, default 0.0): Dropout rate applied to the attention weights.
    rope_theta (float, default 10000.0): Base theta used to compute rotary position embeddings (RoPE).
    attention_bias (bool, default False): Whether to use a bias term in the attention projections.
    tie_word_embeddings (bool, default False): Tie the word embeddings and the output-layer weights.
    gradient_checkpointing (str, default 'nothing_saveable'): Gradient checkpointing policy.
    fcm_min_ratio (float, default -1): Minimum ratio for forgetful causal masking (FCM).
    fcm_max_ratio (float, default -1): Maximum ratio for forgetful causal masking (FCM).
    rope_scaling (Dict[str, Union[str, float]], default None): RoPE scaling configuration.
    scan_mlp_chunk_size (int, default 1024): Chunk size used when scanning the MLP.
    bits (Optional[int], default None): Number of bits used to quantize the weights.
    hidden_act (str, default 'silu'): Activation function used in the MLP.
    pretraining_tp (int, default 1): Tensor-parallelism degree used during pretraining.
    scan_layers (bool, default False): Whether to scan over the transformer layers.
    **kwargs: Extra keyword arguments forwarded to EasyDeLPretrainedConfig.

Returns:

    None
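
For orientation, here is a minimal construction sketch. It imports LlamaConfig from the module documented on this page; the concrete hyperparameter values and the rope_scaling payload (Hugging Face-style "type"/"factor" keys) are illustrative assumptions, not values this page prescribes:

    from easydel.modules.llama.llama_configuration import LlamaConfig

    config = LlamaConfig(
        vocab_size=32000,
        hidden_size=2048,
        intermediate_size=5504,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=4,  # grouped-query attention: 4 KV heads serve 16 query heads
        max_position_embeddings=4096,
        rope_theta=10000.0,
        rope_scaling={"type": "linear", "factor": 2.0},  # assumed HF-style payload
        gradient_checkpointing="nothing_saveable",
    )

    # 4 KV heads, and a 128-dimensional head if the usual hidden_size // heads split applies.
    print(config.num_key_value_heads, config.hidden_size // config.num_attention_heads)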

add_jax_args(resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, tie_word_embeddings=False, gradient_checkpointing='nothing_saveable', fcm_min_ratio=0.0, fcm_max_ratio=0.0, number_rep_kv=1, bits=None, rope_theta=10000.0, attention_bias=False, hidden_act='silu', scan_layers=True, **kwargs)

The add_jax_args function sets the JAX/EasyDeL-specific attributes on an existing configuration object.

Parameters:

    resid_pdrop (float, default 0.0): Dropout rate for residual connections.
    embd_pdrop (float, default 0.0): Dropout rate for the embedding layer.
    attention_dropout (float, default 0.0): Dropout rate applied to the attention weights.
    tie_word_embeddings (bool, default False): Tie the word embeddings and the output-layer weights.
    gradient_checkpointing (str, default 'nothing_saveable'): Gradient checkpointing policy used to trade compute for memory.
    fcm_min_ratio (float, default 0.0): Minimum ratio for forgetful causal masking (FCM).
    fcm_max_ratio (float, default 0.0): Maximum ratio for forgetful causal masking (FCM).
    number_rep_kv (int, default 1): Number of times to repeat the key and value heads.
    bits (Optional[int], default None): Number of bits used to quantize the weights.
    rope_theta (float, default 10000.0): Base theta used to compute rotary position embeddings (RoPE).
    attention_bias (bool, default False): Whether to use a bias term in the attention projections.
    hidden_act (str, default 'silu'): Activation function used in the MLP.
    scan_layers (bool, default True): Whether to scan over the transformer layers.
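
A typical pattern is to attach these JAX-side options to a configuration object after it has been created or loaded. A hedged sketch; the chosen values are placeholders, and only the add_jax_args call itself comes from this page:

    from easydel.modules.llama.llama_configuration import LlamaConfig

    # Placeholder config; in practice this might come from a converted checkpoint.
    config = LlamaConfig(num_hidden_layers=4)

    # Override only the JAX/EasyDeL-specific knobs after the fact.
    config.add_jax_args(
        gradient_checkpointing="nothing_saveable",
        attention_dropout=0.0,
        bits=8,            # quantize weights to 8 bits
        scan_layers=True,
    )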

get_partition_rules(fully_sharded_data_parallel=True)

The get_partition_rules function defines the partitioning scheme for the model. It returns a tuple of pairs, where each pair contains: 1) a regex string that matches the name of one or more parameters in the model, and 2) a PartitionSpec that defines how those parameters are sharded across devices.

Parameters:

    fully_sharded_data_parallel (bool, default True): Whether to use the fully sharded (FSDP-style) rule set.

Returns:

    A tuple of (regex, PartitionSpec) pairs.
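
To see how these rules behave, the sketch below resolves a flattened parameter path to its PartitionSpec with a plain regex search. This illustrates the matching mechanism only; it is not EasyDeL's own resolution code, and the example parameter paths are assumed names:

    import re

    from easydel.modules.llama.llama_configuration import LlamaConfig
    from jax.sharding import PartitionSpec

    def resolve_spec(param_path: str, rules) -> PartitionSpec:
        # The first rule whose regex matches the parameter path wins;
        # the trailing (".*", ...) rule acts as a catch-all.
        for pattern, spec in rules:
            if re.search(pattern, param_path):
                return spec
        raise ValueError(f"no rule matched {param_path!r}")

    rules = LlamaConfig().get_partition_rules(fully_sharded_data_parallel=True)
    print(resolve_spec("model/layers/0/self_attn/q_proj/kernel", rules))  # ('fsdp', 'sp') x 'tp'
    print(resolve_spec("model/layers/0/mlp/down_proj/kernel", rules))     # ('fsdp', 'sp')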
