Skip to content

pallas_operations.softmax.gpu.softmax

Pallas softmax kernel.

softmax(x, *, axis=-1, num_warps=4, interpret=False, debug=False)

Computes the softmax of the input array along the specified axis.

Args:
- x: input array
- axis: the axis along which to perform the computation
- num_warps: the number of warps to use for executing the Triton kernel
- interpret: whether to interpret the kernel using Pallas
- debug: whether to use Pallas in debug mode

Returns: The result of the softmax operation over the specified axis of x.

Source code in src/fjformer/pallas_operations/softmax/gpu/softmax.py (lines 51–86):
@functools.partial(
    jax.jit, static_argnames=("axis", "num_warps", "interpret", "debug")
)
def softmax(
        x: jax.Array, *, axis: int = -1, num_warps: int = 4,
        interpret: bool = False, debug: bool = False
) -> jax.Array:
    """Softmax of ``x`` over one axis, computed by a Pallas (Triton) kernel.

    Only the trailing axis is supported. A single-row ``pallas_call`` is
    built for the last dimension, then ``jax.vmap`` is applied once per
    leading dimension so the same row kernel covers every row of ``x``.

    Args:
      x: input array.
      axis: axis to normalize over; negative values count from the end.
        Must resolve to the last axis of ``x``.
      num_warps: number of warps passed to the Triton kernel launch.
      interpret: if True, run the kernel through the Pallas interpreter.
      debug: if True, run ``pallas_call`` in debug mode.

    Returns:
      An array with the same shape and dtype as ``x`` holding the softmax
      along the trailing axis.

    Raises:
      NotImplementedError: if ``axis`` does not resolve to the last dimension.
    """
    ndim = len(x.shape)
    if axis < 0:
        axis += ndim
    if axis != ndim - 1:
        raise NotImplementedError(
            "reductions along non-trailing dimension unsupported")

    n_cols = x.shape[-1]
    # The block size is rounded up to a power of two; presumably the row
    # kernel masks lanes past n_cols — confirm in _vmappable_softmax_kernel.
    row_kernel = functools.partial(
        _vmappable_softmax_kernel, block_row=pl.next_power_of_2(n_cols))
    row_softmax = pl.pallas_call(
        row_kernel,
        out_shape=jax.ShapeDtypeStruct(shape=(n_cols,), dtype=x.dtype),
        grid=(),
        num_warps=num_warps,
        num_stages=1,
        debug=debug,
        interpret=interpret,
    )

    # Lift the 1-D row kernel over every leading dimension, one vmap per dim.
    batched = functools.reduce(
        lambda fn, _: jax.vmap(fn), range(ndim - 1), row_softmax)
    return batched(x)