# -*- coding: utf-8 -*-
import os
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Sequence, Union
import numpy as np
import pandas as pd
import wfdb
from tqdm.auto import tqdm
from ...cfg import CFG, DEFAULTS
from ...utils.misc import add_docstring, get_record_list_recursive3
from ...utils.utils_interval import generalized_intervals_intersection
from ..base import (
BeatAnn,
DataBaseInfo,
PhysioNetDataBase,
WFDB_Beat_Annotations,
WFDB_Non_Beat_Annotations,
WFDB_Rhythm_Annotations,
)
__all__ = [
"MITDB",
]
_MITDB_INFO = DataBaseInfo(
title="""
MIT-BIH Arrhythmia Database
""",
about="""
1. contains 48 half-hour excerpts of two-channel ambulatory ECG recordings, obtained from 47 subjects.
2. recordings were digitized at 360 samples per second per channel with 11-bit resolution over a 10 mV range.
3. annotations contains:
- beat-wise or finer (e.g. annotations of flutter wave) annotations, accessed via the `symbol` attribute of an `Annotation`.
- rhythm annotations, accessed via the `aux_note` attribute of an `Annotation`.
4. Webpage of the database on PhysioNet [1]_.
""",
usage=[
"Beat classification",
"Rhythm classification (segmentation)",
"R peaks detection",
],
references=[
"https://physionet.org/content/mitdb/",
],
doi=[
"10.1109/51.932724",
"10.13026/C2F305",
],
)
[docs]@add_docstring(_MITDB_INFO.format_database_docstring(), mode="prepend")
class MITDB(PhysioNetDataBase):
"""
Parameters
----------
db_dir : `path-like`, optional
Storage path of the database.
If not specified, data will be fetched from Physionet.
working_dir : `path-like`, optional
Working directory, to store intermediate files and log files.
verbose : int, default 1
Level of logging verbosity.
kwargs : dict, optional
Auxilliary key word arguments.
"""
__name__ = "MITDB"
def __init__(
self,
db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
verbose: int = 1,
**kwargs: Any,
) -> None:
super().__init__(
db_name="mitdb",
db_dir=db_dir,
working_dir=working_dir,
verbose=verbose,
**kwargs,
)
self.fs = 360
self.data_ext = "dat"
self.data_pattern = "^[\\d]{3}$"
self.data_pattern_with_ext = f"^[\\d]{{3}}\\.{self.data_ext}$"
self.ann_ext = "atr"
self.beat_types_extended = list("""!"+/AEFJLNQRSV[]aefjx|~""")
self.nonbeat_types = [item for item in self.beat_types_extended if item in WFDB_Non_Beat_Annotations]
self.beat_types = [item for item in self.beat_types_extended if item in WFDB_Beat_Annotations]
self.beat_types_map = {item: i for i, item in enumerate(self.beat_types)}
self.beat_types_extended_map = {item: i for i, item in enumerate(self.beat_types_extended)}
self.rhythm_types = [
"(AB",
"(AFIB",
"(AFL",
"(B",
"(BII",
"(IVR",
"(N",
"(NOD",
"(P",
"(PREX",
"(SBR",
"(SVTA",
"(T",
"(VFL",
"(VT",
"MISSB",
"PSE",
"TS",
]
self.rhythm_types = [rt.lstrip("(") for rt in self.rhythm_types if rt in WFDB_Rhythm_Annotations]
self.rhythm_types_map = {rt: idx for idx, rt in enumerate(self.rhythm_types)}
self._rhythm_ignore_index = -100
# records have different lead names
# therefore, self.all_leads should not be set
# otherwise, it will cause problems when loading data using `self.load_data`
self._all_leads = ["MLII", "V1", "V2", "V4", "V5"]
self._ls_rec()
self._stats = pd.DataFrame()
self._stats_columns = ["record", "beat_num", "beat_type_num", "rhythm_len"]
self._aggregate_stats()
def _ls_rec(self) -> None:
"""Find all records in the database directory
and store them (path, metadata, etc.) in some private attributes.
"""
subsample = self._subsample
self._subsample = None # so that no subsampling in super()._ls_rec()
super()._ls_rec()
# filters out records with names not matching `self.data_pattern`
if len(self._df_records) > 0:
self._df_records = self._df_records[self._df_records.index.str.match(self.data_pattern)]
if len(self._all_records) == 0:
self._df_records = pd.DataFrame()
self._df_records["path"] = get_record_list_recursive3(self.db_dir, self.data_pattern_with_ext, relative=False)
self._df_records["record"] = self._df_records["path"].apply(lambda x: x.stem)
self._df_records.set_index("record", inplace=True)
if subsample is not None:
size = min(
len(self._df_records),
max(1, int(round(subsample * len(self._df_records)))),
)
self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`")
self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
self._all_records = self._df_records.index.tolist()
self._subsample = subsample
def _aggregate_stats(self) -> None:
"""Aggregate statistics for all records in the database."""
self._stats = pd.DataFrame(columns=self._stats_columns)
if len(self) == 0:
return
with tqdm(
range(len(self)),
desc="Aggregating stats",
unit="record",
dynamic_ncols=True,
mininterval=1.0,
disable=(self.verbose < 1),
) as pbar:
for idx in pbar:
rec_ann = self.load_ann(idx)
beat_type_num = {k: v for k, v in Counter([item.symbol for item in rec_ann["beat"]]).most_common()}
beat_num = sum(beat_type_num.values())
rhythm_len = {k: sum([itv[1] - itv[0] for itv in v]) for k, v in rec_ann["rhythm"].items()}
self._stats = pd.concat(
[
self._stats,
pd.DataFrame(
[
[
self._all_records[idx],
beat_num,
beat_type_num,
rhythm_len,
]
],
columns=self._stats_columns,
),
],
ignore_index=True,
)
[docs] def load_ann(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
rhythm_format: str = "intervals",
rhythm_types: Optional[Sequence[str]] = None,
beat_format: str = "beat",
beat_types: Optional[Sequence[str]] = None,
keep_original: bool = False,
) -> dict:
"""Load rhythm and beat annotations of the record.
Rhythm and beat annotations are stored in the `aux_note`, `symbol`
attributes of corresponding annotation files.
NOTE that qrs annotations (.qrs files) do NOT contain any rhythm annotations.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the annotations to be loaded.
sampto : int, optional
End index of the annotations to be loaded.
rhythm_format : {"interval", "mask"}, optional
Format of returned annotation, by default "interval",
case insensitive.
rhythm_types : list of str, optional
Defaults to `self.rhythm_types`.
If is not None, only the rhythm annotations
with the specified types will be returned.
beat_format : {"beat", "dict"}, optional
Format of returned annotation, by default "beat",
case insensitive.
beat_types : List[str], optional
Beat types to be loaded, by default `self.beat_types`.
If is not None, only the beat annotations
with the specified types will be returned.
keep_original : bool, default False
If True, indices will keep the same with the annotation file,
otherwise subtract `sampfrom` if specified.
Returns
-------
ann : dict
The annotations of ``rhythm`` and ``beat``, with
``rhythm`` annotatoins in the format of intervals, or mask;
``beat`` annotations in the format of dict or
:class:`~torch_ecg.databases.BeatAnn`.
"""
assert rhythm_format.lower() in [
"intervals",
"mask",
], f"`rhythm_format` must be one of ['intervals', 'mask'], got {rhythm_format}"
assert beat_format.lower() in [
"beat",
"dict",
], f"`beat_format` must be one of ['beat', 'dict'], got {beat_format}"
fp = str(self.get_absolute_path(rec))
wfdb_ann = wfdb.rdann(fp, extension=self.ann_ext)
header = wfdb.rdheader(fp)
sig_len = header.sig_len
sf = sampfrom or 0
st = sampto or sig_len
assert st > sf, "`sampto` should be greater than `sampfrom`!"
sample_inds = wfdb_ann.sample
indices = np.where((sample_inds >= sf) & (sample_inds < st))[0]
if beat_types is None:
beat_types = self.beat_types
beat_ann = [BeatAnn(i, s) for i, s in zip(sample_inds[indices], np.array(wfdb_ann.symbol)[indices]) if s in beat_types]
if rhythm_types is None:
rhythm_types = self.rhythm_types
rhythm_types_map = self.rhythm_types_map
else:
rhythm_types = [rt.lstrip("(") for rt in rhythm_types]
rhythm_types_map = {rt: idx for idx, rt in enumerate(rhythm_types)}
rhythm_intervals = defaultdict(list)
start_idx, rhythm = None, None
for ra, si in zip(wfdb_ann.aux_note, sample_inds):
ra = ra.rstrip("\x00").lstrip("(")
if ra in rhythm_types:
if start_idx is not None:
rhythm_intervals[rhythm].append([start_idx, si])
start_idx = si
rhythm = ra.lstrip("(")
if start_idx is not None:
rhythm_intervals[rhythm].append([start_idx, si])
rhythm_intervals = {k: np.array(generalized_intervals_intersection(v, [[sf, st]])) for k, v in rhythm_intervals.items()}
if rhythm_format.lower() == "mask":
rhythm_mask = np.full((st - sf,), self._rhythm_ignore_index, dtype=int)
for k, v in rhythm_intervals.items():
for itv in v:
rhythm_mask[itv[0] - sf : itv[1] - sf] = self.rhythm_types_map[k]
if not keep_original:
rhythm_intervals = {k: v - sf for k, v in rhythm_intervals.items()}
for b in beat_ann:
b.index -= sf
# if not extended_beats:
# beat_ann = [b for b in beat_ann if b.symbol in self.beat_types]
if beat_format.lower() == "dict":
beat_ann = {s: np.array([b.index for b in beat_ann if b.symbol == s], dtype=int) for s in self.beat_types_extended}
beat_ann = {k: v for k, v in beat_ann.items() if len(v) > 0}
ann = {}
ann["beat"] = beat_ann
if rhythm_format.lower() == "intervals":
ann["rhythm"] = rhythm_intervals
else:
ann["rhythm"] = rhythm_mask
return ann
[docs] def load_rhythm_ann(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
rhythm_format: str = "intervals",
rhythm_types: Optional[Sequence[str]] = None,
keep_original: bool = False,
) -> Union[Dict[str, list], np.ndarray]:
"""Load rhythm annotations of the record.
Rhythm annotations are stored in the `aux_note` attribute
of corresponding annotation files.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the annotations to be loaded.
sampto : int, optional
End index of the annotations to be loaded.
rhythm_format : {"interval", "mask"}, optional
Format of returned annotation, by default "interval",
case insensitive.
rhythm_types : list of str, optional
Defaults to `self.rhythm_types`.
If is not None, only the rhythm annotations
with the specified types will be returned.
keep_original : bool, default False
If True, indices will keep the same with the annotation file,
otherwise subtract `sampfrom` if specified.
Returns
-------
ann : dict or numpy.ndarray
Annotations in the format of intervals or mask.
"""
return self.load_ann(
rec,
sampfrom,
sampto,
rhythm_format=rhythm_format,
rhythm_types=rhythm_types,
keep_original=keep_original,
)["rhythm"]
[docs] def load_beat_ann(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
beat_format: str = "beat",
beat_types: Optional[Sequence[str]] = None,
keep_original: bool = False,
) -> Union[Dict[str, np.ndarray], List[BeatAnn]]:
"""Load beat annotations of the record.
Beat annotations are stored in the `symbol` attribute
of corresponding annotation files.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the annotations to be loaded.
sampto : int, optional
End index of the annotations to be loaded.
beat_format : {"beat", "dict"}, optional
Format of returned annotation, by default "beat",
case insensitive.
beat_types : List[str], optional
Beat types to be loaded, by default `self.beat_types`.
If is not None, only the beat annotations
with the specified types will be returned.
keep_original : bool, default False
If True, indices will keep the same with the annotation file,
otherwise subtract `sampfrom` if specified.
Returns
-------
ann : dict or list
Locations (indices) of the all the
beat types ("A", "N", "Q", "V",).
"""
return self.load_ann(
rec,
sampfrom,
sampto,
beat_format=beat_format,
beat_types=beat_types,
keep_original=keep_original,
)["beat"]
[docs] def load_rpeak_indices(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
keep_original: bool = False,
) -> np.ndarray:
"""Load rpeak indices of the record.
Rpeak indices, or equivalently qrs complex locations,
are stored in the `symbol` attribute of corresponding annotation files,
regardless of their beat types.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the annotations to be loaded.
sampto : int, optional
End index of the annotations to be loaded.
keep_original : bool, default False
If True, indices will keep the same with the annotation file,
otherwise subtract `sampfrom` if specified.
Returns
-------
rpeak_inds : numpy.ndarray
Locations (indices) of the all the rpeaks (qrs complexes).
"""
fp = str(self.get_absolute_path(rec))
wfdb_ann = wfdb.rdann(fp, extension=self.ann_ext)
header = wfdb.rdheader(fp)
sig_len = header.sig_len
sf = sampfrom or 0
st = sampto or sig_len
assert st > sf, "`sampto` should be greater than `sampfrom`!"
rpeak_inds = wfdb_ann.sample
indices = np.where((rpeak_inds >= sf) & (rpeak_inds < st) & (np.isin(wfdb_ann.symbol, self.beat_types)))[0]
rpeak_inds = rpeak_inds[indices]
if not keep_original:
rpeak_inds -= sf
return rpeak_inds
def _get_lead_names(self, rec: Union[str, int]) -> List[str]:
"""Get the names of the leads contained in the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
Returns
-------
List[str]
A list of names of the leads contained in the record.
"""
return wfdb.rdheader(str(self.get_absolute_path(rec))).sig_name
@property
def df_stats(self) -> pd.DataFrame:
"""DataFrame of the statistics of the dataset."""
if self._stats.empty:
self._aggregate_stats()
return self._stats
@property
def df_stats_expanded(self) -> pd.DataFrame:
"""Expanded DataFrame of the statistics of the dataset."""
df = self.df_stats.copy(deep=True)
for bt in self.beat_types:
df[f"beat_{bt}"] = df["beat_type_num"].apply(lambda d: d.get(bt, 0))
for rt in self.rhythm_types:
df[f"rhythm_{rt}"] = df["rhythm_len"].apply(lambda d: d.get(rt, 0))
return df.drop(columns=["beat_num", "beat_type_num", "rhythm_len"])
@property
def df_stats_expanded_boolean(self) -> pd.DataFrame:
"""Expanded DataFrame of the statistics of the dataset,
with boolean values.
"""
df = self.df_stats_expanded.copy(deep=True)
for col in df.columns:
if col == "record":
continue
df[col] = df[col].apply(lambda x: int(x > 0))
return df
@property
def db_stats(self) -> Dict[str, Dict[str, int]]:
"""Dictionary of the statistics of the dataset."""
if self._stats.empty:
self._aggregate_stats()
rhythm_len = defaultdict(int)
for rl_dict in self._stats["rhythm_len"]:
for k, v in rl_dict.items():
rhythm_len[k] += v
beat_type_num = defaultdict(int)
for btn_dict in self._stats["beat_type_num"]:
for k, v in btn_dict.items():
beat_type_num[k] += v
return CFG(rhythm_len=dict(rhythm_len), beat_type_num=dict(beat_type_num))
def _categorize_records(self, by: str) -> Dict[str, List[str]]:
"""Categorize records by specific attributes.
Parameters
----------
by : {"beat", "rhythm"}
The attribute to categorize the records,
case insensitive.
Returns
-------
dict
A dict of lists of record names.
"""
assert by.lower() in [
"beat",
"rhythm",
], f"`by` should be one of 'beat' or 'rhythm', but got {by}"
key = dict(beat="beat_type_num", rhythm="rhythm_len")[by.lower()]
return CFG(
{item: [row["record"] for _, row in self.df_stats.iterrows() if item in row[key]] for item in self.db_stats[key]}
)
@property
def beat_types_records(self) -> Dict[str, List[str]]:
"""Dictionary of records with specific beat types."""
return self._categorize_records("beat")
@property
def rhythm_types_records(self) -> Dict[str, List[str]]:
"""Dictionary of records with specific rhythm types."""
return self._categorize_records("rhythm")
[docs] def plot(
self,
rec: Union[str, int],
data: Optional[np.ndarray] = None,
ann: Optional[Dict[str, np.ndarray]] = None,
beat_ann: Optional[Dict[str, np.ndarray]] = None,
rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
ticks_granularity: int = 0,
leads: Optional[Union[int, List[int]]] = None,
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
same_range: bool = False,
**kwargs: Any,
) -> None:
"""Not implemented."""
raise NotImplementedError
@property
def database_info(self) -> DataBaseInfo:
return _MITDB_INFO