Skip to content

pallas_operations.flash_attention.tpu.jax_flash_attn_tpu

Flash Attention TPU kernel.

BlockSizes dataclass

Tile sizes parameterizing FlashAttention kernels.

These parameters have a negligible effect on numerics but greatly affect performance.

Source code in src/fjformer/pallas_operations/flash_attention/tpu/jax_flash_attn_tpu.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@dataclasses.dataclass(frozen=True)
class BlockSizes:
    """Tile sizes parameterizing FlashAttention kernels.

    These parameters have a negligible effect on numerics but greatly affect
    performance.

    The first four fields are required. The fields suffixed ``_dkv`` and
    ``_dq`` are optional backward-pass tile sizes; ``has_backward_blocks``
    reports whether all of them are set.

    Each major/minor pair must satisfy ``0 < minor <= major`` and
    ``major % minor == 0``. The pairs are:

    * ``block_k_major`` / ``block_k`` (always checked)
    * ``block_q_major_dkv`` / ``block_q_dkv``
    * ``block_k_major_dkv`` / ``block_k_dkv``
    * ``block_k_major_dq`` / ``block_k_dq``

    The optional pairs are checked only when both members are set.

    Raises:
      ValueError: at construction time, if a minor tile size is not positive,
        is larger than its major size, or does not divide it.
    """
    block_q: int
    block_k_major: int
    block_k: int
    block_b: int

    block_q_major_dkv: Optional[int] = None
    block_k_major_dkv: Optional[int] = None
    block_k_dkv: Optional[int] = None
    block_q_dkv: Optional[int] = None

    block_k_major_dq: Optional[int] = None
    block_k_dq: Optional[int] = None
    block_q_dq: Optional[int] = None

    def __post_init__(self) -> None:
        def verify_major_minor(prefix: str, suffix: str, major: int, minor: int) -> None:
            # Without this guard, minor == 0 would crash with ZeroDivisionError
            # below, and a negative minor could pass both checks silently
            # (e.g. 128 % -1 == 0).
            if minor <= 0:
                raise ValueError(f"{prefix}{suffix}={minor} should be positive")
            # Equality is allowed, so the message says "or equal to".
            if minor > major:
                raise ValueError(
                    f"{prefix}{suffix}={minor} should be smaller than or equal to"
                    f" {prefix}_major{suffix}={major}"
                )
            if major % minor != 0:
                raise ValueError(
                    f"{prefix}{suffix}={minor} should divide"
                    f" {prefix}_major{suffix}={major}"
                )

        verify_major_minor("block_k", "", self.block_k_major, self.block_k)
        # Optional pairs are validated only when both members are provided.
        optional_pairs = (
            ("block_q", "_dkv", self.block_q_major_dkv, self.block_q_dkv),
            ("block_k", "_dkv", self.block_k_major_dkv, self.block_k_dkv),
            ("block_k", "_dq", self.block_k_major_dq, self.block_k_dq),
        )
        for prefix, suffix, major, minor in optional_pairs:
            if major is not None and minor is not None:
                verify_major_minor(prefix, suffix, major, minor)

    @property
    def has_backward_blocks(self) -> bool:
        """Return True iff every optional ``_dkv`` / ``_dq`` tile size is set."""
        backward_blocks = (
            self.block_q_major_dkv,
            self.block_k_major_dkv,
            self.block_q_dkv,
            self.block_k_dkv,
            self.block_k_major_dq,
            self.block_k_dq,
            self.block_q_dq,
        )
        return all(b is not None for b in backward_blocks)

    @classmethod
    def get_default(
        cls,
        batch_size: int,
        num_heads: int,
        q_seq_len: int,
        kv_len: int,
        d_model: int,
    ) -> "BlockSizes":
        """Return a fixed default configuration (every tile 128, block_b=1).

        The problem-size arguments are accepted for API stability but are
        currently ignored. The result has all backward-pass tile sizes set.
        It is built with ``cls`` so that subclasses get an instance of their
        own type.
        """
        # TODO(apaszke,sharadmv): Select better parameters based on a heuristic.
        del batch_size, num_heads, q_seq_len, kv_len, d_model  # Unused.
        return cls(
            block_q=128,
            block_k_major=128,
            block_k=128,
            block_b=1,
            block_q_major_dkv=128,
            block_k_major_dkv=128,
            block_k_dkv=128,
            block_q_dkv=128,
            block_k_major_dq=128,
            block_k_dq=128,
            block_q_dq=128,
        )

SegmentIds

Bases: NamedTuple

SegmentIds for Q and KV sequences.

SegmentIds are used to generate a segment mask, which prevents attention between different segments of the input sequence. Each array is a list of integer ids, and only tokens with the same id can attend to each other.

Attributes: `q` — segment ids along the Q sequence; `kv` — segment ids along the KV sequence.

Source code in src/fjformer/pallas_operations/flash_attention/tpu/jax_flash_attn_tpu.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class SegmentIds(NamedTuple):
    """Integer segment ids for the Q and KV sequences.

    The ids are used to build a segment mask that blocks attention across
    segment boundaries. Only tokens that carry the same id can attend to
    each other.

    Attributes:
      q: segment ids along the Q sequence, laid out as [batch_size, q_seq_len].
      kv: segment ids along the KV sequence, laid out as
        [batch_size, kv_seq_len].
    """

    # Declaration order defines tuple order: instances index/unpack as (q, kv).
    q: jax.Array
    kv: jax.Array