
func._func

average_metrics(metrics)

The average_metrics function takes a sequence of metric pytrees (for example, one per batch) that share the same structure and averages them leaf by leaf.

Parameters:

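| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics | | Sequence of metric pytrees (for example, one dict of scalars per batch); all entries must have the same tree structure | required |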

Returns:

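| Type | Description |
| --- | --- |
| | A pytree with the same structure as each entry of metrics, where every leaf is the mean of the corresponding leaves across all entries |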

Source code in src/fjformer/func/_func.py
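import jax
import jax.numpy as jnp


def average_metrics(metrics):
    """
    The average_metrics function takes a sequence of metric pytrees and averages them leaf by leaf.

    :param metrics: Sequence of metric pytrees (e.g. one per batch) with identical tree structure
    :return: A pytree whose leaves are the means of the corresponding leaves across all entries

    """
    # jax.tree_map is deprecated in favor of jax.tree_util.tree_map.
    # Stack each leaf across entries and average over that new leading axis only,
    # so non-scalar leaves keep their shape.
    return jax.tree_util.tree_map(
        lambda *args: jnp.mean(jnp.stack(args), axis=0),
        *metrics
    )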
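A minimal usage sketch, assuming per-batch metrics are dicts of scalar arrays. The function is redefined locally (with an explicit axis=0 mean) so the snippet runs without fjformer installed; for scalar leaves like these it gives the same result as the source above. batch_metrics is hypothetical illustration data.

```python
import jax
import jax.numpy as jnp


def average_metrics(metrics):
    # Local copy for illustration: stack each leaf across entries and average
    # over the new leading axis only.
    return jax.tree_util.tree_map(
        lambda *args: jnp.mean(jnp.stack(args), axis=0),
        *metrics,
    )


# Hypothetical per-batch metrics; every entry must share the same pytree structure.
batch_metrics = [
    {"loss": jnp.array(1.0), "accuracy": jnp.array(0.50)},
    {"loss": jnp.array(3.0), "accuracy": jnp.array(0.75)},
]

mean_metrics = average_metrics(batch_metrics)
print(float(mean_metrics["loss"]), float(mean_metrics["accuracy"]))  # 2.0 0.625
```

If the entries have different tree structures, jax.tree_util.tree_map raises an error rather than silently averaging mismatched leaves.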

fused_softmax(x, axis=-1)

The fused_softmax function computes the softmax of x along axis as exp(log_softmax(x)); it returns the same values as jax.nn.softmax up to floating-point rounding.
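Written out, exponentiating the log-softmax recovers the usual softmax, because the log-sum-exp normalizer is subtracted before exponentiating (sums run over the entries along axis):

$$
\operatorname{softmax}(x)_i = \exp\big(\log\operatorname{softmax}(x)_i\big) = \exp\Big(x_i - \log \sum_j e^{x_j}\Big) = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$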

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input array (chex.Array) | required |
| axis | int | Axis along which the softmax is computed | -1 |

Returns:

| Type | Description |
| --- | --- |
| | An array with the same shape as x whose entries along axis are non-negative and sum to 1, i.e. the softmax of x |

Source code in src/fjformer/func/_func.py
import chex
import jax
import jax.numpy as jnp


def fused_softmax(x: chex.Array, axis: int = -1):
    """
    The fused_softmax function computes the softmax of x as exp(log_softmax(x)).

    :param x: chex.Array: Input array
    :param axis: int: Axis along which the softmax is computed
    :return: The softmax of x along axis, equal to jax.nn.softmax(x, axis=axis) up to rounding

    """
    # Exponentiate the log-probabilities to obtain the probabilities.
    return jnp.exp(jax.nn.log_softmax(x, axis=axis))
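A quick check, assuming only JAX is installed: the snippet redefines fused_softmax locally and compares it with jax.nn.softmax, including a row of large logits where a naive exp(x) / sum(exp(x)) would overflow in float32. The logits values are arbitrary illustration data.

```python
import jax
import jax.numpy as jnp


def fused_softmax(x, axis=-1):
    # Local copy of the documented function: softmax computed as exp(log_softmax).
    return jnp.exp(jax.nn.log_softmax(x, axis=axis))


logits = jnp.array([[1.0, 2.0, 3.0],
                    [1000.0, 1001.0, 1002.0]])  # exp(1000.0) alone overflows float32

probs = fused_softmax(logits, axis=-1)
reference = jax.nn.softmax(logits, axis=-1)

print(bool(jnp.allclose(probs, reference, atol=1e-6)))  # True
print(bool(jnp.all(jnp.isfinite(probs))))               # True
```

Both rows differ only by a constant shift, so they produce the same probabilities; softmax is invariant to adding a constant along axis.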

global_norm(tree)

Return the global L2 norm of a pytree.
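The global norm treats all leaves of the pytree as one concatenated vector. With $\ell$ ranging over the leaves of the pytree $\theta$ and $i$ over the elements of leaf $x_\ell$:

$$
\lVert \theta \rVert_2 = \sqrt{\sum_{\ell} \sum_{i} x_{\ell,i}^2}
$$

The inner sum is the per-leaf sum of squares; the outer sum and square root combine the leaves.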

Source code in src/fjformer/func/_func.py
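import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree  # import explicitly; `import jax` alone may not load jax.flatten_util

def global_norm(tree):
    """ Return the global L2 norm of a pytree. """
    squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)  # per-leaf sum of squares
    flattened, _ = ravel_pytree(squared)
    return jnp.sqrt(jnp.sum(flattened))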
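A small sanity check, with a hypothetical two-leaf grads pytree: global_norm should equal the ordinary L2 norm of all leaves flattened into one vector. The function is redefined locally so the snippet runs standalone.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def global_norm(tree):
    # Per-leaf sum of squares, then the square root of the total.
    squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
    flattened, _ = ravel_pytree(squared)
    return jnp.sqrt(jnp.sum(flattened))


# Hypothetical "gradient" pytree whose leaves have different shapes.
grads = {"w": jnp.array([[3.0, 0.0], [0.0, 0.0]]), "b": jnp.array([4.0])}

norm = global_norm(grads)
# Reference: concatenate every leaf into one vector and take its L2 norm.
leaves = jax.tree_util.tree_leaves(grads)
manual = jnp.linalg.norm(jnp.concatenate([leaf.ravel() for leaf in leaves]))
print(float(norm), float(manual))  # 5.0 5.0
```

A common use is clipping gradients by their global norm; optax provides an equivalent optax.global_norm.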

transpose(array, dim0, dim1)

The transpose function takes an array and two dimensions and returns a new array with those two dimensions swapped. Each dimension is either a non-negative integer, where 0 is the outermost dimension of the array, or a negative integer counted from the end of the shape tuple; for example, -2 is equivalent to len(shape) - 2.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| array | Array | Array whose dimensions are swapped (chex.Array) | required |
| dim0 | int | First dimension to swap; may be negative | required |
| dim1 | int | Second dimension to swap; may be negative | required |

Returns:

| Type | Description |
| --- | --- |
| | A new array with the same data as array, with dimensions dim0 and dim1 swapped |

Source code in src/fjformer/func/_func.py
import chex
import jax.numpy as jnp


def transpose(array: chex.Array, dim0: int, dim1: int):
    """
    The transpose function takes an array and two dimensions, and returns a new
    array with those two dimensions swapped. Each dimension is a non-negative
    integer, where 0 is the outermost dimension, or a negative integer counted
    from the end of the shape tuple; -2 is equivalent to len(shape) - 2.

    :param array: chex.Array: Array whose dimensions are swapped
    :param dim0: int: First dimension to swap
    :param dim1: int: Second dimension to swap
    :return: A new array with the same data, with dim0 and dim1 swapped

    """
    # Map negative axes to their non-negative equivalents (-1 -> ndim - 1); 0 stays 0.
    dim0 = dim0 if dim0 >= 0 else array.ndim + dim0
    dim1 = dim1 if dim1 >= 0 else array.ndim + dim1
    perm = list(range(array.ndim))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return jnp.transpose(array, perm)
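A minimal sketch of the intended semantics, assuming JAX only: here negative axes are normalized with Python's modulo plus an explicit range check (the ValueError is an addition for illustration, not fjformer behavior), and the result is compared against jnp.swapaxes, which swaps two axes in the same way.

```python
import jax.numpy as jnp


def transpose(array, dim0, dim1):
    """Swap two axes of `array`; negative axes count from the end."""
    ndim = array.ndim
    # Illustrative range check (an assumption, not fjformer's behavior).
    for d in (dim0, dim1):
        if not -ndim <= d < ndim:
            raise ValueError(f"axis {d} is out of range for a {ndim}-D array")
    # Modulo maps -1 -> ndim - 1, -2 -> ndim - 2, and leaves 0..ndim-1 unchanged.
    dim0, dim1 = dim0 % ndim, dim1 % ndim
    perm = list(range(ndim))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return jnp.transpose(array, perm)


x = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

print(transpose(x, 0, 1).shape)    # (3, 2, 4, 5)
print(transpose(x, -1, -2).shape)  # (2, 3, 5, 4)
# jnp.swapaxes swaps the same pair of axes, so the results match element for element.
print(bool(jnp.array_equal(transpose(x, 1, -1), jnp.swapaxes(x, 1, -1))))  # True
```

In practice, jnp.swapaxes(array, dim0, dim1) expresses the same operation directly.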