from typing import Any

import torch
from torch.utils._contextlib import (
    _DecoratorContextManager,
    _NoParamDecoratorContextManager,
    F,
)

__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
    "inference_mode",
    "set_multithreading_enabled",
]


class no_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.
    There is an exception! All factory functions, or functions that create
    a new Tensor and take a requires_grad kwarg, will NOT be affected by
    this mode.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
        >>> @torch.no_grad
        ... def tripler(x):
        ...     return x * 3
        >>> z = tripler(x)
        >>> z.requires_grad
        False
        >>> # factory function exception
        >>> with torch.no_grad():
        ...     a = torch.nn.Parameter(torch.rand(10))
        >>> a.requires_grad
        True
    """

    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)
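

# Illustrative sketch (not part of this module's API): the grad mode set by
# `no_grad` is thread local, so a worker thread started inside the block still
# records gradients by default. The helper name below is hypothetical.
def _example_no_grad_is_thread_local() -> None:
    import threading

    x = torch.ones(2, requires_grad=True)
    seen = {}

    def worker() -> None:
        # This thread is unaffected by the main thread's `no_grad` block.
        seen["worker"] = (x * 2).requires_grad

    with torch.no_grad():
        seen["main"] = (x * 2).requires_grad  # False: grad disabled here
        t = threading.Thread(target=worker)
        t.start()
        t.join()

    assert seen["main"] is False
    assert seen["worker"] is True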


class enable_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        enable_grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
        >>> @torch.enable_grad
        ... def tripler(x):
        ...     return x * 3
        >>> with torch.no_grad():
        ...     z = tripler(x)
        >>> z.requires_grad
        True

    """

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)


class set_grad_enabled(_DecoratorContextManager):
    r"""Context-manager that sets gradient calculation on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False

    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch.is_grad_enabled()
        self.mode = mode
        torch._C._set_grad_enabled(mode)

    def __call__(self, orig_func: F) -> F:
        # Constructing this object already switched the global grad mode, so
        # restore the previous mode before wrapping the decorated function.
        torch._C._set_grad_enabled(self.prev)
        return super().__call__(orig_func)

    def __enter__(self) -> None:
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)

    def clone(self) -> "set_grad_enabled":
        r"""
        Create a copy of this context manager.
        """
        return self.__class__(self.mode)
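

# Illustrative sketch (hypothetical helper, not part of this module's API):
# because `set_grad_enabled` takes a bool, it is convenient for code paths
# shared between training and evaluation.
def _example_conditional_grad(model: "torch.nn.Module", x: torch.Tensor, train: bool) -> torch.Tensor:
    # Gradients are recorded only when `train` is True; otherwise this behaves
    # like running under `torch.no_grad()`.
    with torch.set_grad_enabled(train):
        return model(x)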


class inference_mode(_DecoratorContextManager):
    r"""Context-manager that enables or disables inference mode.

    InferenceMode is a context manager analogous to :class:`~no_grad`
    to be used when you are certain your operations will have no interactions
    with autograd (e.g., model training). Code run under this mode gets better
    performance by disabling view tracking and version counter bumps. Note that
    unlike some other mechanisms that locally enable or disable grad,
    entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        Inference mode is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Args:
        mode (bool or function): Either a boolean flag whether to enable or
            disable inference mode or a Python function to decorate with
            inference mode enabled.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> import torch
        >>> x = torch.ones(1, 2, 3, requires_grad=True)
        >>> with torch.inference_mode():
        ...     y = x * x
        >>> y.requires_grad
        False
        >>> # xdoctest: +SKIP("want string isn't quite right")
        >>> y._version
        Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
        RuntimeError: Inference tensors do not track version counter.
        >>> @torch.inference_mode()
        ... def func(x):
        ...     return x * x
        >>> out = func(x)
        >>> out.requires_grad
        False
        >>> @torch.inference_mode
        ... def doubler(x):
        ...     return x * 2
        >>> out = doubler(x)
        >>> out.requires_grad
        False

    """

    def __init__(self, mode: bool = True) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.mode = mode

    def __new__(cls, mode=True):
        if isinstance(mode, bool):
            return super().__new__(cls)
        # `mode` is a function: support bare `@torch.inference_mode` usage by
        # decorating it with a default-constructed instance.
        return cls()(mode)

    def __enter__(self) -> None:
        self._inference_mode_context = torch._C._InferenceMode(self.mode)
        self._inference_mode_context.__enter__()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)

    def clone(self) -> "inference_mode":
        r"""
        Create a copy of this context manager.
        """
        return self.__class__(self.mode)
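

# Illustrative sketch (hypothetical helper, not part of this module's API):
# tensors created under `inference_mode` are marked as inference tensors and
# cannot later be used in computations that autograd needs to record. Assumes
# `torch.is_inference_mode_enabled()` and `Tensor.is_inference()` are
# available, as in recent PyTorch releases.
def _example_inference_mode_tensors() -> None:
    x = torch.ones(3, requires_grad=True)
    with torch.inference_mode():
        assert torch.is_inference_mode_enabled()
        y = x * 2
    assert y.is_inference()
    assert not y.requires_grad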


def _enter_inference_mode(mode):
    # Build and enter an RAII-style inference-mode guard; the caller is
    # responsible for passing it back to `_exit_inference_mode`.
    mode_context = torch._C._InferenceMode(mode)
    mode_context.__enter__()
    return mode_context


def _exit_inference_mode(mode):
    # `mode` is the guard object returned by `_enter_inference_mode`.
    mode.__exit__(None, None, None)
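

# Illustrative sketch (hypothetical usage, not part of this module's API): the
# two helpers above are a functional counterpart to the `inference_mode`
# context manager; the guard returned by `_enter_inference_mode` must be
# passed back to `_exit_inference_mode`.
def _example_functional_inference_mode(x: torch.Tensor) -> torch.Tensor:
    guard = _enter_inference_mode(True)
    try:
        return x * 2
    finally:
        _exit_inference_mode(guard)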


class set_multithreading_enabled(_DecoratorContextManager):
    r"""Context-manager that sets multithreaded backwards on or off.

    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
                     (``False``).

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_multithreading_enabled()
        torch._C._set_multithreading_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        # The mode is already applied in __init__; nothing more to do on entry.
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_multithreading_enabled(self.prev)

    def clone(self) -> "set_multithreading_enabled":
        r"""
        Create a copy of this context manager.
        """
        return self.__class__(self.mode)
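

# Illustrative sketch (hypothetical helper, not part of this module's API):
# temporarily run the backward pass single-threaded, e.g. to simplify
# debugging, and restore the previous setting afterwards.
def _example_single_threaded_backward(loss: torch.Tensor) -> None:
    # `loss` is assumed to be a scalar tensor that requires grad.
    with set_multithreading_enabled(False):
        loss.backward()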


class _force_original_view_tracking(_DecoratorContextManager):
    r"""Context-manager that sets whether or not to always enable view-replay in autograd.

    ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    When a tensor view is mutated, the autograd engine needs to decide whether
    to regenerate the "updated view" by replaying the chain of views from the updated base,
    or with a single call to as_strided.

    If set_view_replay_enabled is set to True, then autograd will always use view replay.
    Otherwise, it will fall back to its existing logic.

    Args:
        mode (bool): Flag whether to enable view-replay (``True``), or disable
                     (``False``).

    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_view_replay_enabled()
        torch._C._set_view_replay_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        # The mode is already applied in __init__; nothing more to do on entry.
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_view_replay_enabled(self.prev)

    def clone(self):
        return self.__class__(self.mode)
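

# Illustrative sketch (hypothetical helper, not part of this module's API):
# while the block is active, autograd reconstructs updated views by replaying
# the chain of view ops instead of falling back to as_strided.
def _example_force_view_replay() -> None:
    base = torch.zeros(4, requires_grad=True).clone()  # non-leaf, grad-tracking
    view = base[:2]
    with _force_original_view_tracking(True):
        view.mul_(2)  # the "updated view" is rebuilt via view replay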


class _unsafe_preserve_version_counter(_DecoratorContextManager):
    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.

    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
    (even the ones not touched directly by the context manager)!

    Ordinarily, autograd will track mutations to tensors by incrementing its `._version` attribute.
    This is generally important for correctness: for example, mutating a tensor that autograd has saved
    for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
    and error out in this situation.

    However, there are rare instances where it might be useful to hide mutations from autograd. For example,
    if a tensor is very large, you might want to free its memory by storing its contents elsewhere and
    re-populate the tensor right before it is needed by autograd.

    Args:
        tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor
        self.prev_version = tensor._version

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args) -> None:
        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
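

# Illustrative sketch (hypothetical helper, not part of this module's API) of
# the pattern described in the docstring above: re-populate an offloaded
# tensor in place while hiding the mutation from autograd's version-counter
# check.
def _example_restore_offloaded(t: torch.Tensor, saved: torch.Tensor) -> None:
    # Copy the saved contents back into `t` without bumping `t._version`, so
    # autograd's saved-tensor version check still passes afterwards.
    with torch.no_grad(), _unsafe_preserve_version_counter(t):
        t.copy_(saved)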