Skip to content

modules.mosaic_mpt.modelling_mpt_flax

FlaxMptAttention

Bases: BaseJAXAttentionModule

Source code in src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
class FlaxMptAttention(BaseJAXAttentionModule):
    """
    Multi-head self-attention layer of the Flax MPT model.

    A single fused projection produces queries, keys and values, which are
    split per head, scored with scaled dot-product attention, masked with a
    causal mask plus the caller's attention mask, and projected back to
    ``d_model`` by ``wo``.
    """
    config: MptConfig
    # dtype of the computation.
    dtype: jnp.dtype = jnp.float32
    # dtype of the parameters.
    param_dtype: jnp.dtype = jnp.float32
    # Precision forwarded to the projections and to the q/k score einsum.
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        """
        Build the projections, the optional q/k layer norms and the causal mask.

        :param self: Access variables that belong to the class
        :return: None
        """
        # Fused projection: one matmul yields q, k and v (3 * d_model features).
        self.w_qkv = Linear(
            self.config.d_model * 3,
            kernel_init=jax.nn.initializers.normal(),
            use_bias=self.config.use_bias,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision)
        # Output projection back to the model width.
        self.wo = Linear(
            self.config.d_model,
            kernel_init=jax.nn.initializers.normal(),
            use_bias=self.config.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )
        # NOTE(review): `attention_performer` is constructed here but `__call__`
        # below computes attention manually and never calls it.
        # NOTE(review): `self.head_dim` is not assigned anywhere in this class;
        # presumably it is provided by BaseJAXAttentionModule -- TODO confirm.
        # NOTE(review): `sm_scale` is derived from the number of heads, whereas
        # `__call__` scales scores by rsqrt(head_dim) -- confirm which is intended.
        self.attention_performer = AttentionModule(
            use_sharding_constraint=self.config.use_sharding_constraint,
            block_k_major=self.config.block_k_major,
            block_b=self.config.block_b,
            block_q=self.config.block_q,
            block_k=self.config.block_k,
            block_q_major_dkv=self.config.block_q_major_dkv,
            block_k_major_dkv=self.config.block_k_major_dkv,
            block_k_major_dq=self.config.block_k_major_dq,
            block_k_dkv=self.config.block_k_dkv,
            block_q_dkv=self.config.block_q_dkv,
            block_q_dq=self.config.block_q_dq,
            block_k_dq=self.config.block_k_dq,
            num_attention_heads=self.config.num_attention_heads,
            attention_dropout=self.config.attention_dropout,
            head_dims=self.head_dim,
            attention_partition_spec=self.config.attention_partition_spec,
            shard_attention_computation=self.config.shard_attention_computation,
            precision=self.precision,
            force_float32_tpu=True,
            attn_mechanism=self.config.attn_mechanism,
            dtype=self.dtype,
            bias_partition_spec=self.config.bias_partition_spec,
            key_partition_spec=self.config.key_partition_spec,
            query_partition_spec=self.config.query_partition_spec,
            generation_query_partition_spec=self.config.generation_query_partition_spec,
            generation_bias_partition_spec=self.config.generation_bias_partition_spec,
            generation_attention_partition_spec=self.config.generation_attention_partition_spec,
            value_partition_spec=self.config.value_partition_spec,
            scan_ring_attention=self.config.scan_ring_attention,
            mesh=self.config.jax_mesh(),
            sm_scale=1 / math.sqrt(self.config.n_heads),
            axis_name=self.config.attention_axis_name,
            backward_pass_impl=self.config.flash_attention_backward_pass_impl
        )
        if self.config.qk_ln:
            self.q_ln = nn.LayerNorm(use_bias=self.config.use_norm_bias)
            self.k_ln = nn.LayerNorm(use_bias=self.config.use_norm_bias)
        # Boolean causal mask covering the maximum sequence length; it is
        # sliced down to the actual score shape in `__call__`.
        self.causal_mask = flax.linen.make_causal_mask(
            jnp.ones(
                (1, self.config.max_seq_len),
                dtype="bool"
            ), dtype="bool"
        )

    def __call__(self,
                 hidden_states: chex.Array,
                 attention_mask: chex.Array,
                 position_ids: chex.Array,
                 attn_bias: chex.Array = None,
                 init_cache: bool = False
                 ):
        """
        Apply causal multi-head self-attention to `hidden_states`.

        The input is projected to q/k/v, split into `config.n_heads` heads, scored with
        scaled dot-product attention, masked, soft-maxed and projected back through `wo`.
        No attention dropout is applied by this implementation.

        :param self: Access variables that belong to the class
        :param hidden_states: chex.Array: Rank-3 input (unpacked as batch, sequence, features)
        :param attention_mask: chex.Array: Mask with 1 for positions to keep; reshaped
            to (batch, 1, 1, -1) before use
        :param position_ids: chex.Array: Accepted for interface compatibility; not used here
        :param attn_bias: chex.Array: Optional additive bias for the attention scores,
            truncated on its last axis to the key length
        :param init_cache: bool: Route keys/values through `_concatenate_to_cache`
            even when no cache variable exists yet
        :return: The output of the attention layer, with the same shape as `hidden_states`

        """
        inp_shape = hidden_states.shape
        # Unpacking also enforces that the input is rank 3.
        batch_size, _, _ = inp_shape

        qkv = self.w_qkv(hidden_states)
        q, k, v = jnp.split(qkv, 3, -1)
        if self.config.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)

        # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
        q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_heads)
        k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_heads)
        v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_heads)

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size, 1, 1, -1)
        if self.has_variable('cache', 'key_states') or init_cache:
            k, v, attention_mask = self._concatenate_to_cache(key_states=k, value=v, query=q,
                                                              attention_mask=attention_mask)
        # TODO: MPT WONT WORK CAUSE OF NEW ATTENTION MEC ON FJFORMER

        # Scaled dot-product scores with layout (batch, heads, q_len, k_len).
        head_dim = q.shape[-1]
        attn_output = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision) * jax.lax.rsqrt(
            jnp.asarray(head_dim).astype(v.dtype))
        # NOTE(review): axis 1 of the score tensor is the heads axis (see the
        # einsum output above), yet it is the one compared against 1 and
        # sharded over "sp" -- confirm this is intended.
        attn_output = with_sharding_constraint(attn_output, PartitionSpec(
            ("dp", "fsdp"),
            "sp" if attn_output.shape[1] != 1 else None,
            None,
            None)
                                               )
        if attn_bias is not None:
            attn_output = attn_output + attn_bias[:, :, :, :attn_output.shape[-1]]

        # Most negative finite value of the score dtype; `finfo` is given the
        # dtype explicitly rather than the array itself.
        min_value = jnp.finfo(attn_output.dtype).min
        causal_bias = jnp.where(self.causal_mask == 1, 0, min_value)
        if attention_mask is not None:
            attn_output = attn_output + jnp.where(attention_mask == 1, 0, min_value)
        # NOTE(review): the causal slice always starts at row 0; with a cache and
        # a short query this selects the first rows of the mask -- verify that
        # this is correct for incremental decoding.
        attn_output = attn_output + causal_bias[:, :, :attn_output.shape[-2], :attn_output.shape[-1]]

        attn_output = nn.softmax(attn_output, -1)
        # Weighted sum of values, back to (batch, q_len, heads, head_dim).
        attn_output = jnp.einsum('...hqk,...khd->...qhd', attn_output, v)
        return self.wo(attn_output.reshape(inp_shape))

