Source code for torch_ecg.components.trainer

"""
Abstract base class for trainers,
used to replace the standalone training functions in the training pipelines with classes.
"""

import logging
import os
import textwrap
from abc import ABC, abstractmethod
from collections import OrderedDict, deque
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

from ..augmenters import AugmenterManager
from ..cfg import CFG, DEFAULTS
from ..models.loss import setup_criterion
from ..utils.misc import ReprMixin, dict_to_str, dicts_equal, get_date_str, get_kwargs
from ..utils.utils_nn import default_collate_fn
from .loggers import LoggerManager

__all__ = [
    "BaseTrainer",
]


class BaseTrainer(ReprMixin, ABC):
    """Abstract base class for trainers.

    A trainer is a class that contains the training pipeline,
    and is responsible for training a model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be trained.
    dataset_cls : torch.utils.data.Dataset
        The class of dataset to be used for training.
        `dataset_cls` should be inherited from :class:`~torch.utils.data.Dataset`,
        and be initialized via :code:`dataset_cls(config, training=True)`.
    model_config : dict
        The configuration of the model,
        used to keep a record in the checkpoints.
    train_config : dict
        The configuration of the training,
        including configurations for the data loader, for the optimization, etc.
        Will also be recorded in the checkpoints.
        `train_config` should at least contain the following keys:

            - "monitor": str
            - "loss": str
            - "n_epochs": int
            - "batch_size": int
            - "learning_rate": float
            - "lr_scheduler": str
            - "lr_step_size": int, optional, depending on the scheduler
            - "lr_gamma": float, optional, depending on the scheduler
            - "max_lr": float, optional, depending on the scheduler
            - "optimizer": str
            - "decay": float, optional, depending on the optimizer
            - "momentum": float, optional, depending on the optimizer

    collate_fn : callable, optional
        The collate function for the data loader,
        defaults to :meth:`default_collate_fn`.

        .. versionadded:: 0.0.23
    device : torch.device, optional
        The device to be used for training.
    lazy : bool, default False
        Whether to initialize the data loader lazily.

    """

    __name__ = "BaseTrainer"
    __DEFATULT_CONFIGS__ = {
        "debug": True,
        "final_model_name": None,
        "log_step": 10,
        "flooding_level": 0,
        "early_stopping": {},
    }
    __DEFATULT_CONFIGS__.update(deepcopy(DEFAULTS))

    def __init__(
        self,
        model: nn.Module,
        dataset_cls: Dataset,
        model_config: dict,
        train_config: dict,
        collate_fn: Optional[callable] = None,
        device: Optional[torch.device] = None,
        lazy: bool = False,
    ) -> None:
        self.model = model
        if type(self.model).__name__ in [
            "DataParallel",
        ]:  # TODO: further consider "DistributedDataParallel"
            self._model = self.model.module
        else:
            self._model = self.model
        self.dataset_cls = dataset_cls
        self.model_config = CFG(deepcopy(model_config))
        self._train_config = CFG(deepcopy(train_config))
        self._train_config.checkpoints = Path(self._train_config.checkpoints)
        self.device = device or next(self._model.parameters()).device
        self.dtype = next(self._model.parameters()).dtype
        self.model.to(self.device)
        self.lazy = lazy
        self.collate_fn = collate_fn or default_collate_fn

        self.log_manager = None
        self.augmenter_manager = None
        self.train_loader = None
        self.val_train_loader = None
        self.val_loader = None
        self._setup_from_config(self._train_config)

        # monitor for training: challenge metric
        self.best_state_dict = OrderedDict()
        self.best_metric = -np.inf
        self.best_eval_res = dict()
        self.best_epoch = -1
        self.pseudo_best_epoch = -1

        self.saved_models = deque()
        self.model.train()
        self.global_step = 0
        self.epoch = 0
        self.epoch_loss = 0
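
    # Illustrative sketch (not part of the class): a minimal ``train_config``
    # covering the required keys listed in the class docstring; the concrete
    # values, classes and checkpoint directory are hypothetical, the loss name
    # is assumed to be recognized by ``setup_criterion``, and a concrete
    # subclass may demand more keys via ``extra_required_train_config_fields``.
    #
    #     train_config = CFG(
    #         classes=["N", "AF"],
    #         monitor="accuracy",  # must be a key of the dict returned by `evaluate`
    #         loss="CrossEntropyLoss",
    #         n_epochs=30,
    #         batch_size=64,
    #         log_step=10,
    #         learning_rate=1e-3,
    #         lr_scheduler="step",
    #         lr_step_size=10,
    #         lr_gamma=0.1,
    #         optimizer="adamw_amsgrad",
    #         checkpoints="./checkpoints",
    #     )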

    def train(self) -> OrderedDict:
        """Train the model.

        Returns
        -------
        best_state_dict : OrderedDict
            The state dict of the best model.

        """
        self._setup_optimizer()
        self._setup_scheduler()
        self._setup_criterion()

        msg = textwrap.dedent(
            f"""
            Starting training:
            ------------------
            Epochs:          {self.n_epochs}
            Batch size:      {self.batch_size}
            Learning rate:   {self.lr}
            Training size:   {self.n_train}
            Validation size: {self.n_val}
            Device:          {self.device.type}
            Optimizer:       {self.train_config.optimizer}
            Dataset classes: {self.train_config.classes}
            -----------------------------------------
            """
        )
        self.log_manager.log_message(msg)

        start_epoch = self.epoch
        for _ in range(start_epoch, self.n_epochs):
            # train one epoch
            self.model.train()
            self.epoch_loss = 0
            with tqdm(
                total=self.n_train,
                desc=f"Epoch {self.epoch}/{self.n_epochs}",
                unit="signals",
                dynamic_ncols=True,
                mininterval=1.0,
            ) as pbar:
                self.log_manager.epoch_start(self.epoch)
                # train one epoch
                self.train_one_epoch(pbar)

                # evaluate on train set, if debug is True
                if self.train_config.debug:
                    eval_train_res = self.evaluate(self.val_train_loader)
                    self.log_manager.log_metrics(
                        metrics=eval_train_res,
                        step=self.global_step,
                        epoch=self.epoch,
                        part="train",
                    )
                # evaluate on val set
                if self.val_loader is not None:
                    eval_res = self.evaluate(self.val_loader)
                    self.log_manager.log_metrics(
                        metrics=eval_res,
                        step=self.global_step,
                        epoch=self.epoch,
                        part="val",
                    )
                elif self.val_train_loader is not None:
                    # if no separate val set, use the metrics on the train set
                    eval_res = eval_train_res

                # update best model and best metric if monitor is set
                if self.train_config.monitor is not None:
                    if eval_res[self.train_config.monitor] > self.best_metric:
                        self.best_metric = eval_res[self.train_config.monitor]
                        self.best_state_dict = self._model.state_dict()
                        self.best_eval_res = deepcopy(eval_res)
                        self.best_epoch = self.epoch
                        self.pseudo_best_epoch = self.epoch
                    elif self.train_config.early_stopping:
                        if eval_res[self.train_config.monitor] >= self.best_metric - self.train_config.early_stopping.min_delta:
                            self.pseudo_best_epoch = self.epoch
                        elif self.epoch - self.pseudo_best_epoch >= self.train_config.early_stopping.patience:
                            msg = f"early stopping is triggered at epoch {self.epoch}"
                            self.log_manager.log_message(msg)
                            break

                    msg = textwrap.dedent(
                        f"""
                        best metric = {self.best_metric},
                        obtained at epoch {self.best_epoch}
                        """
                    )
                    self.log_manager.log_message(msg)

                    # save checkpoint
                    save_suffix = f"epochloss_{self.epoch_loss:.5f}_metric_{eval_res[self.train_config.monitor]:.2f}"
                else:
                    save_suffix = f"epochloss_{self.epoch_loss:.5f}"
                save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
                save_path = self.train_config.checkpoints / save_filename
                if self.train_config.keep_checkpoint_max != 0:
                    self.save_checkpoint(str(save_path))
                    self.saved_models.append(save_path)
                # remove outdated models
                if len(self.saved_models) > self.train_config.keep_checkpoint_max > 0:
                    model_to_remove = self.saved_models.popleft()
                    try:
                        os.remove(model_to_remove)
                    except Exception:
                        self.log_manager.log_message(f"failed to remove {str(model_to_remove)}")

                # update learning rate using lr_scheduler
                if self.train_config.lr_scheduler.lower() == "plateau":
                    self._update_lr(eval_res)

                self.log_manager.epoch_end(self.epoch)

            self.epoch += 1

        # save the best model
        if self.best_metric > -np.inf:
            if self.train_config.final_model_name:
                save_filename = self.train_config.final_model_name
            else:
                save_suffix = f"metric_{self.best_eval_res[self.train_config.monitor]:.2f}"
                save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar"
            save_path = self.train_config.model_dir / save_filename
            self.save_checkpoint(path=str(save_path))
            self.log_manager.log_message(f"best model is saved at {save_path}")
        elif self.train_config.monitor is None:
            self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model")
            self.best_state_dict = self._model.state_dict()
            save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
            save_path = self.train_config.model_dir / save_filename
            self.save_checkpoint(path=str(save_path))
        else:
            raise ValueError("No best model found!")

        self.log_manager.close()

        if not self.best_state_dict:
            # in case no best model is found,
            # e.g. monitor is not set, or keep_checkpoint_max is 0
            self.best_state_dict = self._model.state_dict()

        return self.best_state_dict

    def train_one_epoch(self, pbar: tqdm) -> None:
        """Train one epoch, and update the progress bar.

        Parameters
        ----------
        pbar : tqdm
            The progress bar for training.

        """
        for epoch_step, data in enumerate(self.train_loader):
            self.global_step += 1
            # data is assumed to be a tuple of tensors, of the following order:
            # signals, labels, *extra_tensors
            data = self.augmenter_manager(*data)
            out_tensors = self.run_one_step(*data)

            loss = self.criterion(*out_tensors).to(self.dtype)
            if self.train_config.flooding_level > 0:
                flood = (loss - self.train_config.flooding_level).abs() + self.train_config.flooding_level
                self.epoch_loss += loss.item()
                self.optimizer.zero_grad()
                flood.backward()
            else:
                self.epoch_loss += loss.item()
                self.optimizer.zero_grad()
                loss.backward()
            self.optimizer.step()
            self._update_lr()

            if self.global_step % self.train_config.log_step == 0:
                train_step_metrics = {"loss": loss.item()}
                if self.scheduler:
                    train_step_metrics.update({"lr": self.scheduler.get_last_lr()[0]})
                    pbar.set_postfix(
                        **{
                            "loss (batch)": loss.item(),
                            "lr": self.scheduler.get_last_lr()[0],
                        }
                    )
                else:
                    pbar.set_postfix(
                        **{
                            "loss (batch)": loss.item(),
                        }
                    )
                if self.train_config.flooding_level > 0:
                    train_step_metrics.update({"flood": flood.item()})
                self.log_manager.log_metrics(
                    metrics=train_step_metrics,
                    step=self.global_step,
                    epoch=self.epoch,
                    part="train",
                )
            pbar.update(data[0].shape[self.batch_dim])
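
    # The "flooding" trick above replaces the loss by ``flood = |loss - b| + b``
    # with ``b = flooding_level``: the gradient is unchanged while ``loss > b``
    # and has its sign flipped once ``loss < b``, keeping the training loss
    # hovering around ``b``.  A tiny numeric check (values hypothetical):
    #
    #     b = 0.1
    #     loss = torch.tensor(0.25)
    #     (loss - b).abs() + b   # tensor(0.2500): same objective, usual descent
    #     loss = torch.tensor(0.04)
    #     (loss - b).abs() + b   # tensor(0.1600): minimizing this pushes the loss back up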

    @property
    @abstractmethod
    def batch_dim(self) -> int:
        """The batch dimension.

        Usually 0, but can be 1 for some models,
        e.g. :class:`~torch_ecg.models.RR_LSTM`.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def extra_required_train_config_fields(self) -> List[str]:
        """Extra required fields in `train_config`."""
        raise NotImplementedError

    @property
    def required_train_config_fields(self) -> List[str]:
        """Required fields in `train_config`."""
        return [
            "classes",
            # "monitor",  # can be None
            "n_epochs",
            "batch_size",
            "log_step",
            "optimizer",
            "lr_scheduler",
            "learning_rate",
        ] + self.extra_required_train_config_fields

    def _validate_train_config(self) -> None:
        """Validate the `train_config`.

        Check if all required fields are present.
        """
        for field in self.required_train_config_fields:
            if field not in self.train_config:
                raise ValueError(f"{field} is missing in train_config!")

    @property
    def save_prefix(self) -> str:
        """The prefix of the saved model name."""
        model_name = self._model.__name__ if hasattr(self._model, "__name__") else self._model.__class__.__name__
        return f"{model_name}_epoch"

    @property
    def train_config(self) -> CFG:
        return self._train_config
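
    # Illustrative sketch of the abstract members above in a hypothetical
    # subclass (names are examples only):
    #
    #     class MyTrainer(BaseTrainer):
    #         @property
    #         def batch_dim(self) -> int:
    #             return 0  # inputs shaped (batch, channels, seq_len)
    #
    #         @property
    #         def extra_required_train_config_fields(self) -> List[str]:
    #             return []  # nothing beyond the common required fields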

    @abstractmethod
    def run_one_step(self, *data: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """Run one step of training on one batch of data.

        Parameters
        ----------
        data : Tuple[torch.Tensor]
            The data to be processed for training one step (batch),
            should be of the following order:
            ``signals, labels, *extra_tensors``.

        Returns
        -------
        Tuple[torch.Tensor]
            The output of the model for one step (batch) of data,
            along with labels and extra tensors.
            Should be of the following order:
            ``preds, labels, *extra_tensors``.
            ``preds`` are usually NOT the predicted probabilities,
            but the raw model outputs, i.e. the tensors before being fed into
            :meth:`~torch.sigmoid` or :meth:`~torch.softmax`.

        """
        raise NotImplementedError
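
    # Illustrative sketch (model call signature assumed): a typical override
    # forwards the signals and returns the raw outputs first, so that
    # ``self.criterion(*out_tensors)`` receives ``(preds, labels, ...)``:
    #
    #     def run_one_step(self, *data: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    #         signals, labels = data
    #         signals = signals.to(device=self.device, dtype=self.dtype)
    #         labels = labels.to(device=self.device)
    #         preds = self.model(signals)  # raw outputs, no sigmoid/softmax applied
    #         return preds, labels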

    @torch.no_grad()
    @abstractmethod
    def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
        """Do evaluation on the given data loader.

        Parameters
        ----------
        data_loader : torch.utils.data.DataLoader
            The data loader to evaluate on.

        Returns
        -------
        dict
            The evaluation results (metrics).

        """
        raise NotImplementedError
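
    # Illustrative sketch (metric and batch layout hypothetical): an override
    # usually switches to eval mode, loops over the loader, and returns a flat
    # dict of scalars so that ``train_config.monitor`` can index into it:
    #
    #     @torch.no_grad()
    #     def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
    #         self.model.eval()
    #         correct, total = 0, 0
    #         for signals, labels in data_loader:
    #             preds = self.model(signals.to(device=self.device, dtype=self.dtype))
    #             correct += (preds.argmax(dim=-1).cpu() == labels).sum().item()
    #             total += labels.shape[0]
    #         self.model.train()
    #         return {"accuracy": correct / max(total, 1)}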

    def _update_lr(self, eval_res: Optional[dict] = None) -> None:
        """Update the learning rate using the lr_scheduler,
        possibly based on `eval_res`.

        Parameters
        ----------
        eval_res : dict, optional
            The evaluation results (metrics).

        """
        if self.train_config.lr_scheduler.lower() == "none":
            pass
        elif self.train_config.lr_scheduler.lower() == "plateau":
            if eval_res is None:
                return
            metrics = eval_res[self.train_config.monitor]
            if isinstance(metrics, torch.Tensor):
                metrics = metrics.item()
            self.scheduler.step(metrics)
        elif self.train_config.lr_scheduler.lower() == "step":
            self.scheduler.step()
        elif self.train_config.lr_scheduler.lower() in [
            "one_cycle",
            "onecycle",
        ]:
            self.scheduler.step()

    def _setup_from_config(self, train_config: dict) -> None:
        """Setup the trainer from the training configuration.

        Parameters
        ----------
        train_config : dict
            The training configuration.

        """
        _default_config = CFG(deepcopy(self.__DEFATULT_CONFIGS__))
        _default_config.update(train_config)
        self._train_config = CFG(deepcopy(_default_config))
        # check validity of the config
        self._validate_train_config()
        # set aliases
        self.n_epochs = self.train_config.n_epochs
        self.batch_size = self.train_config.batch_size
        self.lr = self.train_config.learning_rate
        # setup log manager first
        self._setup_log_manager()
        msg = f"training configurations are as follows:\n{dict_to_str(self.train_config)}"
        self.log_manager.log_message(msg)
        # setup directories
        self._setup_directories()
        # setup callbacks
        self._setup_callbacks()
        # setup data loaders
        if not self.lazy:
            self._setup_dataloaders()
        # setup augmenters manager
        self._setup_augmenter_manager()
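
    # With ``lazy=True`` the data loaders are not built during ``__init__``;
    # a hedged usage sketch (``MyTrainer`` and the datasets are hypothetical):
    #
    #     trainer = MyTrainer(model, MyDataset, model_config, train_config, lazy=True)
    #     trainer._setup_dataloaders(train_dataset, val_dataset)
    #     trainer.train()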

    def extra_log_suffix(self) -> str:
        """Extra suffix for the log file name."""
        model_name = self._model.__name__ if hasattr(self._model, "__name__") else self._model.__class__.__name__
        return f"{model_name}_{self.train_config.optimizer}_LR_{self.lr}_BS_{self.batch_size}"

    def _setup_log_manager(self) -> None:
        """Setup the log manager."""
        config = {"log_suffix": self.extra_log_suffix()}
        config.update(self.train_config)
        self.log_manager = LoggerManager.from_config(config=config)

    def _setup_directories(self) -> None:
        """Setup the directories for saving checkpoints and logs."""
        if not self.train_config.get("model_dir", None):
            self._train_config.model_dir = self.train_config.checkpoints
        self._train_config.model_dir = Path(self._train_config.model_dir)
        self.train_config.checkpoints.mkdir(parents=True, exist_ok=True)
        self.train_config.model_dir.mkdir(parents=True, exist_ok=True)

    def _setup_callbacks(self) -> None:
        """Setup the callbacks."""
        self._train_config.monitor = self.train_config.get("monitor", None)
        if self.train_config.monitor is None:
            assert (
                self.train_config.lr_scheduler.lower() != "plateau"
            ), "monitor is not specified, lr_scheduler should not be ReduceLROnPlateau"
        self._train_config.keep_checkpoint_max = self.train_config.get("keep_checkpoint_max", 1)
        if self._train_config.keep_checkpoint_max < 0:
            self._train_config.keep_checkpoint_max = -1
            self.log_manager.log_message(
                msg="keep_checkpoint_max is set to -1, all checkpoints will be kept",
                level=logging.WARNING,
            )
        elif self._train_config.keep_checkpoint_max == 0:
            self.log_manager.log_message(
                msg="keep_checkpoint_max is set to 0, no checkpoint will be kept",
                level=logging.WARNING,
            )

    def _setup_augmenter_manager(self) -> None:
        """Setup the augmenter manager."""
        self.augmenter_manager = AugmenterManager.from_config(config=self.train_config)

    @abstractmethod
    def _setup_dataloaders(
        self,
        train_dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
    ) -> None:
        """Setup the dataloaders for training and validation.

        Parameters
        ----------
        train_dataset : torch.utils.data.Dataset, optional
            The training dataset.
        val_dataset : torch.utils.data.Dataset, optional
            The validation dataset.

        Examples
        --------
        .. code-block:: python

            if train_dataset is None:
                train_dataset = self.dataset_cls(config=self.train_config, training=True, lazy=False)
            if val_dataset is None:
                val_dataset = self.dataset_cls(config=self.train_config, training=False, lazy=False)
            num_workers = 4
            self.train_loader = DataLoader(
                dataset=train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=num_workers,
                pin_memory=True,
                drop_last=False,
                collate_fn=self.collate_fn,
            )
            self.val_loader = DataLoader(
                dataset=val_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=num_workers,
                pin_memory=True,
                drop_last=False,
                collate_fn=self.collate_fn,
            )

        """
        raise NotImplementedError

    @property
    def n_train(self) -> int:
        if self.train_loader is not None:
            return len(self.train_loader.dataset)
        return 0

    @property
    def n_val(self) -> int:
        if self.val_loader is not None:
            return len(self.val_loader.dataset)
        return 0

    def _setup_optimizer(self) -> None:
        """Setup the optimizer."""
        if self.train_config.optimizer.lower() == "adam":
            optimizer_kwargs = get_kwargs(optim.Adam)
            optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
            optimizer_kwargs.update(dict(lr=self.lr))
            self.optimizer = optim.Adam(
                params=self.model.parameters(),
                **optimizer_kwargs,
            )
        elif self.train_config.optimizer.lower() in ["adamw", "adamw_amsgrad"]:
            optimizer_kwargs = get_kwargs(optim.AdamW)
            optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
            optimizer_kwargs.update(
                dict(
                    lr=self.lr,
                    amsgrad=self.train_config.optimizer.lower().endswith("amsgrad"),
                )
            )
            self.optimizer = optim.AdamW(
                params=self.model.parameters(),
                **optimizer_kwargs,
            )
        elif self.train_config.optimizer.lower() == "sgd":
            optimizer_kwargs = get_kwargs(optim.SGD)
            optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
            optimizer_kwargs.update(dict(lr=self.lr))
            self.optimizer = optim.SGD(
                params=self.model.parameters(),
                **optimizer_kwargs,
            )
        else:
            raise NotImplementedError(
                f"optimizer `{self.train_config.optimizer}` not implemented! "
                "Please use one of the following: `adam`, `adamw`, `adamw_amsgrad`, `sgd`, "
                "or override this method to setup your own optimizer."
            )

    def _setup_scheduler(self) -> None:
        """Setup the learning rate scheduler."""
        if self.train_config.lr_scheduler is None or self.train_config.lr_scheduler.lower() == "none":
            self.train_config.lr_scheduler = "none"
            self.scheduler = None
        elif self.train_config.lr_scheduler.lower() == "plateau":
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                "max",
                patience=2,
                verbose=False,
            )
        elif self.train_config.lr_scheduler.lower() == "step":
            self.scheduler = optim.lr_scheduler.StepLR(
                self.optimizer,
                self.train_config.lr_step_size,
                self.train_config.lr_gamma,
                # verbose=False,
            )
        elif self.train_config.lr_scheduler.lower() in [
            "one_cycle",
            "onecycle",
        ]:
            self.scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer=self.optimizer,
                max_lr=self.train_config.max_lr,
                epochs=self.n_epochs,
                steps_per_epoch=len(self.train_loader),
                # verbose=False,
            )
        else:
            # TODO: add linear and linear with warmup schedulers
            raise NotImplementedError(
                f"lr scheduler `{self.train_config.lr_scheduler.lower()}` not implemented for training! "
                "Please use one of the following: `none`, `plateau`, `step`, `one_cycle`, "
                "or override this method to setup your own lr scheduler."
            )

    def _setup_criterion(self) -> None:
        """Setup the loss function."""
        loss_kw = self.train_config.get("loss_kw", {})
        for k, v in loss_kw.items():
            if isinstance(v, torch.Tensor):
                loss_kw[k] = v.to(device=self.device, dtype=self.dtype)
        self.criterion = setup_criterion(self.train_config.loss, **loss_kw)
        self.criterion.to(self.device)

    def _check_model_config_compatability(self, model_config: dict) -> bool:
        """Check if `model_config` is compatible with the current model configuration.

        Parameters
        ----------
        model_config : dict
            Model configuration from elsewhere (e.g. from a checkpoint),
            which should be compatible with the current model configuration.

        Returns
        -------
        bool
            True if compatible, False otherwise.

        """
        return dicts_equal(self.model_config, model_config)
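
    # The optimizer, scheduler and criterion above are chosen purely from
    # ``train_config``; an illustrative combination (values hypothetical, and
    # the loss name is assumed to be recognized by ``setup_criterion``):
    #
    #     train_config.optimizer = "sgd"
    #     train_config.momentum = 0.9       # picked up via get_kwargs(optim.SGD)
    #     train_config.lr_scheduler = "one_cycle"
    #     train_config.max_lr = 1e-2        # required by OneCycleLR
    #     train_config.loss = "BCEWithLogitsLoss"
    #     train_config.loss_kw = {"pos_weight": torch.ones(len(train_config.classes))}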

    def resume_from_checkpoint(self, checkpoint: Union[str, dict]) -> None:
        """Resume a training process from a checkpoint.

        NOT finished, NOT checked.

        Parameters
        ----------
        checkpoint : str or dict
            If it is str, then it is the path of the checkpoint,
            which is a ``.pth.tar`` file containing a dict.
            `checkpoint` should contain at least
            "model_state_dict", "optimizer_state_dict", "model_config",
            "train_config", "epoch" to resume a training process.

        """
        if isinstance(checkpoint, str):
            ckpt = torch.load(checkpoint, map_location=self.device)
        else:
            ckpt = checkpoint
        insufficient_msg = "this checkpoint has no sufficient data to resume training"
        assert isinstance(ckpt, dict), insufficient_msg
        assert set(
            [
                "model_state_dict",
                "optimizer_state_dict",
                "model_config",
                "train_config",
                "epoch",
            ]
        ).issubset(ckpt.keys()), insufficient_msg
        if not self._check_model_config_compatability(ckpt["model_config"]):
            raise ValueError("model config of the checkpoint is not compatible with the config of the current model")
        self._model.load_state_dict(ckpt["model_state_dict"])
        self.epoch = ckpt["epoch"]
        self._setup_from_config(ckpt["train_config"])
        # TODO: resume optimizer, etc.

    def save_checkpoint(self, path: str) -> None:
        """Save the current state of the trainer to a checkpoint.

        Parameters
        ----------
        path : str
            Path to save the checkpoint.

        """
        torch.save(
            {
                "model_state_dict": self._model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "model_config": self.model_config,
                "train_config": self.train_config,
                "epoch": self.epoch,
            },
            path,
        )
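
    # Hedged usage sketch: the checkpoint written by ``save_checkpoint`` holds
    # exactly the keys that ``resume_from_checkpoint`` asserts on, so a run can
    # be resumed from the saved file (paths and ``MyTrainer`` are hypothetical;
    # note that the optimizer state is not yet restored):
    #
    #     trainer.save_checkpoint("./checkpoints/run1_epoch10.pth.tar")
    #     new_trainer = MyTrainer(model, MyDataset, model_config, train_config)
    #     new_trainer.resume_from_checkpoint("./checkpoints/run1_epoch10.pth.tar")
    #     new_trainer.train()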

    def extra_repr_keys(self) -> List[str]:
        return [
            "train_config",
        ]