# Source code for fl_sim.optimizers._register

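"""
Name-based registry of ``torch.optim.Optimizer`` subclasses.

Classes are added with the :func:`register_optimizer` decorator and looked
up by name with :func:`get_optimizer`; :func:`list_optimizers` returns the
registered names.
"""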

import re
import warnings
from typing import Any, Dict, List, Optional, Type

import torch.optim as optim

# Module-level registry: registered name -> optimizer class.
_built_in_optimizers: Dict[str, Type[optim.Optimizer]] = {}


def register_optimizer(name: Optional[str] = None, override: bool = True) -> Any:
    """Decorator to register a new optimizer.

    Parameters
    ----------
    name : str, optional
        Name of the optimizer. If not specified, the class name with a
        trailing ``"Optimizer"`` removed will be used (e.g. ``SGDOptimizer``
        is registered as ``"SGD"``). A class named exactly ``Optimizer`` is
        registered under its full name rather than the empty string.
    override : bool, default True
        Whether to override the existing optimizer with the same name.
        If False and the name is already registered, a ``RuntimeWarning``
        is issued and the existing registration is kept.

    Returns
    -------
    The decorator. Applied to a class, it registers the class and returns
    it unchanged.

    Raises
    ------
    TypeError
        Raised by the decorator when the decorated object is not a subclass
        of ``torch.optim.Optimizer``.
    """

    def wrapper(cls_: Any) -> Any:
        # Validate before touching the registry. An explicit check (instead
        # of ``assert``) still runs when Python is started with ``-O``.
        if not (isinstance(cls_, type) and issubclass(cls_, optim.Optimizer)):
            raise TypeError(f"{cls_} is not a valid optimizer")

        if name is None:
            # Strip a trailing "Optimizer"; fall back to the full class name
            # so a class called just "Optimizer" is not registered as "".
            _name = re.sub(r"Optimizer$", "", cls_.__name__) or cls_.__name__
        else:
            _name = name

        if _name in _built_in_optimizers and not override:
            warnings.warn(f"{_name} has already been registered", RuntimeWarning)
        else:
            _built_in_optimizers[_name] = cls_
        return cls_

    return wrapper
def list_optimizers() -> List[str]:
    """Return the registered optimizer names, in the order they were first added."""
    return list(_built_in_optimizers)


def get_optimizer(name: str) -> Type[optim.Optimizer]:
    """Look up a registered optimizer class by name.

    Parameters
    ----------
    name : str
        Name of the optimizer. If ``name`` itself is not registered, a
        trailing ``"Optimizer"`` is stripped and the lookup is retried, so
        ``"SGDOptimizer"`` resolves to an optimizer registered as ``"SGD"``.

    Returns
    -------
    The registered optimizer class (not an instance).

    Raises
    ------
    ValueError
        If neither ``name`` nor its suffix-stripped form is registered.
    """
    if name in _built_in_optimizers:
        return _built_in_optimizers[name]
    # Retry with the same suffix-stripping register_optimizer applies to
    # class names when no explicit name is given.
    stripped = re.sub(r"Optimizer$", "", name)
    if stripped not in _built_in_optimizers:
        raise ValueError(f"Optimizer {name} is not registered")
    return _built_in_optimizers[stripped]