# -*- coding: utf-8 -*-
import os
from copy import deepcopy
from numbers import Real
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
import wfdb
from ...cfg import CFG, DEFAULTS
from ...utils import EAK
from ...utils.misc import add_docstring
from ...utils.utils_data import ECGWaveForm, masks_to_waveforms
from ..base import DataBaseInfo, PhysioNetDataBase
# Public names exported by ``from <this module> import *``.
__all__ = ["LUDB", "compute_metrics"]
_LUDB_INFO = DataBaseInfo(
title="""
Lobachevsky University Electrocardiography Database
""",
about="""
1. consists of 200 10-second conventional 12-lead (i, ii, iii, avr, avl, avf, v1, v2, v3, v4, v5, v6) ECG signal records, with sampling frequency 500 Hz
2. boundaries of P waves, T waves and QRS complexes were manually annotated by cardiologists, along with the corresponding diagnoses
3. annotated are 16797 P waves, 21966 QRS complexes, 19666 T waves (in total, 58429 annotated waves)
4. distributions of data:
- rhythm distribution:
+-------------------------+----------------+
| Rhythms | Number of ECGs |
+=========================+================+
| Sinus rhythm | 143 |
+-------------------------+----------------+
| Sinus tachycardia | 4 |
+-------------------------+----------------+
| Sinus bradycardia | 25 |
+-------------------------+----------------+
| Sinus arrhythmia | 8 |
+-------------------------+----------------+
| Irregular sinus rhythm | 2 |
+-------------------------+----------------+
| Abnormal rhythm | 19 |
+-------------------------+----------------+
- electrical axis distribution:
+----------------------------+----------------+
| Heart electric axis | Number of ECGs |
+============================+================+
| Normal | 75 |
+----------------------------+----------------+
| Left axis deviation (LAD) | 66 |
+----------------------------+----------------+
| Vertical | 26 |
+----------------------------+----------------+
| Horizontal | 20 |
+----------------------------+----------------+
| Right axis deviation (RAD) | 3 |
+----------------------------+----------------+
| Undetermined | 10 |
+----------------------------+----------------+
- distribution of records with conduction abnormalities (totally 79):
+------------------------------------------------+----------------+
| Conduction abnormalities | Number of ECGs |
+================================================+================+
| Sinoatrial blockade, undetermined | 1 |
+------------------------------------------------+----------------+
| I degree AV block | 10 |
+------------------------------------------------+----------------+
| III degree AV-block | 5 |
+------------------------------------------------+----------------+
| Incomplete right bundle branch block | 29 |
+------------------------------------------------+----------------+
| Incomplete left bundle branch block | 6 |
+------------------------------------------------+----------------+
| Left anterior hemiblock | 16 |
+------------------------------------------------+----------------+
| Complete right bundle branch block | 4 |
+------------------------------------------------+----------------+
| Complete left bundle branch block | 4 |
+------------------------------------------------+----------------+
| Non-specific intraventricular conduction delay | 4 |
+------------------------------------------------+----------------+
- distribution of records with extrasystoles (totally 35):
+------------------------------------------------------------------+----------------+
| Extrasystoles | Number of ECGs |
+==================================================================+================+
| Atrial extrasystole, undetermined | 2 |
+------------------------------------------------------------------+----------------+
| Atrial extrasystole, low atrial | 1 |
+------------------------------------------------------------------+----------------+
| Atrial extrasystole, left atrial | 2 |
+------------------------------------------------------------------+----------------+
| Atrial extrasystole, SA-nodal extrasystole | 3 |
+------------------------------------------------------------------+----------------+
| Atrial extrasystole, type: single PAC | 4 |
+------------------------------------------------------------------+----------------+
| Atrial extrasystole, type: bigemini | 1 |
+------------------------------------------------------------------+----------------+
| Atrial extrasystole, type: quadrigemini | 1 |
+------------------------------------------------------------------+----------------+
| Atrial extrasystole, type: allorhythmic pattern | 1 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, morphology: polymorphic | 2 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, localisation: RVOT, anterior wall | 3 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, localisation: RVOT, antero-septal part | 1 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, localisation: IVS, middle part | 1 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, localisation: LVOT, LVS | 2 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, localisation: LV, undefined | 1 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, type: single PVC | 6 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, type: intercalary PVC | 2 |
+------------------------------------------------------------------+----------------+
| Ventricular extrasystole, type: couplet | 2 |
+------------------------------------------------------------------+----------------+
- distribution of records with hypertrophies (totally 253):
+-------------------------------+----------------+
| Hypertrophies | Number of ECGs |
+===============================+================+
| Right atrial hypertrophy | 1 |
+-------------------------------+----------------+
| Left atrial hypertrophy | 102 |
+-------------------------------+----------------+
| Right atrial overload | 17 |
+-------------------------------+----------------+
| Left atrial overload | 11 |
+-------------------------------+----------------+
| Left ventricular hypertrophy | 108 |
+-------------------------------+----------------+
| Right ventricular hypertrophy | 3 |
+-------------------------------+----------------+
| Left ventricular overload | 11 |
+-------------------------------+----------------+
- distribution of records of pacing rhythms (totally 12):
+------------------------------+----------------+
| Cardiac pacing | Number of ECGs |
+==============================+================+
| UNIpolar atrial pacing | 1 |
+------------------------------+----------------+
| UNIpolar ventricular pacing | 6 |
+------------------------------+----------------+
| BIpolar ventricular pacing | 2 |
+------------------------------+----------------+
| Biventricular pacing | 1 |
+------------------------------+----------------+
| P-synchrony | 2 |
+------------------------------+----------------+
- distribution of records with ischemia (totally 141):
+-------------------------------------------------------+----------------+
| Ischemia | Number of ECGs |
+=======================================================+================+
| STEMI: anterior wall | 8 |
+-------------------------------------------------------+----------------+
| STEMI: lateral wall | 7 |
+-------------------------------------------------------+----------------+
| STEMI: septal | 8 |
+-------------------------------------------------------+----------------+
| STEMI: inferior wall | 1 |
+-------------------------------------------------------+----------------+
| STEMI: apical | 5 |
+-------------------------------------------------------+----------------+
| Ischemia: anterior wall | 5 |
+-------------------------------------------------------+----------------+
| Ischemia: lateral wall | 8 |
+-------------------------------------------------------+----------------+
| Ischemia: septal | 4 |
+-------------------------------------------------------+----------------+
| Ischemia: inferior wall | 10 |
+-------------------------------------------------------+----------------+
| Ischemia: posterior wall | 2 |
+-------------------------------------------------------+----------------+
| Ischemia: apical | 6 |
+-------------------------------------------------------+----------------+
| Scar formation: lateral wall | 3 |
+-------------------------------------------------------+----------------+
| Scar formation: septal | 9 |
+-------------------------------------------------------+----------------+
| Scar formation: inferior wall | 3 |
+-------------------------------------------------------+----------------+
| Scar formation: posterior wall | 6 |
+-------------------------------------------------------+----------------+
| Scar formation: apical | 5 |
+-------------------------------------------------------+----------------+
| Undefined ischemia/scar/supp.NSTEMI: anterior wall | 12 |
+-------------------------------------------------------+----------------+
| Undefined ischemia/scar/supp.NSTEMI: lateral wall | 16 |
+-------------------------------------------------------+----------------+
| Undefined ischemia/scar/supp.NSTEMI: septal | 5 |
+-------------------------------------------------------+----------------+
| Undefined ischemia/scar/supp.NSTEMI: inferior wall | 3 |
+-------------------------------------------------------+----------------+
| Undefined ischemia/scar/supp.NSTEMI: posterior wall | 4 |
+-------------------------------------------------------+----------------+
| Undefined ischemia/scar/supp.NSTEMI: apical | 11 |
+-------------------------------------------------------+----------------+
- distribution of records with non-specific repolarization abnormalities (totally 85):
+-------------------------------------------+----------------+
| Non-specific repolarization abnormalities | Number of ECGs |
+===========================================+================+
| Anterior wall | 18 |
+-------------------------------------------+----------------+
| Lateral wall | 13 |
+-------------------------------------------+----------------+
| Septal | 15 |
+-------------------------------------------+----------------+
| Inferior wall | 19 |
+-------------------------------------------+----------------+
| Posterior wall | 9 |
+-------------------------------------------+----------------+
| Apical | 11 |
+-------------------------------------------+----------------+
- there are also 9 records with early repolarization syndrome
note that a single record may well have multiple conditions.
5. ludb.csv stores information about the subjects (gender, age, rhythm type, direction of the electrical axis of the heart, the presence of a cardiac pacemaker, etc.)
6. Webpage of the database on PhysioNet [1]_. Paper describing the database [2]_.
""",
usage=[
"ECG wave delineation",
"ECG arrhythmia classification",
],
issues="""
1. (version 1.0.0, fixed in version 1.0.1) ADC gain might be wrong, either `units` should be μV, or `adc_gain` should be 1000 times larger
""",
references=[
"https://physionet.org/content/ludb/1.0.1/",
"Kalyakulina, A., Yusipov, I., Moskalenko, V., Nikolskiy, A., Kozlov, A., Kosonogov, K., Zolotykh, N., & Ivanchenko, M. (2020). Lobachevsky University Electrocardiography Database (version 1.0.1).",
],
doi=[
"10.1109/ACCESS.2020.3029211",
"10.13026/eegm-h675",
],
)
@add_docstring(_LUDB_INFO.format_database_docstring(), mode="prepend")
class LUDB(PhysioNetDataBase):
    """
    Parameters
    ----------
    db_dir : `path-like`, optional
        Storage path of the database.
        If not specified, data will be fetched from Physionet.
    working_dir : `path-like`, optional
        Working directory, to store intermediate files and log files.
    verbose : int, default 1
        Level of logging verbosity.
    kwargs : dict, optional
        Auxiliary key word arguments.

    """

    __name__ = "LUDB"

    def __init__(
        self,
        db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            db_name="ludb",
            db_dir=db_dir,
            working_dir=working_dir,
            verbose=verbose,
            **kwargs,
        )
        if self.version == "1.0.0":
            self.logger.info("Version of LUDB 1.0.0 has bugs, make sure that version 1.0.1 or higher is used")
        self.fs = 500
        # sampling interval in milliseconds
        self.spacing = 1000 / self.fs
        self.data_ext = "dat"
        self.all_leads = deepcopy(EAK.Standard12Leads)
        self.all_leads_lower = [ld.lower() for ld in self.all_leads]
        # Version 1.0.0 has different beat_ann_ext: [f"atr_{item}" for item in self.all_leads_lower]
        self.beat_ann_ext = [f"{item}" for item in self.all_leads_lower]

        # All symbols appearing in the annotation files.
        # This can be obtained using the following code:
        #     >>> data_gen = LUDB(db_dir="/home/wenhao71/data/PhysioNet/ludb/1.0.1/")
        #     >>> all_symbols = set()
        #     >>> for rec in data_gen.all_records:
        #     ...     for ext in data_gen.beat_ann_ext:
        #     ...         ann = wfdb.rdann(data_gen.get_absolute_path(rec), extension=ext)
        #     ...         all_symbols.update(ann.symbol)
        self._all_symbols = ["(", ")", "p", "N", "t"]
        self._symbol_to_wavename = CFG(N="qrs", p="pwave", t="twave")
        self._wavename_to_symbol = CFG({v: k for k, v in self._symbol_to_wavename.items()})
        self.class_map = CFG(p=1, N=2, t=3, i=0)  # an extra isoelectric

        self._df_subject_info = None
        self._ls_rec()

    def _ls_rec(self) -> None:
        """Find all records in the database directory
        and store them (path, metadata, etc.) in some private attributes.

        Subject information is read from ``ludb.csv`` when that file exists;
        otherwise an empty table with the expected columns is created.
        """
        super()._ls_rec()

        def _strip(value: Any) -> Any:
            # only strings are stripped; other cells (e.g. NaN, or numbers
            # when pandas parses a column as numeric) are kept unchanged
            return value.strip() if isinstance(value, str) else value

        subject_info_path = self.db_dir / "ludb.csv"
        if subject_info_path.is_file():
            # newly added in version 1.0.1
            df_info = pd.read_csv(subject_info_path)
            df_info["ID"] = df_info["ID"].apply(str)
            df_info["Sex"] = df_info["Sex"].apply(_strip)
            df_info["Age"] = df_info["Age"].apply(_strip)
            self._df_subject_info = df_info
        else:
            self._df_subject_info = pd.DataFrame(
                columns=[
                    "ID",
                    "Sex",
                    "Age",
                    "Rhythms",
                    "Electric axis of the heart",
                    "Conduction abnormalities",
                    "Extrasystolies",
                    "Hypertrophies",
                    "Cardiac pacing",
                    "Ischemia",
                    "Non-specific repolarization abnormalities",
                    "Other states",
                ]
            )

    def get_subject_id(self, rec: Union[str, int]) -> int:
        """Attach a unique subject ID for the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        int
            Subject ID associated with the record
            (the record name converted to an integer).

        """
        if isinstance(rec, int):
            rec = self[rec]
        return int(rec)

    def get_absolute_path(self, rec: Union[str, int], extension: Optional[str] = None) -> Path:
        """Get the absolute path of the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        extension : str, optional
            Extension of the file, with or without the leading dot.

        Returns
        -------
        pathlib.Path
            Absolute path of the file, located in the ``data``
            sub-directory of :attr:`db_dir`.

        """
        if isinstance(rec, int):
            rec = self[rec]
        if extension is not None and not extension.startswith("."):
            extension = f".{extension}"
        return self.db_dir / "data" / f"{rec}{extension or ''}"

    def load_ann(
        self,
        rec: Union[str, int],
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        metadata: bool = False,
    ) -> dict:
        """Load the wave delineation, along with metadata if specified.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        leads : str or int or List[str] or List[int], optional
            The leads of the wave delineation to be loaded.
        metadata : bool, default False
            If True, metadata will be loaded from corresponding head file.

        Returns
        -------
        ann_dict : dict
            The wave delineation annotations. Item ``"waves"`` maps each
            lead name to a list of :class:`ECGWaveForm`; if `metadata`
            is True, the header fields from :meth:`_load_header` are
            merged into the dict as well.

        """
        if isinstance(rec, int):
            rec = self[rec]
        ann_dict = CFG()
        rec_fp = str(self.get_absolute_path(rec))

        # wave delineation annotations
        _leads = self._normalize_leads(leads)
        ann_dict["waves"] = CFG({ld: [] for ld in _leads})
        for ld in _leads:
            # annotation extension is the lower-cased lead name
            # (version 1.0.0 used ``atr_<lead>`` instead)
            ann = wfdb.rdann(rec_fp, extension=ld.lower())
            symbols = np.array(ann.symbol)
            n_symbols = len(symbols)
            peak_inds = np.where(np.isin(symbols, ["p", "N", "t"]))[0]
            for peak_idx in peak_inds:
                peak = ann.sample[peak_idx]
                # onset/offset come from the neighbouring "(" / ")" markers;
                # fall back to the peak itself when a marker is missing,
                # including at the boundaries of the annotation sequence
                if peak_idx > 0 and symbols[peak_idx - 1] == "(":
                    onset = ann.sample[peak_idx - 1]
                else:
                    onset = peak
                if peak_idx < n_symbols - 1 and symbols[peak_idx + 1] == ")":
                    offset = ann.sample[peak_idx + 1]
                else:
                    offset = peak
                ann_dict["waves"][ld].append(
                    ECGWaveForm(
                        name=self._symbol_to_wavename[str(symbols[peak_idx])],
                        onset=int(onset),
                        offset=int(offset),
                        peak=int(peak),
                        duration=float(offset - onset) * self.spacing,  # ms
                    )
                )

        if metadata:
            header_dict = self._load_header(rec)
            ann_dict.update(header_dict)

        return ann_dict

    def load_diagnoses(self, rec: Union[str, int]) -> List[str]:
        """Load diagnoses of the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        diagnoses : List[str]
            List of diagnoses of the record.

        """
        diagnoses = self._load_header(rec)["diagnoses"]
        return diagnoses

    def load_masks(
        self,
        rec: Union[str, int],
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        mask_format: str = "channel_first",
        class_map: Optional[Dict[str, int]] = None,
    ) -> np.ndarray:
        """Load the wave delineation in the form of masks.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        leads : str or int or List[str] or List[int], optional
            The leads of the wave delineation to be loaded.
        mask_format : str, default "channel_first"
            Format of the mask,
            "channel_last" (alias "lead_last"), or
            "channel_first" (alias "lead_first").
        class_map : dict, optional
            Custom class map.
            If not set, `self.class_map` will be used.

        Returns
        -------
        masks : numpy.ndarray
            The masks corresponding to the wave delineation annotations
            of the record.

        """
        if isinstance(rec, int):
            rec = self[rec]
        _class_map = CFG(class_map) if class_map is not None else self.class_map
        data = self.load_data(rec, leads=leads, data_format="channel_first")
        # background (isoelectric) class everywhere, then paint the waves
        masks = np.full_like(data, fill_value=_class_map.i, dtype=int)
        waves = self.load_ann(rec, leads=leads, metadata=False)["waves"]
        for idx, lead_waves in enumerate(waves.values()):
            for w in lead_waves:
                masks[idx, w.onset : w.offset] = _class_map[self._wavename_to_symbol[w.name]]
        if mask_format.lower() not in ["channel_first", "lead_first"]:
            masks = masks.T
        return masks

    def from_masks(
        self,
        masks: np.ndarray,
        mask_format: str = "channel_first",
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        class_map: Optional[Dict[str, int]] = None,
        fs: Optional[Real] = None,
    ) -> Dict[str, List[ECGWaveForm]]:
        """Convert masks into lists of waveforms.

        Parameters
        ----------
        masks : numpy.ndarray
            Wave delineation in the form of masks,
            of shape ``(n_leads, seq_len)`` or ``(seq_len,)``.
        mask_format : str, default "channel_first"
            Format of the mask, used only when ``masks.ndim = 2``.
            One of "channel_last" (alias "lead_last"), or
            "channel_first" (alias "lead_first").
        leads : str or int or List[str] or List[int], optional
            The leads of the wave delineation to be loaded.
            If not given, leads are named ``lead_1``, ``lead_2``, etc.
        class_map : dict, optional
            Custom class map.
            If not set, `self.class_map` will be used.
        fs : numbers.Real, optional
            Sampling frequency of the signal corresponding to the `masks`,
            used to compute `duration` of the ECG waveforms.
            If None, `self.fs` will be used.

        Returns
        -------
        waves : dict
            The wave delineation annotations of the record.
            Each item value of the dict is a list
            containing the :class:`ECGWaveForm` corr. to the lead (item key),
            sorted by onset.

        Raises
        ------
        ValueError
            If `masks` is neither 1D nor 2D.

        """
        if masks.ndim == 1:
            _masks = masks[np.newaxis, ...]
        elif masks.ndim == 2:
            if mask_format.lower() not in ["channel_first", "lead_first"]:
                _masks = masks.T
            else:
                _masks = masks.copy()
        else:
            raise ValueError(f"masks should be of dim 1 or 2, but got a {masks.ndim}d array")

        if leads is not None:
            _leads = self._normalize_leads(leads)
        else:
            _leads = [f"lead_{idx+1}" for idx in range(_masks.shape[0])]
        assert len(_leads) == _masks.shape[0], "number of leads must match the number of masks"

        _class_map = CFG(class_map) if class_map is not None else self.class_map
        _fs = fs if fs is not None else self.fs

        waves = CFG({lead_name: [] for lead_name in _leads})
        for channel_idx, lead_name in enumerate(_leads):
            current_mask = _masks[channel_idx, ...]
            for wave_symbol, wave_number in _class_map.items():
                # only P waves, QRS complexes and T waves are converted
                if wave_symbol not in ["p", "N", "t"]:
                    continue
                wave_name = self._symbol_to_wavename[wave_symbol]
                current_wave_inds = np.where(current_mask == wave_number)[0]
                if len(current_wave_inds) == 0:
                    continue
                # split into runs of consecutive samples: after this,
                # `split_inds` holds (start, end) positions pairwise
                split_inds = np.where(np.diff(current_wave_inds) > 1)[0].tolist()
                split_inds = sorted(split_inds + [i + 1 for i in split_inds])
                split_inds = [0] + split_inds + [len(current_wave_inds) - 1]
                for i in range(len(split_inds) // 2):
                    itv_start = current_wave_inds[split_inds[2 * i]]
                    itv_end = current_wave_inds[split_inds[2 * i + 1]] + 1  # exclusive
                    w = ECGWaveForm(
                        name=wave_name,
                        onset=itv_start,
                        offset=itv_end,
                        peak=np.nan,
                        duration=1000 * (itv_end - itv_start) / _fs,  # ms
                    )
                    waves[lead_name].append(w)
            waves[lead_name].sort(key=lambda w: w.onset)
        return waves

    def _load_header(self, rec: Union[str, int]) -> dict:
        """Load header data of a record into a dict.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        header_dict : dict
            The header data of the record, with keys
            ``units``, ``baseline``, ``adc_gain``, ``record_fmt``,
            ``age`` (NaN if not parsable), ``sex`` (empty string if absent)
            and ``diagnoses`` (empty list if absent).

        """
        header_dict = CFG({})
        rec_fp = str(self.get_absolute_path(rec))
        header_reader = wfdb.rdheader(rec_fp)
        header_dict["units"] = header_reader.units
        header_dict["baseline"] = header_reader.baseline
        header_dict["adc_gain"] = header_reader.adc_gain
        header_dict["record_fmt"] = header_reader.fmt
        comments = header_reader.comments
        # age and sex are parsed on a best-effort basis
        try:
            header_dict["age"] = int([line for line in comments if "<age>" in line][0].split(": ")[-1])
        except Exception:
            header_dict["age"] = np.nan
        try:
            header_dict["sex"] = [line for line in comments if "<sex>" in line][0].split(": ")[-1]
        except Exception:
            header_dict["sex"] = ""
        # diagnoses are the comment lines following the "<diagnoses>" marker
        d_markers = [idx for idx, line in enumerate(comments) if "<diagnoses>" in line]
        header_dict["diagnoses"] = comments[d_markers[0] + 1 :] if d_markers else []
        return header_dict

    def load_subject_info(self, rec: Union[str, int], fields: Optional[Union[str, Sequence[str]]] = None) -> Union[dict, str]:
        """Load subject info of a record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        fields : str or Sequence[str], optional
            Field(s) of the subject info of record.
            If not specified, all fields of the subject info will be loaded.

        Returns
        -------
        info : dict or str
            Subject info of the given fields of the record.
            An empty dict is returned if the record has no subject info.

        """
        if isinstance(rec, int):
            rec = self[rec]
        row = self._df_subject_info[self._df_subject_info.ID == rec]
        if row.empty:
            return {}
        row = row.iloc[0]
        info = row.to_dict()
        if fields is not None:
            if isinstance(fields, str):
                assert fields in self._df_subject_info.columns, f"No field `{fields}`"
                info = info[fields]
            else:
                assert set(fields).issubset(
                    set(self._df_subject_info.columns)
                ), f"No field(s) {set(fields).difference(set(self._df_subject_info.columns))}"
                info = {k: v for k, v in info.items() if k in fields}
        return info

    def plot(
        self,
        rec: Union[str, int],
        data: Optional[np.ndarray] = None,
        ticks_granularity: int = 0,
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        same_range: bool = False,
        waves: Optional[Dict[str, List[ECGWaveForm]]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Plot the signals of a record or external signals (units in μV),
        with the diagnoses of the record shown in the legend,
        possibly also along with wave delineations.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        data : numpy.ndarray, optional
            12-lead ECG signal to plot.
            If is not None, data of `rec` will not be used.
            This is useful when plotting filtered data.
        ticks_granularity : int, default 0
            Granularity to plot axis ticks, the higher the more ticks.
            0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks)
        leads : str or int or List[str] or List[int], optional
            The leads of the ECG signal to plot.
        same_range : bool, default False
            If True, all leads are forced to have the same y range.
        waves : dict, optional
            The waves (p waves, t waves, qrs complexes) to plot,
            mapping lead names to lists of :class:`ECGWaveForm`,
            e.g. the ``"waves"`` item returned by :meth:`load_ann`.
            Waves of leads that are not plotted are ignored.
            If not given and `data` is None,
            the annotations of `rec` are loaded and plotted.
        kwargs : dict, optional
            Additional keyword arguments.
            Currently only ``save_path`` is used: if given, the figure
            is saved to this path instead of being shown.

        TODO
        ----
        1. Slice too long records, and plot separately for each segment.

        NOTE
        ----
        `Locator` of ``plt`` has default `MAXTICKS` of 1000.
        If not modifying this number, at most 40 seconds of signal could be plotted once.

        Contributors: Jeethan, and WEN Hao

        """
        import matplotlib.pyplot as plt

        if isinstance(rec, int):
            rec = self[rec]
        # NOTE that `Locator` has default `MAXTICKS` equal to 1000
        plt.MultipleLocator.MAXTICKS = 3000

        if data is not None:
            assert leads is not None, "`leads` must be specified when `data` is given"
            data = np.atleast_2d(data)
            _leads = self._normalize_leads(leads)
            # also validates that every lead is a known lead name
            _lead_indices = [self.all_leads.index(ld) for ld in _leads]
            assert len(_leads) == data.shape[0], "number of leads must match data"
            units = self._auto_infer_units(data)
            self.logger.info(f"input data is auto detected to have units in {units}")
            if units.lower() == "mv":
                _data = 1000 * data
            else:
                _data = data
        else:
            _leads = self._normalize_leads(leads)
            _lead_indices = [self.all_leads.index(ld) for ld in _leads]
            _data = self.load_data(rec, data_format="channel_first", units="μV")[_lead_indices]

        if same_range:
            y_ranges = np.ones((_data.shape[0],)) * np.max(np.abs(_data)) + 100
        else:
            y_ranges = np.max(np.abs(_data), axis=1) + 100

        if data is None and waves is None:
            waves = self.load_ann(rec, leads=_leads)["waves"]

        # [onset, offset] intervals of each wave type, for each plotted lead
        intervals = {key: {ld: [] for ld in _leads} for key in ["pwaves", "qrs", "twaves"]}
        wavename_to_key = {
            self._symbol_to_wavename["p"]: "pwaves",
            self._symbol_to_wavename["N"]: "qrs",
            self._symbol_to_wavename["t"]: "twaves",
        }
        if waves is not None:
            for lead_name, lead_waves in waves.items():
                if lead_name not in intervals["pwaves"]:
                    continue  # this lead is not plotted
                for wv in lead_waves:
                    key = wavename_to_key.get(wv.name)
                    if key is not None:
                        intervals[key][lead_name].append([wv.onset, wv.offset])

        palette = {
            "pwaves": "green",
            "qrs": "red",
            "twaves": "yellow",
        }
        plot_alpha = 0.4

        diagnoses = self.load_diagnoses(rec)

        nb_leads = len(_leads)
        t = np.arange(_data.shape[1]) / self.fs
        duration = len(t) / self.fs  # seconds
        fig_sz_w = int(round(4.8 * duration))
        fig_sz_h = 6 * y_ranges / 1500
        _, axes = plt.subplots(nb_leads, 1, sharex=True, figsize=(fig_sz_w, np.sum(fig_sz_h)))
        if nb_leads == 1:
            axes = [axes]
        for idx, lead_name in enumerate(_leads):
            ax = axes[idx]
            ax.plot(
                t,
                _data[idx],
                color="black",
                linewidth=2.0,
                label=f"lead - {lead_name}",
            )
            ax.axhline(y=0, linestyle="-", linewidth=1.0, color="red")
            if ticks_granularity >= 1:
                ax.xaxis.set_major_locator(plt.MultipleLocator(0.2))
                ax.yaxis.set_major_locator(plt.MultipleLocator(500))
                ax.grid(which="major", linestyle="-", linewidth=0.5, color="red")
            if ticks_granularity >= 2:
                ax.xaxis.set_minor_locator(plt.MultipleLocator(0.04))
                ax.yaxis.set_minor_locator(plt.MultipleLocator(100))
                ax.grid(which="minor", linestyle=":", linewidth=0.5, color="black")
            # add extra info. to legend
            # https://stackoverflow.com/questions/16826711/is-it-possible-to-add-a-string-as-a-legend-item-in-matplotlib
            for d in diagnoses:
                ax.plot([], [], " ", label=d)
            for key in ["pwaves", "qrs", "twaves"]:
                for itv in intervals[key][lead_name]:
                    ax.axvspan(
                        itv[0] / self.fs,
                        itv[1] / self.fs,
                        color=palette[key],
                        alpha=plot_alpha,
                    )
            ax.legend(loc="upper left")
            ax.set_xlim(t[0], t[-1])
            ax.set_ylim(-y_ranges[idx], y_ranges[idx])
            ax.set_xlabel("Time [s]")
            ax.set_ylabel("Voltage [μV]")
        plt.subplots_adjust(hspace=0.2)
        if kwargs.get("save_path", None):
            plt.savefig(kwargs["save_path"], dpi=200, bbox_inches="tight")
        else:
            plt.show()

    @property
    def database_info(self) -> DataBaseInfo:
        """Information of the database (title, description, references, etc.)."""
        return _LUDB_INFO
# Module-level constants: the names of the delineated waves,
# and a tolerance expressed in milliseconds.
__WaveNames = ["pwave", "qrs", "twave"]
__TOLERANCE = 150  # ms
def compute_metrics(
    truth_masks: Sequence[np.ndarray],
    pred_masks: Sequence[np.ndarray],
    class_map: Dict[str, int],
    fs: Real,
    mask_format: str = "channel_first",
) -> Dict[str, Dict[str, float]]:
    """Compute metrics for the wave delineation task.

    Compute metrics
    (sensitivity, precision, f1_score, mean error
    and standard deviation of the mean errors)
    for multiple evaluations.

    Parameters
    ----------
    truth_masks : Sequence[numpy.ndarray]
        A sequence of ground truth masks,
        each of which is either a single 1D mask,
        or a 2D array holding multiple masks
        from different samples (differ by record or by lead),
        laid out according to `mask_format`.
    pred_masks : Sequence[numpy.ndarray]
        Predictions corresponding to `truth_masks`.
    class_map : Dict[str, int]
        Class map, mapping wave names to class numbers from 0 to n_classes-1;
        the keys should contain "pwave", "qrs", "twave".
    fs : numbers.Real
        Sampling frequency of the signal corresponding to the masks,
        used to compute the duration of each waveform,
        hence the error and standard deviations of errors.
    mask_format : str, default "channel_first"
        Format of the 2D masks, one of the following:
        "channel_last" (alias "lead_last"), or
        "channel_first" (alias "lead_first").

    Returns
    -------
    scorings : dict
        A dictionary containing the scorings of onsets and offsets
        of pwaves, qrs complexes, twaves.
        Each scoring is a dict consisting of the following metrics:
        sensitivity, precision, f1_score, mean_error, standard_deviation
        (the latter two in milliseconds).

    """
    assert len(truth_masks) == len(pred_masks), (
        "`truth_masks` and `pred_masks` must have the same length, "
        f"but got {len(truth_masks)} and {len(pred_masks)}"
    )
    channel_first = mask_format.lower() in ["channel_first", "lead_first"]
    truth_waveforms, pred_waveforms = [], []
    # compute for each element
    for tm, pm in zip(truth_masks, pred_masks):
        if tm.ndim == 1:
            # a 1D array is a single mask, i.e. a single sample;
            # assumes `masks_to_waveforms` names it "lead_1"
            # (as `LUDB.from_masks` does) -- TODO confirm
            n_masks = 1
        elif channel_first:
            n_masks = tm.shape[0]
        else:
            n_masks = tm.shape[1]
        new_t = masks_to_waveforms(tm, class_map, fs, mask_format)
        # list of list of `ECGWaveForm`s, one inner list per sample
        truth_waveforms += [new_t[f"lead_{idx+1}"] for idx in range(n_masks)]
        new_p = masks_to_waveforms(pm, class_map, fs, mask_format)
        pred_waveforms += [new_p[f"lead_{idx+1}"] for idx in range(n_masks)]
    scorings = compute_metrics_waveform(truth_waveforms, pred_waveforms, fs)
    return scorings
def compute_metrics_waveform(
    truth_waveforms: Sequence[Sequence[ECGWaveForm]],
    pred_waveforms: Sequence[Sequence[ECGWaveForm]],
    fs: Real,
) -> Dict[str, Dict[str, float]]:
    """
    Compute the sensitivity, precision, f1_score, mean error
    and standard deviation of the mean errors,
    of evaluations on multiple samples (differ by records, or leads).

    The true positive, false positive and false negative counts,
    as well as the raw errors, are pooled over all samples first;
    the metrics are then computed once from the pooled values.

    Parameters
    ----------
    truth_waveforms : Sequence[Sequence[ECGWaveForm]]
        The ground truth. Each element is a sequence of
        :class:`ECGWaveForm` from the same sample.
    pred_waveforms : Sequence[Sequence[ECGWaveForm]]
        The predictions corresponding to `truth_waveforms`.
        Each element is a sequence of :class:`ECGWaveForm`
        from the same sample. Elements are paired with
        `truth_waveforms` via :func:`zip`, hence surplus elements
        of the longer sequence are ignored.
    fs : numbers.Real
        Sampling frequency of the signal corresponding to the waveforms,
        used to compute the duration of each waveform,
        hence the error and standard deviations of errors.

    Returns
    -------
    scorings : dict
        A dictionary containing the scorings of onsets and offsets
        of pwaves, qrs complexes, twaves, keyed by ``"{wave}_{term}"``
        (e.g. ``"qrs_onset"``).
        Each scoring is a dict consisting of the following metrics:
        sensitivity, precision, f1_score, mean_error, standard_deviation.
        ``mean_error`` and ``standard_deviation`` are the pooled errors
        scaled by ``1000 / fs``; they are NaN when no point of that
        kind was matched at all.
    """
    wave_names = ["pwave", "qrs", "twave"]
    terms = ["onset", "offset"]
    keys = [f"{wave}_{term}" for wave in wave_names for term in terms]

    truth_positive = CFG({key: 0 for key in keys})
    false_positive = CFG({key: 0 for key in keys})
    false_negative = CFG({key: 0 for key in keys})
    errors = CFG({key: [] for key in keys})

    # pool counts and raw errors over all (truth, prediction) sample pairs
    for tw, pw in zip(truth_waveforms, pred_waveforms):
        s = _compute_metrics_waveform(tw, pw, fs)
        for key in keys:
            truth_positive[key] += s[key]["truth_positive"]
            false_positive[key] += s[key]["false_positive"]
            false_negative[key] += s[key]["false_negative"]
            errors[key] += s[key]["errors"]

    scorings = CFG()
    for key in keys:
        tp = truth_positive[key]
        fp = false_positive[key]
        fn = false_negative[key]
        err = errors[key]
        # eps keeps the ratios finite when a denominator would be zero
        sensitivity = tp / (tp + fn + DEFAULTS.eps)
        precision = tp / (tp + fp + DEFAULTS.eps)
        f1_score = 2 * sensitivity * precision / (sensitivity + precision + DEFAULTS.eps)
        if len(err) > 0:
            mean_error = np.mean(err) * 1000 / fs
            standard_deviation = np.std(err) * 1000 / fs
        else:
            # np.mean / np.std of an empty list give NaN plus a RuntimeWarning;
            # report the same NaN explicitly, without the warning
            mean_error = standard_deviation = float("nan")
        scorings[key] = CFG(
            sensitivity=sensitivity,
            precision=precision,
            f1_score=f1_score,
            mean_error=mean_error,
            standard_deviation=standard_deviation,
        )
    return scorings
def _compute_metrics_waveform(
    truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], fs: Real
) -> Dict[str, Dict[str, float]]:
    """
    Compute the sensitivity, precision, f1_score, mean error
    and standard deviation of the mean errors,
    of evaluations on a single sample (the same record, the same lead).

    Parameters
    ----------
    truths : Sequence[ECGWaveForm]
        The ground truth.
    preds : Sequence[ECGWaveForm]
        The predictions corresponding to `truths`.
        Waveforms (in either sequence) whose ``name`` is not one of
        "pwave", "qrs", "twave" are ignored.
    fs : numbers.Real
        Sampling frequency of the signal corresponding to the waveforms,
        used to compute the duration of each waveform,
        hence the error and standard deviations of errors.

    Returns
    -------
    scorings : dict
        A dictionary containing the scorings of onsets and offsets
        of pwaves, qrs complexes, twaves, keyed by ``"{wave}_{term}"``.
        Each scoring is a dict consisting of the following metrics:
        truth_positive, false_negative, false_positive, errors,
        sensitivity, precision, f1_score, mean_error, standard_deviation.
    """
    wave_names = ["pwave", "qrs", "twave"]
    terms = ["onset", "offset"]

    # critical points grouped as points[source][f"{wave}_{term}"],
    # source being "truths" or "preds"; replaces the former eval-based
    # lookup of dynamically named local lists
    points = {
        source: {f"{wave}_{term}": [] for wave in wave_names for term in terms}
        for source in ("truths", "preds")
    }
    for source, waveforms in (("truths", truths), ("preds", preds)):
        for w in waveforms:
            if w.name not in wave_names:
                # not an evaluated wave type; nothing to collect
                continue
            for term in terms:
                points[source][f"{w.name}_{term}"].append(getattr(w, term))

    scorings = CFG()
    for wave in wave_names:
        for term in terms:
            key = f"{wave}_{term}"
            (
                truth_positive,
                false_negative,
                false_positive,
                errors,
                sensitivity,
                precision,
                f1_score,
                mean_error,
                standard_deviation,
            ) = _compute_metrics_base(points["truths"][key], points["preds"][key], fs)
            scorings[key] = CFG(
                truth_positive=truth_positive,
                false_negative=false_negative,
                false_positive=false_positive,
                errors=errors,
                sensitivity=sensitivity,
                precision=precision,
                f1_score=f1_score,
                mean_error=mean_error,
                standard_deviation=standard_deviation,
            )
    return scorings
def _compute_metrics_base(
    truths: Sequence[Real],
    preds: Sequence[Real],
    fs: Real,
    tolerance: Optional[Real] = None,
) -> Tuple[int, int, int, List[float], float, float, float, float, float]:
    """The base function for computing the metrics.

    A ground-truth point is a true positive if at least one prediction lies
    within the tolerance window around it; the closest such prediction
    (the first one on ties) is its match and defines its error.
    Otherwise it is a false negative.
    Predictions that are not the match of any ground-truth point
    are false positives.

    Parameters
    ----------
    truths : Sequence[Real]
        Ground truth of indices of corresponding critical points.
    preds : Sequence[Real]
        Predicted indices of corresponding critical points.
    fs : numbers.Real
        Sampling frequency of the signal corresponding to the critical points,
        used to compute the duration of each waveform,
        hence the error and standard deviations of errors.
    tolerance : numbers.Real, optional
        Half-width of the matching window, in the same unit as the
        returned errors (converted to index units via ``* fs / 1000``).
        Defaults to the module-level ``__TOLERANCE``.

    Returns
    -------
    tuple
        tuple of the following metrics:
        truth_positive, false_negative, false_positive, errors,
        sensitivity, precision, f1_score, mean_error, standard_deviation.
        ``errors`` holds ``pred - truth`` in index units;
        ``mean_error`` and ``standard_deviation`` are scaled by
        ``1000 / fs`` and are NaN when there is no true positive.
    """
    if tolerance is None:
        tolerance = __TOLERANCE
    _tolerance = tolerance * fs / 1000
    _preds = np.array(preds)
    truth_positive, false_negative = 0, 0
    errors = []
    # indices (into `preds`) of predictions used as the match of some truth
    matched = set()
    for point in truths:
        candidates = np.flatnonzero(np.abs(_preds - point) <= _tolerance)
        if len(candidates) > 0:
            truth_positive += 1
            best = candidates[np.argmin(np.abs(_preds[candidates] - point))]
            matched.add(int(best))
            errors.append(_preds[best] - point)
        else:
            false_negative += 1
    # Counting distinct matched predictions (instead of `len(preds) - truth_positive`)
    # keeps this non-negative when one prediction is the closest match of
    # several ground-truth points; otherwise both formulas agree.
    false_positive = len(_preds) - len(matched)
    sensitivity = truth_positive / (truth_positive + false_negative + DEFAULTS.eps)
    precision = truth_positive / (truth_positive + false_positive + DEFAULTS.eps)
    f1_score = 2 * sensitivity * precision / (sensitivity + precision + DEFAULTS.eps)
    if len(errors) > 0:
        mean_error = np.mean(errors) * 1000 / fs
        standard_deviation = np.std(errors) * 1000 / fs
    else:
        # np.mean / np.std of an empty list give NaN plus a RuntimeWarning;
        # report the same NaN explicitly, without the warning
        mean_error = standard_deviation = float("nan")
    return (
        truth_positive,
        false_negative,
        false_positive,
        errors,
        sensitivity,
        precision,
        f1_score,
        mean_error,
        standard_deviation,
    )