pallas_operations.splash_attention.tpu.splash_attention_kernel
Implementation of Sparse Flash Attention, a.k.a. "Splash" attention.
BlockSizes
dataclass
Tile sizes parameterizing SplashAttention kernels.
These parameters have a negligible effect on numerics but a large effect on performance.
Note that changing the layouts only influences the physical layout that the kernel enforces. The logical interface to splash attention always takes the head dimension as the minormost (last) one.
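Because splash attention is a sparse (block-sparse) variant of flash attention, one way to see why tile sizes matter for performance but not numerics is to count which tiles of the attention matrix contain any unmasked entry. The pure-Python sketch below assumes a causal mask and two hypothetical tile sizes `block_q` and `block_kv`; they are illustrative stand-ins, not the actual `BlockSizes` fields or the kernel's scheduling logic.

```python
from typing import List


def causal_tile_map(seq_len: int, block_q: int, block_kv: int) -> List[List[bool]]:
    """Mark which (q_block, kv_block) tiles of a causal mask contain any unmasked entry.

    block_q and block_kv are hypothetical tile sizes used only for illustration.
    """
    if seq_len % block_q or seq_len % block_kv:
        raise ValueError("seq_len must be divisible by both tile sizes")
    tiles = []
    for i in range(seq_len // block_q):
        q_last = (i + 1) * block_q - 1  # largest query index covered by tile row i
        row = []
        for j in range(seq_len // block_kv):
            kv_first = j * block_kv  # smallest key index covered by tile column j
            # Causal rule: query q may attend to key k iff k <= q, so the tile
            # has some unmasked entry iff its first key is <= its last query.
            row.append(kv_first <= q_last)
        tiles.append(row)
    return tiles


def visited_fraction(seq_len: int, block_q: int, block_kv: int) -> float:
    """Fraction of tiles that a block-sparse kernel would still have to compute."""
    tiles = causal_tile_map(seq_len, block_q, block_kv)
    total = sum(len(row) for row in tiles)
    return sum(sum(row) for row in tiles) / total


print(causal_tile_map(8, 4, 4))   # [[True, False], [True, True]]
print(visited_fraction(8, 2, 2))  # 0.625
print(visited_fraction(8, 8, 8))  # 1.0
```

In this toy case smaller tiles skip more of the masked region (0.625 of tiles at size 2 versus all of them at size 8), while larger tiles amortize per-tile overhead; the softmax result is the same either way, which is consistent with the note that block sizes barely affect numerics.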
Source code in src/fjformer/pallas_operations/splash_attention/tpu/splash_attention_kernel.py
SegmentIds
Bases: NamedTuple
SegmentIds for Q and KV sequences.
SegmentIds are a mechanism to ensure that there is no cross-attention between segments (fractions of a sequence) that have been concatenated together into a single sequence. Each array is a list of integer ids. Only tokens with the same id are allowed to attend to each other.
The static mask (e.g. causal) is "and-ed" with the segment id mask to form the actual attention mask. It is important that the resulting mask does not have any all-zero rows (along the kv dimension); otherwise the softmax would be invalid, because its denominator would be 0. This condition holds for causal self-attention: the segment id mask is block diagonal and the causal mask keeps the diagonal, so every query can attend at least to itself. It is easy to break this condition with non-self-attention configurations.
Attributes:
q: segment ids along the Q sequence.
kv: segment ids along the KV sequence.
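The masking rule and the all-zero-row hazard described above can be checked on small inputs. This is a minimal sketch, assuming plain Python lists in place of arrays; `ToySegmentIds`, `attention_mask`, and `all_zero_rows` are hypothetical names, and the causal rule used (key index <= query index) is a simplification for illustration, not the kernel's implementation.

```python
from typing import List, NamedTuple


class ToySegmentIds(NamedTuple):
    """Hypothetical stand-in for SegmentIds, using plain lists instead of arrays."""
    q: List[int]
    kv: List[int]


def attention_mask(seg: ToySegmentIds, causal: bool = True) -> List[List[bool]]:
    """AND the segment-id mask with an optional causal mask; one row per query."""
    mask = []
    for i, q_id in enumerate(seg.q):
        row = []
        for j, kv_id in enumerate(seg.kv):
            allowed = q_id == kv_id  # tokens may only attend within the same segment
            if causal:
                allowed = allowed and j <= i  # simplified causal rule: key index <= query index
            row.append(allowed)
        mask.append(row)
    return mask


def all_zero_rows(mask: List[List[bool]]) -> List[int]:
    """Query rows with no allowed key; a softmax over such a row would divide by 0."""
    return [i for i, row in enumerate(mask) if not any(row)]


# Causal self-attention over two packed segments: every row keeps its diagonal.
self_attn = ToySegmentIds(q=[0, 0, 1, 1], kv=[0, 0, 1, 1])
print(all_zero_rows(attention_mask(self_attn)))  # []

# Cross-attention where query 0 belongs to a segment absent from the keys.
cross_attn = ToySegmentIds(q=[0, 1], kv=[1, 1, 1])
print(all_zero_rows(attention_mask(cross_attn)))  # [0]
```

The second case shows how easily a non-self-attention configuration produces an empty row: nothing in the segment ids themselves prevents it, so callers have to guarantee it.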
SplashAttentionKernel
manual_sharding_spec(sharding)
Returns a value that can be used as a shard_map partition spec for the kernel.
get_kernel_name(is_mqa, save_residuals, is_segmented, phase)
Returns a unique name for each SplashAttention kernel variant.
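One way to guarantee a unique name per variant is to give each argument its own fixed-position token. The sketch below is hypothetical: `sketch_kernel_name` and the phase labels `"fwd"`, `"dq"`, `"dkv"` are assumptions for illustration, and the library's actual name format and accepted `phase` values may differ.

```python
from itertools import product

# Assumed phase labels for illustration; the real kernel may use different strings.
PHASES = ("fwd", "dq", "dkv")


def sketch_kernel_name(is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str) -> str:
    """Hypothetical naming scheme: one fixed-position token per argument."""
    if phase not in PHASES:
        raise ValueError(f"unknown phase: {phase!r}")
    attention = "mqa" if is_mqa else "mha"
    segments = "_segmented" if is_segmented else ""
    residuals = "_residuals" if save_residuals else ""
    return f"splash_{attention}_{phase}{segments}{residuals}"


# Every combination of the four arguments yields a distinct name.
names = {
    sketch_kernel_name(mqa, res, seg, phase)
    for mqa, res, seg in product((False, True), repeat=3)
    for phase in PHASES
}
print(len(names))  # 24
print(sketch_kernel_name(True, False, True, "dq"))  # splash_mqa_dq_segmented
```

Because the tokens appear in a fixed order and no token is a prefix or suffix of another in the same position, distinct argument combinations cannot collide.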