Source code for torch_ecg.augmenters.random_masking

"""
"""

from numbers import Real
from random import randint
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from .base import Augmenter

# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = [
    "RandomMasking",
]


class RandomMasking(Augmenter):
    """Randomly mask ECGs with a probability.

    Parameters
    ----------
    fs : int
        Sampling frequency of the ECGs to be augmented.
    mask_value : numbers.Real, default 0.0
        Value written into the masked windows.
    mask_width : Sequence[numbers.Real], default ``[0.08, 0.18]``
        Width range of the masking window, with units in seconds.
    prob : numbers.Real or Sequence[numbers.Real], default ``[0.3, 0.15]``
        Probabilities of masking ECG signals.
        The first probability is for the batch dimension (which records
        are masked at all); the second one drives the selection of mask
        positions inside a selected record. A masked window always covers
        all leads of the record. A single number is used for both.
        Note that 0.15 is approximately the proportion of QRS complexes
        in ECGs.
    inplace : bool, default True
        Whether to mask inplace or not.
    kwargs : dict, optional
        Additional keyword arguments. Not used.

    Examples
    --------
    .. code-block:: python

        rm = RandomMasking(fs=500, prob=0.7)
        sig = torch.randn(32, 12, 5000)
        critical_points = [np.arange(250, 5000 - 250, step=400) for _ in range(32)]
        sig, _ = rm(sig, None, critical_points=critical_points)

    """

    __name__ = "RandomMasking"

    def __init__(
        self,
        fs: int,
        mask_value: Real = 0.0,
        # NOTE: the list defaults are only read (copied by ``np.array``),
        # never mutated, so sharing them across calls is harmless.
        mask_width: Sequence[Real] = [0.08, 0.18],
        prob: Union[Sequence[Real], Real] = [0.3, 0.15],
        inplace: bool = True,
        **kwargs: Any
    ) -> None:
        super().__init__()
        self.fs = fs
        # A scalar probability is used for both the batch-level and the
        # position-level selection.
        if isinstance(prob, Real):
            self.prob = np.array([prob, prob])
        else:
            self.prob = np.array(prob)
        # Kept as an ``assert`` so that callers relying on ``AssertionError``
        # keep working.
        assert (self.prob >= 0).all() and (self.prob <= 1).all(), "Probability must be between 0 and 1"
        self.mask_value = mask_value
        # Convert the width range from seconds to a number of samples.
        self.mask_width = (np.array(mask_width) * self.fs).round().astype(int)
        self.inplace = inplace

    def forward(
        self,
        sig: Tensor,
        label: Optional[Tensor],
        *extra_tensors: Sequence[Tensor],
        critical_points: Optional[Sequence[Sequence[int]]] = None,
        **kwargs: Any
    ) -> Tuple[Tensor, ...]:
        """Forward method of the RandomMasking augmenter.

        Parameters
        ----------
        sig : torch.Tensor
            Batched ECGs to be augmented, of shape ``(batch, lead, siglen)``.
        label : torch.Tensor
            Label tensor of the ECGs.
            Not used, but kept for compatibility with other augmenters.
        extra_tensors : Sequence[torch.Tensor], optional
            Not used, but kept for consistency with other augmenters.
        critical_points : Sequence[Sequence[int]], optional
            One sequence of sample indices per record of the batch.
            If given, random masking will be performed in windows
            centered at (a random subset of) these points.
            This is useful for example when one wants to
            randomly mask QRS complexes.
        kwargs : dict, optional
            Not used, but kept for consistency with other augmenters.

        Returns
        -------
        sig : torch.Tensor
            The augmented ECGs, of shape ``(batch, lead, siglen)``.
            This is the input tensor itself when ``inplace`` is True,
            otherwise a masked clone.
        label : torch.Tensor
            Label tensor of the augmented ECGs, unchanged.
        extra_tensors : Sequence[torch.Tensor], optional
            Unchanged extra tensors.

        """
        batch, _, siglen = sig.shape
        if not self.inplace:
            sig = sig.clone()
        if self.prob[0] == 0:
            return (sig, label, *extra_tensors)

        min_width, max_width = int(self.mask_width[0]), int(self.mask_width[1])
        # Boolean mask of the samples to overwrite; True means "masked".
        mask = torch.zeros_like(sig, dtype=torch.bool)
        # ``get_indices`` is inherited from ``Augmenter`` (not shown here);
        # it is assumed to return a random list of integer indices in
        # ``[0, pop_size)`` -- TODO confirm against the base class.
        for batch_idx in self.get_indices(prob=self.prob[0], pop_size=batch):
            points = None if critical_points is None else critical_points[batch_idx]
            for center in self._sample_mask_centers(siglen, points):
                center = int(center)
                radius = randint(min_width, max_width) // 2
                # Clamp the start: a negative start would wrap around to the
                # end of the signal and typically select nothing at all.
                start = max(center - radius, 0)
                mask[batch_idx, :, start : center + radius] = True
        # Write ``mask_value`` into the selected windows (instead of
        # multiplying, which only works for ``mask_value == 0``).
        sig = sig.masked_fill_(mask, self.mask_value)
        return (sig, label, *extra_tensors)

    def _sample_mask_centers(self, siglen: int, points: Optional[Sequence[int]] = None) -> np.ndarray:
        """Draw the centers of the masking windows for one record.

        Parameters
        ----------
        siglen : int
            Length (number of samples) of the record.
        points : Sequence[int], optional
            Critical points of the record. If given, the centers are a
            random subset of these points; otherwise they are drawn
            along the time axis.

        Returns
        -------
        numpy.ndarray
            Integer array of sample indices of the window centers.

        """
        if points is not None:
            chosen = self.get_indices(prob=self.prob[1], pop_size=len(points))
            # ``chosen`` are positions in ``points``; map them back to the
            # actual sample indices of the critical points.
            return np.asarray(points, dtype=int)[np.asarray(chosen, dtype=int)]

        max_width = self.mask_width[1]
        chosen = self.get_indices(
            prob=self.prob[1] / max_width,
            pop_size=siglen - max_width,
            scale_ratio=min(self.prob[1] / 4, 0.1) / max_width,
        )
        # Shift by half of the maximal width so that every window starts
        # at a non-negative sample index.
        return np.asarray(chosen, dtype=int) + max_width // 2

    def extra_repr_keys(self) -> List[str]:
        """Return the attribute names shown in the representation, followed by those of the parent class."""
        return [
            "fs",
            "mask_value",
            "mask_width",
            "prob",
            "inplace",
        ] + super().extra_repr_keys()