# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFDecorator-aware replacements for the inspect module."""

import collections
import functools
import inspect as _inspect

import tensorflow.compat.v2 as tf

# `inspect.ArgSpec` / `inspect.FullArgSpec` are not present on every Python
# version, so fall back to structurally identical namedtuples when missing.
if hasattr(_inspect, "ArgSpec"):
    ArgSpec = _inspect.ArgSpec
else:
    ArgSpec = collections.namedtuple(
        "ArgSpec",
        [
            "args",
            "varargs",
            "keywords",
            "defaults",
        ],
    )

if hasattr(_inspect, "FullArgSpec"):
    FullArgSpec = _inspect.FullArgSpec
else:
    FullArgSpec = collections.namedtuple(
        "FullArgSpec",
        [
            "args",
            "varargs",
            "varkw",
            "defaults",
            "kwonlyargs",
            "kwonlydefaults",
            "annotations",
        ],
    )


def _convert_maybe_argspec_to_fullargspec(argspec):
    """Returns `argspec` as a `FullArgSpec`.

    Args:
      argspec: Either a `FullArgSpec` or an `ArgSpec`.

    Returns:
      `argspec` itself when it already is a `FullArgSpec`; otherwise a new
      `FullArgSpec` whose `varkw` is taken from `argspec.keywords` and whose
      keyword-only fields and annotations are empty.
    """
    if isinstance(argspec, FullArgSpec):
        return argspec
    return FullArgSpec(
        args=argspec.args,
        varargs=argspec.varargs,
        varkw=argspec.keywords,
        defaults=argspec.defaults,
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={},
    )


if hasattr(_inspect, "getfullargspec"):
    _getfullargspec = _inspect.getfullargspec

    def _getargspec(target):
        """A python3 version of getargspec.

        Calls `getfullargspec` and assigns args, varargs,
        varkw, and defaults to a python 2/3 compatible `ArgSpec`.

        The parameter name 'varkw' is changed to 'keywords' to fit the
        `ArgSpec` struct.

        Args:
          target: the target object to inspect.

        Returns:
          An ArgSpec with args, varargs, keywords, and defaults parameters
          from FullArgSpec.
        """
        fullargspecs = getfullargspec(target)
        argspecs = ArgSpec(
            args=fullargspecs.args,
            varargs=fullargspecs.varargs,
            keywords=fullargspecs.varkw,
            defaults=fullargspecs.defaults,
        )
        return argspecs

else:
    _getargspec = _inspect.getargspec

    def _getfullargspec(target):
        """A python2 version of getfullargspec.

        Args:
          target: the target object to inspect.

        Returns:
          A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
        """
        return _convert_maybe_argspec_to_fullargspec(getargspec(target))


def currentframe():
    """TFDecorator-aware replacement for inspect.currentframe."""
    # Entry 0 of the stack is this function's own frame; entry 1 is the
    # caller's, and element 0 of that entry is the frame object.
    return _inspect.stack()[1][0]


def getargspec(obj):
    """TFDecorator-aware replacement for `inspect.getargspec`.

    Note: `getfullargspec` is recommended as the python 2/3 compatible
    replacement for this function.

    Args:
      obj: A function, partial function, or callable object, possibly
        decorated.

    Returns:
      The `ArgSpec` that describes the signature of the outermost decorator
      that changes the callable's signature, or the `ArgSpec` that describes
      the object if not decorated.

    Raises:
      ValueError: When callable's signature can not be expressed with
        ArgSpec.
      TypeError: For objects of unsupported types.
    """
    if isinstance(obj, functools.partial):
        return _get_argspec_for_partial(obj)

    decorators, target = tf.__internal__.decorator.unwrap(obj)

    # The first decorator that carries an argspec wins.
    spec = next(
        (
            d.decorator_argspec
            for d in decorators
            if d.decorator_argspec is not None
        ),
        None,
    )
    if spec:
        return spec

    try:
        # Python3 will handle most callables here (not partial).
        return _getargspec(target)
    except TypeError:
        pass

    if isinstance(target, type):
        try:
            return _getargspec(target.__init__)
        except TypeError:
            pass

        try:
            return _getargspec(target.__new__)
        except TypeError:
            pass

    # The `type(target)` ensures that if a class is received we don't return
    # the signature of its __call__ method.
    return _getargspec(type(target).__call__)


