Source code for torch_ecg.models.cnn.xception

"""
famous for its use of separable convolutions,
usually the SOTA image classifier,
however seems not have been used in physiological signal processing tasks
"""

import textwrap
from copy import deepcopy
from itertools import repeat
from numbers import Real
from typing import List, Optional, Sequence, Union

from torch import Tensor, nn

from ...cfg import CFG
from ...models._nets import DownSample, MultiConv
from ...utils.misc import CitationMixin, add_docstring
from ...utils.utils_nn import SizeMixin, compute_sequential_output_shape, compute_sequential_output_shape_docstring

__all__ = [
    "Xception",
    "XceptionEntryFlow",
    "XceptionMiddleFlow",
    "XceptionExitFlow",
    "XceptionMultiConv",
]


if not hasattr(nn, "Dropout1d"):
    nn.Dropout1d = nn.Dropout  # added in pytorch 1.12


_DEFAULT_CONV_CONFIGS = CFG(
    ordering="acb",
    conv_type="separable",
    batch_norm=True,
    subsample_mode="max",
    activation="relu",
    kw_activation={"inplace": True},
    kernel_initializer="he_normal",
    kw_initializer={},
)


class XceptionMultiConv(nn.Module, SizeMixin, CitationMixin):
    """Xception multi-convolutional block.

    -> n(2 or 3) x (activation -> norm -> sep_conv) (-> optional sub-sample) ->
    |-------------------------------- shortcut ------------------------------|

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    num_filters : Sequence[int]
        Number of channels produced by the main stream convolutions.
        The length of `num_filters` indicates the number of convolutions.
    filter_lengths : int or Sequence[int]
        Length(s) of the filters (kernel size).
    subsample_length : int, default 1
        Stride of the main stream subsample layer.
    subsample_kernel : int, optional
        Kernel size of the main stream subsample layer.
        If not set, defaults to `subsample_length`.
    dilations : int or Sequence[int], default 1
        Dilation(s) of the convolutions.
    groups : int, default 1
        Connection pattern (of channels) of the inputs and outputs.
    dropouts : float or dict or Sequence[float] or Sequence[dict], default 0.0
        If is dictionary, it should contain the keys ``"p"`` and ``"type"``,
        where ``"p"`` is the dropout rate and ``"type"`` is the type of dropout,
        which can be either ``"1d"`` (:class:`torch.nn.Dropout1d`) or
        ``None`` (:class:`torch.nn.Dropout`).
    config : dict, optional
        Other parameters, including
        activation choices, weight initializer, batch normalization choices, etc.,
        for the convolutional layers,
        and subsampling modes for subsampling layers, etc.

    """

    __name__ = "XceptionMultiConv"

    def __init__(
        self,
        in_channels: int,
        num_filters: Sequence[int],
        filter_lengths: Union[Sequence[int], int],
        subsample_length: int = 1,
        subsample_kernel: Optional[int] = None,
        dilations: Union[Sequence[int], int] = 1,
        groups: int = 1,
        dropouts: Union[Sequence[float], float] = 0.0,
        **config,
    ) -> None:
        super().__init__()
        self.__in_channels = in_channels
        self.__num_filters = list(num_filters)
        self.__num_convs = len(self.__num_filters)
        if isinstance(filter_lengths, int):
            self.__filter_lengths = list(repeat(filter_lengths, self.__num_convs))
        else:
            self.__filter_lengths = list(filter_lengths)
        assert self.__num_convs == len(self.__filter_lengths), (
            f"the main stream has {self.__num_convs} convolutions, "
            f"while `filter_lengths` indicates {len(self.__filter_lengths)}"
        )
        self.__subsample_length = subsample_length
        self.__subsample_kernel = subsample_kernel or subsample_length
        self.__groups = groups
        if isinstance(dilations, int):
            self.__dilations = list(repeat(dilations, self.__num_convs))
        else:
            self.__dilations = list(dilations)
        assert self.__num_convs == len(self.__dilations), (
            f"the main stream has {self.__num_convs} convolutions, " f"while `dilations` indicates {len(self.__dilations)}"
        )
        if isinstance(dropouts, (Real, dict)):
            self.__dropouts = list(repeat(dropouts, self.__num_convs))
        else:
            self.__dropouts = list(dropouts)
        assert self.__num_convs == len(self.__dropouts), (
            f"the main stream has {self.__num_convs} convolutions, " f"while `dropouts` indicates {len(self.__dropouts)}"
        )
        self.config = CFG(deepcopy(_DEFAULT_CONV_CONFIGS))
        self.config.update(deepcopy(config))

        self.main_stream_conv = MultiConv(
            in_channels=self.__in_channels,
            out_channels=self.__num_filters,
            filter_lengths=self.__filter_lengths,
            subsample_lengths=1,
            dilations=self.__dilations,
            groups=self.__groups,
            dropouts=self.__dropouts,
            **self.config,
        )
        if self.__subsample_length > 1:
            self.subsample = DownSample(
                down_scale=self.__subsample_length,
                in_channels=self.__num_filters[-1],
                kernel_size=self.__subsample_kernel,
                groups=self.__groups,
                padding=(self.__subsample_kernel - 1) // 2,
                mode=self.config.subsample_mode,
            )
            self.shortcut = DownSample(
                down_scale=self.__subsample_length,
                in_channels=self.__in_channels,
                out_channels=self.__num_filters[-1],
                groups=self.__groups,
                kernel_size=1,
                norm=self.config.get("norm", self.config.get("batch_norm")),
                mode="conv",
            )
        else:
            self.subsample = nn.Identity()
            self.shortcut = nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor,
            of shape ``(batch_size, n_channels, seq_len)``.

        Returns
        -------
        output : torch.Tensor
            Output tensor,
            of shape ``(batch_size, n_channels, seq_len)``.

        """
        main_out = self.main_stream_conv(input)
        main_out = self.subsample(main_out)
        residue = input
        residue = self.shortcut(residue)
        output = residue + main_out
        return output

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the module.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input tensor.
        batch_size : int, optional
            Batch size of the input tensor.

        Returns
        -------
        output_shape : sequence
            The output shape of the module.

        """
        output_shape = self.main_stream_conv.compute_output_shape(seq_len, batch_size)
        if not isinstance(self.subsample, nn.Identity):
            output_shape = self.subsample.compute_output_shape(output_shape[-1], output_shape[0])
        return output_shape


