# Source code for torch_ecg.augmenters.mixup

"""
"""

from copy import deepcopy
from numbers import Real
from random import shuffle
from typing import Any, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch import Tensor

from ..cfg import DEFAULTS
from .base import Augmenter

__all__ = [
    "Mixup",
]


class Mixup(Augmenter):
    """Mixup augmentor.

    Mixup is a data augmentation technique originally proposed in [1]_.
    The PDF file of the paper can be found on arXiv [2]_.
    The official implementation is provided in [3]_.
    This technique was designed for image classification tasks,
    but it is also widely used for ECG tasks.

    Parameters
    ----------
    fs : int, optional
        Sampling frequency of the ECGs to be augmented.
    alpha : numbers.Real, default 0.5
        alpha parameter of the Beta distribution used in Mixup.
    beta : numbers.Real, optional
        beta parameter of the Beta distribution used in Mixup,
        defaults to `alpha`.
    prob : float, default 0.5
        Probability of applying Mixup.
    inplace : bool, default True
        If True, ECG signal tensors will be modified inplace.
    **kwargs : dict, optional
        Additional keyword arguments, not used.

    Examples
    --------
    .. code-block:: python

        mixup = Mixup(alpha=0.3, beta=0.6, prob=0.7)
        sig = torch.randn(32, 12, 5000)
        label = torch.randint(0, 2, (32, 26), dtype=torch.float32)
        sig, label = mixup(sig, label)

    References
    ----------
    .. [1] Zhang, Hongyi, et al. "mixup: Beyond Empirical Risk Minimization."
           International Conference on Learning Representations. 2018.
    .. [2] https://arxiv.org/abs/1710.09412
    .. [3] https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py

    """

    __name__ = "Mixup"

    def __init__(
        self,
        fs: Optional[int] = None,
        alpha: Real = 0.5,
        beta: Optional[Real] = None,
        prob: float = 0.5,
        inplace: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.fs = fs
        self.alpha = alpha
        # BUGFIX: the original used ``beta or self.alpha``, which silently
        # replaced an explicit ``beta=0`` with ``alpha`` because 0 is falsy.
        # Only fall back to ``alpha`` when ``beta`` is actually omitted.
        self.beta = self.alpha if beta is None else beta
        self.prob = prob
        assert 0 <= self.prob <= 1, "Probability must be between 0 and 1"
        self.inplace = inplace
[docs] def forward(self, sig: Tensor, label: Tensor, *extra_tensors: Sequence[Tensor], **kwargs: Any) -> Tuple[Tensor, ...]: """Forward method of the Mixup augmenter. Parameters ---------- sig : torch.Tensor Batched ECGs to be augmented, of shape ``(batch, lead, siglen)``. label : torch.Tensor Label tensor of the ECGs. extra_tensors: Sequence[torch.Tensor], optional Not used, but kept for consistency with other augmenters. **kwargs : dict, optional Not used, but kept for consistency with other augmenters. Returns ------- sig : torch.Tensor, The augmented ECGs. label : torch.Tensor The augmented labels. extra_tensors : Sequence[torch.Tensor], optional Unchanged extra tensors. """ batch, lead, siglen = sig.shape # TODO: make `lam` different for each batch element, using # lam = torch.from_numpy(DEFAULTS.RNG.beta(self.alpha, self.beta, batch), dtype=sig.dtype, device=sig.device) # of shape (batch,) lam = DEFAULTS.RNG.beta(self.alpha, self.beta) indices = np.arange(batch, dtype=int) ori = self.get_indices(prob=self.prob, pop_size=batch) # print(f"ori = {ori}, len(ori) = {len(ori)}") perm = deepcopy(ori) shuffle(perm) indices[ori] = perm indices = torch.from_numpy(indices).long() if not self.inplace: sig = sig.clone() label = label.clone() sig = lam * sig + (1 - lam) * sig[indices] label = lam * label + (1 - lam) * label[indices] # TODO: if `lam` is a Tensor of shape (batch,) instead of a scalar, # sig = lam.view(batch, 1, 1) * sig + (1 - lam).view(batch, 1, 1) * sig[indices] # label = lam.view(batch, 1) * label + (1 - lam).view(batch, 1) * label[indices] return (sig, label, *extra_tensors)
[docs] def extra_repr_keys(self) -> List[str]: return [ "alpha", "beta", "prob", "inplace", ] + super().extra_repr_keys()