Source code for torch_ecg.databases.other_databases.sph

# -*- coding: utf-8 -*-

import os
import time
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

import h5py
import numpy as np
import pandas as pd

from ...cfg import DEFAULTS
from ...utils import EAK
from ...utils.download import http_get
from ...utils.misc import add_docstring, get_record_list_recursive3, ms2samples
from ..base import DEFAULT_FIG_SIZE_PER_SEC, DataBaseInfo, _DataBase, _PlotCfg

__all__ = [
    "SPH",
]


_SPH_INFO = DataBaseInfo(
    title="""
    Shandong Provincial Hospital Database
    """,
    about=r"""
    1. contains 25770 ECG records from 24666 patients (55.36% male and 44.64% female), with durations ranging from 10 to 60 seconds
    2. sampling frequency is 500 Hz
    3. records were acquired from Shandong Provincial Hospital (SPH) between 2019/08 and 2020/08
    4. diagnostic statements of all ECG records are in full compliance with the AHA/ACC/HRS recommendations, consisting of 44 primary statements and 15 modifiers
    5. 46.04% of the records contain ECG abnormalities, and 14.45% of the records have multiple diagnostic statements
    6. (IMPORTANT) noise caused by power line interference, baseline wander, and muscle contraction has been removed by the machine
    7. (Label production) The ECG analysis system automatically calculates nine ECG features for reference: heart rate, P wave duration, P-R interval, QRS duration, QT interval, corrected QT (QTc) interval, QRS axis, the amplitude of the R wave in lead V5 (RV5), and the amplitude of the S wave in lead V1 (SV1). A cardiologist made the final diagnosis in consideration of the patient's health record.
    8. The database is described in [1]_ and [2]_. Data can be downloaded from [3]_. The annotation standard is described in [4]_.
    """,
    usage=[
        "ECG arrhythmia detection",
    ],
    references=[
        "https://www.nature.com/articles/s41597-022-01403-5",
        "Liu, H., Chen, D., Chen, D. et al. A large-scale multi-label 12-lead electrocardiogram database with standardized diagnostic statements. Sci Data 9, 272 (2022). https://doi.org/10.1038/s41597-022-01403-5",
        "https://springernature.figshare.com/collections/A_large-scale_multi-label_12-lead_electrocardiogram_database_with_standardized_diagnostic_statements/5779802/1",
        "Mason, J. W., Hancock, E. W. & Gettes, L. S. Recommendations for the standardization and interpretation of the electrocardiogram. Circulation 115, 1325–1332 (2007).",
    ],
    doi=[
        "10.1038/s41597-022-01403-5",
        "10.6084/m9.figshare.c.5779802.v1",
    ],
)


@add_docstring(_SPH_INFO.format_database_docstring(), mode="prepend")
class SPH(_DataBase):
    """
    Parameters
    ----------
    db_dir : `path-like`, optional
        Storage path of the database.
        If not specified, data will be fetched from the figshare website.
    working_dir : `path-like`, optional
        Working directory, to store intermediate files and log files.
    verbose : int, default 1
        Level of logging verbosity.
    kwargs : dict, optional
        Auxiliary keyword arguments.

    """

    __name__ = "SPH"

    def __init__(
        self,
        db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            db_name="SPH",
            db_dir=db_dir,
            working_dir=working_dir,
            verbose=verbose,
            **kwargs,
        )
        self.data_ext = "h5"
        self.ann_ext = None
        self.header_ext = None
        self.fs = 500
        self.all_leads = deepcopy(EAK.Standard12Leads)

        self._version = "v1"

        self._df_code = None
        self._df_metadata = None
        self._all_records = None
        self._ls_rec()

    def _ls_rec(self) -> None:
        """Find all records in the database directory
        and store them (path, metadata, etc.) in some private attributes.
        """
        record_list_fp = self.db_dir / "RECORDS"
        self._df_records = pd.DataFrame()
        write_file = False
        if record_list_fp.is_file():
            self._df_records["record"] = [item for item in record_list_fp.read_text().splitlines() if len(item) > 0]
            if self._subsample is not None:
                size = min(
                    len(self._df_records),
                    max(1, int(round(self._subsample * len(self._df_records)))),
                )
                self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
            self._df_records["path"] = self._df_records["record"].apply(lambda x: (self.db_dir / x).resolve())
            self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name)
            indices = self._df_records["path"].apply(lambda x: x.is_file())
            self._df_records = self._df_records.loc[indices]
        if len(self._df_records) == 0:
            write_file = True
            self.logger.info(
                "Please wait patiently to let the reader find " "all records of the database from local storage..."
            )
            start = time.time()
            record_pattern = "A[\\d]{5}\\.h5"
            self._df_records["path"] = get_record_list_recursive3(self.db_dir, record_pattern, relative=False)
            if self._subsample is not None:
                size = min(
                    len(self._df_records),
                    max(1, int(round(self._subsample * len(self._df_records)))),
                )
                self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
            self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x))
            self.logger.info(f"Done in {time.time() - start:.3f} seconds!")
            self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name)
        self._df_records.set_index("record", inplace=True)
        self._all_records = self._df_records.index.values.tolist()
        if write_file and self._subsample is None:
            record_list_fp.write_text(
                "\n".join(self._df_records["path"].apply(lambda x: x.relative_to(self.db_dir).as_posix()).tolist())
            )

        if (self.db_dir / "code.csv").is_file():
            self._df_code = pd.read_csv(self.db_dir / "code.csv").astype(str)
        if (self.db_dir / "metadata.csv").is_file():
            self._df_metadata = pd.read_csv(self.db_dir / "metadata.csv")
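
    # Usage sketch (illustrative only): the local path is hypothetical, and the
    # database files are assumed to be already downloaded under ``db_dir``.
    # Indexing with an integer returns the record name, as used internally
    # via ``self[rec]`` throughout this class.
    #
    #     >>> db = SPH(db_dir="/data/SPH")  # hypothetical path
    #     >>> rec = db[0]  # record name at index 0 of `all_records`, e.g. "A00001"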
    def get_subject_id(self, rec: Union[str, int]) -> str:
        """Attach a unique subject ID for the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        sid : str
            Subject ID associated with the record.

        """
        if isinstance(rec, int):
            rec = self[rec]
        sid = self._df_metadata.loc[self._df_metadata["ECG_ID"] == rec]["Patient_ID"].iloc[0]
        return sid
    def load_data(
        self,
        rec: Union[str, int],
        leads: Optional[Union[str, int, List[Union[str, int]]]] = None,
        data_format: str = "channel_first",
        units: str = "mV",
        return_fs: bool = False,
    ) -> np.ndarray:
        """Load ECG data from the h5 file of the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        leads : str or int or List[str] or List[int], optional
            The leads of the ECG data to be loaded.
        data_format : str, default "channel_first"
            Format of the ECG data,
            "channel_last" (alias "lead_last"), or
            "channel_first" (alias "lead_first").
        units : str, default "mV"
            Units of the output signal,
            can also be "μV" (alias "uV", "muV").
        return_fs : bool, default False
            Whether to return the sampling frequency of the output signal.

        Returns
        -------
        data : numpy.ndarray
            The loaded ECG data.
        data_fs : numbers.Real, optional
            Sampling frequency of the output signal.
            Returned if `return_fs` is True.

        """
        assert data_format.lower() in [
            "channel_first",
            "lead_first",
            "channel_last",
            "lead_last",
        ], f"Invalid data_format: `{data_format}`"
        _leads = self._normalize_leads(leads, numeric=True)

        with h5py.File(self.get_absolute_path(rec, extension=self.data_ext), "r") as f:
            data = f["ecg"][_leads].astype(DEFAULTS.DTYPE.NP)

        if units.lower() in ["uv", "μv", "muv"]:
            data = data * 1000
        elif units.lower() != "mv":
            raise AssertionError(f"Invalid units: `{units}`")

        if data_format.lower() in ["channel_last", "lead_last"]:
            data = data.T

        if return_fs:
            return data, self.fs
        return data
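
    # Sketch of ``load_data`` usage (the record name "A00001" is hypothetical):
    #
    #     >>> data = db.load_data("A00001", leads=["II", "V5"], units="uV")
    #     >>> # "channel_first" by default, so `data` has shape (2, siglen)
    #     >>> data, fs = db.load_data("A00001", return_fs=True)  # fs == 500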
    def load_ann(self, rec: Union[str, int], ann_format: str = "c", ignore_modifier: bool = True) -> List[str]:
        """Load annotation from the metadata file.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        ann_format : str, default "c"
            Format of labels, one of the following (case insensitive):

            - "a": abbreviations
            - "f": full names
            - "c": AHA code

        ignore_modifier : bool, default True
            Whether to ignore the modifiers of the annotations.
            For example, "60+310" will be converted to "60".

        Returns
        -------
        labels : List[str]
            The list of labels.

        """
        if isinstance(rec, int):
            rec = self[rec]
        labels = [lb.strip() for lb in self._df_metadata[self._df_metadata["ECG_ID"] == rec]["AHA_Code"].iloc[0].split(";")]
        modifiers = [lb.split("+")[1] if "+" in lb else "" for lb in labels]
        if ignore_modifier:
            labels = [lb.split("+")[0] for lb in labels]
        if ann_format.lower() == "c":
            pass  # default format
        elif ann_format.lower() == "f":
            labels = [self._df_code[self._df_code["Code"] == lb.split("+")[0]]["Description"].iloc[0] for lb in labels]
            if not ignore_modifier:
                labels = [
                    f"""{self._df_code[self._df_code["Code"] == m]["Description"].iloc[0]} {lb}""" if len(m) > 0 else lb
                    for lb, m in zip(labels, modifiers)
                ]
        elif ann_format.lower() == "a":
            raise NotImplementedError("Abbreviations are not supported yet")
        else:
            raise ValueError(f"Unknown annotation format: `{ann_format}`")
        return labels
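
    # Sketch of ``load_ann`` usage, reusing the "60+310" example from the
    # docstring (record name hypothetical):
    #
    #     >>> db.load_ann("A00001")  # AHA codes, modifiers stripped, e.g. ["60"]
    #     >>> db.load_ann("A00001", ignore_modifier=False)  # e.g. ["60+310"]
    #     >>> db.load_ann("A00001", ann_format="f")  # full diagnostic statements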
    def get_subject_info(self, rec_or_sid: Union[str, int], items: Optional[List[str]] = None) -> dict:
        """Read auxiliary information of a subject (a record) from the metadata file.

        Parameters
        ----------
        rec_or_sid : str or int
            Record name, or index of the record in :attr:`all_records`,
            or the subject ID.
        items : List[str], optional
            Items of information to be returned (e.g. age, sex, etc.).

        Returns
        -------
        subject_info : dict
            Information about the subject, including "age", "sex".

        """
        if isinstance(rec_or_sid, int):
            rec_or_sid = self[rec_or_sid]
            row = self._df_metadata[self._df_metadata["ECG_ID"] == rec_or_sid].iloc[0]
        else:
            if rec_or_sid.startswith("A"):
                row = self._df_metadata[self._df_metadata["ECG_ID"] == rec_or_sid].iloc[0]
            else:
                row = self._df_metadata[self._df_metadata["Patient_ID"] == rec_or_sid].iloc[0]
        if items is None or len(items) == 0:
            info_items = [
                "age",
                "sex",
            ]
        else:
            info_items = items
        subject_info = {item: row[item.capitalize()] for item in info_items}
        return subject_info
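
    # ``get_subject_info`` accepts a record index, an ECG ID (starting with
    # "A"), or a Patient_ID; a sketch with hypothetical IDs:
    #
    #     >>> db.get_subject_info("A00001")  # {"age": ..., "sex": ...}
    #     >>> db.get_subject_info("A00001", items=["age"])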
    def get_age(self, rec: Union[str, int]) -> int:
        """Get the age of the subject that the record belongs to.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        age : int
            Age of the subject.

        """
        if isinstance(rec, int):
            rec = self[rec]
        age = self._df_metadata[self._df_metadata["ECG_ID"] == rec]["Age"].iloc[0].item()
        return age
    def get_sex(self, rec: Union[str, int]) -> str:
        """Get the sex of the subject that the record belongs to.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        sex : str
            Sex of the subject.

        """
        if isinstance(rec, int):
            rec = self[rec]
        sex = self._df_metadata[self._df_metadata["ECG_ID"] == rec]["Sex"].iloc[0]
        return sex
    def get_siglen(self, rec: Union[str, int]) -> int:
        """Get the length of the ECG signal of the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        siglen : int
            Length of the ECG signal of the record.

        """
        if isinstance(rec, int):
            rec = self[rec]
        siglen = self._df_metadata[self._df_metadata["ECG_ID"] == rec]["N"].iloc[0].item()
        return siglen
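
    # Sketch of the per-record metadata accessors (record name and returned
    # values hypothetical):
    #
    #     >>> db.get_age("A00001")  # e.g. 65
    #     >>> db.get_sex("A00001")  # e.g. "M"
    #     >>> db.get_siglen("A00001")  # number of samples; divide by 500 for seconds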
    @property
    def url(self) -> Dict[str, str]:
        return {
            "metadata.csv": "https://springernature.figshare.com/ndownloader/files/34793152",
            "code.csv": "https://springernature.figshare.com/ndownloader/files/32630954",
            "records.tar": "https://springernature.figshare.com/ndownloader/files/32630684",
        }
    def download(self, files: Optional[Union[str, Sequence[str]]] = None) -> None:
        """Download the database from the figshare website.

        Parameters
        ----------
        files : str or Sequence[str], optional
            Names of the files to download.
            If None, all files in :attr:`url` are downloaded.

        """
        if files is None:
            files = self.url.keys()
        if isinstance(files, str):
            files = [files]
        assert set(files).issubset(self.url), f"`files` should be a subset of {list(self.url)}"
        for filename in files:
            url = self.url[filename]
            if not (self.db_dir / filename).is_file():
                http_get(url, self.db_dir, filename=filename)
        self._ls_rec()
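
    # Download sketch (assumes network access to the figshare URLs above):
    #
    #     >>> db.download("metadata.csv")  # a single file
    #     >>> db.download(["metadata.csv", "code.csv"])  # a subset of files
    #     >>> db.download()  # all files listed in `url`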
    def plot(
        self,
        rec: Union[str, int],
        data: Optional[np.ndarray] = None,
        ann: Optional[Sequence[str]] = None,
        ticks_granularity: int = 0,
        leads: Optional[Union[str, int, List[Union[str, int]]]] = None,
        same_range: bool = False,
        waves: Optional[Dict[str, Sequence[int]]] = None,
        **kwargs: Any,
    ) -> None:
        """Plot the signals of a record or external signals (units in μV),
        with metadata (fs, labels, etc.),
        possibly also along with wave delineations.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        data : numpy.ndarray, optional
            (12-lead) ECG signal to plot.
            Should be of the format "channel_first", and compatible with `leads`.
            If not None, data of `rec` will not be used.
            This is useful when plotting filtered data.
        ann : Sequence[str], optional
            Annotations for `data`. Ignored if `data` is None.
        ticks_granularity : int, default 0
            Granularity to plot axis ticks, the higher the more ticks.
            0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks)
        leads : str or int or List[str] or List[int], optional
            The leads of the ECG signal to plot.
        same_range : bool, default False
            If True, forces all leads to have the same y range.
        waves : dict, optional
            A dictionary containing the indices of the wave critical points,
            including "p_onsets", "p_peaks", "p_offsets",
            "q_onsets", "q_peaks", "r_peaks", "s_peaks", "s_offsets",
            "t_onsets", "t_peaks", "t_offsets".
        kwargs : dict, optional
            Additional keyword arguments to pass to :func:`matplotlib.pyplot.plot`.

        TODO
        ----
        1. Slice too long records, and plot separately for each segment.
        2. Plot waves using :func:`matplotlib.pyplot.axvspan`.

        NOTE
        ----
        `Locator` of ``plt`` has default `MAXTICKS` of 1000.
        If this number is not modified,
        at most 40 seconds of signal can be plotted at once.

        Contributors: Jeethan, and WEN Hao

        """
        if isinstance(rec, int):
            rec = self[rec]
        if "plt" not in dir():
            import matplotlib.pyplot as plt

            plt.MultipleLocator.MAXTICKS = 3000
        _leads = self._normalize_leads(leads, numeric=False)
        lead_indices = [self.all_leads.index(ld) for ld in _leads]
        if data is None:
            _data = self.load_data(rec, data_format="channel_first", units="μV")[lead_indices]
        else:
            units = self._auto_infer_units(data)
            self.logger.info(f"input data is auto detected to have units in `{units}`")
            if units.lower() == "mv":
                _data = 1000 * data
            else:
                _data = data
        assert _data.shape[0] == len(_leads), (
            f"number of leads from data of shape ({_data.shape[0]}) does not " f"match the length ({len(_leads)}) of `leads`"
        )

        if same_range:
            y_ranges = np.ones((_data.shape[0],)) * np.max(np.abs(_data)) + 100
        else:
            y_ranges = np.max(np.abs(_data), axis=1) + 100

        if waves:
            if waves.get("p_onsets", None) and waves.get("p_offsets", None):
                p_waves = [[onset, offset] for onset, offset in zip(waves["p_onsets"], waves["p_offsets"])]
            elif waves.get("p_peaks", None):
                p_waves = [
                    [
                        max(0, p + ms2samples(_PlotCfg.p_onset, fs=self.fs)),
                        min(
                            _data.shape[1],
                            p + ms2samples(_PlotCfg.p_offset, fs=self.fs),
                        ),
                    ]
                    for p in waves["p_peaks"]
                ]
            else:
                p_waves = []
            if waves.get("q_onsets", None) and waves.get("s_offsets", None):
                qrs = [[onset, offset] for onset, offset in zip(waves["q_onsets"], waves["s_offsets"])]
            elif waves.get("q_peaks", None) and waves.get("s_peaks", None):
                qrs = [
                    [
                        max(0, q + ms2samples(_PlotCfg.q_onset, fs=self.fs)),
                        min(
                            _data.shape[1],
                            s + ms2samples(_PlotCfg.s_offset, fs=self.fs),
                        ),
                    ]
                    for q, s in zip(waves["q_peaks"], waves["s_peaks"])
                ]
            elif waves.get("r_peaks", None):
                # the QRS interval is centered at the R peak,
                # extending one `qrs_radius` to each side
                qrs = [
                    [
                        max(0, r - ms2samples(_PlotCfg.qrs_radius, fs=self.fs)),
                        min(
                            _data.shape[1],
                            r + ms2samples(_PlotCfg.qrs_radius, fs=self.fs),
                        ),
                    ]
                    for r in waves["r_peaks"]
                ]
            else:
                qrs = []
            if waves.get("t_onsets", None) and waves.get("t_offsets", None):
                t_waves = [[onset, offset] for onset, offset in zip(waves["t_onsets"], waves["t_offsets"])]
            elif waves.get("t_peaks", None):
                t_waves = [
                    [
                        max(0, t + ms2samples(_PlotCfg.t_onset, fs=self.fs)),
                        min(
                            _data.shape[1],
                            t + ms2samples(_PlotCfg.t_offset, fs=self.fs),
                        ),
                    ]
                    for t in waves["t_peaks"]
                ]
            else:
                t_waves = []
        else:
            p_waves, qrs, t_waves = [], [], []
        palette = {
            "p_waves": "green",
            "qrs": "red",
            "t_waves": "yellow",
        }
        plot_alpha = 0.4

        if ann is None or data is None:
            ann = self.load_ann(rec, ann_format="f")

        nb_leads = len(_leads)

        seg_len = self.fs * 25  # 25 seconds
        nb_segs = _data.shape[1] // seg_len  # currently unused, see TODO above

        t = np.arange(_data.shape[1]) / self.fs
        duration = len(t) / self.fs
        fig_sz_w = int(round(DEFAULT_FIG_SIZE_PER_SEC * duration))
        fig_sz_h = 6 * np.maximum(y_ranges, 750) / 1500
        fig, axes = plt.subplots(nb_leads, 1, sharex=False, figsize=(fig_sz_w, np.sum(fig_sz_h)))
        if nb_leads == 1:
            axes = [axes]
        for idx in range(nb_leads):
            axes[idx].plot(
                t,
                _data[idx],
                color="black",
                linewidth="2.0",
                label=f"lead - {_leads[idx]}",
            )
            axes[idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red")
            # NOTE that `Locator` has default `MAXTICKS` equal to 1000
            if ticks_granularity >= 1:
                axes[idx].xaxis.set_major_locator(plt.MultipleLocator(0.2))
                axes[idx].yaxis.set_major_locator(plt.MultipleLocator(500))
                axes[idx].grid(which="major", linestyle="-", linewidth="0.4", color="red")
            if ticks_granularity >= 2:
                axes[idx].xaxis.set_minor_locator(plt.MultipleLocator(0.04))
                axes[idx].yaxis.set_minor_locator(plt.MultipleLocator(100))
                axes[idx].grid(which="minor", linestyle=":", linewidth="0.2", color="gray")
            # add extra info. to legend
            # https://stackoverflow.com/questions/16826711/is-it-possible-to-add-a-string-as-a-legend-item-in-matplotlib
            axes[idx].plot([], [], " ", label=f"labels - {','.join(ann)}")
            for w in ["p_waves", "qrs", "t_waves"]:
                for itv in eval(w):
                    axes[idx].axvspan(t[itv[0]], t[itv[1]], color=palette[w], alpha=plot_alpha)
            axes[idx].legend(loc="upper left", fontsize=14)
            axes[idx].set_xlim(t[0], t[-1])
            axes[idx].set_ylim(min(-600, -y_ranges[idx]), max(600, y_ranges[idx]))
            axes[idx].set_xlabel("Time [s]", fontsize=16)
            axes[idx].set_ylabel("Voltage [μV]", fontsize=16)
        plt.subplots_adjust(hspace=0.05)
        fig.tight_layout()
        if kwargs.get("save_path", None):
            plt.savefig(kwargs["save_path"], dpi=200, bbox_inches="tight")
        else:
            plt.show()
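
    # Plot sketch (record name and output path hypothetical; requires matplotlib):
    #
    #     >>> db.plot("A00001", leads=["II"], ticks_granularity=2)
    #     >>> db.plot("A00001", save_path="./A00001.png")  # save instead of show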
    @property
    def database_info(self) -> DataBaseInfo:
        return _SPH_INFO