SequenceLabellingOutput
- class torch_ecg.components.SequenceLabellingOutput(*args: Any, **kwargs: Any)
Bases: BaseOutput
Class that maintains the output of a sequence tagging task.
- Parameters:
classes (Sequence[str]) – Class names.
prob (numpy.ndarray) – Probabilities of each class at each time step (each sample point), of shape (batch_size, signal_length, num_classes).
pred (numpy.ndarray) – Either the predicted class indices at each time step (each sample point), of shape (batch_size, signal_length), or the binary predictions at each time step (each sample point), of shape (batch_size, signal_length, num_classes).
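The two accepted layouts of `pred` are related to `prob` in a simple way. The sketch below uses NumPy only and is not taken from torch_ecg. It assumes that index predictions come from an argmax over the class axis and that binary predictions come from thresholding `prob`. The threshold value 0.5 and the toy random probabilities are illustrative assumptions.

```python
import numpy as np

# Toy probabilities (made-up values): batch_size=2, signal_length=4, num_classes=3
rng = np.random.default_rng(0)
prob = rng.random((2, 4, 3))
prob /= prob.sum(axis=-1, keepdims=True)  # each time step now sums to 1

# Layout 1: predicted class indices, shape (batch_size, signal_length)
pred_idx = prob.argmax(axis=-1)

# Layout 2: binary predictions, shape (batch_size, signal_length, num_classes).
# The threshold 0.5 is an assumption chosen for illustration.
thr = 0.5
pred_bin = (prob >= thr).astype(int)

# Index predictions can also be one-hot encoded into the 3-D binary layout
pred_onehot = np.eye(prob.shape[-1], dtype=int)[pred_idx]

print(pred_idx.shape)     # (2, 4)
print(pred_bin.shape)     # (2, 4, 3)
print(pred_onehot.shape)  # (2, 4, 3)
```

Thresholding can leave a time step with no positive class, and with unnormalised scores it can leave a time step with several. One-hot encoding the argmax always yields exactly one positive per time step.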
Note
Known issues:
Fields of type dict are not well supported due to limitations of the base class CFG. For example:
>>> output = SequenceTaggingOutput(classes=["AF", "N", "SPB"], thr=0.5, pred=np.ones((1,3,3)), prob=np.ones((1,3,3)), d={"d":1})
>>> output
{'classes': ['AF', 'N', 'SPB'], 'prob': array([[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]), 'pred': array([[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]), 'thr': 0.5, 'd': {'d': 1}}
>>> output.d  # has to access via `output["d"]`
AttributeError: 'SequenceTaggingOutput' object has no attribute 'd'
- compute_metrics(macro: bool = True) → ClassificationMetrics
Compute metrics from the output.
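Macro averaging, as controlled by the `macro` parameter, weights every class equally. Micro averaging pools all time steps. The sketch below illustrates that difference for F1 score under two assumptions. First, labels and predictions are class indices. Second, time steps from all batch items are pooled into one set of samples. The helper `f1_scores` is hypothetical and not part of torch_ecg, and the metrics actually returned in `ClassificationMetrics` may be computed differently.

```python
import numpy as np


def f1_scores(labels: np.ndarray, preds: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical helper: per-class F1 over all time steps (batch and time flattened)."""
    labels = labels.reshape(-1)
    preds = preds.reshape(-1)
    f1 = np.zeros(num_classes)
    for c in range(num_classes):
        tp = np.sum((preds == c) & (labels == c))
        fp = np.sum((preds == c) & (labels != c))
        fn = np.sum((preds != c) & (labels == c))
        denom = 2 * tp + fp + fn
        f1[c] = 2 * tp / denom if denom > 0 else 0.0
    return f1


# One record with 4 time steps; class indices 0, 1, 2
labels = np.array([[0, 0, 1, 2]])
preds = np.array([[0, 1, 1, 2]])

per_class = f1_scores(labels, preds, num_classes=3)
macro_f1 = per_class.mean()           # every class counts equally
micro_f1 = np.mean(labels == preds)   # single-label case: micro F1 equals accuracy

print(round(float(macro_f1), 4))  # 0.7778
print(float(micro_f1))            # 0.75
```

In this sketch the rare class 2 is predicted perfectly, which pulls the macro score (7/9) above the micro score (3/4). With imbalanced rhythm classes, the two averages can diverge much further.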
- Parameters:
macro (bool) – Whether to use macro-averaged metrics or not.
- Returns:
metrics – Metrics computed from the output.
- Return type: