3RNN/Lib/site-packages/tensorflow/tools/common/public_api.py
2024-05-26 19:49:15 +02:00

148 lines
5.3 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visitor restricting traversal to only the public tensorflow API."""
import re
from tensorflow.python.util import tf_inspect
class PublicAPIVisitor:
  """Visitor to use with `traverse` to visit exactly the public TF API.

  Wraps another visitor and filters the children of every visited node:
  private members are removed before the wrapped visitor sees them, and
  members listed in `do_not_descend_map` are removed after the wrapped visitor
  has run, so they are still visited but not descended into.
  """

  # Dunder names that are nevertheless treated as private.
  _PRIVATE_DUNDERS = frozenset(['__base__', '__class__', '__next_in_mro__'])
  # Matches dunder names such as `__init__`; these are not private by default.
  _DUNDER_RE = re.compile(r'__.*__$')
  # A module path with more dot-separated components than this aborts
  # traversal (almost certainly an accidental public import cycle).
  _MAX_MODULE_DEPTH = 10

  def __init__(self, visitor):
    """Constructor.

    `visitor` should be a callable suitable as a visitor for `traverse`. It will
    be called only for members of the public TensorFlow API.

    Args:
      visitor: A visitor to call for the public API.
    """
    self._visitor = visitor
    self._root_name = 'tf'

    # Modules/classes we want to suppress entirely. Keys are fully qualified
    # parent paths (including the root name); values are member names.
    self._private_map = {
        'tf': [
            'compiler',
            'core',
            'security',
            # TODO(scottzhu): See b/227410870 for more details. Currently the
            # dtensor API is exposed under tf.experimental.dtensor, but in the
            # meantime the tensorflow/dtensor directory is treated as a Python
            # package. We want to avoid stepping into the tensorflow/dtensor
            # directory when visiting the API. When tf.dtensor becomes the
            # public API, it will be picked up from tf.compat.v2.dtensor with
            # priority, hiding the tensorflow/dtensor package.
            'dtensor',
            'python',
            'tsl',  # TODO(tlongeri): Remove after TSL is moved out of TF.
        ],
        # Some implementations have this internal module that we shouldn't
        # expose.
        'tf.flags': ['cpp_flags'],
    }

    # Modules/classes we do not want to descend into if we hit them. Usually,
    # system modules exposed through platforms for compatibility reasons.
    # Each entry maps a fully qualified module path to the names whose
    # contents should not be traversed.
    self._do_not_descend_map = {
        'tf': [
            'examples',
            'flags',  # Don't add flags
            # TODO(drpng): This can be removed once sealed off.
            'platform',
            # TODO(drpng): This can be removed once sealed.
            'pywrap_tensorflow',
            # TODO(drpng): This can be removed once sealed.
            'user_ops',
            'tools',
            'tensorboard',
        ],

        ## Everything below here is legitimate.
        # It'll stay, but it's not officially part of the API.
        'tf.app': ['flags'],
        # Imported for compatibility between py2/3.
        'tf.test': ['mock'],
    }

  @property
  def private_map(self):
    """A map from parents to symbols that should not be included at all.

    This map can be edited, but it should not be edited once traversal has
    begun.

    Returns:
      The map marking symbols to not include.
    """
    return self._private_map

  @property
  def do_not_descend_map(self):
    """A map from parents to symbols that should not be descended into.

    This map can be edited, but it should not be edited once traversal has
    begun.

    Returns:
      The map marking symbols to not explore.
    """
    return self._do_not_descend_map

  def set_root_name(self, root_name):
    """Override the default root name of 'tf'."""
    self._root_name = root_name

  def _is_private(self, path, name, obj=None):
    """Return whether a name is private.

    Args:
      path: Fully qualified path of the parent, including the root name.
      name: Member name to check.
      obj: The member itself. Currently unused.

    Returns:
      True if `name` is listed under `path` in the private map, starts with an
      underscore without being a dunder name, or is one of a few special
      dunder names; False otherwise.
    """
    # TODO(wicke): Find out what names to exclude.
    del obj  # Unused.
    if name in self._private_map.get(path, ()):
      return True
    if name.startswith('_') and not self._DUNDER_RE.match(name):
      return True
    return name in self._PRIVATE_DUNDERS

  def _do_not_descend(self, path, name):
    """Safely queries if a specific fully qualified name should be excluded."""
    return name in self._do_not_descend_map.get(path, ())

  def __call__(self, path, parent, children):
    """Visitor interface, see `traverse` for details.

    Args:
      path: Dotted path of `parent` relative to the root (empty for the root).
      parent: The object being visited.
      children: List of `(name, member)` pairs of `parent`. Modified in place
        to drop private members (before the wrapped visitor runs) and members
        that must not be descended into (after it runs).

    Raises:
      RuntimeError: If `parent` is a module whose path has more than
        `_MAX_MODULE_DEPTH` dot-separated components.
    """
    # Avoid long waits in cases of pretty unambiguous failure.
    if (tf_inspect.ismodule(parent) and
        len(path.split('.')) > self._MAX_MODULE_DEPTH):
      raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a '
                         'problem with an accidental public import.' %
                         (self._root_name, path))

    # Includes self._root_name
    full_path = '.'.join([self._root_name, path]) if path else self._root_name

    # Remove things that are not visible. Filter in place via slice assignment
    # because the caller keeps using this same list object; one rebuild is
    # linear, unlike a list.remove() scan per removed entry. The original
    # entry objects are kept as-is.
    children[:] = [
        member for member in children
        if not self._is_private(full_path, member[0], member[1])
    ]

    self._visitor(path, parent, children)

    # Remove things that are visible, but which should not be descended into.
    children[:] = [
        member for member in children
        if not self._do_not_descend(full_path, member[0])
    ]