Skip to content

modules.qwen1.modelling_qwen1_flax

FlaxQwen1Attention

Bases: BaseJAXAttentionModule

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
class FlaxQwen1Attention(BaseJAXAttentionModule):
    """Qwen1 multi-head self-attention with a fused QKV projection and rotary embeddings.

    ``c_attn`` produces query, key and value with a single projection that is split
    into three equal parts along the last axis; ``c_proj`` maps the merged heads back
    to ``config.hidden_size``. The attention computation itself is delegated to
    ``AttentionModule``.
    """
    config: Qwen1Config
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):
        config = self.config

        self.hidden_size = config.hidden_size
        # NOTE(review): q/k/v are reshaped with ``head_dim`` (hidden_size // num_heads)
        # while ``c_attn`` emits ``kv_channels * num_heads`` features per tensor, so this
        # assumes kv_channels == hidden_size // num_attention_heads — confirm for custom configs.
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.projection_size = config.kv_channels * config.num_attention_heads
        assert self.projection_size % config.num_attention_heads == 0
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads

        # Fused projection: one matmul yields query, key and value side by side.
        self.c_attn = Linear(
            self.projection_size * 3,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=True,
            kernel_init=jax.nn.initializers.normal(
                config.initializer_range
            ),
            precision=self.precision,
            **get_dot_general_by_bits(config.bits, config.easy_method)
        )

        self.c_proj = Linear(
            config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=not self.config.no_bias,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range
            ),
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )
        # Log-n scaling table for positions 1..32767: 1 up to ``seq_length`` and
        # log_{seq_length}(i) beyond it. It is not read by ``__call__`` in this class.
        logn_list = [
            math.log(i, self.config.seq_length) if i > self.config.seq_length else 1
            for i in range(1, 32768)
        ]
        logn_tensor = jnp.asarray(logn_list)[None, :, None, None]
        self.logn_tensor = logn_tensor
        self.rotary = FlaxQwen1EmbeddingApplyer(self.dtype)
        self.attention_performer = AttentionModule(
            use_sharding_constraint=self.config.use_sharding_constraint,
            block_k_major=self.config.block_k_major,
            block_b=self.config.block_b,
            block_q=self.config.block_q,
            block_k=self.config.block_k,
            block_q_major_dkv=self.config.block_q_major_dkv,
            block_k_major_dkv=self.config.block_k_major_dkv,
            block_k_major_dq=self.config.block_k_major_dq,
            block_k_dkv=self.config.block_k_dkv,
            block_q_dkv=self.config.block_q_dkv,
            block_q_dq=self.config.block_q_dq,
            block_k_dq=self.config.block_k_dq,
            num_attention_heads=self.config.num_attention_heads,
            attention_dropout=self.config.attn_dropout_prob,
            head_dims=self.head_dim,
            attention_partition_spec=self.config.attention_partition_spec,
            shard_attention_computation=self.config.shard_attention_computation,
            precision=self.precision,
            force_float32_tpu=True,
            attn_mechanism=self.config.attn_mechanism,
            dtype=self.dtype,
            bias_partition_spec=self.config.bias_partition_spec,
            key_partition_spec=self.config.key_partition_spec,
            query_partition_spec=self.config.query_partition_spec,
            generation_query_partition_spec=self.config.generation_query_partition_spec,
            generation_bias_partition_spec=self.config.generation_bias_partition_spec,
            generation_attention_partition_spec=self.config.generation_attention_partition_spec,
            value_partition_spec=self.config.value_partition_spec,
            scan_ring_attention=self.config.scan_ring_attention,
            mesh=self.config.jax_mesh(),
            sm_scale=1 / math.sqrt(self.head_dim),
            axis_name=self.config.attention_axis_name,
            backward_pass_impl=self.config.flash_attention_backward_pass_impl
        )

    def _merge_heads(self, hidden_states):
        """Collapse the head axes back into one feature axis of size ``hidden_size``."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @staticmethod
    def _transpose_sequence_head(query, key, value):
        """
        Swap axes 1 and 2 of the query, key and value tensors.

        :param query: 4-D query tensor
        :param key: 4-D key tensor
        :param value: 4-D value tensor
        :return: The three tensors with axes transposed as (0, 2, 1, 3)

        """
        return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3))

    def apply_rotary(self, batch_size, sequence_length, query, key, value, rotary_pos_emb_list, position_ids):
        """
        Apply rotary position embeddings to the query and key tensors.

        :param batch_size: Batch size (not used here; kept for interface compatibility)
        :param sequence_length: Sequence length (not used here; kept for interface compatibility)
        :param query: Query tensor
        :param key: Key tensor
        :param value: Value tensor, returned unchanged
        :param rotary_pos_emb_list: Rotary embedding tables forwarded to ``self.rotary``
        :param position_ids: Position indices forwarded to ``self.rotary``
        :return: A tuple ``(query, key, value)`` with rotary embeddings applied to query and key

        """
        # ``self.rotary`` takes the query under the keyword ``query_states``.
        query, key = self.rotary(
            position_ids=position_ids,
            query_states=query,
            key=key,
            rotary_pos_emb_list=rotary_pos_emb_list,
        )
        return query, key, value

    def __call__(
            self,
            hidden_states: chex.Array,
            rotary_pos_emb_list: list[chex.Array],
            attention_mask: chex.Array,
            position_ids: chex.Array,
            causal_mask: chex.Array,
            segment_ids: Optional[chex.Array] = None,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = False,
            encoder_hidden_states: Optional[chex.Array] = None,
            encoder_attention_mask: Optional[chex.Array] = None,
            fcm_mask=None,
    ):
        """
        Run the self-attention forward pass.

        :param hidden_states: chex.Array: Input hidden states; the first two axes are read
            as (batch_size, sequence_length)
        :param rotary_pos_emb_list: list[chex.Array]: Rotary embedding tables forwarded to ``apply_rotary``
        :param attention_mask: chex.Array: Padding mask; it is expanded on axes -3 and -2 and
            broadcast to the causal mask's shape before being combined with it
        :param position_ids: chex.Array: Token positions forwarded to ``apply_rotary``
        :param causal_mask: chex.Array: 4-D causal mask, sliced to the current query/key window
        :param segment_ids: Optional[chex.Array]: Forwarded to the attention module
        :param deterministic: bool: When False and ``config.attn_dropout_prob > 0``, a
            "dropout" RNG is drawn for attention dropout
        :param init_cache: bool: Initialize / update the key-value cache
        :param output_attentions: bool: Also return the attention weights
        :param encoder_hidden_states: Optional[chex.Array]: Not used; kept for interface compatibility
        :param encoder_attention_mask: Optional[chex.Array]: Not used; kept for interface compatibility
        :param fcm_mask: Optional extra mask combined with the attention and causal masks
        :return: ``(attn_output,)``, or ``(attn_output, attention_weights)`` when
            ``output_attentions`` is True

        """
        batch_size, sequence_length = hidden_states.shape[:2]
        mixed_x_layer: chex.Array = self.c_attn(hidden_states)
        query_states, key_states, value_states = jnp.split(mixed_x_layer, 3, 2)

        query_states = query_states.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
        key_states = key_states.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
        value_states = value_states.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)

        query_states, key_states, value_states = self.apply_rotary(
            query=query_states,
            key=key_states,
            value=value_states,
            position_ids=position_ids,
            rotary_pos_emb_list=rotary_pos_emb_list,
            batch_size=batch_size,
            sequence_length=sequence_length
        )

        query_length, key_length = query_states.shape[1], key_states.shape[1]
        if self.has_variable("cache", "cached_key"):
            # During cached decoding, take the causal-mask rows starting at the cache index.
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                causal_mask, (0, 0, mask_shift, 0), (1, 1,
                                                     query_length, max_decoder_length)
            )
        else:
            causal_mask = causal_mask[:, :, :query_length, :key_length]

        causal_mask = jnp.broadcast_to(
            causal_mask, (batch_size,) + causal_mask.shape[1:])
        attention_mask = jnp.broadcast_to(jnp.expand_dims(
            attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
        if attention_mask.ndim == 2:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = None
        # Read the same config field that configures AttentionModule's dropout in setup().
        if not deterministic and self.config.attn_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.has_variable("cache", "cached_key") or init_cache:
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states,
                value_states,
                query_states,
                attention_mask
            )

        # Additive bias: 0 where attention is allowed, the dtype's minimum where masked.
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(
                self.dtype).min).astype(self.dtype),
        )

        # Recompute: key_states may have been replaced by the cache contents above.
        query_length, key_length = query_states.shape[1], key_states.shape[1]

        attentions = self.attention_performer.__call__(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            bias=attention_bias,
            attention_mask=attention_mask,
            causal=True,
            dropout_rng=dropout_rng,
            deterministic=deterministic,
            query_sequence_length=query_length,
            key_value_sequence_length=key_length,
            uses_cache=self.has_variable("cache", "cached_key") or init_cache,
            segment_ids=segment_ids,
            causal_mask=causal_mask
        )

        attn_output = self._merge_heads(attentions.attention_outputs)
        if self.config.shard_attention_computation:
            attn_output = with_sharding_constraint(
                attn_output, PartitionSpec(
                    ("dp", "fsdp"),
                    "sp" if attn_output.shape[1] != 1 else None,
                    "tp"
                )
            )

        attn_output = self.c_proj(attn_output)

        if output_attentions:
            return attn_output, attentions.attention_weights
        return (attn_output,)

__call__(hidden_states, rotary_pos_emb_list, attention_mask, position_ids, causal_mask, segment_ids=None, deterministic=True, init_cache=False, output_attentions=False, encoder_hidden_states=None, encoder_attention_mask=None, fcm_mask=None)

The `__call__` method is the main entry point of this JAX (Flax) module: it defines how the module behaves when called with inputs. It can be thought of as the attention "forward pass", and it returns the outputs needed for training or inference.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
hidden_states Array

chex.Array: Pass the hidden states of the previous layer

required
rotary_pos_emb_list list[Array]

list[chex.Array]: Pass in the frequency coefficients for each position

required
attention_mask Array

chex.Array: Mask out certain tokens in the input sequence

required
position_ids Array

chex.Array: Determine the position of each token in a sequence

required
causal_mask Array

chex.Array: Mask out the future tokens in the decoder

required
deterministic bool

bool: Determine whether to use dropout or not

True
init_cache bool

bool: Initialize the cache

False
output_attentions bool

bool: Determine whether to return the attention weights or not

False
fcm_mask

Mask out the attention weights between the input and output tokens

None

Determine if the attention is causal or not

required

Returns:

Type Description

A tuple `(attn_output,)`, or `(attn_output, attention_weights)` when `output_attentions` is True

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
def __call__(
        self,
        hidden_states: chex.Array,
        rotary_pos_emb_list: list[chex.Array],
        attention_mask: chex.Array,
        position_ids: chex.Array,
        causal_mask: chex.Array,
        segment_ids: Optional[chex.Array] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        encoder_hidden_states: Optional[chex.Array] = None,
        encoder_attention_mask: Optional[chex.Array] = None,
        fcm_mask=None,
):
    """
    Run the self-attention forward pass.

    :param hidden_states: chex.Array: Input hidden states; the first two axes are read
        as (batch_size, sequence_length)
    :param rotary_pos_emb_list: list[chex.Array]: Rotary embedding tables forwarded to ``apply_rotary``
    :param attention_mask: chex.Array: Padding mask; it is expanded on axes -3 and -2 and
        broadcast to the causal mask's shape before being combined with it
    :param position_ids: chex.Array: Token positions forwarded to ``apply_rotary``
    :param causal_mask: chex.Array: 4-D causal mask, sliced to the current query/key window
    :param segment_ids: Optional[chex.Array]: Forwarded to the attention module
    :param deterministic: bool: When False and ``config.attn_dropout_prob > 0``, a
        "dropout" RNG is drawn for attention dropout
    :param init_cache: bool: Initialize / update the key-value cache
    :param output_attentions: bool: Also return the attention weights
    :param encoder_hidden_states: Optional[chex.Array]: Not used; kept for interface compatibility
    :param encoder_attention_mask: Optional[chex.Array]: Not used; kept for interface compatibility
    :param fcm_mask: Optional extra mask combined with the attention and causal masks
    :return: ``(attn_output,)``, or ``(attn_output, attention_weights)`` when
        ``output_attentions`` is True

    """
    batch_size, sequence_length = hidden_states.shape[:2]
    mixed_x_layer: chex.Array = self.c_attn(hidden_states)
    query_states, key_states, value_states = jnp.split(mixed_x_layer, 3, 2)

    query_states = query_states.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
    key_states = key_states.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
    value_states = value_states.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)

    query_states, key_states, value_states = self.apply_rotary(
        query=query_states,
        key=key_states,
        value=value_states,
        position_ids=position_ids,
        rotary_pos_emb_list=rotary_pos_emb_list,
        batch_size=batch_size,
        sequence_length=sequence_length
    )

    query_length, key_length = query_states.shape[1], key_states.shape[1]
    if self.has_variable("cache", "cached_key"):
        # During cached decoding, take the causal-mask rows starting at the cache index.
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            causal_mask, (0, 0, mask_shift, 0), (1, 1,
                                                 query_length, max_decoder_length)
        )
    else:
        causal_mask = causal_mask[:, :, :query_length, :key_length]

    causal_mask = jnp.broadcast_to(
        causal_mask, (batch_size,) + causal_mask.shape[1:])
    attention_mask = jnp.broadcast_to(jnp.expand_dims(
        attention_mask, axis=(-3, -2)), causal_mask.shape)
    attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
    if attention_mask.ndim == 2:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    dropout_rng = None
    # Read the same config field (attn_dropout_prob) that configures AttentionModule's dropout.
    if not deterministic and self.config.attn_dropout_prob > 0.0:
        dropout_rng = self.make_rng("dropout")

    if self.has_variable("cache", "cached_key") or init_cache:
        key_states, value_states, attention_mask = self._concatenate_to_cache(
            key_states,
            value_states,
            query_states,
            attention_mask
        )

    # Additive bias: 0 where attention is allowed, the dtype's minimum where masked.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, jnp.finfo(
            self.dtype).min).astype(self.dtype),
    )

    # Recompute: key_states may have been replaced by the cache contents above.
    query_length, key_length = query_states.shape[1], key_states.shape[1]

    attentions = self.attention_performer.__call__(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        bias=attention_bias,
        attention_mask=attention_mask,
        causal=True,
        dropout_rng=dropout_rng,
        deterministic=deterministic,
        query_sequence_length=query_length,
        key_value_sequence_length=key_length,
        uses_cache=self.has_variable("cache", "cached_key") or init_cache,
        segment_ids=segment_ids,
        causal_mask=causal_mask
    )

    attn_output = self._merge_heads(attentions.attention_outputs)
    if self.config.shard_attention_computation:
        attn_output = with_sharding_constraint(
            attn_output, PartitionSpec(
                ("dp", "fsdp"),
                "sp" if attn_output.shape[1] != 1 else None,
                "tp"
            )
        )

    attn_output = self.c_proj(attn_output)

    if output_attentions:
        return attn_output, attentions.attention_weights
    return (attn_output,)

apply_rotary(batch_size, sequence_length, query, key, value, rotary_pos_emb_list, position_ids)

The apply_rotary function applies rotary position embeddings to the query and key tensors by calling `self.rotary` with `position_ids` and `rotary_pos_emb_list`. The value tensor is returned unchanged, and `batch_size` / `sequence_length` are accepted but not used.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
batch_size

Reshape the query, key and value tensors

required
sequence_length

Reshape the query, key and value tensors

required
query

Calculate the attention weights

required
key

Calculate the attention

required
value

Compute the attention weights

required
rotary_pos_emb_list

Calculate the frequency of each word in the vocabulary

required
position_ids

Identify the position of each token in the sequence

required

Returns:

Type Description

A tuple of 3 tensors: query and key with rotary embeddings applied, and value unchanged

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
def apply_rotary(self, batch_size, sequence_length, query, key, value, rotary_pos_emb_list, position_ids):
    """
    Apply rotary position embeddings to the query and key tensors.

    :param batch_size: Batch size (not used here; kept for interface compatibility)
    :param sequence_length: Sequence length (not used here; kept for interface compatibility)
    :param query: Query tensor
    :param key: Key tensor
    :param value: Value tensor, returned unchanged
    :param rotary_pos_emb_list: Rotary embedding tables forwarded to ``self.rotary``
    :param position_ids: Position indices forwarded to ``self.rotary``
    :return: A tuple ``(query, key, value)`` with rotary embeddings applied to query and key

    """
    # ``self.rotary`` takes the query under the keyword ``query_states``; the previous code
    # passed an undefined local named ``query_states`` and raised NameError.
    query, key = self.rotary(
        position_ids=position_ids,
        query_states=query,
        key=key,
        rotary_pos_emb_list=rotary_pos_emb_list,
    )
    return query, key, value

FlaxQwen1Block

Bases: Module

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
class FlaxQwen1Block(nn.Module):
    """A single Qwen1 decoder layer.

    Applies ``ln_1`` -> self-attention -> residual add, then ``ln_2`` -> MLP ->
    residual add. The attention and MLP submodules are wrapped with
    ``nn_partitioning.remat`` when ``config.gradient_checkpointing`` is non-empty.
    """
    config: Qwen1Config
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        attn_block = FlaxQwen1Attention
        if self.config.gradient_checkpointing != "":
            # static_argnums refers to positional arguments, so __call__ passes the
            # attention inputs positionally in FlaxQwen1Attention.__call__ order.
            attn_block = nn_partitioning.remat(
                FlaxQwen1Attention, static_argnums=(1, 3, 4, 6, 7, 8, 9, 10, 11),
                policy=get_gradient_checkpoint_policy(
                    self.config.gradient_checkpointing)
            )

        self.attn = attn_block(
            self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        mlp_block = FlaxQwen1MLP

        if self.config.gradient_checkpointing != "":
            mlp_block = nn_partitioning.remat(
                FlaxQwen1MLP, static_argnums=(1,),
                policy=get_gradient_checkpoint_policy(
                    self.config.gradient_checkpointing)
            )

        self.mlp = mlp_block(
            self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )
        self.ln_1 = Qwen1RMSNorm(
            self.config.hidden_size,
            eps=self.config.layer_norm_epsilon,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )
        self.ln_2 = Qwen1RMSNorm(
            self.config.hidden_size,
            eps=self.config.layer_norm_epsilon,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )

    def __call__(
            self,
            hidden_states: chex.Array,
            rotary_pos_emb_list: list[chex.Array],
            attention_mask: chex.Array,
            position_ids: chex.Array,
            causal_mask: chex.Array,
            segment_ids: Optional[chex.Array] = None,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = False,
            encoder_hidden_states: Optional[chex.Array] = None,
            encoder_attention_mask: Optional[chex.Array] = None,
            fcm_mask: Optional[jnp.ndarray] = None,
    ):
        """
        Run one Qwen1 decoder layer.

        :param hidden_states: chex.Array: Hidden states from the previous layer
        :param rotary_pos_emb_list: list[chex.Array]: Rotary embedding tables forwarded to attention
        :param attention_mask: chex.Array: Padding mask forwarded to attention
        :param position_ids: chex.Array: Token positions forwarded to attention
        :param causal_mask: chex.Array: Causal mask forwarded to attention
        :param segment_ids: Optional[chex.Array]: Forwarded to attention
        :param deterministic: bool: Passed to attention and MLP to control dropout
        :param init_cache: bool: Initialize / update the attention key-value cache
        :param output_attentions: bool: Ask the attention layer to also return its weights
        :param encoder_hidden_states: Optional[chex.Array]: Forwarded to attention
        :param encoder_attention_mask: Optional[chex.Array]: Forwarded to attention
        :param fcm_mask: Optional[jnp.ndarray]: Extra mask forwarded to attention
        :return: A tuple whose first element is the updated hidden states, followed by any
            additional outputs of the attention layer (e.g. attention weights)

        """
        # Positional call in FlaxQwen1Attention.__call__ parameter order (required by remat's
        # static_argnums); encoder_hidden_states precedes encoder_attention_mask there.
        attn_outputs = self.attn(
            self.ln_1(hidden_states),
            rotary_pos_emb_list,
            attention_mask,
            position_ids,
            causal_mask,
            segment_ids,
            deterministic,
            init_cache,
            output_attentions,
            encoder_hidden_states,
            encoder_attention_mask,
            fcm_mask,
        )
        attn_output = attn_outputs[0]
        hidden_states = hidden_states + attn_output

        feed_forward_input = self.ln_2(hidden_states)

        if self.config.use_scan_mlp:
            # Split the sequence axis into chunks of ``scan_mlp_chunk_size`` and scan the MLP
            # over them, sharing params across chunks and splitting dropout RNGs per chunk.
            feed_forward_input = einops.rearrange(
                feed_forward_input,
                '... (b s) d -> ... b s d',
                b=self.config.scan_mlp_chunk_size
            )

            def mlp_forward(mlp, carry, x):
                return None, mlp(x, deterministic)

            scan_axis = feed_forward_input.ndim - 3

            _, feed_forward_hidden_states = nn.scan(
                mlp_forward,
                variable_broadcast="params",
                split_rngs={"params": False, "dropout": True},
                in_axes=scan_axis,
                out_axes=scan_axis,
            )(self.mlp, None, feed_forward_input)
            feed_forward_hidden_states = einops.rearrange(
                feed_forward_hidden_states,
                '... b s d -> ... (b s) d'
            )
        else:
            feed_forward_hidden_states = self.mlp(
                feed_forward_input,
                deterministic,
            )

        hidden_states = hidden_states + feed_forward_hidden_states

        return (hidden_states,) + attn_outputs[1:]

__call__(hidden_states, rotary_pos_emb_list, attention_mask, position_ids, causal_mask, segment_ids=None, deterministic=True, init_cache=False, output_attentions=False, encoder_hidden_states=None, encoder_attention_mask=None, fcm_mask=None)

The `__call__` method runs one Qwen1 decoder block. It applies RMS normalization and self-attention to the hidden states with a residual connection, then RMS normalization and the MLP with a second residual connection, and returns the updated hidden states with shape (batch_size, sequence_length, model_dim).

Parameters:

Name Type Description Default
self

Refer to the class instance itself

required
hidden_states Array

chex.Array: Pass in the hidden state of the previous layer

required
rotary_pos_emb_list list[Array]

list[chex.Array]: Pass in the frequency information

required
attention_mask Array

chex.Array: Mask out the attention weights for padding tokens

required
position_ids Array

chex.Array: Determine the position of each token in the sequence

required
causal_mask Array

chex.Array: Mask the attention weights

required
deterministic bool

bool: Control whether the dropout is applied or not

True
init_cache bool

bool: Initialize the cache in the attention layer

False
output_attentions bool

bool: Return the attention weights

False
fcm_mask Optional[ndarray]

Optional[jnp.ndarray]: Mask the self-attention

None

Control the dropout in the self attention layer

required

Returns:

Type Description

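A tuple whose first element is the updated hidden states, followed by any additional outputs of the attention layer (the attention weights when output_attentions is True)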

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
def __call__(
        self,
        hidden_states: chex.Array,
        rotary_pos_emb_list: list[chex.Array],
        attention_mask: chex.Array,
        position_ids: chex.Array,
        causal_mask: chex.Array,
        segment_ids: Optional[chex.Array] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        encoder_hidden_states: Optional[chex.Array] = None,
        encoder_attention_mask: Optional[chex.Array] = None,
        fcm_mask: Optional[jnp.ndarray] = None,
):
    """
    Run one pre-norm Qwen1 transformer block.

    The block computes ``h = x + attn(ln_1(x))`` and then returns
    ``h + mlp(ln_2(h))``. When ``config.use_scan_mlp`` is set, the MLP is
    applied chunk by chunk along the sequence axis with ``nn.scan``.

    :param self: Refer to the class instance itself
    :param hidden_states: chex.Array: Hidden states produced by the previous layer
    :param rotary_pos_emb_list: list[chex.Array]: Rotary position-embedding tensors, forwarded to the attention layer
    :param attention_mask: chex.Array: Padding mask, forwarded to the attention layer
    :param position_ids: chex.Array: Position of each token, forwarded to the attention layer
    :param causal_mask: chex.Array: Causal mask, forwarded to the attention layer
    :param segment_ids: Optional[chex.Array]: Segment ids, forwarded to the attention layer
    :param deterministic: bool: Forwarded to the attention layer and the MLP; the scanned
        MLP path splits the "dropout" RNG stream per chunk
    :param init_cache: bool: Forwarded to the attention layer (cache initialisation)
    :param output_attentions: bool: Forwarded to the attention layer to request attention weights
    :param encoder_hidden_states: Optional[chex.Array]: Forwarded to the attention layer
    :param encoder_attention_mask: Optional[chex.Array]: Forwarded to the attention layer
    :param fcm_mask: Optional[jnp.ndarray]: Forgetful-causal mask, forwarded to the attention layer
    :return: The tuple ``(hidden_states,) + attn_outputs[1:]``: the block output followed by
        any extra outputs returned by the attention layer (presumably the attention weights
        when ``output_attentions`` is True -- confirm against the attention module)
    """
    # NOTE(review): the attention layer is called positionally, so this
    # argument order must match its __call__ signature (not visible here).
    # encoder_attention_mask is passed BEFORE encoder_hidden_states, which is
    # the reverse of this method's own parameter order -- confirm it is intended.

    attn_outputs = self.attn(
        self.ln_1(hidden_states),
        rotary_pos_emb_list,
        attention_mask,
        position_ids,
        causal_mask,
        segment_ids,
        deterministic,
        init_cache,
        output_attentions,
        encoder_attention_mask,
        encoder_hidden_states,
        fcm_mask,
    )
    attn_output = attn_outputs[0]
    # First residual connection, around the attention sub-layer.
    hidden_states = hidden_states + attn_output

    feed_forward_input = self.ln_2(hidden_states)

    if self.config.use_scan_mlp:
        # Split the sequence axis L into (b, s) with b = scan_mlp_chunk_size,
        # i.e. b chunks of length s = L // b (L must be divisible by b).
        feed_forward_input = einops.rearrange(
            feed_forward_input,
            '... (b s) d -> ... b s d',
            b=self.config.scan_mlp_chunk_size
        )

        def mlp_forward(mlp, carry, x):
            # No carried state: each scan step just returns the MLP output for one chunk.
            return None, mlp(x, deterministic)

        # Index of the `b` axis created above; nn.scan iterates over it.
        scan_axis = feed_forward_input.ndim - 3

        # Parameters are broadcast (shared) across chunks, while every chunk
        # receives its own split of the "dropout" RNG stream.
        _, feed_forward_hidden_states = nn.scan(
            mlp_forward,
            variable_broadcast="params",
            split_rngs={"params": False, "dropout": True},
            in_axes=scan_axis,
            out_axes=scan_axis,
        )(self.mlp, None, feed_forward_input)
        # Merge the chunks back into a single sequence axis.
        feed_forward_hidden_states = einops.rearrange(
            feed_forward_hidden_states,
            '... b s d -> ... (b s) d'
        )
    else:
        feed_forward_hidden_states = self.mlp(
            feed_forward_input,
            deterministic,
        )

    # Second residual connection, around the MLP sub-layer.
    hidden_states = hidden_states + feed_forward_hidden_states

    return (hidden_states,) + attn_outputs[1:]

FlaxQwen1BlockCollection

Bases: Module

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
class FlaxQwen1BlockCollection(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``FlaxQwen1Block`` layers applied in order."""

    config: Qwen1Config
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):
        # Layers are named by their index ("0", "1", ...); these names are the
        # keys of the layers' entries in the parameter tree.
        layer_count = self.config.num_hidden_layers
        self.blocks = [
            FlaxQwen1Block(
                self.config,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                name=str(layer_idx),
            )
            for layer_idx in range(layer_count)
        ]

    def __call__(
            self,
            hidden_states: chex.Array,
            rotary_pos_emb_list: list[chex.Array],
            attention_mask: chex.Array,
            position_ids: chex.Array,
            causal_mask: chex.Array,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = False,
            output_hidden_states: bool = False,
            return_dict: bool = True,
    ):
        """
        Run ``hidden_states`` through every block in ``self.blocks`` in order.

        When ``deterministic`` is False and ``config.fcm_max_ratio > 0``, a random
        forgetful causal mask (FCM) is sampled once from the "fcm" RNG stream and
        shared by all blocks; otherwise no FCM mask is used.

        :param self: Represent the instance of the class
        :param hidden_states: chex.Array: Input hidden states; the first two axes are read as (batch, seq)
        :param rotary_pos_emb_list: list[chex.Array]: Rotary embeddings forwarded to each block
        :param attention_mask: chex.Array: Padding mask forwarded to each block
        :param position_ids: chex.Array: Token positions forwarded to each block
        :param causal_mask: chex.Array: Causal mask forwarded to each block
        :param deterministic: bool: When False, FCM sampling may be enabled; also forwarded to each block
        :param init_cache: bool: Forwarded to each block
        :param output_attentions: bool: Collect the second output of every block
        :param output_hidden_states: bool: Collect the input hidden states of every block
        :param return_dict: bool: Accepted for API compatibility; unused, a tuple is always returned
        :return: ``(hidden_states, all_hidden_states, all_attentions)``; the last two are None
            unless requested. ``all_hidden_states`` contains each block's input, not the final output.
        """
        collected_states = () if output_hidden_states else None
        collected_attn = () if output_attentions else None

        fcm_mask = None
        use_fcm = (not deterministic) and self.config.fcm_max_ratio > 0
        if use_fcm:
            batch_size, seq_length = hidden_states.shape[:2]
            # Per-example drop ratio drawn from [fcm_min_ratio, fcm_max_ratio).
            drop_ratio = jax.random.uniform(
                self.make_rng("fcm"),
                shape=(batch_size, 1, 1, 1),
                minval=self.config.fcm_min_ratio,
                maxval=self.config.fcm_max_ratio,
            )
            keep = jax.random.uniform(
                self.make_rng("fcm"),
                shape=(batch_size, 1, seq_length, seq_length),
            ) > drop_ratio
            # Key position 0 is always kept, so every query attends to something.
            fcm_mask = keep.at[:, :, :, 0].set(True).astype("bool")

        for layer in self.blocks:
            if output_hidden_states:
                collected_states = collected_states + (hidden_states,)

            block_out = layer(
                hidden_states=hidden_states,
                rotary_pos_emb_list=rotary_pos_emb_list,
                attention_mask=attention_mask,
                position_ids=position_ids,
                causal_mask=causal_mask,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
                fcm_mask=fcm_mask,
            )
            hidden_states = block_out[0]

            if output_attentions:
                collected_attn = collected_attn + (block_out[1],)

        return hidden_states, collected_states, collected_attn

__call__(hidden_states, rotary_pos_emb_list, attention_mask, position_ids, causal_mask, deterministic=True, init_cache=False, output_attentions=False, output_hidden_states=False, return_dict=True)

The `__call__` method runs the hidden states through every `FlaxQwen1Block` in order. When `deterministic=False` and `config.fcm_max_ratio > 0`, it first samples a forgetful causal mask (FCM), which is shared by all blocks. It always returns a tuple `(hidden_states, all_hidden_states, all_attentions)`.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
hidden_states Array

chex.Array: Pass the input tensor to the encoder

required
rotary_pos_emb_list list[Array]

chex.Array: Pass in the frequency of each token

required
attention_mask Array

chex.Array: Mask out certain tokens in the input sequence

required
position_ids Array

chex.Array: Specify the position of each token in a sequence

required
causal_mask Array

chex.Array: Mask the attention weights

required
deterministic bool

bool: Determine whether the model is in training or evaluation mode

True
init_cache bool

bool: Initialize the cache for each layer

False
output_attentions bool

bool: Determine whether to output the attention weights

False
output_hidden_states bool

bool: Determine whether to return the hidden states of each layer

False
return_dict bool

bool: Accepted for API compatibility but unused; this method always returns a tuple

True

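Spurious entry generated from an empty `:param :` line in the docstring. Whether the forgetful causal mask is used is controlled by `deterministic` and `config.fcm_max_ratio`.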

required

Returns:

Type Description

A tuple of 3 values: `(hidden_states, all_hidden_states, all_attentions)`; the last two are `None` unless requested

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
def __call__(
        self,
        hidden_states: chex.Array,
        rotary_pos_emb_list: list[chex.Array],
        attention_mask: chex.Array,
        position_ids: chex.Array,
        causal_mask: chex.Array,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
):
    """
    Run ``hidden_states`` through every block in ``self.blocks`` in order.

    When ``deterministic`` is False and ``config.fcm_max_ratio > 0``, a random
    forgetful causal mask (FCM) is sampled once from the "fcm" RNG stream and
    shared by all blocks; otherwise no FCM mask is used.

    :param self: Represent the instance of the class
    :param hidden_states: chex.Array: Input hidden states; the first two axes are read as (batch, seq)
    :param rotary_pos_emb_list: list[chex.Array]: Rotary embeddings forwarded to each block
    :param attention_mask: chex.Array: Padding mask forwarded to each block
    :param position_ids: chex.Array: Token positions forwarded to each block
    :param causal_mask: chex.Array: Causal mask forwarded to each block
    :param deterministic: bool: When False, FCM sampling may be enabled; also forwarded to each block
    :param init_cache: bool: Forwarded to each block
    :param output_attentions: bool: Collect the second output of every block
    :param output_hidden_states: bool: Collect the input hidden states of every block
    :param return_dict: bool: Accepted for API compatibility; unused, a tuple is always returned
    :return: ``(hidden_states, all_hidden_states, all_attentions)``; the last two are None
        unless requested. ``all_hidden_states`` contains each block's input, not the final output.
    """
    collected_states = () if output_hidden_states else None
    collected_attn = () if output_attentions else None

    fcm_mask = None
    use_fcm = (not deterministic) and self.config.fcm_max_ratio > 0
    if use_fcm:
        batch_size, seq_length = hidden_states.shape[:2]
        # Per-example drop ratio drawn from [fcm_min_ratio, fcm_max_ratio).
        drop_ratio = jax.random.uniform(
            self.make_rng("fcm"),
            shape=(batch_size, 1, 1, 1),
            minval=self.config.fcm_min_ratio,
            maxval=self.config.fcm_max_ratio,
        )
        keep = jax.random.uniform(
            self.make_rng("fcm"),
            shape=(batch_size, 1, seq_length, seq_length),
        ) > drop_ratio
        # Key position 0 is always kept, so every query attends to something.
        fcm_mask = keep.at[:, :, :, 0].set(True).astype("bool")

    for layer in self.blocks:
        if output_hidden_states:
            collected_states = collected_states + (hidden_states,)

        block_out = layer(
            hidden_states=hidden_states,
            rotary_pos_emb_list=rotary_pos_emb_list,
            attention_mask=attention_mask,
            position_ids=position_ids,
            causal_mask=causal_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            fcm_mask=fcm_mask,
        )
        hidden_states = block_out[0]

        if output_attentions:
            collected_attn = collected_attn + (block_out[1],)

    return hidden_states, collected_states, collected_attn

FlaxQwen1ForCausalLMModule

Bases: Module

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
class FlaxQwen1ForCausalLMModule(nn.Module):
    config: Qwen1Config
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):
        # Qwen1 backbone; its token-embedding table lives under params["wte"].
        self.transformer = FlaxQwen1Module(
            self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )

        # Output projection to the vocabulary (bypassed in favour of the
        # embedding matrix when tie_word_embeddings is set).
        self.lm_head = Linear(
            self.config.vocab_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            precision=self.precision,
            **get_dot_general_by_bits(self.config.bits, self.config.easy_method)
        )

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: chex.Array = None,
            position_ids: chex.Array = None,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = False,
            output_hidden_states: bool = False,
            return_dict: bool = True,
            extra_embedding: Optional[Union[jnp.ndarray, None]] = None
    ):
        """
        Run the Qwen1 backbone and project its final hidden states to vocabulary logits.

        :param self: Refer to the object itself
        :param input_ids: chex.Array: Input token ids of shape (batch_size, seq_length)
        :param attention_mask: chex.Array: Padding mask; defaults to all ones
        :param position_ids: chex.Array: Token positions; by default the cumulative sum of
            ``attention_mask`` minus one, clipped at 0
        :param deterministic: bool: Forwarded to the backbone
        :param init_cache: bool: Forwarded to the backbone to initialise the decoding cache
        :param output_attentions: bool: Request per-layer attention weights
        :param output_hidden_states: bool: Request per-layer hidden states
        :param return_dict: bool: Return a ``FlaxCausalLMOutput`` instead of a tuple
        :param extra_embedding: Optional[jnp.ndarray]: Extra embedding forwarded unchanged to the backbone
        :return: ``FlaxCausalLMOutput(logits, hidden_states, attentions)`` when ``return_dict``
            is True, otherwise ``(logits,) + outputs[1:]``; logits are always float32
        """
        batch_size, seq_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            # Count only unmasked tokens, so left padding does not shift positions.
            position_ids = jnp.broadcast_to(
                jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
                (batch_size, seq_length)
            )
        outputs = self.transformer(
            input_ids,
            attention_mask,
            position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            extra_embedding=extra_embedding
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection. The
            # backbone submodule is ``self.transformer`` (this class has no
            # ``self.model`` attribute), so its ``wte`` embedding is read from there.
            shared_kernel = self.transformer.variables["params"]["wte"]["embedding"]
            shared_kernel = fjformer.linen.linen.control_quantization(shared_kernel, self.param_dtype).T
            lm_logits = self.lm_head.apply(
                {"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        lm_logits = lm_logits.astype(jnp.float32)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

__call__(input_ids, attention_mask=None, position_ids=None, deterministic=True, init_cache=False, output_attentions=False, output_hidden_states=False, return_dict=True, extra_embedding=None)

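The `__call__` method runs the Qwen1 backbone on `input_ids` and projects the final hidden states to vocabulary logits. It uses `lm_head`, or the input embedding matrix when `tie_word_embeddings` is set.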

Parameters:

Name Type Description Default
self

Refer to the object itself

required
input_ids Array

chex.Array: Pass the input token ids to the model

required
attention_mask Array

chex.Array: Mask out the padding tokens

None
position_ids Array

chex.Array: Specify the position of each token in the input sequence

None
deterministic bool

bool: Control whether the model is trained or not

True
init_cache bool

bool: Initialize the cache for the decoder

False
output_attentions bool

bool: Return the attention weights

False
output_hidden_states bool

bool: Determine whether to return the hidden states

False
return_dict bool

bool: Return a dictionary of the outputs or not

True
extra_embedding Optional[Union[ndarray, None]]

Optional[jnp.ndarray]: Extra embedding forwarded unchanged to the underlying `FlaxQwen1Module`

None
None]]

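Spurious entry: `None]]` is the tail of the `extra_embedding` type annotation, split off by the docstring parser; it is not a separate parameter.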

required

Returns:

Type Description

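A `FlaxCausalLMOutput` (logits, hidden_states, attentions) when `return_dict=True`; otherwise a tuple whose first element is the float32 logits.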

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: chex.Array = None,
        position_ids: chex.Array = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        extra_embedding: Optional[Union[jnp.ndarray, None]] = None
):
    """
    Run the Qwen1 backbone and project its final hidden states to vocabulary logits.

    :param self: Refer to the object itself
    :param input_ids: chex.Array: Input token ids of shape (batch_size, seq_length)
    :param attention_mask: chex.Array: Padding mask; defaults to all ones
    :param position_ids: chex.Array: Token positions; by default the cumulative sum of
        ``attention_mask`` minus one, clipped at 0
    :param deterministic: bool: Forwarded to the backbone
    :param init_cache: bool: Forwarded to the backbone to initialise the decoding cache
    :param output_attentions: bool: Request per-layer attention weights
    :param output_hidden_states: bool: Request per-layer hidden states
    :param return_dict: bool: Return a ``FlaxCausalLMOutput`` instead of a tuple
    :param extra_embedding: Optional[jnp.ndarray]: Extra embedding forwarded unchanged to the backbone
    :return: ``FlaxCausalLMOutput(logits, hidden_states, attentions)`` when ``return_dict``
        is True, otherwise ``(logits,) + outputs[1:]``; logits are always float32
    """
    batch_size, seq_length = input_ids.shape
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    if position_ids is None:
        # Count only unmasked tokens, so left padding does not shift positions.
        position_ids = jnp.broadcast_to(
            jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
            (batch_size, seq_length)
        )
    outputs = self.transformer(
        input_ids,
        attention_mask,
        position_ids,
        deterministic=deterministic,
        init_cache=init_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        extra_embedding=extra_embedding
    )

    hidden_states = outputs[0]

    if self.config.tie_word_embeddings:
        # Reuse the input embedding matrix as the output projection. The
        # backbone submodule is ``self.transformer`` (the module has no
        # ``self.model`` attribute), so its ``wte`` embedding is read from there.
        shared_kernel = self.transformer.variables["params"]["wte"]["embedding"]
        shared_kernel = fjformer.linen.linen.control_quantization(shared_kernel, self.param_dtype).T
        lm_logits = self.lm_head.apply(
            {"params": {"kernel": shared_kernel}}, hidden_states)
    else:
        lm_logits = self.lm_head(hidden_states)

    lm_logits = lm_logits.astype(jnp.float32)

    if not return_dict:
        return (lm_logits,) + outputs[1:]

    return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

FlaxQwen1ForSequenceClassificationModule

Bases: Module

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
class FlaxQwen1ForSequenceClassificationModule(nn.Module):
    num_classes: int
    config: Qwen1Config
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):
        """
        Build the submodules: the Qwen1 backbone (``self.model``) and a
        bias-free linear classification head (``self.classifier``) with
        ``num_classes`` outputs. No optimizer or training state is created here.

        :param self: Access variables that belong to the class
        :return: None; the submodules are stored as attributes
        """
        # Forward param_dtype and precision as well as dtype, so the backbone
        # honours the same configuration as the classifier head (passing only
        # dtype made the backbone fall back to FlaxQwen1Module's defaults).
        self.model = FlaxQwen1Module(
            self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )
        self.classifier = Linear(
            self.num_classes,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range),
            precision=self.precision,
        )

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: chex.Array = None,
            position_ids: chex.Array = None,
            deterministic: bool = True,
            init_cache: bool = False,
            output_attentions: bool = False,
            output_hidden_states: bool = False,
            return_dict: bool = True,
            extra_embedding: Optional[Union[jnp.ndarray, None]] = None
    ):
        """
        Run the Qwen1 backbone and apply the linear classifier to its final hidden states.

        The classifier is applied at every position; no pooling is performed here.

        :param self: Refer to the class instance
        :param input_ids: chex.Array: Input token ids of shape (batch_size, seq_length)
        :param attention_mask: chex.Array: Padding mask; defaults to all ones
        :param position_ids: chex.Array: Token positions; by default the cumulative sum of
            ``attention_mask`` minus one, clipped at 0
        :param deterministic: bool: Forwarded to the backbone
        :param init_cache: bool: Forwarded to the backbone
        :param output_attentions: bool: Request per-layer attention weights (returned in dict mode)
        :param output_hidden_states: bool: Forwarded to the backbone
        :param return_dict: bool: Return a ``FlaxSequenceClassifierOutput`` instead of a tuple
        :param extra_embedding: Optional[jnp.ndarray]: Extra embedding forwarded unchanged to the backbone
        :return: ``FlaxSequenceClassifierOutput(logits, hidden_states, attentions)`` when
            ``return_dict`` is True, where ``hidden_states`` is the backbone's final hidden
            state; otherwise the one-element tuple ``(logits,)``
        """
        batch_size, seq_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            # Count only unmasked tokens, so left padding does not shift positions.
            position_ids = jnp.broadcast_to(
                jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
                (batch_size, seq_length)
            )
        outputs = self.model(
            input_ids,
            attention_mask,
            position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            extra_embedding=extra_embedding
        )

        hidden_states = outputs[0]
        prediction = self.classifier(hidden_states)
        if return_dict:
            # Forward the attentions the caller asked for instead of dropping them.
            return FlaxSequenceClassifierOutput(
                logits=prediction,
                hidden_states=hidden_states,
                attentions=outputs.attentions,
            )
        else:
            return (prediction,)

__call__(input_ids, attention_mask=None, position_ids=None, deterministic=True, init_cache=False, output_attentions=False, output_hidden_states=False, return_dict=True, extra_embedding=None)

The `__call__` method runs the Qwen1 backbone on `input_ids` and applies the linear `classifier` to the final hidden state at every position; no pooling is performed. Like any Flax module, it is invoked by calling the (bound) module instance with its inputs, e.g. `output = model(input_ids)`.

Parameters:

Name Type Description Default
self

Refer to the class instance

required
input_ids Array

chex.Array: Pass the input to the model

required
attention_mask Array

chex.Array: Specify which tokens are masked

None
position_ids Array

chex.Array: Specify the position of each token in the sequence

None
deterministic bool

bool: Control whether the model is run in deterministic or stochastic mode

True
init_cache bool

bool: Initialize the cache for the transformer

False
output_attentions bool

bool: Return the attention weights

False
output_hidden_states bool

bool: Return the hidden states of all layers

False
return_dict bool

bool: Return a dictionary of outputs

True
extra_embedding Optional[Union[ndarray, None]]

Optional[jnp.ndarray]: Extra embedding forwarded unchanged to the underlying `FlaxQwen1Module`

None
None]]

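Spurious entry: `None]]` is the tail of the `extra_embedding` type annotation, split off by the docstring parser; it is not a separate parameter.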

required

Returns:

Type Description

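A `FlaxSequenceClassifierOutput` holding the per-token `logits` and the final `hidden_states` when `return_dict=True`; otherwise a one-element tuple `(logits,)`.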

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: chex.Array = None,
        position_ids: chex.Array = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        extra_embedding: Optional[Union[jnp.ndarray, None]] = None
):
    """
    Run the Qwen1 backbone and apply the linear classifier to its final hidden states.

    The classifier is applied at every position; no pooling is performed here.

    :param self: Refer to the class instance
    :param input_ids: chex.Array: Input token ids of shape (batch_size, seq_length)
    :param attention_mask: chex.Array: Padding mask; defaults to all ones
    :param position_ids: chex.Array: Token positions; by default the cumulative sum of
        ``attention_mask`` minus one, clipped at 0
    :param deterministic: bool: Forwarded to the backbone
    :param init_cache: bool: Forwarded to the backbone
    :param output_attentions: bool: Request per-layer attention weights (returned in dict mode)
    :param output_hidden_states: bool: Forwarded to the backbone
    :param return_dict: bool: Return a ``FlaxSequenceClassifierOutput`` instead of a tuple
    :param extra_embedding: Optional[jnp.ndarray]: Extra embedding forwarded unchanged to the backbone
    :return: ``FlaxSequenceClassifierOutput(logits, hidden_states, attentions)`` when
        ``return_dict`` is True, where ``hidden_states`` is the backbone's final hidden
        state; otherwise the one-element tuple ``(logits,)``
    """
    batch_size, seq_length = input_ids.shape
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    if position_ids is None:
        # Count only unmasked tokens, so left padding does not shift positions.
        position_ids = jnp.broadcast_to(
            jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
            (batch_size, seq_length)
        )
    outputs = self.model(
        input_ids,
        attention_mask,
        position_ids,
        deterministic=deterministic,
        init_cache=init_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        extra_embedding=extra_embedding
    )

    hidden_states = outputs[0]
    prediction = self.classifier(hidden_states)
    if return_dict:
        # Forward the attentions the caller asked for instead of dropping them.
        return FlaxSequenceClassifierOutput(
            logits=prediction,
            hidden_states=hidden_states,
            attentions=outputs.attentions,
        )
    else:
        return (prediction,)

setup()

The `setup` method creates the submodules: the `FlaxQwen1Module` backbone (`self.model`) and a bias-free linear `classifier` head with `num_classes` outputs. No optimizer or training state is created here.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required

Returns:

Type Description

None; the submodules are stored as attributes on the module (`self.model`, `self.classifier`).

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
def setup(self):
    """
    Build the submodules of the sequence-classification model.

    Creates the ``FlaxQwen1Module`` backbone as ``self.model`` and a bias-free
    linear head as ``self.classifier`` that maps hidden states to
    ``self.num_classes`` logits. Flax invokes ``setup`` lazily to register
    submodules; it returns nothing.

    :param self: The module instance; ``config``, ``dtype``, ``param_dtype``,
        ``precision`` and ``num_classes`` are read from it.
    """
    # Forward param_dtype and precision as well as dtype so the backbone uses the
    # same parameter dtype and matmul precision as the classifier head below.
    self.model = FlaxQwen1Module(
        self.config,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        precision=self.precision,
    )
    self.classifier = Linear(
        self.num_classes,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        use_bias=False,
        kernel_init=jax.nn.initializers.normal(
            stddev=self.config.initializer_range),
        precision=self.precision,
    )

FlaxQwen1MLP

Bases: Module

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
class FlaxQwen1MLP(nn.Module):
    """Gated feed-forward block of Qwen1: ``c_proj(silu(w2(x)) * w1(x))``."""
    config: Qwen1Config
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        cfg = self.config

        def dense(features):
            # All three projections share dtype, bias, init and quantisation settings.
            return Linear(
                features,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                use_bias=not cfg.no_bias,
                kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
                precision=self.precision,
                **get_dot_general_by_bits(cfg.bits, cfg.easy_method)
            )

        # Both gate branches project to half of config.intermediate_size.
        inner_width = cfg.intermediate_size // 2
        self.w1 = dense(inner_width)
        self.w2 = dense(inner_width)
        self.c_proj = dense(cfg.hidden_size)

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """
        Apply the gated feed-forward transform ``c_proj(silu(w2(x)) * w1(x))``.

        :param self: The MLP instance holding the ``w1``, ``w2`` and ``c_proj`` projections
        :param x: jnp.ndarray: Input activations
        :param deterministic: bool: Unused; accepted for signature compatibility with other layers
        :return: The ``c_proj`` projection of the gated product (no dropout is applied)
        """
        gate = jax.nn.silu(self.w2(x))
        gated = gate * self.w1(x)
        return self.c_proj(gated)

__call__(x, deterministic=True)

The __call__ function applies the gated feed-forward transform c_proj(silu(w2(x)) * w1(x)) to the input.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
x ndarray

jnp.ndarray: Pass in the input to the layer

required
deterministic bool

bool: Unused; accepted for signature compatibility with other layers (no dropout is applied).

True

Returns:

Type Description
ndarray

The output of c_proj applied to the gated product silu(w2(x)) * w1(x).

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
    """
    Apply the gated feed-forward transform ``c_proj(silu(w2(x)) * w1(x))``.

    :param self: The MLP instance holding the ``w1``, ``w2`` and ``c_proj`` projections
    :param x: jnp.ndarray: Input activations
    :param deterministic: bool: Unused; accepted for signature compatibility with other layers
    :return: The ``c_proj`` projection of the gated product (no dropout is applied)
    """
    gate = jax.nn.silu(self.w2(x))
    gated = gate * self.w1(x)
    return self.c_proj(gated)

FlaxQwen1Module

Bases: Module

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
class FlaxQwen1Module(nn.Module):
    """
    Qwen1 decoder backbone: token embedding, embedding dropout, the block stack
    ``h`` and the final RMSNorm ``ln_f``.

    The rotary table and the causal mask are precomputed once in ``setup`` and
    passed to the block collection on every call.
    """
    config: Qwen1Config
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):

        self.wte = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )
        self.drop = flax.linen.Dropout(rate=self.config.emb_dropout_prob)
        self.h = FlaxQwen1BlockCollection(
            self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.ln_f = Qwen1RMSNorm(
            self.config.hidden_size,
            eps=self.config.layer_norm_epsilon,
            dtype=self.dtype,
            param_dtype=self.param_dtype
        )
        config = self.config
        # rotary_pct < 1 shrinks the rotary dimension to int(kv_channels * rotary_pct);
        # None means the full kv_channels width is used.
        if config.rotary_pct == 1.0:
            self.rotary_ndims = None
        else:
            assert config.rotary_pct < 1
            self.rotary_ndims = int(
                config.kv_channels * config.rotary_pct
            )
        # Causal mask sized by c_max_position_embeddings, falling back to seq_length.
        self.causal_mask = make_causal_mask(
            jnp.ones(
                (1, getattr(config, "c_max_position_embeddings", config.seq_length)),
                dtype="bool"),
            dtype="bool"
        )
        # Rotary table sized by freq_max_position_embeddings, falling back to seq_length.
        self.rope_cache = compute_qwen1_rope(
            dim=self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels,
            base=self.config.rotary_emb_base,
            seqlen=getattr(config, "freq_max_position_embeddings", config.seq_length)
        )

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: chex.Array,
            position_ids: chex.Array,
            deterministic: bool = True,
            inputs_embeds: chex.Array = None,
            init_cache: bool = False,
            output_attentions: bool = False,
            output_hidden_states: bool = False,
            return_dict: bool = True,
            extra_embedding: Optional[Union[jnp.ndarray, None]] = None
    ):
        """
        Run the Qwen1 backbone.

        Embeds ``input_ids`` (unless ``inputs_embeds`` is given), optionally adds
        ``extra_embedding``, applies embedding dropout, runs the block collection
        ``h`` with the precomputed rotary table and causal mask, and normalises the
        result with ``ln_f``.

        :param self: Represent the instance of the class
        :param input_ids: chex.Array: Input token ids; cast to int32 before embedding
        :param attention_mask: chex.Array: Padding mask forwarded to the blocks
        :param position_ids: chex.Array: Position of each token, forwarded to the blocks
        :param deterministic: bool: Disable dropout when True
        :param inputs_embeds: chex.Array: Precomputed input embeddings; when given, ``input_ids`` is not embedded
        :param init_cache: bool: Forwarded to the blocks to initialise the key/value cache
        :param output_attentions: bool: Whether the blocks should return attention weights
        :param output_hidden_states: bool: Whether to return the hidden states of every layer
        :param return_dict: bool: Return a ``FlaxBaseModelOutput`` instead of a tuple
        :param extra_embedding: Optional[Union[jnp.ndarray, None]]: Added element-wise to the
            input embeddings before dropout
        :return: When ``return_dict`` is False, a tuple of the non-None values among
            (last hidden state, all hidden states, attentions); otherwise a
            ``FlaxBaseModelOutput``
        :raises AssertionError: If the sequence length exceeds ``config.seq_length``
        """
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids.astype("i4"))

        # Unpacking enforces a rank-3 embedding tensor; axis 1 is the sequence length.
        _, sequence_length, _ = inputs_embeds.shape
        assert sequence_length <= self.config.seq_length, "Maximum Position Embedding Reached !"

        if extra_embedding is not None:
            inputs_embeds = inputs_embeds + extra_embedding
        hidden_states = self.drop(inputs_embeds, deterministic=deterministic)

        # NOTE: the single precomputed rotary table is used for every call;
        # no dynamic NTK rescaling is applied here.
        outputs = self.h(
            hidden_states=hidden_states,
            rotary_pos_emb_list=[self.rope_cache],
            attention_mask=attention_mask,
            position_ids=position_ids,
            causal_mask=self.causal_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = self.ln_f(outputs[0])

        if output_hidden_states:
            # Append the normalised final state to the per-layer hidden states.
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )

__call__(input_ids, attention_mask, position_ids, deterministic=True, inputs_embeds=None, init_cache=False, output_attentions=False, output_hidden_states=False, return_dict=True, extra_embedding=None)

The __call__ function runs the Qwen1 backbone. It embeds input_ids (unless inputs_embeds is given), applies embedding dropout, runs the decoder blocks, and applies the final RMSNorm. Optional keyword arguments such as deterministic control dropout, cache initialisation, and which outputs are returned.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
input_ids Array

chex.Array: Pass in the input token ids

required
attention_mask Array

chex.Array: Mask out the padding tokens

required
position_ids Array

chex.Array: Indicate the position of each token in a sequence

required
deterministic bool

bool: Control whether dropout is applied or not

True
inputs_embeds Array

chex.Array: Pass in the embeddings of the input tokens

None
init_cache bool

bool: Initialize the cache

False
output_attentions bool

bool: Determine whether to return the attentions or not

False
output_hidden_states bool

bool: Determine whether to return hidden states

False
return_dict bool

bool: Return a dictionary of the output or not

True
extra_embedding Optional[Union[ndarray, None]]

Optional[Union[jnp.ndarray, None]]: Extra embedding added element-wise to the input embeddings before dropout

None

Returns:

Type Description

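A tuple of the non-None outputs (last hidden state, all hidden states, attentions) when return_dict is False; otherwise a FlaxBaseModelOutput.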

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: chex.Array,
        position_ids: chex.Array,
        deterministic: bool = True,
        inputs_embeds: chex.Array = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        extra_embedding: Optional[Union[jnp.ndarray, None]] = None
):
    """
    Run the Qwen1 backbone.

    Embeds ``input_ids`` (unless ``inputs_embeds`` is given), optionally adds
    ``extra_embedding``, applies embedding dropout, runs the block collection
    ``h`` with the precomputed rotary table and causal mask, and normalises the
    result with ``ln_f``.

    :param self: Represent the instance of the class
    :param input_ids: chex.Array: Input token ids; cast to int32 before embedding
    :param attention_mask: chex.Array: Padding mask forwarded to the blocks
    :param position_ids: chex.Array: Position of each token, forwarded to the blocks
    :param deterministic: bool: Disable dropout when True
    :param inputs_embeds: chex.Array: Precomputed input embeddings; when given, ``input_ids`` is not embedded
    :param init_cache: bool: Forwarded to the blocks to initialise the key/value cache
    :param output_attentions: bool: Whether the blocks should return attention weights
    :param output_hidden_states: bool: Whether to return the hidden states of every layer
    :param return_dict: bool: Return a ``FlaxBaseModelOutput`` instead of a tuple
    :param extra_embedding: Optional[Union[jnp.ndarray, None]]: Added element-wise to the
        input embeddings before dropout
    :return: When ``return_dict`` is False, a tuple of the non-None values among
        (last hidden state, all hidden states, attentions); otherwise a
        ``FlaxBaseModelOutput``
    :raises AssertionError: If the sequence length exceeds ``config.seq_length``
    """
    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids.astype("i4"))

    # Unpacking enforces a rank-3 embedding tensor; axis 1 is the sequence length.
    _, sequence_length, _ = inputs_embeds.shape
    assert sequence_length <= self.config.seq_length, "Maximum Position Embedding Reached !"

    if extra_embedding is not None:
        inputs_embeds = inputs_embeds + extra_embedding
    hidden_states = self.drop(inputs_embeds, deterministic=deterministic)

    # NOTE: the single precomputed rotary table is used for every call;
    # no dynamic NTK rescaling is applied here.
    outputs = self.h(
        hidden_states=hidden_states,
        rotary_pos_emb_list=[self.rope_cache],
        attention_mask=attention_mask,
        position_ids=position_ids,
        causal_mask=self.causal_mask,
        deterministic=deterministic,
        init_cache=init_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = self.ln_f(outputs[0])

    if output_hidden_states:
        # Append the normalised final state to the per-layer hidden states.
        all_hidden_states = outputs[1] + (hidden_states,)
        outputs = (hidden_states, all_hidden_states) + outputs[2:]
    else:
        outputs = (hidden_states,) + outputs[1:]

    if not return_dict:
        return tuple(v for v in outputs if v is not None)

    return FlaxBaseModelOutput(
        last_hidden_state=hidden_states,
        hidden_states=outputs[1],
        attentions=outputs[-1],
    )

FlaxQwen1PreTrainedModel

Bases: EasyDeLFlaxPretrainedModel

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
class FlaxQwen1PreTrainedModel(EasyDeLFlaxPretrainedModel):
    """
    Shared base for Qwen1 Flax models.

    Subclasses set ``module_class``. This class instantiates that module,
    initialises its parameters, builds the ``cache`` variable collection and
    wraps ``module.apply`` for forward calls.
    """
    config_class = Qwen1Config
    base_model_prefix = "model"
    module_class: nn.Module = None

    def __init__(
            self,
            config: Qwen1Config,
            input_shape: Tuple = (1, 1),
            seed: int = 0,
            dtype: jnp.dtype = jnp.float32,
            _do_init: bool = True,
            **kwargs,
    ):
        """
        Instantiate ``module_class`` and register it with the EasyDeL base class.

        :param self: The model wrapper being constructed
        :param config: Qwen1Config: Model configuration, passed to the module and the base class
        :param input_shape: Tuple: Forwarded to the base class; presumably the dummy input
            shape later given to ``init_weights``
        :param seed: int: Forwarded to the base class
        :param dtype: jnp.dtype: Computation dtype for the module, also forwarded to the base class
        :param _do_init: bool: Forwarded unchanged to the base class
        :param kwargs: Extra keyword arguments passed to ``module_class``
        """
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialise parameters by running ``module.init`` on dummy inputs.

        :param self: Access variables that belong to the class
        :param rng: jax.random.PRNGKey: Split into the ``params`` and ``dropout`` RNG streams
        :param input_shape: Tuple: Shape of the dummy ``input_ids``
        :param params: FrozenDict: Existing (e.g. pretrained) parameters; only keys listed in
            ``self._missing_keys`` are filled from the fresh initialisation
        :return: A frozendict of parameters
        """
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # Qwen1 modules take no encoder inputs: their 4th/5th positional slots are
        # control arguments, so encoder tensors must never be passed positionally here.
        module_init_outputs = self.module.init(
            rngs, input_ids, attention_mask, position_ids, return_dict=False
        )
        random_params = module_init_outputs["params"]

        if params is None:
            return random_params

        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))

    def init_cache(self, batch_size, max_length):
        """
        Create an empty ``cache`` variable collection for incremental decoding.

        The module is initialised with ``init_cache=True`` on a dummy batch of ones
        of shape ``(batch_size, max_length)``, and its ``cache`` collection is
        returned. Presumably this holds per-layer cached keys/values and the cache
        index.

        :param self: Access the module
        :param batch_size: Batch size of the dummy input
        :param max_length: Sequence length of the dummy input (the cache capacity)
        :return: The ``cache`` collection produced by ``module.init``
        """
        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(
            jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    def __call__(
            self,
            input_ids: chex.Array,
            attention_mask: chex.Array = None,
            position_ids: chex.Array = None,
            params: dict = None,
            past_key_values: dict = None,
            dropout_rng: jax.random.PRNGKey = None,
            train: bool = False,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = True,
            extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
            add_params_field: bool = False,
            **kwargs
    ):
        """
        Run the wrapped module via ``module.apply``.

        When ``past_key_values`` is non-empty it is supplied as the mutable
        ``cache`` collection and the updated cache is returned with the outputs.

        :param self: Represent the instance of the class
        :param input_ids: chex.Array: Input token ids of shape (batch, seq)
        :param attention_mask: chex.Array: Padding mask; defaults to all ones
        :param position_ids: chex.Array: Token positions; defaults to ``arange(seq)``
        :param params: dict: Parameters to use instead of ``self.params``
        :param past_key_values: dict: Cache collection from ``init_cache`` or a previous call
        :param dropout_rng: jax.random.PRNGKey: RNG for dropout
        :param train: bool: Enable dropout (the module runs with ``deterministic=not train``)
        :param output_attentions: Optional[bool]: Return attention weights; defaults to the config
        :param output_hidden_states: Optional[bool]: Return the hidden states of every layer; defaults to the config
        :param return_dict: Optional[bool]: Return a dict-like output instead of a tuple
        :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Extra embedding forwarded to the module
        :param add_params_field: bool: Wrap ``params`` as ``{"params": params}`` before applying
        :return: The module outputs. With a cache, the updated cache is stored as
            ``outputs["past_key_values"]`` (``return_dict=True``) or inserted as the
            second tuple element (``return_dict=False``)
        :raises ValueError: If ``past_key_values`` is given without ``position_ids``
        :raises AssertionError: If the sequence length exceeds ``config.seq_length``
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        assert sequence_length <= self.config.seq_length, "Maximum Position Embedding Reached !"

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[
                                            None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        if self.config.bits is not None:
            rngs['params'] = jax.random.key(0)

        model_params = params or self.params
        inputs = {"params": model_params} if add_params_field else model_params

        mutable = False
        if past_key_values:
            # Build a new mapping: without add_params_field, ``inputs`` is the
            # caller's params / ``self.params`` itself, and writing "cache" into it
            # would leak a stale cache into later calls (or fail on a FrozenDict).
            inputs = {**inputs, "cache": past_key_values}
            mutable = ["cache"]

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            extra_embedding,
            rngs=rngs,
            mutable=mutable,
        )

        # apply() only returns (outputs, updated_variables) when a collection is mutable.
        if not mutable:
            return outputs

        outputs, updated_variables = outputs
        new_cache = unfreeze(updated_variables["cache"])
        if return_dict:
            outputs["past_key_values"] = new_cache
            return outputs
        return outputs[:1] + (new_cache,) + outputs[1:]

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
        """
        Prepare the cache, the extended attention mask and the position ids for generation.

        :param self: Access variables that belong to the class
        :param input_ids: Prompt token ids of shape (batch, seq)
        :param max_length: Total length to generate up to; sizes the cache and the mask
        :param attention_mask: Optional[chex.Array]: Prompt mask. It is copied into the start
            of an all-ones ``(batch, max_length)`` mask, and positions are derived from its
            cumulative sum
        :return: A dictionary of the past_key_values, attention_mask and position ids
        """
        batch_size, seq_length = input_ids.shape
        extended_attention_mask = jnp.ones(
            (batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = jax.lax.dynamic_update_slice(
                extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[
                                            None, :], (batch_size, seq_length))

        return {
            "past_key_values": self.init_cache(batch_size, max_length),
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Carry the updated cache forward and advance the position ids by one step.

        :param model_outputs: Outputs of the previous step; ``past_key_values`` is read from them
        :param model_kwargs: Generation kwargs, updated in place and returned
        :return: ``model_kwargs`` with the new cache and next position ids
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs

__call__(input_ids, attention_mask=None, position_ids=None, params=None, past_key_values=None, dropout_rng=None, train=False, output_attentions=None, output_hidden_states=None, return_dict=True, extra_embedding=None, add_params_field=False, **kwargs)

The __call__ function runs the wrapped module through module.apply. It also accepts mutable state (past_key_values), which is updated during the call and returned with the outputs, and random number generators (rngs) used for dropout.

Parameters:

Name Type Description Default
self

Represent the instance of the class

required
input_ids Array

chex.Array: Pass in the input tokens

required
attention_mask Array

chex.Array: Mask out certain tokens in the input

None
position_ids Array

chex.Array: Create the positional embeddings

None
params dict

dict: Pass in the parameters of the model

None
past_key_values dict

dict: Pass in the past key values from a previous call to __call__

None
dropout_rng PRNGKey

jax.random.PRNGKey: Make sure that the dropout is applied in a random way

None
train bool

bool: Determine whether to use dropout or not

False
output_attentions Optional[bool]

Optional[bool]: Determine whether to return the attention weights

None
output_hidden_states Optional[bool]

Optional[bool]: Return the hidden states of all layers (the block collection h)

None
return_dict Optional[bool]

Optional[bool]: Determine whether to return a dictionary or not

True
extra_embedding Optional[Union[ndarray, None]]

Optional[Union[jnp.ndarray,None]]: Pass in the embedding for the input_ids

None
add_params_field bool

bool: Add the params field to the inputs dictionary

False

Returns:

Type Description

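The module outputs. When past_key_values is given, the updated cache is included: as past_key_values in the returned dict, or as the second element of the returned tuple.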

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: chex.Array = None,
        position_ids: chex.Array = None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
        add_params_field: bool = False,
        **kwargs
):
    """
    Run a forward pass of the wrapped Flax module via ``self.module.apply``.

    Besides the plain forward pass this wrapper:
    - attaches ``past_key_values`` as the mutable ``"cache"`` collection and returns the updated cache;
    - forwards ``dropout_rng`` (and a fixed ``"params"`` key when ``config.bits`` is set) as ``rngs``.

    :param self: Represent the instance of the class
    :param input_ids: chex.Array: Input token ids; the shape is unpacked as ``(batch_size, sequence_length)``
    :param attention_mask: chex.Array: Mask out certain tokens in the input (defaults to all ones)
    :param position_ids: chex.Array: Positions of the tokens (defaults to ``arange(sequence_length)``);
        required when ``past_key_values`` is given
    :param params: dict: Parameters to use instead of ``self.params``
    :param past_key_values: dict: Cache returned by a previous call to ``__call__`` (or by ``init_cache``)
    :param dropout_rng: jax.random.PRNGKey: Key used for the ``"dropout"`` rng stream
    :param train: bool: When True the module runs non-deterministically (dropout enabled)
    :param output_attentions: Optional[bool]: Whether to return attention weights (falls back to the config)
    :param output_hidden_states: Optional[bool]: Whether to return the hidden states of all layers
        (falls back to the config)
    :param return_dict: Optional[bool]: Whether to return a dict-like output instead of a tuple
        (``None`` falls back to the config)
    :param extra_embedding: Optional[Union[jnp.ndarray,None]]: Extra embedding forwarded to the module
    :param add_params_field: bool: Wrap ``params`` as ``{"params": params}`` before calling ``apply``
    :param kwargs: Accepted for API compatibility and ignored
    :return: The module outputs. When ``past_key_values`` is given, the updated cache is stored under
        ``"past_key_values"`` (dict output) or inserted as the second element (tuple output).
    :raises ValueError: If ``past_key_values`` is given without ``position_ids``
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    batch_size, sequence_length = input_ids.shape

    assert sequence_length <= self.config.seq_length, "Maximum Position Embedding Reached !"

    if position_ids is None:
        if past_key_values is not None:
            raise ValueError(
                "Make sure to provide `position_ids` when passing `past_key_values`.")

        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[
                                        None, :], (batch_size, sequence_length))

    if attention_mask is None:
        attention_mask = jnp.ones((batch_size, sequence_length))

    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    if self.config.bits is not None:
        rngs['params'] = jax.random.key(0)

    inputs = {
        "params": params or self.params
    } if add_params_field else params or self.params
    mutable = False
    # Use the same `is not None` test as the output handling below, so an (empty) cache is never
    # skipped here while its outputs are still unpacked as `(outputs, variables)` later.
    if past_key_values is not None:
        # Build a new mapping: when `add_params_field` is False, `inputs` is the caller's own
        # `params` (or `self.params`), and writing "cache" into it would leak a stale cache into
        # later calls (and fail outright for a FrozenDict).
        inputs = {**inputs, "cache": past_key_values}
        mutable = ["cache"]

    outputs = self.module.apply(
        inputs,
        jnp.array(input_ids, dtype="i4"),
        jnp.array(attention_mask, dtype="i4"),
        jnp.array(position_ids, dtype="i4"),
        not train,
        False,
        output_attentions,
        output_hidden_states,
        return_dict,
        extra_embedding,
        rngs=rngs,
        mutable=mutable,
    )

    # With a mutable collection, `apply` returns `(outputs, updated_variables)`.
    if past_key_values is not None and return_dict:
        outputs, past_key_values = outputs
        outputs["past_key_values"] = unfreeze(past_key_values["cache"])
        return outputs
    elif past_key_values is not None and not return_dict:
        outputs, past_key_values = outputs
        outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
    return outputs

__init__(config, input_shape=(1, 1), seed=0, dtype=jnp.float32, _do_init=True, **kwargs)

The `__init__` function is called when the class is instantiated. It builds the Flax module from the configuration and passes it to the parent class constructor.

Parameters:

Name Type Description Default
self

Refer to the object itself

required
config Qwen1Config

Qwen1Config: Pass the configuration to the module

required
input_shape Tuple

Tuple: Specify the shape of the input to the model

(1, 1)
seed int

int: Set the seed for random number generation

0
dtype dtype

jnp.dtype: Specify the data type of the input

float32
_do_init bool

bool: Control whether the module is initialized or not

True
kwargs

Pass in any additional parameters that the module_class might need

{}

Specify the number of h in the network

required

Returns:

Type Description

None; the instance is initialized through the parent class constructor (`super().__init__`).

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
def __init__(
        self,
        config: Qwen1Config,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
):
    """
    Build the Flax module from ``config`` and hand it to the parent constructor.

    :param self: The instance being constructed
    :param config: Qwen1Config: Configuration used to build ``self.module_class``
    :param input_shape: Tuple: Shape of the dummy input forwarded to the parent constructor
    :param seed: int: Seed forwarded to the parent constructor
    :param dtype: jnp.dtype: Computation dtype, passed to both the module and the parent constructor
    :param _do_init: bool: Forwarded to the parent constructor to control weight initialization
    :param kwargs: Extra keyword arguments forwarded to ``self.module_class`` only
    :return: None
    """
    flax_module = self.module_class(config=config, dtype=dtype, **kwargs)
    super().__init__(
        config,
        flax_module,
        input_shape=input_shape,
        seed=seed,
        dtype=dtype,
        _do_init=_do_init,
    )

init_cache(batch_size, max_length)

The init_cache function initializes the cache for a given batch size and maximum sequence length. It initializes the module on dummy inputs of shape (batch_size, max_length) with init_cache=True and returns the resulting cache collection, which prepare_inputs_for_generation passes to the model as past_key_values.

Parameters:

Name Type Description Default
self

Access the module

required
batch_size

Define the batch size of the input tensors

required
max_length

Set the length of the input sequence

required

Returns:

Type Description

The "cache" variable collection produced by initializing the module with init_cache=True.

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
def init_cache(self, batch_size, max_length):
    """
    Create the cache collection used for incremental decoding.

    The module is initialized once on dummy inputs of shape ``(batch_size, max_length)`` with
    ``init_cache=True``. Only the resulting ``"cache"`` collection is returned; this is what
    ``prepare_inputs_for_generation`` passes to the model as ``past_key_values``.

    :param self: Access the module
    :param batch_size: Batch size the cache must support
    :param max_length: Sequence length the cache is sized for
    :return: The ``"cache"`` variable collection produced by ``self.module.init``
    """
    # Token ids are integers everywhere else in this wrapper (`__call__` and `init_weights` use "i4"),
    # so the dummy ids use the same dtype instead of jnp.ones' default floating dtype.
    input_ids = jnp.ones((batch_size, max_length), dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(input_ids.shape[-1]), input_ids.shape)

    init_variables = self.module.init(
        jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
    )
    return init_variables["cache"]

init_weights(rng, input_shape, params=None)

The init_weights function is used to initialize the weights of a model.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
rng PRNGKey

jax.random.PRNGKey: Initialize the weights of the model

required
input_shape Tuple

Tuple: Specify the shape of the input tensor

required
params FrozenDict

FrozenDict: Pass in the parameters of a pre-trained model

None

Returns:

Type Description
FrozenDict

A frozendict of parameters

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    """
    Initialize model parameters, optionally filling gaps in an existing parameter tree.

    The module is initialized on all-zero int32 token ids of ``input_shape``. If ``params`` is given,
    every key listed in ``self._missing_keys`` is copied from the freshly initialized tree into it,
    ``self._missing_keys`` is reset, and the merged tree is returned frozen.

    :param self: Access variables that belong to the class
    :param rng: jax.random.PRNGKey: Split into the ``"params"`` and ``"dropout"`` rng streams
    :param input_shape: Tuple: Shape of the dummy token ids used for initialization
    :param params: FrozenDict: Existing (e.g. pre-trained) parameters to complete, or None
    :return: The freshly initialized parameters, or ``params`` completed with the missing keys
    """
    dummy_ids = jnp.zeros(input_shape, dtype="i4")
    dummy_mask = jnp.ones_like(dummy_ids)
    seq_len = jnp.atleast_2d(dummy_ids).shape[-1]
    dummy_positions = jnp.broadcast_to(jnp.arange(seq_len), input_shape)

    params_key, dropout_key = jax.random.split(rng)
    init_rngs = {"params": params_key, "dropout": dropout_key}

    init_args = [dummy_ids, dummy_mask, dummy_positions]
    if self.config.add_cross_attention:
        # NOTE(review): these two extra positional arguments land in the slots that `__call__`
        # fills with `deterministic` and `init_cache`; confirm the module really accepts
        # encoder states here before relying on this branch.
        encoder_states = jnp.zeros(input_shape + (self.config.hidden_size,))
        init_args += [encoder_states, dummy_mask]
    fresh_params = self.module.init(init_rngs, *init_args, return_dict=False)["params"]

    if params is None:
        return fresh_params

    # Merge on flattened (path-tuple) keys so nested entries can be copied individually.
    fresh_flat = flatten_dict(unfreeze(fresh_params))
    merged = flatten_dict(unfreeze(params))
    for key in self._missing_keys:
        merged[key] = fresh_flat[key]
    self._missing_keys = set()
    return freeze(unflatten_dict(merged))

prepare_inputs_for_generation(input_ids, max_length, attention_mask=None)

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters:

Name Type Description Default
self

Access variables that belong to the class

required
input_ids

Pass in the input tokens

required
max_length

Set the length of the sequence to be generated

required
attention_mask Optional[Array]

Optional[chex.Array]: Mask the attention weights

None

Returns:

Type Description

A dictionary of the past_key_values, attention_mask and position ids

Source code in src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
    """
    Prepare the model inputs for the first step of autoregressive generation.

    :param self: Access variables that belong to the class
    :param input_ids: Prompt token ids; the shape is unpacked as ``(batch_size, seq_length)``
    :param max_length: Length that the cache and the extended attention mask are sized for
    :param attention_mask: Optional[chex.Array]: Mask for the prompt tokens; it is written into the
        leading columns of the extended mask and used to derive position ids
    :return: A dictionary with ``past_key_values`` (a fresh cache from ``init_cache``),
        ``attention_mask`` (int32, shape ``(batch_size, max_length)``) and ``position_ids``
    """
    batch_size, seq_length = input_ids.shape
    # Columns after the prompt default to 1; the prompt's own mask overwrites the first columns below.
    extended_attention_mask = jnp.ones(
        (batch_size, max_length), dtype="i4")
    if attention_mask is not None:
        # `dynamic_update_slice` requires operand and update to share a dtype, so normalize a
        # bool / int64 / NumPy mask to int32 before writing it into the int32 extended mask.
        attention_mask = jnp.asarray(attention_mask, dtype="i4")
        # Positions count only the unmasked tokens (masked leading tokens get negative positions).
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = jax.lax.dynamic_update_slice(
            extended_attention_mask, attention_mask, (0, 0))
    else:
        position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[
                                        None, :], (batch_size, seq_length))

    return {
        "past_key_values": self.init_cache(batch_size, max_length),
        "attention_mask": extended_attention_mask,
        "position_ids": position_ids,
    }