Source code for torch_ecg.databases.physionet_databases.afdb

# -*- coding: utf-8 -*-

import math
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

import numpy as np
import wfdb

from ...cfg import CFG
from ...utils.misc import add_docstring, get_record_list_recursive
from ...utils.utils_interval import generalized_intervals_intersection
from ..base import DEFAULT_FIG_SIZE_PER_SEC, DataBaseInfo, PhysioNetDataBase

# Public API of this module: only the database reader class is exported.
__all__ = [
    "AFDB",
]


# Static, human-readable metadata of the MIT-BIH Atrial Fibrillation Database.
# It is rendered via `format_database_docstring()` and prepended to the
# docstring of the `AFDB` class, and is also returned by `AFDB.database_info`.
_AFDB_INFO = DataBaseInfo(
    title="""
    MIT-BIH Atrial Fibrillation Database
    """,
    about="""
    1. contains 25 long-term (each 10 hours) ECG recordings of human subjects with atrial fibrillation (mostly paroxysmal)
    2. 23 records out of 25 include the two ECG signals, the left 2 records 00735 and 03665 are represented only by the rhythm (.atr) and unaudited beat (.qrs) annotation files
    3. signals are sampled at 250 samples per second with 12-bit resolution over a range of ±10 millivolts, with a typical recording bandwidth of approximately 0.1 Hz to 40 Hz
    4. 4 classes of rhythms are annotated:

        - AFIB:  atrial fibrillation
        - AFL:   atrial flutter
        - J:     AV junctional rhythm
        - N:     all other rhythms

    5. rhythm annotations almost all start with "(N", except for 4 which start with '(AFIB', which are all within 1 second (250 samples)
    6. Webpage of the database on PhysioNet [1]_. Paper describing the database [2]_.
    """,
    note="""
    1. beat annotation files (.qrs files) were prepared using an automated detector and have NOT been corrected manually
    2. for some records, manually corrected beat annotation files (.qrsc files) are available
    3. one should never use wfdb.rdann with arguments `sampfrom`, since one has to know the `aux_note` (with values in ["(N", "(J", "(AFL", "(AFIB"]) before the index at `sampfrom`
    """,
    usage=[
        "Atrial fibrillation (AF) detection",
    ],
    # The entries below are presumably the targets of the [1]_ and [2]_
    # citations used in `about` (in this order) -- TODO confirm against
    # `DataBaseInfo.format_database_docstring`.
    references=[
        "https://physionet.org/content/afdb/",
        "Moody GB, Mark RG. A new method for detecting atrial fibrillation using R-R intervals. Computers in Cardiology. 10:227-230 (1983).",
    ],
    doi=[
        "10.13026/C2MW2D",
    ],
)


