Skip to content

optimizers.adamw

get_adamw_with_cosine_scheduler(steps, learning_rate=5e-05, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, weight_decay=0.1, gradient_accumulation_steps=1, mu_dtype=None)

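Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `gradient_accumulation_steps` | `int` | If greater than 1, the optimizer is wrapped in `optax.MultiSteps` with this many steps. | `1` |
| `steps` | `int` | `decay_steps` of the cosine decay schedule. | *required* |
| `learning_rate` | `float` | Initial value of the cosine decay schedule. | `5e-05` |
| `b1` | `float` | Passed to `optax.scale_by_adam`. | `0.9` |
| `b2` | `float` | Passed to `optax.scale_by_adam`. | `0.999` |
| `eps` | `float` | Passed to `optax.scale_by_adam`. | `1e-08` |
| `eps_root` | `float` | Passed to `optax.scale_by_adam`. | `0.0` |
| `weight_decay` | `float` | Passed to `optax.add_decayed_weights`. | `0.1` |
| `mu_dtype` | `Optional[ArrayDType]` | Passed to `optax.scale_by_adam`. | `None` |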

Returns:

| Type | Description |
|------|-------------|
| `tuple` | Optimizer and Scheduler |

Source code in `src/fjformer/optimizers/adamw.py` (lines 6–53)
def get_adamw_with_cosine_scheduler(
        steps: int,
        learning_rate: float = 5e-5,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        weight_decay: float = 1e-1,
        gradient_accumulation_steps: int = 1,
        mu_dtype: Optional[chex.ArrayDType] = None,
):
    """
    Create an AdamW-style optimizer whose learning rate follows a cosine decay.

    The update pipeline runs four stages in order:

    1. Adam moment scaling (``optax.scale_by_adam``).
    2. Weight decay added to that update (``optax.add_decayed_weights``).
    3. Multiplication by the current value of ``optax.cosine_decay_schedule``.
    4. Negation.

    :param steps: ``decay_steps`` of the cosine schedule.
    :param learning_rate: Initial value of the cosine schedule.
    :param b1: Forwarded to ``optax.scale_by_adam``.
    :param b2: Forwarded to ``optax.scale_by_adam``.
    :param eps: Forwarded to ``optax.scale_by_adam``.
    :param eps_root: Forwarded to ``optax.scale_by_adam``.
    :param weight_decay: Coefficient passed to ``optax.add_decayed_weights``.
    :param gradient_accumulation_steps: When greater than 1, the optimizer is
        wrapped in ``optax.MultiSteps`` with this value.
    :param mu_dtype: Forwarded to ``optax.scale_by_adam``.
    :return: Optimizer and Scheduler, as a ``(tx, scheduler)`` tuple.
    """
    lr_schedule = optax.cosine_decay_schedule(init_value=learning_rate, decay_steps=steps)

    adam_moments = optax.scale_by_adam(
        b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype
    )
    # Decay is applied to the Adam-preconditioned update, not the raw gradient.
    decoupled_decay = optax.add_decayed_weights(weight_decay=weight_decay)
    # Scale by the scheduled learning rate, then flip the sign for descent.
    optimizer = optax.chain(
        adam_moments,
        decoupled_decay,
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1),
    )

    if gradient_accumulation_steps <= 1:
        return optimizer, lr_schedule
    return optax.MultiSteps(optimizer, gradient_accumulation_steps), lr_schedule

get_adamw_with_linear_scheduler(steps, learning_rate_start=5e-05, learning_rate_end=1e-05, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, weight_decay=0.1, gradient_accumulation_steps=1, mu_dtype=None)

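Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `gradient_accumulation_steps` | `int` | If greater than 1, the optimizer is wrapped in `optax.MultiSteps` with this many steps. | `1` |
| `steps` | `int` | `transition_steps` of the linear schedule. | *required* |
| `learning_rate_start` | `float` | Initial value of the linear schedule. | `5e-05` |
| `learning_rate_end` | `float` | Final value of the linear schedule. | `1e-05` |
| `b1` | `float` | Passed to `optax.scale_by_adam`. | `0.9` |
| `b2` | `float` | Passed to `optax.scale_by_adam`. | `0.999` |
| `eps` | `float` | Passed to `optax.scale_by_adam`. | `1e-08` |
| `eps_root` | `float` | Passed to `optax.scale_by_adam`. | `0.0` |
| `weight_decay` | `float` | Passed to `optax.add_decayed_weights`. | `0.1` |
| `mu_dtype` | `Optional[ArrayDType]` | Passed to `optax.scale_by_adam`. | `None` |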

Returns:

| Type | Description |
|------|-------------|
| `tuple` | Optimizer and Scheduler |

Source code in `src/fjformer/optimizers/adamw.py` (lines 56–106)
def get_adamw_with_linear_scheduler(
        steps: int,
        learning_rate_start: float = 5e-5,
        learning_rate_end: float = 1e-5,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        weight_decay: float = 1e-1,
        gradient_accumulation_steps: int = 1,
        mu_dtype: Optional[chex.ArrayDType] = None,
):
    """
    Create an AdamW-style optimizer with a linearly interpolated learning rate.

    The learning rate comes from ``optax.linear_schedule``, which moves from
    ``learning_rate_start`` to ``learning_rate_end`` over ``steps`` steps.

    :param steps: ``transition_steps`` of the linear schedule.
    :param learning_rate_start: Initial value of the linear schedule.
    :param learning_rate_end: Final value of the linear schedule.
    :param b1: Forwarded to ``optax.scale_by_adam``.
    :param b2: Forwarded to ``optax.scale_by_adam``.
    :param eps: Forwarded to ``optax.scale_by_adam``.
    :param eps_root: Forwarded to ``optax.scale_by_adam``.
    :param weight_decay: Coefficient passed to ``optax.add_decayed_weights``.
    :param gradient_accumulation_steps: When greater than 1, the optimizer is
        wrapped in ``optax.MultiSteps`` with this value.
    :param mu_dtype: Forwarded to ``optax.scale_by_adam``.
    :return: Optimizer and Scheduler, as a ``(tx, scheduler)`` tuple.
    """
    lr_schedule = optax.linear_schedule(
        init_value=learning_rate_start,
        end_value=learning_rate_end,
        transition_steps=steps,
    )
    # Order matters: Adam scaling -> decoupled weight decay -> lr -> negate.
    update_pipeline = [
        optax.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype),
        optax.add_decayed_weights(weight_decay=weight_decay),
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1),
    ]
    optimizer = optax.chain(*update_pipeline)
    if gradient_accumulation_steps > 1:
        optimizer = optax.MultiSteps(optimizer, gradient_accumulation_steps)
    return optimizer, lr_schedule

get_adamw_with_warm_up_cosine_scheduler(steps, learning_rate=5e-05, learning_rate_end=1e-05, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, weight_decay=0.1, exponent=1.0, gradient_accumulation_steps=1, warmup_steps=500, mu_dtype=None)

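Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `steps` | `int` | `decay_steps` of the warmup cosine decay schedule. | *required* |
| `learning_rate` | `float` | `peak_value` of the schedule, reached at the end of warmup. | `5e-05` |
| `learning_rate_end` | `float` | `end_value` of the schedule. | `1e-05` |
| `b1` | `float` | Passed to `optax.scale_by_adam`. | `0.9` |
| `b2` | `float` | Passed to `optax.scale_by_adam`. | `0.999` |
| `eps` | `float` | Passed to `optax.scale_by_adam`. | `1e-08` |
| `eps_root` | `float` | Passed to `optax.scale_by_adam`. | `0.0` |
| `weight_decay` | `float` | Passed to `optax.add_decayed_weights`. | `0.1` |
| `exponent` | `float` | `exponent` of the cosine decay. | `1.0` |
| `gradient_accumulation_steps` | `int` | If greater than 1, the optimizer is wrapped in `optax.MultiSteps` with this many steps. | `1` |
| `warmup_steps` | `int` | Number of warmup steps. | `500` |
| `mu_dtype` | `Optional[ArrayDType]` | Passed to `optax.scale_by_adam`. | `None` |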

Returns:

| Type | Description |
|------|-------------|
| `tuple` | Optimizer and Scheduler |

Source code in `src/fjformer/optimizers/adamw.py` (lines 170–226)
def get_adamw_with_warm_up_cosine_scheduler(
        steps: int,
        learning_rate: float = 5e-5,
        learning_rate_end: float = 1e-5,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        weight_decay: float = 1e-1,
        exponent: float = 1.0,
        gradient_accumulation_steps: int = 1,
        warmup_steps: int = 500,
        mu_dtype: Optional[chex.ArrayDType] = None,
        warmup_init_value: float = 0.5e-7,
):
    """
    Create an AdamW-style optimizer driven by a warmup + cosine-decay schedule.

    The learning rate comes from ``optax.warmup_cosine_decay_schedule``. It starts
    at ``warmup_init_value``, rises to ``learning_rate`` over ``warmup_steps``
    steps, and then decays towards ``learning_rate_end``.

    :param steps: ``decay_steps`` of the schedule. Per the optax documentation,
        this is the total schedule length and includes the warmup period.
    :param learning_rate: ``peak_value`` of the schedule.
    :param learning_rate_end: ``end_value`` of the schedule.
    :param b1: Forwarded to ``optax.scale_by_adam``.
    :param b2: Forwarded to ``optax.scale_by_adam``.
    :param eps: Forwarded to ``optax.scale_by_adam``.
    :param eps_root: Forwarded to ``optax.scale_by_adam``.
    :param weight_decay: Coefficient passed to ``optax.add_decayed_weights``.
    :param exponent: ``exponent`` of the cosine decay.
    :param gradient_accumulation_steps: When greater than 1, the optimizer is
        wrapped in ``optax.MultiSteps`` with this value.
    :param warmup_steps: Number of warmup steps.
    :param mu_dtype: Forwarded to ``optax.scale_by_adam``.
    :param warmup_init_value: Learning rate at the start of warmup. It was
        previously hard-coded, and the default keeps that value (``0.5e-7``).
    :return: Optimizer and Scheduler, as a ``(tx, scheduler)`` tuple. ``tx`` is
        an ``optax.MultiSteps`` wrapper when ``gradient_accumulation_steps > 1``.
    """
    scheduler = optax.warmup_cosine_decay_schedule(
        init_value=warmup_init_value,
        peak_value=learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=steps,
        end_value=learning_rate_end,
        exponent=exponent,
    )
    tx = optax.chain(
        optax.scale_by_adam(
            b1=b1,
            b2=b2,
            eps=eps,
            eps_root=eps_root,
            mu_dtype=mu_dtype,
        ),
        # Weight decay is added to the Adam-preconditioned update (decoupled,
        # AdamW-style) rather than to the raw gradient.
        optax.add_decayed_weights(weight_decay=weight_decay),
        # Multiply by the current learning rate, then negate so that applying
        # the update moves parameters downhill.
        optax.scale_by_schedule(scheduler),
        optax.scale(-1),
    )
    if gradient_accumulation_steps > 1:
        # NOTE(review): the schedule presumably advances once per applied
        # (accumulated) update, not once per micro-batch. Confirm against the
        # optax.MultiSteps docs when choosing `steps`.
        tx = optax.MultiSteps(tx, gradient_accumulation_steps)
    return tx, scheduler

get_adamw_with_warmup_linear_scheduler(steps, learning_rate_start=5e-05, learning_rate_end=1e-05, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, weight_decay=0.1, gradient_accumulation_steps=1, mu_dtype=None, warmup_steps=500)

Thanks to JinSeoungwoo ([EasyDeL issue #32](https://github.com/erfanzar/EasyDeL/issues/32)).

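Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `warmup_steps` | `int` | Number of steps for the warmup phase. | `500` |
| `gradient_accumulation_steps` | `int` | If greater than 1, the optimizer is wrapped in `optax.MultiSteps` with this many steps. | `1` |
| `steps` | `int` | Total number of steps; the linear decay phase lasts `steps - warmup_steps`. | *required* |
| `learning_rate_start` | `float` | Learning rate reached at the end of warmup, where the linear decay starts. | `5e-05` |
| `learning_rate_end` | `float` | Final value of the linear decay. | `1e-05` |
| `b1` | `float` | Passed to `optax.scale_by_adam`. | `0.9` |
| `b2` | `float` | Passed to `optax.scale_by_adam`. | `0.999` |
| `eps` | `float` | Passed to `optax.scale_by_adam`. | `1e-08` |
| `eps_root` | `float` | Passed to `optax.scale_by_adam`. | `0.0` |
| `weight_decay` | `float` | Passed to `optax.add_decayed_weights`. | `0.1` |
| `mu_dtype` | `Optional[ArrayDType]` | Passed to `optax.scale_by_adam`. | `None` |

`warmup_steps` is the new parameter that adds the warmup phase.

Returns: the optimizer and the scheduler, with warmup.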
Source code in `src/fjformer/optimizers/adamw.py` (lines 109–167)
def get_adamw_with_warmup_linear_scheduler(
        steps: int,
        learning_rate_start: float = 5e-5,
        learning_rate_end: float = 1e-5,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        weight_decay: float = 1e-1,
        gradient_accumulation_steps: int = 1,
        mu_dtype: Optional[chex.ArrayDType] = None,
        warmup_steps: int = 500,
        warmup_init_value: float = 5e-8,
):
    """
    Create an AdamW-style optimizer with a linear warmup followed by a linear decay.

    Thanks to [JinSeoungwoo](https://github.com/erfanzar/EasyDeL/issues/32).

    The schedule joins two ``optax.linear_schedule`` pieces at ``warmup_steps``:

    - Warmup: ``warmup_init_value`` -> ``learning_rate_start`` over
      ``warmup_steps`` steps.
    - Decay: ``learning_rate_start`` -> ``learning_rate_end`` over
      ``steps - warmup_steps`` steps.

    NOTE(review): if ``warmup_steps >= steps``, the decay piece gets a
    non-positive ``transition_steps``. optax then presumably holds the rate at
    ``learning_rate_start`` instead of decaying. Confirm for your optax version.

    :param steps: Total number of steps (warmup + decay).
    :param learning_rate_start: Rate reached at the end of warmup; the decay starts here.
    :param learning_rate_end: Final value of the linear decay.
    :param b1: Forwarded to ``optax.scale_by_adam``.
    :param b2: Forwarded to ``optax.scale_by_adam``.
    :param eps: Forwarded to ``optax.scale_by_adam``.
    :param eps_root: Forwarded to ``optax.scale_by_adam``.
    :param weight_decay: Coefficient passed to ``optax.add_decayed_weights``.
    :param gradient_accumulation_steps: When greater than 1, the optimizer is
        wrapped in ``optax.MultiSteps`` with this value.
    :param mu_dtype: Forwarded to ``optax.scale_by_adam``.
    :param warmup_steps: Number of steps for the warmup phase.
    :param warmup_init_value: Learning rate at the start of warmup. It was
        previously hard-coded, and the default keeps that value (``5e-8``).
    :return: Optimizer and Scheduler (with warmup), as a ``(tx, scheduler)``
        tuple. ``tx`` is an ``optax.MultiSteps`` wrapper when
        ``gradient_accumulation_steps > 1``.
    """
    scheduler_warmup = optax.linear_schedule(
        init_value=warmup_init_value,
        end_value=learning_rate_start,
        transition_steps=warmup_steps,
    )
    scheduler_decay = optax.linear_schedule(
        init_value=learning_rate_start,
        end_value=learning_rate_end,
        transition_steps=steps - warmup_steps,
    )
    # From step `warmup_steps` on, the decay schedule is evaluated with its
    # own step count restarted at 0.
    scheduler_combined = optax.join_schedules(
        schedules=[scheduler_warmup, scheduler_decay],
        boundaries=[warmup_steps],
    )

    tx = optax.chain(
        optax.scale_by_adam(
            b1=b1,
            b2=b2,
            eps=eps,
            eps_root=eps_root,
            mu_dtype=mu_dtype,
        ),
        # Weight decay is added to the Adam-preconditioned update (decoupled,
        # AdamW-style) rather than to the raw gradient.
        optax.add_decayed_weights(weight_decay=weight_decay),
        # Multiply by the current learning rate, then negate so that applying
        # the update moves parameters downhill.
        optax.scale_by_schedule(scheduler_combined),
        optax.scale(-1),
    )
    if gradient_accumulation_steps > 1:
        # NOTE(review): the schedule presumably advances once per applied
        # (accumulated) update, not once per micro-batch. Confirm against the
        # optax.MultiSteps docs when choosing `steps` / `warmup_steps`.
        tx = optax.MultiSteps(tx, gradient_accumulation_steps)
    return tx, scheduler_combined