Compute gating values for the mixture of experts based on probabilities and top-k indices.
Source code in src/python/easydel/modules/jetmoe/modelling_jetmoe_flax.py
91
92
93
94
95
96
97
98
99
100
101
102
103
104
def compute_gating(k: int, num_experts: int, top_k_gates: jnp.ndarray, top_k_indices: jnp.ndarray) -> Tuple[
    chex.Array, chex.Array, chex.Array, chex.Array
]:
    """
    Build the dispatch bookkeeping for a top-k mixture-of-experts router.

    The flattened (row, slot) routing entries are ordered by expert id so that
    all entries destined for the same expert become contiguous.

    Args:
        k: Divisor used to map a flattened entry position back to its row;
            presumably the number of experts selected per row, i.e. the
            trailing dimension of ``top_k_indices`` (TODO confirm with caller).
        num_experts: Width of the per-expert count vector.
        top_k_gates: Gate values for the selected experts. Only ``shape[0]``
            and ``dtype`` are read before flattening; assumed to share the
            layout of ``top_k_indices``.
        top_k_indices: Expert ids selected for each row; must broadcast
            against a ``(rows, 1)`` row-index column.

    Returns:
        A tuple ``(batch_gates, batch_index, expert_size, index_sorted_experts)``:
        the flattened gate values reordered by expert id, the flattened
        positions integer-divided by ``k``, the int32 number of rows that
        selected each expert, and the argsort permutation itself.
    """
    num_rows = top_k_gates.shape[0]
    row_ids = jnp.arange(num_rows)[:, None]

    # One-hot style membership mask: entry (row, expert) is 1 when that row
    # picked that expert. A repeated expert id within a row still yields 1.
    selection_mask = (
        jnp.zeros([num_rows, num_experts], dtype=top_k_gates.dtype)
        .at[row_ids, top_k_indices]
        .set(1)
    )
    expert_size = selection_mask.astype(jnp.int32).sum(axis=0)

    # Sorting the flattened expert ids groups entries by expert.
    index_sorted_experts = jnp.argsort(top_k_indices.flatten())

    # Reorder the flattened gates to follow the expert-sorted order.
    batch_gates = top_k_gates.flatten()[index_sorted_experts]

    # Flattened position // k recovers the originating row of each entry.
    batch_index = lax.div(index_sorted_experts, k)

    return batch_gates, batch_index, expert_size, index_sorted_experts
|