44 lines
1.0 KiB
Python
44 lines
1.0 KiB
Python
|
from typing import Any, Dict, runtime_checkable, TypeVar
|
||
|
|
||
|
from typing_extensions import Protocol
|
||
|
|
||
|
|
||
|
# Public API of this module: the checkpointable-object protocol and the
# TypeVar bound to it.
__all__ = [
    "Stateful",
    "StatefulT",
]


@runtime_checkable
class Stateful(Protocol):
    """
    Stateful protocol for objects that can be checkpointed and restored.

    Conformance is structural: any object that defines both ``state_dict()``
    and ``load_state_dict()`` satisfies this protocol without inheriting from
    ``Stateful``. Because the class is decorated with ``@runtime_checkable``,
    ``isinstance(obj, Stateful)`` works at runtime. Such checks only confirm
    that both methods are present. They do not validate signatures or
    return types.
    """

    def state_dict(self) -> Dict[str, Any]:
        """
        Objects should return their state_dict representation as a dictionary.

        The output of this function will be checkpointed, and later restored in
        `load_state_dict()`.

        .. warning::
            Because of the inplace nature of restoring a checkpoint, this function
            is also called during `torch.distributed.checkpoint.load`.

        Returns:
            Dict: The object's state dict
        """
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore the object's state from the provided state_dict.

        Args:
            state_dict: The state dict to restore from, i.e. the checkpointed
                output of a previous ``state_dict()`` call.
        """
        ...


# Type variable restricted to ``Stateful`` implementations, so generic code
# can accept an implementation and return the same concrete type.
StatefulT = TypeVar(
    "StatefulT",
    bound=Stateful,
)