3RNN/Lib/site-packages/tensorflow/python/autograph/converters/variables.py
2024-05-26 19:49:15 +02:00

98 lines
2.7 KiB
Python

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Overloads all variable read operations."""
import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import templates
class VariableAccessTransformer(converter.Base):
    """Wraps variable reads in the `ag__.ld` operator.

    A `Name` node in load context that carries the `ORIG_DEFINITIONS`
    annotation is replaced by a call to `ag__.ld` on that name. Provided its
    names are annotated that way, a statement such as:

        a = b + c

    is rewritten to:

        a = ag__.ld(b) + ag__.ld(c)

    Two statement kinds receive dedicated handling:

    * `del a`, where the target is a plain name, is rewritten to an
      assignment `a = ag__.Undefined('a')`. Other targets (for example
      `del a[0]`) stay in a `del` statement.
    * An augmented assignment whose target is a plain name, such as
      `a += b`, is preceded by an explicit read of the target:

          a = ag__.ld(a)
          a += b
    """

    def visit_Name(self, node):
        """Replaces an annotated name read with `ag__.ld(<name>)`.

        Args:
          node: the `Name` node being visited.

        Returns:
          The `ag__.ld(...)` expression when `node` has the
          `ORIG_DEFINITIONS` annotation and a `Load` context; otherwise
          `node` itself, unchanged.
        """
        # Only the loads which existed in the original code are overloaded.
        is_original = anno.hasanno(node, anno.Static.ORIG_DEFINITIONS)
        if is_original and isinstance(node.ctx, gast.Load):
            return templates.replace_as_expression('ag__.ld(var_)', var_=node)
        return node

    def visit_Delete(self, node):
        """Turns deletions of plain names into `ag__.Undefined` assignments.

        Args:
          node: the `Delete` node being visited.

        Returns:
          The visited `Delete` node when none of its targets is a plain
          name. Otherwise a list of statements: one
          `name = ag__.Undefined('name')` assignment per plain-name target,
          followed by a single `Delete` holding the remaining targets, if
          there are any.
        """
        node = self.generic_visit(node)

        # Composites like `del a[0]` are not rewritten; only bare names are.
        name_targets = [
            tgt for tgt in node.targets if isinstance(tgt, gast.Name)
        ]
        if not name_targets:
            return node

        undefine_template = (
            '\n'
            'var_ = ag__.Undefined(var_name)\n'
        )
        statements = []
        for name in name_targets:
            name_literal = gast.Constant(name.id, kind=None)
            assignment = templates.replace(
                undefine_template, var_=name, var_name=name_literal)
            statements.extend(assignment)

        # Whatever was not rewritten is kept together in one `del`.
        leftover = [tgt for tgt in node.targets if tgt not in name_targets]
        if leftover:
            statements.append(gast.Delete(targets=leftover))

        return statements

    def visit_AugAssign(self, node):
        """Emits an explicit read of a plain-name augmented-assign target.

        Args:
          node: the `AugAssign` node being visited.

        Returns:
          For a plain-name target, the result of expanding a template that
          places `target = ag__.ld(target)` ahead of the original statement.
          For any other target, the result of `generic_visit(node)`.
        """
        if not isinstance(node.target, gast.Name):
            return self.generic_visit(node)

        read_then_update = (
            '\n'
            'var_ = ag__.ld(var_)\n'
            'original\n'
        )
        return templates.replace(
            read_then_update, var_=node.target, original=node)
def transform(node, ctx):
    """Runs `VariableAccessTransformer` over `node`.

    Args:
      node: the AST node handed to the transformer's `visit` method.
      ctx: the context object passed to the transformer's constructor.

    Returns:
      Whatever `VariableAccessTransformer(ctx).visit(node)` returns.
    """
    transformer = VariableAccessTransformer(ctx)
    return transformer.visit(node)