153 lines
5.8 KiB
Python
153 lines
5.8 KiB
Python
# Extra utilities for working with context managers that should have been
|
|
# in the standard library but are not
|
|
|
|
import functools
|
|
import inspect
|
|
import warnings
|
|
import sys
|
|
from typing import Any, Callable, TypeVar, cast
|
|
|
|
# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
|
|
# 'no_grad' and 'enable_grad').
|
|
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
|
|
# Alias for a callable that accepts any arguments and returns any value.
FuncType = Callable[..., Any]
|
|
# Type variable restricted to FuncType, so that annotating a decorator as
# `(F) -> F` keeps the decorated callable's own type for type checkers.
F = TypeVar('F', bound=FuncType)
|
|
|
|
|
|
def _wrap_generator(ctx_factory, func):
    """
    Wrap each generator invocation with the context manager factory.

    The input should be a function that returns a context manager,
    not a context manager itself, to handle one-shot context managers.

    Args:
        ctx_factory: zero-argument callable returning a context manager; it
            is called anew around every resumption of the wrapped generator.
        func: the generator function to wrap.

    Returns:
        A generator function carrying ``func``'s metadata (via
        ``functools.wraps``). It relays ``send``/``throw``/``close`` to the
        generator created by ``func`` and only holds a context from
        ``ctx_factory`` while that generator is actually running.
    """
    @functools.wraps(func)
    def generator_context(*args, **kwargs):
        gen = func(*args, **kwargs)

        # Generators are suspended and unsuspended at `yield`, hence we
        # make sure the context is entered every time the execution flow
        # returns into the wrapped generator and exited when it returns
        # through our `yield` to our caller (see PR #49017).
        try:
            # Issuing `None` to a generator fires it up
            with ctx_factory():
                response = gen.send(None)

            while True:
                try:
                    # Forward the response to our caller and get its next request
                    request = yield response

                except GeneratorExit:
                    # Inform the still active generator about its imminent closure
                    with ctx_factory():
                        gen.close()
                    raise

                except BaseException as exc:
                    # Propagate the exception thrown at us by the caller.
                    # The instance is passed on its own because the
                    # (type, value, traceback) signature of `throw` is
                    # deprecated since Python 3.12.
                    with ctx_factory():
                        response = gen.throw(exc)

                else:
                    # Pass the last request to the generator and get its response
                    with ctx_factory():
                        response = gen.send(request)

        # We let the exceptions raised above by the generator's `.throw` or
        # `.send` methods bubble up to our caller, except for StopIteration
        except StopIteration as e:
            # The generator informed us that it is done: take whatever its
            # returned value (if any) was and indicate that we're done too
            # by returning it (see docs for python's return-statement).
            return e.value

    return generator_context
|
|
|
|
|
|
def context_decorator(ctx, func):
    """
    Run every call of ``func`` inside a context manager.

    Comparable to contextlib.ContextDecorator, with these differences:

    1. It wraps rather than inherits, so it also works with context
       managers implemented in C, which cannot easily inherit from
       Python classes
    2. Generator functions are wrapped in the intuitive way
       (c.f. https://bugs.python.org/issue37743)
    3. Wrapping a class is an error, because it is ambiguous whether
       only the constructor was meant to be wrapped

    ``ctx`` is either a context manager (it must then be multi-shot, i.e.
    directly usable for several invocations) or a callable that produces a
    context manager. An object that is both callable and has ``__enter__``
    is rejected as ambiguous.
    """
    assert not callable(ctx) or not hasattr(ctx, '__enter__'), (
        f"Passed in {ctx} is both callable and also a valid context manager "
        "(has __enter__), making it ambiguous which interface to use. If you "
        "intended to pass a context manager factory, rewrite your call as "
        "context_decorator(lambda: ctx()); if you intended to pass a context "
        "manager directly, rewrite your call as context_decorator(lambda: ctx)"
    )

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type. "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    # Normalise `ctx` to a zero-argument factory: a callable is used as is,
    # a plain context manager is handed back unchanged on every call.
    ctx_factory = ctx if callable(ctx) else (lambda: ctx)

    if inspect.isgeneratorfunction(func):
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        # The `return` stays inside the `with` block so that the context is
        # exited on the way out of the call.
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context
|
|
|
|
|
|
class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator.

    Subclasses provide ``__enter__`` and ``__exit__``. Calling an instance on
    a function returns that function wrapped by ``context_decorator``, with
    ``self.clone`` as the context manager factory.
    """

    def __call__(self, orig_func: F) -> F:
        """Return ``orig_func`` wrapped by ``context_decorator``.

        Decorating a class is deprecated: a warning is issued and a function
        that forwards its arguments to the class is wrapped in its place.
        """
        if inspect.isclass(orig_func):
            # stacklevel=2 attributes the warning to the caller of
            # `__call__` instead of to this line.
            warnings.warn(
                "Decorating classes is deprecated and will be disabled in "
                "future versions. You should only decorate functions or methods. "
                "To preserve the current behavior of class decoration, you can "
                "directly decorate the `__init__` method and nothing else.",
                stacklevel=2,
            )
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func

        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        """Enter the context; must be implemented by subclasses."""
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Exit the context; must be implemented by subclasses."""
        raise NotImplementedError

    def clone(self):
        """Return a new instance of this object's class, built without arguments."""
        # override this method if your child class takes __init__ parameters
        return self.__class__()
|
|
|
|
|
|
class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """Decorator context manager that can also be applied without parentheses."""

    def __new__(cls, orig_func=None):
        # An argument was handed to the constructor (as happens with a bare
        # `@cls`): build a regular instance and use it to decorate it.
        if orig_func is not None:
            return cls()(orig_func)
        # No argument: ordinary instantiation.
        return super().__new__(cls)
|