from copy import deepcopy
from numbers import Real
from random import shuffle
from typing import Any, List, Optional, Sequence, Tuple
import numpy as np
import torch
from torch import Tensor
from ..cfg import DEFAULTS
from .base import Augmenter
__all__ = [
class Mixup(Augmenter):
    """Mixup augmentor.

    Mixup is a data augmentation technique originally proposed in [1]_.
    The PDF file of the paper can be found on arXiv [2]_.
    The official implementation is provided in [3]_. This technique was designed
    for image classification tasks, but it is also widely used for ECG tasks.

    Parameters
    ----------
    fs : int, optional
        Sampling frequency of the ECGs to be augmented.
    alpha : numbers.Real, default 0.5
        alpha parameter of the Beta distribution used in Mixup.
    beta : numbers.Real, optional
        beta parameter of the Beta distribution used in Mixup,
        defaults to `alpha`.
    prob : float, default 0.5
        Probability of applying Mixup.
    inplace : bool, default True
        If True, ECG signal tensors will be modified inplace.
    **kwargs : dict, optional
        Additional keyword arguments, not used.

    Examples
    --------
    .. code-block:: python

        mixup = Mixup(alpha=0.3, beta=0.6, prob=0.7)
        sig = torch.randn(32, 12, 5000)
        label = torch.randint(0, 2, (32, 26), dtype=torch.float32)
        sig, label = mixup(sig, label)

    References
    ----------
    .. [1] Zhang, Hongyi, et al. "mixup: Beyond Empirical Risk Minimization."
           International Conference on Learning Representations. 2018.
    .. [2] https://arxiv.org/abs/1710.09412
    .. [3] https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py

    """

    __name__ = "Mixup"

    def __init__(
        self,
        fs: Optional[int] = None,
        alpha: Real = 0.5,
        beta: Optional[Real] = None,
        prob: float = 0.5,
        inplace: bool = True,
        **kwargs: Any
    ) -> None:
        # Initialise the base class before any attribute is set.
        # NOTE(review): assumes ``Augmenter.__init__`` takes no required
        # arguments -- confirm against ``.base``.
        super().__init__()
        self.fs = fs
        self.alpha = alpha
        # A falsy ``beta`` (``None`` or ``0``) falls back to ``alpha``.
        self.beta = beta or self.alpha
        self.prob = prob
        assert 0 <= self.prob <= 1, "Probability must be between 0 and 1"
        self.inplace = inplace

    def forward(self, sig: Tensor, label: Tensor, *extra_tensors: Sequence[Tensor], **kwargs: Any) -> Tuple[Tensor, ...]:
        """Forward method of the Mixup augmenter.

        Parameters
        ----------
        sig : torch.Tensor
            Batched ECGs to be augmented, of shape ``(batch, lead, siglen)``.
        label : torch.Tensor
            Label tensor of the ECGs.
        extra_tensors : Sequence[torch.Tensor], optional
            Not used, but kept for consistency with other augmenters.
        **kwargs : dict, optional
            Not used, but kept for consistency with other augmenters.

        Returns
        -------
        sig : torch.Tensor
            The augmented ECGs.
        label : torch.Tensor
            The augmented labels.
        extra_tensors : Sequence[torch.Tensor], optional
            Unchanged extra tensors.

        """
        # The three-way unpacking doubles as a check that ``sig`` is 3-D.
        batch, _lead, _siglen = sig.shape

        # A single mixing coefficient is drawn for the whole batch.
        # TODO: make `lam` different for each batch element, using
        # lam = torch.as_tensor(
        #     DEFAULTS.RNG.beta(self.alpha, self.beta, batch), dtype=sig.dtype, device=sig.device
        # )
        # of shape (batch,)
        lam = DEFAULTS.RNG.beta(self.alpha, self.beta)

        # Start from the identity mapping: sample i is paired with itself,
        # which leaves it unchanged by the convex combination below.
        indices = np.arange(batch, dtype=int)
        # Positions selected for augmentation.
        # NOTE(review): presumably a random subset of ``range(batch)`` whose
        # size is governed by ``prob`` -- confirm against ``get_indices``.
        ori = self.get_indices(prob=self.prob, pop_size=batch)
        # Pair every selected sample with another selected sample by shuffling
        # a copy of the selection. Without the shuffle ``indices`` would stay
        # the identity and the augmentation would be a no-op.
        perm = deepcopy(ori)
        shuffle(perm)
        indices[ori] = perm
        indices = torch.from_numpy(indices).long()

        if not self.inplace:
            sig = sig.clone()
            label = label.clone()

        # Convex combination of each sample (and its label) with its partner.
        # The expressions below build new tensors and rebind the local names.
        sig = lam * sig + (1 - lam) * sig[indices]
        label = lam * label + (1 - lam) * label[indices]
        # TODO: if `lam` is a Tensor of shape (batch,) instead of a scalar,
        # sig = lam.view(batch, 1, 1) * sig + (1 - lam).view(batch, 1, 1) * sig[indices]
        # label = lam.view(batch, 1) * label + (1 - lam).view(batch, 1) * label[indices]

        return (sig, label, *extra_tensors)