import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection

import torch.optim
import torch.utils.data
from labml import tracker, logger, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
from torch import nn
from .device import DeviceConfigs
from .metrics import StateModule


# Iterator for the training loop: advances by a fixed ``step`` size when one is
# given, otherwise counts iterations and yields the current global step.
class TrainingLoopIterator(Collection):
    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        self.i = None

    def __iter__(self):
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        else:
            return self.i

    def __len__(self) -> int:
        if self.step is not None:
            return (self.total - self.start) // self.step
        else:
            return self.total

    def __contains__(self, x: object) -> bool:
        # Only needed to satisfy the ``Collection`` interface used by ``monit.loop``
        return False


class TrainingLoop:
    _iter: Optional[TrainingLoopIterator]
    __loop: Loop
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 log_new_line_interval: int,
                 log_write_interval: int,
                 is_loop_on_interrupt: bool):
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            # signal.signal raises ValueError when not called from the main thread
            pass
        return self
    @property
    def idx(self):
        if not self._iter:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        try:
            signal.signal(signal.SIGINT, self.old_handler)
        except ValueError:
            pass
        tracker.save()
        tracker.new_line()

    def __next__(self):
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration as e:
            self.__finish()
            raise e

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step

        return global_step

    def __handler(self, sig, frame):
        # Pass second interrupt without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"

This is a configurable training loop. You can extend this class for your configurations if they involve a training loop.

>>> for step in conf.training_loop:
>>>     ...

Arguments:
    loop_count (int): Total number of steps. Defaults to ``10``.
    loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
    log_new_line_interval (int): The interval (in steps) at which to print a new line to the screen. Defaults to ``1``.
    log_write_interval (int): The interval (in steps) at which to call :func:`labml.tracker.save`. Defaults to ``1``.
    is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an iteration is complete. Defaults to ``False``.

class TrainingLoopConfigs(BaseConfigs):
    loop_count: int = 10
    loop_step: int = 1
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    is_loop_on_interrupt: bool = False

    training_loop: TrainingLoop


@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    return TrainingLoop(loop_count=c.loop_count,
                        loop_step=c.loop_step,
                        log_new_line_interval=c.log_new_line_interval,
                        log_write_interval=c.log_write_interval,
                        is_loop_on_interrupt=c.is_loop_on_interrupt)


meta_config(TrainingLoopConfigs.loop_step,
            TrainingLoopConfigs.loop_count,
            TrainingLoopConfigs.log_new_line_interval,
            TrainingLoopConfigs.log_write_interval,
            TrainingLoopConfigs.is_loop_on_interrupt)
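
A minimal usage sketch (the experiment name, the override value, and the dummy metric are illustrative only), assuming the standard ``labml.experiment`` API for calculating configs:

>>> from labml import experiment
>>>
>>> conf = TrainingLoopConfigs()
>>> experiment.create(name='loop_demo')
>>> experiment.configs(conf, {'loop_count': 100})
>>> with experiment.start():
>>>     for step in conf.training_loop:
>>>         tracker.add('loss', 1.0 / (step + 1))  # dummy metric

The loop itself calls :func:`labml.tracker.save` and starts new lines at the configured intervals, so the body only needs to feed the tracker.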


# Tracks whether we are currently training and/or optimizing; ``Mode`` (below)
# updates these flags and rolls them back on exit.
class ModeState:
    def __init__(self):
        self._rollback_stack = []

        self.is_train = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, Any]):
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        assert n == len(self._rollback_stack)

        rollback = self._rollback_stack[-1]
        self._rollback_stack.pop(-1)

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        return Mode(self,
                    is_train=is_train,
                    is_optimize=is_optimize)


class Mode:
    def __init__(self, mode: ModeState, **kwargs: Any):
        self.mode = mode
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mode._exit(self.idx)


# Runs a model over a data loader for either training or validation,
# one inner iteration (a fraction of an epoch) per call.
class Trainer:
    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 is_track_time: bool,
                 step: Callable[[Any, 'BatchIndex'], None]):
        self.is_track_time = is_track_time
        self.mode = mode
        self.name = name
        self.step = step
        self.state_modules = state_modules
        self.__iterable = None
        self.__states = [sm.create_state() for sm in self.state_modules]
        self.inner_iterations = inner_iterations
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)

    def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
        self.__iterable = None

    def __call__(self):
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset(len(self.data_loader), self.inner_iterations)
            for sm in self.state_modules:
                sm.on_epoch_start()
        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

        self._batch_index.step_inner()


class BatchIndex:
    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total

    def is_interval(self, interval: int):
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        else:
            return (self.idx + 1) % interval == 0

    @property
    def is_last(self):
        return self.idx + 1 == self.total

    @property
    def completed(self):
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self):
        # The integer division ``//`` is important so that the last step
        # happens on the last iteration
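        # For example, with ``total = 100`` batches and ``total_iterations = 4``
        # the cut-off points are 25, 50, 75 and 100, so the final inner
        # iteration only completes after the last batch has been processed.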

        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self):
        return self.idx / self.total

    def step(self):
        self.idx += 1

    def step_inner(self):
        self.iteration += 1

    def reset(self, total: int, total_iterations: int):
        self.total = total
        self.total_iterations = total_iterations
        self.idx = 0
        self.iteration = 0


This is a configurable module that you can extend for experiments that involve training and validation datasets (i.e. most DL experiments).

Arguments:
    epochs (int): Number of epochs to train on. Defaults to ``10``.
    train_loader (torch.utils.data.DataLoader): Training data loader.
    valid_loader (torch.utils.data.DataLoader): Validation data loader.
    inner_iterations (int): Number of times to switch between training and validation within an epoch. Defaults to ``1``.

You can override the ``init`` and ``step`` functions. There is also a ``sample`` function that you can override to generate samples every time it switches between training and validation.

class TrainValidConfigs(TrainingLoopConfigs):
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer
    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    is_track_time: bool = False

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        raise NotImplementedError

    def run_step(self):
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()
            tracker.save()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        pass
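
A sketch of how the ``step`` hook is typically filled in when extending ``TrainValidConfigs`` directly (the class name, its ``model``/``optimizer`` configs, and the loss call are hypothetical; ``SimpleTrainValidConfigs`` below is a ready-made version of this pattern, and providers for ``mode``, ``model`` and ``optimizer`` are assumed to be registered elsewhere with ``@option``):

class MyConfigs(TrainValidConfigs):
    model: nn.Module
    optimizer: torch.optim.Optimizer
    state_modules: List[StateModule] = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        # Called once per batch for both training and validation;
        # ``self.mode.is_train`` tells which phase is running
        data, target = batch
        output = self.model(data)
        loss = torch.nn.functional.cross_entropy(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            tracker.add_global_step(len(data))
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

    def sample(self):
        # Optionally generate samples each time training and validation switch
        pass
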

@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    return c.epochs


This is a configurable module that works for many standard DL experiments.

Arguments:
    model: A PyTorch model.
    optimizer: A PyTorch optimizer to update the model.
    device: The device to train the model on. This defaults to a configurable device.
    loss_func: A function to calculate the loss. This should accept ``model_output, target`` as arguments.
    update_batches (int): Number of batches to accumulate before taking an optimizer step. Defaults to ``1``.
    log_save_batches (int): How often (in batches) to call :func:`labml.tracker.save`. Defaults to ``1``.

class SimpleTrainValidConfigs(TrainValidConfigs):
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    update_batches: int = 1
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with monit.section("model"):
            output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            # Accumulate gradients for ``update_batches`` batches before
            # taking an optimizer step
            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                self.optimizer.zero_grad()

            if batch_idx.is_interval(self.log_save_batches):
                tracker.save()


meta_config(SimpleTrainValidConfigs.update_batches)


@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
    from .optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
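

A rough end-to-end sketch of wiring this up (every name here, such as ``TinyConfigs``, the synthetic data and the ``tiny_demo`` experiment, is hypothetical and not part of this module; it assumes the standard ``labml.experiment`` API and the ``'Adam'`` option of ``OptimizerConfigs``):

from labml import experiment
from torch.utils.data import DataLoader, TensorDataset


class TinyConfigs(SimpleTrainValidConfigs):
    epochs: int = 2


@option(TinyConfigs.mode)
def _tiny_mode(c: TinyConfigs):
    # Provide a ModeState explicitly in case no default is registered elsewhere
    return ModeState()


@option(TinyConfigs.model)
def _tiny_model(c: TinyConfigs):
    # A toy classifier; a real experiment would plug in its own model
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).to(c.device)


@option(TinyConfigs.loss_func)
def _tiny_loss(c: TinyConfigs):
    return nn.CrossEntropyLoss()


@option(TinyConfigs.train_loader)
def _tiny_train_loader(c: TinyConfigs):
    # Synthetic data stands in for a real dataset
    ds = TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,)))
    return DataLoader(ds, batch_size=32, shuffle=True)


@option(TinyConfigs.valid_loader)
def _tiny_valid_loader(c: TinyConfigs):
    ds = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
    return DataLoader(ds, batch_size=32)


def main():
    conf = TinyConfigs()
    experiment.create(name='tiny_demo')
    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
    with experiment.start():
        conf.run()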