Source code for torch_ecg.models.unets.ecg_subtract_unet

"""
2nd place (entry 0433) of CPSC2019

the main differences to a normal Unet are that

1. at the bottom, subtraction (and concatenation) is used
2. uses triple convolutions at each block, instead of double convolutions
3. dropout is used between certain convolutional layers ("cba" layers indeed)

"""

import textwrap
import warnings
from copy import deepcopy
from itertools import repeat
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
from torch import Tensor, nn

from ...cfg import CFG
from ...model_configs import ECG_SUBTRACT_UNET_CONFIG
from ...models._nets import BranchedConv, Conv_Bn_Activation, DownSample, MultiConv
from ...utils.misc import add_docstring
from ...utils.utils_nn import (
    CkptMixin,
    SizeMixin,
    compute_deconv_output_shape,
    compute_sequential_output_shape,
    compute_sequential_output_shape_docstring,
)

__all__ = [
    "ECG_SUBTRACT_UNET",
]


class TripleConv(MultiConv):
    """Triple convolutional layer.

    CBA --> (Dropout) --> CBA --> (Dropout) --> CBA --> (Dropout).

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    out_channels : int or Sequence[int]
        Number of channels produced by the (last) convolutional layer(s).
    filter_lengths : int or Sequence[int]
        Length(s) of the filters (kernel size).
    subsample_lengths : int or Sequence[int], default 1
        Subsample length(s) (stride(s)) of the convolutions
    groups: int, default 1,
        connection pattern (of channels) of the inputs and outputs.
    dropouts : float or dict or Sequence[Union[float, dict]], default 0.0
        Dropout ratio after each :class:`Conv_Bn_Activation` block.
    out_activation : bool, default True
        If True, the last mini-block of :class:`Conv_Bn_Activation`
        will have activation as in `config`; otherwise, no activation.
    config : dict
        Other hyper-parameters, including
        activation choices, weight initializer, batch normalization choices, etc.
        for the convolutional layers.

    """

    __name__ = "TripleConv"

    def __init__(
        self,
        in_channels: int,
        out_channels: Union[Sequence[int], int],
        filter_lengths: Union[Sequence[int], int],
        subsample_lengths: Union[Sequence[int], int] = 1,
        groups: int = 1,
        dropouts: Union[Sequence[Union[float, dict]], float, dict] = 0.0,
        out_activation: bool = True,
        **config,
    ) -> None:
        _num_convs = 3
        if isinstance(out_channels, int):
            _out_channels = list(repeat(out_channels, _num_convs))
        else:
            _out_channels = list(out_channels)
            assert _num_convs == len(_out_channels)

        super().__init__(
            in_channels=in_channels,
            out_channels=_out_channels,
            filter_lengths=filter_lengths,
            subsample_lengths=subsample_lengths,
            groups=groups,
            dropouts=dropouts,
            out_activation=out_activation,
            **config,
        )


class DownTripleConv(nn.Sequential, SizeMixin):
    """Down sampling block of the U-net architecture.

    Composed of a down sampling layer and 3 convolutional layers.

    Parameters
    ----------
    down_scale : int
        Down sampling scale.
    in_channels : int
        Number of channels in the input.
    out_channels : int or Sequence[int]
        Number of channels produced by the (last) convolutional layer(s).
    filter_lengths : int or Sequence[int]
        Length(s) of the filters (kernel size).
    groups : int, default 1
        Connection pattern (of channels) of the inputs and outputs.
    dropouts : float or dict or Sequence[Union[float, dict]], default 0.0
        Dropout ratio after each :class:`Conv_Bn_Activation` block.
    mode : str, default "max"
        Down sampling mode, one of {:class:`DownSample`.__MODES__}.
    config : dict
        Other parameters, including
        activation choices, weight initializer, batch normalization choices, etc.
        for the convolutional layers.

    """

    __name__ = "DownTripleConv"
    __MODES__ = deepcopy(DownSample.__MODES__)

    def __init__(
        self,
        down_scale: int,
        in_channels: int,
        out_channels: Union[Sequence[int], int],
        filter_lengths: Union[Sequence[int], int],
        groups: int = 1,
        dropouts: Union[Sequence[Union[float, dict]], float, dict] = 0.0,
        mode: str = "max",
        **config,
    ) -> None:
        super().__init__()
        self.__mode = mode.lower()
        assert self.__mode in self.__MODES__
        self.__down_scale = down_scale
        self.__in_channels = in_channels
        self.__out_channels = out_channels
        self.config = CFG(deepcopy(config))

        self.add_module(
            "down_sample",
            DownSample(
                down_scale=self.__down_scale,
                in_channels=self.__in_channels,
                norm=False,
                mode=mode,
            ),
        )
        self.add_module(
            "triple_conv",
            TripleConv(
                in_channels=self.__in_channels,
                out_channels=self.__out_channels,
                filter_lengths=filter_lengths,
                subsample_lengths=1,
                groups=groups,
                dropouts=dropouts,
                **(self.config),
            ),
        )

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor,
            of shape ``(batch_size, channels, seq_len)``.

        Returns
        -------
        out : torch.Tensor
            Output tensor,
            of shape ``(batch_size, channels, seq_len)``.

        """
        out = super().forward(input)
        return out

    @add_docstring(
        textwrap.indent(compute_sequential_output_shape_docstring, " " * 4),
        mode="append",
    )
    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the down sampling block."""
        return compute_sequential_output_shape(self, seq_len, batch_size)


