98 lines
2.7 KiB
Python
98 lines
2.7 KiB
Python
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Overloads all variable read operations."""
|
||
|
|
||
|
import gast
|
||
|
|
||
|
from tensorflow.python.autograph.core import converter
|
||
|
from tensorflow.python.autograph.pyct import anno
|
||
|
from tensorflow.python.autograph.pyct import templates
|
||
|
|
||
|
|
||
|
class VariableAccessTransformer(converter.Base):
  """Wraps plain variable reads in an `ag__.ld` call.

  Routing reads through a "read" operator makes it possible to track variable
  activity.

  Example:

  A simple statement such as:

      a = b + c

  becomes:

      a = ag__.ld(b) + ag__.ld(c)

  An augmented assignment to a plain name, such as:

      a += b

  gets an explicit read of its target inserted in front of it:

      a = ag__.ld(a)
      a += b

  Deleting a plain name is turned into an assignment of an undefined marker:

      del a

  becomes:

      a = ag__.Undefined('a')
  """

  def visit_Name(self, node):
    """Replaces an original `Load` of a name with `ag__.ld(<name>)`.

    Args:
      node: the name node being visited.

    Returns:
      The `ag__.ld(...)` expression when the node carries the
      `ORIG_DEFINITIONS` annotation and has a `Load` context; otherwise the
      node itself, unchanged.
    """
    # Only the loads which existed in the original code are overloaded.
    came_from_original = anno.hasanno(node, anno.Static.ORIG_DEFINITIONS)
    if came_from_original and isinstance(node.ctx, gast.Load):
      return templates.replace_as_expression('ag__.ld(var_)', var_=node)
    return node

  def visit_Delete(self, node):
    """Rewrites `del name` into `name = ag__.Undefined('name')`.

    Args:
      node: the delete statement being visited.

    Returns:
      The visited node when none of its targets is a plain name. Otherwise a
      list of statements: the generated assignments for every plain-name
      target, followed by one `Delete` holding the remaining targets, if any.
    """
    node = self.generic_visit(node)

    # Split the targets in one pass. Composites like `del a[0]` are not
    # rewritten; they stay in a regular delete statement.
    name_targets = []
    composite_targets = []
    for target in node.targets:
      if isinstance(target, gast.Name):
        name_targets.append(target)
      else:
        composite_targets.append(target)

    if not name_targets:
      return node

    # NOTE(review): the template text is indented for readability; this
    # assumes templates.replace dedents it before parsing - confirm.
    template = """
      var_ = ag__.Undefined(var_name)
    """
    statements = []
    for target in name_targets:
      statements.extend(
          templates.replace(
              template,
              var_=target,
              var_name=gast.Constant(target.id, kind=None)))

    if composite_targets:
      statements.append(gast.Delete(targets=composite_targets))

    return statements

  def visit_AugAssign(self, node):
    """Inserts an explicit read of the target before `name <op>= value`.

    Args:
      node: the augmented assignment being visited.

    Returns:
      The result of generic visitation when the target is not a plain name.
      Otherwise the expansion of a template that places
      `name = ag__.ld(name)` before the original statement.
    """
    if not isinstance(node.target, gast.Name):
      return self.generic_visit(node)

    # NOTE(review): on this path the original statement is handed to the
    # template without being visited, so this method does not wrap reads in
    # its right-hand side - confirm that this is intended.
    template = """
      var_ = ag__.ld(var_)
      original
    """
    return templates.replace(template, var_=node.target, original=node)
|
||
|
|
||
|
|
||
|
def transform(node, ctx):
  """Runs `VariableAccessTransformer` over `node`.

  Args:
    node: the root node to visit.
    ctx: the context object passed to the transformer's constructor.

  Returns:
    Whatever the transformer's `visit` returns for `node`.
  """
  transformer = VariableAccessTransformer(ctx)
  return transformer.visit(node)
|