# -*- coding: utf-8 -*-
import math
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
import numpy as np
import wfdb
from ...cfg import CFG
from ...utils.misc import add_docstring, get_record_list_recursive
from ...utils.utils_interval import generalized_intervals_intersection
from ..base import DEFAULT_FIG_SIZE_PER_SEC, DataBaseInfo, PhysioNetDataBase
__all__ = [
"AFDB",
]
_AFDB_INFO = DataBaseInfo(
title="""
MIT-BIH Atrial Fibrillation Database
""",
about="""
1. contains 25 long-term (each 10 hours) ECG recordings of human subjects with atrial fibrillation (mostly paroxysmal)
2. 23 records out of 25 include the two ECG signals, the left 2 records 00735 and 03665 are represented only by the rhythm (.atr) and unaudited beat (.qrs) annotation files
3. signals are sampled at 250 samples per second with 12-bit resolution over a range of ±10 millivolts, with a typical recording bandwidth of approximately 0.1 Hz to 40 Hz
4. 4 classes of rhythms are annotated:
- AFIB: atrial fibrillation
- AFL: atrial flutter
- J: AV junctional rhythm
- N: all other rhythms
5. rhythm annotations almost all start with "(N", except for 4 which start with '(AFIB', which are all within 1 second (250 samples)
6. Webpage of the database on PhysioNet [1]_. Paper describing the database [2]_.
""",
note="""
1. beat annotation files (.qrs files) were prepared using an automated detector and have NOT been corrected manually
2. for some records, manually corrected beat annotation files (.qrsc files) are available
3. one should never use wfdb.rdann with arguments `sampfrom`, since one has to know the `aux_note` (with values in ["(N", "(J", "(AFL", "(AFIB"]) before the index at `sampfrom`
""",
usage=[
"Atrial fibrillation (AF) detection",
],
references=[
"https://physionet.org/content/afdb/",
"Moody GB, Mark RG. A new method for detecting atrial fibrillation using R-R intervals. Computers in Cardiology. 10:227-230 (1983).",
],
doi=[
"10.13026/C2MW2D",
],
)
[docs]
@add_docstring(_AFDB_INFO.format_database_docstring(), mode="prepend")
class AFDB(PhysioNetDataBase):
"""
Parameters
----------
db_dir : `path-like`, optional
Storage path of the database.
If not specified, data will be fetched from Physionet.
working_dir : `path-like`, optional
Working directory, to store intermediate files and log files.
verbose : int, default 1
Level of logging verbosity.
kwargs : dict, optional
Auxilliary key word arguments.
"""
__name__ = "AFDB"
def __init__(
self,
db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
verbose: int = 1,
**kwargs: Any,
) -> None:
super().__init__(
db_name="afdb",
db_dir=db_dir,
working_dir=working_dir,
verbose=verbose,
**kwargs,
)
self.fs = 250
self.data_ext = "dat"
self.ann_ext = "atr"
self.auto_beat_ann_ext = "qrs"
self.manual_beat_ann_ext = "qrsc"
self.all_leads = [
"ECG1",
"ECG2",
]
self.special_records = ["00735", "03665"]
self.qrsc_records = None
self._ls_rec()
self.class_map = CFG(AFIB=1, AFL=2, J=3, N=0) # an extra isoelectric
self.palette = kwargs.get("palette", None)
if self.palette is None:
self.palette = CFG(
AFIB="blue",
AFL="red",
J="yellow",
# N="green",
qrs="green",
)
def _ls_rec(self, local: bool = True) -> None:
"""
Find all records (relative path without file extension),
and save into `self._all_records` for further use.
Parameters
----------
local : bool, default True
If True, read from local storage,
prior to using :func:`wfdb.get_record_list`.
"""
super()._ls_rec(local=local)
self._all_records = [rec for rec in self._all_records if rec not in self.special_records]
self.qrsc_records = get_record_list_recursive(self.db_dir, self.manual_beat_ann_ext, relative=False)
self.qrsc_records = [Path(rec).stem for rec in self.qrsc_records]
[docs]
def load_ann(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
ann_format: Literal["intervals", "mask"] = "intervals",
keep_original: bool = False,
) -> Union[Dict[str, list], np.ndarray]:
"""Load annotations (header) from the .hea files.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the annotations to be loaded.
sampto: int, optional
End index of the annotations to be loaded.
ann_format : {"intervals", "mask"}, default "intervals"
Format of returned annotation, case insensitive.
keep_original : bool, default False
If True, when `ann_format` is "intervals",
intervals (in the form [a,b]) will keep the same with the annotation file,
otherwise subtract `sampfrom` if specified.
Returns
-------
ann : dict or numpy.ndarray
The annotations in the format of intervals, or in the format of masks.
"""
fp = str(self.get_absolute_path(rec))
wfdb_ann = wfdb.rdann(fp, extension=self.ann_ext)
header = wfdb.rdheader(fp)
sig_len = header.sig_len
sf = sampfrom or 0
st = sampto or sig_len
assert st > sf, "`sampto` should be greater than `sampfrom`!"
ann = CFG({k: [] for k in self.class_map.keys()})
critical_points = wfdb_ann.sample.tolist() + [sig_len]
aux_note = wfdb_ann.aux_note
if aux_note[0] == "(N":
# ref. the doc string of the class
critical_points[0] = 0
else:
critical_points.insert(0, 0)
aux_note.insert(0, "(N")
for idx, rhythm in enumerate(aux_note):
ann[rhythm.replace("(", "")].append([critical_points[idx], critical_points[idx + 1]])
ann = CFG({k: generalized_intervals_intersection(l_itv, [[sf, st]]) for k, l_itv in ann.items()})
if ann_format.lower() == "mask":
tmp = deepcopy(ann)
ann = np.full(shape=(st - sf,), fill_value=self.class_map.N, dtype=int)
for rhythm, l_itv in tmp.items():
for itv in l_itv:
ann[itv[0] - sf : itv[1] - sf] = self.class_map[rhythm]
elif not keep_original:
for k, l_itv in ann.items():
ann[k] = [[itv[0] - sf, itv[1] - sf] for itv in l_itv]
return ann
[docs]
def load_beat_ann(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
use_manual: bool = True,
keep_original: bool = False,
) -> np.ndarray:
"""Load beat annotations from corresponding annotation files.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the annotations to be loaded.
sampto : int, optional
End index of the annotations to be loaded.
use_manual : bool, default True
If True, use manually annotated beat annotations (qrs),
instead of those generated by algorithms.
keep_original : bool, default False
If True, indices will keep the same with the annotation file,
otherwise subtract `sampfrom` if specified.
Returns
-------
ann : numpy.ndarray
Locations (indices) of the qrs complexes.
"""
if isinstance(rec, int):
rec = self[rec]
fp = str(self.get_absolute_path(rec))
if use_manual and rec in self.qrsc_records:
ext = self.manual_beat_ann_ext
else:
ext = self.auto_beat_ann_ext
ann = wfdb.rdann(
fp,
extension=ext,
sampfrom=sampfrom or 0,
sampto=sampto,
)
ann = ann.sample
if not keep_original and sampfrom is not None:
ann -= sampfrom
return ann
[docs]
@add_docstring(load_beat_ann.__doc__)
def load_rpeak_indices(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
use_manual: bool = True,
keep_original: bool = False,
) -> np.ndarray:
"""
alias of `self.load_beat_ann`
"""
return self.load_beat_ann(rec, sampfrom, sampto, use_manual, keep_original)
[docs]
def plot(
self,
rec: Union[str, int],
data: Optional[np.ndarray] = None,
ann: Optional[Dict[str, np.ndarray]] = None,
rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
ticks_granularity: int = 0,
leads: Optional[Union[str, int, List[str], List[int]]] = None,
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
same_range: bool = False,
**kwargs: Any,
) -> None:
"""
Plot the signals of a record or external signals (units in μV),
with metadata (fs, labels, tranche, etc.),
possibly also along with wave delineations.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
data : numpy.ndarray, optional
(2-lead) ECG signal to plot,
should be of the format "channel_first", and compatible with `leads`.
If given, data of `rec` will not be used,
which is useful when plotting filtered data.
ann : dict, optional
Annotations for `data`, covering those from annotation files,
in the form of ``{"AFIB":l_itv, "AFL":l_itv, "J":l_itv, "N":l_itv}``,
where ``l_itv`` in the form of ``[[a, b], ...]``.
Ignored if `data` is None.
rpeak_inds : array_like, optional
Indices of R peaks, covering those from annotation files.
If `data` is None, then indices should be the absolute indices in the record.
ticks_granularity : int, default 0
Granularity to plot axis ticks, the higher the more ticks.
0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks).
leads : str or int or List[str] or List[int], optional
The leads to plot.
sampfrom : int, optional
Start index of the data to plot.
sampto : int, optional
End index of the data to plot.
same_range : bool, default False
If True, forces all leads to have the same y range.
kwargs : dict, optional
Keyword arguments for :func:`matplotlib.pyplot.plot`, etc.
"""
if isinstance(rec, int):
rec = self[rec]
if "plt" not in dir():
import matplotlib.pyplot as plt
plt.MultipleLocator.MAXTICKS = 3000
if leads is None or leads == "all":
_leads = self.all_leads
elif isinstance(leads, str):
_leads = [leads]
elif isinstance(leads, int):
_leads = [self.all_leads[leads]]
else:
_leads = leads
assert all([ld in self.all_leads for ld in _leads]) or set(_leads) <= {0, 1}
try:
lead_indices = [self.all_leads.index(ld) for ld in _leads]
except ValueError:
lead_indices = _leads
if data is None:
_data = self.load_data(
rec,
leads=_leads,
sampfrom=sampfrom,
sampto=sampto,
data_format="channel_first",
units="μV",
)
else:
units = self._auto_infer_units(data)
self.logger.info(f"input data is auto detected to have units in `{units}`")
if units.lower() == "mv":
_data = 1000 * data
else:
_data = data
_leads = [f"ECG_{idx}" for idx in range(_data.shape[0])]
if ann is None and data is None:
_ann = self.load_ann(
rec,
sampfrom=sampfrom,
sampto=sampto,
ann_format="intervals",
keep_original=False,
)
else:
_ann = ann or CFG({k: [] for k in self.class_map.keys()})
# indices to time
_ann = {k: [[itv[0] / self.fs, itv[1] / self.fs] for itv in l_itv] for k, l_itv in _ann.items()}
if rpeak_inds is None and data is None:
_rpeak = self.load_rpeak_indices(
rec,
sampfrom=sampfrom,
sampto=sampto,
use_manual=True,
keep_original=False,
)
_rpeak = _rpeak / self.fs # indices to time
else:
_rpeak = np.array(rpeak_inds or []) / self.fs # indices to time
ann_plot_alpha = 0.2
rpeaks_plot_alpha = 0.8
nb_leads = len(_leads)
line_len = self.fs * 25 # 25 seconds
nb_lines = math.ceil(_data.shape[1] / line_len)
for seg_idx in range(nb_lines):
seg_data = _data[..., seg_idx * line_len : (seg_idx + 1) * line_len]
secs = (np.arange(seg_data.shape[1]) + seg_idx * line_len) / self.fs
seg_ann = {k: generalized_intervals_intersection(l_itv, [[secs[0], secs[-1]]]) for k, l_itv in _ann.items()}
seg_rpeaks = _rpeak[np.where((_rpeak >= secs[0]) & (_rpeak < secs[-1]))[0]]
fig_sz_w = int(round(DEFAULT_FIG_SIZE_PER_SEC * seg_data.shape[1] / self.fs))
if same_range:
y_ranges = np.ones((seg_data.shape[0],)) * np.max(np.abs(seg_data)) + 100
else:
y_ranges = np.max(np.abs(seg_data), axis=1) + 100
fig_sz_h = 6 * y_ranges / 1500
fig, axes = plt.subplots(nb_leads, 1, sharex=True, figsize=(fig_sz_w, np.sum(fig_sz_h)))
if nb_leads == 1:
axes = [axes]
for idx in range(nb_leads):
axes[idx].plot(secs, seg_data[idx], color="black", label=f"lead - {_leads[idx]}")
axes[idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red")
# NOTE that `Locator` has default `MAXTICKS` equal to 1000
if ticks_granularity >= 1:
axes[idx].xaxis.set_major_locator(plt.MultipleLocator(0.2))
axes[idx].yaxis.set_major_locator(plt.MultipleLocator(500))
axes[idx].grid(which="major", linestyle="-", linewidth="0.5", color="red")
if ticks_granularity >= 2:
axes[idx].xaxis.set_minor_locator(plt.MultipleLocator(0.04))
axes[idx].yaxis.set_minor_locator(plt.MultipleLocator(100))
axes[idx].grid(which="minor", linestyle=":", linewidth="0.5", color="black")
for k, l_itv in seg_ann.items():
if k == "N":
continue
for itv in l_itv:
axes[idx].axvspan(
itv[0],
itv[1],
color=self.palette[k],
alpha=ann_plot_alpha,
label=k,
)
for ri in seg_rpeaks:
axes[idx].axvspan(
ri - 0.01,
ri + 0.01,
color=self.palette["qrs"],
alpha=rpeaks_plot_alpha,
)
axes[idx].axvspan(
ri - 0.075,
ri + 0.075,
color=self.palette["qrs"],
alpha=ann_plot_alpha,
)
axes[idx].legend(loc="upper left")
axes[idx].set_xlim(secs[0], secs[-1])
axes[idx].set_ylim(-y_ranges[idx], y_ranges[idx])
axes[idx].set_xlabel("Time [s]")
axes[idx].set_ylabel("Voltage [μV]")
plt.subplots_adjust(hspace=0.2)
plt.show()
@property
def database_info(self) -> DataBaseInfo:
return _AFDB_INFO