Skip to content

modules.mixtral.modelling_mixtral_flax

FlaxMixtralAttention

Bases: BaseJAXAttentionModule

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
class FlaxMixtralAttention(BaseJAXAttentionModule):
    """Self-attention block used inside a Mixtral decoder layer.

    The block projects the hidden states to query/key/value heads, applies the
    rotary embedding module to query and key, repeats the key/value heads
    ``num_key_value_groups`` times so that they line up with the query heads,
    and hands the actual attention computation to an ``AttentionModule`` that
    is configured from ``config``.
    """
    config: MixtralConfig
    layer_index: int
    dtype: jnp.dtype = jnp.bfloat16
    param_dtype: jnp.dtype = jnp.bfloat16
    precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest")

    def setup(self) -> None:
        """Create the projections, the rotary module and the attention backend."""
        config = self.config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        # All four projections share the same dtype/precision/initialiser; the
        # bias is only enabled when the config defines ``attention_bias``.
        dense = functools.partial(
            Linear,
            use_bias=getattr(self.config, "attention_bias", False),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            kernel_init=nn.initializers.normal(),
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )

        self.q_proj = dense(self.num_heads * self.head_dim)
        self.k_proj = dense(self.num_key_value_heads * self.head_dim)
        self.v_proj = dense(self.num_key_value_heads * self.head_dim)
        self.o_proj = dense(self.hidden_size)
        self.rotary = FlaxMixtralRotaryEmbedding(self.dtype)
        self.attention_performer = AttentionModule(
            use_sharding_constraint=self.config.use_sharding_constraint,
            block_k_major=self.config.block_k_major,
            block_b=self.config.block_b,
            block_q=self.config.block_q,
            block_k=self.config.block_k,
            block_q_major_dkv=self.config.block_q_major_dkv,
            block_k_major_dkv=self.config.block_k_major_dkv,
            block_k_major_dq=self.config.block_k_major_dq,
            block_k_dkv=self.config.block_k_dkv,
            block_q_dkv=self.config.block_q_dkv,
            block_q_dq=self.config.block_q_dq,
            block_k_dq=self.config.block_k_dq,
            num_attention_heads=self.config.num_attention_heads,
            attention_dropout=self.config.attention_dropout,
            head_dims=self.head_dim,
            attention_partition_spec=self.config.attention_partition_spec,
            shard_attention_computation=self.config.shard_attention_computation,
            precision=self.precision,
            force_float32_tpu=True,
            attn_mechanism=self.config.attn_mechanism,
            dtype=self.dtype,
            bias_partition_spec=self.config.bias_partition_spec,
            key_partition_spec=self.config.key_partition_spec,
            query_partition_spec=self.config.query_partition_spec,
            generation_query_partition_spec=self.config.generation_query_partition_spec,
            generation_bias_partition_spec=self.config.generation_bias_partition_spec,
            generation_attention_partition_spec=self.config.generation_attention_partition_spec,
            value_partition_spec=self.config.value_partition_spec,
            scan_ring_attention=self.config.scan_ring_attention,
            mesh=self.config.jax_mesh(),
            sm_scale=1 / math.sqrt(self.head_dim),
            axis_name=self.config.attention_axis_name
        )

    @staticmethod
    def _transpose_sequence_head(query, key, value):
        """Swap axes 1 and 2 of ``query``, ``key`` and ``value`` (4-D inputs)."""
        return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3))

    def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
        """Split the projections into heads, rotate query/key and repeat key/value heads.

        :param batch_size: leading dimension used when reshaping the projections
        :param sequence_length: second dimension used when reshaping the projections
        :param query: query projection, reshaped to (batch, sequence, num_attention_heads, head_dim)
        :param key: key projection, reshaped to (batch, sequence, num_key_value_heads, head_dim)
        :param value: value projection, reshaped like ``key``
        :param freq_cis: forwarded unchanged to the rotary embedding module
        :param position_ids: forwarded unchanged to the rotary embedding module
        :return: (query, key, value) laid out as (batch, sequence, heads, head_dim)
        """
        query = query.reshape(batch_size, sequence_length,
                              self.config.num_attention_heads, self.head_dim)
        key = key.reshape(batch_size, sequence_length,
                          self.config.num_key_value_heads, self.head_dim)
        value = value.reshape(batch_size, sequence_length,
                              self.config.num_key_value_heads, self.head_dim)

        # The rotary module and repeat_kv_bnsh are applied with the head axis
        # ahead of the sequence axis, so transpose in and back out again.
        query, key, value = self._transpose_sequence_head(query, key, value)
        query, key = self.rotary(
            position_ids=position_ids, query=query, key=key, freq_cis=freq_cis)
        key = repeat_kv_bnsh(key, self.num_key_value_groups)
        value = repeat_kv_bnsh(value, self.num_key_value_groups)
        return self._transpose_sequence_head(query, key, value)

    def _merge_heads(self, hidden_states):
        """Collapse everything after the first two axes into one axis of size ``hidden_size``."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    def __call__(
            self,
            hidden_states: chex.Array,
            freq_cis: Tuple[chex.Array, chex.Array],
            attention_mask: chex.Array,
            causal_mask: chex.Array,
            position_ids: chex.Array,
            segment_ids: Optional[chex.Array] = None,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = True
    ):
        """
        Run self-attention over ``hidden_states``.

        The hidden states are projected to query/key/value, passed through
        ``apply_rotary``, optionally merged with the key/value cache, and then
        given to the configured attention backend. The merged heads are finally
        passed through the output projection.

        :param hidden_states: chex.Array: Input whose first two axes are read as (batch, sequence)
        :param freq_cis: Tuple[chex.Array, chex.Array]: Forwarded to the rotary embedding module
        :param attention_mask: chex.Array: Mask that is expanded on axes (-3, -2), broadcast to the
            causal mask's shape and combined with it (presumably (batch, key_length) - confirm with caller)
        :param causal_mask: chex.Array: 4-D causal mask; sliced to the current query/key lengths, or to
            the cache window when a cache is present
        :param position_ids: chex.Array: Forwarded to the rotary embedding module
        :param segment_ids: Optional[chex.Array]: Forwarded to the attention backend
        :param deterministic: bool: When False and ``config.attention_dropout > 0`` a "dropout" RNG is
            drawn and given to the attention backend
        :param init_cache: bool: When True (or when a cache already exists) keys and values go through
            ``_concatenate_to_cache``
        :param output_attentions: bool: Not read by this method; the attention weights reported by the
            backend are always returned
        :return: A tuple of (attn_output, attention_weights)

        """
        batch_size, sequence_length = hidden_states.shape[:2]
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # apply_rotary performs the split into heads itself, so the flat
        # projections can be handed over directly.
        query_states, key_states, value_states = self.apply_rotary(
            query=query_states,
            key=key_states,
            value=value_states,
            position_ids=position_ids,
            freq_cis=freq_cis,
            batch_size=batch_size,
            sequence_length=sequence_length
        )

        assert_msg = (
            "num_attention_heads repeat wont work likely\n"
            f"INFO :\n\trepeat_kv_bnsh Used with num_key_value_groups = {self.num_key_value_groups}\n\t"
            f"NH : {self.config.num_attention_heads} KVH : {self.config.num_key_value_heads}"
        )

        # After the key/value repeat every tensor must expose num_attention_heads heads.
        assert query_states.shape[-2] == self.config.num_attention_heads, assert_msg
        assert key_states.shape[-2] == self.config.num_attention_heads, assert_msg
        assert value_states.shape[-2] == self.config.num_attention_heads, assert_msg

        query_length, key_length = query_states.shape[1], key_states.shape[1]

        if self.has_variable("cache", "cached_key"):
            # With a cache, take the rows starting at the current cache index
            # and as many columns as the cache holds.
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                causal_mask, (0, 0, mask_shift, 0), (1, 1,
                                                     query_length, max_decoder_length)
            )
        else:
            causal_mask = causal_mask[:, :, :query_length, :key_length]

        causal_mask = jnp.broadcast_to(
            causal_mask, (batch_size,) + causal_mask.shape[1:])
        attention_mask = jnp.broadcast_to(jnp.expand_dims(
            attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)
        if attention_mask.ndim == 2:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = None

        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.has_variable("cache", "cached_key") or init_cache:
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states,
                value_states,
                query_states,
                attention_mask
            )

        # Additive bias: 0 where attention is allowed, the dtype's most
        # negative finite value where it is masked out.
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(
                self.dtype).min).astype(self.dtype),
        )

        # Re-read the lengths: the cache step above may have replaced key_states.
        query_length, key_length = query_states.shape[1], key_states.shape[1]

        attentions = self.attention_performer.__call__(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            bias=attention_bias,
            attention_mask=attention_mask,
            causal=True,
            dropout_rng=dropout_rng,
            deterministic=deterministic,
            query_sequence_length=query_length,
            key_value_sequence_length=key_length,
            uses_cache=self.has_variable("cache", "cached_key") or init_cache,
            segment_ids=segment_ids,
            causal_mask=causal_mask
        )

        attn_output = self._merge_heads(attentions.attention_outputs)
        if self.config.shard_attention_computation:
            # The sequence axis is left unsharded when it has length 1.
            attn_output = with_sharding_constraint(
                attn_output, PartitionSpec(
                    ("dp", "fsdp"),
                    "sp" if attn_output.shape[1] != 1 else None,
                    "tp"
                )
            )
        attn_output = self.o_proj(attn_output)
        outputs = (
            attn_output, attentions.attention_weights
        )
        return outputs

__call__(hidden_states, freq_cis, attention_mask, causal_mask, position_ids, segment_ids=None, deterministic=True, init_cache=False, output_attentions=True)

Runs self-attention over the hidden states. The input is projected to query, key and value heads, rotary position embeddings are applied to query and key, the key/value heads are repeated to match the number of query heads, and the attention mask is combined with the causal mask before the attention backend is called. The merged heads are then passed through the output projection.

Parameters:

Name Type Description Default
self

Refer to the object itself

required
hidden_states Array

chex.Array: Pass in the hidden state of the model

required
freq_cis Tuple[Array, Array]

Tuple[chex.Array, chex.Array]: Values forwarded to the rotary embedding module

required
attention_mask Array

chex.Array: Mask the attention weights

required
causal_mask Array

chex.Array: Causal mask, sliced to the current query/key lengths and combined with attention_mask

required
position_ids Array

chex.Array: Specify the position of each token in a sequence

required
deterministic bool

bool: Determine whether to use dropout or not

True
init_cache bool

bool: Initialize the cache

False
output_attentions bool

bool: Determine whether to return the attention weights

True

Returns:

Type Description

A tuple of (attn_output, attention_weights)

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
def __call__(
        self,
        hidden_states: chex.Array,
        freq_cis: Tuple[chex.Array, chex.Array],
        attention_mask: chex.Array,
        causal_mask: chex.Array,
        position_ids: chex.Array,
        segment_ids: Optional[chex.Array] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = True
):
    """
    Run self-attention over ``hidden_states``.

    The hidden states are projected to query/key/value, passed through
    ``apply_rotary``, optionally merged with the key/value cache, and then
    given to the configured attention backend. The merged heads are finally
    passed through the output projection.

    :param hidden_states: chex.Array: Input whose first two axes are read as (batch, sequence)
    :param freq_cis: Tuple[chex.Array, chex.Array]: Forwarded to ``apply_rotary``
    :param attention_mask: chex.Array: Mask that is expanded on axes (-3, -2), broadcast to the
        causal mask's shape and combined with it (presumably (batch, key_length) - confirm with caller)
    :param causal_mask: chex.Array: 4-D causal mask; sliced to the current query/key lengths, or to
        the cache window when a cache is present
    :param position_ids: chex.Array: Forwarded to ``apply_rotary``
    :param segment_ids: Optional[chex.Array]: Forwarded to the attention backend
    :param deterministic: bool: When False and ``config.attention_dropout > 0`` a "dropout" RNG is
        drawn and given to the attention backend
    :param init_cache: bool: When True (or when a cache already exists) keys and values go through
        ``_concatenate_to_cache``
    :param output_attentions: bool: Not read by this method; the attention weights reported by the
        backend are always returned
    :return: A tuple of (attn_output, attention_weights)

    """
    batch_size, sequence_length = hidden_states.shape[:2]
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Split into heads explicitly so the shapes handed to apply_rotary are
    # (batch, sequence, heads, head_dim).
    query_states = query_states.reshape(
        batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
    key_states = key_states.reshape(
        batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
    value_states = value_states.reshape(
        batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)

    query_states, key_states, value_states = self.apply_rotary(
        query=query_states,
        key=key_states,
        value=value_states,
        position_ids=position_ids,
        freq_cis=freq_cis,
        batch_size=batch_size,
        sequence_length=sequence_length
    )

    assert_msg = (
        "num_attention_heads repeat wont work likely\n"
        f"INFO :\n\trepeat_kv_bnsh Used with num_key_value_groups = {self.num_key_value_groups}\n\t"
        f"NH : {self.config.num_attention_heads} KVH : {self.config.num_key_value_heads}"
    )

    # After apply_rotary every tensor must expose num_attention_heads heads.
    assert query_states.shape[-2] == self.config.num_attention_heads, assert_msg
    assert key_states.shape[-2] == self.config.num_attention_heads, assert_msg
    assert value_states.shape[-2] == self.config.num_attention_heads, assert_msg

    query_length, key_length = query_states.shape[1], key_states.shape[1]

    if self.has_variable("cache", "cached_key"):
        # With a cache, take the rows starting at the current cache index
        # and as many columns as the cache holds.
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            causal_mask, (0, 0, mask_shift, 0), (1, 1,
                                                 query_length, max_decoder_length)
        )
    else:
        causal_mask = causal_mask[:, :, :query_length, :key_length]

    causal_mask = jnp.broadcast_to(
        causal_mask, (batch_size,) + causal_mask.shape[1:])
    attention_mask = jnp.broadcast_to(jnp.expand_dims(
        attention_mask, axis=(-3, -2)), causal_mask.shape)
    attention_mask = combine_masks(attention_mask, causal_mask)
    if attention_mask.ndim == 2:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    dropout_rng = None

    if not deterministic and self.config.attention_dropout > 0.0:
        dropout_rng = self.make_rng("dropout")

    if self.has_variable("cache", "cached_key") or init_cache:
        key_states, value_states, attention_mask = self._concatenate_to_cache(
            key_states,
            value_states,
            query_states,
            attention_mask
        )

    # Additive bias: 0 where attention is allowed, the dtype's most
    # negative finite value where it is masked out.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, jnp.finfo(
            self.dtype).min).astype(self.dtype),
    )

    # Re-read the lengths: the cache step above may have replaced key_states.
    query_length, key_length = query_states.shape[1], key_states.shape[1]

    attentions = self.attention_performer.__call__(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        bias=attention_bias,
        attention_mask=attention_mask,
        causal=True,
        dropout_rng=dropout_rng,
        deterministic=deterministic,
        query_sequence_length=query_length,
        key_value_sequence_length=key_length,
        uses_cache=self.has_variable("cache", "cached_key") or init_cache,
        segment_ids=segment_ids,
        causal_mask=causal_mask
    )

    attn_output = self._merge_heads(attentions.attention_outputs)
    if self.config.shard_attention_computation:
        # The sequence axis is left unsharded when it has length 1.
        attn_output = with_sharding_constraint(
            attn_output, PartitionSpec(
                ("dp", "fsdp"),
                "sp" if attn_output.shape[1] != 1 else None,
                "tp"
            )
        )
    attn_output = self.o_proj(attn_output)
    outputs = (
        attn_output, attentions.attention_weights
    )
    return outputs

FlaxMixtralDecoderLayer

Bases: Module

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
class FlaxMixtralDecoderLayer(nn.Module):
    """One Mixtral decoder layer.

    The layer normalises its input, runs self-attention and adds the result
    back to the input; it then normalises again, runs the sparse
    mixture-of-experts block and adds that result back as well.
    """
    config: MixtralConfig
    layer_index: int
    dtype: jnp.dtype = jnp.bfloat16
    param_dtype: jnp.dtype = jnp.bfloat16
    precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest")

    def setup(self) -> None:
        """Build the attention block, the MoE block and the two RMS norms."""
        attention_cls = FlaxMixtralAttention
        moe_cls = FlaxMixtralSparseMoeBlock

        # An empty string means "no gradient checkpointing"; anything else
        # wraps both sub-blocks with re_mat using the matching policy.
        if self.config.gradient_checkpointing != "":
            # NOTE(review): confirm these static indices against the positional
            # order used when self_attn is called in __call__.
            attention_cls = re_mat(
                attention_cls,
                policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing),
                static_argnums=(1, 3, 4, 6, 7, 8, 9),
            )
            moe_cls = re_mat(
                moe_cls,
                policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing),
                static_argnums=(1,),
            )

        # Keyword sets shared by the two sub-blocks and by the two norms.
        block_kwargs = dict(
            config=self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )
        norm_kwargs = dict(
            dim=self.config.hidden_size,
            eps=self.config.rms_norm_eps,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )

        self.self_attn = attention_cls(layer_index=self.layer_index, **block_kwargs)
        self.block_sparse_moe = moe_cls(**block_kwargs)
        self.input_layernorm = MixtralRMSNorm(**norm_kwargs)
        self.post_attention_layernorm = MixtralRMSNorm(**norm_kwargs)

    def __call__(
            self,
            hidden_states: chex.Array,
            freq_cis: Tuple[chex.Array, chex.Array],
            attention_mask: chex.Array,
            causal_mask: chex.Array,
            position_ids: chex.Array,
            segment_ids: Optional[chex.Array] = None,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = True,
            output_router_logits: Optional[bool] = None,
    ):
        """
        Apply the decoder layer to ``hidden_states``.

        :param hidden_states: chex.Array: Input to the layer
        :param freq_cis: Tuple[chex.Array, chex.Array]: Passed through to the attention block
        :param attention_mask: chex.Array: Passed through to the attention block
        :param causal_mask: chex.Array: Passed through to the attention block
        :param position_ids: chex.Array: Passed through to the attention block
        :param segment_ids: Optional[chex.Array]: Passed through to the attention block
        :param deterministic: bool: Passed through to the attention block
        :param init_cache: bool: Passed through to the attention block
        :param output_attentions: bool: Append the attention weights to the returned tuple
        :param output_router_logits: Optional[bool]: Append the router logits to the returned tuple
        :return: A tuple whose first item is the new hidden states, followed by the attention
            weights and/or router logits when requested

        """
        # Attention sub-layer. The arguments stay positional because the
        # re_mat wrapper set up in ``setup`` refers to them by position.
        attn_out, self_attn_weights = self.self_attn(
            self.input_layernorm(hidden_states),
            freq_cis,
            attention_mask,
            causal_mask,
            position_ids,
            segment_ids,
            deterministic,
            init_cache,
            output_attentions
        )
        hidden_states = hidden_states + attn_out

        # Mixture-of-experts sub-layer.
        moe_out, router_logits = self.block_sparse_moe(
            self.post_attention_layernorm(hidden_states)
        )
        hidden_states = hidden_states + moe_out

        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if output_router_logits:
            extras.append(router_logits)
        return (hidden_states, *extras)

__call__(hidden_states, freq_cis, attention_mask, causal_mask, position_ids, segment_ids=None, deterministic=True, init_cache=False, output_attentions=True, output_router_logits=None)

Applies one decoder layer. The hidden states are normalised with the input RMS norm, passed through self-attention and added back to the layer input; the result is normalised with the post-attention RMS norm, passed through the sparse mixture-of-experts block and added back again.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
hidden_states Array

chex.Array: Represent the input to the encoder layer

required
freq_cis Tuple[Array, Array]

Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer

required
attention_mask Array

chex.Array: Mask out the attention weights for certain positions

required
causal_mask Array

chex.Array: Mask the future tokens

required
position_ids Array

chex.Array: Indicate the position of each token in the sequence

required
deterministic bool

bool: Determine whether to use dropout or not

True
init_cache bool

bool: Initialize the cache for the self-attention layer

False
output_attentions bool

bool: Determine whether to return the attention weights or not

True

Returns:

Type Description

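A tuple whose first item is hidden_states, followed by the attention weights when output_attentions is true and the router logits when output_router_logits is true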

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
def __call__(
        self,
        hidden_states: chex.Array,
        freq_cis: Tuple[chex.Array, chex.Array],
        attention_mask: chex.Array,
        causal_mask: chex.Array,
        position_ids: chex.Array,
        segment_ids: Optional[chex.Array] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = True,
        output_router_logits: Optional[bool] = None,
):
    """
    Apply the decoder layer to ``hidden_states``.

    :param hidden_states: chex.Array: Input to the layer
    :param freq_cis: Tuple[chex.Array, chex.Array]: Passed through to the attention block
    :param attention_mask: chex.Array: Passed through to the attention block
    :param causal_mask: chex.Array: Passed through to the attention block
    :param position_ids: chex.Array: Passed through to the attention block
    :param segment_ids: Optional[chex.Array]: Passed through to the attention block
    :param deterministic: bool: Passed through to the attention block
    :param init_cache: bool: Passed through to the attention block
    :param output_attentions: bool: Append the attention weights to the returned tuple
    :param output_router_logits: Optional[bool]: Append the router logits to the returned tuple
    :return: A tuple whose first item is the new hidden states, followed by the attention
        weights and/or router logits when requested

    """
    # Attention sub-layer: normalise, attend, add the input back. The
    # arguments are passed positionally, exactly as the attention block
    # expects them.
    attn_out, self_attn_weights = self.self_attn(
        self.input_layernorm(hidden_states),
        freq_cis,
        attention_mask,
        causal_mask,
        position_ids,
        segment_ids,
        deterministic,
        init_cache,
        output_attentions
    )
    hidden_states = hidden_states + attn_out

    # Mixture-of-experts sub-layer: normalise, route, add the input back.
    moe_out, router_logits = self.block_sparse_moe(
        self.post_attention_layernorm(hidden_states)
    )
    hidden_states = hidden_states + moe_out

    extras = []
    if output_attentions:
        extras.append(self_attn_weights)
    if output_router_logits:
        extras.append(router_logits)
    return (hidden_states, *extras)

FlaxMixtralDecoderLayerCollection

Bases: Module

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
class FlaxMixtralDecoderLayerCollection(nn.Module):
    """Ordered stack of ``FlaxMixtralDecoderLayer`` blocks applied one after another."""
    config: MixtralConfig
    dtype: jnp.dtype = jnp.bfloat16
    param_dtype: jnp.dtype = jnp.bfloat16
    precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest")

    def setup(self) -> None:
        """Create ``config.num_hidden_layers`` decoder layers, each named after its index."""
        layers = []
        for idx in range(self.config.num_hidden_layers):
            layers.append(
                FlaxMixtralDecoderLayer(
                    layer_index=idx,
                    config=self.config,
                    dtype=self.dtype,
                    param_dtype=self.param_dtype,
                    precision=self.precision,
                    name=str(idx),
                )
            )
        self.blocks = layers

    def __call__(
            self,
            hidden_states: chex.Array,
            freq_cis: Tuple[chex.Array, chex.Array],
            attention_mask: chex.Array,
            causal_mask: chex.Array,
            position_ids: chex.Array,
            deterministic: bool = True,
            init_cache: bool = False,
            output_hidden_states: Optional[bool] = False,
            output_attentions: Optional[bool] = False,
            output_router_logits: Optional[bool] = None,
    ):
        """
        Run the hidden states through every decoder layer in order.

        Each layer receives the previous layer's first output as its ``hidden_states``;
        all other arguments are forwarded unchanged to every layer.

        :param self: Represent the instance of the class
        :param hidden_states: chex.Array: Input to the first decoder layer
        :param freq_cis: Tuple[chex.Array, chex.Array]: Frequency arrays forwarded to each layer
        :param attention_mask: chex.Array: Attention mask forwarded to each layer
        :param causal_mask: chex.Array: Causal mask forwarded to each layer
        :param position_ids: chex.Array: Position ids forwarded to each layer
        :param deterministic: bool: Forwarded to each layer as ``deterministic``
        :param init_cache: bool: Forwarded to each layer as ``init_cache``
        :param output_hidden_states: Optional[bool]: Collect the input of every layer
        :param output_attentions: Optional[bool]: Collect element ``1`` of every layer's output
        :param output_router_logits: Optional[bool]: Collect the last element of every layer's output
        :return: A tuple starting with the final hidden states, followed (only for the flags
            that are truthy, in this order) by the per-layer attentions, the per-layer input
            hidden states and the per-layer router logits, each as a tuple

        """
        hidden_trace = []
        attention_trace = []
        router_trace = []

        for layer in self.blocks:
            # The trace records what goes INTO each layer; the output of the
            # last layer is only returned as the first element of the result.
            if output_hidden_states:
                hidden_trace.append(hidden_states)

            layer_result = layer(
                hidden_states=hidden_states,
                freq_cis=freq_cis,
                attention_mask=attention_mask,
                causal_mask=causal_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
                output_router_logits=output_router_logits,
            )
            hidden_states = layer_result[0]

            if output_attentions:
                attention_trace.append(layer_result[1])
            if output_router_logits:
                # Router logits are always the trailing element of a layer's output.
                router_trace.append(layer_result[-1])

        result = (hidden_states,)
        if output_attentions:
            result += (tuple(attention_trace),)
        if output_hidden_states:
            result += (tuple(hidden_trace),)
        if output_router_logits:
            result += (tuple(router_trace),)
        return result

__call__(hidden_states, freq_cis, attention_mask, causal_mask, position_ids, deterministic=True, init_cache=False, output_hidden_states=False, output_attentions=False, output_router_logits=None)

The `__call__` function runs the hidden states through every decoder layer of the collection, in order. It takes in the following arguments: hidden_states (chex.Array): The input to the first decoder layer; the output of each layer becomes the input of the next one. freq_cis (Tuple[chex.Array, chex.Array]): A pair of frequency arrays that is forwarded unchanged to every decoder layer for use in its attention computation.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
hidden_states Array

chex.Array: Represent the input to the encoder layer

required
freq_cis Tuple[Array, Array]

Tuple[chex.Array, chex.Array],: Pass the frequency information to the attention layer

required
attention_mask Array

chex.Array: Mask out the attention weights for certain positions

required
causal_mask Array

chex.Array: Mask the future tokens

required
position_ids Array

chex.Array: Indicate the position of each token in the sequence

required
deterministic bool

bool: Determine whether to use dropout or not

True
init_cache bool

bool: Initialize the cache for the self-attention layer

False
output_attentions Optional[bool]

bool: Determine whether to return the attention weights or not

False

Returns:

Type Description

A tuple of hidden_states, attention_output, all_hidden_states and all_router_logits

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
def __call__(
        self,
        hidden_states: chex.Array,
        freq_cis: Tuple[chex.Array, chex.Array],
        attention_mask: chex.Array,
        causal_mask: chex.Array,
        position_ids: chex.Array,
        deterministic: bool = True,
        init_cache: bool = False,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = None,
):
    """
    Run the hidden states through every decoder layer in order.

    Each layer receives the previous layer's first output as its ``hidden_states``;
    all other arguments are forwarded unchanged to every layer.

    :param self: Represent the instance of the class
    :param hidden_states: chex.Array: Input to the first decoder layer
    :param freq_cis: Tuple[chex.Array, chex.Array]: Frequency arrays forwarded to each layer
    :param attention_mask: chex.Array: Attention mask forwarded to each layer
    :param causal_mask: chex.Array: Causal mask forwarded to each layer
    :param position_ids: chex.Array: Position ids forwarded to each layer
    :param deterministic: bool: Forwarded to each layer as ``deterministic``
    :param init_cache: bool: Forwarded to each layer as ``init_cache``
    :param output_hidden_states: Optional[bool]: Collect the input of every layer
    :param output_attentions: Optional[bool]: Collect element ``1`` of every layer's output
    :param output_router_logits: Optional[bool]: Collect the last element of every layer's output
    :return: A tuple starting with the final hidden states, followed (only for the flags
        that are truthy, in this order) by the per-layer attentions, the per-layer input
        hidden states and the per-layer router logits, each as a tuple

    """
    hidden_trace = []
    attention_trace = []
    router_trace = []

    for layer in self.blocks:
        # The trace records what goes INTO each layer; the output of the
        # last layer is only returned as the first element of the result.
        if output_hidden_states:
            hidden_trace.append(hidden_states)

        layer_result = layer(
            hidden_states=hidden_states,
            freq_cis=freq_cis,
            attention_mask=attention_mask,
            causal_mask=causal_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_router_logits=output_router_logits,
        )
        hidden_states = layer_result[0]

        if output_attentions:
            attention_trace.append(layer_result[1])
        if output_router_logits:
            # Router logits are always the trailing element of a layer's output.
            router_trace.append(layer_result[-1])

    result = (hidden_states,)
    if output_attentions:
        result += (tuple(attention_trace),)
    if output_hidden_states:
        result += (tuple(hidden_trace),)
    if output_router_logits:
        result += (tuple(router_trace),)
    return result

FlaxMixtralForCausalLM

Bases: MixtralPreTrainedModel

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
class FlaxMixtralForCausalLM(MixtralPreTrainedModel):
    """Causal-LM wrapper whose Flax module is ``FlaxMixtralForCausalLMModule``."""
    module_class = FlaxMixtralForCausalLMModule

    def set_input_embeddings(self, value):
        """Assign ``value`` to ``module.model.embed_tokens``."""
        self.module.model.embed_tokens = value

    def get_input_embeddings(self):
        """Return ``module.model.embed_tokens``."""
        return self.module.model.embed_tokens

    def set_decoder(self, decoder):
        """Assign ``decoder`` to ``module.model``."""
        self.module.model = decoder

    def get_decoder(self):
        """Return ``module.model``."""
        return self.module.model

    def get_output_embeddings(self):
        """Return ``module.lm_head``."""
        return self.module.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Assign ``new_embeddings`` to ``module.lm_head``."""
        self.module.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
        """
        Build the initial keyword arguments for a generation loop.

        A fresh cache sized for ``max_length`` is created, the attention mask is
        extended to ``max_length`` columns (positions beyond the supplied mask stay
        ``1``), and position ids are derived for the prompt tokens.

        :param self: Access variables that belong to the class
        :param input_ids: Prompt tokens; must be two-dimensional (batch, sequence)
        :param max_length: Number of columns of the cache and of the extended mask
        :param attention_mask: Optional[chex.Array]: Prompt mask; when given, position ids
            are its running sum minus one, otherwise they are ``0..seq_length-1``
        :return: A dictionary of the past_key_values, attention_mask and position ids

        """
        num_rows, prompt_len = input_ids.shape

        cache = self.init_cache(num_rows, max_length)
        full_mask = jnp.ones((num_rows, max_length), dtype="i4")

        if attention_mask is None:
            prompt_positions = jnp.arange(prompt_len, dtype="i4")[None, :]
            positions = jnp.broadcast_to(prompt_positions, (num_rows, prompt_len))
        else:
            positions = attention_mask.cumsum(axis=-1) - 1
            # Overwrite the leading columns of the all-ones mask with the prompt mask.
            full_mask = lax.dynamic_update_slice(full_mask, attention_mask, (0, 0))

        return {
            "past_key_values": cache,
            "attention_mask": full_mask,
            "position_ids": positions,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Advance the generation kwargs by one step.

        Stores the cache returned by the model and replaces ``position_ids`` with
        the last position of every row incremented by one. ``model_kwargs`` is
        modified in place and also returned.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        next_positions = model_kwargs["position_ids"][:, -1:] + 1
        model_kwargs["position_ids"] = next_positions
        return model_kwargs

prepare_inputs_for_generation(input_ids, max_length, attention_mask=None)

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
input_ids

Pass in the input tokens

required
max_length

Set the length of the sequence to be generated

required
attention_mask Optional[Array]

Optional[chex.Array]: Mask the attention weights

None

Returns:

Type Description

A dictionary of the past_key_values, attention_mask and position ids

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
    """
    Build the initial keyword arguments for a generation loop.

    A fresh cache sized for ``max_length`` is created, the attention mask is
    extended to ``max_length`` columns (positions beyond the supplied mask stay
    ``1``), and position ids are derived for the prompt tokens.

    :param self: Access variables that belong to the class
    :param input_ids: Prompt tokens; must be two-dimensional (batch, sequence)
    :param max_length: Number of columns of the cache and of the extended mask
    :param attention_mask: Optional[chex.Array]: Prompt mask; when given, position ids
        are its running sum minus one, otherwise they are ``0..seq_length-1``
    :return: A dictionary of the past_key_values, attention_mask and position ids

    """
    num_rows, prompt_len = input_ids.shape

    cache = self.init_cache(num_rows, max_length)
    full_mask = jnp.ones((num_rows, max_length), dtype="i4")

    if attention_mask is None:
        prompt_positions = jnp.arange(prompt_len, dtype="i4")[None, :]
        positions = jnp.broadcast_to(prompt_positions, (num_rows, prompt_len))
    else:
        positions = attention_mask.cumsum(axis=-1) - 1
        # Overwrite the leading columns of the all-ones mask with the prompt mask.
        full_mask = lax.dynamic_update_slice(full_mask, attention_mask, (0, 0))

    return {
        "past_key_values": cache,
        "attention_mask": full_mask,
        "position_ids": positions,
    }

FlaxMixtralSparseMoeBlock

Bases: Module

This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of tokens to experts, whereas standard MoE either (1) drops tokens at the cost of reduced performance or (2) sets the capacity factor to the number of experts and thus wastes computation and memory on padding.

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
class FlaxMixtralSparseMoeBlock(nn.Module):
    """
    Sparse mixture-of-experts block: a linear router followed by an expert collection.

    For every token the router produces one logit per local expert; the
    ``config.num_experts_per_tok`` largest logits are kept and softmax-normalised
    to give the routing weights handed to the experts.

    Per the upstream description, this implementation is strictly equivalent to
    standard MoE with full capacity (no dropped tokens). It's faster since it
    formulates MoE operations in terms of block-sparse operations to accommodate
    imbalanced assignments of tokens to experts, whereas standard MoE either
    (1) drops tokens at the cost of reduced performance or (2) sets the capacity
    factor to the number of experts and thus wastes computation and memory on padding.
    """
    config: MixtralConfig
    dtype: jnp.dtype = jnp.bfloat16
    param_dtype: jnp.dtype = jnp.bfloat16
    precision: Optional[
        Union[None, jax.lax.Precision]
    ] = jax.lax.Precision("fastest")

    def setup(self) -> None:
        """Create the bias-free router projection and the expert collection."""
        numerics = dict(
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )

        self.gate = Linear(
            self.config.num_local_experts,
            use_bias=False,
            kernel_init=nn.initializers.normal(),
            **numerics,
        )
        self.experts = FlaxMixtralBlocKSparesTop2MLPCollection(
            config=self.config,
            **numerics,
        )

    def __call__(
            self,
            hidden_states: chex.Array,
            e: bool = False  # Ignored
    ) -> Tuple[chex.Array, chex.Array]:
        """
        Route every token to its top experts and return the expert output.

        :param hidden_states: chex.Array: Three-dimensional input, unpacked as
            (batch_size, sequence_length, hidden_dim)
        :param e: bool: Ignored
        :return: A tuple of (output of the expert collection, router logits over all experts)
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        # Router arithmetic is carried out in at least float32.
        router_dtype = jnp.promote_types(self.dtype, jnp.float32)

        # The gate is applied to the un-reshaped input (no flattening needed).
        router_logits = self.gate(hidden_states).astype(router_dtype)

        top_logits, selected_experts = jax.lax.top_k(
            router_logits,
            k=self.config.num_experts_per_tok,
        )
        # Normalise over the selected experts only.
        routing_weights = jax.nn.softmax(top_logits.astype(router_dtype), axis=-1)

        expert_output = self.experts(
            selected_experts=selected_experts,
            batch_size=batch_size,
            sequence_length=sequence_length,
            hidden_dim=hidden_dim,
            hidden_states=hidden_states,
            routing_weights=routing_weights,
        )
        return expert_output, router_logits

MixtralPreTrainedModel

Bases: EasyDeLFlaxPretrainedModel

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
class MixtralPreTrainedModel(EasyDeLFlaxPretrainedModel):
    """
    Base wrapper for the Mixtral Flax models.

    Subclasses set ``module_class``; this class instantiates that module and
    provides parameter initialisation, cache initialisation and the forward call.
    """
    config_class: MixtralConfig = MixtralConfig
    module_class: nn.Module = None
    base_model_prefix = "model"

    # main_input_name = "input_ids"

    def __init__(
            self,
            config: MixtralConfig,
            dtype: jnp.dtype = jnp.bfloat16,
            param_dtype: jnp.dtype = jnp.bfloat16,
            precision: Optional[jax.lax.Precision] = jax.lax.Precision(
                "fastest"),
            input_shape: Tuple[int, int] = (1, 1),
            seed: int = 0,
            _do_init: bool = False,
            **kwargs
    ):
        """
        Instantiate ``module_class`` and hand it to the parent constructor.

        :param config: MixtralConfig: Model configuration passed to the module
        :param dtype: jnp.dtype: Passed to the module and to the parent constructor
        :param param_dtype: jnp.dtype: Passed to the module
        :param precision: Optional[jax.lax.Precision]: Passed to the module
        :param input_shape: Tuple[int, int]: Passed to the parent constructor
        :param seed: int: Passed to the parent constructor
        :param _do_init: bool: Passed to the parent constructor
        :param kwargs: Extra keyword arguments forwarded to ``module_class``
        """
        module = self.module_class(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            **kwargs
        )

        super().__init__(
            dtype=dtype, _do_init=_do_init,
            module=module, config=config, input_shape=input_shape,
            seed=seed,
        )

    def init_weights(
            self,
            rng: jax.random.PRNGKey,
            input_shape: Tuple,
            params: FrozenDict = None
    ) -> FrozenDict:
        """
        Initialise the module on dummy inputs and return its parameters.

        While the module is being initialised ``config.initialization_of_moe`` is
        set to ``True``; it is reset to ``False`` afterwards, also when
        initialisation raises.

        :param self: Access variables that belong to the class
        :param rng: jax.random.PRNGKey: Split into the ``params`` and ``dropout`` rng streams
        :param input_shape: Tuple: Shape of the dummy input_ids, attention_mask and position_ids
        :param params: flax.core.FrozenDict: Optional existing parameters; entries listed in
            ``self._missing_keys`` are filled in from the freshly initialised ones
        :return: The freshly initialised parameters when ``params`` is None, otherwise a frozen
            copy of ``params`` completed with the missing keys
        """

        self.config.initialization_of_moe = True
        try:
            input_ids = jnp.zeros(input_shape, dtype="i4")
            attention_mask = jnp.ones_like(input_ids, dtype="i4")
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"),
                input_shape,
            )
            params_rng, dropout_rng = jax.random.split(rng)
            rngs = {"params": params_rng, "dropout": dropout_rng}
            if self.config.add_cross_attention:
                encoder_hidden_states = jnp.zeros(
                    input_shape + (self.config.hidden_size,))
                encoder_attention_mask = attention_mask
                module_init_outputs = self.module.init(
                    rngs,
                    input_ids,
                    attention_mask,
                    position_ids,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    return_dict=False,
                )
            else:
                module_init_outputs = self.module.init(
                    rngs,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    return_dict=False
                )
            random_params = module_init_outputs["params"]
        finally:
            # Never leave the config stuck in "initialisation" mode if init fails.
            self.config.initialization_of_moe = False

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):
        """
        Create an empty cache for auto-regressive decoding.

        The module is initialised with ``init_cache=True`` on all-ones dummy inputs of
        shape ``(batch_size, max_length)`` using a fixed ``PRNGKey(0)``.

        :param batch_size: Number of rows of the dummy inputs
        :param max_length: Number of columns of the dummy inputs
        :return: The ``"cache"`` collection produced by the module initialisation
        """

        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(
            jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: Optional[chex.Array] = None,
            position_ids: Optional[chex.Array] = None,
            params: dict = None,
            past_key_values: dict = None,
            dropout_rng: jax.random.PRNGKey = None,
            train: bool = False,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            output_router_logits: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            add_params_field: bool = False,
            **kwargs
    ):
        """
        Apply the wrapped module to a batch of token ids.

        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` fall back to
        the corresponding config values when None; ``output_router_logits`` is forwarded
        to the module as given. Missing ``position_ids`` default to ``0..seq_len-1`` and a
        missing ``attention_mask`` defaults to all ones.

        :param self: Represent the instance of the class
        :param input_ids: Two-dimensional (batch, sequence) token ids
        :param attention_mask: Mask out the padding tokens
        :param position_ids: Specify the position of each token in the sequence
        :param params: dict: Variables to apply the module with; ``self.params`` when falsy
        :param past_key_values: dict: Cache collection from a previous call / ``init_cache``
        :param dropout_rng: jax.random.PRNGKey: Rng for the ``dropout`` stream
        :param train: bool: The module is called with ``deterministic = not train``
        :param output_attentions: Optional[bool]: Determine whether to return the attention weights
        :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers
        :param output_router_logits: Optional[bool]: Forwarded unchanged to the module
        :param return_dict: Optional[bool]: Return a dictionary of the outputs
        :param add_params_field: bool: Wrap the variables as ``{"params": ...}`` before applying
        :param kwargs: Accepted for interface compatibility; not used here
        :return: The module outputs. When a non-empty ``past_key_values`` was supplied the
            updated cache is added under ``"past_key_values"`` (``return_dict``) or inserted
            as the second tuple element (otherwise)
        :raises ValueError: If ``past_key_values`` is given without ``position_ids``

        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[
                                            None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        rng_s = {}
        if dropout_rng is not None:
            rng_s["dropout"] = dropout_rng

        inputs = params or self.params
        if add_params_field:
            inputs = {"params": inputs}

        if self.config.bits is not None:
            rng_s['params'] = jax.random.key(0)

        # One flag drives both attaching the cache and unpacking the mutated
        # collections afterwards, so the two can never disagree.
        use_cache = bool(past_key_values)
        if use_cache:
            # Shallow-copy instead of assigning in place: without
            # `add_params_field`, `inputs` IS the caller's `params` (or
            # `self.params`) and must not permanently gain a "cache" entry.
            inputs = {**inputs, "cache": past_key_values}
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),  # input_ids: chex.Array
            # attention_mask: Optional[chex.Array] = None
            jnp.array(attention_mask, dtype="i4"),
            # position_ids: Optional[chex.Array] = None
            jnp.array(position_ids, dtype="i4"),
            None,  # inputs_embeds: Optional[chex.Array] = None
            output_attentions,  # output_attentions: Optional[bool] = None
            # output_hidden_states: Optional[bool] = None
            output_hidden_states,
            # output_router_logits: Optional[bool] = None
            output_router_logits,
            False,  # init_cache: bool = False
            not train,  # deterministic: bool = True
            return_dict,  # return_dict: bool = True
            rngs=rng_s,
            mutable=mutable,
        )

        if not use_cache:
            return outputs

        # With a mutable collection, `apply` returns (outputs, mutated_variables).
        outputs, mutated_variables = outputs
        updated_cache = unfreeze(mutated_variables["cache"])
        if return_dict:
            outputs["past_key_values"] = updated_cache
            return outputs
        return outputs[:1] + (updated_cache,) + outputs[1:]

__call__(input_ids, attention_mask=None, position_ids=None, params=None, past_key_values=None, dropout_rng=None, train=False, output_attentions=None, output_hidden_states=None, output_router_logits=None, return_dict=None, add_params_field=False, **kwargs)

The `__call__` function is the main entry point of the model wrapper. It takes as input: - The parameters of the model (`params`, falling back to `self.params`) - The inputs to the model (input_ids, attention_mask, position_ids) - Whether we are training (train=True/False) and whether we want to return the hidden states and attention weights of every layer in addition to just the last layer output (output_hidden_states=True/False, output_attentions=True/False).

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
input_ids Array

Pass the input sequence to the model

required
attention_mask Optional[Array]

Mask out the padding tokens

None
position_ids Optional[Array]

Specify the position of each token in the sequence

None
params dict

dict: Pass in the parameters of the model

None
past_key_values dict

dict: Pass the past key values to the model

None
dropout_rng PRNGKey

jax.random.PRNGKey: Pass in a random number generator key to the model

None
train bool

bool: Determine whether to use dropout or not

False
output_attentions Optional[bool]

Optional[bool]: Determine whether to return the attention weights

None
output_hidden_states Optional[bool]

Optional[bool]: Determine whether to return the hidden states of all layers

None
return_dict Optional[bool]

Optional[bool]: Return a dictionary of the outputs

None
add_params_field bool

bool: Add a params field to the inputs dictionary

False

Returns:

Type Description

A tuple of (last_hidden_state, past_key_values)

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: Optional[chex.Array] = None,
        position_ids: Optional[chex.Array] = None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        add_params_field: bool = False,
        **kwargs
):
    """
    Run the wrapped module on a batch of token ids.

    Missing `position_ids` / `attention_mask` are filled with defaults, the variable
    collections and RNGs for `self.module.apply` are assembled, and (when a cache is
    supplied) the updated cache is merged back into the returned outputs.

    :param self: Represent the instance of the class
    :param input_ids: Token ids; must be 2-D, unpacked as (batch_size, sequence_length)
    :param attention_mask: Mask for the tokens; defaults to all ones
    :param position_ids: Position of each token; defaults to `arange(sequence_length)` per row.
        Must be given explicitly whenever `past_key_values` is not None
    :param params: dict: Model parameters; falls back to `self.params` when falsy
    :param past_key_values: dict: Cache collection passed to the module as "cache".
        A falsy value (None or an empty dict) means no cache is used
    :param dropout_rng: jax.random.PRNGKey: RNG registered under the "dropout" name
    :param train: bool: The module is called with `deterministic = not train`
    :param output_attentions: Optional[bool]: Defaults to `config.output_attentions`
    :param output_hidden_states: Optional[bool]: Defaults to `config.output_hidden_states`
    :param output_router_logits: Optional[bool]: Forwarded to the module unchanged
        (no config fallback is applied here)
    :param return_dict: Optional[bool]: Defaults to `config.return_dict`
    :param add_params_field: bool: Wrap the parameters as `{"params": ...}` before applying
    :param kwargs: Accepted for call-site compatibility; not used by this method
    :return: The module outputs. When a cache was supplied, the updated cache is added
        under "past_key_values" (`return_dict` true) or inserted as the second element
        of the output tuple (`return_dict` false)
    :raises ValueError: If `past_key_values` is given without `position_ids`
    """

    config = self.config
    if output_attentions is None:
        output_attentions = config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = config.output_hidden_states
    if return_dict is None:
        return_dict = config.return_dict

    batch_size, sequence_length = input_ids.shape

    if position_ids is None:
        if past_key_values is not None:
            raise ValueError(
                "Make sure to provide `position_ids` when passing `past_key_values`.")

        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length),
        )

    if attention_mask is None:
        attention_mask = jnp.ones((batch_size, sequence_length))

    rng_s = {}
    if dropout_rng is not None:
        rng_s["dropout"] = dropout_rng
    if config.bits is not None:
        rng_s['params'] = jax.random.key(0)

    variables = params or self.params
    if add_params_field:
        variables = {"params": variables}

    # One flag drives both the `mutable` setting and the output unpacking below, so an
    # empty cache dict can no longer request a non-mutable apply and then be unpacked
    # as if a (outputs, mutated_variables) pair had been returned.
    use_cache = bool(past_key_values)
    if use_cache:
        # Build a new top-level mapping instead of assigning into `variables`: without
        # `add_params_field` that object is the caller's `params` / `self.params`.
        variables = {**variables, "cache": past_key_values}
        mutable = ["cache"]
    else:
        mutable = False

    outputs = self.module.apply(
        variables,
        jnp.array(input_ids, dtype="i4"),  # input_ids: chex.Array
        # attention_mask: Optional[chex.Array] = None
        jnp.array(attention_mask, dtype="i4"),
        # position_ids: Optional[chex.Array] = None
        jnp.array(position_ids, dtype="i4"),
        None,  # inputs_embeds: Optional[chex.Array] = None
        output_attentions,  # output_attentions: Optional[bool] = None
        # output_hidden_states: Optional[bool] = None
        output_hidden_states,
        # output_router_logits: Optional[bool] = None
        output_router_logits,
        False,  # init_cache: bool = False
        not train,  # deterministic: bool = True
        return_dict,  # return_dict: bool = True
        rngs=rng_s,
        mutable=mutable,
    )

    if not use_cache:
        return outputs

    # With a mutable collection, apply returns (outputs, mutated_variables).
    outputs, mutated_variables = outputs
    updated_cache = unfreeze(mutated_variables["cache"])
    if return_dict:
        outputs["past_key_values"] = updated_cache
        return outputs
    return outputs[:1] + (updated_cache,) + outputs[1:]

init_weights(rng, input_shape, params=None)

The init_weights function is used to initialize the weights of a model. It takes in a rng, which is a random number generator key that can be used to generate random numbers. The input_shape parameter specifies the shape of the inputs that will be fed into this model. The params parameter allows you to pass in pre-trained weights for your model, if you have them available.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
rng PRNGKey

jax.random.PRNGKey: Initialize the weights of the model

required
input_shape Tuple

Tuple: Initialize the input_ids, attention_mask and position_ids

required
params FrozenDict

flax.core.FrozenDict: Pass in the parameters of a pre-trained model

None

Returns:

Type Description
FrozenDict

A frozendict of parameters

Source code in src/python/easydel/modules/mixtral/modelling_mixtral_flax.py
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
def init_weights(
        self,
        rng: jax.random.PRNGKey,
        input_shape: Tuple,
        params: FrozenDict = None
) -> FrozenDict:
    """
    Initialize the model parameters by running `self.module.init` on dummy inputs.

    Dummy `input_ids` (zeros), `attention_mask` (ones) and `position_ids` (a broadcast
    arange over the last axis) are built from `input_shape`. While the module is being
    initialized, `self.config.initialization_of_moe` is set to True; it is reset to
    False afterwards, even if initialization raises.

    :param self: Access variables that belong to the class
    :param rng: jax.random.PRNGKey: Split into the "params" and "dropout" RNG streams
    :param input_shape: Tuple: Shape of the dummy input_ids, attention_mask and position_ids
    :param params: flax.core.FrozenDict: Existing parameters; when given, only the keys in
        `self._missing_keys` are filled from the freshly initialized values
    :return: The randomly initialized parameters when `params` is None, otherwise a frozen
        copy of `params` completed with the missing keys
    """

    self.config.initialization_of_moe = True
    try:
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"),
            input_shape,
        )
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            # NOTE(review): the `apply` call in `__call__` labels the module's 4th and 5th
            # positional arguments as inputs_embeds and output_attentions, so these encoder
            # tensors may be landing in the wrong slots -- confirm against the module
            # signature before relying on this branch.
            encoder_hidden_states = jnp.zeros(
                input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                position_ids,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs,
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=False
            )
        random_params = module_init_outputs["params"]
    finally:
        # The config object outlives this call; never leave it stuck in
        # initialization mode when module.init fails.
        self.config.initialization_of_moe = False

    if params is None:
        return random_params

    # Flatten both trees so missing entries can be copied over by their key paths.
    random_params = flatten_dict(unfreeze(random_params))
    params = flatten_dict(unfreeze(params))
    for missing_key in self._missing_keys:
        params[missing_key] = random_params[missing_key]
    self._missing_keys = set()
    return freeze(unflatten_dict(params))