@add_docstring(_AFDB_INFO.format_database_docstring(), mode="prepend")
class AFDB(PhysioNetDataBase):
    """
    Parameters
    ----------
    db_dir : `path-like`, optional
        Storage path of the database.
        If not specified, data will be fetched from Physionet.
    working_dir : `path-like`, optional
        Working directory, to store intermediate files and log files.
    verbose : int, default 1
        Level of logging verbosity.
    kwargs : dict, optional
        Auxiliary key word arguments.
        The key ``palette``, if given, overrides the default mapping
        from annotation names to the colors used by :meth:`plot`.

    """

    __name__ = "AFDB"

    def __init__(
        self,
        db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            db_name="afdb",
            db_dir=db_dir,
            working_dir=working_dir,
            verbose=verbose,
            **kwargs,
        )
        self.fs = 250
        self.data_ext = "dat"
        self.ann_ext = "atr"
        self.auto_beat_ann_ext = "qrs"
        self.manual_beat_ann_ext = "qrsc"

        self.all_leads = [
            "ECG1",
            "ECG2",
        ]

        # records excluded from `_all_records` by `_ls_rec`
        self.special_records = ["00735", "03665"]
        # names of the records having a manual beat annotation (.qrsc) file,
        # filled in by `_ls_rec`
        self.qrsc_records = None
        self._ls_rec()

        # integer labels used by `load_ann` with `ann_format="mask"`;
        # "N" (0) is also the fill value of samples not covered by any interval
        self.class_map = CFG(AFIB=1, AFL=2, J=3, N=0)
        self.palette = kwargs.get("palette", None)
        if self.palette is None:
            self.palette = CFG(
                AFIB="blue",
                AFL="red",
                J="yellow",
                # N="green",
                qrs="green",
            )

    def _ls_rec(self, local: bool = True) -> None:
        """
        Find all records (relative path without file extension),
        and save into `self._all_records` for further use.

        Parameters
        ----------
        local : bool, default True
            If True, read from local storage,
            prior to using :func:`wfdb.get_record_list`.

        """
        super()._ls_rec(local=local)
        self._all_records = [rec for rec in self._all_records if rec not in self.special_records]
        qrsc_files = get_record_list_recursive(self.db_dir, self.manual_beat_ann_ext, relative=False)
        # keep only the record names (file stems)
        self.qrsc_records = [Path(rec).stem for rec in qrsc_files]

    def load_ann(
        self,
        rec: Union[str, int],
        sampfrom: Optional[int] = None,
        sampto: Optional[int] = None,
        ann_format: Literal["intervals", "mask"] = "intervals",
        keep_original: bool = False,
    ) -> Union[Dict[str, list], np.ndarray]:
        """Load rhythm annotations from the annotation (.atr) files.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        sampfrom : int, optional
            Start index of the annotations to be loaded.
        sampto: int, optional
            End index of the annotations to be loaded.
        ann_format : {"intervals", "mask"}, default "intervals"
            Format of returned annotation, case insensitive.
            Any value other than "mask" is treated as "intervals".
        keep_original : bool, default False
            If True, when `ann_format` is "intervals",
            intervals (in the form [a,b]) will keep the same with the annotation file,
            otherwise subtract `sampfrom` if specified.

        Returns
        -------
        ann : dict or numpy.ndarray
            The annotations in the format of intervals,
            or in the format of masks.

        """
        if isinstance(rec, int):
            rec = self[rec]
        fp = str(self.get_absolute_path(rec))
        # the whole annotation file is always read (no `sampfrom`), since the
        # rhythm in effect at `sampfrom` is given by the preceding `aux_note`
        wfdb_ann = wfdb.rdann(fp, extension=self.ann_ext)
        header = wfdb.rdheader(fp)
        sig_len = header.sig_len
        sf = sampfrom or 0
        st = sampto or sig_len
        assert st > sf, "`sampto` should be greater than `sampfrom`!"

        ann = CFG({k: [] for k in self.class_map.keys()})
        critical_points = wfdb_ann.sample.tolist() + [sig_len]
        # work on a copy so that the object returned by wfdb is left untouched
        aux_note = list(wfdb_ann.aux_note)
        if aux_note and aux_note[0] == "(N":
            # the first rhythm is extended to the beginning of the record
            critical_points[0] = 0
        else:
            # the record starts with another rhythm (or has no rhythm
            # annotation at all): prepend a leading "N" segment starting at 0
            critical_points.insert(0, 0)
            aux_note.insert(0, "(N")

        # the rhythm at position idx lasts until the next critical point
        for idx, rhythm in enumerate(aux_note):
            ann[rhythm.replace("(", "")].append([critical_points[idx], critical_points[idx + 1]])

        # restrict the intervals to [sf, st]
        ann = CFG({k: generalized_intervals_intersection(l_itv, [[sf, st]]) for k, l_itv in ann.items()})

        if ann_format.lower() == "mask":
            intervals = ann
            ann = np.full(shape=(st - sf,), fill_value=self.class_map.N, dtype=int)
            for rhythm, l_itv in intervals.items():
                for itv in l_itv:
                    ann[itv[0] - sf : itv[1] - sf] = self.class_map[rhythm]
        elif not keep_original:
            for k, l_itv in ann.items():
                ann[k] = [[itv[0] - sf, itv[1] - sf] for itv in l_itv]

        return ann

    def load_beat_ann(
        self,
        rec: Union[str, int],
        sampfrom: Optional[int] = None,
        sampto: Optional[int] = None,
        use_manual: bool = True,
        keep_original: bool = False,
    ) -> np.ndarray:
        """Load beat annotations from corresponding annotation files.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        sampfrom : int, optional
            Start index of the annotations to be loaded.
        sampto : int, optional
            End index of the annotations to be loaded.
        use_manual : bool, default True
            If True, use manually corrected beat annotations (.qrsc files)
            when available for the record, otherwise (or if False) use
            the automatically generated ones (.qrs files).
        keep_original : bool, default False
            If True, indices will keep the same with the annotation file,
            otherwise subtract `sampfrom` if specified.

        Returns
        -------
        ann : numpy.ndarray
            Locations (indices) of the qrs complexes.

        """
        if isinstance(rec, int):
            rec = self[rec]
        fp = str(self.get_absolute_path(rec))
        if use_manual and rec in self.qrsc_records:
            ext = self.manual_beat_ann_ext
        else:
            ext = self.auto_beat_ann_ext
        ann = wfdb.rdann(
            fp,
            extension=ext,
            sampfrom=sampfrom or 0,
            sampto=sampto,
        )
        ann = ann.sample
        if not keep_original and sampfrom is not None:
            ann -= sampfrom
        return ann

    @add_docstring(load_beat_ann.__doc__)
    def load_rpeak_indices(
        self,
        rec: Union[str, int],
        sampfrom: Optional[int] = None,
        sampto: Optional[int] = None,
        use_manual: bool = True,
        keep_original: bool = False,
    ) -> np.ndarray:
        """
        alias of `self.load_beat_ann`
        """
        return self.load_beat_ann(rec, sampfrom, sampto, use_manual, keep_original)

    def plot(
        self,
        rec: Union[str, int],
        data: Optional[np.ndarray] = None,
        ann: Optional[Dict[str, np.ndarray]] = None,
        rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
        ticks_granularity: int = 0,
        leads: Optional[Union[str, int, List[str], List[int]]] = None,
        sampfrom: Optional[int] = None,
        sampto: Optional[int] = None,
        same_range: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Plot the signals of a record or external signals (units in μV),
        with metadata (fs, labels, tranche, etc.),
        possibly also along with wave delineations.

        The signal is plotted in consecutive figures of at most 25 seconds each.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        data : numpy.ndarray, optional
            (2-lead) ECG signal to plot,
            should be of the format "channel_first",
            and compatible with `leads`.
            If given, data of `rec` will not be used,
            which is useful when plotting filtered data.
        ann : dict, optional
            Annotations to plot,
            in the form of ``{"AFIB":l_itv, "AFL":l_itv, "J":l_itv, "N":l_itv}``,
            where ``l_itv`` in the form of ``[[a, b], ...]``,
            with indices counted from the first plotted sample.
            If both `ann` and `data` are None,
            annotations are loaded from the annotation file;
            if `ann` is None (or empty) but `data` is given,
            no rhythm annotation is plotted.
        rpeak_inds : array_like, optional
            Indices of R peaks to plot,
            counted from the first plotted sample
            (they are NOT shifted by `sampfrom`).
            If both `rpeak_inds` and `data` are None,
            R peaks are loaded from the beat annotation file;
            if `rpeak_inds` is None but `data` is given,
            no R peak is plotted.
        ticks_granularity : int, default 0
            Granularity to plot axis ticks, the higher the more ticks.
            0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks).
        leads : str or int or List[str] or List[int], optional
            The leads to plot.
        sampfrom : int, optional
            Start index of the data to plot.
        sampto : int, optional
            End index of the data to plot.
        same_range : bool, default False
            If True, forces all leads to have the same y range.
        kwargs : dict, optional
            Keyword arguments for :func:`matplotlib.pyplot.plot`, etc.
            Currently not used by this method.

        """
        if isinstance(rec, int):
            rec = self[rec]

        # imported lazily so that matplotlib is only needed for plotting
        import matplotlib.pyplot as plt

        # NOTE that `Locator` has default `MAXTICKS` equal to 1000
        plt.MultipleLocator.MAXTICKS = 3000

        if leads is None or leads == "all":
            _leads = self.all_leads
        elif isinstance(leads, str):
            _leads = [leads]
        elif isinstance(leads, int):
            _leads = [self.all_leads[leads]]
        else:
            _leads = leads
        assert all([ld in self.all_leads for ld in _leads]) or set(_leads) <= {0, 1}

        if data is None:
            _data = self.load_data(
                rec,
                leads=_leads,
                sampfrom=sampfrom,
                sampto=sampto,
                data_format="channel_first",
                units="μV",
            )
        else:
            units = self._auto_infer_units(data)
            self.logger.info(f"input data is auto detected to have units in `{units}`")
            if units.lower() == "mv":
                _data = 1000 * data
            else:
                _data = data
            _leads = [f"ECG_{idx}" for idx in range(_data.shape[0])]

        if ann is None and data is None:
            _ann = self.load_ann(
                rec,
                sampfrom=sampfrom,
                sampto=sampto,
                ann_format="intervals",
                keep_original=False,
            )
        else:
            _ann = ann or CFG({k: [] for k in self.class_map.keys()})
        # indices to time
        _ann = {k: [[itv[0] / self.fs, itv[1] / self.fs] for itv in l_itv] for k, l_itv in _ann.items()}

        if rpeak_inds is None and data is None:
            _rpeak = self.load_rpeak_indices(
                rec,
                sampfrom=sampfrom,
                sampto=sampto,
                use_manual=True,
                keep_original=False,
            )
            _rpeak = _rpeak / self.fs  # indices to time
        else:
            # explicit `None` check: the truth value of a multi-element
            # numpy array is ambiguous, so `rpeak_inds or []` would raise
            _rpeak = np.array([] if rpeak_inds is None else rpeak_inds) / self.fs  # indices to time

        ann_plot_alpha = 0.2
        rpeaks_plot_alpha = 0.8

        nb_leads = len(_leads)

        line_len = self.fs * 25  # 25 seconds
        nb_lines = math.ceil(_data.shape[1] / line_len)

        for seg_idx in range(nb_lines):
            seg_data = _data[..., seg_idx * line_len : (seg_idx + 1) * line_len]
            secs = (np.arange(seg_data.shape[1]) + seg_idx * line_len) / self.fs
            seg_ann = {k: generalized_intervals_intersection(l_itv, [[secs[0], secs[-1]]]) for k, l_itv in _ann.items()}
            seg_rpeaks = _rpeak[np.where((_rpeak >= secs[0]) & (_rpeak < secs[-1]))[0]]
            # at least 1, since a very short (trailing) segment would otherwise
            # be rounded to a width of 0, which is not a valid figure size
            fig_sz_w = max(1, int(round(DEFAULT_FIG_SIZE_PER_SEC * seg_data.shape[1] / self.fs)))
            if same_range:
                y_ranges = np.ones((seg_data.shape[0],)) * np.max(np.abs(seg_data)) + 100
            else:
                y_ranges = np.max(np.abs(seg_data), axis=1) + 100
            fig_sz_h = 6 * y_ranges / 1500
            fig, axes = plt.subplots(nb_leads, 1, sharex=True, figsize=(fig_sz_w, np.sum(fig_sz_h)))
            if nb_leads == 1:
                axes = [axes]
            for idx in range(nb_leads):
                axes[idx].plot(secs, seg_data[idx], color="black", label=f"lead - {_leads[idx]}")
                axes[idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red")
                if ticks_granularity >= 1:
                    axes[idx].xaxis.set_major_locator(plt.MultipleLocator(0.2))
                    axes[idx].yaxis.set_major_locator(plt.MultipleLocator(500))
                    axes[idx].grid(which="major", linestyle="-", linewidth="0.5", color="red")
                if ticks_granularity >= 2:
                    axes[idx].xaxis.set_minor_locator(plt.MultipleLocator(0.04))
                    axes[idx].yaxis.set_minor_locator(plt.MultipleLocator(100))
                    axes[idx].grid(which="minor", linestyle=":", linewidth="0.5", color="black")
                for k, l_itv in seg_ann.items():
                    # "N" has no color in the default palette and is not highlighted
                    if k == "N":
                        continue
                    for itv in l_itv:
                        axes[idx].axvspan(
                            itv[0],
                            itv[1],
                            color=self.palette[k],
                            alpha=ann_plot_alpha,
                            label=k,
                        )
                for ri in seg_rpeaks:
                    # a narrow, more opaque band at the peak ...
                    axes[idx].axvspan(
                        ri - 0.01,
                        ri + 0.01,
                        color=self.palette["qrs"],
                        alpha=rpeaks_plot_alpha,
                    )
                    # ... inside a wider, lighter band around it
                    axes[idx].axvspan(
                        ri - 0.075,
                        ri + 0.075,
                        color=self.palette["qrs"],
                        alpha=ann_plot_alpha,
                    )
                axes[idx].legend(loc="upper left")
                axes[idx].set_xlim(secs[0], secs[-1])
                axes[idx].set_ylim(-y_ranges[idx], y_ranges[idx])
                axes[idx].set_xlabel("Time [s]")
                axes[idx].set_ylabel("Voltage [μV]")
            plt.subplots_adjust(hspace=0.2)
            plt.show()

    @property
    def database_info(self) -> DataBaseInfo:
        return _AFDB_INFO