Skip to content

optimizers.adafactor

get_adafactor_with_cosine_scheduler(steps, learning_rate=5e-05, min_dim_size_to_factor=128, decay_rate=0.8, decay_offset=0, multiply_by_parameter_scale=True, clipping_threshold=1.0, momentum=None, dtype_momentum=jnp.float32, weight_decay_rate=None, eps=1e-30, factored=True, gradient_accumulation_steps=1)

Parameters:

Name Type Description Default
gradient_accumulation_steps int
1
steps int
required
learning_rate float
5e-05
min_dim_size_to_factor int
128
decay_rate float
0.8
decay_offset int
0
multiply_by_parameter_scale bool
True
clipping_threshold Optional[float]
1.0
momentum Optional[float]
None
dtype_momentum ArrayDType
float32
weight_decay_rate Optional[float]
None
eps float
1e-30
factored bool
True

Returns:

Type Description

Optimizer and Scheduler

Source code in src/fjformer/optimizers/adafactor.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
def get_adafactor_with_cosine_scheduler(
        steps: int,
        learning_rate: float = 5e-5,
        min_dim_size_to_factor: int = 128,
        decay_rate: float = 0.8,
        decay_offset: int = 0,
        multiply_by_parameter_scale: bool = True,
        clipping_threshold: Optional[float] = 1.0,
        momentum: Optional[float] = None,
        dtype_momentum: chex.ArrayDType = jnp.float32,
        weight_decay_rate: Optional[float] = None,
        eps: float = 1e-30,
        factored: bool = True,
        gradient_accumulation_steps: int = 1
):
    """
    Build an Adafactor optimizer driven by a cosine-decay learning-rate schedule.

    The learning rate starts at ``learning_rate`` and follows
    ``optax.cosine_decay_schedule`` over ``steps`` steps. With optax's default
    ``alpha=0`` it decays to zero. Unlike the linear and warmup-cosine variants,
    no scheduled (decoupled) weight decay is chained here. The only weight decay
    applied is Adafactor's own ``weight_decay_rate``.

    :param steps: Number of steps over which the cosine decay runs
        (``decay_steps`` of the schedule).
    :param learning_rate: Initial learning rate of the cosine schedule.
    :param min_dim_size_to_factor: Forwarded to ``optax.adafactor``. Minimum
        dimension size for the second-moment estimate to be factored.
    :param decay_rate: Forwarded to ``optax.adafactor``.
    :param decay_offset: Forwarded to ``optax.adafactor``.
    :param multiply_by_parameter_scale: Forwarded to ``optax.adafactor``.
        Boolean flag: scale the step size by the parameter norm when True.
    :param clipping_threshold: Forwarded to ``optax.adafactor``. ``None``
        disables update clipping.
    :param momentum: Forwarded to ``optax.adafactor``. ``None`` disables
        momentum.
    :param dtype_momentum: Forwarded to ``optax.adafactor``. Dtype of the
        momentum accumulator.
    :param weight_decay_rate: Forwarded to ``optax.adafactor``. ``None``
        disables Adafactor's built-in weight decay.
    :param eps: Forwarded to ``optax.adafactor``.
    :param factored: Forwarded to ``optax.adafactor``. Whether to use the
        factored second-moment estimate.
    :param gradient_accumulation_steps: If greater than 1, the optimizer is
        wrapped in ``optax.MultiSteps`` so that gradients are accumulated over
        this many calls before an update is applied. In that case the schedule
        presumably advances once per applied update, not once per call.
        Confirm against the training loop when choosing ``steps``.
    :return: Tuple ``(tx, scheduler)``: the optax gradient transformation and
        the learning-rate schedule it uses.
    """
    scheduler = optax.cosine_decay_schedule(
        init_value=learning_rate,
        decay_steps=steps
    )
    # Kept as a one-element chain so the optimizer-state layout is unchanged
    # for existing checkpoints.
    tx = optax.chain(
        optax.adafactor(
            learning_rate=scheduler,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            multiply_by_parameter_scale=multiply_by_parameter_scale,
            clipping_threshold=clipping_threshold,
            eps=eps,
            momentum=momentum,
            weight_decay_rate=weight_decay_rate,
            dtype_momentum=dtype_momentum,
            factored=factored
        )
    )
    if gradient_accumulation_steps > 1:
        tx = optax.MultiSteps(
            tx, gradient_accumulation_steps
        )
    return tx, scheduler

