ClassificationMetrics
- class torch_ecg.components.ClassificationMetrics(multi_label: bool = True, macro: bool = True, extra_metrics: Callable | None = None)
Bases: Metrics
Metrics for the task of classification.
- Parameters:
multi_label (bool, default True) – Whether the task is multi-label classification.
macro (bool, default True) – Whether to use macro-averaged metrics.
extra_metrics (callable, optional) – Extra metrics to compute. Must be a function with the signature:
    def extra_metrics(
        labels: np.ndarray,
        outputs: np.ndarray,
        num_classes: Optional[int] = None,
        weights: Optional[np.ndarray] = None,
    ) -> dict
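A function matching that signature could look like the sketch below. It is a hypothetical example, not part of torch_ecg: the name `exact_match_ratio`, the returned dictionary key, and the choice to threshold `outputs` at 0.5 inside the function are all assumptions, since the signature does not say whether `outputs` arrives as probabilities or already binarized. It assumes binary multi-label `labels` of shape (n_samples, n_classes) and ignores `num_classes` and `weights`, which it accepts only to match the required signature.

```python
from typing import Dict, Optional, Sequence


def exact_match_ratio(
    labels: Sequence[Sequence[int]],
    outputs: Sequence[Sequence[float]],
    num_classes: Optional[int] = None,
    weights: Optional[Sequence[float]] = None,
) -> Dict[str, float]:
    """Hypothetical extra metric: share of samples whose whole label row is predicted exactly.

    `num_classes` and `weights` are accepted only to match the required signature.
    Outputs are thresholded at 0.5 (an assumption), so binary outputs pass through unchanged.
    """
    if len(labels) == 0:
        return {"exact_match_ratio": float("nan")}
    hits = 0
    for y_true, y_score in zip(labels, outputs):
        # Binarize this sample's outputs, then compare the full row at once.
        y_pred = [1 if float(s) >= 0.5 else 0 for s in y_score]
        if [int(v) for v in y_true] == y_pred:
            hits += 1
    return {"exact_match_ratio": hits / len(labels)}
```

It would then be passed as `ClassificationMetrics(extra_metrics=exact_match_ratio)`; how the returned dictionary is exposed on the metrics object is not specified here.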
- compute(labels: ndarray | Tensor, outputs: ndarray | Tensor, num_classes: int | None = None, weights: ndarray | None = None, thr: float = 0.5) → ClassificationMetrics
Compute macro-averaged metrics and per-class metrics.
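To make the distinction between per-class and macro metrics concrete, the sketch below computes an F1 score for each class from binary labels and predictions, then averages the scores, optionally with class weights. It is a hypothetical pure-Python illustration, not torch_ecg's implementation: the names `per_class_f1` and `macro_average`, the weighted-mean formula used for `weights`, and returning NaN for a class with no true and no predicted positives are assumptions. The NaN handling mimics the `fillna` parameter documented for `compute`.

```python
import math
from typing import List, Optional, Sequence, Union


def per_class_f1(labels: Sequence[Sequence[int]],
                 preds: Sequence[Sequence[int]]) -> List[float]:
    """F1 score for each class column of binary (n_samples, n_classes) data.

    Returns NaN for a class with no true and no predicted positives (0/0).
    """
    n_classes = len(labels[0])
    scores = []
    for c in range(n_classes):
        tp = sum(1 for y, p in zip(labels, preds) if y[c] == 1 and p[c] == 1)
        fp = sum(1 for y, p in zip(labels, preds) if y[c] == 0 and p[c] == 1)
        fn = sum(1 for y, p in zip(labels, preds) if y[c] == 1 and p[c] == 0)
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        scores.append(2 * tp / denom if denom else math.nan)
    return scores


def macro_average(scores: Sequence[float],
                  weights: Optional[Sequence[float]] = None,
                  fillna: Union[bool, float] = 0.0) -> float:
    """Weighted mean of per-class scores, with NaN handling modeled on `fillna`."""
    if fillna is False:
        filled = list(scores)  # NaN is kept and propagates into the average
    else:
        # True means "fill with 0.0"; a number means "fill with that number".
        value = 0.0 if fillna is True else float(fillna)
        filled = [value if math.isnan(s) else s for s in scores]
    if weights is None:
        weights = [1.0] * len(filled)  # plain (unweighted) macro average
    total = sum(weights)
    return sum(w * s for w, s in zip(weights, filled)) / total
```

With `weights=None` every class counts equally, which is the usual macro average; a weight of 0 drops a class from the average entirely.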
- Parameters:
labels (numpy.ndarray or torch.Tensor) – Binary labels of shape (n_samples, n_classes), or the class index of each sample, of shape (n_samples,).
outputs (numpy.ndarray or torch.Tensor) – Probability outputs of shape (n_samples, n_classes), binary outputs of shape (n_samples, n_classes), or the predicted class index of each sample, of shape (n_samples,).
num_classes (int, optional) – Number of classes. Must be specified if labels and outputs are both of shape (n_samples,).
weights (numpy.ndarray or torch.Tensor, optional) – Weight of each class, of shape (n_classes,), used to compute macro metrics.
thr (float, default 0.5) – Threshold used to binarize the outputs; only applies when outputs is of shape (n_samples, n_classes).
fillna (bool or float, default 0.0) – If False, NaN values are left in the result. If True, NaN values are filled with 0.0. If a float, NaN values are filled with that value.
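The shape rules for labels, outputs, num_classes and thr amount to bringing both arrays into one binary (n_samples, n_classes) form before anything is counted. The sketch below shows that normalization in pure Python under stated assumptions: it takes nested lists instead of numpy.ndarray or torch.Tensor, the helpers `one_hot`, `binarize` and `to_binary_pair` are hypothetical rather than torch_ecg functions, 2-D labels are assumed to be binary already, and treating a probability exactly equal to `thr` as positive (`>=`) is a guess.

```python
from typing import List, Optional, Sequence


def one_hot(indices: Sequence[int], num_classes: int) -> List[List[int]]:
    """Turn class indices of shape (n_samples,) into binary rows (n_samples, n_classes)."""
    rows = []
    for idx in indices:
        if not 0 <= idx < num_classes:
            raise ValueError(f"class index {idx} out of range for {num_classes} classes")
        rows.append([1 if c == idx else 0 for c in range(num_classes)])
    return rows


def binarize(probs: Sequence[Sequence[float]], thr: float = 0.5) -> List[List[int]]:
    """Threshold 2-D outputs; already-binary 0/1 outputs are unchanged for 0 < thr <= 1."""
    return [[1 if p >= thr else 0 for p in row] for row in probs]


def to_binary_pair(labels, outputs, num_classes: Optional[int] = None, thr: float = 0.5):
    """Bring labels and outputs (nested lists) to the same binary (n_samples, n_classes) form."""
    labels_1d = not isinstance(labels[0], (list, tuple))
    outputs_1d = not isinstance(outputs[0], (list, tuple))
    if labels_1d and outputs_1d and num_classes is None:
        raise ValueError("num_classes must be given when labels and outputs are both 1-D")
    if num_classes is None:
        # Infer the class count from whichever argument is 2-D.
        num_classes = len(outputs[0]) if labels_1d else len(labels[0])
    if labels_1d:
        labels = one_hot(labels, num_classes)
    outputs = one_hot(outputs, num_classes) if outputs_1d else binarize(outputs, thr)
    return labels, outputs
```

For example, `to_binary_pair([0, 2, 1], [[0.8, 0.1, 0.3], [0.2, 0.4, 0.6], [0.5, 0.5, 0.1]])` one-hot encodes the labels and thresholds the probabilities, so the third sample is predicted as belonging to both class 0 and class 1.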
- Returns:
self – The metrics object itself with the computed metrics.
- Return type:
ClassificationMetrics
Examples
>>> from torch_ecg.cfg import DEFAULTS
>>> from torch_ecg.components import ClassificationMetrics
>>> # binary labels (100 samples, 10 classes, multi-label)
>>> labels = DEFAULTS.RNG_randint(0, 1, (100, 10))
>>> # probability outputs (100 samples, 10 classes, multi-label)
>>> outputs = DEFAULTS.RNG.random((100, 10))
>>> metrics = ClassificationMetrics()
>>> metrics = metrics.compute(labels, outputs)
>>> metrics.f1_measure
0.5062821146226457
>>> metrics.set_macro(False)
>>> metrics.f1_measure
array([0.46938776, 0.4742268 , 0.4375    , 0.52941176, 0.58      ,
       0.57692308, 0.55769231, 0.48351648, 0.55855856, 0.3956044 ])
>>> # binarized outputs (100 samples, 10 classes, multi-label)
>>> outputs = DEFAULTS.RNG_randint(0, 1, (100, 10))
>>> # would raise
>>> # RuntimeWarning: `outputs` is probably binary, AUC may be incorrect
>>> metrics = ClassificationMetrics()
>>> metrics = metrics.compute(labels, outputs)
>>> metrics.f1_measure
0.5062821146226457
>>> metrics.set_macro(False)
>>> metrics.f1_measure
array([0.46938776, 0.4742268 , 0.4375    , 0.52941176, 0.58      ,
       0.57692308, 0.55769231, 0.48351648, 0.55855856, 0.3956044 ])
>>> # categorical outputs (100 samples, 10 classes)
>>> outputs = DEFAULTS.RNG_randint(0, 9, (100,))
>>> # would raise
>>> # RuntimeWarning: `outputs` is probably binary, AUC may be incorrect
>>> metrics = ClassificationMetrics()
>>> metrics = metrics.compute(labels, outputs)
>>> metrics.f1_measure
0.5062821146226457
>>> metrics.set_macro(False)
>>> metrics.f1_measure
array([0.46938776, 0.4742268 , 0.4375    , 0.52941176, 0.58      ,
       0.57692308, 0.55769231, 0.48351648, 0.55855856, 0.3956044 ])