import torch
from collections import OrderedDict
import weakref
import warnings
from typing import Any, Tuple

__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"]


class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
        extra_dict (Union[dict, List[dict]]): An additional dictionary or list of
            dictionaries whose keys will be deleted when the same keys are
            removed from ``hooks_dict``.
    """

    id: int
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        self.extra_dict_ref: Tuple = ()
        if isinstance(extra_dict, dict):
            self.extra_dict_ref = (weakref.ref(extra_dict),)
        elif isinstance(extra_dict, list):
            self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

    def remove(self) -> None:
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        for ref in self.extra_dict_ref:
            extra_dict = ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]

    def __getstate__(self):
        if self.extra_dict_ref is None:
            return (self.hooks_dict_ref(), self.id)
        else:
            return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref))

    def __setstate__(self, state) -> None:
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3 or state[2] is None:
            self.extra_dict_ref = ()
        else:
            self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2])

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        self.remove()


def unserializable_hook(f):
    """
    Mark a function as an unserializable hook with this decorator.

    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.
    """
    f.__torch_unserializable__ = True
    return f


def warn_if_has_hooks(tensor):
    if tensor._backward_hooks:
        for k in tensor._backward_hooks:
            hook = tensor._backward_hooks[k]
            if not hasattr(hook, "__torch_unserializable__"):
                warnings.warn(f"backward hook {repr(hook)} on tensor will not be "
                              "serialized. If this is expected, you can "
                              "decorate the function with @torch.utils.hooks.unserializable_hook "
                              "to suppress this warning")
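
# Illustrative usage sketch (hypothetical names, not part of this module's API):
# the owner of a hooks dictionary creates a RemovableHandle over that dict and
# registers the hook under the handle's id; calling ``remove()`` later deletes
# that entry again.
#
#     my_hooks = OrderedDict()
#
#     @unserializable_hook            # decorated hooks are skipped by
#     def log_grad(grad):             # warn_if_has_hooks above
#         print(grad.norm())
#
#     handle = RemovableHandle(my_hooks)
#     my_hooks[handle.id] = log_grad  # register under the handle's id
#     ...
#     handle.remove()                 # drops my_hooks[handle.id] (and any extra_dict entries)
#
# The handle only holds weak references to the dictionaries, so it does not keep
# the owning object (e.g. an nn.Module) alive.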


class BackwardHook:
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
      - Ignoring non-Tensor inputs and replacing them with None before calling the user hook
      - Generating the proper Node to capture a set of Tensors' gradients
      - Linking the gradients captured for the outputs with the gradients captured for the input
      - Calling the user hook once both output and input gradients are available
    """

    def __init__(self, module, user_hooks, user_pre_hooks):
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks
        self.module = module

        self.grad_outputs = None
        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None

    def _pack_with_none(self, indices, values, size):
        # Scatter ``values`` into a tuple of length ``size`` at ``indices``, None elsewhere.
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val
        return tuple(res)

    def _unpack_none(self, indices, values):
        # Inverse of _pack_with_none: gather the entries of ``values`` at ``indices``.
        res = []
        for idx in indices:
            res.append(values[idx])
        return tuple(res)

    def _set_user_hook(self, grad_fn):
        def hook(grad_input, _):
            if self.grad_outputs is None:
                # This happens because the gradient in your nn.Module flows to
                # the Module's input without passing through the Module's
                # output, e.g. when you're doing double backward.
                return
            res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

            for user_hook in self.user_hooks:
                out = user_hook(self.module, res, self.grad_outputs)

                if out is None:
                    continue

                if len(out) != len(res):
                    raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                       f"got {len(out)}, but expected {len(res)}")

                res = out

            self.grad_outputs = None

            return self._unpack_none(self.input_tensors_index, res)

        grad_fn.register_hook(hook)

    def _apply_on_tensors(self, fn, args):
        # Can be used to apply the given function to the tensors contained in the
        # args. Will return updated args and the tensors indices
        tensors_idx = []
        tensors = []

        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not (requires_grad and torch.is_grad_enabled()):
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None
                    and t.grad_fn.name() == "BackwardHookFunctionBackward"]
        if len(grad_fns) == 0:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")
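
        # All tensors returned by one BackwardHookFunction.apply(...) share the same
        # "BackwardHookFunctionBackward" node, so registering the hook on grad_fns[0]
        # is enough to observe the gradients of every wrapped tensor.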
Please open " "an issue with a code sample to reproduce this.") fn(grad_fns[0]) arg_list = list(args) for idx, val in zip(tensors_idx, new_tensors): arg_list[idx] = val if type(args) is tuple: out = tuple(arg_list) else: out = type(args)(*arg_list) return out, tensors_idx def setup_input_hook(self, args): def fn(grad_fn): self._set_user_hook(grad_fn) res, input_idx = self._apply_on_tensors(fn, args) self.n_inputs = len(args) self.input_tensors_index = input_idx return res def setup_output_hook(self, args): def fn(grad_fn): def hook(_, grad_output): self.grad_outputs = self._pack_with_none(self.output_tensors_index, grad_output, self.n_outputs) if self.user_pre_hooks: expected_len = len(self.grad_outputs) for user_pre_hook in self.user_pre_hooks: hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs) if hook_grad_outputs is None: continue actual_len = len(hook_grad_outputs) if actual_len != expected_len: raise RuntimeError("Backward pre hook returned an invalid number of grad_output, " f"got {actual_len}, but expected {expected_len}") self.grad_outputs = hook_grad_outputs # We need to be able to clear self.grad_outputs but also return it local_grad_outputs = self.grad_outputs # Special case if no input required gradients, this hook should call the user # hook directly if self.input_tensors_index is None: grad_inputs = self._pack_with_none([], [], self.n_inputs) for user_hook in self.user_hooks: res = user_hook(self.module, grad_inputs, self.grad_outputs) if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)): raise RuntimeError("Backward hook for Modules where no input requires " "gradient should always return None or None for all gradients.") self.grad_outputs = None if local_grad_outputs is not None: assert self.output_tensors_index is not None # mypy return tuple(local_grad_outputs[i] for i in self.output_tensors_index) grad_fn.register_hook(hook) is_tuple = True if not isinstance(args, tuple): args = (args,) is_tuple = False res, output_idx = self._apply_on_tensors(fn, args) self.n_outputs = len(args) self.output_tensors_index = output_idx if not is_tuple: res = res[0] return res