Skip to content

pallas_operations.splash_attention.tpu.splash_attention_mask

Mini-mask creation library.

CausalMask

Bases: _ComputableMask

Lazy causal mask, prevents the model from attending to future tokens.

Attributes: offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 261–314).
class CausalMask(_ComputableMask):
    """Lazy causal mask that stops queries from attending to later keys.

    Entry ``[q, kv]`` is True when ``q + offset >= kv``.

    Attributes:
      offset: Offset of q start wrt kv. A positive offset shifts the bottom
        triangle upward, a negative one shifts it downward. A negative offset
        makes the first 'offset' rows of the attention matrix all 0s which leads
        to undefined softmax.
    """

    offset: int

    def __init__(
            self,
            shape: Tuple[int, int],
            offset: int = 0,
            shard_count: int = 1,
    ):
        self.offset = offset

        def causal_mask_function(q_ids, kv_ids):
            # _process_mask usually hands us numpy array views; skipping the
            # addition in the common offset == 0 case avoids materializing a
            # new array.
            if self.offset != 0:
                return q_ids + self.offset >= kv_ids
            return q_ids >= kv_ids

        super().__init__(
            shape=shape,
            mask_function=causal_mask_function,
            shard_count=shard_count,
        )

    def __eq__(self, other: object):
        if not isinstance(other, type(self)):
            return NotImplemented
        if self.shape != other.shape or self.offset != other.offset:
            return False
        return np.array_equal(self.q_sequence, other.q_sequence)

    def __hash__(self):
        q_seq = self.q_sequence
        q_seq_key = None if q_seq is None else q_seq.tobytes()
        return hash((type(self), self.shape, self.offset, q_seq_key))

FullMask dataclass

Bases: Mask

Lazy full mask, allows all tokens to attend to all other tokens.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 458–491).
@dataclasses.dataclass(frozen=True)
class FullMask(Mask):
    """Lazy mask in which every query may attend to every key.

    The mask is never stored densely; slicing it returns an all-True block
    sized to the requested (q, kv) slice.
    """

    # TODO(amagni): Transform FullMask into a _ComputableMask.

    _shape: tuple[int, int]

    def __post_init__(self):
        shape = self.shape
        if isinstance(shape, tuple):
            return
        raise ValueError(f'Unsupported shape type: {type(shape)}')

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    def __getitem__(self, idx) -> np.ndarray:
        if len(idx) != 2:
            raise NotImplementedError(f'Unsupported slice: {idx}')
        q_slice, kv_slice = idx
        # Only plain (q, kv) slice pairs are supported.
        if not (isinstance(q_slice, slice) and isinstance(kv_slice, slice)):
            raise NotImplementedError(f'Unsupported slice: {idx}')
        q_slice = _fill_slice(q_slice, self.shape[0])
        kv_slice = _fill_slice(kv_slice, self.shape[1])
        block_shape = (
            q_slice.stop - q_slice.start,
            kv_slice.stop - kv_slice.start,
        )
        return np.ones(block_shape, dtype=np.bool_)

    def __eq__(self, other: object):
        if isinstance(other, type(self)):
            return self.shape == other.shape
        return NotImplemented

    def __hash__(self):
        key = (type(self), self.shape)
        return hash(key)

LocalMask

Bases: Mask

Lazy local mask, prevents model from attending to tokens outside window.

Attributes:

