pallas_operations.ring_attention.ring_attention
This module contains the ring attention forward and backward passes and supports both blockwise computation and TPU-compatible fused attention. It also provides blockwise computation for feedforward networks to reduce memory cost. For more details, refer to 'RingAttention' at https://arxiv.org/abs/2310.01889 and 'Blockwise Parallel Transformers' at https://arxiv.org/abs/2305.19370.
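The core of the blockwise forward pass can be illustrated in plain JAX. This is a minimal single-device sketch, not fjformer's implementation. The function name `blockwise_attention` and the `block_size` parameter are hypothetical. The sketch has no masking or causal handling, and it simulates the ring communication step (rotating KV blocks between devices, for example with a collective permute) by scanning over KV blocks locally. For each query row it keeps a running maximum and a running softmax denominator. This lets the result match ordinary softmax attention without ever materializing the full score matrix.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, block_size):
    """Hypothetical single-device sketch of a blockwise (ring-style) forward pass.

    q: [q_len, d]; k, v: [kv_len, d]. KV is consumed one block at a time,
    as if each block had just arrived from the neighbouring device in the ring.
    Assumes kv_len is divisible by block_size.
    """
    scale = q.shape[-1] ** -0.5
    num_blocks = k.shape[0] // block_size
    k_blocks = k.reshape(num_blocks, block_size, -1)
    v_blocks = v.reshape(num_blocks, block_size, -1)

    def step(carry, kv_block):
        acc, row_max, denom = carry
        k_blk, v_blk = kv_block
        scores = (q @ k_blk.T) * scale                       # [q_len, block_size]
        new_max = jnp.maximum(row_max, scores.max(axis=-1))  # running row max
        correction = jnp.exp(row_max - new_max)              # rescale old statistics
        p = jnp.exp(scores - new_max[:, None])
        acc = acc * correction[:, None] + p @ v_blk          # unnormalized output
        denom = denom * correction + p.sum(axis=-1)          # softmax denominator
        return (acc, new_max, denom), None

    init = (
        jnp.zeros((q.shape[0], v.shape[-1]), q.dtype),
        jnp.full((q.shape[0],), -jnp.inf, q.dtype),
        jnp.zeros((q.shape[0],), q.dtype),
    )
    (acc, _, denom), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / denom[:, None]


# Usage: compare against ordinary (non-blockwise) softmax attention.
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (8, 16))
k = jax.random.normal(key_k, (32, 16))
v = jax.random.normal(key_v, (32, 16))

out = blockwise_attention(q, k, v, block_size=8)
ref = jax.nn.softmax((q @ k.T) * 16 ** -0.5, axis=-1) @ v
```

Each step needs only the current KV block plus two running statistics per query row and the running output accumulator. The peak score-matrix footprint therefore scales with the block size rather than the full KV length. In ring attention, each device would apply the same update to the KV block it has just received.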
SegmentIds
Bases: NamedTuple
SegmentIds for Q and KV sequences.
SegmentIds are used to generate a segment mask, which prevents attention between different segments of the input sequence. Each array holds integer segment ids. Only tokens with the same id can attend to each other.
Attributes:
    q: segment ids along the Q sequence.
    kv: segment ids along the KV sequence.
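A segment mask can be derived from the two id arrays by comparing every query id with every KV id. The sketch below defines a local `SegmentIds` stand-in with the documented `q` and `kv` fields rather than importing the library class. It assumes integer arrays of shape `[seq_len]`, optionally with leading batch dimensions. `segment_mask` and `masked_attention_weights` are hypothetical helper names and are not part of the module's API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SegmentIds(NamedTuple):
    """Local stand-in mirroring the documented fields (not the library class)."""
    q: jnp.ndarray   # [..., q_len] integer segment ids
    kv: jnp.ndarray  # [..., kv_len] integer segment ids


def segment_mask(segment_ids: SegmentIds) -> jnp.ndarray:
    # True where the query token and the key token share a segment id.
    return segment_ids.q[..., :, None] == segment_ids.kv[..., None, :]


def masked_attention_weights(scores, mask):
    # Push cross-segment scores to the most negative finite value so they
    # receive (numerically) zero probability after the softmax.
    big_neg = jnp.finfo(scores.dtype).min
    return jax.nn.softmax(jnp.where(mask, scores, big_neg), axis=-1)


# Usage: two packed sequences of length 2 each.
ids = SegmentIds(q=jnp.array([0, 0, 1, 1]), kv=jnp.array([0, 0, 1, 1]))
mask = segment_mask(ids)
weights = masked_attention_weights(jnp.zeros((4, 4)), mask)
```

With uniform scores, each query spreads its attention evenly over the tokens of its own segment and assigns zero weight to the other segment.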
Source code in src/fjformer/pallas_operations/ring_attention/ring_attention.py