AugmenterManager
- class torch_ecg.augmenters.AugmenterManager(*augs: Tuple[Augmenter, ...] | None, random: bool = False)
Bases: Module
The Module to manage the augmenters.
- Parameters:
Examples
    import torch
    from torch_ecg.cfg import CFG
    from torch_ecg.augmenters import AugmenterManager

    # Build the manager from a configuration
    config = CFG(
        random=False,
        fs=500,
        baseline_wander={},
        label_smooth={},
        mixup={},
        random_flip={},
        random_masking={},
        random_renormalize={},
        stretch_compress={},
    )
    am = AugmenterManager.from_config(config)

    # Dummy batch of shape (batch, lead, siglen)
    sig = torch.randn(32, 12, 5000)
    label = torch.randint(0, 2, (32, 26), dtype=torch.float32)
    # Extra tensors (here two masks) are passed alongside the signals and labels
    mask1 = torch.randint(0, 2, (32, 5000, 3), dtype=torch.float32)
    mask2 = torch.randint(0, 3, (32, 5000), dtype=torch.long)

    sig, label, mask1, mask2 = am(sig, label, mask1, mask2)
- forward(sig: Tensor, label: Tensor | None, *extra_tensors: Sequence[Tensor], **kwargs: Any) → Tensor | Tuple[Tensor]
Forward the input ECGs through the augmenters.
- Parameters:
sig (torch.Tensor) – Batched ECGs to be augmented, of shape (batch, lead, siglen).
label (torch.Tensor, optional) – Batched labels of the ECGs.
*extra_tensors (Sequence[torch.Tensor], optional) – Extra tensors to be augmented, e.g. masks for custom loss functions, etc.
**kwargs (dict, optional) – Additional keyword arguments to be passed to the augmenters.
- Returns:
The augmented ECGs, labels, and optional extra tensors.
- Return type:
Sequence[torch.Tensor]
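The calling convention described above can be illustrated without the library. The following is a minimal sketch, not torch_ecg's implementation: `ToyManager`, `negate`, and `smooth_label` are hypothetical names, plain lists stand in for tensors, and reading `random=True` as "apply the augmenters in shuffled order" is an assumption that the text above does not spell out.

```python
import random as _random


def negate(sig, label, *extras):
    """Toy augmenter: flips the sign of every sample, leaves the rest untouched."""
    return ([-x for x in sig], label, *extras)


def smooth_label(sig, label, *extras, eps=0.1):
    """Toy augmenter: binary label smoothing; does nothing when label is None."""
    if label is not None:
        label = [(1 - eps) * y + eps / 2 for y in label]
    return (sig, label, *extras)


class ToyManager:
    """Hypothetical stand-in for a manager that chains augmenters."""

    def __init__(self, *augs, random=False, seed=None):
        self.augs = list(augs)
        self.random = random  # assumption: shuffle the application order
        self._rng = _random.Random(seed)

    def __call__(self, sig, label=None, *extras):
        order = list(self.augs)
        if self.random:
            self._rng.shuffle(order)
        # Every augmenter receives and returns (sig, label, *extras),
        # so the output of one can be fed directly into the next.
        for aug in order:
            sig, label, *extras = aug(sig, label, *extras)
        return (sig, label, *extras)


manager = ToyManager(negate, smooth_label)
sig, label, mask = manager([1.0, -2.0], [0.0, 1.0], [1, 1])
print(sig, label, mask)
```

The point of the sketch is the shared tuple layout: because each step returns the same positional structure it received, extra tensors such as masks travel through the whole chain and come back in the order they were passed in.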
- classmethod from_config(config: dict) → AugmenterManager
Create an AugmenterManager from a configuration.
- Parameters:
config (dict) – The configuration of the augmenters, preferably an OrderedDict.
- Returns:
am – A new instance of AugmenterManager.
- Return type:
AugmenterManager
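One plausible reason an ordered mapping is preferred is that the key order decides the order in which the augmenters are applied. The sketch below illustrates that idea only; `REGISTRY`, `GLOBAL_KEYS`, `build_pipeline`, `apply`, and the two toy augmenters are hypothetical names, not part of torch_ecg. Treating `random` and `fs` as manager-level settings rather than augmenter names is an assumption based on the example configuration.

```python
from collections import OrderedDict

# Hypothetical registry mapping config keys to augmenter factories.
REGISTRY = {
    "scale": lambda factor=2.0: (lambda sig: [x * factor for x in sig]),
    "shift": lambda offset=1.0: (lambda sig: [x + offset for x in sig]),
}

# Keys assumed to configure the manager itself rather than name an augmenter.
GLOBAL_KEYS = {"random", "fs"}


def build_pipeline(config):
    """Create augmenters in the iteration order of ``config``."""
    steps = []
    for name, kwargs in config.items():
        if name in GLOBAL_KEYS:
            continue
        steps.append(REGISTRY[name](**kwargs))
    return steps


def apply(steps, sig):
    """Run the signal through every step in sequence."""
    for step in steps:
        sig = step(sig)
    return sig


cfg_a = OrderedDict([("fs", 500), ("scale", {}), ("shift", {})])
cfg_b = OrderedDict([("fs", 500), ("shift", {}), ("scale", {})])

print(apply(build_pipeline(cfg_a), [1.0]))  # scale, then shift: [3.0]
print(apply(build_pipeline(cfg_b), [1.0]))  # shift, then scale: [4.0]
```

The two configurations contain the same entries but give different results, which is why key order matters under this assumption. Since Python 3.7 a plain `dict` also preserves insertion order, so an `OrderedDict` mainly makes the intent explicit.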