"""
This file is forked from `FedDR <https://github.com/unc-optimization/FedDR/tree/main/FedDR/flearn/regularizers>`_.
"""
import re
from abc import ABC, abstractmethod
from math import sqrt
from typing import Iterable, List, Optional
from torch.nn.parameter import Parameter
from torch_ecg.utils import ReprMixin, add_docstring
# Names exported by ``from <this module> import *``: the factory function
# ``get_regularizer``, the abstract base class ``Regularizer``, and the
# concrete regularizer classes defined in this file.
__all__ = [
"get_regularizer",
"Regularizer",
"L1Norm",
"L2Norm",
"L2NormSquared",
"LInfNorm",
"NullRegularizer",
]
class Regularizer(ReprMixin, ABC):
    """Regularizer base class.

    Parameters
    ----------
    coeff : float, default 1.0
        The coefficient of the regularizer.

    """

    # NOTE: the first sentence of the docstring above is used verbatim by the
    # subclasses in this file (``Regularizer.__doc__.replace(...)``), so it
    # must not be reworded.

    __name__ = "Regularizer"

    def __init__(self, coeff: float = 1.0) -> None:
        # Default coefficient; used by subclasses when a call passes ``coeff=None``.
        self.coeff = coeff

    @abstractmethod
    def eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> float:
        """Compute the value of the regularizer at the given parameters.

        Parameters
        ----------
        params : Iterable[torch.nn.parameter.Parameter]
            The parameters at which the regularizer is evaluated.
        coeff : float, optional
            Coefficient that scales the regularizer.
            If None, the instance attribute ``coeff`` is used.

        Returns
        -------
        float
            The value of the regularizer.

        """
        raise NotImplementedError

    @abstractmethod
    def prox_eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> Iterable[Parameter]:
        """Apply the proximity operator of the regularizer
        to the given parameters.

        The proximity operator is defined by

        .. math::

            \\mathrm{prox}_{\\lambda R}(\\mathbf{w})
            = \\arg\\min_{\\mathbf{u}} \\frac{1}{2} \\lVert \\mathbf{u} - \\mathbf{w} \\rVert_2^2
            + \\lambda R(\\mathbf{u})

        where :math:`R` is the regularizer and :math:`\\lambda`
        is its coefficient.

        Parameters
        ----------
        params : Iterable[torch.nn.parameter.Parameter]
            The parameters at which the proximity operator is evaluated.
        coeff : float, optional
            Coefficient that scales the regularizer.
            If None, the instance attribute ``coeff`` is used.

        Returns
        -------
        Iterable[torch.nn.parameter.Parameter]
            The result of the proximity operator,
            one item per input parameter.

        """
        raise NotImplementedError
def get_regularizer(reg_type: str, reg_coeff: float = 1.0) -> Regularizer:
    """Build a regularizer instance from its (loosely written) name.

    The name is lower-cased, and every occurrence of ``"regularizer"``,
    ``"norm"``, whitespace, underscore and hyphen is removed before the
    lookup, so that for example ``"L2-Norm Squared"`` is looked up as
    ``"l2squared"``.

    Parameters
    ----------
    reg_type : str
        The name of the regularizer.
    reg_coeff : float, default 1.0
        The coefficient passed to the regularizer's constructor.

    Returns
    -------
    Regularizer
        The regularizer instance.

    Raises
    ------
    ValueError
        If the normalized name matches none of the known aliases.

    """
    reg_type = re.sub("regularizer|norm|[\\s\\_\\-]+", "", reg_type.lower())

    # (accepted normalized names, class to instantiate)
    alias_table = (
        (("l1",), L1Norm),
        (("l2",), L2Norm),
        (("l2squared",), L2NormSquared),
        (("no", "empty", "zero", "none", "null"), NullRegularizer),
        (("linf", "inf", "linfinity", "infinity", "linfty", "infty"), LInfNorm),
    )
    for aliases, regularizer_cls in alias_table:
        if reg_type in aliases:
            return regularizer_cls(reg_coeff)

    # The message reports the normalized name, not the raw argument.
    raise ValueError(f"Unknown regularizer type: {reg_type}")
@add_docstring(
    Regularizer.__doc__.replace(
        "Regularizer base class.",
        "Null regularizer, or equivalently the zero function.",
    )
)
class NullRegularizer(Regularizer):
    # No inline class docstring: ``add_docstring`` receives the base-class
    # docstring with its summary sentence swapped (presumably it installs
    # that text as this class's docstring -- confirm against torch_ecg).

    __name__ = "NullRegularizer"

    def eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> float:
        """Return ``0.0``; both ``params`` and ``coeff`` are ignored."""
        return 0.0

    def prox_eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> Iterable[Parameter]:
        """Return the items of ``params`` unchanged, collected into a new list.

        ``coeff`` is ignored.
        """
        return [*params]
