Skip to content

modules.qwen2.qwen_configuration

Qwen2Config

Bases: EasyDeLPretrainedConfig

Source code in src/python/easydel/modules/qwen2/qwen_configuration.py
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class Qwen2Config(EasyDeLPretrainedConfig):
    """
    Configuration object for the ``qwen2`` model type.

    Stores the constructor arguments as instance attributes, forwards a subset of
    them to the parent config, and provides the parameter partition rules, the
    weight-decay exclusions and the RNG key names for this model.
    """
    model_type: str = "qwen2"

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        resid_pdrop: float = 0.0,
        embd_pdrop: float = 0.0,
        gradient_checkpointing: str = "nothing_saveable",
        fcm_min_ratio: float = 0.0,
        fcm_max_ratio: float = 0.0,
        use_scan_mlp: bool = False,
        scan_mlp_chunk_size: int = 1024,
        number_rep_kv: int = 1,
        bits: Optional[int] = None,
        scan_layers: bool = True,
        rope_scaling: Optional[Mapping[str, str | float]] = None,
        **kwargs,
    ):
        """
        Build the configuration.

        Every named argument is stored unchanged as an instance attribute of the
        same name, with one exception: when ``num_key_value_heads`` is ``None`` the
        value of ``num_attention_heads`` is stored instead.

        ``tie_word_embeddings``, ``use_scan_mlp``, ``scan_mlp_chunk_size`` and ``bits``
        are additionally forwarded to the parent initializer together with ``kwargs``.

        :param num_key_value_heads: ``None`` selects ``num_attention_heads`` (backward compatibility)
        :param kwargs: Extra keyword arguments handed to the parent initializer
        """
        # Backward compatibility: a missing KV head count falls back to the
        # attention head count.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings

        # Activation, initialisation, normalisation and caching.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.tie_word_embeddings = tie_word_embeddings

        # Rotary position embedding settings.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Sliding-window settings.
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # Dropout rates.
        self.attention_dropout = attention_dropout
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop

        # JAX / EasyDeL specific options.
        self.gradient_checkpointing = gradient_checkpointing
        self.fcm_min_ratio = fcm_min_ratio
        self.fcm_max_ratio = fcm_max_ratio
        self.use_scan_mlp = use_scan_mlp
        self.scan_mlp_chunk_size = scan_mlp_chunk_size
        self.number_rep_kv = number_rep_kv
        self.bits = bits
        self.scan_layers = scan_layers

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            use_scan_mlp=use_scan_mlp,
            scan_mlp_chunk_size=scan_mlp_chunk_size,
            bits=bits,
            **kwargs,
        )

    def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
        """
        Return the partition rules for the model parameters.

        Each rule is a ``(pattern, spec)`` pair:
            1) ``pattern`` is a regex string matched against parameter names.
            2) ``spec`` is the ``PartitionSpec`` to use for the matching parameters.

        :param fully_sharded_data_parallel: bool: Pick the fully-sharded rule set when
            ``True``, the alternative rule set when ``False``
        :return: A tuple of ``(pattern, PartitionSpec)`` tuples

        """
        # Axis-name group shared by most of the rules below.
        fsdp_sp = ("fsdp", "sp")

        if fully_sharded_data_parallel:
            return (
                ("model/embed_tokens/embedding", PartitionSpec("tp", fsdp_sp)),

                ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(fsdp_sp, "tp")),
                # NOTE(review): axis order is ("sp", "fsdp") here, unlike every other
                # rule; reproduced as-is, confirm it is intentional.
                ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))),

                ("mlp/gate_proj/kernel", PartitionSpec(fsdp_sp)),
                ("mlp/down_proj/kernel", PartitionSpec(fsdp_sp)),
                ("mlp/up_proj/kernel", PartitionSpec(fsdp_sp)),

                ("input_layernorm/kernel", PartitionSpec(None)),
                ("post_attention_layernorm/kernel", PartitionSpec(None)),

                ("model/norm/kernel", PartitionSpec(None)),
                ("lm_head/kernel", PartitionSpec(fsdp_sp, "tp")),
                # Catch-all pattern.
                (".*", PartitionSpec(fsdp_sp)),
            )

        return (
            ("model/embed_tokens/embedding", PartitionSpec("tp", fsdp_sp)),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(fsdp_sp, "tp")),
            ("self_attn/o_proj/kernel", PartitionSpec("tp", fsdp_sp)),

            ("mlp/gate_proj/kernel", PartitionSpec(fsdp_sp, "tp")),
            ("mlp/down_proj/kernel", PartitionSpec("tp", fsdp_sp)),
            ("mlp/up_proj/kernel", PartitionSpec(fsdp_sp, "tp")),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(fsdp_sp, "tp")),
            # Catch-all pattern.
            (".*", PartitionSpec(None)),
        )

    def add_jax_args(
        self,
        resid_pdrop: float = 0.0,
        embd_pdrop: float = 0.0,
        attention_dropout: float = 0.0,
        tie_word_embeddings: bool = False,
        gradient_checkpointing: str = "nothing_saveable",
        fcm_min_ratio: float = 0.0,
        fcm_max_ratio: float = 0.0,
        use_scan_mlp: bool = False,
        scan_mlp_chunk_size: int = 1024,
        number_rep_kv: int = 1,
        bits: Optional[int] = None,
        rope_theta: float = 10000.,
        hidden_act: str = "silu",
        scan_layers: bool = True,
        rope_scaling: Optional[Mapping[str, str | float]] = None,
        **kwargs,
    ):
        """
        Store the JAX-specific options as attributes of this config.

        Each named argument is written to the attribute of the same name.

        :param self: The config object receiving the attributes
        :param resid_pdrop: float: Dropout rate for residual connections
        :param embd_pdrop: float: Dropout rate for embeddings
        :param attention_dropout: float: Dropout rate for the attention layer
        :param tie_word_embeddings: bool: Whether the word embeddings are tied
        :param gradient_checkpointing: str: Name of the gradient checkpointing policy
        :param fcm_min_ratio: float: Minimum "fcm" ratio (NOTE(review): exact meaning not visible here, confirm)
        :param fcm_max_ratio: float: Maximum "fcm" ratio (NOTE(review): exact meaning not visible here, confirm)
        :param use_scan_mlp: bool: Whether to use the scan_mlp function
        :param scan_mlp_chunk_size: int: Chunk size for scan_mlp
        :param number_rep_kv: int: How many times the key and value vectors are repeated
        :param bits: Optional[int]: Number of bits used in the quantization
        :param rope_theta: float: rope_theta used to compute rope
        :param hidden_act: str: hidden_act used by the mlp
        :param scan_layers: bool: Whether to use scan layers
        :param rope_scaling: Optional[Mapping[str, str | float]]: Stored as ``self.rope_scaling``
        :param kwargs: Accepted for call compatibility; not used by this method
        :return: None

        """
        # Dropout rates.
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attention_dropout = attention_dropout

        # Model behaviour.
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_act = hidden_act
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.number_rep_kv = number_rep_kv

        # Execution / memory options.
        self.gradient_checkpointing = gradient_checkpointing
        self.fcm_min_ratio = fcm_min_ratio
        self.fcm_max_ratio = fcm_max_ratio
        self.use_scan_mlp = use_scan_mlp
        self.scan_mlp_chunk_size = scan_mlp_chunk_size
        self.scan_layers = scan_layers
        self.bits = bits

    @staticmethod
    def get_weight_decay_exclusions():
        """Return the weight-decay exclusions; this config defines none (empty tuple)."""
        return ()

    @staticmethod
    def rng_keys():
        """Return the RNG key names as a tuple: ``"params"``, ``"dropout"``, ``"fcm"``."""
        return ("params", "dropout", "fcm")

