# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates

# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
# Note that this isn't completely safe either, because tensors may have control
# dependencies.
# Note that for loops, that should be done after the loop was converted to
# tf.while_loop so that the expanded conditionals are properly scoped.

# Used to signal that an operand is safe for non-lazy evaluation.
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'


# Maps boolean operator node types to the name of the function replacing them.
LOGICAL_OPERATORS = {
    gast.And: 'ag__.and_',
    gast.Not: 'ag__.not_',
    gast.Or: 'ag__.or_',
}

# Maps equality operator node types to the name of the function replacing
# them. Only consulted when the EQUALITY_OPERATORS feature is enabled.
EQUALITY_OPERATORS = {
    gast.Eq: 'ag__.eq',
    gast.NotEq: 'ag__.not_eq',
}


class LogicalExpressionTransformer(converter.Base):
  """Converts logical expressions to corresponding TF calls."""

  def _overload_of(self, operator):
    """Returns the name of the function overloading `operator`, if any.

    Args:
      operator: An operator AST node (e.g. an instance of `gast.And`).

    Returns:
      The dotted name of the replacement function, as a string, or None when
      the operator has no overload (or its overload is not enabled).
    """
    op_type = type(operator)
    if op_type in LOGICAL_OPERATORS:
      return LOGICAL_OPERATORS[op_type]
    # Equality operators are only overloaded when the feature is turned on.
    if (self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS) and
        op_type in EQUALITY_OPERATORS):
      return EQUALITY_OPERATORS[op_type]
    return None

  def _as_lambda(self, expr):
    """Wraps `expr` in a zero-argument lambda expression."""
    return templates.replace_as_expression('lambda: expr', expr=expr)

  def _as_binary_function(self, func_name, arg1, arg2):
    """Builds the call expression `func_name(arg1, arg2)`.

    Args:
      func_name: The (possibly dotted) function name, as a string.
      arg1: AST node used as the first argument.
      arg2: AST node used as the second argument.

    Returns:
      The expression node produced from the template.
    """
    return templates.replace_as_expression(
        'func_name(arg1, arg2)',
        func_name=parser.parse_expression(func_name),
        arg1=arg1,
        arg2=arg2)

  def _as_binary_operation(self, op, arg1, arg2):
    """Builds the comparison `arg1 <op> arg2`, keeping the operator as is.

    Args:
      op: The comparison operator node to place between the operands.
      arg1: AST node used as the left operand.
      arg2: AST node used as the right operand.

    Returns:
      The comparison node produced from the template, with its operator
      replaced by `op`.
    """
    template = templates.replace_as_expression(
        'arg1 is arg2',  # Note: `is` will be replaced with `op` below.
        arg1=arg1,
        arg2=arg2)
    template.ops[0] = op
    return template

  def _as_unary_function(self, func_name, arg):
    """Builds the call expression `func_name(arg)`.

    Args:
      func_name: The (possibly dotted) function name, as a string.
      arg: AST node used as the only argument.

    Returns:
      The expression node produced from the template.
    """
    return templates.replace_as_expression(
        'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)

  def _process_binop(self, op, left, right):
    """Converts a single binary comparison.

    Args:
      op: The comparison operator node.
      left: AST node of the left operand.
      right: AST node of the right operand.

    Returns:
      A call to the operator's overload when one exists, otherwise a plain
      comparison using the original operator.
    """
    overload = self._overload_of(op)
    if overload is None:
      return self._as_binary_operation(op, left, right)
    return self._as_binary_function(overload, left, right)

  def visit_Compare(self, node):
    """Converts a (possibly chained) comparison.

    Repeated comparisons are converted to conjunctions:
      a < b < c   ->   a < b and b < c
    The conjunction is expressed as a call to `ag__.and_`, with each side
    wrapped in a lambda.

    Args:
      node: The comparison node to convert.

    Returns:
      The expression node replacing `node`.
    """
    node = self.generic_visit(node)

    left = node.left
    op_tree = None
    # Iterate the pairs directly rather than draining a list with pop(0).
    for op, right in zip(node.ops, node.comparators):
      binary_comparison = self._process_binop(op, left, right)
      if op_tree is None:
        op_tree = binary_comparison
      else:
        op_tree = self._as_binary_function(
            'ag__.and_', self._as_lambda(op_tree),
            self._as_lambda(binary_comparison))
      # The right operand of this comparison is the left operand of the next.
      left = right

    assert op_tree is not None
    return op_tree

  def visit_UnaryOp(self, node):
    """Converts a unary operation that has an overload (e.g. `not`).

    Args:
      node: The unary operation node to convert.

    Returns:
      A call to the operator's overload, or the visited node unchanged when
      the operator has no overload.
    """
    node = self.generic_visit(node)

    overload = self._overload_of(node.op)
    if overload is None:
      return node

    return self._as_unary_function(overload, node.operand)

  def visit_BoolOp(self, node):
    """Converts a boolean operation into nested binary function calls.

    The operands are folded from the right and each one is wrapped in a
    lambda:
      a and b and c   ->   and_(lambda: a, lambda: and_(lambda: b, lambda: c))

    Args:
      node: The boolean operation node to convert.

    Returns:
      The expression node replacing `node`.
    """
    node = self.generic_visit(node)

    # Read the operands without popping them, so `node.values` is left intact.
    values = node.values
    result = values[-1]
    remaining = values[:-1]
    if not remaining:
      return result

    # The operator is the same for every operand: look up its overload once.
    overload = self._overload_of(node.op)
    for operand in reversed(remaining):
      result = self._as_binary_function(
          overload, self._as_lambda(operand), self._as_lambda(result))
    return result


def transform(node, ctx):
  """Applies `LogicalExpressionTransformer` to `node`.

  Args:
    node: The AST node to convert.
    ctx: The context passed to the transformer's constructor.

  Returns:
    The result of visiting `node` with the transformer.
  """
  transformer = LogicalExpressionTransformer(ctx)
  return transformer.visit(node)