from typing import Any, Tuple

import torch

from labml.configs import BaseConfigs, option, meta_config
from labml_nn.optimizers import WeightDecay


Optimizer configurations
class OptimizerConfigs(BaseConfigs):
Optimizer
    optimizer: torch.optim.Adam
Weight decay handler (a `WeightDecay` instance)
    weight_decay_obj: WeightDecay
Whether weight decay is decoupled; i.e. weight decay is not added to gradients
    weight_decouple: bool = True
Weight decay rate
    weight_decay: float = 0.0
Whether weight decay is absolute or should be multiplied by the learning rate
    weight_decay_absolute: bool = False
Whether to use the optimized Adam update (which applies epsilon slightly differently)
    optimized_adam_update: bool = True
Parameters to be optimized
    parameters: Any
Learning rate
    learning_rate: float = 0.01
Beta values for Adam
    betas: Tuple[float, float] = (0.9, 0.999)
Epsilon for Adam
    eps: float = 1e-08
Momentum for SGD
    momentum: float = 0.5
Whether to use AMSGrad
    amsgrad: bool = False
Number of warmup optimizer steps
    warmup: int = 2_000
Total number of optimizer steps (for cosine decay)
    total_steps: int = int(1e10)
Whether to degenerate to SGD (used by RAdam and AdaBelief)
    degenerate_to_sgd: bool = True
Whether to use Rectified Adam in AdaBelief
    rectify: bool = True
Model embedding size for Noam optimizer
    d_model: int
Rho for the Sophia optimizer
    rho: float
Set the primary configuration item to `optimizer`
    def __init__(self):
        super().__init__(_primary='optimizer')


Set `parameters` as a meta config item, so it is not treated as a hyper-parameter
meta_config(OptimizerConfigs.parameters)
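A usage sketch, for illustration only and not part of this module: an outer experiment config typically delegates its optimizer to `OptimizerConfigs` through an option function. The `TrainerConfigs` class and the toy model below are hypothetical names introduced just for this example; the optimizer names such as `'Adam'` are the options registered below.
import torch.nn as nn


class TrainerConfigs(BaseConfigs):
    # Hypothetical outer experiment configuration (illustration only)
    model: nn.Module
    optimizer: torch.optim.Optimizer


@option(TrainerConfigs.model)
def _model(c: TrainerConfigs):
    # A toy model so the sketch stays self-contained
    return nn.Linear(16, 4)


@option(TrainerConfigs.optimizer)
def _trainer_optimizer(c: TrainerConfigs):
    # Delegate to OptimizerConfigs; since `_primary='optimizer'`,
    # the nested config resolves to its `optimizer` attribute
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'  # any of the option names registered below
    opt_conf.learning_rate = 2.5e-4
    return opt_conf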
Create the weight decay object
@option(OptimizerConfigs.weight_decay_obj, 'L2')
def _weight_decay(c: OptimizerConfigs):
    return WeightDecay(c.weight_decay, c.weight_decouple, c.weight_decay_absolute)
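For intuition only, here is a minimal sketch of the two decay modes the flags above select between. It illustrates the general decoupled-vs-L2 distinction and is not the library's `WeightDecay` implementation.
def _apply_weight_decay_sketch(param: torch.Tensor, grad: torch.Tensor,
                               weight_decay: float, lr: float,
                               decouple: bool, absolute: bool) -> torch.Tensor:
    # Illustration only, not labml_nn code
    if decouple:
        # Decoupled (AdamW-style): shrink the parameter directly,
        # optionally scaling the decay by the learning rate
        factor = weight_decay if absolute else weight_decay * lr
        param.data.mul_(1. - factor)
        return grad
    # Classic L2 regularization: add the penalty gradient to the gradient
    return grad + weight_decay * param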
SGD optimizer
@option(OptimizerConfigs.optimizer, 'SGD')
def _sgd_optimizer(c: OptimizerConfigs):
    return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum,
                           weight_decay=c.weight_decay)
Adam optimizer, or AMSGrad if `amsgrad` is set
@option(OptimizerConfigs.optimizer, 'Adam')
def _adam_optimizer(c: OptimizerConfigs):
    if c.amsgrad:
        from labml_nn.optimizers.amsgrad import AMSGrad
        return AMSGrad(c.parameters,
                       lr=c.learning_rate, betas=c.betas, eps=c.eps,
                       optimized_update=c.optimized_adam_update,
                       weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad)
    else:
        from labml_nn.optimizers.adam import Adam
        return Adam(c.parameters,
                    lr=c.learning_rate, betas=c.betas, eps=c.eps,
                    optimized_update=c.optimized_adam_update,
                    weight_decay=c.weight_decay_obj)
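The key idea of AMSGrad is to use the running maximum of the second-moment estimate in the denominator instead of the current estimate. A toy sketch of just that step, ignoring bias correction; this is for intuition only and not the library's AMSGrad code.
def _amsgrad_denominator_sketch(v: torch.Tensor, v_max: torch.Tensor, eps: float) -> torch.Tensor:
    # Keep the element-wise maximum of second-moment estimates seen so far
    torch.maximum(v_max, v, out=v_max)
    # Use it in the Adam-style denominator
    return v_max.sqrt().add_(eps)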
Adam optimizer with warmup
@option(OptimizerConfigs.optimizer, 'AdamW')
def _adam_warmup_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.adam_warmup import AdamWarmup
    return AdamWarmup(c.parameters,
                      lr=c.learning_rate, betas=c.betas, eps=c.eps,
                      weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup)
Rectified Adam (RAdam) optimizer
@option(OptimizerConfigs.optimizer, 'RAdam')
def _radam_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.radam import RAdam
    return RAdam(c.parameters,
                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                 degenerated_to_sgd=c.degenerate_to_sgd)
AdaBelief optimizer
@option(OptimizerConfigs.optimizer, 'AdaBelief')
def _ada_belief_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.ada_belief import AdaBelief
    return AdaBelief(c.parameters,
                     lr=c.learning_rate, betas=c.betas, eps=c.eps,
                     weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                     degenerate_to_sgd=c.degenerate_to_sgd,
                     rectify=c.rectify)
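AdaBelief's central change relative to Adam is to track how much the gradient deviates from its exponential moving average rather than the raw squared gradient. A one-step sketch of that statistic, for intuition only and not this library's code:
def _adabelief_second_moment_sketch(s: torch.Tensor, m: torch.Tensor, g: torch.Tensor,
                                    beta2: float, eps: float) -> torch.Tensor:
    # s_t = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2 + eps
    return beta2 * s + (1. - beta2) * (g - m) ** 2 + eps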
Noam optimizer (Adam with the Noam learning-rate schedule)
@option(OptimizerConfigs.optimizer, 'Noam')
def _noam_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.noam import Noam
    return Noam(c.parameters,
                lr=c.learning_rate, betas=c.betas, eps=c.eps,
                weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup,
                d_model=c.d_model)
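For reference, the Noam schedule from "Attention Is All You Need" combines a linear warmup with inverse-square-root decay, scaled by the model size. A standalone sketch of the rate, for intuition only; the `Noam` class computes its own rate internally.
def _noam_lr_sketch(step: int, d_model: int, warmup: int) -> float:
    # rate is proportional to d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})
    step = max(step, 1)
    return d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))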
Sophia optimizer
@option(OptimizerConfigs.optimizer, 'Sophia')
def _sophia_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.sophia import Sophia
    return Sophia(c.parameters,
                  lr=c.learning_rate, betas=c.betas, eps=c.eps,
                  weight_decay=c.weight_decay_obj, rho=c.rho)
Adam with warmup and cosine decay
@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
def _adam_warmup_cosine_decay_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay
    return AdamWarmupCosineDecay(c.parameters,
                                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                                 warmup=c.warmup, total_steps=c.total_steps)
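And a rough sketch of a warmup-plus-cosine-decay rate using `warmup` and `total_steps`; the exact shape used by `AdamWarmupCosineDecay` may differ in its details, so treat this as an illustration only.
import math


def _warmup_cosine_lr_sketch(step: int, lr: float, warmup: int, total_steps: int) -> float:
    # Linear warmup to `lr`, then cosine decay towards zero over `total_steps`
    if step < warmup:
        return lr * step / max(warmup, 1)
    progress = min((step - warmup) / max(total_steps - warmup, 1), 1.)
    return 0.5 * lr * (1. + math.cos(math.pi * progress))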