1from typing import Optional, Set, List
2
3import torch.nn as nn
4import torch.optim
5import torch.utils.data
6from torch.cuda import amp
7from torch.cuda.amp import GradScaler
8
9from labml import monit, tracker
10from labml.configs import BaseConfigs, option
11from labml_nn.neox.utils.finetune import FineTuner

def get_trainable_params(model: nn.Module):
    """
    ### Get trainable parameters

    * `model` is the model to train

    Returns a list of the parameters of `model` that require gradients,
    i.e. the parameters used for training.
    """
    # Walk all parameters and keep only those that require gradients
    return [param for param in model.parameters() if param.requires_grad]
class TrainerConf(BaseConfigs):
    """
    ### Trainer configurations

    String defaults such as `'Adam'` (for `optimizer`) and `'Default'`
    (for `scaler`) name the `@option` functions registered for those configs.
    """
    model: nn.Module
    layers: List[nn.Module]
    optimizer: torch.optim.Optimizer = 'Adam'
    train_loader: torch.utils.data.DataLoader
    # No trailing comma: `= None,` would make the default the tuple `(None,)`,
    # which is never `None`, so a broken validation iterator would always be added.
    valid_loader: Optional[torch.utils.data.DataLoader] = None
    device: torch.device = torch.device('cuda:0')
    scaler: Optional[GradScaler] = 'Default'
    is_amp: bool = True
    dtype: torch.dtype = torch.float16

    is_clone_layers: bool = True

    loss_func: nn.Module = nn.CrossEntropyLoss()
    checkpoints_per_epoch: int = 0
    samples_per_epoch: int = 0

    grad_norm: Optional[float] = 1.0
    learning_rate: float = 3e-4
    max_seq_len: int = 1024
    batch_size: int = 64
    epochs: int = 16

    n_gpus: int = torch.cuda.device_count()

    filter_layers: Optional[Set] = None

    def get_loss(self, sample, dataset_split: str):
        """
        ### Calculate loss

        * `sample` is the sample, unpacked as `(data, target)`
        * `dataset_split` is `train` or `valid`
        * Returns the loss, output and the target
        """
        data, target = sample

        # Forward pass
        with monit.section('Forward pass'):
            output = self.model(data.to(self.device))

        # Move targets to the same device as output
        target = target.to(output.device)

        # Calculate loss over the flattened targets
        loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))

        return loss, output, target

    def train(self):
        """
        ### Train for `epochs` epochs
        """
        for epoch in monit.loop(self.epochs):
            self.train_epoch()
            tracker.new_line()

    def sample(self, idx):
        """
        Sampling hook; does nothing here.

        Added to the mixed iterators by `get_iterators` when `samples_per_epoch > 0`.
        """
        pass

    def save_checkpoint(self, idx):
        """
        Checkpoint hook; does nothing here.

        Added to the mixed iterators by `get_iterators` when `checkpoints_per_epoch > 0`.
        """
        pass

    def get_iterators(self):
        """
        Build the `(name, iterable)` pairs that `train_epoch` passes to `monit.mix`.
        """
        # Iterate through the batches
        iterators = [('train', self.train_loader)]
        if self.valid_loader is not None:
            iterators.append(('valid', self.valid_loader))

        # Sampling / checkpointing hooks are mixed in with callables as names
        # (presumably `monit.mix` invokes callable names — confirm against labml)
        if self.samples_per_epoch > 0:
            iterators.append((self.sample, list(range(self.samples_per_epoch))))

        if self.checkpoints_per_epoch > 0:
            iterators.append((self.save_checkpoint, list(range(self.checkpoints_per_epoch))))

        return iterators

    def train_epoch(self):
        """
        ### Train for one epoch

        Interleaves training batches with the other iterators from `get_iterators`,
        tracking loss and accuracy per split.
        """
        # Set model for train
        # NOTE(review): the model stays in train mode for validation batches as well
        # (e.g. dropout remains active) — confirm this is intended.
        self.model.train()

        # Targets equal to the loss's `ignore_index` are excluded from accuracy;
        # `nn.CrossEntropyLoss` defaults this to -100.
        ignore_index = getattr(self.loss_func, 'ignore_index', -100)

        iterators = self.get_iterators()
        for split_name, sample in monit.mix(1024, *iterators):
            is_train = split_name == 'train'
            if is_train:
                # Set gradients to zero
                self.optimizer.zero_grad()
                tracker.add_global_step()

            with torch.set_grad_enabled(is_train):
                # Forward pass (under autocast when AMP is enabled)
                if self.is_amp:
                    with amp.autocast():
                        loss, output, target = self.get_loss(sample, split_name)
                else:
                    loss, output, target = self.get_loss(sample, split_name)

                # Get predictions
                pred = output.argmax(dim=-1)
                # Calculate accuracy over non-ignored targets only; the count is
                # clamped to 1 so a batch with every target ignored yields 0, not NaN
                mask = target != ignore_index
                n_valid = max(mask.sum().item(), 1)
                accuracy = (pred.eq(target) & mask).sum().item() / n_valid

                tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})

            if is_train:
                if self.scaler is not None:
                    # Scale the loss with the gradient scaler before the backward pass
                    loss = self.scaler.scale(loss)
                    # tracker.add({'loss.scaled': loss})

                # Backward pass
                with monit.section('Backward pass'):
                    loss.backward()

                # Optimize
                with monit.section('Optimize'):
                    if self.scaler is None:
                        # Clip here too so `grad_norm` is honoured without a scaler
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        self.optimizer.step()
                    else:
                        # Unscale first so clipping sees the true gradient magnitudes
                        self.scaler.unscale_(self.optimizer)
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()

            tracker.save()
@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
    """
    ### Adam optimizer

    Uses `torch.optim.Adam` for `float32` models and `AdamFP16` for `float16`
    models, over the trainable parameters only.

    Raises `NotImplementedError` for any other `c.dtype`.
    """
    if c.dtype == torch.float32:
        return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
    elif c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import AdamFP16
        return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
    else:
        # Name the offending dtype instead of raising a bare error
        raise NotImplementedError(f'Adam optimizer is not implemented for dtype {c.dtype}')
159
160
@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
    """
    ### SGD optimizer

    Plain SGD over the trainable parameters with the configured learning rate.
    """
    trainable = get_trainable_params(c.model)
    return torch.optim.SGD(trainable, lr=c.learning_rate)
164
165
@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
    """
    ### Default gradient scaler

    Returns `None` when AMP is disabled, `GradScalerFP16` for `float16`,
    and the standard `GradScaler` for any other dtype.
    """
    # Mixed precision disabled: no loss scaling
    if not c.is_amp:
        return None

    if c.dtype != torch.float16:
        return GradScaler()

    from labml_nn.optimizers.adam_fp16 import GradScalerFP16
    return GradScalerFP16()
176
177
178class PipelineParallelTrainerConf(TrainerConf):
179    is_checkpointing: bool = False
180    chunks: int
181
182    fine_tuner: FineTuner