Source code for torch_ecg.components.outputs

"""
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence, Set, Union

import numpy as np
import pandas as pd

from ..cfg import CFG
from ..utils.misc import add_docstring
from ..utils.utils_data import ECGWaveFormNames
from .metrics import ClassificationMetrics, RPeaksDetectionMetrics, WaveDelineationMetrics

__all__ = [
    "BaseOutput",
    "ClassificationOutput",
    "MultiLabelClassificationOutput",
    "SequenceTaggingOutput",
    "SequenceLabellingOutput",
    "WaveDelineationOutput",
    "RPeaksDetectionOutput",
]


_KNOWN_ISSUES = """
    NOTE
    ----
    Known issues:

        - fields of type `dict` are not well supported due to the limitations of the base class `CFG`, for example

        .. code-block:: python

        {}
    """
_ClassificationOutput_ISSUE_EXAMPLE = """
            >>> output = ClassificationOutput(classes=["AF", "N", "SPB"], pred=np.ones((1,3)), prob=np.ones((1,3)), d={"d":1})
            >>> output
            {'classes': ['AF', 'N', 'SPB'],
                'prob': array([[1., 1., 1.]]),
                'pred': array([[1., 1., 1.]]),
                'd': {'d': 1}}
            >>> output.d  # has to be accessed via `output["d"]`
            AttributeError: 'ClassificationOutput' object has no attribute 'd'
    """
_MultiLabelClassificationOutput_ISSUE_EXAMPLE = """
            >>> output = MultiLabelClassificationOutput(classes=["AF", "N", "SPB"], thr=0.5, pred=np.ones((1,3)), prob=np.ones((1,3)), d={"d":1})
            >>> output
            {'classes': ['AF', 'N', 'SPB'],
                'prob': array([[1., 1., 1.]]),
                'pred': array([[1., 1., 1.]]),
                'thr': 0.5,
                'd': {'d': 1}}
            >>> output.d  # has to be accessed via `output["d"]`
            AttributeError: 'MultiLabelClassificationOutput' object has no attribute 'd'
    """
_SequenceTaggingOutput_ISSUE_EXAMPLE = """
            >>> output = SequenceTaggingOutput(classes=["AF", "N", "SPB"], thr=0.5, pred=np.ones((1,3,3)), prob=np.ones((1,3,3)), d={"d":1})
            >>> output
            {'classes': ['AF', 'N', 'SPB'],
                'prob': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'pred': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'thr': 0.5,
                'd': {'d': 1}}
            >>> output.d  # has to be accessed via `output["d"]`
            AttributeError: 'SequenceTaggingOutput' object has no attribute 'd'
    """
_WaveDelineationOutput_ISSUE_EXAMPLE = """
            >>> output = WaveDelineationOutput(classes=["N", "P", "Q",], thr=0.5, mask=np.ones((1,3,3)), prob=np.ones((1,3,3)), d={"d":1})
            >>> output
            {'classes': ['N', 'P', 'Q'],
                'prob': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'mask': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'd': {'d': 1}}
            >>> output.d  # has to be accessed via `output["d"]`
            AttributeError: 'WaveDelineationOutput' object has no attribute 'd'
    """
_RPeaksDetectionOutput_ISSUE_EXAMPLE = """
            >>> output = RPeaksDetectionOutput(rpeak_indices=[[2]], thr=0.5, prob=np.ones((1,3,3)), d={"d":1})
            >>> output
            {'rpeak_indices': [[2]],
                'prob': array([[[1., 1., 1.],
                        [1., 1., 1.],
                        [1., 1., 1.]]]),
                'thr': 0.5,
                'd': {'d': 1}}
            >>> output.d  # has to be accessed via `output["d"]`
            AttributeError: 'RPeaksDetectionOutput' object has no attribute 'd'
    """


class BaseOutput(CFG, ABC):
    """Base class for all outputs.

    Parameters
    ----------
    *args : sequence
        Positional arguments.
    **kwargs : dict
        Keyword arguments.

    """

    __name__ = "BaseOutput"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        pop_fields = [k for k in self if k in ["required_fields", "append", "compute_metrics"] or k.startswith("_abc")]
        for f in pop_fields:
            self.pop(f, None)
        assert all([field in self.keys() for field in self.required_fields()]), (
            f"{self.__name__} requires {self.required_fields()}, "
            f"but `{', '.join(self.required_fields() - set(self.keys()))}` are missing"
        )
        assert all(
            [self[field] is not None for field in self.required_fields()]
        ), f"Fields `{', '.join([field for field in self.required_fields() if self[field] is None])}` are not set"

    @abstractmethod
    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        raise NotImplementedError("Subclass must implement method `required_fields`")

    def append(self, values: Union["BaseOutput", Sequence["BaseOutput"]]) -> None:
        """Append other :class:`Output` to `self`

        Parameters
        ----------
        values : Output or Sequence[Output]
            The values to be appended.

        Returns
        -------
        None

        """
        if not isinstance(values, Sequence):
            values = [values]
        for v in values:
            assert v.__class__ == self.__class__, "`values` must be of the same type as `self`"
            assert set(v.keys()) == set(self.keys()), "`values` must have the same fields as `self`"
            for k, v_ in v.items():
                if k in ["classes"]:
                    assert v_ == self[k], f"the ordered sequence field `{k}` must be identical"
                    continue
                if isinstance(v_, np.ndarray):
                    self[k] = np.concatenate((self[k], v_))
                elif isinstance(v_, pd.DataFrame):
                    self[k] = pd.concat([self[k], v_], axis=0, ignore_index=True)
                elif isinstance(v_, Sequence):  # list, tuple, etc.
                    self[k] += v_
                else:
                    raise ValueError(f"field `{k}` of type `{type(v_)}` is not supported")
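
    # Illustrative sketch (an assumption for demonstration, not part of the
    # library source): ``append`` merges per-batch outputs of the same type.
    #
    #     >>> import numpy as np
    #     >>> out = ClassificationOutput(classes=["AF", "N"], prob=np.ones((2, 2)), pred=np.zeros((2,)))
    #     >>> more = ClassificationOutput(classes=["AF", "N"], prob=np.ones((3, 2)), pred=np.zeros((3,)))
    #     >>> out.append(more)
    #     >>> out.prob.shape
    #     (5, 2)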


@add_docstring(_KNOWN_ISSUES.format(_ClassificationOutput_ISSUE_EXAMPLE), "append")
class ClassificationOutput(BaseOutput):
    """Class that maintains the output of a (typically single-label) classification task.

    Parameters
    ----------
    classes : Sequence[str]
        Class names.
    prob : numpy.ndarray
        Probabilities of each class, of shape ``(batch_size, num_classes)``.
    pred : numpy.ndarray
        Predicted class indices of shape ``(batch_size,)``,
        or binary predictions of shape ``(batch_size, num_classes)``.

    """

    __name__ = "ClassificationOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return set(
            [
                "classes",
                "prob",
                "pred",
            ]
        )

    def compute_metrics(self) -> ClassificationMetrics:
        """Compute metrics from the output.

        Returns
        -------
        metrics : ClassificationMetrics
            Metrics computed from the output.

        """
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        clf_met = ClassificationMetrics(multi_label=False, macro=True)
        return clf_met.compute(self.get("labels", self.get("label")), self.pred, len(self.classes))
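
# A minimal, hypothetical usage sketch (values invented for illustration):
# ``compute_metrics`` expects the ground truth to be stored in the output
# under ``labels`` (or ``label``).
#
#     >>> import numpy as np
#     >>> output = ClassificationOutput(
#     ...     classes=["AF", "N"],
#     ...     prob=np.array([[0.9, 0.1], [0.2, 0.8]]),
#     ...     pred=np.array([0, 1]),
#     ...     labels=np.array([0, 1]),
#     ... )
#     >>> metrics = output.compute_metrics()  # a ClassificationMetrics instance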


@add_docstring(_KNOWN_ISSUES.format(_MultiLabelClassificationOutput_ISSUE_EXAMPLE), "append")
class MultiLabelClassificationOutput(BaseOutput):
    """Class that maintains the output of a multi-label classification task.

    Parameters
    ----------
    classes : Sequence[str]
        Class names.
    thr : float
        Threshold for making binary predictions.
    prob : numpy.ndarray
        Probabilities of each class, of shape ``(batch_size, num_classes)``.
    pred : numpy.ndarray
        Binary predictions, of shape ``(batch_size, num_classes)``.

    """

    __name__ = "MultiLabelClassificationOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return set(
            [
                "classes",
                "thr",
                "prob",
                "pred",
            ]
        )

    def compute_metrics(self, macro: bool = True) -> ClassificationMetrics:
        """Compute metrics from the output.

        Parameters
        ----------
        macro : bool
            Whether to use macro-averaged metrics or not.

        Returns
        -------
        metrics : ClassificationMetrics
            Metrics computed from the output.

        """
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        clf_met = ClassificationMetrics(multi_label=True, macro=macro)
        return clf_met.compute(self.get("labels", self.get("label")), self.pred, len(self.classes))
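
# Hypothetical usage sketch (values invented for illustration): multi-label
# outputs hold binarized predictions, and ``labels`` is assumed to be multi-hot.
#
#     >>> import numpy as np
#     >>> output = MultiLabelClassificationOutput(
#     ...     classes=["AF", "N", "SPB"],
#     ...     thr=0.5,
#     ...     prob=np.array([[0.9, 0.1, 0.7]]),
#     ...     pred=np.array([[1, 0, 1]]),
#     ...     labels=np.array([[1, 0, 1]]),
#     ... )
#     >>> metrics = output.compute_metrics(macro=True)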


@add_docstring(_KNOWN_ISSUES.format(_SequenceTaggingOutput_ISSUE_EXAMPLE), "append")
class SequenceTaggingOutput(BaseOutput):
    """Class that maintains the output of a sequence tagging task.

    Parameters
    ----------
    classes : Sequence[str]
        Class names.
    prob : numpy.ndarray
        Probabilities of each class at each time step (each sample point),
        of shape ``(batch_size, signal_length, num_classes)``.
    pred : numpy.ndarray
        Predicted class indices at each time step (each sample point),
        of shape ``(batch_size, signal_length)``;
        or binary predictions at each time step (each sample point),
        of shape ``(batch_size, signal_length, num_classes)``.

    """

    __name__ = "SequenceTaggingOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return set(
            [
                "classes",
                "prob",
                "pred",
            ]
        )

    def compute_metrics(self, macro: bool = True) -> ClassificationMetrics:
        """Compute metrics from the output.

        Parameters
        ----------
        macro : bool
            Whether to use macro-averaged metrics or not.

        Returns
        -------
        metrics : ClassificationMetrics
            Metrics computed from the output.

        """
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        clf_met = ClassificationMetrics(multi_label=False, macro=macro)
        labels = self.get("labels", self.get("label"))
        return clf_met.compute(
            labels.reshape((-1, labels.shape[-1])),
            self.pred.reshape((-1, self.pred.shape[-1])),
            len(self.classes),
        )
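
# Hypothetical usage sketch: ``compute_metrics`` flattens the batch and time
# dimensions, so ``pred`` and ``labels`` are assumed here to be one-hot arrays
# of shape ``(batch_size, signal_length, num_classes)``.
#
#     >>> import numpy as np
#     >>> onehot = np.eye(3)[np.zeros((1, 4), dtype=int)]  # shape (1, 4, 3)
#     >>> output = SequenceTaggingOutput(
#     ...     classes=["AF", "N", "SPB"],
#     ...     prob=onehot,
#     ...     pred=onehot,
#     ...     labels=onehot,
#     ... )
#     >>> metrics = output.compute_metrics()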


# alias
SequenceLabellingOutput = SequenceTaggingOutput
SequenceLabellingOutput.__name__ = "SequenceLabellingOutput"


@add_docstring(_KNOWN_ISSUES.format(_WaveDelineationOutput_ISSUE_EXAMPLE), "append")
class WaveDelineationOutput(SequenceTaggingOutput):
    """Class that maintains the output of a wave delineation task.

    Parameters
    ----------
    classes : Sequence[str]
        Class names.
    prob : numpy.ndarray
        Probabilities of each class at each time step (each sample point),
        of shape ``(batch_size, signal_length, num_classes)``.
    mask : numpy.ndarray
        Predicted class indices at each time step (each sample point),
        or binary predictions at each time step (each sample point),
        of shape ``(batch_size, num_channels, signal_length)``.

    """

    __name__ = "WaveDelineationOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return set(
            [
                "classes",
                "prob",
                "mask",
            ]
        )

    @add_docstring(
        f"""Compute metrics from the output.

        Parameters
        ----------
        fs : numbers.Real
            Sampling frequency of the signal corresponding to the masks,
            used to compute the duration of each waveform,
            and thus the error and standard deviations of errors.
        class_map : dict
            Class map, mapping names of waves to numbers from 0 to n_classes-1,
            the keys should contain {", ".join([f'"{item}"' for item in ECGWaveFormNames])}.
        macro : bool
            Whether to use macro-averaged metrics or not.
        tol : float, default 0.15
            Tolerance for the duration of the waveform, with units in seconds.

        Returns
        -------
        metrics : WaveDelineationMetrics
            Metrics computed from the output.

        """
    )
    def compute_metrics(
        self,
        fs: int,
        class_map: Dict[str, int],
        macro: bool = True,
        tol: float = 0.15,
    ) -> WaveDelineationMetrics:
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        wd_met = WaveDelineationMetrics(macro=macro, tol=tol)
        labels = self.get("labels", self.get("label"))
        return wd_met.compute(labels, self.mask, class_map=class_map, fs=fs)
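
# Hypothetical usage sketch: the keys of ``class_map`` must match
# ``ECGWaveFormNames``; the wave names and values below are assumptions made
# for illustration, not taken from the library.
#
#     >>> import numpy as np
#     >>> mask = np.zeros((1, 1, 500), dtype=int)
#     >>> mask[0, 0, 100:140] = 2  # pretend a QRS complex was delineated here
#     >>> output = WaveDelineationOutput(
#     ...     classes=["pwave", "qrs", "twave"],
#     ...     prob=np.zeros((1, 500, 3)),
#     ...     mask=mask,
#     ...     labels=mask.copy(),
#     ... )
#     >>> class_map = {"pwave": 1, "qrs": 2, "twave": 3}  # assumed mapping
#     >>> metrics = output.compute_metrics(fs=500, class_map=class_map)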


@add_docstring(_KNOWN_ISSUES.format(_RPeaksDetectionOutput_ISSUE_EXAMPLE), "append")
class RPeaksDetectionOutput(BaseOutput):
    """Class that maintains the output of an R peaks detection task.

    Parameters
    ----------
    rpeak_indices : Sequence[Sequence[int]]
        R peak indices for each batch sample.
    prob : numpy.ndarray
        Probabilities at each time step (each sample point),
        of shape ``(batch_size, signal_length)``.

    """

    __name__ = "RPeaksDetectionOutput"

    def required_fields(self) -> Set[str]:
        """The required fields of the output class."""
        return set(
            [
                "rpeak_indices",
                "prob",
            ]
        )

    def compute_metrics(self, fs: int, thr: float = 0.075) -> RPeaksDetectionMetrics:
        """Compute metrics from the output.

        Parameters
        ----------
        fs : int
            Sampling frequency of the signal corresponding to the R peak indices.
        thr : float, default 0.075
            Threshold for a prediction to be a true positive, with units in seconds.

        Returns
        -------
        metrics : RPeaksDetectionMetrics
            Metrics computed from the output.

        """
        assert hasattr(self, "labels") or hasattr(
            self, "label"
        ), "`labels` or `label` must be stored in the output for computing metrics"
        rpd_met = RPeaksDetectionMetrics(thr=thr)
        return rpd_met.compute(self.get("labels", self.get("label")), self.rpeak_indices, fs=fs)
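
# Hypothetical usage sketch: ``labels`` is assumed to hold the ground-truth
# R peak indices per batch sample; ``thr`` is expressed in seconds.
#
#     >>> import numpy as np
#     >>> output = RPeaksDetectionOutput(
#     ...     rpeak_indices=[[100, 500, 900]],
#     ...     prob=np.zeros((1, 1000)),
#     ...     labels=[[98, 502, 903]],
#     ... )
#     >>> metrics = output.compute_metrics(fs=250, thr=0.075)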