Source code for torch_ecg.databases.datasets.cpsc2021.cpsc2021_dataset

"""
data generator for feeding data into pytorch models

NOTE
----
To guard against potential errors in the methods that slice signals and RR intervals,
one can check the shapes of the generated samples using the following code

.. code-block:: python

    from torch_ecg.databases.datasets.cpsc2021 import CPSC2021Dataset, CPSC2021TrainCfg

    ds_train = CPSC2021Dataset(CPSC2021TrainCfg, task="qrs_detection", training=True)
    ds_val = CPSC2021Dataset(CPSC2021TrainCfg, task="qrs_detection", training=False)
    err_list = []
    for idx, seg in enumerate(ds_train.segments):
        sig, lb = ds_train[idx]
        if sig.shape != (2,6000) or lb.shape != (750, 1):
            print("\n"+f"segment {seg} has sig.shape = {sig.shape}, lb.shape = {lb.shape}"+"\n")
            err_list.append(seg)
        print(f"{idx+1}/{len(ds_train)}", end="\r")
    for idx, seg in enumerate(ds_val.segments):
        sig, lb = ds_val[idx]
        if sig.shape != (2,6000) or lb.shape != (750, 1):
            print("\n"+f"segment {seg} has sig.shape = {sig.shape}, lb.shape = {lb.shape}"+"\n")
            err_list.append(seg)
        print(f"{idx+1}/{len(ds_val)}", end="\r")
    for idx, seg in enumerate(err_list):
        path = ds_train._get_seg_data_path(seg)
        os.remove(path)
        path = ds_train._get_seg_ann_path(seg)
        os.remove(path)
        print(f"{idx+1}/{len(err_list)}", end="\r")

The same check can be carried out for the `rr_lstm` task.

"""

import json
import os
import re
import time
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from scipy import signal as SS
from scipy.io import loadmat, savemat
from torch.utils.data.dataset import Dataset
from tqdm.auto import tqdm

from ...._preprocessors import PreprocManager
from ....cfg import CFG, DEFAULTS
from ....databases import CPSC2021 as CR
from ....utils.misc import ReprMixin, add_docstring, get_record_list_recursive3, list_sum, nildent
from ....utils.utils_data import generate_weight_mask, mask_to_intervals
from ....utils.utils_signal import remove_spikes_naive
from .cpsc2021_cfg import CPSC2021TrainCfg

# Public API of this module.
__all__ = ["CPSC2021Dataset"]


