Skip to content

modules.jetmoe.modelling_jetmoe_flax

compute_gating(k, num_experts, top_k_gates, top_k_indices)

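Compute the expert-sorted gate values, the source row of each sorted entry, and the per-expert assignment counts for a mixture of experts, given the top-k gate values and top-k expert indices.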

Source code in src/python/easydel/modules/jetmoe/modelling_jetmoe_flax.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def compute_gating(k: int, num_experts: int, top_k_gates: jnp.ndarray, top_k_indices: jnp.ndarray) -> Tuple[
    chex.Array, chex.Array, chex.Array, chex.Array
]:
    """
    Derive expert-grouped routing information from top-k gate selections.

    Args:
        k: Divisor that maps a position in the flattened selections back to a
            row index (presumably the number of experts picked per row, i.e.
            ``top_k_indices.shape[1]`` -- confirm against callers).
        num_experts: Number of columns in the selection mask, and therefore the
            length of the returned ``expert_size``.
        top_k_gates: Gate values of the selected experts. Its leading dimension
            is taken as the number of rows.
        top_k_indices: Expert ids of the selected experts, used as column
            indices for each row (assumed to be laid out like ``top_k_gates``
            -- TODO confirm).

    Returns:
        A tuple ``(batch_gates, batch_index, expert_size, index_sorted_experts)``:

        * ``batch_gates``: flattened gate values reordered so that entries are
          grouped by ascending expert id.
        * ``batch_index``: for each reordered entry, its flattened position
          integer-divided by ``k``.
        * ``expert_size``: ``int32`` vector of length ``num_experts`` holding,
          per expert, the number of rows whose selection contains that expert.
        * ``index_sorted_experts``: the argsort of the flattened expert ids.
    """
    num_rows = top_k_gates.shape[0]

    # Permutation of the flattened selections that groups them by expert id.
    index_sorted_experts = jnp.argsort(top_k_indices.flatten())
    batch_gates = top_k_gates.flatten()[index_sorted_experts]
    batch_index = lax.div(index_sorted_experts, k)

    # Dense (rows x experts) mask with a 1 wherever a row selected an expert;
    # the column vector of row ids broadcasts against the per-row expert ids.
    row_ids = jnp.arange(num_rows)[:, None]
    blank_mask = jnp.zeros([num_rows, num_experts], dtype=top_k_gates.dtype)
    selection_mask = blank_mask.at[row_ids, top_k_indices].set(1)

    # Column sums of the mask give the per-expert load.
    expert_size = selection_mask.astype(jnp.int32).sum(axis=0)

    return batch_gates, batch_index, expert_size, index_sorted_experts