modules.deepseek_v2.modeling_deepseek_flax

DeepseekV2PreTrainedModel

Bases: EasyDeLFlaxPretrainedModel

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
class DeepseekV2PreTrainedModel(EasyDeLFlaxPretrainedModel):
    config_class: DeepseekV2Config = DeepseekV2Config
    module_class: nn.Module = None
    base_model_prefix = "model"

    def __init__(
            self,
            config: DeepseekV2Config,
            dtype: jnp.dtype = jnp.bfloat16,
            param_dtype: jnp.dtype = jnp.bfloat16,
            precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"),
            input_shape: Tuple[int, int] = (1, 1),
            seed: int = 0,
            _do_init: bool = False,
            **kwargs
    ):
        module = self.module_class(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            **kwargs
        )

        super().__init__(
            dtype=dtype, _do_init=_do_init,
            module=module, config=config, input_shape=input_shape,
            seed=seed,
        )

    def init_weights(
            self,
            rng: jax.random.PRNGKey,
            input_shape: Tuple,
            params: FrozenDict = None
    ) -> FrozenDict:
        """
        The init_weights function is used to initialize the weights of a model.
        It takes in a rng, which is a random number generator key that can be used to generate random numbers.
        The input_shape parameter specifies the shape of the inputs that will be fed into this model.
        The params parameter allows you to pass in pre-trained weights for your model, if you have them available.

        :param self: Access variables that belong to the class
        :param rng: jax.random.PRNGKey: Initialize the weights of the model
        :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids
        :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model
        :return: A frozendict of parameters
        """

        self.config.initialization_of_moe = True
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"),
            input_shape,
        )
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(
                input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                position_ids,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs,
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=False
            )
        random_params = module_init_outputs["params"]

        self.config.initialization_of_moe = False
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):

        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(
            jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: Optional[chex.Array] = None,
            position_ids: Optional[chex.Array] = None,
            params: dict = None,
            past_key_values: dict = None,
            dropout_rng: jax.random.PRNGKey = None,
            train: bool = False,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            output_router_logits: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            add_params_field: bool = False,
            **kwargs
    ):
        """
        The __call__ function is the main function of a JAX module.
        It takes as input:
        - The parameters of the model (self.params)
        - The inputs to the model (input_ids, attention_mask, position_ids)
        - Whether we are training (train=True/False) and whether we want to return all hidden states and
        attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False).

        :param self: Represent the instance of the class
        :param input_ids: Pass the input sequence to the model
        :param attention_mask: Mask out the padding tokens
        :param position_ids: Specify the position of each token in the sequence
        :param params: dict: Pass in the parameters of the model
        :param past_key_values: dict: Pass the past key values to the model
        :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model
        :param train: bool: Determine whether to use dropout or not
        :param output_attentions: Optional[bool]: Determine whether to return the attention weights
        :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers
        :param return_dict: Optional[bool]: Return a dictionary of the outputs
        :param add_params_field: bool: Add a params field to the inputs dictionary
        :return: A tuple of (last_hidden_state, past_key_values)

        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[
                                            None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        rng_s = {}
        if dropout_rng is not None:
            rng_s["dropout"] = dropout_rng

        inputs = {
            "params": params or self.params} if add_params_field else params or self.params

        if self.config.bits is not None:
            rng_s['params'] = jax.random.key(0)
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),  # input_ids: chex.Array
            # attention_mask: Optional[chex.Array] = None
            jnp.array(attention_mask, dtype="i4"),
            # position_ids: Optional[chex.Array] = None
            jnp.array(position_ids, dtype="i4"),
            None,  # inputs_embeds: Optional[chex.Array] = None
            output_attentions,  # output_attentions: Optional[bool] = None
            # output_hidden_states: Optional[bool] = None
            output_hidden_states,
            # output_router_logits: Optional[bool] = None
            output_router_logits,
            False,  # init_cache: bool = False
            not train,  # deterministic: bool = True
            return_dict,  # return_dict: bool = True
            rngs=rng_s,
            mutable=mutable,
        )

        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + \
                      (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs

__call__(input_ids, attention_mask=None, position_ids=None, params=None, past_key_values=None, dropout_rng=None, train=False, output_attentions=None, output_hidden_states=None, output_router_logits=None, return_dict=None, add_params_field=False, **kwargs)

The __call__ function is the main entry point of a JAX module. It takes as input the parameters of the model (self.params), the inputs to the model (input_ids, attention_mask, position_ids), whether the model is training (train=True/False), and whether to return all hidden states and attention weights at each layer in addition to just the last layer's output (output_hidden_states / output_attentions).

Parameters:

- self: Represent the instance of the class. Required.
- input_ids (Array): Pass the input sequence to the model. Required.
- attention_mask (Optional[Array]): Mask out the padding tokens. Default: None.
- position_ids (Optional[Array]): Specify the position of each token in the sequence. Default: None.
- params (dict): Pass in the parameters of the model. Default: None.
- past_key_values (dict): Pass the past key values to the model. Default: None.
- dropout_rng (PRNGKey): Pass in a random number generator key to the model. Default: None.
- train (bool): Determine whether to use dropout or not. Default: False.
- output_attentions (Optional[bool]): Determine whether to return the attention weights. Default: None.
- output_hidden_states (Optional[bool]): Determine whether to return the hidden states of all layers. Default: None.
- return_dict (Optional[bool]): Return a dictionary of the outputs. Default: None.
- add_params_field (bool): Add a params field to the inputs dictionary. Default: False.

Returns:

A tuple of (last_hidden_state, past_key_values).

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: Optional[chex.Array] = None,
        position_ids: Optional[chex.Array] = None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        add_params_field: bool = False,
        **kwargs
):
    """
    The __call__ function is the main function of a JAX module.
    It takes as input:
    - The parameters of the model (self.params)
    - The inputs to the model (input_ids, attention_mask, position_ids)
    - Whether we are training (train=True/False) and whether we want to return all hidden states and
    attentions weights at each layer in addition to just the last layer output (output_hidden_states=True/False).

    :param self: Represent the instance of the class
    :param input_ids: Pass the input sequence to the model
    :param attention_mask: Mask out the padding tokens
    :param position_ids: Specify the position of each token in the sequence
    :param params: dict: Pass in the parameters of the model
    :param past_key_values: dict: Pass the past key values to the model
    :param dropout_rng: jax.random.PRNGKey: Pass in a random number generator key to the model
    :param train: bool: Determine whether to use dropout or not
    :param output_attentions: Optional[bool]: Determine whether to return the attention weights
    :param output_hidden_states: Optional[bool]: Determine whether to return the hidden states of all layers
    :param return_dict: Optional[bool]: Return a dictionary of the outputs
    :param add_params_field: bool: Add a params field to the inputs dictionary
    :return: A tuple of (last_hidden_state, past_key_values)

    """

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    batch_size, sequence_length = input_ids.shape

    if position_ids is None:
        if past_key_values is not None:
            raise ValueError(
                "Make sure to provide `position_ids` when passing `past_key_values`.")

        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[
                                        None, :], (batch_size, sequence_length))

    if attention_mask is None:
        attention_mask = jnp.ones((batch_size, sequence_length))

    rng_s = {}
    if dropout_rng is not None:
        rng_s["dropout"] = dropout_rng

    inputs = {
        "params": params or self.params} if add_params_field else params or self.params

    if self.config.bits is not None:
        rng_s['params'] = jax.random.key(0)
    if past_key_values:
        inputs["cache"] = past_key_values
        mutable = ["cache"]
    else:
        mutable = False

    outputs = self.module.apply(
        inputs,
        jnp.array(input_ids, dtype="i4"),  # input_ids: chex.Array
        # attention_mask: Optional[chex.Array] = None
        jnp.array(attention_mask, dtype="i4"),
        # position_ids: Optional[chex.Array] = None
        jnp.array(position_ids, dtype="i4"),
        None,  # inputs_embeds: Optional[chex.Array] = None
        output_attentions,  # output_attentions: Optional[bool] = None
        # output_hidden_states: Optional[bool] = None
        output_hidden_states,
        # output_router_logits: Optional[bool] = None
        output_router_logits,
        False,  # init_cache: bool = False
        not train,  # deterministic: bool = True
        return_dict,  # return_dict: bool = True
        rngs=rng_s,
        mutable=mutable,
    )

    if past_key_values is not None and return_dict:
        outputs, past_key_values = outputs
        outputs["past_key_values"] = unfreeze(past_key_values["cache"])
        return outputs
    elif past_key_values is not None and not return_dict:
        outputs, past_key_values = outputs
        outputs = outputs[:1] + \
                  (unfreeze(past_key_values["cache"]),) + outputs[1:]

    return outputs
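
For orientation, the following is a minimal forward-pass sketch built only from the signatures documented on this page. The top-level `easydel` import path and the default `DeepseekV2Config()` are assumptions; in practice you would load or customize a real configuration or checkpoint.

# Minimal forward-pass sketch; import path and config values are assumptions.
import jax
import jax.numpy as jnp

from easydel import DeepseekV2Config, FlaxDeepseekV2ForCausalLM  # assumed exports

config = DeepseekV2Config()  # in practice, load or customize a real configuration
model = FlaxDeepseekV2ForCausalLM(config=config, dtype=jnp.bfloat16, input_shape=(1, 1))

# _do_init defaults to False, so parameters are built explicitly (see init_weights below).
params = model.init_weights(jax.random.PRNGKey(0), input_shape=(1, 8))

input_ids = jnp.ones((1, 8), dtype="i4")
attention_mask = jnp.ones_like(input_ids)

outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    params=params,
    add_params_field=True,  # wrap the parameters under a "params" key for module.apply
    train=False,            # deterministic forward pass (no dropout)
    return_dict=True,
)
logits = outputs.logits     # (batch, sequence_length, vocab_size)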

init_weights(rng, input_shape, params=None)

The init_weights function is used to initialize the weights of a model. It takes an rng, a random number generator key used to generate random numbers; input_shape, which specifies the shape of the dummy inputs that will be fed into the model; and params, which lets you pass in pre-trained weights for your model if you have them available.

Parameters:

- self: Access variables that belong to the class. Required.
- rng (PRNGKey): Initialize the weights of the model. Required.
- input_shape (Tuple): Initialize the input_ids, attention_mask and position_ids. Required.
- params (FrozenDict): Pass in the parameters of a pre-trained model. Default: None.

Returns:

FrozenDict: a frozen dictionary of parameters.

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
def init_weights(
        self,
        rng: jax.random.PRNGKey,
        input_shape: Tuple,
        params: FrozenDict = None
) -> FrozenDict:
    """
    The init_weights function is used to initialize the weights of a model.
    It takes in a rng, which is a random number generator key that can be used to generate random numbers.
    The input_shape parameter specifies the shape of the inputs that will be fed into this model.
    The params parameter allows you to pass in pre-trained weights for your model, if you have them available.

    :param self: Access variables that belong to the class
    :param rng: jax.random.PRNGKey: Initialize the weights of the model
    :param input_shape: Tuple: Initialize the input_ids, attention_mask and position_ids
    :param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model
    :return: A frozendict of parameters
    """

    self.config.initialization_of_moe = True
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids, dtype="i4")
    position_ids = jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(input_ids).shape[-1], dtype="i4"),
        input_shape,
    )
    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}
    if self.config.add_cross_attention:
        encoder_hidden_states = jnp.zeros(
            input_shape + (self.config.hidden_size,))
        encoder_attention_mask = attention_mask
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
    else:
        module_init_outputs = self.module.init(
            rngs,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=False
        )
    random_params = module_init_outputs["params"]

    self.config.initialization_of_moe = False
    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
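
As a short sketch of the two ways init_weights is typically used; `model` is assumed to be constructed as in the earlier example, and `pretrained_params` is a hypothetical FrozenDict loaded elsewhere.

import jax

# 1) Fresh random parameters for a (batch, sequence_length) dummy input shape.
params = model.init_weights(jax.random.PRNGKey(0), input_shape=(1, 8))

# 2) Merge pre-trained weights: any keys the checkpoint is missing are filled
#    from the freshly initialized random parameters.
params = model.init_weights(
    jax.random.PRNGKey(0),
    input_shape=(1, 8),
    params=pretrained_params,  # hypothetical FrozenDict loaded elsewhere
)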

FlaxDeepseekV2ForCausalLM

Bases: DeepseekV2PreTrainedModel

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
class FlaxDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
    module_class = FlaxDeepseekV2ForCausalLMModule

    def set_input_embeddings(self, value):
        self.module.model.embed_tokens = value

    def get_input_embeddings(self):
        return self.module.model.embed_tokens

    def set_decoder(self, decoder):
        self.module.model = decoder

    def get_decoder(self):
        return self.module.model

    def get_output_embeddings(self):
        return self.module.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.module.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
        """
        The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

        :param self: Access variables that belong to the class
        :param input_ids: Pass in the input tokens
        :param max_length: Set the length of the sequence to be generated
        :param attention_mask: Optional[chex.Array]: Mask the attention weights
        :return: A dictionary of the past_key_values, attention_mask and position ids

        """
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        extended_attention_mask = jnp.ones(
            (batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[
                                            None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs

prepare_inputs_for_generation(input_ids, max_length, attention_mask=None)

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters:

- self: Access variables that belong to the class. Required.
- input_ids: Pass in the input tokens. Required.
- max_length: Set the length of the sequence to be generated. Required.
- attention_mask (Optional[Array]): Mask the attention weights. Default: None.

Returns:

A dictionary of the past_key_values, attention_mask and position_ids.

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
    """
    The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

    :param self: Access variables that belong to the class
    :param input_ids: Pass in the input tokens
    :param max_length: Set the length of the sequence to be generated
    :param attention_mask: Optional[chex.Array]: Mask the attention weights
    :return: A dictionary of the past_key_values, attention_mask and position ids

    """
    batch_size, seq_length = input_ids.shape

    past_key_values = self.init_cache(batch_size, max_length)
    extended_attention_mask = jnp.ones(
        (batch_size, max_length), dtype="i4")
    if attention_mask is not None:
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(
            extended_attention_mask, attention_mask, (0, 0))
    else:
        position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[
                                        None, :], (batch_size, seq_length))

    return {
        "past_key_values": past_key_values,
        "attention_mask": extended_attention_mask,
        "position_ids": position_ids,
    }
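
To show how prepare_inputs_for_generation and update_inputs_for_generation cooperate, here is a hand-rolled greedy decoding sketch. `model`, `params`, and `prompt_ids` are assumed to exist already; in practice you would rely on the built-in generation utilities instead of writing this loop yourself.

import jax.numpy as jnp

max_length = 32
prompt_len = prompt_ids.shape[1]

# Pre-allocates the KV cache and builds the padded attention mask / position ids.
model_kwargs = model.prepare_inputs_for_generation(
    prompt_ids, max_length, attention_mask=jnp.ones_like(prompt_ids)
)

input_ids, generated = prompt_ids, []
for _ in range(max_length - prompt_len):
    outputs = model(
        input_ids=input_ids,
        attention_mask=model_kwargs["attention_mask"],
        position_ids=model_kwargs["position_ids"],
        past_key_values=model_kwargs["past_key_values"],
        params=params,
        add_params_field=True,
        return_dict=True,
    )
    next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)[:, None]
    generated.append(next_token)
    # Advances the cache and position ids for the next single-token step.
    model_kwargs = model.update_inputs_for_generation(outputs, model_kwargs)
    input_ids = next_token  # after the prompt, only the new token is fed in

sequence = jnp.concatenate([prompt_ids] + generated, axis=-1)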

FlaxDeepseekV2ForCausalLMModule

Bases: Module

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
class FlaxDeepseekV2ForCausalLMModule(nn.Module):
    config: DeepseekV2Config
    dtype: jnp.dtype = jnp.bfloat16
    param_dtype: jnp.dtype = jnp.bfloat16
    precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest")

    def setup(self) -> None:
        self.model = FlaxDeepseekV2Module(
            config=self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.lm_head = nn.Linear(
            self.config.vocab_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            use_bias=False,
            kernel_init=nn.initializers.normal(self.config.initializer_range),
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: chex.Array,
            position_ids: chex.Array,
            deterministic: bool = True,
            inputs_embeds: chex.Array = None,
            init_cache: bool = False,
            output_attentions: bool = False,
            output_hidden_states: bool = False,
            return_dict: bool = True,
    ):
        """
            The __call__ function is the main function of a Flax module. It defines how the model will be called,
            and what it returns. In this case, we are calling our Transformer model with input_ids and attention_mask
            as inputs (these are defined in __init__). We also have some optional arguments that can be passed to
            the call function: deterministic (whether to use dropout), inputs_embeds (if you want to pass your own embeddings),
            output_attentions and output_hidden states which return additional outputs from the transformer layers if set True. Finally,

            :param self: Refer to the object itself
            :param input_ids: chex.Array: Pass in the input tokens
            :param attention_mask: chex.Array: Mask out the padding tokens
            :param position_ids: chex.Array: Specify the position of each token in the sequence
            :param deterministic: bool: Determine whether to use dropout in the model
            :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens
            :param init_cache: bool: Initialize the cache for the decoder
            :param output_attentions: bool: Return the attention weights
            :param output_hidden_states: bool: Return the hidden states of all layers
            :param return_dict: bool: Return a dictionary of the outputs or just the logits
            :param : Determine whether to return the logits or not
            :return: A tuple of (lm_logits, hidden_states, attentions)

        """
        batch_size, seq_length = input_ids.shape

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
                (batch_size, seq_length)
            )
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            inputs_embeds=inputs_embeds,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_kernel = self.transformer.variables["params"]["embed_tokens"]["embedding"]
            shared_kernel = fjformer.linen.linen.control_quantization(shared_kernel, self.param_dtype).T
            lm_logits = self.lm_head.apply(
                {"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        # lm_logits = lm_logits.astype(jnp.float32)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

__call__(input_ids, attention_mask, position_ids, deterministic=True, inputs_embeds=None, init_cache=False, output_attentions=False, output_hidden_states=False, return_dict=True)

The __call__ function is the main function of a Flax module: it defines how the model is called and what it returns. In this case it runs the Transformer model with input_ids and attention_mask as inputs (these are set up in the module's setup) and projects the resulting hidden states to vocabulary logits with lm_head. Optional arguments control the call: deterministic (whether to disable dropout), inputs_embeds (if you want to pass your own embeddings), and output_attentions / output_hidden_states, which return additional per-layer outputs from the transformer when set to True.

Parameters:

- self: Refer to the object itself. Required.
- input_ids (chex.Array): Pass in the input tokens. Required.
- attention_mask (chex.Array): Mask out the padding tokens. Required.
- position_ids (chex.Array): Specify the position of each token in the sequence. Required.
- deterministic (bool): Determine whether to use dropout in the model. Default: True.
- inputs_embeds (chex.Array): Pass in the embeddings of the input tokens. Default: None.
- init_cache (bool): Initialize the cache for the decoder. Default: False.
- output_attentions (bool): Return the attention weights. Default: False.
- output_hidden_states (bool): Return the hidden states of all layers. Default: False.
- return_dict (bool): Return a dictionary of the outputs or just the logits. Default: True.

Returns:

A tuple of (lm_logits, hidden_states, attentions).

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: chex.Array,
        position_ids: chex.Array,
        deterministic: bool = True,
        inputs_embeds: chex.Array = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
):
    """
        The __call__ function is the main function of a Flax module. It defines how the model will be called,
        and what it returns. In this case, we are calling our Transformer model with input_ids and attention_mask
        as inputs (these are defined in __init__). We also have some optional arguments that can be passed to
        the call function: deterministic (whether to use dropout), inputs_embeds (if you want to pass your own embeddings),
        output_attentions and output_hidden states which return additional outputs from the transformer layers if set True. Finally,

        :param self: Refer to the object itself
        :param input_ids: chex.Array: Pass in the input tokens
        :param attention_mask: chex.Array: Mask out the padding tokens
        :param position_ids: chex.Array: Specify the position of each token in the sequence
        :param deterministic: bool: Determine whether to use dropout in the model
        :param inputs_embeds: chex.Array: Pass in the embeddings of the input tokens
        :param init_cache: bool: Initialize the cache for the decoder
        :param output_attentions: bool: Return the attention weights
        :param output_hidden_states: bool: Return the hidden states of all layers
        :param return_dict: bool: Return a dictionary of the outputs or just the logits
        :param : Determine whether to return the logits or not
        :return: A tuple of (lm_logits, hidden_states, attentions)

    """
    batch_size, seq_length = input_ids.shape

    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    if position_ids is None:
        position_ids = jnp.broadcast_to(
            jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
            (batch_size, seq_length)
        )
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        deterministic=deterministic,
        inputs_embeds=inputs_embeds,
        init_cache=init_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )

    hidden_states = outputs[0]

    if self.config.tie_word_embeddings:
        shared_kernel = self.transformer.variables["params"]["embed_tokens"]["embedding"]
        shared_kernel = fjformer.linen.linen.control_quantization(shared_kernel, self.param_dtype).T
        lm_logits = self.lm_head.apply(
            {"params": {"kernel": shared_kernel}}, hidden_states)
    else:
        lm_logits = self.lm_head(hidden_states)

    # lm_logits = lm_logits.astype(jnp.float32)

    if not return_dict:
        return (lm_logits,) + outputs[1:]

    return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
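
The tie_word_embeddings branch above reuses the token-embedding table as the output projection instead of a separate lm_head kernel. Conceptually it reduces to the following sketch (not the library call), where `embedding` is the (vocab_size, hidden_size) table:

import jax.numpy as jnp

def tied_lm_logits(hidden_states: jnp.ndarray, embedding: jnp.ndarray) -> jnp.ndarray:
    # hidden_states: (batch, seq_len, hidden_size)
    # embedding:     (vocab_size, hidden_size)
    # returns:       (batch, seq_len, vocab_size)
    return hidden_states @ embedding.T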

FlaxDeepseekV2Module

Bases: Module

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
class FlaxDeepseekV2Module(nn.Module):
    config: DeepseekV2Config
    dtype: jnp.dtype = jnp.bfloat16
    param_dtype: jnp.dtype = jnp.bfloat16
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):

        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )

        self.layers = FlaxDeepseekV2DecoratorCollection(
            self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.norm = DeepseekV2RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
            dtype=self.dtype,
            param_dtype=self.param_dtype
        )

        initial_rope_kwargs = {}
        method = None
        if self.config.rope_scaling is not None:
            scaling_type = self.config.rope_scaling["type"]
            method = scaling_type
            if scaling_type != "yarn":
                initial_rope_kwargs = dict(scaling_factor=self.config.rope_scaling["factor"])
            else:
                initial_rope_kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                initial_rope_kwargs["scaling_factor"] = self.config.rope_scaling["factor"]
        self.freq_cis = init_deepseek_rotary_embedding(
            dim=self.config.hidden_size // self.config.num_attention_heads,
            max_position_embeddings=(
                getattr(
                    self.config,
                    "freq_max_position_embeddings",
                    self.config.max_position_embeddings
                )
            ),
            base=self.config.rope_theta,
            method=method,  # type:ignore
            kwargs=initial_rope_kwargs
        )
        self.causal_mask = flax.linen.make_causal_mask(
            jnp.ones(
                (
                    1,
                    getattr(
                        self.config,
                        "c_max_position_embeddings",
                        self.config.max_position_embeddings
                    )
                ),
                dtype="bool"
            ),
            dtype="bool"
        )

    def __call__(
            self,
            input_ids: Optional[chex.Array] = None,
            attention_mask: Optional[chex.Array] = None,
            position_ids: Optional[chex.Array] = None,
            deterministic: bool = True,
            inputs_embeds: chex.Array = None,
            init_cache: bool = False,
            output_attentions: bool = False,
            output_hidden_states: bool = False,
            return_dict: bool = True,
    ) -> typing.Union[Tuple[chex.Array, ...], FlaxBaseModelOutput]:
        """
        The __call__ function is the main function of a Flax model.
        It takes in input_ids, attention_mask, and position_ids as inputs to the model.
        The output is a tuple containing: last hidden state (hidden states), all hidden states (if output_hidden_states=True), attentions (if output attentions=True).


        :param self: Represent the instance of the class
        :param input_ids: chex.Array: Pass in the input ids
        :param attention_mask: chex.Array: Mask out the attention weights for certain tokens
        :param position_ids: chex.Array: Determine the position of each token in a sequence
        :param deterministic: bool: Determine whether to use dropout or not
        :param inputs_embeds: chex.Array: Pass in the embedding of the input_ids
        :param init_cache: bool: Initialize the cache for the decoder
        :param output_attentions: bool: Determine whether to return the attention weights or not
        :param output_hidden_states: bool: Return all hidden states or just the last one
        :param return_dict: bool: Return a dictionary of the outputs or not
        :param : Determine whether the model is in training mode or not
        :return: A tuple of the hidden states, all hidden states, and attentions

        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
        if attention_mask.ndim == 2:
            b, s = attention_mask.shape
            attention_mask = attention_mask.reshape(b, 1, 1, s)

        outputs = self.layers(
            hidden_states=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            freq_cis=self.freq_cis,
            init_cache=init_cache,
            output_attentions=output_attentions,
            deterministic=deterministic,
            causal_mask=self.causal_mask
        )

        hidden_states = outputs[0]
        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            return tuple(value for value in outputs if value is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )

__call__(input_ids=None, attention_mask=None, position_ids=None, deterministic=True, inputs_embeds=None, init_cache=False, output_attentions=False, output_hidden_states=False, return_dict=True)

The __call__ function is the main function of a Flax model. It takes input_ids, attention_mask, and position_ids as inputs. The output is a tuple containing the last hidden state, all hidden states (if output_hidden_states=True), and the attentions (if output_attentions=True).

Parameters:

- self: Represent the instance of the class. Required.
- input_ids (Optional[chex.Array]): Pass in the input ids. Default: None.
- attention_mask (Optional[chex.Array]): Mask out the attention weights for certain tokens. Default: None.
- position_ids (Optional[chex.Array]): Determine the position of each token in a sequence. Default: None.
- deterministic (bool): Determine whether to use dropout or not. Default: True.
- inputs_embeds (chex.Array): Pass in the embedding of the input_ids. Default: None.
- init_cache (bool): Initialize the cache for the decoder. Default: False.
- output_attentions (bool): Determine whether to return the attention weights or not. Default: False.
- output_hidden_states (bool): Return all hidden states or just the last one. Default: False.
- return_dict (bool): Return a dictionary of the outputs or not. Default: True.

Returns:

Union[Tuple[Array, ...], FlaxBaseModelOutput]: a tuple of the last hidden state, all hidden states, and attentions.

Source code in src/python/easydel/modules/deepseek_v2/modeling_deepseek_flax.py
def __call__(
        self,
        input_ids: Optional[chex.Array] = None,
        attention_mask: Optional[chex.Array] = None,
        position_ids: Optional[chex.Array] = None,
        deterministic: bool = True,
        inputs_embeds: chex.Array = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
) -> typing.Union[Tuple[chex.Array, ...], FlaxBaseModelOutput]:
    """
    The __call__ function is the main function of a Flax model.
    It takes in input_ids, attention_mask, and position_ids as inputs to the model.
    The output is a tuple containing: last hidden state (hidden states), all hidden states (if output_hidden_states=True), attentions (if output attentions=True).


    :param self: Represent the instance of the class
    :param input_ids: chex.Array: Pass in the input ids
    :param attention_mask: chex.Array: Mask out the attention weights for certain tokens
    :param position_ids: chex.Array: Determine the position of each token in a sequence
    :param deterministic: bool: Determine whether to use dropout or not
    :param inputs_embeds: chex.Array: Pass in the embedding of the input_ids
    :param init_cache: bool: Initialize the cache for the decoder
    :param output_attentions: bool: Determine whether to return the attention weights or not
    :param output_hidden_states: bool: Return all hidden states or just the last one
    :param return_dict: bool: Return a dictionary of the outputs or not
    :param : Determine whether the model is in training mode or not
    :return: A tuple of the hidden states, all hidden states, and attentions

    """
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
    if attention_mask.ndim == 2:
        b, s = attention_mask.shape
        attention_mask = attention_mask.reshape(b, 1, 1, s)

    outputs = self.layers(
        hidden_states=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        freq_cis=self.freq_cis,
        init_cache=init_cache,
        output_attentions=output_attentions,
        deterministic=deterministic,
        causal_mask=self.causal_mask
    )

    hidden_states = outputs[0]
    hidden_states = self.norm(hidden_states)

    if output_hidden_states:
        all_hidden_states = outputs[1] + (hidden_states,)
        outputs = (hidden_states, all_hidden_states) + outputs[2:]
    else:
        outputs = (hidden_states,) + outputs[1:]

    if not return_dict:
        return tuple(value for value in outputs if value is not None)

    return FlaxBaseModelOutput(
        last_hidden_state=hidden_states,
        hidden_states=outputs[1],
        attentions=outputs[-1],
    )
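
As a small standalone illustration of the mask handling above: a 2D padding mask of shape (batch, seq) is reshaped to (batch, 1, 1, seq) so it can broadcast against a (1, 1, seq, seq) causal mask inside the attention layers. The concrete values below are made up for illustration.

import jax.numpy as jnp

attention_mask = jnp.array([[1, 1, 1, 0]])               # (batch=1, seq=4); last token is padding
b, s = attention_mask.shape
attention_mask_4d = attention_mask.reshape(b, 1, 1, s)   # broadcastable over heads and query length

causal_mask = jnp.tril(jnp.ones((1, 1, s, s), dtype=bool))
combined = jnp.logical_and(causal_mask, attention_mask_4d.astype(bool))
# combined[0, 0, i, j] is True only when token i may attend to token j.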