class CPSC2021Dataset(ReprMixin, Dataset):
    """Data generator for feeding data into pytorch models
    using the :class:`~torch_ecg.databases.CPSC2021` database.

    Strategies for generating data and labels:

    1. ECGs are preprocessed and stored in one folder
    2. preprocessed ECGs are sliced with overlap to generate data and label for different tasks:

       - the data files store segments of fixed length of preprocessed ECGs,
       - the annotation files contain "qrs_mask", and "af_mask"

    The returned values (tuple) of :meth:`__getitem__` depend on the task:

    1. "qrs_detection": (`data`, `qrs_mask`)
    2. "rr_lstm": (`rr_seq`, `rr_af_mask`, `rr_weight_mask`)
    3. "main": (`data`, `af_mask`, `weight_mask`)

    where

    - `data` shape: ``(n_lead, n_sample)``
    - `qrs_mask` shape: ``(n_sample, 1)``
    - `af_mask` shape: ``(n_sample, 1)``
    - `weight_mask` shape: ``(n_sample, 1)``
    - `rr_seq` shape: ``(n_rr, 1)``
    - `rr_af_mask` shape: ``(n_rr, 1)``
    - `rr_weight_mask` shape: ``(n_rr, 1)``

    Typical values of ``n_sample`` and ``n_rr`` are 6000 and 30, respectively.
    ``n_lead`` is typically 2, which is the number of leads in the ECG signal
    of the :class:`CPSC2021` database.

    Parameters
    ----------
    config : dict
        Configurations for the dataset, ref. `CPSC2021TrainCfg`.
        A simple example is as follows:

        .. code-block:: python

            >>> config = deepcopy(CPSC2021TrainCfg)
            >>> config.db_dir = "some/path/to/db"
            >>> dataset = CPSC2021Dataset(config, task="main", training=True, lazy=False)

    task : str
        Name of the task, one of `CPSC2021TrainCfg.tasks` (case-insensitive).
    training : bool, default True
        If True, the training set will be loaded,
        otherwise the test (val) set will be loaded.
    lazy : bool, default True
        If True, the data will not be loaded immediately,
        instead, it will be loaded on demand.
    **reader_kwargs : dict, optional
        Keyword arguments for the database reader class.

    """

    __DEBUG__ = False
    __name__ = "CPSC2021Dataset"

    def __init__(
        self,
        config: CFG,
        task: str,
        training: bool = True,
        lazy: bool = True,
        **reader_kwargs: Any,
    ) -> None:
        super().__init__()
        self.config = deepcopy(config)
        # `db_dir` always comes from the config; a duplicate in `reader_kwargs` is dropped
        if reader_kwargs.pop("db_dir", None) is not None:
            warnings.warn("`db_dir` is specified in both config and reader_kwargs", RuntimeWarning)
        self.reader = CR(db_dir=self.config.db_dir, **reader_kwargs)
        # assert self.config.db_dir is not None, "db_dir must be specified"
        # use the directory as resolved by the reader from now on
        self.config.db_dir = self.reader.db_dir
        self.dtype = self.config.np_dtype
        self.allowed_preproc = list(
            set(
                [
                    "bandpass",
                    "baseline_remove",
                ]
            ).intersection(set(self.config.keys()))
        )

        self.training = training
        self.lazy = lazy

        # record-level preprocessing: everything in the config except "normalize"
        ppm_config = CFG(random=False)
        ppm_config.update(deepcopy(self.config))
        ppm_config.pop("normalize")
        # segment-level preprocessing: everything in the config except "bandpass"
        seg_ppm_config = CFG(random=False)
        seg_ppm_config.update(deepcopy(self.config))
        seg_ppm_config.pop("bandpass")
        self.ppm = PreprocManager.from_config(ppm_config)
        self.seg_ppm = PreprocManager.from_config(seg_ppm_config)

        # create directories if needed
        # preprocess_dir stores pre-processed signals
        self.preprocess_dir = self.config.db_dir / "preprocessed"
        self.preprocess_dir.mkdir(parents=True, exist_ok=True)
        # segments_dir for sliced segments of fixed length
        self.segments_base_dir = self.config.db_dir / "segments"
        self.segments_base_dir.mkdir(parents=True, exist_ok=True)
        self.segment_name_pattern = "S_[\\d]{1,3}_[\\d]{1,2}_[\\d]{7}"
        self.segment_ext = "mat"
        # rr_dir for sequence of rr intervals of fix length
        self.rr_seq_base_dir = self.config.db_dir / "rr_seq"
        self.rr_seq_base_dir.mkdir(parents=True, exist_ok=True)
        self.rr_seq_name_pattern = "R_[\\d]{1,3}_[\\d]{1,2}_[\\d]{7}"
        self.rr_seq_ext = "mat"

        self._all_data = None
        self._all_labels = None
        self._all_masks = None
        self.__set_task(task, lazy=self.lazy)

    def _load_all_data(self) -> None:
        """Load all data into memory."""
        self.__set_task(self.task, lazy=False)

    def __set_task(self, task: str, lazy: bool = True) -> None:
        """Set the task and load the data if needed.

        Parameters
        ----------
        task : str
            Name of the task, can be one of `CPSC2021TrainCfg.tasks`.
            The name is matched case-insensitively.
        lazy : bool, default True
            If True, the data will not be loaded immediately,
            instead, it will be loaded on demand.

        Returns
        -------
        None

        """
        assert task.lower() in CPSC2021TrainCfg.tasks, f"illegal task \042{task}\042"
        # nothing to do if the same task is already set and its data is already in memory
        if hasattr(self, "task") and self.task == task.lower() and self._all_data is not None and len(self._all_data) > 0:
            return
        self.task = task.lower()
        # index the config with the normalized task name, consistent with the
        # case-insensitive check above
        self.all_classes = self.config[self.task].classes
        self.n_classes = len(self.config[self.task].classes)
        self.lazy = lazy

        self.seglen = self.config[self.task].input_len  # alias, for simplicity
        split_res = self._train_test_split(
            train_ratio=self.config.train_ratio,
            force_recompute=False,
        )
        if self.training:
            self.subjects = split_res.train
        else:
            self.subjects = split_res.test

        if self.task in [
            "qrs_detection",
            "main",
        ]:
            # for qrs detection, or for the main task
            self.segments_dirs = CFG()
            self.__all_segments = CFG()
            self.segments_json = self.segments_base_dir / "segments.json"
            self._ls_segments()
            self.segments = list_sum([self.__all_segments[subject] for subject in self.subjects])
            if self.__DEBUG__:
                # keep only about 1% of the segments when debugging
                self.segments = DEFAULTS.RNG_sample(self.segments, int(len(self.segments) * 0.01)).tolist()
            if self.training:
                DEFAULTS.RNG.shuffle(self.segments)
            # preload data
            self.fdr = _FastDataReader(
                self.config,
                self.task,
                self.seg_ppm,
                self.segments_dirs,
                self.segments,
                self.segment_ext,
            )
            if self.lazy:
                return
            self._all_data, self._all_labels, self._all_masks = [], [], []
            with tqdm(
                range(len(self.fdr)),
                desc="Loading data",
                unit="record",
                dynamic_ncols=True,
                mininterval=1.0,
            ) as pbar:
                for idx in pbar:
                    d, l, m = self.fdr[idx]
                    self._all_data.append(d)
                    self._all_labels.append(l)
                    self._all_masks.append(m)
            self._all_data = np.array(self._all_data).astype(self.dtype)
            self._all_labels = np.array(self._all_labels).astype(self.dtype)
            if self.task == "qrs_detection":
                # no masks are kept for the qrs detection task
                self._all_masks = None
            else:
                self._all_masks = np.array(self._all_masks).astype(self.dtype)
        elif self.task in [
            "rr_lstm",
        ]:
            self.rr_seq_dirs = CFG()
            self.__all_rr_seq = CFG()
            self.rr_seq_json = self.rr_seq_base_dir / "rr_seq.json"
            self._ls_rr_seq()
            self.rr_seq = list_sum([self.__all_rr_seq[subject] for subject in self.subjects])
            if self.__DEBUG__:
                # keep only about 1% of the rr sequences when debugging
                self.rr_seq = DEFAULTS.RNG_sample(self.rr_seq, int(len(self.rr_seq) * 0.01)).tolist()
            if self.training:
                DEFAULTS.RNG.shuffle(self.rr_seq)
            # preload data
            self.fdr = _FastDataReader(
                self.config,
                self.task,
                self.seg_ppm,
                self.rr_seq_dirs,
                self.rr_seq,
                self.rr_seq_ext,
            )
            if self.lazy:
                return
            self._all_data, self._all_labels, self._all_masks = [], [], []
            with tqdm(
                range(len(self.fdr)),
                desc="Loading data",
                unit="record",
                dynamic_ncols=True,
                mininterval=1.0,
            ) as pbar:
                for idx in pbar:
                    d, l, m = self.fdr[idx]
                    self._all_data.append(d)
                    self._all_labels.append(l)
                    self._all_masks.append(m)
            self._all_data = np.array(self._all_data).astype(self.dtype)
            self._all_labels = np.array(self._all_labels).astype(self.dtype)
            self._all_masks = np.array(self._all_masks).astype(self.dtype)
        else:
            raise NotImplementedError(f"data generator for task \042{self.task}\042 not implemented")
def reset_task(self, task: str, lazy: bool = True) -> None:
    """Switch the data generator to another task.

    Parameters
    ----------
    task : str
        Name of the task to switch to.
    lazy : bool, optional
        If True, the data will not be loaded immediately,
        instead, it will be loaded on demand.

    Returns
    -------
    None

    """
    # public entry point that simply delegates to the private task setter
    self.__set_task(task=task, lazy=lazy)