class DownBranchedDoubleConv(nn.Module, SizeMixin):
    """The bottom block of the encoder in the U-Net architecture.

    Parameters
    ----------
    down_scale : int
        Down sampling scale.
    in_channels : int
        Number of channels in the input tensor.
    out_channels : Sequence[Sequence[int]]
        Number of channels produced by the (last) convolutional layer(s).
    filter_lengths : int or Sequence[int] or Sequence[Sequence[int]]
        Length(s) of the filters (kernel size).
    dilations : int or Sequence[int] or Sequence[Sequence[int]], default 1
        Dilation(s) of the convolutions.
    groups : int, default 1
        Connection pattern (of channels) of the inputs and outputs.
    dropouts : float or dict or Sequence[Union[float, dict]], default 0.0
        Dropout ratio after each :class:`Conv_Bn_Activation` block.
    mode : str, default "max"
        Down sampling mode, one of {:class:`DownSample`.__MODES__}.
    config: dict
        Other hyper-parameters, including
        activation choices, weight initializer, batch normalization choices, etc.
        for the convolutional layers.

    """

    __name__ = "DownBranchedDoubleConv"
    __MODES__ = deepcopy(DownSample.__MODES__)

    def __init__(
        self,
        down_scale: int,
        in_channels: int,
        out_channels: Sequence[Sequence[int]],
        filter_lengths: Union[Sequence[Sequence[int]], Sequence[int], int],
        dilations: Union[Sequence[Sequence[int]], Sequence[int], int] = 1,
        groups: int = 1,
        dropouts: Union[Sequence[Union[float, dict]], float, dict] = 0.0,
        mode: str = "max",
        **config,
    ) -> None:
        super().__init__()
        self.__mode = mode.lower()
        assert self.__mode in self.__MODES__
        self.__down_scale = down_scale
        self.__in_channels = in_channels
        self.__out_channels = out_channels
        self.config = CFG(deepcopy(config))

        self.down_sample = DownSample(
            down_scale=self.__down_scale,
            in_channels=self.__in_channels,
            norm=False,
            mode=mode,
        )
        self.branched_conv = BranchedConv(
            in_channels=self.__in_channels,
            out_channels=self.__out_channels,
            filter_lengths=filter_lengths,
            subsample_lengths=1,
            dilations=dilations,
            groups=groups,
            dropouts=dropouts,
            **(self.config),
        )

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor,
            of shape ``(batch_size, channels, seq_len)``.

        Returns
        -------
        out : torch.Tensor
            Output tensor,
            of shape ``(batch_size, channels, seq_len)``.

        """
        out = self.down_sample(input)
        out = self.branched_conv(out)
        # SUBTRACT
        # currently (micro scope) - (macro scope)
        # TODO: consider (macro scope) - (micro scope)
        out.append(out[0] - out[1])
        out = torch.cat(out, dim=1)  # concate along the channel axis
        return out

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the block.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input tensor.
        batch_size : int, optional
            Batch size of the input tensor.

        Returns
        -------
        output_shape : sequence
            Output shape of the block.

        """
        _seq_len = seq_len
        output_shape = self.down_sample.compute_output_shape(seq_len=_seq_len)
        _, _, _seq_len = output_shape
        output_shapes = self.branched_conv.compute_output_shape(seq_len=_seq_len)
        # output_shape = output_shapes[0][0], sum([s[1] for s in output_shapes]), output_shapes[0][-1]
        n_branches = len(output_shapes)
        output_shape = (
            output_shapes[0][0],
            (n_branches + 1) * output_shapes[0][1],
            output_shapes[0][-1],
        )
        return output_shape


