utils
GenerateRNG
Source code in src/fjformer/utils.py
__init__(seed=0)
`__init__` is called when the class is instantiated. It stores the seed and creates the random number generator state from it. If no seed is passed, the seed defaults to 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| self | | The instance being initialized | required |
| seed | int | Seed for the random number generator | 0 |
Returns:
| Type | Description |
|---|---|
| | None (calling `GenerateRNG(seed)` returns the new instance) |
__next__()
`__next__` is called by Python's built-in `next()` to get the next value. Each call splits the random number generator into two parts: one part becomes the new internal state and the other is returned as a key. Because the state advances on every call, repeated calls produce an unbounded sequence of distinct keys.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| self | | The generator instance | required |
Returns:
| Type | Description |
|---|---|
| | A new PRNG key split from the internal random number generator |
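The split-and-return behaviour described for `__next__` can be sketched as follows. This is a minimal illustration, assuming the internal state is a legacy `jax.random.PRNGKey`; `KeyStream` is a hypothetical name, not fjformer's class. It also defines `__iter__`, which Python requires (in addition to `__next__`) before an object can be used directly in a `for` loop.

```python
import jax


class KeyStream:
    """Hypothetical stand-in for GenerateRNG: an endless stream of PRNG keys."""

    def __init__(self, seed: int = 0):
        self.seed = seed
        # Internal state: a JAX PRNG key derived from the seed.
        self.rng = jax.random.PRNGKey(seed)

    def __iter__(self):
        # Returning self lets the object be used directly in a for loop.
        return self

    def __next__(self):
        # Split the current key: keep one half as the new state, return the other.
        self.rng, key = jax.random.split(self.rng)
        return key


stream = KeyStream(seed=0)
k1 = next(stream)
k2 = next(stream)
# Each key can be consumed by any jax.random sampler.
noise = jax.random.normal(k1, (3,))
```

Because the state is replaced on every call, `k1` and `k2` differ, while two streams built from the same seed yield identical sequences.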
JaxRNG
Bases: object
__call__(keys=None)
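`__call__` lets a `JaxRNG` instance be called like a function. Each call splits the stored random number generator and returns newly split key(s); the optional `keys` argument requests a split into multiple parts.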
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| self | | The `JaxRNG` instance | required |
| keys | | Optional; when given, the random number generator is split into multiple parts | None |
Returns:
| Type | Description |
|---|---|
| | Key(s) split from the internal random number generator |
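A minimal sketch of a callable generator whose `keys` argument controls how many parts are split off. The three branches (no argument, an integer count, a sequence of names) are an assumption about how such an API is commonly designed, not a description of fjformer's exact code; `SplittableRNG` is a hypothetical name.

```python
import jax


class SplittableRNG:
    """Hypothetical sketch of a callable PRNG holder in the style of JaxRNG."""

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, keys=None):
        if keys is None:
            # One fresh key; the other half becomes the new state.
            self.rng, key = jax.random.split(self.rng)
            return key
        if isinstance(keys, int):
            # keys + 1 parts: the first replaces the state, the rest are returned.
            parts = jax.random.split(self.rng, keys + 1)
            self.rng = parts[0]
            return tuple(parts[1:])
        # Otherwise treat keys as a sequence of names and return a dict of keys.
        parts = jax.random.split(self.rng, len(keys) + 1)
        self.rng = parts[0]
        return {name: k for name, k in zip(keys, parts[1:])}


rng = SplittableRNG(jax.random.PRNGKey(0))
single = rng()
a, b, c = rng(3)
named = rng(["params", "dropout"])
```

Returning a dict of named keys is convenient when a model needs separate streams (for example one for parameter initialization and one for dropout).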
__init__(rng)
`__init__` is called when the class is instantiated. It stores the given random number generator (a JAX PRNG key), from which random numbers are later drawn, for example to initialize weights and biases.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| self | | The instance being initialized | required |
| rng | | The random number generator (PRNG key) to store | required |
Returns:
| Type | Description |
|---|---|
| | None (calling `JaxRNG(rng)` returns the new instance) |
from_seed(seed)
classmethod
`from_seed` is a class method that takes a seed and returns a new instance of the class. Instances built from different seeds produce different random streams, while reusing a seed reproduces the same stream, which is useful for debugging and reproducibility.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cls | | The class being instantiated | required |
| seed | | Seed used to initialize the random number generator | required |
Returns:
| Type | Description |
|---|---|
| | A new instance of the class, seeded with `seed` |
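The reproducibility argument for `from_seed` can be checked with a small sketch. It assumes the seed is turned into a key with `jax.random.PRNGKey`; `SeededRNG` is a hypothetical minimal class that only illustrates the classmethod-constructor pattern.

```python
import jax


class SeededRNG:
    """Hypothetical minimal class showing the from_seed classmethod pattern."""

    def __init__(self, rng):
        self.rng = rng

    @classmethod
    def from_seed(cls, seed: int):
        # Build the PRNG key from the integer seed, then delegate to __init__.
        return cls(jax.random.PRNGKey(seed))

    def __call__(self):
        # Advance the state and hand out one fresh key.
        self.rng, key = jax.random.split(self.rng)
        return key


a = SeededRNG.from_seed(42)
b = SeededRNG.from_seed(42)
c = SeededRNG.from_seed(7)
# Same seed: the first keys drawn from a and b are identical.
same = bool((a() == b()).all())
```

Using `cls(...)` rather than a hard-coded class name means subclasses inherit a working `from_seed`.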
count_num_params(_p)
`count_num_params` counts the parameters in a model. It takes an unfrozen parameter dictionary and returns the total number of parameters, that is, the sum of the sizes of all parameter arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| _p | | Unfrozen parameter dictionary of the model | required |
Returns:
| Type | Description |
|---|---|
| | The total number of parameters in the model |
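A minimal sketch of parameter counting over a nested parameter dictionary. It assumes the parameters are a pytree of JAX arrays (a plain nested dict here); `count_leaves` is a hypothetical helper name, not fjformer's function.

```python
import jax
import jax.numpy as jnp


def count_leaves(params) -> int:
    """Hypothetical helper: total number of scalar entries across all arrays in a pytree."""
    # tree_leaves flattens nested dicts/lists/tuples into a flat list of arrays.
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))


params = {
    "dense": {"kernel": jnp.zeros((4, 3)), "bias": jnp.zeros((3,))},
    "out": {"kernel": jnp.zeros((3, 1)), "bias": jnp.zeros((1,))},
}
total = count_leaves(params)  # 12 + 3 + 3 + 1 = 19
```

Each leaf contributes its `size` (product of its shape), so a 4x3 kernel counts as 12 parameters.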
count_params(_p)
`count_params` takes the parameters of a Flax model and prints the number of parameters they contain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| _p | Flax params | Parameters of the Flax model to count | required |
Returns:
| Type | Description |
|---|---|
| | The number of parameters in the model |
init_rng(seed)
The init_rng function initializes the global random number generator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | | Seed used to initialize the global random number generator | required |
Returns:
| Type | Description |
|---|---|
| | A random number generator |
is_torch_available()
The is_torch_available function checks if PyTorch is installed.
Returns:
| Type | Description |
|---|---|
| bool | `True` if the `torch` module is installed, otherwise `False` |
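One common way to perform such an installation check without importing the package is `importlib.util.find_spec`. This is a minimal sketch of that technique; `is_module_available` is a hypothetical generalization, and whether fjformer uses exactly this call is an assumption.

```python
import importlib.util


def is_module_available(name: str) -> bool:
    """Return True if `name` can be imported, without actually importing it."""
    # find_spec returns None when no top-level module with this name is installed.
    return importlib.util.find_spec(name) is not None


has_torch = is_module_available("torch")
```

Checking with `find_spec` avoids paying PyTorch's import cost just to test for its presence.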
next_rng(*args, **kwargs)
`next_rng` returns the next key from the global random number generator `jax_utils_rng`; JAX uses such PRNG keys to generate random numbers. Each call derives a new PRNG key from the previous one and updates `jax_utils_rng` with the new state, so successive calls return different keys.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | | Optional positional arguments | () |
| **kwargs | | Optional keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
| | A new PRNG key derived from the global random number generator |
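The `init_rng` / `next_rng` pair follows a module-level (global) generator pattern. The sketch below is a minimal illustration, assuming the global state is a single legacy `jax.random.PRNGKey`; `init_global_rng`, `next_global_key` and `_global_rng` are hypothetical names, not fjformer's.

```python
import jax

# Module-level state, mirroring the global-generator pattern described above.
_global_rng = None


def init_global_rng(seed: int) -> None:
    """Hypothetical analogue of init_rng: seed the module-level key."""
    global _global_rng
    _global_rng = jax.random.PRNGKey(seed)


def next_global_key():
    """Hypothetical analogue of next_rng: return a fresh key and advance the state."""
    global _global_rng
    if _global_rng is None:
        raise RuntimeError("call init_global_rng(seed) first")
    # Keep one half as the new global state, return the other half.
    _global_rng, key = jax.random.split(_global_rng)
    return key


init_global_rng(0)
k1 = next_global_key()
k2 = next_global_key()
```

A global generator is convenient in scripts, but it is mutable Python state: inside `jax.jit`-compiled functions, passing keys explicitly is the usual JAX practice.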