class XceptionEntryFlow(nn.Sequential, SizeMixin):
    """Entry flow of the Xception model.

    The entry flow is consisting of
    2 initial convolutions which subsamples at the first one,
    followed by several Xception blocks of 2 convolutions
    and of sub-sampling size 2.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input signal.
    init_num_filters : Sequence[int]
        Number of filters (output channels) of the initial convolutions.
    init_filter_lengths : int or Sequence[int]
        Filter length(s) (kernel size(s)) of the initial convolutions.
    init_subsample_lengths : int or Sequence[int]
        Subsampling length(s) (stride(s)) of the initial convolutions.
    num_filters : Sequence[int] or Sequence[Sequence[int]]
        Number of filters of the convolutions of Xception blocks.
    filter_lengths : int or Sequence[int] or Sequence[Sequence[int]]
        Filter length(s) of the convolutions of Xception blocks.
    subsample_lengths : int or Sequence[int]
        Subsampling length(s) of the Xception blocks.
    subsample_kernels : int or Sequence[int], optional
        Subsampling kernel size(s) of the Xception blocks.
    dilations : int or Sequence[int] or Sequence[Sequence[int]], default 1
        Dilation(s) of the convolutions of Xception blocks.
    groups : int, default 1
        Connection pattern (of channels) of the inputs and outputs.
    dropouts : float or dict or Sequence[Union[float, dict]] or Sequence[Sequence[Union[float, dict]]], default 0.0
        Dropout(s) after each :class:`Conv_Bn_Activation` blocks
        in the Xception blocks.
    block_dropouts : float or dict or Sequence[Union[float, dict]], default 0.0
        Dropout(s) after the Xception blocks.
    config : dict, optional
        Other parameters, Xception blocks and initial convolutions, including
        activation choices, weight initializer, batch normalization choices, etc.
        for the convolutional layers,
        and subsampling modes for subsampling layers, etc.

    """

    __name__ = "XceptionEntryFlow"

    def __init__(
        self,
        in_channels: int,
        init_num_filters: Sequence[int],
        init_filter_lengths: Union[int, Sequence[int]],
        init_subsample_lengths: Union[int, Sequence[int]],
        num_filters: Union[Sequence[int], Sequence[Sequence[int]]],
        filter_lengths: Union[int, Sequence[int], Sequence[Sequence[int]]],
        subsample_lengths: Union[int, Sequence[int]],
        subsample_kernels: Optional[Union[int, Sequence[int]]] = None,
        dilations: Union[int, Sequence[int], Sequence[Sequence[int]]] = 1,
        groups: int = 1,
        dropouts: Union[
            float,
            dict,
            Sequence[Union[float, dict]],
            Sequence[Sequence[Union[float, dict]]],
        ] = 0.0,
        block_dropouts: Union[float, dict, Sequence[Union[float, dict]]] = 0.0,
        **config,
    ) -> None:
        super().__init__()
        self.__in_channels = in_channels
        self.__num_filters = list(num_filters)
        self.__num_blocks = len(self.__num_filters)
        if isinstance(filter_lengths, int):
            self.__filter_lengths = list(repeat(filter_lengths, self.__num_blocks))
        else:
            self.__filter_lengths = list(filter_lengths)
        assert self.__num_blocks == len(self.__filter_lengths), (
            f"the entry flow has {self.__num_blocks} blocks, " f"while `filter_lengths` indicates {len(self.__filter_lengths)}"
        )
        if isinstance(subsample_lengths, int):
            self.__subsample_lengths = list(repeat(subsample_lengths, self.__num_blocks))
        else:
            self.__subsample_lengths = list(subsample_lengths)
        assert self.__num_blocks == len(self.__subsample_lengths), (
            f"the entry flow has {self.__num_blocks} blocks, "
            f"while `subsample_lengths` indicates {len(self.__subsample_lengths)}"
        )
        if subsample_kernels is None:
            self.__subsample_kernels = deepcopy(self.__subsample_lengths)
        elif isinstance(subsample_kernels, int):
            self.__subsample_kernels = list(repeat(subsample_kernels, self.__num_blocks))
        else:
            self.__subsample_kernels = list(subsample_kernels)
        assert self.__num_blocks == len(self.__subsample_kernels), (
            f"the entry flow has {self.__num_blocks} blocks, "
            f"while `subsample_kernels` indicates {len(self.__subsample_kernels)}"
        )
        if isinstance(dilations, int):
            self.__dilations = list(repeat(dilations, self.__num_blocks))
        else:
            self.__dilations = list(dilations)
        assert self.__num_blocks == len(self.__dilations), (
            f"the entry flow has {self.__num_blocks} blocks, " f"while `dilations` indicates {len(self.__dilations)}"
        )
        if isinstance(dropouts, (Real, dict)):
            self.__dropouts = list(repeat(dropouts, self.__num_blocks))
        else:
            self.__dropouts = list(dropouts)
        assert self.__num_blocks == len(self.__dropouts), (
            f"the entry flow has {self.__num_blocks} blocks, " f"while `dropouts` indicates {len(self.__dropouts)}"
        )
        if isinstance(block_dropouts, (Real, dict)):
            self.__block_dropouts = list(repeat(block_dropouts, self.__num_blocks))
        else:
            self.__block_dropouts = list(block_dropouts)
        assert self.__num_blocks == len(self.__block_dropouts), (
            f"the entry flow has {self.__num_blocks} blocks, except the initial convolutions, "
            f"while `block_dropouts` indicates {len(self.__block_dropouts)}"
        )
        self.__groups = groups
        self.config = CFG(deepcopy(_DEFAULT_CONV_CONFIGS))
        self.config.update(deepcopy(config))

        self.add_module(
            "init_convs",
            MultiConv(
                in_channels=self.__in_channels,
                out_channels=init_num_filters,
                filter_lengths=init_filter_lengths,
                subsample_lengths=init_subsample_lengths,
                groups=groups,
                activation=self.config.activation,
            ),
        )

        block_in_channels = init_num_filters[-1]
        for idx, nf in enumerate(self.__num_filters):
            # in the case of ordering of "acb",
            # `out_activation` is indeed `in_activation`
            if idx == 0:
                out_activation = False
            else:
                out_activation = True
            # number of main stream convolution defaults to 2
            if isinstance(nf, int):
                block_out_channels = list(repeat(nf, 2))
            else:
                block_out_channels = list(nf)
            self.add_module(
                f"entry_flow_conv_block_{idx}",
                XceptionMultiConv(
                    in_channels=block_in_channels,
                    num_filters=block_out_channels,
                    filter_lengths=self.__filter_lengths[idx],
                    subsample_length=self.__subsample_lengths[idx],
                    subsample_kernel=self.__subsample_kernels[idx],
                    dilations=self.__dilations[idx],
                    groups=self.__groups,
                    dropouts=self.__dropouts[idx],
                    out_activation=out_activation,
                    **self.config,
                ),
            )
            block_in_channels = block_out_channels[-1]
            if isinstance(self.__block_dropouts[idx], dict):
                if self.__block_dropouts[idx]["type"] == "1d" and self.__block_dropouts[idx]["p"] > 0:
                    self.add_module(
                        f"entry_flow_dropout_{idx}",
                        nn.Dropout1d(self.__block_dropouts[idx]["p"]),
                    )
                elif self.__block_dropouts[idx]["type"] is None and self.__block_dropouts[idx]["p"] > 0:
                    self.add_module(
                        f"entry_flow_dropout_{idx}",
                        nn.Dropout(self.__block_dropouts[idx]["p"]),
                    )
            elif self.__block_dropouts[idx] > 0:
                self.add_module(f"entry_flow_dropout_{idx}", nn.Dropout(self.__block_dropouts[idx]))

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass of the entry flow.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor,
            of shape ``(batch_size, n_channels, seq_len)``.

        Returns
        -------
        output : torch.Tensor
            Output tensor,
            of shape ``(batch_size, n_channels, seq_len)``.

        """
        output = super().forward(input)
        return output

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the entry flow.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input tensors.
        batch_size : int, optional
            Batch size of the input tensors.

        Returns
        -------
        output_shape : sequence
            The output shape of the entry flow.

        """
        _seq_len = seq_len
        for module in self:
            if isinstance(module, (nn.Dropout, nn.Dropout1d)):
                continue
            output_shape = module.compute_output_shape(_seq_len, batch_size)
            _, _, _seq_len = output_shape
        return output_shape


