Source code for fl_sim.optimizers._register
"""
"""
import re
import warnings
from typing import Any, Dict, List, Optional, Type

import torch.optim as optim

# module-level registry mapping an optimizer name to its class
_built_in_optimizers: Dict[str, Type[optim.Optimizer]] = {}


def register_optimizer(name: Optional[str] = None, override: bool = True) -> Any:
    """Decorator to register a new optimizer.

    Parameters
    ----------
    name : str, optional
        Name under which to register the optimizer.
        If not specified, the class name with the trailing
        ``Optimizer`` suffix (if any) removed will be used.
    override : bool, default True
        Whether to override an existing optimizer
        registered under the same name.

    Returns
    -------
    callable
        The decorator, which registers the class and returns it unchanged.
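
    Examples
    --------
    A minimal usage sketch (``MyOptimizer`` is a hypothetical
    ``torch.optim.SGD`` subclass, shown here only for illustration):

    >>> @register_optimizer()
    ... class MyOptimizer(optim.SGD):
    ...     pass
    >>> "My" in list_optimizers()
    True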
"""

    def wrapper(cls_: Any) -> Any:
        if name is None:
            if hasattr(cls_, "__name__"):
                _name = cls_.__name__
            else:
                _name = cls_.__class__.__name__
            # strip the trailing "Optimizer" suffix, if any
            _name = re.sub("(?:Optimizer)?$", "", _name)
        else:
            _name = name
        assert issubclass(cls_, optim.Optimizer), f"{cls_} is not a valid optimizer"
        if _name in _built_in_optimizers:
            if override:
                _built_in_optimizers[_name] = cls_
            else:
                # raise ValueError(f"{_name} has already been registered")
                warnings.warn(f"{_name} has already been registered", RuntimeWarning)
        else:
            _built_in_optimizers[_name] = cls_
        return cls_

    return wrapper


def list_optimizers() -> List[str]:
    """List the names of all registered optimizers."""
    return list(_built_in_optimizers)


def get_optimizer(name: str) -> Type[optim.Optimizer]:
    """Get a registered optimizer class by name,
    with or without the trailing ``Optimizer`` suffix.
    """
    if name not in _built_in_optimizers:
        # retry with the "Optimizer" suffix stripped, since
        # classes are registered without it by default
        _name = re.sub("(?:Optimizer)?$", "", name)
    else:
        _name = name
    if _name not in _built_in_optimizers:
        raise ValueError(f"Optimizer {name} is not registered")
    return _built_in_optimizers[_name]
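

# A minimal end-to-end sketch of the registry API (assumes ``torch`` is
# installed); ``DemoOptimizer`` is a hypothetical class used only here.
if __name__ == "__main__":

    @register_optimizer()
    class DemoOptimizer(optim.SGD):
        """A trivial ``torch.optim.SGD`` subclass for demonstration."""

    # the "Optimizer" suffix is stripped at registration time, and
    # ``get_optimizer`` accepts the name with or without it
    print(list_optimizers())  # includes "Demo"
    assert get_optimizer("Demo") is DemoOptimizer
    assert get_optimizer("DemoOptimizer") is DemoOptimizer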