"""
Designing Network Design Spaces
"""
import math
import textwrap
import warnings
from collections import Counter
from itertools import repeat
from numbers import Real
from typing import List, Optional, Sequence, Type, Union
import torch
from torch import nn
from ...cfg import CFG
from ...models._nets import Conv_Bn_Activation, DownSample, SpaceToDepth
from ...utils.misc import CitationMixin, add_docstring
from ...utils.utils_nn import SizeMixin, compute_sequential_output_shape, compute_sequential_output_shape_docstring
from .resnet import ResNetBasicBlock, ResNetBottleNeck
class AnyStage(nn.Sequential, SizeMixin):
"""AnyStage of :class:`RegNet`.
Parameters
----------
in_channels : int
Number of features (channels) of the input.
    num_filters : int
        Number of filters (channels) of the blocks in the stage.
filter_length : int
        Length (size) of the filter kernels for the neck conv layers.
subsample_length : int
        Subsample length, including pool size for the shortcut,
        and stride for the (top or neck) conv layer.
num_blocks : int
Number of blocks in the stage.
group_width : int
Group width for the bottleneck block.
stage_index : int
Index of the stage in the whole :class:`RegNet`.
    block_config : dict, optional
        Configs for the blocks, including
- block: str or torch.nn.Module,
the block class, can be one of
"bottleneck", "bottle_neck", :class:`ResNetBottleNeck`, etc.
- expansion: int,
the expansion factor for the bottleneck block.
- increase_channels_method: str,
the method to increase the number of channels,
can be one of {"conv", "zero_padding"}.
- subsample_mode: str,
the mode of subsampling, can be one of
{:class:`DownSample`.__MODES__},
- activation: str or torch.nn.Module,
the activation function, can be one of
{:class:`Activations`}.
- kw_activation: dict,
keyword arguments for the activation function.
- kernel_initializer: str,
the kernel initializer, can be one of
{:class:`Initializers`}.
- kw_initializer: dict,
keyword arguments for the kernel initializer.
- bias: bool,
whether to use bias in the convolution.
- dilation: int,
the dilation factor for the convolution.
- base_width: int,
number of filters per group for the neck conv layer
usually number of filters of the initial conv layer
of the whole :class:`RegNet`.
- base_groups: int,
pattern of connections between inputs and outputs of
conv layers at the two ends, which should divide `groups`.
- base_filter_length: int,
lengths (sizes) of the filter kernels for conv layers at the two ends.
- attn: dict,
attention mechanism for the neck conv layer.
            If None, no attention mechanism is used.
            If not None, it should be a dict with the following items:
- name: str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
- pos: int, position of the attention mechanism.
Other keys are specific to the attention mechanism.
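
    Examples
    --------
    A minimal usage sketch; the parameter values below are illustrative only:

    >>> stage = AnyStage(
    ...     in_channels=64,
    ...     num_filters=128,
    ...     filter_length=3,
    ...     subsample_length=2,
    ...     num_blocks=4,
    ...     group_width=16,
    ...     stage_index=0,
    ... )
    >>> out_shape = stage.compute_output_shape(seq_len=2000, batch_size=1)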
"""
__name__ = "AnyStage"
__DEFAULT_BLOCK_CONFIG__ = {
"block": "bottleneck",
"expansion": 1,
"increase_channels_method": "conv",
"subsample_mode": "conv",
"activation": "relu",
"kw_activation": {"inplace": True},
"kernel_initializer": "he_normal",
"kw_initializer": {},
"bias": False,
}
def __init__(
self,
in_channels: int,
num_filters: int,
filter_length: int,
subsample_length: int,
num_blocks: int,
group_width: int,
stage_index: int,
**block_config,
) -> None:
super().__init__()
self.block_config = CFG(self.__DEFAULT_BLOCK_CONFIG__.copy())
self.block_config.update(block_config)
block_cls = self.get_building_block_cls(self.block_config)
# adjust num_filters based on group_width
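        # round `num_filters` down to the nearest multiple of `group_width`;
        # if this removes more than 10% of the width, round up instead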
if num_filters % group_width != 0:
_num_filters = num_filters // group_width * group_width
if _num_filters < 0.9 * num_filters:
_num_filters += group_width
num_filters = _num_filters
groups = num_filters // group_width
base_width = block_cls.__DEFAULT_BASE_WIDTH__ / groups
block_in_channels = in_channels
for i in range(num_blocks):
block = block_cls(
in_channels=block_in_channels,
num_filters=num_filters,
filter_length=filter_length,
subsample_length=subsample_length if i == 0 else 1,
groups=groups,
base_width=base_width,
**self.block_config,
)
block_in_channels = block.compute_output_shape()[1]
self.add_module(f"block_{stage_index}_{i}", block)
@staticmethod
    def get_building_block_cls(config: CFG) -> Type[nn.Module]:
"""Get the building block class."""
block_cls = config.get("block")
if isinstance(block_cls, str):
if block_cls.lower() in ["bottleneck", "bottle_neck"]:
block_cls = ResNetBottleNeck
else:
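                # any other string falls back to the basic block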
block_cls = ResNetBasicBlock
return block_cls
@add_docstring(
textwrap.indent(compute_sequential_output_shape_docstring, " " * 4),
mode="append",
)
def compute_output_shape(
self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
"""Compute the output shape of the stage."""
return compute_sequential_output_shape(self, seq_len, batch_size)
class RegNetStem(nn.Sequential, SizeMixin):
"""The input stem of :class:`RegNet`.
Parameters
----------
in_channels : int
Number of input channels.
    out_channels : int or Sequence[int]
Number of output channels.
filter_lengths : int or Sequence[int]
Length of the filter, or equivalently,
the kernel size(s) of the convolutions.
conv_stride : int
Stride of the convolution.
pool_size : int
Size of the pooling window.
pool_stride : int
Stride of the pooling window.
subsample_mode : str
Mode of subsampling, can be one of
{:class:`DownSample`.__MODES__},
or "s2d" (with aliases "space_to_depth", "SpaceToDepth").
groups : int
Number of groups for the convolution.
config : dict
Other configs for convolution and pooling.
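
    Examples
    --------
    A minimal usage sketch; the parameter values below are illustrative only:

    >>> stem = RegNetStem(
    ...     in_channels=12,
    ...     out_channels=32,
    ...     filter_lengths=3,
    ...     conv_stride=2,
    ...     pool_size=3,
    ...     pool_stride=2,
    ... )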
"""
__name__ = "ResNetStem"
def __init__(
self,
in_channels: int,
out_channels: Union[int, Sequence[int]],
filter_lengths: Union[int, Sequence[int]],
conv_stride: int,
pool_size: int,
pool_stride: int,
subsample_mode: str = "max",
groups: int = 1,
**config,
) -> None:
super().__init__()
self.__in_channels = in_channels
self.__out_channels = out_channels
self.__filter_lengths = filter_lengths
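        # in "s2d" mode, the stem is a single space-to-depth block,
        # and no convolution or pooling layers are added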
        if subsample_mode.lower() in ["s2d", "space_to_depth", "spacetodepth"]:
self.add_module(
"s2d",
SpaceToDepth(self.__in_channels, self.__out_channels, config.get("block_size", 4)),
)
return
if isinstance(self.__filter_lengths, int):
self.__filter_lengths = [self.__filter_lengths]
if isinstance(self.__out_channels, int):
self.__out_channels = [self.__out_channels]
        assert len(self.__filter_lengths) == len(self.__out_channels), (
            "`filter_lengths` and `out_channels` should have the same length"
        )
conv_in_channels = self.__in_channels
for idx, fl in enumerate(self.__filter_lengths):
self.add_module(
f"conv_{idx}",
Conv_Bn_Activation(
conv_in_channels,
self.__out_channels[idx],
self.__filter_lengths[idx],
stride=conv_stride if idx == 0 else 1,
groups=groups,
**config,
),
)
conv_in_channels = self.__out_channels[idx]
if pool_stride > 1:
self.add_module(
"pool",
DownSample(
pool_stride,
conv_in_channels,
kernel_size=pool_size,
groups=groups,
padding=(pool_stride - 1) // 2,
mode=subsample_mode.lower(),
**config,
),
)
@add_docstring(
textwrap.indent(compute_sequential_output_shape_docstring, " " * 4),
mode="append",
)
def compute_output_shape(
self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
"""Compute the output shape of the stem."""
return compute_sequential_output_shape(self, seq_len, batch_size)
class RegNet(nn.Sequential, SizeMixin, CitationMixin):
"""RegNet model.
    RegNet is a family of convolutional neural networks whose per-stage
    widths and depths are generated by a quantized linear parameterization
    of block widths. The architecture is proposed in [1]_, and the
    implementation here is adapted from [2]_.
References
----------
.. [1] https://arxiv.org/abs/2003.13678
.. [2] https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py
Parameters
----------
in_channels : int
Number of channels of the input.
config : dict
        Hyper-parameters of the module; refer to the corresponding config file.
        Keyword arguments that must be set:
- filter_lengths: int or sequence of int,
filter length(s) (kernel size(s)) of the convolutions,
with granularity to the whole network, to each stage.
- subsample_lengths: int or sequence of int,
subsampling length(s) (ratio(s)) of all blocks,
with granularity to the whole network, to each stage.
- tot_blocks: int,
the total number of building blocks.
- w_a, w_0, w_m: float,
the parameters for the widths generating function.
- group_widths: int or sequence of int,
the number of channels in each group,
with granularity to the whole network, to each stage.
- num_blocks: sequence of int, optional,
the number of blocks in each stage,
if not given, will be computed from tot_blocks
and `w_a`, `w_0`, `w_m`.
- num_filters: int or sequence of int, optional,
the number of filters in each stage.
If not given, will be computed from tot_blocks
and `w_a`, `w_0`, `w_m`.
- stem: dict,
the config of the input stem.
- block: dict,
other parameters that can be set for the building blocks.
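
    Examples
    --------
    A minimal configuration sketch; the values below are illustrative only
    and do not correspond to a published RegNet variant:

    >>> config = dict(
    ...     filter_lengths=3,
    ...     subsample_lengths=2,
    ...     tot_blocks=16,
    ...     w_a=24.48,
    ...     w_0=24,
    ...     w_m=2.54,
    ...     group_widths=8,
    ...     stem=dict(
    ...         num_filters=32,
    ...         filter_lengths=3,
    ...         conv_stride=2,
    ...         pool_size=3,
    ...         pool_stride=2,
    ...     ),
    ... )
    >>> model = RegNet(in_channels=12, **config)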
"""
__name__ = "RegNet"
__DEFAULT_CONFIG__ = dict(
activation="relu",
kw_activation={"inplace": True},
kernel_initializer="he_normal",
kw_initializer={},
base_groups=1,
dropouts=0,
)
def __init__(self, in_channels: int, **config) -> None:
super().__init__()
self.__in_channels = in_channels
self.config = CFG(self.__DEFAULT_CONFIG__.copy())
self.config.update(config)
stem_config = CFG(self.config.stem)
stem_config.pop("num_filters", None)
self.add_module(
"input_stem",
RegNetStem(
in_channels=self.__in_channels,
out_channels=self.config.stem.num_filters,
groups=self.config.base_groups,
activation=self.config.activation,
**stem_config,
),
)
stage_configs = self._get_stage_configs()
in_channels = self.input_stem.compute_output_shape()[1]
for idx, stage_config in enumerate(stage_configs):
stage_block = AnyStage(in_channels=in_channels, **stage_config)
self.add_module(f"stage_{idx}", stage_block)
in_channels = stage_block.compute_output_shape()[1]
def _get_stage_configs(self) -> List[CFG]:
"""Get the configs for each stage."""
stage_configs = []
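        # Case 1: `config.num_blocks` is given explicitly,
        # and one stage is built per entry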
if self.config.get("num_blocks", None) is not None:
if isinstance(self.config.filter_lengths, int):
self.__filter_lengths = list(repeat(self.config.filter_lengths, len(self.config.num_blocks)))
else:
self.__filter_lengths = self.config.filter_lengths
assert len(self.__filter_lengths) == len(self.config.num_blocks), (
f"`config.filter_lengths` indicates {len(self.__filter_lengths)} stages, "
f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
)
if isinstance(self.config.subsample_lengths, int):
self.__subsample_lengths = list(repeat(self.config.subsample_lengths, len(self.config.num_blocks)))
else:
self.__subsample_lengths = self.config.subsample_lengths
assert len(self.__subsample_lengths) == len(self.config.num_blocks), (
f"`config.subsample_lengths` indicates {len(self.__subsample_lengths)} stages, "
f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
)
self.__num_filters = self.config.num_filters
assert len(self.__num_filters) == len(self.config.num_blocks), (
f"`config.num_filters` indicates {len(self.__num_filters)} stages, "
f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
)
if isinstance(self.config.dropouts, Real):
self.__dropouts = list(repeat(self.config.dropouts, len(self.config.num_blocks)))
else:
self.__dropouts = self.config.dropouts
assert len(self.__dropouts) == len(self.config.num_blocks), (
f"`config.dropouts` indicates {len(self.__dropouts)} stages, "
f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
)
if isinstance(self.config.group_widths, int):
self.__group_widths = list(repeat(self.config.group_widths, len(self.config.num_blocks)))
else:
self.__group_widths = self.config.group_widths
assert len(self.__group_widths) == len(self.config.num_blocks), (
f"`config.group_widths` indicates {len(self.__group_widths)} stages, "
f"while `config.num_blocks` indicates {len(self.config.num_blocks)}"
)
block_config = CFG(self.config.get("block", {}))
block_config.pop("dropout", None)
stage_configs = [
CFG(
dict(
num_blocks=self.config.num_blocks[idx],
num_filters=self.__num_filters[idx],
filter_length=self.__filter_lengths[idx],
subsample_length=self.__subsample_lengths[idx],
dropout=self.__dropouts[idx],
group_width=self.__group_widths[idx],
stage_index=idx,
**block_config,
)
)
for idx in range(len(self.config.num_blocks))
]
return stage_configs
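        # Case 2: `config.num_blocks` is not given;
        # per-stage widths and depths are generated from
        # `w_a`, `w_0`, `w_m` and `tot_blocks` as in the RegNet paper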
if self.config.get("num_filters", None) is not None:
            warnings.warn(
                "`config.num_filters` is ignored and will be re-computed from "
                "`config.w_a`, `config.w_0`, `config.w_m`, since `config.num_blocks` "
                "is not provided. This may not be the intended behavior.",
                RuntimeWarning,
            )
assert {"w_a", "w_0", "w_m", "tot_blocks"}.issubset(set(self.config.keys())), (
"If `config.num_blocks` is not provided, then `config.w_a`, `config.w_0`, `config.w_m`, "
"and `config.tot_blocks` must be provided."
)
QUANT = 8
        if self.config.w_a < 0 or self.config.w_0 <= 0 or self.config.w_m <= 1 or self.config.w_0 % QUANT != 0:
            raise ValueError(
                "Invalid RegNet settings: `w_a` must be non-negative, "
                f"`w_0` positive and divisible by {QUANT}, and `w_m` greater than 1"
            )
# Compute the block widths. Each stage has one unique block width
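        # continuous widths: u_j = w_0 + w_a * j for j = 0, ..., tot_blocks - 1;
        # each block then gets the quantized width
        # w_j = round(w_0 * w_m ** round(log(u_j / w_0) / log(w_m)) / QUANT) * QUANT,
        # and consecutive blocks sharing the same width form one stage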
widths_cont = torch.arange(self.config.tot_blocks) * self.config.w_a + self.config.w_0
block_capacity = torch.round(torch.log(widths_cont / self.config.w_0) / math.log(self.config.w_m))
block_widths = (
(
torch.round(
torch.divide(
self.config.w_0 * torch.pow(self.config.w_m, block_capacity),
QUANT,
)
)
* QUANT
)
.int()
.tolist()
)
counter = Counter(block_widths)
num_stages = len(counter)
if isinstance(self.config.filter_lengths, int):
self.__filter_lengths = list(repeat(self.config.filter_lengths, num_stages))
else:
self.__filter_lengths = self.config.filter_lengths
assert len(self.__filter_lengths) == num_stages, (
f"`config.filter_lengths` indicates {len(self.__filter_lengths)} stages, "
f"while there are {num_stages} computed from "
"`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`"
)
if isinstance(self.config.subsample_lengths, int):
self.__subsample_lengths = list(repeat(self.config.subsample_lengths, num_stages))
else:
self.__subsample_lengths = self.config.subsample_lengths
assert len(self.__subsample_lengths) == num_stages, (
f"`config.subsample_lengths` indicates {len(self.__subsample_lengths)} stages, "
f"while there are {num_stages} computed from "
"`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`"
)
if isinstance(self.config.dropouts, Real):
self.__dropouts = list(repeat(self.config.dropouts, num_stages))
else:
self.__dropouts = self.config.dropouts
assert len(self.__dropouts) == num_stages, (
f"`config.dropouts` indicates {len(self.__dropouts)} stages, "
f"while there are {num_stages} computed from "
"`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`"
)
if isinstance(self.config.group_widths, int):
self.__group_widths = list(repeat(self.config.group_widths, num_stages))
else:
self.__group_widths = self.config.group_widths
assert len(self.__group_widths) == num_stages, (
f"`config.group_widths` indicates {len(self.__group_widths)} stages, "
f"while there are {num_stages} computed from "
"`config.w_a`, `config.w_0`, `config.w_m`, `config.tot_blocks`"
)
block_config = CFG(self.config.get("block", {}))
block_config.pop("dropout", None)
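        # one stage per distinct block width, ordered by increasing width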
for idx, num_filters in enumerate(sorted(counter)):
stage_configs.append(
CFG(
dict(
num_blocks=counter[num_filters],
num_filters=num_filters,
filter_length=self.__filter_lengths[idx],
subsample_length=self.__subsample_lengths[idx],
group_width=self.__group_widths[idx],
dropout=self.__dropouts[idx],
stage_index=idx,
**block_config,
)
)
)
return stage_configs
    @add_docstring(
textwrap.indent(compute_sequential_output_shape_docstring, " " * 4),
mode="append",
)
def compute_output_shape(
self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
"""Compute the output shape of the network."""
return compute_sequential_output_shape(self, seq_len, batch_size)
@property
def in_channels(self) -> int:
return self.__in_channels
@property
def doi(self) -> List[str]:
return list(set(self.config.get("doi", []) + ["10.48550/ARXIV.2003.13678"]))