@add_docstring(
    Regularizer.__doc__.replace(
        "Regularizer base class.",
        "L1 norm regularizer.",
    )
)
class L1Norm(Regularizer):
    # No inline class docstring: ``add_docstring`` receives the base-class
    # docstring with its summary sentence swapped (presumably it installs
    # that text as this class's docstring -- confirm against torch_ecg).

    __name__ = "L1Norm"

    def eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> float:
        """Return ``coeff`` times the sum of absolute values of all entries.

        ``coeff`` falls back to ``self.coeff`` when it is None.
        """
        weight = self.coeff if coeff is None else coeff
        total = 0
        for p in params:
            total += p.data.abs().sum().item()
        return weight * total

    def prox_eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> Iterable[Parameter]:
        """Soft-threshold every entry of every parameter by ``coeff``.

        Each entry keeps its sign while its magnitude is reduced by ``coeff``
        and floored at zero. ``coeff`` falls back to ``self.coeff`` when it
        is None. The result is a list with one tensor per input parameter.
        """
        threshold = self.coeff if coeff is None else coeff
        shrunk = []
        for p in params:
            direction = p.data.sign()
            magnitude = (p.data.abs() - threshold).clamp(min=0)
            shrunk.append(direction * magnitude)
        return shrunk
@add_docstring(
    Regularizer.__doc__.replace(
        "Regularizer base class.",
        "L2 norm regularizer.",
    )
)
class L2Norm(Regularizer):
    # No inline class docstring: ``add_docstring`` receives the base-class
    # docstring with its summary sentence swapped (presumably it installs
    # that text as this class's docstring -- confirm against torch_ecg).

    __name__ = "L2Norm"

    @staticmethod
    def _euclidean_norm(params: Iterable[Parameter]) -> float:
        """Return the (unscaled) Euclidean norm over all entries of ``params``."""
        return sqrt(sum(p.data.pow(2).sum().item() for p in params))

    def eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> float:
        """Return ``coeff`` times the Euclidean norm of all entries.

        ``coeff`` falls back to ``self.coeff`` when it is None.
        """
        if coeff is None:
            coeff = self.coeff
        return coeff * self._euclidean_norm(params)

    def prox_eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> Iterable[Parameter]:
        """Apply the proximity operator of ``coeff * ||.||_2``.

        Every parameter is multiplied by the common factor
        ``max(0, 1 - coeff / ||w||_2)``, where ``||w||_2`` is the Euclidean
        norm taken jointly over all parameters. ``coeff`` falls back to
        ``self.coeff`` when it is None.

        Returns
        -------
        list
            One scaled tensor per input parameter.
        """
        if coeff is None:
            coeff = self.coeff
        _params = list(params)  # params may be a generator; it is iterated twice
        # The shrinkage factor needs the plain norm. Using ``self.eval`` here
        # would give ``coeff * ||w||`` and cancel the coefficient out of the
        # ratio below.
        norm = self._euclidean_norm(_params)
        if norm == 0:
            # All entries are zero, so the result is zero whatever the factor;
            # this also avoids dividing by zero.
            scale = 0.0
        else:
            scale = max(0.0, 1 - coeff / norm)
        return [scale * p.data for p in _params]
@add_docstring(
    Regularizer.__doc__.replace(
        "Regularizer base class.",
        "L2 norm squared regularizer.",
    )
)
class L2NormSquared(Regularizer):
    # No inline class docstring: ``add_docstring`` receives the base-class
    # docstring with its summary sentence swapped (presumably it installs
    # that text as this class's docstring -- confirm against torch_ecg).

    __name__ = "L2NormSquared"

    def eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> float:
        """Return ``coeff`` times the sum of squares of all entries.

        ``coeff`` falls back to ``self.coeff`` when it is None.
        """
        weight = self.coeff if coeff is None else coeff
        total = 0
        for p in params:
            total += p.data.pow(2).sum().item()
        return weight * total

    def prox_eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> Iterable[Parameter]:
        """Scale every parameter by ``1 / (1 + 2 * coeff)``.

        ``coeff`` falls back to ``self.coeff`` when it is None.
        The result is a list with one tensor per input parameter.
        """
        weight = self.coeff if coeff is None else coeff
        shrink = 1 / (1 + 2 * weight)
        return [shrink * p.data for p in params]
@add_docstring(
    Regularizer.__doc__.replace(
        "Regularizer base class.",
        "L-infinity norm regularizer.",
    )
)
class LInfNorm(Regularizer):
    # No inline class docstring: ``add_docstring`` receives the base-class
    # docstring with its summary sentence swapped (presumably it installs
    # that text as this class's docstring -- confirm against torch_ecg).

    __name__ = "LInfNorm"

    def eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> float:
        """Return ``coeff`` times the largest absolute value among all entries.

        ``coeff`` falls back to ``self.coeff`` when it is None. When there is
        no entry at all (``params`` is empty, or every tensor has zero
        elements) the largest absolute value is taken to be ``0.0``.
        """
        if coeff is None:
            coeff = self.coeff
        # Tensors with no elements are skipped because ``max()`` of an empty
        # tensor raises; ``default`` covers the case where nothing is left.
        largest = max(
            (p.data.abs().max().item() for p in params if p.data.numel() > 0),
            default=0.0,
        )
        return coeff * largest

    def prox_eval(self, params: Iterable[Parameter], coeff: Optional[float] = None) -> Iterable[Parameter]:
        """Not implemented for this regularizer.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("L-infinity norm is not implemented yet")