# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting code to AST.

Adapted from Tangent.
"""

import ast
import inspect
import io
import linecache
import re
import sys
import textwrap
import tokenize

import astunparse
import gast

from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.util import tf_inspect


PY2_PREAMBLE = textwrap.dedent("""
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0

if sys.version_info >= (3, 9):
  astunparse = ast

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


_LEADING_WHITESPACE = re.compile(r'\s*')


def _unfold_continuations(code_string):
  """Removes any backslash line continuations from the code."""
  return code_string.replace('\\\n', '')


def dedent_block(code_string):
  """Dedents a code so that its first line starts at row zero."""

  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(io.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      block_level = len(block_indentation)
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`.
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code
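
# Illustrative sketch of dedent_block (comment-only example; the input and
# the expected output below are assumptions derived from the logic above,
# not an executable test). Source captured from a nested scope, e.g. via
# inspect.getsource on a method, is shifted so its first line starts at
# column zero while relative indentation is preserved:
#
#   nested_src = (
#       '  def f(x):\n'
#       '    return x + 1\n')
#   dedent_block(nested_src)
#   # -> 'def f(x):\n  return x + 1\n'
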

def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have
    added).
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except OSError as e:
    raise errors.InaccessibleSourceCodeError(
        f'Unable to locate the source code of {entity}. Note that functions'
        ' defined in certain environments, like the interactive Python shell,'
        ' do not expose their source code. If that is the case, you should'
        ' define them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        f' @tf.autograph.experimental.do_not_convert. Original error: {e}')

  source = dedent_block(original_source)

  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  return parse(source, preamble_len=len(future_features)), source


def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  for n in gast.walk(node):
    lineno = getattr(n, 'lineno', None)
    if lineno is not None:
      n.lineno = lineno - minl
    end_lineno = getattr(n, 'end_lineno', None)
    if end_lineno is not None:
      n.end_lineno = end_lineno - minl

  code_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col_offset = getattr(node, 'end_col_offset', None)
  if end_col_offset is not None:
    # This is only available in Python 3.8 and later.
    code_lines[-1] = code_lines[-1][:end_col_offset]

  col_offset = getattr(node, 'col_offset', None)
  if col_offset is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
    if match is not None:
      col_offset = match.start(0)
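    # Illustrative sketch (hypothetical value): for a first line such as
    #   code_lines[0] == 'foo = lambda x: x + 1'
    # the search above would match the bare keyword at column 6, while the
    # (?<!\w) and (?!\w) guards avoid matching inside identifiers such as
    # 'my_lambda' or 'lambdas'.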