SequenceLabellingOutput

class torch_ecg.components.SequenceLabellingOutput(*args: Any, **kwargs: Any)

Bases: BaseOutput

Class that maintains the output of a sequence labelling (also called sequence tagging) task.

Parameters:
  • classes (Sequence[str]) – Class names.

  • prob (numpy.ndarray) – Probabilities of each class at each time step (each sample point), of shape (batch_size, signal_length, num_classes).

  • pred (numpy.ndarray) – Predictions at each time step (each sample point), in one of two forms: predicted class indices, of shape (batch_size, signal_length), or binary predictions, of shape (batch_size, signal_length, num_classes).
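The shapes above can be made concrete with plain NumPy. This is a minimal sketch that does not use torch_ecg: the toy sizes, the softmax used to produce `prob`, and the 0.5 threshold used for the binary form of `pred` are illustrative assumptions, not the library's behaviour.

```python
import numpy as np

# Illustrative toy sizes (assumption, not library defaults).
batch_size, signal_length, num_classes = 2, 5, 3
classes = ["AF", "N", "SPB"]

rng = np.random.default_rng(0)
logits = rng.normal(size=(batch_size, signal_length, num_classes))

# Numerically stable softmax over the class axis:
# prob has shape (batch_size, signal_length, num_classes).
shifted = logits - logits.max(axis=-1, keepdims=True)
prob = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

# Form 1 of `pred`: class indices, shape (batch_size, signal_length).
pred_indices = prob.argmax(axis=-1)

# Form 2 of `pred`: binary predictions, shape (batch_size, signal_length, num_classes),
# obtained here with an assumed threshold of 0.5.
thr = 0.5
pred_binary = (prob >= thr).astype(int)

print(prob.shape, pred_indices.shape, pred_binary.shape)
```

The index form suits mutually exclusive classes; the binary form allows several classes (or none) to be active at the same sample point.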

Note

Known issues:

  • Fields of type dict are not well supported because of limitations of the base class CFG: such a field is stored in the output but cannot be accessed as an attribute. For example:

>>> import numpy as np
>>> from torch_ecg.components import SequenceLabellingOutput
>>> output = SequenceLabellingOutput(classes=["AF", "N", "SPB"], thr=0.5, pred=np.ones((1,3,3)), prob=np.ones((1,3,3)), d={"d":1})
>>> output
{'classes': ['AF', 'N', 'SPB'],
    'prob': array([[[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]]]),
    'pred': array([[[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]]]),
    'thr': 0.5,
    'd': {'d': 1}}
>>> output.d  # must be accessed as output["d"] instead
Traceback (most recent call last):
    ...
AttributeError: 'SequenceLabellingOutput' object has no attribute 'd'

compute_metrics(macro: bool = True) → ClassificationMetrics

Compute metrics from the output.

Parameters:

macro (bool) – Whether to use macro-averaged metrics or not.

Returns:

metrics – Metrics computed from the output.

Return type:

ClassificationMetrics
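The exact metrics held in ClassificationMetrics are not listed here. As an illustration only, the sketch below shows what the `macro` switch typically means for sequence labelling once the batch and time axes are flattened into one long list of per-sample-point labels, using F1 as the example metric. The helper `f1_scores` is hypothetical and is not the torch_ecg implementation.

```python
import numpy as np


def f1_scores(labels, preds, num_classes, macro=True):
    """Hypothetical helper: F1 over flattened per-sample-point class indices.

    labels, preds: integer arrays of shape (batch_size, signal_length).
    """
    y_true = np.asarray(labels).ravel()
    y_pred = np.asarray(preds).ravel()
    # Per-class true positives, false positives and false negatives.
    tp = np.array([np.sum((y_pred == c) & (y_true == c)) for c in range(num_classes)])
    fp = np.array([np.sum((y_pred == c) & (y_true != c)) for c in range(num_classes)])
    fn = np.array([np.sum((y_pred != c) & (y_true == c)) for c in range(num_classes)])
    if macro:
        # Macro: F1 per class, then an unweighted mean over classes.
        denom = 2 * tp + fp + fn
        per_class = np.divide(2 * tp, denom, out=np.zeros(num_classes), where=denom > 0)
        return float(per_class.mean())
    # Micro: pool the counts over all classes before computing a single F1.
    denom = 2 * tp.sum() + fp.sum() + fn.sum()
    return float(2 * tp.sum() / denom) if denom > 0 else 0.0


labels = np.array([[0, 0, 1, 2]])
preds = np.array([[0, 1, 1, 2]])
print(f1_scores(labels, preds, 3, macro=True))   # per-class F1 = 2/3, 2/3, 1 -> 7/9
print(f1_scores(labels, preds, 3, macro=False))  # 0.75
```

Macro averaging gives rare classes (for example short SPB segments) the same weight as the dominant class, whereas micro averaging is dominated by the most frequent labels.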

required_fields() → Set[str]

The required fields of the output class.
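A set of required field names is typically used to validate the keyword arguments passed to the constructor. The sketch below shows such a check with a hypothetical helper `missing_fields`; the set `{"classes", "prob", "pred"}` is an assumption mirroring the Parameters list above and is not read from torch_ecg.

```python
from typing import Any, Dict, Set

# Assumed required set, mirroring the documented parameters (not taken from torch_ecg).
REQUIRED: Set[str] = {"classes", "prob", "pred"}


def missing_fields(kwargs: Dict[str, Any], required: Set[str] = REQUIRED) -> Set[str]:
    """Hypothetical helper: return the required field names absent from kwargs."""
    return required - set(kwargs)


print(sorted(missing_fields({"classes": ["AF", "N", "SPB"], "prob": None})))  # ['pred']
```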