266 lines
12 KiB
Python
266 lines
12 KiB
Python
import itertools
|
|
import warnings
|
|
from typing import Protocol, Optional, Type, Any
|
|
|
|
import torch
|
|
from ..parameter import is_lazy
|
|
|
|
__all__ = ['LazyModuleMixin']
|
|
|
|
class _LazyProtocol(Protocol):
    """Structural type listing the members ``LazyModuleMixin`` uses on ``self``.

    This class is used to avoid errors with mypy checks for the attributes in a mixin.
    https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes

    Every member is a stub (``...``); the protocol exists for static typing
    only and carries no behavior of its own.
    """

    def _register_load_state_dict_pre_hook(self, hook):
        ...

    def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False):
        ...

    def _lazy_load_hook(
            self, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        ...

    def _get_name(self):
        ...

    # Mirrors ``LazyModuleMixin._infer_parameters``, which is registered as a
    # forward pre-hook with ``with_kwargs=True`` and so receives both the
    # positional args and the kwargs of the forward call.
    def _infer_parameters(self, module, args, kwargs=None):
        ...

    @property
    def _parameters(self):
        ...

    @property
    def _buffers(self):
        ...

    @property
    def _non_persistent_buffers_set(self):
        ...

    @property
    def _load_hook(self):
        ...

    @property
    def _initialize_hook(self):
        ...
|
|
|
|
|
|
class LazyModuleMixin:
    r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules".

    .. warning::
        Lazy modules are an experimental new feature under active development,
        and their API is likely to change.

    Modules that lazily initialize parameters, or "lazy modules",
    derive the shapes of their parameters from the first input(s)
    to their forward method. Until that first forward they contain
    :class:`torch.nn.UninitializedParameter` s that should not be accessed
    or used, and afterward they contain regular :class:`torch.nn.Parameter` s.
    Lazy modules are convenient since they don't require computing some
    module arguments, like the :attr:`in_features` argument of a
    typical :class:`torch.nn.Linear`.

    After construction, networks with lazy modules should first
    be converted to the desired dtype and placed on the expected device.
    This is because lazy modules only perform shape inference so the usual dtype
    and device placement behavior applies.
    The lazy modules should then perform "dry runs" to initialize all the components in the module.
    These "dry runs" send inputs of the correct size, dtype, and device through
    the network and to each one of its lazy modules. After this the network can be used as usual.

    >>> # xdoctest: +SKIP
    >>> class LazyMLP(torch.nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.fc1 = torch.nn.LazyLinear(10)
    ...         self.relu1 = torch.nn.ReLU()
    ...         self.fc2 = torch.nn.LazyLinear(1)
    ...         self.relu2 = torch.nn.ReLU()
    ...
    ...     def forward(self, input):
    ...         x = self.relu1(self.fc1(input))
    ...         y = self.relu2(self.fc2(x))
    ...         return y
    >>> # constructs a network with lazy modules
    >>> lazy_mlp = LazyMLP()
    >>> # transforms the network's device and dtype
    >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
    >>> lazy_mlp = lazy_mlp.cuda().double()
    >>> lazy_mlp
    LazyMLP(
      (fc1): LazyLinear(in_features=0, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): LazyLinear(in_features=0, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # performs a dry run to initialize the network's lazy modules
    >>> # (the input matches the network's device and dtype)
    >>> lazy_mlp(torch.ones(10, 10).cuda().double())
    >>> # after initialization, LazyLinear modules become regular Linear modules
    >>> lazy_mlp
    LazyMLP(
      (fc1): Linear(in_features=10, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): Linear(in_features=10, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # attaches an optimizer, since parameters can now be used as usual
    >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)

    A final caveat when using lazy modules is that the order of initialization of a network's
    parameters may change, since the lazy modules are always initialized after other modules.
    For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
    first and then a regular :class:`torch.nn.Linear` second, the second module would be
    initialized on construction and the first module would be initialized during the first dry run.
    This can cause the parameters of a network using lazy modules to be initialized differently
    than the parameters of a network without lazy modules as the order of parameter initializations,
    which often depends on a stateful random number generator, is different.
    Check :doc:`/notes/randomness` for more details.

    Lazy modules can be serialized with a state dict like other modules. For example:

    >>> lazy_mlp = LazyMLP()
    >>> # The state dict shows the uninitialized parameters
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight', Uninitialized parameter),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight', Uninitialized parameter),
                 ('fc2.bias', tensor([0.0019]))])


    Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize
    initialized LazyModules and they will remain initialized)


    >>> full_mlp = LazyMLP()
    >>> # Dry run to initialize another module
    >>> full_mlp.forward(torch.ones(10, 1))
    >>> # Load an initialized state into a lazy module
    >>> lazy_mlp.load_state_dict(full_mlp.state_dict())
    >>> # The state dict now holds valid values
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight',
                  tensor([[-0.3837],
                          [ 0.0907],
                          [ 0.6708],
                          [-0.5223],
                          [-0.9028],
                          [ 0.2851],
                          [-0.4537],
                          [ 0.6813],
                          [ 0.5766],
                          [-0.8678]])),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight',
                  tensor([[ 0.1320,  0.2938,  0.0679,  0.2793,  0.1088, -0.1795, -0.2301,  0.2807,
                            0.2479,  0.1091]])),
                 ('fc2.bias', tensor([0.0019]))])

    Note, however, that the loaded parameters will not be replaced when doing a "dry run" if they are initialized
    when the state is loaded. This prevents using initialized modules in different contexts.
    """

    # modules inheriting from this will change their __class__ to the specified
    # one after they are fully initialized
    cls_to_become: Optional[Type[Any]] = None

    def __init__(self: "_LazyProtocol", *args, **kwargs):
        """Forward all arguments to the next ``__init__`` and install the lazy hooks.

        Registers ``_lazy_load_hook`` as a load-state-dict pre-hook and
        ``_infer_parameters`` as a forward pre-hook (receiving kwargs), keeping
        both handles so they can be removed once initialization is complete.
        Emits a warning that the feature is still under development.
        """
        # Mypy doesnt like this super call in a mixin
        super().__init__(*args, **kwargs)  # type: ignore[misc]
        self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
        self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters, with_kwargs=True)
        warnings.warn('Lazy modules are a new feature under heavy development '
                      'so changes to the API or functionality can happen at any moment.')

    def _save_to_state_dict(self: "_LazyProtocol", destination, prefix, keep_vars):
        """Write this module's parameters and persistent buffers into ``destination``.

        Args:
            destination: mapping that receives ``prefix + name -> tensor`` entries.
            prefix: string prepended to every parameter/buffer name.
            keep_vars: when true, tensors are stored as-is instead of detached.

        Lazy (uninitialized) entries are always stored as-is and never detached.
        """
        # This should be ideally implemented as a hook,
        # but we should override `detach` in the UninitializedParameter to return itself
        # which is not clean
        for name, param in self._parameters.items():
            if param is not None:
                if not (is_lazy(param) or keep_vars):
                    param = param.detach()
                destination[prefix + name] = param
        for name, buf in self._buffers.items():
            # Non-persistent buffers are excluded from the state dict.
            if buf is not None and name not in self._non_persistent_buffers_set:
                if not (is_lazy(buf) or keep_vars):
                    buf = buf.detach()
                destination[prefix + name] = buf

    def _lazy_load_hook(
            self: "_LazyProtocol", state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        """load_state_dict pre-hook function for lazy buffers and parameters.

        The purpose of this hook is to adjust the current state and/or
        ``state_dict`` being loaded so that a module instance serialized in
        both un/initialized state can be deserialized onto both un/initialized
        module instance.
        See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
        for the details of the hook specification.
        """
        for name, param in itertools.chain(self._parameters.items(), self._buffers.items()):
            key = prefix + name
            if param is None or key not in state_dict:
                continue
            input_param = state_dict[key]
            # The current parameter is not initialized but the one being loaded is:
            # materialize the uninitialized one with the incoming shape so the
            # regular loading logic can copy the values into it.
            if is_lazy(param) and not is_lazy(input_param):
                with torch.no_grad():
                    param.materialize(input_param.shape)

    def initialize_parameters(self: "_LazyProtocol", *args, **kwargs):
        r"""Initialize parameters according to the input batch properties.

        This adds an interface to isolate parameter initialization from the
        forward pass when doing parameter shape inference.

        Raises:
            NotImplementedError: always; subclasses are expected to override this.
        """
        raise NotImplementedError(f'initialize_parameters is not implemented for {self.__class__.__name__}')

    def has_uninitialized_params(self: "_LazyProtocol"):
        r"""Check if a module has parameters that are not initialized.

        Returns:
            ``True`` if any parameter or buffer of this module is still lazy,
            ``False`` otherwise.
        """
        # This is to avoid the JIT to track this parameter and force
        # custom modules __setstate__ to add it
        params = self._parameters.values()
        buffers = self._buffers.values()
        return any(is_lazy(param) for param in itertools.chain(params, buffers))

    def _infer_parameters(self: "_LazyProtocol", module, args, kwargs=None):
        r"""Infers the size and initializes the parameters according to the provided input batch.

        Forward pre-hook: calls ``module.initialize_parameters`` with the
        arguments of the forward call, checks that no parameter or buffer is
        left uninitialized, then removes and deletes both lazy hooks so this
        runs only once. Finally, if ``cls_to_become`` is set, the module's
        ``__class__`` is switched to it.

        Args:
            module: the module being initialized.
            args: positional arguments of the forward call.
            kwargs: keyword arguments of the forward call; ``None`` is treated
                as no keyword arguments.

        Raises:
            RuntimeError: if the module still has uninitialized parameters
                after ``initialize_parameters`` returned.
        """
        kwargs = kwargs or {}
        module.initialize_parameters(*args, **kwargs)
        if module.has_uninitialized_params():
            raise RuntimeError(f'module {self._get_name()} has not been fully initialized')
        module._initialize_hook.remove()
        module._load_hook.remove()
        delattr(module, '_initialize_hook')
        delattr(module, '_load_hook')
        if module.cls_to_become is not None:
            module.__class__ = module.cls_to_become

    def _replicate_for_data_parallel(self: "_LazyProtocol"):
        """Refuse replication of a lazy module.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError('Modules with uninitialized parameters can\'t be used with `DataParallel`. '
                           'Run a dummy forward pass to correctly initialize the modules')
|