
pallas_operations.splash_attention.tpu.splash_attention_mask_info

Mini-mask creation library.

MaskInfo

Bases: NamedTuple

Contains runtime masking information for the Splash attention kernel.

The arrays data_next, mask_next and block_mask are placed in TPU scalar memory. This is a scarce resource, so the mask creation logic tries to shrink these arrays to the smallest possible data type: np.int32, np.int16 or np.int8.
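The following is a minimal sketch of how such a dtype selection could work. It assumes the choice depends only on each array's minimum and maximum values. The helper name shrink_to_smallest_int is hypothetical and not part of fjformer, and the library's actual logic may differ.

```python
import numpy as np


def shrink_to_smallest_int(arr: np.ndarray) -> np.ndarray:
    """Cast an integer array to the narrowest of int8, int16, int32 that holds all its values.

    Hypothetical helper for illustration; not the library's implementation.
    """
    lo, hi = (int(arr.min()), int(arr.max())) if arr.size else (0, 0)
    for dtype in (np.int8, np.int16, np.int32):
        info = np.iinfo(dtype)
        if info.min <= lo and hi <= info.max:
            return arr.astype(dtype)
    raise ValueError("values do not fit in int32")


block_mask = np.array([[[2, 0], [1, 2]]])  # entries are always 0, 1 or 2
data_next = np.array([[[0, 300], [1, 1]]])  # 300 does not fit in int8

print(shrink_to_smallest_int(block_mask).dtype)  # int8
print(shrink_to_smallest_int(data_next).dtype)   # int16
```

Because block_mask only ever holds 0, 1 or 2, it always fits in int8. Index arrays such as data_next need wider types as the number of blocks grows.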

For the arrays data_next, mask_next and block_mask, the first dimension has one of two sizes, num_heads or num_head_shards:

* num_head_shards when all heads in a shard share a single unique mask. In this case the three arrays are broadcast to all the heads in the shard.
* num_heads when the heads in a shard do not all share the same mask.
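The sketch below illustrates the choice between num_head_shards and num_heads. It assumes heads are split into equal, contiguous shards, and it decides the layout by comparing the per-head block_mask arrays directly. The helper per_shard_or_per_head is hypothetical and only approximates the idea of "one unique mask per shard".

```python
import numpy as np


def per_shard_or_per_head(block_mask: np.ndarray, num_head_shards: int) -> np.ndarray:
    """Return one entry per shard if every head in each shard is identical, else one per head.

    Hypothetical helper. block_mask has shape [num_heads, num_q_blocks, num_kv_blocks] and
    heads are assumed to be split into num_head_shards contiguous, equal-sized groups.
    """
    num_heads = block_mask.shape[0]
    if num_heads % num_head_shards:
        raise ValueError("num_heads must be divisible by num_head_shards")
    heads_per_shard = num_heads // num_head_shards
    shards = block_mask.reshape(num_head_shards, heads_per_shard, *block_mask.shape[1:])
    # A shard is uniform when every head in it equals the shard's first head.
    if all((shard == shard[0]).all() for shard in shards):
        return shards[:, 0]  # [num_head_shards, ...]; broadcast to each shard's heads
    return block_mask  # [num_heads, ...]; one entry per head


same = np.zeros((4, 2, 2), dtype=np.int8)  # 4 heads with identical masks
mixed = same.copy()
mixed[1, 0, 0] = 2  # head 1 now differs from head 0 within shard 0

print(per_shard_or_per_head(same, num_head_shards=2).shape)   # (2, 2, 2)
print(per_shard_or_per_head(mixed, num_head_shards=2).shape)  # (4, 2, 2)
```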

Attributes:

* data_next: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array where each entry contains the next kv block index to prefetch.
* mask_next: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array where each entry contains the next mask block index in partial_mask_blocks to prefetch.
* block_mask: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array whose entries are 0, 1 or 2. An entry of 0 means the corresponding block of the full mask is all zeros. An entry of 1 means the block contains both zeros and ones. An entry of 2 means the block is entirely ones.
* partial_mask_blocks: An i32[num_partial_blocks, block_q, block_kv] NumPy array containing the blocks of the original mask that contain both zeros and ones. The entries of mask_next are indices into the first axis of this array.
* q_sequence: An i32[q_sequence_length] NumPy array. With causal masking, it contains the indices that correspond to q tokens. For plain causal masking it is simply np.arange(q_sequence_length).
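The sketch below shows, for a single head, how block_mask, mask_next and partial_mask_blocks relate to one another. It is an illustrative reimplementation, not the library's code. The helper classify_blocks is hypothetical, and it marks non-partial entries of mask_next with -1, which is an assumption. It also does not deduplicate identical partial blocks and does not compute data_next.

```python
import numpy as np


def classify_blocks(mask: np.ndarray, block_q: int, block_kv: int):
    """Tile a 2-D boolean mask and classify each tile as 0 (empty), 1 (partial) or 2 (full).

    Hypothetical helper. Returns (block_mask, mask_next, partial_mask_blocks) for one head;
    mask_next holds -1 for tiles that are not partial (an assumption of this sketch).
    """
    q_len, kv_len = mask.shape
    if q_len % block_q or kv_len % block_kv:
        raise ValueError("mask shape must be divisible by the block sizes")
    nq, nkv = q_len // block_q, kv_len // block_kv
    block_mask = np.zeros((nq, nkv), dtype=np.int32)
    mask_next = np.full((nq, nkv), -1, dtype=np.int32)
    partial = []
    for i in range(nq):
        for j in range(nkv):
            tile = mask[i * block_q:(i + 1) * block_q, j * block_kv:(j + 1) * block_kv]
            if tile.all():
                block_mask[i, j] = 2  # entirely ones
            elif tile.any():
                block_mask[i, j] = 1  # mix of zeros and ones
                mask_next[i, j] = len(partial)  # index into partial_mask_blocks
                partial.append(tile.astype(np.int32))
            # otherwise the tile is all zeros and block_mask stays 0
    if partial:
        partial_mask_blocks = np.stack(partial)
    else:
        partial_mask_blocks = np.zeros((0, block_q, block_kv), dtype=np.int32)
    return block_mask, mask_next, partial_mask_blocks


# Plain causal mask over 8 tokens, tiled into 4x4 blocks.
seq_len = 8
causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
block_mask, mask_next, partial_mask_blocks = classify_blocks(causal, 4, 4)
q_sequence = np.arange(seq_len, dtype=np.int32)  # plain causal: q token indices

print(block_mask.tolist())        # [[1, 0], [2, 1]]
print(mask_next.tolist())         # [[0, -1], [-1, 1]]
print(partial_mask_blocks.shape)  # (2, 4, 4)
```

For the 8-token causal mask, the two diagonal tiles are partial, the upper-right tile is empty, and the lower-left tile is full. This matches the 0/1/2 encoding described in the attribute list.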

Source code in src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask_info.py
from typing import NamedTuple

import numpy as np


class MaskInfo(NamedTuple):
    """Contains runtime masking information for the Splash attention kernel.

    The arrays data_next, mask_next and block_mask are placed in TPU
    scalar memory. This is a scarce resource, so the mask creation logic
    tries to shrink these arrays to the smallest possible data type:
    np.int32, np.int16 or np.int8.

    For the arrays data_next, mask_next and block_mask, the first dimension
    has one of two sizes, num_heads or num_head_shards:
    * num_head_shards when all heads in a shard share a single unique mask.
      In this case the three arrays are broadcast to all the heads in the
      shard.
    * num_heads when the heads in a shard do not all share the same mask.

    Attributes:
      data_next: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
        NumPy array where each entry contains the next `kv` block index to
        prefetch.
      mask_next: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
        NumPy array where each entry contains the next mask block index in
        `partial_mask_blocks` to prefetch.
      block_mask: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
        NumPy array whose entries can be 0, 1 or 2. An entry of 0 indicates that
        the corresponding block in the full mask was all zeros. An entry of 1
        indicates that the corresponding block in the full mask contained both
        zeros and ones. An entry of 2 indicates the corresponding block was
        entirely ones.
      partial_mask_blocks: An i32[num_partial_blocks, block_q, block_kv] NumPy
        array that contains the blocks of the original mask that contained both
        zeros and ones. The entries in `mask_next` point to indices in the first
        axis of this array.
      q_sequence: An i32[q_sequence_length] NumPy array. When using causal
        masking, this contains the list of indices that correspond to q tokens.
        For plain causal this is just np.arange(q_sequence_length).
    """

    data_next: np.ndarray | None
    mask_next: np.ndarray | None
    block_mask: np.ndarray | None
    partial_mask_blocks: np.ndarray | None
    q_sequence: np.ndarray | None