import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection

import torch.optim
import torch.utils.data
from labml import tracker, logger, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
from torch import nn

from .device import DeviceConfigs
from .metrics import StateModule


class TrainingLoopIterator(Collection):
    """
    Iterator for the training loop. If ``step`` is given it yields ``start``,
    ``start + step``, ... up to ``total``; otherwise it counts iterations and
    yields the current global step from ``tracker``.
    """

    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        self.i = None

    def __iter__(self):
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        else:
            return self.i

    def __len__(self) -> int:
        if self.step is not None:
            return (self.total - self.start) // self.step
        else:
            return self.total

    def __contains__(self, x: object) -> bool:
        # Needed to satisfy the ``Collection`` ABC; membership tests are never used.
        return False


class TrainingLoop:
    _iter: Optional[TrainingLoopIterator]
    __loop: Loop
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 log_new_line_interval: int,
                 log_write_interval: int,
                 is_loop_on_interrupt: bool):
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            # signal handlers can only be installed in the main thread
            pass
        return self

    @property
    def idx(self):
        if not self._iter:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        try:
            signal.signal(signal.SIGINT, self.old_handler)
        except ValueError:
            pass
        tracker.save()
        tracker.new_line()

    def __next__(self):
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration as e:
            self.__finish()
            raise e

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step

        return global_step

    def __handler(self, sig, frame):
        # Pass second interrupt without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"


class TrainingLoopConfigs(BaseConfigs):
    r"""
    This is a configurable training loop. You can extend this class for your
    configurations if it involves a training loop.

    >>> for step in conf.training_loop:
    >>>     ...

    Arguments:
        loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
        log_new_line_interval (int): The interval (in steps) at which to print a new line
            to the screen. Defaults to ``1``.
        log_write_interval (int): The interval (in steps) at which to call
            :func:`labml.tracker.save`. Defaults to ``1``.
        is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until
            an iteration completes. Defaults to ``False``.
    """
    loop_count: int = 10
    loop_step: int = 1
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    is_loop_on_interrupt: bool = False

    training_loop: TrainingLoop


@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    return TrainingLoop(loop_count=c.loop_count,
                        loop_step=c.loop_step,
                        log_new_line_interval=c.log_new_line_interval,
                        log_write_interval=c.log_write_interval,
                        is_loop_on_interrupt=c.is_loop_on_interrupt)


meta_config(TrainingLoopConfigs.loop_step,
            TrainingLoopConfigs.loop_count,
            TrainingLoopConfigs.log_new_line_interval,
            TrainingLoopConfigs.log_write_interval,
            TrainingLoopConfigs.is_loop_on_interrupt)

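
# A minimal usage sketch, mirroring the class docstring above (``MyLoopConfigs`` and the
# overridden values are illustrative assumptions, not part of this module). In a full
# experiment the options are normally computed via ``labml.experiment.configs`` first.
#
#     class MyLoopConfigs(TrainingLoopConfigs):
#         loop_count: int = 100
#         log_new_line_interval: int = 10
#
#     conf = MyLoopConfigs()
#     for step in conf.training_loop:
#         ...  # one iteration of work; tracker.save() / tracker.new_line() fire on the intervals

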
class ModeState:
    def __init__(self):
        self._rollback_stack = []

        self.is_train = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, Any]):
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        assert n == len(self._rollback_stack)

        rollback = self._rollback_stack[-1]
        self._rollback_stack.pop(-1)

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        return Mode(self,
                    is_train=is_train,
                    is_optimize=is_optimize)


class Mode:
    def __init__(self, mode: ModeState, **kwargs: Any):
        self.mode = mode
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mode._exit(self.idx)

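
# Illustrative sketch of how ``ModeState`` and ``Mode`` work together (the variable names
# are assumptions, not part of this module): ``update`` returns a context manager that
# flips the flags on entry and rolls them back on exit.
#
#     mode = ModeState()
#     with mode.update(is_train=True):
#         assert mode.is_train       # gradients enabled inside Trainer.__call__
#     assert not mode.is_train       # previous value restored on exit

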
class Trainer:
    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 is_track_time: bool,
                 step: Callable[[Any, 'BatchIndex'], None]):
        self.is_track_time = is_track_time
        self.mode = mode
        self.name = name
        self.step = step
        self.state_modules = state_modules
        self.__iterable = None
        self.__states = [sm.create_state() for sm in self.state_modules]
        self.inner_iterations = inner_iterations
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)

    def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
        self.__iterable = None

    def __call__(self):
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset(len(self.data_loader), self.inner_iterations)
            for sm in self.state_modules:
                sm.on_epoch_start()
        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

        self._batch_index.step_inner()


class BatchIndex:
    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total

    def is_interval(self, interval: int):
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        else:
            return (self.idx + 1) % interval == 0

    @property
    def is_last(self):
        return self.idx + 1 == self.total

    @property
    def completed(self):
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self):
        # The floor division (//) is important so that the last step happens on the last iteration
        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self):
        return self.idx / self.total

    def step(self):
        self.idx += 1

    def step_inner(self):
        self.iteration += 1

    def reset(self, total: int, total_iterations: int):
        self.total = total
        self.total_iterations = total_iterations
        self.idx = 0
        self.iteration = 0


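# Worked example of ``BatchIndex`` (illustrative values, assuming 10 batches per epoch
# and 2 inner iterations). ``iteration_completed`` splits the epoch so each call to a
# ``Trainer`` consumes its share of batches, and ``is_interval`` is useful for things
# like gradient accumulation:
#
#     bi = BatchIndex(total=10, total_iterations=2)
#     bi.reset(10, 2)
#     # the first inner iteration stops once bi.idx reaches (0 + 1) * 10 // 2 == 5
#     bi.idx, bi.iteration = 4, 0
#     assert not bi.iteration_completed              # 4 < 5: keep consuming batches
#     bi.idx = 5
#     assert bi.iteration_completed                  # hand control back to the outer loop
#     assert bi.is_interval(3) == ((5 + 1) % 3 == 0)  # every 3rd batch, or the last one

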
class TrainValidConfigs(TrainingLoopConfigs):
    r"""
    This is a configurable module that you can extend for experiments that involve
    training and validation datasets (i.e. most DL experiments).

    Arguments:
        epochs (int): Number of epochs to train on. Defaults to ``10``.
        train_loader (torch.utils.data.DataLoader): Training data loader.
        valid_loader (torch.utils.data.DataLoader): Validation data loader.
        inner_iterations (int): Number of times to switch between training and validation
            within an epoch. Defaults to ``1``.

    You can override the ``init`` and ``step`` functions. There is also a ``sample``
    function that you can override to generate samples every time it switches between
    training and validation.
    """
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer
    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    is_track_time: bool = False

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        raise NotImplementedError

    def run_step(self):
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()
            tracker.save()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        # Accessing these computes the lazy config options before the loop starts
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        pass


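# A sketch of how this class is typically extended (``MyConfigs`` and the body of ``step``
# are illustrative assumptions, not part of this module). ``step`` is shared between
# training and validation; ``self.mode.is_train`` tells which phase is running:
#
#     class MyConfigs(TrainValidConfigs):
#         def init(self):
#             self.state_modules = []
#
#         def step(self, batch, batch_idx):
#             data, target = batch
#             if self.mode.is_train:
#                 tracker.add_global_step(len(data))
#             ...  # forward pass, loss, and (when training) backward/optimizer updates

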
@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    return c.epochs


class SimpleTrainValidConfigs(TrainValidConfigs):
    r"""
    This is a configurable module that works for many standard DL experiments.

    Arguments:
        model: A PyTorch model.
        optimizer: A PyTorch optimizer to update the model.
        device: The device to train the model on. This defaults to a configurable device.
        loss_func: A function to calculate the loss. This should accept
            ``model_output, target`` as arguments.
        update_batches (int): Number of batches to accumulate before taking an optimizer
            step. Defaults to ``1``.
        log_save_batches (int): How often (in batches) to call :func:`labml.tracker.save`.
            Defaults to ``1``.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    update_batches: int = 1
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with monit.section("model"):
            output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            if batch_idx.is_interval(self.log_save_batches):
                tracker.save()


501
502meta_config(SimpleTrainValidConfigs.update_batches,
503 )
506@option(SimpleTrainValidConfigs.optimizer)
507def _default_optimizer(c: SimpleTrainValidConfigs):
508 from .optimizer import OptimizerConfigs
509 opt_conf = OptimizerConfigs()
510 opt_conf.parameters = c.model.parameters()
511 return opt_conf
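

# End-to-end sketch (illustrative; the ``labml.experiment`` calls and the config overrides
# below are assumptions about typical usage, not defined in this module):
#
#     from labml import experiment
#
#     class Configs(SimpleTrainValidConfigs):
#         epochs: int = 5
#         loss_func: nn.Module = nn.CrossEntropyLoss()
#
#     conf = Configs()
#     experiment.create(name='simple_train_valid')
#     experiment.configs(conf, {'model': ..., 'train_loader': ..., 'valid_loader': ...})
#     with experiment.start():
#         conf.run()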