# Source code for torch_ecg.components.outputs


from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence, Set, Union

import numpy as np
import pandas as pd

from ..cfg import CFG
from ..utils.misc import add_docstring
from ..utils.utils_data import ECGWaveFormNames
from .metrics import ClassificationMetrics, RPeaksDetectionMetrics, WaveDelineationMetrics

# Public API of this module: the output containers defined below.
__all__ = [
    "BaseOutput",
    "ClassificationOutput",
    "MultiLabelClassificationOutput",
    "SequenceTaggingOutput",
    "SequenceLabellingOutput",
    "WaveDelineationOutput",
    "RPeaksDetectionOutput",
]


# Docstring template shared by all output classes. The single ``{}``
# placeholder is filled via ``str.format`` with a class-specific example
# string and the result is attached to the class docstring by
# ``add_docstring``.
_KNOWN_ISSUES = """
    Known issues:

        - fields of type `dict` are not well supported due to the limitations of the base class `CFG`, for example

        .. code-block:: python

{}
"""

# Class-specific doctest-style snippets. Each one is substituted into the
# ``{}`` placeholder of the known-issues docstring template of the
# corresponding output class, so the 12-space indentation is intentional:
# it nests the snippet under the template's ``.. code-block:: python``.
_ClassificationOutput_ISSUE_EXAMPLE = """
            >>> output = ClassificationOutput(classes=["AF", "N", "SPB"], pred=np.ones((1,3)), prob=np.ones((1,3)), d={"d":1})
            >>> output
            {'classes': ['AF', 'N', 'SPB'],
                'prob': array([[1., 1., 1.]]),
                'pred': array([[1., 1., 1.]]),
                'd': {'d': 1}}
            >>> output.d  # has to access via `output["d"]`
            AttributeError: 'ClassificationOutput' object has no attribute 'd'
"""

_MultiLabelClassificationOutput_ISSUE_EXAMPLE = """
            >>> output = MultiLabelClassificationOutput(classes=["AF", "N", "SPB"], thr=0.5, pred=np.ones((1,3)), prob=np.ones((1,3)), d={"d":1})
            >>> output
            {'classes': ['AF', 'N', 'SPB'],
                'prob': array([[1., 1., 1.]]),
                'pred': array([[1., 1., 1.]]),
                'thr': 0.5,
                'd': {'d': 1}}
            >>> output.d  # has to access via `output["d"]`
            AttributeError: 'MultiLabelClassificationOutput' object has no attribute 'd'
"""

_SequenceTaggingOutput_ISSUE_EXAMPLE = """
            >>> output = SequenceTaggingOutput(classes=["AF", "N", "SPB"], thr=0.5, pred=np.ones((1,3,3)), prob=np.ones((1,3,3)), d={"d":1})
            >>> output
            {'classes': ['AF', 'N', 'SPB'],
                'prob': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'pred': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'thr': 0.5,
                'd': {'d': 1}}
            >>> output.d  # has to access via `output["d"]`
            AttributeError: 'SequenceTaggingOutput' object has no attribute 'd'
"""

# The printed ``classes`` and ``thr`` entries below mirror the constructor
# call of the example (the previous text showed the classes of the other
# examples and omitted ``thr``).
_WaveDelineationOutput_ISSUE_EXAMPLE = """
            >>> output = WaveDelineationOutput(classes=["N", "P", "Q",], thr=0.5, mask=np.ones((1,3,3)), prob=np.ones((1,3,3)), d={"d":1})
            >>> output
            {'classes': ['N', 'P', 'Q'],
                'prob': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'mask': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'thr': 0.5,
                'd': {'d': 1}}
            >>> output.d  # has to access via `output["d"]`
            AttributeError: 'WaveDelineationOutput' object has no attribute 'd'
"""

_RPeaksDetectionOutput_ISSUE_EXAMPLE = """
            >>> output = RPeaksDetectionOutput(rpeak_indices=[[2]], thr=0.5, prob=np.ones((1,3,3)), d={"d":1})
            >>> output
            {'rpeak_indices': [[2]],
                'prob': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'thr': 0.5,
                'd': {'d': 1}}
            >>> output.d  # has to access via `output["d"]`
            AttributeError: 'RPeaksDetectionOutput' object has no attribute 'd'
"""

