modules.olmo.olmo_configuration

OLMoConfig

Bases: EasyDeLPretrainedConfig

OLMo (model) configuration.

Source code in src/python/easydel/modules/olmo/olmo_configuration.py
class OLMoConfig(EasyDeLPretrainedConfig):
    """
    OLMo (model) configuration.
    """

    def __init__(
            self,
            d_model: int = 768,
            n_heads: int = 12,
            n_layers: int = 12,
            mlp_ratio: int = 4,
            mlp_hidden_size: Optional[int] = None,
            activation_type: ActivationType = ActivationType.swiglu,
            block_type: BlockType = BlockType.sequential,
            block_group_size: int = 1,
            alibi: bool = False,
            alibi_bias_max: float = 8.0,
            rope: bool = False,
            rope_full_precision: bool = True,
            flash_attention: bool = False,
            attention_dropout: float = 0.1,
            multi_query_attention: bool = False,
            attention_layer_norm: bool = False,
            residual_dropout: float = 0.1,
            embedding_dropout: float = 0.1,
            layer_norm_type: LayerNormType = LayerNormType.default,
            layer_norm_with_affine: bool = True,
            attention_layer_norm_with_affine: bool = True,
            max_sequence_length: int = 1024,
            include_bias: bool = True,
            bias_for_layer_norm: Optional[bool] = None,
            scale_logits: bool = False,
            vocab_size: int = 50257,
            embedding_size: Optional[int] = 50304,
            weight_tying: bool = True,
            eos_token_id: int = 50256,
            pad_token_id: int = 50256,
            init_std: float = 0.02,
            init_cutoff_factor: Optional[float] = None,
            gradient_checkpointing: str = "nothing_saveable",
            **kwargs
    ):
        _ = kwargs.pop("precision", None)
        _ = kwargs.pop("init_fn", None)
        _ = kwargs.pop("init_device", None)
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_size = mlp_hidden_size
        self.activation_type = activation_type
        self.block_type = block_type
        self.block_group_size = block_group_size
        self.alibi = alibi
        self.alibi_bias_max = alibi_bias_max
        self.rope = rope
        self.rope_full_precision = rope_full_precision
        self.flash_attention = flash_attention
        self.attention_dropout = attention_dropout
        self.multi_query_attention = multi_query_attention
        self.attention_layer_norm = attention_layer_norm
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout
        self.layer_norm_type = layer_norm_type
        self.layer_norm_with_affine = layer_norm_with_affine
        self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
        self.max_sequence_length = max_sequence_length
        self.include_bias = include_bias
        self.bias_for_layer_norm = bias_for_layer_norm
        self.scale_logits = scale_logits
        self.gradient_checkpointing = gradient_checkpointing
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.weight_tying = weight_tying
        self.init_std = init_std
        self.init_cutoff_factor = init_cutoff_factor
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

    def add_jax_args(
            self,
            gradient_checkpointing: str = "nothing_saveable"
    ):
        # Only set the attribute when it was not already assigned in __init__.
        if not hasattr(self, "gradient_checkpointing"):
            self.gradient_checkpointing = gradient_checkpointing

    def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
        """
        The get_partition_rules function is used to define the partitioning scheme for a model.
        It returns a list of tuples, where each tuple contains two elements:
            1) A regex string that matches the name of one or more parameters in the model.
            2) A PartitionScheme object that defines how those parameters should be partitioned across devices.

        :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not
        :return: A list of tuples

        """
        return (
            # Tensor-parallel ("tp") combined with FSDP/SP sharding (fully_sharded_data_parallel=False).
            ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            (".*", PartitionSpec(None)),
        ) if not fully_sharded_data_parallel else (
            # Fully sharded data parallel: every kernel is sharded along the ("fsdp", "sp") axes.
            ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"))),
            ("self_attn/o_proj/kernel", PartitionSpec(("fsdp", "sp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/down_proj/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"))),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))),
            (".*", PartitionSpec(("fsdp", "sp"))),
        )
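
A minimal usage sketch follows, assuming the class can be imported from easydel.modules.olmo.olmo_configuration as the source path above suggests; the chosen values are illustrative, not a recommended configuration.

from easydel.modules.olmo.olmo_configuration import OLMoConfig  # import path assumed from the source path above

# Illustrative overrides only; every other field falls back to the defaults shown in __init__.
config = OLMoConfig(
    d_model=2048,
    n_heads=16,
    n_layers=16,
    rope=True,                   # rope and alibi are independent flags; this sketch enables only RoPE
    alibi=False,
    max_sequence_length=2048,
    attention_dropout=0.0,
    residual_dropout=0.0,
    embedding_dropout=0.0,
    gradient_checkpointing="nothing_saveable",
)

print(config.d_model, config.n_heads, config.max_sequence_length)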

get_partition_rules(fully_sharded_data_parallel=True)

The get_partition_rules function defines the partitioning scheme for the model. It returns a tuple of pairs, where each pair contains two elements: 1) a regex string that matches the names of one or more parameters in the model, and 2) a PartitionSpec that defines how those parameters are sharded across devices.

Parameters:

    fully_sharded_data_parallel (bool, default True):
        If True, shard every parameter along the ("fsdp", "sp") mesh axes.
        If False, combine tensor-parallel ("tp") sharding with ("fsdp", "sp") sharding.

Returns:

    A tuple of (regex, PartitionSpec) pairs.
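
As a rough illustration of how these (regex, PartitionSpec) pairs can be consumed, the sketch below resolves a PartitionSpec for a flattened parameter name by taking the first rule whose pattern matches, with the ".*" rule acting as the fallback. resolve_spec is a hypothetical helper written for this page, not an EasyDeL utility, and the import path is assumed from the source path above.

import re

from jax.sharding import PartitionSpec  # the second element of each rule is a PartitionSpec

from easydel.modules.olmo.olmo_configuration import OLMoConfig  # import path assumed

# Hypothetical helper (not part of EasyDeL): the first matching rule wins, ".*" is the fallback.
def resolve_spec(name: str, rules) -> PartitionSpec:
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    raise ValueError(f"no rule matched {name!r}")

rules = OLMoConfig().get_partition_rules(fully_sharded_data_parallel=True)
print(resolve_spec("model/layers/0/self_attn/q_proj/kernel", rules))  # sharded along ("fsdp", "sp")
print(resolve_spec("model/layers/0/input_layernorm/kernel", rules))   # PartitionSpec(None), i.e. replicated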
