BaseTrainer
class torch_ecg.components.BaseTrainer(model: Module, dataset_cls: Dataset, model_config: dict, train_config: dict, collate_fn: callable | None = None, device: device | None = None, lazy: bool = False)
Abstract base class for trainers.
A trainer encapsulates the training pipeline and is responsible for training a model.
- Parameters:
model (torch.nn.Module) – The model to be trained.
dataset_cls (torch.utils.data.Dataset) – The dataset class used for training. dataset_cls should inherit from torch.utils.data.Dataset and be initializable via dataset_cls(config, training=True).
model_config (dict) – The configuration of the model, kept as a record in the checkpoints.
train_config (dict) – The configuration of the training process, including configurations for the data loader, the optimization, etc. Also recorded in the checkpoints. train_config should at least contain the following keys (a minimal example is sketched after this parameter list):
    "monitor": str
    "loss": str
    "n_epochs": int
    "batch_size": int
    "learning_rate": float
    "lr_scheduler": str
        "lr_step_size": int, optional, depending on the scheduler
        "lr_gamma": float, optional, depending on the scheduler
        "max_lr": float, optional, depending on the scheduler
    "optimizer": str
        "decay": float, optional, depending on the optimizer
        "momentum": float, optional, depending on the optimizer
collate_fn (callable, optional) – The collate function for the data loader; defaults to default_collate_fn().
device (torch.device, optional) – The device to be used for training.
lazy (bool, default False) – Whether to initialize the data loader lazily.
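For orientation, a minimal train_config might look like the sketch below. All values, the metric name under "monitor", and the exact strings accepted for "loss", "lr_scheduler", and "optimizer" are illustrative assumptions, not documented defaults.

    # A minimal ``train_config`` sketch covering the required keys above.
    # All values are illustrative assumptions; the strings accepted for
    # "loss", "lr_scheduler" and "optimizer" depend on the concrete trainer.
    train_config = {
        "monitor": "f1",            # hypothetical metric used to pick the best model
        "loss": "BCEWithLogitsLoss",
        "n_epochs": 20,
        "batch_size": 32,
        "learning_rate": 1e-3,
        "lr_scheduler": "step",
        "lr_step_size": 5,          # consumed by the "step" scheduler
        "lr_gamma": 0.1,            # consumed by the "step" scheduler
        "optimizer": "adam",
    }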
- abstract property batch_dim: int
The batch dimension. Usually 0, but can be 1 for some models, e.g. RR_LSTM.
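As a minimal sketch (the subclass name MyTrainer is hypothetical, and the remaining abstract members are omitted), a trainer for batch-first models would implement this property as follows:

    from torch_ecg.components import BaseTrainer

    class MyTrainer(BaseTrainer):
        # evaluate, run_one_step and extra_required_train_config_fields
        # are omitted here for brevity.

        @property
        def batch_dim(self) -> int:
            # Batch-first inputs of shape (batch_size, ...); a model fed
            # (seq_len, batch_size, ...) tensors, such as RR_LSTM, would
            # return 1 instead.
            return 0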
- abstract evaluate(data_loader: DataLoader) → Dict[str, float]
Do evaluation on the given data loader.
- Parameters:
data_loader (torch.utils.data.DataLoader) – The data loader to evaluate on.
- Returns:
The evaluation results (metrics).
- Return type:
Dict[str, float]
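A sketch of a concrete evaluate(), assuming the trainer exposes the wrapped model as self.model and the target device as self.device (as the constructor arguments suggest), a data loader yielding (signals, labels) batches, and a simple accuracy metric; the metrics actually computed are up to the concrete trainer:

    from typing import Dict

    import torch
    from torch.utils.data import DataLoader

    from torch_ecg.components import BaseTrainer

    class MyTrainer(BaseTrainer):
        # Other abstract members are omitted here for brevity.

        @torch.no_grad()
        def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
            self.model.eval()
            n_correct, n_total = 0, 0
            for signals, labels in data_loader:
                signals = signals.to(self.device)
                labels = labels.to(self.device)
                preds = self.model(signals)            # raw model outputs
                hard = (preds.sigmoid() > 0.5).long()  # binarize predictions
                n_correct += (hard == labels).sum().item()
                n_total += labels.numel()
            self.model.train()  # restore training mode
            return {"acc": n_correct / n_total}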
- abstract property extra_required_train_config_fields: List[str]
Extra required fields in train_config.
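For example, a concrete trainer could declare the extra keys it requires in train_config (the key names below are purely hypothetical):

    from typing import List

    from torch_ecg.components import BaseTrainer

    class MyTrainer(BaseTrainer):
        # Other abstract members are omitted here for brevity.

        @property
        def extra_required_train_config_fields(self) -> List[str]:
            # Hypothetical keys this trainer requires beyond the common
            # ones listed in the class docstring.
            return ["classes", "checkpoints"]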
- resume_from_checkpoint(checkpoint: str | dict) → None
Resume a training process from a checkpoint.
Note: NOT finished, NOT checked.
- abstract run_one_step(*data: Tuple[Tensor, ...]) → Tuple[Tensor, ...]
Run one step of training on one batch of data.
- Parameters:
data (Tuple[torch.Tensor]) – The data to be processed for one training step (batch), in the following order: signals, labels, *extra_tensors.
- Returns:
The output of the model for one step (batch) of data, along with the labels and extra tensors, in the following order: preds, labels, *extra_tensors. preds are usually NOT probabilities, but the raw model outputs (logits) before sigmoid() or softmax() is applied.
- Return type:
Tuple[torch.Tensor]
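A sketch under the same assumptions as above (self.model, self.device, and a (signals, labels) batch layout); computing the loss from preds is left to the training loop:

    from typing import Tuple

    import torch

    from torch_ecg.components import BaseTrainer

    class MyTrainer(BaseTrainer):
        # Other abstract members are omitted here for brevity.

        def run_one_step(self, *data: torch.Tensor) -> Tuple[torch.Tensor, ...]:
            signals, labels = data
            signals = signals.to(self.device)
            labels = labels.to(self.device)
            # preds are the raw outputs (pre-sigmoid/softmax), from which
            # the training loop computes the loss.
            preds = self.model(signals)
            return preds, labels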
- save_checkpoint(path: str) → None
Save the current state of the trainer to a checkpoint.
- Parameters:
path (str) – Path to save the checkpoint.
- train() → OrderedDict
Train the model.
- Returns:
best_state_dict – The state dict of the best model.
- Return type:
OrderedDict
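Putting the pieces together, a hedged end-to-end sketch: MyTrainer and MyDataset stand in for a concrete trainer subclass (as sketched above) and a Dataset subclass initializable as MyDataset(config, training=True); my_model, model_config, and train_config are assumed to be defined beforehand.

    import torch

    trainer = MyTrainer(                  # concrete subclass, assumed defined
        model=my_model,                   # a torch.nn.Module, assumed defined
        dataset_cls=MyDataset,            # Dataset subclass, assumed defined
        model_config=model_config,
        train_config=train_config,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    best_state_dict = trainer.train()                       # best model's state dict
    trainer.save_checkpoint("./final_checkpoint.pth.tar")   # path is illustrative
    torch.save(best_state_dict, "./best_model.pth")         # path is illustrative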