# -*- coding: utf-8 -*-
import json
import math
import os
from copy import deepcopy
from numbers import Real
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import wfdb
from ...cfg import CFG
from ...utils.misc import add_docstring
from ...utils.utils_interval import generalized_intervals_intersection
from ..base import DEFAULT_FIG_SIZE_PER_SEC, BeatAnn, DataBaseInfo, PhysioNetDataBase
# Public API of this module: only the database reader class is exported.
__all__ = [
"LTAFDB",
]
# Static metadata describing the Long Term AF Database.
# It is used to build the docstring of `LTAFDB` (via `format_database_docstring`)
# and is returned by the `LTAFDB.database_info` property.
# NOTE(review): `about` cites "[1]_" and "[2]_" and `note` points to "ref 3",
# while `references` below lists only two entries -- confirm whether a third
# reference is missing.
_LTAFDB_INFO = DataBaseInfo(
title="""
Long Term AF Database
""",
about="""
1. contains 84 long-term ECG recordings of subjects with paroxysmal or sustained atrial fibrillation
2. each record contains two simultaneously recorded ECG signals digitized at 128 Hz
3. records have duration 24 - 25 hours
4. qrs annotations (.qrs files) were produced by an automated QRS detector, in which detected beats (including occasional ventricular ectopic beats) are labelled "N", detected artifacts are labelled "|", and AF terminations are labelled "T" (inserted manually)
5. atr annotations (.atr files) were obtained by manual review of the output of an automated ECG analysis system; in these annotation files, all detected beats are labelled by type ('"', "+", "A", "N", "Q", "V"), and rhythm changes ("\x01 Aux", "(AB", "(AFIB", "(B", "(IVR", "(N", "(SBR", "(SVTA", "(T", "(VT", "M", "MB", "MISSB", "PSE") are also annotated
6. Webpage of the database on PhysioNet [1]_. Paper describing the database [2]_.
""",
note="""
1. both channels of the signals have name "ECG"
2. the automatically generated qrs annotations (.qrs files) contains NO rhythm annotations
3. `aux_note` of .atr files of all but one ("64") record start with valid rhythms, all but one end with "" ("30" ends with "\x01 Aux")
4. for more statistics on the whole database, see [ref 3](#ref3)
""",
# Typical tasks this database is used for.
usage=[
"Atrial fibrillation (AF) detection",
"(3 or 4) beat type classification",
"Rhythm classification",
],
# Database webpage followed by the paper describing the database.
references=[
"https://physionet.org/content/ltafdb/1.0.0/",
"Petrutiu S, Sahakian AV, Swiryn S. Abrupt changes in fibrillatory wave characteristics at the termination of paroxysmal atrial fibrillation in humans. Europace 9:466-470 (2007).",
],
# DOIs associated with the database.
doi=[
"10.1093/europace/eum096",
"10.13026/C2QG6Q",
],
)
[docs]@add_docstring(_LTAFDB_INFO.format_database_docstring(), mode="prepend")
class LTAFDB(PhysioNetDataBase):
"""
Parameters
----------
db_dir : `path-like`, optional
Storage path of the database.
If not specified, data will be fetched from Physionet.
working_dir : `path-like`, optional
Working directory, to store intermediate files and log files.
verbose : int, default 1
Level of logging verbosity.
kwargs : dict, optional
Auxiliary keyword arguments.
"""
__name__ = "LTAFDB"
def __init__(
    self,
    db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
    working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
    verbose: int = 1,
    **kwargs: Any,
) -> None:
    """Initialize the reader and its database-specific constants.

    Parameters
    ----------
    db_dir : `path-like`, optional
        Storage path of the database, forwarded to the base class.
    working_dir : `path-like`, optional
        Working directory, forwarded to the base class.
    verbose : int, default 1
        Level of logging verbosity, forwarded to the base class.
    kwargs : dict, optional
        Auxiliary keyword arguments, forwarded to the base class.
        If it contains ``"palette"`` (a mapping from rhythm names and/or
        ``"qrs"`` to colors), its entries override the default colors.
        The mapping passed in is copied, never modified.

    """
    # Imported here so that matplotlib is only needed once a reader is created.
    from matplotlib.pyplot import cm

    super().__init__(
        db_name="ltafdb",
        db_dir=db_dir,
        working_dir=working_dir,
        verbose=verbose,
        **kwargs,
    )
    self.fs = 128
    self.data_ext = "dat"
    self.auto_ann_ext = "qrs"
    self.manual_ann_ext = "atr"
    self.all_leads = [0, 1]

    self._ls_rec()

    self.rhythm_types = [
        "(N",
        "(AB",
        "(AFIB",
        "(B",
        "(IVR",
        "(SBR",
        "(SVTA",
        "(T",
        "(VT",
        "NOISE",  # additional, since head of each record are noisy
    ]  # others include "\x01 Aux", "M", "MB", "MISSB", "PSE"
    # Rhythm name (without the leading "(") -> integer code used in masks.
    self.rhythm_types_map = CFG({k.replace("(", ""): idx for idx, k in enumerate(self.rhythm_types)})

    self.beat_types = [
        "A",
        "N",
        "Q",
        "V",
        # '"', "+", are not beat types
    ]

    # Default colors: one rainbow color per rhythm that gets shaded in plots
    # ("N" and "NOISE" are never shaded), plus a fixed color for QRS markers.
    shaded_rhythms = [k for k in self.rhythm_types_map.keys() if k not in ["N", "NOISE"]]
    palette = dict(zip(shaded_rhythms, cm.rainbow(np.linspace(0, 1, len(shaded_rhythms)))))
    palette["qrs"] = "green"

    # Merge the caller's palette on top of the defaults instead of adopting and
    # mutating the caller's object: the caller's mapping stays untouched, and a
    # partial palette no longer leads to missing keys when plotting.
    user_palette = kwargs.get("palette", None)
    if user_palette is not None:
        palette.update(user_palette)
    self.palette = CFG(palette)
