3RNN/Lib/site-packages/tensorflow/python/ops/sparse_grad.py
2024-05-26 19:49:15 +02:00

363 lines
13 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in sparse_ops.py."""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
# TODO(b/31222613): This op may be differentiable, and there may be
# latent bugs here.
ops.NotDifferentiable("SparseAddGrad")
ops.NotDifferentiable("SparseConcat")
@ops.RegisterGradient("SparseReorder")
def _SparseReorderGrad(
op: ops.Operation, unused_output_indices_grad, output_values_grad
):
"""Gradients for the SparseReorder op.
Args:
op: the SparseReorder op
unused_output_indices_grad: the incoming gradients of the output indices
output_values_grad: the incoming gradients of the output values
Returns:
Gradient for each of the 3 input tensors:
(input_indices, input_values, input_shape)
The gradients for input_indices and input_shape is None.
"""
input_indices = op.inputs[0]
input_shape = op.inputs[2]
num_entries = array_ops.shape(input_indices)[0]
entry_indices = math_ops.range(num_entries)
sp_unordered = sparse_tensor.SparseTensor(input_indices, entry_indices,
input_shape)
sp_ordered = sparse_ops.sparse_reorder(sp_unordered)
inverted_permutation = array_ops.invert_permutation(sp_ordered.values)
return (None, array_ops.gather(output_values_grad,
inverted_permutation), None)
@ops.RegisterGradient("SparseAdd")
def _SparseAddGrad(op: ops.Operation, *grads):
"""The backward operator for the SparseAdd op.
The SparseAdd op calculates A + B, where A, B, and the sum are all represented
as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
values of A and B.
Args:
op: the SparseAdd op
*grads: the incoming gradients, one element per output of `op`
Returns:
Gradient for each of the 6 input tensors of SparseAdd:
(a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
The gradients for the indices, shapes, and the threshold are None.
"""
val_grad = grads[1]
a_indices = op.inputs[0]
b_indices = op.inputs[3]
sum_indices = op.outputs[0]
# NOTE: we do not need to take `thresh` into account, since it simply affects
# the non-zero elements of the sum, and we will peek into `sum_indices` in the
# gradient op.
a_val_grad, b_val_grad = gen_sparse_ops.sparse_add_grad(
val_grad, a_indices, b_indices, sum_indices)
a_val_grad.set_shape(op.inputs[1].get_shape())
b_val_grad.set_shape(op.inputs[4].get_shape())
# (a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
return (None, a_val_grad, None, None, b_val_grad, None, None)
@ops.RegisterGradient("SparseTensorDenseAdd")
def _SparseTensorDenseAddGrad(op: ops.Operation, out_grad):
sp_indices = op.inputs[0]
# (sparse_indices, sparse_values, sparse_shape, dense)
return (None, array_ops.gather_nd(out_grad, sp_indices), None, out_grad)
@ops.RegisterGradient("SparseReduceSum")
def _SparseReduceSumGrad(op: ops.Operation, out_grad):
"""Similar to gradient for the Sum Op (i.e. tf.reduce_sum())."""
sp_indices = op.inputs[0]
sp_shape = op.inputs[2]
output_shape_kept_dims = math_ops.reduced_shape(sp_shape, op.inputs[3])
out_grad_reshaped = array_ops.reshape(out_grad, output_shape_kept_dims)
scale = sp_shape // math_ops.cast(output_shape_kept_dims, dtypes.int64)
# (sparse_indices, sparse_values, sparse_shape, reduction_axes)
return (None, array_ops.gather_nd(out_grad_reshaped,
sp_indices // scale), None, None)
@ops.RegisterGradient("SparseSlice")
def _SparseSliceGrad(op: ops.Operation, *grads):
"""The backward operator for the SparseSlice op.
This op takes in the upstream gradient w.r.t. non-empty values of
the sliced `SparseTensor`, and outputs the gradients w.r.t.
the non-empty values of input `SparseTensor`.
Args:
op: the SparseSlice op
*grads: the incoming gradients, one element per output of `op`
Returns:
Gradient for each of the 5 input tensors of SparseSlice:
(indices, values, shape, start, size)
The gradients for the indices, shape, start and the size are None.
"""
backprop_val_grad = grads[1]
input_indices = op.inputs[0]
input_start = op.inputs[3]
output_indices = op.outputs[0]
val_grad = gen_sparse_ops.sparse_slice_grad(backprop_val_grad, input_indices,
input_start, output_indices)
val_grad.set_shape(op.inputs[1].get_shape())
# (indices, values, shape, start, size)
return (None, val_grad, None, None, None)
@ops.RegisterGradient("SparseTensorDenseMatMul")
def _SparseTensorDenseMatMulGrad(op: ops.Operation, grad):
"""Gradients for the dense tensor in the SparseTensorDenseMatMul op.
Args:
op: the SparseTensorDenseMatMul op
grad: the incoming gradient
Returns:
Gradient for each of the 4 input tensors:
(sparse_indices, sparse_values, sparse_shape, dense_tensor)
The gradients for indices and shape are None.
Raises:
TypeError: When the two operands don't have the same type.
"""
a_indices, a_values, a_shape = op.inputs[:3]
b = op.inputs[3]
adj_a = op.get_attr("adjoint_a")
adj_b = op.get_attr("adjoint_b")
a_type = a_values.dtype.base_dtype
b_type = b.dtype.base_dtype
if a_type != b_type:
raise TypeError(
f"SparseTensorDenseMatMul op received operands with different types: "
f"`{a_type}` and `{b_type}`.")
# gradient w.r.t. dense
b_grad = gen_sparse_ops.sparse_tensor_dense_mat_mul(
a_indices, a_values, a_shape, grad, adjoint_a=not adj_a)
if adj_b:
b_grad = array_ops.matrix_transpose(b_grad, conjugate=True)
# gradient w.r.t. sparse values
# TODO(zongheng): these gather calls could potentially duplicate rows/cols in
# memory. If there is a need, we should look into implementing this more
# intelligently to avoid duplicating data.
# With no adjoints, a_grad is matmul(grad, adjoint(b)). Since a is sparse, we
# just want to compute that matmul at the rows/columns of non-zero values. The
# (r, c) value is sum(grad[r, :] * adjoint(b)[:, c]), where the latter term is
# more conveniently written as conj(b)[c, :]. That expression is more
# efficient to calculate as a matmul, after expanding the two terms to be 2D
# (i.e. a row vector and a column vector).
#
# If adj_b then we replace conj(b) by transpose(b); if adj_a we need to
# adjoint the result, which is equivalent to swapping r and c and taking
# conjugates.
# Get grad[r, :] and b[c, :] (or with r and c swapped if adj_a, or with
# transpose(b) if adj_b), as batches of vectors (with the batch dimension
# corresponding to the non-zero indices of a).
rows = a_indices[:, 0]
cols = a_indices[:, 1]
parts_a = array_ops.gather(grad, rows if not adj_a else cols)
parts_b = array_ops.gather(
b if not adj_b else array_ops.transpose(b), cols if not adj_a else rows)
if not adj_a and not adj_b:
# grad[r, :] * conj(b[c, :]) = row(grad[r, :]) @ adjoint(row(b[c, :]))
a_values_grad = math_ops.matmul(
array_ops.expand_dims(parts_a, -2),
array_ops.expand_dims(parts_b, -2),
adjoint_b=True)
elif adj_a and not adj_b:
# conj(grad[c, :] * conj(b[r, :])) = adjoint(col(grad[c, :])) @ col(b[r, :])
a_values_grad = math_ops.matmul(
array_ops.expand_dims(parts_a, -1),
array_ops.expand_dims(parts_b, -1),
adjoint_a=True)
elif not adj_a and adj_b:
# grad[r, :] * transpose(b)[c, :] =
# row(grad[r, :]) @ col(transpose(b)[c, :])
a_values_grad = math_ops.matmul(
array_ops.expand_dims(parts_a, -2), array_ops.expand_dims(parts_b, -1))
elif adj_a and adj_b:
# conj(grad[c, :] * transpose(b)[r, :]) =
# adjoint(col(grad[c, :])) @ adjoint(row(transpose(b)[r, :])
a_values_grad = math_ops.matmul(
array_ops.expand_dims(parts_a, -1),
array_ops.expand_dims(parts_b, -2),
adjoint_a=True,
adjoint_b=True)
# gradients w.r.t. (a_indices, a_values, a_shape, b)
return (None, array_ops.squeeze(a_values_grad, axis=[-2, -1]), None, b_grad)
@ops.RegisterGradient("SparseDenseCwiseAdd")
def _SparseDenseCwiseAddGrad(unused_op, unused_grad):
raise NotImplementedError(
"Gradient for SparseDenseCwiseAdd is not implemented.")
def _SparseDenseCwiseMulOrDivGrad(op: ops.Operation, grad, is_mul):
"""Common code for SparseDenseCwise{Mul,Div} gradients."""
x_indices = op.inputs[0]
x_shape = op.inputs[2]
y = op.inputs[3]
y_shape = math_ops.cast(array_ops.shape(y), dtypes.int64)
num_added_dims = array_ops.expand_dims(
array_ops.size(x_shape) - array_ops.size(y_shape), 0)
augmented_y_shape = array_ops.concat(
[array_ops.ones(num_added_dims, ops.dtypes.int64), y_shape], 0)
scaling = x_shape // augmented_y_shape
scaled_indices = x_indices // scaling
scaled_indices = array_ops.slice(scaled_indices,
array_ops.concat([[0], num_added_dims], 0),
[-1, -1])
dense_vals = array_ops.gather_nd(y, scaled_indices)
if is_mul:
dx = grad * dense_vals
dy_val = grad * op.inputs[1]
else:
dx = grad / dense_vals
dy_val = grad * (-op.inputs[1] / math_ops.square(dense_vals))
# indices can repeat after scaling, so we can't use sparse_to_dense().
dy = sparse_ops.sparse_add(
array_ops.zeros_like(y),
sparse_tensor.SparseTensor(scaled_indices, dy_val, y_shape))
# (sp_indices, sp_vals, sp_shape, dense)
return (None, dx, None, dy)
@ops.RegisterGradient("SparseDenseCwiseMul")
def _SparseDenseCwiseMulGrad(op: ops.Operation, grad):
"""Gradients for SparseDenseCwiseMul."""
return _SparseDenseCwiseMulOrDivGrad(op, grad, True)
@ops.RegisterGradient("SparseDenseCwiseDiv")
def _SparseDenseCwiseDivGrad(op: ops.Operation, grad):
"""Gradients for SparseDenseCwiseDiv."""
return _SparseDenseCwiseMulOrDivGrad(op, grad, False)
@ops.RegisterGradient("SparseSoftmax")
def _SparseSoftmaxGrad(op: ops.Operation, grad):
"""Gradients for SparseSoftmax.
The calculation is the same as SoftmaxGrad:
grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax
where we now only operate on the non-zero values present in the SparseTensors.
Args:
op: the SparseSoftmax op.
grad: the upstream gradient w.r.t. the non-zero SparseSoftmax output values.
Returns:
Gradients w.r.t. the input (sp_indices, sp_values, sp_shape).
"""
indices, shape = op.inputs[0], op.inputs[2]
out_vals = op.outputs[0]
sp_output = sparse_tensor.SparseTensor(indices, out_vals, shape)
sp_grad = sparse_tensor.SparseTensor(indices, grad, shape)
sp_product = sparse_tensor.SparseTensor(indices,
sp_output.values * sp_grad.values,
shape)
# [..., B, 1], dense.
sum_reduced = -sparse_ops.sparse_reduce_sum(sp_product, [-1], keepdims=True)
# sparse [..., B, C] + dense [..., B, 1] with broadcast; outputs sparse.
sp_sum = sparse_ops.sparse_dense_cwise_add(sp_grad, sum_reduced)
grad_x = sp_sum.values * sp_output.values
return [None, grad_x, None]
@ops.RegisterGradient("SparseSparseMaximum")
def _SparseSparseMaximumGrad(unused_op: ops.Operation, unused_grad):
raise NotImplementedError(
"Gradient for SparseSparseMaximum is not implemented."
)
@ops.RegisterGradient("SparseSparseMinimum")
def _SparseSparseMinimumGrad(unused_op: ops.Operation, unused_grad):
raise NotImplementedError(
"Gradient for SparseSparseMinimum is not implemented."
)
@ops.RegisterGradient("SparseFillEmptyRows")
def _SparseFillEmptyRowsGrad(
op: ops.Operation,
unused_grad_output_indices,
output_grad_values,
unused_grad_empty_row_indicator,
unused_grad_reverse_index_map,
):
"""Gradients for SparseFillEmptyRows."""
reverse_index_map = op.outputs[3]
d_values, d_default_value = gen_sparse_ops.sparse_fill_empty_rows_grad(
reverse_index_map=reverse_index_map, grad_values=output_grad_values
)
# d_indices, d_values, d_dense_shape, d_default_value.
return [None, d_values, None, d_default_value]
@ops.RegisterGradient("SparseToDense")
def _SparseToDenseGrad(op: ops.Operation, grad):
sparse_indices, output_shape, _, _ = op.inputs
sparse_values_grad = array_ops.gather_nd(grad, sparse_indices)
default_value_grad = math_ops.reduce_sum(grad) - math_ops.reduce_sum(
sparse_values_grad)
return [
array_ops.zeros_like(sparse_indices),
array_ops.zeros_like(output_shape), sparse_values_grad, default_value_grad
]