BaseTrainer

class torch_ecg.components.BaseTrainer(model: Module, dataset_cls: Dataset, model_config: dict, train_config: dict, collate_fn: callable | None = None, device: device | None = None, lazy: bool = False)[source]

Bases: ReprMixin, ABC

Abstract base class for trainers.

A trainer is a class that encapsulates the training pipeline and is responsible for training a model.

Parameters:
  • model (torch.nn.Module) – The model to be trained.

  • dataset_cls (torch.utils.data.Dataset) – The class of the dataset to be used for training. dataset_cls should inherit from Dataset and be instantiated as dataset_cls(config, training=True).

  • model_config (dict) – The configuration of the model, used to keep a record in the checkpoints.

  • train_config (dict) –

    The training configuration, including configurations for the data loader, the optimization, etc. It will also be recorded in the checkpoints. train_config should contain at least the following keys (an illustrative example follows this parameter list):

    • "monitor": str

    • "loss": str

    • "n_epochs": int

    • "batch_size": int

    • "learning_rate": float

    • "lr_scheduler": str
      • "lr_step_size": int, optional, depending on the scheduler

      • "lr_gamma": float, optional, depending on the scheduler

      • "max_lr": float, optional, depending on the scheduler

    • "optimizer": str
      • "decay": float, optional, depending on the optimizer

      • "momentum": float, optional, depending on the optimizer

  • collate_fn (callable, optional) – The collate function for the data loader; defaults to default_collate_fn().

  • device (torch.device, optional) – The device to be used for training.

  • lazy (bool, default False) – Whether to initialize the data loader lazily.
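
For orientation, here is an illustrative train_config covering the required keys listed above; the key names come from the list, but the concrete values and the scheduler/optimizer identifiers are assumptions, not recommended settings:

    # illustrative only; values and identifiers are assumptions
    train_config = {
        "monitor": "accuracy",      # metric used to select the best model
        "loss": "BCEWithLogitsLoss",
        "n_epochs": 50,
        "batch_size": 32,
        "learning_rate": 1e-3,
        "lr_scheduler": "step",
        "lr_step_size": 10,         # used by the step scheduler
        "lr_gamma": 0.1,            # used by the step scheduler
        "optimizer": "adam",
        "decay": 1e-2,              # weight decay, used by the optimizer
    }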

abstract property batch_dim: int

The batch dimension.

Usually 0, but can be 1 for some models, e.g. RR_LSTM.

abstract evaluate(data_loader: DataLoader) → Dict[str, float][source]

Do evaluation on the given data loader.

Parameters:

data_loader (torch.utils.data.DataLoader) – The data loader to evaluate on.

Returns:

The evaluation results (metrics).

Return type:

dict
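
A minimal sketch of what an evaluate() override might look like, assuming single-label classification with integer labels; the attribute names self.model and self.device, the batch layout, and the metric are illustrative assumptions:

    import torch
    from torch.utils.data import DataLoader

    def evaluate(self, data_loader: DataLoader) -> dict:
        # switch to inference mode and collect predictions and labels
        self.model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for signals, labels in data_loader:  # assumed batch layout
                all_preds.append(self.model(signals.to(self.device)).cpu())
                all_labels.append(labels.cpu())
        preds, labels = torch.cat(all_preds), torch.cat(all_labels)
        # a single illustrative metric; real subclasses compute the metrics
        # referenced by train_config["monitor"]
        acc = (preds.argmax(dim=-1) == labels).float().mean().item()
        return {"accuracy": acc}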

extra_log_suffix() → str[source]

Extra suffix for the log file name.

extra_repr_keys() → List[str][source]

Extra keys for __repr__() and __str__().

abstract property extra_required_train_config_fields: List[str]

Extra required fields in train_config.

property required_train_config_fields: List[str]

Required fields in train_config.

resume_from_checkpoint(checkpoint: str | dict) → None[source]

NOT finished, NOT checked.

Resume a training process from a checkpoint.

Parameters:

checkpoint (str or dict) – If str, it is the path of the checkpoint, a .pth.tar file containing a dict. checkpoint should contain at least "model_state_dict", "optimizer_state_dict", "model_config", "train_config", and "epoch" to resume a training process.
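
A hedged sketch of resuming a training process; the checkpoint path is illustrative, and trainer is an assumed, already-constructed subclass instance:

    import torch

    # load the checkpoint dict and verify the keys required above
    checkpoint = torch.load("./checkpoints/model.pth.tar", map_location="cpu")
    required = ("model_state_dict", "optimizer_state_dict",
                "model_config", "train_config", "epoch")
    assert all(key in checkpoint for key in required)
    trainer.resume_from_checkpoint(checkpoint)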

abstract run_one_step(*data: Tuple[Tensor, ...]) → Tuple[Tensor, ...][source]

Run one step of training on one batch of data.

Parameters:

data (Tuple[torch.Tensor]) – The data to be processed for training one step (one batch), in the following order: signals, labels, *extra_tensors.

Returns:

The output of the model for one step (one batch) of data, along with the labels and extra tensors, in the following order: preds, labels, *extra_tensors. preds are usually NOT probabilities, but the raw model outputs (logits) before sigmoid() or softmax() is applied.

Return type:

Tuple[torch.Tensor]
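
A hedged sketch of a possible run_one_step() implementation; the attribute names self.model and self.device follow common trainer layouts and are assumptions, not guaranteed internals of the library:

    import torch

    def run_one_step(self, *data: torch.Tensor):
        signals, labels, *extra_tensors = data
        signals = signals.to(self.device)
        labels = labels.to(self.device)
        # forward pass; preds are raw model outputs (logits),
        # before sigmoid()/softmax() is applied
        preds = self.model(signals)
        return (preds, labels, *extra_tensors)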

save_checkpoint(path: str) → None[source]

Save the current state of the trainer to a checkpoint.

Parameters:

path (str) – Path to save the checkpoint.

property save_prefix: str

The prefix of the saved model name.

train() → OrderedDict[source]

Train the model.

Returns:

best_state_dict – The state dict of the best model.

Return type:

OrderedDict
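
Typical usage, as a sketch (trainer is an assumed, already-constructed subclass instance):

    # run the full training loop, then load the best weights back into the model
    best_state_dict = trainer.train()
    trainer.model.load_state_dict(best_state_dict)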

train_one_epoch(pbar: tqdm_asyncio) → None[source]

Train one epoch and update the progress bar.

Parameters:

pbar (tqdm) – The progress bar for training.