WaveDelineationOutput

class torch_ecg.components.WaveDelineationOutput(*args: Any, **kwargs: Any)

Bases: SequenceTaggingOutput

Class that maintains the output of a wave delineation task.

Parameters:
  • classes (Sequence[str]) – Class names.

  • prob (numpy.ndarray) – Probabilities of each class at each time step (each sample point), of shape (batch_size, signal_length, num_classes).

  • mask (numpy.ndarray) – Predicted class indices, or binary predictions, at each time step (each sample point), of shape (batch_size, num_channels, signal_length).
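The two array layouts can be related with NumPy alone. This is a sketch, not torch_ecg code: the sizes are made up, and both the use of argmax to collapse `prob` into class indices and the size-1 channel axis for a single-lead signal are assumptions about how such a `mask` might be produced.

```python
import numpy as np

# Hypothetical sizes, for illustration only.
batch_size, signal_length, num_classes = 2, 10, 4

rng = np.random.default_rng(0)
logits = rng.normal(size=(batch_size, signal_length, num_classes))

# Softmax over the class axis gives per-sample-point probabilities,
# laid out as (batch_size, signal_length, num_classes) like `prob`.
exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
prob = exp / exp.sum(axis=-1, keepdims=True)

# Most probable class index at each sample point: (batch_size, signal_length).
pred = prob.argmax(axis=-1)

# Assumption: for a single-lead signal, add a channel axis of size 1
# to match the (batch_size, num_channels, signal_length) layout of `mask`.
mask = pred[:, np.newaxis, :]

print(prob.shape, mask.shape)  # (2, 10, 4) (2, 1, 10)
```

Note that the class axis is last in `prob`, while the time axis is last in `mask`.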

Note

Known issues:

  • Fields of type dict are not well supported, due to limitations of the base class CFG. For example, a dict-valued field cannot be accessed as an attribute:

>>> output = WaveDelineationOutput(classes=["N", "P", "Q"], prob=np.ones((1, 3, 3)), mask=np.ones((1, 3, 3)), d={"d": 1})
>>> output
{'classes': ['N', 'P', 'Q'],
    'prob': array([[[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]]]),
    'mask': array([[[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]]]),
    'd': {'d': 1}}
>>> output.d  # must be accessed as output["d"] instead
AttributeError: 'WaveDelineationOutput' object has no attribute 'd'

compute_metrics(fs: int, class_map: Dict[str, int], macro: bool = True, tol: float = 0.15) → WaveDelineationMetrics

Compute metrics from the output.

Parameters:
  • fs (numbers.Real) – Sampling frequency of the signal corresponding to the masks, used to compute the duration of each waveform and thus the errors and their standard deviations.

  • class_map (dict) – Class map from wave names to class indices 0 to n_classes - 1; the keys should include "pwave", "qrs", and "twave".

  • macro (bool) – Whether to use macro-averaged metrics.

  • tol (float, default 0.15) – Tolerance for the duration of the waveform, in seconds.
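The roles of `fs`, `class_map` and `tol` can be illustrated with a stand-alone NumPy sketch of tolerance-based onset matching on 1-D class-index masks. This is not the torch_ecg implementation: the helpers `wave_intervals` and `onset_errors_ms`, the class indices in `class_map`, the use of 0 as background, and the reading of `tol` as a matching window around each reference onset are all assumptions made for illustration.

```python
import numpy as np


def wave_intervals(mask_1d, class_index):
    """Return (onset, offset) sample indices of each contiguous run of `class_index`.

    Offsets are inclusive.
    """
    # Pad with zeros so runs touching either end still yield a rising and a falling edge.
    is_wave = np.concatenate(([0], (np.asarray(mask_1d) == class_index).astype(int), [0]))
    edges = np.diff(is_wave)
    onsets = np.flatnonzero(edges == 1)
    offsets = np.flatnonzero(edges == -1) - 1
    return list(zip(onsets.tolist(), offsets.tolist()))


def onset_errors_ms(ref_onsets, pred_onsets, fs, tol):
    """Signed onset errors (pred - ref) in milliseconds for matched reference onsets.

    Each reference onset is paired with the nearest predicted onset, and the pair
    counts only if the two are at most `tol` seconds apart. Simplification: one
    predicted onset may be paired with several reference onsets.
    """
    tol_samples = tol * fs  # convert the tolerance from seconds to samples
    errors = []
    for r in ref_onsets:
        if not pred_onsets:
            break
        nearest = min(pred_onsets, key=lambda p: abs(p - r))
        if abs(nearest - r) <= tol_samples:
            errors.append((nearest - r) * 1000 / fs)
    return errors


fs = 500
class_map = {"pwave": 1, "qrs": 2, "twave": 3}  # hypothetical indices; 0 = background

ref = np.zeros(1000, dtype=int)
ref[100:150] = class_map["qrs"]
ref[600:650] = class_map["qrs"]
pred = np.zeros(1000, dtype=int)
pred[105:150] = class_map["qrs"]
pred[700:740] = class_map["qrs"]

ref_onsets = [on for on, _ in wave_intervals(ref, class_map["qrs"])]
pred_onsets = [on for on, _ in wave_intervals(pred, class_map["qrs"])]
errors = onset_errors_ms(ref_onsets, pred_onsets, fs, tol=0.15)
print(errors)  # [10.0]
```

With `fs = 500`, the default `tol = 0.15` corresponds to 75 samples, so the second reference QRS onset, whose nearest prediction is 100 samples (0.2 s) away, is left unmatched instead of contributing a 200 ms error.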

Returns:

metrics – Metrics computed from the output.

Return type:

WaveDelineationMetrics

required_fields() → Set[str]

The required fields of the output class.
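Because the output is dict-like (see the known issue above), checking for required fields reduces to a set difference over its keys. The helper below is hypothetical and not part of torch_ecg; it assumes the required set consists of the three documented parameters `classes`, `prob` and `mask`.

```python
from typing import Mapping, Set

# Assumption: the required fields are the three documented parameters.
REQUIRED_FIELDS: Set[str] = {"classes", "prob", "mask"}


def missing_fields(output: Mapping[str, object], required: Set[str] = REQUIRED_FIELDS) -> Set[str]:
    """Return the required field names that are absent from a dict-like output."""
    return set(required) - set(output.keys())


print(sorted(missing_fields({"classes": ["N", "P", "Q"], "prob": None})))  # ['mask']
```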