# Source code for torch_ecg.preprocessors.baseline_remove

"""
"""

import warnings
from numbers import Real
from typing import Any

import torch

from .._preprocessors.base import preprocess_multi_lead_signal

# Public API of this module.
__all__ = ["BaselineRemove"]


class BaselineRemove(torch.nn.Module):
    """Baseline removal using median filtering.

    The filtering is delegated to ``preprocess_multi_lead_signal``, which
    works on NumPy arrays. The signal tensor is therefore moved to the CPU,
    filtered, and converted back to a tensor with the original ``dtype``
    and ``device``.

    Parameters
    ----------
    fs : numbers.Real
        Sampling frequency of the ECG signal to be filtered.
    window1 : float, default 0.2
        The smaller window size of the median filter, with units in seconds.
    window2 : float, default 0.6
        The larger window size of the median filter, with units in seconds.
        If ``window2 < window1``, the two values are swapped and a
        :class:`RuntimeWarning` is issued.
    inplace : bool, default True
        If False, the input tensor is cloned before it is handed to the
        filtering routine, so the caller's tensor is never touched.
        :meth:`forward` always returns a newly created tensor, whatever
        the value of this flag.
    kwargs : dict, optional
        Extra keyword arguments. They are accepted for interface
        compatibility but are currently unused; in particular, they are
        NOT forwarded to :class:`torch.nn.Module`.

    """

    __name__ = "BaselineRemove"

    def __init__(
        self,
        fs: Real,
        window1: float = 0.2,
        window2: float = 0.6,
        inplace: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.fs = fs
        self.window1 = window1
        self.window2 = window2
        if self.window2 < self.window1:
            # Keep ``window1`` as the smaller window, as documented.
            self.window1, self.window2 = self.window2, self.window1
            warnings.warn("values of `window1` and `window2` are switched", RuntimeWarning)
        self.inplace = inplace

    def extra_repr(self) -> str:
        """Return the configuration shown inside ``repr(module)``."""
        return f"fs={self.fs}, window1={self.window1}, window2={self.window2}, inplace={self.inplace}"

    def forward(self, sig: torch.Tensor) -> torch.Tensor:
        """Apply the preprocessor to the signal tensor.

        Parameters
        ----------
        sig : torch.Tensor
            The ECG signal tensor, of shape ``(batch, lead, siglen)``.
            Tensors that require grad are accepted. The filtering is not
            differentiable, so the result carries no autograd history.

        Returns
        -------
        torch.Tensor
            The median filtered (hence baseline removed) ECG signals,
            of shape ``(batch, lead, siglen)``, with the same ``dtype``
            and on the same ``device`` as `sig`.

        """
        if not self.inplace:
            sig = sig.clone()
        # ``Tensor.numpy()`` raises for tensors that require grad, hence
        # ``detach()``. It shares storage with ``sig``, so the in-place
        # semantics of the original call are unchanged.
        filtered = preprocess_multi_lead_signal(
            raw_sig=sig.detach().cpu().numpy(),
            fs=self.fs,
            bl_win=[self.window1, self.window2],
        )
        # ``.copy()`` yields a fresh C-contiguous array, so the returned
        # tensor never aliases memory owned by the filtering routine.
        return torch.as_tensor(filtered.copy(), dtype=sig.dtype, device=sig.device)