Skip to content

modules.qwen1.qwen1_configuration

Qwen1Config

Bases: EasyDeLPretrainedConfig

Source code in src/python/easydel/modules/qwen1/qwen1_configuration.py
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class Qwen1Config(EasyDeLPretrainedConfig):
    """
    Configuration object for the Qwen1 model.

    The constructor stores every hyper-parameter it receives as an instance
    attribute and forwards ``tie_word_embeddings``, ``use_scan_mlp``,
    ``scan_mlp_chunk_size``, ``bits`` and any extra keyword arguments to
    ``EasyDeLPretrainedConfig.__init__``.
    """

    model_type: str = "qwen"

    def __init__(
            self,
            vocab_size=151936,
            hidden_size=4096,
            num_hidden_layers=32,
            num_attention_heads=32,
            emb_dropout_prob=0.0,
            attn_dropout_prob=0.0,
            layer_norm_epsilon=1e-6,
            initializer_range=0.02,
            seq_length=8192,
            scale_attn_weights=True,
            use_cache=True,
            kv_channels=128,
            rotary_pct=1.0,
            rotary_emb_base=10000,
            use_dynamic_ntk=True,
            use_logn_attn=True,
            intermediate_size=22016,
            no_bias=True,
            tie_word_embeddings=False,
            softmax_in_fp32=False,
            gradient_checkpointing: str = "nothing_saveable",
            use_scan_mlp: bool = False,
            scan_mlp_chunk_size: int = 1024,
            bits: Optional[int] = None,
            scan_layers: bool = True,
            init_rope_cache_auto: bool = False,
            **kwargs,
    ):
        """
        Store the given hyper-parameters on the instance and initialise the base config.

        Every named argument is saved unchanged under an attribute of the same name;
        no validation is performed here.

        :param gradient_checkpointing: checkpointing policy name, stored as-is
        :param use_scan_mlp: whether to use the scan_mlp function (also forwarded to the base class)
        :param scan_mlp_chunk_size: chunk size for scan_mlp (also forwarded to the base class)
        :param bits: optional number of bits used in the quantization (also forwarded to the base class)
        :param scan_layers: whether to use scan layers
        :param init_rope_cache_auto: whether to use the rope_cache_auto in model
        :param kwargs: extra keyword arguments passed through to the base class constructor
        """
        # Model hyper-parameters.
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.scale_attn_weights = scale_attn_weights
        self.no_bias = no_bias
        self.kv_channels = kv_channels
        self.use_dynamic_ntk = use_dynamic_ntk
        self.use_logn_attn = use_logn_attn
        self.rotary_emb_base = rotary_emb_base
        self.rotary_pct = rotary_pct
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.layer_norm_epsilon = layer_norm_epsilon
        self.softmax_in_fp32 = softmax_in_fp32
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        # Runtime options (the same set that add_jax_args can overwrite later).
        self.scan_layers = scan_layers
        self.emb_dropout_prob = emb_dropout_prob
        self.attn_dropout_prob = attn_dropout_prob
        self.init_rope_cache_auto = init_rope_cache_auto
        self.tie_word_embeddings = tie_word_embeddings
        self.gradient_checkpointing = gradient_checkpointing
        self.use_scan_mlp = use_scan_mlp
        self.scan_mlp_chunk_size = scan_mlp_chunk_size
        self.bits = bits
        # Only this subset is forwarded explicitly; gradient_checkpointing, scan_layers
        # and init_rope_cache_auto are kept on the instance but not passed to the base.
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            use_scan_mlp=use_scan_mlp,
            scan_mlp_chunk_size=scan_mlp_chunk_size,
            bits=bits,
            **kwargs,
        )

    def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
        """
        Return the partitioning rules for the model parameters.

        The result is a tuple of ``(pattern, spec)`` pairs, where each pair contains:
            1) A regex string that matches the name of one or more parameters in the model.
            2) A PartitionSpec that defines how those parameters should be partitioned across devices.

        Both rule sets end with a catch-all ``".*"`` rule mapped to ``PartitionSpec(None)``.

        :param fully_sharded_data_parallel: select the fully-sharded rule set (default) or the alternative one
        :return: A tuple of (regex, PartitionSpec) pairs
        """
        if fully_sharded_data_parallel:
            return (
                ("model/wte/embedding", PartitionSpec(("fsdp", "sp"))),

                # NOTE(review): these patterns use q_proj/k_proj/v_proj/o_proj while the
                # other rule set uses c_attn/c_proj; if the model only defines the latter,
                # these rules never match and the catch-all applies - confirm against the
                # modeling code before changing.
                ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
                ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))),

                ("mlp/w1/kernel", PartitionSpec(("fsdp", "sp"))),
                ("mlp/w2/kernel", PartitionSpec(("fsdp", "sp"))),
                ("mlp/c_proj/kernel", PartitionSpec(("fsdp", "sp"))),

                ("ln_1/kernel", PartitionSpec(None)),
                ("ln_2/kernel", PartitionSpec(None)),

                ("model/ln_f/kernel", PartitionSpec(None)),
                ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))),
                (".*", PartitionSpec(None)),
            )

        return (
            ("model/wte/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

            ("self_attn/c_attn/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("self_attn/c_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

            ("mlp/w1/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            # "tp" belongs inside the PartitionSpec; a misplaced parenthesis previously
            # turned this rule into a 3-tuple with a spec that lacked the "tp" axis.
            ("mlp/w2/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("mlp/c_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

            ("ln_1/kernel", PartitionSpec(None)),
            ("ln_2/kernel", PartitionSpec(None)),

            ("model/ln_f/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            (".*", PartitionSpec(None)),
        )

    def add_jax_args(
            self,
            gradient_checkpointing: str = "nothing_saveable",
            use_scan_mlp: bool = False,
            scan_mlp_chunk_size: int = 1024,
            bits: Optional[int] = None,
            scan_layers: bool = True,
            init_rope_cache_auto: bool = False,
            **kwargs,
    ):
        """
        Store the JAX-specific runtime options on this configuration object.

        Each argument overwrites the attribute of the same name. Extra keyword
        arguments are accepted but ignored.

        :param gradient_checkpointing: checkpointing policy name, stored as-is
        :param use_scan_mlp: whether to use the scan_mlp function or not
        :param scan_mlp_chunk_size: chunk size for scan_mlp
        :param init_rope_cache_auto: whether to use the rope_cache_auto in model
        :param bits: optional number of bits used in the quantization
        :param scan_layers: whether to use scan layers or not
        :return: None
        """
        self.scan_layers = scan_layers
        self.gradient_checkpointing = gradient_checkpointing
        self.use_scan_mlp = use_scan_mlp
        self.scan_mlp_chunk_size = scan_mlp_chunk_size
        self.bits = bits
        self.init_rope_cache_auto = init_rope_cache_auto

    @staticmethod
    def get_weight_decay_exclusions():
        """Return an empty tuple: this config declares no weight-decay exclusions."""
        return tuple()

    @staticmethod
    def rng_keys():
        """Return the RNG key names: ``("params", "dropout", "fcm")``."""
        return "params", "dropout", "fcm"

add_jax_args(gradient_checkpointing='nothing_saveable', use_scan_mlp=False, scan_mlp_chunk_size=1024, bits=None, scan_layers=True, init_rope_cache_auto=False, **kwargs)

Store the JAX-specific runtime options on this configuration object. Each argument overwrites the attribute of the same name; extra keyword arguments are accepted but ignored.

Parameters:

Name Type Description Default
self

Refer to the current object

required
gradient_checkpointing str

str: Control the amount of memory used by jax

'nothing_saveable'
use_scan_mlp bool

bool: Determine whether to use the scan_mlp function or not

False
scan_mlp_chunk_size int

int: Set the chunk size for scan_mlp

1024
init_rope_cache_auto bool

bool: Whether to use the rope_cache_auto in model

False
bits Optional[int]

Optional[int]: Determine the number of bits used in the quantization

None
scan_layers bool

bool: Determine whether to use scan layers or not

True

Returns:

Type Description

None. The values are stored as attributes on the configuration object.

Source code in src/python/easydel/modules/qwen1/qwen1_configuration.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def add_jax_args(
        self,
        gradient_checkpointing: str = "nothing_saveable",
        use_scan_mlp: bool = False,
        scan_mlp_chunk_size: int = 1024,
        bits: Optional[int] = None,
        scan_layers: bool = True,
        init_rope_cache_auto: bool = False,
        **kwargs,
):
    """
    Record the JAX-specific runtime options as attributes of ``self``.

    Extra keyword arguments are accepted but not used.

    :param self: the object that receives the attributes
    :param gradient_checkpointing: checkpointing policy name, stored as-is
    :param use_scan_mlp: whether to use the scan_mlp function or not
    :param scan_mlp_chunk_size: chunk size for scan_mlp
    :param init_rope_cache_auto: whether to use the rope_cache_auto in model
    :param bits: optional number of bits used in the quantization
    :param scan_layers: whether to use scan layers or not
    :return: None

    """
    # Insertion order is the order in which the attributes are assigned.
    jax_options = {
        "scan_layers": scan_layers,
        "gradient_checkpointing": gradient_checkpointing,
        "use_scan_mlp": use_scan_mlp,
        "scan_mlp_chunk_size": scan_mlp_chunk_size,
        "bits": bits,
        "init_rope_cache_auto": init_rope_cache_auto,
    }
    for option_name, option_value in jax_options.items():
        setattr(self, option_name, option_value)

get_partition_rules(fully_sharded_data_parallel=True)

The get_partition_rules function defines the partitioning scheme for the model. It returns a tuple of (pattern, spec) pairs, where each pair contains two elements: 1) A regex string that matches the name of one or more parameters in the model. 2) A PartitionSpec object that defines how those parameters should be partitioned across devices.

Parameters:

Name Type Description Default
fully_sharded_data_parallel bool

bool: Determine whether to partition the model fully or not

True

Returns:

Type Description

A tuple of (regex, PartitionSpec) pairs

Source code in src/python/easydel/modules/qwen1/qwen1_configuration.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
    """
    Return the partitioning rules for the model parameters.

    The result is a tuple of ``(pattern, spec)`` pairs, where each pair contains:
        1) A regex string that matches the name of one or more parameters in the model.
        2) A PartitionSpec that defines how those parameters should be partitioned across devices.

    Both rule sets end with a catch-all ``".*"`` rule mapped to ``PartitionSpec(None)``.

    :param fully_sharded_data_parallel: select the fully-sharded rule set (default) or the alternative one
    :return: A tuple of (regex, PartitionSpec) pairs

    """
    if fully_sharded_data_parallel:
        return (
            ("model/wte/embedding", PartitionSpec(("fsdp", "sp"))),

            # NOTE(review): these patterns use q_proj/k_proj/v_proj/o_proj while the
            # other rule set uses c_attn/c_proj; if the model only defines the latter,
            # these rules never match and the catch-all applies - confirm against the
            # modeling code before changing.
            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("self_attn/o_proj/kernel", PartitionSpec("tp", ("sp", "fsdp"))),

            ("mlp/w1/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/w2/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/c_proj/kernel", PartitionSpec(("fsdp", "sp"))),

            ("ln_1/kernel", PartitionSpec(None)),
            ("ln_2/kernel", PartitionSpec(None)),

            ("model/ln_f/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))),
            (".*", PartitionSpec(None)),
        )

    return (
        ("model/wte/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

        ("self_attn/c_attn/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        ("self_attn/c_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

        ("mlp/w1/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        # "tp" belongs inside the PartitionSpec; a misplaced parenthesis previously
        # turned this rule into a 3-tuple with a spec that lacked the "tp" axis.
        ("mlp/w2/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        ("mlp/c_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

        ("ln_1/kernel", PartitionSpec(None)),
        ("ln_2/kernel", PartitionSpec(None)),

        ("model/ln_f/kernel", PartitionSpec(None)),
        ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        (".*", PartitionSpec(None)),
    )