"""
"""
from numbers import Real
from typing import Any, Iterable, List, Optional, Sequence, Tuple
import numpy as np
from torch import Tensor
from ..cfg import DEFAULTS
from ..utils.utils_signal_t import normalize as normalize_t
from .base import Augmenter
# Names exported by ``from <this module> import *``; only the augmenter class is public.
__all__ = [
"RandomRenormalize",
]
class RandomRenormalize(Augmenter):
    """Randomly re-normalize the ECG tensor,
    using the Z-score normalization method.

    For every selected sample a target mean and a target standard deviation
    are drawn from normal distributions. Each distribution is centred at the
    midpoint of the corresponding range, and its scale is ``0.3`` times the
    distance from that midpoint to the last (upper) entry of the range.

    Parameters
    ----------
    mean : array_like, default ``[-0.05, 0.1]``
        Range of mean value of the re-normalized signal, of shape ``(2,)``;
        or range of mean values for each lead of the re-normalized signal,
        of shape ``(lead, 2)``.
    std : array_like, default ``[0.08, 0.32]``
        Range of standard deviation of the re-normalized signal, of shape ``(2,)``;
        or range of standard deviations for each lead of the re-normalized signal,
        of shape ``(lead, 2)``.
    per_channel : bool, default False
        If True, re-normalization will be done per channel.
        Per-lead ranges (2-D `mean` / `std`) require ``per_channel=True``.
    prob : float, default 0.5
        Probability of applying the random re-normalization augmenter.
    inplace : bool, default True
        Whether to apply the random re-normalization augmenter in-place.
    kwargs : dict, optional
        Additional keyword arguments. Not used by this class.

    Examples
    --------
    .. code-block:: python

        rrn = RandomRenormalize()
        sig = torch.randn(32, 12, 5000)
        sig, _ = rrn(sig, None)

    """

    __name__ = "RandomRenormalize"

    def __init__(
        self,
        mean: Iterable[Real] = [-0.05, 0.1],
        std: Iterable[Real] = [0.08, 0.32],
        per_channel: bool = False,
        prob: float = 0.5,
        inplace: bool = True,
        **kwargs: Any
    ) -> None:
        super().__init__()
        # ``np.array`` copies its input, so the list defaults are never mutated.
        self.mean = np.array(mean)
        # Centre of the range; ``keepdims`` keeps a trailing axis of size 1.
        self.mean_mean = self.mean.mean(axis=-1, keepdims=True)
        # Slice with ``-1:`` (not ``-1``) so the upper bound keeps its trailing
        # axis. With a ``(lead, 2)`` range, ``mean[..., -1]`` has shape
        # ``(lead,)`` and subtracting the ``(lead, 1)`` centre would broadcast
        # to a ``(lead, lead)`` matrix mixing different leads.
        self.mean_scale = (self.mean[..., -1:] - self.mean_mean) * 0.3
        self.std = np.array(std)
        self.std_mean = self.std.mean(axis=-1, keepdims=True)
        self.std_scale = (self.std[..., -1:] - self.std_mean) * 0.3
        self.per_channel = per_channel
        if not self.per_channel:
            assert self.mean.ndim == 1 and self.std.ndim == 1, (
                "`mean` and `std` must be 1-dimensional ranges when `per_channel` is False, "
                f"but got arrays of ndim {self.mean.ndim} and {self.std.ndim}"
            )
        self.prob = prob
        self.inplace = inplace

    def forward(
        self, sig: Tensor, label: Optional[Tensor], *extra_tensors: Sequence[Tensor], **kwargs: Any
    ) -> Tuple[Tensor, ...]:
        """Forward function of the RandomRenormalize augmenter.

        Parameters
        ----------
        sig : torch.Tensor
            The input ECG tensor, of shape ``(batch, lead, siglen)``.
        label : torch.Tensor, optional
            The input ECG label tensor.
            Not used, but kept for compatibility with other augmenters.
        extra_tensors : Sequence[torch.Tensor], optional,
            Not used, but kept for consistency with other augmenters.
        kwargs : dict, optional
            Not used, but kept for consistency with other augmenters.

        Returns
        -------
        sig : torch.Tensor
            The randomly re-normalized ECG tensor.
            It is a clone of the input if `inplace` is False.
        label : torch.Tensor
            The label tensor of the augmented ECGs, unchanged.
        extra_tensors: Sequence[torch.Tensor], optional,
            Unchanged extra tensors.

        Raises
        ------
        AssertionError
            If per-lead ranges were given and their number of rows
            differs from the number of leads of `sig`.

        """
        # Unpacking also enforces that ``sig`` is 3-dimensional.
        batch, lead, _ = sig.shape
        if self.mean.ndim == 2:
            assert self.mean.shape[0] == lead, (
                f"`mean` has ranges for {self.mean.shape[0]} leads, but `sig` has {lead} leads"
            )
        if self.std.ndim == 2:
            assert self.std.shape[0] == lead, (
                f"`std` has ranges for {self.std.shape[0]} leads, but `sig` has {lead} leads"
            )
        if not self.inplace:
            sig = sig.clone()
        if self.prob == 0:
            return (sig, label, *extra_tensors)
        indices = self.get_indices(self.prob, pop_size=batch)
        if len(indices) == 0:
            # No sample selected in this batch: nothing to re-normalize.
            return (sig, label, *extra_tensors)
        # One target value per selected sample, and per lead when ``per_channel``;
        # the trailing axis of size 1 spans the signal length.
        sample_shape = (len(indices), lead if self.per_channel else 1, 1)
        # NOTE(review): assumes ``DEFAULTS.RNG`` follows the numpy Generator API
        # (``normal(loc, scale, size)``) - confirm against ``..cfg``.
        mean = DEFAULTS.RNG.normal(self.mean_mean, self.mean_scale, size=sample_shape)
        std = DEFAULTS.RNG.normal(self.std_mean, self.std_scale, size=sample_shape)
        sig[indices, ...] = normalize_t(
            sig[indices, ...],
            method="z-score",
            mean=mean,
            std=std,
            per_channel=self.per_channel,
            inplace=True,
        )
        return (sig, label, *extra_tensors)

    def extra_repr_keys(self) -> List[str]:
        """Return this class's attribute names followed by the keys from the parent class."""
        return [
            "mean",
            "std",
            "per_channel",
            "prob",
            "inplace",
        ] + super().extra_repr_keys()