WaveDelineationOutput
- class torch_ecg.components.WaveDelineationOutput(*args: Any, **kwargs: Any)
Bases: SequenceTaggingOutput
Class that maintains the output of a wave delineation task.
- Parameters:
classes (Sequence[str]) – Class names.
prob (numpy.ndarray) – Probabilities of each class at each time step (each sample point), of shape (batch_size, signal_length, num_classes).
mask (numpy.ndarray) – Predicted class indices, or binary predictions, at each time step (each sample point), of shape (batch_size, num_channels, signal_length).
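The layout of prob can be illustrated with plain NumPy. The sketch below assumes, as is common in sequence tagging but not stated above, that per-sample class indices are obtained by taking the argmax of prob over its last (class) axis; the sizes and the softmax-generated probabilities are hypothetical, for illustration only.

```python
import numpy as np

# Hypothetical sizes: 2 records, 10 sample points, 3 classes
batch_size, signal_length, num_classes = 2, 10, 3
classes = ["N", "P", "Q"]  # hypothetical class names, as in the example below

rng = np.random.default_rng(0)
logits = rng.normal(size=(batch_size, signal_length, num_classes))

# Numerically stable softmax over the class axis, so prob sums to 1 at each sample point
shifted = logits - logits.max(axis=-1, keepdims=True)
prob = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

# Assumption: predicted class index at each sample point = argmax over classes
class_indices = prob.argmax(axis=-1)  # shape (batch_size, signal_length)

# Map the indices of the first record back to class names
names = [classes[i] for i in class_indices[0]]
print(class_indices.shape)  # (2, 10)
```

This only illustrates the prob layout; the mask field documented above has its own shape, so check the producing model before converting between the two.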
Note
Known issues:
Fields of type dict are not well supported due to limitations of the base class CFG. For example:
>>> import numpy as np
>>> from torch_ecg.components import WaveDelineationOutput
>>> output = WaveDelineationOutput(classes=["N", "P", "Q"], prob=np.ones((1, 3, 3)), mask=np.ones((1, 3, 3)), d={"d": 1})
>>> output
{'classes': ['N', 'P', 'Q'],
 'prob': array([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]]),
 'mask': array([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]]),
 'd': {'d': 1}}
>>> output.d  # has to be accessed via output["d"]
AttributeError: 'WaveDelineationOutput' object has no attribute 'd'
>>> output["d"]
{'d': 1}
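A defensive workaround is to read fields by key rather than by attribute. In the sketch below, get_field is a hypothetical helper (not part of torch_ecg), and ToyOutput is a hypothetical dict subclass that only imitates the symptom described above; it is not the real CFG class.

```python
from typing import Any


class ToyOutput(dict):
    """Hypothetical stand-in: stores every field as a dict item, but mirrors
    only non-dict fields as attributes, imitating the known issue."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        for key, value in kwargs.items():
            if not isinstance(value, dict):
                setattr(self, key, value)


def get_field(output: dict, name: str) -> Any:
    """Hypothetical helper: prefer item access, fall back to attribute access."""
    if name in output:
        return output[name]
    return getattr(output, name)


toy = ToyOutput(classes=["N", "P", "Q"], d={"d": 1})
print(toy.classes)          # ['N', 'P', 'Q'] (attribute access works for the list field)
print(get_field(toy, "d"))  # {'d': 1} (item access works for the dict field)
```

Item access works for every field, including dict-valued ones, which is why the note above recommends output["d"].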
- compute_metrics(fs: int, class_map: Dict[str, int], macro: bool = True, tol: float = 0.15) → ClassificationMetrics
Compute metrics from the output.
- Parameters:
fs (numbers.Real) – Sampling frequency of the signal corresponding to the masks, used to compute the duration of each waveform and hence the errors and their standard deviations.
class_map (dict) – Class map, mapping the names of waves to integers from 0 to n_classes-1; the keys should contain “pwave”, “qrs”, and “twave”.
macro (bool) – Whether to use macro-averaged metrics.
tol (float, default 0.15) – Tolerance for the duration of the waveform, in seconds.
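To make the roles of class_map, fs, and tol concrete, the sketch below extracts onsets and offsets of one wave from a 1-D array of per-sample class indices, then matches predicted onsets to reference onsets within tol seconds (tol * fs samples). It is a generic illustration under stated assumptions (a single record, class indices rather than probabilities, greedy nearest matching, and tol interpreted as a tolerance on boundary positions); it is not torch_ecg's implementation, and wave_boundaries and match_within_tolerance are hypothetical names.

```python
from typing import Dict, List, Tuple

import numpy as np


def wave_boundaries(mask_1d: np.ndarray, class_index: int) -> List[Tuple[int, int]]:
    """Return (onset, offset) sample indices of each contiguous run of class_index.

    Offsets are inclusive (the last sample of the run).
    """
    is_wave = (mask_1d == class_index).astype(np.int8)
    # Pad with zeros so that runs touching either end of the signal are detected
    diff = np.diff(np.concatenate(([0], is_wave, [0])))
    onsets = np.flatnonzero(diff == 1)
    offsets = np.flatnonzero(diff == -1) - 1
    return list(zip(onsets.tolist(), offsets.tolist()))


def match_within_tolerance(pred: List[int], ref: List[int], fs: float, tol: float) -> List[float]:
    """Greedily match predicted to reference positions; return errors in seconds.

    A prediction is matched only if it lies within tol seconds (tol * fs samples)
    of a reference position that has not been matched yet.
    """
    tol_samples = tol * fs
    unmatched = list(ref)
    errors = []
    for p in pred:
        if not unmatched:
            break
        nearest = min(unmatched, key=lambda r: abs(r - p))
        if abs(nearest - p) <= tol_samples:
            errors.append((p - nearest) / fs)
            unmatched.remove(nearest)
    return errors


# Hypothetical example: fs = 500 Hz, class_map using the keys listed above
fs = 500
class_map: Dict[str, int] = {"pwave": 1, "qrs": 2, "twave": 3}
pred_mask = np.array([0, 0, 2, 2, 2, 0, 0, 0, 2, 2, 0])
ref_qrs_onsets = [3, 8]

qrs = wave_boundaries(pred_mask, class_map["qrs"])
pred_onsets = [onset for onset, _ in qrs]
errors = match_within_tolerance(pred_onsets, ref_qrs_onsets, fs=fs, tol=0.15)
print(qrs)     # [(2, 4), (8, 9)]
print(errors)  # [-0.002, 0.0]
# Mean error and standard deviation of the errors, in seconds
print(np.mean(errors), np.std(errors))
```

Dividing sample differences by fs is what turns errors into seconds, which is why fs is required even though the masks themselves carry no time units.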
- Returns:
metrics – Metrics computed from the output.
- Return type: