"""
Abstract base class for trainers,
intended to replace the standalone training functions in the training pipelines with classes.
"""
import logging
import os
import textwrap
from abc import ABC, abstractmethod
from collections import OrderedDict, deque
from copy import deepcopy
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from ..augmenters import AugmenterManager
from ..cfg import CFG, DEFAULTS
from ..models.loss import setup_criterion
from ..utils.misc import ReprMixin, dict_to_str, dicts_equal, get_date_str, get_kwargs
from ..utils.utils_nn import default_collate_fn
from .loggers import LoggerManager
__all__ = [
"BaseTrainer",
]
class BaseTrainer(ReprMixin, ABC):
"""Abstract base class for trainers.
A trainer is a class that contains the training pipeline,
and is responsible for training a model.
Parameters
----------
    model : torch.nn.Module
        The model to be trained.
    dataset_cls : torch.utils.data.Dataset
        The dataset class used for training.
        `dataset_cls` should be a subclass of :class:`~torch.utils.data.Dataset`,
        and be initialized via :code:`dataset_cls(config, training=True)`.
model_config : dict
The configuration of the model,
used to keep a record in the checkpoints.
train_config : dict
The configuration of the training,
including configurations for the data loader, for the optimization, etc.
Will also be recorded in the checkpoints.
        `train_config` should contain at least the following keys
        (a concrete sketch is given in the Examples section below):
- "monitor": str
- "loss": str
- "n_epochs": int
- "batch_size": int
- "learning_rate": float
- "lr_scheduler": str
- "lr_step_size": int, optional, depending on the scheduler
- "lr_gamma": float, optional, depending on the scheduler
- "max_lr": float, optional, depending on the scheduler
- "optimizer": str
- "decay": float, optional, depending on the optimizer
- "momentum": float, optional, depending on the optimizer
collate_fn : callable, optional
The collate function for the data loader,
defaults to :meth:`default_collate_fn`.
.. versionadded:: 0.0.23
device : torch.device, optional
The device to be used for training.
lazy : bool, default False
Whether to initialize the data loader lazily.
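
    Examples
    --------
    A minimal sketch of a concrete subclass and its usage.
    ``MyTrainer``, ``MyDataset``, ``MyModel`` and the config values below are
    hypothetical placeholders, not part of the library:

    .. code-block:: python

        class MyTrainer(BaseTrainer):
            @property
            def batch_dim(self) -> int:
                return 0

            @property
            def extra_required_train_config_fields(self) -> List[str]:
                return []

            def run_one_step(self, *data):
                signals, labels = data
                return self.model(signals), labels

            @torch.no_grad()
            def evaluate(self, data_loader):
                ...  # return a dict of metrics, e.g. {"accuracy": 0.9}

            def _setup_dataloaders(self, train_dataset=None, val_dataset=None):
                ...  # build self.train_loader and self.val_loader

        train_config = dict(
            classes=["N", "A"],
            monitor="accuracy",       # must be a key of the dict returned by `evaluate`
            loss="CrossEntropyLoss",  # must be a loss name accepted by `setup_criterion`
            n_epochs=10,
            batch_size=32,
            log_step=10,
            optimizer="adam",
            lr_scheduler="none",
            learning_rate=1e-3,
            checkpoints="./checkpoints",
        )
        trainer = MyTrainer(
            model=MyModel(),
            dataset_cls=MyDataset,
            model_config=model_config,
            train_config=train_config,
            device=torch.device("cpu"),
        )
        best_state_dict = trainer.train()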
"""
__name__ = "BaseTrainer"
__DEFATULT_CONFIGS__ = {
"debug": True,
"final_model_name": None,
"log_step": 10,
"flooding_level": 0,
"early_stopping": {},
}
__DEFATULT_CONFIGS__.update(deepcopy(DEFAULTS))
def __init__(
self,
model: nn.Module,
dataset_cls: Dataset,
model_config: dict,
train_config: dict,
        collate_fn: Optional[Callable] = None,
device: Optional[torch.device] = None,
lazy: bool = False,
) -> None:
self.model = model
if type(self.model).__name__ in [
"DataParallel",
]:
# TODO: further consider "DistributedDataParallel"
self._model = self.model.module
else:
self._model = self.model
self.dataset_cls = dataset_cls
self.model_config = CFG(deepcopy(model_config))
self._train_config = CFG(deepcopy(train_config))
self._train_config.checkpoints = Path(self._train_config.checkpoints)
self.device = device or next(self._model.parameters()).device
self.dtype = next(self._model.parameters()).dtype
self.model.to(self.device)
self.lazy = lazy
self.collate_fn = collate_fn or default_collate_fn
self.log_manager = None
self.augmenter_manager = None
self.train_loader = None
self.val_train_loader = None
self.val_loader = None
self._setup_from_config(self._train_config)
# monitor for training: challenge metric
self.best_state_dict = OrderedDict()
self.best_metric = -np.inf
self.best_eval_res = dict()
self.best_epoch = -1
self.pseudo_best_epoch = -1
self.saved_models = deque()
self.model.train()
self.global_step = 0
self.epoch = 0
self.epoch_loss = 0
def train(self) -> OrderedDict:
"""Train the model.
Returns
-------
best_state_dict : OrderedDict
The state dict of the best model.
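
        Examples
        --------
        A minimal usage sketch, assuming ``trainer`` is an instance of a
        concrete subclass of :class:`BaseTrainer` with its data loaders set up:

        .. code-block:: python

            best_state_dict = trainer.train()
            # load the best weights back into the (unwrapped) model
            trainer._model.load_state_dict(best_state_dict)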
"""
self._setup_optimizer()
self._setup_scheduler()
self._setup_criterion()
msg = textwrap.dedent(
f"""
Starting training:
------------------
Epochs: {self.n_epochs}
Batch size: {self.batch_size}
Learning rate: {self.lr}
Training size: {self.n_train}
Validation size: {self.n_val}
Device: {self.device.type}
Optimizer: {self.train_config.optimizer}
Dataset classes: {self.train_config.classes}
-----------------------------------------
"""
)
self.log_manager.log_message(msg)
start_epoch = self.epoch
for _ in range(start_epoch, self.n_epochs):
# train one epoch
self.model.train()
self.epoch_loss = 0
with tqdm(
total=self.n_train,
desc=f"Epoch {self.epoch}/{self.n_epochs}",
unit="signals",
dynamic_ncols=True,
mininterval=1.0,
) as pbar:
self.log_manager.epoch_start(self.epoch)
# train one epoch
self.train_one_epoch(pbar)
# evaluate on train set, if debug is True
if self.train_config.debug:
eval_train_res = self.evaluate(self.val_train_loader)
self.log_manager.log_metrics(
metrics=eval_train_res,
step=self.global_step,
epoch=self.epoch,
part="train",
)
# evaluate on val set
if self.val_loader is not None:
eval_res = self.evaluate(self.val_loader)
self.log_manager.log_metrics(
metrics=eval_res,
step=self.global_step,
epoch=self.epoch,
part="val",
)
elif self.val_train_loader is not None:
# if no separate val set, use the metrics on the train set
eval_res = eval_train_res
# update best model and best metric if monitor is set
if self.train_config.monitor is not None:
if eval_res[self.train_config.monitor] > self.best_metric:
self.best_metric = eval_res[self.train_config.monitor]
self.best_state_dict = self._model.state_dict()
self.best_eval_res = deepcopy(eval_res)
self.best_epoch = self.epoch
self.pseudo_best_epoch = self.epoch
elif self.train_config.early_stopping:
if eval_res[self.train_config.monitor] >= self.best_metric - self.train_config.early_stopping.min_delta:
self.pseudo_best_epoch = self.epoch
elif self.epoch - self.pseudo_best_epoch >= self.train_config.early_stopping.patience:
msg = f"early stopping is triggered at epoch {self.epoch}"
self.log_manager.log_message(msg)
break
msg = textwrap.dedent(
f"""
best metric = {self.best_metric},
obtained at epoch {self.best_epoch}
"""
)
self.log_manager.log_message(msg)
# save checkpoint
save_suffix = f"epochloss_{self.epoch_loss:.5f}_metric_{eval_res[self.train_config.monitor]:.2f}"
else:
save_suffix = f"epochloss_{self.epoch_loss:.5f}"
save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
save_path = self.train_config.checkpoints / save_filename
if self.train_config.keep_checkpoint_max != 0:
self.save_checkpoint(str(save_path))
self.saved_models.append(save_path)
# remove outdated models
if len(self.saved_models) > self.train_config.keep_checkpoint_max > 0:
model_to_remove = self.saved_models.popleft()
try:
os.remove(model_to_remove)
except Exception:
self.log_manager.log_message(f"failed to remove {str(model_to_remove)}")
# update learning rate using lr_scheduler
if self.train_config.lr_scheduler.lower() == "plateau":
self._update_lr(eval_res)
self.log_manager.epoch_end(self.epoch)
self.epoch += 1
# save the best model
if self.best_metric > -np.inf:
if self.train_config.final_model_name:
save_filename = self.train_config.final_model_name
else:
save_suffix = f"metric_{self.best_eval_res[self.train_config.monitor]:.2f}"
save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar"
save_path = self.train_config.model_dir / save_filename
self.save_checkpoint(path=str(save_path))
self.log_manager.log_message(f"best model is saved at {save_path}")
elif self.train_config.monitor is None:
self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model")
self.best_state_dict = self._model.state_dict()
save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
save_path = self.train_config.model_dir / save_filename
self.save_checkpoint(path=str(save_path))
else:
raise ValueError("No best model found!")
self.log_manager.close()
if not self.best_state_dict:
# in case no best model is found,
# e.g. monitor is not set, or keep_checkpoint_max is 0
self.best_state_dict = self._model.state_dict()
return self.best_state_dict
def train_one_epoch(self, pbar: tqdm) -> None:
"""Train one epoch, and update the progress bar
Parameters
----------
pbar : tqdm
The progress bar for training.
"""
for epoch_step, data in enumerate(self.train_loader):
self.global_step += 1
# data is assumed to be a tuple of tensors, of the following order:
# signals, labels, *extra_tensors
data = self.augmenter_manager(*data)
out_tensors = self.run_one_step(*data)
loss = self.criterion(*out_tensors).to(self.dtype)
if self.train_config.flooding_level > 0:
flood = (loss - self.train_config.flooding_level).abs() + self.train_config.flooding_level
self.epoch_loss += loss.item()
self.optimizer.zero_grad()
flood.backward()
else:
self.epoch_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self._update_lr()
if self.global_step % self.train_config.log_step == 0:
train_step_metrics = {"loss": loss.item()}
if self.scheduler:
train_step_metrics.update({"lr": self.scheduler.get_last_lr()[0]})
pbar.set_postfix(
**{
"loss (batch)": loss.item(),
"lr": self.scheduler.get_last_lr()[0],
}
)
else:
pbar.set_postfix(
**{
"loss (batch)": loss.item(),
}
)
if self.train_config.flooding_level > 0:
train_step_metrics.update({"flood": flood.item()})
self.log_manager.log_metrics(
metrics=train_step_metrics,
step=self.global_step,
epoch=self.epoch,
part="train",
)
pbar.update(data[0].shape[self.batch_dim])
@property
@abstractmethod
def batch_dim(self) -> int:
"""The batch dimension
Usually 0, but can be 1 for some models, e.g. :class:`~torch_ecg.models.RR_LSTM`.
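
        Examples
        --------
        A typical override for batch-first data, e.g. tensors of shape
        ``(batch_size, n_channels, seq_len)`` (a sketch, not library code):

        .. code-block:: python

            @property
            def batch_dim(self) -> int:
                return 0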
"""
raise NotImplementedError
@property
@abstractmethod
def extra_required_train_config_fields(self) -> List[str]:
"""Extra required fields in `train_config`."""
raise NotImplementedError
@property
def required_train_config_fields(self) -> List[str]:
"""Required fields in `train_config`."""
return [
"classes",
# "monitor", # can be None
"n_epochs",
"batch_size",
"log_step",
"optimizer",
"lr_scheduler",
"learning_rate",
] + self.extra_required_train_config_fields
def _validate_train_config(self) -> None:
"""Validate the `train_config`.
Check if all required fields are present.
"""
for field in self.required_train_config_fields:
if field not in self.train_config:
raise ValueError(f"{field} is missing in train_config!")
@property
def save_prefix(self) -> str:
"""The prefix of the saved model name."""
model_name = self._model.__name__ if hasattr(self._model, "__name__") else self._model.__class__.__name__
return f"{model_name}_epoch"
@property
    def train_config(self) -> CFG:
        """The training configuration."""
        return self._train_config
@abstractmethod
def run_one_step(self, *data: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
"""Run one step of training on one batch of data.
Parameters
----------
data : Tuple[torch.Tensor]
The data to be processed for training one step (batch),
should be of the following order:
``signals, labels, *extra_tensors``.
Returns
-------
Tuple[torch.Tensor]
The output of the model for one step (batch) data,
along with labels and extra tensors.
Should be of the following order:
``preds, labels, *extra_tensors``.
        `preds` are usually NOT probabilities,
        but the raw tensors before being fed into :func:`torch.sigmoid`
        or :func:`torch.softmax` (i.e. the logits).
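
        Examples
        --------
        A sketch of a concrete override for a plain classifier;
        the tensor names and device handling are illustrative assumptions:

        .. code-block:: python

            def run_one_step(self, *data):
                signals, labels = data
                signals = signals.to(device=self.device, dtype=self.dtype)
                labels = labels.to(device=self.device)
                preds = self.model(signals)  # raw outputs, before sigmoid/softmax
                return preds, labels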
"""
raise NotImplementedError
@torch.no_grad()
@abstractmethod
def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
"""Do evaluation on the given data loader.
Parameters
----------
data_loader : torch.utils.data.DataLoader
The data loader to evaluate on.
Returns
-------
dict
The evaluation results (metrics).
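
        Examples
        --------
        A sketch of a concrete override computing a single accuracy metric;
        the metric name and label format (class indices) are assumptions:

        .. code-block:: python

            @torch.no_grad()
            def evaluate(self, data_loader):
                self.model.eval()
                correct, total = 0, 0
                for signals, labels in data_loader:
                    signals = signals.to(device=self.device, dtype=self.dtype)
                    preds = self.model(signals).argmax(dim=-1).cpu()
                    correct += (preds == labels.cpu()).sum().item()
                    total += labels.shape[0]
                self.model.train()
                return {"accuracy": correct / max(total, 1)}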
"""
raise NotImplementedError
def _update_lr(self, eval_res: Optional[dict] = None) -> None:
"""Update learning rate using lr_scheduler,
        possibly based on `eval_res` (e.g. for the ``plateau`` scheduler).
Parameters
----------
eval_res : dict, optional
The evaluation results (metrics).
"""
if self.train_config.lr_scheduler.lower() == "none":
pass
elif self.train_config.lr_scheduler.lower() == "plateau":
if eval_res is None:
return
metrics = eval_res[self.train_config.monitor]
if isinstance(metrics, torch.Tensor):
metrics = metrics.item()
self.scheduler.step(metrics)
elif self.train_config.lr_scheduler.lower() == "step":
self.scheduler.step()
elif self.train_config.lr_scheduler.lower() in [
"one_cycle",
"onecycle",
]:
self.scheduler.step()
def _setup_from_config(self, train_config: dict) -> None:
"""Setup the trainer from the training configuration.
Parameters
----------
train_config : dict
The training configuration.
"""
_default_config = CFG(deepcopy(self.__DEFATULT_CONFIGS__))
_default_config.update(train_config)
self._train_config = CFG(deepcopy(_default_config))
# check validity of the config
self._validate_train_config()
# set aliases
self.n_epochs = self.train_config.n_epochs
self.batch_size = self.train_config.batch_size
self.lr = self.train_config.learning_rate
# setup log manager first
self._setup_log_manager()
msg = f"training configurations are as follows:\n{dict_to_str(self.train_config)}"
self.log_manager.log_message(msg)
# setup directories
self._setup_directories()
# setup callbacks
self._setup_callbacks()
# setup data loaders
if not self.lazy:
self._setup_dataloaders()
# setup augmenters manager
self._setup_augmenter_manager()
def _setup_log_manager(self) -> None:
"""Setup the log manager."""
config = {"log_suffix": self.extra_log_suffix()}
config.update(self.train_config)
self.log_manager = LoggerManager.from_config(config=config)
def _setup_directories(self) -> None:
"""Setup the directories for saving checkpoints and logs."""
if not self.train_config.get("model_dir", None):
self._train_config.model_dir = self.train_config.checkpoints
self._train_config.model_dir = Path(self._train_config.model_dir)
self.train_config.checkpoints.mkdir(parents=True, exist_ok=True)
self.train_config.model_dir.mkdir(parents=True, exist_ok=True)
def _setup_callbacks(self) -> None:
"""Setup the callbacks."""
self._train_config.monitor = self.train_config.get("monitor", None)
if self.train_config.monitor is None:
assert (
self.train_config.lr_scheduler.lower() != "plateau"
), "monitor is not specified, lr_scheduler should not be ReduceLROnPlateau"
self._train_config.keep_checkpoint_max = self.train_config.get("keep_checkpoint_max", 1)
if self._train_config.keep_checkpoint_max < 0:
self._train_config.keep_checkpoint_max = -1
self.log_manager.log_message(
msg="keep_checkpoint_max is set to -1, all checkpoints will be kept",
level=logging.WARNING,
)
elif self._train_config.keep_checkpoint_max == 0:
self.log_manager.log_message(
msg="keep_checkpoint_max is set to 0, no checkpoint will be kept",
level=logging.WARNING,
)
def _setup_augmenter_manager(self) -> None:
"""Setup the augmenter manager."""
self.augmenter_manager = AugmenterManager.from_config(config=self.train_config)
@abstractmethod
def _setup_dataloaders(
self,
train_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
) -> None:
"""Setup the dataloaders for training and validation.
Parameters
----------
train_dataset : torch.utils.data.Dataset, optional
The training dataset.
val_dataset : torch.utils.data.Dataset, optional
The validation dataset
Examples
--------
.. code-block:: python
if train_dataset is None:
train_dataset = self.dataset_cls(config=self.train_config, training=True, lazy=False)
if val_dataset is None:
val_dataset = self.dataset_cls(config=self.train_config, training=False, lazy=False)
num_workers = 4
self.train_loader = DataLoader(
dataset=train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
drop_last=False,
collate_fn=self.collate_fn,
)
self.val_loader = DataLoader(
dataset=val_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
drop_last=False,
collate_fn=self.collate_fn,
)
"""
raise NotImplementedError
@property
    def n_train(self) -> int:
        """The number of training samples."""
if self.train_loader is not None:
return len(self.train_loader.dataset)
return 0
@property
    def n_val(self) -> int:
        """The number of validation samples."""
if self.val_loader is not None:
return len(self.val_loader.dataset)
return 0
def _setup_optimizer(self) -> None:
"""Setup the optimizer."""
if self.train_config.optimizer.lower() == "adam":
optimizer_kwargs = get_kwargs(optim.Adam)
optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
optimizer_kwargs.update(dict(lr=self.lr))
self.optimizer = optim.Adam(
params=self.model.parameters(),
**optimizer_kwargs,
)
elif self.train_config.optimizer.lower() in ["adamw", "adamw_amsgrad"]:
optimizer_kwargs = get_kwargs(optim.AdamW)
optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
optimizer_kwargs.update(
dict(
lr=self.lr,
amsgrad=self.train_config.optimizer.lower().endswith("amsgrad"),
)
)
self.optimizer = optim.AdamW(
params=self.model.parameters(),
**optimizer_kwargs,
)
elif self.train_config.optimizer.lower() == "sgd":
optimizer_kwargs = get_kwargs(optim.SGD)
optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
optimizer_kwargs.update(dict(lr=self.lr))
self.optimizer = optim.SGD(
params=self.model.parameters(),
**optimizer_kwargs,
)
else:
raise NotImplementedError(
f"optimizer `{self.train_config.optimizer}` not implemented! "
"Please use one of the following: `adam`, `adamw`, `adamw_amsgrad`, `sgd`, "
"or override this method to setup your own optimizer."
)
def _setup_scheduler(self) -> None:
"""Setup the learning rate scheduler."""
if self.train_config.lr_scheduler is None or self.train_config.lr_scheduler.lower() == "none":
self.train_config.lr_scheduler = "none"
self.scheduler = None
elif self.train_config.lr_scheduler.lower() == "plateau":
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
"max",
patience=2,
verbose=False,
)
elif self.train_config.lr_scheduler.lower() == "step":
self.scheduler = optim.lr_scheduler.StepLR(
self.optimizer,
self.train_config.lr_step_size,
self.train_config.lr_gamma,
# verbose=False,
)
elif self.train_config.lr_scheduler.lower() in [
"one_cycle",
"onecycle",
]:
self.scheduler = optim.lr_scheduler.OneCycleLR(
optimizer=self.optimizer,
max_lr=self.train_config.max_lr,
epochs=self.n_epochs,
steps_per_epoch=len(self.train_loader),
# verbose=False,
)
else: # TODO: add linear and linear with warmup schedulers
raise NotImplementedError(
f"lr scheduler `{self.train_config.lr_scheduler.lower()}` not implemented for training! "
"Please use one of the following: `none`, `plateau`, `step`, `one_cycle`, "
"or override this method to setup your own lr scheduler."
)
def _setup_criterion(self) -> None:
"""Setup the loss function."""
loss_kw = self.train_config.get("loss_kw", {})
for k, v in loss_kw.items():
if isinstance(v, torch.Tensor):
loss_kw[k] = v.to(device=self.device, dtype=self.dtype)
self.criterion = setup_criterion(self.train_config.loss, **loss_kw)
self.criterion.to(self.device)
def _check_model_config_compatability(self, model_config: dict) -> bool:
"""Check if `model_config` is compatible with the current model configuration.
Parameters
----------
model_config : dict
Model configuration from elsewhere (e.g. from a checkpoint),
which should be compatible with the current model configuration.
Returns
-------
bool
True if compatible, False otherwise
"""
return dicts_equal(self.model_config, model_config)
    def resume_from_checkpoint(self, checkpoint: Union[str, dict]) -> None:
        """Resume a training process from a checkpoint.
        NOTE: this method is not finished and not fully checked.
Parameters
----------
checkpoint : str or dict
If it is str, then it is the path of the checkpoint,
which is a ``.pth.tar`` file containing a dict.
`checkpoint` should contain at least
"model_state_dict", "optimizer_state_dict",
"model_config", "train_config", "epoch"
to resume a training process.
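
        Examples
        --------
        A usage sketch (keeping in mind that this method is marked as unfinished);
        the checkpoint path is an arbitrary placeholder:

        .. code-block:: python

            trainer.resume_from_checkpoint("./checkpoints/some_checkpoint.pth.tar")
            best_state_dict = trainer.train()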
"""
if isinstance(checkpoint, str):
ckpt = torch.load(checkpoint, map_location=self.device)
else:
ckpt = checkpoint
        insufficient_msg = "this checkpoint does not contain sufficient data to resume training"
assert isinstance(ckpt, dict), insufficient_msg
assert set(
[
"model_state_dict",
"optimizer_state_dict",
"model_config",
"train_config",
"epoch",
]
).issubset(ckpt.keys()), insufficient_msg
if not self._check_model_config_compatability(ckpt["model_config"]):
raise ValueError("model config of the checkpoint is not compatible with the config of the current model")
self._model.load_state_dict(ckpt["model_state_dict"])
self.epoch = ckpt["epoch"]
self._setup_from_config(ckpt["train_config"])
# TODO: resume optimizer, etc.
def save_checkpoint(self, path: str) -> None:
"""Save the current state of the trainer to a checkpoint.
Parameters
----------
path : str
Path to save the checkpoint
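
        Examples
        --------
        A usage sketch, assuming ``trainer`` is a concrete subclass instance:

        .. code-block:: python

            trainer.save_checkpoint(str(trainer.train_config.checkpoints / "manual_checkpoint.pth.tar"))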
"""
torch.save(
{
"model_state_dict": self._model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"model_config": self.model_config,
"train_config": self.train_config,
"epoch": self.epoch,
},
path,
)