get_adafactor_with_linear_scheduler(steps, learning_rate_start=5e-05, learning_rate_end=1e-05, weight_decay=0.1, min_dim_size_to_factor=128, decay_rate=0.8, decay_offset=0, multiply_by_parameter_scale=True, clipping_threshold=1.0, momentum=None, dtype_momentum=jnp.float32, weight_decay_rate=None, eps=1e-30, factored=True, gradient_accumulation_steps=1, weight_decay_mask=None)

Parameters:

Name Type Description Default
gradient_accumulation_steps int
1
steps int
required
learning_rate_start float
5e-05
learning_rate_end float
1e-05
weight_decay
0.1
min_dim_size_to_factor int
128
decay_rate float
0.8
decay_offset int
0
multiply_by_parameter_scale bool
True
clipping_threshold Optional[float]
1.0
momentum Optional[float]
None
dtype_momentum ArrayDType
float32
weight_decay_rate Optional[float]
None
eps float
1e-30
factored bool
True
weight_decay_mask
None

Returns:

Type Description

Optimizer and Scheduler

Source code in src/fjformer/optimizers/adafactor.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def get_adafactor_with_linear_scheduler(
        steps: int,
        learning_rate_start: float = 5e-5,
        learning_rate_end: float = 1e-5,
        weight_decay: float = 1e-1,
        min_dim_size_to_factor: int = 128,
        decay_rate: float = 0.8,
        decay_offset: int = 0,
        multiply_by_parameter_scale: bool = True,
        clipping_threshold: Optional[float] = 1.0,
        momentum: Optional[float] = None,
        dtype_momentum: chex.ArrayDType = jnp.float32,
        weight_decay_rate: Optional[float] = None,
        eps: float = 1e-30,
        factored: bool = True,
        gradient_accumulation_steps: int = 1,
        weight_decay_mask=None,
):
    """
    Build an Adafactor optimizer with a linear learning-rate schedule and scheduled weight decay.

    The learning rate moves linearly from ``learning_rate_start`` to
    ``learning_rate_end`` over ``steps`` steps (``optax.linear_schedule``). It
    then stays at ``learning_rate_end``. After Adafactor, a scheduled
    weight-decay transformation is chained whose coefficient at each step is
    ``-scheduler(step) * weight_decay``, so the decay strength follows the
    learning rate.

    :param steps: Number of steps for the linear transition.
    :param learning_rate_start: Learning rate at step 0.
    :param learning_rate_end: Learning rate reached after ``steps`` steps.
    :param weight_decay: Multiplier applied to the current learning rate to
        obtain the scheduled (decoupled) weight-decay coefficient.
    :param min_dim_size_to_factor: Forwarded to ``optax.adafactor``.
    :param decay_rate: Forwarded to ``optax.adafactor``.
    :param decay_offset: Forwarded to ``optax.adafactor``.
    :param multiply_by_parameter_scale: Forwarded to ``optax.adafactor``
        (boolean flag).
    :param clipping_threshold: Forwarded to ``optax.adafactor``. ``None``
        disables update clipping.
    :param momentum: Forwarded to ``optax.adafactor``. ``None`` disables
        momentum.
    :param dtype_momentum: Forwarded to ``optax.adafactor``.
    :param weight_decay_rate: Forwarded to ``optax.adafactor``. This is
        Adafactor's own, unscheduled weight decay, separate from
        ``weight_decay``. ``None`` disables it.
    :param eps: Forwarded to ``optax.adafactor``.
    :param factored: Forwarded to ``optax.adafactor``.
    :param gradient_accumulation_steps: If greater than 1, the optimizer is
        wrapped in ``optax.MultiSteps`` to accumulate gradients over this many
        calls before applying an update.
    :param weight_decay_mask: Optional mask selecting which parameters receive
        weight decay. It is passed both to the scheduled weight-decay
        transformation and to ``optax.adafactor`` (optax only uses it there
        when ``weight_decay_rate`` is set).
    :return: Tuple ``(tx, scheduler)``: the optax gradient transformation and
        the learning-rate schedule it uses.
    """
    scheduler = optax.linear_schedule(
        init_value=learning_rate_start,
        end_value=learning_rate_end,
        transition_steps=steps
    )

    tx = optax.chain(
        optax.adafactor(
            learning_rate=scheduler,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            multiply_by_parameter_scale=multiply_by_parameter_scale,
            clipping_threshold=clipping_threshold,
            eps=eps,
            momentum=momentum,
            weight_decay_rate=weight_decay_rate,
            dtype_momentum=dtype_momentum,
            factored=factored,
            # Without this, Adafactor's built-in decay (weight_decay_rate)
            # would ignore the user's mask and decay every parameter.
            weight_decay_mask=weight_decay_mask
        ),
        # This term is added after Adafactor, whose updates already point in
        # the descent direction. The coefficient is therefore negated so it
        # shrinks the weights. This assumes the helper adds
        # schedule(step) * params to the updates; confirm against
        # optax_add_scheduled_weight_decay.
        optax_add_scheduled_weight_decay(
            lambda step: -scheduler(step) * weight_decay,
            weight_decay_mask
        )
    )
    if gradient_accumulation_steps > 1:
        tx = optax.MultiSteps(
            tx, gradient_accumulation_steps
        )
    return tx, scheduler

