322 lines
11 KiB
Python
322 lines
11 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Live entity inspection utilities.
|
|
|
|
This module contains whatever inspect doesn't offer out of the box.
|
|
"""
|
|
|
|
import builtins
|
|
import inspect
|
|
import itertools
|
|
import linecache
|
|
import sys
|
|
import threading
|
|
import types
|
|
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
# Guards linecache access. Concurrent use of linecache has produced errors in
# the past, and serializing it behind this lock appears to avoid them.
_linecache_lock = threading.Lock()

# Identities of every value published by the builtins module, frozen once at
# import time so that isbuiltin() can test membership in constant time.
_BUILTIN_FUNCTION_IDS = frozenset(map(id, vars(builtins).values()))
|
|
|
|
|
|
def islambda(f):
  """Returns True if the argument is a lambda function.

  Both the function's own name and the name of its code object are checked:
  some wrappers rename the function, but renaming the code object is much
  harder, so either one reading '<lambda>' is accepted.

  Args:
    f: Any

  Returns:
    Bool
  """
  if not tf_inspect.isfunction(f):
    return False
  # TODO(mdan): Look into checking only the code object.
  if not all(hasattr(f, attr) for attr in ('__name__', '__code__')):
    return False
  return f.__name__ == '<lambda>' or f.__code__.co_name == '<lambda>'
|
|
|
|
|
|
def isnamedtuple(f):
  """Returns True if the argument is a namedtuple-like.

  A namedtuple-like is a class that subclasses tuple and carries a `_fields`
  attribute which is a tuple made up only of strings.

  Args:
    f: Any

  Returns:
    Bool
  """
  if not tf_inspect.isclass(f) or not issubclass(f, tuple):
    return False
  field_names = getattr(f, '_fields', None)
  if not isinstance(field_names, tuple):
    return False
  return all(isinstance(field_name, str) for field_name in field_names)
|
|
|
|
|
|
def isbuiltin(f):
  """Returns True if the argument is a built-in function.

  An object qualifies when it is one of the values published by the `builtins`
  module (matched by identity), when it is a C-implemented function or method,
  or when it is `eval` itself.

  Args:
    f: Any

  Returns:
    Bool
  """
  return (id(f) in _BUILTIN_FUNCTION_IDS or
          isinstance(f, types.BuiltinFunctionType) or
          inspect.isbuiltin(f) or
          f is eval)
|
|
|
|
|
|
def isconstructor(cls):
  """Returns True if the argument is an object constructor.

  Any class counts as a constructor unless its metaclass is a `type` subclass
  that overrides `__call__`. Such a callable metaclass controls what `cls(...)`
  returns, so invoking the class is not necessarily plain construction. See:
  https://docs.python.org/3/reference/datamodel.html#customizing-class-creation

  Args:
    cls: Any

  Returns:
    Bool
  """
  if not inspect.isclass(cls):
    return False
  metaclass = cls.__class__
  overrides_call = (
      issubclass(metaclass, type) and
      hasattr(metaclass, '__call__') and
      metaclass.__call__ is not type.__call__)
  return not overrides_call
|
|
|
|
|
|
def _fix_linecache_record(obj):
  """Fixes potential corruption of linecache in the presence of functools.wraps.

  functools.wraps copies the wrapped object's __module__ onto the wrapper, which
  seems to confuse linecache in special instances, for example when the source
  is loaded from a .par file (see https://google.github.io/subpar/subpar.html).

  For every loaded module whose __file__ is the object's source file, this
  function refreshes the linecache entry for that file using the module's
  globals, but only when the module's name differs from the object's
  __module__ (i.e. when the two disagree about who owns the source).

  Args:
    obj: Any
  """
  if not hasattr(obj, '__module__'):
    return

  obj_file = inspect.getfile(obj)
  obj_module = obj.__module__

  # A snapshot of the loaded modules helps avoid "dict changed size during
  # iteration" errors.
  loaded_modules = tuple(sys.modules.values())
  for m in loaded_modules:
    if getattr(m, '__file__', None) != obj_file:
      continue
    # obj.__module__ is a module *name*, so compare it with the module's name.
    # Comparing the string against the module object itself can never match,
    # which would force a cache refresh (and a file re-read) on every call.
    if getattr(m, '__name__', None) != obj_module:
      linecache.updatecache(obj_file, m.__dict__)
|
|
|
|
|
|
def getimmediatesource(obj):
  """Returns the source text of `obj` itself, not following `__wrapped__`.

  This behaves like inspect.getsource, except that the object is not unwrapped
  first. The linecache entry for the object's file is repaired and the source
  read while holding `_linecache_lock`.

  Args:
    obj: Any object whose source can be located by inspect.findsource.

  Returns:
    The source code of the block that defines `obj`, as a single string.
  """
  with _linecache_lock:
    _fix_linecache_record(obj)
    all_lines, first_line = inspect.findsource(obj)
    block_lines = inspect.getblock(all_lines[first_line:])
    return ''.join(block_lines)
|
|
|
|
|
|
def getnamespace(f):
  """Returns the complete namespace of a function.

  The namespace maps every non-local name the function can see to its current
  value: a copy of the function's globals, overlaid with its closure variables.
  The whole globals dict is included, so the result may contain symbols the
  function never uses. Closure cells that are still empty are left out.

  Args:
    f: User defined function.

  Returns:
    A dict mapping symbol names to values.
  """
  namespace = dict(f.__globals__)
  for name, cell in zip(f.__code__.co_freevars, f.__closure__ or ()):
    try:
      namespace[name] = cell.cell_contents
    except ValueError:
      # Reading an empty cell raises ValueError: the variable is not bound
      # yet, so it has no value to report.
      continue
  return namespace
|
|
|
|
|
|
def getqualifiedname(namespace, object_, max_depth=5, visited=None):
  """Returns the name by which a value can be referred to in a given namespace.

  The namespace is first searched for a symbol bound directly to the value.
  Failing that, if the object defines a parent module that is itself reachable
  from the namespace, the function attempts to use it to locate the object.
  Finally, modules found in the namespace are searched recursively.

  This function will recurse inside modules, but it will not search objects for
  attributes. The recursion depth is controlled by max_depth.

  Args:
    namespace: Dict[str, Any], the namespace to search into.
    object_: Any, the value to search.
    max_depth: Optional[int], a limit to the recursion depth when searching
      inside modules.
    visited: Optional[Set[int]], IDs of modules to avoid visiting. Updated in
      place with the IDs of the modules searched.

  Returns:
    Union[str, None], the fully-qualified name that resolves to object_, or
    None if it couldn't be found.
  """
  if visited is None:
    visited = set()

  # Keep a handle on the caller's mapping: the parent-module search below must
  # recognize when it is already looking at the parent's own __dict__.
  original_namespace = namespace

  # Copy the dict to avoid "changed size error" during concurrent invocations.
  # TODO(mdan): This is on the hot path. Can we avoid the copy?
  namespace = dict(namespace)

  for name, value in namespace.items():
    # The value may be referenced by more than one symbol, case in which
    # any symbol will be fine. If the program contains symbol aliases that
    # change over time, this may capture a symbol that will later point to
    # something else.
    # TODO(mdan): Prefer the symbol that matches the value type name.
    if object_ is value:
      return name

  # If an object is not found, try to search its parent modules.
  parent = tf_inspect.getmodule(object_)
  if (parent is not None and parent is not object_ and
      getattr(parent, '__dict__', None) is not original_namespace):
    # max_depth=0 confines both lookups below to a single dict. Recursion
    # stays bounded: looking up `parent` stops at `parent is not object_`, and
    # looking inside parent.__dict__ stops at the __dict__ identity check.
    parent_name = getqualifiedname(
        namespace, parent, max_depth=0, visited=visited)
    if parent_name is not None:
      name_in_parent = getqualifiedname(
          parent.__dict__, object_, max_depth=0, visited=visited)
      # The owner module need not expose the object under any name (e.g.
      # nested functions, methods, or a __module__ rewritten by
      # functools.wraps). Fall through to the module search in that case
      # rather than failing; an assert here would also be stripped under -O
      # and produce names like 'mod.None'.
      if name_in_parent is not None:
        return '{}.{}'.format(parent_name, name_in_parent)

  if max_depth:
    # `namespace` is already a private copy, so iterating it is safe even if
    # new modules are loaded concurrently.
    for name, value in namespace.items():
      if tf_inspect.ismodule(value) and id(value) not in visited:
        visited.add(id(value))
        name_in_module = getqualifiedname(value.__dict__, object_,
                                          max_depth - 1, visited)
        if name_in_module is not None:
          return '{}.{}'.format(name, name_in_module)
  return None
|
|
|
|
|
|
def getdefiningclass(m, owner_class):
  """Resolves the class (e.g. one of the superclasses) that defined a method.

  Walks the MRO of `owner_class` and returns the first class that itself
  defines the name `m.__name__`, either in its own `__dict__` or in its own
  `__slots__` declaration.

  Args:
    m: A function or method; only its `__name__` is used.
    owner_class: The class whose MRO is searched.

  Returns:
    The first class in the MRO defining the name, or `owner_class` if none does.
  """
  method_name = m.__name__
  for super_class in inspect.getmro(owner_class):
    own_attrs = getattr(super_class, '__dict__', {})
    if method_name in own_attrs:
      return super_class
    # Only look at slots declared by this very class: hasattr() would also see
    # a base class's inherited __slots__ and misattribute it to the subclass.
    # A single slot may be declared as a bare string, which must be matched as
    # a whole name rather than searched by substring.
    slots = own_attrs.get('__slots__', ())
    if isinstance(slots, str):
      slots = (slots,)
    if method_name in slots:
      return super_class
  return owner_class
|
|
|
|
|
|
def getmethodclass(m):
  """Resolves a function's owner, e.g. a method's class.

  Note that this returns the object that the function was retrieved from, not
  necessarily the class where it was defined.

  This function relies on Python stack frame support in the interpreter, and
  has the same limitations as inspect.currentframe.

  Limitations. This function will only work correctly if the owned class is
  visible in the caller's global or local variables.

  Resolution order:
    1. Callable objects without a `__name__`: their own class.
    2. Bound methods (non-None `__self__`): `__self__` if it is a class,
       otherwise the class of `__self__`.
    3. Otherwise, every value in the caller's locals and globals that exposes
       an attribute named `m.__name__` bound to `m` is collected as an owner.

  Args:
    m: A user defined function

  Returns:
    The class that this function was retrieved from, or None if the function
    is not an object or class method, or the class that owns the object or
    method is not visible in the caller's namespace.
    NOTE(review): in step 3 any matching value is accepted (a class, an
    instance or even a module), and a single match is returned as-is, so the
    result is not guaranteed to be a class.

  Raises:
    ValueError: if the class could not be resolved for any unexpected reason.
  """

  # Callable objects: return their own class.
  if (not hasattr(m, '__name__') and hasattr(m, '__class__') and
      hasattr(m, '__call__')):
    if isinstance(m.__class__, type):
      return m.__class__

  # Instance and class: return the class of "self".
  m_self = getattr(m, '__self__', None)
  if m_self is not None:
    if inspect.isclass(m_self):
      return m_self
    return m_self.__class__

  # Class, static and unbound methods: search all defined classes in any
  # namespace. This is inefficient but more robust a method.
  owners = []
  # f_back is the frame of whoever called getmethodclass, so this must be
  # called directly by the code whose namespace should be searched; any
  # intermediate wrapper would shift the frame being inspected.
  caller_frame = tf_inspect.currentframe().f_back
  try:
    # TODO(mdan): This doesn't consider cell variables.
    # TODO(mdan): This won't work if the owner is hidden inside a container.
    # Cell variables may be pulled using co_freevars and the closure.
    for v in itertools.chain(caller_frame.f_locals.values(),
                             caller_frame.f_globals.values()):
      if hasattr(v, m.__name__):
        candidate = getattr(v, m.__name__)
        # Py2 methods may be bound or unbound, extract im_func to get the
        # underlying function.
        if hasattr(candidate, 'im_func'):
          candidate = candidate.im_func
        if hasattr(m, 'im_func'):
          m = m.im_func
        if candidate is m:
          owners.append(v)
  finally:
    # Drop the frame reference explicitly so this function does not keep the
    # caller's frame (and everything it references) alive longer than needed.
    del caller_frame

  if owners:
    if len(owners) == 1:
      return owners[0]

    # If multiple owners are found, and are not subclasses, raise an error.
    # Unlike the single-owner case above, instances are reduced to their
    # classes here, so this path returns a class.
    owner_types = tuple(
        o if tf_inspect.isclass(o) else type(o) for o in owners)
    for o in owner_types:
      # NOTE(review): `o` is itself an element of `owner_types`, so
      # issubclass(o, tuple(owner_types)) appears to be True for the first
      # class-like entry. As written this returns owner_types[0] and the
      # ValueError below looks unreachable; confirm the intended semantics
      # (perhaps "a subclass of all other owners") before relying on it.
      if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)):
        return o
    raise ValueError('Found too many owners of %s: %s' % (m, owners))

  return None
|
|
|
|
|
|
def getfutureimports(entity):
  """Detects what future imports are necessary to safely execute entity source.

  The check looks for globals of `entity` whose value reports '__future__' as
  its `__module__`. Those are the feature objects bound by
  `from __future__ import ...` statements.

  Args:
    entity: Any object

  Returns:
    A tuple of future strings: the sorted names of those globals. It is empty
    when `entity` is neither a function nor a method.
  """
  is_function_like = (
      tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity))
  if not is_function_like:
    return tuple()
  future_names = [
      name for name, value in entity.__globals__.items()
      if getattr(value, '__module__', None) == '__future__'
  ]
  future_names.sort()
  return tuple(future_names)
|