132 lines
4.3 KiB
Python
132 lines
4.3 KiB
Python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""
|
|
|
|
import gast
|
|
|
|
from tensorflow.python.autograph.core import converter
|
|
from tensorflow.python.autograph.pyct import parser
|
|
from tensorflow.python.autograph.pyct import templates
|
|
|
|
# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
|
|
# Note that this isn't completely safe either, because tensors may have control
|
|
# dependencies.
|
|
# Note that for loops, this should be done after the loop has been converted
# to tf.while_loop, so that the expanded conditionals are properly scoped.
|
|
|
|
# Marker value signalling that a boolean operand is safe to evaluate
# non-lazily.
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
|
|
|
|
|
|
# Boolean operator node types mapped to the source names of their `ag__`
# overloads. LogicalExpressionTransformer._overload_of consults this table
# unconditionally.
LOGICAL_OPERATORS = {
    gast.And: 'ag__.and_',
    gast.Not: 'ag__.not_',
    gast.Or: 'ag__.or_',
}
|
|
|
|
# Equality comparison node types mapped to the source names of their `ag__`
# overloads. LogicalExpressionTransformer._overload_of only consults this
# table when the converter.Feature.EQUALITY_OPERATORS option is enabled.
EQUALITY_OPERATORS = {
    gast.Eq: 'ag__.eq',
    gast.NotEq: 'ag__.not_eq',
}
|
|
|
|
|
|
class LogicalExpressionTransformer(converter.Base):
  """Converts logical expressions to corresponding TF calls.

  `and`, `or` and `not` are rewritten into calls to the `ag__` overloads listed
  in LOGICAL_OPERATORS. `==` and `!=` are rewritten using EQUALITY_OPERATORS
  only when the EQUALITY_OPERATORS converter feature is enabled. Operands of
  `and`/`or` (including the conjunctions produced from chained comparisons)
  are wrapped in lambdas so that they are evaluated lazily.
  """

  def _overload_of(self, operator):
    """Returns the `ag__` overload name for an operator node, or None.

    Args:
      operator: a gast operator node, e.g. `gast.And()` or `gast.Eq()`.

    Returns:
      The source text of the overload (e.g. 'ag__.and_'), or None when the
      operator should be left unconverted.
    """
    op_type = type(operator)
    if op_type in LOGICAL_OPERATORS:
      return LOGICAL_OPERATORS[op_type]
    # Equality operators are only overloaded when the feature is enabled.
    if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS):
      if op_type in EQUALITY_OPERATORS:
        return EQUALITY_OPERATORS[op_type]
    return None

  def _as_lambda(self, expr):
    """Wraps `expr` in a zero-argument lambda, deferring its evaluation."""
    return templates.replace_as_expression('lambda: expr', expr=expr)

  def _as_binary_function(self, func_name, arg1, arg2):
    """Builds the call expression `func_name(arg1, arg2)`.

    Args:
      func_name: str, source text of the callee, e.g. 'ag__.and_'.
      arg1: AST node for the first argument.
      arg2: AST node for the second argument.

    Returns:
      The AST of the call expression.
    """
    return templates.replace_as_expression(
        'func_name(arg1, arg2)',
        func_name=parser.parse_expression(func_name),
        arg1=arg1,
        arg2=arg2)

  def _as_binary_operation(self, op, arg1, arg2):
    """Builds the single comparison `arg1 <op> arg2`, keeping `op` as-is."""
    template = templates.replace_as_expression(
        'arg1 is arg2',  # Note: `is` will be replaced with `op` below.
        arg1=arg1,
        arg2=arg2)
    template.ops[0] = op
    return template

  def _as_unary_function(self, func_name, arg):
    """Builds the call expression `func_name(arg)`."""
    return templates.replace_as_expression(
        'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)

  def _process_binop(self, op, left, right):
    """Converts one `left <op> right` comparison, overloading `op` if needed."""
    overload = self._overload_of(op)
    if overload is None:
      return self._as_binary_operation(op, left, right)
    return self._as_binary_function(overload, left, right)

  def visit_Compare(self, node):
    """Converts a (possibly chained) comparison."""
    node = self.generic_visit(node)

    # Repeated comparisons are converted to conjunctions:
    #   a < b < c -> a < b and b < c
    # NOTE(review): Python evaluates `b` once in `a < b < c`, but the rewrite
    # emits `b` in two places, so a side-effecting middle operand may run
    # twice.
    op_tree = None
    left = node.left
    for op, right in zip(node.ops, node.comparators):
      binary_comparison = self._process_binop(op, left, right)
      if op_tree is None:
        op_tree = binary_comparison
      else:
        op_tree = self._as_binary_function('ag__.and_',
                                           self._as_lambda(op_tree),
                                           self._as_lambda(binary_comparison))
      left = right

    # A Compare node always carries at least one operator.
    assert op_tree is not None
    return op_tree

  def visit_UnaryOp(self, node):
    """Converts `not x` into its overload; other unary ops are unchanged."""
    node = self.generic_visit(node)

    overload = self._overload_of(node.op)
    if overload is None:
      return node

    return self._as_unary_function(overload, node.operand)

  def visit_BoolOp(self, node):
    """Converts `and`/`or` chains into nested, lazily-evaluated calls.

    The fold is right-associative, e.g.
      a and b and c -> and_(lambda: a, lambda: and_(lambda: b, lambda: c))
    so each operand stays inside a lambda and can be short-circuited.
    """
    node = self.generic_visit(node)
    # The overload depends only on the operator, so resolve it once.
    overload = self._overload_of(node.op)
    values = node.values
    right = values[-1]
    # Walk the operands right-to-left without mutating node.values.
    for left in reversed(values[:-1]):
      right = self._as_binary_function(
          overload, self._as_lambda(left), self._as_lambda(right))
    return right
|
|
|
|
|
|
def transform(node, ctx):
  """Applies LogicalExpressionTransformer to an AST.

  Args:
    node: the AST to convert.
    ctx: the converter context used to construct the transformer.

  Returns:
    The converted AST, as returned by the transformer's `visit`.
  """
  return LogicalExpressionTransformer(ctx).visit(node)
|