get_adafactor_with_warm_up_cosine_scheduler(steps, learning_rate=5e-05, learning_rate_end=1e-05, weight_decay=0.1, min_dim_size_to_factor=128, decay_rate=0.8, decay_offset=0, multiply_by_parameter_scale=True, clipping_threshold=1.0, momentum=None, dtype_momentum=jnp.float32, weight_decay_rate=None, eps=1e-30, factored=True, exponent=1.0, weight_decay_mask=None, gradient_accumulation_steps=1, warmup_steps=500)

Parameters:

Name Type Description Default
steps int
required
learning_rate
5e-05
learning_rate_end
1e-05
weight_decay
0.1
min_dim_size_to_factor int
128
decay_rate float
0.8
decay_offset int
0
multiply_by_parameter_scale bool
True
clipping_threshold Optional[float]
1.0
momentum Optional[float]
None
dtype_momentum ArrayDType
float32
weight_decay_rate Optional[float]
None
eps float
1e-30
factored bool
True
exponent float
1.0
weight_decay_mask
None
gradient_accumulation_steps int
1
warmup_steps int
500

Returns:

Type Description

Optimizer and Scheduler

Source code in src/fjformer/optimizers/adafactor.py
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
def get_adafactor_with_warm_up_cosine_scheduler(
        steps: int,
        learning_rate: float = 5e-5,
        learning_rate_end: float = 1e-5,
        weight_decay: float = 1e-1,
        min_dim_size_to_factor: int = 128,
        decay_rate: float = 0.8,
        decay_offset: int = 0,
        multiply_by_parameter_scale: bool = True,
        clipping_threshold: Optional[float] = 1.0,
        momentum: Optional[float] = None,
        dtype_momentum: chex.ArrayDType = jnp.float32,
        weight_decay_rate: Optional[float] = None,
        eps: float = 1e-30,
        factored: bool = True,
        exponent: float = 1.0,
        weight_decay_mask=None,
        gradient_accumulation_steps: int = 1,
        warmup_steps: int = 500,
):
    """
    Build an Adafactor optimizer with a warmup + cosine-decay schedule and scheduled weight decay.

    The schedule is ``optax.warmup_cosine_decay_schedule``:
    - it ramps up from ``0.5e-7`` to ``learning_rate`` over ``warmup_steps``
      steps;
    - it then follows a cosine decay down to ``learning_rate_end``.

    ``steps`` is passed as ``decay_steps``, which optax documents as the total
    schedule length including the warmup. After Adafactor, a scheduled
    weight-decay transformation is chained whose coefficient is
    ``-scheduler(step) * weight_decay``.

    :param steps: Total length of the schedule (warmup included).
    :param learning_rate: Peak learning rate reached at the end of warmup.
    :param learning_rate_end: Final learning rate after the cosine decay.
    :param weight_decay: Multiplier applied to the current learning rate to
        obtain the scheduled (decoupled) weight-decay coefficient.
    :param min_dim_size_to_factor: Forwarded to ``optax.adafactor``.
    :param decay_rate: Forwarded to ``optax.adafactor``.
    :param decay_offset: Forwarded to ``optax.adafactor``.
    :param multiply_by_parameter_scale: Forwarded to ``optax.adafactor``
        (boolean flag).
    :param clipping_threshold: Forwarded to ``optax.adafactor``. ``None``
        disables update clipping.
    :param momentum: Forwarded to ``optax.adafactor``. ``None`` disables
        momentum.
    :param dtype_momentum: Forwarded to ``optax.adafactor``.
    :param weight_decay_rate: Forwarded to ``optax.adafactor``. This is
        Adafactor's own, unscheduled weight decay. ``None`` disables it.
    :param eps: Forwarded to ``optax.adafactor``.
    :param factored: Forwarded to ``optax.adafactor``.
    :param exponent: Exponent of the cosine decay, forwarded to
        ``optax.warmup_cosine_decay_schedule``.
    :param weight_decay_mask: Optional mask selecting which parameters receive
        weight decay. It is passed both to the scheduled weight-decay
        transformation and to ``optax.adafactor`` (optax only uses it there
        when ``weight_decay_rate`` is set).
    :param gradient_accumulation_steps: If greater than 1, the optimizer is
        wrapped in ``optax.MultiSteps`` to accumulate gradients over this many
        calls before applying an update.
    :param warmup_steps: Number of warmup steps at the start of the schedule.
    :return: Tuple ``(tx, scheduler)``: the optax gradient transformation and
        the learning-rate schedule it uses.
    """
    scheduler = optax.warmup_cosine_decay_schedule(
        init_value=0.5e-7,
        peak_value=learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=steps,
        end_value=learning_rate_end,
        exponent=exponent
    )
    tx = optax.chain(
        optax.adafactor(
            learning_rate=scheduler,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            multiply_by_parameter_scale=multiply_by_parameter_scale,
            clipping_threshold=clipping_threshold,
            eps=eps,
            momentum=momentum,
            weight_decay_rate=weight_decay_rate,
            dtype_momentum=dtype_momentum,
            factored=factored,
            # Without this, Adafactor's built-in decay (weight_decay_rate)
            # would ignore the user's mask and decay every parameter.
            weight_decay_mask=weight_decay_mask
        ),
        # This term is added after Adafactor, whose updates already point in
        # the descent direction. The coefficient is therefore negated so it
        # shrinks the weights. This assumes the helper adds
        # schedule(step) * params to the updates; confirm against
        # optax_add_scheduled_weight_decay.
        optax_add_scheduled_weight_decay(
            lambda step: -scheduler(step) * weight_decay,
            weight_decay_mask
        )
    )
    if gradient_accumulation_steps > 1:
        tx = optax.MultiSteps(
            tx, gradient_accumulation_steps
        )
    return tx, scheduler