- `_shape`: Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
- `window_size`: Size of the two sides of the local window (None identifies no limit for the given side).
- `offset`: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax.
- `_q_sequence`: Important for performance.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 317–416).
class LocalMask(Mask):
    """Lazy local mask, prevents model from attending to tokens outside window.

    Entry ``[q, kv]`` is kept when ``q + offset - left <= kv`` (if ``left`` is
    not None) and ``kv <= q + offset + right`` (if ``right`` is not None), where
    ``(left, right) = window_size``.

    Attributes:
      _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
      window_size: Size of the two sides of the local window (None identifies no
        limit for the given side).
      offset: Offset of q start wrt kv. A positive offset shifts the bottom
        triangle upward, a negative one shifts it downward. A negative offset
        makes the first 'offset' rows of the attention matrix all 0s which leads
        to undefined softmax.
      _q_sequence: Optional explicit query positions. When set, slicing rows
        ``q_slice`` uses ``_q_sequence[q_slice]`` as the query positions instead
        of ``arange(q_slice.start, q_slice.stop)``. Important for performance.
    """

    # TODO(amagni): Transform LocalMask into a _ComputableMask.

    _shape: Tuple[int, int]
    window_size: Tuple[int | None, int | None]
    offset: int
    _q_sequence: np.ndarray | None = None

    def __init__(
            self,
            shape: Tuple[int, int],
            window_size: Tuple[int | None, int | None],
            offset: int,
            shard_count: int = 1,
    ):
        self._shape = shape
        self.window_size = window_size
        self.offset = offset

        # shard_count is only used for this divisibility check; it is not
        # stored on the instance.
        if self.shape[0] % (shard_count * shard_count) != 0:
            raise ValueError(
                f'Shard count squared ({shard_count * shard_count}) must'
                f' divide Q seq_len ({self.shape[0]}) evenly.'
            )

    @property
    def shape(self) -> Tuple[int, int]:
        return self._shape

    def __getitem__(self, idx) -> np.ndarray:
        """Materializes the boolean block selected by a (q_slice, kv_slice)."""
        if len(idx) != 2:
            raise NotImplementedError(f'Unsupported slice: {idx}')
        q_slice, kv_slice = idx
        if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice):
            raise NotImplementedError(f'Unsupported slice: {idx}')

        q_slice = _fill_slice(q_slice, self.shape[0])
        kv_slice = _fill_slice(kv_slice, self.shape[1])

        if self._q_sequence is None:
            rows = np.arange(q_slice.start, q_slice.stop)
        else:
            rows = self._q_sequence[q_slice]

        cols = np.arange(kv_slice.start, kv_slice.stop)

        left_size, right_size = self.window_size

        if left_size is None and right_size is None:
            return np.ones((rows.shape[0], cols.shape[0]), dtype=np.bool_)
        else:
            expanded_cols = cols[None, :]
            # Skip the addition when offset == 0 to avoid an extra array.
            if self.offset != 0:
                expanded_rows = rows[:, None] + self.offset
            else:
                expanded_rows = rows[:, None]
            if left_size is not None and right_size is not None:
                return (expanded_rows <= expanded_cols + left_size) & (
                        expanded_cols - right_size <= expanded_rows
                )

            elif left_size is not None and right_size is None:
                return expanded_rows <= expanded_cols + left_size
            else:
                assert left_size is None and right_size is not None
                return expanded_cols - right_size <= expanded_rows

    def __eq__(self, other: object):
        if not isinstance(other, type(self)):
            return NotImplemented

        if (
                self.shape != other.shape
                or self.window_size != other.window_size
                or self.offset != other.offset
        ):
            return False

        # Compare q_sequences symmetrically and consistently with __hash__
        # (which hashes None and an array differently): equal only when both
        # are None, or both are arrays with identical contents.
        if self._q_sequence is None or other._q_sequence is None:
            return self._q_sequence is None and other._q_sequence is None
        return np.array_equal(self._q_sequence, other._q_sequence)

    def __hash__(self):
        return hash((
            type(self),
            self.shape,
            self.window_size,
            self.offset,
            self._q_sequence.tobytes() if self._q_sequence is not None else None,
        ))

Mask

A base class for splash attention masks.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 26–54).
class Mask:
    """Base class for splash attention masks.

    Subclasses supply ``shape`` and ``__getitem__``. Two masks of the same
    shape can be combined lazily with ``|`` and ``&``; converting a mask to
    bool is rejected so that ``or``/``and`` cannot be used by mistake.
    """

    @property
    def shape(self) -> Tuple[int, ...]:
        raise NotImplementedError

    def __getitem__(self, idx) -> np.ndarray:
        raise NotImplementedError

    def __bool__(self) -> bool:
        raise NotImplementedError(
            'Conversion to bool is unsupported. Could be caused by using logical'
            ' instead of bitwise operations on masks.'
        )

    def _check_same_shape(self, other: 'Mask') -> None:
        """Raises ValueError unless ``other`` has the same shape as ``self``."""
        expected = self.shape
        actual = other.shape
        if expected != actual:
            raise ValueError(
                f'Invalid shape for other: {actual}, expected: {expected}'
            )

    def __or__(self, other: 'Mask') -> 'Mask':
        self._check_same_shape(other)
        return LogicalOr(self, other)

    def __and__(self, other: 'Mask') -> 'Mask':
        self._check_same_shape(other)
        return LogicalAnd(self, other)

MultiHeadMask dataclass

Bases: Mask

Lazy multihead mask, combines multiple lazy masks one per head.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 147–193).
@dataclasses.dataclass
class MultiHeadMask(Mask):
    """Lazy multihead mask, combines multiple lazy masks one per head.

    Attributes:
      masks: One mask per head. All masks must share the same shape and none
        may itself be a MultiHeadMask. The resulting shape is
        ``(num_heads,) + masks[0].shape``.
    """

    masks: Sequence[Mask]

    def __post_init__(self):
        if not self.masks:
            raise ValueError('Unsupported empty tuple of masks')

        # Validate element types before reading ``.shape`` so a non-Mask
        # element raises the intended ValueError instead of an AttributeError.
        if not all(isinstance(mask, Mask) for mask in self.masks):
            raise ValueError('masks should be of type Mask')

        if any(isinstance(mask, MultiHeadMask) for mask in self.masks):
            raise ValueError('Nesting MultiHeadMasks is not supported')

        shape = self.masks[0].shape
        for mask in self.masks[1:]:
            if shape != mask.shape:
                raise ValueError(
                    f'Unexpected mask shape, got: {mask.shape}, expected: {shape}'
                )

    @property
    def shape(self) -> Tuple[int, ...]:
        return (len(self.masks),) + self.masks[0].shape

    def __getitem__(self, idx) -> np.ndarray:
        """Indexes with a (head, q, kv) triple.

        An integer head returns that head's 2-dim block; any other head index
        (e.g. a slice) selects several heads and stacks their blocks along a
        new leading axis.

        Raises:
          NotImplementedError: if ``idx`` does not have exactly 3 entries.
          IndexError: if an integer head index is outside [0, num_heads).
        """
        if len(idx) != 3:
            raise NotImplementedError(f'Unsupported slice: {idx}')

        head_slice = idx[0]
        # Keep the remaining index as a tuple; per-head masks (e.g. numpy
        # backed ones) may treat a list index differently.
        inner_idx = idx[1:]
        # Accept numpy integer scalars as well as Python ints.
        if isinstance(head_slice, (int, np.integer)):
            num_heads = len(self.masks)
            if not 0 <= head_slice < num_heads:
                raise IndexError(
                    f'Head index {head_slice} out of range for {num_heads} heads'
                )
            return self.masks[head_slice][inner_idx]

        slice_masks = [mask[inner_idx] for mask in self.masks[head_slice]]
        return np.stack(slice_masks)

    def __eq__(self, other: object):
        if not isinstance(other, type(self)):
            return NotImplemented

        return self.masks == other.masks

    def __hash__(self):
        return hash((type(self),) + tuple(hash(mask) for mask in self.masks))

