AugmenterManager

class torch_ecg.augmenters.AugmenterManager(*augs: Tuple[Augmenter, ...] | None, random: bool = False)

Bases: Module

A module that manages a collection of augmenters and applies them to batched ECGs.

Parameters:
  • *augs (Tuple[Augmenter], optional) – The augmenters to be added to the manager.

  • random (bool, default False) – Whether to apply the augmenters in random order.
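As a rough illustration of what the random flag changes, the sketch below models a manager as an ordered list of callables applied one after another, optionally in a freshly shuffled order on every call. ToyAugmenter, ToyManager, the seed argument and the last_order attribute are all hypothetical; plain Python lists stand in for tensors, and this is not torch_ecg's implementation.

```python
import random as _random
from typing import Callable, List, Optional


class ToyAugmenter:
    """Hypothetical stand-in for an augmenter: a named transform on a list of floats."""

    def __init__(self, name: str, fn: Callable[[List[float]], List[float]]) -> None:
        self.name = name
        self.fn = fn

    def __call__(self, sig: List[float]) -> List[float]:
        return self.fn(sig)


class ToyManager:
    """Toy model of how a ``random`` flag can control the order of augmenters."""

    def __init__(self, *augs: ToyAugmenter, random: bool = False, seed: Optional[int] = None) -> None:
        self.augs = list(augs)
        self.random = random
        self._rng = _random.Random(seed)
        self.last_order: List[str] = []  # names in the order used by the latest call

    def __call__(self, sig: List[float]) -> List[float]:
        augs = self.augs
        if self.random:
            # Draw a new permutation of the augmenters for every call.
            augs = self._rng.sample(augs, k=len(augs))
        self.last_order = [a.name for a in augs]
        for aug in augs:
            sig = aug(sig)  # output of one augmenter feeds the next
        return sig


scale = ToyAugmenter("scale", lambda s: [2.0 * x for x in s])
shift = ToyAugmenter("shift", lambda s: [x + 1.0 for x in s])

fixed = ToyManager(scale, shift)  # random=False: insertion order
print(fixed([1.0, 2.0]))  # scale then shift -> [3.0, 5.0]
```

Because the two toy transforms do not commute (shift then scale would give [4.0, 6.0]), the order chosen by the manager visibly changes the result, which is the reason an ordering option exists at all.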

Examples

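import torch
from torch_ecg.cfg import CFG
from torch_ecg.augmenters import AugmenterManager

# Besides `random` and `fs`, each key names an augmenter; an empty dict keeps its default parameters.
config = CFG(
    random=False,  # apply the augmenters in the order listed
    fs=500,  # sampling frequency of the ECGs, in Hz
    baseline_wander={},
    label_smooth={},
    mixup={},
    random_flip={},
    random_masking={},
    random_renormalize={},
    stretch_compress={},
)
am = AugmenterManager.from_config(config)
sig = torch.randn(32, 12, 5000)  # (batch, lead, siglen): 10 s of 12-lead ECG at 500 Hz
label = torch.randint(0, 2, (32, 26), dtype=torch.float32)  # multi-label targets, 26 classes
mask1 = torch.randint(0, 2, (32, 5000, 3), dtype=torch.float32)  # extra tensor: 3-channel per-sample mask
mask2 = torch.randint(0, 3, (32, 5000), dtype=torch.long)  # extra tensor: per-sample class indices
sig, label, mask1, mask2 = am(sig, label, mask1, mask2)

property augmenters: List[Augmenter]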

The list of augmenters in the manager.

extra_repr() -> str

Extra keys for __repr__() and __str__().

forward(sig: Tensor, label: Tensor | None, *extra_tensors: Sequence[Tensor], **kwargs: Any) -> Tensor | Tuple[Tensor]

Forward the input ECGs through the augmenters.

Parameters:
  • sig (torch.Tensor) – Batched ECGs to be augmented, of shape (batch, lead, siglen).

  • label (torch.Tensor, optional) – Batched labels of the ECGs.

  • *extra_tensors (Sequence[torch.Tensor], optional) – Extra tensors to be augmented, e.g. masks used by custom loss functions.

  • **kwargs (dict, optional) – Additional keyword arguments to be passed to the augmenters.

Returns:

The augmented ECGs, labels, and optional extra tensors.

Return type:

Sequence[torch.Tensor]
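The toy sketch below shows one way a forward pass can thread the signal, the label and any extra tensors through a chain of augmenters, assuming each augmenter takes (sig, label, *extra_tensors, **kwargs) and returns the full tuple. negate_signal, smooth_labels and run_pipeline are hypothetical helpers, nested lists stand in for tensors, and the label smoothing shown is a simplified binary variant, not torch_ecg's LabelSmooth.

```python
from typing import Any, Callable, Sequence, Tuple

# Hypothetical augmenter signature: take (sig, label, *extra) and return the full tuple.
AugFn = Callable[..., Tuple[Any, ...]]


def negate_signal(sig, label, *extra, **kwargs):
    # Touches only the signal; label and extra tensors pass through unchanged.
    return ([[-x for x in row] for row in sig], label, *extra)


def smooth_labels(sig, label, *extra, eps: float = 0.1, **kwargs):
    # Touches only the label (toy label smoothing for binary targets).
    if label is None:
        return (sig, label, *extra)
    new_label = [[y * (1 - eps) + eps / 2 for y in row] for row in label]
    return (sig, new_label, *extra)


def run_pipeline(augs: Sequence[AugFn], sig, label, *extra_tensors, **kwargs) -> Tuple[Any, ...]:
    outputs = (sig, label, *extra_tensors)
    for aug in augs:
        # Every augmenter sees the outputs of the previous one, plus shared kwargs.
        outputs = aug(*outputs, **kwargs)
    return outputs


sig = [[1.0, -2.0]]
label = [[1.0, 0.0]]
mask = [[1, 0]]
new_sig, new_label, new_mask = run_pipeline([negate_signal, smooth_labels], sig, label, mask)
print(new_sig, new_label, new_mask)
```

Keeping the tuple length fixed means augmenters that only care about the signal need no knowledge of how many extra tensors the caller passed.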

classmethod from_config(config: dict) -> AugmenterManager

Create an AugmenterManager from a configuration.

Parameters:

config (dict) – The configuration of the augmenters, preferably an OrderedDict so that the order of the augmenters is explicit.

Returns:

am – A new instance of AugmenterManager.

Return type:

AugmenterManager
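A minimal sketch of a registry-based builder in the spirit of from_config, assuming (as the example config suggests) that random and fs configure the manager itself while every other key names an augmenter whose dict holds its keyword arguments. The toy classes, REGISTRY, RESERVED, build_from_config, the made-up strength keyword, and the choice to pass fs to every augmenter are all hypothetical, not the library's code.

```python
from collections import OrderedDict
from typing import Any, Dict, List, Mapping


class ToyAug:
    """Hypothetical base: stores the sampling frequency and its own kwargs."""

    name = "base"

    def __init__(self, fs: int, **kwargs: Any) -> None:
        self.fs = fs
        self.kwargs = kwargs


class ToyBaselineWander(ToyAug):
    name = "baseline_wander"


class ToyMixup(ToyAug):
    name = "mixup"


# Hypothetical registry from config keys to augmenter classes.
REGISTRY: Dict[str, type] = {cls.name: cls for cls in (ToyBaselineWander, ToyMixup)}
# Keys that configure the manager rather than naming an augmenter (assumption).
RESERVED = {"random", "fs"}


def build_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
    augs: List[ToyAug] = []
    for key, kwargs in config.items():  # iteration order becomes augmenter order
        if key in RESERVED:
            continue
        if key not in REGISTRY:
            raise ValueError(f"unknown augmenter: {key!r}")
        augs.append(REGISTRY[key](fs=config["fs"], **kwargs))
    return {"random": config.get("random", False), "augmenters": augs}


config = OrderedDict(random=False, fs=500, mixup={"strength": 0.5}, baseline_wander={})
manager = build_from_config(config)
print([a.name for a in manager["augmenters"]])  # ['mixup', 'baseline_wander']
```

Since the augmenters are created in the order the keys are iterated, an ordered mapping makes the resulting pipeline order explicit.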

rearrange(new_ordering: List[str]) -> None

Rearrange the augmenters in the manager.

Parameters:

new_ordering (List[str]) – The list of augmenter names in the new order.
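The toy function below sketches how a name-based reordering can work, assuming the new ordering must list every current augmenter name exactly once. Named and rearrange here are hypothetical; unlike the method above, which returns None and presumably reorders the manager in place, this version returns a new list, and what counts as an augmenter's "name" in torch_ecg is not specified in this section.

```python
from typing import List, Sequence


class Named:
    """Hypothetical stand-in for an augmenter that only carries a name."""

    def __init__(self, name: str) -> None:
        self.name = name


def rearrange(augs: Sequence[Named], new_ordering: List[str]) -> List[Named]:
    """Return ``augs`` reordered to follow ``new_ordering`` (toy version)."""
    by_name = {a.name: a for a in augs}
    # Assumption: the new ordering must be a permutation of the current names.
    if len(new_ordering) != len(augs) or sorted(new_ordering) != sorted(by_name):
        raise ValueError("new_ordering must list every augmenter name exactly once")
    return [by_name[name] for name in new_ordering]


augs = [Named("mixup"), Named("random_flip"), Named("stretch_compress")]
reordered = rearrange(augs, ["stretch_compress", "mixup", "random_flip"])
print([a.name for a in reordered])  # ['stretch_compress', 'mixup', 'random_flip']
```

Validating the ordering up front rejects typos and duplicates before any reordering happens, so a bad call leaves the original order untouched.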