add_jax_args(resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, tie_word_embeddings=False, gradient_checkpointing='nothing_saveable', fcm_min_ratio=0.0, fcm_max_ratio=0.0, use_scan_mlp=False, scan_mlp_chunk_size=1024, number_rep_kv=1, bits=None, rope_theta=10000.0, hidden_act='silu', scan_layers=True, rope_scaling=None, **kwargs)

The add_jax_args function stores the following JAX-specific arguments as attributes on the configuration object:

Parameters:

Name Type Description Default
self

Refer to the current object

required
resid_pdrop float

float: Set the dropout rate for residual connections

0.0
embd_pdrop float

float: Set the probability of dropping an embedding

0.0
attention_dropout float

float: Set the probability of dropping out the attention layer

0.0
tie_word_embeddings bool

bool: Tie the word embeddings to the decoder

False
gradient_checkpointing str

str: Control the amount of memory used by jax

'nothing_saveable'
fcm_min_ratio float

float: Control the minimum ratio of the number of chunks to be used in flash-based computation

0.0
fcm_max_ratio float

float: Set the maximum ratio of the number of input tokens to output tokens

0.0
use_scan_mlp bool

bool: Determine whether to use the scan_mlp function or not

False
scan_mlp_chunk_size int

int: Set the chunk size for scan_mlp

1024
number_rep_kv int

int: Determine how many times the key and value vectors are repeated

1
bits Optional[int]

Optional[int]: Determine the number of bits used in the quantization

None
rope_theta float

float: The rope_theta value used when computing RoPE

10000.0
hidden_act str

str: The hidden_act (activation function name) used by the MLP

'silu'
scan_layers bool

bool: Determine whether to use scan layers or not

True

Returns:

Type Description

None. The function has no return statement; the values are stored as attributes on the configuration object.

