# Source code for torch_ecg._preprocessors.bandpass

"""BandPass filter preprocessor."""

from numbers import Real
from typing import Any, List, Optional, Tuple

import numpy as np

from .base import PreProcessor, preprocess_multi_lead_signal

__all__ = [
    "BandPass",
]


class BandPass(PreProcessor):
    """Bandpass filtering preprocessor.

    Parameters
    ----------
    lowcut : numbers.Real, optional
        Low cutoff frequency, default 0.5.
        ``None`` (or 0) is stored as 0, i.e. no lower bound on the band.
    highcut : numbers.Real, optional
        High cutoff frequency, default 45.
        ``None`` (or 0) is stored as ``float("inf")``,
        i.e. no upper bound on the band.
    filter_type : {"butter", "fir"}, optional
        Type of the bandpass filter, default "butter".
    filter_order : int, optional
        Order of the bandpass filter. ``None`` (the default) is passed
        through unchanged to :func:`preprocess_multi_lead_signal`.
    **kwargs : dict, optional
        Other arguments for :class:`PreProcessor`.
        NOTE(review): currently accepted but not used by this class.

    Raises
    ------
    AssertionError
        If both `lowcut` and `highcut` are ``None``.

    Examples
    --------
    .. code-block:: python

        from torch_ecg.cfg import DEFAULTS
        sig = DEFAULTS.RNG.randn(1000)
        pp = BandPass(lowcut=0.5, highcut=45, filter_type="butter", filter_order=4)
        sig, _ = pp(sig, 500)

    """

    __name__ = "BandPass"

    def __init__(
        self,
        lowcut: Optional[Real] = 0.5,
        highcut: Optional[Real] = 45,
        filter_type: str = "butter",
        filter_order: Optional[int] = None,
        **kwargs: Any
    ) -> None:
        # Explicit check rather than a bare ``assert``: an ``assert`` is
        # removed under ``python -O``, which would silently turn a
        # misconfiguration into a [0, inf] (no-op) band. ``AssertionError``
        # and its message are kept for backward compatibility with callers
        # that already catch it.
        if lowcut is None and highcut is None:
            raise AssertionError("At least one of lowcut and highcut should be set")
        # A missing (``None``) or zero cutoff is replaced by an open bound, so
        # that both entries of ``band_fs`` used in ``apply`` are numeric.
        self.lowcut = lowcut or 0
        self.highcut = highcut or float("inf")
        self.filter_type = filter_type
        self.filter_order = filter_order

    def apply(self, sig: np.ndarray, fs: int) -> Tuple[np.ndarray, int]:
        """Apply the preprocessor to `sig`.

        Parameters
        ----------
        sig : numpy.ndarray
            The ECG signal, can be

                - 1d array, which is a single-lead ECG;
                - 2d array, which is a multi-lead ECG of "lead_first" format;
                - 3d array, which is a tensor of several ECGs,
                  of shape ``(batch, lead, siglen)``.

        fs : int
            Sampling frequency of the ECG signal.

        Returns
        -------
        filtered_sig : :class:`numpy.ndarray`
            Bandpass filtered ECG signal.
        fs : :class:`int`
            Sampling frequency of the filtered ECG signal
            (returned unchanged; filtering does not resample).

        """
        self._check_sig(sig)
        filtered_sig = preprocess_multi_lead_signal(
            raw_sig=sig,
            fs=fs,
            band_fs=[self.lowcut, self.highcut],
            filter_type=self.filter_type,
            filter_order=self.filter_order,
        )
        return filtered_sig, fs

    def extra_repr_keys(self) -> List[str]:
        """Attribute names shown in ``repr``, prepended to the parent's keys."""
        return [
            "lowcut",
            "highcut",
            "filter_type",
            "filter_order",
        ] + super().extra_repr_keys()