class XceptionMiddleFlow(nn.Sequential, SizeMixin):
    """Middle flow of the Xception model.

    Middle flow consists of several Xception blocks of 3 convolutions
    and without sub-sampling.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    num_filters: Sequence[int] or Sequence[Sequence[int]]
        Number of filters (output channels) of the convolutions
        of Xception blocks.
    filter_lengths : int or Sequence[int] or Sequence[Sequence[int]]
        Filter length(s) of the convolutions of Xception blocks.
    dilations : int or Sequence[int] or Sequence[Sequence[int]], default 1
        Dilation(s) of the convolutions of Xception blocks.
    groups : int, default 1
        Connection pattern (of channels) of the inputs and outputs.
    dropouts : float or dict or Sequence[Union[float, dict]] or Sequence[Sequence[Union[float, dict]]], default 0.0
        Dropout(s) after each :class:`Conv_Bn_Activation` blocks
        in the Xception blocks.
    block_dropouts : float or dict or Sequence[Union[float, dict]], default 0.0
        Dropout(s) after the Xception blocks
    config : dict, optional
        Other parameters for Xception blocks, including
        activation choices, weight initializer, batch normalization choices, etc.
        for the convolutional layers,
        and subsampling modes for subsampling layers, etc.

    """

    __name__ = "XceptionMiddleFlow"

    def __init__(
        self,
        in_channels: int,
        num_filters: Union[Sequence[int], Sequence[Sequence[int]]],
        filter_lengths: Union[int, Sequence[int], Sequence[Sequence[int]]],
        dilations: Union[int, Sequence[int], Sequence[Sequence[int]]] = 1,
        groups: int = 1,
        dropouts: Union[
            float,
            dict,
            Sequence[Union[float, dict]],
            Sequence[Sequence[Union[float, dict]]],
        ] = 0.0,
        block_dropouts: Union[float, dict, Sequence[Union[float, dict]]] = 0.0,
        **config,
    ) -> None:
        super().__init__()
        self.__in_channels = in_channels
        self.__num_filters = list(num_filters)
        self.__num_blocks = len(self.__num_filters)
        if isinstance(filter_lengths, int):
            self.__filter_lengths = list(repeat(filter_lengths, self.__num_blocks))
        else:
            self.__filter_lengths = list(filter_lengths)
        assert self.__num_blocks == len(self.__filter_lengths), (
            f"the middle flow has {self.__num_blocks} blocks, " f"while `filter_lengths` indicates {len(self.__filter_lengths)}"
        )
        if isinstance(dilations, int):
            self.__dilations = list(repeat(dilations, self.__num_blocks))
        else:
            self.__dilations = list(dilations)
        assert self.__num_blocks == len(self.__dilations), (
            f"the middle flow has {self.__num_blocks} blocks, " f"while `dilations` indicates {len(self.__dilations)}"
        )
        if isinstance(dropouts, (Real, dict)):
            self.__dropouts = list(repeat(dropouts, self.__num_blocks))
        else:
            self.__dropouts = list(dropouts)
        assert self.__num_blocks == len(self.__dropouts), (
            f"the middle flow has {self.__num_blocks} blocks, " f"while `dropouts` indicates {len(self.__dropouts)}"
        )
        if isinstance(block_dropouts, (Real, dict)):
            self.__block_dropouts = list(repeat(block_dropouts, self.__num_blocks))
        else:
            self.__block_dropouts = list(block_dropouts)
        assert self.__num_blocks == len(self.__block_dropouts), (
            f"the middle flow has {self.__num_blocks} blocks, " f"while `block_dropouts` indicates {len(self.__block_dropouts)}"
        )
        self.__groups = groups
        self.config = CFG(deepcopy(_DEFAULT_CONV_CONFIGS))
        self.config.update(deepcopy(config))

        block_in_channels = self.__in_channels
        for idx, nf in enumerate(self.__num_filters):
            # number of main stream convolution defaults to 3
            if isinstance(nf, int):
                block_out_channels = list(repeat(nf, 3))
            else:
                block_out_channels = list(nf)
            self.add_module(
                f"middle_flow_conv_block_{idx}",
                XceptionMultiConv(
                    in_channels=block_in_channels,
                    num_filters=block_out_channels,
                    filter_lengths=self.__filter_lengths[idx],
                    dilations=self.__dilations[idx],
                    groups=self.__groups,
                    dropouts=self.__dropouts[idx],
                    **self.config,
                ),
            )
            block_in_channels = block_out_channels[-1]
            if isinstance(self.__block_dropouts[idx], dict):
                if self.__block_dropouts[idx]["type"] == "1d" and self.__block_dropouts[idx]["p"] > 0:
                    self.add_module(
                        f"middle_flow_dropout_{idx}",
                        nn.Dropout1d(self.__block_dropouts[idx]["p"]),
                    )
                elif self.__block_dropouts[idx]["type"] is None and self.__block_dropouts[idx]["p"] > 0:
                    self.add_module(
                        f"middle_flow_dropout_{idx}",
                        nn.Dropout(self.__block_dropouts[idx]["p"]),
                    )
            elif self.__block_dropouts[idx] > 0:
                self.add_module(f"middle_flow_dropout_{idx}", nn.Dropout(self.__block_dropouts[idx]))

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass of the middle flow.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor,
            of shape ``(batch_size, n_channels, seq_len)``.

        Returns
        -------
        output : torch.Tensor
            Output tensor,
            of shape ``(batch_size, n_channels, seq_len)``.

        """
        output = super().forward(input)
        return output

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the middle flow.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input tensors.
        batch_size : int, optional
            Batch size of the input tensors.

        Returns
        -------
        output_shape : sequence
            The output shape of the middle flow.

        """
        _seq_len = seq_len
        for module in self:
            if isinstance(module, (nn.Dropout, nn.Dropout1d)):
                continue
            output_shape = module.compute_output_shape(_seq_len, batch_size)
            _, _, _seq_len = output_shape
        return output_shape


