Source code for torch_ecg.databases.physionet_databases.cinc2017

# -*- coding: utf-8 -*-

import math
import os
import time
import warnings
from pathlib import Path
from typing import Any, List, Optional, Sequence, Union

import numpy as np
import pandas as pd

from ...cfg import DEFAULTS
from ...utils.misc import add_docstring, get_record_list_recursive3
from ..base import DEFAULT_FIG_SIZE_PER_SEC, DataBaseInfo, PhysioNetDataBase

__all__ = [
    "CINC2017",
]


_CINC2017_INFO = DataBaseInfo(
    title="""
    AF Classification from a Short Single Lead ECG Recording
    -- The PhysioNet Computing in Cardiology Challenge 2017
    """,
    about="""
    1. training set contains 8,528 single lead ECG recordings lasting from 9 s to just over 60 s, and the test set contains 3,658 ECG recordings of similar lengths
    2. records are of frequency 300 Hz and have been band pass filtered
    3. data distribution:

        +------------------+--------------+-------------------------------------+
        |                  |              |           Time length (s)           |
        |  Type            | # recording  +------+------+------+--------+-------+
        |                  |              | Mean | SD   | Max  | Median | Min   |
        +==================+==============+======+======+======+========+=======+
        | Normal           | 5154         | 31.9 | 10.0 | 61.0 | 30     | 9.0   |
        +------------------+--------------+------+------+------+--------+-------+
        | AF               | 771          | 31.6 | 12.5 | 60   | 30     | 10.0  |
        +------------------+--------------+------+------+------+--------+-------+
        | Other rhythm     | 2557         | 34.1 | 11.8 | 60.9 | 30     | 9.1   |
        +------------------+--------------+------+------+------+--------+-------+
        | Noisy            | 46           | 27.1 | 9.0  | 60   | 30     | 10.2  |
        +------------------+--------------+------+------+------+--------+-------+
        | Total            | 8528         | 32.5 | 10.9 | 61.0 | 30     | 9.0   |
        +------------------+--------------+------+------+------+--------+-------+

    4. Webpage of the database on PhysioNet [1]_.
    """,
    usage=[
        "Atrial fibrillation (AF) detection",
    ],
    references=[
        "https://physionet.org/content/challenge-2017/",
    ],
    doi=[
        "10.22489/CinC.2017.065-469",
    ],
)


