MaskedBCEWithLogitsLoss

class torch_ecg.models.loss.MaskedBCEWithLogitsLoss

Bases: BCEWithLogitsLoss

Masked Binary Cross Entropy Loss class.

This loss is mainly used for segmentation tasks in which some regions are much more important than others, for example, the onsets and offsets of particular events (e.g., paroxysmal atrial fibrillation (AF) episodes).
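The following sketch makes the idea of emphasizing onsets and offsets concrete by building a per-sample weight mask from binary event labels. The helper `make_boundary_weight_mask`, its `radius` and `boundary_weight` parameters, and the constant-weight scheme are hypothetical illustrations and are not part of torch_ecg. The masks actually used in [1] may be constructed differently.

```python
import numpy as np


def make_boundary_weight_mask(
    labels: np.ndarray, radius: int, boundary_weight: float = 5.0
) -> np.ndarray:
    """Hypothetical helper (not torch_ecg API).

    labels: binary array of shape (sig_len, n_classes).
    Returns a float32 array of the same shape that is 1.0 everywhere,
    except within `radius` samples of a label transition (an event onset
    or offset), where it is `boundary_weight`.
    """
    sig_len, n_classes = labels.shape
    mask = np.ones(labels.shape, dtype=np.float32)
    for c in range(n_classes):
        # index of the first sample after each 0->1 or 1->0 change
        changes = np.flatnonzero(np.diff(labels[:, c].astype(np.int64)) != 0) + 1
        for idx in changes:
            lo = max(0, idx - radius)
            hi = min(sig_len, idx + radius)
            # up-weight the window [idx - radius, idx + radius) around the boundary
            mask[lo:hi, c] = boundary_weight
    return mask
```

Masks for individual records can be stacked with `np.stack(masks, axis=0)` to obtain the `(batch_size, sig_len, n_classes)` shape expected by `forward`.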

This loss was proposed in [1] and is modeled on the weighted cross-entropy loss used in the U-Net paper [2].
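The loss can be written out if we assume that the per-element BCE terms are multiplied by the mask and then averaged over all elements. The exact reduction is not stated on this page, so treat that as an assumption. For logits $x$, targets $y \in \{0, 1\}$ and weights $w$, all of shape $(B, T, C)$ (batch_size, sig_len, n_classes), and $\sigma(z) = 1/(1 + e^{-z})$:

```latex
\mathcal{L}(x, y, w) = \frac{1}{BTC} \sum_{b=1}^{B} \sum_{t=1}^{T} \sum_{c=1}^{C}
  w_{btc} \Big[ -y_{btc} \log \sigma(x_{btc})
  - (1 - y_{btc}) \log\big(1 - \sigma(x_{btc})\big) \Big]
```

With $w \equiv 1$ this reduces to the standard `BCEWithLogitsLoss` with `reduction="mean"`.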

References

forward(input: Tensor, target: Tensor, weight_mask: Tensor) → Tensor

Forward pass.

Parameters:
  • input (torch.Tensor) – The predicted logits (raw scores before the sigmoid), of shape (batch_size, sig_len, n_classes).

  • target (torch.Tensor) – The target tensor, of shape (batch_size, sig_len, n_classes).

  • weight_mask (torch.Tensor) – The weight mask tensor holding per-element loss weights, of shape (batch_size, sig_len, n_classes).

Returns:

The masked binary cross entropy loss.

Return type:

torch.Tensor

Note

input, target, and weight_mask should be 3-D tensors of the same shape.
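Here is a minimal sketch of how such a loss can be built on top of `torch.nn.BCEWithLogitsLoss`. It assumes, without confirmation from this page, that the element-wise losses are weighted by `weight_mask` and then averaged over all elements. The class name `MaskedBCEWithLogitsLossSketch` is hypothetical, and the code is not the torch_ecg source.

```python
import torch
from torch import nn


class MaskedBCEWithLogitsLossSketch(nn.BCEWithLogitsLoss):
    """Hypothetical sketch of a masked BCE-with-logits loss.

    Assumption: per-element BCE values are multiplied by `weight_mask`
    and then averaged over batch, time and class dimensions.
    """

    def __init__(self) -> None:
        # keep per-element losses so they can be weighted before reduction
        super().__init__(reduction="none")

    def forward(
        self, input: torch.Tensor, target: torch.Tensor, weight_mask: torch.Tensor
    ) -> torch.Tensor:
        # enforce the documented contract: 3-D tensors of identical shape
        if not (input.ndim == 3 and input.shape == target.shape == weight_mask.shape):
            raise ValueError(
                "input, target and weight_mask must be 3-D tensors of the same shape"
            )
        # element-wise BCE on logits; the sigmoid is applied internally
        elementwise = super().forward(input, target)
        # weight each element, then average over all elements
        return (elementwise * weight_mask).mean()
```

Setting `reduction="none"` in the parent constructor is what makes the weighting possible. The parent's `forward` then returns one loss value per element, and each value can be scaled before averaging. With an all-ones mask, the sketch matches the unmasked mean-reduced BCE.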