__call__(hidden_states, attention_mask, position_ids, attn_bias=None, init_cache=False)

The `__call__` function is the main function of a JAX module. It takes in inputs and returns outputs, just like any other Python function. The difference is that `__call__` can also take in state (e.g., parameters) from the module itself, and it can update that state as part of its computation.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
hidden_states Array

chex.Array: Pass the input to the attention layer

required
attention_mask Array

chex.Array: Mask out certain positions in the sequence

required
position_ids Array

chex.Array: Specify the position of each token in the sequence

required
attn_bias Array

chex.Array: Add a bias to the attention scores

None
init_cache bool

bool: Initialize the cache

False

Returns:

Type Description

The output of the attention layer

Source code in src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def __call__(self,
             hidden_states: chex.Array,
             attention_mask: chex.Array,
             position_ids: chex.Array,
             attn_bias: chex.Array = None,
             init_cache: bool = False
             ):
    """
    Apply causal multi-head self-attention to `hidden_states`.

    The input is projected to q/k/v, split into `config.n_heads` heads, scored with
    scaled dot-product attention, masked, soft-maxed and projected back through `wo`.
    No attention dropout is applied by this implementation.

    :param self: Access variables that belong to the class
    :param hidden_states: chex.Array: Rank-3 input (unpacked as batch, sequence, features)
    :param attention_mask: chex.Array: Mask with 1 for positions to keep; reshaped
        to (batch, 1, 1, -1) before use
    :param position_ids: chex.Array: Accepted for interface compatibility; not used here
    :param attn_bias: chex.Array: Optional additive bias for the attention scores,
        truncated on its last axis to the key length
    :param init_cache: bool: Route keys/values through `_concatenate_to_cache`
        even when no cache variable exists yet
    :return: The output of the attention layer, with the same shape as `hidden_states`

    """
    inp_shape = hidden_states.shape
    # Unpacking also enforces that the input is rank 3.
    batch_size, _, _ = inp_shape

    qkv = self.w_qkv(hidden_states)
    q, k, v = jnp.split(qkv, 3, -1)
    if self.config.qk_ln:
        q = self.q_ln(q)
        k = self.k_ln(k)

    # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
    q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_heads)
    k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_heads)
    v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_heads)

    if attention_mask is not None:
        attention_mask = attention_mask.reshape(batch_size, 1, 1, -1)
    if self.has_variable('cache', 'key_states') or init_cache:
        k, v, attention_mask = self._concatenate_to_cache(key_states=k, value=v, query=q,
                                                          attention_mask=attention_mask)
    # TODO: MPT WONT WORK CAUSE OF NEW ATTENTION MEC ON FJFORMER

    # Scaled dot-product scores with layout (batch, heads, q_len, k_len).
    head_dim = q.shape[-1]
    attn_output = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision) * jax.lax.rsqrt(
        jnp.asarray(head_dim).astype(v.dtype))
    # NOTE(review): axis 1 of the score tensor is the heads axis (see the
    # einsum output above), yet it is the one compared against 1 and
    # sharded over "sp" -- confirm this is intended.
    attn_output = with_sharding_constraint(attn_output, PartitionSpec(
        ("dp", "fsdp"),
        "sp" if attn_output.shape[1] != 1 else None,
        None,
        None)
                                           )
    if attn_bias is not None:
        attn_output = attn_output + attn_bias[:, :, :, :attn_output.shape[-1]]

    # Most negative finite value of the score dtype; `finfo` is given the
    # dtype explicitly rather than the array itself.
    min_value = jnp.finfo(attn_output.dtype).min
    causal_bias = jnp.where(self.causal_mask == 1, 0, min_value)
    if attention_mask is not None:
        attn_output = attn_output + jnp.where(attention_mask == 1, 0, min_value)
    # NOTE(review): the causal slice always starts at row 0; with a cache and
    # a short query this selects the first rows of the mask -- verify that
    # this is correct for incremental decoding.
    attn_output = attn_output + causal_bias[:, :, :attn_output.shape[-2], :attn_output.shape[-1]]

    attn_output = nn.softmax(attn_output, -1)
    # Weighted sum of values, back to (batch, q_len, heads, head_dim).
    attn_output = jnp.einsum('...hqk,...khd->...qhd', attn_output, v)
    return self.wo(attn_output.reshape(inp_shape))