1from typing import Optional, Set, List
2
3import torch.nn as nn
4import torch.optim
5import torch.utils.data
6from torch.cuda import amp
7from torch.cuda.amp import GradScaler
8
9from labml import monit, tracker
10from labml.configs import BaseConfigs, option
11from labml_nn.neox.utils.finetune import FineTuner
def get_trainable_params(model: nn.Module):
    """
    Collect the parameters of a module that take part in optimization.

    :param model: the module whose parameters are inspected
    :return: a list with every parameter of `model` that has `requires_grad` set,
        in the order `model.parameters()` yields them
    """
    return [param for param in model.parameters() if param.requires_grad]
class TrainerConf(BaseConfigs):
    """
    Trainer configurations and training loop.

    Holds the model, data loaders, optimizer and mixed-precision settings as
    configuration items, and implements the per-epoch loop that mixes training
    batches with optional validation batches, sampling and checkpointing.
    """
    model: nn.Module
    layers: List[nn.Module]
    # `'Adam'` / `'Default'` are option names; the matching factories are registered with `@option`
    optimizer: torch.optim.Optimizer = 'Adam'
    train_loader: torch.utils.data.DataLoader
    # Must be a plain `None` (not a tuple) so that `get_iterators` skips
    # validation when no validation loader is configured
    valid_loader: Optional[torch.utils.data.DataLoader] = None
    device: torch.device = torch.device('cuda:0')
    scaler: Optional[GradScaler] = 'Default'
    # Whether the forward pass runs under `amp.autocast()`
    is_amp: bool = True
    dtype: torch.dtype = torch.float16

    is_clone_layers: bool = True

    loss_func: nn.Module = nn.CrossEntropyLoss()
    # How many times `save_checkpoint` / `sample` are scheduled per epoch (0 disables them)
    checkpoints_per_epoch: int = 0
    samples_per_epoch: int = 0

    # Maximum gradient norm for clipping; `None` disables clipping
    grad_norm: Optional[float] = 1.0
    learning_rate: float = 3e-4
    max_seq_len: int = 1024
    batch_size: int = 64
    epochs: int = 16

    n_gpus: int = torch.cuda.device_count()

    filter_layers: Optional[Set] = None

    def get_loss(self, sample, dataset_split: str):
        """
        :param dataset_split: train/valid
        :param sample: is the sample
        :return: the loss, output and the target
        """
        data, target = sample

        # Forward pass
        with monit.section('Forward pass'):
            output = self.model(data.to(self.device))
        # Move targets to the same device as output
        target = target.to(output.device)
        # Calculate loss; the output is flattened to one row per target element
        loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))

        return loss, output, target

    def train(self):
        """
        Run `train_epoch` for `epochs` epochs, starting a new tracker line after each.
        """
        for _ in monit.loop(self.epochs):
            self.train_epoch()
            tracker.new_line()

    def sample(self, idx):
        """
        Sampling hook scheduled `samples_per_epoch` times per epoch; a no-op here.

        :param idx: index of the sampling step within the epoch
        """
        pass

    def save_checkpoint(self, idx):
        """
        Checkpoint hook scheduled `checkpoints_per_epoch` times per epoch; a no-op here.

        :param idx: index of the checkpoint step within the epoch
        """
        pass

    def get_iterators(self):
        """
        Build the list of `(name or callable, iterable)` pairs given to `monit.mix`.

        :return: the training loader, plus the validation loader, sampling steps
            and checkpoint steps when they are enabled
        """
        # Iterate through the batches
        iterators = [('train', self.train_loader)]
        if self.valid_loader is not None:
            iterators.append(('valid', self.valid_loader))

        if self.samples_per_epoch > 0:
            iterators.append((self.sample, list(range(self.samples_per_epoch))))

        if self.checkpoints_per_epoch > 0:
            iterators.append((self.save_checkpoint, list(range(self.checkpoints_per_epoch))))

        return iterators

    def train_epoch(self):
        """
        Run one epoch over everything returned by `get_iterators`.

        Training samples get a forward pass, backward pass and optimizer step;
        other named splits only get a forward pass with gradients disabled.
        Loss and accuracy are tracked per split.
        """
        # Set model for train
        self.model.train()

        iterators = self.get_iterators()
        # NOTE(review): entries whose first element is a callable (`sample`, `save_checkpoint`)
        # are presumably invoked by `monit.mix` itself rather than yielded here - confirm against labml
        for split_name, sample in monit.mix(1024, *iterators):
            if split_name == 'train':
                # Set gradients to zero
                self.optimizer.zero_grad()
                tracker.add_global_step()

            with torch.set_grad_enabled(split_name == 'train'):
                if self.is_amp:
                    # Forward pass
                    with amp.autocast():
                        loss, output, target = self.get_loss(sample, split_name)
                else:
                    loss, output, target = self.get_loss(sample, split_name)

                # Get predictions
                pred = output.argmax(dim=-1)
                # Calculate accuracy; `-100` is the default `ignore_index` of `nn.CrossEntropyLoss`,
                # so those positions are left out of the denominator
                accuracy = pred.eq(target).sum().item() / (target != -100).sum()

                tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})

            if split_name == 'train':
                if self.scaler is not None:
                    # Backward pass
                    loss = self.scaler.scale(loss)
                    # tracker.add({'loss.scaled': loss})

                with monit.section('Backward pass'):
                    loss.backward()

                # Optimize
                with monit.section('Optimize'):
                    if self.scaler is None:
                        # NOTE(review): `grad_norm` clipping is only applied on the scaler path below
                        self.optimizer.step()
                    else:
                        # Gradients have to be unscaled before their norm is clipped
                        self.scaler.unscale_(self.optimizer)
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()

            tracker.save()