def _get_argspec_for_partial(obj):
    """Implements `getargspec` for `functools.partial` objects.

    Args:
      obj: The `functools.partial` object

    Returns:
      An `inspect.ArgSpec`

    Raises:
      ValueError: When callable's signature can not be expressed with
        ArgSpec.
    """
    # When callable is a functools.partial object, we construct its ArgSpec
    # with following strategy:
    # - If callable partial contains default value for positional arguments
    #   (ie. object.args), then final ArgSpec doesn't contain those positional
    #   arguments.
    # - If callable partial contains default value for keyword arguments
    #   (ie. object.keywords), then we merge them with wrapped target. Default
    #   values from callable partial takes precedence over those from wrapped
    #   target.
    #
    # However, there is a case where it is impossible to construct a valid
    # ArgSpec. Python requires arguments that have no default values must be
    # defined before those with default values. ArgSpec structure is only
    # valid when this presumption holds true because default values are
    # expressed as a tuple of values without keywords and they are always
    # assumed to belong to last K arguments where K is number of default
    # values present.
    #
    # Since functools.partial can give default value to any argument, this
    # presumption may no longer hold in some cases. For example:
    #
    # def func(m, n):
    #   return 2 * m + n
    # partialed = functools.partial(func, m=1)
    #
    # This example will result in m having a default value but n doesn't. This
    # is usually not allowed in Python and can not be expressed in ArgSpec
    # correctly.
    #
    # Thus, we must detect cases like this by finding first argument with
    # default value and ensures all following arguments also have default
    # values. When this is not true, a ValueError is raised.

    n_prune_args = len(obj.args)
    partial_keywords = obj.keywords or {}

    args, varargs, keywords, defaults = getargspec(obj.func)

    # Pruning first n_prune_args arguments.
    args = args[n_prune_args:]

    # Partial function may give default value to any argument, therefore
    # length of default value list must be len(args) to allow each argument to
    # potentially be given a default value.
    no_default = object()
    all_defaults = [no_default] * len(args)

    if defaults:
        all_defaults[-len(defaults) :] = defaults

    # Fill in default values provided by partial function in all_defaults.
    for kw, default in partial_keywords.items():
        if kw in args:
            idx = args.index(kw)
            all_defaults[idx] = default
        elif not keywords:
            raise ValueError(
                "Function does not have **kwargs parameter, but "
                "contains an unknown partial keyword."
            )

    # Find first argument with default value set.
    first_default = next(
        (idx for idx, x in enumerate(all_defaults) if x is not no_default),
        None,
    )

    # If no default values are found, return ArgSpec with defaults=None.
    if first_default is None:
        return ArgSpec(args, varargs, keywords, None)

    # Checks if all arguments have default value set after first one.
    invalid_default_values = [
        args[i]
        for i, j in enumerate(all_defaults)
        if j is no_default and i > first_default
    ]

    if invalid_default_values:
        raise ValueError(
            f"Some arguments {invalid_default_values} do not have "
            "default value, but they are positioned after those with "
            "default values. This can not be expressed with ArgSpec."
        )

    return ArgSpec(
        args, varargs, keywords, tuple(all_defaults[first_default:])
    )


def getfullargspec(obj):
    """TFDecorator-aware replacement for `inspect.getfullargspec`.

    This wrapper emulates `inspect.getfullargspec` in Python2.

    Args:
      obj: A callable, possibly decorated.

    Returns:
      The `FullArgSpec` that describes the signature of
      the outermost decorator that changes the callable's signature. If the
      callable is not decorated, `inspect.getfullargspec()` will be called
      directly on the callable.
    """
    decorators, target = tf.__internal__.decorator.unwrap(obj)

    for d in decorators:
        if d.decorator_argspec is not None:
            return _convert_maybe_argspec_to_fullargspec(d.decorator_argspec)
    return _getfullargspec(target)


def getcallargs(*func_and_positional, **named):
    """TFDecorator-aware replacement for inspect.getcallargs.

    Args:
      *func_and_positional: A callable, possibly decorated, followed by any
        positional arguments that would be passed to `func`.
      **named: The named argument dictionary that would be passed to `func`.

    Returns:
      A dictionary mapping `func`'s named arguments to the values they would
      receive if `func(*positional, **named)` were called.

    `getcallargs` will use the argspec from the outermost decorator that
    provides it. If no attached decorators modify argspec, the final unwrapped
    target's argspec will be used.
    """
    func = func_and_positional[0]
    positional = func_and_positional[1:]
    argspec = getfullargspec(func)
    call_args = named.copy()

    # Look up the bound instance with identity checks rather than truthiness.
    # An instance may legitimately be falsy (e.g. it defines `__len__`
    # returning 0 or `__bool__` returning False) or may raise when coerced to
    # bool; a truthiness test would then drop `self` and shift every
    # positional argument by one.
    this = getattr(func, "im_self", None)
    if this is None:
        this = getattr(func, "__self__", None)
    if ismethod(func) and this is not None:
        positional = (this,) + positional

    # Positional values fill, in order, the named parameters that were not
    # already supplied by keyword.
    remaining_positionals = [
        arg for arg in argspec.args if arg not in call_args
    ]
    call_args.update(dict(zip(remaining_positionals, positional)))

    # Defaults belong to the trailing `default_count` entries of `args`.
    default_count = 0 if not argspec.defaults else len(argspec.defaults)
    if default_count:
        for arg, value in zip(argspec.args[-default_count:], argspec.defaults):
            if arg not in call_args:
                call_args[arg] = value

    if argspec.kwonlydefaults is not None:
        for k, v in argspec.kwonlydefaults.items():
            if k not in call_args:
                call_args[k] = v

    return call_args


