Source code for torch_ecg._preprocessors.base

"""Base class for preprocessors."""

from abc import ABC, abstractmethod
from itertools import repeat
from numbers import Real
from typing import List, Optional, Tuple

import numpy as np
from biosppy.signals.tools import filter_signal
from scipy.ndimage import median_filter

from ..cfg import DEFAULTS
from ..utils.misc import ReprMixin, add_docstring
from ..utils.utils_signal import butter_bandpass_filter

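# NOTE: `scipy.ndimage.median_filter` is imported above in place of
# `scipy.signal.medfilt` (`from scipy.signal import medfilt`); see
# https://github.com/scipy/scipy/issues/9680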


# Public names of this module, i.e. what `from <module> import *` exports.
__all__ = [
    "PreProcessor",
    "preprocess_multi_lead_signal",
    "preprocess_single_lead_signal",
]


class PreProcessor(ReprMixin, ABC):
    """Base class for preprocessors.

    Subclasses implement :meth:`apply`; calling an instance
    forwards to :meth:`apply`.
    """

    __name__ = "PreProcessor"

    @abstractmethod
    def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
        """Apply the preprocessor to `sig`.

        Parameters
        ----------
        sig : numpy.ndarray
            The ECG signal, can be
                - 1d array, which is a single-lead ECG;
                - 2d array, which is a multi-lead ECG of "lead_first" format;
                - 3d array, which is a tensor of several ECGs,
                  of shape ``(batch, lead, siglen)``.
        fs : numbers.Real
            Sampling frequency of the ECG signal.

        """
        raise NotImplementedError

    @add_docstring(apply)
    def __call__(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
        """alias of :meth:`self.apply`."""
        return self.apply(sig, fs)

    def _check_sig(self, sig: np.ndarray) -> None:
        """Check validity of the signal.

        Parameters
        ----------
        sig : numpy.ndarray
            The ECG signal, can be
                - 1d array, which is a single-lead ECG;
                - 2d array, which is a multi-lead ECG of "lead_first" format;
                - 3d array, which is a tensor of several ECGs,
                  of shape ``(batch, lead, siglen)``.

        Raises
        ------
        ValueError
            If `sig` is not a 1d, 2d or 3d array.

        """
        if sig.ndim not in [1, 2, 3]:
            # Adjacent string literals are concatenated verbatim, so each
            # fragment carries its own trailing space.
            raise ValueError(
                "Invalid input ECG signal. Should be "
                "1d array, which is a single-lead ECG; "
                "or 2d array, which is a multi-lead ECG of `lead_first` format; "
                "or 3d array, which is a tensor of several ECGs, of shape (batch, lead, siglen)."
            )
def preprocess_multi_lead_signal(
    raw_sig: np.ndarray,
    fs: Real,
    sig_fmt: str = "channel_first",
    bl_win: Optional[List[Real]] = None,
    band_fs: Optional[List[Real]] = None,
    filter_type: str = "butter",
    filter_order: Optional[int] = None,
) -> np.ndarray:
    """Perform preprocessing for multi-lead ECG signal (with units in mV).

    Preprocessing consists of optional baseline removal (two cascaded
    median filters) followed by optional frequency filtering.
    Also works for single-lead ECG signal given as a 2d array
    (setting ``sig_fmt="channel_first"``).

    Parameters
    ----------
    raw_sig : numpy.ndarray
        The raw ECG signal, with units in mV.
        Must be a 2d array, or a 3d array with a leading batch dimension.
    fs : numbers.Real
        Sampling frequency of `raw_sig`.
    sig_fmt : str, default "channel_first"
        Format of the multi-lead ECG signal,
        "channel_last" (alias "lead_last"), or
        "channel_first" (alias "lead_first").
    bl_win : List[numbers.Real], optional
        Window (units in second) of baseline removal
        using :meth:`~scipy.ndimage.median_filter`,
        the first is the shorter one, the second the longer one,
        a typical pair is ``[0.2, 0.6]``.
        If is None or empty, baseline removal will not be performed.
    band_fs : List[numbers.Real], optional
        Frequency band of the bandpass filter,
        a typical pair is ``[0.5, 45]``.
        Be careful when detecting paced rhythm.
        If is None or empty, bandpass filtering will not be performed.
    filter_type : {"butter", "fir"}, optional
        Type of the bandpass filter, default "butter".
    filter_order : int, optional
        Order of the bandpass filter.
        If not given (or 0), ``int(0.2 * fs)`` is used for "fir",
        and ``round(0.01 * fs)`` for "butter".

    Returns
    -------
    filtered_ecg : numpy.ndarray
        The array of the processed ECG signal.
        The format of the signal is kept the same with the original signal,
        i.e. `sig_fmt`.

    Raises
    ------
    AssertionError
        If `raw_sig` is not 2d or 3d, if `sig_fmt` is not supported,
        or if `band_fs` is not a valid frequency band.
    ValueError
        If `filter_type` is not one of "butter", "fir".

    """
    raw_sig = np.asarray(raw_sig)
    assert raw_sig.ndim in [2, 3], "multi-lead signal should be 2d or 3d array"
    assert sig_fmt.lower() in [
        "channel_first",
        "lead_first",
        "channel_last",
        "lead_last",
    ], f"multi-lead signal format `{sig_fmt}` not supported"
    is_channel_last = sig_fmt.lower() in ["channel_last", "lead_last"]

    # Work internally in channel-first layout, i.e. time is the last axis.
    if is_channel_last:
        # might have a batch dimension at the first axis
        filtered_ecg = np.moveaxis(raw_sig, -2, -1).astype(DEFAULTS.np_dtype)
    else:
        filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype)

    # remove baseline
    if bl_win:
        # Window of size 1 on every axis except the last (time) axis,
        # so that the median filter runs along time only.
        window1 = [1] * filtered_ecg.ndim
        window2 = [1] * filtered_ecg.ndim
        window1[-1] = 2 * (int(bl_win[0] * fs) // 2) + 1  # window size must be odd
        window2[-1] = 2 * (int(bl_win[1] * fs) // 2) + 1
        baseline = median_filter(filtered_ecg, size=window1, mode="nearest")
        baseline = median_filter(baseline, size=window2, mode="nearest")
        filtered_ecg = filtered_ecg - baseline

    # filter signal
    if band_fs:
        assert band_fs[0] < band_fs[1], "Invalid frequency band"
        nyq = 0.5 * fs
        # Degenerate band edges turn the bandpass into a lowpass / highpass.
        if band_fs[0] <= 0 and band_fs[1] < nyq:
            band = "lowpass"
            frequency = band_fs[1]
        elif band_fs[1] >= nyq and band_fs[0] > 0:
            band = "highpass"
            frequency = band_fs[0]
        elif band_fs[0] > 0 and band_fs[1] < nyq:
            band = "bandpass"
            frequency = band_fs
        else:
            raise AssertionError("Invalid frequency band")
        if filter_type.lower() == "fir":
            filtered_ecg = filter_signal(
                signal=filtered_ecg,
                ftype="FIR",
                # ftype="butter",
                band=band,
                order=filter_order or int(0.2 * fs),
                sampling_rate=fs,
                frequency=frequency,
            )["signal"]
        elif filter_type.lower() == "butter":
            filtered_ecg = butter_bandpass_filter(
                data=filtered_ecg,
                lowcut=band_fs[0],
                highcut=band_fs[1],
                fs=fs,
                order=filter_order or round(0.01 * fs),  # better be determined by the `buttord`
            )
        else:
            raise ValueError(f"Unsupported filter type `{filter_type}`")

    if is_channel_last:
        # Undo the layout change made on input. Swapping only the last two
        # axes keeps a leading batch axis in place; a plain `.T` would
        # reverse all axes of a 3d array.
        filtered_ecg = np.moveaxis(filtered_ecg, -1, -2)

    return filtered_ecg
def preprocess_single_lead_signal(
    raw_sig: np.ndarray,
    fs: Real,
    bl_win: Optional[List[Real]] = None,
    band_fs: Optional[List[Real]] = None,
    filter_type: str = "butter",
    filter_order: Optional[int] = None,
) -> np.ndarray:
    """Perform preprocessing for single lead ECG signal (with units in mV).

    Preprocessing consists of optional baseline removal (two cascaded
    median filters) followed by optional frequency filtering.

    Parameters
    ----------
    raw_sig : numpy.ndarray
        Raw ECG signal, with units in mV. Must be a 1d array.
    fs : numbers.Real
        Sampling frequency of `raw_sig`.
    bl_win : list (of 2 numbers.Real), optional
        Window (units in second) of baseline removal
        using :meth:`~scipy.ndimage.median_filter`,
        the first is the shorter one, the second the longer one,
        a typical pair is ``[0.2, 0.6]``.
        If is None or empty, baseline removal will not be performed.
    band_fs : list of numbers.Real, optional
        Frequency band of the bandpass filter,
        a typical pair is ``[0.5, 45]``.
        Be careful when detecting paced rhythm.
        If is None or empty, bandpass filtering will not be performed.
    filter_type : {"butter", "fir"}, optional
        Type of the bandpass filter, default "butter".
    filter_order : int, optional
        Order of the bandpass filter.
        If not given (or 0), ``int(0.3 * fs)`` is used for "fir",
        and ``round(0.01 * fs)`` for "butter".

    Returns
    -------
    filtered_ecg : numpy.ndarray
        The array of the processed ECG signal.

    Raises
    ------
    AssertionError
        If `raw_sig` is not a 1d array,
        or if `band_fs` is not a valid frequency band.
    ValueError
        If `filter_type` is not one of "butter", "fir".

    NOTE
    ----
    The filter is selected via `filter_type`: a Butterworth filter
    (the default) or an FIR filter.

    """
    filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype)
    assert filtered_ecg.ndim == 1, "single-lead signal should be 1d array"

    # remove baseline
    if bl_win:
        window1 = 2 * (int(bl_win[0] * fs) // 2) + 1  # window size must be odd
        window2 = 2 * (int(bl_win[1] * fs) // 2) + 1
        baseline = median_filter(filtered_ecg, size=window1, mode="nearest")
        baseline = median_filter(baseline, size=window2, mode="nearest")
        filtered_ecg = filtered_ecg - baseline

    # filter signal
    if band_fs:
        assert band_fs[0] < band_fs[1], "Invalid frequency band"
        nyq = 0.5 * fs
        # Degenerate band edges turn the bandpass into a lowpass / highpass.
        if band_fs[0] <= 0 and band_fs[1] < nyq:
            band = "lowpass"
            frequency = band_fs[1]
        elif band_fs[1] >= nyq and band_fs[0] > 0:
            band = "highpass"
            frequency = band_fs[0]
        elif band_fs[0] > 0 and band_fs[1] < nyq:
            band = "bandpass"
            frequency = band_fs
        else:
            raise AssertionError("Invalid frequency band")
        if filter_type.lower() == "fir":
            filtered_ecg = filter_signal(
                signal=filtered_ecg,
                ftype="FIR",
                # ftype="butter",
                band=band,
                # Honor a caller-supplied order; fall back to the historical
                # default when none is given.
                order=filter_order or int(0.3 * fs),
                sampling_rate=fs,
                frequency=frequency,
            )["signal"]
        elif filter_type.lower() == "butter":
            filtered_ecg = butter_bandpass_filter(
                data=filtered_ecg,
                lowcut=band_fs[0],
                highcut=band_fs[1],
                fs=fs,
                order=filter_order or round(0.01 * fs),  # better be determined by the `buttord`
            )
        else:
            raise ValueError(f"Unsupported filter type `{filter_type}`")

    return filtered_ecg