def get_subject_id(self, rec: Union[str, int]) -> int:
    """Return a unique subject ID for the given record.

    Not implemented for this database: calling it always raises.

    Parameters
    ----------
    rec : str or int
        Record name or index of the record in :attr:`all_records`.

    Returns
    -------
    int
        Subject ID associated with the record.

    Raises
    ------
    NotImplementedError
        Always.

    """
    raise NotImplementedError
@add_docstring(
    PhysioNetDataBase.load_data.__doc__.replace(
        "leads: str or int or sequence of str or int, optional,",
        "leads: int or list of int, optional,",
    ).replace("the leads to load", "the lead number(s) to load")
)
def load_data(
    self,
    rec: Union[str, int],
    leads: Optional[Union[int, List[int]]] = None,
    sampfrom: Optional[int] = None,
    sampto: Optional[int] = None,
    data_format: str = "channel_first",
    units: str = "mV",
    fs: Optional[Real] = None,
    return_fs: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
    # The docstring comes from the decorator above: it is the base class
    # docstring with the description of `leads` narrowed to lead numbers.
    # All arguments are handed to the base implementation unchanged and in order.
    loaded = super().load_data(rec, leads, sampfrom, sampto, data_format, units, fs, return_fs)
    return loaded
def load_ann(
    self,
    rec: Union[str, int],
    sampfrom: Optional[int] = None,
    sampto: Optional[int] = None,
    rhythm_format: str = "interval",
    beat_format: str = "beat",
    keep_original: bool = False,
) -> dict:
    """Load both beat and rhythm annotations of a record.

    This is a thin combination of :meth:`load_beat_ann` and
    :meth:`load_rhythm_ann`; see those methods for details.

    Parameters
    ----------
    rec : str or int
        Record name or index of the record in :attr:`all_records`.
    sampfrom : int, optional
        Start index of the annotations to be loaded.
    sampto : int, optional
        End index of the annotations to be loaded.
    rhythm_format : {"interval", "mask"}, optional
        Format of the returned rhythm annotations, by default "interval",
        case insensitive.
    beat_format : {"beat", "dict"}, optional
        Format of the returned beat annotations, by default "beat",
        case insensitive.
    keep_original : bool, default False
        If True, indices are kept the same as in the annotation file,
        otherwise `sampfrom` is subtracted if specified.

    Returns
    -------
    ann : dict
        Dictionary with the keys ``"beat"`` and ``"rhythm"``, holding the
        results of :meth:`load_beat_ann` and :meth:`load_rhythm_ann`
        respectively.

    """
    # Resolve an integer index into a record name once, for both loaders.
    record = self[rec] if isinstance(rec, int) else rec
    beats = self.load_beat_ann(record, sampfrom, sampto, beat_format, keep_original)
    rhythms = self.load_rhythm_ann(record, sampfrom, sampto, rhythm_format, keep_original)
    return {"beat": beats, "rhythm": rhythms}
def load_rhythm_ann(
    self,
    rec: Union[str, int],
    sampfrom: Optional[int] = None,
    sampto: Optional[int] = None,
    rhythm_format: str = "interval",
    keep_original: bool = False,
) -> Union[Dict[str, list], np.ndarray]:
    """Load rhythm annotations of a record.

    Rhythm changes are read from the `aux_note` attribute of the manual
    (.atr) annotation file and turned into intervals per rhythm.
    The intervals of the whole record are cached in ``<rec>_ann.json``
    under :attr:`db_dir` and read back from there on later calls.

    Parameters
    ----------
    rec : str or int
        Record name or index of the record in :attr:`all_records`.
    sampfrom : int, optional
        Start index of the annotations to be loaded.
    sampto : int, optional
        End index of the annotations to be loaded.
    rhythm_format : {"interval", "mask"}, optional
        Format of the returned annotations, by default "interval",
        case insensitive.
    keep_original : bool, default False
        If True, indices are kept the same as in the annotation file,
        otherwise `sampfrom` is subtracted if specified.
        Only affects the "interval" format.

    Returns
    -------
    ann : dict or numpy.ndarray
        For "interval": mapping from rhythm name to a list of
        ``[start, end]`` intervals (rhythms without intervals are omitted).
        For "mask": 1D integer array of length ``sampto - sampfrom`` holding
        the codes of :attr:`rhythm_types_map`, with "N" as fill value.

    NOTE
    ----
    The part before the first valid rhythm annotation and the part after
    the last annotation are labelled "NOISE".

    """
    if isinstance(rec, int):
        rec = self[rec]
    fmt = rhythm_format.lower()
    assert fmt in [
        "interval",
        "mask",
    ], f"rhythm_format must be 'interval' or 'mask', got {rhythm_format}"

    record_path = str(self.get_absolute_path(rec))
    sig_len = wfdb.rdheader(record_path).sig_len
    seg_start = sampfrom or 0
    seg_end = sampto or sig_len
    assert seg_end > seg_start, "`sampto` should be greater than `sampfrom`!"

    cache_file = self.db_dir / f"{rec}_ann.json"
    if cache_file.is_file():
        intervals = CFG(json.loads(cache_file.read_text()))
    else:
        wfdb_ann = wfdb.rdann(record_path, extension=self.manual_ann_ext)
        samples = wfdb_ann.sample.tolist()
        intervals = CFG({name: [] for name in self.rhythm_types_map.keys()})
        # Walk through the annotations; each valid rhythm label closes the
        # running segment and opens a new one. The record starts as "NOISE".
        running_name, running_onset = "NOISE", 0
        for sample, note in zip(samples, wfdb_ann.aux_note):
            if note in self.rhythm_types:
                intervals[running_name].append([running_onset, sample])
                running_name, running_onset = note.replace("(", ""), sample
        # all but one end with "" ("30" ends with "\x01 Aux")
        # i.e. none ends with (start of) valid rhythm
        last_sample = samples[-1]
        intervals[running_name].append([running_onset, last_sample])
        intervals["NOISE"].append([last_sample, sig_len])
        cache_file.write_text(json.dumps(intervals, ensure_ascii=False))

    # Restrict every rhythm to the requested window and drop empty rhythms.
    clipped = CFG()
    for name, l_itv in intervals.items():
        within = generalized_intervals_intersection(l_itv, [[seg_start, seg_end]])
        if len(within) > 0:
            clipped[name] = within

    if fmt == "mask":
        mask = np.full(shape=(seg_end - seg_start,), fill_value=self.rhythm_types_map.N, dtype=int)
        for name, l_itv in clipped.items():
            code = self.rhythm_types_map[name]
            for onset, offset in l_itv:
                mask[onset - seg_start : offset - seg_start] = code
        return mask

    if not keep_original:
        for name in list(clipped.keys()):
            clipped[name] = [[onset - seg_start, offset - seg_start] for onset, offset in clipped[name]]
    return clipped
def load_beat_ann(
    self,
    rec: Union[str, int],
    sampfrom: Optional[int] = None,
    sampto: Optional[int] = None,
    beat_format: str = "beat",
    keep_original: bool = False,
) -> Union[Dict[str, np.ndarray], List[BeatAnn]]:
    """Load beat annotations of a record.

    Beat labels are read from the `symbol` attribute of the manual (.atr)
    annotation file; symbols not listed in :attr:`beat_types` are skipped.

    Parameters
    ----------
    rec : str or int
        Record name or index of the record in :attr:`all_records`.
    sampfrom : int, optional
        Start index of the annotations to be loaded.
    sampto : int, optional
        End index of the annotations to be loaded.
    beat_format : {"beat", "dict"}, optional
        Format of the returned annotations, by default "beat",
        case insensitive.
    keep_original : bool, default False
        If True, indices are kept the same as in the annotation file,
        otherwise `sampfrom` is subtracted if specified.

    Returns
    -------
    ann : dict or list
        For "dict": mapping from beat type ("A", "N", "Q", "V") to an
        integer array of beat locations (indices).
        For "beat": list of :class:`~torch_ecg.databases.BeatAnn`,
        grouped by beat type in the order of :attr:`beat_types`.

    """
    if isinstance(rec, int):
        rec = self[rec]
    fmt = beat_format.lower()
    assert fmt in [
        "beat",
        "dict",
    ], f"beat_format must be 'beat' or 'dict', got {beat_format}"

    record_path = str(self.get_absolute_path(rec))
    sig_len = wfdb.rdheader(record_path).sig_len
    seg_start = sampfrom or 0
    seg_end = sampto or sig_len
    assert seg_end > seg_start, "`sampto` should be greater than `sampfrom`!"

    wfdb_ann = wfdb.rdann(
        record_path,
        extension=self.manual_ann_ext,
        sampfrom=seg_start,
        sampto=sampto,
    )

    # Collect the sample indices per beat type, ignoring non-beat symbols.
    grouped = CFG({symbol: [] for symbol in self.beat_types})
    for sample, symbol in zip(wfdb_ann.sample, wfdb_ann.symbol):
        if symbol in self.beat_types:
            grouped[symbol].append(sample)

    # Indices are made relative to `sampfrom` only when it was given explicitly.
    make_relative = not keep_original and sampfrom is not None
    beat_arrays = CFG()
    for symbol, samples in grouped.items():
        indices = np.array(samples, dtype=int)
        beat_arrays[symbol] = indices - seg_start if make_relative else indices

    if fmt == "beat":
        return [BeatAnn(index, symbol) for symbol, indices in beat_arrays.items() for index in indices]
    return beat_arrays
def load_rpeak_indices(
    self,
    rec: Union[str, int],
    sampfrom: Optional[int] = None,
    sampto: Optional[int] = None,
    use_manual: bool = True,
    keep_original: bool = False,
) -> np.ndarray:
    """Load rpeak indices of a record.

    Rpeak indices, or equivalently qrs complex locations, are the sample
    indices of the annotations whose `symbol` is one of :attr:`beat_types`,
    regardless of which beat type it is.

    Parameters
    ----------
    rec : str or int
        Record name or index of the record in :attr:`all_records`.
    sampfrom : int, optional
        Start index of the annotations to be loaded.
    sampto : int, optional
        End index of the annotations to be loaded.
    use_manual : bool, default True
        If True, the manual annotation file (.atr) is read,
        otherwise the automatically generated one (.qrs).
    keep_original : bool, default False
        If True, indices are kept the same as in the annotation file,
        otherwise `sampfrom` is subtracted if specified.

    Returns
    -------
    rpeak_inds : numpy.ndarray
        Locations (indices) of all the rpeaks (qrs complexes).

    """
    # Resolve integer record indices here, like the other loaders of this
    # class do, instead of relying on `get_absolute_path` to accept them.
    if isinstance(rec, int):
        rec = self[rec]
    fp = str(self.get_absolute_path(rec))
    ext = self.manual_ann_ext if use_manual else self.auto_ann_ext
    wfdb_ann = wfdb.rdann(
        fp,
        extension=ext,
        sampfrom=sampfrom or 0,
        sampto=sampto,
    )
    # Keep only annotations that mark beats; other symbols are dropped.
    rpeak_inds = wfdb_ann.sample[np.isin(wfdb_ann.symbol, self.beat_types)]
    if not keep_original and sampfrom is not None:
        rpeak_inds = rpeak_inds - sampfrom
    return rpeak_inds
def plot(
    self,
    rec: Union[str, int],
    data: Optional[np.ndarray] = None,
    ann: Optional[Dict[str, np.ndarray]] = None,
    beat_ann: Optional[Dict[str, np.ndarray]] = None,
    rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
    ticks_granularity: int = 0,
    leads: Optional[Union[int, List[int]]] = None,
    sampfrom: Optional[int] = None,
    sampto: Optional[int] = None,
    same_range: bool = False,
    **kwargs: Any,
) -> None:
    """Plot the signals of a record, or externally supplied signals.

    The signals are drawn in μV, one figure per 25 seconds of data, together
    with rhythm intervals (shaded), rpeaks (thin vertical bands) and
    beat type labels (text).

    Parameters
    ----------
    rec : str or int
        Record name or index of the record in :attr:`all_records`.
    data : numpy.ndarray, optional
        ECG signal to plot, in the format "channel_first".
        If given, the data of `rec` is not loaded, and the plotted leads
        are numbered ``0 .. data.shape[0] - 1``.
        Its units are inferred automatically; data detected to be in mV
        is multiplied by 1000.
        This is useful when plotting filtered data.
    ann : dict, optional
        Rhythm annotations in the form of ``{k: l_itv, ...}``,
        where ``k`` are keys of `self.rhythm_types_map`,
        and ``l_itv`` are of the form ``[[a, b], ...]`` (sample indices).
        If None and `data` is None, they are loaded from the annotation
        file; if None and `data` is given, no rhythm is drawn.
    beat_ann : dict, optional
        Beat annotations in the form of ``{k: l_inds, ...}``,
        where ``k`` are listed in `self.beat_types`,
        and ``l_inds`` are arrays of sample indices.
        If None and `data` is None, they are loaded from the annotation
        file; if None and `data` is given, no beat label is drawn.
    rpeak_inds : array_like, optional
        Indices of R peaks (a sequence or a numpy array).
        If None and `data` is None, they are loaded from the annotation
        file; if None and `data` is given, no rpeak is drawn.
        NOTE that the given indices are converted to seconds as they are,
        `sampfrom` is not subtracted from them.
    ticks_granularity : int, default 0
        Granularity to plot axis ticks, the higher the more ticks.
        0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks)
    leads : int or List[int], optional
        The lead number(s) of the data to plot.
        None or "all" selects all leads.
    sampfrom : int, optional
        Start index of the data to plot.
    sampto : int, optional
        End index of the data to plot.
    same_range : bool, default False
        If True, all leads are forced to have the same y range.
    kwargs : dict, optional
        Currently not used.

    """
    if isinstance(rec, int):
        rec = self[rec]

    # Imported here so that matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt

    # NOTE that `Locator` has default `MAXTICKS` equal to 1000,
    # which the fine tick grids below may exceed.
    plt.MultipleLocator.MAXTICKS = 3000

    if leads is None or leads == "all":
        _leads = self.all_leads
    elif isinstance(leads, int):
        _leads = [leads]
    else:
        _leads = leads
    assert all([ld in self.all_leads for ld in _leads])

    if data is None:
        _data = self.load_data(
            rec,
            leads=_leads,
            sampfrom=sampfrom,
            sampto=sampto,
            data_format="channel_first",
            units="μV",
        )
    else:
        units = self._auto_infer_units(data)
        self.logger.info(f"input data is auto detected to have units in {units}")
        if units.lower() == "mv":
            _data = 1000 * data
        else:
            _data = data
        # External data carries no lead names: number its rows instead.
        _leads = list(range(_data.shape[0]))

    if ann is None and data is None:
        _ann = self.load_rhythm_ann(
            rec,
            sampfrom=sampfrom,
            sampto=sampto,
            rhythm_format="interval",
            keep_original=False,
        )
    else:
        _ann = ann or CFG({k: [] for k in self.rhythm_types_map.keys()})
    # indices to time
    _ann = {k: [[itv[0] / self.fs, itv[1] / self.fs] for itv in l_itv] for k, l_itv in _ann.items()}

    if rpeak_inds is None and data is None:
        _rpeak = self.load_rpeak_indices(
            rec,
            sampfrom=sampfrom,
            sampto=sampto,
            use_manual=True,
            keep_original=False,
        )
        _rpeak = _rpeak / self.fs  # indices to time
    else:
        # `rpeak_inds` may be a numpy array, whose truth value is ambiguous,
        # hence the explicit comparison against None instead of `or`.
        _rpeak = np.array(rpeak_inds if rpeak_inds is not None else []) / self.fs  # indices to time

    if beat_ann is None and data is None:
        _beat_ann = self.load_beat_ann(
            rec,
            beat_format="dict",
            sampfrom=sampfrom,
            sampto=sampto,
            keep_original=False,
        )
    else:
        _beat_ann = beat_ann or CFG({k: [] for k in self.beat_types})
    _beat_ann = {k: [i / self.fs for i in l_inds] for k, l_inds in _beat_ann.items()}  # indices to time

    ann_plot_alpha = 0.2
    rpeaks_plot_alpha = 0.8

    nb_leads = len(_leads)

    line_len = self.fs * 25  # 25 seconds
    nb_lines = math.ceil(_data.shape[1] / line_len)

    for seg_idx in range(nb_lines):
        seg_data = _data[..., seg_idx * line_len : (seg_idx + 1) * line_len]
        # Time axis of this segment, in seconds from the first plotted sample.
        secs = (np.arange(seg_data.shape[1]) + seg_idx * line_len) / self.fs
        # Restrict all annotations to the time span of this segment.
        seg_ann = {k: generalized_intervals_intersection(l_itv, [[secs[0], secs[-1]]]) for k, l_itv in _ann.items()}
        seg_rpeaks = _rpeak[(_rpeak >= secs[0]) & (_rpeak < secs[-1])]
        seg_beat_ann = {k: [i for i in l_inds if secs[0] <= i <= secs[-1]] for k, l_inds in _beat_ann.items()}
        fig_sz_w = int(round(DEFAULT_FIG_SIZE_PER_SEC * seg_data.shape[1] / self.fs))
        if same_range:
            y_ranges = np.ones((seg_data.shape[0],)) * np.max(np.abs(seg_data)) + 100
        else:
            y_ranges = np.max(np.abs(seg_data), axis=1) + 100
        # Height of each lead is proportional to its y range.
        fig_sz_h = 6 * y_ranges / 1500
        fig, axes = plt.subplots(nb_leads, 1, sharex=True, figsize=(fig_sz_w, np.sum(fig_sz_h)))
        if nb_leads == 1:
            axes = [axes]
        for idx in range(nb_leads):
            axes[idx].plot(secs, seg_data[idx], color="black", label=f"lead - {_leads[idx]}")
            axes[idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red")
            if ticks_granularity >= 1:
                axes[idx].xaxis.set_major_locator(plt.MultipleLocator(0.2))
                axes[idx].yaxis.set_major_locator(plt.MultipleLocator(500))
                axes[idx].grid(which="major", linestyle="-", linewidth="0.5", color="red")
            if ticks_granularity >= 2:
                axes[idx].xaxis.set_minor_locator(plt.MultipleLocator(0.04))
                axes[idx].yaxis.set_minor_locator(plt.MultipleLocator(100))
                axes[idx].grid(which="minor", linestyle=":", linewidth="0.5", color="black")
            # Shade rhythm intervals; "N" and "NOISE" have no palette color.
            for k, l_itv in seg_ann.items():
                if k in ["N", "NOISE"]:
                    continue
                for itv in l_itv:
                    axes[idx].axvspan(
                        itv[0],
                        itv[1],
                        color=self.palette[k],
                        alpha=ann_plot_alpha,
                        label=k,
                    )
            # Mark each rpeak with a 20 ms wide band.
            for ri in seg_rpeaks:
                axes[idx].axvspan(
                    ri - 0.01,
                    ri + 0.01,
                    color=self.palette["qrs"],
                    alpha=rpeaks_plot_alpha,
                )
            # Write the beat type next to each beat; labels close to the
            # right border are moved to the left of the beat instead.
            for k, l_t in seg_beat_ann.items():
                for t in l_t:
                    x_pos = t + 0.05 if t + 0.05 < secs[-1] else t - 0.15
                    axes[idx].text(x_pos, 0.65 * y_ranges[idx], k, color="black", fontsize=16)
            axes[idx].legend(loc="upper left")
            axes[idx].set_xlim(secs[0], secs[-1])
            axes[idx].set_ylim(-y_ranges[idx], y_ranges[idx])
            axes[idx].set_xlabel("Time [s]")
            axes[idx].set_ylabel("Voltage [μV]")
        plt.subplots_adjust(hspace=0.2)
        plt.show()
@property
def database_info(self) -> DataBaseInfo:
    """The :class:`DataBaseInfo` object describing this database."""
    info = _LTAFDB_INFO
    return info