@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
    """
    Create the optimizer for the `'Adam'` option of `TrainerConf.optimizer`.

    :param c: the trainer configurations; `dtype`, `model` and `learning_rate` are read
    :return: `torch.optim.Adam` when `dtype` is `float32`, `AdamFP16` when it is `float16`
    :raises NotImplementedError: for any other `dtype`
    """
    if c.dtype == torch.float32:
        return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
    elif c.dtype == torch.float16:
        # Imported here so the FP16 optimizer module is only loaded when it is used
        from labml_nn.optimizers.adam_fp16 import AdamFP16
        return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
    else:
        # Name the rejected dtype so a misconfiguration is easy to diagnose
        raise NotImplementedError(f'Adam optimizer is not implemented for dtype {c.dtype}')
159
160
@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
    """
    Create the optimizer for the `'SGD'` option of `TrainerConf.optimizer`.

    :param c: the trainer configurations; `model` and `learning_rate` are read
    :return: a `torch.optim.SGD` over the trainable parameters of the model
    """
    trainable = get_trainable_params(c.model)
    return torch.optim.SGD(trainable, lr=c.learning_rate)
164
165
@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
    """
    Create the gradient scaler for the `'Default'` option of `TrainerConf.scaler`.

    :param c: the trainer configurations; `is_amp` and `dtype` are read
    :return: `None` when `is_amp` is off, `GradScalerFP16` when `dtype` is `float16`,
        otherwise a `GradScaler`
    """
    if c.is_amp:
        if c.dtype != torch.float16:
            return GradScaler()

        # Imported here so the FP16 scaler module is only loaded when it is used
        from labml_nn.optimizers.adam_fp16 import GradScalerFP16
        return GradScalerFP16()

    # No scaler without automatic mixed precision
    return None
176
177
class PipelineParallelTrainerConf(TrainerConf):
    """
    `TrainerConf` extended with the extra configuration items declared below.

    Only the declarations are visible here; the options that compute or consume
    these items are expected to be registered elsewhere.
    """
    # NOTE(review): presumably toggles activation checkpointing in the pipeline - confirm
    is_checkpointing: bool = False
    # NOTE(review): presumably the number of micro-batch chunks per batch - confirm; has no default
    chunks: int

    # Fine-tuning helper from `labml_nn.neox.utils.finetune`; has no default
    fine_tuner: FineTuner