Source code for torch_ecg.databases.base

# -*- coding: utf-8 -*-
"""
Base classes for datasets from different sources:

    - PhysioNet
    - NSRR
    - CPSC
    - Other databases

Remarks
-------
1. For whole-dataset visualization: http://zzz.bwh.harvard.edu/luna/vignettes/dataplots/
2. Visualization using UMAP: http://zzz.bwh.harvard.edu/luna/vignettes/nsrr-umap/

"""

import logging
import os
import posixpath
import pprint
import re
import textwrap
import time
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from numbers import Real
from pathlib import Path
from string import punctuation
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import requests
import scipy.signal as SS
import wfdb
from pyedflib import EdfReader

from ..cfg import _DATA_CACHE, CFG, DEFAULTS
from ..utils import ecg_arrhythmia_knowledge as EAK  # noqa: F401
from ..utils.download import http_get
from ..utils.misc import CitationMixin, ReprMixin, dict_to_str, get_record_list_recursive, init_logger
from .aux_data import get_physionet_dbs

__all__ = [
    "WFDB_Beat_Annotations",
    "WFDB_Non_Beat_Annotations",
    "WFDB_Rhythm_Annotations",
    "PhysioNetDataBase",
    "NSRRDataBase",
    "CPSCDataBase",
    "DEFAULT_FIG_SIZE_PER_SEC",
    "BeatAnn",
    "DataBaseInfo",
    "PSGDataBaseMixin",
]


WFDB_Beat_Annotations = {
    "N": "Normal beat",
    "L": "Left bundle branch block beat",
    "R": "Right bundle branch block beat",
    "B": "Bundle branch block beat (unspecified)",
    "A": "Atrial premature beat",
    "a": "Aberrated atrial premature beat",
    "J": "Nodal (junctional) premature beat",
    "S": "Supraventricular premature or ectopic beat (atrial or nodal)",
    "V": "Premature ventricular contraction",
    "r": "R-on-T premature ventricular contraction",
    "F": "Fusion of ventricular and normal beat",
    "e": "Atrial escape beat",
    "j": "Nodal (junctional) escape beat",
    "n": "Supraventricular escape beat (atrial or nodal)",
    "E": "Ventricular escape beat",
    "/": "Paced beat",
    "f": "Fusion of paced and normal beat",
    "Q": "Unclassifiable beat",
    "?": "Beat not classified during learning",
}

WFDB_Non_Beat_Annotations = {
    "[": "Start of ventricular flutter/fibrillation",
    "!": "Ventricular flutter wave",
    "]": "End of ventricular flutter/fibrillation",
    "x": "Non-conducted P-wave (blocked APC)",
    "(": "Waveform onset",
    ")": "Waveform end",
    "p": "Peak of P-wave",
    "t": "Peak of T-wave",
    "u": "Peak of U-wave",
    "`": "PQ junction",
    "'": "J-point",
    "^": "(Non-captured) pacemaker artifact",
    "|": "Isolated QRS-like artifact",
    "~": "Change in signal quality",
    "+": "Rhythm change",
    "s": "ST segment change",
    "T": "T-wave change",
    "*": "Systole",
    "D": "Diastole",
    "=": "Measurement annotation",
    '"': "Comment annotation",
    "@": "Link to external data",
}

WFDB_Rhythm_Annotations = {
    "(AB": "Atrial bigeminy",
    "(AFIB": "Atrial fibrillation",
    "(AFL": "Atrial flutter",
    "(B": "Ventricular bigeminy",
    "(BII": "2° heart block",
    "(IVR": "Idioventricular rhythm",
    "(N": "Normal sinus rhythm",
    "(NOD": "Nodal (A-V junctional) rhythm",
    "(P": "Paced rhythm",
    "(PREX": "Pre-excitation (WPW)",
    "(SBR": "Sinus bradycardia",
    "(SVTA": "Supraventricular tachyarrhythmia",
    "(T": "Ventricular trigeminy",
    "(VFL": "Ventricular flutter",
    "(VT": "Ventricular tachycardia",
}


class _DataBase(ReprMixin, ABC):
    """Universal abstract base class for all databases.

    Parameters
    ----------
    db_dir : `path-like`, optional
        Storage path of the database.
        If not specified, `_DATA_CACHE / db_name` will be used
        as the default storage path.
    working_dir : `path-like`, optional
        Working directory, to store intermediate files and log files.
    verbose : int, default 1
        Level of logging verbosity.
    kwargs : dict, optional
        Auxiliary keyword arguments.

    """

    def __init__(
        self,
        db_name: str,
        db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        self.db_name = db_name
        if db_dir is None:
            db_dir = _DATA_CACHE / db_name
            warnings.warn(
                f"`db_dir` is not specified, using default `{db_dir}` as the storage path",
                RuntimeWarning,
            )
        self.db_dir = Path(db_dir).expanduser().resolve().absolute()
        if not self.db_dir.exists():
            self.db_dir.mkdir(parents=True, exist_ok=True)
            warnings.warn(
                f"`{self.db_dir}` does not exist. It is now created. "
                "Please check if it is set correctly. "
                "Or if you may want to download the database into this folder, "
                "please use the `download()` method.",
                RuntimeWarning,
            )
        self.working_dir = Path(working_dir or DEFAULTS.working_dir).expanduser().resolve().absolute() / self.db_name
        self.working_dir.mkdir(parents=True, exist_ok=True)

        self.logger = kwargs.get("logger", None)
        if self.logger is None:
            self.logger = init_logger(
                log_dir=False,
                suffix=self.__class__.__name__,
                verbose=verbose,
            )
        else:
            assert isinstance(self.logger, logging.Logger), "logger must be a `logging.Logger` instance"

        self.data_ext = None
        self.ann_ext = None
        self.header_ext = "hea"
        self.verbose = verbose
        self._df_records = pd.DataFrame()
        self._all_records = None

        self._subsample = kwargs.get("subsample", None)
        assert (
            self._subsample is None or 0 < self._subsample <= 1
        ), f"`subsample` must be in (0, 1], but got `{self._subsample}`"

    @abstractmethod
    def _ls_rec(self) -> None:
        """Find all records in the database."""
        raise NotImplementedError

    @abstractmethod
    def load_data(self, rec: Union[str, int], **kwargs) -> Any:
        """Load data from the record."""
        raise NotImplementedError

    @abstractmethod
    def load_ann(self, rec: Union[str, int], **kwargs) -> Any:
        """Load annotations of the record.

        NOTE that the records might have several annotation files.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def database_info(self) -> "DataBaseInfo":
        """The :class:`DataBaseInfo` object of the database."""
        raise NotImplementedError

    def get_citation(self, format: Optional[str] = None, style: Optional[str] = None) -> None:
        """Get the citations of the papers related to the database.

        Parameters
        ----------
        format : str, optional
            Format of the final output.
            If specified, the default format ("bib") will be overridden.
        style : str, optional
            Style of the final output.
            If specified, the default style ("apa") will be overridden.
            Valid only when `format` is ``"text"``.

        Returns
        -------
        None

        """
        self.database_info.get_citation(lookup=True, format=format, style=style, timeout=10.0, print_result=True)

    def _auto_infer_units(self, sig: np.ndarray, sig_type: str = "ECG") -> str:
        """Automatically infer the units of the signal.

        It is assumed that `sig` is not the raw digital signal,
        but a physical signal with baseline removed.

        Parameters
        ----------
        sig : ndarray
            The signal to infer its units.
        sig_type : str, default "ECG"
            Type of the signal, case insensitive.

        Returns
        -------
        units : {"μV", "mV"}
            Units of the signal.

        """
        if sig_type.lower() == "ecg":
            _MAX_mV = 20  # 20 mV; an ECG device seldom has a range larger than this value
            max_val = np.max(np.abs(sig))
            if max_val > _MAX_mV:
                units = "μV"
            else:
                units = "mV"
        else:
            raise NotImplementedError(f"not implemented for {sig_type}")
        return units
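    # A minimal sketch of the heuristic above, assuming `reader` is an
    # instance of a concrete subclass (the synthetic signal is made up):
    #
    #   >>> sig = 1000 * np.sin(np.linspace(0, 2 * np.pi, 500))  # ~1 mV signal stored in μV
    #   >>> reader._auto_infer_units(sig)  # max |sig| = 1000 > 20, hence inferred as μV
    #   'μV'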

    @property
    def all_records(self) -> List[str]:
        if self._all_records is None:
            self._ls_rec()
        return self._all_records

    def get_absolute_path(self, rec: Union[str, int], extension: Optional[str] = None) -> Path:
        """Get the absolute path of the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        extension : str, optional
            Extension of the file.

        Returns
        -------
        path : pathlib.Path
            Absolute path of the file.

        """
        if isinstance(rec, int):
            rec = self[rec]
        path = self._df_records.loc[rec].path
        if extension is not None:
            path = path.with_suffix(extension if extension.startswith(".") else f".{extension}")
        return path
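    # Usage sketch, assuming `reader` is a concrete reader and "100" is a
    # record of its database (the path shown is illustrative):
    #
    #   >>> reader.get_absolute_path("100", extension="hea")
    #   PosixPath('/path/to/db_dir/100.hea')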

    def _normalize_leads(
        self,
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        all_leads: Optional[Sequence[str]] = None,
        numeric: bool = False,
    ) -> List[Union[str, int]]:
        """Normalize the leads to a list of standard lead names.

        Parameters
        ----------
        leads : str or int or List[str] or List[int], optional
            The leads (names or indices) to normalize.
        all_leads : list of str, optional
            All leads of the records in the database.
            If None, the database class should have the attribute `all_leads`,
            and `self.all_leads` will be used.
        numeric : bool, default False
            If True, indices of the leads will be returned
            instead of lead names.

        Returns
        -------
        leads : List[str] or List[int]
            The normalized leads.

        """
        if all_leads is None:
            assert hasattr(
                self, "all_leads"
            ), "If `all_leads` is not specified, the database class should have attribute `all_leads`!"
            all_leads = self.all_leads
        err_msg = (
            f"`leads` should be a subset of {all_leads} or non-negative integers "
            f"less than {len(all_leads)}, but got {leads}"
        )
        if leads is None or (isinstance(leads, str) and leads.lower() == "all"):
            _leads = all_leads
        elif isinstance(leads, str):
            _leads = [leads]
        elif isinstance(leads, int):
            assert len(all_leads) > leads >= 0, err_msg
            _leads = [all_leads[leads]]
        else:
            try:
                _leads = [ld if isinstance(ld, str) else all_leads[ld] for ld in leads]
            except Exception:
                raise AssertionError(err_msg)
        assert set(_leads).issubset(all_leads), err_msg
        if numeric:
            _leads = [all_leads.index(ld) for ld in _leads]
        return _leads
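    # Usage sketch for `_normalize_leads` (the 12-lead list is a typical
    # assumption; a concrete reader would define its own `all_leads`):
    #
    #   >>> all_leads = ["I", "II", "III", "aVR", "aVL", "aVF",
    #   ...              "V1", "V2", "V3", "V4", "V5", "V6"]
    #   >>> reader._normalize_leads(["II", 6], all_leads)
    #   ['II', 'V1']
    #   >>> reader._normalize_leads("II", all_leads, numeric=True)
    #   [1]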

    @classmethod
    def get_arrhythmia_knowledge(cls, arrhythmias: Union[str, List[str]]) -> None:
        """Knowledge about ECG features of specific arrhythmias.

        Parameters
        ----------
        arrhythmias : str or List[str]
            The arrhythmia(s) to check,
            in abbreviation or in SNOMED CT code.

        Returns
        -------
        None

        """
        if isinstance(arrhythmias, str):
            d = [arrhythmias]
        else:
            d = arrhythmias
        for idx, item in enumerate(d):
            print(dict_to_str(getattr(EAK, item)))
            if idx < len(d) - 1:
                print("*" * 110)

    def extra_repr_keys(self) -> List[str]:
        return [
            "db_name",
            "db_dir",
        ]

    @property
    @abstractmethod
    def url(self) -> Union[str, List[str]]:
        """URL(s) for downloading the database."""
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.all_records)

    def __getitem__(self, index: int) -> str:
        return self.all_records[index]



class PhysioNetDataBase(_DataBase):
    """Base class for readers for PhysioNet databases.

    PhysioNet is a large repository of freely available biomedical signals,
    including ECG, EEG, EMG, and other signals.
    The website is [#phy_website]_.

    Parameters
    ----------
    db_name : str
        Name of the database.
    db_dir : `path-like`, optional
        Storage path of the database.
        If None, `wfdb` will fetch data from PhysioNet.
    working_dir : `path-like`, optional
        Working directory, to store intermediate files and log files.
    verbose : int, default 1
        Verbosity level for logging.
    kwargs : dict, optional
        Auxiliary keyword arguments.

    References
    ----------
    .. [#phy_website] https://www.physionet.org/

    """

    def __init__(
        self,
        db_name: str,
        db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            db_name=db_name,
            db_dir=db_dir,
            working_dir=working_dir,
            verbose=verbose,
            **kwargs,
        )
        # `self.fs` for those with a single signal source, e.g. ECG;
        # for those with multiple signal sources like PSG,
        # `self.fs` defaults to the frequency of ECG, if ECG is applicable
        self.fs = kwargs.get("fs", None)
        self._all_records = None
        self._version = None
        self._url_compressed = None

        self.df_all_db_info = get_physionet_dbs()

        if self.verbose > 2:
            self.df_all_db_info = (
                pd.DataFrame(
                    wfdb.get_dbs(),
                    columns=["db_name", "db_description"],
                )
                .drop_duplicates()
                .reset_index(drop=True)
            )

    def _ls_rec(self, db_name: Optional[str] = None, local: bool = True) -> None:
        """Find all records (relative path without file extension),
        and save them into some private attributes for further use.

        Parameters
        ----------
        db_name : str, optional
            Name of the database for using :meth:`wfdb.get_record_list`.
            If None, :attr:`self.db_name` will be used.
        local : bool, default True
            If True, search records in local storage,
            prior to using :meth:`wfdb.get_record_list`.

        Returns
        -------
        None

        """
        empty_warning_msg = (
            "No records found in the database! "
            "Please check if the path to the database is correct. "
            "Or you can try to download the database first using the `download` method."
        )
        if local:
            self._ls_rec_local()
            if len(self._df_records) == 0:
                warnings.warn(empty_warning_msg, RuntimeWarning)
            return
        try:
            self._df_records = pd.DataFrame()
            self._df_records["record"] = wfdb.get_record_list(db_name or self.db_name)
            self._df_records["path"] = self._df_records["record"].apply(lambda x: (self.db_dir / x).resolve())
            # keep only the records that exist in `self.db_dir`
            # NOTE:
            # 1. data files might be in some subdirectories of `self.db_dir`
            # 2. `wfdb.get_record_list` will return records without file extension
            self._df_records = self._df_records[
                self._df_records["path"].apply(lambda x: len(list(x.parent.glob(f"{x.name}.*"))) > 0)
            ]
            # if no record found,
            # search locally and recursively inside `self.db_dir`
            if len(self._df_records) == 0:
                return self._ls_rec_local()
            self._df_records["record"] = self._df_records["path"].apply(
                lambda x: x.name
            )  # remove relative path, leaving only the record name
            self._df_records.set_index("record", inplace=True)
            if self._subsample is not None:
                size = min(
                    len(self._df_records),
                    max(1, int(round(self._subsample * len(self._df_records)))),
                )
                self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`")
                self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
            self._all_records = self._df_records.index.tolist()
        except Exception:
            self._ls_rec_local()
            if len(self._df_records) == 0:
                warnings.warn(empty_warning_msg, RuntimeWarning)

    def _ls_rec_local(self) -> None:
        """Find all records in :attr:`self.db_dir`."""
        record_list_fp = self.db_dir / "RECORDS"
        self._df_records = pd.DataFrame()
        if record_list_fp.is_file():
            self._df_records["record"] = [item for item in record_list_fp.read_text().splitlines() if len(item) > 0]
            if len(self._df_records) > 0:
                if self._subsample is not None:
                    size = min(
                        len(self._df_records),
                        max(1, int(round(self._subsample * len(self._df_records)))),
                    )
                    self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`")
                    self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
                self._df_records["path"] = self._df_records["record"].apply(lambda x: (self.db_dir / x).resolve())
                self._df_records = self._df_records[self._df_records["path"].apply(lambda x: x.is_file())]
                self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name)
        if len(self._df_records) == 0:
            print("Please wait patiently to let the reader find all records of the database from local storage...")
            start = time.time()
            self._df_records["path"] = get_record_list_recursive(self.db_dir, self.data_ext, relative=False)
            if self._subsample is not None:
                size = min(
                    len(self._df_records),
                    max(1, int(round(self._subsample * len(self._df_records)))),
                )
                self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`")
                self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
            self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x))
            self.logger.info(f"Done in {time.time() - start:.3f} seconds!")
            self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name)
        self._df_records.set_index("record", inplace=True)
        self._all_records = self._df_records.index.values.tolist()

    def get_subject_id(self, rec: Union[str, int]) -> int:
        """Attach a unique subject ID to the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        int
            Subject ID associated with the record.

        """
        raise NotImplementedError

    def load_data(
        self,
        rec: Union[str, int],
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        sampfrom: Optional[int] = None,
        sampto: Optional[int] = None,
        data_format: str = "channel_first",
        units: Union[str, type(None)] = "mV",
        fs: Optional[Real] = None,
        return_fs: bool = False,
    ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
        """Load physical (converted from digital) ECG data,
        which is more understandable for humans;
        or load the digital signal directly.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        leads : str or int or Sequence[str] or Sequence[int], optional
            The leads of the ECG data to load.
            None or "all" for all leads.
        sampfrom : int, optional
            Start index of the data to be loaded.
        sampto : int, optional
            End index of the data to be loaded.
        data_format : str, default "channel_first"
            Format of the ECG data,
            "channel_last" (alias "lead_last"), or
            "channel_first" (alias "lead_first"), or
            "flat" (alias "plain") which is valid only when
            `leads` is a single lead.
        units : str or None, default "mV"
            Units of the output signal, can also be "μV" (aliases "uV", "muV").
            None for digital data, without digital-to-physical conversion.
        fs : numbers.Real, optional
            Sampling frequency of the output signal.
            If not None, the loaded data will be resampled to this frequency;
            if None, `self.fs` will be used if available and not None;
            otherwise, the original sampling frequency will be used.
        return_fs : bool, default False
            Whether to return the sampling frequency of the output signal.

        Returns
        -------
        data : numpy.ndarray
            The ECG data loaded from the record,
            with given `units` and `data_format`.
        data_fs : numbers.Real, optional
            Sampling frequency of the output signal.
            Returned if `return_fs` is True.

        """
        fp = str(self.get_absolute_path(rec))
        if hasattr(self, "all_leads"):
            all_leads = self.all_leads
        else:
            all_leads = wfdb.rdheader(fp).sig_name
        _leads = self._normalize_leads(leads, all_leads, numeric=False)
        allowed_data_format = [
            "channel_first",
            "lead_first",
            "channel_last",
            "lead_last",
            "flat",
            "plain",
        ]
        assert (
            data_format.lower() in allowed_data_format
        ), f"`data_format` should be one of `{allowed_data_format}`, but got `{data_format}`"
        if len(_leads) > 1:
            assert data_format.lower() in [
                "channel_first",
                "lead_first",
                "channel_last",
                "lead_last",
            ], (
                "`data_format` should be one of `['channel_first', 'lead_first', 'channel_last', 'lead_last']` "
                f"when the passed number of `leads` is larger than 1, but got `{data_format}`"
            )
        allowed_units = ["mv", "uv", "μv", "muv"]
        assert (
            units is None or units.lower() in allowed_units
        ), f"`units` should be one of `{allowed_units}` or None, but got `{units}`"

        rdrecord_kwargs = dict(
            sampfrom=sampfrom or 0,
            sampto=sampto,
            physical=units is not None,
            return_res=DEFAULTS.DTYPE.INT,
            channels=[all_leads.index(ld) for ld in _leads],
        )
        # use `channels` instead of `channel_names` since there are exceptional cases
        # where `channel_names` has duplicates
        wfdb_rec = wfdb.rdrecord(fp, **rdrecord_kwargs)

        # p_signal or d_signal is in the format of "lead_last", with units in "mV"
        if units is None:
            data = wfdb_rec.d_signal
        elif units.lower() == "mv":
            data = wfdb_rec.p_signal
        elif units.lower() in ["μv", "uv", "muv"]:
            data = 1000 * wfdb_rec.p_signal

        if fs is not None:
            data_fs = fs
        elif getattr(self, "fs", None) is not None:
            data_fs = self.fs
        else:
            data_fs = wfdb_rec.fs
        if data_fs != wfdb_rec.fs:
            data = SS.resample_poly(data, data_fs, wfdb_rec.fs, axis=0).astype(data.dtype)

        if data_format.lower() in ["channel_first", "lead_first"]:
            data = data.T
        elif data_format.lower() in ["flat", "plain"]:
            data = data.flatten()

        if return_fs:
            return data, data_fs
        return data

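    # Usage sketch for `load_data`, assuming `reader` is a concrete
    # PhysioNet reader and "100"/"MLII" are a record and lead of its database:
    #
    #   >>> sig = reader.load_data("100", leads="MLII", data_format="flat", units="mV")
    #   >>> sig, sig_fs = reader.load_data("100", fs=250, return_fs=True)
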
    def helper(self, items: Union[List[str], str, type(None)] = None) -> None:
        """Print the corresponding meanings of the symbols belonging to `items`.

        More details can be found in the
        PhysioNet WFDB annotation manual [#ann_man]_.

        Parameters
        ----------
        items : str or List[str], optional
            Items to print.
            If None, then a comprehensive printing
            of meanings of all symbols will be performed.

        Returns
        -------
        None

        References
        ----------
        .. [#ann_man] https://archive.physionet.org/physiobank/annotations.shtml

        """
        attrs = vars(self)
        methods = [
            func
            for func in dir(self)
            if callable(getattr(self, func)) and not (func.startswith("__") and func.endswith("__"))
        ]
        beat_annotations = deepcopy(WFDB_Beat_Annotations)
        non_beat_annotations = deepcopy(WFDB_Non_Beat_Annotations)
        rhythm_annotations = deepcopy(WFDB_Rhythm_Annotations)
        all_annotations = [
            beat_annotations,
            non_beat_annotations,
            rhythm_annotations,
        ]
        summary_items = [
            "beat",
            "non-beat",
            "rhythm",
        ]
        if items is None:
            _items = [
                "attributes",
                "methods",
                "beat",
                "non-beat",
                "rhythm",
            ]
        elif isinstance(items, str):
            _items = [items]
        else:
            _items = items
        pp = pprint.PrettyPrinter(indent=4)

        if "attributes" in _items:
            print("--- helper - attributes ---")
            pp.pprint(attrs)
        if "methods" in _items:
            print("--- helper - methods ---")
            pp.pprint(methods)
        if "beat" in _items:
            print("--- helper - beat ---")
            pp.pprint(beat_annotations)
        if "non-beat" in _items:
            print("--- helper - non-beat ---")
            pp.pprint(non_beat_annotations)
        if "rhythm" in _items:
            print("--- helper - rhythm ---")
            pp.pprint(rhythm_annotations)

        for k in _items:
            if k in summary_items:
                continue
            for a in all_annotations:
                if k in a.keys() or "(" + k in a.keys():
                    try:
                        print(f"`{k.split('(')[1]}` stands for `{a[k]}`")
                    except IndexError:
                        try:
                            print(f"`{k}` stands for `{a[k]}`")
                        except KeyError:
                            print(f"`{k}` stands for `{a['('+k]}`")

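    # Usage sketch for `helper`:
    #
    #   >>> reader.helper("beat")   # print the whole beat annotation table
    #   >>> reader.helper("AFIB")   # look up a single rhythm symbol
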
    def get_file_download_url(self, file_name: Union[str, bytes, os.PathLike]) -> str:
        """Get the download URL of the file.

        Parameters
        ----------
        file_name : `path-like`
            Name of the file,
            e.g. "data/001a.dat", "training/tr03-0005/tr03-0005.mat", etc.

        Returns
        -------
        url : str
            URL of the file to be downloaded.

        """
        url = posixpath.join(
            wfdb.io.download.PN_INDEX_URL,
            self.db_name,
            self.version,
            file_name,
        )
        return url

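    # Usage sketch (the file name is illustrative; the URL follows the
    # PhysioNet files index layout):
    #
    #   >>> reader.get_file_download_url("data/001a.dat")
    #   'https://physionet.org/files/<db_name>/<version>/data/001a.dat'
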
    @property
    def version(self) -> str:
        """Version of the database."""
        if self._version is not None:
            return self._version
        try:
            self._version = wfdb.io.record.get_version(self.db_name)
        except Exception:
            warnings.warn(
                "Cannot get the version number from PhysioNet! Defaults to '1.0.0'",
                RuntimeWarning,
            )
            self._version = "1.0.0"
        return self._version

    @property
    def webpage(self) -> str:
        """URL of the database webpage."""
        return posixpath.join(wfdb.io.download.PN_CONTENT_URL, f"{self.db_name}/{self.version}")

    @property
    def url(self) -> str:
        """URL of the database index page for downloading."""
        return posixpath.join(wfdb.io.download.PN_INDEX_URL, f"{self.db_name}/{self.version}")

    @property
    def url_(self) -> Union[str, type(None)]:
        """URL of the compressed database file for downloading."""
        if self._url_compressed is not None:
            return self._url_compressed
        domain = "https://physionet.org/static/published-projects/"
        punct = re.sub("[\\-:]", "", punctuation)
        try:
            db_desc = self.df_all_db_info[self.df_all_db_info["db_name"] == self.db_name].iloc[0]["db_description"]
        except IndexError:
            self.logger.info(f"\042{self.db_name}\042 is not in the database list hosted at PhysioNet!")
            return None
        db_desc = re.sub(f"[{punct}]+", "", db_desc).lower()
        db_desc = re.sub("[\\s:]+", "-", db_desc)
        url = posixpath.join(domain, f"{self.db_name}/{db_desc}-{self.version}.zip")
        if requests.head(url).headers.get("Content-Type") == "application/zip":
            self._url_compressed = url
        else:
            new_url = posixpath.join(wfdb.io.download.PN_INDEX_URL, f"{self.db_name}/get-zip/{self.version}")
            print(f"{url} is not available, try {new_url} instead")
        return self._url_compressed

    def download(self, compressed: bool = True) -> None:
        """Download the database from PhysioNet."""
        if compressed:
            if self.url_ is not None:
                http_get(self.url_, self.db_dir, extract=True)
                self._ls_rec()
                return
            else:
                self.logger.info("No compressed database available! Downloading the uncompressed version...")
        wfdb.dl_database(
            self.db_name,
            self.db_dir,
            keep_subdirs=True,
            overwrite=False,
        )
        self._ls_rec()


class NSRRDataBase(_DataBase):
    """Base class for readers for the NSRR databases.

    For a full list of available databases, and their descriptions,
    please visit the NSRR database webpage [1]_.

    Parameters
    ----------
    db_name : str
        Name of the database.
    db_dir : `path-like`, optional
        Local storage path of the database.
    working_dir : `path-like`, optional
        Working directory, to store intermediate files and log files.
    verbose : int, default 1
        Verbosity level for logging.
    kwargs : dict, optional
        Auxiliary keyword arguments.

    References
    ----------
    .. [1] https://sleepdata.org/

    """

    def __init__(
        self,
        db_name: str,
        db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            db_name=db_name,
            db_dir=db_dir,
            working_dir=working_dir,
            verbose=verbose,
            **kwargs,
        )
        self.fs = kwargs.get("fs", None)
        self._all_records = None
        self.file_opened = None

        all_dbs = [
            [
                "shhs",
                "Multi-cohort study focused on sleep-disordered breathing and cardiovascular outcomes",
            ],
            ["mesa", ""],
            ["oya", ""],
            [
                "chat",
                "Multi-center randomized trial comparing early adenotonsillectomy to "
                "watchful waiting plus supportive care",
            ],
            [
                "heartbeat",
                "Multi-center Phase II randomized controlled trial that evaluates the effects "
                "of supplemental nocturnal oxygen or Positive Airway Pressure (PAP) therapy",
            ],
            # more to be added
        ]
        self.df_all_db_info = pd.DataFrame(
            {
                "db_name": [item[0] for item in all_dbs],
                "db_description": [item[1] for item in all_dbs],
            }
        )
        self.kwargs = kwargs

    def safe_edf_file_operation(
        self,
        operation: str = "close",
        full_file_path: Optional[Union[str, bytes, os.PathLike]] = None,
    ) -> None:
        """Safe IO operation for an EDF file.

        Parameters
        ----------
        operation : {"open", "close"}, optional
            Operation name, by default "close".
        full_file_path : `path-like`, optional
            Path of the file which contains the data.
            If None, the default path will be used.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the operation is not supported.

        """
        if operation == "open":
            if self.file_opened is not None:
                self.file_opened._close()
            self.file_opened = EdfReader(str(full_file_path))
        elif operation == "close":
            if self.file_opened is not None:
                self.file_opened._close()
                self.file_opened = None
        else:
            raise ValueError("Illegal operation")

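    # Usage sketch (the EDF path is a placeholder):
    #
    #   >>> reader.safe_edf_file_operation("open", "/path/to/record.edf")
    #   >>> n_channels = reader.file_opened.signals_in_file
    #   >>> reader.safe_edf_file_operation("close")
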
    def get_subject_id(self, rec: Union[str, int]) -> int:
        """Attach a unique subject ID to the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        int
            Subject ID associated with the record.

        """
        raise NotImplementedError

    def show_rec_stats(self, rec: Union[str, int]) -> None:
        """Print the statistics about the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        None

        """
        raise NotImplementedError

    def helper(self, items: Union[List[str], str, type(None)] = None) -> None:
        """Print the corresponding meanings of the symbols belonging to `items`.

        Parameters
        ----------
        items : str or List[str], optional
            Items to print.
            If None, then a comprehensive printing
            of meanings of all symbols will be performed.

        Returns
        -------
        None

        """
        attrs = vars(self)
        methods = [
            func
            for func in dir(self)
            if callable(getattr(self, func)) and not (func.startswith("__") and func.endswith("__"))
        ]
        if items is None:
            _items = [
                "attributes",
                "methods",
            ]
        elif isinstance(items, str):
            _items = [items]
        else:
            _items = items
        pp = pprint.PrettyPrinter(indent=4)

        if "attributes" in _items:
            print("--- helper - attributes ---")
            pp.pprint(attrs)
        if "methods" in _items:
            print("--- helper - methods ---")
            pp.pprint(methods)


class CPSCDataBase(_DataBase):
    """Base class for readers for the CPSC databases.

    Parameters
    ----------
    db_name : str
        Name of the database.
    db_dir : `path-like`, optional
        Local storage path of the database.
    working_dir : `path-like`, optional
        Working directory, to store intermediate files and log files.
    verbose : int, default 1
        Verbosity level for logging.
    kwargs : dict, optional
        Auxiliary keyword arguments.

    """

    def __init__(
        self,
        db_name: str,
        db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            db_name=db_name,
            db_dir=db_dir,
            working_dir=working_dir,
            verbose=verbose,
            **kwargs,
        )
        self.fs = kwargs.get("fs", None)
        self._all_records = None
        self.kwargs = kwargs

    def get_subject_id(self, rec: Union[str, int]) -> int:
        """Attach a unique subject ID to the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        int
            Subject ID associated with the record.

        """
        raise NotImplementedError

    def helper(self, items: Union[List[str], str, type(None)] = None) -> None:
        """Print the corresponding meanings of the symbols belonging to `items`.

        Parameters
        ----------
        items : str or List[str], optional
            Items to print.
            If None, then a comprehensive printing
            of meanings of all symbols will be performed.

        Returns
        -------
        None

        """
        attrs = vars(self)
        methods = [
            func
            for func in dir(self)
            if callable(getattr(self, func)) and not (func.startswith("__") and func.endswith("__"))
        ]
        if items is None:
            _items = [
                "attributes",
                "methods",
            ]
        elif isinstance(items, str):
            _items = [items]
        else:
            _items = items
        pp = pprint.PrettyPrinter(indent=4)

        if "attributes" in _items:
            print("--- helper - attributes ---")
            pp.pprint(attrs)
        if "methods" in _items:
            print("--- helper - methods ---")
            pp.pprint(methods)

    def download(self) -> None:
        """Download the database from `self.url`."""
        if isinstance(self.url, str):
            http_get(self.url, self.db_dir, extract=True)
        else:
            for url in self.url:
                http_get(url, self.db_dir, extract=True)
        self._ls_rec()


@dataclass
class DataBaseInfo(CitationMixin):
    """A dataclass to store the information of a database.

    Attributes
    ----------
    title : str
        Title of the database.
    about : str or list of str
        Description of the database.
    usage : list of str
        Potential usages of the database.
    references : list of str
        References of the database.
    note : str or list of str, optional
        Notes of the database.
    issues : str or list of str, optional
        Known issues of the database.
    status : str, optional
        Status of the database.
    doi : str or list of str, optional
        DOI of the paper(s) describing the database.

    """

    title: str
    about: Union[str, Sequence[str]]
    usage: Sequence[str]
    references: Sequence[str]
    note: Optional[Union[str, Sequence[str]]] = None
    issues: Optional[Union[str, Sequence[str]]] = None
    status: Optional[str] = None
    doi: Optional[Union[str, Sequence[str]]] = None

    def format_database_docstring(self, indent: Optional[str] = None) -> str:
        """Format the database docstring from the information
        stored in the dataclass.

        The docstring will use the reStructuredText format.

        Parameters
        ----------
        indent : str, optional
            Indent of the docstring.
            If not specified, then 4 spaces will be used.

        Returns
        -------
        str
            The formatted docstring.

        NOTE
        ----
        An environment variable ``DB_BIB_LOOKUP`` can be set to ``True``
        to enable the lookup of the bib entries.

        """
        if indent is None:
            indent = " " * 4
        title = textwrap.dedent(self.title).strip("\n ")
        if isinstance(self.about, str):
            about = "ABOUT\n-----\n" + textwrap.dedent(self.about).strip("\n ")
        else:
            about = ["ABOUT", "-----"] + [f"{idx+1}. {line}" for idx, line in enumerate(self.about)]
            about = "\n".join(about)
        if self.note is None:
            # note = "NOTE\n----"
            note = ""
        elif isinstance(self.note, str):
            note = "NOTE\n----\n" + textwrap.dedent(self.note).strip("\n ")
        else:
            note = ["NOTE", "----"] + [f"{idx+1}. {line}" for idx, line in enumerate(self.note)]
            note = "\n".join(note)
        if self.issues is None:
            # issues = "Issues\n------"
            issues = ""
        elif isinstance(self.issues, str):
            issues = "Issues\n------\n" + textwrap.dedent(self.issues).strip("\n ")
        else:
            issues = ["Issues", "-" * 6] + [f"{idx+1}. {line}" for idx, line in enumerate(self.issues)]
            issues = "\n".join(issues)
        references = ["References", "-" * 10] + [
            # f"""{idx+1}. <a name="ref{idx+1}"></a> {line}"""
            f""".. [{idx+1}] {line}"""
            for idx, line in enumerate(self.references)
        ]
        references = "\n".join(references)
        usage = ["Usage", "------"] + [f"{idx+1}. {line}" for idx, line in enumerate(self.usage)]
        usage = "\n".join(usage)
        docstring = textwrap.indent(
            f"""\n{title}\n\n{about}\n\n{note}\n\n{usage}\n\n{issues}\n\n{references}\n""",
            indent,
        )
        if self.status is not None and len(self.status) > 0:
            docstring = f"{self.status}\n\n{docstring}"
        lookup = os.getenv("DB_BIB_LOOKUP", False)
        citation = self.get_citation(lookup=lookup, print_result=False)
        if citation.startswith("@"):
            citation = textwrap.indent(citation, indent)
            citation = textwrap.indent(f"""Citation\n--------\n.. code-block:: bibtex\n\n{citation}""", indent)
            docstring = f"{docstring}\n\n{citation}\n"
        elif not lookup:
            citation = textwrap.indent(f"""Citation\n--------\n{citation}""", indent)
            docstring = f"{docstring}\n\n{citation}\n"
        return docstring

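# Usage sketch for `DataBaseInfo` (all field values are placeholders):
#
#   >>> info = DataBaseInfo(
#   ...     title="Some ECG Database",
#   ...     about="A brief description of the database.",
#   ...     usage=["arrhythmia classification"],
#   ...     references=["https://example.com/placeholder"],
#   ... )
#   >>> print(info.format_database_docstring())
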

class PSGDataBaseMixin:
    """A mixin class for PSG databases.

    Contains methods for

        - conversions between sleep stage intervals and sleep stage masks
        - hypnogram plotting

    """

    def sleep_stage_intervals_to_mask(
        self,
        intervals: Dict[str, List[List[int]]],
        fs: Optional[int] = None,
        granularity: int = 30,
        class_map: Optional[Dict[str, int]] = None,
    ) -> np.ndarray:
        """Convert sleep stage intervals to a sleep stage mask.

        Parameters
        ----------
        intervals : dict
            Sleep stage intervals, in the format of a dict of lists of lists of int.
            Keys are sleep stages and values are lists of
            start and end indices of the sleep stages.
        fs : int, optional
            Sampling frequency corresponding to the sleep stage intervals,
            defaults to the sampling frequency of the database.
        granularity : int, default 30
            Granularity of the sleep stage mask, with units in seconds.
        class_map : dict, optional
            A dictionary mapping sleep stages to integers.
            If the database reader does not have a `sleep_stage_names` attribute,
            this parameter must be provided.

        Returns
        -------
        numpy.ndarray
            Sleep stage mask.

        """
        fs = fs or self.fs
        assert fs is not None and fs > 0, "`fs` must be positive"
        assert granularity > 0, "`granularity` must be positive"
        if not hasattr(self, "sleep_stage_names"):
            assert class_map is not None, "`class_map` must be provided"
        else:
            class_map = class_map or {k: len(self.sleep_stage_names) - i - 1 for i, k in enumerate(self.sleep_stage_names)}
        intervals = {
            class_map[k]: [[int(round(s / fs / granularity)), int(round(e / fs / granularity))] for s, e in v]
            for k, v in intervals.items()
        }
        intervals = {k: [[s, e] for s, e in v if s < e] for k, v in intervals.items()}
        intervals = {k: v for k, v in intervals.items() if len(v) > 0}
        siglen = max([e for v in intervals.values() for s, e in v])
        mask = np.zeros(siglen, dtype=int)
        for k, v in intervals.items():
            for s, e in v:
                mask[s:e] = k
        return mask

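    # Worked sketch for the conversion above (fs = 100 Hz, 30 s epochs; the
    # intervals and class map are made up). Sample 6000 maps to epoch
    # round(6000 / 100 / 30) = 2, so "W" covers epochs [0, 2) and "N2"
    # covers epochs [2, 4):
    #
    #   >>> intervals = {"W": [[0, 6000]], "N2": [[6000, 12000]]}
    #   >>> mask = reader.sleep_stage_intervals_to_mask(
    #   ...     intervals, fs=100, class_map={"W": 0, "N2": 2}
    #   ... )
    #   >>> mask.tolist()
    #   [0, 0, 2, 2]
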
    def plot_hypnogram(
        self,
        mask: np.ndarray,
        granularity: int = 30,
        class_map: Optional[Dict[str, int]] = None,
        **kwargs,
    ) -> tuple:
        """Hypnogram visualization.

        Parameters
        ----------
        mask : numpy.ndarray
            Sleep stage mask.
        granularity : int, default 30
            Granularity of the sleep stage mask to be plotted,
            with units in seconds.
        class_map : dict, optional
            A dictionary mapping sleep stages to integers.
            If the database reader does not have a `sleep_stage_names` attribute,
            this parameter must be provided.
        kwargs : dict, optional
            Additional keyword arguments passed to :meth:`matplotlib.pyplot.plot`.

        Returns
        -------
        fig : matplotlib.figure.Figure
            Figure object.
        ax : matplotlib.axes.Axes
            Axes object.

        """
        if not hasattr(self, "sleep_stage_names"):
            assert class_map is not None, "`class_map` must be provided"
        else:
            class_map = class_map or {k: len(self.sleep_stage_names) - i - 1 for i, k in enumerate(self.sleep_stage_names)}
        if "plt" not in globals():
            import matplotlib.pyplot as plt

        fig_width = len(mask) * granularity / 3600 / 6 * 20  # standard width is 20 for 6 hours
        fig, ax = plt.subplots(figsize=(fig_width, 4))
        color = kwargs.pop("color", "black")
        ax.plot(mask, color=color, **kwargs)
        # xticks in the format of HH:MM, every half hour
        xticks = np.arange(0, len(mask), 1800 / granularity)
        xticklabels = [f"{int(i * granularity / 3600):02d}:{int(i * granularity / 60 % 60):02d}" for i in xticks]
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, fontsize=14)
        ax.set_xlabel("Time", fontsize=18)
        ax.set_xlim(0, len(mask))
        # yticks labeled with the sleep stage names
        yticks = sorted(class_map.values())
        yticklabels = [k for k, v in sorted(class_map.items(), key=lambda x: x[1])]
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticklabels, fontsize=14)
        ax.set_ylabel("Sleep Stage", fontsize=18)
        return fig, ax

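    # Plotting sketch, continuing the mask example above:
    #
    #   >>> fig, ax = reader.plot_hypnogram(mask, class_map={"W": 0, "N2": 2})
    #   >>> fig.savefig("hypnogram.png")
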
DEFAULT_FIG_SIZE_PER_SEC = 4.8

@dataclass
class BeatAnn:
    """Dataclass for beat annotations.

    Attributes
    ----------
    index : int
        Index of the beat.
    symbol : str
        Symbol of the beat.

    Properties
    ----------
    name : str
        Name of the beat.

    """

    index: int
    symbol: str

    @property
    def name(self) -> str:
        if self.symbol in WFDB_Beat_Annotations:
            return WFDB_Beat_Annotations[self.symbol]
        return WFDB_Non_Beat_Annotations.get(self.symbol, self.symbol)

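# Example for `BeatAnn` (index 360 is an arbitrary sample position):
#
#   >>> ann = BeatAnn(index=360, symbol="V")
#   >>> ann.name
#   'Premature ventricular contraction'
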
# configurations for visualization _PlotCfg = CFG() # used only when corr. values are absent # all values are time bias w.r.t. corr. peaks, with units in ms _PlotCfg.p_onset = -40 _PlotCfg.p_offset = 40 _PlotCfg.q_onset = -20 _PlotCfg.s_offset = 40 _PlotCfg.qrs_radius = 60 _PlotCfg.t_onset = -100 _PlotCfg.t_offset = 60