Skip to content

pallas_operations.splash_attention.tpu.splash_attention_kernel

Implementation of Sparse Flash Attention, a.k.a. "Splash" attention.

BlockSizes dataclass

Tile sizes parameterizing SplashAttention kernels.

Those parameters have negligible effect on numerics, but affect performance greatly.

Note that changing the layouts only influences the physical layout that the kernel will enforce. The logical interface to splash attention always takes the head dimension as the minormost one.

Source code in src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_kernel.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
@dataclasses.dataclass(unsafe_hash=True)
class BlockSizes:
    """Tile sizes parameterizing SplashAttention kernels.

    Those parameters have negligible effect on numerics, but affect performance
    greatly.

    Note that changing the layouts only influences the physical layout that the
    kernel will enforce. The logical interface to splash attention always takes
    the head dimension as the minormost one.

    Defaults resolved in ``__post_init__``:
      * ``block_kv_compute`` falls back to ``block_kv``.
      * ``block_kv_dkv_compute`` falls back to ``block_kv_dkv`` (so it stays
        ``None`` when ``block_kv_dkv`` is not given either).

    Raises:
      ValueError: if ``use_fused_bwd_kernel`` is set while ``block_q_dq`` or
        ``block_kv_dq`` is also provided.
    """
    # NOTE(review): presumably the tile sizes of the forward kernel -- confirm
    # against the kernel implementation, which is not visible here.
    block_q: int
    block_kv: int
    block_kv_compute: int | None = None

    # Tile sizes checked by `has_backward_blocks` (dkv part).
    block_q_dkv: int | None = None
    block_kv_dkv: int | None = None
    block_kv_dkv_compute: int | None = None

    # Tile sizes checked by `has_backward_blocks` only when the fused backward
    # kernel is NOT used; they must be left unset when it is.
    block_q_dq: int | None = None
    block_kv_dq: int | None = None

    use_fused_bwd_kernel: bool = False

    q_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
    k_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
    v_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR

    def __post_init__(self):
        """Fill in derived defaults and reject contradictory settings."""
        if self.block_kv_compute is None:
            self.block_kv_compute = self.block_kv
        if self.block_kv_dkv_compute is None:
            self.block_kv_dkv_compute = self.block_kv_dkv
        # A fused backward kernel has no use for separate dq block sizes, so
        # passing them is treated as a configuration error.
        if self.use_fused_bwd_kernel and (
                self.block_q_dq is not None or self.block_kv_dq is not None
        ):
            raise ValueError(
                "Block sizes for dq kernel are not needed with a fused kernel."
            )

    @property
    def has_backward_blocks(self) -> bool:
        """Whether every block size required for the backward pass is set.

        The dkv sizes are always required; the dq sizes are required only when
        the fused backward kernel is not in use.
        """
        backward_blocks = (
            self.block_q_dkv, self.block_kv_dkv, self.block_kv_dkv_compute,
        )
        if not self.use_fused_bwd_kernel:
            backward_blocks += (self.block_q_dq, self.block_kv_dq)
        return all(b is not None for b in backward_blocks)

    @classmethod
    def get_default(cls):
        """Return an instance with every block size set to 128.

        Built through ``cls`` so that subclasses get an instance of their own
        type rather than a plain ``BlockSizes``.
        """
        # TODO(apaszke,sharadmv): Select better parameters based on a heuristic.
        return cls(
            block_q=128,
            block_kv=128,
            block_kv_compute=128,
            block_q_dkv=128,
            block_kv_dkv=128,
            block_kv_dkv_compute=128,
            block_q_dq=128,
            block_kv_dq=128,
        )

SegmentIds

Bases: NamedTuple

SegmentIds for Q and KV sequences.

SegmentIds are a mechanism to ensure that there is no cross-attention between segments (fractions of a sequence) that have been concatenated together into a sequence. Each array is a list of ids (integers). Only tokens with the same id are allowed to attend to each other.

The static mask (e.g. causal) is "and-ed" with the segment id mask to form the actual attention mask. It is important that the latter does not have any all-zero rows (along dimension kv). Otherwise it would result in an invalid softmax (the denominator would be 0). This condition holds for causal self-attention because in this case segment ids form a block diagonal matrix so at least one element in each row is set. It is easy to break this condition with non-self-attention configurations. Attributes: q: segment ids along the Q sequence kv: segment ids along the KV sequence

