modules.mamba.modelling_mamba_flax
FlaxMambaPretrainedModel
Bases: EasyDeLFlaxPretrainedModel
Source code in src/python/easydel/modules/mamba/modelling_mamba_flax.py
__call__(input_ids=None, inputs_embeds=None, cache_params=None, deterministic=True, params=None, dropout_rng=None, train=False, output_hidden_states=None, return_dict=None, extra_embedding=None, add_params_field=False, attention_mask=None, use_cache=False, **kwargs)
Runs the model's forward pass. This is the entry point invoked when the model instance is called.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self | | Represent the instance of the class | required |
input_ids | Optional[Array] | Optional[chex.Array]: Pass in the input tokens | None |
inputs_embeds | Optional[Array] | Optional[chex.Array]: Pass in the embedded tokens | None |
cache_params | dict | dict: Pass in the past cache_params from a previous call to call | None |
params | dict | dict: Pass in the parameters of the model | None |
dropout_rng | PRNGKey | jax.random.PRNGKey: Make sure that the dropout is applied in a random way | None |
train | bool | bool: Determine whether to use dropout or not | False |
output_hidden_states | Optional[bool] | Optional[bool]: Return the hidden states of all layers | None |
return_dict | Optional[bool] | Optional[bool]: Determine whether to return a dictionary or not | None |
extra_embedding | Optional[Union[ndarray, None]] | Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids | None |
add_params_field | bool | bool: Add the params field to the inputs dictionary | False |
Returns:
Type | Description |
---|---|
A tuple of the following: |
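The conventions suggested by the parameter table can be sketched in plain Python. This is an illustration only, not EasyDeL's implementation: `resolve_call_args` and `default_return_dict` are hypothetical names, and three behaviours are assumptions borrowed from common Flax model wrappers rather than facts taken from this page: that exactly one of `input_ids` and `inputs_embeds` is expected, that the forward pass is deterministic when `train` is False, and that `return_dict=None` falls back to a configured default.

```python
from typing import Any, Dict, Optional, Sequence


def resolve_call_args(
    input_ids: Optional[Sequence[int]] = None,
    inputs_embeds: Optional[Sequence[Sequence[float]]] = None,
    train: bool = False,
    dropout_rng: Optional[Any] = None,
    return_dict: Optional[bool] = None,
    default_return_dict: bool = True,
) -> Dict[str, Any]:
    """Hypothetical helper that normalises __call__-style arguments."""
    # Assumption: exactly one of the two input forms is supplied.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("Pass exactly one of input_ids or inputs_embeds.")

    # A dropout RNG is only forwarded when the caller provides one.
    rngs: Dict[str, Any] = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    # Assumption: None means "use the configured default".
    if return_dict is None:
        return_dict = default_return_dict

    return {
        "uses_embeds": inputs_embeds is not None,
        "deterministic": not train,  # assumption: dropout is off outside training
        "rngs": rngs,
        "return_dict": return_dict,
    }


print(resolve_call_args(input_ids=[1, 2, 3]))
```

Calling the helper with neither input form, or with both, raises a `ValueError`. Whether the real method is equally strict is not stated on this page.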
Source code in src/python/easydel/modules/mamba/modelling_mamba_flax.py
__init__(config, input_shape=(1, 1), seed=0, dtype=jnp.float32, param_dtype=jnp.float32, precision=None, _do_init=True, **kwargs)
Constructs the model from a MambaConfig. The constructor sets the computation and parameter data types, the seed and input shape used to initialize the parameters, and whether initialization happens immediately (_do_init).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self | | Refer to the object itself | required |
config | MambaConfig | MambaConfig: Pass the configuration to the module | required |
input_shape | Tuple | Tuple: Specify the shape of the input to the model | (1, 1) |
seed | int | int: Set the seed for random number generation | 0 |
dtype | dtype | jnp.dtype: Specify the data type of the model | float32 |
param_dtype | dtype | jnp.dtype: Specify the data type of the model parameters | float32 |
precision | Optional[Union[str, Precision]] | Optional[Union[str, lax.Precision]]: precision for model operations | None |
_do_init | bool | bool: Control whether the module is initialized or not | True |
kwargs | | Pass in any additional parameters that the module_class might need | {} |
 | | Specify the number of layers in the network | required |
Returns:
Type | Description |
---|---|
None. The constructor delegates to the parent class through super(). |
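How `seed` and `_do_init` interact can be illustrated with a small stand-in class. The sketch below is hypothetical and uses only the standard library: `LazyInitModel` and `hidden_size` are not EasyDeL names, and the assumption that `_do_init=False` leaves the parameters unset until the caller supplies them follows the usual Flax wrapper pattern rather than anything confirmed on this page.

```python
import random
from typing import Dict, List, Optional


class LazyInitModel:
    """Hypothetical stand-in showing the seed / _do_init pattern."""

    def __init__(self, hidden_size: int, seed: int = 0, _do_init: bool = True):
        self.hidden_size = hidden_size
        self.seed = seed
        # Assumption: parameters stay unset when initialization is deferred.
        self.params: Optional[Dict[str, List[float]]] = None
        if _do_init:
            self.params = self.init_weights(seed)

    def init_weights(self, seed: int) -> Dict[str, List[float]]:
        # A seeded generator makes initialization reproducible.
        rng = random.Random(seed)
        return {"embedding": [rng.gauss(0.0, 0.02) for _ in range(self.hidden_size)]}


eager_a = LazyInitModel(hidden_size=4, seed=0)
eager_b = LazyInitModel(hidden_size=4, seed=0)
deferred = LazyInitModel(hidden_size=4, _do_init=False)

print(eager_a.params == eager_b.params)  # same seed, same parameters
print(deferred.params is None)           # nothing allocated yet
```

Deferring initialization is typically useful when pretrained weights will be loaded anyway, since it avoids allocating a random parameter tree that would be discarded immediately.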
Source code in src/python/easydel/modules/mamba/modelling_mamba_flax.py
init_weights(rng, input_shape, params=None)
The init_weights function is used to initialize the weights of a model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self | | Access variables that belong to the class | required |
rng | PRNGKey | jax.random.PRNGKey: Initialize the weights of the model | required |
input_shape | Tuple | Tuple: Specify the shape of the input tensor | required |
params | FrozenDict | FrozenDict: Pass in the parameters of a pre-trained model | None |
|
Returns:
Type | Description |
---|---|
FrozenDict | A frozendict of parameters |
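The optional `params` argument suggests that supplied pretrained parameters are combined with freshly initialized ones. A minimal sketch of that idea, using plain nested dicts in place of a FrozenDict, follows. It assumes that supplied values take precedence and that only missing entries are filled from the fresh initialization; the page does not confirm this is what `init_weights` does. The helper `fill_missing` and the keys `backbone`, `embeddings`, `norm`, and `lm_head` are hypothetical.

```python
from typing import Any, Dict, Mapping


def fill_missing(fresh: Mapping[str, Any], given: Mapping[str, Any]) -> Dict[str, Any]:
    """Keep every value in `given`; take entries it lacks from `fresh`."""
    merged: Dict[str, Any] = dict(given)
    for key, fresh_value in fresh.items():
        if key not in merged:
            # Entry absent from the supplied tree: use the fresh value.
            merged[key] = fresh_value
        elif isinstance(fresh_value, Mapping) and isinstance(merged[key], Mapping):
            # Both sides are subtrees: merge them recursively.
            merged[key] = fill_missing(fresh_value, merged[key])
    return merged


# Freshly initialized tree (placeholder values) and a partial pretrained tree.
fresh_params = {"backbone": {"embeddings": [0.0, 0.0], "norm": [1.0]}, "lm_head": [0.5]}
pretrained = {"backbone": {"embeddings": [0.3, -0.1]}}

merged = fill_missing(fresh_params, pretrained)
print(merged)
```

In this sketch the pretrained `embeddings` survive unchanged, while `norm` and `lm_head` come from the fresh tree. A real implementation would freeze the result before returning it.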
Source code in src/python/easydel/modules/mamba/modelling_mamba_flax.py