297 lines
9.7 KiB
Python
297 lines
9.7 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Container for origin source code information before AutoGraph compilation."""
|
|
import collections
|
|
import difflib
|
|
import io
|
|
import os
|
|
import tokenize
|
|
|
|
import gast
|
|
|
|
from tensorflow.python.autograph.pyct import anno
|
|
from tensorflow.python.autograph.pyct import ast_util
|
|
from tensorflow.python.autograph.pyct import parser
|
|
from tensorflow.python.autograph.pyct import pretty_printer
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
|
|
class LineLocation(
    collections.namedtuple('LineLocation', ('filename', 'lineno'))):
  """A code position identified by file and line only, with no column.

  This is the column-less counterpart of Location.

  Attributes:
    filename: Text
    lineno: int, 1-based
  """
|
|
|
|
|
|
class Location(
    collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))):
  """A position in source code: file, line and column.

  Attributes:
    filename: Text
    lineno: int, 1-based
    col_offset: int
    line_loc: LineLocation
  """

  @property
  def line_loc(self):
    """Returns this position without the column, as a LineLocation."""
    return LineLocation(filename=self.filename, lineno=self.lineno)
|
|
|
|
|
|
class OriginInfo(
    collections.namedtuple(
        'OriginInfo',
        ('loc', 'function_name', 'source_code_line', 'comment'))):
  """Describes where a piece of code came from, prior to conversion.

  Attributes:
    loc: Location
    function_name: Optional[Text]
    source_code_line: Text
    comment: Optional[Text]
  """

  def as_frame(self):
    """Returns a 4-tuple consistent with the return of traceback.extract_tb.

    The tuple is (filename, lineno, function_name, source_code_line).
    """
    loc = self.loc
    return (loc.filename, loc.lineno, self.function_name,
            self.source_code_line)

  def __repr__(self):
    """Formats the origin as `<file base name>:<lineno>:<col_offset>`.

    When the location has no (or an empty) filename, the placeholder
    `<no file>` is used instead of the file name.
    """
    loc = self.loc
    if not loc.filename:
      return '<no file>:{}:{}'.format(loc.lineno, loc.col_offset)
    # Only the last path component is shown, not the full path.
    return '{}:{}:{}'.format(
        os.path.basename(loc.filename), loc.lineno, loc.col_offset)
|
|
|
|
|
|
# TODO(mdan): This source map should be a class - easier to refer to.
|
|
def create_source_map(nodes, code, filepath):
  """Creates a source map between an annotated AST and the code it compiles to.

  The given code is parsed again and the resulting nodes are annotated with
  their own locations in code. The original and the reparsed nodes are then
  walked in parallel, pairing each location in code with the origin annotation
  found on the corresponding original node.

  Note: this function assumes nodes, code and filepath correspond to the
  same code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more AST nodes.
    code: Text, the source code in which nodes are found.
    filepath: Text

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: if a ValueError is raised while walking nodes in parallel with
      the AST reparsed from code. The message contains the original cause
      followed by a context diff of the two trees, and the original error is
      attached as the cause of the new one.
  """
  reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False)
  for node in reparsed_nodes:
    resolve(node, code, filepath, node.lineno, node.col_offset)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      # Note: the keys are by line only, excluding the column offset.
      line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno)

      existing_origin = source_map.get(line_loc)
      if existing_origin is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. The exception is decorated functions,
        # where both lines are mapped to the same line in the AST.

        # Line overlaps: keep bottom node.
        # NOTE(review): when the two line locations compare equal their line
        # numbers are equal as well, so the inner check always holds and the
        # existing entry is kept. Left as is to preserve behavior.
        if existing_origin.loc.line_loc == origin_info.loc.line_loc:
          if existing_origin.loc.lineno >= origin_info.loc.lineno:
            continue

        # In case of column overlaps, keep the leftmost node.
        if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info

  except ValueError as err:
    # Collect the message pieces and join them once at the end. The newline
    # before 'Diff:' keeps the cause and the diff header on separate lines.
    msg_parts = [
        'Inconsistent ASTs detected. This is a bug. Cause: \n',
        str(err),
        '\nDiff:\n',
    ]

    for n, rn in zip(nodes, reparsed_nodes):
      nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
      reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
      diff = difflib.context_diff(
          nodes_str.split('\n'),
          reparsed_nodes_str.split('\n'),
          fromfile='Original nodes',
          tofile='Reparsed nodes',
          n=7)
      msg_parts.append('\n'.join(diff) + '\n')

    # Chain explicitly so the traceback shows the original error as the cause.
    raise ValueError(''.join(msg_parts)) from err

  return source_map
|
|
|
|
|
|
class _Function:
  """Minimal record describing a function; it stores only the function name.

  Attributes:
    name: the value passed to the constructor.
  """

  def __init__(self, name):
    self.name = name
|
|
|
|
|
|
class OriginResolver(gast.NodeVisitor):
  """Visitor that annotates AST nodes with their origin (file, line, ...)."""

  def __init__(self, root_node, source_lines, comments_map,
               context_lineno, context_col_offset,
               filepath):
    """Initializes the resolver and computes line/column offsets.

    Args:
      root_node: the AST node anchoring the offsets. Its col_offset is read,
        together with either the lineno of its first decorator or its own
        lineno.
      source_lines: sequence of source lines, indexed with `lineno - 1` of the
        visited nodes.
      comments_map: mapping looked up (via `get`) with the lineno of the
        visited nodes to obtain the comment stored in the origin info.
      context_lineno: int, the line number that the anchor line of root_node
        is translated to.
      context_col_offset: int, the column offset that the col_offset of
        root_node is translated to.
      filepath: Text, the filename recorded in every attached location.
    """
    self._source_lines = source_lines
    self._comments_map = comments_map
    self._filepath = filepath
    self._function_stack = []

    decorators = getattr(root_node, 'decorator_list', None)
    if decorators and hasattr(decorators[0], 'lineno'):
      # Typical case: functions. The line number of the first decorator
      # is more accurate than the line number of the function itself in
      # 3.8+. In earlier versions they coincide.
      anchor_lineno = decorators[0].lineno
    else:
      # Fall back to the line number of the root node.
      anchor_lineno = root_node.lineno

    self._lineno_offset = context_lineno - anchor_lineno
    self._col_offset = context_col_offset - root_node.col_offset

  def _absolute_lineno(self, lineno):
    """Shifts a node line number by the line offset computed in __init__."""
    return lineno + self._lineno_offset

  def _absolute_col_offset(self, col_offset):
    """Shifts a node column by the column offset; a None column maps to 0."""
    return 0 if col_offset is None else col_offset + self._col_offset

  def _attach_origin_info(self, node):
    """Sets the 'lineno' and ORIGIN annotations on node, if it has a lineno.

    Nodes without a `lineno` attribute (or with a None one) are left untouched.
    """
    lineno = getattr(node, 'lineno', None)
    if lineno is None:
      return
    col_offset = getattr(node, 'col_offset', None)

    # The innermost function currently being visited, if there is one.
    enclosing = self._function_stack
    function_name = enclosing[-1].name if enclosing else None

    # lineno is 1-based, the list of source lines is 0-based.
    source_code_line = self._source_lines[lineno - 1]
    comment = self._comments_map.get(lineno)

    origin = OriginInfo(
        Location(self._filepath, self._absolute_lineno(lineno),
                 self._absolute_col_offset(col_offset)),
        function_name, source_code_line, comment)
    anno.setanno(node, 'lineno', lineno)
    anno.setanno(node, anno.Basic.ORIGIN, origin)

  def visit(self, node):
    """Annotates node, then visits its children.

    While the children of a gast.FunctionDef are visited, the name of that
    function is on top of the function stack, so it is the function name
    recorded for the function node itself and for the nodes under it.
    """
    is_function = isinstance(node, gast.FunctionDef)
    if is_function:
      self._function_stack.append(_Function(node.name))

    self._attach_origin_info(node)
    self.generic_visit(node)

    if is_function:
      self._function_stack.pop()
|
|
|
|
|
|
def resolve(node, source, context_filepath, context_lineno, context_col_offset):
  """Annotates an AST with origin information derived from its source code.

  The annotations make it possible to map line numbers of the original source
  code to generated source code.

  Note: node may belong to a larger context (e.g. a function inside a module
  that contains other things). This function neither assumes that source holds
  that entire context, nor that it holds only the code of node itself. It does
  assume that node was parsed from the given source code. That is why the two
  extra location arguments are required: they give the position of node in
  the original context.

  Args:
    node: gast.AST, the AST to annotate.
    source: Text, the source code representing node.
    context_filepath: Text
    context_lineno: int
    context_col_offset: int

  Raises:
    tokenize.TokenError: if source cannot be tokenized and node is not a
      gast.Lambda.
  """
  # TODO(mdan): Pull this to a separate utility.
  # Maps the row on which a comment starts to the comment text, without the
  # leading '#' and surrounding whitespace.
  comments_map = {}
  readline = io.StringIO(source).readline
  try:
    for token in tokenize.generate_tokens(readline):
      if token.type == tokenize.COMMENT:
        start_row = token.start[0]
        comments_map[start_row] = token.string.strip()[1:].strip()
  except tokenize.TokenError:
    # Source code resolution in older Python versions is brittle for
    # lambda functions, and may contain garbage. Such errors are tolerated
    # for lambdas only; whatever comments were collected so far are kept.
    if not isinstance(node, gast.Lambda):
      raise

  resolver = OriginResolver(
      node, source.split('\n'), comments_map,
      context_lineno, context_col_offset,
      context_filepath)
  resolver.visit(node)
|
|
|
|
|
|
def resolve_entity(node, source, entity):
  """Like resolve, but extracts the context information from an entity.

  The file path and the starting line number are obtained from entity through
  tf_inspect, and the column offset is estimated from the first source line of
  the entity.

  Args:
    node: the AST to annotate, passed through to resolve.
    source: Text, the source code representing node, passed through to
      resolve.
    entity: the object handed to tf_inspect.getsourcelines and
      tf_inspect.getsourcefile.
  """
  entity_lines, first_lineno = tf_inspect.getsourcelines(entity)
  entity_filepath = tf_inspect.getsourcefile(entity)

  # Poor man's attempt at guessing the column offset: count the leading
  # whitespace of the first line. This might not work well with tabs.
  first_line = entity_lines[0]
  leading_whitespace = len(first_line) - len(first_line.lstrip())

  resolve(node, source, entity_filepath, first_lineno, leading_whitespace)
|
|
|
|
|
|
def copy_origin(from_node, to_node):
  """Copies the origin info from a node to another, recursively.

  Args:
    from_node: the node whose ORIGIN annotation is read. When it has none,
      nothing is copied.
    to_node: a single node, or a list/tuple of nodes. The annotation is set on
      each of them and on every node yielded by gast.walk for them.
  """
  origin = anno.Basic.ORIGIN.of(from_node, default=None)
  if origin is None:
    return

  # Accept both a single node and a list/tuple of nodes.
  targets = to_node if isinstance(to_node, (list, tuple)) else (to_node,)
  for target in targets:
    for descendant in gast.walk(target):
      anno.setanno(descendant, anno.Basic.ORIGIN, origin)
|