Source code in src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_kernel.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class SegmentIds(NamedTuple):
    """Segment ids for the Q and KV sequences.

    Segment ids are a mechanism that prevents cross-attention between segments
    (fractions of a sequence) which were concatenated into a single sequence.
    Each array holds one integer id per token, and a token may only attend to
    tokens carrying the same id.

    The resulting segment-id mask is "and-ed" with the static mask (e.g.
    causal) to produce the actual attention mask. That combined mask must not
    contain an all-zero row along the kv dimension: such a row would give the
    softmax a zero denominator and make it invalid.

    Causal self-attention satisfies this requirement, because there the segment
    ids form a block diagonal matrix and every row has at least one element
    set. Configurations that are not self-attention can easily violate it.

    Attributes:
      q: segment ids along the Q sequence
      kv: segment ids along the KV sequence
    """

    q: jax.Array  # [q_seq_len]
    kv: jax.Array  # [kv_seq_len]

SplashAttentionKernel

Source code in src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_kernel.py
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
@jax.tree_util.register_pytree_node_class
class SplashAttentionKernel:
    """Callable bundle of mask metadata and keyword options for splash attention.

    Calling an instance forwards to ``_splash_attention`` with the stored mask
    infos first, then the call's own arguments, then the stored keyword
    options.

    The class is registered as a JAX pytree node: the three mask infos are the
    children and the stored keyword options are the auxiliary data.
    """

    def __init__(
            self,
            fwd_mask_info: mask_info_lib.MaskInfo,
            dq_mask_info: mask_info_lib.MaskInfo | None,
            dkv_mask_info: mask_info_lib.MaskInfo | None,
            **kwargs,
    ):
        """Store the mask infos and the extra keyword options unchanged.

        Args:
          fwd_mask_info: mask info passed first to ``_splash_attention``.
          dq_mask_info: mask info passed second; may be ``None``.
          dkv_mask_info: mask info passed third; may be ``None``.
          **kwargs: options forwarded verbatim on every call.
        """
        self.kwargs = kwargs
        self.fwd_mask_info = fwd_mask_info
        self.dq_mask_info = dq_mask_info
        self.dkv_mask_info = dkv_mask_info

    def __call__(self, *args, **kwargs) -> SplashCustomReturnType:
        """Run ``_splash_attention`` with the stored mask infos and options.

        A keyword passed here that is also present in the stored options is
        rejected by Python as a duplicate keyword argument (``TypeError``).
        """
        return _splash_attention(
            self.fwd_mask_info,
            self.dq_mask_info,
            self.dkv_mask_info,
            *args,
            **kwargs,
            **self.kwargs,
        )

    def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
        """Returns a value that can be used as a shard_map partition spec for the kernel.

        The result is another ``SplashAttentionKernel`` whose mask-info fields
        hold partition specs (or ``None`` where this kernel's field is
        ``None``), so it mirrors the structure of this kernel.

        Raises:
          ValueError: if ``sharding`` cannot evenly split the ``data_next``
            array, or if it would split that array's last dimension.
          AssertionError: if ``sharding.spec`` does not have exactly two
            entries (skipped when Python runs with ``-O``).
        """
        if self.fwd_mask_info.data_next is not None:
            block_mask_shape = self.fwd_mask_info.data_next.shape
            try:
                shard_shape = sharding.shard_shape(block_mask_shape)
            except ValueError as exc:
                raise ValueError(
                    "The sharding must divide the mask blocks evenly between devices"
                ) from exc
            # An unchanged last dimension means that axis is not partitioned.
            if block_mask_shape[-1] != shard_shape[-1]:
                raise ValueError("Sharding the kv sequence dimension is not supported")
        spec = sharding.spec
        assert len(spec) == 2
        replicated = jax.sharding.PartitionSpec()
        # Shard q_sequence over the sequence dimension only.
        q_sequence_spec = jax.sharding.PartitionSpec(spec[1])
        mask_info_specs = mask_info_lib.MaskInfo(  # pytype: disable=wrong-arg-types
            data_next=spec if self.fwd_mask_info.data_next is not None else None,
            mask_next=spec if self.fwd_mask_info.mask_next is not None else None,
            block_mask=spec if self.fwd_mask_info.block_mask is not None else None,
            partial_mask_blocks=replicated
            if self.fwd_mask_info.partial_mask_blocks is not None
            else None,
            q_sequence=q_sequence_spec
            if self.fwd_mask_info.q_sequence is not None
            else None,
        )
        # The same spec object is reused for dq/dkv; only their presence
        # (None vs. not None) is taken from this kernel.
        return SplashAttentionKernel(
            mask_info_specs,
            mask_info_specs if self.dq_mask_info is not None else None,
            mask_info_specs if self.dkv_mask_info is not None else None,
            **self.kwargs,
        )

    def tree_flatten(self):
        """Return ``(children, aux_data)``: the three mask infos and the kwargs."""
        return (
            (self.fwd_mask_info, self.dq_mask_info, self.dkv_mask_info),
            self.kwargs,
        )

    @classmethod
    def tree_unflatten(cls, kwargs, values):
        """Rebuild a kernel from the output of ``tree_flatten``.

        Args:
          kwargs: the auxiliary data, i.e. the stored keyword options.
          values: the three children ``(fwd, dq, dkv)``; ``dq`` and ``dkv``
            may be ``None``.

        The instance is built through ``cls`` so that a registered subclass
        round-trips to its own type.
        """
        fwd_mask_info, dq_mask_info, dkv_mask_info = values
        # NamedTuples are not preserved during pytree serialization.
        dq_mask_info = (
            mask_info_lib.MaskInfo(*dq_mask_info)
            if dq_mask_info is not None
            else None
        )
        dkv_mask_info = (
            mask_info_lib.MaskInfo(*dkv_mask_info)
            if dkv_mask_info is not None
            else None
        )
        return cls(
            mask_info_lib.MaskInfo(*fwd_mask_info),
            dq_mask_info,
            dkv_mask_info,
            **kwargs,
        )

