# Source code for torch_ecg.augmenters.random_renormalize

"""
Random re-normalization (Z-score) augmenter for ECG tensors.
"""

from numbers import Real
from typing import Any, Iterable, List, Optional, Sequence, Tuple

import numpy as np
from torch import Tensor

from ..cfg import DEFAULTS
from ..utils.utils_signal_t import normalize as normalize_t
from .base import Augmenter

# Public API of this module: only the augmenter class is exported.
__all__ = ["RandomRenormalize"]


class RandomRenormalize(Augmenter):
    """Randomly re-normalize the ECG tensor, using the Z-score normalization method.

    Parameters
    ----------
    mean : array_like, default ``(-0.05, 0.1)``
        Range of mean value of the re-normalized signal, of shape ``(2,)``;
        or range of mean values for each lead of the re-normalized signal,
        of shape ``(lead, 2)``. Per-lead ranges require ``per_channel=True``.
    std : array_like, default ``(0.08, 0.32)``
        Range of standard deviation of the re-normalized signal, of shape ``(2,)``;
        or range of standard deviations for each lead of the re-normalized signal,
        of shape ``(lead, 2)``. Per-lead ranges require ``per_channel=True``.
    per_channel : bool, default False
        If True, re-normalization will be done per channel.
    prob : float, default 0.5
        Probability of applying the random re-normalization augmenter.
    inplace : bool, default True
        Whether to apply the random re-normalization augmenter in-place.
    kwargs : dict, optional
        Additional keyword arguments (currently unused).

    Notes
    -----
    The target mean / standard deviation are not drawn uniformly from the
    given ranges. They are sampled from a normal distribution centred at the
    midpoint of each range, with standard deviation ``0.3`` times the
    half-width of the range. Sampled values therefore lie inside the range
    with high probability, but may occasionally fall outside it.

    Examples
    --------
    .. code-block:: python

        rrn = RandomRenormalize()
        sig = torch.randn(32, 12, 5000)
        sig, _ = rrn(sig, None)

    """

    __name__ = "RandomRenormalize"

    def __init__(
        self,
        mean: Iterable[Real] = (-0.05, 0.1),
        std: Iterable[Real] = (0.08, 0.32),
        per_channel: bool = False,
        prob: float = 0.5,
        inplace: bool = True,
        **kwargs: Any
    ) -> None:
        super().__init__()
        self.mean = np.array(mean)
        # Midpoint of each range. ``keepdims`` keeps a trailing axis of size 1,
        # so a ``(2,)`` range gives ``(1,)`` and a ``(lead, 2)`` range gives ``(lead, 1)``.
        self.mean_mean = self.mean.mean(axis=-1, keepdims=True)
        # Slice the upper bound with ``-1:`` (not ``-1``) so it keeps the same
        # trailing axis as the midpoint. Indexing with ``-1`` drops that axis.
        # For ``(lead, 2)`` ranges that broadcasts ``(lead,) - (lead, 1)`` into a
        # ``(lead, lead)`` matrix, which is incompatible with the sampling size
        # used in ``forward``.
        self.mean_scale = (self.mean[..., -1:] - self.mean_mean) * 0.3
        self.std = np.array(std)
        self.std_mean = self.std.mean(axis=-1, keepdims=True)
        self.std_scale = (self.std[..., -1:] - self.std_mean) * 0.3
        self.per_channel = per_channel
        if not self.per_channel:
            # One value per sample is drawn in this mode, so per-lead (2-D) ranges make no sense.
            assert self.mean.ndim == 1 and self.std.ndim == 1, (
                "per-lead (2-D) `mean` / `std` ranges require `per_channel=True`"
            )
        self.prob = prob
        self.inplace = inplace

    def forward(
        self,
        sig: Tensor,
        label: Optional[Tensor],
        *extra_tensors: Tensor,
        **kwargs: Any
    ) -> Tuple[Tensor, ...]:
        """Forward function of the RandomRenormalize augmenter.

        Parameters
        ----------
        sig : torch.Tensor
            The input ECG tensor, of shape ``(batch, lead, siglen)``.
        label : torch.Tensor, optional
            The input ECG label tensor.
            Not used, but kept for compatibility with other augmenters.
        extra_tensors : torch.Tensor, optional
            Not used, but kept for consistency with other augmenters.
        kwargs : dict, optional
            Not used, but kept for consistency with other augmenters.

        Returns
        -------
        sig : torch.Tensor
            The randomly re-normalized ECG tensor.
        label : torch.Tensor
            The label tensor of the augmented ECGs, unchanged.
        extra_tensors : Sequence[torch.Tensor], optional
            Unchanged extra tensors.

        """
        # Unpacking also enforces a 3-D input.
        batch, lead, _ = sig.shape
        if self.mean.ndim == 2:
            assert self.mean.shape[0] == lead, (
                f"`mean` has ranges for {self.mean.shape[0]} leads, but the signal has {lead} leads"
            )
        if self.std.ndim == 2:
            assert self.std.shape[0] == lead, (
                f"`std` has ranges for {self.std.shape[0]} leads, but the signal has {lead} leads"
            )
        if not self.inplace:
            sig = sig.clone()
        if self.prob == 0:
            return (sig, label, *extra_tensors)
        # Indices of the batch samples selected for augmentation (selection is done by the base class).
        indices = self.get_indices(self.prob, pop_size=batch)
        if len(indices) == 0:
            # Nothing selected: avoid calling ``normalize_t`` on an empty batch.
            return (sig, label, *extra_tensors)
        # One target mean/std per selected sample, and per lead when ``per_channel``.
        # The trailing axis of size 1 presumably broadcasts over ``siglen`` inside ``normalize_t``.
        size = (len(indices), lead if self.per_channel else 1, 1)
        mean = DEFAULTS.RNG.normal(self.mean_mean, self.mean_scale, size=size)
        std = DEFAULTS.RNG.normal(self.std_mean, self.std_scale, size=size)
        # Advanced indexing yields a copy, so ``inplace=True`` only affects that
        # copy; the result is written back into ``sig``.
        sig[indices, ...] = normalize_t(
            sig[indices, ...],
            method="z-score",
            mean=mean,
            std=std,
            per_channel=self.per_channel,
            inplace=True,
        )
        return (sig, label, *extra_tensors)

    def extra_repr_keys(self) -> List[str]:
        """Names of the attributes shown in the augmenter's ``repr``."""
        return [
            "mean",
            "std",
            "per_channel",
            "prob",
            "inplace",
        ] + super().extra_repr_keys()