import os
import re
import warnings
from numbers import Real
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
from scipy.io import loadmat
from ...cfg import DEFAULTS
from ...utils.download import http_get
from ...utils.misc import add_docstring, get_record_list_recursive
from ..base import DEFAULT_FIG_SIZE_PER_SEC, CPSCDataBase, DataBaseInfo
__all__ = [
_CPSC2018_INFO = DataBaseInfo(
The China Physiological Signal Challenge 2018:
Automatic identification of the rhythm/morphology abnormalities in 12-lead ECGs.
1. training set contains 6,877 (female: 3178; male: 3699) 12 leads ECG recordings lasting from 6 s to just 60 s.
2. ECG recordings were sampled at 500 Hz.
3. the training data can be downloaded using links in [1]_, but the link in [2]_ is recommended. File structure will be assumed to follow [2]_.
4. the training data are in the ``channel first`` format.
5. types of abnormal rhythm/morphology + normal in the training set:
| No. | name | abbr. | number of records |
| 0 | Normal | N | 918 |
| 1 | Atrial fibrillation | AF | 1098 |
| 2 | First-degree atrioventricular block | I-AVB | 704 |
| 3 | Left bundle brunch block | LBBB | 207 |
| 4 | Right bundle brunch block | RBBB | 1695 |
| 5 | Premature atrial contraction | PAC | 556 |
| 6 | Premature ventricular contraction | PVC | 672 |
| 7 | ST-segment depression | STD | 825 |
| 8 | ST-segment elevated | STE | 202 |
6. ordering of the leads in the data of all the records is
.. code-block:: python
["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
7. meanings in the .hea files: **to write**
8. knowledge about the abnormal rhythms: ref. :meth:`get_disease_knowledge`.
9. Challenge official website [1]_, see also [2]_.
1. Ages of records A0608, A1549, A1876, A2299, A5990 are "NaN".
2. CINC2020 (ref. [2]_) released a total of 3453 unused training records of CPSC2018, whose filenames start with "Q".
These file names are not "continuous". The last record is "Q3581".
"ECG arrythmia detection",
[docs]@add_docstring(_CPSC2018_INFO.format_database_docstring(), mode="prepend")
class CPSC2018(CPSCDataBase):
db_dir : `path-like`, optional
Storage path of the database.
If not specified, data will be fetched from Physionet.
working_dir : `path-like`, optional
Working directory, to store intermediate files and log files.
verbose : int, default 1
Level of logging verbosity.
kwargs : dict, optional
Auxiliary keyword arguments.
__name__ = "CPSC2018"
def __init__(
db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
verbose: int = 1,
**kwargs: Any,
) -> None:
self.fs = 500
self.spacing = 1000 / self.fs
self.rec_ext = "mat"
self.ann_ext = "mat" # the same file as the record
self.nb_records = 6877
self.all_leads = [
self.all_diagnosis = [
self.all_diagnosis_original = sorted(
self.diagnosis_abbr_to_full = {
"N": "Normal",
"AF": "Atrial fibrillation",
"I-AVB": "First-degree atrioventricular block",
"LBBB": "Left bundle brunch block",
"RBBB": "Right bundle brunch block",
"PAC": "Premature atrial contraction",
"PVC": "Premature ventricular contraction",
"STD": "ST-segment depression",
"STE": "ST-segment elevated",
self.diagnosis_num_to_abbr = {
1: "N",
2: "AF",
3: "I-AVB",
4: "LBBB",
5: "RBBB",
6: "PAC",
7: "PVC",
8: "STD",
9: "STE",
self._all_records = None
def _ls_rec(self) -> None:
    """Find all records in the database directory
    and store them (path, metadata, etc.) in some private attributes.

    Populates ``self._df_records`` (indexed by record name, with a ``path``
    column and the label columns ``labels_n``, ``labels_a``, ``labels_f``)
    and ``self._all_records``.

    If exactly one ``REFERENCE.csv`` cannot be located under ``self.db_dir``,
    a warning is issued and the three label columns are filled with ``None``.
    """
    self._df_records = pd.DataFrame()
    self._df_records["path"] = get_record_list_recursive(self.db_dir, self.rec_ext, relative=False)
    if self._subsample is not None:
        # keep at least one record, but never ask for more than exist
        # (sampling is done without replacement)
        size = min(
            len(self._df_records),
            max(1, int(round(self._subsample * len(self._df_records)))),
        )
        self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
    self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x))
    self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name)
    self._df_records.set_index("record", inplace=True)
    self._all_records = self._df_records.index.values.tolist()

    label_columns = ["labels_n", "labels_a", "labels_f"]

    # find and load the annotation csv file
    ann_file = list(self.db_dir.rglob("REFERENCE.csv"))
    if len(ann_file) != 1:
        warnings.warn(
            "Annotation file not found. Please call method `_download_labels`, " "and call method `_ls_rec` again.",
            RuntimeWarning,
        )
        for c in label_columns:
            # plain column assignment; `.at` is meant for scalar (row, column) access only
            self._df_records[c] = None
        return

    df_ann = pd.read_csv(ann_file[0])
    # every column after the first one holds one (possibly missing) numeric label
    raw_codes = df_ann[df_ann.columns[1:]].values
    missing = np.isnan(raw_codes)
    codes = [row[~missing[idx]].astype(int).tolist() for idx, row in enumerate(raw_codes)]
    # wrap in an object Series so that each cell holds one list of labels
    df_ann["labels_n"] = pd.Series(codes, index=df_ann.index, dtype=object)
    df_ann["labels_a"] = df_ann["labels_n"].apply(lambda nums: [self.diagnosis_num_to_abbr[i] for i in nums])
    df_ann["labels_f"] = df_ann["labels_a"].apply(lambda abbrs: [self.diagnosis_abbr_to_full[i] for i in abbrs])

    df_ann = df_ann[[df_ann.columns[0]] + label_columns]
    df_ann.columns = ["record"] + label_columns
    df_ann = df_ann.set_index("record")

    # merge `df_ann` and `self._df_records`
    self._df_records = self._df_records.merge(df_ann, how="left", left_index=True, right_index=True)
def get_subject_id(self, rec: Union[int, str]) -> int:
    """Attach a unique subject ID for the record.

    Parameters
    ----------
    rec : str or int
        Record name or index of the record in :attr:`all_records`.

    Returns
    -------
    sid : int
        Subject ID associated with the record.

    Raises
    ------
    KeyError
        If the upper-case prefix of the record name has no known code.

    """
    if isinstance(rec, int):
        rec = self[rec]
    # two-digit code for each record-name prefix
    s2d = {
        "A": "11",
        "B": "12",
    }
    # the prefix is made of all upper-case letters found in the record name
    prefix = "".join(re.findall(r"[A-Z]", rec))
    n = rec.replace(prefix, "")
    # prefix code followed by the remaining part, left-padded with zeros to 8 characters
    sid = int(f"{s2d[prefix]}{n.rjust(8, '0')}")
    return sid
def load_data(
    self,
    rec: Union[int, str],
    leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
    data_format: str = "channel_first",
    units: Optional[str] = "mV",
    return_fs: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
    """Load the ECG data of a record.

    Parameters
    ----------
    rec : int or str
        Record name or index of the record in :attr:`all_records`.
    leads : str or int or Sequence[str] or Sequence[int], optional
        The leads to load,
        None or "all" for all leads.
    data_format : str, default "channel_first"
        Format of the ECG data, "channel_last" or "channel_first" (original)
        (with aliases "lead_last" and "lead_first"), case insensitive.
    units : str, default "mV"
        Units of the output signal, can also be "μV" (with an alias "uV"),
        case insensitive.
        If None, the values are returned as stored, without any rescaling.
    return_fs : bool, default False
        Whether to return the sampling frequency of the output signal.

    Returns
    -------
    data : numpy.ndarray
        The loaded ECG data.
    data_fs : numbers.Real, optional
        Sampling frequency of the output signal.
        Returned if `return_fs` is True.

    """
    if isinstance(rec, int):
        rec = self[rec]
    rec_fp = self.get_absolute_path(rec, self.rec_ext)
    _leads = self._normalize_leads(leads, numeric=True)

    # NOTE(review): "lead_first" is assumed as the counterpart of the "lead_last" alias; confirm
    allowed_data_format = [
        "channel_first",
        "lead_first",
        "channel_last",
        "lead_last",
    ]
    assert (
        data_format.lower() in allowed_data_format
    ), f"`data_format` should be one of `{allowed_data_format}`, but got `{data_format}`"
    allowed_units = ["mv", "uv", "μv", "muv"]
    assert (
        units is None or units.lower() in allowed_units
    ), f"`units` should be one of `{allowed_units}` or None, but got `{units}`"

    data = loadmat(str(rec_fp))
    # the file stores the signal with leads along the first axis
    data = np.asarray(data["ECG"]["data"][0, 0], dtype=DEFAULTS.DTYPE.NP)[_leads, :]
    if data_format.lower() in ["channel_last", "lead_last"]:
        data = data.T

    # `None` is accepted by the validation above, so it must not reach `.lower()`
    if units is not None:
        target_units = units.lower()
        if target_units == "mv":
            if self._auto_infer_units(data) != "mV":
                data /= 1000
        elif target_units in ["uv", "μv", "muv"]:
            if self._auto_infer_units(data) != "μV":
                data *= 1000

    if return_fs:
        return data, self.fs
    return data
def load_ann(self, rec: Union[str, int], ann_format: str = "n") -> Union[List[str], List[int]]:
    """Load labels (diagnoses or arrhythmias) of a record.

    Parameters
    ----------
    rec : str or int
        Record name or index of the record in :attr:`all_records`.
    ann_format : str, default "n"
        Format of labels, one of the following (case insensitive):

            - "a", abbreviations
            - "f", full names
            - "n", numeric codes

    Returns
    -------
    labels : List[str] or List[int]
        The list of labels (integers for the numeric format).

    Raises
    ------
    ValueError
        If `ann_format` is not one of "a", "f", "n".

    """
    if isinstance(rec, int):
        rec = self[rec]
    try:
        col = {
            "a": "labels_a",
            "f": "labels_f",
            "n": "labels_n",
        }[ann_format.lower()]
    except KeyError:
        # the KeyError itself carries no extra information for the caller
        raise ValueError(f"`ann_format` should be one of `['a', 'f', 'n']`, but got `{ann_format}`") from None
    labels = self._df_records.loc[rec, col]
    return labels
@add_docstring(load_ann.__doc__)
def get_labels(self, rec: Union[str, int], ann_format: str = "n") -> Union[List[str], List[int]]:
    """alias of `load_ann`"""
    # thin alias; all validation and lookup is done by `load_ann`
    return self.load_ann(rec, ann_format)
def get_subject_info(self, rec: Union[int, str], items: Optional[List[str]] = None) -> dict:
    """Get subject information (e.g sex, age, etc.).

    Parameters
    ----------
    rec : int or str
        Record name or index of the record in :attr:`all_records`.
    items : List[str], optional
        Items of the subject information (e.g. sex, age, etc.).
        If None or empty, all available items ("age" and "sex") are returned.
        Items that are not available are ignored.

    Returns
    -------
    subject_info : dict
        The subject information, restricted to the requested items.

    """
    if items is None or len(items) == 0:
        info_items = ["age", "sex"]
    else:
        info_items = items
    if isinstance(rec, int):
        rec = self[rec]
    rec_fp = self.get_absolute_path(rec, self.ann_ext)
    data = loadmat(str(rec_fp))["ECG"]
    # fields are read from the "ECG" struct stored in the .mat file
    available = {
        "age": data["age"][0, 0][0, 0],
        "sex": data["sex"][0, 0][0],
    }
    # honour `items`: return only what was asked for, in the requested order
    subject_info = {item: available[item] for item in info_items if item in available}
    return subject_info
[docs] def plot(
rec: Union[int, str],
ticks_granularity: int = 0,
leads: Optional[Union[str, List[str]]] = None,
**kwargs: Any,
) -> None:
"""Plot the ECG data of a record.
rec : int or str
Record name or index of the record in :attr:`all_records`.
ticks_granularity : int, default 0
Granularity to plot axis ticks, the higher the more ticks.
0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks)
leads : str or List[str], optional
The leads to plot
kwargs : dict, optional
Auxiliary keyword arguments to pass to :func:`matplotlib.pyplot.subplots`.
if isinstance(rec, int):
rec = self[rec]
if "plt" not in dir():
import matplotlib.pyplot as plt
leads = self._normalize_leads(leads)
data = self.load_data(rec, data_format="channel_first", units="μV", leads=leads)
y_ranges = np.max(np.abs(data), axis=1) + 100
diag = self.get_labels(rec, ann_format="f")
nb_leads = len(leads)
t = np.arange(data.shape[1]) / self.fs
duration = len(t) / self.fs
fig_sz_w = int(round(DEFAULT_FIG_SIZE_PER_SEC * duration))
fig_sz_h = 6 * y_ranges / 1500
fig, axes = plt.subplots(nb_leads, 1, sharex=True, figsize=(fig_sz_w, np.sum(fig_sz_h)))
for idx in range(nb_leads):
label=f"lead - {leads[idx]}",
axes[idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red")
if ticks_granularity >= 1:
axes[idx].grid(which="major", linestyle="-", linewidth="0.5", color="red")
if ticks_granularity >= 2:
axes[idx].grid(which="minor", linestyle=":", linewidth="0.5", color="black")
axes[idx].plot([], [], " ", label=f"labels - {','.join(diag)}")
axes[idx].legend(loc="upper left")
axes[idx].set_xlim(t[0], t[-1])
axes[idx].set_ylim(-y_ranges[idx], y_ranges[idx])
axes[idx].set_xlabel("Time [s]")
axes[idx].set_ylabel("Voltage [μV]")
if kwargs.get("save_path", None):
plt.savefig(kwargs["save_path"], dpi=200, bbox_inches="tight")
def url(self) -> List[str]:
return [
def database_info(self) -> DataBaseInfo:
    """Return the module-level :class:`DataBaseInfo` object describing this database."""
    info = _CPSC2018_INFO
    return info
def webpage(self) -> str:
    """Return the URL of the challenge web page."""
    page = "http://2018.icbeb.org/Challenge.html"
    return page
def _download_labels(self) -> None:
    """Fetch ``REFERENCE.csv`` into ``self.db_dir`` via ``http_get``, without extraction."""
    http_get("http://2018.icbeb.org/file/REFERENCE.csv", self.db_dir, extract=False)
def download(self) -> None:
    """Download the database from :attr:`self.url`.

    Every URL in :attr:`self.url` is fetched into ``self.db_dir``
    with extraction enabled.
    """
    for url in self.url:
        http_get(url, self.db_dir, extract=True)
def compute_metrics():
    """Placeholder for the challenge metrics computation.

    Raises
    ------
    NotImplementedError
        Always; the metrics are not implemented yet.

    """
    raise NotImplementedError("`compute_metrics` is not implemented yet")