class XceptionExitFlow(nn.Sequential, SizeMixin):
    """Exit flow of the Xception model.

    Exit flow consists of several Xception blocks of 2 convolutions,
    followed by several separable convolutions.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input tensor.
    final_num_filters : Sequence[int]
        Number of filters (output channels) of the final convolutions.
    final_filter_lengths : int or Sequence[int]
        Filter length(s) of the convolutions of the final convolutions.
    final_subsample_lengths : int or Sequence[int]
        Subsampling length(s) (stride(s)) of the final convolutions.
    num_filters : Sequence[int] or Sequence[Sequence[int]]
        Number of filters of the convolutions of Xception blocks.
    filter_lengths : int or Sequence[int] or Sequence[Sequence[int]]
        Filter length(s) of the convolutions of Xception blocks.
    subsample_lengths : int or Sequence[int]
        Subsampling length(s) of the Xception blocks.
    subsample_kernels : int or Sequence[int], optional
        Subsampling kernel size(s) of the Xception blocks.
    dilations: int or Sequence[int] or Sequence[Sequence[int]], default 1
        Dilation(s) of the convolutions of Xception blocks.
    groups : int, default 1
        Connection pattern (of channels) of the inputs and outputs.
    dropouts : float or dict or Sequence[Union[float, dict]] or Sequence[Sequence[Union[float, dict]]], default 0.0
        Dropout(s) after each :class:`Conv_Bn_Activation` blocks
        in the Xception blocks.
    block_dropouts : float or dict or Sequence[Union[float, dict]], default 0.0
        Dropout(s) after each of the Xception blocks
        and each of the final convolutions.
    config : dict, optional
        Other parameters for Xception blocks and final convolutions, including
        activation choices, weight initializer, batch normalization choices, etc.
        for the convolutional layers,
        and subsampling modes for subsampling layers, etc.

    """

    __name__ = "XceptionExitFlow"

    def __init__(
        self,
        in_channels: int,
        final_num_filters: Sequence[int],
        final_filter_lengths: Union[int, Sequence[int]],
        num_filters: Union[Sequence[int], Sequence[Sequence[int]]],
        filter_lengths: Union[int, Sequence[int], Sequence[Sequence[int]]],
        subsample_lengths: Union[int, Sequence[int]],
        subsample_kernels: Optional[Union[int, Sequence[int]]] = None,
        dilations: Union[int, Sequence[int], Sequence[Sequence[int]]] = 1,
        groups: int = 1,
        dropouts: Union[
            float,
            dict,
            Sequence[Union[float, dict]],
            Sequence[Sequence[Union[float, dict]]],
        ] = 0.0,
        block_dropouts: Union[float, dict, Sequence[Union[float, dict]]] = 0.0,
        **config,
    ) -> None:
        super().__init__()
        self.__in_channels = in_channels
        self.__num_filters = list(num_filters)
        self.__num_blocks = len(self.__num_filters)
        if isinstance(filter_lengths, int):
            self.__filter_lengths = list(repeat(filter_lengths, self.__num_blocks))
        else:
            self.__filter_lengths = list(filter_lengths)
        assert self.__num_blocks == len(
            self.__filter_lengths
        ), f"the exit flow has {self.__num_blocks} blocks, while `filter_lengths` indicates {len(self.__filter_lengths)}"
        if isinstance(subsample_lengths, int):
            self.__subsample_lengths = list(repeat(subsample_lengths, self.__num_blocks))
        else:
            self.__subsample_lengths = list(subsample_lengths)
        assert self.__num_blocks == len(self.__subsample_lengths), (
            f"the exit flow has {self.__num_blocks} blocks, "
            f"while `subsample_lengths` indicates {len(self.__subsample_lengths)}"
        )
        if subsample_kernels is None:
            self.__subsample_kernels = deepcopy(self.__subsample_lengths)
        elif isinstance(subsample_kernels, int):
            self.__subsample_kernels = list(repeat(subsample_kernels, self.__num_blocks))
        else:
            self.__subsample_kernels = list(subsample_kernels)
        assert self.__num_blocks == len(self.__subsample_kernels), (
            f"the exit flow has {self.__num_blocks} blocks, "
            f"while `subsample_kernels` indicates {len(self.__subsample_kernels)}"
        )
        if isinstance(dilations, int):
            self.__dilations = list(repeat(dilations, self.__num_blocks))
        else:
            self.__dilations = list(dilations)
        assert self.__num_blocks == len(self.__dilations), (
            f"the exit flow has {self.__num_blocks} blocks, " f"while `dilations` indicates {len(self.__dilations)}"
        )
        if isinstance(dropouts, (Real, dict)):
            self.__dropouts = list(repeat(dropouts, self.__num_blocks))
        else:
            self.__dropouts = list(dropouts)
        assert self.__num_blocks == len(self.__dropouts), (
            f"the exit flow has {self.__num_blocks} blocks, " f"while `dropouts` indicates {len(self.__dropouts)}"
        )
        if isinstance(block_dropouts, (Real, dict)):
            self.__block_dropouts = list(repeat(block_dropouts, self.__num_blocks + len(final_num_filters)))
        else:
            self.__block_dropouts = list(block_dropouts)
        assert self.__num_blocks + len(final_num_filters) == len(self.__block_dropouts), (
            f"the exit flow has {self.__num_blocks + len(final_num_filters)} blocks, "
            f"including the final convolutions, while `block_dropouts` indicates {len(self.__block_dropouts)}"
        )
        self.__groups = groups
        self.config = CFG(deepcopy(_DEFAULT_CONV_CONFIGS))
        self.config.update(deepcopy(config))

        block_in_channels = self.__in_channels
        for idx, nf in enumerate(self.__num_filters):
            # number of main stream convolution defaults to 2
            if isinstance(nf, int):
                block_out_channels = list(repeat(nf, 2))
            else:
                block_out_channels = list(nf)
            self.add_module(
                f"exit_flow_conv_block_{idx}",
                XceptionMultiConv(
                    in_channels=block_in_channels,
                    num_filters=block_out_channels,
                    filter_lengths=self.__filter_lengths[idx],
                    subsample_length=self.__subsample_lengths[idx],
                    subsample_kernel=self.__subsample_kernels[idx],
                    dilations=self.__dilations[idx],
                    groups=self.__groups,
                    dropouts=self.__dropouts[idx],
                    **self.config,
                ),
            )
            block_in_channels = block_out_channels[-1]
            if isinstance(self.__block_dropouts[idx], dict):
                if self.__block_dropouts[idx]["type"] == "1d" and self.__block_dropouts[idx]["p"] > 0:
                    self.add_module(
                        f"exit_flow_dropout_{idx}",
                        nn.Dropout1d(self.__block_dropouts[idx]["p"]),
                    )
                elif self.__block_dropouts[idx]["type"] is None and self.__block_dropouts[idx]["p"] > 0:
                    self.add_module(
                        f"exit_flow_dropout_{idx}",
                        nn.Dropout(self.__block_dropouts[idx]["p"]),
                    )
            elif self.__block_dropouts[idx] > 0:
                self.add_module(f"exit_flow_dropout_{idx}", nn.Dropout(self.__block_dropouts[idx]))

        self.add_module(
            "final_convs",
            MultiConv(
                in_channels=block_in_channels,
                out_channels=final_num_filters,
                filter_lengths=final_filter_lengths,
                groups=groups,
                conv_type="separable",
                activation=self.config.activation,
            ),
        )

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass of the exit flow.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor,
            of shape ``(batch_size, n_channels, seq_len)``.

        Returns
        -------
        output : torch.Tensor
            Output tensor,
            of shape ``(batch_size, n_channels, seq_len)``.

        """
        output = super().forward(input)
        return output

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the exit flow.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input tensors.
        batch_size : int, optional
            Batch size of the input tensors.

        Returns
        -------
        output_shape : sequence
            The output shape of the exit flow.

        """
        _seq_len = seq_len
        for module in self:
            if isinstance(module, (nn.Dropout, nn.Dropout1d)):
                continue
            output_shape = module.compute_output_shape(_seq_len, batch_size)
            _, _, _seq_len = output_shape
        return output_shape


