# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in sparse_ops.py."""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops

# TODO(b/31222613): This op may be differentiable, and there may be
# latent bugs here.
ops.NotDifferentiable("SparseAddGrad")
ops.NotDifferentiable("SparseConcat")


@ops.RegisterGradient("SparseReorder")
def _SparseReorderGrad(
    op: ops.Operation, unused_output_indices_grad, output_values_grad
):
  """Gradients for the SparseReorder op.

  Args:
    op: the SparseReorder op
    unused_output_indices_grad: the incoming gradients of the output indices
    output_values_grad: the incoming gradients of the output values

  Returns:
    Gradient for each of the 3 input tensors:
      (input_indices, input_values, input_shape)
    The gradients for input_indices and input_shape are None.
  """
  input_indices = op.inputs[0]
  input_shape = op.inputs[2]

  num_entries = array_ops.shape(input_indices)[0]
  entry_indices = math_ops.range(num_entries)
  sp_unordered = sparse_tensor.SparseTensor(input_indices, entry_indices,
                                            input_shape)
  sp_ordered = sparse_ops.sparse_reorder(sp_unordered)
  inverted_permutation = array_ops.invert_permutation(sp_ordered.values)
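  # Illustrative sketch (hypothetical values): if the unordered entries sit at
  # indices [[2], [0], [1]], reordering sorts them to [[0], [1], [2]] and
  # sp_ordered.values becomes [1, 2, 0] (the original position of each sorted
  # entry). invert_permutation([1, 2, 0]) == [2, 0, 1], so input entry i reads
  # its gradient from output slot inverted_permutation[i] in the gather below.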
  return (None, array_ops.gather(output_values_grad,
                                 inverted_permutation), None)


@ops.RegisterGradient("SparseAdd")
def _SparseAddGrad(op: ops.Operation, *grads):
  """The backward operator for the SparseAdd op.

  The SparseAdd op calculates A + B, where A, B, and the sum are all represented
  as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
  non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
  values of A and B.

  Args:
    op: the SparseAdd op
    *grads: the incoming gradients, one element per output of `op`

  Returns:
    Gradient for each of the 7 input tensors of SparseAdd:
      (a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
    The gradients for the indices, shapes, and the threshold are None.
  """
  val_grad = grads[1]
  a_indices = op.inputs[0]
  b_indices = op.inputs[3]
  sum_indices = op.outputs[0]
  # NOTE: we do not need to take `thresh` into account, since it simply affects
  # the non-zero elements of the sum, and we will peek into `sum_indices` in
  # the gradient op.
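  # Roughly, sparse_add_grad routes each entry of `val_grad` back to the
  # matching positions in `a_indices` and `b_indices`: a non-zero of A (or B)
  # whose index also appears in `sum_indices` receives that slot's gradient,
  # while an entry pruned from the sum by `thresh` receives a zero gradient.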
  a_val_grad, b_val_grad = gen_sparse_ops.sparse_add_grad(
      val_grad, a_indices, b_indices, sum_indices)
  a_val_grad.set_shape(op.inputs[1].get_shape())
  b_val_grad.set_shape(op.inputs[4].get_shape())
  # (a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
  return (None, a_val_grad, None, None, b_val_grad, None, None)


@ops.RegisterGradient("SparseTensorDenseAdd")
def _SparseTensorDenseAddGrad(op: ops.Operation, out_grad):
  sp_indices = op.inputs[0]
  # (sparse_indices, sparse_values, sparse_shape, dense)
  return (None, array_ops.gather_nd(out_grad, sp_indices), None, out_grad)


@ops.RegisterGradient("SparseReduceSum")
def _SparseReduceSumGrad(op: ops.Operation, out_grad):
  """Similar to the gradient for the Sum Op (i.e. tf.reduce_sum())."""
  sp_indices = op.inputs[0]
  sp_shape = op.inputs[2]
  output_shape_kept_dims = math_ops.reduced_shape(sp_shape, op.inputs[3])
  out_grad_reshaped = array_ops.reshape(out_grad, output_shape_kept_dims)
  scale = sp_shape // math_ops.cast(output_shape_kept_dims, dtypes.int64)
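  # Example (hypothetical shapes): for sp_shape == [2, 3] reduced over axis 1,
  # output_shape_kept_dims == [2, 1] and scale == [1, 3]; a non-zero at index
  # [1, 2] maps to [1, 2] // [1, 3] == [1, 0], the cell of the reshaped
  # upstream gradient shared by every entry of that row.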
  # (sparse_indices, sparse_values, sparse_shape, reduction_axes)
  return (None, array_ops.gather_nd(out_grad_reshaped,
                                    sp_indices // scale), None, None)


@ops.RegisterGradient("SparseSlice")
def _SparseSliceGrad(op: ops.Operation, *grads):
  """The backward operator for the SparseSlice op.

  This op takes in the upstream gradient w.r.t. non-empty values of
  the sliced `SparseTensor`, and outputs the gradients w.r.t.
  the non-empty values of the input `SparseTensor`.

  Args:
    op: the SparseSlice op
    *grads: the incoming gradients, one element per output of `op`

  Returns:
    Gradient for each of the 5 input tensors of SparseSlice:
      (indices, values, shape, start, size)
    The gradients for the indices, shape, start, and size are None.
  """
  backprop_val_grad = grads[1]
  input_indices = op.inputs[0]
  input_start = op.inputs[3]
  output_indices = op.outputs[0]

  val_grad = gen_sparse_ops.sparse_slice_grad(backprop_val_grad, input_indices,
                                              input_start, output_indices)
  val_grad.set_shape(op.inputs[1].get_shape())
  # (indices, values, shape, start, size)
  return (None, val_grad, None, None, None)


@ops.RegisterGradient("SparseTensorDenseMatMul")
def _SparseTensorDenseMatMulGrad(op: ops.Operation, grad):
  """Gradients for the SparseTensorDenseMatMul op.

  Args:
    op: the SparseTensorDenseMatMul op
    grad: the incoming gradient

  Returns:
    Gradient for each of the 4 input tensors:
      (sparse_indices, sparse_values, sparse_shape, dense_tensor)
    The gradients for indices and shape are None.

  Raises:
    TypeError: When the two operands don't have the same type.
  """
  a_indices, a_values, a_shape = op.inputs[:3]
  b = op.inputs[3]
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")

  a_type = a_values.dtype.base_dtype
  b_type = b.dtype.base_dtype
  if a_type != b_type:
    raise TypeError(
        f"SparseTensorDenseMatMul op received operands with different types: "
        f"`{a_type}` and `{b_type}`.")

  # gradient w.r.t. dense
  b_grad = gen_sparse_ops.sparse_tensor_dense_mat_mul(
      a_indices, a_values, a_shape, grad, adjoint_a=not adj_a)
  if adj_b:
    b_grad = array_ops.matrix_transpose(b_grad, conjugate=True)

  # gradient w.r.t. sparse values

  # TODO(zongheng): these gather calls could potentially duplicate rows/cols in
  # memory. If there is a need, we should look into implementing this more
  # intelligently to avoid duplicating data.

  # With no adjoints, a_grad is matmul(grad, adjoint(b)). Since a is sparse, we
  # just want to compute that matmul at the rows/columns of non-zero values. The
  # (r, c) value is sum(grad[r, :] * adjoint(b)[:, c]), where the latter term is
  # more conveniently written as conj(b)[c, :]. That expression is more
  # efficient to calculate as a matmul, after expanding the two terms to be 2D
  # (i.e. a row vector and a column vector).
  #
  # If adj_b then we replace conj(b) by transpose(b); if adj_a we need to
  # adjoint the result, which is equivalent to swapping r and c and taking
  # conjugates.

  # Get grad[r, :] and b[c, :] (or with r and c swapped if adj_a, or with
  # transpose(b) if adj_b), as batches of vectors (with the batch dimension
  # corresponding to the non-zero indices of a).
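  # Shape sketch for the non-adjoint case (a: [m, k] sparse with nnz non-zero
  # entries, b: [k, n], grad: [m, n]): parts_a and parts_b below are both
  # [nnz, n]; each batched matmul then yields [nnz, 1, 1], which is squeezed
  # down to the [nnz] values gradient in the return statement.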
  rows = a_indices[:, 0]
  cols = a_indices[:, 1]
  parts_a = array_ops.gather(grad, rows if not adj_a else cols)
  parts_b = array_ops.gather(
      b if not adj_b else array_ops.transpose(b), cols if not adj_a else rows)

  if not adj_a and not adj_b:
    # grad[r, :] * conj(b[c, :]) = row(grad[r, :]) @ adjoint(row(b[c, :]))
    a_values_grad = math_ops.matmul(
        array_ops.expand_dims(parts_a, -2),
        array_ops.expand_dims(parts_b, -2),
        adjoint_b=True)
  elif adj_a and not adj_b:
    # conj(grad[c, :] * conj(b[r, :])) = adjoint(col(grad[c, :])) @ col(b[r, :])
    a_values_grad = math_ops.matmul(
        array_ops.expand_dims(parts_a, -1),
        array_ops.expand_dims(parts_b, -1),
        adjoint_a=True)
  elif not adj_a and adj_b:
    # grad[r, :] * transpose(b)[c, :] =
    #   row(grad[r, :]) @ col(transpose(b)[c, :])
    a_values_grad = math_ops.matmul(
        array_ops.expand_dims(parts_a, -2), array_ops.expand_dims(parts_b, -1))
  elif adj_a and adj_b:
    # conj(grad[c, :] * transpose(b)[r, :]) =
    #   adjoint(col(grad[c, :])) @ adjoint(row(transpose(b)[r, :]))
    a_values_grad = math_ops.matmul(
        array_ops.expand_dims(parts_a, -1),
        array_ops.expand_dims(parts_b, -2),
        adjoint_a=True,
        adjoint_b=True)

  # gradients w.r.t. (a_indices, a_values, a_shape, b)
  return (None, array_ops.squeeze(a_values_grad, axis=[-2, -1]), None, b_grad)


@ops.RegisterGradient("SparseDenseCwiseAdd")
def _SparseDenseCwiseAddGrad(unused_op, unused_grad):
  raise NotImplementedError(
      "Gradient for SparseDenseCwiseAdd is not implemented.")


def _SparseDenseCwiseMulOrDivGrad(op: ops.Operation, grad, is_mul):
  """Common code for SparseDenseCwise{Mul,Div} gradients."""
  x_indices = op.inputs[0]
  x_shape = op.inputs[2]
  y = op.inputs[3]

  y_shape = math_ops.cast(array_ops.shape(y), dtypes.int64)
  num_added_dims = array_ops.expand_dims(
      array_ops.size(x_shape) - array_ops.size(y_shape), 0)
  augmented_y_shape = array_ops.concat(
      [array_ops.ones(num_added_dims, ops.dtypes.int64), y_shape], 0)

  scaling = x_shape // augmented_y_shape
  scaled_indices = x_indices // scaling
  scaled_indices = array_ops.slice(scaled_indices,
                                   array_ops.concat([[0], num_added_dims], 0),
                                   [-1, -1])
  dense_vals = array_ops.gather_nd(y, scaled_indices)

  if is_mul:
    dx = grad * dense_vals
    dy_val = grad * op.inputs[1]
  else:
    dx = grad / dense_vals
    dy_val = grad * (-op.inputs[1] / math_ops.square(dense_vals))
  # indices can repeat after scaling, so we can't use sparse_to_dense().
  dy = sparse_ops.sparse_add(
      array_ops.zeros_like(y),
      sparse_tensor.SparseTensor(scaled_indices, dy_val, y_shape))

  # (sp_indices, sp_vals, sp_shape, dense)
  return (None, dx, None, dy)


@ops.RegisterGradient("SparseDenseCwiseMul")
def _SparseDenseCwiseMulGrad(op: ops.Operation, grad):
  """Gradients for SparseDenseCwiseMul."""
  return _SparseDenseCwiseMulOrDivGrad(op, grad, True)


@ops.RegisterGradient("SparseDenseCwiseDiv")
def _SparseDenseCwiseDivGrad(op: ops.Operation, grad):
  """Gradients for SparseDenseCwiseDiv."""
  return _SparseDenseCwiseMulOrDivGrad(op, grad, False)


@ops.RegisterGradient("SparseSoftmax")
def _SparseSoftmaxGrad(op: ops.Operation, grad):
  """Gradients for SparseSoftmax.

  The calculation is the same as SoftmaxGrad:

    grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

  where we now only operate on the non-zero values present in the SparseTensors.

  Args:
    op: the SparseSoftmax op.
    grad: the upstream gradient w.r.t. the non-zero SparseSoftmax output values.

  Returns:
    Gradients w.r.t. the input (sp_indices, sp_values, sp_shape).
  """
  indices, shape = op.inputs[0], op.inputs[2]
  out_vals = op.outputs[0]
  sp_output = sparse_tensor.SparseTensor(indices, out_vals, shape)
  sp_grad = sparse_tensor.SparseTensor(indices, grad, shape)
  sp_product = sparse_tensor.SparseTensor(indices,
                                          sp_output.values * sp_grad.values,
                                          shape)

  # [..., B, 1], dense.
  sum_reduced = -sparse_ops.sparse_reduce_sum(sp_product, [-1], keepdims=True)
  # sparse [..., B, C] + dense [..., B, 1] with broadcast; outputs sparse.
  sp_sum = sparse_ops.sparse_dense_cwise_add(sp_grad, sum_reduced)
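  # sp_sum.values is grad - sum(grad * softmax) broadcast over the last
  # dimension, so the product below is exactly the docstring formula,
  # evaluated only at the non-zero entries.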
  grad_x = sp_sum.values * sp_output.values
  return [None, grad_x, None]


@ops.RegisterGradient("SparseSparseMaximum")
def _SparseSparseMaximumGrad(unused_op: ops.Operation, unused_grad):
  raise NotImplementedError(
      "Gradient for SparseSparseMaximum is not implemented."
  )


@ops.RegisterGradient("SparseSparseMinimum")
def _SparseSparseMinimumGrad(unused_op: ops.Operation, unused_grad):
  raise NotImplementedError(
      "Gradient for SparseSparseMinimum is not implemented."
  )


@ops.RegisterGradient("SparseFillEmptyRows")
def _SparseFillEmptyRowsGrad(
    op: ops.Operation,
    unused_grad_output_indices,
    output_grad_values,
    unused_grad_empty_row_indicator,
    unused_grad_reverse_index_map,
):
  """Gradients for SparseFillEmptyRows."""
  reverse_index_map = op.outputs[3]

  d_values, d_default_value = gen_sparse_ops.sparse_fill_empty_rows_grad(
      reverse_index_map=reverse_index_map, grad_values=output_grad_values
  )

  # d_indices, d_values, d_dense_shape, d_default_value.
  return [None, d_values, None, d_default_value]


@ops.RegisterGradient("SparseToDense")
def _SparseToDenseGrad(op: ops.Operation, grad):
  sparse_indices, output_shape, _, _ = op.inputs

  sparse_values_grad = array_ops.gather_nd(grad, sparse_indices)
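  # Every output cell not covered by `sparse_indices` holds `default_value`,
  # so its gradient is the total upstream gradient minus the part already
  # attributed to the sparse values (assuming `sparse_indices` has no
  # repeated entries).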
  default_value_grad = math_ops.reduce_sum(grad) - math_ops.reduce_sum(
      sparse_values_grad)
  return [
      array_ops.zeros_like(sparse_indices),
      array_ops.zeros_like(output_shape), sparse_values_grad, default_value_grad
  ]