Skip to content

modules.dbrx.modelling_dbrx_flax

DbrxPreTrainedModel

Bases: EasyDeLFlaxPretrainedModel

Source code in src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
class DbrxPreTrainedModel(EasyDeLFlaxPretrainedModel):
    """
    Base class that exposes a DBRX Flax module through the EasyDeL pretrained-model interface.

    ``__init__`` instantiates ``module_class`` and hands the resulting module to
    ``EasyDeLFlaxPretrainedModel.__init__``; the remaining methods build dummy inputs for
    parameter / cache initialisation and forward user inputs to ``self.module.apply``.
    """
    config_class: DbrxConfig = DbrxConfig
    # Subclasses must override this: ``__init__`` calls it to construct the module.
    module_class: nn.Module = None
    base_model_prefix = "model"

    def __init__(
            self,
            config: DbrxConfig,
            dtype: jnp.dtype = jnp.bfloat16,
            param_dtype: jnp.dtype = jnp.bfloat16,
            precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"),
            input_shape: Tuple[int, int] = (1, 1),
            seed: int = 0,
            _do_init: bool = False,
            **kwargs
    ):
        """
        Build the underlying module and initialise the pretrained-model wrapper.

        :param self: Represent the instance of the class
        :param config: DbrxConfig: Configuration forwarded to the module and the base class
        :param dtype: jnp.dtype: Forwarded to the module and the base class
        :param param_dtype: jnp.dtype: Forwarded to the module
        :param precision: Optional[jax.lax.Precision]: Forwarded to the module
        :param input_shape: Tuple[int, int]: Forwarded to the base class
        :param seed: int: Forwarded to the base class
        :param _do_init: bool: Forwarded to the base class
        :param kwargs: Extra keyword arguments forwarded to ``module_class``
        """
        module = self.module_class(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            **kwargs
        )

        super().__init__(
            dtype=dtype, _do_init=_do_init,
            module=module, config=config, input_shape=input_shape,
            seed=seed,
        )

    def init_weights(
            self,
            rng: jax.random.PRNGKey,
            input_shape: Tuple,
            params: FrozenDict = None
    ) -> FrozenDict:
        """
        The init_weights function is used to initialize the weights of a model.
        It runs ``self.module.init`` on dummy inputs of shape ``input_shape`` and, when ``params``
        is given, copies only the entries listed in ``self._missing_keys`` from the freshly
        initialised parameters into ``params``.

        ``self.config.initialization_of_moe`` is set to ``True`` while the module is being
        initialised and is always restored to ``False`` afterwards, even if initialisation fails.

        :param self: Access variables that belong to the class
        :param rng: jax.random.PRNGKey: Split into the ``params`` and ``dropout`` rng streams
        :param input_shape: Tuple: Shape of the dummy input_ids, attention_mask and position_ids
        :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model
        :return: The ``params`` collection returned by ``module.init`` when ``params`` is None,
            otherwise a frozen copy of ``params`` completed with the missing keys
        """

        self.config.initialization_of_moe = True
        try:
            input_ids = jnp.zeros(input_shape, dtype="i4")
            attention_mask = jnp.ones_like(input_ids, dtype="i4")
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"),
                input_shape,
            )
            params_rng, dropout_rng = jax.random.split(rng)
            rngs = {"params": params_rng, "dropout": dropout_rng}
            if self.config.add_cross_attention:
                # NOTE(review): __call__ annotates the 4th/5th positional module arguments as
                # `inputs_embeds` / `output_attentions`; confirm the module really accepts
                # encoder states in these slots before relying on this branch.
                encoder_hidden_states = jnp.zeros(
                    input_shape + (self.config.hidden_size,))
                encoder_attention_mask = attention_mask
                module_init_outputs = self.module.init(
                    rngs,
                    input_ids,
                    attention_mask,
                    position_ids,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    return_dict=False,
                )
            else:
                module_init_outputs = self.module.init(
                    rngs,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    return_dict=False
                )
            random_params = module_init_outputs["params"]
        finally:
            # Never leave the shared config in "initialising" mode if module.init raised.
            self.config.initialization_of_moe = False

        if params is None:
            return random_params

        # Fill only the keys the loaded checkpoint is missing; everything else is kept as given.
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))

    def init_cache(self, batch_size, max_length):
        """
        Create the ``cache`` collection by initialising the module with ``init_cache=True``.

        :param self: Access variables that belong to the class
        :param batch_size: Leading dimension of the dummy inputs
        :param max_length: Trailing (sequence) dimension of the dummy inputs
        :return: The ``cache`` entry of the variables returned by ``module.init``
        """

        # Integer dummy ids, consistent with init_weights and with the "i4" casts in __call__.
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(
            jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: Optional[chex.Array] = None,
            position_ids: Optional[chex.Array] = None,
            params: dict = None,
            past_key_values: dict = None,
            dropout_rng: jax.random.PRNGKey = None,
            train: bool = False,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            output_router_logits: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            add_params_field: bool = False,
            **kwargs
    ):
        """
        The __call__ function runs a forward pass through ``self.module.apply``.
        It fills in default ``position_ids`` / ``attention_mask``, resolves unset output flags from
        ``self.config``, casts the array inputs to ``i4`` and, when a cache is supplied, returns
        the updated cache together with the module outputs.

        :param self: Represent the instance of the class
        :param input_ids: Pass the input sequence to the model (2-D: it is unpacked as batch, sequence)
        :param attention_mask: Mask out the padding tokens (defaults to all ones)
        :param position_ids: Specify the position of each token in the sequence
        :param params: dict: Pass in the parameters of the model (falls back to ``self.params``)
        :param past_key_values: dict: Pass the past key values to the model
        :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model
        :param train: bool: Determine whether to use dropout or not (module gets ``deterministic=not train``)
        :param output_attentions: Optional[bool]: Determine whether to return the attention weights
        :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers
        :param output_router_logits: Optional[bool]: Forwarded to the module unchanged
        :param return_dict: Optional[bool]: Return a dictionary of the outputs
        :param add_params_field: bool: Wrap the parameters as ``{"params": ...}`` before applying the module
        :param kwargs: Accepted for interface compatibility; not used here
        :return: The module outputs; when a cache is supplied the updated cache is added under
            ``past_key_values`` (dict outputs) or inserted as the second tuple element
        :raises ValueError: If ``past_key_values`` is given without ``position_ids``
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[
                                            None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        rng_s = {}
        if dropout_rng is not None:
            rng_s["dropout"] = dropout_rng

        variables = params or self.params
        inputs = {"params": variables} if add_params_field else variables

        if self.config.bits is not None:
            rng_s['params'] = jax.random.key(0)

        # One flag drives both the `mutable` choice and the output unpacking below, so an empty
        # cache mapping is consistently treated as "no cache".
        use_cache = bool(past_key_values)
        if use_cache:
            # Build a new mapping instead of writing into `params` / `self.params`: mutating them
            # would leak a stale cache into later calls and fails on immutable FrozenDicts.
            inputs = {**inputs, "cache": past_key_values}
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),  # input_ids: chex.Array
            # attention_mask: Optional[chex.Array] = None
            jnp.array(attention_mask, dtype="i4"),
            # position_ids: Optional[chex.Array] = None
            jnp.array(position_ids, dtype="i4"),
            None,  # inputs_embeds: Optional[chex.Array] = None
            output_attentions,  # output_attentions: Optional[bool] = None
            # output_hidden_states: Optional[bool] = None
            output_hidden_states,
            # output_router_logits: Optional[bool] = None
            output_router_logits,
            False,  # init_cache: bool = False
            not train,  # deterministic: bool = True
            return_dict,  # return_dict: bool = True
            rngs=rng_s,
            mutable=mutable,
        )

        if not use_cache:
            return outputs

        # With a mutable collection, apply returns (outputs, mutated_variables).
        outputs, mutated_variables = outputs
        updated_cache = unfreeze(mutated_variables["cache"])
        if return_dict:
            outputs["past_key_values"] = updated_cache
            return outputs
        return outputs[:1] + (updated_cache,) + outputs[1:]

__call__(input_ids, attention_mask=None, position_ids=None, params=None, past_key_values=None, dropout_rng=None, train=False, output_attentions=None, output_hidden_states=None, output_router_logits=None, return_dict=None, add_params_field=False, **kwargs)

The `__call__` method runs a forward pass of the wrapped Flax module. It takes as input: the parameters of the model (`params`, falling back to `self.params`); the inputs to the model (`input_ids`, `attention_mask`, `position_ids`); whether we are training (`train=True/False`); and whether to return the hidden states and attention weights of every layer in addition to the last-layer output (`output_hidden_states=True/False`, `output_attentions=True/False`).

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
input_ids Array

Pass the input sequence to the model

required
attention_mask Optional[Array]

Mask out the padding tokens

None
position_ids Optional[Array]

Specify the position of each token in the sequence

None
params dict

dict: Pass in the parameters of the model

None
past_key_values dict

dict: Pass the past key values to the model

None
dropout_rng PRNGKey

jax.random.PRNGKey: Pass in a random number generator key to the model

None
train bool

bool: Determine whether to use dropout or not

False
output_attentions Optional[bool]

Optional[bool]: Determine whether to return the attention weights

None
output_hidden_states Optional[bool]

Optional[bool]: Determine whether to return the hidden states of all layers

None
return_dict Optional[bool]

Optional[bool]: Return a dictionary of the outputs

None
add_params_field bool

bool: Add a params field to the inputs dictionary

False

Returns:

Type Description

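The module outputs: a dict-like output object when `return_dict` is true, otherwise a tuple. When `past_key_values` is supplied, the updated cache is added under the `past_key_values` key (dict form) or inserted as the second element of the tuple.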

Source code in src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: Optional[chex.Array] = None,
        position_ids: Optional[chex.Array] = None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        add_params_field: bool = False,
        **kwargs
):
    """
    The __call__ function runs a forward pass through ``self.module.apply``.
    It fills in default ``position_ids`` / ``attention_mask``, resolves unset output flags from
    ``self.config``, casts the array inputs to ``i4`` and, when a cache is supplied, returns
    the updated cache together with the module outputs.

    :param self: Represent the instance of the class
    :param input_ids: Pass the input sequence to the model (2-D: it is unpacked as batch, sequence)
    :param attention_mask: Mask out the padding tokens (defaults to all ones)
    :param position_ids: Specify the position of each token in the sequence
    :param params: dict: Pass in the parameters of the model (falls back to ``self.params``)
    :param past_key_values: dict: Pass the past key values to the model
    :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model
    :param train: bool: Determine whether to use dropout or not (module gets ``deterministic=not train``)
    :param output_attentions: Optional[bool]: Determine whether to return the attention weights
    :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers
    :param output_router_logits: Optional[bool]: Forwarded to the module unchanged
    :param return_dict: Optional[bool]: Return a dictionary of the outputs
    :param add_params_field: bool: Wrap the parameters as ``{"params": ...}`` before applying the module
    :param kwargs: Accepted for interface compatibility; not used here
    :return: The module outputs; when a cache is supplied the updated cache is added under
        ``past_key_values`` (dict outputs) or inserted as the second tuple element
    :raises ValueError: If ``past_key_values`` is given without ``position_ids``
    """

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    batch_size, sequence_length = input_ids.shape

    if position_ids is None:
        if past_key_values is not None:
            raise ValueError(
                "Make sure to provide `position_ids` when passing `past_key_values`.")

        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[
                                        None, :], (batch_size, sequence_length))

    if attention_mask is None:
        attention_mask = jnp.ones((batch_size, sequence_length))

    rng_s = {}
    if dropout_rng is not None:
        rng_s["dropout"] = dropout_rng

    variables = params or self.params
    inputs = {"params": variables} if add_params_field else variables

    if self.config.bits is not None:
        rng_s['params'] = jax.random.key(0)

    # One flag drives both the `mutable` choice and the output unpacking below, so an empty
    # cache mapping is consistently treated as "no cache".
    use_cache = bool(past_key_values)
    if use_cache:
        # Build a new mapping instead of writing into `params` / `self.params`: mutating them
        # would leak a stale cache into later calls and fails on immutable FrozenDicts.
        inputs = {**inputs, "cache": past_key_values}
        mutable = ["cache"]
    else:
        mutable = False

    outputs = self.module.apply(
        inputs,
        jnp.array(input_ids, dtype="i4"),  # input_ids: chex.Array
        # attention_mask: Optional[chex.Array] = None
        jnp.array(attention_mask, dtype="i4"),
        # position_ids: Optional[chex.Array] = None
        jnp.array(position_ids, dtype="i4"),
        None,  # inputs_embeds: Optional[chex.Array] = None
        output_attentions,  # output_attentions: Optional[bool] = None
        # output_hidden_states: Optional[bool] = None
        output_hidden_states,
        # output_router_logits: Optional[bool] = None
        output_router_logits,
        False,  # init_cache: bool = False
        not train,  # deterministic: bool = True
        return_dict,  # return_dict: bool = True
        rngs=rng_s,
        mutable=mutable,
    )

    if not use_cache:
        return outputs

    # With a mutable collection, apply returns (outputs, mutated_variables).
    outputs, mutated_variables = outputs
    updated_cache = unfreeze(mutated_variables["cache"])
    if return_dict:
        outputs["past_key_values"] = updated_cache
        return outputs
    return outputs[:1] + (updated_cache,) + outputs[1:]

init_weights(rng, input_shape, params=None)

The init_weights function is used to initialize the weights of a model. It takes in a rng, which is a random number generator key that can be used to generate random numbers. The input_shape parameter specifies the shape of the inputs that will be fed into this model. The params parameter allows you to pass in pre-trained weights for your model, if you have them available.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
rng PRNGKey

jax.random.PRNGKey: Initialize the weights of the model

required
input_shape Tuple

Tuple: Initialize the input_ids, attention_mask and position_ids

required
params FrozenDict

flax.core.FrozenDict: Pass in the parameters of a pre-trained model

None

Returns:

Type Description
FrozenDict

A frozendict of parameters

Source code in src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
def init_weights(
        self,
        rng: jax.random.PRNGKey,
        input_shape: Tuple,
        params: FrozenDict = None
) -> FrozenDict:
    """
    The init_weights function is used to initialize the weights of a model.
    It runs ``self.module.init`` on dummy inputs of shape ``input_shape`` and, when ``params``
    is given, copies only the entries listed in ``self._missing_keys`` from the freshly
    initialised parameters into ``params``.

    ``self.config.initialization_of_moe`` is set to ``True`` while the module is being
    initialised and is always restored to ``False`` afterwards, even if initialisation fails.

    :param self: Access variables that belong to the class
    :param rng: jax.random.PRNGKey: Split into the ``params`` and ``dropout`` rng streams
    :param input_shape: Tuple: Shape of the dummy input_ids, attention_mask and position_ids
    :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model
    :return: The ``params`` collection returned by ``module.init`` when ``params`` is None,
        otherwise a frozen copy of ``params`` completed with the missing keys
    """

    self.config.initialization_of_moe = True
    try:
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"),
            input_shape,
        )
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        if self.config.add_cross_attention:
            # NOTE(review): the module's positional signature is not visible here; confirm that
            # the 4th/5th positional arguments really are the encoder states and encoder mask.
            encoder_hidden_states = jnp.zeros(
                input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                position_ids,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs,
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=False
            )
        random_params = module_init_outputs["params"]
    finally:
        # Never leave the shared config in "initialising" mode if module.init raised.
        self.config.initialization_of_moe = False

    if params is None:
        return random_params

    # Fill only the keys the loaded checkpoint is missing; everything else is kept as given.
    random_params = flatten_dict(unfreeze(random_params))
    params = flatten_dict(unfreeze(params))
    for missing_key in self._missing_keys:
        params[missing_key] = random_params[missing_key]
    self._missing_keys = set()
    return freeze(unflatten_dict(params))

FlaxDbrxAttention

Bases: BaseJAXAttentionModule

Source code in src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
class FlaxDbrxAttention(BaseJAXAttentionModule):
    """
    Self-attention layer for DBRX using a single fused query/key/value projection.

    ``Wqkv`` produces ``hidden_size + 2 * kv_n_heads * head_dim`` features which are
    split into query, key and value, passed through the rotary module, and then
    handed to ``AttentionModule`` for the actual attention computation.
    """
    config: DbrxConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):
        """Create the projections, rotary module, attention backend and dropout from ``self.config``."""
        self.num_attention_heads = self.config.n_heads
        self.num_key_value_heads = self.config.attn_config.kv_n_heads
        config = self.config
        self.hidden_size = config.hidden_size
        self.head_dim = self.config.d_model // self.config.n_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        if self.num_key_value_groups == 1:
            assert self.num_attention_heads == self.config.attn_config.kv_n_heads
        # Fused projection: the first hidden_size features are the query, followed by
        # kv_n_heads * head_dim features each for key and value (see the splits in __call__).
        self.Wqkv = Linear(
            self.hidden_size + 2 * self.num_key_value_heads * self.head_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )
        self.out_proj = Linear(
            config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )

        self.rotary = FlaxDbrxEmbedding(self.dtype)
        self.attention_performer = AttentionModule(
            use_sharding_constraint=self.config.use_sharding_constraint,
            block_k_major=self.config.block_k_major,
            block_b=self.config.block_b,
            block_q=self.config.block_q,
            block_k=self.config.block_k,
            block_q_major_dkv=self.config.block_q_major_dkv,
            block_k_major_dkv=self.config.block_k_major_dkv,
            block_k_major_dq=self.config.block_k_major_dq,
            block_k_dkv=self.config.block_k_dkv,
            block_q_dkv=self.config.block_q_dkv,
            block_q_dq=self.config.block_q_dq,
            block_k_dq=self.config.block_k_dq,
            num_attention_heads=self.num_attention_heads,
            attention_dropout=self.config.attn_config.attn_pdrop,
            head_dims=self.head_dim,
            attention_partition_spec=self.config.attention_partition_spec,
            shard_attention_computation=self.config.shard_attention_computation,
            precision=self.precision,
            force_float32_tpu=True,
            attn_mechanism=self.config.attn_mechanism,
            dtype=self.dtype,
            bias_partition_spec=self.config.bias_partition_spec,
            key_partition_spec=self.config.key_partition_spec,
            query_partition_spec=self.config.query_partition_spec,
            generation_query_partition_spec=self.config.generation_query_partition_spec,
            generation_bias_partition_spec=self.config.generation_bias_partition_spec,
            generation_attention_partition_spec=self.config.generation_attention_partition_spec,
            value_partition_spec=self.config.value_partition_spec,
            scan_ring_attention=self.config.scan_ring_attention,
            mesh=self.config.jax_mesh(),
            sm_scale=1 / math.sqrt(self.head_dim),
            axis_name=self.config.attention_axis_name,
            backward_pass_impl=self.config.flash_attention_backward_pass_impl
        )
        self.resid_dropout = flax.linen.Dropout(rate=config.resid_pdrop)

    def _merge_heads(self, hidden_states):
        """Reshape ``hidden_states`` to its first two dimensions followed by ``hidden_size``."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @staticmethod
    def _transpose_sequence_head(query, key, value):
        """
        Swap axes 1 and 2 of each of the three 4-D inputs.

        :param query: 4-D query array
        :param key: 4-D key array
        :param value: 4-D value array
        :return: The three arrays, each transposed with axis order (0, 2, 1, 3)

        """
        return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3))

    def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
        """
        Split the flat query/key/value projections into heads, run the rotary module on
        query and key, and pass key and value through ``repeat_kv_bnsh``.

        :param self: Access variables that belong to the class
        :param batch_size: Leading dimension used when reshaping query, key and value
        :param sequence_length: Second dimension used when reshaping query, key and value
        :param query: Query projection, reshaped to (batch, seq, num_attention_heads, head_dim)
        :param key: Key projection, reshaped to (batch, seq, num_key_value_heads, head_dim)
        :param value: Value projection, reshaped to (batch, seq, num_key_value_heads, head_dim)
        :param freq_cis: Forwarded unchanged to the rotary module
        :param position_ids: Forwarded unchanged to the rotary module
        :return: A tuple (query, key, value) with the sequence axis back at position 1

        """
        query = query.reshape(
            batch_size,
            sequence_length,
            self.num_attention_heads,
            self.head_dim
        )
        key = key.reshape(
            batch_size,
            sequence_length,
            self.num_key_value_heads,
            self.head_dim
        )
        value = value.reshape(
            batch_size,
            sequence_length,
            self.num_key_value_heads,
            self.head_dim
        )

        # The rotary module and repeat_kv_bnsh are applied with the head axis at position 1.
        query, key, value = self._transpose_sequence_head(query, key, value)
        query, key = self.rotary(
            position_ids=position_ids, query=query, key=key, freq_cis=freq_cis
        )
        key = repeat_kv_bnsh(key, self.num_key_value_groups)
        value = repeat_kv_bnsh(value, self.num_key_value_groups)
        # Transpose back so the sequence axis is at position 1 again.
        return self._transpose_sequence_head(query, key, value)

    def __call__(
            self,
            hidden_states: chex.Array,
            freq_cis: Tuple[chex.Array, chex.Array],
            attention_mask: chex.Array,
            position_ids: chex.Array,
            causal_mask: chex.Array,
            segment_ids: Optional[chex.Array] = None,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = False,
            fcm_mask=None,
    ):
        """
        Run the attention layer: fused QKV projection, optional clipping, rotary embedding,
        mask construction, optional key/value caching, attention, and output projection.

        :param self: Access variables that belong to the class
        :param hidden_states: chex.Array: Input activations; the first two dimensions are used as
            batch size and sequence length
        :param freq_cis: Tuple[chex.Array, chex.Array]: Forwarded to the rotary module
        :param attention_mask: chex.Array: Padding mask, expanded and combined with the causal mask
        :param position_ids: chex.Array: Forwarded to the rotary module
        :param causal_mask: chex.Array: 4-D causal mask that is sliced to the current query/key lengths
        :param segment_ids: Optional[chex.Array]: Forwarded unchanged to the attention backend
        :param deterministic: bool: Disables attention and residual dropout when True
        :param init_cache: bool: When True, key/value states go through ``_concatenate_to_cache``
        :param output_attentions: bool: Not read by this method; attention weights are always returned
            as produced by the attention backend
        :param fcm_mask: Optional extra mask passed to ``combine_masks``
        :return: A tuple (attn_output, attention_weights)

        """
        batch_size, sequence_length = hidden_states.shape[:2]
        qkv_states = self.Wqkv(hidden_states)
        if self.config.attn_config.clip_qkv is not None:
            qkv_states = qkv_states.clip(
                min=-self.config.attn_config.clip_qkv,
                max=self.config.attn_config.clip_qkv
            )

        query_size = self.hidden_size
        key_size = self.num_key_value_heads * self.head_dim

        # Fused layout along the last axis: [query | key | value].
        query_states, key_value_states = jnp.split(qkv_states, [query_size], axis=2)
        key_states, value_states = jnp.split(key_value_states, [key_size], axis=2)

        query_states, key_states, value_states = self.apply_rotary(
            query=query_states,
            key=key_states,
            value=value_states,
            position_ids=position_ids,
            freq_cis=freq_cis,
            batch_size=batch_size,
            sequence_length=sequence_length
        )

        # Diagnostic text for the head-count checks below; KVH reports the configured
        # key/value head count (it previously repeated the attention head count).
        assert_msg = (
            "num_attention_heads repeat wont work likely\n"
            f"INFO :\n\trepeat_kv_bnsh Used with num_key_value_groups = {self.num_key_value_groups}\n\t"
            f"NH : {self.num_attention_heads} KVH : {self.num_key_value_heads}"
        )

        assert query_states.shape[-2] == self.num_attention_heads, assert_msg
        assert key_states.shape[-2] == self.num_attention_heads, assert_msg
        assert value_states.shape[-2] == self.num_attention_heads, assert_msg

        query_length, key_length = query_states.shape[1], key_states.shape[1]

        has_cache = self.has_variable("cache", "cached_key")
        uses_cache = has_cache or init_cache

        if has_cache:
            # Slice the causal mask starting at the current cache index and spanning
            # the full cached key length.
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                causal_mask,
                (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length)
            )
        else:
            causal_mask = causal_mask[:, :, :query_length, :key_length]

        causal_mask = jnp.broadcast_to(
            causal_mask, (batch_size,) + causal_mask.shape[1:])
        attention_mask = jnp.broadcast_to(jnp.expand_dims(
            attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
        # NOTE(review): attention_mask was broadcast to causal_mask's shape above, so this
        # branch looks unreachable unless combine_masks changes the rank — confirm.
        if attention_mask.ndim == 2:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = None

        if not deterministic and self.config.attn_config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        if uses_cache:
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states,
                value_states,
                query_states,
                attention_mask
            )

        # Additive bias: 0 where attention is allowed, the dtype's most negative value elsewhere.
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(
                self.dtype).min).astype(self.dtype),
        )

        # Recomputed because key_states may have been replaced by the cache step above.
        query_length, key_length = query_states.shape[1], key_states.shape[1]

        attentions = self.attention_performer.__call__(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            bias=attention_bias,
            attention_mask=attention_mask,
            causal=True,
            dropout_rng=dropout_rng,
            deterministic=deterministic,
            query_sequence_length=query_length,
            key_value_sequence_length=key_length,
            uses_cache=uses_cache,
            segment_ids=segment_ids,
            causal_mask=causal_mask
        )

        attn_output = self._merge_heads(attentions.attention_outputs)
        if self.config.shard_attention_computation:
            # The sequence axis is left unsharded when it has length 1.
            attn_output = with_sharding_constraint(
                attn_output, PartitionSpec(
                    ("dp", "fsdp"),
                    "sp" if attn_output.shape[1] != 1 else None,
                    "tp"
                )
            )
        attn_output = self.out_proj(attn_output)

        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
        return attn_output, attentions.attention_weights

__call__(hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids=None, deterministic=True, init_cache=False, output_attentions=False, fcm_mask=None)

The call function is the main function of a JAX module. It defines how the module behaves when called with inputs. The call function can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
hidden_states Array

chex.Array: Pass the hidden states of the previous layer

required
freq_cis Tuple[Array, Array]

Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position

required
attention_mask Array

chex.Array: Mask out certain tokens in the input sequence

required
position_ids Array

chex.Array: Determine the position of each token in a sequence

required
causal_mask Array

chex.Array: Mask out the future tokens in the decoder

required
deterministic bool

bool: Determine whether to use dropout or not

True
init_cache bool

bool: Initialize the cache

False
output_attentions bool

bool: Determine whether to return the attention weights or not

False
fcm_mask

Mask out the attention weights between the input and output tokens

None

Determine if the attention is causal or not

required

Returns:

Type Description

A tuple of two arrays

Source code in src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
def __call__(
        self,
        hidden_states: chex.Array,
        freq_cis: Tuple[chex.Array, chex.Array],
        attention_mask: chex.Array,
        position_ids: chex.Array,
        causal_mask: chex.Array,
        segment_ids: Optional[chex.Array] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        fcm_mask=None,
):
    """
    Run the attention layer: fused QKV projection, optional clipping, rotary embedding,
    mask construction, optional key/value caching, attention, and output projection.

    :param self: Access variables that belong to the class
    :param hidden_states: chex.Array: Input activations; the first two dimensions are used as
        batch size and sequence length
    :param freq_cis: Tuple[chex.Array, chex.Array]: Forwarded to ``apply_rotary``
    :param attention_mask: chex.Array: Padding mask, expanded and combined with the causal mask
    :param position_ids: chex.Array: Forwarded to ``apply_rotary``
    :param causal_mask: chex.Array: 4-D causal mask that is sliced to the current query/key lengths
    :param segment_ids: Optional[chex.Array]: Forwarded unchanged to the attention backend
    :param deterministic: bool: Disables attention and residual dropout when True
    :param init_cache: bool: When True, key/value states go through ``_concatenate_to_cache``
    :param output_attentions: bool: Not read by this method; attention weights are always returned
        as produced by the attention backend
    :param fcm_mask: Optional extra mask passed to ``combine_masks``
    :return: A tuple (attn_output, attention_weights)

    """
    batch_size, sequence_length = hidden_states.shape[:2]
    qkv_states = self.Wqkv(hidden_states)
    if self.config.attn_config.clip_qkv is not None:
        qkv_states = qkv_states.clip(
            min=-self.config.attn_config.clip_qkv,
            max=self.config.attn_config.clip_qkv
        )

    query_size = self.hidden_size
    key_size = self.num_key_value_heads * self.head_dim

    # Fused layout along the last axis: [query | key | value].
    query_states, key_value_states = jnp.split(qkv_states, [query_size], axis=2)
    key_states, value_states = jnp.split(key_value_states, [key_size], axis=2)

    query_states, key_states, value_states = self.apply_rotary(
        query=query_states,
        key=key_states,
        value=value_states,
        position_ids=position_ids,
        freq_cis=freq_cis,
        batch_size=batch_size,
        sequence_length=sequence_length
    )

    # Diagnostic text for the head-count checks below; KVH reports the configured
    # key/value head count (it previously repeated the attention head count).
    assert_msg = (
        "num_attention_heads repeat wont work likely\n"
        f"INFO :\n\trepeat_kv_bnsh Used with num_key_value_groups = {self.num_key_value_groups}\n\t"
        f"NH : {self.num_attention_heads} KVH : {self.num_key_value_heads}"
    )

    assert query_states.shape[-2] == self.num_attention_heads, assert_msg
    assert key_states.shape[-2] == self.num_attention_heads, assert_msg
    assert value_states.shape[-2] == self.num_attention_heads, assert_msg

    query_length, key_length = query_states.shape[1], key_states.shape[1]

    has_cache = self.has_variable("cache", "cached_key")
    uses_cache = has_cache or init_cache

    if has_cache:
        # Slice the causal mask starting at the current cache index and spanning
        # the full cached key length.
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            causal_mask,
            (0, 0, mask_shift, 0),
            (1, 1, query_length, max_decoder_length)
        )
    else:
        causal_mask = causal_mask[:, :, :query_length, :key_length]

    causal_mask = jnp.broadcast_to(
        causal_mask, (batch_size,) + causal_mask.shape[1:])
    attention_mask = jnp.broadcast_to(jnp.expand_dims(
        attention_mask, axis=(-3, -2)), causal_mask.shape)
    attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
    # NOTE(review): attention_mask was broadcast to causal_mask's shape above, so this
    # branch looks unreachable unless combine_masks changes the rank — confirm.
    if attention_mask.ndim == 2:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    dropout_rng = None

    if not deterministic and self.config.attn_config.attn_pdrop > 0.0:
        dropout_rng = self.make_rng("dropout")

    if uses_cache:
        key_states, value_states, attention_mask = self._concatenate_to_cache(
            key_states,
            value_states,
            query_states,
            attention_mask
        )

    # Additive bias: 0 where attention is allowed, the dtype's most negative value elsewhere.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, jnp.finfo(
            self.dtype).min).astype(self.dtype),
    )

    # Recomputed because key_states may have been replaced by the cache step above.
    query_length, key_length = query_states.shape[1], key_states.shape[1]

    attentions = self.attention_performer.__call__(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        bias=attention_bias,
        attention_mask=attention_mask,
        causal=True,
        dropout_rng=dropout_rng,
        deterministic=deterministic,
        query_sequence_length=query_length,
        key_value_sequence_length=key_length,
        uses_cache=uses_cache,
        segment_ids=segment_ids,
        causal_mask=causal_mask
    )

    attn_output = self._merge_heads(attentions.attention_outputs)
    if self.config.shard_attention_computation:
        # The sequence axis is left unsharded when it has length 1.
        attn_output = with_sharding_constraint(
            attn_output, PartitionSpec(
                ("dp", "fsdp"),
                "sp" if attn_output.shape[1] != 1 else None,
                "tp"
            )
        )
    attn_output = self.out_proj(attn_output)

    attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
    return attn_output, attentions.attention_weights

apply_rotary(batch_size, sequence_length, query, key, value, freq_cis, position_ids)

The apply_rotary function reshapes the flat query, key and value projections into per-head tensors, applies the rotary embedding module to query and key (using freq_cis and position_ids), and passes key and value through repeat_kv_bnsh with the configured number of key/value groups. It returns the three tensors with the sequence axis restored to position 1.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
batch_size

Reshape the query, key and value tensors

required
sequence_length

Reshape the query, key and value tensors

required
query

Calculate the attention weights

required
key

Calculate the attention

required
value

Compute the attention weights

required
freq_cis

Rotary frequency inputs, forwarded unchanged to the rotary embedding module

required
position_ids

Identify the position of each token in the sequence

required

Returns:

Type Description

A tuple of 3 tensors: query, key and value

Source code in src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
    """
    Split the flat query/key/value projections into heads, run the rotary module on
    query and key, and pass key and value through ``repeat_kv_bnsh``.

    :param self: Access variables that belong to the class
    :param batch_size: Leading dimension used when reshaping query, key and value
    :param sequence_length: Second dimension used when reshaping query, key and value
    :param query: Query projection, reshaped to (batch, seq, num_attention_heads, head_dim)
    :param key: Key projection, reshaped to (batch, seq, num_key_value_heads, head_dim)
    :param value: Value projection, reshaped to (batch, seq, num_key_value_heads, head_dim)
    :param freq_cis: Forwarded unchanged to the rotary module
    :param position_ids: Forwarded unchanged to the rotary module
    :return: A tuple (query, key, value) with the sequence axis back at position 1

    """
    # Key and value share one per-head layout; the query uses the full attention head count.
    kv_heads_shape = (batch_size, sequence_length, self.num_key_value_heads, self.head_dim)
    query = query.reshape((batch_size, sequence_length, self.num_attention_heads, self.head_dim))
    key = key.reshape(kv_heads_shape)
    value = value.reshape(kv_heads_shape)

    # The rotary module and repeat_kv_bnsh are applied with the head axis at position 1.
    query, key, value = self._transpose_sequence_head(query, key, value)
    query, key = self.rotary(
        query=query, key=key, freq_cis=freq_cis, position_ids=position_ids
    )

    groups = self.num_key_value_groups
    key = repeat_kv_bnsh(key, groups)
    value = repeat_kv_bnsh(value, groups)

    # Transpose back so the sequence axis is at position 1 again.
    return self._transpose_sequence_head(query, key, value)

FlaxDbrxForCausalLM

Bases: DbrxPreTrainedModel

Source code in src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
class FlaxDbrxForCausalLM(DbrxPreTrainedModel):
    """Causal-LM wrapper that binds ``FlaxDbrxForCausalLMModule`` and adds generation helpers."""
    module_class = FlaxDbrxForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
        """
        Build the initial keyword arguments for a generation loop.

        :param self: Access variables that belong to the class
        :param input_ids: 2-D prompt token array; its shape supplies batch size and prompt length
        :param max_length: Length of the cache and of the returned attention mask
        :param attention_mask: Optional[chex.Array]: Prompt mask written into the start of an all-ones
            mask of width max_length; also used to derive position ids
        :return: A dictionary with the keys past_key_values, attention_mask and position_ids

        """
        batch_size, seq_length = input_ids.shape

        cache = self.init_cache(batch_size, max_length)
        # Start from an all-ones mask covering the whole cache length.
        full_mask = jnp.ones((batch_size, max_length), dtype="i4")

        if attention_mask is None:
            # No mask given: positions are simply 0..seq_length-1 for every row.
            positions = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :],
                (batch_size, seq_length),
            )
        else:
            # Positions follow the running count of mask entries, minus one.
            positions = attention_mask.cumsum(axis=-1) - 1
            full_mask = lax.dynamic_update_slice(full_mask, attention_mask, (0, 0))

        return dict(
            past_key_values=cache,
            attention_mask=full_mask,
            position_ids=positions,
        )

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Carry the cache forward and advance position ids by one step.

        :param model_outputs: Output object exposing ``past_key_values``
        :param model_kwargs: Mutable keyword-argument dict; updated in place and returned
        :return: The same ``model_kwargs`` dict

        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        # Keep only the last position column and move it one step ahead.
        next_position = model_kwargs["position_ids"][:, -1:] + 1
        model_kwargs["position_ids"] = next_position
        return model_kwargs

prepare_inputs_for_generation(input_ids, max_length, attention_mask=None)

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
input_ids

Pass in the input tokens

required
max_length

Set the length of the sequence to be generated

required
attention_mask Optional[Array]

Optional[chex.Array]: Mask the attention weights

None

Returns:

Type Description

A dictionary of the past_key_values, attention_mask and position ids

Source code in src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
    """
    Build the initial keyword arguments for a generation loop.

    :param self: Access variables that belong to the class
    :param input_ids: 2-D prompt token array; its shape supplies batch size and prompt length
    :param max_length: Length of the cache and of the returned attention mask
    :param attention_mask: Optional[chex.Array]: Prompt mask written into the start of an all-ones
        mask of width max_length; also used to derive position ids
    :return: A dictionary with the keys past_key_values, attention_mask and position_ids

    """
    batch_size, seq_length = input_ids.shape

    cache = self.init_cache(batch_size, max_length)
    # Start from an all-ones mask covering the whole cache length.
    full_mask = jnp.ones((batch_size, max_length), dtype="i4")

    if attention_mask is None:
        # No mask given: positions are simply 0..seq_length-1 for every row.
        positions = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :],
            (batch_size, seq_length),
        )
    else:
        # Positions follow the running count of mask entries, minus one.
        positions = attention_mask.cumsum(axis=-1) - 1
        full_mask = lax.dynamic_update_slice(full_mask, attention_mask, (0, 0))

    return dict(
        past_key_values=cache,
        attention_mask=full_mask,
        position_ids=positions,
    )