def _ls_segments(self) -> None:
    """Collect the names of the cached segments of every subject.

    Makes sure the per-subject "data" and "ann" directories exist under
    ``self.segments_base_dir``, scans the "data" directories for segment
    files and stores the result in a private attribute. The result is
    written to ``self.segments_json`` only when every subject has at
    least one segment.
    """
    subjects = self.reader.all_subjects
    for kind in ["data", "ann"]:
        self.segments_dirs[kind] = CFG()
        for subject in subjects:
            self.segments_dirs[kind][subject] = self.segments_base_dir / kind / subject
            self.segments_dirs[kind][subject].mkdir(parents=True, exist_ok=True)
    if self.segments_json.is_file():
        self.__all_segments = json.loads(self.segments_json.read_text())
        # no early return: the directories are scanned again below
    print(f"please allow the reader a few minutes to collect the segments from {self.segments_base_dir}...")
    filename_regex = f"{self.segment_name_pattern}\\.{self.segment_ext}"
    collected = {}
    for subject in subjects:
        collected[subject] = get_record_list_recursive3(str(self.segments_dirs.data[subject]), filename_regex)
    self.__all_segments = CFG(collected)
    if all(len(self.__all_segments[subject]) > 0 for subject in subjects):
        self.segments_json.write_text(json.dumps(self.__all_segments, ensure_ascii=False))


def _ls_rr_seq(self) -> None:
    """Collect the names of the cached rr sequences of every subject.

    Makes sure the per-subject directories exist under
    ``self.rr_seq_base_dir``, scans them for rr sequence files and stores
    the result in a private attribute. The result is written to
    ``self.rr_seq_json`` only when every subject has at least one sequence.
    """
    subjects = self.reader.all_subjects
    for subject in subjects:
        self.rr_seq_dirs[subject] = self.rr_seq_base_dir / subject
        self.rr_seq_dirs[subject].mkdir(parents=True, exist_ok=True)
    if self.rr_seq_json.is_file():
        self.__all_rr_seq = json.loads(self.rr_seq_json.read_text())
        # no early return: the directories are scanned again below
    print(f"please allow the reader a few minutes to collect the rr sequences from {self.rr_seq_base_dir}...")
    filename_regex = f"{self.rr_seq_name_pattern}\\.{self.rr_seq_ext}"
    collected = {}
    for subject in subjects:
        collected[subject] = get_record_list_recursive3(self.rr_seq_dirs[subject], filename_regex)
    self.__all_rr_seq = CFG(collected)
    if all(len(self.__all_rr_seq[subject]) > 0 for subject in subjects):
        self.rr_seq_json.write_text(json.dumps(self.__all_rr_seq, ensure_ascii=False))


@property
def all_segments(self) -> CFG:
    """Segment names per subject; empty unless the task is "qrs_detection" or "main"."""
    if self.task not in ("qrs_detection", "main"):
        return CFG()
    return self.__all_segments


@property
def all_rr_seq(self) -> CFG:
    """Rr sequence names per subject; empty unless the task is "rr_lstm"."""
    if self.task.lower() != "rr_lstm":
        return CFG()
    return self.__all_rr_seq


def __len__(self) -> int:
    """Number of samples, as reported by the underlying data reader."""
    return len(self.fdr)


def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
    """Get the sample at `index`.

    A pair (data, label) is returned for the task "qrs_detection",
    a triple (data, label, mask) for the other tasks.
    """
    qrs_only = self.task == "qrs_detection"
    if self.lazy:
        # read on demand through the data reader
        sample = self.fdr[index]
        return sample[:2] if qrs_only else sample
    if qrs_only:
        return self._all_data[index], self._all_labels[index]
    return (
        self._all_data[index],
        self._all_labels[index],
        self._all_masks[index],
    )


def _get_seg_data_path(self, seg: str) -> Path:
    """Get the path of the data file of segment.

    Parameters
    ----------
    seg : str
        Name of the segment, of pattern like "S_1_1_0000193".

    Returns
    -------
    pathlib.Path
        Path of the data file of the segment.

    """
    # the subject ID is the second underscore-separated field of the name
    return self.segments_dirs.data[seg.split("_")[1]] / f"{seg}.{self.segment_ext}"


@add_docstring(_get_seg_data_path.__doc__.replace("data file", "annotation file"))
def _get_seg_ann_path(self, seg: str) -> Path:
    # the subject ID is the second underscore-separated field of the name
    return self.segments_dirs.ann[seg.split("_")[1]] / f"{seg}.{self.segment_ext}"


def _load_seg_data(self, seg: str) -> np.ndarray:
    """Load the data of the segment.

    Parameters
    ----------
    seg : str
        Name of the segment, of pattern like "S_1_1_0000193".

    Returns
    -------
    numpy.ndarray
        The "ecg" entry of the segment's data file.

    """
    return loadmat(str(self._get_seg_data_path(seg)))["ecg"]


def _load_seg_ann(self, seg: str) -> dict:
    """Load the annotations of the segment.

    Parameters
    ----------
    seg : str
        Name of the segment, of pattern like "S_1_1_0000193".

    Returns
    -------
    seg_ann : dict
        Every non-dunder entry of the annotation file, flattened to 1D.
        The segment writer in this class stores the keys
        "rpeaks", "qrs_mask", "af_mask" and "interval".

    """
    raw_ann = loadmat(str(self._get_seg_ann_path(seg)))
    seg_ann = {}
    for key, value in raw_ann.items():
        # keys starting with "__" are skipped
        if key.startswith("__"):
            continue
        seg_ann[key] = value.flatten()
    return seg_ann


def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """Load the mask(s) of segment.

    Parameters
    ----------
    seg : str
        Name of the segment, of pattern like "S_1_1_0000193".
    task : str, optional
        Task of the mask(s) to be loaded, by default the current task.
        If given, it overrides the current task;
        if it is "all", both masks ("qrs_mask" and "af_mask") are returned.

    Returns
    -------
    numpy.ndarray or dict
        "qrs_mask" for the task "qrs_detection", "af_mask" for the task "main",
        each reshaped to ``(self.seglen, -1)``;
        for any other task, the dict holding both masks.

    """
    masks = {}
    for key, value in self._load_seg_ann(seg).items():
        if key in ("qrs_mask", "af_mask"):
            masks[key] = value.reshape((self.seglen, -1))
    effective_task = (task or self.task).lower()
    if effective_task == "qrs_detection":
        return masks["qrs_mask"]
    if effective_task == "main":
        return masks["af_mask"]
    # "all" and every other task name get the whole dict
    return masks