class BaseOutput(CFG, ABC):
    """Base class for all outputs.

    Validates on construction that every field named by
    :meth:`required_fields` is present and not ``None``, and provides
    :meth:`append` for merging outputs of the same type field by field.

    Parameters
    ----------
    *args : sequence
        Positional arguments, forwarded to the base class constructor.
    **kwargs : dict
        Keyword arguments, forwarded to the base class constructor.

    """

    # Read through ``self.__name__`` when building validation messages.
    __name__ = "BaseOutput"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Remove keys that collide with method names or start with ``_abc``,
        # so that only data fields remain in the mapping.
        # NOTE(review): presumably the base class can register such names as
        # keys during construction -- confirm against `CFG`.
        pop_fields = [k for k in self if k in ["required_fields", "append", "compute_metrics"] or k.startswith("_abc")]
        for f in pop_fields:
            self.pop(f, None)
        required = self.required_fields()
        # Sorted so that the error message is deterministic.
        missing = sorted(required - set(self.keys()))
        assert not missing, (
            f"{self.__name__} requires {required}, "
            f"but `{', '.join(missing)}` are missing"
        )
        unset = sorted(field for field in required if self[field] is None)
        assert not unset, f"Fields `{', '.join(unset)}` are not set"

    @abstractmethod
    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        raise NotImplementedError("Subclass must implement method `required_fields`")

    def append(self, values: Union["BaseOutput", Sequence["BaseOutput"]]) -> None:
        """Append other :class:`Output` to `self`

        Fields are merged in place, key by key:

        - ``classes`` must be identical in both outputs and is left untouched;
        - :class:`numpy.ndarray` fields are concatenated with
          :func:`numpy.concatenate` (default axis);
        - :class:`pandas.DataFrame` fields are concatenated row-wise with a
          fresh index;
        - other sequences (list, tuple, etc.) are extended with ``+=``.

        Parameters
        ----------
        values : Output or Sequence[Output]
            The values to be appended.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If an item of `values` is not of the same class as `self`,
            does not have the same set of fields, or has different `classes`.
        ValueError
            If a field holds a value of an unsupported type.

        """
        if not isinstance(values, Sequence):
            values = [values]
        for v in values:
            assert v.__class__ == self.__class__, "`values` must be of the same type as `self`"
            assert set(v.keys()) == set(self.keys()), "`values` must have the same fields as `self`"
            for k, v_ in v.items():
                if k in ["classes"]:
                    assert v_ == self[k], f"the field of ordered sequence `{k}` must be the identical"
                    # `classes` describes the label space, it is not per-sample
                    # data: skip it so it is not duplicated by the `+=` below.
                    continue
                if isinstance(v_, np.ndarray):
                    self[k] = np.concatenate((self[k], v_))
                elif isinstance(v_, pd.DataFrame):
                    self[k] = pd.concat([self[k], v_], axis=0, ignore_index=True)
                elif isinstance(v_, Sequence):  # list, tuple, etc.
                    self[k] += v_
                else:
                    raise ValueError(f"field `{k}` of type `{type(v_)}` is not supported")

@add_docstring(_KNOWN_ISSUES.format(_ClassificationOutput_ISSUE_EXAMPLE), "append")
class ClassificationOutput(BaseOutput):
    """
    Class that maintains the output of a (typically single-label) classification task.

    Parameters
    ----------
    classes : Sequence[str]
        Class names.
    prob : numpy.ndarray
        Probabilities of each class,
        of shape ``(batch_size, num_classes)``.
    pred : numpy.ndarray
        Predicted class indices of shape ``(batch_size,)``,
        or binary predictions of shape ``(batch_size, num_classes)``.

    """

    __name__ = "ClassificationOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return {"classes", "prob", "pred"}

    def compute_metrics(self) -> ClassificationMetrics:
        """Compute metrics from the output.

        The ground truth is read from the field ``labels``,
        falling back to ``label``.

        Returns
        -------
        metrics : ClassificationMetrics
            Metrics computed from the output.

        """
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        clf_met = ClassificationMetrics(multi_label=False, macro=True)
        return clf_met.compute(self.get("labels", self.get("label")), self.pred, len(self.classes))
@add_docstring(_KNOWN_ISSUES.format(_MultiLabelClassificationOutput_ISSUE_EXAMPLE), "append")
class MultiLabelClassificationOutput(BaseOutput):
    """
    Class that maintains the output of a multi-label classification task.

    Parameters
    ----------
    classes : Sequence[str]
        Class names.
    thr : float
        Threshold for making binary predictions.
    prob : numpy.ndarray
        Probabilities of each class,
        of shape ``(batch_size, num_classes)``.
    pred : numpy.ndarray
        Binary predictions,
        of shape ``(batch_size, num_classes)``.

    """

    __name__ = "MultiLabelClassificationOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return {"classes", "thr", "prob", "pred"}

    def compute_metrics(self, macro: bool = True) -> ClassificationMetrics:
        """Compute metrics from the output.

        The ground truth is read from the field ``labels``,
        falling back to ``label``.

        Parameters
        ----------
        macro : bool
            Whether to use macro-averaged metrics or not.

        Returns
        -------
        metrics : ClassificationMetrics
            Metrics computed from the output.

        """
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        clf_met = ClassificationMetrics(multi_label=True, macro=macro)
        return clf_met.compute(self.get("labels", self.get("label")), self.pred, len(self.classes))
