23 lines
587 B
Python
23 lines
587 B
Python
from typing import Optional, Tuple, Union
|
|
|
|
from torch import Tensor
|
|
|
|
from .optimizer import Optimizer, ParamsT
|
|
|
|
class Adam(Optimizer):
|
|
def __init__(
|
|
self,
|
|
params: ParamsT,
|
|
lr: Union[float, Tensor] = 1e-3,
|
|
betas: Tuple[float, float] = (0.9, 0.999),
|
|
eps: float = 1e-8,
|
|
weight_decay: float = 0,
|
|
amsgrad: bool = False,
|
|
*,
|
|
foreach: Optional[bool] = None,
|
|
maximize: bool = False,
|
|
capturable: bool = False,
|
|
differentiable: bool = False,
|
|
fused: Optional[bool] = None,
|
|
) -> None: ...
|