Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/autograph/converters/return_statements.py

403 lines
13 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Canonicalizes functions with multiple returns to use just one."""
import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
BODY_DEFINITELY_RETURNS = 'BODY_DEFINITELY_RETURNS'
ORELSE_DEFINITELY_RETURNS = 'ORELSE_DEFINITELY_RETURNS'
STMT_DEFINITELY_RETURNS = 'STMT_DEFINITELY_RETURNS'
class _RewriteBlock(object):
def __init__(self):
self.definitely_returns = False
class ConditionalReturnRewriter(converter.Base):
"""Rewrites a pattern where it's unobvious that all paths return a value.
This rewrite allows avoiding intermediate None return values.
The following pattern:
if cond:
<block 1>
return
else:
<block 2>
<block 3>
is converted to:
if cond:
<block 1>
return
else:
<block 2>
<block 3>
and vice-versa (if the else returns, subsequent statements are moved under the
if branch).
"""
def visit_Return(self, node):
self.state[_RewriteBlock].definitely_returns = True
return node
def _postprocess_statement(self, node):
# If the node definitely returns (e.g. it's a with statement with a
# return statement in it), then the current block also definitely returns.
if anno.getanno(node, STMT_DEFINITELY_RETURNS, default=False):
self.state[_RewriteBlock].definitely_returns = True
# The special case: collapse a typical conditional return pattern into
# a single conditional with possibly returns on both branches. This
# reduces the use of None return values, which don't work with TF
# conditionals.
if (isinstance(node, gast.If)
and anno.getanno(node, BODY_DEFINITELY_RETURNS, default=False)):
return node, node.orelse
elif (isinstance(node, gast.If)
and anno.getanno(node, ORELSE_DEFINITELY_RETURNS, default=False)):
return node, node.body
return node, None
def _visit_statement_block(self, node, nodes):
self.state[_RewriteBlock].enter()
new_nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
block_definitely_returns = self.state[_RewriteBlock].definitely_returns
self.state[_RewriteBlock].exit()
return new_nodes, block_definitely_returns
def visit_While(self, node):
node.test = self.visit(node.test)
node.body, _ = self._visit_statement_block(node, node.body)
node.orelse, _ = self._visit_statement_block(node, node.orelse)
return node
def visit_For(self, node):
node.iter = self.visit(node.iter)
node.target = self.visit(node.target)
node.body, _ = self._visit_statement_block(node, node.body)
node.orelse, _ = self._visit_statement_block(node, node.orelse)
return node
def visit_With(self, node):
node.items = self.visit_block(node.items)
node.body, definitely_returns = self._visit_statement_block(node, node.body)
if definitely_returns:
anno.setanno(node, STMT_DEFINITELY_RETURNS, True)
return node
def visit_Try(self, node):
# We could decide whether a 'try' DEFINITELY_RETURNS based on its components
# It is not clear whether we want to do anything with this given
# a 'try' is likely to throw an exception in some circumstances.
node.body, _ = self._visit_statement_block(node, node.body)
node.orelse, _ = self._visit_statement_block(node, node.orelse)
node.finalbody, _ = self._visit_statement_block(node, node.finalbody)
node.handlers = self.visit_block(node.handlers)
return node
def visit_ExceptHandler(self, node):
# To determine whether `try` DEFINITELY_RETURNS we need to revisit this.
node.body, _ = self._visit_statement_block(node, node.body)
return node
def visit_If(self, node):
node.test = self.visit(node.test)
node.body, body_definitely_returns = self._visit_statement_block(
node, node.body)
if body_definitely_returns:
anno.setanno(node, BODY_DEFINITELY_RETURNS, True)
node.orelse, orelse_definitely_returns = self._visit_statement_block(
node, node.orelse)
if orelse_definitely_returns:
anno.setanno(node, ORELSE_DEFINITELY_RETURNS, True)
if body_definitely_returns and orelse_definitely_returns:
self.state[_RewriteBlock].definitely_returns = True
return node
def visit_FunctionDef(self, node):
node.args = self.visit(node.args)
node.body, _ = self._visit_statement_block(node, node.body)
return node
class _Block(object):
def __init__(self):
self.is_function = False
self.return_used = False
self.create_guard_next = False
self.create_guard_now = False
def __repr__(self):
return 'used: {}'.format(
self.return_used)
class _Function(object):
def __init__(self):
self.do_return_var_name = None
self.retval_var_name = None
def __repr__(self):
return 'return control: {}, return value: {}'.format(
self.do_return_var_name, self.retval_var_name)
class ReturnStatementsTransformer(converter.Base):
"""Lowers return statements into variables and conditionals.
Specifically, the following pattern:
<block 1>
return val
<block 2>
is converted to:
do_return = False
retval = None
<block 1>
do_return = True
retval = val
if not do_return:
<block 2>
return retval
The conversion adjusts loops as well:
<block 1>
while cond:
<block 2>
return retval
is converted to:
<block 1>
while not do_return and cond:
<block 2>
do_return = True
retval = val
"""
def __init__(self, ctx, allow_missing_return):
super(ReturnStatementsTransformer, self).__init__(ctx)
self.allow_missing_return = allow_missing_return
def visit_Return(self, node):
for block in reversed(self.state[_Block].stack):
block.return_used = True
block.create_guard_next = True
if block.is_function:
break
retval = node.value if node.value else parser.parse_expression('None')
# Note: If `return <expr> raises, then the return is aborted.
# The try-catch below ensures the variables remain consistent in that case.
template = """
try:
do_return_var_name = True
retval_var_name = retval
except:
do_return_var_name = False
raise
"""
node = templates.replace(
template,
do_return_var_name=self.state[_Function].do_return_var_name,
retval_var_name=self.state[_Function].retval_var_name,
retval=retval)
return node
def _postprocess_statement(self, node):
if not self.state[_Block].return_used:
return node, None
state = self.state[_Block]
if state.create_guard_now:
template = """
if not do_return_var_name:
original_node
"""
cond, = templates.replace(
template,
do_return_var_name=self.state[_Function].do_return_var_name,
original_node=node)
node, block = cond, cond.body
else:
node, block = node, None
state.create_guard_now = state.create_guard_next
state.create_guard_next = False
return node, block
def _visit_statement_block(self, node, nodes):
self.state[_Block].enter()
nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
self.state[_Block].exit()
return nodes
def visit_While(self, node):
node.test = self.visit(node.test)
# Add the check for return to the loop condition.
node.body = self._visit_statement_block(node, node.body)
if self.state[_Block].return_used:
node.test = templates.replace_as_expression(
'not control_var and test',
test=node.test,
control_var=self.state[_Function].do_return_var_name)
node.orelse = self._visit_statement_block(node, node.orelse)
return node
def visit_For(self, node):
node.iter = self.visit(node.iter)
node.target = self.visit(node.target)
# Add the check for return to the loop condition.
node.body = self._visit_statement_block(node, node.body)
if self.state[_Block].return_used:
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
if extra_test is not None:
extra_test = templates.replace_as_expression(
'not control_var and extra_test',
extra_test=extra_test,
control_var=self.state[_Function].do_return_var_name)
else:
extra_test = templates.replace_as_expression(
'not control_var',
control_var=self.state[_Function].do_return_var_name)
anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
node.orelse = self._visit_statement_block(node, node.orelse)
return node
def visit_With(self, node):
node.items = self.visit_block(node.items)
node.body = self._visit_statement_block(node, node.body)
return node
def visit_Try(self, node):
node.body = self._visit_statement_block(node, node.body)
node.orelse = self._visit_statement_block(node, node.orelse)
node.finalbody = self._visit_statement_block(node, node.finalbody)
node.handlers = self.visit_block(node.handlers)
return node
def visit_ExceptHandler(self, node):
node.body = self._visit_statement_block(node, node.body)
return node
def visit_If(self, node):
node.test = self.visit(node.test)
node.body = self._visit_statement_block(node, node.body)
node.orelse = self._visit_statement_block(node, node.orelse)
return node
def visit_FunctionDef(self, node):
with self.state[_Function] as fn:
with self.state[_Block] as block:
block.is_function = True
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
do_return_var_name = self.ctx.namer.new_symbol('do_return',
scope.referenced)
retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
fn.do_return_var_name = do_return_var_name
fn.retval_var_name = retval_var_name
node.body = self._visit_statement_block(node, node.body)
if block.return_used:
if self.allow_missing_return:
# The function would have a single `with` node that wraps the
# entire body. If the function had a docstring, the body has two
# nodes, with the `with` as the second node.
wrapper_node = node.body[-1]
assert isinstance(wrapper_node, gast.With), (
'This transformer requires the functions converter.')
template = """
do_return_var_name = False
retval_var_name = ag__.UndefinedReturnValue()
body
return function_context.ret(retval_var_name, do_return_var_name)
"""
wrapper_node.body = templates.replace(
template,
body=wrapper_node.body,
do_return_var_name=do_return_var_name,
function_context=anno.getanno(node, 'function_context_name'),
retval_var_name=retval_var_name)
else:
template = """
body
return retval_var_name
"""
node.body = templates.replace(
template,
body=node.body,
do_return_var_name=do_return_var_name,
retval_var_name=retval_var_name)
return node
def transform(node, ctx, default_to_null_return=True):
"""Ensure a function has only a single return, at the end."""
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
# Note: Technically, these two could be merged into a single walk, but
# keeping them separate helps with readability.
node = ConditionalReturnRewriter(ctx).visit(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
transformer = ReturnStatementsTransformer(
ctx, allow_missing_return=default_to_null_return)
node = transformer.visit(node)
return node