"""
"""
import warnings
from copy import deepcopy
from typing import List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from ..cfg import CFG
from ..model_configs.ecg_seq_lab_net import ECG_SEQ_LAB_NET_CONFIG
from ..utils import add_docstring
from .ecg_crnn import ECG_CRNN, ECG_CRNN_v1
__all__ = [
"ECG_SEQ_LAB_NET",
]
class ECG_SEQ_LAB_NET(ECG_CRNN):
"""SOTA model from CPSC2019 challenge.
Sequence labeling nets, for wave delineation, QRS complex detection, etc.
Proposed in [:footcite:ct:`cai2020rpeak_seq_lab_net`].
pipeline
--------
(multi-scopic, etc.) cnn --> head ((bidi-lstm -->) "attention" --> seq linear) -> output
Parameters
----------
classes : List[str]
List of the classes for sequence labeling.
n_leads : int
Number of leads (number of input channels).
config : dict, optional
Other hyper-parameters, including kernel sizes, etc.
Refer to corresponding config file.
.. footbibliography::
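
    Examples
    --------
    A minimal usage sketch; the class list, lead count, and tensor sizes
    below are illustrative only:

    >>> import torch
    >>> model = ECG_SEQ_LAB_NET(classes=["qrs"], n_leads=12)
    >>> sig = torch.randn(2, 12, 4000)  # (batch_size, n_leads, seq_len)
    >>> pred = model(sig)  # (batch_size, reduced_seq_len, n_classes)
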
"""
__name__ = "ECG_SEQ_LAB_NET"
__DEFAULT_CONFIG__ = {"recover_length": False}
__DEFAULT_CONFIG__.update(deepcopy(ECG_SEQ_LAB_NET_CONFIG))
def __init__(self, classes: Sequence[str], n_leads: int, config: Optional[CFG] = None) -> None:
_config = CFG(deepcopy(self.__DEFAULT_CONFIG__))
if not config:
warnings.warn("No config is provided, using default config.", RuntimeWarning)
_config.update(deepcopy(config) or {})
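        # sequence labeling needs per-timestep predictions, hence global pooling is disabled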
_config.global_pool = "none"
super().__init__(classes, n_leads, _config)
    def forward(self, input: Tensor) -> Tensor:
"""Forward pass.
Parameters
----------
input : torch.Tensor
Input tensor,
of shape ``(batch_size, channels, seq_len)``.
Returns
-------
pred : torch.Tensor
Output tensor,
of shape ``(batch_size, seq_len, n_classes)``
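
        Examples
        --------
        Continuing the sketch from the class docstring: with
        ``recover_length`` enabled, the output length matches the input.

        >>> model.config.recover_length = True
        >>> model(sig).shape[1] == sig.shape[-1]
        True
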
"""
batch_size, channels, seq_len = input.shape
pred = super().forward(input) # (batch_size, len, n_classes)
if self.config.recover_length:
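            # linearly interpolate along the time axis to recover the input sequence length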
pred = F.interpolate(
pred.permute(0, 2, 1),
size=seq_len,
mode="linear",
align_corners=True,
).permute(0, 2, 1)
return pred
    def compute_output_shape(
self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
"""Compute the output shape of the model.
Parameters
----------
seq_len : int, optional
Length of the 1d input signal tensor.
batch_size : int, optional
Batch size of the input signal tensor.
Returns
-------
output_shape : sequence
The output shape of the model.
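
        Examples
        --------
        With the sketch model from the class docstring (``recover_length``
        enabled, a single class), illustrative values only:

        >>> model.compute_output_shape(seq_len=4000, batch_size=2)
        (2, 4000, 1)
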
"""
output_shape = super().compute_output_shape(seq_len, batch_size)
if self.config.recover_length:
output_shape = (batch_size, seq_len, output_shape[-1])
return output_shape
    @classmethod
def from_v1(cls, v1_ckpt: str, device: Optional[torch.device] = None) -> "ECG_SEQ_LAB_NET":
"""Convert the v1 model to the current version.
Parameters
----------
v1_ckpt : str
Path to the v1 checkpoint file.
Returns
-------
model : ECG_SEQ_LAB_NET
The converted model.
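
        Examples
        --------
        A hedged sketch; the checkpoint path is hypothetical:

        >>> model = ECG_SEQ_LAB_NET.from_v1("/path/to/v1_ckpt.pth.tar")
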
"""
v1_model, _ = ECG_SEQ_LAB_NET_v1.from_checkpoint(v1_ckpt, device=device)
model = cls(classes=v1_model.classes, n_leads=v1_model.n_leads, config=v1_model.config)
model = model.to(v1_model.device)
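        # transfer weights module by module; rnn and attn may be Identity placeholders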
model.cnn.load_state_dict(v1_model.cnn.state_dict())
if model.rnn.__class__.__name__ != "Identity":
model.rnn.load_state_dict(v1_model.rnn.state_dict())
if model.attn.__class__.__name__ != "Identity":
model.attn.load_state_dict(v1_model.attn.state_dict())
model.clf.load_state_dict(v1_model.clf.state_dict())
del v1_model
return model
@property
def doi(self) -> List[str]:
return list(set(super().doi + ["10.1109/access.2020.2997473"]))
@add_docstring(ECG_SEQ_LAB_NET.__doc__)
class ECG_SEQ_LAB_NET_v1(ECG_CRNN_v1):
__name__ = "ECG_SEQ_LAB_NET_v1"
__DEFAULT_CONFIG__ = {"recover_length": False}
__DEFAULT_CONFIG__.update(deepcopy(ECG_SEQ_LAB_NET_CONFIG))
def __init__(self, classes: Sequence[str], n_leads: int, config: Optional[CFG] = None) -> None:
_config = CFG(deepcopy(self.__DEFAULT_CONFIG__))
if not config:
warnings.warn("No config is provided, using default config.", RuntimeWarning)
_config.update(deepcopy(config) or {})
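        # as in ECG_SEQ_LAB_NET, disable global pooling to keep per-timestep predictions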
_config.global_pool = "none"
super().__init__(classes, n_leads, _config)
def extract_features(self, input: Tensor) -> Tensor:
"""Extract feature map before the dense (linear) classifying layer(s).
Parameters
----------
input : torch.Tensor
Input tensor,
of shape ``(batch_size, channels, seq_len)``.
Returns
-------
features : torch.Tensor
Feature map tensor,
of shape ``(batch_size, seq_len, channels)``.
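
        Examples
        --------
        A hedged sketch; tensor sizes are illustrative:

        >>> v1_model = ECG_SEQ_LAB_NET_v1(classes=["qrs"], n_leads=12)
        >>> feats = v1_model.extract_features(torch.randn(2, 12, 4000))
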
"""
# cnn
cnn_output = self.cnn(input) # (batch_size, channels, seq_len)
# rnn or none
if self.rnn:
rnn_output = cnn_output.permute(2, 0, 1) # (seq_len, batch_size, channels)
rnn_output = self.rnn(rnn_output) # (seq_len, batch_size, channels)
rnn_output = rnn_output.permute(1, 2, 0) # (batch_size, channels, seq_len)
else:
rnn_output = cnn_output
# attention
if self.attn:
features = self.attn(rnn_output) # (batch_size, channels, seq_len)
else:
features = rnn_output
# features = features.permute(0, 2, 1) # (batch_size, seq_len, channels)
return features
@add_docstring(ECG_SEQ_LAB_NET.forward.__doc__)
def forward(self, input: Tensor) -> Tensor:
batch_size, channels, seq_len = input.shape
pred = super().forward(input) # (batch_size, len, n_classes)
if self.config.recover_length:
pred = F.interpolate(
pred.permute(0, 2, 1),
size=seq_len,
mode="linear",
align_corners=True,
).permute(0, 2, 1)
return pred
@property
def doi(self) -> List[str]:
return list(set(super().doi + ["10.1109/access.2020.2997473"]))