
reinforcement_learning.models.modelling_casual_language_rl

ValueHead

Bases: Module

Source code in src/python/easydel/reinforcement_learning/models/modelling_casual_language_rl.py
from typing import Callable, Optional

import chex
import flax.linen as nn
import jax
import jax.numpy as jnp

# `Linear` is imported elsewhere in this module (its import is not shown in this excerpt).


class ValueHead(nn.Module):
    summary_dropout_prob: float = 0.0
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest")
    kernel_init: Callable = nn.initializers.orthogonal()

    def setup(self):
        """
        Create the submodules of the value head.

        Flax calls setup lazily when the module is bound (for example inside
        ``init`` or ``apply``), not in the constructor. It registers a dropout
        layer and a bias-free ``Linear`` projection with a single output unit.
        """
        # Dropout applied to the incoming hidden states during training.
        self.dropout = nn.Dropout(self.summary_dropout_prob)

        # Projects hidden_size -> 1, giving one scalar value per position.
        self.summary = Linear(
            1,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
            use_bias=False,
        )

    def __call__(self, hidden_states: chex.Array, deterministic: bool = True):
        """
        Compute one scalar value per position of ``hidden_states``.

        :param hidden_states: chex.Array: hidden states from the preceding layer,
            with shape (..., hidden_size)
        :param deterministic: bool: if True, dropout is disabled; if False, a
            "dropout" PRNG key must be supplied to ``apply``
        :return: an array of shape (..., 1), e.g. (batch_size, seq_len, 1)
        """
        return self.summary(self.dropout(hidden_states, deterministic=deterministic))

__call__(hidden_states, deterministic=True)

Compute one scalar value for every position of hidden_states. The hidden states first pass through dropout (skipped when deterministic is True) and are then projected to a single output unit by the bias-free summary layer.
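Written out, and assuming `Linear` behaves like a standard bias-free dense layer (its definition is not shown on this page), the computation for batch index b and position t is as follows. Here h is the input of hidden size H, W is the (H, 1) summary kernel, p is summary_dropout_prob, and the dropout case follows the inverted-dropout convention of `flax.linen.Dropout`.

$$
v_{b,t} = \sum_{i=1}^{H} \tilde{h}_{b,t,i}\, W_{i,1},
\qquad
\tilde{h}_{b,t,i} =
\begin{cases}
h_{b,t,i} & \text{if deterministic} \\[4pt]
\dfrac{m_{b,t,i}\, h_{b,t,i}}{1-p} & \text{otherwise, with } m_{b,t,i} \sim \mathrm{Bernoulli}(1-p)
\end{cases}
$$

Scaling the surviving activations by 1/(1-p) keeps the expected value of each element of \(\tilde{h}\) equal to \(h\), so no rescaling is needed at evaluation time.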

Parameters:

hidden_states (Array, required): hidden states from the preceding layer, with shape (..., hidden_size).

deterministic (bool, default True): if True, dropout is disabled; if False, a "dropout" PRNG key must be supplied to apply().

Returns:

Array: one value per position, with shape (..., 1); for input of shape (batch_size, seq_len, hidden_size) this is (batch_size, seq_len, 1).


setup()

Create the value head's submodules: a dropout layer with rate summary_dropout_prob and a bias-free Linear projection to a single output unit, built with the module's dtype, param_dtype, precision and kernel_init. Flax calls setup lazily when the module is bound (for example inside init or apply), not in the constructor; it only declares submodules and does not load weights or create an optimizer.

Parameters:

None besides self.