MaskedBCEWithLogitsLoss

class torch_ecg.models.loss.MaskedBCEWithLogitsLoss

Bases: BCEWithLogitsLoss

Masked Binary Cross Entropy Loss class.

This loss is mainly used for segmentation tasks in which some regions are much more important than others, for example the onsets and offsets of particular events such as paroxysmal atrial fibrillation (AF) episodes.

This loss was proposed in [1], drawing on the loss function used in the U-Net paper [2].
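For reference, one way to write a masked (weighted) binary cross-entropy over logits is shown below. It assumes that the element-wise losses are averaged using the mask entries as weights; the exact normalization used by torch_ecg is not stated on this page. Here x_i, y_i and w_i are the entries of input, target and weight_mask (after broadcasting weight_mask to the shape of input), σ is the sigmoid, and i runs over all elements.

```latex
\ell_i = -\Bigl[ y_i \log \sigma(x_i) + (1 - y_i) \log\bigl(1 - \sigma(x_i)\bigr) \Bigr],
\qquad
\mathcal{L} = \frac{\sum_i w_i \, \ell_i}{\sum_i w_i}
```

With w_i = 1 for every element, this reduces to the ordinary mean binary cross-entropy with logits.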

References

forward(input: Tensor, target: Tensor, weight_mask: Tensor) → Tensor

Forward pass.

Parameters:
  • input (torch.Tensor) – The predicted logits (values before the sigmoid), of shape (batch_size, ..., n_classes).

  • target (torch.Tensor) – The target tensor, of shape (batch_size, ..., n_classes).

  • weight_mask (torch.Tensor) – The weight mask tensor, of shape (batch_size, ..., n_classes), or (batch_size, ..., 1), or (batch_size, ...).
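As a rough illustration of how the three accepted weight_mask shapes can be reconciled with input, here is a minimal sketch in plain PyTorch. The function name masked_bce_with_logits is hypothetical, and normalizing by the sum of the mask is an assumption; this is not torch_ecg's actual implementation.

```python
import torch
import torch.nn.functional as F


def masked_bce_with_logits(
    input: torch.Tensor, target: torch.Tensor, weight_mask: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch of a masked BCE-with-logits loss (not torch_ecg's code).

    Assumption: the result is the mask-weighted mean of the element-wise loss.
    """
    if input.shape != target.shape:
        raise ValueError("input and target must have the same shape")
    if weight_mask.ndim == input.ndim - 1:
        # (batch_size, ...) -> (batch_size, ..., 1): one weight shared by all classes
        weight_mask = weight_mask.unsqueeze(-1)
    # (batch_size, ..., 1) is broadcast over n_classes; a full-shape mask is kept as-is.
    # expand_as raises a RuntimeError for any other incompatible shape.
    weight_mask = weight_mask.expand_as(input).to(input.dtype)
    # Element-wise loss with no reduction, so each element can be weighted
    per_element = F.binary_cross_entropy_with_logits(input, target, reduction="none")
    # Weighted mean over all elements (assumed normalization)
    return (per_element * weight_mask).sum() / weight_mask.sum()
```

With an all-ones mask the sketch gives the same value as the plain mean BCE with logits, whichever of the three mask shapes is passed; a 0/1 mask restricts the average to the selected region.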

Returns:

The masked binary cross entropy loss.

Return type:

torch.Tensor

Note

input and target should be N-D tensors of the same shape, with N ≥ 3; weight_mask should have the same shape, or the same shape with a last dimension of size 1, or the same shape with the last dimension omitted.

A typical example is when input and target have shape (batch_size, sig_len, n_classes).
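A weight mask that emphasizes event onsets and offsets, as described at the top of this page, can be built from per-sample labels. The sketch below is one possible construction, not part of torch_ecg: the helper name boundary_weight_mask and its parameters radius and boundary_weight are hypothetical. It marks every label change and widens each mark to a window of 2 * radius + 1 samples with a 1-D max pooling.

```python
import torch
import torch.nn.functional as F


def boundary_weight_mask(
    labels: torch.Tensor, radius: int = 5, boundary_weight: float = 5.0
) -> torch.Tensor:
    """Hypothetical helper: up-weight samples within ``radius`` of a label change.

    labels: (batch_size, sig_len) binary tensor, e.g. 1 inside AF episodes.
    Returns a (batch_size, sig_len) float tensor equal to ``boundary_weight``
    near onsets/offsets and 1.0 elsewhere.
    """
    labels = labels.float()
    # change[b, t] = 1 where labels[b, t] != labels[b, t - 1], i.e. an onset or offset
    change = torch.zeros_like(labels)
    change[:, 1:] = (labels[:, 1:] != labels[:, :-1]).float()
    # Widen each change point to t - radius .. t + radius; output keeps length sig_len
    near = F.max_pool1d(
        change.unsqueeze(1), kernel_size=2 * radius + 1, stride=1, padding=radius
    ).squeeze(1)
    return torch.where(
        near > 0, torch.full_like(labels, boundary_weight), torch.ones_like(labels)
    )


# One hypothetical AF episode covering samples 8..13 of a 20-sample signal
labels = torch.zeros(1, 20)
labels[0, 8:14] = 1
mask = boundary_weight_mask(labels, radius=1, boundary_weight=5.0)
# Samples 7-9 (onset) and 13-15 (offset) get weight 5.0; all others get 1.0
```

The returned mask has shape (batch_size, sig_len), which is one of the weight_mask shapes accepted by forward when input and target have shape (batch_size, sig_len, n_classes).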