def getframeinfo(*args, **kwargs):
    """Forwards all arguments unchanged to `inspect.getframeinfo`."""
    return _inspect.getframeinfo(*args, **kwargs)


def getdoc(obj):
    """TFDecorator-aware replacement for inspect.getdoc.

    Args:
      obj: An object, possibly decorated.

    Returns:
      The docstring associated with the object.

    The outermost-decorated object is intended to have the most complete
    documentation, so the decorated parameter is not unwrapped.
    """
    return _inspect.getdoc(obj)


def getfile(obj):
    """TFDecorator-aware replacement for inspect.getfile."""
    unwrapped_object = tf.__internal__.decorator.unwrap(obj)[1]

    # Work around for the case when object is a stack frame
    # and only .pyc files are used. In this case, getfile
    # might return incorrect path. So, we get the path from f_globals
    # instead.
    if (
        hasattr(unwrapped_object, "f_globals")
        and "__file__" in unwrapped_object.f_globals
    ):
        return unwrapped_object.f_globals["__file__"]
    return _inspect.getfile(unwrapped_object)


def getmembers(obj, predicate=None):
    """TFDecorator-aware replacement for inspect.getmembers."""
    return _inspect.getmembers(obj, predicate)


def getmodule(obj):
    """TFDecorator-aware replacement for inspect.getmodule."""
    return _inspect.getmodule(obj)


def getmro(cls):
    """TFDecorator-aware replacement for inspect.getmro."""
    return _inspect.getmro(cls)


def getsource(obj):
    """TFDecorator-aware replacement for inspect.getsource."""
    return _inspect.getsource(tf.__internal__.decorator.unwrap(obj)[1])


def getsourcefile(obj):
    """TFDecorator-aware replacement for inspect.getsourcefile."""
    return _inspect.getsourcefile(tf.__internal__.decorator.unwrap(obj)[1])


def getsourcelines(obj):
    """TFDecorator-aware replacement for inspect.getsourcelines."""
    return _inspect.getsourcelines(tf.__internal__.decorator.unwrap(obj)[1])


def isbuiltin(obj):
    """TFDecorator-aware replacement for inspect.isbuiltin."""
    return _inspect.isbuiltin(tf.__internal__.decorator.unwrap(obj)[1])


def isclass(obj):
    """TFDecorator-aware replacement for inspect.isclass."""
    return _inspect.isclass(tf.__internal__.decorator.unwrap(obj)[1])


def isfunction(obj):
    """TFDecorator-aware replacement for inspect.isfunction."""
    return _inspect.isfunction(tf.__internal__.decorator.unwrap(obj)[1])


def isframe(obj):
    """TFDecorator-aware replacement for inspect.isframe."""
    return _inspect.isframe(tf.__internal__.decorator.unwrap(obj)[1])


def isgenerator(obj):
    """TFDecorator-aware replacement for inspect.isgenerator."""
    return _inspect.isgenerator(tf.__internal__.decorator.unwrap(obj)[1])


def isgeneratorfunction(obj):
    """TFDecorator-aware replacement for inspect.isgeneratorfunction."""
    return _inspect.isgeneratorfunction(
        tf.__internal__.decorator.unwrap(obj)[1]
    )


def ismethod(obj):
    """TFDecorator-aware replacement for inspect.ismethod."""
    return _inspect.ismethod(tf.__internal__.decorator.unwrap(obj)[1])


def ismodule(obj):
    """TFDecorator-aware replacement for inspect.ismodule."""
    return _inspect.ismodule(tf.__internal__.decorator.unwrap(obj)[1])


def isroutine(obj):
    """TFDecorator-aware replacement for inspect.isroutine."""
    return _inspect.isroutine(tf.__internal__.decorator.unwrap(obj)[1])


def stack(context=1):
    """TFDecorator-aware replacement for inspect.stack."""
    # Drop the first entry, which is this wrapper's own frame.
    return _inspect.stack(context)[1:]