# Source code for torch_ecg.databases.datasets.mitdb.mitdb_dataset

"""
PyTorch dataset for the :class:`~torch_ecg.databases.MITDB` database.

Provides :class:`MITDBDataset`, which serves the tasks "qrs_detection",
"rhythm_segmentation", "af_event", "rr_lstm" and "beat_classification".
For the segment-based tasks and "rr_lstm", records are sliced into
fixed-length segments / rr-interval sequences that are cached as ``.mat``
files under ``db_dir`` (see ``MITDBDataset.persistence``).
"""

import json
import os
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from scipy import signal as SS
from scipy.io import loadmat, savemat
from torch.utils.data.dataset import Dataset
from tqdm.auto import tqdm

from ...._preprocessors import PreprocManager
from ....cfg import CFG, DEFAULTS
from ....databases import MITDB as DR
from ....utils.misc import ReprMixin, add_docstring, get_record_list_recursive3, list_sum
from ....utils.utils_data import cls_to_bin, ensure_siglen, generate_weight_mask, mask_to_intervals
from ....utils.utils_signal import remove_spikes_naive
from .mitdb_cfg import MITDBTrainCfg

# public API of this module
__all__ = ["MITDBDataset"]


class MITDBDataset(ReprMixin, Dataset):
    """Data generator for feeding data into pytorch models
    using the :class:`~torch_ecg.databases.MITDB` database.

    Parameters
    ----------
    config : dict
        Configurations for the dataset, ref. `MITDBTrainCfg`.
        A simple example is as follows:

        .. code-block:: python

            >>> config = deepcopy(MITDBTrainCfg)
            >>> config.db_dir = "some/path/to/db"
            >>> dataset = MITDBDataset(config, task="qrs_detection", training=True, lazy=False)

    task : str
        Name of the task (case-insensitive), must be listed in
        `MITDBTrainCfg.tasks`. The tasks handled by this class are
        "beat_classification", "qrs_detection", "rhythm_segmentation",
        "af_event" and "rr_lstm".
    training : bool, default True
        If True, the training set will be loaded,
        otherwise the test (val) set will be loaded.
    lazy : bool, default True
        If True, the data will not be loaded immediately,
        instead, it will be loaded on demand.
    **reader_kwargs : dict, optional
        Keyword arguments for the database reader class.
        A ``db_dir`` entry is dropped (with a warning),
        ``config.db_dir`` is always used.

    """

    __name__ = "MITDBDataset"

    def __init__(
        self,
        config: CFG,
        task: str,
        training: bool = True,
        lazy: bool = True,
        **reader_kwargs: Any,
    ) -> None:
        super().__init__()
        self.config = deepcopy(config)
        if reader_kwargs.pop("db_dir", None) is not None:
            warnings.warn("`db_dir` is specified in both config and reader_kwargs", RuntimeWarning)
        self.reader = DR(db_dir=self.config.db_dir, **reader_kwargs)
        # keep the directory as resolved by the reader
        self.config.db_dir = self.reader.db_dir
        self.dtype = self.config.np_dtype
        self.allowed_preproc = list(
            set(
                [
                    "bandpass",
                    "baseline_remove",
                    "normalize",
                ]
            ).intersection(set(self.config.keys()))
        )

        self.training = training
        self.lazy = lazy

        # preprocessors applied to whole records, before slicing
        ppm_config = CFG(random=False)
        ppm_config.update(deepcopy(self.config))
        # preprocessors applied to already-sliced segments (no bandpass)
        seg_ppm_config = CFG(random=False)
        seg_ppm_config.update(deepcopy(self.config))
        seg_ppm_config.pop("bandpass", None)
        self.ppm = PreprocManager.from_config(ppm_config)
        self.seg_ppm = PreprocManager.from_config(seg_ppm_config)

        # directory for the sliced segments of fixed length
        self.segments_base_dir = self.config.db_dir / "segments"
        self.segments_base_dir.mkdir(parents=True, exist_ok=True)
        self.segment_name_pattern = "S_[\\d]{3}_[\\d]{7}"
        self.segment_ext = "mat"

        # directory for the sequences of rr intervals of fixed length
        self.rr_seq_base_dir = self.config.db_dir / "rr_seq"
        self.rr_seq_base_dir.mkdir(parents=True, exist_ok=True)
        self.rr_seq_name_pattern = "R_[\\d]{3}_[\\d]{7}"
        self.rr_seq_ext = "mat"

        self._all_data = None
        self._all_labels = None
        self._all_masks = None
        self.__set_task(task, lazy=self.lazy)

    def _load_all_data(self) -> None:
        """Load all data of the current task into memory."""
        self.__set_task(self.task, lazy=False)

    def __set_task(self, task: str, lazy: bool = True) -> None:
        """Set the task and load the data if needed.

        Parameters
        ----------
        task : str
            Name of the task, as listed in `MITDBTrainCfg.tasks`.
            Case-insensitive.
        lazy : bool, default True
            If True, the data will not be loaded immediately,
            instead, it will be loaded on demand.

        Returns
        -------
        None

        Raises
        ------
        NotImplementedError
            If `task` is a legal task that this class does not handle.

        """
        assert task.lower() in MITDBTrainCfg.tasks, f"illegal task \042{task}\042"
        if (
            hasattr(self, "task")
            and self.task == task.lower()
            and self._all_data is not None
            and len(self._all_data) > 0
        ):
            # the same task is already loaded in memory
            return
        self.task = task.lower()
        # always use the normalized task name for config lookups
        task_cfg = self.config[self.task]
        self.all_classes = task_cfg.classes
        self.n_classes = len(task_cfg.classes)
        self.lazy = lazy
        self.seglen = task_cfg.input_len  # alias, for simplicity

        split_res = self._train_test_split(self.task)
        self.records = split_res.train if self.training else split_res.test

        if self.task in ["beat_classification"]:
            self._all_data, self._all_labels = [], []
            if self.lazy:
                return
            self._load_beats()
        elif self.task in [
            "qrs_detection",
            "rhythm_segmentation",
            "af_event",
        ]:
            self.segments_dirs = CFG()
            self.__all_segments = CFG()
            self.segments_json = self.segments_base_dir / "segments.json"
            self._ls_segments()
            self.segments = list_sum([self.__all_segments[rec] for rec in self.records])
            if self.training:
                DEFAULTS.RNG.shuffle(self.segments)
            self.fdr = _FastDataReader(
                self.config,
                self.task,
                self.seg_ppm,
                self.segments_dirs,
                self.segments,
                self.segment_ext,
                self.reader.rhythm_types_map,
            )
            if self.lazy:
                return
            self._preload_from_fdr()
        elif self.task in [
            "rr_lstm",
        ]:
            self.rr_seq_dirs = CFG()
            self.__all_rr_seq = CFG()
            self.rr_seq_json = self.rr_seq_base_dir / "rr_seq.json"
            self._ls_rr_seq()
            self.rr_seq = list_sum([self.__all_rr_seq[rec] for rec in self.records])
            if self.training:
                DEFAULTS.RNG.shuffle(self.rr_seq)
            self.fdr = _FastDataReader(
                self.config,
                self.task,
                self.seg_ppm,
                self.rr_seq_dirs,
                self.rr_seq,
                self.rr_seq_ext,
                self.reader.rhythm_types_map,
            )
            if self.lazy:
                return
            self._preload_from_fdr()
        else:
            raise NotImplementedError(f"data generator for task \042{self.task}\042 not implemented")

    def _load_beats(self) -> None:
        """Load the windows around annotated beats of `self.records`
        (task "beat_classification") into memory."""
        task_cfg = self.config[self.task]
        all_data, all_labels = [], []
        with tqdm(
            self.records,
            desc="Loading data",
            unit="record",
            dynamic_ncols=True,
            mininterval=1.0,
        ) as pbar:
            for rec in pbar:
                data, _ = self.ppm(self.reader.load_data(rec), self.config.fs)
                for beat in self.reader.load_beat_ann(rec):
                    if beat.symbol not in task_cfg.classes:
                        continue
                    # window [index - winL, index + winR), clipped to the record
                    start = max(0, beat.index - task_cfg.winL)
                    end = min(data.shape[-1], beat.index + task_cfg.winR)
                    all_data.append(ensure_siglen(data[:, start:end], task_cfg.input_len))
                    all_labels.append(task_cfg.class_map[beat.symbol])
        self._all_data = np.array(all_data)
        self._all_labels = np.array(all_labels)
        if task_cfg.loss not in ["CrossEntropyLoss"]:
            # binarize the class indices for losses other than CrossEntropyLoss
            self._all_labels = cls_to_bin(self._all_labels, len(task_cfg.classes))

    def _preload_from_fdr(self) -> None:
        """Read every item of `self.fdr` and store them as stacked arrays."""
        if len(self.fdr) == 0:
            warnings.warn(
                f"No data found for task {self.task}, slice the data first.",
                RuntimeWarning,
            )
        all_data, all_labels, all_masks = [], [], []
        with tqdm(
            range(len(self.fdr)),
            desc="Loading data",
            unit="record",
            dynamic_ncols=True,
            mininterval=1.0,
        ) as pbar:
            for idx in pbar:
                data, label, mask = self.fdr[idx]
                all_data.append(data)
                all_labels.append(label)
                all_masks.append(mask)
        self._all_data = np.array(all_data).astype(self.dtype)
        self._all_labels = np.array(all_labels).astype(self.dtype)
        if self.task == "qrs_detection":
            # the reader yields no weight masks for qrs detection
            self._all_masks = None
        else:
            self._all_masks = np.array(all_masks).astype(self.dtype)
