
modules.mamba.modelling_mamba_flax

FlaxMambaPretrainedModel

Bases: EasyDeLFlaxPretrainedModel

Source code in src/python/easydel/modules/mamba/modelling_mamba_flax.py
class FlaxMambaPretrainedModel(EasyDeLFlaxPretrainedModel):
    config_class = MambaConfig
    base_model_prefix = "backbone"
    module_class: nn.Module = None

    def __init__(
            self,
            config: MambaConfig,
            input_shape: Tuple = (1, 1),
            seed: int = 0,
            dtype: jnp.dtype = jnp.float32,
            param_dtype: jnp.dtype = jnp.float32,
            precision: Optional[Union[str, lax.Precision]] = None,
            _do_init: bool = True,
            **kwargs,
    ):
        """
        The __init__ function is called when the class is instantiated.
        It sets up the instance of the class, and defines what happens when it's created.
        The __init__ function can take arguments, but self is always required (it refers to the instance of the object).


        :param self: Refer to the object itself
        :param config: MambaConfig: Pass the configuration to the module
        :param input_shape: Tuple: Specify the shape of the input to the model
        :param seed: int: Set the seed for random number generation
        :param dtype: jnp.dtype: Specify the data type of the model ra
        :param param_dtype: jnp.dtype: Specify the data type of the param_dtype
        :param precision: Optional[Union[str, lax.Precision]]: precision for model operations
        :param _do_init: bool: Control whether the module is initialized or not
        :param kwargs: Pass in any additional parameters that the module_class might need
        :param : Specify the number of layers in the network
        :return: The super() of the class

        """
        module = self.module_class(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            **kwargs
        )
        super().__init__(
            config,
            module,
            input_shape=(input_shape[0], 1),
            seed=seed,
            dtype=dtype,
            _do_init=_do_init
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        The init_weights function is used to initialize the weights of a model.

        :param self: Access variables that belong to the class
        :param rng: jax.random.PRNGKey: Initialize the weights of the model
        :param input_shape: Tuple: Specify the shape of the input tensor
        :param params: FrozenDict: Pass in the parameters of a pre-trained model
        :return: A frozendict of parameters

        """
        input_ids = jnp.zeros(input_shape, dtype="i4")
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            return_dict=False
        )

        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):
        return None

    def __call__(
            self,
            input_ids: Optional[chex.Array] = None,
            inputs_embeds: Optional[chex.Array] = None,
            cache_params: dict = None,
            deterministic: bool = True,
            params: dict = None,
            dropout_rng: jax.random.PRNGKey = None,
            train: bool = False,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
            add_params_field: bool = False,
            attention_mask: Optional[chex.Array] = None,  # Ignored(we are using an SSM model not attention)
            use_cache: bool = False,
            **kwargs
    ):
        """
        The __call__ function is the main function of a JAX module.

        :param self: Represent the instance of the class
        :param input_ids: Optional[chex.Array]: Pass in the input tokens
        :param inputs_embeds: Optional[chex.Array]: Pass in the embedded tokens
        :param cache_params: dict: Pass in the past cache_params from a previous call to __call__
        :param params: dict: Pass in the parameters of the model
        :param dropout_rng: jax.random.PRNGKey: Make sure that the dropout is applied in a random way
        :param train: bool: Determine whether to use dropout or not
        :param output_hidden_states: Optional[bool]: Return the hidden states of all layers
        :param return_dict: Optional[bool]: Determine whether to return a dictionary or not
        :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids
        :param add_params_field: bool: Add the params field to the inputs dictionary
        :return: A tuple of the following:

        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        assert sequence_length <= self.config.max_position_embeddings, "Maximum Position Embedding Reached !"
        if cache_params is not None:
            assert isinstance(cache_params, FlaxMambaCache), f"Wrong cache input_type of {type(cache_params)}"
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        rngs["params"] = jax.random.key(0)

        inputs = {
            "params": params or self.params
        } if add_params_field else params or self.params

        # Positional arguments below must follow the module's __call__ order:
        # input_ids, inputs_embeds, cache_params, deterministic, use_cache,
        # output_hidden_states, return_dict (note: `train` is forwarded in the
        # `deterministic` slot).

        return self.module.apply(
            inputs,
            input_ids,
            inputs_embeds,
            cache_params,
            train,
            use_cache,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=False,
        )
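
The snippet below is a minimal usage sketch of this base class: a concrete subclass sets module_class to the inner flax.linen module and is then constructed like any other EasyDeL Flax model. The subclass and module names (FlaxMambaModel, FlaxMambaModule) and a default-constructible MambaConfig are assumptions based on the naming pattern of this file, not guaranteed by this page.

# Minimal sketch, assuming a concrete subclass exists elsewhere in this file.
# FlaxMambaModel is an assumed name following the file's naming pattern.
import jax.numpy as jnp

config = MambaConfig()             # Mamba configuration (config_class of this model)
model = FlaxMambaModel(            # assumed concrete subclass of FlaxMambaPretrainedModel
    config=config,
    input_shape=(1, 1),            # (batch, 1); the base __init__ forces the second dim to 1
    seed=0,
    dtype=jnp.float32,
    param_dtype=jnp.float32,
    _do_init=True,                 # eagerly create random parameters via init_weights
)
params = model.params              # parameter PyTree produced at initialization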

__call__(input_ids=None, inputs_embeds=None, cache_params=None, deterministic=True, params=None, dropout_rng=None, train=False, output_hidden_states=None, return_dict=None, extra_embedding=None, add_params_field=False, attention_mask=None, use_cache=False, **kwargs)

The __call__ function runs a forward pass of the model by applying the wrapped Flax module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_ids | Optional[chex.Array] | Input token ids | None |
| inputs_embeds | Optional[chex.Array] | Pre-computed input embeddings used instead of input_ids | None |
| cache_params | dict | cache_params returned by a previous call to __call__ | None |
| deterministic | bool | Disable dropout when True | True |
| params | dict | Model parameters (defaults to self.params) | None |
| dropout_rng | jax.random.PRNGKey | PRNG key used for dropout | None |
| train | bool | Enable dropout and other training-time behaviour | False |
| output_hidden_states | Optional[bool] | Return the hidden states of all layers | None |
| return_dict | Optional[bool] | Return a model output object instead of a plain tuple | None |
| extra_embedding | Optional[Union[jnp.ndarray, None]] | Optional extra embedding for the input_ids | None |
| add_params_field | bool | Wrap the parameters in a {"params": ...} dictionary before applying the module | False |
| attention_mask | Optional[chex.Array] | Ignored; Mamba is an SSM and does not use attention | None |
| use_cache | bool | Ask the module to return cache_params for incremental decoding | False |

Returns:

The output of the wrapped module: a model output object when return_dict is True, otherwise a plain tuple.

Source code in src/python/easydel/modules/mamba/modelling_mamba_flax.py
def __call__(
        self,
        input_ids: Optional[chex.Array] = None,
        inputs_embeds: Optional[chex.Array] = None,
        cache_params: dict = None,
        deterministic: bool = True,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
        add_params_field: bool = False,
        attention_mask: Optional[chex.Array] = None,  # Ignored(we are using an SSM model not attention)
        use_cache: bool = False,
        **kwargs
):
    """
    The __call__ function is the main function of a JAX module.

    :param self: Represent the instance of the class
    :param input_ids: Optional[chex.Array]: Pass in the input tokens
    :param inputs_embeds: Optional[chex.Array]: Pass in the embedded tokens
    :param cache_params: dict: Pass in the past cache_params from a previous call to __call__
    :param params: dict: Pass in the parameters of the model
    :param dropout_rng: jax.random.PRNGKey: Make sure that the dropout is applied in a random way
    :param train: bool: Determine whether to use dropout or not
    :param output_hidden_states: Optional[bool]: Return the hidden states of all layers
    :param return_dict: Optional[bool]: Determine whether to return a dictionary or not
    :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids
    :param add_params_field: bool: Add the params field to the inputs dictionary
    :return: A tuple of the following:

    """
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    batch_size, sequence_length = input_ids.shape

    assert sequence_length <= self.config.max_position_embeddings, "Maximum Position Embedding Reached !"
    if cache_params is not None:
        assert isinstance(cache_params, FlaxMambaCache), f"Wrong cache input_type of {type(cache_params)}"
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    rngs["params"] = jax.random.key(0)

    inputs = {
        "params": params or self.params
    } if add_params_field else params or self.params

    # Positional arguments below must follow the module's __call__ order:
    # input_ids, inputs_embeds, cache_params, deterministic, use_cache,
    # output_hidden_states, return_dict (note: `train` is forwarded in the
    # `deterministic` slot).

    return self.module.apply(
        inputs,
        input_ids,
        inputs_embeds,
        cache_params,
        train,
        use_cache,
        output_hidden_states,
        return_dict,
        rngs=rngs,
        mutable=False,
    )
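
A forward-pass sketch under the same assumptions (a concrete, initialized subclass instance called model). The sequence length must stay within config.max_position_embeddings, and attention_mask is accepted but ignored.

# Forward-pass sketch, assuming `model` is an initialized concrete subclass
# (e.g. the FlaxMambaModel assumed above) with parameters in model.params.
import jax
import jax.numpy as jnp

input_ids = jnp.ones((1, 16), dtype="i4")    # (batch, sequence) token ids

outputs = model(
    input_ids=input_ids,
    train=False,                             # no dropout
    dropout_rng=jax.random.PRNGKey(0),       # only used when dropout is active
    output_hidden_states=True,               # also return per-layer hidden states
    return_dict=True,                        # model-output object instead of a tuple
)
# With return_dict=True the module is expected to return a model-output object
# (its first field is the last hidden state); with return_dict=False, a plain tuple.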

__init__(config, input_shape=(1, 1), seed=0, dtype=jnp.float32, param_dtype=jnp.float32, precision=None, _do_init=True, **kwargs)

The __init__ function builds the wrapped Flax module from the configuration and delegates the rest of the setup (input shape, seed, parameter initialization) to EasyDeLFlaxPretrainedModel.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | MambaConfig | Configuration passed to the module | required |
| input_shape | Tuple | (batch, sequence) shape used for weight initialization | (1, 1) |
| seed | int | Seed for random number generation | 0 |
| dtype | jnp.dtype | Computation data type of the model | float32 |
| param_dtype | jnp.dtype | Data type of the model parameters | float32 |
| precision | Optional[Union[str, Precision]] | Precision for model operations | None |
| _do_init | bool | Whether the parameters are initialized eagerly | True |
| **kwargs | | Additional keyword arguments forwarded to the module_class | {} |

Returns:

None; the constructor initializes the instance via the parent class.

Source code in src/python/easydel/modules/mamba/modelling_mamba_flax.py
def __init__(
        self,
        config: MambaConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: Optional[Union[str, lax.Precision]] = None,
        _do_init: bool = True,
        **kwargs,
):
    """
    The __init__ function is called when the class is instantiated.
    It sets up the instance of the class, and defines what happens when it's created.
    The __init__ function can take arguments, but self is always required (it refers to the instance of the object).


    :param self: Refer to the object itself
    :param config: MambaConfig: Pass the configuration to the module
    :param input_shape: Tuple: Specify the shape of the input to the model
    :param seed: int: Set the seed for random number generation
    :param dtype: jnp.dtype: Specify the data type of the model ra
    :param param_dtype: jnp.dtype: Specify the data type of the param_dtype
    :param precision: Optional[Union[str, lax.Precision]]: precision for model operations
    :param _do_init: bool: Control whether the module is initialized or not
    :param kwargs: Pass in any additional parameters that the module_class might need
    :param : Specify the number of layers in the network
    :return: The super() of the class

    """
    module = self.module_class(
        config=config,
        dtype=dtype,
        param_dtype=param_dtype,
        precision=precision,
        **kwargs
    )
    super().__init__(
        config,
        module,
        input_shape=(input_shape[0], 1),
        seed=seed,
        dtype=dtype,
        _do_init=_do_init
    )
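
As a construction example under the same assumed subclass name, dtype and param_dtype can be set independently for mixed-precision setups; note that __init__ always reduces the initialization input shape to (batch, 1).

# Construction sketch with mixed precision (assumed subclass name FlaxMambaModel).
import jax.numpy as jnp
from jax import lax

model = FlaxMambaModel(
    config=MambaConfig(),
    input_shape=(4, 128),             # only the batch size matters; __init__ passes (4, 1) upward
    seed=42,                          # seed for parameter initialization
    dtype=jnp.bfloat16,               # computation dtype
    param_dtype=jnp.float32,          # parameter storage dtype
    precision=lax.Precision.HIGHEST,  # matmul precision hint for XLA
)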

init_weights(rng, input_shape, params=None)

The init_weights function is used to initialize the weights of a model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rng | jax.random.PRNGKey | PRNG key used to initialize the weights | required |
| input_shape | Tuple | Shape of the dummy input tensor used for initialization | required |
| params | FrozenDict | Parameters of a pre-trained model; missing keys are filled with freshly initialized values | None |

Returns:

FrozenDict: a frozen dictionary of model parameters.

Source code in src/python/easydel/modules/mamba/modelling_mamba_flax.py
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    """
    The init_weights function is used to initialize the weights of a model.

    :param self: Access variables that belong to the class
    :param rng: jax.random.PRNGKey: Initialize the weights of the model
    :param input_shape: Tuple: Specify the shape of the input tensor
    :param params: FrozenDict: Pass in the parameters of a pre-trained model
    :return: A frozendict of parameters

    """
    input_ids = jnp.zeros(input_shape, dtype="i4")
    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    module_init_outputs = self.module.init(
        rngs,
        input_ids,
        return_dict=False
    )

    random_params = module_init_outputs["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
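
init_weights can also be called directly, for example to re-initialize a model with a new PRNG key or to complete a partial pretrained parameter tree. The sketch below assumes an existing model instance as above and uses the internal _missing_keys bookkeeping described by the source.

# Sketch of direct init_weights usage, assuming `model` is a concrete subclass instance.
import jax

rng = jax.random.PRNGKey(0)

# Fresh random parameters for a dummy (batch=1, sequence=1) input:
fresh_params = model.init_weights(rng, input_shape=(1, 1))

# Stand-in for a (possibly partial) parameter tree loaded from a checkpoint:
pretrained_params = fresh_params

# Keys listed in model._missing_keys are filled from freshly initialized values:
completed_params = model.init_weights(rng, input_shape=(1, 1), params=pretrained_params)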