3RNN/Lib/site-packages/tensorflow/python/eager/forwardprop.py
2024-05-26 19:49:15 +02:00

488 lines
20 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for forward-mode automatic differentiation."""
import functools
import threading
from tensorflow.core.function.polymorphism import function_cache
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import execute
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager.polymorphic_function import tracing_compilation
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# Registry of hand-written JVP functions, keyed by op type name. Any op that is
# not listed here gets its JVP by transposing its backward function on a tape.
_SPECIAL_CASES = {}


def _identity_jvp(attr_tuple, inputs, outputs, tangents):
  """JVP for `Identity`: every output tangent is an identity of its input's.

  This op is special-cased mostly for resource handles: building ones-Tensors
  from handle data in order to transpose the backward function on a tape is
  error-prone, and even with good handle data, partially defined shapes cause
  trouble.
  """
  del attr_tuple, inputs, outputs  # Unused; only the tangents matter.
  return list(map(array_ops.identity, tangents))


_SPECIAL_CASES["Identity"] = _identity_jvp
def _read_variable_jvp(attr_tuple, inputs, outputs, tangents):
  """JVP for `ReadVariableOp`: passes each incoming tangent through unchanged.

  As with `Identity`, special-casing this op means no variable-shaped Tensors
  have to be created from resource handles.
  """
  del attr_tuple, inputs, outputs  # Unused; only the tangents matter.
  passthrough = array_ops.identity
  return [passthrough(incoming) for incoming in tangents]


_SPECIAL_CASES["ReadVariableOp"] = _read_variable_jvp
# Serializes increments of _TRACE_COUNT (done in _jvp_helper) so concurrent
# traces don't clobber each other's updates. Reads in _jvp_dispatch do not take
# this lock.
_TRACE_COUNT_CONSISTENCY_LOCK = threading.Lock()
# Map from op names to number of traces of _jvp_helper. Used to cap the number
# of traces due to shape differences while still specializing where possible.
_TRACE_COUNT = {}
def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
  """Computes a Jacobian-vector product for an op.

  Note that this function would be wasteful if executed eagerly. It runs the
  backward gradient function and throws away the result just to record its
  operations on a GradientTape. These unused ops are pruned away when this
  function is traced.

  The JVP is obtained by "transposing" the op's backward function. With
  placeholder output gradients `u` (the `forwardprop_aids` below), the backward
  function computes `grads = J^T u`. Because `grads` is linear in `u`,
  differentiating it with respect to `u` using `tangents` as output gradients
  yields `J t`, the Jacobian-vector product.

  Each call increments `_TRACE_COUNT[op_name]` (under
  `_TRACE_COUNT_CONSISTENCY_LOCK`). `_jvp_dispatch` reads this count when
  choosing between exact-shape and relaxed-shape tracing.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, same shape as `inputs`. Tangents for
      inputs that `backprop_util.IsTrainable` rejects are ignored.

  Returns:
    A flat list of tangents corresponding to `outputs`. Outputs that are not
    trainable get a None tangent. Ops registered in `_SPECIAL_CASES` return
    whatever their special-case function returns, and ops with no outputs
    return an empty list.
  """
  with _TRACE_COUNT_CONSISTENCY_LOCK:
    # Just make sure writes don't clobber each other's increments; reads in
    # _jvp_dispatch do not lock.
    _TRACE_COUNT[op_name] = _TRACE_COUNT.get(op_name, 0) + 1

  special_case = _SPECIAL_CASES.get(op_name, None)
  if special_case is not None:
    return special_case(attr_tuple, inputs, outputs, tangents)
  if not outputs:
    # tape.gradients([], inputs) doesn't make much sense
    return []
  # Generally inner GradientTapes won't function while outer accumulators are
  # recording. We temporarily reset forwardprop state to allow GradientTapes to
  # function here.
  with forwardprop_util.push_forwardprop_state():
    # Keep only trainable inputs, together with their matching tangents.
    # NOTE(review): trainable_indices is collected but not used below.
    trainable_inputs = []
    trainable_indices = []
    nontrivial_tangents = []
    for input_index, tensor in enumerate(inputs):
      if backprop_util.IsTrainable(tensor):
        trainable_inputs.append(tensor)
        trainable_indices.append(input_index)
        nontrivial_tangents.append(tangents[input_index])

    # The outer tape records the backward pass itself, so that it can later be
    # differentiated ("transposed") with respect to the placeholder gradients.
    with backprop.GradientTape() as transpose_tape:
      # The inner tape records the op so its backward function can be run.
      with backprop.GradientTape() as backfunc_tape:
        backfunc_tape.watch(trainable_inputs)
        execute.record_gradient(op_name, inputs, attr_tuple, outputs)

      # Placeholder output gradients `u`, one per trainable output. Since the
      # backward function is linear in them, their values should not affect
      # the final JVP; they only serve as differentiation targets.
      forwardprop_aids = []
      trainable_outputs = []
      nontrivial_output_indices = []
      for output_index, output in enumerate(outputs):
        if backprop_util.IsTrainable(output):
          forwardprop_aids.append(
              array_ops.ones_like(output, name="unused_forwardprop_aid"))
          trainable_outputs.append(output)
          nontrivial_output_indices.append(output_index)

      transpose_tape.watch(forwardprop_aids)
      # VJP grads = J^T u. Unconnected inputs get zeros rather than None so the
      # transposition below has a tensor to differentiate.
      grads = backfunc_tape.gradient(
          trainable_outputs,
          trainable_inputs,
          forwardprop_aids,
          unconnected_gradients=UnconnectedGradients.ZERO)
    # Differentiating grads w.r.t. u with output gradients t gives J t.
    nontrivial_output_tangents = transpose_tape.gradient(
        grads, forwardprop_aids, output_gradients=nontrivial_tangents)
    # Scatter the results back into positions aligned with `outputs`.
    # Non-trainable outputs keep a None tangent.
    output_tangents = [None] * len(outputs)
    for index, tangent in zip(nontrivial_output_indices,
                              nontrivial_output_tangents):
      output_tangents[index] = tangent
    return output_tangents
def _jvp_helper_wrapper(op_name, attr_tuple, inputs, outputs, tangents,
                        use_batch):
  """Computes a (possibly batched) Jacobian-vector product for an op.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors. When `use_batch` is False each has the
      shape of its corresponding input; when True each must be compatible with
      `[None] + input_shape`.
    use_batch: A bool, True to vectorize over a batch of tangents of shape
      `[None] + input_shape`.

  Returns:
    A flat list of tangents compatible with `outputs`, or with
    `[None] + output_shape` when `use_batch` is True.

  Raises:
    ValueError: if `use_batch` is True and a tangent's shape is not compatible
      with `[None] + input_shape`.
  """
  if not use_batch:
    return _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents)

  for primal, tangent in zip(inputs, tangents):
    # Computed once and reused for both the check and the error message.
    # Matches the shape construction used by
    # ForwardAccumulator._batch_accumulator.
    expected_shape = tensor_shape.TensorShape([None]) + primal.shape
    if not tangent.shape.is_compatible_with(expected_shape):
      raise ValueError("Tangent {} was expected to be of shape "
                       "{} but is instead of shape {}".format(
                           tangent, expected_shape, tangent.shape))
  # vectorized_map maps over the leading (batch) dimension of every tangent,
  # while the primal inputs/outputs bound by `partial` are shared by all
  # batch elements.
  return control_flow_ops.vectorized_map(
      functools.partial(_jvp_helper, op_name, attr_tuple, inputs, outputs),
      tangents,
  )
# TODO(allenl): reduce_retracing for gradients which rely on static
# shape information are underspecialized. We may want hand-written forward
# implementations, or a more satisfying story about how we re-specialize
# gradients which were traced with relaxed shapes (e.g. use conds instead of
# trace-time Python logic).
#
# `_jvp_helper_wrapper` is traced through `tracing_compilation` directly rather
# than through `tf.function`. The intent is to avoid
# tf.config.run_functions_eagerly(True): `_jvp_helper` doesn't successfully run
# eagerly (infinite recursion), and even if it did it would use extra memory and
# run unnecessary computation. The function does not create variables, so the
# two approaches are otherwise equivalent.
# NOTE(review): this comment originally contrasted `function.defun` with
# `def_function.function`. The code now uses `tracing_compilation`, which is
# assumed to play the role defun used to; confirm.
#
# Both configurations below trace the same Python function and share a single
# function cache. They differ only in `reduce_retracing` and in their names.
_jvp_function_cache = function_cache.FunctionCache()
# Shape-relaxed tracing, selected by _jvp_dispatch once an op's trace count
# reaches _TRACE_COUNT_LIMIT.
_jvp_relaxed_config = tracing_compilation.TracingOptions(
    _jvp_helper_wrapper,
    name="_jvp_relaxed_shapes",
    reduce_retracing=True,
    function_cache=_jvp_function_cache,
)
# Exact-shape tracing, selected by _jvp_dispatch while an op's trace count is
# below _TRACE_COUNT_LIMIT.
_jvp_exact_config = tracing_compilation.TracingOptions(
    _jvp_helper_wrapper,
    name="_jvp_exact_shapes",
    reduce_retracing=False,
    function_cache=_jvp_function_cache,
)
# The maximum number of exact-shape traces to perform for a single op before
# switching to shape relaxation.
_TRACE_COUNT_LIMIT = 32
def _jvp_dispatch(op_name,
                  attr_tuple,
                  inputs,
                  outputs,
                  tangents,
                  use_batch=False):
  """Computes an op's JVP using exact-shape or relaxed-shape tracing.

  Ops that have been traced fewer than `_TRACE_COUNT_LIMIT` times are traced
  with exact shapes. After that, shape relaxation is used to bound retracing.
  This function is registered below with `TFE_Py_RegisterJVPFunction`.
  """
  # The _TRACE_COUNT read deliberately takes no lock. A racing writer can only
  # cause a few extra exact-shape traces before relaxation kicks in.
  still_exact = _TRACE_COUNT.get(op_name, 0) < _TRACE_COUNT_LIMIT
  chosen_options = _jvp_exact_config if still_exact else _jvp_relaxed_config
  call_args = (op_name, attr_tuple, inputs, outputs, tangents, use_batch)
  return tracing_compilation.call_function(
      call_args, tracing_options=chosen_options)


pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
@tf_export("autodiff.ForwardAccumulator", v1=[])
class ForwardAccumulator:
  """Computes Jacobian-vector products ("JVP"s) using forward-mode autodiff.

  Compare to `tf.GradientTape` which computes vector-Jacobian products ("VJP"s)
  using reverse-mode autodiff (backprop). Reverse mode is more attractive when
  computing gradients of a scalar-valued function with respect to many inputs
  (e.g. a neural network with many parameters and a scalar loss). Forward mode
  works best on functions with many outputs and few inputs. Since it does not
  hold on to intermediate activations, it is much more memory efficient than
  backprop where it is applicable.

  Consider a simple linear regression:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> with tf.autodiff.ForwardAccumulator(
  ...    primals=dense.kernel,
  ...    tangents=tf.constant([[1.], [0.]])) as acc:
  ...   loss = tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> acc.jvp(loss)
  <tf.Tensor: shape=(), dtype=float32, numpy=...>

  The example has two variables containing parameters, `dense.kernel` (2
  parameters) and `dense.bias` (1 parameter). Considering the training data `x`
  as a constant, this means the Jacobian matrix for the function mapping from
  parameters to loss has one row and three columns.

  With forwardprop, we specify a length-three vector in advance which multiplies
  the Jacobian. The `primals` constructor argument is the parameter (a
  `tf.Tensor` or `tf.Variable`) we're specifying a vector for, and the
  `tangents` argument is the "vector" in Jacobian-vector product. If our goal is
  to compute the entire Jacobian matrix, forwardprop computes one column at a
  time while backprop computes one row at a time. Since the Jacobian in the
  linear regression example has only one row, backprop requires fewer
  invocations:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> loss_fn = lambda: tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> kernel_fprop = []
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[1.], [0.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[0.], [1.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(dense.bias, tf.constant([1.])) as acc:
  ...   bias_fprop = acc.jvp(loss_fn())
  >>> with tf.GradientTape() as tape:
  ...   loss = loss_fn()
  >>> kernel_grad, bias_grad = tape.gradient(loss, (dense.kernel, dense.bias))
  >>> np.testing.assert_allclose(
  ...     kernel_grad, tf.stack(kernel_fprop)[:, tf.newaxis])
  >>> np.testing.assert_allclose(bias_grad, bias_fprop[tf.newaxis])

  Implicit in the `tape.gradient` call is a length-one vector which
  left-multiplies the Jacobian, a vector-Jacobian product.

  `ForwardAccumulator` maintains JVPs corresponding to primal tensors it is
  watching, derived from the original `primals` specified in the constructor. As
  soon as a primal tensor is deleted, `ForwardAccumulator` deletes the
  corresponding JVP.

  `acc.jvp(x)` retrieves `acc`'s JVP corresponding to the primal tensor `x`. It
  does not perform any computation. `acc.jvp` calls can be repeated as long as
  `acc` is accessible, whether the context manager is active or not. New JVPs
  are only computed while the context manager is active.

  Note that `ForwardAccumulator`s are always applied in the order their context
  managers were entered, so inner accumulators will not see JVP computation from
  outer accumulators. Take higher-order JVPs from outer accumulators:

  >>> primal = tf.constant(1.1)
  >>> with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as outer:
  ...   with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as inner:
  ...     primal_out = primal ** tf.constant(3.5)
  >>> inner_jvp = inner.jvp(primal_out)
  >>> inner_jvp  # 3.5 * 1.1 ** 2.5
  <tf.Tensor: shape=(), dtype=float32, numpy=4.4417057>
  >>> outer.jvp(inner_jvp)  # 3.5 * 2.5 * 1.1 ** 1.5
  <tf.Tensor: shape=(), dtype=float32, numpy=10.094786>

  Reversing the nesting in the last line to instead retrieve
  `inner.jvp(outer.jvp(primal_out))` will not work.

  Strict nesting also applies to combinations of `ForwardAccumulator` and
  `tf.GradientTape`. More deeply nested `GradientTape` objects will ignore the
  products of outer `ForwardAccumulator` objects. This allows (for example)
  memory-efficient forward-over-backward computation of Hessian-vector products,
  where the inner `GradientTape` would otherwise hold on to all intermediate
  JVPs:

  >>> v = tf.Variable([1., 2.])
  >>> with tf.autodiff.ForwardAccumulator(
  ...     v,
  ...     # The "vector" in Hessian-vector product.
  ...     tf.constant([1., 0.])) as acc:
  ...   with tf.GradientTape() as tape:
  ...     y = tf.reduce_sum(v ** 3.)
  ...   backward = tape.gradient(y, v)
  >>> backward  # gradient from backprop
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 3., 12.], dtype=float32)>
  >>> acc.jvp(backward)  # forward-over-backward Hessian-vector product
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 0.], dtype=float32)>
  """

  def __init__(self, primals, tangents):
    """Specify tensors to watch and their Jacobian-vector products.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
    (a Jacobian-vector product) for the function computed while this accumulator
    is active. Since JVPs are computed in forward mode as the computation
    happens, this vector must be supplied in advance.

    Listing a single tensor multiple times in `primals` raises an
    exception. Excluding a tensor from `primals` is equivalent to watching it
    with a tangent tensor of zeros.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with the same
        size as the corresponding primal element.

    Raises:
      ValueError: If the same tensor or variable is specified multiple times in
        `primals`.
    """
    # False here vs. True in `_batch_accumulator`; presumably this flag selects
    # the non-batched accumulator mode.
    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
    self._recording = False
    self._check_unique_primals(primals)
    self._watch(primals, tangents)

  @staticmethod
  def _check_unique_primals(primals):
    """Raises ValueError if any Python object occurs twice in `primals`.

    Args:
      primals: A tensor or nested structure of tensors.

    Raises:
      ValueError: If the same object (compared by `id`) appears more than once.
    """
    seen_ids = set()
    for primal in nest.flatten(primals):
      if id(primal) in seen_ids:
        raise ValueError(
            "Tensor {} was specified as a primal multiple times. This may "
            "indicate an error. If it was intended, please sum the "
            "corresponding tangents.".format(primal))
      seen_ids.add(id(primal))

  def __enter__(self):
    self._push_accumulator()
    return self

  def __exit__(self, typ, value, traceback):
    # Only pop if still recording, so exiting never raises a second error on
    # top of an in-flight exception.
    if self._recording:
      self._pop_accumulator()

  def _push_accumulator(self):
    """Adds this accumulator to the active set; raises if already recording."""
    if self._recording:
      raise ValueError("Accumulator is already recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
    self._recording = True

  def _pop_accumulator(self):
    """Removes this accumulator from the active set; raises if not recording."""
    if not self._recording:
      raise ValueError("Accumulator is not recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
    self._recording = False

  def _watch(self, primals, tangents):
    """Ensures that `primals` are being traced by this accumulator.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
    (a Jacobian-vector product) for the function computed while this accumulator
    is active. Since JVPs are computed in forward mode as the computation
    happens, this vector must be supplied in advance.

    Watching a single tensor multiple times sums each of its `tangents`. Any
    un-watched tensor has zeros for its tangent vector.

    Args:
      primals: A Tensor (or object with a `handle`, e.g. a variable), or a
        nested structure of them.
      tangents: A Tensor or nested structure of Tensors matching `primals`.
        Each tangent is converted to its primal's dtype.
    """

    def _watch_one(primal, tangent):
      if not primal.dtype.is_floating:
        logging.log_first_n(
            logging.WARN, "The dtype of the watched primal must be "
            "floating (e.g. tf.float32), got %r", 5, primal.dtype)
      tangent = ops.convert_to_tensor(tangent, dtype=primal.dtype)
      if hasattr(primal, "handle"):
        # Run convert_to_tensor to get the captured handle from whichever
        # function we're running if necessary.
        primal = ops.convert_to_tensor(primal.handle)
      pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, primal,
                                                tangent)

    nest.map_structure(_watch_one, primals, tangents)

  def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
    """Fetches the Jacobian-vector product computed for `primals`.

    Note that this method performs no computation, and simply looks up a JVP
    that was already computed (unlike backprop using a `tf.GradientTape`, where
    the computation happens on the call to `tape.gradient`).

    Args:
      primals: A watched Tensor or structure of Tensors to fetch the JVPs for.
        Objects with a `handle` attribute (e.g. variables) are looked up via
        their handle.
      unconnected_gradients: A value which can either hold 'none' or 'zero' and
        alters the value which will be returned if no JVP was computed for
        `primals`. The possible values and effects are detailed in
        'tf.UnconnectedGradients' and it defaults to 'none'.

    Returns:
      Tensors with the same shapes and dtypes as `primals`. Where no JVP is
      available, None (for 'none') or zeros like the primal (for 'zero').

    Raises:
      ValueError: If no underlying accumulator exists.
    """
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    if self._accumulator is None:
      raise ValueError("Called jvp() without first tracing anything.")

    def _fetch_jvp(tensor):
      if hasattr(tensor, "handle"):
        unwrapped_tensor = ops.convert_to_tensor(tensor.handle)
      else:
        unwrapped_tensor = tensor
      result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
                                                       unwrapped_tensor)
      if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
        result = array_ops.zeros_like(tensor)
      return result

    return nest.map_structure(_fetch_jvp, primals)

  @classmethod
  def _batch_accumulator(cls, primals, tangents):
    """Factory constructor to test accumulator on batches of tangents.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with compatible
        shape `[None] + primal.shape` of the corresponding primal element.

    Returns:
      A batch accumulator object.

    Raises:
      ValueError: If a tangent's shape is incompatible with
        `[None] + primal.shape`, or if a primal is listed more than once.
    """
    # Bypass __init__, which would create a non-batched accumulator. No extra
    # arguments are passed: object.__new__ rejects them whenever the class (or
    # a subclass) overrides __new__.
    acc = super(ForwardAccumulator, cls).__new__(cls)
    acc._recording = False
    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
      tangent.shape.assert_is_compatible_with(
          tensor_shape.TensorShape([None]) + primal.shape)
    cls._check_unique_primals(primals)
    acc._watch(primals, tangents)
    return acc