Source code for torch_ecg.augmenters.random_flip

"""
"""

from numbers import Real
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from .base import Augmenter

# Public names exported by this module (honoured by ``from <module> import *``).
__all__ = [
    "RandomFlip",
]


class RandomFlip(Augmenter):
    """Randomly flip the ECGs along the voltage axis.

    Flipping is implemented as a multiplication by ``-1`` of the
    selected samples (or of the selected leads of the selected samples
    when `per_channel` is True).

    Parameters
    ----------
    fs : int, optional
        Sampling frequency of the ECGs to be augmented.
        Stored on the instance; not used by the flipping itself.
    per_channel : bool, default True
        Whether to flip each channel independently.
    prob : float or Sequence[float], default ``(0.4, 0.2)``
        Probability of performing flip,
        the first probability is for the batch dimension,
        the second probability is for the lead dimension.
        A single number is used for both dimensions.
    inplace : bool, default True
        If True, ECG signal tensors will be modified inplace.
    kwargs : dict, optional
        Additional keyword arguments.
        Accepted for interface compatibility; not used by this class.

    Raises
    ------
    AssertionError
        If any probability lies outside of the interval ``[0, 1]``.

    Examples
    --------
    .. code-block:: python

        rf = RandomFlip()
        sig = torch.randn(32, 12, 5000)
        sig, _ = rf(sig, None)

    """

    __name__ = "RandomFlip"

    def __init__(
        self,
        fs: Optional[int] = None,
        per_channel: bool = True,
        # an immutable tuple is used instead of a list so that the default
        # object can never be shared-and-mutated across instances
        prob: Union[Sequence[float], float] = (0.4, 0.2),
        inplace: bool = True,
        **kwargs: Any
    ) -> None:
        super().__init__()
        self.fs = fs
        self.per_channel = per_channel
        self.inplace = inplace
        if isinstance(prob, Real):
            # a scalar applies to both the batch and the lead dimension
            self.prob = np.array([prob, prob])
        else:
            # np.array copies the input, so the caller's sequence is never aliased
            self.prob = np.array(prob)
        # NOTE: deliberately kept as an assertion (same exception type and
        # message as before) so that existing callers keep working
        assert (self.prob >= 0).all() and (self.prob <= 1).all(), "Probability must be between 0 and 1"

    def forward(
        self, sig: Tensor, label: Optional[Tensor], *extra_tensors: Sequence[Tensor], **kwargs: Any
    ) -> Tuple[Tensor, ...]:
        """Forward function of the RandomFlip augmenter.

        Parameters
        ----------
        sig : torch.Tensor
            The ECGs to be augmented, of shape ``(batch, lead, siglen)``.
        label : torch.Tensor, optional
            Label tensor of the ECGs.
            Not used, but kept for consistency with other augmenters.
        extra_tensors : Sequence[torch.Tensor], optional
            Not used, but kept for consistency with other augmenters.
        kwargs : dict, optional
            Additional keyword arguments.
            Not used, but kept for consistency with other augmenters.

        Returns
        -------
        sig : torch.Tensor
            The augmented ECGs.
            This is the input tensor itself if `inplace` is True,
            otherwise a clone of it.
        label : torch.Tensor
            The label tensor of the augmented ECGs, unchanged.
        extra_tensors : Sequence[torch.Tensor], optional
            Unchanged extra tensors.

        """
        # unpacking into three names also rejects inputs that are not 3-dimensional
        batch, lead, _ = sig.shape
        if not self.inplace:
            sig = sig.clone()
        if self.prob[0] == 0:
            # no sample of the batch can be selected, nothing to flip
            return (sig, label, *extra_tensors)

        # `flip` holds the sign (+1 / -1) applied to each sample (and lead);
        # its trailing singleton dimension broadcasts over the time axis.
        # `get_indices` is inherited from the base class; presumably it draws
        # a random subset of ``range(pop_size)`` according to `prob` -- confirm in `Augmenter`.
        if self.per_channel:
            flip = torch.ones((batch, lead, 1), dtype=sig.dtype, device=sig.device)
            for i in self.get_indices(prob=self.prob[0], pop_size=batch):
                # leads to flip are drawn independently for every selected sample
                flip[i, self.get_indices(prob=self.prob[1], pop_size=lead), ...] = -1
        else:
            flip = torch.ones((batch, 1, 1), dtype=sig.dtype, device=sig.device)
            flip[self.get_indices(prob=self.prob[0], pop_size=batch), ...] = -1
        sig = sig.mul_(flip)
        return (sig, label, *extra_tensors)

    def extra_repr_keys(self) -> List[str]:
        """Return the attribute names of this class followed by those of the base class."""
        return [
            "per_channel",
            "prob",
            "inplace",
        ] + super().extra_repr_keys()