# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting code to AST.

Adapted from Tangent.
"""

import ast
import inspect
import io
import linecache
import re
import sys
import textwrap
import tokenize

import astunparse
import gast

from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.util import tf_inspect


PY2_PREAMBLE = textwrap.dedent("""
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0

if sys.version_info >= (3, 9):
  # Python 3.9+ can unparse with the standard library; alias it so the rest of
  # this module can call astunparse.unparse either way.
  astunparse = ast

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


_LEADING_WHITESPACE = re.compile(r'\s*')


def _unfold_continuations(code_string):
  """Removes any backslash line continuations from the code."""
  return code_string.replace('\\\n', '')
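
# Editorial note, not part of the original module: a minimal sketch of the
# effect of _unfold_continuations. A trailing backslash joins two physical
# lines into a single logical line before tokenization, e.g.:
#
#   _unfold_continuations('x = 1 + \\\n    2')  # -> 'x = 1 +     2'
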
def dedent_block(code_string):
  """Dedents a code block so that its first line starts at row zero."""

  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(io.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      block_level = len(block_indentation)
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
        tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`.
  # So rather than using the untokenized output directly, we only use its
  # leading whitespace to decide how much to strip from each original line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code
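
# Editorial illustration, not part of the original module: dedent_block is
# typically applied to source returned by inspect for a nested or class-level
# definition, which keeps its original indentation. A minimal sketch:
#
#   indented = '  def f(x):\n    return x + 1\n'
#   print(dedent_block(indented))
#   # def f(x):
#   #   return x + 1
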
def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have
    added).
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except OSError as e:
    raise errors.InaccessibleSourceCodeError(
        f'Unable to locate the source code of {entity}. Note that functions'
        ' defined in certain environments, like the interactive Python shell,'
        ' do not expose their source code. If that is the case, you should'
        ' define them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        f' @tf.autograph.experimental.do_not_convert. Original error: {e}')

  source = dedent_block(original_source)

  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  return parse(source, preamble_len=len(future_features)), source
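
# Editorial illustration, not part of the original module: a sketch of how
# parse_entity might be called. The returned node is a gast.FunctionDef, and
# the returned source includes any __future__ imports prepended above.
#
#   def f(x):
#     return x + 1
#
#   node, source = parse_entity(f, future_features=())
#   # isinstance(node, gast.FunctionDef) -> True
#   # source is the (dedented) text of f's definition, e.g.
#   # 'def f(x):\n  return x + 1\n'
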
def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  for n in gast.walk(node):
    lineno = getattr(n, 'lineno', None)
    if lineno is not None:
      n.lineno = lineno - minl
    end_lineno = getattr(n, 'end_lineno', None)
    if end_lineno is not None:
      n.end_lineno = end_lineno - minl

  code_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col_offset = getattr(node, 'end_col_offset', None)
  if end_col_offset is not None:
    # end_col_offset is only available in Python 3.8+.
    code_lines[-1] = code_lines[-1][:end_col_offset]

  col_offset = getattr(node, 'col_offset', None)
  if col_offset is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
    if match is not None:
      col_offset = match.start(0)

  if col_offset is not None:
    code_lines[0] = code_lines[0][col_offset:]

  code_block = '\n'.join([c.rstrip() for c in code_lines])

  return node, code_block


def _arg_name(node):
  if node is None:
    return None
  if isinstance(node, gast.Name):
    return node.id
  assert isinstance(node, str)
  return node


def _node_matches_argspec(node, func):
  """Returns True if node fits the argspec of func."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  arg_spec = tf_inspect.getfullargspec(func)

  node_args = tuple(_arg_name(arg) for arg in node.args.args)
  if node_args != tuple(arg_spec.args):
    return False

  if arg_spec.varargs != _arg_name(node.args.vararg):
    return False

  if arg_spec.varkw != _arg_name(node.args.kwarg):
    return False

  node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs)
  if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
    return False

  return True
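
# Editorial illustration, not part of the original module:
# _node_matches_argspec compares the argument names in a Lambda AST node
# against the live function's argspec. For example, given two lambdas on one
# line:
#
#   outer = lambda x: (lambda y: x + y)
#
# both Lambda nodes span the same source line, but only the outer one has the
# argument tuple ('x',) matching outer's argspec, which is how _parse_lambda
# (below) narrows down multiple candidates.
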
def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, the lambda function to parse

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have
    added).
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust than just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  candidates = []
  for ln in lambda_nodes:
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    lambda_codes = '\n'.join([unparse(l) for l in lambda_nodes])
    raise errors.UnsupportedLanguageElementError(
        f'could not parse the source code of {lam}:'
        f' no matching AST found among candidates:\n{lambda_codes}')

  # Attempt to narrow down the selection by signature if multiple nodes are
  # found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if we could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      f'could not parse the source code of {lam}: found multiple definitions'
      ' with identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      f' unique argument names. The matching definitions were:\n{matches}')
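
# Editorial illustration, not part of the original module: sketches of the
# cases handled above, assuming the lambdas are defined in an importable
# module (so linecache can retrieve their source).
#
#   f = lambda x: x + 1
#   node, source = _parse_lambda(f)       # exactly one candidate: happy path
#
#   g, h = lambda x: x, lambda x: x + 1   # same line, identical signatures
#   _parse_lambda(h)                      # raises UnsupportedLanguageElementError
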
# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text
    preamble_len: Int, indicates leading nodes in the parsed AST which should
      be dropped.
    single_node: Bool, whether `src` is assumed to be represented by exactly
      one AST node.

  Returns:
    ast.AST
  """
  module_node = gast.parse(src)
  nodes = module_node.body
  if preamble_len:
    nodes = nodes[preamble_len:]
  if single_node:
    if len(nodes) != 1:
      raise ValueError('expected exactly one node, got {}'.format(nodes))
    return nodes[0]
  return nodes
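
# Editorial illustration, not part of the original module: parse() wraps
# gast.parse, optionally dropping preamble nodes and unwrapping a single node.
#
#   parse('x = 1')                                    # a single gast.Assign
#   parse('x = 1\ny = 2', single_node=False)          # a list of two nodes
#   parse('from __future__ import division\nx = 1',
#         preamble_len=1)                             # just the Assign node
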
def parse_expression(src):
  """Returns the AST of given expression.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  src = STANDARD_PREAMBLE + src.strip()
  node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  if __debug__:
    if not isinstance(node, gast.Expr):
      raise ValueError(
          'expected exactly one node of type Expr, got {}'.format(node))
  return node.value
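
# Editorial illustration, not part of the original module:
#
#   parse_expression('a.b.c')   # returns a gast.Attribute node
#   parse_expression('x = 1')   # raises ValueError: a statement, not an
#                               # expression
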
def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object.
    indentation: Unused, deprecated. The returned code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    Text: the source code generated from the AST object.
  """
  del indentation  # astunparse doesn't allow configuring it.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  codes = []
  if include_encoding_marker:
    codes.append('# coding=utf-8')
  for n in node:
    if isinstance(n, gast.AST):
      ast_n = gast.gast_to_ast(n)
    else:
      ast_n = n

    if astunparse is ast:
      ast.fix_missing_locations(ast_n)  # Only ast needs to call this.
    codes.append(astunparse.unparse(ast_n).strip())

  return '\n'.join(codes)
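
# Editorial illustration, not part of the original module:
#
#   node = parse('x = 1')
#   unparse(node)                                  # '# coding=utf-8\nx = 1'
#   unparse(node, include_encoding_marker=False)   # 'x = 1'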