Source code for reflectorch.ml.basic_trainer

from typing import Optional, Tuple, Iterable, Any, Union, Type
from collections import defaultdict

from tqdm.notebook import trange
import numpy as np

import torch
from torch.nn import Module

from reflectorch.ml.loggers import Logger, Loggers

from .utils import is_divisor

__all__ = [
    'Trainer',
    'TrainerCallback',
    'DataLoader',
    'PeriodicTrainerCallback',
]


class Trainer(object):
    """Trainer class

    Args:
        model (nn.Module): neural network
        loader (DataLoader): data loader
        lr (float): learning rate
        batch_size (int): batch size
        clip_grad_norm_max (int, optional): maximum norm for gradient clipping if it is not ``None``. Defaults to None.
        train_with_q_input (bool, optional): if ``True``, the q values are also used as input. Defaults to False.
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer class. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): optimizer keyword arguments. Defaults to None.
    """

    TOTAL_LOSS_KEY: str = 'total_loss'

    def __init__(self,
                 model: Module,
                 loader: 'DataLoader',
                 lr: float,
                 batch_size: int,
                 clip_grad_norm_max: Optional[int] = None,
                 train_with_q_input: bool = False,
                 logger: Union[Logger, Tuple[Logger, ...], Loggers] = None,
                 optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
                 optim_kwargs: dict = None,
                 **kwargs
                 ):
        self.model = model
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max
        self.train_with_q_input = train_with_q_input
        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        self.lrs = []
        self.losses = defaultdict(list)
        self.logger = _init_logger(logger)
        self.callback_params = {}

        for k, v in kwargs.items():
            setattr(self, k, v)

        self.init()

    def init(self):
        pass

    def log(self, name: str, data):
        """log data"""
        self.logger.log(name, data)

    def train(self,
              num_batches: int,
              callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
              disable_tqdm: bool = False,
              update_tqdm_freq: int = 10,
              grad_accumulation_steps: int = 1,
              ):
        """starts the training process

        Args:
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback', ...], 'TrainerCallback'], optional): the trainer callbacks. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 10.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.
        """
        if isinstance(callbacks, TrainerCallback):
            callbacks = (callbacks,)

        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

        pbar = trange(num_batches, disable=disable_tqdm)

        callbacks.start_training(self)

        for batch_num in pbar:
            self.model.train()
            self.optim.zero_grad()

            total_loss, avr_loss_dict = 0, defaultdict(list)

            for _ in range(grad_accumulation_steps):
                batch_data = self.get_batch_by_idx(batch_num)
                loss_dict = self.get_loss_dict(batch_data)
                loss = loss_dict['loss'] / grad_accumulation_steps
                total_loss += loss.item()
                _update_loss_dict(avr_loss_dict, loss_dict)

                if not torch.isfinite(loss).item():
                    raise ValueError('Loss is not finite!')

                loss.backward()

            if self.clip_grad_norm_max is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)

            self.optim.step()

            avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}

            self._update_losses(avr_loss_dict, total_loss)

            if not disable_tqdm:
                self._update_tqdm(pbar, batch_num, update_tqdm_freq)

            break_epoch = callbacks.end_batch(self, batch_num)

            if break_epoch:
                break

        callbacks.end_training(self)

    def _update_tqdm(self, pbar, batch_num: int, update_tqdm_freq: int):
        if is_divisor(batch_num, update_tqdm_freq):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')

    def get_batch_by_idx(self, batch_num: int) -> Any:
        raise NotImplementedError

    def get_loss_dict(self, batch_data) -> dict:
        raise NotImplementedError

    def _update_losses(self, loss_dict: dict, loss: float) -> None:
        _update_loss_dict(self.losses, loss_dict)
        self.losses[self.TOTAL_LOSS_KEY].append(loss)
        self.lrs.append(self.lr())

    def configure_optimizer(self, optim_cls, lr: float, **kwargs) -> torch.optim.Optimizer:
        """configure the optimizer based on the optimizer class, the learning rate and the optimizer keyword arguments

        Args:
            optim_cls: the class of the optimizer
            lr (float): the learning rate

        Returns:
            torch.optim.Optimizer: the configured optimizer
        """
        optim = optim_cls(self.model.parameters(), lr, **kwargs)
        return optim

    def lr(self, param_group: int = 0) -> float:
        """get the learning rate"""
        return self.optim.param_groups[param_group]['lr']

    def set_lr(self, lr: float, param_group: int = 0) -> None:
        """set the learning rate"""
        self.optim.param_groups[param_group]['lr'] = lr


class TrainerCallback(object):
    """Base class for trainer callbacks"""

    def start_training(self, trainer: Trainer) -> None:
        """add functionality at the start of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_training(self, trainer: Trainer) -> None:
        """add functionality at the end of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: whether to stop the training early
        """
        pass

    def __repr__(self):
        return f'{self.__class__.__name__}()'
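

# --- Illustrative sketch (not part of the original module) ---
# A minimal custom callback, assuming only the TrainerCallback interface above:
# ``end_batch`` may return True to stop training early, which ``Trainer.train``
# picks up as ``break_epoch``. The class name and threshold are hypothetical.
class EarlyStoppingExample(TrainerCallback):
    def __init__(self, loss_threshold: float = 1e-4):
        self.loss_threshold = loss_threshold

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        # stop training once the most recent total loss drops below the threshold
        losses = trainer.losses[trainer.TOTAL_LOSS_KEY]
        if losses and losses[-1] < self.loss_threshold:
            return True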


class DataLoader(TrainerCallback):
    pass


class PeriodicTrainerCallback(TrainerCallback):
    """Base class for trainer callbacks which perform an action periodically after a number of iterations

    Args:
        step (int, optional): number of iterations after which the action is repeated. Defaults to 1.
        last_epoch (int, optional): the last training iteration for which the action is performed. Defaults to -1.
    """

    def __init__(self, step: int = 1, last_epoch: int = -1):
        self.step = step
        self.last_epoch = last_epoch

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]: whether to stop the training early
        """
        if (
                is_divisor(batch_num, self.step) and
                (self.last_epoch == -1 or batch_num < self.last_epoch)
        ):
            return self._end_batch(trainer, batch_num)

    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        pass
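

# --- Illustrative sketch (not part of the original module) ---
# A hypothetical subclass of PeriodicTrainerCallback: ``_end_batch`` is invoked only
# every ``step`` iterations (and, if ``last_epoch`` is set, only before it), so the
# scheduling logic does not need to be repeated in each callback. The class name and
# decay factor are assumptions made for illustration.
class PeriodicLrDecayExample(PeriodicTrainerCallback):
    def __init__(self, gamma: float = 0.5, step: int = 1000, last_epoch: int = -1):
        super().__init__(step=step, last_epoch=last_epoch)
        self.gamma = gamma

    def _end_batch(self, trainer: Trainer, batch_num: int) -> None:
        # multiply the learning rate of the first parameter group by gamma
        trainer.set_lr(trainer.lr() * self.gamma)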


class _StackedTrainerCallbacks(TrainerCallback):
    def __init__(self, callbacks: Iterable[TrainerCallback]):
        self.callbacks = tuple(callbacks)

    def start_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.start_training(trainer)

    def end_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.end_training(trainer)

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        break_epoch = False
        for c in self.callbacks:
            break_epoch += bool(c.end_batch(trainer, batch_num))
        return break_epoch

    def __repr__(self):
        callbacks = ", ".join(repr(c) for c in self.callbacks)
        return f'StackedTrainerCallbacks({callbacks})'


def _init_logger(logger: Union[Logger, Tuple[Logger, ...], Loggers] = None):
    if not logger:
        return Logger()
    if isinstance(logger, Logger):
        return logger
    return Loggers(*logger)


def _update_loss_dict(loss_dict: dict, new_values: dict):
    for k, v in new_values.items():
        loss_dict[k].append(v.item())
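

# --- Illustrative sketch (not part of the original module) ---
# Minimal usage of the Trainer API: a concrete subclass implements the two abstract
# hooks ``get_batch_by_idx`` and ``get_loss_dict``; the returned dict must contain at
# least a ``'loss'`` tensor, which ``train`` backpropagates. All names below
# (RegressionTrainerExample, the toy model and data) are hypothetical.
class RegressionTrainerExample(Trainer):
    def get_batch_by_idx(self, batch_num: int) -> Any:
        # draw a random toy batch: learn y = sin(x) from sampled points
        x = torch.rand(self.batch_size, 1) * 6.28
        y = torch.sin(x)
        return x, y

    def get_loss_dict(self, batch_data) -> dict:
        x, y = batch_data
        pred = self.model(x)
        return {'loss': torch.nn.functional.mse_loss(pred, y)}


if __name__ == '__main__':
    # train a small MLP for a few hundred iterations with the default Adam optimizer
    trainer = RegressionTrainerExample(
        model=torch.nn.Sequential(torch.nn.Linear(1, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)),
        loader=DataLoader(),
        lr=1e-3,
        batch_size=128,
    )
    trainer.train(num_batches=500, disable_tqdm=True)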