class UpTripleConv(nn.Module, SizeMixin):
    """Up sampling block in the U-Net architecture.

    Upscaling then double conv, with input of corr. down layer concatenated
    up sampling --> conv (conv --> (dropout -->) conv --> (dropout -->) conv)
        ^
        |
    extra input

    Channels are shrinked after up sampling.

    Parameters
    ----------
    up_scale : int
        Scale of up sampling.
    in_channels : int
        Number of channels in the input tensor.
    out_channels : int
        Number of channels produced by the convolutional layers.
    filter_lengths : int or Sequence[int]
        Length(s) of the filters (kernel size) of the convolutional layers.
    deconv_filter_length : int
        Length(s) of the filters (kernel size) of the
        deconvolutional upsampling layer,
        used when `mode` is ``"deconv"``.
    groups : int, default 1
        Connection pattern (of channels) of the inputs and outputs.
        Not used currently.
    dropouts : float or dict or Sequence[Union[float, dict]], default 0.0
        Dropout ratio after each :class:`Conv_Bn_Activation` block.
    mode : str, default "deconv"
        Mode of up sampling, case insensitive. Should be ``"deconv"``
        or methods supported by :class:`torch.nn.Upsample`.
    config : dict
        Other hyper-parameters, including
        activation choices, weight initializer, batch normalization choices, etc.
        for the deconvolutional layers.

    """

    __name__ = "UpTripleConv"
    __MODES__ = [
        "nearest",
        "linear",
        "area",
        "deconv",
    ]

    def __init__(
        self,
        up_scale: int,
        in_channels: int,
        out_channels: int,
        filter_lengths: Union[Sequence[int], int],
        deconv_filter_length: Optional[int] = None,
        groups: int = 1,
        dropouts: Union[Sequence[Union[float, dict]], float, dict] = 0.0,
        mode: str = "deconv",
        **config,
    ) -> None:
        super().__init__()
        self.__up_scale = up_scale
        self.__in_channels = in_channels
        self.__out_channels = out_channels
        self.__deconv_filter_length = deconv_filter_length
        self.__mode = mode.lower()
        assert self.__mode in self.__MODES__
        self.config = CFG(deepcopy(config))

        # the following has to be checked
        # if bilinear, use the normal convolutions to reduce the number of channels
        if self.__mode == "deconv":
            self.__deconv_padding = max(0, (self.__deconv_filter_length - self.__up_scale) // 2)
            self.up = nn.ConvTranspose1d(
                in_channels=self.__in_channels,
                out_channels=self.__in_channels,
                kernel_size=self.__deconv_filter_length,
                stride=self.__up_scale,
                padding=self.__deconv_padding,
            )
        else:
            self.up = nn.Upsample(
                scale_factor=self.__up_scale,
                mode=mode,
            )
        self.conv = TripleConv(
            # `+ self.__out_channels` corr. to the output of the corr. down layer
            in_channels=self.__in_channels + self.__out_channels[-1],
            out_channels=self.__out_channels,
            filter_lengths=filter_lengths,
            subsample_lengths=1,
            groups=groups,
            dropouts=dropouts,
            **(self.config),
        )

    def forward(self, input: Tensor, down_output: Tensor) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor from the previous up sampling block.
        down_output : torch.Tensor
            Input tensor of the last layer of corr. down sampling block.

        Returns
        -------
        output : torch.Tensor
            Output tensor of this block.

        """
        output = self.up(input)
        output = torch.cat([down_output, output], dim=1)  # concate along the channel axis
        output = self.conv(output)

        return output

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the block.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input tensor.
        batch_size : int, optional
            Batch size of the input tensor.

        Returns
        -------
        output_shape : sequence
            Output shape of the block.

        """
        _sep_len = seq_len
        if self.__mode == "deconv":
            output_shape = compute_deconv_output_shape(
                input_shape=[batch_size, self.__in_channels, _sep_len],
                num_filters=self.__in_channels,
                kernel_size=self.__deconv_filter_length,
                stride=self.__up_scale,
                padding=self.__deconv_padding,
            )
        else:
            output_shape = [batch_size, self.__in_channels, self.__up_scale * _sep_len]
        _, _, _seq_len = output_shape
        output_shape = self.conv.compute_output_shape(_seq_len, batch_size)
        return output_shape


[docs]class ECG_SUBTRACT_UNET(nn.Module, CkptMixin, SizeMixin): """U-Net for ECG wave delineation. Entry 0433 of CPSC2019, which is a modification of the U-Net using subtraction instead of addition in branched bottom block. Parameters ---------- classes : Sequence[str] List of names of the classes. n_leads : int Number of input leads (number of input channels). config : CFG, optional Other hyper-parameters, including kernel sizes, etc. Refer to the corresponding config file. """ __name__ = "ECG_SUBTRACT_UNET" def __init__( self, classes: Sequence[str], n_leads: int, config: Optional[CFG] = None, ) -> None: super().__init__() self.classes = list(classes) self.n_classes = len(classes) self.__out_channels = len(classes) self.__in_channels = n_leads self.config = deepcopy(ECG_SUBTRACT_UNET_CONFIG) if not config: warnings.warn("No config is provided, using default config.", RuntimeWarning) self.config.update(deepcopy(config) or {}) # TODO: an init batch normalization? if self.config.init_batch_norm: self.init_bn = nn.BatchNorm1d( num_features=self.__in_channels, eps=1e-5, # default val momentum=0.1, # default val ) self.init_conv = TripleConv( in_channels=self.__in_channels, out_channels=self.config.init_num_filters, filter_lengths=self.config.init_filter_length, subsample_lengths=1, groups=self.config.groups, dropouts=self.config.init_dropouts, batch_norm=self.config.batch_norm, activation=self.config.activation, kw_activation=self.config.kw_activation, kernel_initializer=self.config.kernel_initializer, kw_initializer=self.config.kw_initializer, ) self.down_blocks = nn.ModuleDict() in_channels = self.config.init_num_filters for idx in range(self.config.down_up_block_num - 1): self.down_blocks[f"down_{idx}"] = DownTripleConv( down_scale=self.config.down_scales[idx], in_channels=in_channels, out_channels=self.config.down_num_filters[idx], filter_lengths=self.config.down_filter_lengths[idx], groups=self.config.groups, dropouts=self.config.down_dropouts[idx], mode=self.config.down_mode, **(self.config.down_block), ) in_channels = self.config.down_num_filters[idx][-1] self.bottom_block = DownBranchedDoubleConv( down_scale=self.config.down_scales[-1], in_channels=in_channels, out_channels=self.config.bottom_num_filters, filter_lengths=self.config.bottom_filter_lengths, dilations=self.config.bottom_dilations, groups=self.config.groups, dropouts=self.config.bottom_dropouts, mode=self.config.down_mode, **(self.config.down_block), ) self.up_blocks = nn.ModuleDict() # in_channels = sum([branch[-1] for branch in self.config.bottom_num_filters]) in_channels = self.bottom_block.compute_output_shape(None, None)[1] for idx in range(self.config.down_up_block_num): self.up_blocks[f"up_{idx}"] = UpTripleConv( up_scale=self.config.up_scales[idx], in_channels=in_channels, out_channels=self.config.up_num_filters[idx], filter_lengths=self.config.up_conv_filter_lengths[idx], deconv_filter_length=self.config.up_deconv_filter_lengths[idx], groups=self.config.groups, mode=self.config.up_mode, dropouts=self.config.up_dropouts[idx], **(self.config.up_block), ) in_channels = self.config.up_num_filters[idx][-1] self.out_conv = Conv_Bn_Activation( in_channels=self.config.up_num_filters[-1][-1], out_channels=self.__out_channels, kernel_size=self.config.out_filter_length, stride=1, groups=self.config.groups, norm=self.config.get("out_norm", self.config.get("out_batch_norm")), activation=None, kernel_initializer=self.config.kernel_initializer, kw_initializer=self.config.kw_initializer, ) # for inference # if background counted in `classes`, use softmax # otherwise use sigmoid self.softmax = nn.Softmax(-1) self.sigmoid = nn.Sigmoid()
[docs] def forward(self, input: Tensor) -> Tensor: """Forward pass of the model. Parameters ---------- input : torch.Tensor Input signal tensor, of shape ``(batch_size, n_channels, seq_len)``. Returns ------- output : torch.Tensor Output tensor, of shape ``(batch_size, n_channels, seq_len)``. """ if self.config.init_batch_norm: x = self.init_bn(input) else: x = input # down to_concat = [self.init_conv(x)] for idx in range(self.config.down_up_block_num - 1): to_concat.append(self.down_blocks[f"down_{idx}"](to_concat[-1])) to_concat.append(self.bottom_block(to_concat[-1])) # up up_input = to_concat[-1] to_concat = to_concat[-2::-1] for idx in range(self.config.down_up_block_num): up_output = self.up_blocks[f"up_{idx}"](up_input, to_concat[idx]) up_input = up_output # output output = self.out_conv(up_output) # to keep in accordance with other models # (batch_size, channels, seq_len) --> (batch_size, seq_len, channels) output = output.permute(0, 2, 1) # TODO: consider adding CRF at the tail to make final prediction return output
[docs] @torch.no_grad() def inference(self, input: Union[np.ndarray, Tensor], bin_pred_thr: float = 0.5) -> Tensor: """Method for making inference on a single input.""" raise NotImplementedError("implement a task specific inference method")
[docs] def compute_output_shape( self, seq_len: Optional[int] = None, batch_size: Optional[int] = None ) -> Sequence[Union[int, None]]: """Compute the output shape of the model. Parameters ---------- seq_len : int, optional Length of the input signal tensor. batch_size : int, optional Batch size of the input signal tensor. Returns ------- output_shape : sequence Output shape of the model. """ output_shape = (batch_size, seq_len, self.n_classes) return output_shape
@property def doi(self) -> List[str]: # TODO: add doi return list(set(self.config.get("doi", []) + []))