FocalLoss

class torch_ecg.models.loss.FocalLoss(gamma: float = 2.0, weight: Tensor | None = None, class_weight: Tensor | None = None, size_average: bool | None = None, reduce: bool | None = None, reduction: str = 'mean', multi_label: bool = True, **kwargs: Any)[source]

Bases: _WeightedLoss

Focal loss class.

The focal loss was proposed in [1]; this implementation is based on [2], [3], and [4]. The focal loss is computed as follows:

\[\operatorname{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \log(p_t)\]

Where:

  • \(p_t\) is the model’s estimated probability for the ground-truth class.

  • \(\gamma\) is the focusing parameter (the gamma argument).

  • \(\alpha_t\) is an optional class weighting factor (see class_weight).
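As a standalone numerical sketch of the formula (not part of the class API), the focusing term \((1 - p_t)^{\gamma}\) down-weights well-classified examples; the probabilities below are made up for illustration:

>>> import torch
>>> p_t = torch.tensor([0.8, 0.1])  # an easy example (0.8) and a hard one (0.1)
>>> gamma, alpha_t = 2.0, 1.0       # alpha_t = 1.0 means no class re-weighting
>>> -alpha_t * (1 - p_t) ** gamma * torch.log(p_t)  # the easy example contributes far less
tensor([0.0089, 1.8651])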

Parameters:
  • gamma (float, default 2.0) – The gamma parameter of focal loss.

  • weight (torch.Tensor, optional) – If multi_label is True, a manual rescaling weight given to the loss of each batch element, of size batch_size; if multi_label is False, a weight for each class, of size n_classes.

  • class_weight (torch.Tensor, optional) – The class weight, of shape (1, n_classes).

  • size_average (bool, optional) – Not used; kept for consistency with PyTorch native losses.

  • reduce (bool, optional) – Not used; kept for consistency with PyTorch native losses.

  • reduction ({"none", "mean", "sum"}, optional) – Specifies the reduction to apply to the output, by default “mean”.

  • multi_label (bool, default True) – If True, the loss is computed for multi-label classification.
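A minimal construction sketch, assuming a hypothetical 5-class multi-label task; the uniform class_weight is made up for illustration and follows the documented shape (1, n_classes):

>>> import torch
>>> from torch_ecg.models.loss import FocalLoss
>>> n_classes = 5  # hypothetical number of classes
>>> class_weight = torch.ones((1, n_classes))  # uniform class weights
>>> criterion = FocalLoss(gamma=2.0, class_weight=class_weight, reduction="mean", multi_label=True)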

References

forward(input: Tensor, target: Tensor) → Tensor[source]

Forward pass.

Parameters:
  • input (torch.Tensor) – The predicted value tensor (before sigmoid), of shape (batch_size, n_classes).

  • target (torch.Tensor) – The target tensor: a multi-label binarized tensor of shape (batch_size, n_classes), or a single-label binarized vector of shape (batch_size,).

Returns:

The focal loss.

Return type:

torch.Tensor
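A minimal forward-pass sketch in the default multi-label setting; the batch size, number of classes, and random tensors are made up for illustration:

>>> import torch
>>> from torch_ecg.models.loss import FocalLoss
>>> criterion = FocalLoss()  # defaults: gamma=2.0, multi_label=True, reduction="mean"
>>> batch_size, n_classes = 32, 5  # hypothetical sizes
>>> logits = torch.randn(batch_size, n_classes)  # predicted values, before sigmoid
>>> targets = torch.randint(0, 2, (batch_size, n_classes)).float()  # multi-label binarized targets
>>> loss = criterion(logits, targets)  # a torch.Tensor; scalar under the "mean" reduction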