[docs]class Xception(nn.Sequential, SizeMixin, CitationMixin): """Xception model. Xception is an architecture that uses depthwise separable convolutions to build light-weight deep neural networks, as described in [1]_. Its official implementation is available in [2]_, and a PyTorch implementation is available in [3]_. Xception is currently not widely used in the field of ECG analysis, but has the potential to be highly effective for this task. Parameters ---------- in_channels : int Number of channels in the input. config : dict Other hyper-parameters of the Module, ref. corr. config file. For keyword arguments that must be set in 3 sub-dict, namely in "entry_flow", "middle_flow", and "exit_flow", refer to corr. docstring of each class. References ---------- .. [1] Chollet, François. "Xception: Deep learning with depthwise separable convolutions." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. .. [2] https://github.com/keras-team/keras-applications/blob/master/keras_applications/xception.py .. [3] https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py """ __name__ = "Xception" def __init__(self, in_channels: int, **config) -> None: super().__init__() self.__in_channels = in_channels self.config = CFG(deepcopy(config)) entry_flow_in_channels = self.__in_channels entry_flow = XceptionEntryFlow(in_channels=entry_flow_in_channels, **(self.config.entry_flow)) self.add_module("entry_flow", entry_flow) _, middle_flow_in_channels, _ = entry_flow.compute_output_shape() middle_flow = XceptionMiddleFlow(in_channels=middle_flow_in_channels, **(self.config.middle_flow)) self.add_module("middle_flow", middle_flow) _, exit_flow_in_channels, _ = middle_flow.compute_output_shape() exit_flow = XceptionExitFlow(in_channels=exit_flow_in_channels, **(self.config.exit_flow)) self.add_module( "exit_flow", exit_flow, )
[docs] def forward(self, input: Tensor) -> Tensor: """Forward pass of the model. Parameters ---------- input : torch.Tensor. Input signal tensor, of shape ``(batch_size, n_channels, seq_len)``. Returns ------- output : torch.Tensor Output tensor, of shape ``(batch_size, n_channels, seq_len)``. """ output = super().forward(input) return output
[docs] @add_docstring( textwrap.indent(compute_sequential_output_shape_docstring, " " * 4), mode="append", ) def compute_output_shape( self, seq_len: Optional[int] = None, batch_size: Optional[int] = None ) -> Sequence[Union[int, None]]: """Compute the output shape the model.""" return compute_sequential_output_shape(self, seq_len, batch_size)
@property def in_channels(self) -> int: return self.__in_channels @property def doi(self) -> List[str]: return list(set(self.config.get("doi", []) + ["10.1109/cvpr.2017.195"]))