torch_ecg.utils.extend_predictions

torch_ecg.utils.extend_predictions(preds: Sequence, classes: List[str], extended_classes: List[str]) → ndarray

Extend prediction arrays defined over one set of classes to a larger set of classes.

Parameters:
  • preds (array_like) – Array of predictions (scalar or binary) of shape (n_records, n_classes), or categorical predictions (class indices into classes) of shape (n_records,), where n_classes = len(classes).

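  • classes (List[str]) – Class names of the predictions in preds: the column order for scalar or binary predictions, or the index order for categorical predictions.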

  • extended_classes (List[str]) – A superset of classes. The predictions will be extended to this range of classes.
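The relation between classes and extended_classes can be illustrated with the class lists from the example below. This is a minimal NumPy sketch, assuming the extension is a lookup of each class name's position in extended_classes; the names index_map and extended_cate are hypothetical and not part of torch_ecg. It shows how categorical predictions (indices into classes) would be remapped to indices into extended_classes.

```python
import numpy as np

classes = ["NSR", "AF", "PVC"]
extended_classes = ["AF", "RBBB", "PVC", "NSR"]

# extended_classes must contain every entry of classes.
assert set(classes) <= set(extended_classes)

# Position of each original class inside extended_classes (assumed name lookup).
index_map = np.array([extended_classes.index(c) for c in classes])

# Categorical predictions: one index into `classes` per record, shape (n_records,).
cate_pred = np.array([0, 1, 2, 1])  # NSR, AF, PVC, AF

# Remap each index so that it points into `extended_classes` instead.
extended_cate = index_map[cate_pred]
```

Here index_map is [3, 0, 2], so a prediction of "NSR" (index 0 in classes) becomes index 3 in extended_classes, while the class name it denotes is unchanged.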

Returns:

extended_preds – The extended array of predictions, indexed by extended_classes, of shape (n_records, len(extended_classes)) for scalar or binary predictions, or (n_records,) for categorical predictions.

Return type:

numpy.ndarray
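The parameter and return descriptions above can be summarized in a short pure-NumPy sketch. It is not the torch_ecg source: extend_preds_sketch is a hypothetical name, it accepts only lists or NumPy arrays, and it assumes that columns for classes absent from classes (such as "RBBB" here) are filled with zeros.

```python
from typing import List, Sequence

import numpy as np


def extend_preds_sketch(
    preds: Sequence, classes: List[str], extended_classes: List[str]
) -> np.ndarray:
    """Hypothetical illustration of extending predictions; not the library code."""
    missing = set(classes) - set(extended_classes)
    if missing:
        raise ValueError(f"extended_classes lacks {sorted(missing)}")
    arr = np.asarray(preds)
    # Column (or index) of each original class inside extended_classes.
    index_map = np.array([extended_classes.index(c) for c in classes])
    if arr.ndim == 2:
        # Scalar or binary predictions, shape (n_records, n_classes):
        # move each column to its new position; other columns stay zero (assumption).
        out = np.zeros((arr.shape[0], len(extended_classes)), dtype=arr.dtype)
        out[:, index_map] = arr
        return out
    # Categorical predictions, shape (n_records,): remap the class indices.
    return index_map[arr]


classes = ["NSR", "AF", "PVC"]
extended_classes = ["AF", "RBBB", "PVC", "NSR"]

scalar_pred = np.array([[0.9, 0.2, 0.1], [0.1, 0.8, 0.3]])
extended_scalar = extend_preds_sketch(scalar_pred, classes, extended_classes)

bin_pred = np.array([[1, 0, 1]])
extended_bin = extend_preds_sketch(bin_pred, classes, extended_classes)

cate_pred = np.array([0, 2])
extended_cate = extend_preds_sketch(cate_pred, classes, extended_classes)
```

With these inputs, the first row of extended_scalar is [0.2, 0.0, 0.1, 0.9]: the "NSR" score moves to the last column and the "RBBB" column is zero. Using a single integer index array for both the column scatter and the categorical remap keeps the two cases consistent by construction.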

Examples

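import torch
from torch_ecg.utils import extend_predictions

n_records, n_classes = 10, 3
classes = ["NSR", "AF", "PVC"]
extended_classes = ["AF", "RBBB", "PVC", "NSR"]
# scalar predictions (e.g. probabilities), shape (n_records, n_classes)
scalar_pred = torch.rand(n_records, n_classes)
extended_pred = extend_predictions(scalar_pred, classes, extended_classes)
# binary predictions, shape (n_records, n_classes)
bin_pred = torch.randint(0, 2, (n_records, n_classes))
extended_pred = extend_predictions(bin_pred, classes, extended_classes)
# categorical predictions (indices into classes), shape (n_records,)
cate_pred = torch.randint(0, n_classes, (n_records,))
extended_pred = extend_predictions(cate_pred, classes, extended_classes)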