# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converter for list operations.

This includes converting Python lists to TensorArray/TensorList.
"""

# TODO(mdan): Elaborate the logic here.
# TODO(mdan): Does it even make sense to try to use TAs?
# The current rule (always convert to TensorArray) is naive and insufficient.
# In general, a better mechanism could look like:
#   * convert to TensorList by default
#   * leave as Python list if the user explicitly forbids it
#   * convert to TensorArray only when complete write once behavior can be
#     guaranteed (e.g. list comprehensions)

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno


class _Statement(object):
  """Per-statement state used while rewriting list operations."""

  def __init__(self):
    # List of (pop() call node, replacement variable name) pairs recorded by
    # _replace_pop_call while visiting the current statement; None until the
    # first pop() is seen. Consumed by ListTransformer._postprocess_statement.
    self.pop_uses = None


class ListTransformer(converter.Base):
  """Converts lists and related operations to their TF counterpart."""

  def visit_List(self, node):
    """Wraps a (visited) list literal in a call to `ag__.new_list`."""
    node = self.generic_visit(node)
    template = """
      ag__.new_list(elements)
    """
    return templates.replace_as_expression(template, elements=node)

  def _replace_append_call(self, node):
    """Rewrites `target.append(element)`.

    The call becomes `target = ag__.list_append(target, element)`.

    Args:
      node: a gast.Call whose func is an Attribute and which has exactly one
        positional argument.

    Returns:
      The result of `templates.replace` for the assignment above.
    """
    assert len(node.args) == 1
    assert isinstance(node.func, gast.Attribute)
    template = """
      target = ag__.list_append(target, element)
    """
    return templates.replace(
        template, target=node.func.value, element=node.args[0])

  def _replace_pop_call(self, node):
    """Replaces a `target.pop(...)` call with a freshly named variable.

    The actual `ag__.list_pop` statement is generated later by
    _generate_pop_operation, from the pair recorded on the current
    _Statement state.

    Args:
      node: a gast.Call whose func is an Attribute.

    Returns:
      An expression node referring to the new variable name.
    """
    # Expressions that use pop() are converted to a statement + expression.
    #
    # For example:
    #
    #   print(target.pop())
    #
    # ... is converted to:
    #
    #   target, target_pop = ag__.list_pop(target)
    #   print(target_pop)
    #
    # Here, we just generate the variable name and swap it in,
    # and _generate_pop_operation will handle the rest.
    #
    # Multiple uses of pop() are allowed:
    #
    #   print(target.pop(), target.pop())
    #   print(target.pop().pop())
    #
    assert isinstance(node.func, gast.Attribute)
    scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    target_node = node.func.value

    # Attempt to use a related name if one exists. Otherwise use something
    # generic.
    if anno.hasanno(target_node, anno.Basic.QN):
      target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
    else:
      target_name = 'list_'
    # The new name must not collide with anything referenced by the call's
    # arguments.
    pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced)

    stmt = self.state[_Statement]
    if stmt.pop_uses is None:
      stmt.pop_uses = []
    stmt.pop_uses.append((node, pop_var_name))

    return templates.replace_as_expression('var_name', var_name=pop_var_name)

  def _replace_stack_call(self, node):
    """Rewrites `<func>(target)` into an `ag__.list_stack` call.

    The element dtype comes from a `set_element_type` directive on `target`.
    It defaults to `None` when no such directive exists. Keyword arguments of
    the original call are not forwarded.

    Args:
      node: a gast.Call with exactly one positional argument.

    Returns:
      An expression node for the `ag__.list_stack(...)` call.
    """
    assert len(node.args) == 1
    dtype = self.get_definition_directive(
        node.args[0],
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))

    template = """
      ag__.list_stack(
          target,
          opts=ag__.ListStackOpts(
              element_dtype=dtype,
              original_call=orig_call))
    """
    return templates.replace_as_expression(
        template,
        dtype=dtype,
        target=node.args[0],
        orig_call=node.func)

  def visit_Call(self, node):
    """Rewrites recognized attribute calls: `append`, `pop` and `stack`."""
    node = self.generic_visit(node)

    # TODO(mdan): This is insufficient if target is a function argument.
    # In the case of function arguments, we need to add the list to the
    # function's return value, because it is being modified.
    # TODO(mdan): Checking just the name is brittle, can it be improved?
    if isinstance(node.func, gast.Attribute):
      func_name = node.func.attr
      if func_name == 'append' and (len(node.args) == 1):
        node = self._replace_append_call(node)
      elif func_name == 'pop' and (len(node.args) <= 1):
        node = self._replace_pop_call(node)
      elif (func_name == 'stack' and (len(node.args) == 1) and
            all(kw.arg == 'strict' for kw in node.keywords)):
        # This avoids false positives with keyword args. _replace_stack_call
        # does not forward keywords, so every keyword must be `strict`.
        # Previously only the first keyword was checked, and any others were
        # silently dropped. A `**kwargs` entry has arg None and is rejected.
        # TODO(mdan): handle kwargs properly.
        node = self._replace_stack_call(node)

    return node

  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds the `ag__.list_pop` assignment for one recorded pop() call.

    Args:
      original_call_node: the `target.pop(...)` gast.Call recorded by
        _replace_pop_call.
      pop_var_name: the variable name that replaced the call in its
        expression.

    Returns:
      The result of `templates.replace` for
      `target, pop_var_name = ag__.list_pop(target, element, opts=...)`.
      `element` is the call's positional argument, or `None` if there was
      none.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    if original_call_node.args:
      pop_element = original_call_node.args[0]
    else:
      pop_element = parser.parse_expression('None')

    # The call will be something like "target.pop()", and the dtype is hooked to
    # target, hence the func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    dtype = self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))
    shape = self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        'shape',
        default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=original_call_node.func.value,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)

  def _postprocess_statement(self, node):
    """Inserts any separate pop() calls that node may use.

    The generated list_pop statements are placed before `node`, in the order
    the pop() calls were recorded. The current _Statement state is then
    exited. This method is used as the `after_visit` hook of `visit_block`,
    and returns the pair that hook expects.
    """
    pop_uses = self.state[_Statement].pop_uses
    if pop_uses:
      replacements = []
      for original_call_node, pop_var_name in pop_uses:
        replacements.extend(
            self._generate_pop_operation(original_call_node, pop_var_name))
      replacements.append(node)
      node = replacements
    self.state[_Statement].exit()
    return node, None

  def _visit_and_process_block(self, block):
    """Visits a statement block with per-statement pop() bookkeeping.

    A _Statement state is entered through the `before_visit` hook. Pending
    pop() operations are flushed through `_postprocess_statement` as the
    `after_visit` hook.
    """
    return self.visit_block(
        block,
        before_visit=self.state[_Statement].enter,
        after_visit=self._postprocess_statement)

  def visit_FunctionDef(self, node):
    """Visits arguments and decorators, then processes the body."""
    node.args = self.generic_visit(node.args)
    node.decorator_list = self.visit_block(node.decorator_list)
    node.body = self._visit_and_process_block(node.body)
    return node

  def visit_For(self, node):
    """Visits the loop target, then processes the body and orelse blocks."""
    # NOTE(review): node.iter is not visited here, so list literals and
    # append/pop/stack calls in the loop iterable are left unconverted.
    # Confirm this is intentional.
    node.target = self.visit(node.target)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_While(self, node):
    """Visits the loop test, then processes the body and orelse blocks."""
    # NOTE(review): node.test is visited outside _visit_and_process_block.
    # A pop() in the test is therefore recorded on the enclosing statement's
    # state, and its list_pop is emitted before the loop. It presumably runs
    # once rather than once per iteration. TODO confirm this is intended.
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_If(self, node):
    """Visits the condition, then processes the body and orelse blocks."""
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_With(self, node):
    """Visits the context items, then processes the body block."""
    node.items = self.visit_block(node.items)
    node.body = self._visit_and_process_block(node.body)
    return node


def transform(node, ctx):
  """Applies ListTransformer to `node`.

  Qualified-name resolution and activity analysis run first. They are
  presumably the source of the `anno.Basic.QN` and `NodeAnno.ARGS_SCOPE`
  annotations read by the transformer.

  Args:
    node: the AST to convert.
    ctx: the converter context, passed to activity analysis and
      ListTransformer.

  Returns:
    The transformed AST.
  """
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)

  return ListTransformer(ctx).visit(node)