get_adafactor_with_warmup_linear_scheduler(steps, min_dim_size_to_factor=128, decay_rate=0.8, decay_offset=0, multiply_by_parameter_scale=True, clipping_threshold=1.0, momentum=None, dtype_momentum=jnp.float32, weight_decay_rate=None, eps=1e-30, factored=True, gradient_accumulation_steps=1, learning_rate_start=5e-05, learning_rate_end=1e-05, warmup_steps=500)

Parameters:

Name Type Description Default
min_dim_size_to_factor int
128
decay_rate float
0.8
decay_offset int
0
multiply_by_parameter_scale bool
True
clipping_threshold Optional[float]
1.0
momentum Optional[float]
None
dtype_momentum ArrayDType
float32
weight_decay_rate Optional[float]
None
factored bool
True
warmup_steps int
500
gradient_accumulation_steps int
1
steps int
required
learning_rate_start float
5e-05
learning_rate_end float
1e-05
eps float
1e-30

Returns:

Type Description

Optimizer and Scheduler. The scheduler warms up linearly over warmup_steps steps (the number of steps in the warmup phase), then decays linearly until step steps.

Source code in src/fjformer/optimizers/adafactor.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def get_adafactor_with_warmup_linear_scheduler(
        steps: int,
        min_dim_size_to_factor: int = 128,
        decay_rate: float = 0.8,
        decay_offset: int = 0,
        multiply_by_parameter_scale: bool = True,
        clipping_threshold: Optional[float] = 1.0,
        momentum: Optional[float] = None,
        dtype_momentum: chex.ArrayDType = jnp.float32,
        weight_decay_rate: Optional[float] = None,
        eps: float = 1e-30,
        factored: bool = True,
        gradient_accumulation_steps: int = 1,
        learning_rate_start: float = 5e-5,
        learning_rate_end: float = 1e-5,
        warmup_steps: int = 500
):
    """
    Build an Adafactor optimizer with a linear warmup followed by linear decay.

    The learning rate rises linearly from ``5e-8`` to ``learning_rate_start``
    over ``warmup_steps`` steps. It then decays linearly to
    ``learning_rate_end`` over the remaining ``steps - warmup_steps`` steps, so
    the whole schedule spans ``steps`` steps. The two phases are joined with
    ``optax.join_schedules`` at boundary ``warmup_steps``. No scheduled weight
    decay is chained here; the only weight decay applied is Adafactor's own
    ``weight_decay_rate``.

    :param steps: Total number of steps covered by warmup plus decay.
    :param min_dim_size_to_factor: Forwarded to ``optax.adafactor``.
    :param decay_rate: Forwarded to ``optax.adafactor``.
    :param decay_offset: Forwarded to ``optax.adafactor``.
    :param multiply_by_parameter_scale: Forwarded to ``optax.adafactor``
        (boolean flag).
    :param clipping_threshold: Forwarded to ``optax.adafactor``. ``None``
        disables update clipping.
    :param momentum: Forwarded to ``optax.adafactor``. ``None`` disables
        momentum.
    :param dtype_momentum: Forwarded to ``optax.adafactor``.
    :param weight_decay_rate: Forwarded to ``optax.adafactor``. ``None``
        disables Adafactor's built-in weight decay.
    :param eps: Forwarded to ``optax.adafactor``.
    :param factored: Forwarded to ``optax.adafactor``.
    :param gradient_accumulation_steps: If greater than 1, the optimizer is
        wrapped in ``optax.MultiSteps`` to accumulate gradients over this many
        calls before applying an update.
    :param learning_rate_start: Peak learning rate reached at the end of warmup.
    :param learning_rate_end: Learning rate reached at step ``steps``.
    :param warmup_steps: Number of steps in the linear warmup phase.
    :return: Tuple ``(tx, scheduler)``: the optax gradient transformation and
        the combined warmup + decay learning-rate schedule.
    """

    scheduler_warmup = optax.linear_schedule(init_value=5e-8, end_value=learning_rate_start,
                                             transition_steps=warmup_steps)
    # The decay phase only covers what is left after warmup, so the combined
    # schedule ends at `steps`.
    scheduler_decay = optax.linear_schedule(init_value=learning_rate_start, end_value=learning_rate_end,
                                            transition_steps=steps - warmup_steps)

    scheduler_combined = optax.join_schedules(schedules=[scheduler_warmup, scheduler_decay], boundaries=[warmup_steps])

    tx = optax.chain(
        optax.adafactor(
            learning_rate=scheduler_combined,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            multiply_by_parameter_scale=multiply_by_parameter_scale,
            clipping_threshold=clipping_threshold,
            eps=eps,
            momentum=momentum,
            weight_decay_rate=weight_decay_rate,
            dtype_momentum=dtype_momentum,
            factored=factored
        )
    )
    if gradient_accumulation_steps > 1:
        tx = optax.MultiSteps(
            tx, gradient_accumulation_steps
        )
    return tx, scheduler_combined