Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/framework/auto_control_deps_utils.py
2023-06-19 00:49:18 +02:00

169 lines
5.3 KiB
Python

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for AutomaticControlDependencies."""
from tensorflow.python.framework import dtypes
from tensorflow.python.util import object_identity
# Name of the op attribute that lists the indices of `op.inputs` whose
# resource handles the op only reads (never writes).
READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"

# Attribute name; not referenced in this file. Presumably consumed elsewhere
# for collective ops -- NOTE(review): confirm against callers.
COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"

# Registry of op types (compared against `Operation.type`) whose resource
# inputs are all treated as read-only. Populated via
# `register_read_only_resource_op`.
RESOURCE_READ_OPS = set()


def register_read_only_resource_op(op_type):
  """Declares that `op_type` does not update its touched resource.

  Args:
    op_type: Op type name to add to `RESOURCE_READ_OPS`. All resource inputs
      of ops with this type are then classified as read-only.
  """
  RESOURCE_READ_OPS.update((op_type,))
def get_read_only_resource_input_indices_graph(func_graph):
  """Returns sorted list of read-only resource indices in func_graph.inputs.

  A resource-typed input counts as read-only when every op consuming it
  inside `func_graph` only reads it. That includes inputs with no consumers.

  Args:
    func_graph: Graph exposing `.inputs`; each input exposes `.dtype` and
      `.consumers()`.

  Returns:
    Ascending list of positions in `func_graph.inputs`.
  """
  read_only_positions = []
  # Memo shared by all inputs: consumer Operation -> ObjectIdentitySet of the
  # resource handles that consumer only reads. Each consumer is analyzed at
  # most once even if it consumes several graph inputs.
  reads_by_consumer = {}

  def _only_read_by(handle, consumer):
    # Fill the memo lazily, then test membership by object identity.
    if consumer not in reads_by_consumer:
      reads_by_consumer[consumer] = object_identity.ObjectIdentitySet(
          [consumer.inputs[j]
           for j in _get_read_only_resource_input_indices_op(consumer)])
    return handle in reads_by_consumer[consumer]

  for position, tensor in enumerate(func_graph.inputs):
    if tensor.dtype != dtypes.resource:
      continue
    # `all` stops at the first consumer that may write the handle.
    if all(_only_read_by(tensor, consumer)
           for consumer in tensor.consumers()):
      read_only_positions.append(position)
  return read_only_positions
def _get_read_only_resource_input_indices_op(op):
  """Returns sorted list of read-only resource indices in op.inputs.

  An index is reported when the input at that position is a resource handle
  and one of the following holds:
    * `op.type` was registered via `register_read_only_resource_op`, or
    * the index appears in the op's `READ_ONLY_RESOURCE_INPUTS_ATTR` attr.

  Args:
    op: Operation.

  Returns:
    Ascending list of indices into `op.inputs`. The list is empty when the op
    is not a registered read-only op and does not carry the attr.
  """
  if op.type in RESOURCE_READ_OPS:
    return [i for i, t in enumerate(op.inputs) if t.dtype == dtypes.resource]

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set: conservatively report no read-only resource inputs.
    return []

  # Use a membership test instead of walking the attr list in lockstep with
  # `op.inputs`. The result then does not depend on the attr being sorted,
  # duplicate-free, or listing only resource inputs. This matches the `in`
  # check used by `_op_writes_to_resource`.
  read_only = set(read_only_input_indices)
  return [
      i for i, t in enumerate(op.inputs)
      if t.dtype == dtypes.resource and i in read_only
  ]
def get_read_write_resource_inputs(op):
  """Returns a tuple of resource reads, writes in op.inputs.

  Args:
    op: Operation

  Returns:
    A 2-tuple of ObjectIdentitySets, the first entry containing read-only
    resource handles and the second containing read-write resource handles in
    `op.inputs`. A handle fed to the op at both a read-only and a read-write
    position appears in both sets.
  """
  reads = object_identity.ObjectIdentitySet()
  writes = object_identity.ObjectIdentitySet()

  if op.type in RESOURCE_READ_OPS:
    # Add all resource inputs to `reads` and return.
    reads.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Conservatively add all resource inputs to `writes`.
    writes.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  # Use set membership rather than a lockstep walk over the attr list. A
  # lockstep walk stalls on an unsorted, duplicate, or non-resource entry and
  # would then misclassify every later read-only input as a write.
  read_only = set(read_only_input_indices)
  for i, t in enumerate(op.inputs):
    if t.dtype != dtypes.resource:
      continue
    if i in read_only:
      reads.add(t)
    else:
      writes.add(t)
  return (reads, writes)
def _op_writes_to_resource(handle, op):
  """Returns whether op writes to resource handle.

  Args:
    handle: Resource handle. Must be an input of `op`.
    op: Operation.

  Returns:
    False if `op` is a read-only op registered using
    `register_read_only_resource_op`, or if every position at which `handle`
    is fed to `op` is listed in the op's `READ_ONLY_RESOURCE_INPUTS_ATTR`
    attr. True otherwise, including when the attr is not set.

  Raises:
    ValueError: if `handle` is not an input of `op`. This is not checked for
      ops registered via `register_read_only_resource_op`.
  """
  if op.type in RESOURCE_READ_OPS:
    return False
  # Called only for validation: raises ValueError if `handle` is not an input.
  _input_index(op, handle)
  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Conservatively assume that the resource is written to.
    return True
  # The same handle may be fed at several input positions. It is written to
  # if any of those positions is not read-only. Checking only the first
  # occurrence could miss a write, which `get_read_write_resource_inputs`
  # records.
  read_only = set(read_only_input_indices)
  return any(
      i not in read_only for i, t in enumerate(op.inputs) if t is handle)
def _input_index(op, handle):
"""Returns the index of `handle` in `op.inputs`.
Args:
op: Operation.
handle: Resource handle.
Returns:
Index in `op.inputs` receiving the resource `handle`.
Raises:
ValueError: If handle and its replicated input are both not found in
`op.inputs`.
"""
for i, t in enumerate(op.inputs):
if handle is t:
return i
raise ValueError(f"{handle!s} not in list of inputs for op: {op!r}")