StretchCompressOffline

class torch_ecg.augmenters.StretchCompressOffline(ratio: Real = 6, prob: float = 0.5, overlap: float = 0.5, critical_overlap: float = 0.85)

Bases: ReprMixin

Offline stretch-or-compress augmenter.

Stretch-or-compress augmenter that operates on original, variable-length ECG signals (given as numpy arrays), intended for offline data generation.
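To illustrate what stretching or compressing a segment means, the sketch below reads a window of round(seglen * factor) samples and resamples it to exactly seglen samples with linear interpolation. The helper `stretch_compress_window` and its `factor` argument are hypothetical, and torch_ecg may resample differently; this is not the library's implementation.

```python
import numpy as np


def stretch_compress_window(
    sig: np.ndarray, start: int, seglen: int, factor: float
) -> np.ndarray:
    """Hypothetical helper, not part of torch_ecg.

    Reads round(seglen * factor) samples of every lead starting at `start`
    and linearly resamples them to exactly `seglen` samples:
    factor > 1 compresses a longer window, factor < 1 stretches a shorter one.
    """
    src_len = max(2, int(round(seglen * factor)))
    if start < 0 or start + src_len > sig.shape[1]:
        raise ValueError("source window exceeds the signal")
    window = sig[:, start : start + src_len]
    old_t = np.arange(src_len)  # sample positions in the source window
    new_t = np.linspace(0, src_len - 1, seglen)  # seglen evenly spaced positions
    # Linear interpolation, one lead at a time.
    return np.stack([np.interp(new_t, old_t, lead) for lead in window])


rng = np.random.default_rng(0)
sig = rng.standard_normal((12, 5000))  # 12 leads, 5000 samples
compressed = stretch_compress_window(sig, start=100, seglen=600, factor=1.1)
stretched = stretch_compress_window(sig, start=100, seglen=600, factor=0.9)
print(compressed.shape, stretched.shape)  # (12, 600) (12, 600)
```

Both outputs have the same length, so segments of a fixed seglen can be produced regardless of how much each one was stretched or compressed.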

Parameters:
  • ratio (numbers.Real, default 6) – Mean ratio of the stretch or compression. If it lies in the interval [1, 100], it is rescaled to [0, 1]. The actual ratio applied is sampled from a normal distribution.

  • prob (float, default 0.5) – Probability of the augmenter to be applied.

  • overlap (float, default 0.5) – Fraction of overlap between consecutive segments generated offline.

  • critical_overlap (float, default 0.85) – Fraction of overlap between segments around the critical points of the ECG (see critical_points in generate()); usually larger than overlap.
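One plausible reading of overlap and critical_overlap is as the fraction of samples shared by consecutive windows, so the step between window starts is seglen * (1 - overlap). The sketch below assumes exactly that, and additionally assumes that windows containing a critical point are placed with the denser step seglen * (1 - critical_overlap). The function `segment_starts` is hypothetical and does not reproduce torch_ecg's internal logic.

```python
from typing import List, Optional, Sequence


def segment_starts(
    siglen: int,
    seglen: int,
    overlap: float = 0.5,
    critical_overlap: float = 0.85,
    critical_points: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical sketch of overlap-based windowing (not torch_ecg's code).

    Regular windows advance by seglen * (1 - overlap) samples; windows that
    contain a critical point advance by the smaller step
    seglen * (1 - critical_overlap), so those regions are sampled more densely.
    """
    step = max(1, int(round(seglen * (1 - overlap))))
    crit_step = max(1, int(round(seglen * (1 - critical_overlap))))
    starts = set(range(0, siglen - seglen + 1, step))
    for cp in critical_points or []:
        # Starts s such that the window [s, s + seglen) contains cp
        # and still fits inside the signal.
        lo = max(0, cp - seglen + 1)
        hi = min(cp, siglen - seglen)
        starts.update(range(lo, hi + 1, crit_step))
    return sorted(starts)


print(len(segment_starts(6000, 600)))  # 19
print(len(segment_starts(6000, 600, critical_points=[3000])))  # 26
```

With these assumptions, a 6000-sample signal yields 19 regular windows, and the critical point at sample 3000 adds 7 more densely spaced windows that cover it.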

Example

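import torch
from torch_ecg.augmenters import StretchCompressOffline

sco = StretchCompressOffline()
seglen = 600
sig = torch.randn((12, 60000)).numpy()  # 12-lead ECG, 60000 samples
labels = torch.ones((60000, 3)).numpy().astype(int)  # (label_len, channels)
masks = torch.ones((60000, 1)).numpy().astype(int)
# Each item of `segments` has the form (seg, labels, masks, start_idx, end_idx).
segments = sco(seglen, sig, labels, masks, critical_points=[10000, 30000])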
extra_repr_keys() → List[str]

Extra keys for __repr__() and __str__().
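The extra_repr_keys() hook follows a common pattern: a base class builds its string representation from the attribute names that a subclass returns. The minimal sketch below assumes that pattern; `ReprMixinSketch` and `ToyAugmenter` are hypothetical stand-ins, and their output format differs from torch_ecg's actual ReprMixin.

```python
from typing import List


class ReprMixinSketch:
    """Hypothetical stand-in for a ReprMixin-style base class."""

    def extra_repr_keys(self) -> List[str]:
        # Subclasses override this to list the attributes worth showing.
        return []

    def __repr__(self) -> str:
        args = ", ".join(
            f"{k}={getattr(self, k)!r}" for k in self.extra_repr_keys()
        )
        return f"{type(self).__name__}({args})"


class ToyAugmenter(ReprMixinSketch):
    def __init__(self, ratio=6, prob=0.5):
        self.ratio = ratio
        self.prob = prob

    def extra_repr_keys(self) -> List[str]:
        return ["ratio", "prob"]


# print() falls back to __repr__ because __str__ is not overridden.
print(ToyAugmenter())  # ToyAugmenter(ratio=6, prob=0.5)
```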

generate(seglen: int, sig: ndarray, *labels: Sequence[ndarray], critical_points: Sequence[int] | None = None) → List[Tuple[ndarray | int, ...]]

Generate stretched or compressed segments from the ECGs.

Parameters:
  • seglen (int) – Length of the ECG segments to be generated.

  • sig (numpy.ndarray) – The ECGs from which to generate stretched or compressed segments, of shape (lead, siglen).

  • labels (numpy.ndarray, optional) – Labels of the ECGs, each of shape (label_len, channels); any number of label arrays may be passed. For example, in segmentation tasks, label_len should be divisible by siglen, and channels should equal the number of classes.

  • critical_points (Sequence[int], optional) – Indices of the critical points of the ECG. Segments around these points are usually generated with the larger overlap given by self.critical_overlap.

Returns:

List of generated segments, each a tuple of the form (seg, label1, label2, ..., start_idx, end_idx).
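Because every returned tuple starts with the segment, ends with the two indices, and holds one entry per label array in between, Python's extended unpacking can separate the parts without knowing how many label arrays were passed. The list below is mock data shaped like the example above (two label arrays); its values and shapes are placeholders, not real augmenter output.

```python
import numpy as np

# Mock stand-in for the list returned by generate(); placeholder values only.
segments = [
    (np.zeros((12, 600)), np.ones((600, 3), dtype=int), np.ones((600, 1), dtype=int), 1000, 1600),
    (np.zeros((12, 600)), np.ones((600, 3), dtype=int), np.ones((600, 1), dtype=int), 1300, 1900),
]

for seg, *labels, start_idx, end_idx in segments:
    # `labels` collects every label array between the segment and the indices.
    print(seg.shape, [lb.shape for lb in labels], start_idx, end_idx)
```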

Return type:

list