pallas_operations.splash_attention.tpu.splash_attention_mask_info
Mini-mask creation library.
MaskInfo
Bases: NamedTuple
Contains runtime masking information for the Splash attention kernel.
The arrays data_next, mask_next and block_mask are placed in TPU scalar memory. This is a scarce resource, so the mask creation logic attempts to shrink the data type of these arrays to the smallest possible one: np.int32, np.int16 or np.int8.
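The dtype shrinking described above amounts to picking the narrowest integer type whose range covers every value in the array. The sketch below is a minimal illustration under that assumption; the helper name shrink_dtype is hypothetical and is not part of the library.

```python
import numpy as np

def shrink_dtype(arr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: cast to the narrowest of int8/int16/int32 that holds all values."""
    # Candidates ordered from smallest to largest, matching the list above.
    for dtype in (np.int8, np.int16, np.int32):
        info = np.iinfo(dtype)
        # An empty array fits any type; otherwise check the min/max against the type's range.
        if arr.size == 0 or (arr.min() >= info.min and arr.max() <= info.max):
            return arr.astype(dtype)
    raise ValueError("values do not fit in int32")

block_mask = np.array([[0, 1, 2], [2, 2, 0]])
print(shrink_dtype(block_mask).dtype)  # int8
data_next = np.array([0, 5, 300])
print(shrink_dtype(data_next).dtype)   # int16, since 300 exceeds the int8 maximum of 127
```

block_mask only ever holds 0, 1 or 2, so it always fits in np.int8; index arrays such as data_next grow with the number of blocks and may need a wider type.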
For the arrays data_next, mask_next and block_mask, the first dimension has one of two sizes, num_heads or num_head_shards:
* num_head_shards, when all heads in a shard share a single unique mask. In this case the three arrays are broadcast to all the heads in the shard.
* num_heads, when the heads in a shard have more than one unique mask.
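The choice of leading dimension can be illustrated on the full per-head masks: if every head in a shard has the same mask, one copy per shard suffices; otherwise one entry per head is kept. This is a minimal sketch assuming heads are split into contiguous, equally sized shards; the function name select_leading_masks is hypothetical, and the library applies the decision to its block-level arrays rather than to the raw masks.

```python
import numpy as np

def select_leading_masks(masks: np.ndarray, num_head_shards: int) -> np.ndarray:
    """Hypothetical helper: keep one mask per shard if possible, else one per head."""
    num_heads = masks.shape[0]
    if num_heads % num_head_shards != 0:
        raise ValueError("num_heads must be divisible by num_head_shards")
    heads_per_shard = num_heads // num_head_shards
    # Group contiguous heads by shard: [num_head_shards, heads_per_shard, q_len, kv_len].
    grouped = masks.reshape(num_head_shards, heads_per_shard, *masks.shape[1:])
    # A shard has a single unique mask if every head matches the shard's first head.
    one_mask_per_shard = all(
        np.array_equal(shard[0], head) for shard in grouped for head in shard
    )
    if one_mask_per_shard:
        return grouped[:, 0]  # [num_head_shards, q_len, kv_len], broadcast over the shard's heads
    return masks              # [num_heads, q_len, kv_len]

causal = np.tril(np.ones((4, 4), dtype=bool))
masks = np.stack([causal] * 4)               # 4 heads, all causal
print(select_leading_masks(masks, 2).shape)  # (2, 4, 4)
masks[1] = True                              # head 1 now attends everywhere
print(select_leading_masks(masks, 2).shape)  # (4, 4, 4)
```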
Attributes:
data_next: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
NumPy array where each entry contains the next kv block index to
prefetch.
mask_next: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
NumPy array where each entry contains the next mask block index in
partial_mask_blocks to prefetch.
block_mask: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
NumPy array whose entries can be 0, 1 or 2. An entry of 0 indicates that
the corresponding block in the full mask was all zeros. An entry of 1
indicates that the corresponding block in the full mask contained both
zeros and ones. An entry of 2 indicates the corresponding block was
entirely ones.
partial_mask_blocks: An i32[num_partial_blocks, block_q, block_kv] NumPy
array that contains the blocks of the original mask that contained both
zeros and ones. The entries in mask_next point to indices in the first
axis of this array.
q_sequence: An i32[q_sequence_length] NumPy array. When using causal masking,
this contains the list of indices that correspond to q tokens. For plain
causal this is just np.arange(q_sequence_length).
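The block_mask and partial_mask_blocks attributes can be illustrated by tiling a single 2-D mask into (block_q, block_kv) blocks and classifying each one. This is a minimal sketch, not the library's implementation: the name classify_blocks is hypothetical, identical partial blocks are deduplicated as an assumption, and mask_index simply records which partial block each grid position uses (-1 if none) instead of the "next block to prefetch" ordering that the real mask_next encodes. data_next is not modeled.

```python
import numpy as np

def classify_blocks(mask: np.ndarray, block_q: int, block_kv: int):
    """Hypothetical helper: classify blocks as empty (0), partial (1) or full (2)."""
    q_len, kv_len = mask.shape
    assert q_len % block_q == 0 and kv_len % block_kv == 0
    num_q_blocks, num_kv_blocks = q_len // block_q, kv_len // block_kv
    # View the mask as [num_q_blocks, block_q, num_kv_blocks, block_kv].
    blocks = mask.reshape(num_q_blocks, block_q, num_kv_blocks, block_kv)
    any_set = blocks.any(axis=(1, 3))
    all_set = blocks.all(axis=(1, 3))
    block_mask = np.where(all_set, 2, np.where(any_set, 1, 0)).astype(np.int32)

    # Store each distinct partial block once; mask_index points into that list.
    partial_blocks = []
    mask_index = np.full((num_q_blocks, num_kv_blocks), -1, dtype=np.int32)
    for i in range(num_q_blocks):
        for j in range(num_kv_blocks):
            if block_mask[i, j] != 1:
                continue
            block = mask[i * block_q:(i + 1) * block_q, j * block_kv:(j + 1) * block_kv]
            for k, existing in enumerate(partial_blocks):
                if np.array_equal(existing, block):
                    mask_index[i, j] = k
                    break
            else:
                mask_index[i, j] = len(partial_blocks)
                partial_blocks.append(block)
    if partial_blocks:
        partial_mask_blocks = np.stack(partial_blocks).astype(np.int32)
    else:
        partial_mask_blocks = np.zeros((0, block_q, block_kv), dtype=np.int32)
    return block_mask, partial_mask_blocks, mask_index

q_len = kv_len = 8
mask = np.tril(np.ones((q_len, kv_len), dtype=bool))  # plain causal mask
q_sequence = np.arange(q_len)                          # plain causal q_sequence
block_mask, partial_mask_blocks, mask_index = classify_blocks(mask, 4, 4)
print(block_mask)
# [[1 0]
#  [2 1]]
print(partial_mask_blocks.shape)  # (1, 4, 4)
```

For a causal mask, the diagonal blocks are partial and identical, so only one partial block needs to be stored; blocks above the diagonal are empty and can be skipped entirely, which is where the kernel saves work.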
Source code in src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask_info.py