modules.gemma.gemma_configuration

GemmaConfig

Bases: EasyDeLPretrainedConfig

Source code in src/python/easydel/modules/gemma/gemma_configuration.py
class GemmaConfig(EasyDeLPretrainedConfig):
    model_type: str = "gemma"

    def __init__(
            self,
            vocab_size=256000,
            hidden_size=3072,
            intermediate_size=24576,
            num_hidden_layers=28,
            num_attention_heads=16,
            num_key_value_heads=16,
            head_dim=256,
            hidden_act="gelu_pytorch_tanh",
            max_position_embeddings=8192,
            initializer_range=0.02,
            rms_norm_eps=1e-6,
            use_cache=True,
            pad_token_id=0,
            eos_token_id=1,
            bos_token_id=2,
            tie_word_embeddings=True,
            rope_theta=10000.0,
            attention_bias=False,
            attention_dropout=0.0,
            gradient_checkpointing: str = "nothing_saveable",
            bits: Optional[int] = None,
            scan_layers: bool = False,
            hidden_activation=None,
            **kwargs,
    ):
        """
        Initializes a GemmaConfig with the Gemma architecture hyperparameters
        (vocabulary size, hidden sizes, layer/head counts, RoPE theta, etc.) and
        forwards token-id and embedding-tying options to EasyDeLPretrainedConfig.
        """

        self.gradient_checkpointing = gradient_checkpointing
        self.bits = bits
        self.scan_layers = scan_layers
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            bits=bits,
            **kwargs,
        )

    def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
        """
        The get_partition_rules function defines the sharding scheme for the model's parameters.
        It returns a sequence of rules, where each rule is a tuple of two elements:
            1) A regex string that matches the name of one or more parameters in the model.
            2) A PartitionSpec that defines how those parameters should be partitioned across devices.

        :param fully_sharded_data_parallel: bool: If True, shard only along the ("fsdp", "sp") axes; if False, mix in tensor-parallel ("tp") sharding
        :return: A sequence of (regex, PartitionSpec) rules

        """
        return (

            ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            (".*", PartitionSpec(None)),
        ) if not fully_sharded_data_parallel else (

            ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"))),
            ("self_attn/o_proj/kernel", PartitionSpec(("fsdp", "sp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/down_proj/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"))),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))),
            (".*", PartitionSpec(("fsdp", "sp"))),
        )

    def add_jax_args(
            self,
            gradient_checkpointing: str = "nothing_saveable",
            bits: Optional[int] = None,
            **kwargs,
    ):
        """
        Sets JAX-specific options on the configuration after construction.

        :param self: Refer to the current object
        :param gradient_checkpointing: str: Gradient-checkpointing policy controlling the amount of memory used by JAX
        :param bits: Optional[int]: Number of bits used for quantization
        """
        self.gradient_checkpointing = gradient_checkpointing
        self.bits = bits

    @staticmethod
    def get_weight_decay_exclusions():
        return tuple()

    @staticmethod
    def rng_keys():
        return 'params', 'dropout', 'fcm'

__init__(vocab_size=256000, hidden_size=3072, intermediate_size=24576, num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=16, head_dim=256, hidden_act='gelu_pytorch_tanh', max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, gradient_checkpointing='nothing_saveable', bits=None, scan_layers=False, hidden_activation=None, **kwargs)

Initializes a GemmaConfig. The arguments set the Gemma architecture hyperparameters (vocabulary size, hidden sizes, layer and head counts, RoPE theta, and so on) as attributes; token-id and embedding-tying options are forwarded to EasyDeLPretrainedConfig.

Source code in src/python/easydel/modules/gemma/gemma_configuration.py
def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        gradient_checkpointing: str = "nothing_saveable",
        bits: Optional[int] = None,
        scan_layers: bool = False,
        hidden_activation=None,
        **kwargs,
):
    """
    Initializes a GemmaConfig with the Gemma architecture hyperparameters
    (vocabulary size, hidden sizes, layer/head counts, RoPE theta, etc.) and
    forwards token-id and embedding-tying options to EasyDeLPretrainedConfig.
    """

    self.gradient_checkpointing = gradient_checkpointing
    self.bits = bits
    self.scan_layers = scan_layers
    self.vocab_size = vocab_size
    self.max_position_embeddings = max_position_embeddings
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.head_dim = head_dim
    self.num_key_value_heads = num_key_value_heads
    self.hidden_act = hidden_act
    self.initializer_range = initializer_range
    self.rms_norm_eps = rms_norm_eps
    self.use_cache = use_cache
    self.rope_theta = rope_theta
    self.attention_bias = attention_bias
    self.attention_dropout = attention_dropout
    self.hidden_activation = hidden_activation
    super().__init__(
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        tie_word_embeddings=tie_word_embeddings,
        bits=bits,
        **kwargs,
    )
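
A minimal usage sketch (not taken from the source): constructing a small GemmaConfig by overriding only the sizes of interest. It assumes `GemmaConfig` is importable from the top-level `easydel` package; the overridden values are illustrative toy sizes, not the Gemma defaults shown above.

```python
from easydel import GemmaConfig  # assumed top-level export

config = GemmaConfig(
    vocab_size=32000,            # shrunk from the 256000 default for a toy model
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=8,
    head_dim=64,
    max_position_embeddings=2048,
)

print(config.model_type)   # "gemma"
print(config.rope_theta)   # 10000.0 (default kept)
```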

add_jax_args(gradient_checkpointing='nothing_saveable', bits=None, **kwargs)

The add_jax_args function sets JAX-specific options (gradient-checkpointing policy and quantization bit-width) on the configuration after construction.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `gradient_checkpointing` | `str` | Gradient-checkpointing policy controlling the amount of memory used by JAX. | `'nothing_saveable'` |
| `bits` | `Optional[int]` | Number of bits used for quantization. | `None` |
Source code in src/python/easydel/modules/gemma/gemma_configuration.py
def add_jax_args(
        self,
        gradient_checkpointing: str = "nothing_saveable",
        bits: Optional[int] = None,
        **kwargs,
):
    """
    Sets JAX-specific options on the configuration after construction.

    :param self: Refer to the current object
    :param gradient_checkpointing: str: Gradient-checkpointing policy controlling the amount of memory used by JAX
    :param bits: Optional[int]: Number of bits used for quantization
    """
    self.gradient_checkpointing = gradient_checkpointing
    self.bits = bits
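
A hedged sketch of how add_jax_args might be called to adjust these options on an existing configuration; the bit-width value is an example, not a recommendation.

```python
from easydel import GemmaConfig  # assumed top-level export, as above

config = GemmaConfig()
config.add_jax_args(
    gradient_checkpointing="nothing_saveable",  # checkpointing policy string (the default)
    bits=8,                                     # quantization bit-width (the default None leaves it unset)
)
assert config.bits == 8
assert config.gradient_checkpointing == "nothing_saveable"
```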

get_partition_rules(fully_sharded_data_parallel=True)

The get_partition_rules function defines the sharding scheme for a model. It returns a sequence of rules, where each rule contains two elements: 1) a regex string that matches the name of one or more parameters in the model, and 2) a PartitionSpec that defines how those parameters are partitioned across devices.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `fully_sharded_data_parallel` | `bool` | If `True`, shard parameters only along the `('fsdp', 'sp')` axes; if `False`, mix in tensor-parallel (`'tp'`) sharding. | `True` |

Returns:

| Type | Description |
| ---- | ----------- |
| `tuple` | A sequence of `(regex, PartitionSpec)` rules. |

Source code in src/python/easydel/modules/gemma/gemma_configuration.py
def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
    """
    The get_partition_rules function defines the sharding scheme for the model's parameters.
    It returns a sequence of rules, where each rule is a tuple of two elements:
        1) A regex string that matches the name of one or more parameters in the model.
        2) A PartitionSpec that defines how those parameters should be partitioned across devices.

    :param fully_sharded_data_parallel: bool: If True, shard only along the ("fsdp", "sp") axes; if False, mix in tensor-parallel ("tp") sharding
    :return: A sequence of (regex, PartitionSpec) rules

    """
    return (

        ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

        ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

        ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
        ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),

        ("input_layernorm/kernel", PartitionSpec(None)),
        ("post_attention_layernorm/kernel", PartitionSpec(None)),

        ("model/norm/kernel", PartitionSpec(None)),
        ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        (".*", PartitionSpec(None)),
    ) if not fully_sharded_data_parallel else (

        ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))),

        ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"))),
        ("self_attn/o_proj/kernel", PartitionSpec(("fsdp", "sp"))),

        ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"))),
        ("mlp/down_proj/kernel", PartitionSpec(("fsdp", "sp"))),
        ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"))),

        ("input_layernorm/kernel", PartitionSpec(None)),
        ("post_attention_layernorm/kernel", PartitionSpec(None)),

        ("model/norm/kernel", PartitionSpec(None)),
        ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))),
        (".*", PartitionSpec(("fsdp", "sp"))),
    )
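
A hedged sketch showing how the returned rules can be matched against a flat parameter path. The `resolve_spec` helper below is hypothetical (EasyDeL applies these rules internally when sharding a model); it only illustrates that each rule is a `(regex, PartitionSpec)` pair and that the catch-all `.*` rule applies last.

```python
import re

from easydel import GemmaConfig  # assumed top-level export
from jax.sharding import PartitionSpec


def resolve_spec(param_path: str, rules) -> PartitionSpec:
    """Return the PartitionSpec of the first rule whose regex matches the path (hypothetical helper)."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no rule matched {param_path!r}")


rules = GemmaConfig().get_partition_rules(fully_sharded_data_parallel=True)
print(resolve_spec("model/layers_0/self_attn/q_proj/kernel", rules))
# expected to print the ('fsdp', 'sp') spec from the q_proj rule above
```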