# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient support for Composite Tensors."""
import abc
import sys
from tensorflow.python.framework import composite_tensor
from tensorflow.python.util import nest
# pylint:disable=g-import-not-at-top
if sys.version_info >= (3, 8):
from typing import Protocol
from typing import runtime_checkable
else:
from typing_extensions import Protocol
from typing_extensions import runtime_checkable
# pylint:enable=g-import-not-at-top
# TODO(xjun): Add CompositeTensorGradient support for SparseTensor,
# StructuredTensor, and MaskedTensor.
class CompositeTensorGradient(object, metaclass=abc.ABCMeta):
"""Class used to help compute gradients for CompositeTensors.
This abstract base class defines two methods: `get_gradient_components`, which
returns the components of a value that should be included in gradients; and
`replace_gradient_components`, which replaces the gradient components in a
value. These methods can be used to compute the gradient of a `y` with
respect to `x` (`grad(y, x)`) as follows:
* If `y` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
`y.__composite_gradient__`, then `grad(y, x)` =
`grad(cg.get_gradient_components(y), x)`.
* If `x` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
'x.__composite_gradient__', then `grad(y, x)` =
`cg.replace_gradient_components(x, grad(y, cg.get_gradient_components(x))`.
"""

  @abc.abstractmethod
  def get_gradient_components(self, value):
    """Returns the components of `value` that should be included in gradients.

    This method may not call TensorFlow ops, since any new ops added to the
    graph would not be properly tracked by the gradient mechanisms.

    Args:
      value: A `CompositeTensor` value.

    Returns:
      A nested structure of `Tensor` or `IndexedSlices`.
    """
    raise NotImplementedError(
        f"{type(self).__name__}.get_gradient_components()")

  @abc.abstractmethod
  def replace_gradient_components(self, value, component_grads):
    """Replaces the gradient components in `value` with `component_grads`.

    Args:
      value: A value with its gradient components compatible with
        `component_grads`.
      component_grads: A nested structure of `Tensor` or `IndexedSlices` or
        `None` (for unconnected gradients).

    Returns:
      A copy of `value`, where the components that should be included in
      gradients have been replaced by `component_grads`; or `None` (if
      `component_grads` includes `None`).
    """
    raise NotImplementedError(
        f"{type(self).__name__}.replace_gradient_components()")


@runtime_checkable
class CompositeTensorGradientProtocol(Protocol):
  """Protocol for adding gradient support to CompositeTensors."""

  __composite_gradient__: CompositeTensorGradient
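

# Illustrative sketch of the contract above (assumption: `PairTensor` is a
# hypothetical CompositeTensor with tensor components `first` and `second`
# and a `replace()` constructor; none of these names exist in this module):
#
#   class PairTensorGradient(CompositeTensorGradient):
#
#     def get_gradient_components(self, value):
#       # Both tensor components should receive gradients.
#       return (value.first, value.second)
#
#     def replace_gradient_components(self, value, component_grads):
#       # Propagate `None` for unconnected gradients, per the contract above.
#       if component_grads is None or None in component_grads:
#         return None
#       return value.replace(first=component_grads[0],
#                            second=component_grads[1])
#
# `PairTensor` itself would then opt in by defining the class attribute
# checked by `CompositeTensorGradientProtocol`, e.g.
# `__composite_gradient__ = PairTensorGradient()`.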


class WithValuesCompositeTensorGradient(CompositeTensorGradient):
  """CompositeTensorGradient based on `T.values` and `T.with_values`."""

  def get_gradient_components(self, value):
    return value.values

  def replace_gradient_components(self, value, component_grads):
    return value.with_values(component_grads)
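

# Illustrative note: composites whose gradient flows entirely through a single
# `values` tensor (ragged-style types that expose `.values` and
# `.with_values(...)`) can reuse the class above directly, e.g. (hypothetical
# `MyRaggedLike` type, not defined here):
#
#   class MyRaggedLike(composite_tensor.CompositeTensor):
#     ...
#     __composite_gradient__ = WithValuesCompositeTensorGradient()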


def _get_tensors_for_gradient(x):
  """Returns the Tensors in `x` that should be differentiated.

  Args:
    x: A `Tensor` or `CompositeTensor`.

  Returns:
    A `Tensor` or a nested structure of `Tensor`.
  """
  if not isinstance(x, composite_tensor.CompositeTensor):
    return x

  if not isinstance(x, CompositeTensorGradientProtocol):
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source or "
        "gradient target.")

  composite_gradient = x.__composite_gradient__
  gradient_components = composite_gradient.get_gradient_components(x)
  if gradient_components is x:
    return x
  return nest.map_structure(_get_tensors_for_gradient, gradient_components)
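

# Illustrative examples (hypothetical values, not defined in this module):
#
#   _get_tensors_for_gradient(some_tensor)      # -> some_tensor, unchanged
#   _get_tensors_for_gradient(my_ragged_like)   # -> recurses into
#                                               #    my_ragged_like.values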


def _replace_tensors_for_gradient(x, grad):
  """Replaces the tensors in `x` that should be differentiated with `grad`.

  Args:
    x: A `Tensor` or `CompositeTensor`.
    grad: A nested structure of `Tensor`, with the same structure as the value
      returned by `_get_tensors_for_gradient(x)`.

  Returns:
    A `Tensor` or `CompositeTensor`.
  """
  if not isinstance(x, composite_tensor.CompositeTensor):
    return grad

  if not isinstance(x, CompositeTensorGradientProtocol):
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source.")

  composite_gradient = x.__composite_gradient__
  x_components = composite_gradient.get_gradient_components(x)
  if x_components is x:
    grad_components = grad
  else:
    grad_components = nest.map_structure_up_to(x_components,
                                               _replace_tensors_for_gradient,
                                               x_components, grad)
  if grad_components is None:
    return None
  return composite_gradient.replace_gradient_components(x, grad_components)
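

# Note: `_replace_tensors_for_gradient` is the structural inverse of
# `_get_tensors_for_gradient`: given `grad` with the same nested structure as
# `_get_tensors_for_gradient(x)`, it rebuilds a value shaped like `x` whose
# gradient components come from `grad`, or returns `None` when an unconnected
# gradient is encountered.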


def get_flat_tensors_for_gradients(xs):
  """Returns a flat list of Tensors that should be differentiated for `xs`.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.

  Returns:
    A flat list of `Tensor`s constructed from `xs`, where `Tensor` values are
    left as-is, and `CompositeTensor`s are replaced with
    `_get_tensors_for_gradient(x)`.
  """
  return nest.flatten([_get_tensors_for_gradient(x) for x in xs])


def replace_flat_tensors_for_gradients(xs, flat_grads):
  """Replaces Tensors that should be differentiated in `xs` with `flat_grads`.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.
    flat_grads: A list of `Tensor`.

  Returns:
    A list of `Tensor` or `CompositeTensor`.
  """
  xs_structure = [_get_tensors_for_gradient(x) for x in xs]
  grads = nest.pack_sequence_as(xs_structure, flat_grads)
  return [_replace_tensors_for_gradient(x, grad) for x, grad in zip(xs, grads)]
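

# Usage sketch (illustrative; `tape`, `ys`, and `xs` are hypothetical values,
# not defined in this module). Gradient code can lower composite sources and
# targets to flat Tensor lists, differentiate, and then pack the flat
# gradients back into composite form:
#
#   flat_xs = get_flat_tensors_for_gradients(xs)
#   flat_ys = get_flat_tensors_for_gradients(ys)
#   flat_grads = tape.gradient(flat_ys, flat_xs)  # one grad per flat source
#   grads = replace_flat_tensors_for_gradients(xs, flat_grads)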