414 lines
14 KiB
Python
414 lines
14 KiB
Python
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Handles control flow statements: while, for, if."""
|
||
|
|
||
|
import gast
|
||
|
|
||
|
from tensorflow.python.autograph.core import converter
|
||
|
from tensorflow.python.autograph.lang import directives
|
||
|
from tensorflow.python.autograph.pyct import anno
|
||
|
from tensorflow.python.autograph.pyct import cfg
|
||
|
from tensorflow.python.autograph.pyct import origin_info
|
||
|
from tensorflow.python.autograph.pyct import parser
|
||
|
from tensorflow.python.autograph.pyct import qual_names
|
||
|
from tensorflow.python.autograph.pyct import templates
|
||
|
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||
|
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||
|
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||
|
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||
|
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
|
||
|
|
||
|
|
||
|
class _Function(object):
|
||
|
|
||
|
scope = None
|
||
|
|
||
|
|
||
|
class ControlFlowTransformer(converter.Base):
|
||
|
"""Transforms control flow structures like loops an conditionals."""
|
||
|
|
||
|
def visit_Lambda(self, node):
|
||
|
with self.state[_Function] as fn:
|
||
|
fn.scope = anno.getanno(node, anno.Static.SCOPE)
|
||
|
return self.generic_visit(node)
|
||
|
|
||
|
def visit_FunctionDef(self, node):
|
||
|
with self.state[_Function] as fn:
|
||
|
fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||
|
return self.generic_visit(node)
|
||
|
|
||
|
def _create_nonlocal_declarations(self, vars_):
|
||
|
vars_ = set(vars_)
|
||
|
results = []
|
||
|
global_vars = self.state[_Function].scope.globals & vars_
|
||
|
|
||
|
if global_vars:
|
||
|
results.append(gast.Global([str(v) for v in global_vars]))
|
||
|
|
||
|
nonlocal_vars = [
|
||
|
v for v in vars_ if not v.is_composite() and v not in global_vars]
|
||
|
if nonlocal_vars:
|
||
|
results.append(gast.Nonlocal([str(v) for v in nonlocal_vars]))
|
||
|
|
||
|
return results
|
||
|
|
||
|
def _create_state_functions(
|
||
|
self, block_vars, nonlocal_declarations, getter_name, setter_name):
|
||
|
if not block_vars:
|
||
|
template = """
|
||
|
def getter_name():
|
||
|
return ()
|
||
|
def setter_name(block_vars):
|
||
|
pass
|
||
|
"""
|
||
|
return templates.replace(
|
||
|
template, getter_name=getter_name, setter_name=setter_name)
|
||
|
|
||
|
guarded_block_vars = []
|
||
|
for v in block_vars:
|
||
|
if v.is_simple():
|
||
|
guarded_block_vars.append(v)
|
||
|
else:
|
||
|
guarded_block_vars.append(
|
||
|
templates.replace_as_expression(
|
||
|
'ag__.ldu(lambda: var_, name)',
|
||
|
var_=v,
|
||
|
name=gast.Constant(str(v), kind=None)))
|
||
|
|
||
|
template = """
|
||
|
def getter_name():
|
||
|
return guarded_state_vars,
|
||
|
def setter_name(vars_):
|
||
|
nonlocal_declarations
|
||
|
state_vars, = vars_
|
||
|
"""
|
||
|
return templates.replace(
|
||
|
template,
|
||
|
nonlocal_declarations=nonlocal_declarations,
|
||
|
getter_name=getter_name,
|
||
|
guarded_state_vars=guarded_block_vars,
|
||
|
setter_name=setter_name,
|
||
|
state_vars=tuple(block_vars))
|
||
|
|
||
|
def _create_loop_options(self, node):
|
||
|
if not anno.hasanno(node, anno.Basic.DIRECTIVES):
|
||
|
return gast.Dict([], [])
|
||
|
|
||
|
loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
|
||
|
if directives.set_loop_options not in loop_directives:
|
||
|
return gast.Dict([], [])
|
||
|
|
||
|
opts_dict = loop_directives[directives.set_loop_options]
|
||
|
str_keys, values = zip(*opts_dict.items())
|
||
|
keys = [gast.Constant(s, kind=None) for s in str_keys]
|
||
|
values = list(values) # ast and gast don't play well with tuples.
|
||
|
return gast.Dict(keys, values)
|
||
|
|
||
|
def _create_undefined_assigns(self, undefined_symbols):
|
||
|
assignments = []
|
||
|
for s in undefined_symbols:
|
||
|
template = '''
|
||
|
var = ag__.Undefined(symbol_name)
|
||
|
'''
|
||
|
assignments += templates.replace(
|
||
|
template,
|
||
|
var=s,
|
||
|
symbol_name=gast.Constant(s.ssf(), kind=None))
|
||
|
return assignments
|
||
|
|
||
|
def _get_block_basic_vars(self, modified, live_in, live_out):
|
||
|
nonlocals = self.state[_Function].scope.nonlocals
|
||
|
basic_scope_vars = []
|
||
|
for s in modified:
|
||
|
if s.is_composite():
|
||
|
# TODO(mdan): Raise an error when this happens for a TF scope.
|
||
|
continue
|
||
|
# Variables not live into or out of the scope are considered local to the
|
||
|
# scope.
|
||
|
if s in live_in or s in live_out or s in nonlocals:
|
||
|
basic_scope_vars.append(s)
|
||
|
continue
|
||
|
return frozenset(basic_scope_vars)
|
||
|
|
||
|
def _get_block_composite_vars(self, modified, live_in):
|
||
|
# The scope variables corresponding to composite symbols (e.g. `self.x`).
|
||
|
composite_scope_vars = []
|
||
|
for s in modified:
|
||
|
if not s.is_composite():
|
||
|
continue
|
||
|
# Mutations made to objects created inside the scope will appear as writes
|
||
|
# to composite symbols. Because these mutations appear as modifications
|
||
|
# made to composite symbols, we check whether the composite's parent is
|
||
|
# actually live into the scope.
|
||
|
# Example:
|
||
|
# while cond:
|
||
|
# x = Foo()
|
||
|
# x.foo = 2 * x.foo # x.foo is live into the scope, but x is not.
|
||
|
#
|
||
|
# Note that some parents might not be symbols - for example, in x['foo'],
|
||
|
# 'foo' is a parent, but it's a literal, not a symbol. We don't check the
|
||
|
# liveness of literals.
|
||
|
support_set_symbols = tuple(
|
||
|
sss for sss in s.support_set if sss.is_symbol())
|
||
|
if not all(sss in live_in for sss in support_set_symbols):
|
||
|
continue
|
||
|
composite_scope_vars.append(s)
|
||
|
return frozenset(composite_scope_vars)
|
||
|
|
||
|
def _get_block_vars(self, node, modified):
|
||
|
"""Determines the variables affected inside a control flow statement."""
|
||
|
defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
|
||
|
live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
|
||
|
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
|
||
|
fn_scope = self.state[_Function].scope
|
||
|
|
||
|
basic_scope_vars = self._get_block_basic_vars(
|
||
|
modified,
|
||
|
live_in,
|
||
|
live_out)
|
||
|
composite_scope_vars = self._get_block_composite_vars(modified, live_in)
|
||
|
scope_vars = tuple(basic_scope_vars | composite_scope_vars)
|
||
|
|
||
|
# Variables that are modified inside the scope, but not defined
|
||
|
# before entering it. Only simple variables must be defined. The
|
||
|
# composite ones will be implicitly checked at runtime.
|
||
|
possibly_undefined = (
|
||
|
modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
|
||
|
undefined = tuple(v for v in possibly_undefined if not v.is_composite())
|
||
|
|
||
|
# Variables that are modified inside the scope, and depend on values outside
|
||
|
# it.
|
||
|
input_only = basic_scope_vars & live_in - live_out
|
||
|
|
||
|
# Place the outputs first, then sort lexicographically.
|
||
|
scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
|
||
|
nouts = len(scope_vars) - len(input_only)
|
||
|
|
||
|
return scope_vars, undefined, nouts
|
||
|
|
||
|
def visit_If(self, node):
|
||
|
node = self.generic_visit(node)
|
||
|
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||
|
orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
|
||
|
|
||
|
cond_vars, undefined, nouts = self._get_block_vars(
|
||
|
node, body_scope.bound | orelse_scope.bound)
|
||
|
|
||
|
undefined_assigns = self._create_undefined_assigns(undefined)
|
||
|
|
||
|
nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)
|
||
|
|
||
|
reserved = body_scope.referenced | orelse_scope.referenced
|
||
|
state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
|
||
|
state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
|
||
|
state_functions = self._create_state_functions(
|
||
|
cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)
|
||
|
|
||
|
orelse_body = node.orelse
|
||
|
if not orelse_body:
|
||
|
orelse_body = [gast.Pass()]
|
||
|
|
||
|
template = """
|
||
|
state_functions
|
||
|
def body_name():
|
||
|
nonlocal_declarations
|
||
|
body
|
||
|
def orelse_name():
|
||
|
nonlocal_declarations
|
||
|
orelse
|
||
|
undefined_assigns
|
||
|
ag__.if_stmt(
|
||
|
test,
|
||
|
body_name,
|
||
|
orelse_name,
|
||
|
state_getter_name,
|
||
|
state_setter_name,
|
||
|
(symbol_names,),
|
||
|
nouts)
|
||
|
"""
|
||
|
new_nodes = templates.replace(
|
||
|
template,
|
||
|
body=node.body,
|
||
|
body_name=self.ctx.namer.new_symbol('if_body', reserved),
|
||
|
orelse=orelse_body,
|
||
|
orelse_name=self.ctx.namer.new_symbol('else_body', reserved),
|
||
|
nonlocal_declarations=nonlocal_declarations,
|
||
|
nouts=gast.Constant(nouts, kind=None),
|
||
|
state_functions=state_functions,
|
||
|
state_getter_name=state_getter_name,
|
||
|
state_setter_name=state_setter_name,
|
||
|
symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
|
||
|
test=node.test,
|
||
|
undefined_assigns=undefined_assigns)
|
||
|
origin_info.copy_origin(node, new_nodes[-1])
|
||
|
return new_nodes
|
||
|
|
||
|
def visit_While(self, node):
|
||
|
node = self.generic_visit(node)
|
||
|
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||
|
|
||
|
loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound)
|
||
|
|
||
|
undefined_assigns = self._create_undefined_assigns(undefined)
|
||
|
|
||
|
nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)
|
||
|
|
||
|
reserved = body_scope.referenced
|
||
|
state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
|
||
|
state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
|
||
|
state_functions = self._create_state_functions(
|
||
|
loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)
|
||
|
|
||
|
opts = self._create_loop_options(node)
|
||
|
|
||
|
template = """
|
||
|
state_functions
|
||
|
def body_name():
|
||
|
nonlocal_declarations
|
||
|
body
|
||
|
def test_name():
|
||
|
return test
|
||
|
undefined_assigns
|
||
|
ag__.while_stmt(
|
||
|
test_name,
|
||
|
body_name,
|
||
|
state_getter_name,
|
||
|
state_setter_name,
|
||
|
(symbol_names,),
|
||
|
opts)
|
||
|
"""
|
||
|
new_nodes = templates.replace(
|
||
|
template,
|
||
|
body=node.body,
|
||
|
body_name=self.ctx.namer.new_symbol('loop_body', reserved),
|
||
|
nonlocal_declarations=nonlocal_declarations,
|
||
|
opts=opts,
|
||
|
state_functions=state_functions,
|
||
|
state_getter_name=state_getter_name,
|
||
|
state_setter_name=state_setter_name,
|
||
|
symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
|
||
|
test=node.test,
|
||
|
test_name=self.ctx.namer.new_symbol('loop_test', reserved),
|
||
|
undefined_assigns=undefined_assigns)
|
||
|
origin_info.copy_origin(node, new_nodes[-1])
|
||
|
return new_nodes
|
||
|
|
||
|
def visit_For(self, node):
|
||
|
node = self.generic_visit(node)
|
||
|
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||
|
iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)
|
||
|
|
||
|
loop_vars, undefined, _ = self._get_block_vars(
|
||
|
node, body_scope.bound | iter_scope.bound)
|
||
|
|
||
|
undefined_assigns = self._create_undefined_assigns(undefined)
|
||
|
|
||
|
nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)
|
||
|
|
||
|
reserved = body_scope.referenced | iter_scope.referenced
|
||
|
state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
|
||
|
state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
|
||
|
state_functions = self._create_state_functions(
|
||
|
loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)
|
||
|
|
||
|
opts = self._create_loop_options(node)
|
||
|
opts.keys.append(gast.Constant('iterate_names', kind=None))
|
||
|
opts.values.append(gast.Constant(
|
||
|
parser.unparse(node.target, include_encoding_marker=False), kind=None))
|
||
|
|
||
|
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
|
||
|
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
|
||
|
extra_test_name = self.ctx.namer.new_symbol(
|
||
|
'extra_test', reserved)
|
||
|
template = """
|
||
|
def extra_test_name():
|
||
|
nonlocal_declarations
|
||
|
return extra_test_expr
|
||
|
"""
|
||
|
extra_test_function = templates.replace(
|
||
|
template,
|
||
|
extra_test_expr=extra_test,
|
||
|
extra_test_name=extra_test_name,
|
||
|
loop_vars=loop_vars,
|
||
|
nonlocal_declarations=nonlocal_declarations)
|
||
|
else:
|
||
|
extra_test_name = parser.parse_expression('None')
|
||
|
extra_test_function = []
|
||
|
|
||
|
# iterate_arg_name holds a single arg with the iterates, which may be a
|
||
|
# tuple.
|
||
|
iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
|
||
|
template = """
|
||
|
iterates = iterate_arg_name
|
||
|
"""
|
||
|
iterate_expansion = templates.replace(
|
||
|
template, iterate_arg_name=iterate_arg_name, iterates=node.target)
|
||
|
origin_info.copy_origin(node, iterate_expansion)
|
||
|
|
||
|
template = """
|
||
|
state_functions
|
||
|
def body_name(iterate_arg_name):
|
||
|
nonlocal_declarations
|
||
|
iterate_expansion
|
||
|
body
|
||
|
extra_test_function
|
||
|
undefined_assigns
|
||
|
ag__.for_stmt(
|
||
|
iterated,
|
||
|
extra_test_name,
|
||
|
body_name,
|
||
|
state_getter_name,
|
||
|
state_setter_name,
|
||
|
(symbol_names,),
|
||
|
opts)
|
||
|
"""
|
||
|
new_nodes = templates.replace(
|
||
|
template,
|
||
|
body=node.body,
|
||
|
body_name=self.ctx.namer.new_symbol('loop_body', reserved),
|
||
|
extra_test_function=extra_test_function,
|
||
|
extra_test_name=extra_test_name,
|
||
|
iterate_arg_name=iterate_arg_name,
|
||
|
iterate_expansion=iterate_expansion,
|
||
|
iterated=node.iter,
|
||
|
nonlocal_declarations=nonlocal_declarations,
|
||
|
opts=opts,
|
||
|
symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
|
||
|
state_functions=state_functions,
|
||
|
state_getter_name=state_getter_name,
|
||
|
state_setter_name=state_setter_name,
|
||
|
undefined_assigns=undefined_assigns)
|
||
|
origin_info.copy_origin(node, new_nodes[-1])
|
||
|
return new_nodes
|
||
|
|
||
|
|
||
|
class AnnotatedDef(reaching_definitions.Definition):
|
||
|
|
||
|
def __init__(self):
|
||
|
super(AnnotatedDef, self).__init__()
|
||
|
self.directives = {}
|
||
|
|
||
|
|
||
|
def transform(node, ctx):
|
||
|
graphs = cfg.build(node)
|
||
|
node = qual_names.resolve(node)
|
||
|
node = activity.resolve(node, ctx, None)
|
||
|
node = reaching_definitions.resolve(node, ctx, graphs)
|
||
|
node = reaching_fndefs.resolve(node, ctx, graphs)
|
||
|
node = liveness.resolve(node, ctx, graphs)
|
||
|
|
||
|
node = ControlFlowTransformer(ctx).visit(node)
|
||
|
return node
|