Augmenter
- class torch_ecg.augmenters.Augmenter
Base class for augmenters.
An augmenter performs data augmentation on the input ECGs, their labels, and optional extra tensors.
- abstract forward(sig: Tensor, label: Tensor | None = None, *extra_tensors: Sequence[Tensor], **kwargs: Any) → Tuple[Tensor, ...]
Forward method of the augmenter.
- Parameters:
sig (torch.Tensor) – Batched ECGs to be augmented, of shape (batch, lead, siglen).
label (torch.Tensor, optional) – Batched labels of the ECGs.
*extra_tensors (Sequence[torch.Tensor], optional) – Extra tensors to be augmented, e.g., masks used by custom loss functions.
**kwargs (dict, optional) – Additional keyword arguments to be passed to the augmenter.
- Returns:
The augmented ECGs, labels, and optional extra tensors.
- Return type:
Sequence[torch.Tensor]
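As a minimal sketch of how a concrete augmenter might fulfill the forward contract above, the example below defines a hypothetical stand-in base class AugmenterBase, which mirrors only the documented signature and is not the torch_ecg class (making it a torch.nn.Module is an assumption), plus a hypothetical subclass RandomScale that rescales each ECG in the batch by a random factor. Because amplitude scaling does not change what the ECG shows, the labels and any extra tensors are passed back unchanged.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence, Tuple

import torch
from torch import Tensor


class AugmenterBase(torch.nn.Module, ABC):
    """Hypothetical stand-in: mirrors the documented forward() signature only."""

    @abstractmethod
    def forward(
        self,
        sig: Tensor,
        label: Optional[Tensor] = None,
        *extra_tensors: Sequence[Tensor],
        **kwargs: Any,
    ) -> Tuple[Tensor, ...]:
        ...


class RandomScale(AugmenterBase):
    """Hypothetical augmenter: multiply each ECG by a random factor in [low, high)."""

    def __init__(self, low: float = 0.8, high: float = 1.2) -> None:
        super().__init__()
        self.low, self.high = low, high

    def forward(
        self,
        sig: Tensor,
        label: Optional[Tensor] = None,
        *extra_tensors: Sequence[Tensor],
        **kwargs: Any,
    ) -> Tuple[Tensor, ...]:
        batch = sig.shape[0]
        # One factor per sample, shaped (batch, 1, 1) to broadcast over (lead, siglen).
        factors = torch.empty(batch, 1, 1, dtype=sig.dtype, device=sig.device)
        factors.uniform_(self.low, self.high)
        # Scaling does not alter labels or masks, so they are returned as-is.
        return (sig * factors, label, *extra_tensors)


sig = torch.randn(4, 12, 5000)  # (batch, lead, siglen)
label = torch.zeros(4, 9)
mask = torch.ones(4, 5000)
out_sig, out_label, out_mask = RandomScale()(sig, label, mask)
print(out_sig.shape)  # torch.Size([4, 12, 5000])
```

Returning a flat tuple of (ECGs, labels, extra tensors) in input order lets several such augmenters be chained by unpacking one's output into the next.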
- get_indices(prob: float, pop_size: int, scale_ratio: float = 0.1) → List[int]
Get a list of indices to be selected.
A random list of indices in the range [0, pop_size-1] is generated, with prob giving the probability of each index being selected.
- Parameters:
- Returns:
indices – A list of indices.
- Return type:
List[int]
TODO
Add a parameter min_dist so that any two selected indices are at least min_dist apart.
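This page does not spell out how prob and scale_ratio determine how many indices are returned. The sketch below assumes one plausible reading: the count k is drawn from a normal distribution with mean prob * pop_size and standard deviation scale_ratio * pop_size, clipped to [0, pop_size], and then k distinct indices are sampled without replacement. The function get_indices_sketch and its rng argument are hypothetical; treat this as an illustration of that reading, not as the library's implementation.

```python
import random
from typing import List, Optional


def get_indices_sketch(
    prob: float,
    pop_size: int,
    scale_ratio: float = 0.1,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical illustration of get_indices (assumed sampling scheme)."""
    if rng is None:
        rng = random.Random()
    # Assumption: the number of selected indices is normally distributed
    # around prob * pop_size, with std proportional to pop_size via scale_ratio.
    k = rng.gauss(prob * pop_size, scale_ratio * pop_size)
    k = int(round(min(max(k, 0.0), pop_size)))  # clip to a valid count
    # Draw k distinct indices from [0, pop_size - 1]; sorted for readability.
    return sorted(rng.sample(range(pop_size), k))


indices = get_indices_sketch(prob=0.5, pop_size=32, rng=random.Random(0))
print(len(indices), indices)
```

Under this reading, scale_ratio=0.0 makes the count deterministic (round(prob * pop_size) indices), while larger values let it vary from call to call. Sorting the result is a choice made here, not something the page specifies.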