NumpyMask dataclass

Bases: Mask

A mask backed by a dense numpy array.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 419–446).
@dataclasses.dataclass
class NumpyMask(Mask):
    """A mask backed by a dense numpy array.

    Attributes:
      array: 2-dim boolean array holding the mask values; validated in
        ``__post_init__``.
    """

    array: np.ndarray

    def __post_init__(self):
        if self.array.ndim != 2:
            raise ValueError('Expected a 2-dim array')

        if self.array.dtype != np.bool_:
            raise ValueError('Mask must be a boolean array')

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.array.shape

    def __getitem__(self, idx) -> np.ndarray:
        """Forwards ``idx`` directly to the underlying numpy array."""
        return self.array[idx]

    def __eq__(self, other: object):
        if not isinstance(other, type(self)):
            return NotImplemented

        return np.array_equal(self.array, other.array, equal_nan=True)

    def __hash__(self):
        # Include the shape: ``tobytes()`` alone is identical for arrays with
        # the same flattened data but different shapes (e.g. 2x8 vs 4x4),
        # which would always collide. The other masks also hash their shape.
        return hash((type(self), self.array.shape, self.array.tobytes()))

make_causal_mask(shape, offset=0)

Makes a causal attention mask.

Args: shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax.

Returns: The causal mask.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 57–73).
def make_causal_mask(shape: Tuple[int, int], offset: int = 0) -> np.ndarray:
    """Builds a dense causal attention mask.

    Entry ``[q, kv]`` is True exactly when ``q + offset >= kv``.

    Args:
      shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
      offset: Offset of q start wrt kv. A positive offset shifts the bottom
        triangle upward, a negative one shifts it downward. A negative offset
        makes the first 'offset' rows of the attention matrix all 0s which leads
        to undefined softmax.

    Returns:
      A boolean array of the given shape holding the causal mask.
    """
    num_q, num_kv = shape
    shifted_q_positions = np.arange(num_q, dtype=np.int32) + offset
    kv_positions = np.arange(num_kv, dtype=np.int32)
    # outer() compares every (shifted) query position with every key position.
    visible = np.greater_equal.outer(shifted_q_positions, kv_positions)
    return visible.astype(np.bool_)

make_local_attention_mask(shape, window_size, *, offset=0)

Makes a local attention mask.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 76–92).
def make_local_attention_mask(
        shape: Tuple[int, int],
        window_size: Tuple[int | None, int | None],
        *,
        offset: int = 0,
) -> np.ndarray:
    """Makes a local attention mask."""
    q_seq_len, kv_seq_len = shape
    q_idx = np.arange(q_seq_len, dtype=np.int32)
    kv_idx = np.arange(kv_seq_len, dtype=np.int32)
    mask = np.ones((q_seq_len, kv_seq_len), dtype=np.bool_)
    left, right = window_size
    if left is not None:
        mask = mask & (q_idx[:, None] - left + offset <= kv_idx[None, :])
    if right is not None:
        mask = mask & (q_idx[:, None] + right + offset >= kv_idx[None, :])
    return mask.astype(np.bool_)

make_random_mask(shape, sparsity, seed)

Makes a random attention mask.

Source code in `src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_mask.py` (lines 95–100).
def make_random_mask(
        shape: Tuple[int, int], sparsity: float, seed: int
) -> np.ndarray:
    """Makes a random attention mask.

    Each entry is independently True with probability ``1 - sparsity``.

    Args:
      shape: Shape of the mask.
      sparsity: Probability that an entry is False (masked out).
      seed: Seed for the random generator; the same seed always yields the
        same mask.

    Returns:
      A boolean array of the given shape.
    """
    # A private RandomState produces exactly the same stream as seeding the
    # legacy global generator (np.random.seed) and drawing from it, but does
    # not reset numpy's global RNG state as a side effect.
    rng = np.random.RandomState(seed)
    return rng.binomial(n=1, p=1.0 - sparsity, size=shape).astype(np.bool_)