64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
from typing import Type
|
|
|
|
from torch import optim
|
|
from .functional_adadelta import _FunctionalAdadelta
|
|
from .functional_adagrad import _FunctionalAdagrad
|
|
from .functional_adam import _FunctionalAdam
|
|
from .functional_adamax import _FunctionalAdamax
|
|
from .functional_adamw import _FunctionalAdamW
|
|
from .functional_rmsprop import _FunctionalRMSprop
|
|
from .functional_rprop import _FunctionalRprop
|
|
from .functional_sgd import _FunctionalSGD
|
|
|
|
# Lookup table from a user-facing optimizer class (e.g. ``optim.Adam``) to the
# functional optimizer class defined in the ``distributed.optim`` package.
# Users keep passing the regular optimizer class; the functional
# implementation stays hidden behind the same API.
# Entries can be added at runtime with ``register_functional_optim``.
functional_optim_map: dict = {
    optim.Adagrad: _FunctionalAdagrad,  # Adagrad
    optim.Adam: _FunctionalAdam,  # Adam
    optim.AdamW: _FunctionalAdamW,  # AdamW
    optim.SGD: _FunctionalSGD,  # SGD
    optim.Adadelta: _FunctionalAdadelta,  # Adadelta
    optim.RMSprop: _FunctionalRMSprop,  # RMSprop
    optim.Rprop: _FunctionalRprop,  # Rprop
    optim.Adamax: _FunctionalAdamax,  # Adamax
}
|
|
|
|
|
|
def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.

    ``key`` and ``optim`` are user defined. Neither needs to be a
    :class:`torch.optim.Optimizer` (e.g. for custom optimizers); ``key`` only
    has to be hashable, since it is used as a dictionary key.

    If ``key`` is already present in ``functional_optim_map``, the existing
    entry is kept and this call does nothing (first registration wins). The
    built-in mappings therefore cannot be overwritten through this function.

    Args:
        key: Lookup key; :func:`as_functional_optim` later looks up its
            ``optim_cls`` argument in ``functional_optim_map`` with it.
        optim: Functional optimizer class to associate with ``key``.

    Example::

        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    # NOTE: the parameter ``optim`` shadows the module-level ``torch.optim``
    # import inside this function. The name is kept so existing keyword
    # callers (``optim=...``) keep working.
    # ``setdefault`` inserts only when ``key`` is absent, so it never replaces
    # an existing mapping.
    functional_optim_map.setdefault(key, optim)
|
|
|
|
|
|
def as_functional_optim(optim_cls: Type, *args, **kwargs):
    """
    Create an instance of the functional counterpart of ``optim_cls``.

    ``optim_cls`` is looked up in ``functional_optim_map``, which also contains
    any entries added via :func:`register_functional_optim`. The mapped class is
    then instantiated through ``_create_functional_optim``, which passes it an
    empty parameter list.

    Args:
        optim_cls: Lookup key, typically a :class:`torch.optim.Optimizer`
            subclass such as ``torch.optim.Adam``.
        *args: Positional arguments forwarded to the functional optimizer's
            constructor after the (empty) parameter list.
        **kwargs: Keyword arguments forwarded to that constructor.

    Returns:
        The functional optimizer instance built from the class mapped to
        ``optim_cls``.

    Raises:
        ValueError: If ``optim_cls`` has no entry in ``functional_optim_map``.
            The original ``KeyError`` is chained as ``__cause__``.
    """
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e

    return _create_functional_optim(functional_cls, *args, **kwargs)
|
|
|
|
|
|
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
|
|
return functional_optim_cls(
|
|
[],
|
|
*args,
|
|
**kwargs,
|
|
_allow_empty_param_list=True,
|
|
)
|