RR_LSTM

class torch_ecg.models.RR_LSTM(classes: Sequence[str], config: CFG | None = None, **kwargs: Any)

Bases: Module, CkptMixin, SizeMixin, CitationMixin

LSTM model for RR time series classification or sequence labeling.

An LSTM model that takes RR time series as input was studied in [Faust et al.[1]] for atrial fibrillation detection. It was further improved in [Wen et al.[2]] by incorporating an attention mechanism and conditional random fields (CRF).
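The model consumes RR interval series rather than raw ECG samples. The sketch below shows one way to derive such a series. It assumes R-peak locations are already available as sample indices at a known sampling frequency `fs`. Both helper names, `rr_series_from_rpeaks` and `as_single_batch`, are hypothetical and not part of torch_ecg.

```python
from typing import List, Sequence


def rr_series_from_rpeaks(rpeaks: Sequence[int], fs: float) -> List[float]:
    """Return RR intervals in seconds from R-peak sample indices (hypothetical helper)."""
    if fs <= 0:
        raise ValueError("fs must be positive")
    # Successive differences of R-peak positions, converted from samples to seconds;
    # fewer than two R-peaks yields an empty series
    return [(curr - prev) / fs for prev, curr in zip(rpeaks[:-1], rpeaks[1:])]


def as_single_batch(rr: Sequence[float]) -> List[List[List[float]]]:
    """Wrap one RR series as nested lists of shape (batch_size=1, n_channels=1, seq_len)."""
    return [[list(rr)]]


# R-peaks at samples 0, 200, 410 and 600, recorded at 250 Hz
rr = rr_series_from_rpeaks([0, 200, 410, 600], fs=250.0)
print(rr)  # [0.8, 0.84, 0.76]
```

The nested output of `as_single_batch` is laid out as `(batch_size, n_channels, seq_len)`. Before it is passed to the model, it would still need to be converted to a `torch.Tensor`.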

Parameters:
  • classes (Sequence[str]) – Names of the classes.

  • config (CFG, optional) – Other hyper-parameters, including kernel sizes, etc. Refer to the corresponding config file for details.

References

compute_output_shape(seq_len: int | None = None, batch_size: int | None = None) → Sequence[int | None]

Compute the output shape of the model.

Parameters:
  • seq_len (int, optional) – Length of the input series tensor.

  • batch_size (int, optional) – Batch size of the input series tensor.

Returns:

output_shape – Output shape of the model.

Return type:

sequence
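The return type `Sequence[int | None]` suggests that unspecified dimensions are reported as `None`. The sketch below shows how a shape like this could be assembled from the two output layouts listed under `forward`. It assumes those layouts and nothing else. The function name and the `return_sequences` switch are hypothetical stand-ins, not the torch_ecg implementation or config keys.

```python
from typing import Optional, Tuple


def rr_lstm_output_shape(
    n_classes: int,
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
    return_sequences: bool = True,  # hypothetical: sequence labeling vs. classification
) -> Tuple[Optional[int], ...]:
    """Illustrative mirror of the documented output shapes (not library code)."""
    if return_sequences:
        # Sequence labeling: one score vector per RR interval
        return (batch_size, seq_len, n_classes)
    # Classification: one score vector per input series
    return (batch_size, n_classes)


print(rr_lstm_output_shape(2, seq_len=100, batch_size=4))  # (4, 100, 2)
print(rr_lstm_output_shape(2, batch_size=4, return_sequences=False))  # (4, 2)
```

Leaving `seq_len` or `batch_size` unset keeps the corresponding entry as `None`. This is convenient when the shape is only needed symbolically.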

forward(input: Tensor) → Tensor

Forward pass of the model.

Parameters:

input (torch.Tensor) – Input RR series tensor of shape (seq_len, batch_size, n_channels), or (batch_size, n_channels, seq_len) if config.batch_first is True.

Returns:

Output tensor of shape (batch_size, seq_len, n_classes) for sequence labeling, or (batch_size, n_classes) for classification.

Return type:

torch.Tensor
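When `config.batch_first` is True, the documented input layout is `(batch_size, n_channels, seq_len)`. By contrast, `torch.nn.LSTM` with its default `batch_first=False` expects `(seq_len, batch_size, input_size)`. The plain-Python sketch below performs that reordering on nested lists. With a `torch.Tensor`, the equivalent call is `x.permute(2, 0, 1)`. The function name is hypothetical, and it is an assumption that RR_LSTM performs exactly this permutation internally.

```python
from typing import List

Nested3D = List[List[List[float]]]


def channel_first_to_seq_first(x: Nested3D) -> Nested3D:
    """Reorder (batch_size, n_channels, seq_len) into (seq_len, batch_size, n_channels)."""
    batch_size = len(x)
    n_channels = len(x[0]) if batch_size else 0
    seq_len = len(x[0][0]) if n_channels else 0
    # out[t][b][c] = x[b][c][t]
    return [
        [[x[b][c][t] for c in range(n_channels)] for b in range(batch_size)]
        for t in range(seq_len)
    ]


# Two single-channel RR series of length 3: shape (2, 1, 3)
x = [[[0.8, 0.84, 0.76]], [[1.0, 0.9, 1.1]]]
print(channel_first_to_seq_first(x))
# [[[0.8], [1.0]], [[0.84], [0.9]], [[0.76], [1.1]]]
```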

classmethod from_v1(v1_ckpt: str, device: device | None = None, return_config: bool = False) → RR_LSTM | Tuple[RR_LSTM, dict]

Restore an instance of the model from a v1 checkpoint.

Parameters:
  • v1_ckpt (str) – Path to the v1 checkpoint file.

  • device (torch.device, optional) – The device to load the model to. Defaults to “cuda” if available, otherwise “cpu”.

  • return_config (bool, default False) – Whether to return the config dict.

Returns:

model – The model instance restored from the v1 checkpoint. If return_config is True, a tuple (model, config) is returned instead.

Return type:

RR_LSTM or Tuple[RR_LSTM, dict]

inference(input: Tensor, bin_pred_thr: float = 0.5) → BaseOutput

Inference method for the model.
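The `bin_pred_thr` parameter reads as a threshold for binary predictions. The minimal sketch below assumes that raw per-class scores pass through a sigmoid and that probabilities at or above `bin_pred_thr` are marked positive. Both the sigmoid step and the inclusive comparison are assumptions about `inference`, not its documented code. The function name is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def threshold_predictions(
    logits: Sequence[float], bin_pred_thr: float = 0.5
) -> Tuple[List[float], List[int]]:
    """Convert raw scores to sigmoid probabilities and 0/1 predictions (illustrative)."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    # Assumed rule: a probability at or above the threshold counts as positive
    preds = [int(p >= bin_pred_thr) for p in probs]
    return probs, preds


probs, preds = threshold_predictions([2.0, -1.0, 0.0])
print(preds)  # [1, 0, 1]
```

Raising `bin_pred_thr` makes positive predictions rarer. This typically lowers recall and raises precision.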