Source code for torch_ecg.models.cnn.regnet

"""
Designing Network Design Spaces
"""

import math
import textwrap
import warnings
from collections import Counter
from itertools import repeat
from numbers import Real
from typing import List, Optional, Sequence, Union

import torch
from torch import nn

from ...cfg import CFG
from ...models._nets import Conv_Bn_Activation, DownSample, SpaceToDepth
from ...utils.misc import CitationMixin, add_docstring
from ...utils.utils_nn import SizeMixin, compute_sequential_output_shape, compute_sequential_output_shape_docstring
from .resnet import ResNetBasicBlock, ResNetBottleNeck


class AnyStage(nn.Sequential, SizeMixin):
    """AnyStage of :class:`RegNet`.

    Parameters
    ----------
    in_channels : int
        Number of features (channels) of the input.
    num_filters : int
        Number of filters for the neck conv layer.
    filter_length : int
        Length (size) of the filter kernels for the neck conv layer.
    subsample_length : int
        Subsample length, including pool size for short cut,
        and stride for the (top or neck) conv layer.
    num_blocks : int
        Number of blocks in the stage.
    group_width : int
        Group width for the bottleneck block.
    stage_index : int
        Index of the stage in the whole :class:`RegNet`.
    block_config : dict, optional
        Configs for the blocks, including
            - block: str or torch.nn.Module,
              the block class, can be one of
              "bottleneck", "bottle_neck", :class:`ResNetBottleNeck`, etc.
            - expansion: int,
              the expansion factor for the bottleneck block.
            - increase_channels_method: str,
              the method to increase the number of channels,
              can be one of {"conv", "zero_padding"}.
            - subsample_mode: str,
              the mode of subsampling, can be one of
              {:class:`DownSample`.__MODES__},
            - activation: str or torch.nn.Module,
              the activation function, can be one of
              {:class:`Activations`}.
            - kw_activation: dict,
              keyword arguments for the activation function.
            - kernel_initializer: str,
              the kernel initializer, can be one of
              {:class:`Initializers`}.
            - kw_initializer: dict,
              keyword arguments for the kernel initializer.
            - bias: bool,
              whether to use bias in the convolution.
            - dilation: int,
              the dilation factor for the convolution.
            - base_width: int,
              number of filters per group for the neck conv layer
              usually number of filters of the initial conv layer
              of the whole :class:`RegNet`.
            - base_groups: int,
              number of groups (the pattern of connections between inputs
              and outputs) for the conv layers at the two ends,
              which should divide `groups`.
            - base_filter_length: int,
              lengths (sizes) of the filter kernels for conv layers at the two ends.
            - attn: dict,
              attention mechanism for the neck conv layer.
              If is None, no attention mechanism is used.
              If is not None, it should be a dict with the following items:

                - name: str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
                - pos: int, position of the attention mechanism.

              Other keys are specific to the attention mechanism.
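
    Examples
    --------
    A minimal usage sketch (the numbers below are illustrative assumptions,
    not library defaults):

    >>> stage = AnyStage(
    ...     in_channels=64,
    ...     num_filters=128,
    ...     filter_length=3,
    ...     subsample_length=2,
    ...     num_blocks=2,
    ...     group_width=16,
    ...     stage_index=0,
    ... )
    >>> out = stage(torch.randn(2, 64, 1000))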

    """

    __name__ = "AnyStage"
    __DEFAULT_BLOCK_CONFIG__ = {
        "block": "bottleneck",
        "expansion": 1,
        "increase_channels_method": "conv",
        "subsample_mode": "conv",
        "activation": "relu",
        "kw_activation": {"inplace": True},
        "kernel_initializer": "he_normal",
        "kw_initializer": {},
        "bias": False,
    }

    def __init__(
        self,
        in_channels: int,
        num_filters: int,
        filter_length: int,
        subsample_length: int,
        num_blocks: int,
        group_width: int,
        stage_index: int,
        **block_config,
    ) -> None:
        super().__init__()

        self.block_config = CFG(self.__DEFAULT_BLOCK_CONFIG__.copy())
        self.block_config.update(block_config)

        block_cls = self.get_building_block_cls(self.block_config)

        # adjust num_filters based on group_width
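        # e.g. (illustrative numbers) num_filters=100, group_width=16:
        # 100 is first rounded down to 96, which is kept since 96 >= 0.9 * 100,
        # giving groups = 96 // 16 = 6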
        if num_filters % group_width != 0:
            _num_filters = num_filters // group_width * group_width
            if _num_filters < 0.9 * num_filters:
                _num_filters += group_width
            num_filters = _num_filters
        groups = num_filters // group_width
        base_width = block_cls.__DEFAULT_BASE_WIDTH__ / groups

        block_in_channels = in_channels
        for i in range(num_blocks):
            block = block_cls(
                in_channels=block_in_channels,
                num_filters=num_filters,
                filter_length=filter_length,
                subsample_length=subsample_length if i == 0 else 1,
                groups=groups,
                base_width=base_width,
                **self.block_config,
            )
            block_in_channels = block.compute_output_shape()[1]
            self.add_module(f"block_{stage_index}_{i}", block)

    @staticmethod
    def get_building_block_cls(config: CFG) -> type:
        """Get the building block class."""
        block_cls = config.get("block")
        if isinstance(block_cls, str):
            if block_cls.lower() in ["bottleneck", "bottle_neck"]:
                block_cls = ResNetBottleNeck
            else:
                block_cls = ResNetBasicBlock
        return block_cls

    @add_docstring(
        textwrap.indent(compute_sequential_output_shape_docstring, " " * 4),
        mode="append",
    )
    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the stage."""
        return compute_sequential_output_shape(self, seq_len, batch_size)


class RegNetStem(nn.Sequential, SizeMixin):
    """The input stem of :class:`RegNet`.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int or Sequence[int]
        Number of output channels.
    filter_lengths : int or Sequence[int]
        Length(s) of the filters, or equivalently,
        kernel size(s) of the convolutions.
    conv_stride : int
        Stride of the convolution.
    pool_size : int
        Size of the pooling window.
    pool_stride : int
        Stride of the pooling window.
    subsample_mode : str
        Mode of subsampling, can be one of
        {:class:`DownSample`.__MODES__},
        or "s2d" (with aliases "space_to_depth", "SpaceToDepth").
    groups : int
        Number of groups for the convolution.
    config : dict
        Other configs for convolution and pooling.
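
    Examples
    --------
    A minimal usage sketch (channel counts and kernel sizes below are
    illustrative assumptions, not library defaults):

    >>> stem = RegNetStem(
    ...     in_channels=12,
    ...     out_channels=32,
    ...     filter_lengths=15,
    ...     conv_stride=2,
    ...     pool_size=3,
    ...     pool_stride=2,
    ... )
    >>> out = stem(torch.randn(2, 12, 4000))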

    """

    __name__ = "ResNetStem"

    def __init__(
        self,
        in_channels: int,
        out_channels: Union[int, Sequence[int]],
        filter_lengths: Union[int, Sequence[int]],
        conv_stride: int,
        pool_size: int,
        pool_stride: int,
        subsample_mode: str = "max",
        groups: int = 1,
        **config,
    ) -> None:
        super().__init__()
        self.__in_channels = in_channels
        self.__out_channels = out_channels
        self.__filter_lengths = filter_lengths
        if subsample_mode.lower() in ["s2d", "space_to_depth", "spacetodepth"]:
            self.add_module(
                "s2d",
                SpaceToDepth(self.__in_channels, self.__out_channels, config.get("block_size", 4)),
            )
            return
        if isinstance(self.__filter_lengths, int):
            self.__filter_lengths = [self.__filter_lengths]
        if isinstance(self.__out_channels, int):
            self.__out_channels = [self.__out_channels]
        assert len(self.__filter_lengths) == len(self.__out_channels)

        conv_in_channels = self.__in_channels
        for idx, fl in enumerate(self.__filter_lengths):
            self.add_module(
                f"conv_{idx}",
                Conv_Bn_Activation(
                    conv_in_channels,
                    self.__out_channels[idx],
                    self.__filter_lengths[idx],
                    stride=conv_stride if idx == 0 else 1,
                    groups=groups,
                    **config,
                ),
            )
            conv_in_channels = self.__out_channels[idx]
        if pool_stride > 1:
            self.add_module(
                "pool",
                DownSample(
                    pool_stride,
                    conv_in_channels,
                    kernel_size=pool_size,
                    groups=groups,
                    padding=(pool_stride - 1) // 2,
                    mode=subsample_mode.lower(),
                    **config,
                ),
            )

    @add_docstring(
        textwrap.indent(compute_sequential_output_shape_docstring, " " * 4),
        mode="append",
    )
    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the stem."""
        return compute_sequential_output_shape(self, seq_len, batch_size)


class RegNet(nn.Sequential, SizeMixin, CitationMixin):
    """RegNet model.

    RegNet is a family of convolutional neural networks that can be
    constructed by efficiently scaling and pruning a single convolutional
    "stem" network. This architecture is proposed in [1]_, and the
    implementation is adapted from [2]_.

    References
    ----------
    .. [1] https://arxiv.org/abs/2003.13678
    .. [2] https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py

    Parameters
    ----------
    in_channels : int
        Number of channels of the input.
    config : dict
        Hyper-parameters of the Module, ref. corr. config file.
        Keyword arguments that must be set:

            - filter_lengths: int or sequence of int,
              filter length(s) (kernel size(s)) of the convolutions,
              with granularity to the whole network, to each stage.
            - subsample_lengths: int or sequence of int,
              subsampling length(s) (ratio(s)) of all blocks,
              with granularity to the whole network, to each stage.
            - tot_blocks: int,
              the total number of building blocks.
            - w_a, w_0, w_m: float,
              the parameters for the widths generating function.
            - group_widths: int or sequence of int,
              the number of channels in each group,
              with granularity to the whole network, to each stage.
            - num_blocks: sequence of int, optional,
              the number of blocks in each stage,
              if not given, will be computed from
              `tot_blocks` and `w_a`, `w_0`, `w_m`.
            - num_filters: int or sequence of int, optional,
              the number of filters in each stage,
              if not given, will be computed from
              `tot_blocks` and `w_a`, `w_0`, `w_m`.
            - stem: dict,
              the config of the input stem.
            - block: dict,
              other parameters that can be set for the building blocks.

    """

    __name__ = "RegNet"
    __DEFAULT_CONFIG__ = dict(
        activation="relu",
        kw_activation={"inplace": True},
        kernel_initializer="he_normal",
        kw_initializer={},
        base_groups=1,
        dropouts=0,
    )

    def __init__(self, in_channels: int, **config) -> None:
        super().__init__()
        self.__in_channels = in_channels
        self.config = CFG(self.__DEFAULT_CONFIG__.copy())
        self.config.update(config)

        stem_config = CFG(self.config.stem)
        stem_config.pop("num_filters", None)
        self.add_module(
            "input_stem",
            RegNetStem(
                in_channels=self.__in_channels,
                out_channels=self.config.stem.num_filters,
                groups=self.config.base_groups,
                activation=self.config.activation,
                **stem_config,
            ),
        )

        stage_configs = self._get_stage_configs()
        in_channels = self.input_stem.compute_output_shape()[1]
        for idx, stage_config in enumerate(stage_configs):
            stage_block = AnyStage(in_channels=in_channels, **stage_config)
            self.add_module(f"stage_{idx}", stage_block)
            in_channels = stage_block.compute_output_shape()[1]

    def _get_stage_configs(self) -> List[CFG]:
        """Get the configs for each stage."""
        stage_configs = []
        if self.config.get("num_blocks", None) is not None:
            if isinstance(self.config.filter_lengths, int):
                self.__filter_lengths = list(repeat(self.config.filter_lengths, len(self.config.num_blocks)))
            else:
                self.__filter_lengths = self.config.filter_lengths
            assert len(self.__filter_lengths) == len(self.config.num_blocks), (
                f"`config.filter_lengths` indicates {len(self.__filter_lengths)} stages, "
                f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
            )
            if isinstance(self.config.subsample_lengths, int):
                self.__subsample_lengths = list(repeat(self.config.subsample_lengths, len(self.config.num_blocks)))
            else:
                self.__subsample_lengths = self.config.subsample_lengths
            assert len(self.__subsample_lengths) == len(self.config.num_blocks), (
                f"`config.subsample_lengths` indicates {len(self.__subsample_lengths)} stages, "
                f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
            )
            self.__num_filters = self.config.num_filters
            assert len(self.__num_filters) == len(self.config.num_blocks), (
                f"`config.num_filters` indicates {len(self.__num_filters)} stages, "
                f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
            )
            if isinstance(self.config.dropouts, Real):
                self.__dropouts = list(repeat(self.config.dropouts, len(self.config.num_blocks)))
            else:
                self.__dropouts = self.config.dropouts
            assert len(self.__dropouts) == len(self.config.num_blocks), (
                f"`config.dropouts` indicates {len(self.__dropouts)} stages, "
                f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
            )
            if isinstance(self.config.group_widths, int):
                self.__group_widths = list(repeat(self.config.group_widths, len(self.config.num_blocks)))
            else:
                self.__group_widths = self.config.group_widths
            assert len(self.__group_widths) == len(self.config.num_blocks), (
                f"`config.group_widths` indicates {len(self.__group_widths)} stages, "
                f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
            )
            block_config = CFG(self.config.get("block", {}))
            block_config.pop("dropout", None)
            stage_configs = [
                CFG(
                    dict(
                        num_blocks=self.config.num_blocks[idx],
                        num_filters=self.__num_filters[idx],
                        filter_length=self.__filter_lengths[idx],
                        subsample_length=self.__subsample_lengths[idx],
                        dropout=self.__dropouts[idx],
                        group_width=self.__group_widths[idx],
                        stage_index=idx,
                        **block_config,
                    )
                )
                for idx in range(len(self.config.num_blocks))
            ]
            return stage_configs

        if self.config.get("num_filters", None) is not None:
            warnings.warn(
                "`num_filters` are computed from `config.w_a`, `config.w_0`, `config.w_m`, "
                "if `config.num_blocks` is not provided. "
                "This may not be the intended behavior.",
                RuntimeWarning,
            )
        assert {"w_a", "w_0", "w_m", "tot_blocks"}.issubset(set(self.config.keys())), (
            "If `config.num_blocks` is not provided, then `config.w_a`, `config.w_0`, `config.w_m`, "
            "and `config.tot_blocks` must be provided."
        )

        QUANT = 8
        if self.config.w_a < 0 or self.config.w_0 <= 0 or self.config.w_m <= 1 or self.config.w_0 % QUANT != 0:
            raise ValueError("Invalid RegNet settings")
        # Compute the block widths. Each stage has one unique block width
        widths_cont = torch.arange(self.config.tot_blocks) * self.config.w_a + self.config.w_0
        block_capacity = torch.round(torch.log(widths_cont / self.config.w_0) / math.log(self.config.w_m))
        block_widths = (
            (
                torch.round(
                    torch.divide(
                        self.config.w_0 * torch.pow(self.config.w_m, block_capacity),
                        QUANT,
                    )
                )
                * QUANT
            )
            .int()
            .tolist()
        )
        counter = Counter(block_widths)
        num_stages = len(counter)

        if isinstance(self.config.filter_lengths, int):
            self.__filter_lengths = list(repeat(self.config.filter_lengths, num_stages))
        else:
            self.__filter_lengths = self.config.filter_lengths
        assert len(self.__filter_lengths) == num_stages, (
            f"`config.filter_lengths` indicates {len(self.__filter_lengths)} stages, "
            f"while there are {num_stages} computed from "
            "`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`"
        )
        if isinstance(self.config.subsample_lengths, int):
            self.__subsample_lengths = list(repeat(self.config.subsample_lengths, num_stages))
        else:
            self.__subsample_lengths = self.config.subsample_lengths
        assert len(self.__subsample_lengths) == num_stages, (
            f"`config.subsample_lengths` indicates {len(self.__subsample_lengths)} stages, "
            f"while there are {num_stages} computed from "
            "`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`"
        )
        if isinstance(self.config.dropouts, Real):
            self.__dropouts = list(repeat(self.config.dropouts, num_stages))
        else:
            self.__dropouts = self.config.dropouts
        assert len(self.__dropouts) == num_stages, (
            f"`config.dropouts` indicates {len(self.__dropouts)} stages, "
            f"while there are {num_stages} computed from "
            "`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`"
        )
        if isinstance(self.config.group_widths, int):
            self.__group_widths = list(repeat(self.config.group_widths, num_stages))
        else:
            self.__group_widths = self.config.group_widths
        assert len(self.__group_widths) == num_stages, (
            f"`config.group_widths` indicates {len(self.__group_widths)} stages, "
            f"while there are {num_stages} computed from "
            "`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`"
        )

        block_config = CFG(self.config.get("block", {}))
        block_config.pop("dropout", None)
        for idx, num_filters in enumerate(sorted(counter)):
            stage_configs.append(
                CFG(
                    dict(
                        num_blocks=counter[num_filters],
                        num_filters=num_filters,
                        filter_length=self.__filter_lengths[idx],
                        subsample_length=self.__subsample_lengths[idx],
                        group_width=self.__group_widths[idx],
                        dropout=self.__dropouts[idx],
                        stage_index=idx,
                        **block_config,
                    )
                )
            )
        return stage_configs
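
    # Worked example of the widths-generating function used in
    # ``_get_stage_configs`` (illustrative numbers, not defaults):
    # with w_0=24, w_a=24.48, w_m=2.54, tot_blocks=4 and QUANT=8,
    #   widths_cont    = [24.00, 48.48, 72.96, 97.44]
    #   block_capacity = round(log(widths_cont / 24) / log(2.54)) = [0, 1, 1, 2]
    #   block_widths   = round(24 * 2.54 ** capacity / 8) * 8 = [24, 64, 64, 152]
    # so there are 3 stages with (num_filters, num_blocks) = (24, 1), (64, 2), (152, 1).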
    @add_docstring(
        textwrap.indent(compute_sequential_output_shape_docstring, " " * 4),
        mode="append",
    )
    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the network."""
        return compute_sequential_output_shape(self, seq_len, batch_size)
    @property
    def in_channels(self) -> int:
        return self.__in_channels

    @property
    def doi(self) -> List[str]:
        return list(set(self.config.get("doi", []) + ["10.48550/ARXIV.2003.13678"]))
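

# A minimal usage sketch of ``RegNet`` (illustrative only): the config keys
# follow the keyword arguments documented in the class docstring, while the
# concrete values (w_a, w_0, w_m, tot_blocks, channel counts, etc.) are
# assumptions chosen for demonstration, not recommended settings.
if __name__ == "__main__":  # pragma: no cover
    _demo_config = CFG(
        filter_lengths=3,
        subsample_lengths=2,
        tot_blocks=16,
        w_a=24.48,
        w_0=24,
        w_m=2.54,
        group_widths=16,
        stem=CFG(
            num_filters=32,
            filter_lengths=3,
            conv_stride=2,
            pool_size=3,
            pool_stride=2,
        ),
    )
    _model = RegNet(in_channels=12, **_demo_config)
    # a batch of 2 twelve-lead ECGs, each 4000 samples long
    _output = _model(torch.randn(2, 12, 4000))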