[docs] def reset_task(self, task: str, lazy: bool = True) -> None: """Reset the task of the data generator. Parameters ---------- task : str The task to be set. lazy : bool, optional Whether to load the data lazily, by default True. Returns ------- None """ self.__set_task(task, lazy)
def _ls_segments(self) -> None: """Find all the segments in the segments base directory, and store them in some private attributes. """ for item in ["data", "ann"]: self.segments_dirs[item] = CFG() for rec in self.reader: self.segments_dirs[item][rec] = self.segments_base_dir / item / rec self.segments_dirs[item][rec].mkdir(parents=True, exist_ok=True) if self.segments_json.is_file(): self.__all_segments = json.loads(self.segments_json.read_text()) # return print("please allow the reader a few minutes to collect " f"the segments from {self.segments_base_dir}...") seg_filename_pattern = f"{self.segment_name_pattern}\\.{self.segment_ext}" self.__all_segments = CFG( {rec: get_record_list_recursive3(str(self.segments_dirs.data[rec]), seg_filename_pattern) for rec in self.reader} ) if all([len(self.__all_segments[rec]) > 0 for rec in self.reader]): self.segments_json.write_text(json.dumps(self.__all_segments, ensure_ascii=False)) def _ls_rr_seq(self) -> None: """Find all the rr sequences in the rr sequences base directory, and store them in some private attributes. 
""" for rec in self.reader: self.rr_seq_dirs[rec] = self.rr_seq_base_dir / rec self.rr_seq_dirs[rec].mkdir(parents=True, exist_ok=True) if self.rr_seq_json.is_file(): self.__all_rr_seq = json.loads(self.rr_seq_json.read_text()) # return print("please allow the reader a few minutes to collect " f"the rr sequences from {self.rr_seq_base_dir}...") rr_seq_filename_pattern = f"{self.rr_seq_name_pattern}\\.{self.rr_seq_ext}" self.__all_rr_seq = CFG( {rec: get_record_list_recursive3(self.rr_seq_dirs[rec], rr_seq_filename_pattern) for rec in self.reader} ) if all([len(self.__all_rr_seq[rec]) > 0 for rec in self.reader]): self.rr_seq_json.write_text(json.dumps(self.__all_rr_seq, ensure_ascii=False)) @property def all_segments(self) -> CFG: if self.task in [ "qrs_detection", "rhythm_segmentation", "af_event", ]: return self.__all_segments else: return CFG() @property def all_rr_seq(self) -> CFG: if self.task.lower() in [ "rr_lstm", ]: return self.__all_rr_seq else: return CFG() def __len__(self) -> int: if self.task in ["beat_classification"]: return len(self._all_data) return len(self.fdr) def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]: if self.task in ["beat_classification"]: return self._all_data[index], self._all_labels[index] if self.lazy: if self.task in ["qrs_detection"]: return self.fdr[index][:2] else: return self.fdr[index] else: if self.task in ["qrs_detection"]: return self._all_data[index], self._all_labels[index] else: return ( self._all_data[index], self._all_labels[index], self._all_masks[index], ) def _get_seg_data_path(self, seg: str) -> Path: """Get the path of the data file of the segment. Parameters ---------- seg : str Name of the segment, of pattern like "S_100_0000193". Returns ------- pathlib.Path Absolute path of the data file of the segment. 
""" rec = self._get_rec_name(seg) fp = self.segments_dirs.data[rec] / f"{seg}.{self.segment_ext}" return fp @add_docstring(_get_seg_data_path.__doc__.replace("data file", "annotation file")) def _get_seg_ann_path(self, seg: str) -> Path: rec = self._get_rec_name(seg) fp = self.segments_dirs.ann[rec] / f"{seg}.{self.segment_ext}" return fp def _load_seg_data(self, seg: str) -> np.ndarray: """Load data of the segment. Parameters ---------- seg : str Name of the segment, of pattern like "S_100_0000193". Returns ------- numpy.ndarray Data of the segment, of shape ``(2, self.seglen)``. """ seg_data_fp = self._get_seg_data_path(seg) seg_data = loadmat(str(seg_data_fp))["ecg"] return seg_data def _load_seg_ann(self, seg: str) -> dict: """Load annotations of the segment. Parameters ---------- seg : str Name of the segment, of pattern like "S_100_0000193". Returns ------- dict A dictionay of annotations of the segment, including - rpeaks: indices of rpeaks of the segment - qrs_mask: mask of qrs complexes of the segment - rhythm_mask: mask of rhythms of the segment - interval: interval ([start_idx, end_idx]) in the original ECG record of the segment """ seg_ann_fp = self._get_seg_ann_path(seg) seg_ann = {k: v.flatten() for k, v in loadmat(str(seg_ann_fp)).items() if not k.startswith("__")} return seg_ann def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]: """Load mask(s) of the segment. Parameters ---------- seg : str Name of the segment, of pattern like "S_100_0000193". task : str, optional Task for doing the segmentation, by default the current task. If specified, overrides self.task. If is "all", then all masks ("qrs_mask", "rhythm_mask", etc.) will be returned. Returns ------- numpy.ndarray or dict Mask(s) of the segment, of shape ``(self.seglen, self.n_classes)``. 
""" seg_mask = { k: v.reshape((self.seglen, -1)) for k, v in self._load_seg_ann(seg).items() if k in [ "qrs_mask", "rhythm_mask", ] } _task = (task or self.task).lower() if _task == "all": return seg_mask if _task in [ "qrs_detection", ]: seg_mask = seg_mask["qrs_mask"] elif _task in ["rhythm_segmentation", "af_event"]: seg_mask = seg_mask["rhythm_mask"] return seg_mask def _load_seg_seq_lab(self, seg: str, reduction: int) -> np.ndarray: """Load sequence label of the segment. Parameters ---------- seg : str Name of the segment, of pattern like "S_100_0000193". reduction : int Reduction ratio (granularity) of length of the model output, compared to the original signal length. Returns ------- numpy.ndarray Label of the sequence, of shape ``(self.seglen//reduction, self.n_classes)``. """ seg_mask = self._load_seg_mask(seg) seg_len, n_classes = seg_mask.shape seq_lab = np.stack( arrays=[ np.mean( seg_mask[reduction * idx : reduction * (idx + 1)], axis=0, keepdims=True, ).astype(int) for idx in range(seg_len // reduction) ], axis=0, ).squeeze(axis=1) return seq_lab def _get_rr_seq_path(self, seq_name: str) -> Path: """Get path of the data file of the rr_seq. Parameters ---------- seq_name : str Name of the rr_seq, of pattern like "R_100_0000193". Returns ------- pathlib.Path Absolute path of the data file of the rr_seq. """ rec = self._get_rec_name(seq_name) fp = self.rr_seq_dirs[rec] / f"{seq_name}.{self.rr_seq_ext}" return fp def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]: """Load metadata of sequence of rr intervals. Parameters ---------- seq_name : str Name of the rr_seq, of pattern like "R_100_0000193". 
Returns ------- dict Metadata of sequence of rr intervals, including - rr: the sequence of rr intervals, with units in seconds, of shape ``(self.seglen, 1)`` - label: label of the rr intervals, of shape ``(self.seglen, self.n_classes)`` - interval: interval of the current rr sequence in the whole rr sequence in the original record """ rr_seq_path = self._get_rr_seq_path(seq_name) rr_seq = {k: v for k, v in loadmat(str(rr_seq_path)).items() if not k.startswith("__")} rr_seq["rr"] = rr_seq["rr"].reshape((self.seglen, 1)) rr_seq["label"] = rr_seq["label"].reshape((self.seglen, self.n_classes)) rr_seq["interval"] = rr_seq["interval"].flatten() return rr_seq
[docs] def persistence(self, force_recompute: bool = False, verbose: int = 0) -> None: """Save the preprocessed data to disk. Parameters ---------- force_recompute : bool, default False Whether to force recompute the preprocessed data. verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ # TODO: consider whether preprocessing should be added # if verbose >= 1: # print(" preprocessing data ".center(110, "#")) # self._preprocess_data( # force_recompute=force_recompute, # verbose=verbose, # ) original_task = self.task original_lazy = self.lazy self.__set_task("qrs_detection", lazy=True) if verbose >= 1: print("\n" + " slicing data into segments ".center(110, "#")) self._slice_data( force_recompute=force_recompute, verbose=verbose, ) self.__set_task("rr_lstm", lazy=True) if verbose >= 1: print("\n" + " generating rr sequences ".center(110, "#")) self._slice_rr_seq( force_recompute=force_recompute, verbose=verbose, ) self.__set_task(original_task, lazy=original_lazy)
def _slice_data(self, force_recompute: bool = False, verbose: int = 0) -> None: """Slice all records into segments. Slice all records into segments of length `self.seglen`, and perform data augmentations specified in `self.config`. Parameters ---------- force_recompute : bool, default False Whether to force recompute the preprocessed data. verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ self.__assert_task( [ "qrs_detection", "rhythm_segmentation", "af_event", # segmentation of AF events ] ) if force_recompute: self._clear_cached_segments() with tqdm( enumerate(self.reader), total=len(self.reader), desc="Slicing data", unit="record", dynamic_ncols=True, mininterval=1.0, ) as pbar: for idx, rec in pbar: self._slice_one_record( rec=rec, force_recompute=False, update_segments_json=False, verbose=verbose, ) # if verbose >= 1: # print(f"{idx+1}/{len(self.reader)} records", end="\r") if force_recompute: self.segments_json.write_text(json.dump(self.__all_segments, ensure_ascii=False)) def _slice_one_record( self, rec: str, force_recompute: bool = False, update_segments_json: bool = False, verbose: int = 0, ) -> None: """Slice one record into segments. Slice one record into segments of length `self.seglen`, and perform data augmentations specified in `self.config`. Parameters ---------- rec : str Name of the record. force_recompute : bool, default False Whether to force recompute the preprocessed data. update_segments_json : bool, default False If both `force_recompute` and `update_segments_json` are True, the file `self.segments_json` will be updated. Useful when slicing not all records. verbose : int, default 0 Verbosity level for printing the progress. 
Returns ------- None """ self.__assert_task( [ "qrs_detection", "rhythm_segmentation", "af_event", # segmentation of AF events ] ) rec_segs = self.__all_segments[rec] if (not force_recompute) and len(rec_segs) > 0: return elif force_recompute: self._clear_cached_segments([rec]) data, _ = self.ppm(self.reader.load_data(rec), self.config.fs) siglen = data.shape[1] rpeaks = self.reader.load_rpeak_indices(rec) rhythm_mask = self.reader.load_rhythm_ann( rec, rhythm_format="mask", rhythm_types=self.config.rhythm_segmentation.classes, ) forward_len = self.seglen - self.config[self.task].overlap_len critical_forward_len = self.seglen - self.config[self.task].critical_overlap_len critical_forward_len = [critical_forward_len // 4, critical_forward_len] # find critical points critical_points = np.where(np.diff(rhythm_mask) != 0)[0] critical_points = [p for p in critical_points if critical_forward_len[1] <= p < siglen - critical_forward_len[1]] segments = [] # ordinary segments with constant forward_len print("Slicing ordinary segments with constant `forward_len`") with tqdm( range((siglen - self.seglen) // forward_len + 1), desc=f"Slicing segments for record {rec}", unit="segment", dynamic_ncols=True, mininterval=1.0, ) as pbar: for idx in pbar: start_idx = idx * forward_len new_seg = self.__generate_segment( rec=rec, data=data, start_idx=start_idx, ) segments.append(new_seg) # the tail segment new_seg = self.__generate_segment( rec=rec, data=data, end_idx=siglen, ) segments.append(new_seg) if len(critical_points) == 0: # save segments self.__save_segments(rec, segments, update_segments_json) return # special segments around critical_points with random forward_len in critical_forward_len print("Slicing special segments around `critical_points` " "with random `forward_len` in `critical_forward_len`") with tqdm( critical_points, desc=f"Slicing segments for record {rec}", unit="segment", dynamic_ncols=True, mininterval=1.0, ) as pbar: for cp in pbar: start_idx = max( 0, cp - 
self.seglen + DEFAULTS.RNG_randint(critical_forward_len[0], critical_forward_len[1]), ) while start_idx <= min(cp - critical_forward_len[1], siglen - self.seglen): new_seg = self.__generate_segment( rec=rec, data=data, start_idx=start_idx, ) segments.append(new_seg) start_idx += DEFAULTS.RNG_randint(critical_forward_len[0], critical_forward_len[1]) # save segments self.__save_segments(rec, segments, update_segments_json) def __generate_segment( self, rec: str, data: np.ndarray, start_idx: Optional[int] = None, end_idx: Optional[int] = None, ) -> CFG: """Generate segment, with possible data augmentation. Parameter --------- rec : str Name of the record. data : numpy.ndarray The whole of (preprocessed) ECG record. start_idx : int, optional Start index of the signal for generating the segment. end_idx : int, optional End index of the signal for generating the segment. If `start_idx` is set, then `end_idx` is ignored, since the segment length is fixed to `self.seglen`. At least one of `start_idx` and `end_idx` should be set. Returns ------- dict Segments (meta-)data, containing: - data: values of the segment, with units in mV - rpeaks: indices of rpeaks of the segment - qrs_mask: mask of qrs complexes of the segment - rhythm_mask: mask of rhythms of the segment - interval: interval ([start_idx, end_idx]) in the original ECG record of the segment """ assert not all([start_idx is None, end_idx is None]), "at least one of `start_idx` and `end_idx` should be set" siglen = data.shape[1] # offline augmentations are done, including strech-or-compress, ... 
if self.config.stretch_compress != 0: stretch_compress_choices = [0, 1, -1] sign = DEFAULTS.RNG_sample(stretch_compress_choices, 1)[0] if sign != 0: sc_ratio = self.config.stretch_compress sc_ratio = 1 + (DEFAULTS.RNG.uniform(sc_ratio / 4, sc_ratio) * sign) / 100 sc_len = int(round(sc_ratio * self.seglen)) if start_idx is not None: end_idx = start_idx + sc_len else: start_idx = end_idx - sc_len if end_idx > siglen: end_idx = siglen start_idx = max(0, end_idx - sc_len) sc_ratio = (end_idx - start_idx) / self.seglen aug_seg = data[..., start_idx:end_idx] aug_seg = SS.resample(x=aug_seg, num=self.seglen, axis=1) else: if start_idx is not None: end_idx = start_idx + self.seglen if end_idx > siglen: end_idx = siglen start_idx = end_idx - self.seglen else: start_idx = end_idx - self.seglen if start_idx < 0: start_idx = 0 end_idx = self.seglen # the segment of original signal, with no augmentation aug_seg = data[..., start_idx:end_idx] sc_ratio = 1 else: if start_idx is not None: end_idx = start_idx + self.seglen if end_idx > siglen: end_idx = siglen start_idx = end_idx - self.seglen else: start_idx = end_idx - self.seglen if start_idx < 0: start_idx = 0 end_idx = self.seglen aug_seg = data[..., start_idx:end_idx] sc_ratio = 1 # adjust rpeaks seg_rpeaks = self.reader.load_rpeak_indices( rec=rec, sampfrom=start_idx, sampto=end_idx, keep_original=False, ) seg_rpeaks = [ int(round(r / sc_ratio)) for r in seg_rpeaks if self.config.rpeaks_dist2border <= r < self.seglen - self.config.rpeaks_dist2border ] # generate qrs_mask from rpeaks seg_qrs_mask = np.zeros((self.seglen,), dtype=int) for r in seg_rpeaks: seg_qrs_mask[r - self.config.qrs_mask_bias : r + self.config.qrs_mask_bias] = 1 # adjust rhythm_intervals seg_rhythm_intervals = self.reader.load_rhythm_ann( rec=rec, sampfrom=start_idx, sampto=end_idx, rhythm_format="intervals", rhythm_types=self.config.rhythm_segmentation.classes, keep_original=False, ) seg_rhythm_intervals = { rt: [[int(round(itv[0] / sc_ratio)), 
int(round(itv[1] / sc_ratio))] for itv in l_itvs] for rt, l_itvs in seg_rhythm_intervals.items() } # generate rhythm_mask from rhythm_intervals seg_rhythm_mask = np.zeros((self.seglen,), dtype=int) for rt, l_itvs in seg_rhythm_intervals.items(): for itv in l_itvs: seg_rhythm_mask[itv[0] : itv[1]] = self.config.rhythm_segmentation.class_map[rt] new_seg = CFG( data=aug_seg, rpeaks=seg_rpeaks, qrs_mask=seg_qrs_mask, rhythm_mask=seg_rhythm_mask, interval=[start_idx, end_idx], ) return new_seg def __save_segments(self, rec: str, segments: List[CFG], update_segments_json: bool = False) -> None: """Save the segments to disk. Parameters ---------- rec : str Name of the record segments : List[dict] List of the segments (meta-)data to be saved. update_segments_json : bool, default False Whether to update the segments.json file. """ ordering = list(range(len(segments))) DEFAULTS.RNG.shuffle(ordering) for i, idx in enumerate(ordering): seg = segments[idx] filename = f"S_{rec}_{i:07d}.{self.segment_ext}" data_path = self.segments_dirs.data[rec] / filename savemat(str(data_path), {"ecg": seg.data}) self.__all_segments[rec].append(Path(filename).with_suffix("").name) ann_path = self.segments_dirs.ann[rec] / filename savemat( str(ann_path), { k: v for k, v in seg.items() if k not in [ "data", ] }, ) if update_segments_json: self.segments_json.write_text(json.dumps(self.__all_segments, ensure_ascii=False)) def _clear_cached_segments(self, recs: Optional[Sequence[str]] = None) -> None: """Clear the cached segments of the records. Parameters ---------- recs : Sequence[str], optional Sequence of the records whose segments are to be cleared. Defaults to all records. 
Returns ------- None """ self.__assert_task( [ "qrs_detection", "rhythm_segmentation", "af_event", # segmentation of AF events ] ) if recs is None: recs = self.reader.all_records for rec in recs: for item in [ "data", "ann", ]: path = str(self.segments_dirs[item][rec]) for f in [n for n in os.listdir(path) if n.endswith(self.segment_ext)]: os.remove(os.path.join(path, f)) if os.path.splitext(f)[0] in self.__all_segments[rec]: self.__all_segments[rec].remove(os.path.splitext(f)[0]) self.segments = list_sum([self.__all_segments[rec] for rec in self.records]) def _slice_rr_seq(self, force_recompute: bool = False, verbose: int = 0) -> None: """Slice sequences of rr intervals into fixed length (sub)sequences. Parameters ---------- force_recompute : bool, default False Whether to force recompute the rr sequences. verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ self.__assert_task(["rr_lstm"]) if force_recompute: self._clear_cached_rr_seq() with tqdm( enumerate(self.reader), total=len(self.reader), desc="Slicing rr_seq", unit="record", dynamic_ncols=True, mininterval=1.0, ) as pbar: for idx, rec in pbar: self._slice_rr_seq_one_record( rec=rec, force_recompute=False, update_rr_seq_json=False, verbose=verbose, ) # if verbose >= 1: # print(f"{idx+1}/{len(self.reader.all_records)} records", end="\r") if force_recompute: self.rr_seq_json.write_text(json.dumps(self.__all_rr_seq, ensure_ascii=False)) def _slice_rr_seq_one_record( self, rec: str, force_recompute: bool = False, update_rr_seq_json: bool = False, verbose: int = 0, ) -> None: """Slice sequences of rr intervals into fixed length (sub)sequences for one record. Parameters ---------- rec : str Name of the record force_recompute : bool, default False, If True, the rr sequences will be recomputed regardless of whether they have been computed before. update_rr_seq_json : bool, default False Whether to update the rr_seq.json file. 
verbose : int, default 0 Verbosity level for printing the progress. Returns ------- None """ self.__assert_task(["rr_lstm"]) rec_rr_seq = self.__all_rr_seq[rec] if (not force_recompute) and len(rec_rr_seq) > 0: return elif force_recompute: self._clear_cached_rr_seq([rec]) forward_len = self.seglen - self.config[self.task].overlap_len critical_forward_len = self.seglen - self.config[self.task].critical_overlap_len critical_forward_len = [critical_forward_len - 2, critical_forward_len] rpeaks = self.reader.load_rpeak_indices(rec) rr = np.diff(rpeaks) / self.config.fs if len(rr) < self.seglen: return rhythm_mask = self.reader.load_rhythm_ann( rec, rhythm_format="mask", rhythm_types=self.config.rhythm_segmentation.classes, ) label_seq = rhythm_mask[rpeaks][:-1] # find critical points critical_points = np.where(np.diff(label_seq) != 0)[0] critical_points = [p for p in critical_points if critical_forward_len[1] <= p < len(rr) - critical_forward_len[1]] rr_seq = [] # ordinary rr_seq with constant forward_len print("Slicing ordinary rr_seq with constant `forward_len`") with tqdm( range((len(rr) - self.seglen) // forward_len + 1), desc=f"Slicing rr_seq for record `{rec}`", unit="segment", dynamic_ncols=True, mininterval=1.0, ) as pbar: for idx in pbar: start_idx = idx * forward_len end_idx = start_idx + self.seglen new_rr_seq = CFG( rr=rr[start_idx:end_idx], label=label_seq[start_idx:end_idx], interval=[start_idx, end_idx], ) rr_seq.append(new_rr_seq) # the tail segment if end_idx < len(rr): end_idx = len(rr) start_idx = end_idx - self.seglen new_rr_seq = CFG( rr=rr[start_idx:end_idx], label=label_seq[start_idx:end_idx], interval=[start_idx, end_idx], ) rr_seq.append(new_rr_seq) if len(critical_points) == 0: # save rr sequences self.__save_rr_seq(rec, rr_seq, update_rr_seq_json) return # special rr_seq around critical_points with random forward_len in critical_forward_len print("Slicing special rr_seq around `critical_points` " "with random `forward_len` in 
`critical_forward_len`") with tqdm( critical_points, desc=f"Slicing rr_seq for record {rec}", unit="segment", dynamic_ncols=True, mininterval=1.0, ) as pbar: for cp in pbar: start_idx = max( 0, cp - self.seglen + DEFAULTS.RNG_randint(critical_forward_len[0], critical_forward_len[1]), ) while start_idx <= min(cp - critical_forward_len[1], len(rr) - self.seglen): end_idx = start_idx + self.seglen new_rr_seq = CFG( rr=rr[start_idx:end_idx], label=label_seq[start_idx:end_idx], interval=[start_idx, end_idx], ) rr_seq.append(new_rr_seq) start_idx += DEFAULTS.RNG_randint(critical_forward_len[0], critical_forward_len[1]) # save rr sequences self.__save_rr_seq(rec, rr_seq, update_rr_seq_json) def __save_rr_seq(self, rec: str, rr_seq: List[CFG], update_rr_seq_json: bool = False) -> None: """Save rr_seq to disk. Parameters ---------- rec : str Name of the record. rr_seq : List[dict] List of the rr_seq (meta-)data to be saved. update_rr_seq_json : bool, default False Whether to update the rr_seq.json file. Returns ------- None """ ordering = list(range(len(rr_seq))) DEFAULTS.RNG.shuffle(ordering) for i, idx in enumerate(ordering): item = rr_seq[idx] filename = f"R_{rec}_{i:07d}.{self.rr_seq_ext}" data_path = self.rr_seq_dirs[rec] / filename savemat(str(data_path), item) self.__all_rr_seq[rec].append(Path(filename).with_suffix("").name) if update_rr_seq_json: self.rr_seq_json.write_text(json.dumps(self.__all_rr_seq, ensure_ascii=False)) def _clear_cached_rr_seq(self, recs: Optional[Sequence[str]] = None) -> None: """Clear the cached rr sequences. Parameters ---------- recs : Sequence[str], optional Sequence of the records whose segments are to be cleared. Defaults to all records. 
Returns ------- None """ self.__assert_task(["rr_lstm"]) if recs is None: recs = self.reader.all_records for rec in recs: path = str(self.rr_seq_dirs[rec]) for f in [n for n in os.listdir(path) if n.endswith(self.rr_seq_ext)]: os.remove(os.path.join(path, f)) if os.path.splitext(f)[0] in self.__all_rr_seq[rec]: self.__all_rr_seq[rec].remove(os.path.splitext(f)[0]) self.rr_seq = list_sum([self.__all_rr_seq[rec] for rec in self.records]) def _get_rec_name(self, seg_or_rr: str) -> str: """Get the record name from the segment or rr_seq name. Parameters ---------- seg_or_rr : str Name of the segment or rr_seq. Returns ------- rec : str Name of the record that the segment or rr_seq was generated from. """ rec = seg_or_rr.split("_")[1] return rec def _train_test_split(self, task: str) -> Dict[str, List[str]]: """Perform train-test split. Parameters ---------- task : str Task name for which the split is performed. Returns ------- split_res : dict A dictionary of the split results. Keys are "train" and "test", and values are list of the subjects split for training or validation. """ if task in [ "beat_classification", "qrs_detection", ]: full_test_set = ["101", "102", "108", "114", "207", "223"] train_set = [rec for rec in self.reader if rec not in full_test_set] test_set = [rec for rec in self.reader if rec in full_test_set] else: # rhythm segmentation, af event, rr_lstm full_test_set = ["106", "114", "124", "202", "217", "232"] train_set = [rec for rec in self.reader if rec not in full_test_set] test_set = [rec for rec in self.reader if rec in full_test_set] split_res = CFG( { "train": train_set, "test": test_set, } ) return split_res def __assert_task(self, tasks: List[str]) -> None: """Check if the current task is in the given list.""" assert self.task in tasks, ( f"DO NOT call this method when the current task is {self.task}. " "Switch task using `reset_task`" )
[docs] def plot_seg(self, seg: str, ticks_granularity: int = 0) -> None: """Plot the segment. Parameters ---------- seg : str Name of the segment, of pattern like "S_1_1_0000193". ticks_granularity : int, default 0 Granularity to plot axis ticks, the higher the more ticks. 0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks) Returns ------- None """ seg_data = self._load_seg_data(seg) seg_ann = self._load_seg_ann(seg) seg_ann["rhythm_intervals"] = mask_to_intervals(seg_ann["rhythm_mask"], vals=1) rec_name = self._get_rec_name(seg) self.reader.plot( rec=rec_name, # unnecessary indeed data=seg_data, ann=seg_ann, ticks_granularity=ticks_granularity, )
[docs] def extra_repr_keys(self) -> List[str]: return [ "training", "task", "reader", ]
class _FastDataReader(ReprMixin, Dataset):
    """Fast data reader for pre-generated MITDB segments and RR sequences.

    Items are read directly from the files previously written to disk, so
    no database reader is needed at iteration time.

    Parameters
    ----------
    config : CFG
        The dataset configuration. ``config[task]`` must provide ``input_len``
        and ``classes`` (and ``reduction`` for the segmentation tasks);
        ``config.fs`` is the sampling frequency.
    task : str
        Name of the task, one of "qrs_detection", "rhythm_segmentation",
        "af_event" and "rr_lstm".
    seg_ppm : PreprocManager
        Preprocessor manager applied to each loaded segment.
    file_dirs : dict
        Directories of the stored files. For the segmentation tasks it needs
        attribute-style ``data`` and ``ann`` mappings from record name to
        directory; for "rr_lstm" it maps record name to directory directly.
        Directories are joined with ``/`` (``pathlib.Path``-like).
    files : List[str]
        Names (without extension) of the segments or RR sequences; the record
        name is the second underscore-separated field of each name.
    file_ext : str
        Extension (without the leading dot) of the stored files.
    rhythm_types_map : dict
        Mapping from rhythm type to the integer code stored in the labels;
        only the ``"AFIB"`` entry is used (task "rr_lstm").
    """

    def __init__(
        self,
        config: CFG,
        task: str,
        seg_ppm: PreprocManager,
        file_dirs: dict,
        files: List[str],
        file_ext: str,
        rhythm_types_map: dict,
    ) -> None:
        super().__init__()
        self.config = config
        self.task = task
        self.seg_ppm = seg_ppm
        self.file_dirs = file_dirs
        self.files = files
        self.file_ext = file_ext
        self.rhythm_types_map = rhythm_types_map

        self.seglen = self.config[self.task].input_len
        self.n_classes = len(self.config[task].classes)
        # key of the label array inside the stored annotation file, per task
        self._seg_keys = {
            "qrs_detection": "qrs_mask",
            "rhythm_segmentation": "rhythm_mask",
            "af_event": "rhythm_mask",  # segmentation of AF events
        }

    def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
        """Return ``(data, label, weight_mask)``; ``weight_mask`` is None for "qrs_detection"."""
        if self.task in ("qrs_detection", "rhythm_segmentation", "af_event"):
            return self._get_seg_item(index)
        if self.task == "rr_lstm":
            return self._get_rr_item(index)
        raise NotImplementedError(f"data generator for task \042{self.task}\042 not implemented")

    def _get_seg_item(self, index: int) -> Tuple[np.ndarray, ...]:
        """Load one ECG segment and its (optionally downsampled) mask label."""
        seg_name = self.files[index]
        rec = seg_name.split("_")[1]

        seg_data_fp = self.file_dirs.data[rec] / f"{seg_name}.{self.file_ext}"
        seg_data = loadmat(str(seg_data_fp))["ecg"]
        # first axis of ``seg_data`` is iterated lead by lead
        for idx in range(seg_data.shape[0]):
            seg_data[idx] = remove_spikes_naive(seg_data[idx])

        seg_ann_fp = self.file_dirs.ann[rec] / f"{seg_name}.{self.file_ext}"
        seg_label = loadmat(str(seg_ann_fp))[self._seg_keys[self.task]].reshape((self.seglen, -1))

        reduction = self.config[self.task].reduction
        if reduction > 1:
            # downsample the LABEL (not the signal) by averaging non-overlapping
            # windows of ``reduction`` samples along the time axis
            seg_len = seg_label.shape[0]
            seg_label = np.stack(
                arrays=[
                    np.mean(
                        seg_label[reduction * idx : reduction * (idx + 1)],
                        axis=0,
                        keepdims=True,
                    ).astype(int)
                    for idx in range(seg_len // reduction)
                ],
                axis=0,
            ).squeeze(axis=1)

        seg_data, _ = self.seg_ppm(seg_data, self.config.fs)

        if self.task in ("rhythm_segmentation", "af_event"):
            weight_mask = generate_weight_mask(
                target_mask=seg_label.squeeze(-1),
                fg_weight=2,
                fs=self.config.fs,
                reduction=reduction,
                radius=0.8,
                boundary_weight=5,
            )[..., np.newaxis]
            return seg_data, seg_label, weight_mask
        return seg_data, seg_label, None

    def _get_rr_item(self, index: int) -> Tuple[np.ndarray, ...]:
        """Load one RR sequence with its binary (AFIB vs. other) label."""
        seq_name = self.files[index]
        rec = seq_name.split("_")[1]
        rr_seq_path = self.file_dirs[rec] / f"{seq_name}.{self.file_ext}"
        rr_seq = loadmat(str(rr_seq_path))

        rr = rr_seq["rr"].reshape((self.seglen, 1))
        label = rr_seq["label"].reshape((self.seglen, self.n_classes))
        # Binarize: 1 where the stored code is AFIB, 0 elsewhere. The mask is
        # computed once; assigning in two steps would overwrite freshly-set 1s
        # whenever the AFIB code itself is not 1.
        label = np.where(label == self.rhythm_types_map["AFIB"], 1, 0).astype(label.dtype)

        weight_mask = generate_weight_mask(
            target_mask=label.squeeze(-1),
            fg_weight=2,
            fs=1 / 0.8,
            reduction=1,
            radius=2,
            boundary_weight=5,
        )[..., np.newaxis]
        return rr, label, weight_mask

    def __len__(self) -> int:
        return len(self.files)

    def extra_repr_keys(self) -> List[str]:
        # only attributes that this class actually defines
        return [
            "task",
            "seg_ppm",
        ]