# -*- coding: utf-8 -*-
import os
from collections import defaultdict
from numbers import Real
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
import scipy.signal as SS
import wfdb
from tqdm.auto import tqdm
from ...cfg import DEFAULTS
from ...utils import add_docstring, generalized_intervals_intersection, get_record_list_recursive3
from ..base import DataBaseInfo, PhysioNetDataBase, PSGDataBaseMixin
__all__ = [
"CINC2018",
]
_CINC2018_INFO = DataBaseInfo(
title="""
You Snooze You Win - The PhysioNet Computing in Cardiology Challenge 2018
""",
about="""
1. includes 1,985 subjects, partitioned into balanced training (n = 994) and test (n = 989) sets
2. signals include
electrocardiogram (ECG),
electroencephalography (EEG),
electrooculography (EOG),
electromyography (EMG),
electrocardiogram (EKG),
oxygen saturation (SaO2),
etc.
3. sampling frequency of all signal channels is 200 Hz
4. units of signals:
mV for ECG, EEG, EOG, EMG, EKG
percentage for SaO2
5. six sleep stages were annotated in contiguous 30-second intervals:
wakefulness,
stage 1,
stage 2,
stage 3,
rapid eye movement (REM),
undefined
6. annotated arousals were classified as one of the following:
spontaneous arousals,
respiratory effort related arousals (RERA),
bruxisms,
hypoventilations,
hypopneas,
apneas (central, obstructive and mixed),
vocalizations,
snores,
periodic leg movements,
Cheyne-Stokes breathing,
partial airway obstructions
7. Webpage of the database on PhysioNet [1]_.
""",
usage=[
"sleep stage",
"sleep apnea",
],
references=[
"https://physionet.org/content/challenge-2018/",
],
doi=[
"10.22489/CinC.2018.049",
"10.13026/6phb-r450",
],
)
@add_docstring(_CINC2018_INFO.format_database_docstring(), mode="prepend")
class CINC2018(PhysioNetDataBase, PSGDataBaseMixin):
"""
Parameters
----------
db_dir : `path-like`, optional
Storage path of the database.
If not specified, data will be fetched from PhysioNet.
working_dir : `path-like`, optional
Working directory, to store intermediate files and log files.
verbose : int, default 1
Level of logging verbosity.
kwargs : dict, optional
Auxiliary keyword arguments.
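Examples
--------
A minimal usage sketch; the local path below is hypothetical and assumes
the challenge files have already been downloaded:
>>> db = CINC2018(db_dir="/path/to/challenge-2018")
>>> n_train, n_test = len(db.training_records), len(db.test_records)
>>> ecg = db.load_data(db.training_records[0])  # ECG of the first training record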
"""
__name__ = "CINC2018"
def __init__(
self,
db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
verbose: int = 1,
**kwargs: Any,
) -> None:
super().__init__(
db_name="challenge-2018",
db_dir=db_dir,
working_dir=working_dir,
verbose=verbose,
**kwargs,
)
self.fs = 200
self._subset = kwargs.get("subset", "training")
self.rec_ext = "mat"
self.ann_ext = "arousal"
# fmt: off
self.sleep_stage_names = ["W", "R", "N1", "N2", "N3"]
self.arousal_types = [
"arousal_bruxism", "arousal_noise", "arousal_plm", "arousal_rera", "arousal_snore", "arousal_spontaneous",
"resp_centralapnea", "resp_cheynestokesbreath", "resp_hypopnea", "resp_hypoventilation",
"resp_mixedapnea", "resp_obstructiveapnea", "resp_partialobstructive",
]
# fmt: on
self.training_rec_pattern = "^tr\\d{2}-\\d{4}\\.mat$"
self.test_rec_pattern = "^te\\d{2}-\\d{4}\\.mat$"
self.training_records = []
self.test_records = []
self._all_records = []
self._df_records = None
self._ls_rec()
def _ls_rec(self) -> None:
"""Find all records in the database directory
and store them (path, metadata, etc.) in some private attributes.
"""
self._df_records = pd.DataFrame()
records = get_record_list_recursive3(
self.db_dir,
{"training": self.training_rec_pattern, "test": self.test_rec_pattern},
relative=False,
)
for k in records:
df_tmp = pd.DataFrame(sorted(records[k]), columns=["path"])
df_tmp["subset"] = k
self._df_records = pd.concat([self._df_records, df_tmp], axis=0, ignore_index=True)
self._df_records["record"] = self._df_records["path"].apply(lambda x: Path(x).stem)
self._df_records["subject_id"] = self._df_records["record"].apply(self.get_subject_id)
self._df_records.set_index("record", inplace=True)
self._df_records["fs"] = None
self._df_records["siglen"] = None
self._df_records["available_signals"] = None
with tqdm(
self._df_records.iterrows(),
total=len(self._df_records),
mininterval=1.0,
desc="Loading metadata",
disable=self.verbose < 1,
) as pbar:
for idx, row in pbar:
header = wfdb.rdheader(row["path"])
self._df_records.at[idx, "fs"] = header.fs
self._df_records.at[idx, "siglen"] = header.sig_len
self._df_records.at[idx, "available_signals"] = header.sig_name
if self._subset is not None:
self._df_records = self._df_records[self._df_records["subset"] == self._subset]
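# Optionally keep only a reproducible random fraction of the records
# (seeded with `DEFAULTS.SEED`); when no subset filter is active, records
# are sampled from the training and test subsets separately.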
if self._subsample is not None:
if self._subset is None:
df_tmp = pd.DataFrame(columns=self._df_records.columns)
for k in records:
size = int(round(self._subsample * len(records[k])))
if size > 0:
df_tmp = pd.concat(
[
df_tmp,
self._df_records[self._df_records["subset"] == k].sample(
size, random_state=DEFAULTS.SEED, replace=False
),
],
axis=0,
ignore_index=True,
)
if len(df_tmp) == 0:
size = min(
len(self._df_records),
max(1, int(round(self._subsample * len(self._df_records)))),
)
df_tmp = self._df_records.sample(size, random_state=DEFAULTS.SEED, replace=False)
del self._df_records
self._df_records = df_tmp.copy()
del df_tmp
else:
size = min(
len(self._df_records),
max(1, int(round(self._subsample * len(self._df_records)))),
)
if size > 0:
self._df_records = self._df_records.sample(size, random_state=DEFAULTS.SEED, replace=False)
self._all_records = self._df_records.index.tolist()
self.training_records = self._df_records[self._df_records["subset"] == "training"].index.tolist()
self.test_records = self._df_records[self._df_records["subset"] == "test"].index.tolist()
def get_subject_id(self, rec: Union[str, int]) -> int:
"""Get a unique subject ID for the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
Returns
-------
int
Subject ID associated with the record.
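Examples
--------
The ID is built as ``int("2018" + rec[2:4] + rec[-4:])``; for the
hypothetical record name "tr03-0005" (``db`` being a CINC2018 instance):
>>> db.get_subject_id("tr03-0005")
2018030005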
"""
head = "2018"
mid = rec[2:4]
tail = rec[-4:]
pid = int(head + mid + tail)
return pid
def set_subset(self, subset: Union[str, None]) -> None:
"""Set the subset of the database to use."""
assert subset in [
"training",
"test",
None,
], """`subset` must be in ``["training", "test", None]``."""
self._subset = subset
self._ls_rec()
def get_available_signals(self, rec: Union[str, int]) -> List[str]:
"""Get the available signals of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
Returns
-------
signals : List[str]
Names of the available signals of the record.
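Examples
--------
A sketch, with ``db`` a CINC2018 instance; the exact channel names
depend on the record's header:
>>> sigs = db.get_available_signals(0)
>>> has_ecg = "ECG" in sigs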
"""
if isinstance(rec, int):
rec = self[rec]
return self._df_records.at[rec, "available_signals"]
def get_fs(self, rec: Union[str, int]) -> int:
"""Get the sampling frequency of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
Returns
-------
fs : int
Sampling frequency of the record.
"""
if isinstance(rec, int):
rec = self[rec]
return self._df_records.at[rec, "fs"]
def get_siglen(self, rec: Union[str, int]) -> int:
"""Get the length of the signal of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
Returns
-------
siglen : int
Length of the signal of the record.
"""
if isinstance(rec, int):
rec = self[rec]
return self._df_records.at[rec, "siglen"]
def load_psg_data(
self,
rec: Union[str, int],
channel: Optional[Union[str, Sequence[str]]] = None,
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
data_format: str = "channel_first",
physical: bool = True,
fs: Optional[Real] = None,
return_fs: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
"""Load PSG data of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
channel : str or Sequence[str], optional
Name(s) of the channel(s) of PSG data to load.
If None, all channels will be returned.
sampfrom : int, optional
Start index of the data to be loaded.
sampto : int, optional
End index of the data to be loaded.
data_format : str, default "channel_first"
Format of the PSG data, can be one of
"channel_last" (alias "lead_last"),
"channel_first" (alias "lead_first"), or
"flat" (alias "plain"), the last of which is valid
only when a single `channel` is passed.
physical : bool, default True
If True, the data will be converted to physical units,
otherwise, the data will be in digital units.
fs : numbers.Real, optional
Sampling frequency of the output signal.
If not None, the loaded data will be resampled to this frequency,
otherwise, the original sampling frequency will be used.
return_fs : bool, default False
Whether to return the sampling frequency of the output signal.
Returns
-------
data : numpy.ndarray
PSG data corresponding to the given `channel` of the record.
data_fs : numbers.Real, optional
Sampling frequency of the output signal.
Returned if `return_fs` is True.
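Examples
--------
A usage sketch, with ``db`` a CINC2018 instance; the channel name
"SaO2" is assumed to be present in the record's header:
>>> rec = db.training_records[0]
>>> sao2 = db.load_psg_data(rec, channel="SaO2", data_format="flat")
>>> sig, sig_fs = db.load_psg_data(rec, fs=100, return_fs=True)  # all channels, resampled to 100 Hz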
"""
available_signals = self.get_available_signals(rec)
chn = available_signals if channel is None else channel
if isinstance(chn, str):
chn = [chn]
assert set(chn).issubset(set(available_signals)), f"`channel` should be a subset of `{available_signals}`, but got `{chn}`"
allowed_data_format = [
"channel_first",
"lead_first",
"channel_last",
"lead_last",
"flat",
"plain",
]
assert (
data_format.lower() in allowed_data_format
), f"`data_format` should be one of `{allowed_data_format}`, but got `{data_format}`"
if len(chn) > 1:
assert data_format.lower() in [
"channel_first",
"lead_first",
"channel_last",
"lead_last",
], (
"`data_format` should be one of "
"`['channel_first', 'lead_first', 'channel_last', 'lead_last']` "
f"when the passed number of `channel` is larger than 1, but got `{data_format}`"
)
frp = str(self.get_absolute_path(rec))
wfdb_header = wfdb.rdheader(frp)
sampfrom = max(0, sampfrom or 0)
sampto = min(sampto or wfdb_header.sig_len, wfdb_header.sig_len)
wfdb_rec = wfdb.rdrecord(frp, sampfrom=sampfrom, sampto=sampto, channel_names=chn, physical=physical)
ret_data = wfdb_rec.p_signal.T if physical else wfdb_rec.d_signal.T
if fs is not None and fs != wfdb_header.fs:
ret_data = SS.resample_poly(ret_data, fs, wfdb_header.fs, axis=-1)
data_fs = fs
else:
data_fs = wfdb_header.fs
if data_format.lower() in ["channel_last", "lead_last"]:
ret_data = ret_data.T
elif data_format.lower() in ["flat", "plain"]:
ret_data = ret_data.flatten()
if return_fs:
return ret_data, data_fs
return ret_data
def load_data(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
data_format: str = "channel_first",
units: Optional[str] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
"""Load ECG data of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the data to be loaded.
sampto : int, optional
End index of the data to be loaded.
data_format : str, default "channel_first"
Format of the ECG data, can be one of
"channel_last" (alias "lead_last"),
"channel_first" (alias "lead_first"), or
"flat" (alias "plain").
units : str or None, default "mV"
Units of the output signal, can also be "μV" (aliases "uV", "muV").
None for digital data, without digital-to-physical conversion.
fs : numbers.Real, optional
Sampling frequency of the output signal.
If not None, the loaded data will be resampled to this frequency,
otherwise, the original sampling frequency will be used.
return_fs : bool, default False
Whether to return the sampling frequency of the output signal.
Returns
-------
data : numpy.ndarray
The ECG data loaded from the record,
with given `units` and `data_format`.
data_fs : numbers.Real, optional
Sampling frequency of the output signal.
Returned if `return_fs` is True.
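Examples
--------
A usage sketch, with ``db`` a CINC2018 instance:
>>> ecg_uv = db.load_data(db.training_records[0], units="uV", data_format="flat")
>>> ecg, ecg_fs = db.load_data(db.training_records[0], fs=100, return_fs=True)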
"""
available_signals = self.get_available_signals(rec)
assert "ECG" in available_signals, f"the record `{rec}` does not have ECG signal"
allowed_units = ["mv", "uv", "μv", "muv"]
assert (
units is None or units.lower() in allowed_units
), f"`units` should be one of `{allowed_units}` or None, but got `{units}`"
data, data_fs = self.load_psg_data(
rec=rec,
channel="ECG",
sampfrom=sampfrom,
sampto=sampto,
data_format=data_format,
physical=units is not None,
fs=fs,
return_fs=True,
)
if units is not None and units.lower() in ["μv", "uv", "muv"]:
data = 1000 * data
if return_fs:
return data, data_fs
return data
@add_docstring(load_data.__doc__)
def load_ecg_data(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
data_format: str = "channel_first",
units: Optional[str] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
"""alias of `load_data`"""
return self.load_data(
rec=rec,
sampfrom=sampfrom,
sampto=sampto,
data_format=data_format,
units=units,
fs=fs,
return_fs=return_fs,
)
def load_ann(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
keep_original: bool = False,
) -> Dict[str, Dict[str, List[List[int]]]]:
"""Load sleep stage and arousal annotations of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the corresponding PSG data.
sampto : int, optional
End index of the corresponding PSG data.
keep_original : bool, default False
If True, indices are kept the same as in the annotation file;
otherwise `sampfrom` (if specified) is subtracted from them.
Returns
-------
dict
A dictionary with keys "sleep_stages" and "arousals".
Each value is itself a dictionary mapping a sleep stage
or arousal type to a list of ``[start, end]`` index pairs.
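Examples
--------
A sketch, with ``db`` a CINC2018 instance; the keys actually present
depend on the record:
>>> ann = db.load_ann(db.training_records[0])
>>> stages = ann["sleep_stages"]  # keys are a subset of ["W", "R", "N1", "N2", "N3"]
>>> arousals = ann["arousals"]  # keys are a subset of :attr:`arousal_types`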
"""
frp = str(self.get_absolute_path(rec))
wfdb_ann = wfdb.rdann(frp, extension=self.ann_ext)
sleep_stages = defaultdict(list)
arousals = defaultdict(list)
current_sleep_stage = None
current_sleep_stage_start = None
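# The `.arousal` aux_note stream interleaves two kinds of entries:
# sleep stage names ("W", "R", "N1", "N2", "N3"), each marking the onset of a stage,
# and arousal events encoded as paired markers "(event" (onset) and "event)" (offset).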
for aux_note, sample in zip(wfdb_ann.aux_note, wfdb_ann.sample.tolist()):
if aux_note in self.sleep_stage_names:
if current_sleep_stage is not None:
sleep_stages[current_sleep_stage].append([current_sleep_stage_start, sample])
current_sleep_stage = aux_note
current_sleep_stage_start = sample
else:
if "(" in aux_note:
current_arousal_start = sample
else:
arousals[aux_note.strip(")")].append([current_arousal_start, sample])
siglen = self.get_siglen(rec)
if current_sleep_stage is not None and current_sleep_stage_start < siglen:
sleep_stages[current_sleep_stage].append([current_sleep_stage_start, siglen])
sampfrom = max(0, sampfrom or 0)
sampto = min(sampto or siglen, siglen)
sleep_stages = {
k: generalized_intervals_intersection(v, [[sampfrom, sampto]], drop_degenerate=True)
for k, v in sleep_stages.items()
}
sleep_stages = {k: v for k, v in sleep_stages.items() if len(v) > 0}
arousals = {
k: generalized_intervals_intersection(v, [[sampfrom, sampto]], drop_degenerate=True) for k, v in arousals.items()
}
arousals = {k: v for k, v in arousals.items() if len(v) > 0}
if not keep_original:
sleep_stages = {k: [[s - sampfrom, e - sampfrom] for s, e in v] for k, v in sleep_stages.items()}
arousals = {k: [[s - sampfrom, e - sampfrom] for s, e in v] for k, v in arousals.items()}
return {
"sleep_stages": sleep_stages,
"arousals": arousals,
}
def load_sleep_stages_ann(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
keep_original: bool = False,
) -> Dict[str, List[List[int]]]:
"""Load sleep stage annotations of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the corresponding PSG data.
sampto : int, optional
End index of the corresponding PSG data.
keep_original : bool, default False
If True, indices are kept the same as in the annotation file;
otherwise `sampfrom` (if specified) is subtracted from them.
Returns
-------
dict
A dictionary mapping each sleep stage to a list of
``[start, end]`` index pairs.
"""
return self.load_ann(
rec=rec,
sampfrom=sampfrom,
sampto=sampto,
keep_original=keep_original,
)["sleep_stages"]
def load_arousals_ann(
self,
rec: Union[str, int],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
keep_original: bool = False,
) -> Dict[str, List[List[int]]]:
"""Load arousal annotations of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
sampfrom : int, optional
Start index of the corresponding PSG data.
sampto : int, optional
End index of the corresponding PSG data.
keep_original : bool, default False
If True, indices are kept the same as in the annotation file;
otherwise `sampfrom` (if specified) is subtracted from them.
Returns
-------
dict
A dictionary mapping each arousal type to a list of
``[start, end]`` index pairs.
"""
return self.load_ann(
rec=rec,
sampfrom=sampfrom,
sampto=sampto,
keep_original=keep_original,
)["arousals"]
def plot(self) -> None:
"""NOT implemented yet."""
raise NotImplementedError
def plot_ann(self, rec: Union[str, int]) -> tuple:
"""Plot the sleep stage and arousal annotations of the record.
Parameters
----------
rec : str or int
Record name or index of the record in :attr:`all_records`.
Returns
-------
fig : matplotlib.figure.Figure
The figure object.
ax : matplotlib.axes.Axes
The axes object.
TODO
----
Plot arousal events.
"""
ann = self.load_ann(rec)
sleep_stages = ann["sleep_stages"]
arousals = ann["arousals"]
stage_mask = self.sleep_stage_intervals_to_mask(sleep_stages)
fig, ax = self.plot_hypnogram(stage_mask)
# TODO: plot arousal events
return fig, ax
@property
def database_info(self) -> DataBaseInfo:
return _CINC2018_INFO