@add_docstring(_KNOWN_ISSUES.format(_SequenceTaggingOutput_ISSUE_EXAMPLE), "append")
class SequenceTaggingOutput(BaseOutput):
    """Class that maintains the output of a sequence tagging task.

    Parameters
    ----------
    classes : Sequence[str]
        Class names.
    prob : numpy.ndarray
        Probabilities of each class at each time step (each sample point),
        of shape ``(batch_size, signal_length, num_classes)``.
    pred : numpy.ndarray
        Predicted class indices at each time step (each sample point),
        of shape ``(batch_size, signal_length)``;
        or binary predictions at each time step (each sample point),
        of shape ``(batch_size, signal_length, num_classes)``.

    """

    __name__ = "SequenceTaggingOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return {"classes", "prob", "pred"}

    def compute_metrics(self, macro: bool = True) -> ClassificationMetrics:
        """Compute metrics from the output.

        The ground truth is read from the field ``labels``,
        falling back to ``label``.

        Parameters
        ----------
        macro : bool
            Whether to use macro-averaged metrics or not.

        Returns
        -------
        metrics : ClassificationMetrics
            Metrics computed from the output.

        """
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        clf_met = ClassificationMetrics(multi_label=False, macro=macro)
        labels = self.get("labels", self.get("label"))
        # Collapse all leading axes so that labels and predictions become 2D
        # arrays whose second axis is the original last axis.
        return clf_met.compute(
            labels.reshape((-1, labels.shape[-1])),
            self.pred.reshape((-1, self.pred.shape[-1])),
            len(self.classes),
        )
# alias
SequenceLabellingOutput = SequenceTaggingOutput
# NOTE(review): the alias is the very same class object, so this assignment
# also changes what `SequenceTaggingOutput.__name__` reports at class level.
# Instances still resolve `__name__` to the string set in the class body --
# confirm that this is intended.
SequenceLabellingOutput.__name__ = "SequenceLabellingOutput"
@add_docstring(_KNOWN_ISSUES.format(_WaveDelineationOutput_ISSUE_EXAMPLE), "append")
class WaveDelineationOutput(SequenceTaggingOutput):
    """Class that maintains the output of a wave delineation task.

    Parameters
    ----------
    classes : Sequence[str]
        Class names.
    prob : numpy.ndarray
        Probabilities of each class at each time step (each sample point),
        of shape ``(batch_size, signal_length, num_classes)``.
    mask : numpy.ndarray
        Predicted class indices at each time step (each sample point),
        or binary predictions at each time step (each sample point),
        of shape ``(batch_size, num_channels, signal_length)``.

    """

    __name__ = "WaveDelineationOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return {"classes", "prob", "mask"}

    # The docstring is attached through `add_docstring` because it is an
    # f-string (it interpolates `ECGWaveFormNames`), which cannot serve as a
    # regular docstring literal.
    @add_docstring(
        f"""Compute metrics from the output

        Parameters
        ----------
        fs : numbers.Real
            Sampling frequency of the signal corresponding to the masks,
            used to compute the duration of each waveform,
            and thus the error and standard deviations of errors.
        class_map : dict
            Class map, mapping names to waves to numbers from 0 to n_classes-1,
            the keys should contain {", ".join([f'"{item}"' for item in ECGWaveFormNames])}.
        macro : bool
            Whether to use macro-averaged metrics or not.
        tol : float, default 0.15
            Tolerance for the duration of the waveform,
            with units in seconds.

        Returns
        -------
        metrics : WaveDelineationMetrics
            Metrics computed from the output

        """
    )
    def compute_metrics(
        self,
        fs: int,
        class_map: Dict[str, int],
        macro: bool = True,
        tol: float = 0.15,
    ) -> WaveDelineationMetrics:
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        wd_met = WaveDelineationMetrics(macro=macro, tol=tol)
        # Ground truth is stored under `labels`, with `label` as a fallback.
        labels = self.get("labels", self.get("label"))
        return wd_met.compute(labels, self.mask, class_map=class_map, fs=fs)
@add_docstring(_KNOWN_ISSUES.format(_RPeaksDetectionOutput_ISSUE_EXAMPLE), "append")
class RPeaksDetectionOutput(BaseOutput):
    """
    Class that maintains the output of an R peaks detection task.

    Parameters
    ----------
    rpeak_indices : Sequence[Sequence[int]]
        Rpeak indices for each batch sample.
    prob : numpy.ndarray
        Probabilities at each time step (each sample point),
        of shape ``(batch_size, signal_length)``.

    """

    __name__ = "RPeaksDetectionOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return {"rpeak_indices", "prob"}

    def compute_metrics(self, fs: int, thr: float = 0.075) -> RPeaksDetectionMetrics:
        """Compute metrics from the output.

        The ground truth is read from the field ``labels``,
        falling back to ``label``.

        Parameters
        ----------
        fs : int
            Sampling frequency of the signal corresponding to the masks.
        thr : float, default 0.075
            Threshold for a prediction to be true positive,
            with units in seconds.

        Returns
        -------
        metrics : RPeaksDetectionMetrics
            Metrics computed from the output.

        """
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        rpd_met = RPeaksDetectionMetrics(thr=thr)
        return rpd_met.compute(self.get("labels", self.get("label")), self.rpeak_indices, fs=fs)