def _load_seg_seq_lab(self, seg: str, reduction: int) -> np.ndarray:
    """Load sequence labeling annotations of the segment.

    Parameters
    ----------
    seg : str
        Name of the segment, of pattern like "S_1_1_0000193".
    reduction : int
        Reduction ratio (granularity) of length of the model output,
        compared to the original signal length.

    Returns
    -------
    numpy.ndarray
        Label of the sequence,
        of shape ``(seg_len // reduction, n_classes)``.

    """
    mask = self._load_seg_mask(seg)
    seg_len, n_classes = mask.shape
    chunk_labels = []
    for chunk_idx in range(seg_len // reduction):
        chunk = mask[reduction * chunk_idx : reduction * (chunk_idx + 1)]
        # the mean is truncated to int
        chunk_labels.append(np.mean(chunk, axis=0, keepdims=True).astype(int))
    return np.stack(arrays=chunk_labels, axis=0).squeeze(axis=1)


def _get_rr_seq_path(self, seq_name: str) -> Path:
    """Get the path of the annotation file of the rr_seq.

    Parameters
    ----------
    seq_name : str
        Name of the rr_seq, of pattern like "R_1_1_0000193".

    Returns
    -------
    pathlib.Path
        Path of the annotation file of the rr_seq.

    """
    # the subject ID is the second underscore-separated field of the name
    return self.rr_seq_dirs[seq_name.split("_")[1]] / f"{seq_name}.{self.rr_seq_ext}"


def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]:
    """Load the metadata of the rr_seq.

    Parameters
    ----------
    seq_name : str
        Name of the rr_seq, of pattern like "R_1_1_0000193".

    Returns
    -------
    dict
        Metadata of sequence of rr intervals, including

            - rr: the sequence of rr intervals,
              reshaped to ``(self.seglen, 1)``
            - label: label of the rr intervals,
              reshaped to ``(self.seglen, self.n_classes)``
            - interval: interval of the current rr sequence
              in the whole rr sequence in the original record, flattened

    """
    contents = loadmat(str(self._get_rr_seq_path(seq_name)))
    rr_seq = {}
    for key, value in contents.items():
        # keys starting with "__" are skipped
        if not key.startswith("__"):
            rr_seq[key] = value
    rr_seq["rr"] = rr_seq["rr"].reshape((self.seglen, 1))
    rr_seq["label"] = rr_seq["label"].reshape((self.seglen, self.n_classes))
    rr_seq["interval"] = rr_seq["interval"].flatten()
    return rr_seq
def persistence(self, force_recompute: bool = False, verbose: int = 0) -> None:
    """Preprocess the records, then cache segments and rr sequences on disk.

    The task is temporarily switched to "main" for slicing the segments
    and to "rr_lstm" for generating the rr sequences;
    the original task and laziness are restored at the end.

    Parameters
    ----------
    force_recompute : bool, default False
        Whether to force recompute the preprocessed data.
    verbose : int, default 0
        Verbosity level for printing the progress.

    Returns
    -------
    None

    """
    show_banner = verbose >= 1
    common_kwargs = dict(force_recompute=force_recompute, verbose=verbose)

    if show_banner:
        print(" preprocessing data ".center(110, "#"))
    self._preprocess_data(**common_kwargs)

    # remember the current state so that it can be restored afterwards
    task_to_restore, lazy_to_restore = self.task, self.lazy

    self.__set_task("main", lazy=True)
    if show_banner:
        print("\n" + " slicing data into segments ".center(110, "#"))
    self._slice_data(**common_kwargs)

    self.__set_task("rr_lstm", lazy=True)
    if show_banner:
        print("\n" + " generating rr sequences ".center(110, "#"))
    self._slice_rr_seq(**common_kwargs)

    self.__set_task(task_to_restore, lazy=lazy_to_restore)
def _preprocess_data(self, force_recompute: bool = False, verbose: int = 0) -> None: """Preprocesses the ECG data in advance for further use. Parameters ---------- force_recompute : bool, default False Whether to force recompute the preprocessed data. verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ for idx, rec in enumerate(self.reader.all_records): self._preprocess_one_record( rec=rec, force_recompute=force_recompute, verbose=verbose, ) if verbose >= 1: print(f"{idx+1}/{len(self.reader.all_records)} records", end="\r") def _preprocess_one_record(self, rec: str, force_recompute: bool = False, verbose: int = 0) -> None: """Preprocesses one ECG record in advance for further use. Parameters ---------- rec : str Name of the record. force_recompute : bool, default False Whether to force recompute the preprocessed data. verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ suffix = self._get_rec_suffix(self.allowed_preproc) save_fp = self.preprocess_dir / f"{rec}-{suffix}.{self.segment_ext}" if (not force_recompute) and save_fp.is_file(): return # perform pre-process pps, _ = self.ppm(self.reader.load_data(rec), self.config.fs) savemat(save_fp, {"ecg": pps}, format="5")
def load_preprocessed_data(self, rec: str) -> np.ndarray:
    """Read the cached preprocessed signal of a record from disk.

    Parameters
    ----------
    rec : str
        Name of the record.

    Returns
    -------
    numpy.ndarray
        The pre-computed preprocessed ECG data of the record,
        transposed if its first dimension is not 2.

    Raises
    ------
    FileNotFoundError
        If the preprocessed file of the record does not exist.

    """
    operations = self.allowed_preproc
    filename = f"{rec}-{self._get_rec_suffix(operations)}.{self.segment_ext}"
    cached_file = self.preprocess_dir / filename
    if not cached_file.is_file():
        raise FileNotFoundError(f"preprocess(es) \042{operations}\042 not done for {rec} yet")
    signal = loadmat(str(cached_file))["ecg"]
    # a first dimension other than 2 is taken to mean the array is stored transposed
    return signal if signal.shape[0] == 2 else signal.T
