Configurable Optimizers

from typing import Any, Tuple

import torch

from labml.configs import BaseConfigs, option, meta_config
from labml_nn.optimizers import WeightDecay

Optimizer configurations

class OptimizerConfigs(BaseConfigs):

Optimizer

    optimizer: torch.optim.Adam

Weight decay

    weight_decay_obj: WeightDecay

Whether weight decay is decoupled; i.e., weight decay is not added to gradients (see the sketch after the weight decay option below)

    weight_decouple: bool = True

Weight decay rate

    weight_decay: float = 0.0

Whether weight decay is absolute or should be multiplied by the learning rate

    weight_decay_absolute: bool = False

Whether the Adam update is optimized (uses a different epsilon)

    optimized_adam_update: bool = True

Parameters to be optimized

    parameters: Any

Learning rate

    learning_rate: float = 0.01

Beta values for Adam

    betas: Tuple[float, float] = (0.9, 0.999)

Epsilon for Adam

    eps: float = 1e-08

Momentum for SGD

    momentum: float = 0.5

Whether to use AMSGrad

    amsgrad: bool = False

Number of warmup optimizer steps

    warmup: int = 2_000

Total number of optimizer steps (for cosine decay)

    total_steps: int = int(1e10)

Whether to degenerate to SGD in AdaBelief

    degenerate_to_sgd: bool = True

Whether to use rectified Adam in AdaBelief

    rectify: bool = True

Model embedding size for the Noam optimizer

    d_model: int

Rho for the Sophia optimizer

    rho: float

    def __init__(self):
        super().__init__(_primary='optimizer')


meta_config(OptimizerConfigs.parameters)
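
Since optimizer is the primary configuration and parameters is registered as a meta configuration (so it is not treated as a logged hyperparameter), an OptimizerConfigs instance is usually built inside an option of a larger configurations class, where it then stands in for the selected optimizer. Below is a minimal sketch of that wiring; TrainingConfigs and model are illustrative assumptions, not part of this module.

# Illustrative sketch only; assumes the usual labml pattern of nesting configs.
import torch
import torch.nn as nn
from labml.configs import BaseConfigs, option


class TrainingConfigs(BaseConfigs):  # hypothetical experiment configurations
    model: nn.Module
    optimizer: torch.optim.Optimizer


@option(TrainingConfigs.optimizer)
def _optimizer(c: TrainingConfigs):
    # Build the optimizer through OptimizerConfigs; because its primary
    # configuration is 'optimizer', it resolves to the optimizer instance.
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'          # any name registered below
    opt_conf.learning_rate = 3e-4
    return opt_conf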

@option(OptimizerConfigs.weight_decay_obj, 'L2')
def _weight_decay(c: OptimizerConfigs):
    return WeightDecay(c.weight_decay, c.weight_decouple, c.weight_decay_absolute)
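
For reference, weight_decouple selects between the two standard forms of weight decay. The sketch below only restates the idea and is not the actual WeightDecay implementation: coupled L2 regularization adds the decay term to the gradient (so it also passes through the optimizer's adaptive scaling), while decoupled weight decay, as in AdamW, shrinks the parameter directly, scaled by the learning rate unless weight_decay_absolute is set.

# Conceptual sketch of the two modes for a single parameter tensor; not the
# actual WeightDecay class from labml_nn.optimizers.
def apply_weight_decay(param, grad, lr, weight_decay, decouple, absolute):
    if not decouple:
        # Coupled (L2 regularization): decay enters the gradient.
        return grad + weight_decay * param, param
    # Decoupled: shrink the parameter directly, bypassing adaptive scaling.
    factor = weight_decay if absolute else lr * weight_decay
    return grad, param * (1 - factor)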


@option(OptimizerConfigs.optimizer, 'SGD')
def _sgd_optimizer(c: OptimizerConfigs):
    return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum,
                           weight_decay=c.weight_decay)


@option(OptimizerConfigs.optimizer, 'Adam')
def _adam_optimizer(c: OptimizerConfigs):
    if c.amsgrad:
        from labml_nn.optimizers.amsgrad import AMSGrad
        return AMSGrad(c.parameters,
                       lr=c.learning_rate, betas=c.betas, eps=c.eps,
                       optimized_update=c.optimized_adam_update,
                       weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad)
    else:
        from labml_nn.optimizers.adam import Adam
        return Adam(c.parameters,
                    lr=c.learning_rate, betas=c.betas, eps=c.eps,
                    optimized_update=c.optimized_adam_update,
                    weight_decay=c.weight_decay_obj)


@option(OptimizerConfigs.optimizer, 'AdamW')
def _adam_warmup_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.adam_warmup import AdamWarmup
    return AdamWarmup(c.parameters,
                      lr=c.learning_rate, betas=c.betas, eps=c.eps,
                      weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup)
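
The warmup count is shared by AdamWarmup, Noam, and AdamWarmupCosineDecay. As a rough sketch of the usual behaviour (an assumption about the exact form, not code from those optimizers), a linear warmup ramps the base learning rate up over the first warmup steps:

# Assumed linear warmup schedule, for illustration only.
def warmup_lr(base_lr: float, step: int, warmup: int) -> float:
    if step < warmup:
        return base_lr * step / warmup  # ramp up linearly from zero
    return base_lr                      # then keep the base learning rate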


@option(OptimizerConfigs.optimizer, 'RAdam')
def _radam_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.radam import RAdam
    return RAdam(c.parameters,
                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                 degenerated_to_sgd=c.degenerate_to_sgd)


@option(OptimizerConfigs.optimizer, 'AdaBelief')
def _ada_belief_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.ada_belief import AdaBelief
    return AdaBelief(c.parameters,
                     lr=c.learning_rate, betas=c.betas, eps=c.eps,
                     weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                     degenerate_to_sgd=c.degenerate_to_sgd,
                     rectify=c.rectify)
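
For context on these two flags: AdaBelief differs from Adam by tracking the variance of the gradient around its moving average instead of the raw second moment; rectify applies RAdam-style variance rectification on top of that, and degenerate_to_sgd falls back to a plain, non-adaptive step while the rectification term is undefined in the first few steps. A schematic of the key difference, not the actual implementation:

# Schematic second-moment updates: Adam vs. AdaBelief (element-wise tensors).
# m is the first-moment EMA, g the current gradient, beta2 the EMA coefficient.
def second_moment_updates(m, v, s, g, beta2):
    v_adam = beta2 * v + (1 - beta2) * g * g                 # Adam: EMA of g^2
    s_belief = beta2 * s + (1 - beta2) * (g - m) * (g - m)   # AdaBelief: EMA of (g - m)^2
    return v_adam, s_belief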


@option(OptimizerConfigs.optimizer, 'Noam')
def _noam_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.noam import Noam
    return Noam(c.parameters,
                lr=c.learning_rate, betas=c.betas, eps=c.eps,
                weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup,
                d_model=c.d_model)
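
d_model is only used by the Noam optimizer, whose learning rate follows the schedule from Attention Is All You Need; the published formula is shown below for reference, which this wrapper is presumably built on.

# The published Noam schedule (Vaswani et al., 2017), for reference; assumes step >= 1.
def noam_lr(d_model: int, step: int, warmup: int) -> float:
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)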


@option(OptimizerConfigs.optimizer, 'Sophia')
def _sophia_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.sophia import Sophia
    return Sophia(c.parameters,
                  lr=c.learning_rate, betas=c.betas, eps=c.eps,
                  weight_decay=c.weight_decay_obj, rho=c.rho)


@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
def _adam_warmup_cosine_decay_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay
    return AdamWarmupCosineDecay(c.parameters,
                                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                                 warmup=c.warmup, total_steps=c.total_steps)
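
total_steps only matters for this last option. A generic warmup-plus-cosine-decay schedule is sketched below to show how the two settings interact; it is an assumption about the shape, not necessarily the exact schedule AdamWarmupCosineDecay implements.

import math

# Generic warmup followed by cosine decay to zero; illustrative only.
def warmup_cosine_lr(base_lr: float, step: int, warmup: int, total_steps: int) -> float:
    if step < warmup:
        return base_lr * step / warmup                  # linear warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))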