228 lines
7.9 KiB
Python
228 lines
7.9 KiB
Python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Core conversion logic, serves as main point of access."""
|
|
|
|
import functools
|
|
import inspect
|
|
import sys
|
|
import unittest
|
|
|
|
from tensorflow.python.autograph.core import config
|
|
from tensorflow.python.autograph.pyct import cache
|
|
from tensorflow.python.autograph.pyct import inspect_utils
|
|
from tensorflow.python.autograph.utils import ag_logging as logging
|
|
from tensorflow.python.eager.polymorphic_function import tf_method_target
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
|
|
# Module-level cache of allowlisted entities. It is indexed by entity and then
# by options; cache_allowlisted writes to it and is_in_allowlist_cache reads it.
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()
|
|
|
|
|
|
def _is_of_known_loaded_module(f, module_name):
  """Tests whether `f` is one of the attributes of an already-loaded module.

  Args:
    f: The object to look for. Matching is by identity, not equality.
    module_name: Name of the module, as it would appear in `sys.modules`.

  Returns:
    True if the module is present in `sys.modules` and one of the values in
    its `__dict__` is `f` itself (and is not None); False otherwise.
  """
  loaded = sys.modules.get(module_name)
  if loaded is None:
    return False
  for member in loaded.__dict__.values():
    # A None member never counts as a match, even when `f` is None.
    if member is f and member is not None:
      return True
  return False
|
|
|
|
|
|
def _is_known_loaded_type(f, module_name, entity_name):
  """Tests whether a function or method is an instance of a known type.

  The type is looked up by name in an already-loaded module, so nothing is
  imported by this check.

  Args:
    f: The function or method to test.
    module_name: Name of the module expected to define the type.
    entity_name: Name of the type inside that module.

  Returns:
    True if `f`, or the function underlying the method `f`, is an instance of
    the named type. False otherwise, including when the module is not loaded
    or does not define `entity_name`.
  """
  if module_name not in sys.modules:
    return False
  owner = sys.modules[module_name]
  if not hasattr(owner, entity_name):
    return False
  known_type = getattr(owner, entity_name)

  # The callable itself is of this type. Example:
  #
  #   o = ClassType()
  #   function(o.method)()
  if isinstance(f, known_type):
    return True

  # The function wrapped by the method is of this type. Example:
  #
  #   class ClassType:
  #     @function
  #     def method(self):
  #       ...
  #   o = ClassType()
  #   o.method()
  #
  # Note: inspect is required here, to avoid unpacking tf.function decorators.
  return inspect.ismethod(f) and isinstance(f.__func__, known_type)
|
|
|
|
|
|
def is_unsupported(o):
  """Checks whether an entity is supported by AutoGraph at all.

  Each rule below identifies a kind of entity that is bypassed permanently;
  the rule that matched is reported through `logging.log`.

  Args:
    o: A Python entity, usually a function, method or other callable.

  Returns:
    True if `o` matches one of the bypass rules, False otherwise.
  """

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
    logging.warning(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(o))
    logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
    return True

  if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently allowed: %s: lru_cache', o)
    return True

  # Constructors are permanently allowed.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(o):
    logging.log(2, 'Permanently allowed: %s: constructor', o)
    return True

  # Other built-in modules are permanently allowed.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      _is_of_known_loaded_module(o, m)
      for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
    logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
    return True

  # Custom ops and kernels are also permanently allowed.
  # See tensorflow.framework.load_library.
  #
  # `__module__` normally holds the *name* of the defining module (a str), and
  # a plain string never carries the plugin marker. So besides the value of
  # `__module__` itself, also inspect the loaded module registered under that
  # name in sys.modules.
  module_ref = getattr(o, '__module__', None)
  plugin_candidates = [module_ref]
  if isinstance(module_ref, str):
    plugin_candidates.append(sys.modules.get(module_ref))
  if any(hasattr(c, '_IS_TENSORFLOW_PLUGIN') for c in plugin_candidates):
    logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
    return True

  return False
|
|
|
|
|
|
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
|
|
def is_allowlisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Checks whether an entity is allowed for use in graph mode.

  Examples of allowed entities include all members of the tensorflow
  package.

  The rules are tried in order and the first one that matches decides the
  result: the module-based conversion rules, generator functions, callable
  objects whose `__call__` is allowed, methods (by their owner class), and
  namedtuple types. Anything left over is not allowed.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When `False`, the rule
      that allows a non-class callable object because its `__call__` method
      is allowed is skipped.
    allow_namedtuple_subclass: Reserved for internal use. When `True`, the
      namedtuple rule only matches types none of whose direct bases is a
      namedtuple; when `False`, it matches any namedtuple type, subclasses
      included.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  # functools.partial objects do not have a __module__ attribute, so
  # tf_inspect.getmodule would return None for them; use functools instead.
  module = (
      functools if isinstance(o, functools.partial)
      else tf_inspect.getmodule(o))

  # Callables without a usable module (builtins, for example) skip the rules.
  if hasattr(module, '__name__'):
    for rule in config.CONVERSION_RULES:
      verdict = rule.get_action(module)
      if verdict == config.Action.CONVERT:
        logging.log(2, 'Not allowed: %s: %s', o, rule)
        return False
      if verdict == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Allowlisted: %s: %s', o, rule)
        return True

  # isgeneratorfunction crashes on objects without __code__, hence the guard.
  is_generator = hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o)
  if is_generator:
    logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
    return True

  # Callable objects: allowed if their __call__ method is. The type
  # comparison avoids infinite recursion around the __call__ method of
  # function objects.
  if (check_call_override and not tf_inspect.isclass(o) and
      hasattr(o, '__call__') and
      type(o) != type(o.__call__) and  # pylint: disable=unidiomatic-typecheck
      is_allowlisted(o.__call__)):
    logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
    return True

  if tf_inspect.ismethod(o):
    # Methods of allowed classes are also allowed, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
    # allowed.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be allowed.
    method_owner = inspect_utils.getmethodclass(o)
    if method_owner is tf_method_target.TfMethodTarget:
      method_owner = o.__self__.target_class

    if method_owner is not None:
      if issubclass(method_owner, unittest.TestCase):
        logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
        return True

      defining_class = inspect_utils.getdefiningclass(o, method_owner)
      owner_allowed = is_allowlisted(
          defining_class,
          check_call_override=False,
          allow_namedtuple_subclass=True)
      if owner_allowed:
        logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
                    defining_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if not allow_namedtuple_subclass:
      logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
      return True
    if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
      logging.log(2, 'Allowlisted: %s: named tuple', o)
      return True

  logging.log(2, 'Not allowed: %s: default rule', o)
  return False
|
|
|
|
|
|
def is_in_allowlist_cache(entity, options):
  """Reports whether `entity` is recorded in the allowlist cache for `options`.

  Args:
    entity: The entity to look up.
    options: The options value the entry was cached under.

  Returns:
    The result of `_ALLOWLIST_CACHE.has(entity, options)`, or False when that
    lookup raises TypeError.
  """
  try:
    found = _ALLOWLIST_CACHE.has(entity, options)
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    found = False
  return found
|
|
|
|
|
|
def cache_allowlisted(entity, options):
  """Records `entity` as allowlisted for `options` in the allowlist cache.

  Caching is best effort: if the cache raises TypeError for this entity, the
  entry is not recorded and no error is propagated.

  Args:
    entity: The entity to record.
    options: The options value to record the entry under.
  """
  try:
    per_entity = _ALLOWLIST_CACHE[entity]
    per_entity[options] = True
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    return
|