manual_sharding_spec(sharding)

Returns a value that can be used as a shard_map partition spec for the kernel.

Source code in src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_kernel.py
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
    """Returns a value that can be used as a shard_map partition spec for the kernel.

    The returned ``SplashAttentionKernel`` carries partition specs in place of
    the mask-info fields, with ``None`` wherever this kernel's corresponding
    field is ``None``.

    Raises:
      ValueError: if ``sharding`` cannot evenly split the ``data_next`` array,
        or if it would split that array's last dimension.
      AssertionError: if ``sharding.spec`` does not have exactly two entries.
    """
    fwd_info = self.fwd_mask_info

    if fwd_info.data_next is not None:
        full_shape = fwd_info.data_next.shape
        try:
            per_shard_shape = sharding.shard_shape(full_shape)
        except ValueError as exc:
            raise ValueError(
                "The sharding must divide the mask blocks evenly between devices"
            ) from exc
        # An unchanged last dimension means that axis is not partitioned.
        if per_shard_shape[-1] != full_shape[-1]:
            raise ValueError("Sharding the kv sequence dimension is not supported")

    spec = sharding.spec
    assert len(spec) == 2

    def _when_present(field, field_spec):
        # Mirror the optional structure: absent fields stay None.
        return None if field is None else field_spec

    replicated = jax.sharding.PartitionSpec()
    # Shard q_sequence over the sequence dimension only.
    q_sequence_spec = jax.sharding.PartitionSpec(spec[1])

    mask_info_specs = mask_info_lib.MaskInfo(  # pytype: disable=wrong-arg-types
        data_next=_when_present(fwd_info.data_next, spec),
        mask_next=_when_present(fwd_info.mask_next, spec),
        block_mask=_when_present(fwd_info.block_mask, spec),
        partial_mask_blocks=_when_present(fwd_info.partial_mask_blocks, replicated),
        q_sequence=_when_present(fwd_info.q_sequence, q_sequence_spec),
    )
    # The forward specs are reused for dq/dkv whenever those infos exist.
    dq_specs = _when_present(self.dq_mask_info, mask_info_specs)
    dkv_specs = _when_present(self.dkv_mask_info, mask_info_specs)
    return SplashAttentionKernel(
        mask_info_specs,
        dq_specs,
        dkv_specs,
        **self.kwargs,
    )

get_kernel_name(is_mqa, save_residuals, is_segmented, phase)

Returns a unique name for all SplashAttention kernel variants.

Source code in src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_kernel.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def get_kernel_name(
        is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str
) -> str:
    """Build the unique name identifying one SplashAttention kernel variant.

    The name has the form ``splash_<mqa|mha>_<phase>[_segmented][<residuals>]``
    where the residuals suffix is ``_residuals`` when residuals are saved,
    ``_no_residuals`` for a forward kernel that does not save them, and empty
    otherwise.

    Args:
      is_mqa: selects the ``mqa`` tag when true, ``mha`` otherwise.
      save_residuals: whether the kernel saves residuals (forward phase only).
      is_segmented: adds the ``_segmented`` suffix when true.
      phase: one of ``"fwd"``, ``"dq"`` or ``"dkv"``.

    Raises:
      AssertionError: if ``phase`` is not one of the three accepted values, or
        if ``save_residuals`` is requested for a phase other than ``"fwd"``.
    """
    assert phase in ("dq", "dkv", "fwd")
    # Only the forward phase is able to save residuals.
    assert phase == "fwd" or not save_residuals

    if save_residuals:
        residuals_suffix = "_residuals"
    elif phase == "fwd":
        residuals_suffix = "_no_residuals"
    else:
        residuals_suffix = ""

    if is_mqa:
        variant = "mqa"
    else:
        variant = "mha"
    segment_suffix = "_segmented" if is_segmented else ""

    base = "_".join(("splash", variant, phase))
    return base + segment_suffix + residuals_suffix