modules.falcon.modelling_falcon_flax

FlaxFalconPretrainedModel

Bases: EasyDeLFlaxPretrainedModel

Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
class FlaxFalconPretrainedModel(EasyDeLFlaxPretrainedModel):
    module_class: nn.Module = None
    config_class = FalconConfig

    def __init__(self, config,
                 _do_init=False,
                 dtype: jnp.dtype = jnp.float32,
                 param_dtype: jnp.dtype = jnp.float32,
                 input_shape: Tuple = (1, 1),
                 precision: Optional[Union[str, jax.lax.Precision]] = jax.lax.Precision("fastest")
                 ):
        module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype, precision=precision)
        super().__init__(_do_init=_do_init, module=module, config=config, dtype=dtype, input_shape=input_shape)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        The init_weights function is used to initialize the weights of a model.

        :param self: Access variables that belong to the class
        :param rng: jax.random.PRNGKey: Initialize the weights of the model
        :param input_shape: Tuple: Specify the shape of the input tensor
        :param params: FrozenDict: Pass in the parameters of a pre-trained model
        :return: A frozendict of parameters

        """
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(
            jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(
                input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                position_ids,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                position_ids,
                return_dict=False
            )

        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: Optional[chex.Array] = None,
            position_ids: Optional[chex.Array] = None,
            past_key_values: Optional[nn.Module] = None,
            output_attentions: bool = False,
            train: bool = True,
            return_dict: Optional[bool] = True,
            params: FrozenDict = None,
            add_params_field: bool = False,
            **kwargs
    ):
        input_ids = jnp.asarray(input_ids, dtype=jnp.int32)
        inputs = {'params': params or self.params} if add_params_field else params or self.params
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(input_ids.shape[1])[None, :],
                                            (input_ids.shape[0], input_ids.shape[1]))
        rngs = {}
        if self.config.bits is not None:
            rngs['params'] = jax.random.key(0)
        if attention_mask is None:
            attention_mask = jnp.ones((input_ids.shape[0], input_ids.shape[1]))

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            output_attentions,
            not train,
            False,
            return_dict,
            mutable=mutable,
            rngs=rngs
        )

        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
        return outputs

    def init_cache(self, batch_size, max_length):

        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    @staticmethod
    def update_inputs_for_generation(model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
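
The snippet below is a rough sketch, not taken from the library, of how init_cache, prepare_inputs_for_generation and __call__ fit together for a single decoding step. It assumes a concrete subclass of FlaxFalconPretrainedModel with a language-modelling head (so the outputs expose logits and past_key_values) and with model.params already loaded; names such as first_decode_step are illustrative only.

import jax.numpy as jnp

# Hypothetical helper: run the prompt through the model once and pick the
# next token greedily. `model` is an assumed concrete subclass instance.
def first_decode_step(model, input_ids, max_length):
    # Build the key/value cache and an attention mask padded to max_length.
    gen_inputs = model.prepare_inputs_for_generation(input_ids, max_length)

    outputs = model(
        input_ids=input_ids,
        attention_mask=gen_inputs["attention_mask"],
        position_ids=gen_inputs["position_ids"],
        past_key_values=gen_inputs["past_key_values"],
        train=False,
        return_dict=True,
    )
    # In a full generation loop the returned cache would be threaded back in
    # via update_inputs_for_generation.
    next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)
    return next_token, outputs.past_key_values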

init_weights(rng, input_shape, params=None)

The init_weights function is used to initialize the weights of a model.

Parameters:

    rng (PRNGKey, required):
        Initialize the weights of the model.
    input_shape (Tuple, required):
        Specify the shape of the input tensor.
    params (FrozenDict, default None):
        Pass in the parameters of a pre-trained model.

Returns:

    FrozenDict: A frozendict of parameters.

Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    """
    The init_weights function is used to initialize the weights of a model.

    :param self: Access variables that belong to the class
    :param rng: jax.random.PRNGKey: Initialize the weights of the model
    :param input_shape: Tuple: Specify the shape of the input tensor
    :param params: FrozenDict: Pass in the parameters of a pre-trained model
    :return: A frozendict of parameters

    """
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(
        jnp.atleast_2d(input_ids).shape[-1]), input_shape)
    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    if self.config.add_cross_attention:
        encoder_hidden_states = jnp.zeros(
            input_shape + (self.config.hidden_size,))
        encoder_attention_mask = attention_mask
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
    else:
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            return_dict=False
        )

    random_params = module_init_outputs["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
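
As a short usage sketch (illustrative only): weights can be materialised explicitly when the model is constructed with _do_init=False; here model stands for any concrete FlaxFalconPretrainedModel subclass.

import jax

rng = jax.random.PRNGKey(0)

# Fresh random parameters for a dummy (batch, sequence) input shape.
random_params = model.init_weights(rng, input_shape=(1, 128))

# Passing an existing parameter tree merges it with the random init: only the
# keys recorded in `_missing_keys` are taken from the random parameters.
merged_params = model.init_weights(rng, input_shape=(1, 128), params=random_params)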

apply_rotary_pos_embedding(tensor, sin_, cos_)

The apply_rotary_pos_embedding function applies a rotary positional embedding to the input tensor.

Parameters:

    tensor (required):
        The tensor to apply the positional embedding to.
    sin_ (required):
        Sine table used to rotate the tensor by half of its length.
    cos_ (required):
        Cosine table multiplied with the tensor.

Returns:

    A tensor with the same shape as its input.

Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
def apply_rotary_pos_embedding(tensor, sin_, cos_):
    """
    The apply_rotary_pos_embedding function applies a rotary positional embedding to the input tensor.

    :param tensor: Pass in the tensor that we want to apply the positional embedding to
    :param sin_: Rotate the tensor by half of its length
    :param cos_: Multiply the tensor and cosine of the angle
    :return: A tensor with the same shape as its input,

    """
    return (tensor * cos_) + (_rotate_half(tensor) * sin_)
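
As a usage sketch with hypothetical shapes: the sin_ and cos_ tables would typically come from precompute_falcon_freq_cis (documented further down this page), gathered at the token positions and broadcast against the query or key tensor. The exact axis layout inside the attention module may differ.

import jax.numpy as jnp

batch, seq_len, num_heads, head_dim = 2, 16, 4, 64

# Per-position tables of shape (max_position_embedding, head_dim).
sin, cos = precompute_falcon_freq_cis(max_position_embedding=2048, head_dim=head_dim)

position_ids = jnp.broadcast_to(jnp.arange(seq_len)[None, :], (batch, seq_len))
# Gather the rows for each position and add a singleton head axis so the
# tables broadcast against (batch, seq_len, num_heads, head_dim).
sin_ = sin[position_ids][:, :, None, :]
cos_ = cos[position_ids][:, :, None, :]

query = jnp.ones((batch, seq_len, num_heads, head_dim))
rotated_query = apply_rotary_pos_embedding(query, sin_, cos_)  # same shape as query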

built_bloom_alibi(attention_mask, num_attention_heads)

The built_bloom_alibi function builds the ALiBi (Attention with Linear Biases) bias tensor used by Bloom-style attention. Each attention head is assigned a fixed slope, and the bias grows linearly with the token position derived from the attention mask, so the attention scores carry positional information without explicit position embeddings; padded positions contribute a zero bias.

Parameters:

    attention_mask (required):
        Mask out the padding tokens in the input sequence.
    num_attention_heads (required):
        Determine the number of attention heads in the model.

Returns:

    A tensor of shape (batch_size, num_attention_heads, 1, sequence_length).

Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
def built_bloom_alibi(attention_mask, num_attention_heads):
    """
    The built_bloom_alibi function is used to create a bloom alibi for the attention mask.
    The bloom alibi is used in the Bloom Attention layer to ensure that each token has a unique
    attention vector, even if it's masked out. This ensures that all tokens have an equal chance of being selected as
    the most important token in the sequence, which helps with training stability and performance.

    :param attention_mask: Mask out the padding tokens in the input sequence
    :param num_attention_heads: Determine the number of attention heads in the model
    :return: A tensor of shape (batch_size, num_attention_heads, 1, sequence_length)

    """
    batch_size, sequence_length = attention_mask.shape
    cp2 = 2 ** math.floor(math.log2(num_attention_heads))
    base = jnp.asarray(
        2 ** (- (2 ** -(math.log2(cp2) - 3))), dtype=jnp.float32
    )
    powers = jnp.arange(1, 1 + cp2, dtype=jnp.float32)
    slops = jnp.power(base, powers)
    if cp2 != num_attention_heads:
        extra_base = jnp.asarray(
            2 ** (-(2 ** -(math.log2(2 * cp2) - 3))), dtype=jnp.float32
        )
        num_rem_heads = min(cp2, num_attention_heads - cp2)
        extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.float32)
        slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
    arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
    alibi = slops[..., jnp.newaxis].astype(jnp.bfloat16) * arange_tensor
    return alibi.reshape(batch_size, num_attention_heads, 1, sequence_length)
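
A quick shape check with a right-padded batch (values are illustrative):

import jax.numpy as jnp

# Two sequences of length 4; the second has two padding positions.
attention_mask = jnp.array(
    [[1, 1, 1, 1],
     [1, 1, 0, 0]],
    dtype=jnp.int32,
)

alibi = built_bloom_alibi(attention_mask, num_attention_heads=8)
print(alibi.shape)  # (2, 8, 1, 4): one bias row per head, broadcast over queries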

dropout_add(linen_drop, x, residual, deterministic)

The dropout_add function is a small helper that applies dropout to the branch output and then adds the residual. It lets the same code path serve evaluation, where deterministic=True turns dropout into a no-op but the residual must still be added, and training, where the block effectively has two paths, one with dropout and one without (the residual), and gradients flow through both.

Parameters:

    linen_drop (flax.linen.Dropout, required):
        The dropout layer to apply.
    x (chex.Array, required):
        Input to the dropout layer.
    residual (chex.Array, required):
        Residual added to the dropout output.
    deterministic (bool, required):
        Whether the dropout layer is disabled (True) or active (False).

Returns:

    chex.Array: The sum of the residual and the dropout output.

Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
def dropout_add(linen_drop: flax.linen.Dropout, x: chex.Array, residual: chex.Array, deterministic: bool) -> chex.Array:
    """
    The dropout_add function is a helper function that adds the residual to the output of
    the dropout layer. This is necessary because we want to use deterministic=True when
    we are evaluating our model, but we still need to add in the residual. The reason for this
    is that during training, we have two paths through our network: one with dropout and one without.
    The path without dropout (residual) allows us to backpropagate gradients through both paths at once.

    :param linen_drop: flax.linen.Dropout: Specify the dropout layer
    :param x: chex.Array: Pass in the input to the dropout layer
    :param residual: chex.Array: Add the residual to the output of dropout_add
    :param deterministic: bool: Determine whether the dropout layer is active or not
    :return: A tensor that is the sum of the residual and a dropout layer

    """
    out = linen_drop(inputs=x, deterministic=deterministic)
    out = residual + out
    return out
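
Since flax.linen.Dropout must be bound to a module to draw its RNG, the sketch below wraps dropout_add in a tiny throwaway module (ResidualDropout is a name invented for this example, not part of the library):

import jax
import jax.numpy as jnp
import flax.linen as nn

class ResidualDropout(nn.Module):
    rate: float = 0.1

    @nn.compact
    def __call__(self, x, residual, deterministic: bool):
        drop = nn.Dropout(rate=self.rate)
        # Dropout on the branch output, then the residual is added back,
        # mirroring how the Falcon blocks use dropout_add.
        return dropout_add(drop, x, residual, deterministic)

x = jnp.ones((2, 8, 16))
residual = jnp.full_like(x, 0.5)
module = ResidualDropout()
variables = module.init(jax.random.PRNGKey(0), x, residual, deterministic=True)

# Evaluation: dropout is a no-op, so the result is exactly x + residual.
eval_out = module.apply(variables, x, residual, deterministic=True)

# Training: deterministic=False requires a "dropout" RNG.
train_out = module.apply(
    variables, x, residual, deterministic=False,
    rngs={"dropout": jax.random.PRNGKey(1)},
)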

precompute_falcon_freq_cis(max_position_embedding, head_dim, theta=10000)

The precompute_falcon_freq_cis function precomputes the sinusoidal tables used for Falcon's rotary position embeddings. It takes three arguments: max_position_embedding, head_dim, and theta. The first two set the number of positions and the width of each table; theta is the base of the inverse-frequency schedule, so larger values give longer wavelengths. The default of 10000 is the standard base from the original rotary-embedding formulation.

Parameters:

    max_position_embedding (int, required):
        Maximum sequence length covered by the tables.
    head_dim (int, required):
        Size of the positional embedding (per attention head).
    theta (float, default 10000):
        Base of the inverse-frequency schedule.

Returns:

    A tuple of two arrays: the sine and cosine tables.

Source code in src/python/easydel/modules/falcon/modelling_falcon_flax.py
def precompute_falcon_freq_cis(max_position_embedding: int, head_dim: int, theta: float = 10000):
    """
    The precompute_falcon_freq_cis function is used to precompute the sinusoidal frequencies for the FALCON model.
    The function takes in three arguments: max_position_embedding, head_dim, and theta. The first two are self-explanatory;
    the third is a hyperparameter that controls how quickly the frequency increases with position (i.e., how many times
    higher it will be at position i than at position 0). The default value of 10000 was chosen because it worked well on
    the tasks we tested.

    :param max_position_embedding: int: Set the maximum length of the sequence
    :param head_dim: int: Determine the size of the positional embedding
    :param theta: float: Adjust the frequency of the sinusoid
    :return: A tuple of two arrays

    """
    inv_freq_cis = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    freq = jnp.einsum("i , j -> i j", jnp.arange(max_position_embedding), inv_freq_cis).astype("float32")

    embed = jnp.concatenate((freq, freq), axis=-1)
    return jnp.sin(embed)[:, :], jnp.cos(embed)[:, :]
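
A quick shape check (illustrative values): the half-dimension frequencies are concatenated with themselves, so each table has head_dim columns.

# Sketch: one row of sin/cos values per position, head_dim values per row.
sin, cos = precompute_falcon_freq_cis(max_position_embedding=2048, head_dim=64)
print(sin.shape, cos.shape)  # (2048, 64) (2048, 64)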