Skip to content

modules.stablelm.modelling_stablelm_flax

FlaxStableLmAttention

Bases: BaseJAXAttentionModule

Source code in src/python/easydel/modules/stablelm/modelling_stablelm_flax.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
class FlaxStableLmAttention(BaseJAXAttentionModule):
    """
    Self-attention block for StableLM with partial rotary position embeddings.

    Hidden states are projected to query/key/value, the first ``rotary_emb_dim`` features of every
    query and key head are rotated, key/value heads are repeated ``num_key_value_groups`` times,
    and the attention itself is delegated to ``AttentionModule``.
    """
    config: StableLmConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):
        """Read the head layout from the config and build the projections and the attention backend."""
        config: StableLmConfig = self.config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor

        if self.num_key_value_groups == 1:
            assert self.config.num_attention_heads == self.config.num_key_value_heads
        self.q_proj = Linear(
            config.num_attention_heads * self.head_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.use_qkv_bias,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )
        self.k_proj = Linear(
            config.num_key_value_heads * self.head_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.use_qkv_bias,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )
        self.v_proj = Linear(
            config.num_key_value_heads * self.head_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.use_qkv_bias,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )
        # The output projection never carries a bias, unlike q/k/v which follow `use_qkv_bias`.
        self.o_proj = Linear(
            config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )

        # Only this many leading features of each head are rotated; the rest pass through unchanged.
        self.rotary_emb_dim = int(self.config.partial_rotary_factor * self.head_dim)
        self.attention_performer = AttentionModule(
            use_sharding_constraint=self.config.use_sharding_constraint,
            block_k_major=self.config.block_k_major,
            block_b=self.config.block_b,
            block_q=self.config.block_q,
            block_k=self.config.block_k,
            block_q_major_dkv=self.config.block_q_major_dkv,
            block_k_major_dkv=self.config.block_k_major_dkv,
            block_k_major_dq=self.config.block_k_major_dq,
            block_k_dkv=self.config.block_k_dkv,
            block_q_dkv=self.config.block_q_dkv,
            block_q_dq=self.config.block_q_dq,
            block_k_dq=self.config.block_k_dq,
            num_attention_heads=self.config.num_attention_heads,
            attention_dropout=self.config.attention_dropout,
            head_dims=self.head_dim,
            attention_partition_spec=self.config.attention_partition_spec,
            shard_attention_computation=self.config.shard_attention_computation,
            precision=self.precision,
            force_float32_tpu=True,
            attn_mechanism=self.config.attn_mechanism,
            dtype=self.dtype,
            bias_partition_spec=self.config.bias_partition_spec,
            key_partition_spec=self.config.key_partition_spec,
            query_partition_spec=self.config.query_partition_spec,
            generation_query_partition_spec=self.config.generation_query_partition_spec,
            generation_bias_partition_spec=self.config.generation_bias_partition_spec,
            generation_attention_partition_spec=self.config.generation_attention_partition_spec,
            value_partition_spec=self.config.value_partition_spec,
            scan_ring_attention=self.config.scan_ring_attention,
            mesh=self.config.jax_mesh(),
            sm_scale=1 / math.sqrt(self.head_dim),
            axis_name=self.config.attention_axis_name,
            backward_pass_impl=self.config.flash_attention_backward_pass_impl
        )

    def _merge_heads(self, hidden_states):
        """Collapse every axis after the first two into a single axis of size ``hidden_size``."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @staticmethod
    def _transpose_sequence_head(query, key, value):
        """
        Swap axes 1 and 2 of the query, key and value tensors.

        :param query: 4-D query tensor
        :param key: 4-D key tensor
        :param value: 4-D value tensor
        :return: The three tensors with axes 1 and 2 exchanged, in the same order

        """
        return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3))

    def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
        """
        Split the projections into heads, rotate part of every query/key head and repeat the key/value heads.

        :param self: Access variables that belong to the class
        :param batch_size: Leading dimension used when reshaping query, key and value
        :param sequence_length: Second dimension used when reshaping query, key and value
        :param query: Query projection, reshaped to (batch, seq, num_attention_heads, head_dim)
        :param key: Key projection, reshaped to (batch, seq, num_key_value_heads, head_dim)
        :param value: Value projection, reshaped to (batch, seq, num_key_value_heads, head_dim)
        :param freq_cis: Tuple ``(sin, cos)`` of rotary tables that are indexed with position_ids
        :param position_ids: Indices used to select rows of the sin/cos tables
            (assumed to be (batch, seq) - TODO confirm against the caller)
        :return: A tuple (query, key, value) laid out as (batch, seq, heads, head_dim)

        """
        query = query.reshape(
            batch_size,
            sequence_length,
            self.config.num_attention_heads,
            self.head_dim
        )
        key = key.reshape(
            batch_size,
            sequence_length,
            self.config.num_key_value_heads,
            self.head_dim
        )
        value = value.reshape(
            batch_size,
            sequence_length,
            self.config.num_key_value_heads,
            self.head_dim
        )

        # Work in (batch, heads, seq, head_dim) while rotating and repeating.
        query, key, value = self._transpose_sequence_head(query, key, value)

        sin, cos = freq_cis

        # Insert a singleton axis at position 1 so the tables line up with the head axis.
        sin = sin[position_ids][:, None, :, :]
        cos = cos[position_ids][:, None, :, :]

        query_rot, query_pass = (
            query[..., : self.rotary_emb_dim],
            query[..., self.rotary_emb_dim:],
        )
        key_rot, key_pass = (
            key[..., : self.rotary_emb_dim],
            key[..., self.rotary_emb_dim:],
        )

        key_rot = apply_rotary_pos_emb(key_rot, sin, cos)
        query_rot = apply_rotary_pos_emb(query_rot, sin, cos)

        query = jnp.concatenate((query_rot, query_pass), axis=-1)
        key = jnp.concatenate((key_rot, key_pass), axis=-1)

        key = repeat_kv_bnsh(key, self.num_key_value_groups)
        value = repeat_kv_bnsh(value, self.num_key_value_groups)
        # Back to (batch, seq, heads, head_dim) for the caller.
        return self._transpose_sequence_head(query, key, value)

    def __call__(
            self,
            hidden_states: chex.Array,
            freq_cis: Tuple[chex.Array, chex.Array],
            attention_mask: chex.Array,
            position_ids: chex.Array,
            causal_mask: chex.Array,
            segment_ids: Optional[chex.Array] = None,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = False,
            fcm_mask=None,
    ):
        """
        Run the attention forward pass.

        :param self: Access variables that belong to the class
        :param hidden_states: chex.Array: Input whose first two axes are (batch, sequence)
        :param freq_cis: Tuple[chex.Array, chex.Array]: ``(sin, cos)`` rotary tables passed to apply_rotary
        :param attention_mask: chex.Array: Padding mask; two axes are inserted at (-3, -2) and it is
            broadcast to the causal mask shape (assumed to be (batch, key_length) - TODO confirm)
        :param position_ids: chex.Array: Position indices used to look up the rotary tables
        :param causal_mask: chex.Array: 4-D causal mask that is sliced to the current query/key lengths
        :param segment_ids: Optional[chex.Array]: Forwarded unchanged to the attention backend
        :param deterministic: bool: When False and ``attention_dropout > 0`` a "dropout" RNG is drawn
        :param init_cache: bool: Route key/value states through ``_concatenate_to_cache``
        :param output_attentions: bool: Also return the attention weights
        :param fcm_mask: Extra mask merged with the padding and causal masks through combine_masks
        :return: ``(attn_output,)`` or ``(attn_output, attention_weights)`` when output_attentions is True

        """
        batch_size, sequence_length = hidden_states.shape[:2]
        query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(
            hidden_states)

        query_states = query_states.reshape(
            batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
        key_states = key_states.reshape(
            batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
        value_states = value_states.reshape(
            batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)

        query_states, key_states, value_states = self.apply_rotary(
            query=query_states,
            key=key_states,
            value=value_states,
            position_ids=position_ids,
            freq_cis=freq_cis,
            batch_size=batch_size,
            sequence_length=sequence_length
        )

        # After the key/value repeat every tensor must expose num_attention_heads heads.
        # The KVH field reports the key/value head count (it previously repeated NH by mistake).
        assert_msg = (
            "num_attention_heads repeat wont work likely\n"
            f"INFO :\n\trepeat_kv_bnsh Used with num_key_value_groups = {self.num_key_value_groups}\n\t"
            f"NH : {self.config.num_attention_heads} KVH : {self.config.num_key_value_heads}"
        )

        assert query_states.shape[-2] == self.config.num_attention_heads, assert_msg
        assert key_states.shape[-2] == self.config.num_attention_heads, assert_msg
        assert value_states.shape[-2] == self.config.num_attention_heads, assert_msg

        query_length, key_length = query_states.shape[1], key_states.shape[1]

        if self.has_variable("cache", "cached_key"):
            # Decoding with a cache: take the mask rows starting at the current cache index and
            # keep as many key columns as the cache can hold.
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                causal_mask,
                (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length)
            )
        else:
            causal_mask = causal_mask[:, :, :query_length, :key_length]

        causal_mask = jnp.broadcast_to(
            causal_mask, (batch_size,) + causal_mask.shape[1:])
        attention_mask = jnp.broadcast_to(jnp.expand_dims(
            attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
        # NOTE(review): the mask was just broadcast to the 4-D causal mask shape, so this branch
        # looks unreachable unless combine_masks changes the rank - confirm before removing.
        if attention_mask.ndim == 2:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = None

        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.has_variable("cache", "cached_key") or init_cache:
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states,
                value_states,
                query_states,
                attention_mask
            )

        # Turn the boolean-like mask into an additive bias: 0 where attention is allowed,
        # the most negative finite value of `dtype` where it is masked.
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(
                self.dtype).min).astype(self.dtype),
        )

        # Recomputed because the cache path may have replaced key_states.
        query_length, key_length = query_states.shape[1], key_states.shape[1]

        attentions = self.attention_performer.__call__(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            bias=attention_bias,
            attention_mask=attention_mask,
            causal=True,
            dropout_rng=dropout_rng,
            deterministic=deterministic,
            query_sequence_length=query_length,
            key_value_sequence_length=key_length,
            uses_cache=self.has_variable("cache", "cached_key") or init_cache,
            segment_ids=segment_ids,
            causal_mask=causal_mask
        )

        attn_output = self._merge_heads(attentions.attention_outputs)
        if self.config.shard_attention_computation:
            attn_output = with_sharding_constraint(
                attn_output, PartitionSpec(
                    ("dp", "fsdp"),
                    "sp" if attn_output.shape[1] != 1 else None,
                    "tp"
                )
            )
        attn_output = self.o_proj(attn_output)
        outputs = (attn_output, attentions.attention_weights) if output_attentions else (attn_output,)
        return outputs

__call__(hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids=None, deterministic=True, init_cache=False, output_attentions=False, fcm_mask=None)

The `__call__` method is the main entry point of a JAX module. It defines how the module behaves when called with inputs. `__call__` can be thought of as a "forward pass" through the model, and it should return all outputs that are needed for training or inference.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
hidden_states Array

chex.Array: Pass the hidden states of the previous layer

required
freq_cis Tuple[Array, Array]

Tuple[chex.Array, chex.Array],: Pass in the frequency coefficients for each position

required
attention_mask Array

chex.Array: Mask out certain tokens in the input sequence

required
position_ids Array

chex.Array: Determine the position of each token in a sequence

required
causal_mask Array

chex.Array: Mask out the future tokens in the decoder

required
deterministic bool

bool: Determine whether to use dropout or not

True
init_cache bool

bool: Initialize the cache

False
output_attentions bool

bool: Determine whether to return the attention weights or not

False
fcm_mask

Mask out the attention weights between the input and output tokens

None

Determine if the attention is causal or not

required

Returns:

Type Description

A tuple `(attn_output,)`, or `(attn_output, attention_weights)` when `output_attentions` is True

Source code in src/python/easydel/modules/stablelm/modelling_stablelm_flax.py
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
def __call__(
        self,
        hidden_states: chex.Array,
        freq_cis: Tuple[chex.Array, chex.Array],
        attention_mask: chex.Array,
        position_ids: chex.Array,
        causal_mask: chex.Array,
        segment_ids: Optional[chex.Array] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        fcm_mask=None,
):
    """
    Run the attention forward pass.

    :param self: Access variables that belong to the class
    :param hidden_states: chex.Array: Input whose first two axes are (batch, sequence)
    :param freq_cis: Tuple[chex.Array, chex.Array]: Rotary tables passed through to apply_rotary
    :param attention_mask: chex.Array: Padding mask; two axes are inserted at (-3, -2) and it is
        broadcast to the causal mask shape (assumed to be (batch, key_length) - TODO confirm)
    :param position_ids: chex.Array: Position indices passed through to apply_rotary
    :param causal_mask: chex.Array: 4-D causal mask that is sliced to the current query/key lengths
    :param segment_ids: Optional[chex.Array]: Forwarded unchanged to the attention backend
    :param deterministic: bool: When False and ``attention_dropout > 0`` a "dropout" RNG is drawn
    :param init_cache: bool: Route key/value states through ``_concatenate_to_cache``
    :param output_attentions: bool: Also return the attention weights
    :param fcm_mask: Extra mask merged with the padding and causal masks through combine_masks
    :return: ``(attn_output,)`` or ``(attn_output, attention_weights)`` when output_attentions is True

    """
    batch_size, sequence_length = hidden_states.shape[:2]
    query_states, key_states, value_states = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(
        hidden_states)

    query_states = query_states.reshape(
        batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
    key_states = key_states.reshape(
        batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
    value_states = value_states.reshape(
        batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)

    query_states, key_states, value_states = self.apply_rotary(
        query=query_states,
        key=key_states,
        value=value_states,
        position_ids=position_ids,
        freq_cis=freq_cis,
        batch_size=batch_size,
        sequence_length=sequence_length
    )

    # After apply_rotary every tensor must expose num_attention_heads heads.
    # The KVH field reports the key/value head count (it previously repeated NH by mistake).
    assert_msg = (
        "num_attention_heads repeat wont work likely\n"
        f"INFO :\n\trepeat_kv_bnsh Used with num_key_value_groups = {self.num_key_value_groups}\n\t"
        f"NH : {self.config.num_attention_heads} KVH : {self.config.num_key_value_heads}"
    )

    assert query_states.shape[-2] == self.config.num_attention_heads, assert_msg
    assert key_states.shape[-2] == self.config.num_attention_heads, assert_msg
    assert value_states.shape[-2] == self.config.num_attention_heads, assert_msg

    query_length, key_length = query_states.shape[1], key_states.shape[1]

    if self.has_variable("cache", "cached_key"):
        # Decoding with a cache: take the mask rows starting at the current cache index and
        # keep as many key columns as the cache can hold.
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            causal_mask,
            (0, 0, mask_shift, 0),
            (1, 1, query_length, max_decoder_length)
        )
    else:
        causal_mask = causal_mask[:, :, :query_length, :key_length]

    causal_mask = jnp.broadcast_to(
        causal_mask, (batch_size,) + causal_mask.shape[1:])
    attention_mask = jnp.broadcast_to(jnp.expand_dims(
        attention_mask, axis=(-3, -2)), causal_mask.shape)
    attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
    # NOTE(review): the mask was just broadcast to the 4-D causal mask shape, so this branch
    # looks unreachable unless combine_masks changes the rank - confirm before removing.
    if attention_mask.ndim == 2:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    dropout_rng = None

    if not deterministic and self.config.attention_dropout > 0.0:
        dropout_rng = self.make_rng("dropout")

    if self.has_variable("cache", "cached_key") or init_cache:
        key_states, value_states, attention_mask = self._concatenate_to_cache(
            key_states,
            value_states,
            query_states,
            attention_mask
        )

    # Turn the boolean-like mask into an additive bias: 0 where attention is allowed,
    # the most negative finite value of `dtype` where it is masked.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, jnp.finfo(
            self.dtype).min).astype(self.dtype),
    )

    # Recomputed because the cache path may have replaced key_states.
    query_length, key_length = query_states.shape[1], key_states.shape[1]

    attentions = self.attention_performer.__call__(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        bias=attention_bias,
        attention_mask=attention_mask,
        causal=True,
        dropout_rng=dropout_rng,
        deterministic=deterministic,
        query_sequence_length=query_length,
        key_value_sequence_length=key_length,
        uses_cache=self.has_variable("cache", "cached_key") or init_cache,
        segment_ids=segment_ids,
        causal_mask=causal_mask
    )

    attn_output = self._merge_heads(attentions.attention_outputs)
    if self.config.shard_attention_computation:
        attn_output = with_sharding_constraint(
            attn_output, PartitionSpec(
                ("dp", "fsdp"),
                "sp" if attn_output.shape[1] != 1 else None,
                "tp"
            )
        )
    attn_output = self.o_proj(attn_output)
    outputs = (attn_output, attentions.attention_weights) if output_attentions else (attn_output,)
    return outputs

apply_rotary(batch_size, sequence_length, query, key, value, freq_cis, position_ids)

The apply_rotary function reshapes the query, key and value projections into per-head tensors, applies rotary position embeddings (looked up from freq_cis with position_ids) to the first rotary_emb_dim features of every query and key head, and repeats the key and value heads num_key_value_groups times. The value tensor is reshaped and repeated but never rotated.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
batch_size

Reshape the query_states, key and value tensors

required
sequence_length

Reshape the query_states, key and value tensors

required
query

Calculate the attention weights

required
key

Calculate the attention

required
value

Compute the attention weights

required
freq_cis

Tuple (sin, cos) of rotary embedding tables that are indexed with position_ids

required
position_ids

Identify the position of each token in the sequence

required

Returns:

Type Description

A tuple of 3 tensors: query_states, key and value

Source code in src/python/easydel/modules/stablelm/modelling_stablelm_flax.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
    """
    Split the projections into heads, rotate part of every query/key head and repeat the key/value heads.

    :param self: Access variables that belong to the class
    :param batch_size: Leading dimension used when reshaping query, key and value
    :param sequence_length: Second dimension used when reshaping query, key and value
    :param query: Query projection, reshaped to (batch, seq, num_attention_heads, head_dim)
    :param key: Key projection, reshaped to (batch, seq, num_key_value_heads, head_dim)
    :param value: Value projection, reshaped to (batch, seq, num_key_value_heads, head_dim)
    :param freq_cis: Tuple ``(sin, cos)`` of rotary tables that are indexed with position_ids
    :param position_ids: Indices used to select rows of the sin/cos tables
        (assumed to be (batch, seq) - TODO confirm against the caller)
    :return: A tuple (query, key, value) laid out as (batch, seq, heads, head_dim)

    """
    kv_shape = (batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
    query = query.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
    key = key.reshape(kv_shape)
    value = value.reshape(kv_shape)

    # Work in (batch, heads, seq, head_dim) while rotating and repeating.
    query, key, value = self._transpose_sequence_head(query, key, value)

    sin_table, cos_table = freq_cis
    # Insert a singleton axis at position 1 so the tables line up with the head axis.
    sin = sin_table[position_ids][:, None, :, :]
    cos = cos_table[position_ids][:, None, :, :]

    # Only the first `rot` features of each head are rotated; the tail passes through unchanged.
    rot = self.rotary_emb_dim
    rotated_key = apply_rotary_pos_emb(key[..., :rot], sin, cos)
    rotated_query = apply_rotary_pos_emb(query[..., :rot], sin, cos)

    query = jnp.concatenate((rotated_query, query[..., rot:]), axis=-1)
    key = jnp.concatenate((rotated_key, key[..., rot:]), axis=-1)

    groups = self.num_key_value_groups
    key = repeat_kv_bnsh(key, groups)
    value = repeat_kv_bnsh(value, groups)

    # Back to (batch, seq, heads, head_dim) for the caller.
    return self._transpose_sequence_head(query, key, value)

FlaxStableLmMLP

Bases: Module

Source code in src/python/easydel/modules/stablelm/modelling_stablelm_flax.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
class FlaxStableLmMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
    config: StableLmConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        """Create the three bias-free projections and look up the activation function."""
        cfg = self.config

        def bias_free_linear(features):
            # All three projections share every option except their output width.
            return Linear(
                features,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                use_bias=False,
                kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
                precision=self.precision,
                **get_dot_general_by_bits(cfg.bits, cfg.easy_method)
            )

        self.gate_proj = bias_free_linear(cfg.intermediate_size)
        self.down_proj = bias_free_linear(cfg.hidden_size)
        self.up_proj = bias_free_linear(cfg.intermediate_size)
        self.act_fn = ACT2FN[cfg.hidden_act]

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """
        Apply the gated feed-forward transformation to x.

        :param self: Represent the instance of the class
        :param x: jnp.ndarray: Input passed to both gate_proj and up_proj
        :param deterministic: bool: Accepted for interface compatibility; not used by this method
        :return: down_proj applied to the product of the activated gate and the up projection

        """
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

__call__(x, deterministic=True)

The `__call__` method is the main function of a class. It is called when an instance of the class (an object) is invoked as a function, i.e., obj(arguments). The `__call__` method enables instances of a class to be called like standard Python functions.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
x ndarray

jnp.ndarray: Pass in the input to the layer

required
deterministic bool

bool: Determine whether to use dropout # Ignored

True

Returns:

Type Description
ndarray

A tensor that is the result of function to x

Source code in src/python/easydel/modules/stablelm/modelling_stablelm_flax.py
111
112
113
114
115
116
117
118
119
120
121
122
123
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
    """
    Run the gated feed-forward computation on the input.

    The input goes through the gate projection followed by the activation,
    and separately through the up projection; the two results are multiplied
    and the product is fed to the down projection.

    :param self: The module instance holding the projections and activation
    :param x: jnp.ndarray: Input given to both the gate and the up projection
    :param deterministic: bool: Accepted for interface compatibility; not used by this method
    :return: The output of the down projection

    """
    gate_out = self.gate_proj(x)
    activated = self.act_fn(gate_out)
    up_out = self.up_proj(x)
    gated = activated * up_out
    return self.down_proj(gated)

FlaxStableLmPreTrainedModel

Bases: EasyDeLFlaxPretrainedModel

StableLm pre-trained model.

Source code in src/python/easydel/modules/stablelm/modelling_stablelm_flax.py
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
class FlaxStableLmPreTrainedModel(EasyDeLFlaxPretrainedModel):
    """StableLm pre-trained model.

    Builds the underlying Flax module from ``module_class`` and exposes helpers
    to initialise parameters, initialise the decoding cache, and run the module.
    """
    # Instantiated in ``__init__``; it is ``None`` here, so it has to be
    # overridden with a module class before this wrapper can be constructed.
    module_class = None
    config_class = StableLmConfig
    base_model_prefix = "model"

    def __init__(
            self,
            config: StableLmConfig,
            dtype: jnp.dtype = jnp.float32,
            param_dtype: jnp.dtype = jnp.float32,
            precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"),
            input_shape=(1, 1),
            seed: int = 42,
            _do_init: bool = False
    ) -> None:
        """
        Create the Flax module and hand it to the base class.

        :param config: StableLmConfig: Model configuration forwarded to the module
        :param dtype: jnp.dtype: Forwarded to the module as ``dtype``
        :param param_dtype: jnp.dtype: Forwarded to the module as ``param_dtype``
        :param precision: Optional[jax.lax.Precision]: Forwarded to the module as ``precision``
        :param input_shape: Forwarded to the base class initialiser
        :param seed: int: Forwarded to the base class initialiser
        :param _do_init: bool: Forwarded to the base class initialiser
        """
        module = self.module_class(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision
        )
        super().__init__(
            config=config,
            module=module,
            input_shape=input_shape,
            _do_init=_do_init,
            seed=seed
        )

    def init_cache(self, batch_size, max_length):
        """
        Initialise the module with ``init_cache=True`` and return its cache.

        :param batch_size: First dimension of the dummy inputs
        :param max_length: Second dimension of the dummy inputs
        :return: The ``"cache"`` collection produced by ``module.init``
        """
        # Integer token ids, matching the dummy inputs built in init_weights.
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
            input_ids.shape
        )

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialise module parameters, optionally filling gaps in given ones.

        :param rng: jax.random.PRNGKey: Key split into the "params" and "dropout" streams
        :param input_shape: Tuple: Shape of the dummy ``input_ids`` used for initialisation
        :param params: FrozenDict: Existing parameters; entries listed in
            ``self._missing_keys`` are copied in from the fresh initialisation
        :return: The fresh parameters when ``params`` is None, otherwise the
            completed (frozen) ``params``
        """
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(rngs, input_ids, attention_mask)
        random_params = module_init_outputs["params"]

        if params is None:
            return random_params

        # Work on flattened copies so missing entries can be filled by key path.
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: chex.Array = None,
            position_ids: chex.Array = None,
            params: dict = None,
            past_key_values: dict = None,
            dropout_rng: jax.random.PRNGKey = None,
            train: bool = False,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = True,
            extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
            add_params_field: bool = False,
            **kwargs
    ):
        """
        Apply the underlying module to ``input_ids``.

        :param input_ids: chex.Array: Two-dimensional array unpacked as (batch_size, sequence_length)
        :param attention_mask: chex.Array: Defaults to an all-ones (batch_size, sequence_length) array
        :param position_ids: chex.Array: Forwarded to the module unchanged
        :param params: dict: Variables to apply with; falls back to ``self.params`` when empty or None
        :param past_key_values: dict: Cache collection; when non-empty the module is applied
            with the "cache" collection mutable and the updated cache is returned
        :param dropout_rng: jax.random.PRNGKey: Used as the "dropout" rng when given
        :param train: bool: The module receives ``deterministic=not train``
        :param output_attentions: Optional[bool]: Falls back to ``config.output_attentions`` when None
        :param output_hidden_states: Optional[bool]: Falls back to ``config.output_hidden_states`` when None
        :param return_dict: Optional[bool]: Falls back to ``config.return_dict`` when None
        :param extra_embedding: Optional[jnp.ndarray]: Forwarded to the module unchanged
        :param add_params_field: bool: Wrap the variables as ``{"params": ...}`` before applying
        :param kwargs: Accepted and ignored
        :return: The module outputs; with a cache, either the output mapping with a
            "past_key_values" entry or a tuple with the cache inserted at index 1
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        assert sequence_length <= self.config.max_position_embeddings, "Maximum Position Embedding Reached !"

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        if self.config.bits is not None:
            rngs['params'] = jax.random.key(0)

        variables = params or self.params
        inputs = {"params": variables} if add_params_field else variables

        # One flag drives both the mutable setting and the unpacking below, so an
        # empty cache mapping is consistently treated as "no cache".
        use_cache = bool(past_key_values)
        if use_cache:
            # Build a new mapping rather than assigning into `inputs`: it may be the
            # caller's (or the model's own) parameter dict, which must not keep a
            # stale "cache" entry after this call.
            inputs = {**inputs, "cache": past_key_values}
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            input_ids=input_ids,
            inputs_embeds=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            extra_embedding=extra_embedding,
            deterministic=not train,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            init_cache=False,
            return_dict=return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        if not use_cache:
            return outputs

        # With a mutable collection, apply returns (outputs, updated_variables).
        outputs, updated_variables = outputs
        updated_cache = unfreeze(updated_variables["cache"])
        if return_dict:
            outputs["past_key_values"] = updated_cache
            return outputs
        return outputs[:1] + (updated_cache,) + outputs[1:]