Source code in src/python/easydel/modules/qwen2/qwen_configuration.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def add_jax_args(
        self,
        resid_pdrop: float = 0.0,
        embd_pdrop: float = 0.0,
        attention_dropout: float = 0.0,
        tie_word_embeddings: bool = False,
        gradient_checkpointing: str = "nothing_saveable",
        fcm_min_ratio: float = 0.0,
        fcm_max_ratio: float = 0.0,
        use_scan_mlp: bool = False,
        scan_mlp_chunk_size: int = 1024,
        number_rep_kv: int = 1,
        bits: Optional[int] = None,
        rope_theta: float = 10000.,
        hidden_act: str = "silu",
        scan_layers: bool = True,
        rope_scaling: Optional[Mapping[str, str | float]] = None,
        **kwargs,
):
    """
    The add_jax_args function adds the following arguments to the Transformer class:

    :param self: Refer to the current object
    :param resid_pdrop: float: Set the dropout rate for residual connections
    :param embd_pdrop: float: Set the probability of dropping an embedding
    :param attention_dropout: float: Set the probability of dropping out the attention layer
    :param tie_word_embeddings: bool: Tie the word embeddings to the decoder
    :param gradient_checkpointing: str: Control the amount of memory used by jax
    :param fcm_min_ratio: float: Control the minimum ratio of the number of chunks to be used in flash-based computation
    :param fcm_max_ratio: float: Set the maximum ratio of the number of input tokens to output tokens
    :param use_scan_mlp: bool: Determine whether to use the scan_mlp function or not
    :param scan_mlp_chunk_size: int: Set the chunk size for scan_mlp
    :param number_rep_kv: int: Determine how many times the key and value vectors are repeated
    :param bits: Optional[int]: Determine the number of bits used in the quantization
    :param rope_theta: float : rope_theta for compute rope
    :param hidden_act: str : hidden_act for mlp
    :param scan_layers: bool: Determine whether to use scan layers or not
    :return: The following:

    """
    self.scan_layers = scan_layers
    self.embd_pdrop = embd_pdrop
    self.number_rep_kv = number_rep_kv
    self.resid_pdrop = resid_pdrop
    self.rope_theta = rope_theta
    self.rope_scaling = rope_scaling
    self.attention_dropout = attention_dropout
    self.hidden_act = hidden_act
    self.tie_word_embeddings = tie_word_embeddings
    self.gradient_checkpointing = gradient_checkpointing
    self.fcm_min_ratio = fcm_min_ratio
    self.fcm_max_ratio = fcm_max_ratio

    self.use_scan_mlp = use_scan_mlp
    self.scan_mlp_chunk_size = scan_mlp_chunk_size
    self.bits = bits

get_partition_rules(fully_sharded_data_parallel=True)

The get_partition_rules function defines the partitioning scheme for the model. It returns a tuple of tuples, where each inner tuple contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionSpec object that defines how those parameters should be partitioned across devices.

Parameters:

Name Type Description Default
fully_sharded_data_parallel bool

bool: Determine whether to partition the model fully or not

True

Returns:

Type Description

A tuple of (regex, PartitionSpec) tuples

Source code in src/python/easydel/modules/qwen2/qwen_configuration.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
    """
    Return the partition rules for the model parameters.

    Each rule is a ``(pattern, spec)`` pair:
        1) ``pattern`` is a regex string matched against parameter names.
        2) ``spec`` is the ``PartitionSpec`` to use for the matching parameters.

    :param fully_sharded_data_parallel: bool: Pick the fully-sharded rule set when
        ``True``, the alternative rule set when ``False``
    :return: A tuple of ``(pattern, PartitionSpec)`` tuples

    """
    # Axis-name group shared by most of the rules below.
    fsdp_sp = ("fsdp", "sp")

    if fully_sharded_data_parallel:
        return (
            ("model/embed_tokens/embedding", PartitionSpec("tp", fsdp_sp)),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(fsdp_sp, "tp")),
            # NOTE(review): axis order is ("sp", "fsdp") here, unlike every other
            # rule; reproduced as-is, confirm it is intentional.
            ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(fsdp_sp)),
            ("mlp/down_proj/kernel", PartitionSpec(fsdp_sp)),
            ("mlp/up_proj/kernel", PartitionSpec(fsdp_sp)),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(fsdp_sp, "tp")),
            # Catch-all pattern.
            (".*", PartitionSpec(fsdp_sp)),
        )

    return (
        ("model/embed_tokens/embedding", PartitionSpec("tp", fsdp_sp)),

        ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(fsdp_sp, "tp")),
        ("self_attn/o_proj/kernel", PartitionSpec("tp", fsdp_sp)),

        ("mlp/gate_proj/kernel", PartitionSpec(fsdp_sp, "tp")),
        ("mlp/down_proj/kernel", PartitionSpec("tp", fsdp_sp)),
        ("mlp/up_proj/kernel", PartitionSpec(fsdp_sp, "tp")),

        ("input_layernorm/kernel", PartitionSpec(None)),
        ("post_attention_layernorm/kernel", PartitionSpec(None)),

        ("model/norm/kernel", PartitionSpec(None)),
        ("lm_head/kernel", PartitionSpec(fsdp_sp, "tp")),
        # Catch-all pattern.
        (".*", PartitionSpec(None)),
    )