# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles function calls, by generating compiled function names and calls.

Note: this transformer does not rename the top level object being converted;
that is the caller's responsibility.

Requires function_scopes.
"""

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.utils import ag_logging


# TODO(mdan): Rename to FunctionCallsTransformer.


class _Function(object):
  """Transformer state kept for each function (or lambda) scope visited.

  Attributes:
    context_name: Name of the function context manager for the current
      scope, copied from the node's 'function_context_name' annotation.
      None until a scope assigns it.
  """

  # NOTE(review): presumably tells the transformer state stack not to create
  # a root instance, so this state only exists inside an entered scope;
  # confirm against pyct's transformer state implementation.
  no_root = True

  def __init__(self):
    self.context_name = None


# Module-level flag so the pdb/breakpoint warning in visit_Call is logged at
# most once.
set_trace_warned = False


class _ArgTemplateBuilder(object):
  """Constructs a tuple representing the positional arguments in a call.

  Consecutive plain arguments are grouped into tuple literals, and every
  starred argument is wrapped in a `tuple(...)` call; the pieces are then
  joined with `+`.

  Example (yes, it's legal Python 3):

      f(*args1, b, *args2, c, d) -> tuple(args1) + (b,) + tuple(args2) + (c, d)
  """

  def __init__(self):
    # Plain (non-starred) arguments not yet flushed into a tuple literal.
    self._arg_accumulator = []
    # Ordered list of tuple-valued AST pieces to be concatenated.
    self._argspec = []
    self._finalized = False

  def _consume_args(self):
    """Flushes pending plain arguments into a single tuple literal piece."""
    if self._arg_accumulator:
      self._argspec.append(
          gast.Tuple(elts=self._arg_accumulator, ctx=gast.Load()))
      self._arg_accumulator = []

  def add_arg(self, a):
    """Queues a plain positional argument node."""
    self._arg_accumulator.append(a)

  def add_stararg(self, a):
    """Adds the value of a starred argument, wrapped as `tuple(a)`."""
    # Flush first so argument order is preserved.
    self._consume_args()
    self._argspec.append(
        gast.Call(
            gast.Name(
                'tuple', ctx=gast.Load(), annotation=None, type_comment=None),
            args=[a],
            keywords=()))

  def finalize(self):
    """Flushes any remaining plain arguments; must be called before to_ast."""
    self._consume_args()
    self._finalized = True

  def to_ast(self):
    """Returns an expression AST evaluating to the full positional args tuple.

    The pieces are combined as a left-nested chain of `+` BinOps. With no
    arguments at all, an empty tuple literal is returned.
    """
    assert self._finalized
    if self._argspec:
      result = self._argspec[0]
      for piece in self._argspec[1:]:
        result = gast.BinOp(result, gast.Add(), piece)
      return result
    return gast.Tuple([], gast.Load())


class CallTreeTransformer(converter.Base):
  """Transforms the call tree by renaming transformed symbols."""

  def visit_Lambda(self, node):
    """Visits a lambda, entering a function scope when it has a context."""
    if not anno.hasanno(node, 'function_context_name'):
      # Lambda functions created during the conversion process have no
      # context manager.
      return self.generic_visit(node)
    with self.state[_Function] as fn_scope:
      fn_scope.context_name = anno.getanno(node, 'function_context_name')
      return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    """Visits a function; its body is processed inside a new function scope."""
    # Decorators and arg defaults are part of the outer scope.
    node.decorator_list = self.visit_block(node.decorator_list)
    node.args.defaults = self.visit_block(node.args.defaults)
    # kw_defaults may contain None for keyword-only args without a default.
    for i, d in enumerate(node.args.kw_defaults):
      if d is not None:
        node.args.kw_defaults[i] = self.visit(d)
    with self.state[_Function] as fn_scope:
      # Note: if the conversion process ever creates helper functions, this
      # assumption will no longer hold.
      assert anno.hasanno(node, 'function_context_name'), (
          'The function_scopes converter always creates a scope for functions.')
      fn_scope.context_name = anno.getanno(node, 'function_context_name')
      node.body = self.visit_block(node.body)
      if node.returns:
        node.returns = self.visit(node.returns)
      return node

  def visit_With(self, node):
    """Visits only the body of a `with`; its context managers are untouched."""
    # Context manager calls (in node.items) are not converted.
    node.body = self.visit_block(node.body)
    return node

  def _args_to_tuple(self, node):
    """Ties together all positional and *arg arguments in a single tuple."""
    # TODO(mdan): We could rewrite this to just a call to tuple(). Maybe better?
    # For example for
    #   f(a, b, *args)
    # instead of writing:
    #   (a, b) + args
    # just write this?
    #   tuple(a, b, *args)
    builder = _ArgTemplateBuilder()
    for a in node.args:
      if isinstance(a, gast.Starred):
        builder.add_stararg(a.value)
      else:
        builder.add_arg(a)
    builder.finalize()
    return builder.to_ast()

  def _kwargs_to_dict(self, node):
    """Ties together all keyword and **kwarg arguments in a single dict.

    Returns a `dict(...)` call AST carrying the call's keywords, or the
    expression `None` when the call has no keyword arguments.
    """
    if node.keywords:
      return gast.Call(
          gast.Name(
              'dict', ctx=gast.Load(), annotation=None, type_comment=None),
          args=(),
          keywords=node.keywords)
    else:
      return parser.parse_expression('None')

  def visit_Call(self, node):
    """Rewrites a call into `ag__.converted_call(...)` unless it is exempt.

    Calls left as-is: calls into `ag__.`, calls on the current function
    context manager, pdb/ipdb/breakpoint debugger hooks, and `print` when
    the BUILTIN_FUNCTIONS feature is disabled.
    """
    # Read the qualified name and scope before generic_visit rewrites children.
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    function_context_name = self.state[_Function].context_name
    node = self.generic_visit(node)

    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    if full_name.startswith('ag__.'):
      return node

    # Calls to the function context manager (inserted by function_scopes) are
    # also safe.
    if full_name.startswith(function_context_name + '.'):
      return node

    # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
    # the normal mechanisms to bypass these literals because they are sensitive
    # to the frame they are being called from.
    # TODO(mdan): Generalize this to a "static allowlist" config.
    if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
      global set_trace_warned
      if not set_trace_warned:
        # TODO(mdan): Update and shorten once available on tensorflow.org.
        ag_logging.warning(
            'Detected `pdb.set_trace()` in user code. The code'
            ' generated by AutoGraph is not optimized for step-by-step'
            ' debugging. See https://github.com/tensorflow/tensorflow/'
            'blob/master/tensorflow/python/autograph/g3doc/reference/'
            'debugging.md.')
        set_trace_warned = True
      return node

    if (full_name == 'print' and
        not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
      return node

    template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
    new_call = templates.replace_as_expression(
        template,
        func=node.func,
        args=self._args_to_tuple(node),
        kwargs=self._kwargs_to_dict(node),
        function_ctx=function_context_name)

    return new_call


def transform(node, ctx):
  """Transform function call to the compiled counterparts.

  Qualified names are resolved first, then CallTreeTransformer rewrites
  eligible calls into `ag__.converted_call(...)`.

  Args:
    node: AST
    ctx: EntityContext

  Returns:
    The transformed AST.
  """
  node = qual_names.resolve(node)
  node = CallTreeTransformer(ctx).visit(node)
  return node