reinforcement_learning.models.modelling_casual_language_rl
ValueHead
Bases: Module
Source code in src/python/easydel/reinforcement_learning/models/modelling_casual_language_rl.py
__call__(hidden_states, deterministic=True)
The __call__ method is the module's forward pass. It runs when the module is invoked like a function, e.g. head(hidden_states) inside a parent module or head.apply(params, hidden_states), and applies the value head to the incoming hidden states.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self | | The instance of the class. | required |
hidden_states | Array | Hidden states produced by the previous layer (chex.Array). | required |
deterministic | bool | If True, dropout is disabled; if False, dropout is applied. | True |
Returns:
Type | Description |
---|---|
A tensor of shape (batch_size, num_classes) |
setup()
The setup function initializes the module's layers (submodules) and assigns them to attributes so that __call__ can use them. In Flax, setup is not run by the constructor; it is called lazily once the module is bound, for example during init or apply.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self | | The instance of the class. | required |