ECG_CRNN
- class torch_ecg.models.ECG_CRNN(classes: Sequence[str], n_leads: int, config: CFG | None = None, **kwargs: Any)
Bases: Module, CkptMixin, SizeMixin, CitationMixin
Convolutional (Recurrent) Neural Network for ECG tasks.
This C(R)NN architecture was originally adapted from [Yao et al.[1], Yao et al.[2]] and then modified to be more general and more flexible. Perhaps the best-known model of this kind is [Hannun et al.[3]], a modified 1D-ResNet34 model. The website of that model is https://stanfordmlgroup.github.io/projects/ecg2/, and the code is hosted on awni/ecg.
The C(R)NN models have long been competitive in various ECG tasks, e.g. CPSC2018 entry 0236, CPSC2019 entry 0416. The models are also used in the PhysioNet/CinC Challenges.
- Parameters:
References
- compute_output_shape(seq_len: int | None = None, batch_size: int | None = None) → Sequence[int | None]
Compute the output shape of the model.
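The idea behind an output-shape computation can be sketched in plain Python. This is only an illustration, not the library's implementation: the helper names `conv1d_out_len` and `sketch_output_shape`, the `(kernel_size, stride, padding)` layer tuples, and the `global_pool` flag are all assumptions made for this example. The length formula is the standard one for 1D convolutions.

```python
from typing import Optional, Sequence, Tuple


def conv1d_out_len(seq_len: int, kernel_size: int, stride: int = 1,
                   padding: int = 0, dilation: int = 1) -> int:
    """Output length of a single 1D convolution (standard formula)."""
    return (seq_len + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def sketch_output_shape(
    seq_len: Optional[int],
    batch_size: Optional[int],
    n_classes: int,
    layers: Sequence[Tuple[int, int, int]],
    global_pool: bool = True,
) -> Tuple[Optional[int], ...]:
    """Hypothetical helper: propagate a sequence length through conv layers.

    ``layers`` holds (kernel_size, stride, padding) tuples; this layout is an
    assumption for the sketch, not the format of the model's config.
    """
    if global_pool:
        # A global pooling step removes the sequence axis entirely.
        return (batch_size, n_classes)
    if seq_len is None:
        # An unknown input length stays unknown in the output.
        return (batch_size, None, n_classes)
    for kernel_size, stride, padding in layers:
        seq_len = conv1d_out_len(seq_len, kernel_size, stride, padding)
    return (batch_size, seq_len, n_classes)


if __name__ == "__main__":
    demo_layers = [(15, 2, 7), (15, 2, 7)]  # hypothetical layer settings
    print(sketch_output_shape(5000, None, 9, demo_layers, global_pool=False))
```

With these hypothetical settings, a 5000-sample input passes through two stride-2 layers and ends with a length of 1250. Unknown dimensions are carried through as None, which is why the return annotation above allows None entries.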
- extract_features(input: Tensor) → Tensor
Extract feature map before the dense (linear) classifying layer(s).
- Parameters:
input (torch.Tensor) – Input signal tensor, of shape (batch_size, channels, seq_len).
- Returns:
features – Feature map tensor, of shape (batch_size, channels, seq_len) or (batch_size, channels).
- Return type:
torch.Tensor
- forward(input: Tensor) → Tensor
Forward pass of the model.
- Parameters:
input (torch.Tensor) – Input signal tensor, of shape (batch_size, channels, seq_len).
- Returns:
pred – Predictions tensor, of shape (batch_size, seq_len, channels) or (batch_size, channels).
- Return type:
torch.Tensor
- classmethod from_v1(v1_ckpt: str, device: device | None = None, return_config: bool = False) → ECG_CRNN | Tuple[ECG_CRNN, dict]
Restore an instance of the model from a v1 checkpoint.
- Parameters:
v1_ckpt (str) – Path to the v1 checkpoint file.
device (torch.device, optional) – The device to load the model to. Defaults to “cuda” if available, otherwise “cpu”.
return_config (bool, default False) – Whether to return the config dict.
- Returns:
model (ECG_CRNN) – The model instance restored from the v1 checkpoint.
config (dict) – The config dict. (if return_config is True)
- inference(input: ndarray | Tensor, class_names: bool = False, bin_pred_thr: float = 0.5) → BaseOutput
Inference method for the model.
- Parameters:
input (numpy.ndarray or torch.Tensor) – Input tensor, of shape (batch_size, channels, seq_len).
class_names (bool, default False) – If True, the returned scalar predictions will be a DataFrame, with class names for each scalar prediction.
bin_pred_thr (float, default 0.5) – Threshold for making binary predictions from scalar predictions.
- Returns:
output –
The output of the inference method, including the following items:
prob: numpy.ndarray or torch.Tensor, scalar predictions (and binary predictions if class_names is True).
pred: numpy.ndarray or torch.Tensor, array of binary predictions (0 or 1 for each class).
- Return type:
BaseOutput
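The role of bin_pred_thr and class_names can be illustrated with a small plain-Python sketch. It is not the library's code: the helpers `binarize` and `label_probs` are hypothetical, the use of `>=` (rather than `>`) at the threshold is an assumption, and plain lists and dicts stand in for the arrays and DataFrame that the method is documented to return.

```python
from typing import Dict, List, Sequence


def binarize(probs: Sequence[Sequence[float]],
             bin_pred_thr: float = 0.5) -> List[List[int]]:
    """Turn per-class scalar predictions into 0/1 predictions.

    Assumption: a class is predicted when its probability is >= the threshold.
    """
    return [[int(p >= bin_pred_thr) for p in row] for row in probs]


def label_probs(probs: Sequence[Sequence[float]],
                classes: Sequence[str]) -> List[Dict[str, float]]:
    """Attach class names to each scalar prediction (stand-in for a DataFrame)."""
    return [dict(zip(classes, row)) for row in probs]


if __name__ == "__main__":
    classes = ["AF", "PVC", "NSR"]       # hypothetical class names
    probs = [[0.91, 0.08, 0.35],         # one row per record in the batch
             [0.12, 0.66, 0.50]]
    print(binarize(probs, bin_pred_thr=0.5))
    print(label_probs(probs, classes))
```

Raising bin_pred_thr makes the binary predictions more conservative, while the scalar predictions themselves are unchanged.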