Skip to content

reinforcement_learning.core

add_suffix(input_dict, suffix)

Add suffix to dict keys.

Source code in src/python/easydel/reinforcement_learning/core.py
92
93
94
def add_suffix(input_dict, suffix):
    """Return a new dict whose keys are the original keys with ``suffix`` appended.

    Values are carried over unchanged and insertion order is preserved; the
    input mapping is not modified.
    """
    return {key + suffix: value for key, value in input_dict.items()}

multinomial(logits, num_samples, replacement=False)

Implements the torch.multinomial function in JAX.

Args: logits (jnp.array): The unnormalized log probabilities of the events. num_samples (int): The number of samples to draw. replacement (bool): Whether to sample with replacement; if False, each category is drawn at most once. Returns: jnp.array: An array of shape (num_samples, batch_size) containing the sampled indices.

Source code in src/python/easydel/reinforcement_learning/core.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def multinomial(logits, num_samples: int, replacement: bool = False, key=None):
    """
    Implements the `torch.multinomial` function in JAX.

    Draws category indices from the distribution defined by ``logits`` over
    the last axis. Any leading axes are treated as batch dimensions.

    Args:
        logits (jnp.array): The unnormalized log probabilities of the events,
            with categories on the last axis.
        num_samples (int): The number of samples to draw per distribution.
        replacement (bool): If True, samples are drawn independently (with
            replacement). If False, each category is drawn at most once per
            distribution, so ``num_samples`` must not exceed the number of
            categories.
        key (jax.Array, optional): PRNG key used for sampling. When omitted,
            ``jax.random.PRNGKey(0)`` is used, so results are deterministic
            across calls; pass a fresh key for independent draws.

    Returns:
        jnp.array: An integer array of shape ``(num_samples, *batch_shape)``
            containing the sampled indices (``(num_samples,)`` for 1-D logits).

    Raises:
        ValueError: If ``replacement`` is False and ``num_samples`` exceeds
            the number of categories.
    """
    if key is None:
        # NOTE(review): a fixed fallback key keeps key-less call sites working,
        # but every such call yields the same draw — callers should pass a key.
        key = jax.random.PRNGKey(0)
    # Normalizing does not change the sampling distribution; it only keeps the
    # values in a well-scaled range.
    logits = jax.nn.log_softmax(jnp.asarray(logits), axis=-1)
    batch_shape = logits.shape[:-1]

    if replacement:
        # `shape` must broadcast against the batch shape; putting num_samples
        # first produces the documented (num_samples, *batch_shape) layout.
        return jax.random.categorical(
            key, logits, axis=-1, shape=(num_samples,) + batch_shape
        )

    num_categories = logits.shape[-1]
    if num_samples > num_categories:
        raise ValueError(
            f"cannot draw {num_samples} samples without replacement from "
            f"{num_categories} categories"
        )
    # Gumbel-top-k trick: perturbing the logits with i.i.d. Gumbel noise and
    # taking the top-k indices is equivalent to sequential sampling without
    # replacement. It works per row, so batched logits are handled correctly.
    gumbel = jax.random.gumbel(key, logits.shape, dtype=logits.dtype)
    _, indices = jax.lax.top_k(logits + gumbel, num_samples)
    # top_k puts the samples on the last axis; move them to the front.
    return jnp.moveaxis(indices, -1, 0)