from typing import Optional, Tuple, Iterable, Any, Union, Type
from collections import defaultdict

from tqdm import tqdm as standard_tqdm
from tqdm.notebook import tqdm as notebook_tqdm

import numpy as np
import torch
from torch.nn import Module

from reflectorch.ml.loggers import Logger, Loggers

from .utils import is_divisor

__all__ = [
    'Trainer',
    'TrainerCallback',
    'DataLoader',
    'PeriodicTrainerCallback',
]

class Trainer(object):
    """Trainer class

    Args:
        model (nn.Module): neural network
        loader (DataLoader): data loader
        lr (float): learning rate
        batch_size (int): batch size
        clip_grad_norm_max (float, optional): maximum norm for gradient clipping; no clipping is applied if ``None``. Defaults to None.
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): optimizer arguments. Defaults to None.
    """

    TOTAL_LOSS_KEY: str = 'total_loss'
    def __init__(self,
                 model: Module,
                 loader: 'DataLoader',
                 lr: float,
                 batch_size: int,
                 clip_grad_norm_max: Optional[float] = None,
                 logger: Optional[Union[Logger, Tuple[Logger, ...], Loggers]] = None,
                 optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
                 optim_kwargs: Optional[dict] = None,
                 **kwargs
                 ):
        self.model = model
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max
        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        self.lrs = []
        self.losses = defaultdict(list)
        self.logger = _init_logger(logger)
        self.callback_params = {}

        for k, v in kwargs.items():
            setattr(self, k, v)

        self.init()

    def init(self):
        pass
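
    # NOTE: ``configure_optimizer`` is called in ``__init__`` above but was missing
    # from this section; the minimal implementation sketched here is an assumption
    # (the optimizer is built over all model parameters, with subclasses free to override).
    def configure_optimizer(self, optim_cls: Type[torch.optim.Optimizer], **kwargs) -> torch.optim.Optimizer:
        """instantiate the optimizer over the model parameters"""
        return optim_cls(self.model.parameters(), **kwargs)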

    def log(self, name: str, data):
        """log data to the attached logger(s)"""
        self.logger.log(name, data)
    def train(self,
              num_batches: int,
              callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
              disable_tqdm: bool = False,
              use_notebook_tqdm: bool = False,
              update_tqdm_freq: int = 1,
              grad_accumulation_steps: int = 1,
              ):
        """starts the training process

        Args:
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback', ...], 'TrainerCallback'], optional): the trainer callbacks. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
            use_notebook_tqdm (bool, optional): should be set to ``True`` when used in a Jupyter notebook. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 1.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.
        """
        if isinstance(callbacks, TrainerCallback):
            callbacks = (callbacks,)
        # the loader participates in training as a callback as well
        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

        tqdm_class = notebook_tqdm if use_notebook_tqdm else standard_tqdm
        pbar = tqdm_class(range(num_batches), disable=disable_tqdm)

        callbacks.start_training(self)

        for batch_num in pbar:
            self.model.train()
            self.optim.zero_grad()

            total_loss, avr_loss_dict = 0, defaultdict(list)

            for _ in range(grad_accumulation_steps):
                batch_data = self.get_batch_by_idx(batch_num)
                loss_dict = self.get_loss_dict(batch_data)
                loss = loss_dict['loss'] / grad_accumulation_steps
                total_loss += loss.item()
                _update_loss_dict(avr_loss_dict, loss_dict)

                if not torch.isfinite(loss).item():
                    raise ValueError('Loss is not finite!')

                loss.backward()

            if self.clip_grad_norm_max is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)

            self.optim.step()

            avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
            self._update_losses(avr_loss_dict, total_loss)

            if not disable_tqdm:
                self._update_tqdm(pbar, batch_num, update_tqdm_freq)

            break_epoch = callbacks.end_batch(self, batch_num)

            if break_epoch:
                break

        callbacks.end_training(self)

    def _update_tqdm(self, pbar, batch_num: int, update_tqdm_freq: int):
        if is_divisor(batch_num, update_tqdm_freq):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')

            postfix = {}
            for key in self.losses.keys():
                if key != self.TOTAL_LOSS_KEY:
                    last_value = self.losses[key][-1]
                    postfix[key] = f'{last_value:.4f}'
            postfix['lr'] = f'{self.lr():.2e}'
            pbar.set_postfix(postfix)

    def get_batch_by_idx(self, batch_num: int) -> Any:
        """get the training batch for the given iteration; must be implemented by subclasses"""
        raise NotImplementedError

    def get_loss_dict(self, batch_data) -> dict:
        """compute the losses for a batch as a dict containing a ``'loss'`` entry; must be implemented by subclasses"""
        raise NotImplementedError

    def _update_losses(self, loss_dict: dict, loss: float) -> None:
        _update_loss_dict(self.losses, loss_dict)
        self.losses[self.TOTAL_LOSS_KEY].append(loss)
        self.lrs.append(self.lr())

    def lr(self, param_group: int = 0) -> float:
        """get the learning rate"""
        return self.optim.param_groups[param_group]['lr']

    def set_lr(self, lr: float, param_group: int = 0) -> None:
        """set the learning rate"""
        self.optim.param_groups[param_group]['lr'] = lr


class TrainerCallback(object):
    """Base class for trainer callbacks"""

    def start_training(self, trainer: Trainer) -> None:
        """add functionality at the start of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_training(self, trainer: Trainer) -> None:
        """add functionality at the end of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: ``True`` if training should be interrupted after this batch, ``None`` (or ``False``) otherwise
        """
        pass

    def __repr__(self):
        return f'{self.__class__.__name__}()'
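

# A minimal illustration of the callback API (an example sketch, not part of the
# library): a hypothetical callback that interrupts training by returning ``True``
# from ``end_batch`` once the last recorded total loss drops below a threshold.
class _ExampleLossThresholdCallback(TrainerCallback):
    def __init__(self, threshold: float):
        self.threshold = threshold

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        losses = trainer.losses[trainer.TOTAL_LOSS_KEY]
        # returning ``True`` makes ``Trainer.train`` break out of its loop
        if losses and losses[-1] < self.threshold:
            return True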


class DataLoader(TrainerCallback):
    """Base class for data loaders, which participate in training as trainer callbacks"""


class PeriodicTrainerCallback(TrainerCallback):
    """Base class for trainer callbacks which perform an action periodically, every fixed number of iterations

    Args:
        step (int, optional): number of iterations after which the action is repeated. Defaults to 1.
        last_epoch (int, optional): the action is only performed for iterations below this index; ``-1`` means no limit. Defaults to -1.
    """

    def __init__(self, step: int = 1, last_epoch: int = -1):
        self.step = step
        self.last_epoch = last_epoch

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: ``True`` if training should be interrupted after this batch, ``None`` otherwise
        """
        if (
                is_divisor(batch_num, self.step) and
                (self.last_epoch == -1 or batch_num < self.last_epoch)
        ):
            return self._end_batch(trainer, batch_num)

    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        pass
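

# Illustrative subclass (an example sketch, not part of the library): logs the
# current learning rate every ``step`` iterations via the trainer's logger.
class _ExampleLrLoggerCallback(PeriodicTrainerCallback):
    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        trainer.log('lr', trainer.lr())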


class _StackedTrainerCallbacks(TrainerCallback):
    def __init__(self, callbacks: Iterable[TrainerCallback]):
        self.callbacks = tuple(callbacks)

    def start_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.start_training(trainer)

    def end_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.end_training(trainer)

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        # every callback is invoked; training is interrupted if any of them returns True
        break_epoch = False
        for c in self.callbacks:
            break_epoch |= bool(c.end_batch(trainer, batch_num))
        return break_epoch

    def __repr__(self):
        callbacks = ", ".join(repr(c) for c in self.callbacks)
        return f'StackedTrainerCallbacks({callbacks})'


def _init_logger(logger: Optional[Union[Logger, Tuple[Logger, ...], Loggers]] = None):
    if not logger:
        return Logger()
    if isinstance(logger, Logger):
        return logger
    return Loggers(*logger)


def _update_loss_dict(loss_dict: dict, new_values: dict):
    for k, v in new_values.items():
        loss_dict[k].append(v.item())
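

if __name__ == '__main__':
    # Usage sketch (an illustration, not part of the library): subclasses must
    # implement ``get_batch_by_idx`` and ``get_loss_dict``, and the returned dict
    # must contain a scalar ``'loss'`` tensor. The toy regression task, the
    # ``MSETrainer`` name, and the hyperparameters below are assumptions chosen
    # only for this demo.
    class MSETrainer(Trainer):
        def get_batch_by_idx(self, batch_num: int):
            # synthetic batch: predict the sum of 10 random features
            x = torch.randn(self.batch_size, 10)
            return x, x.sum(dim=-1, keepdim=True)

        def get_loss_dict(self, batch_data) -> dict:
            x, y = batch_data
            return {'loss': torch.nn.functional.mse_loss(self.model(x), y)}

    trainer = MSETrainer(
        model=torch.nn.Linear(10, 1),
        loader=DataLoader(),
        lr=1e-3,
        batch_size=64,
        clip_grad_norm_max=1.0,
    )
    trainer.train(num_batches=100, disable_tqdm=True)
    print(f"final loss: {trainer.losses[trainer.TOTAL_LOSS_KEY][-1]:.4f}")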