def _get_rec_suffix(self, operations: List[str]) -> str: """Get the suffix of the filename of the preprocessed ECG signal. Parameters ---------- operations : List[str] Names of operations to perform (or has performed). Should be sublist of `self.allowed_preproc`. Returns ------- str Suffix of the filename of the preprocessed ECG signal. """ suffix = "-".join(sorted([item.lower() for item in operations])) return suffix def _slice_data(self, force_recompute: bool = False, verbose: int = 0) -> None: """Slice the preprocessed data into segments. Slice all records into segments of length `self.seglen`, and perform data augmentations specified in `self.config`. Parameters ---------- force_recompute : bool, default False Whether to force recompute the preprocessed data. verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ self.__assert_task( [ "qrs_detection", "main", ] ) if force_recompute: self._clear_cached_segments() for idx, rec in enumerate(self.reader.all_records): self._slice_one_record( rec=rec, force_recompute=False, update_segments_json=False, verbose=verbose, ) if verbose >= 1: print(f"{idx+1}/{len(self.reader.all_records)} records", end="\r") if force_recompute: self.segments_json.write_text(json.dump(self.__all_segments, ensure_ascii=False)) def _slice_one_record( self, rec: str, force_recompute: bool = False, update_segments_json: bool = False, verbose: int = 0, ) -> None: """Slice one preprocessed record into segments. slice one record into segments of length `self.seglen`, and perform data augmentations specified in `self.config`. Parameters ---------- rec : str Name of the record. force_recompute : bool, default False Whether to force recompute the preprocessed data. update_segments_json : bool, default False If both `force_recompute` and `update_segments_json` are True, the file `self.segments_json` will be updated. Useful when slicing not all records. 
verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ self.__assert_task( [ "qrs_detection", "main", ] ) subject = self.reader.get_subject_id(rec) rec_segs = [item for item in self.__all_segments[subject] if item.startswith(rec.replace("data", "S"))] if (not force_recompute) and len(rec_segs) > 0: return elif force_recompute: self._clear_cached_segments([rec]) # data = self.reader.load_data(rec, units="mV") data = self.load_preprocessed_data(rec) siglen = data.shape[1] rpeaks = self.reader.load_rpeaks(rec) af_mask = self.reader.load_af_episodes(rec, fmt="mask") forward_len = self.seglen - self.config[self.task].overlap_len critical_forward_len = self.seglen - self.config[self.task].critical_overlap_len critical_forward_len = [critical_forward_len // 4, critical_forward_len] # skip those records that are too short if siglen < self.seglen: return # find critical points critical_points = np.where(np.diff(af_mask) != 0)[0] critical_points = [p for p in critical_points if critical_forward_len[1] <= p < siglen - critical_forward_len[1]] segments = [] # ordinary segments with constant forward_len for idx in range((siglen - self.seglen) // forward_len + 1): start_idx = idx * forward_len new_seg = self.__generate_segment( rec=rec, data=data, start_idx=start_idx, ) segments.append(new_seg) # the tail segment new_seg = self.__generate_segment( rec=rec, data=data, end_idx=siglen, ) segments.append(new_seg) # special segments around critical_points with random forward_len in critical_forward_len for cp in critical_points: start_idx = max( 0, cp - self.seglen + DEFAULTS.RNG_randint(critical_forward_len[0], critical_forward_len[1]), ) while start_idx <= min(cp - critical_forward_len[1], siglen - self.seglen): new_seg = self.__generate_segment( rec=rec, data=data, start_idx=start_idx, ) segments.append(new_seg) start_idx += DEFAULTS.RNG_randint(critical_forward_len[0], critical_forward_len[1]) # return segments self.__save_segments(rec, 
segments, update_segments_json) def __generate_segment( self, rec: str, data: np.ndarray, start_idx: Optional[int] = None, end_idx: Optional[int] = None, ) -> CFG: """Generate segment, with possible data augmentation. Parameter --------- rec : str Name of the record. data : numpy.ndarray The whole of (preprocessed) ECG record. start_idx : int, optional Start index of the signal for generating the segment. end_idx : int, optional End index of the signal for generating the segment. If `start_idx` is set, then `end_idx` is ignored, since the segment length is fixed to `self.seglen`. At least one of `start_idx` and `end_idx` should be set. Returns ------- dict Segments (meta-)data, containing: - data: values of the segment, with units in mV - rpeaks: indices of rpeaks of the segment - qrs_mask: mask of qrs complexes of the segment - af_mask: mask of af episodes of the segment - interval: interval ([start_idx, end_idx]) in the original ECG record of the segment """ assert not all([start_idx is None, end_idx is None]), "at least one of `start_idx` and `end_idx` should be set" siglen = data.shape[1] # offline augmentations are done, including strech-or-compress, ... 
if self.config.stretch_compress != 0: stretch_compress_choices = [0, 1, -1] sign = DEFAULTS.RNG_sample(stretch_compress_choices, 1)[0] if sign != 0: sc_ratio = self.config.stretch_compress sc_ratio = 1 + (DEFAULTS.RNG.uniform(sc_ratio / 4, sc_ratio) * sign) / 100 sc_len = int(round(sc_ratio * self.seglen)) if start_idx is not None: end_idx = start_idx + sc_len else: start_idx = end_idx - sc_len if end_idx > siglen: end_idx = siglen start_idx = max(0, end_idx - sc_len) sc_ratio = (end_idx - start_idx) / self.seglen aug_seg = data[..., start_idx:end_idx] aug_seg = SS.resample(x=aug_seg, num=self.seglen, axis=1) else: if start_idx is not None: end_idx = start_idx + self.seglen if end_idx > siglen: end_idx = siglen start_idx = end_idx - self.seglen else: start_idx = end_idx - self.seglen if start_idx < 0: start_idx = 0 end_idx = self.seglen # the segment of original signal, with no augmentation aug_seg = data[..., start_idx:end_idx] sc_ratio = 1 else: if start_idx is not None: end_idx = start_idx + self.seglen if end_idx > siglen: end_idx = siglen start_idx = end_idx - self.seglen else: start_idx = end_idx - self.seglen if start_idx < 0: start_idx = 0 end_idx = self.seglen aug_seg = data[..., start_idx:end_idx] sc_ratio = 1 # adjust rpeaks seg_rpeaks = self.reader.load_rpeaks( rec=rec, sampfrom=start_idx, sampto=end_idx, keep_original=False, ) seg_rpeaks = [ int(round(r / sc_ratio)) for r in seg_rpeaks if self.config.rpeaks_dist2border <= r < self.seglen - self.config.rpeaks_dist2border ] # generate qrs_mask from rpeaks seg_qrs_mask = np.zeros((self.seglen,), dtype=int) for r in seg_rpeaks: seg_qrs_mask[r - self.config.qrs_mask_bias : r + self.config.qrs_mask_bias] = 1 # adjust af_intervals seg_af_intervals = self.reader.load_af_episodes( rec=rec, sampfrom=start_idx, sampto=end_idx, keep_original=False, fmt="intervals", ) seg_af_intervals = [[int(round(itv[0] / sc_ratio)), int(round(itv[1] / sc_ratio))] for itv in seg_af_intervals] # generate af_mask from af_intervals 
seg_af_mask = np.zeros((self.seglen,), dtype=int) for itv in seg_af_intervals: seg_af_mask[itv[0] : itv[1]] = 1 new_seg = CFG( data=aug_seg, rpeaks=seg_rpeaks, qrs_mask=seg_qrs_mask, af_mask=seg_af_mask, interval=[start_idx, end_idx], ) return new_seg def __save_segments(self, rec: str, segments: List[CFG], update_segments_json: bool = False) -> None: """Save the segments to the disk. Parameters ---------- rec : str Name of the record. segments : List[dict] List of the segments (meta-)data to be saved. update_segments_json : bool, default False Whether to update the segments.json file. Returns ------- None """ subject = self.reader.get_subject_id(rec) ordering = list(range(len(segments))) DEFAULTS.RNG.shuffle(ordering) for i, idx in enumerate(ordering): seg = segments[idx] filename = f"{rec}_{i:07d}.{self.segment_ext}".replace("data", "S") data_path = self.segments_dirs.data[subject] / filename savemat(str(data_path), {"ecg": seg.data}) self.__all_segments[subject].append(Path(filename).with_suffix("").name) ann_path = self.segments_dirs.ann[subject] / filename savemat( str(ann_path), { k: v for k, v in seg.items() if k not in [ "data", ] }, ) if update_segments_json: self.segments_json.write_text(json.dumps(self.__all_segments, ensure_ascii=False)) def _clear_cached_segments(self, recs: Optional[Sequence[str]] = None) -> None: """Clear the cached segments. Parameters ---------- recs : Sequence[str], optional Sequence of the records whose segments are to be cleared. Defaults to all records. 
Returns ------- None """ self.__assert_task( [ "qrs_detection", "main", ] ) if recs is not None: for rec in recs: subject = self.reader.get_subject_id(rec) for item in [ "data", "ann", ]: path = str(self.segments_dirs[item][subject]) for f in [n for n in os.listdir(path) if n.endswith(self.segment_ext)]: if self._get_rec_name(f) == rec: os.remove(os.path.join(path, f)) if os.path.splitext(f)[0] in self.__all_segments[subject]: self.__all_segments[subject].remove(os.path.splitext(f)[0]) else: for subject in self.reader.all_subjects: for item in [ "data", "ann", ]: path = str(self.segments_dirs[item][subject]) for f in [n for n in os.listdir(path) if n.endswith(self.segment_ext)]: os.remove(os.path.join(path, f)) if os.path.splitext(f)[0] in self.__all_segments[subject]: self.__all_segments[subject].remove(os.path.splitext(f)[0]) self.segments = list_sum([self.__all_segments[subject] for subject in self.subjects]) def _slice_rr_seq(self, force_recompute: bool = False, verbose: int = 0) -> None: """Slice sequences of rr intervals into fixed length (sub)sequences. Parameters ---------- force_recompute : bool, default False Whether to force recompute the rr sequences. verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ self.__assert_task(["rr_lstm"]) if force_recompute: self._clear_cached_rr_seq() for idx, rec in enumerate(self.reader.all_records): self._slice_rr_seq_one_record( rec=rec, force_recompute=False, update_rr_seq_json=False, verbose=verbose, ) if verbose >= 1: print(f"{idx+1}/{len(self.reader.all_records)} records", end="\r") if force_recompute: self.rr_seq_json.write_text(json.dumps(self.__all_rr_seq, ensure_ascii=False)) def _slice_rr_seq_one_record( self, rec: str, force_recompute: bool = False, update_rr_seq_json: bool = False, verbose: int = 0, ) -> None: """Slice sequences of rr intervals from one record into fixed length (sub)sequences. Parameters ---------- rec : str Name of the record. 
force_recompute : bool, default False Whether to force recompute the rr sequences. update_rr_seq_json : bool, default False Whether to update the rr_seq.json file. verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ self.__assert_task(["rr_lstm"]) subject = self.reader.get_subject_id(rec) rec_rr_seq = [item for item in self.__all_rr_seq[subject] if item.startswith(rec.replace("data", "R"))] if (not force_recompute) and len(rec_rr_seq) > 0: return elif force_recompute: self._clear_cached_rr_seq([rec]) forward_len = self.seglen - self.config[self.task].overlap_len critical_forward_len = self.seglen - self.config[self.task].critical_overlap_len critical_forward_len = [critical_forward_len - 2, critical_forward_len] rpeaks = self.reader.load_rpeaks(rec) rr = np.diff(rpeaks) / self.config.fs if len(rr) < self.seglen: return af_mask = self.reader.load_af_episodes(rec, fmt="mask") label_seq = af_mask[rpeaks][:-1] # find critical points critical_points = np.where(np.diff(label_seq) != 0)[0] critical_points = [p for p in critical_points if critical_forward_len[1] <= p < len(rr) - critical_forward_len[1]] rr_seq = [] # ordinary segments with constant forward_len for idx in range((len(rr) - self.seglen) // forward_len + 1): start_idx = idx * forward_len end_idx = start_idx + self.seglen new_rr_seq = CFG( rr=rr[start_idx:end_idx], label=label_seq[start_idx:end_idx], interval=[start_idx, end_idx], ) rr_seq.append(new_rr_seq) # the tail segment if end_idx < len(rr): end_idx = len(rr) start_idx = end_idx - self.seglen new_rr_seq = CFG( rr=rr[start_idx:end_idx], label=label_seq[start_idx:end_idx], interval=[start_idx, end_idx], ) rr_seq.append(new_rr_seq) # special segments around critical_points with random forward_len in critical_forward_len for cp in critical_points: start_idx = max( 0, cp - self.seglen + DEFAULTS.RNG_randint(critical_forward_len[0], critical_forward_len[1]), ) while start_idx <= min(cp - critical_forward_len[1], len(rr) 
- self.seglen): end_idx = start_idx + self.seglen new_rr_seq = CFG( rr=rr[start_idx:end_idx], label=label_seq[start_idx:end_idx], interval=[start_idx, end_idx], ) rr_seq.append(new_rr_seq) start_idx += DEFAULTS.RNG_randint(critical_forward_len[0], critical_forward_len[1]) # save rr sequences self.__save_rr_seq(rec, rr_seq, update_rr_seq_json) def __save_rr_seq(self, rec: str, rr_seq: List[CFG], update_rr_seq_json: bool = False) -> None: """Save the sliced rr sequences to disk. Parameters ---------- rec : str, Name of the record. rr_seq : List[dict], List of the rr_seq (meta-)data. update_rr_seq_json : bool, default False Whether to update the rr_seq.json file. Returns ------- None """ subject = self.reader.get_subject_id(rec) ordering = list(range(len(rr_seq))) DEFAULTS.RNG.shuffle(ordering) for i, idx in enumerate(ordering): item = rr_seq[idx] filename = f"{rec}_{i:07d}.{self.rr_seq_ext}".replace("data", "R") data_path = self.rr_seq_dirs[subject] / filename savemat(str(data_path), item) self.__all_rr_seq[subject].append(Path(filename).with_suffix("").name) if update_rr_seq_json: self.rr_seq_json.write_text(json.dumps(self.__all_rr_seq, ensure_ascii=False)) def _clear_cached_rr_seq(self, recs: Optional[Sequence[str]] = None) -> None: """Clear the cached rr sequences. Parameters ---------- recs : Sequence[str], optional Sequence of the records whose cached rr sequences are to be cleared. Defaults to all records. 
Returns ------- None """ self.__assert_task(["rr_lstm"]) if recs is not None: for rec in recs: subject = self.reader.get_subject_id(rec) path = str(self.rr_seq_dirs[subject]) for f in [n for n in os.listdir(path) if n.endswith(self.rr_seq_ext)]: if self._get_rec_name(f) == rec: os.remove(os.path.join(path, f)) if os.path.splitext(f)[0] in self.__all_rr_seq[subject]: self.__all_rr_seq[subject].remove(os.path.splitext(f)[0]) else: for subject in self.reader.all_subjects: path = str(self.rr_seq_dirs[subject]) for f in [n for n in os.listdir(path) if n.endswith(self.rr_seq_ext)]: os.remove(os.path.join(path, f)) if os.path.splitext(f)[0] in self.__all_rr_seq[subject]: self.__all_rr_seq[subject].remove(os.path.splitext(f)[0]) self.rr_seq = list_sum([self.__all_rr_seq[subject] for subject in self.subjects]) def _get_rec_name(self, seg_or_rr: str) -> str: """Get the name of the record that a segment or rr_seq was generated from. Parameters ---------- seg_or_rr : str Name of the segment or rr_seq. Returns ------- rec : str Name of the record that the segment or rr_seq was generated from. """ rec = re.sub("[RS]", "data", os.path.splitext(seg_or_rr)[0])[:-8] return rec def _train_test_split(self, train_ratio: float = 0.8, force_recompute: bool = False) -> Dict[str, List[str]]: """Perform train-test split. Parameters ---------- train_ratio : float, default 0.8 Ratio of the train set in the whole dataset. force_recompute : bool, default False If True, the train-test split will be recomputed, regardless of the existing ones stored in json files. Returns ------- dict Dictionary of the train-test split. Keys are "train" and "test", and values are list of the subjects split for training or validation. """ start = time.time() if len(self.reader) == 0: warnings.warn( "No data found. Please check the data directory. 
" "The train-test split will be empty.", RuntimeWarning, ) split_res = CFG( { "train": [], "test": [], } ) return split_res print("\nstart performing train test split...\n") _train_ratio = int(train_ratio * 100) _test_ratio = 100 - _train_ratio assert _train_ratio * _test_ratio > 0 train_file = self.reader.db_dir_base / f"train_ratio_{_train_ratio}.json" test_file = self.reader.db_dir_base / f"test_ratio_{_test_ratio}.json" if force_recompute or not all([train_file.is_file(), test_file.is_file()]): all_subjects = set(self.reader.df_stats.subject_id.tolist()) afp_subjects = set(self.reader.df_stats[self.reader.df_stats.label == "AFp"].subject_id.tolist()) aff_subjects = set(self.reader.df_stats[self.reader.df_stats.label == "AFf"].subject_id.tolist()) - afp_subjects normal_subjects = all_subjects - afp_subjects - aff_subjects test_set = ( DEFAULTS.RNG_sample( list(afp_subjects), max(1, int(round(len(afp_subjects) * _test_ratio / 100))), ).tolist() + DEFAULTS.RNG_sample( list(aff_subjects), max(1, int(round(len(aff_subjects) * _test_ratio / 100))), ).tolist() + DEFAULTS.RNG_sample( list(normal_subjects), max(1, int(round(len(normal_subjects) * _test_ratio / 100))), ).tolist() ) train_set = list(all_subjects - set(test_set)) DEFAULTS.RNG.shuffle(test_set) DEFAULTS.RNG.shuffle(train_set) train_file.write_text(json.dumps(train_set, ensure_ascii=False)) test_file.write_text(json.dumps(test_set, ensure_ascii=False)) print( nildent( f""" train set saved to \n\042{str(train_file)}\042 test set saved to \n\042{str(test_file)}\042 """ ) ) else: train_set = json.loads(train_file.read_text()) test_set = json.loads(test_file.read_text()) print(f"train test split finished in {(time.time()-start)/60:.2f} minutes") split_res = CFG( { "train": train_set, "test": test_set, } ) return split_res def __assert_task(self, tasks: List[str]) -> None: """Check if the current task is in the given list of tasks.""" assert self.task in tasks, ( f"DO NOT call this method when the current task is 
{self.task}. " "Switch task using `reset_task`" )
def plot_seg(self, seg: str, ticks_granularity: int = 0) -> None:
    """Plot a segment along with its annotations.

    The segment data and annotations are loaded and printed, then handed
    to the ``plot`` method of the reader.

    Parameters
    ----------
    seg : str
        Name of the segment, of pattern like "S_1_1_0000193".
    ticks_granularity : int, default 0
        Granularity to plot axis ticks, the higher the more ticks.
        0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks)

    Returns
    -------
    None

    """
    ecg = self._load_seg_data(seg)
    print(f"seg_data.shape = {ecg.shape}")

    annotations = self._load_seg_ann(seg)
    # add an "af_episodes" entry computed by `mask_to_intervals` from the "af_mask" entry
    annotations["af_episodes"] = mask_to_intervals(annotations["af_mask"], vals=1)
    print(f"seg_ann = {annotations}")

    self.reader.plot(
        rec=self._get_rec_name(seg),  # unnecessary indeed
        data=ecg,
        ann=annotations,
        ticks_granularity=ticks_granularity,
    )
def extra_repr_keys(self) -> List[str]:
    """Return the attribute names "training", "task" and "reader".

    Judging by the method name, these are the extra attributes shown in the
    string representation of the instance (the consumer is not visible here).

    Returns
    -------
    List[str]
        A new list on every call.

    """
    return list(("training", "task", "reader"))
class _FastDataReader(ReprMixin, Dataset):
    """Fast data reader.

    Serves samples directly from the cached ``.mat`` files listed in `files`.

    Parameters
    ----------
    config : CFG
        The configuration. ``config[task]`` has to provide ``input_len`` and
        ``classes`` (and ``reduction`` for the tasks "qrs_detection" and "main");
        ``config.fs`` is passed to the preprocessors and to the weight mask
        generator.
    task : str
        Name of the task, one of "qrs_detection", "main", "rr_lstm".
    seg_ppm : PreprocManager
        The preprocessor manager applied to every loaded ECG segment.
    file_dirs : dict
        Directories of the cached files.
        For "qrs_detection" and "main", it has the fields ``data`` and ``ann``,
        each of which maps a subject to a directory;
        for "rr_lstm", it maps a subject to a directory directly.
    files : List[str]
        Names (without extension) of the segments or rr sequences.
        The subject is taken from the second "_"-separated field of a name.
    file_ext : str
        Extension (without the leading dot) of the cached files.

    """

    def __init__(
        self,
        config: CFG,
        task: str,
        seg_ppm: PreprocManager,
        file_dirs: dict,
        files: List[str],
        file_ext: str,
    ) -> None:
        self.config = config
        self.task = task
        self.seg_ppm = seg_ppm
        self.file_dirs = file_dirs
        self.files = files
        self.file_ext = file_ext

        self.seglen = self.config[self.task].input_len
        self.n_classes = len(self.config[task].classes)

        # key of the label array inside an annotation file, for each segment task
        self._seg_keys = {
            "qrs_detection": "qrs_mask",
            "main": "af_mask",
        }

    def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
        """Load one sample.

        Parameters
        ----------
        index : int
            Index into `files`.

        Returns
        -------
        tuple
            ``(data, label, weight_mask)``.
            For "qrs_detection" the third item is None.
            For "qrs_detection" and "main", the label has shape
            ``(seglen // reduction, n_label_columns)`` if ``reduction > 1``,
            otherwise ``(seglen, n_label_columns)``.
            For "rr_lstm", data has shape ``(seglen, 1)``
            and the label has shape ``(seglen, n_classes)``.

        Raises
        ------
        NotImplementedError
            If the task is not one of "qrs_detection", "main", "rr_lstm".

        """
        if self.task in [
            "qrs_detection",
            "main",
        ]:
            seg_name = self.files[index]
            subject = seg_name.split("_")[1]
            seg_data_fp = self.file_dirs.data[subject] / f"{seg_name}.{self.file_ext}"
            seg_data = loadmat(str(seg_data_fp))["ecg"]
            for idx in range(seg_data.shape[0]):
                seg_data[idx] = remove_spikes_naive(seg_data[idx])
            seg_ann_fp = self.file_dirs.ann[subject] / f"{seg_name}.{self.file_ext}"
            seg_label = loadmat(str(seg_ann_fp))[self._seg_keys[self.task]].reshape((self.seglen, -1))
            reduction = self.config[self.task].reduction
            if reduction > 1:
                # Down-sample the LABEL (not the signal) along the time axis:
                # average every `reduction` consecutive rows, then truncate to int.
                seg_len, n_label_cols = seg_label.shape
                n_reduced = seg_len // reduction
                seg_label = (
                    seg_label[: n_reduced * reduction]
                    .reshape((n_reduced, reduction, n_label_cols))
                    .mean(axis=1)
                    .astype(int)
                )
            seg_data, _ = self.seg_ppm(seg_data, self.config.fs)
            if self.task == "main":
                weight_mask = generate_weight_mask(
                    target_mask=seg_label.squeeze(-1),
                    fg_weight=2,
                    fs=self.config.fs,
                    reduction=self.config[self.task].reduction,
                    radius=0.8,
                    boundary_weight=5,
                )[..., np.newaxis]
                return seg_data, seg_label, weight_mask
            return seg_data, seg_label, None
        elif self.task in [
            "rr_lstm",
        ]:
            seq_name = self.files[index]
            subject = seq_name.split("_")[1]
            rr_seq_path = self.file_dirs[subject] / f"{seq_name}.{self.file_ext}"
            rr_seq = loadmat(str(rr_seq_path))
            rr_seq["rr"] = rr_seq["rr"].reshape((self.seglen, 1))
            rr_seq["label"] = rr_seq["label"].reshape((self.seglen, self.n_classes))
            weight_mask = generate_weight_mask(
                target_mask=rr_seq["label"].squeeze(-1),
                fg_weight=2,
                fs=1 / 0.8,
                reduction=1,
                radius=2,
                boundary_weight=5,
            )[..., np.newaxis]
            return rr_seq["rr"], rr_seq["label"], weight_mask
        else:
            raise NotImplementedError(f"data generator for task \042{self.task}\042 not implemented")

    def __len__(self) -> int:
        """Number of files (samples) served by this reader."""
        return len(self.files)

    def extra_repr_keys(self) -> List[str]:
        """Names of attributes of this class to list as extra repr keys.

        Only attributes that are set in ``__init__`` are listed.
        NOTE(review): presumably consumed by ``ReprMixin`` -- confirm.
        """
        return [
            "task",
            "seg_ppm",
        ]


# `StandaloneSegmentSlicer` can be found in `benchmarks/train_hybrid_cpsc2021/dataset.py`