[docs]@add_docstring(_CINC2017_INFO.format_database_docstring(), mode="prepend") class CINC2017(PhysioNetDataBase): """ Parameters ---------- db_dir : `path-like`, optional Storage path of the database. If not specified, data will be fetched from Physionet. working_dir : `path-like`, optional Working directory, to store intermediate files and log files. verbose : int, default 1 Level of logging verbosity. kwargs : dict, optional Auxilliary key word arguments. """ __name__ = "CINC2017" def __init__( self, db_dir: Optional[Union[str, bytes, os.PathLike]] = None, working_dir: Optional[Union[str, bytes, os.PathLike]] = None, verbose: int = 1, **kwargs: Any, ) -> None: super().__init__( db_name="challenge-2017", db_dir=db_dir, working_dir=working_dir, verbose=verbose, **kwargs, ) self.fs = 300 self.rec_ext = "mat" self.ann_ext = "hea" self._all_records = [] self._df_ann = pd.DataFrame() self._df_ann_ori = pd.DataFrame() self._all_ann = [] self._ls_rec() self.d_ann_names = { "N": "Normal rhythm", "A": "AF rhythm", "O": "Other rhythm", "~": "Noisy", } self.palette = { "N": "green", "A": "red", "O": "yellow", "~": "blue", } # self._url_compressed = ( # "https://physionet.org/static/published-projects/challenge-2017/" # "af-classification-from-a-short-single-lead-ecg-recording-" # "the-physionetcomputing-in-cardiology-challenge-2017-1.0.0.zip" # ) self._url_compressed = self.get_file_download_url("training2017.zip") def _ls_rec(self) -> None: """Find all records in the database directory and store them (path, metadata, etc.) in some private attributes. """ record_list_fp = self.db_dir / "RECORDS" self._df_records = pd.DataFrame() if record_list_fp.is_file(): self._df_records["record"] = [item for item in record_list_fp.read_text().splitlines() if len(item) > 0] if len(self._df_records) > 0: if self._subsample is not None: size = min( len(self._df_records), max(1, int(round(self._subsample * len(self._df_records)))), ) self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) self._df_records["path"] = self._df_records["record"].apply(lambda x: (self.db_dir / x).resolve()) self._df_records = self._df_records[self._df_records["path"].apply(lambda x: x.is_file())] self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name) if len(self._df_records) == 0: self.logger.info( "Please wait patiently to let the reader find " "all records of the database from local storage..." ) start = time.time() self._df_records["path"] = get_record_list_recursive3( db_dir=str(self.db_dir), rec_patterns=f"A[\\d]{{5}}\\.{self.rec_ext}", relative=False, ) if self._subsample is not None: size = min( len(self._df_records), max(1, int(round(self._subsample * len(self._df_records)))), ) self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x)) self.logger.info(f"Done in {time.time() - start:.3f} seconds!") self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name) self._df_records.set_index("record", inplace=True) self._all_records = self._df_records.index.values.tolist() ann_file = list(self.db_dir.rglob("REFERENCE.csv")) if len(ann_file) > 0: self._df_ann = pd.read_csv(ann_file[0], header=None) self._df_ann.columns = ["rec", "ann"] else: self._df_ann = pd.DataFrame(columns=["rec", "ann"]) warnings.warn( "Cannot find the annotation file `REFERENCE.csv`!", RuntimeWarning, ) ann_file = list(self.db_dir.rglob("REFERENCE-original.csv")) if len(ann_file) > 0: self._df_ann_ori = pd.read_csv(ann_file[0], header=None) self._df_ann_ori.columns = ["rec", "ann"] else: self._df_ann_ori = pd.DataFrame(columns=["rec", "ann"]) warnings.warn( "Cannot find the annotation file `REFERENCE-original.csv`!", RuntimeWarning, ) # ["N", "A", "O", "~"] self._all_ann = list(set(self._df_ann.ann.unique().tolist() + self._df_ann_ori.ann.unique().tolist()))
[docs] def load_ann(self, rec: Union[str, int], original: bool = False, ann_format: str = "a") -> str: """Load the annotation of the record. Parameters ---------- rec : str or int Record name or index of the record in :attr:`all_records`. original : bool, default False If True, load annotations from the annotation file ``REFERENCE-original.csv``, otherwise from ``REFERENCE.csv``. ann_format : {"a", "f"}, optional Format of returned annotation, by default "a". - "a" - abbreviation - "f" - full name Returns ------- ann : str Annotation (label) of the record. """ if isinstance(rec, int): rec = self[rec] assert rec in self.all_records and ann_format.lower() in ["a", "f"] if original: df = self._df_ann_ori else: df = self._df_ann row = df[df.rec == rec].iloc[0] ann = row.ann if ann_format.lower() == "f": ann = self.d_ann_names[ann] return ann
[docs] def plot( self, rec: Union[str, int], data: Optional[np.ndarray] = None, ann: Optional[str] = None, ticks_granularity: int = 0, rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None, ) -> None: """Plot the ECG signal of the record. Parameters ---------- rec : str or int Record name or index of the record in :attr:`all_records`. data : numpy.ndarray, optional The ECG signal to plot. If not None, data of `rec` will not be used. This is useful when plotting filtered data. ann : dict, optional, Annotations for `data`, which is a dict with keys "SPB_indices", "PVC_indices", and with :class:`~numpy.ndarray` values. Ignored if `data` is None. ticks_granularity : int, default 0 Granularity to plot axis ticks, the higher the more ticks. 0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks) rpeak_inds : array_like, optional Array of indices of R peaks. Returns ------- None """ if isinstance(rec, int): rec = self[rec] if "plt" not in dir(): import matplotlib.pyplot as plt import matplotlib.patches as mpatches if data is None: _data = self.load_data( rec, units="μV", data_format="flat", ) else: units = self._auto_infer_units(data) if units == "mV": _data = data * 1000 elif units == "μV": _data = data.copy() if ann is None or data is None: ann = self.load_ann(rec, ann_format="a") ann_fullname = self.load_ann(rec, ann_format="f") else: ann_fullname = self.d_ann_names.get(ann, ann) patch = mpatches.Patch(color=self.palette.get(ann, "blue"), label=ann_fullname) if rpeak_inds is not None: rpeak_secs = np.array(rpeak_inds) / self.fs line_len = self.fs * 25 # 25 seconds nb_lines = math.ceil(len(_data) / line_len) for idx in range(nb_lines): seg = _data[idx * line_len : (idx + 1) * line_len] secs = (np.arange(len(seg)) + idx * line_len) / self.fs fig_sz_w = int(round(DEFAULT_FIG_SIZE_PER_SEC * len(seg) / self.fs)) y_range = np.max(np.abs(seg)) + 100 fig_sz_h = 6 * y_range / 1500 fig, ax = plt.subplots(figsize=(fig_sz_w, fig_sz_h)) ax.plot(secs, seg, color="black") ax.axhline(y=0, linestyle="-", linewidth="1.0", color="red") if ticks_granularity >= 1: ax.xaxis.set_major_locator(plt.MultipleLocator(0.2)) ax.yaxis.set_major_locator(plt.MultipleLocator(500)) ax.grid(which="major", linestyle="-", linewidth="0.5", color="red") if ticks_granularity >= 2: ax.xaxis.set_minor_locator(plt.MultipleLocator(0.04)) ax.yaxis.set_minor_locator(plt.MultipleLocator(100)) ax.grid(which="minor", linestyle=":", linewidth="0.5", color="black") ax.legend(handles=[patch], loc="lower left", prop={"size": 16}) if rpeak_inds is not None: for r in rpeak_secs: ax.axvspan(r - 0.01, r + 0.01, color="green", alpha=0.7) ax.set_xlim(secs[0], secs[-1]) ax.set_ylim(-y_range, y_range) ax.set_xlabel("Time [s]") ax.set_ylabel("Voltage [μV]") plt.show()
@property def _validation_set(self) -> List[str]: """The validation set specified at https://physionet.org/content/challenge-2017/1.0.0/ """ return ( "A00001,A00002,A00003,A00004,A00005,A00006,A00007,A00008,A00009,A00010," "A00011,A00012,A00013,A00014,A00015,A00016,A00017,A00018,A00019,A00020," "A00021,A00022,A00023,A00024,A00025,A00026,A00027,A00028,A00029,A00030," "A00031,A00032,A00033,A00034,A00035,A00036,A00037,A00038,A00039,A00040," "A00041,A00042,A00043,A00044,A00045,A00046,A00047,A00048,A00049,A00050," "A00051,A00052,A00053,A00054,A00055,A00056,A00057,A00058,A00059,A00060," "A00061,A00062,A00063,A00064,A00065,A00066,A00067,A00068,A00069,A00070," "A00071,A00072,A00073,A00074,A00075,A00076,A00077,A00078,A00079,A00080," "A00081,A00082,A00083,A00084,A00085,A00086,A00087,A00088,A00089,A00090," "A00091,A00092,A00093,A00094,A00095,A00096,A00097,A00098,A00099,A00100," "A00101,A00102,A00103,A00104,A00105,A00106,A00107,A00108,A00109,A00110," "A00111,A00112,A00113,A00114,A00115,A00116,A00117,A00118,A00119,A00120," "A00121,A00122,A00123,A00124,A00125,A00126,A00127,A00128,A00129,A00130," "A00131,A00132,A00133,A00134,A00135,A00136,A00137,A00138,A00139,A00140," "A00141,A00142,A00143,A00144,A00145,A00146,A00147,A00148,A00149,A00150," "A00151,A00152,A00153,A00154,A00155,A00156,A00157,A00158,A00159,A00160," "A00161,A00162,A00163,A00164,A00165,A00166,A00167,A00168,A00169,A00170," "A00171,A00172,A00173,A00174,A00175,A00176,A00177,A00178,A00179,A00180," "A00181,A00182,A00183,A00184,A00185,A00186,A00187,A00188,A00189,A00190," "A00191,A00192,A00193,A00194,A00195,A00196,A00197,A00198,A00199,A00200," "A00201,A00202,A00203,A00204,A00205,A00206,A00207,A00208,A00209,A00210," "A00211,A00212,A00213,A00214,A00215,A00216,A00217,A00218,A00219,A00220," "A00221,A00222,A00223,A00224,A00225,A00226,A00227,A00228,A00229,A00230," "A00231,A00232,A00233,A00234,A00235,A00236,A00237,A00238,A00239,A00240," "A00241,A00242,A00244,A00245,A00247,A00248,A00249,A00253,A00267,A00271," "A00301,A00321,A00375,A00395,A00397,A00405,A00422,A00432,A00438,A00439," "A00441,A00456,A00465,A00473,A00486,A00509,A00519,A00520,A00524,A00542," "A00551,A00585,A01006,A01070,A01246,A01299,A01521,A01567,A01707,A01727," "A01772,A01833,A02168,A02372,A02772,A02785,A02833,A03549,A03738,A04086," "A04137,A04170,A04186,A04216,A04282,A04452,A04522,A04701,A04735,A04805" ).split(",") @property def database_info(self) -> DataBaseInfo: return _CINC2017_INFO