
pallas_operations.ring_attention.ring_attention

This module contains the forward and backward passes of ring attention, supporting both blockwise computation and TPU-compatible fused attention. It also provides blockwise computation of feedforward networks to reduce memory cost. For more details, refer to 'RingAttention' at https://arxiv.org/abs/2310.01889 and 'Blockwise Parallel Transformer' at https://arxiv.org/abs/2305.19370.
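The core idea behind blockwise and ring attention is that softmax attention can be accumulated one KV block at a time, using a running maximum and a running sum (online softmax). The sketch below simulates the ring schedule on a single host and is not the fjformer implementation. The function names `reference_attention` and `simulated_ring_attention` are hypothetical. Each "device" keeps one query chunk, and the KV chunks are rotated with `jnp.roll`. A real multi-device version would typically pass blocks between devices with a collective permute such as `jax.lax.ppermute`.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    # Plain (non-causal) softmax attention over the full sequence.
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


def simulated_ring_attention(q, k, v, num_devices):
    """Hypothetical single-host simulation of the ring attention schedule.

    The sequence is split into `num_devices` chunks. "Device" i keeps query
    chunk i. At each step it attends to the KV chunk it currently holds,
    then hands that chunk to its neighbour. Partial results are merged with
    a running max and running sum, so the full attention matrix is never
    materialised for any one device.
    """
    d = q.shape[-1]
    qs = q.reshape(num_devices, -1, d)  # [n, chunk, d]
    ks = k.reshape(num_devices, -1, d)
    vs = v.reshape(num_devices, -1, d)
    chunk = qs.shape[1]

    out = jnp.zeros_like(qs)                               # un-normalised numerator
    row_sum = jnp.zeros((num_devices, chunk, 1))           # softmax denominator
    row_max = jnp.full((num_devices, chunk, 1), -jnp.inf)  # running max for stability

    for _ in range(num_devices):
        scores = jnp.einsum("nqd,nkd->nqk", qs, ks) / jnp.sqrt(d)
        new_max = jnp.maximum(row_max, scores.max(axis=-1, keepdims=True))
        correction = jnp.exp(row_max - new_max)  # rescales earlier partial sums
        p = jnp.exp(scores - new_max)
        out = out * correction + jnp.einsum("nqk,nkd->nqd", p, vs)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        row_max = new_max
        # Pass each KV chunk to the next "device" in the ring.
        ks = jnp.roll(ks, shift=1, axis=0)
        vs = jnp.roll(vs, shift=1, axis=0)

    return (out / row_sum).reshape(-1, d)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (16, 8))
k = jax.random.normal(key_k, (16, 8))
v = jax.random.normal(key_v, (16, 8))
ring_out = simulated_ring_attention(q, k, v, num_devices=4)
print(jnp.max(jnp.abs(ring_out - reference_attention(q, k, v))))
```

Because the non-causal accumulation is order-independent, the result matches full attention up to floating-point error no matter which KV chunk each device sees first.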

SegmentIds

Bases: NamedTuple

SegmentIds for Q and KV sequences.

SegmentIds are used to generate a segment mask, which prevents attention between different segments in the input sequence. Each array holds integer segment ids. Only tokens with the same id can attend to each other.

Attributes:
  q: segment ids along the Q sequence.
  kv: segment ids along the KV sequence.
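A segment mask of this kind is commonly built by comparing every query id with every key id. The helper `make_segment_mask` below is hypothetical and only illustrates that equality-based construction. It is not necessarily how this module builds its mask internally.

```python
import jax.numpy as jnp


def make_segment_mask(q_segment_ids, kv_segment_ids):
    # True where query i and key j share a segment id; shape [q_len, kv_len].
    return q_segment_ids[:, None] == kv_segment_ids[None, :]


# Two documents packed into one length-5 sequence:
# tokens 0-2 belong to segment 1, tokens 3-4 to segment 2.
ids = jnp.array([1, 1, 1, 2, 2])
mask = make_segment_mask(ids, ids)
print(mask.astype(jnp.int32))
```

Here the mask is block-diagonal. Queries 0 to 2 may attend only to keys 0 to 2, and queries 3 and 4 only to keys 3 and 4.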

Source code in src/fjformer/pallas_operations/ring_attention/ring_attention.py
from typing import NamedTuple

import jax


class SegmentIds(NamedTuple):
    """SegmentIds for Q and KV sequences.

    SegmentIds are used to generate a segment mask, which prevents attention
    between different segments in the input sequence. Each array holds integer
    segment ids. Only tokens with the same id can attend to each other.

    Attributes:
      q: segment ids along the Q sequence.
      kv: segment ids along the KV sequence.
    """

    q: jax.Array  # [q_seq_len]
    kv: jax.Array  # [kv_seq_len]
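A minimal usage sketch: the example below redefines `SegmentIds` locally so that it runs on its own. It passes an instance of that class to `masked_attention`, a hypothetical plain-JAX reference attention that honours segment ids. This is not the fused or ring kernel in this module; it only shows the intended effect of the mask on the attention output.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SegmentIds(NamedTuple):
    # Local copy of the class above so this sketch is self-contained.
    q: jax.Array   # [q_seq_len]
    kv: jax.Array  # [kv_seq_len]


def masked_attention(q, k, v, segment_ids: SegmentIds):
    # Hypothetical reference attention that blocks cross-segment attention.
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    mask = segment_ids.q[:, None] == segment_ids.kv[None, :]
    # A large finite negative value (instead of -inf) keeps a fully masked
    # row from turning into NaN after the softmax.
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ v


ids = jnp.array([1, 1, 1, 2, 2])
segment_ids = SegmentIds(q=ids, kv=ids)

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (5, 4))
k = jax.random.normal(key_k, (5, 4))
v = jax.random.normal(key_v, (5, 4))
out = masked_attention(q, k, v, segment_ids)
```

Each packed segment therefore behaves as if attention had been computed on that segment alone.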