modules.mistral.vision_mistral_configuration

VisionMistralConfig

Bases: MistralConfig

Source code in src/python/easydel/modules/mistral/vision_mistral_configuration.py
class VisionMistralConfig(MistralConfig):
    def __init__(
            self,
            vision_vocab_size=8448,
            tie_vision_embeddings=False,
            sample_mode="all",
            **kwargs
    ):
        super().__init__(**kwargs)
        self.vision_vocab_size = vision_vocab_size
        self.tie_vision_embeddings = tie_vision_embeddings
        self.sample_mode = sample_mode

    def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
        """
        The get_partition_rules function is used to define the partitioning scheme for a model.
        It returns a list of tuples, where each tuple contains two elements:
            1) A regex string that matches the name of one or more parameters in the model.
            2) A PartitionScheme object that defines how those parameters should be partitioned across devices.

        :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not
        :return: A list of tuples

        """
        return (

            ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),
            ("model/embed_vision/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("vision_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            (".*", PartitionSpec(None)),
        ) if not fully_sharded_data_parallel else (

            ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))),
            ("model/embed_vision/embedding", PartitionSpec(("fsdp", "sp"))),

            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"))),
            ("self_attn/o_proj/kernel", PartitionSpec(("fsdp", "sp"))),

            ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/down_proj/kernel", PartitionSpec(("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"))),

            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),

            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))),
            ("vision_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            (".*", PartitionSpec(("fsdp", "sp"))),
        )
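
A minimal usage sketch (assumed, not part of the source above): the import path mirrors the module shown at the top of this page, and keyword arguments beyond the three vision-specific ones are assumed to be standard MistralConfig fields forwarded through **kwargs.

# Sketch only; hidden_size and num_hidden_layers are assumed MistralConfig fields.
from easydel.modules.mistral.vision_mistral_configuration import VisionMistralConfig

config = VisionMistralConfig(
    vision_vocab_size=8448,       # size of the vision token vocabulary
    tie_vision_embeddings=False,  # keep embed_vision and vision_head untied
    sample_mode="all",
    hidden_size=4096,             # forwarded to MistralConfig via **kwargs
    num_hidden_layers=32,
)

# Partition rules for fully sharded data parallel training.
rules = config.get_partition_rules(fully_sharded_data_parallel=True)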

get_partition_rules(fully_sharded_data_parallel=True)

Defines the parameter partitioning scheme for the model. Returns a tuple of rules, where each rule is a pair of: 1) a regex string that matches the names of one or more parameters in the model, and 2) a PartitionSpec that defines how those parameters are sharded across devices.

Parameters:

    fully_sharded_data_parallel (bool, default: True)

        If True, shard parameters along the combined ("fsdp", "sp") axes; if False, combine ("fsdp", "sp") sharding with "tp" tensor parallelism for embeddings and projection kernels.

Returns:

    A tuple of (regex, PartitionSpec) pairs
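
As a rough sketch of how such rules are typically consumed (the resolve_spec helper and the example parameter paths below are illustrative, not EasyDeL API): each flattened parameter path is matched against the regexes in order and takes the PartitionSpec of the first matching rule, with the trailing ".*" rule acting as a catch-all.

# Illustrative only: resolve_spec and the example paths are hypothetical,
# not part of EasyDeL; PartitionSpec comes from jax.sharding.
import re
from jax.sharding import PartitionSpec

rules = (
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    (".*", PartitionSpec(None)),  # catch-all: replicate anything not matched above
)

def resolve_spec(param_path: str) -> PartitionSpec:
    """Return the PartitionSpec of the first rule whose regex matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no rule matched {param_path!r}")

print(resolve_spec("model/layers/0/self_attn/q_proj/kernel"))  # sharded over ('fsdp', 'sp') and 'tp'
print(resolve_spec("model/norm/kernel"))                        # falls through to the catch-all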

Source code in src/python/easydel/modules/mistral/vision_mistral_configuration.py
def get_partition_rules(self, fully_sharded_data_parallel: bool = True):
    """
    The get_partition_rules function is used to define the partitioning scheme for a model.
    It returns a list of tuples, where each tuple contains two elements:
        1) A regex string that matches the name of one or more parameters in the model.
        2) A PartitionScheme object that defines how those parameters should be partitioned across devices.

    :param fully_sharded_data_parallel: bool: Determine whether to partition the model fully or not
    :return: A list of tuples

    """
    return (

        ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),
        ("model/embed_vision/embedding", PartitionSpec("tp", ("fsdp", "sp"))),

        ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),

        ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
        ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),

        ("input_layernorm/kernel", PartitionSpec(None)),
        ("post_attention_layernorm/kernel", PartitionSpec(None)),

        ("model/norm/kernel", PartitionSpec(None)),
        ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        ("vision_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        (".*", PartitionSpec(None)),
    ) if not fully_sharded_data_parallel else (

        ("model/embed_tokens/embedding", PartitionSpec(("fsdp", "sp"))),
        ("model/embed_vision/embedding", PartitionSpec(("fsdp", "sp"))),

        ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"))),
        ("self_attn/o_proj/kernel", PartitionSpec(("fsdp", "sp"))),

        ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"))),
        ("mlp/down_proj/kernel", PartitionSpec(("fsdp", "sp"))),
        ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"))),

        ("input_layernorm/kernel", PartitionSpec(None)),
        ("post_attention_layernorm/kernel", PartitionSpec(None)),

        ("model/norm/kernel", PartitionSpec(None)),
        ("lm_head/kernel", PartitionSpec(("fsdp", "sp"))),
        ("vision_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
        (".*", PartitionSpec(("fsdp", "sp"))),
    )