Computes the softmax of the input array along the specified axis.
Args:
x: input array
axis: the axis along which to perform the computation
num_warps: the number of warps to use for executing the Triton kernel
interpret: whether to run the kernel in Pallas interpret mode instead of compiling it
debug: whether to use pallas in debug mode
Returns:
The result of the softmax operation over the specified axis of x.
Source code in src/fjformer/pallas_operations/softmax/gpu/softmax.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
@functools.partial(
    jax.jit,
    static_argnames=["axis", "num_warps", "interpret", "debug"],
)
def softmax(
    x: jax.Array,
    *,
    axis: int = -1,
    num_warps: int = 4,
    interpret: bool = False,
    debug: bool = False,
) -> jax.Array:
    """Apply softmax to ``x`` along its last axis via a Pallas kernel.

    A single-row kernel call is built with ``pl.pallas_call`` and then wrapped
    in ``jax.vmap`` once per leading dimension of ``x``, so every row along
    the trailing axis is processed by the same kernel.

    Args:
      x: input array.
      axis: axis to reduce over. Negative values count from the end. Only
        the trailing axis is accepted.
      num_warps: forwarded to ``pl.pallas_call`` as ``num_warps``.
      interpret: forwarded to ``pl.pallas_call`` as ``interpret``.
      debug: forwarded to ``pl.pallas_call`` as ``debug``.

    Returns:
      The result of the softmax operation over the specified axis of ``x``.

    Raises:
      NotImplementedError: if ``axis`` does not resolve to the last axis.
    """
    rank = len(x.shape)
    last_axis = rank - 1

    # Map a negative axis onto its non-negative equivalent before comparing.
    resolved_axis = axis + rank if axis < 0 else axis
    if resolved_axis != last_axis:
        raise NotImplementedError(
            "reductions along non-trailing dimension unsupported")

    n_cols = x.shape[-1]
    # The kernel receives the row length rounded by ``pl.next_power_of_2``
    # as its ``block_row`` argument.
    padded_cols = pl.next_power_of_2(n_cols)
    row_struct = jax.ShapeDtypeStruct(shape=(n_cols,), dtype=x.dtype)
    row_kernel = functools.partial(
        _vmappable_softmax_kernel, block_row=padded_cols)
    row_softmax = pl.pallas_call(
        row_kernel,
        num_warps=num_warps,
        num_stages=1,
        grid=(),
        out_shape=row_struct,
        debug=debug,
        interpret=interpret,
    )

    # One ``jax.vmap`` layer per leading (batch) dimension of ``x``.
    batched_softmax = functools.reduce(
        lambda fn, _: jax.vmap(fn), range(last_axis), row_softmax)
    return batched_softmax(x)
|