# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in math_ops.py."""

import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops


@ops.RegisterGradient("ArgMax")
|
|
def _ArgMaxGrad(op: ops.Operation, grad):
|
|
del op, grad
|
|
return [None, None]
|
|
|
|
|
|
@ops.RegisterGradient("ArgMin")
|
|
def _ArgMinGrad(op: ops.Operation, grad):
|
|
del op, grad
|
|
return [None, None]
|
|
|
|
|
|
@ops.RegisterGradient("EuclideanNorm")
|
|
def _EuclideanNormGrad(op: ops.Operation, grad):
|
|
"""Gradient for EuclideanNorm."""
|
|
|
|
output = op.outputs[0]
|
|
|
|
if not op.get_attr("keep_dims"):
|
|
output_shape_kept_dims = math_ops.reduced_shape(
|
|
array_ops.shape(op.inputs[0]), op.inputs[1])
|
|
output = array_ops.reshape(output, output_shape_kept_dims)
|
|
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
|
|
|
return math_ops.truediv(op.inputs[0], output / grad), None
|
|
|
|
|
|
def SmartBroadcastGradientArgs(x, y, grad=None):
  """Version of `BroadcastGradientArgs` optimized for partially-known shapes.

  Args:
    x: The first argument of a broadcasting binary op.
    y: The second argument of a broadcasting binary op.
    grad: Deprecated.

  Returns:
    A pair of triples, one per argument, with
    * Shape of the argument (tensor);
    * Reduction axes for the argument (list or tensor);
    * Boolean indicating whether the reduction must be applied.
  """
  del grad
  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)

  if (not context.executing_eagerly() and
      isinstance(x, tensor.Tensor) and
      isinstance(y, tensor.Tensor)):
    x_axes, y_axes = _InferGradientReductionAxes(x.shape, y.shape)
  else:
    x_axes, y_axes = None, None

  if x_axes is None or y_axes is None:
    # NOTE: In graph mode, this is never exercised for statically known shapes.
    x_axes, y_axes = gen_array_ops.broadcast_gradient_args(x_shape, y_shape)
    x_must_reduce = True
    y_must_reduce = True
  else:
    x_must_reduce = x_axes or x.shape.rank < y.shape.rank
    y_must_reduce = y_axes or y.shape.rank < x.shape.rank

  return (x_shape, x_axes, x_must_reduce), (y_shape, y_axes, y_must_reduce)


def _InferGradientReductionAxes(x_shape, y_shape):
  """Infers the sets of axes that might have been broadcasted."""
  x_rank = x_shape.rank
  y_rank = y_shape.rank
  if x_rank is None or y_rank is None:
    return None, None

  # Convert shapes for V1 compatibility, can be omitted in V2.
  x_shape = x_shape.as_list()
  y_shape = y_shape.as_list()

  b_rank = max(x_rank, y_rank)
  x_axes = []
  y_axes = []
  for axis in range(b_rank):
    x_dim = 1 if axis < b_rank - x_rank else x_shape[axis - (b_rank - x_rank)]
    y_dim = 1 if axis < b_rank - y_rank else y_shape[axis - (b_rank - y_rank)]
    if x_dim == 1 and y_dim != 1:
      # It's safe to assume that x_dim was broadcasted.
      x_axes.append(axis)
    elif y_dim == 1 and x_dim != 1:
      # It's safe to assume that y_dim was broadcasted.
      y_axes.append(axis)
    elif x_dim is None or y_dim is None:
      # Broadcasting decision is dynamic (data-dependent).
      return None, None

  return x_axes, y_axes


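# --- Illustrative sketch (not part of the original module) ------------------
# A minimal example of how `_InferGradientReductionAxes` resolves the axes a
# broadcasting binary op's gradient must be summed over when both shapes are
# statically known. The concrete shapes below are assumptions chosen purely
# for demonstration; the helper is defined but never called at import time.
def _ExampleInferGradientReductionAxes():
  from tensorflow.python.framework import tensor_shape  # example-only import
  x_shape = tensor_shape.TensorShape([5, 1, 3])
  y_shape = tensor_shape.TensorShape([4, 3])
  # Broadcasting [5, 1, 3] against [4, 3] stretches x along axis 1 and y along
  # axis 0, so the gradients must be reduced over exactly those axes:
  #   x_axes == [1] and y_axes == [0]
  return _InferGradientReductionAxes(x_shape, y_shape)

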
def _ReduceGradientArg(grad, shape_axes_must_reduce):
  """Reduces gradients of one of the arguments of a broadcasting binary op."""
  shape, axes, must_reduce = shape_axes_must_reduce
  if grad is not None and must_reduce:
    # Applying keepdims=True in presence of unknown axes opens up some
    # opportunities for optimizations. For example, _SumGrad below won't have to
    # emit extra ops to recover reduced indices for broadcasting.
    grad = math_ops.reduce_sum(grad, axes, keepdims=True)
    grad = array_ops.reshape(grad, shape)
  return grad


def _ReduceGradientArgs(x, y, gx, gy):
  """Reduces gradients of both arguments of a broadcasting binary op."""
  if gx is not None or gy is not None:
    bx, by = SmartBroadcastGradientArgs(x, y)
    gx = _ReduceGradientArg(gx, bx)
    gy = _ReduceGradientArg(gy, by)
  return gx, gy


_EMPTY_TUPLE = ()


def _IsScalar(x):
  return x._shape_tuple() is _EMPTY_TUPLE  # pylint: disable=protected-access


def _SafeShapeDiv(x, y):
  """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
  return x // math_ops.maximum(y, 1)


@ops.RegisterGradient("Sum")
|
|
def _SumGrad(op: ops.Operation, grad):
|
|
"""Gradient for Sum."""
|
|
# Fast path for when reducing to a scalar and ndims is known: adds only
|
|
# Reshape and Tile ops (and possibly a Shape).
|
|
input_0_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access
|
|
if input_0_shape is not None:
|
|
axes = tensor_util.constant_value(op.inputs[1])
|
|
if axes is not None:
|
|
rank = len(input_0_shape)
|
|
if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
|
|
if context.executing_eagerly():
|
|
ctx = context.context()
|
|
new_shape = ctx.ones_rank_cache().get(rank)
|
|
if new_shape is None:
|
|
new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
|
|
ctx.ones_rank_cache().put(rank, new_shape)
|
|
else:
|
|
new_shape = [1] * rank
|
|
grad = array_ops.reshape(grad, new_shape)
|
|
# If shape is not fully defined (but rank is), we use Shape.
|
|
if None not in input_0_shape:
|
|
input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32)
|
|
else:
|
|
input_shape = array_ops.shape(op.inputs[0])
|
|
return [array_ops.tile(grad, input_shape), None]
|
|
elif None not in input_0_shape and not context.executing_eagerly():
|
|
# The shape and reduction indices are statically known, so we use a
|
|
# graph-level cache to avoid recomputing `reduced_shape()` for each
|
|
# invocation.
|
|
graph = ops.get_default_graph()
|
|
|
|
# Canonicalize `axes` to be a tuple of indices. The incoming
|
|
# value may be a scalar or a vector, and may include negative indices.
|
|
axes = tuple(axes.reshape(-1))
|
|
|
|
try:
|
|
output_shape_kept_dims, tile_scaling = graph._reduced_shape_cache[ # pylint: disable=protected-access
|
|
(input_0_shape, axes)]
|
|
except KeyError:
|
|
|
|
# Compute and cache `output_shape_kept_dims` and `tile_scaling`.
|
|
def EvaluateAsTuple(t):
|
|
if tensor_util.is_tf_type(t):
|
|
value = tensor_util.try_evaluate_constant(t)
|
|
assert value is not None
|
|
else:
|
|
value = t
|
|
return tuple(value)
|
|
|
|
output_shape_kept_dims = EvaluateAsTuple(
|
|
math_ops.reduced_shape(input_0_shape, axes))
|
|
tile_scaling = EvaluateAsTuple(
|
|
_SafeShapeDiv(input_0_shape, output_shape_kept_dims))
|
|
graph._reduced_shape_cache[(input_0_shape, axes)] = ( # pylint:disable=protected-access
|
|
output_shape_kept_dims, tile_scaling)
|
|
|
|
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
|
return [array_ops.tile(grad, tile_scaling), None]
|
|
|
|
input_shape = array_ops.shape(op.inputs[0])
|
|
|
|
if not op.get_attr("keep_dims"):
|
|
with ops.colocate_with(input_shape):
|
|
# TODO(apassos) remove this once device placement for eager ops makes
|
|
# more sense.
|
|
output_shape_kept_dims = math_ops.reduced_shape(input_shape,
|
|
op.inputs[1])
|
|
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
|
return [array_ops.broadcast_to(grad, input_shape), None]
|
|
|
|
|
|
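# --- Illustrative sketch (not part of the original module) ------------------
# A minimal NumPy analogue of what _SumGrad computes: the upstream gradient of
# a reduce_sum is reshaped to the kept-dims shape and broadcast back over the
# reduced axes, since every input element contributes exactly once to the sum.
# The shapes and values are assumptions chosen for demonstration; the helper
# is defined but never called at import time.
def _ExampleSumGradBroadcast():
  upstream = np.array([10., 20.])     # grad of reduce_sum(x, axis=1), x: (2, 3)
  kept_dims = upstream.reshape(2, 1)  # analogue of output_shape_kept_dims
  # Each x[i, j] receives upstream[i]:
  return np.broadcast_to(kept_dims, (2, 3))

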
def _MinOrMaxGrad(op: ops.Operation, grad):
  """Gradient for Min or Max. Amazingly it's precisely the same code."""
  input_shape = array_ops.shape(op.inputs[0])
  y = op.outputs[0]
  if not op.get_attr("keep_dims"):
    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
    y = array_ops.reshape(y, output_shape_kept_dims)
    grad = array_ops.reshape(grad, output_shape_kept_dims)
  else:
    output_shape_kept_dims = array_ops.shape(y)

  # Compute the number of selected (maximum or minimum) elements in each
  # reduction dimension. If there are multiple minimum or maximum elements
  # then the gradient will be divided between them.
  indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
  num_selected = array_ops.reshape(
      math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims)

  return [math_ops.divide(indicators, num_selected) * grad, None]


@ops.RegisterGradient("Max")
def _MaxGrad(op: ops.Operation, grad):
  """Gradient for Max."""
  return _MinOrMaxGrad(op, grad)


@ops.RegisterGradient("Min")
def _MinGrad(op: ops.Operation, grad):
  return _MinOrMaxGrad(op, grad)


@ops.RegisterGradient("Mean")
def _MeanGrad(op: ops.Operation, grad):
  """Gradient for Mean."""
  sum_grad = _SumGrad(op, grad)[0]
  input_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  output_shape = op.outputs[0]._shape_tuple()  # pylint: disable=protected-access
  if (input_shape is not None and output_shape is not None and
      None not in input_shape and None not in output_shape):
    input_size = np.prod(input_shape)
    output_size = np.prod(output_shape)
    factor = input_size // max(output_size, 1)
    factor = constant_op.constant(factor, dtype=sum_grad.dtype)
  else:
    input_shape = array_ops.shape(op.inputs[0])
    input_rank = array_ops.size(input_shape)
    axes = (op.inputs[1] + input_rank) % input_rank
    factor = math_ops.reduce_prod(array_ops.gather(input_shape, axes))
  return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None


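# --- Illustrative sketch (not part of the original module) ------------------
# A NumPy analogue of _MeanGrad's fast path: when input and output shapes are
# static, the mean gradient is the sum gradient divided by the number of
# elements that were averaged together. The shapes and values below are
# assumptions for demonstration; the helper is never called at import time.
def _ExampleMeanGradFactor():
  input_shape, output_shape = (2, 3), (2,)   # reduce_mean over axis 1
  factor = np.prod(input_shape) // max(np.prod(output_shape), 1)  # == 3
  upstream = np.array([9., 6.])              # grad w.r.t. the two means
  sum_grad = np.broadcast_to(upstream.reshape(2, 1), input_shape)
  return sum_grad / factor                   # each input element gets grad / 3

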
@ops.RegisterGradient("Prod")
|
|
def _ProdGrad(op: ops.Operation, grad):
|
|
"""Gradient for Prod."""
|
|
# The gradient can be expressed by dividing the product by each entry of the
|
|
# input tensor, but this approach can't deal with zeros in the input.
|
|
# Here, we avoid this problem by composing the output as a product of two
|
|
# cumprod operations.
|
|
|
|
input_shape = array_ops.shape(op.inputs[0])
|
|
# Reshape reduction indices for the case where the parameter is a scalar
|
|
reduction_indices = array_ops.reshape(op.inputs[1], [-1])
|
|
|
|
# Expand grad to full input shape
|
|
if not op.get_attr("keep_dims"):
|
|
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
|
|
grad = array_ops.reshape(grad, output_shape_kept_dims)
|
|
|
|
grad = array_ops.broadcast_to(grad, input_shape)
|
|
|
|
# Pack all reduced dimensions into a single one, so we can perform the
|
|
# cumprod ops. If the reduction dims list is empty, it defaults to float32,
|
|
# so we need to cast here. We put all the shape-related ops on CPU to avoid
|
|
# copying back and forth, and since listdiff is CPU only.
|
|
with ops.device("/cpu:0"):
|
|
rank = array_ops.rank(op.inputs[0])
|
|
reduction_indices = (reduction_indices + rank) % rank
|
|
reduced = math_ops.cast(reduction_indices, dtypes.int32)
|
|
idx = math_ops.range(0, rank)
|
|
other, _ = gen_array_ops.list_diff(idx, reduced, dtypes.int32)
|
|
perm = array_ops.concat([reduced, other], 0)
|
|
reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
|
|
other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
|
|
permuted = array_ops.transpose(op.inputs[0], perm)
|
|
permuted_shape = array_ops.shape(permuted)
|
|
reshaped = array_ops.reshape(permuted, (reduced_num, other_num))
|
|
|
|
# Calculate product, leaving out the current entry
|
|
left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
|
|
right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
|
|
# For complex inputs, the gradient is in the conjugate direction.
|
|
y = array_ops.reshape(
|
|
math_ops.conj(left) * math_ops.conj(right), permuted_shape)
|
|
|
|
# Invert the transpose and reshape operations.
|
|
# Make sure to set the statically known shape information through a reshape.
|
|
out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
|
|
return array_ops.reshape(out, input_shape), None
|
|
|
|
|
|
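# --- Illustrative sketch (not part of the original module) ------------------
# A NumPy analogue of the two-cumprod trick used by _ProdGrad: the partial
# derivative of prod(x) w.r.t. x[i] is the product of every entry except x[i],
# which an exclusive forward cumprod times an exclusive reverse cumprod yields
# without ever dividing by x (so zeros in the input are handled correctly).
# The values are assumptions for demonstration; never called at import time.
def _ExampleProdGradCumprod():
  x = np.array([2., 0., 3., 4.])
  left = np.cumprod(np.concatenate([[1.], x[:-1]]))           # exclusive cumprod
  right = np.cumprod(np.concatenate([[1.], x[:0:-1]]))[::-1]  # exclusive, reversed
  return left * right   # -> [0., 24., 0., 0.] == d prod(x) / d x

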
@ops.RegisterGradient("SegmentSum")
|
|
def _SegmentSumGrad(op: ops.Operation, grad):
|
|
"""Gradient for SegmentSum."""
|
|
return array_ops.gather(grad, op.inputs[1]), None
|
|
|
|
|
|
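# --- Illustrative sketch (not part of the original module) ------------------
# A NumPy analogue of _SegmentSumGrad: each data row contributes with weight 1
# to exactly one segment, so the gradient w.r.t. the data is the upstream
# gradient gathered by segment id. Values are assumptions for demonstration;
# the helper is never called at import time.
def _ExampleSegmentSumGradGather():
  segment_ids = np.array([0, 0, 1])  # rows 0 and 1 feed segment 0, row 2 feeds segment 1
  upstream = np.array([10., 20.])    # one gradient value per segment
  return upstream[segment_ids]       # -> [10., 10., 20.]

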
@ops.RegisterGradient("SegmentMean")
|
|
def _SegmentMeanGrad(op: ops.Operation, grad):
|
|
"""Gradient for SegmentMean."""
|
|
input_rank = array_ops.rank(op.inputs[0])
|
|
ones_shape = array_ops.concat([
|
|
array_ops.shape(op.inputs[1]),
|
|
array_ops.ones(
|
|
array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32)
|
|
], 0)
|
|
ones = array_ops.ones(ones_shape, dtype=grad.dtype)
|
|
scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1]))
|
|
return array_ops.gather(scaled_grad, op.inputs[1]), None
|
|
|
|
|
|
def _SparseSegmentReduceGradV2(op, grad, norm=None):
|
|
"""Sparse gradient for SparseSegment(Sum|Mean|SqrtN)[WithNumSegments]."""
|
|
assert norm is None or norm == "mean" or norm == "sqrtn"
|
|
data = op.inputs[0]
|
|
indices = op.inputs[1]
|
|
segment_ids = op.inputs[2]
|
|
data_shape = array_ops.shape(op.inputs[0])
|
|
dense_output_dim0 = data_shape[0]
|
|
grad_fn = (
|
|
math_ops.sparse_segment_mean_grad_v2
|
|
if norm == "mean"
|
|
else math_ops.sparse_segment_sqrt_n_grad_v2
|
|
if norm == "sqrtn"
|
|
else math_ops.sparse_segment_sum_grad_v2
|
|
)
|
|
grad_values, sorted_unique_indices = grad_fn(
|
|
grad, indices, segment_ids, dense_output_dim0
|
|
)
|
|
return indexed_slices_lib.IndexedSlices(
|
|
grad_values, sorted_unique_indices, data_shape
|
|
)
|
|
|
|
|
|
def _GetOpAttrOrNone(op, name):
|
|
"""Returns the value of the attr of `op` with the given `name`, or None if no
|
|
|
|
such attr exists.
|
|
"""
|
|
try:
|
|
return op.get_attr(name)
|
|
except ValueError:
|
|
return None
|
|
|
|
|
|
@ops.RegisterGradient("SparseSegmentSum")
|
|
def _SparseSegmentSumGrad(op: ops.Operation, grad):
|
|
"""Gradient for SparseSegmentSum."""
|
|
if _GetOpAttrOrNone(op, "sparse_gradient"):
|
|
return _SparseSegmentReduceGradV2(op, grad), None, None
|
|
dim0 = array_ops.shape(op.inputs[0])[0]
|
|
if compat.forward_compatible(2021, 6, 10):
|
|
return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
|
|
dim0), None, None)
|
|
else:
|
|
return (math_ops.unsorted_segment_sum(
|
|
array_ops.gather(grad, op.inputs[2]), op.inputs[1], dim0), None, None)
|
|
|
|
|
|
@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
|
|
def _SparseSegmentSumWithNumSegmentsGrad(op: ops.Operation, grad):
|
|
"""Gradient for SparseSegmentSumWithNumSegments."""
|
|
if _GetOpAttrOrNone(op, "sparse_gradient"):
|
|
return _SparseSegmentReduceGradV2(op, grad), None, None, None
|
|
dim0 = array_ops.shape(op.inputs[0])[0]
|
|
if compat.forward_compatible(2021, 6, 10):
|
|
return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
|
|
dim0), None, None, None)
|
|
else:
|
|
return (math_ops.unsorted_segment_sum(
|
|
array_ops.gather(grad, op.inputs[2]), op.inputs[1],
|
|
dim0), None, None, None)
|
|
|
|
|
|
@ops.RegisterGradient("SparseSegmentMean")
|
|
def _SparseSegmentMeanGrad(op: ops.Operation, grad):
|
|
"""Gradient for SparseSegmentMean."""
|
|
if _GetOpAttrOrNone(op, "sparse_gradient"):
|
|
return _SparseSegmentReduceGradV2(op, grad, "mean"), None, None
|
|
dim0 = array_ops.shape(op.inputs[0])[0]
|
|
return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
|
|
dim0), None, None)
|
|
|
|
|
|
@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
|
|
def _SparseSegmentMeanWithNumSegmentsGrad(op: ops.Operation, grad):
|
|
"""Gradient for SparseSegmentMeanWithNumSegments."""
|
|
if _GetOpAttrOrNone(op, "sparse_gradient"):
|
|
return _SparseSegmentReduceGradV2(op, grad, "mean"), None, None, None
|
|
dim0 = array_ops.shape(op.inputs[0])[0]
|
|
return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
|
|
dim0), None, None, None)
|
|
|
|
|
|
@ops.RegisterGradient("SparseSegmentSqrtN")
|
|
def _SparseSegmentSqrtNGrad(op: ops.Operation, grad):
|
|
"""Gradient for SparseSegmentSqrtN."""
|
|
if _GetOpAttrOrNone(op, "sparse_gradient"):
|
|
return _SparseSegmentReduceGradV2(op, grad, "sqrtn"), None, None
|
|
dim0 = array_ops.shape(op.inputs[0])[0]
|
|
return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
|
|
dim0), None, None)
|
|
|
|
|
|
@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
|
|
def _SparseSegmentSqrtNWithNumSegmentsGrad(op: ops.Operation, grad):
|
|
"""Gradient for SparseSegmentSqrtNWithNumSegments."""
|
|
if _GetOpAttrOrNone(op, "sparse_gradient"):
|
|
return _SparseSegmentReduceGradV2(op, grad, "sqrtn"), None, None, None
|
|
dim0 = array_ops.shape(op.inputs[0])[0]
|
|
return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
|
|
dim0), None, None, None)
|
|
|
|
|
|
def _SegmentMinOrMaxGrad(op: ops.Operation, grad):
  """Gradient for SegmentMin and SegmentMax."""
  zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype)
  # Get the number of selected (minimum or maximum) elements in each segment.
  gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
  is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
  num_selected = math_ops.segment_sum(
      math_ops.cast(is_selected, grad.dtype), op.inputs[1])
  # Compute the gradient for each segment. The gradient for the ith segment is
  # divided evenly among the selected elements in that segment.
  weighted_grads = math_ops.divide(grad, num_selected)
  gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])
  return array_ops.where_v2(is_selected, gathered_grads, zeros), None


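# --- Illustrative sketch (not part of the original module) ------------------
# A NumPy analogue of how _SegmentMinOrMaxGrad splits the upstream gradient
# evenly among tied extrema: only the selected (min/max) entries of a segment
# receive gradient, each getting an equal share. Values are assumptions for
# demonstration; the helper is never called at import time.
def _ExampleSegmentMaxGradTies():
  data = np.array([3., 3., 1.])
  segment_ids = np.array([0, 0, 0])
  segment_max = np.array([3.])
  upstream = np.array([6.])
  is_selected = data == segment_max[segment_ids]   # [True, True, False]
  num_selected = np.array([is_selected.sum()])     # two tied maxima
  shared = (upstream / num_selected)[segment_ids]  # 3. per selected entry
  return np.where(is_selected, shared, 0.)         # -> [3., 3., 0.]

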
@ops.RegisterGradient("SegmentMin")
|
|
def _SegmentMinGrad(op: ops.Operation, grad):
|
|
"""Gradient for SegmentMin."""
|
|
return _SegmentMinOrMaxGrad(op, grad)
|
|
|
|
|
|
@ops.RegisterGradient("SegmentMax")
|
|
def _SegmentMaxGrad(op: ops.Operation, grad):
|
|
"""Gradient for SegmentMax."""
|
|
return _SegmentMinOrMaxGrad(op, grad)
|
|
|
|
|
|
# pylint: disable=g-doc-args
@ops.RegisterGradient("SegmentProd")
def _SegmentProdGrad(op: ops.Operation, grad):
  """Gradient for SegmentProd.

  The gradient can be expressed for each segment by dividing the segment's
  product by each element of the segment input tensor, but this approach can't
  deal with zeros in the input.
  Unlike reduce_prod we can't use cumsum here as individual segments may have
  a different number of elements. Therefore we consider three cases:
  1) A segment input contains no zeros and we can safely divide by the input
     tensor.
  2) A segment contains exactly one zero. Then the gradient of each input of
     the segment is zero except for the 0-input, there the gradient is
     the product of the remaining segment entries.
  3) A segment contains at least two zeros. The gradient is zero for all
     segment inputs.
  """
  data = op.inputs[0]
  segment_ids = op.inputs[1]
  is_zero = math_ops.equal(data, 0)
  num_zeros = gen_math_ops.segment_sum(
      math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids)
  # handle case 3 and set the gradient to 0 for segments with more than one
  # 0 as input
  grad = array_ops.where_v2(
      math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
  # replace all zeros with ones and compute the segment_prod
  non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(data), data)
  non_zero_prod = gen_math_ops.segment_prod(non_zero_data, segment_ids)
  gathered_prod = array_ops.gather(op.outputs[0], segment_ids)
  gathered_non_zero_prod = array_ops.gather(non_zero_prod, segment_ids)
  prod_divided_by_el = gathered_prod / non_zero_data
  # Now fetch the individual results for segments containing 0 and those that
  # don't.
  partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
                                          prod_divided_by_el)
  gathered_grad = array_ops.gather(grad, segment_ids)
  return gathered_grad * partial_derivative, None


def _GatherDropNegatives(params,
                         ids,
                         zero_clipped_indices=None,
                         is_positive=None):
  """Helper function for unsorted segment ops.

  Gathers params for positive segment ids and gathers 0 for inputs with
  negative segment id. Also returns the clipped indices and a boolean mask
  with the same shape as ids where a positive id is masked as true. With
  this, the latter two can be passed as arguments to this function to
  reuse them.
  """
  if zero_clipped_indices is None:
    zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids))
  gathered = array_ops.gather(params, zero_clipped_indices)
  if is_positive is None:
    is_positive = math_ops.greater_equal(ids, 0)
    # tf.where(condition, x, y) requires condition to have the same shape as x
    # and y.
    is_positive_shape = array_ops.shape(is_positive)
    broadcastable_shape = array_ops.concat(
        [
            is_positive_shape,
            array_ops.ones(
                [array_ops.rank(gathered) - array_ops.rank(is_positive)],
                dtype=is_positive_shape.dtype,
            ),
        ],
        axis=0,
    )
    is_positive = array_ops.reshape(is_positive, broadcastable_shape)
    is_positive = is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool)
  # replace gathered params of negative indices with 0
  zero_slice = array_ops.zeros_like(gathered)
  return (
      array_ops.where_v2(is_positive, gathered, zero_slice),
      zero_clipped_indices,
      is_positive,
  )


def _UnsortedSegmentMinOrMaxGrad(op: ops.Operation, grad):
  """Gradient for UnsortedSegmentMin and UnsortedSegmentMax."""
  # Get the number of selected (minimum or maximum) elements in each segment.
  gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(
      op.outputs[0], op.inputs[1]
  )
  is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
  is_selected = math_ops.logical_and(is_selected, is_positive)
  num_selected = math_ops.unsorted_segment_sum(
      math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]
  )
  # Compute the gradient for each segment. The gradient for the ith segment is
  # divided evenly among the selected elements in that segment.
  weighted_grads = math_ops.divide(grad, num_selected)
  gathered_grads, _, _ = _GatherDropNegatives(
      weighted_grads, None, zero_clipped_indices, is_positive
  )
  zeros = array_ops.zeros_like(gathered_grads)
  return array_ops.where_v2(is_selected, gathered_grads, zeros), None, None


@ops.RegisterGradient("UnsortedSegmentSum")
|
|
def _UnsortedSegmentSumGrad(op: ops.Operation, grad):
|
|
"""Gradient for UnsortedSegmentSum."""
|
|
return _GatherDropNegatives(grad, op.inputs[1])[0], None, None
|
|
|
|
|
|
@ops.RegisterGradient("UnsortedSegmentMax")
|
|
def _UnsortedSegmentMaxGrad(op: ops.Operation, grad):
|
|
""" Gradient for UnsortedSegmentMax. """
|
|
return _UnsortedSegmentMinOrMaxGrad(op, grad)
|
|
|
|
|
|
@ops.RegisterGradient("UnsortedSegmentMin")
|
|
def _UnsortedSegmentMinGrad(op: ops.Operation, grad):
|
|
""" Gradient for UnsortedSegmentMin. """
|
|
return _UnsortedSegmentMinOrMaxGrad(op, grad)
|
|
|
|
|
|
@ops.RegisterGradient("UnsortedSegmentProd")
|
|
def _UnsortedSegmentProdGrad(op: ops.Operation, grad):
|
|
""" Gradient for UnsortedSegmentProd.
|
|
|
|
The gradient can be expressed for each segment by dividing the segment's
|
|
product by each element of the segment input tensor, but this approach can't
|
|
deal with zeros in the input.
|
|
Unlike reduce_prod we can't use cumsum here as individual segments may have
|
|
a different number of elements. Therefore we consider three cases:
|
|
1) A segment input contains no zeros and we can safely divide by the input
|
|
tensor.
|
|
2) A segment contains exactly one zero. Then the gradient of each input of
|
|
the segment is zero except for the 0-input, there the gradient is
|
|
the product of the remaining segment entries.
|
|
3) A segment contains at least two zeros. The gradient is zero for all
|
|
segment inputs.
|
|
"""
|
|
# Note that unsorted_segment_sum will filter out the negative indices,
|
|
# so we don't need to do a logical_and with is_positive here
|
|
is_zero = math_ops.equal(op.inputs[0], 0)
|
|
num_zeros = gen_math_ops.unsorted_segment_sum(
|
|
math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2])
|
|
# handle case 3 and set the gradient to 0 for segments with more than one
|
|
# 0 as input
|
|
grad = array_ops.where_v2(
|
|
math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
|
|
# replace all zeros with ones and compute the unsorted_segment_prod
|
|
non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(op.inputs[0]),
|
|
op.inputs[0])
|
|
non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data,
|
|
op.inputs[1], op.inputs[2])
|
|
# clip the indices for gather to be positive
|
|
zero_clipped_indices = math_ops.maximum(op.inputs[1],
|
|
array_ops.zeros_like(op.inputs[1]))
|
|
gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
|
|
gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices)
|
|
prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf.
|
|
# Now fetch the individual results for segments containing 0 and those that
|
|
# don't. is_zero will also fetch results for entries with negative index
|
|
# but the following gather_drop_negatives sets the corresponding entry in
|
|
# grad to 0 for these
|
|
partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
|
|
prod_divided_by_el)
|
|
gathered_grad = _GatherDropNegatives(grad, op.inputs[1],
|
|
zero_clipped_indices)[0]
|
|
return gathered_grad * partial_derivative, None, None
|
|
|
|
|
|
@ops.RegisterGradient("Abs")
|
|
def _AbsGrad(op: ops.Operation, grad):
|
|
x = op.inputs[0]
|
|
return grad * math_ops.sign(x)
|
|
|
|
|
|
@ops.RegisterGradient("Neg")
|
|
def _NegGrad(_, grad):
|
|
"""Returns -grad."""
|
|
return -grad
|
|
|
|
|
|
@ops.RegisterGradient("Inv")
|
|
def _InvGrad(op: ops.Operation, grad):
|
|
"""Returns -grad * (1 / x^2)."""
|
|
y = op.outputs[0] # y = 1 / x
|
|
return gen_math_ops.reciprocal_grad(y, grad)
|
|
|
|
|
|
@ops.RegisterGradient("Reciprocal")
|
|
def _ReciprocalGrad(op: ops.Operation, grad):
|
|
"""Returns -grad * (1 / x^2)."""
|
|
y = op.outputs[0] # y = 1 / x
|
|
return gen_math_ops.reciprocal_grad(y, grad)
|
|
|
|
|
|
@ops.RegisterGradient("InvGrad")
|
|
def _InvGradGrad(op: ops.Operation, grad):
|
|
b = op.inputs[1]
|
|
# op.output[0]: y = -b * conj(a)^2
|
|
with ops.control_dependencies([grad]):
|
|
ca = math_ops.conj(op.inputs[0])
|
|
cg = math_ops.conj(grad)
|
|
return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad)
|
|
|
|
|
|
@ops.RegisterGradient("ReciprocalGrad")
|
|
def _ReciprocalGradGrad(op: ops.Operation, grad):
|
|
b = op.inputs[1]
|
|
# op.output[0]: y = -b * conj(a)^2
|
|
with ops.control_dependencies([grad]):
|
|
ca = math_ops.conj(op.inputs[0])
|
|
cg = math_ops.conj(grad)
|
|
return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad)
|
|
|
|
|
|
@ops.RegisterGradient("Square")
|
|
def _SquareGrad(op: ops.Operation, grad):
|
|
x = op.inputs[0]
|
|
# Added control dependencies to prevent 2*x from being computed too early.
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
y = constant_op.constant(2.0, dtype=x.dtype)
|
|
return math_ops.multiply(grad, math_ops.multiply(x, y))
|
|
|
|
|
|
@ops.RegisterGradient("Sqrt")
|
|
def _SqrtGrad(op: ops.Operation, grad):
|
|
y = op.outputs[0] # y = x^(1/2)
|
|
return gen_math_ops.sqrt_grad(y, grad)
|
|
|
|
|
|
@ops.RegisterGradient("SqrtGrad")
|
|
def _SqrtGradGrad(op: ops.Operation, grad):
|
|
a = op.inputs[0]
|
|
y = op.outputs[0] # y = 0.5 * b / conj(a)
|
|
with ops.control_dependencies([grad]):
|
|
ga = grad / a
|
|
return -math_ops.conj(ga) * y, 0.5 * ga # pylint: disable=invalid-unary-operand-type
|
|
|
|
|
|
@ops.RegisterGradient("Rsqrt")
|
|
def _RsqrtGrad(op: ops.Operation, grad):
|
|
"""Returns -0.5 * grad * conj(y)^3."""
|
|
y = op.outputs[0] # y = x^(-1/2)
|
|
return gen_math_ops.rsqrt_grad(y, grad)
|
|
|
|
|
|
@ops.RegisterGradient("RsqrtGrad")
|
|
def _RsqrtGradGrad(op: ops.Operation, grad):
|
|
"""Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3."""
|
|
a = op.inputs[0] # a = x^{-1/2}
|
|
b = op.inputs[1] # backprop gradient for a
|
|
with ops.control_dependencies([grad]):
|
|
ca = math_ops.conj(a)
|
|
cg = math_ops.conj(grad)
|
|
grad_a = -1.5 * cg * b * math_ops.square(ca)
|
|
grad_b = gen_math_ops.rsqrt_grad(ca, grad)
|
|
return grad_a, grad_b
|
|
|
|
|
|
@ops.RegisterGradient("Exp")
|
|
def _ExpGrad(op: ops.Operation, grad):
|
|
"""Returns grad * exp(x)."""
|
|
y = op.outputs[0] # y = e^x
|
|
with ops.control_dependencies([grad]):
|
|
y = math_ops.conj(y)
|
|
return grad * y
|
|
|
|
|
|
@ops.RegisterGradient("Expm1")
|
|
def _Expm1Grad(op: ops.Operation, grad):
|
|
"""Returns grad * exp(x)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
y = math_ops.exp(x)
|
|
return grad * y
|
|
|
|
|
|
@ops.RegisterGradient("Log")
|
|
def _LogGrad(op: ops.Operation, grad):
|
|
"""Returns grad * (1/x)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return grad * math_ops.reciprocal(x)
|
|
|
|
|
|
@ops.RegisterGradient("Log1p")
|
|
def _Log1pGrad(op: ops.Operation, grad):
|
|
"""Returns grad * (1/(1 + x))."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return grad * math_ops.reciprocal(1 + x)
|
|
|
|
|
|
@ops.RegisterGradient("Xlogy")
|
|
def _XLogyGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of xlogy(x, y) with respect to x and y."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
sx = array_ops.shape(x)
|
|
sy = array_ops.shape(y)
|
|
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
|
with ops.control_dependencies([grad]):
|
|
not_zero_x = math_ops.cast(
|
|
math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
|
|
partial_x = gen_math_ops.xlogy(not_zero_x, y)
|
|
partial_y = gen_math_ops.xdivy(x, y)
|
|
return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
|
|
array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
|
|
|
|
|
|
@ops.RegisterGradient("Xlog1py")
|
|
def _XLog1pyGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of xlog1py(x, y) with respect to x and y."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
sx = array_ops.shape(x)
|
|
sy = array_ops.shape(y)
|
|
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
|
with ops.control_dependencies([grad]):
|
|
not_zero_x = math_ops.cast(
|
|
math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
|
|
partial_x = gen_math_ops.xlog1py(not_zero_x, y)
|
|
partial_y = gen_math_ops.xdivy(x, y + 1.)
|
|
return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
|
|
array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
|
|
|
|
|
|
@ops.RegisterGradient("Xdivy")
|
|
def _XDivyGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of xdivy(x, y) with respect to x and y."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
sx = array_ops.shape(x)
|
|
sy = array_ops.shape(y)
|
|
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
|
|
with ops.control_dependencies([grad]):
|
|
not_zero_x = math_ops.cast(
|
|
math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
|
|
partial_x = gen_math_ops.xdivy(not_zero_x, y)
|
|
partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
|
|
return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
|
|
array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
|
|
|
|
|
|
@ops.RegisterGradient("Sinh")
|
|
def _SinhGrad(op: ops.Operation, grad):
|
|
"""Returns grad * cosh(x)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return grad * math_ops.cosh(x)
|
|
|
|
|
|
@ops.RegisterGradient("Cosh")
|
|
def _CoshGrad(op: ops.Operation, grad):
|
|
"""Returns grad * sinh(x)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return grad * math_ops.sinh(x)
|
|
|
|
|
|
@ops.RegisterGradient("Tanh")
|
|
def _TanhGrad(op: ops.Operation, grad):
|
|
"""Returns grad * (1 - tanh(x) * tanh(x))."""
|
|
y = op.outputs[0] # y = tanh(x)
|
|
with ops.control_dependencies([grad]):
|
|
y = math_ops.conj(y)
|
|
return gen_math_ops.tanh_grad(y, grad)
|
|
|
|
|
|
@ops.RegisterGradient("Asinh")
|
|
def _AsinhGrad(op: ops.Operation, grad):
|
|
"""Returns grad * 1/cosh(y)."""
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
y = math_ops.conj(y)
|
|
return grad / math_ops.cosh(y)
|
|
|
|
|
|
@ops.RegisterGradient("Acosh")
|
|
def _AcoshGrad(op: ops.Operation, grad):
|
|
"""Returns grad * 1/sinh(y)."""
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
y = math_ops.conj(y)
|
|
return grad / math_ops.sinh(y)
|
|
|
|
|
|
@ops.RegisterGradient("Atanh")
|
|
def _AtanhGrad(op: ops.Operation, grad):
|
|
"""Returns grad * 1/ (1 - x^2)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
x2 = math_ops.square(x)
|
|
one = constant_op.constant(1, dtype=grad.dtype)
|
|
inv = math_ops.reciprocal(math_ops.subtract(one, x2))
|
|
return grad * inv
|
|
|
|
|
|
@ops.RegisterGradient("TanhGrad")
|
|
def _TanhGradGrad(op: ops.Operation, grad):
|
|
with ops.control_dependencies([grad]):
|
|
a = math_ops.conj(op.inputs[0])
|
|
b = math_ops.conj(op.inputs[1])
|
|
return grad * -2.0 * b * a, gen_math_ops.tanh_grad(a, grad)
|
|
|
|
|
|
@ops.RegisterGradient("Erf")
|
|
def _ErfGrad(op: ops.Operation, grad):
|
|
"""Returns grad * 2/sqrt(pi) * exp(-x**2)."""
|
|
x = op.inputs[0]
|
|
two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))
|
|
|
|
|
|
@ops.RegisterGradient("Erfc")
|
|
def _ErfcGrad(op: ops.Operation, grad):
|
|
"""Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
|
|
x = op.inputs[0]
|
|
minus_two_over_root_pi = constant_op.constant(
|
|
-2 / np.sqrt(np.pi), dtype=grad.dtype)
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))
|
|
|
|
|
|
@ops.RegisterGradient("Erfinv")
|
|
def _ErfinvGrad(op: ops.Operation, grad):
|
|
"""Returns grad * sqrt(pi) / 2 * exp(erfinv(x)**2)."""
|
|
root_pi_over_two = constant_op.constant(np.sqrt(np.pi) / 2, dtype=grad.dtype)
|
|
with ops.control_dependencies([grad]):
|
|
return grad * root_pi_over_two * math_ops.exp(
|
|
math_ops.square(op.outputs[0]))
|
|
|
|
|
|
@ops.RegisterGradient("Ndtri")
|
|
def _NdtriGrad(op: ops.Operation, grad):
|
|
"""Returns grad * sqrt(2 * pi) * exp(ndtri(x)**2 / 2)."""
|
|
root_two_pi = constant_op.constant(np.sqrt(2 * np.pi), dtype=grad.dtype)
|
|
with ops.control_dependencies([grad]):
|
|
return grad * root_two_pi * math_ops.exp(
|
|
math_ops.square(op.outputs[0]) / 2.)
|
|
|
|
|
|
@ops.RegisterGradient("Lgamma")
|
|
def _LgammaGrad(op: ops.Operation, grad):
|
|
"""Returns grad * digamma(x)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return grad * math_ops.digamma(x)
|
|
|
|
|
|
@ops.RegisterGradient("Digamma")
|
|
def _DigammaGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of the digamma function with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
partial_x = math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("Dawsn")
|
|
def _DawsnGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of dawsn(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
return grad * (1. - 2 * x * y)
|
|
|
|
|
|
@ops.RegisterGradient("Expint")
|
|
def _ExpintGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of expint(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
return grad * math_ops.exp(x) / x
|
|
|
|
|
|
@ops.RegisterGradient("FresnelCos")
|
|
def _FresnelCosGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of fresnel_cos(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
return grad * math_ops.cos((np.pi / 2.) * math_ops.square(x))
|
|
|
|
|
|
@ops.RegisterGradient("FresnelSin")
|
|
def _FresnelSinGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of fresnel_sin(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
return grad * math_ops.sin((np.pi / 2.) * math_ops.square(x))
|
|
|
|
|
|
@ops.RegisterGradient("Spence")
|
|
def _SpenceGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of spence(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
partial_x = math_ops.log(x) / (1 - x)
|
|
partial_x = array_ops.where(
|
|
math_ops.equal(x, 1.), -array_ops.ones_like(x), partial_x) # pylint: disable=invalid-unary-operand-type
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselI0")
|
|
def _BesselI0Grad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_i0(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
partial_x = special_math_ops.bessel_i1(x)
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselI0e")
|
|
def _BesselI0eGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_i0e(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
partial_x = (special_math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselI1")
|
|
def _BesselI1Grad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_i1(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
# For x = 0, the correct gradient is 1.0.
|
|
# However, the main branch gives NaN because of the division by x, so
|
|
# we impute the gradient manually.
|
|
# An alternative solution is to express the gradient via bessel_i0 and
|
|
# bessel_i2, but the latter is not yet implemented in Eigen.
|
|
dy_dx = array_ops.where_v2(
|
|
math_ops.equal(x, 0.), math_ops.cast(1., x.dtype),
|
|
special_math_ops.bessel_i0(x) - math_ops.div(y, x))
|
|
return grad * dy_dx
|
|
|
|
|
|
@ops.RegisterGradient("BesselI1e")
|
|
def _BesselI1eGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_i1e(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
# For x = 0, the correct gradient is 0.5.
|
|
# However, the main branch gives NaN because of the division by x, so
|
|
# we impute the gradient manually.
|
|
# An alternative solution is to express the gradient via bessel_i0e and
|
|
# bessel_i2e, but the latter is not yet implemented in Eigen.
|
|
dy_dx = array_ops.where_v2(
|
|
math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
|
|
special_math_ops.bessel_i0e(x) - y *
|
|
(math_ops.sign(x) + math_ops.reciprocal(x)))
|
|
return grad * dy_dx
|
|
|
|
|
|
@ops.RegisterGradient("BesselK0")
|
|
def _BesselK0Grad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_k0(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
partial_x = -special_math_ops.bessel_k1(x)
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselK0e")
|
|
def _BesselK0eGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_k0e(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
partial_x = (y - special_math_ops.bessel_k1e(x))
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselK1")
|
|
def _BesselK1Grad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_k1(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
# At 0., this is NaN which is fine since the derivative is undefined
|
|
# at 0.
|
|
partial_x = -special_math_ops.bessel_k0(x) - math_ops.div(y, x)
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselK1e")
|
|
def _BesselK1eGrad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_k1e(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
# At 0., this is NaN which is fine since the derivative is undefined
|
|
# at 0.
|
|
partial_x = (
|
|
y * (1. - math_ops.reciprocal(x)) - special_math_ops.bessel_k0e(x))
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselJ0")
|
|
def _BesselJ0Grad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_j0(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
partial_x = -special_math_ops.bessel_j1(x)
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselJ1")
|
|
def _BesselJ1Grad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_j1(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
# For x = 0, the correct gradient is 0.5.
|
|
# However, the main branch gives NaN because of the division by x, so
|
|
# we impute the gradient manually.
|
|
# An alternative solution is to express the gradient via bessel_i0e and
|
|
# bessel_i2e, but the latter is not yet implemented in Eigen.
|
|
dy_dx = array_ops.where_v2(
|
|
math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
|
|
special_math_ops.bessel_j0(x) - math_ops.div(y, x))
|
|
return grad * dy_dx
|
|
|
|
|
|
@ops.RegisterGradient("BesselY0")
|
|
def _BesselY0Grad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_y0(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
partial_x = -special_math_ops.bessel_y1(x)
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("BesselY1")
|
|
def _BesselY1Grad(op: ops.Operation, grad):
|
|
"""Compute gradient of bessel_y1(x) with respect to its argument."""
|
|
x = op.inputs[0]
|
|
y = op.outputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
# At 0., this is NaN which is fine since the derivative is undefined
|
|
# at 0.
|
|
partial_x = special_math_ops.bessel_y0(x) - math_ops.div(y, x)
|
|
return grad * partial_x
|
|
|
|
|
|
@ops.RegisterGradient("Igamma")
|
|
def _IgammaGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of igamma(a, x) with respect to a and x."""
|
|
a = op.inputs[0]
|
|
x = op.inputs[1]
|
|
sa = array_ops.shape(a)
|
|
sx = array_ops.shape(x)
|
|
ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)
|
|
|
|
with ops.control_dependencies([grad]):
|
|
partial_a = gen_math_ops.igamma_grad_a(a, x)
|
|
# Perform operations in log space before summing, because Gamma(a)
|
|
# and Gamma'(a) can grow large.
|
|
partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
|
|
math_ops.lgamma(a))
|
|
return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
|
|
array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
|
|
|
|
|
|
@ops.RegisterGradient("Igammac")
|
|
def _IgammacGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x."""
|
|
igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad)
|
|
return (-igamma_grad_a, -igamma_grad_x)
|
|
|
|
|
|
@ops.RegisterGradient("Betainc")
|
|
def _BetaincGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of betainc(a, b, x) with respect to x."""
|
|
# TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
|
|
a, b, x = op.inputs
|
|
|
|
# two cases: x is a scalar and a/b are same-shaped tensors, or vice
|
|
# versa; so its sufficient to check against shape(a).
|
|
sa = array_ops.shape(a)
|
|
sx = array_ops.shape(x)
|
|
_, rx = gen_array_ops.broadcast_gradient_args(sa, sx)
|
|
|
|
# Perform operations in log space before summing, because terms
|
|
# can grow large.
|
|
log_beta = (
|
|
gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
|
|
gen_math_ops.lgamma(a + b))
|
|
# We use xlog1py and xlogy since the derivatives should tend to
|
|
# zero one of the tails when a is 1. or b is 1.
|
|
partial_x = math_ops.exp(math_ops.xlog1py(b - 1, -x) +
|
|
math_ops.xlogy(a - 1, x) - log_beta)
|
|
|
|
return (
|
|
None, # da
|
|
None, # db
|
|
array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
|
|
|
|
|
|
@ops.RegisterGradient("Zeta")
|
|
def _ZetaGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of zeta(x, q) with respect to x and q."""
|
|
# TODO(tillahoffmann): Add derivative with respect to x
|
|
x = op.inputs[0]
|
|
q = op.inputs[1]
|
|
# Broadcast gradients
|
|
sx = array_ops.shape(x)
|
|
sq = array_ops.shape(q)
|
|
unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq)
|
|
# Evaluate gradient
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
q = math_ops.conj(q)
|
|
partial_q = -x * math_ops.zeta(x + 1, q) # pylint: disable=invalid-unary-operand-type
|
|
return (None,
|
|
array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))
|
|
|
|
|
|
@ops.RegisterGradient("Polygamma")
|
|
def _PolygammaGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of psi(n, x) with respect to n and x."""
|
|
# TODO(tillahoffmann): Add derivative with respect to n
|
|
n = op.inputs[0]
|
|
x = op.inputs[1]
|
|
# Broadcast gradients
|
|
sn = array_ops.shape(n)
|
|
sx = array_ops.shape(x)
|
|
unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx)
|
|
# Evaluate gradient
|
|
with ops.control_dependencies([grad]):
|
|
n = math_ops.conj(n)
|
|
x = math_ops.conj(x)
|
|
partial_x = math_ops.polygamma(n + 1, x)
|
|
return (None,
|
|
array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
|
|
|
|
|
|
@ops.RegisterGradient("Sigmoid")
|
|
def _SigmoidGrad(op: ops.Operation, grad):
|
|
"""Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
|
|
y = op.outputs[0] # y = sigmoid(x)
|
|
with ops.control_dependencies([grad]):
|
|
y = math_ops.conj(y)
|
|
return gen_math_ops.sigmoid_grad(y, grad)
|
|
|
|
|
|
@ops.RegisterGradient("SigmoidGrad")
|
|
def _SigmoidGradGrad(op: ops.Operation, grad):
|
|
with ops.control_dependencies([grad]):
|
|
a = math_ops.conj(op.inputs[0])
|
|
b = math_ops.conj(op.inputs[1])
|
|
gb = grad * b
|
|
return gb - 2.0 * gb * a, gen_math_ops.sigmoid_grad(a, grad)
|
|
|
|
|
|
@ops.RegisterGradient("Sign")
|
|
def _SignGrad(op: ops.Operation, _):
|
|
"""Returns 0."""
|
|
x = op.inputs[0]
|
|
return array_ops.zeros_like(x)
|
|
|
|
|
|
@ops.RegisterGradient("Sin")
|
|
def _SinGrad(op: ops.Operation, grad):
|
|
"""Returns grad * cos(x)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return grad * math_ops.cos(x)
|
|
|
|
|
|
@ops.RegisterGradient("Cos")
|
|
def _CosGrad(op: ops.Operation, grad):
|
|
"""Returns grad * -sin(x)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
return -grad * math_ops.sin(x)
|
|
|
|
|
|
@ops.RegisterGradient("Tan")
|
|
def _TanGrad(op: ops.Operation, grad):
|
|
"""Returns grad * 1/sec^2(x)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
secx = math_ops.reciprocal(math_ops.cos(x))
|
|
secx2 = math_ops.square(secx)
|
|
return secx2 * grad
|
|
|
|
|
|
@ops.RegisterGradient("Asin")
|
|
def _AsinGrad(op: ops.Operation, grad):
|
|
"""Returns grad * 1/sqrt(1-x^2)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
x2 = math_ops.square(x)
|
|
one = constant_op.constant(1, dtype=grad.dtype)
|
|
den = math_ops.sqrt(math_ops.subtract(one, x2))
|
|
inv = math_ops.reciprocal(den)
|
|
return grad * inv
|
|
|
|
|
|
@ops.RegisterGradient("Acos")
|
|
def _AcosGrad(op: ops.Operation, grad):
|
|
"""Returns grad * -1/sqrt(1-x^2)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
x2 = math_ops.square(x)
|
|
one = constant_op.constant(1, dtype=grad.dtype)
|
|
den = math_ops.sqrt(math_ops.subtract(one, x2))
|
|
inv = math_ops.reciprocal(den)
|
|
return -grad * inv
|
|
|
|
|
|
@ops.RegisterGradient("Atan")
|
|
def _AtanGrad(op: ops.Operation, grad):
|
|
"""Returns grad * 1/ (1 + x^2)."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
x = math_ops.conj(x)
|
|
x2 = math_ops.square(x)
|
|
one = constant_op.constant(1, dtype=grad.dtype)
|
|
inv = math_ops.reciprocal(math_ops.add(one, x2))
|
|
return grad * inv
|
|
|
|
|
|
@ops.RegisterGradient("Atan2")
|
|
def _Atan2Grad(op: ops.Operation, grad):
|
|
"""Returns grad * x / (y^2 + x^2), grad * -y / (y^2 + x^2)."""
|
|
y = op.inputs[0]
|
|
x = op.inputs[1]
|
|
with ops.control_dependencies([grad]):
|
|
grad_inv = grad / (math_ops.square(y) + math_ops.square(x))
|
|
gy = x * grad_inv
|
|
gx = -y * grad_inv
|
|
# pylint: disable=arguments-out-of-order
|
|
return _ReduceGradientArgs(y, x, gy, gx)
|
|
# pylint: enable=arguments-out-of-order
|
|
|
|
|
|
@ops.RegisterGradient("AddN")
|
|
def _AddNGrad(op: ops.Operation, grad):
|
|
"""Copies the gradient to all inputs."""
|
|
# Not broadcasting.
|
|
return [grad] * len(op.inputs)
|
|
|
|
|
|
def _ShapesFullySpecifiedAndEqual(x, y, grad):
|
|
# pylint: disable=protected-access
|
|
x_shape = x._shape_tuple()
|
|
y_shape = y._shape_tuple()
|
|
grad_shape = grad._shape_tuple()
|
|
# pylint: enable=protected-access
|
|
return (x_shape == y_shape and x_shape == grad_shape and
|
|
x_shape is not None and None not in x_shape)
|
|
|
|
|
|
@ops.RegisterGradient("Add")
|
|
@ops.RegisterGradient("AddV2")
|
|
def _AddGrad(op: ops.Operation, grad):
|
|
"""Gradient for Add."""
|
|
y = op.inputs[1]
|
|
try:
|
|
skip_input_indices = op.skip_input_indices or ()
|
|
if 1 in skip_input_indices and _IsScalar(y):
|
|
return grad, None
|
|
except AttributeError:
|
|
# No gradient skipping, so do the full gradient computation
|
|
skip_input_indices = ()
|
|
|
|
x = op.inputs[0]
|
|
if isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(
|
|
x, y, grad
|
|
):
|
|
return grad, grad
|
|
|
|
gx = None if 0 in skip_input_indices else grad
|
|
gy = None if 1 in skip_input_indices else grad
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
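# --- Illustrative sketch (not part of the original module) ------------------
# A NumPy analogue of the broadcasting reduction _AddGrad delegates to
# _ReduceGradientArgs: when y was broadcast against x, its gradient is the
# upstream gradient summed over the broadcast axes. Shapes are assumptions for
# demonstration; the helper is never called at import time.
def _ExampleAddGradBroadcastReduce():
  upstream = np.ones((2, 3))   # grad of x + y with x: (2, 3), y: (3,)
  gx = upstream                # x needs no reduction
  gy = upstream.sum(axis=0)    # y was broadcast over axis 0 -> [2., 2., 2.]
  return gx, gy

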
@ops.RegisterGradient("Sub")
|
|
def _SubGrad(op: ops.Operation, grad):
|
|
"""Gradient for Sub."""
|
|
y = op.inputs[1]
|
|
try:
|
|
skip_input_indices = op.skip_input_indices or ()
|
|
if 1 in skip_input_indices and _IsScalar(y):
|
|
return grad, None
|
|
except AttributeError:
|
|
# No gradient skipping, so do the full gradient computation
|
|
skip_input_indices = ()
|
|
|
|
x = op.inputs[0]
|
|
if isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(
|
|
x, y, grad
|
|
):
|
|
return grad, -grad
|
|
|
|
gx = None if 0 in skip_input_indices else grad
|
|
gy = None if 1 in skip_input_indices else -grad
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
@ops.RegisterGradient("Mul")
|
|
def _MulGrad(op: ops.Operation, grad):
|
|
"""The gradient of scalar multiplication."""
|
|
y = op.inputs[1]
|
|
try:
|
|
skip_input_indices = op.skip_input_indices or ()
|
|
if 1 in skip_input_indices and _IsScalar(y):
|
|
return gen_math_ops.mul(grad, math_ops.conj(y)), None
|
|
except AttributeError:
|
|
# No gradient skipping, so do the full gradient computation
|
|
skip_input_indices = ()
|
|
|
|
x = op.inputs[0]
|
|
if (
|
|
isinstance(grad, tensor.Tensor)
|
|
and _ShapesFullySpecifiedAndEqual(x, y, grad)
|
|
and grad.dtype in (dtypes.int32, dtypes.float32)
|
|
):
|
|
return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
|
|
assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
|
|
|
|
if 0 in skip_input_indices:
|
|
gx = None
|
|
else:
|
|
gx = gen_math_ops.mul(grad, math_ops.conj(y))
|
|
|
|
if 1 in skip_input_indices:
|
|
gy = None
|
|
else:
|
|
gy = gen_math_ops.mul(math_ops.conj(x), grad)
|
|
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
@ops.RegisterGradient("MulNoNan")
|
|
def _MulNoNanGrad(op: ops.Operation, grad):
|
|
"""The gradient of scalar multiplication with NaN-suppression."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
if isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(
|
|
x, y, grad
|
|
):
|
|
return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
|
|
|
|
assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
|
|
gx = gen_math_ops.mul_no_nan(grad, y)
|
|
gy = gen_math_ops.mul_no_nan(x, grad)
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
@ops.RegisterGradient("Div")
|
|
def _DivGrad(op: ops.Operation, grad):
|
|
"""The gradient for the Div operator."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
cx = math_ops.conj(x)
|
|
cy = math_ops.conj(y)
|
|
gx = math_ops.divide(grad, cy)
|
|
gy = grad * math_ops.divide(math_ops.divide(-cx, cy), cy)
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
@ops.RegisterGradient("FloorDiv")
|
|
def _FloorDivGrad(_, unused_grad):
|
|
"""The gradient for the FloorDiv operator."""
|
|
return None, None
|
|
|
|
|
|
@ops.RegisterGradient("FloorMod")
|
|
def _FloorModGrad(op: ops.Operation, grad):
|
|
"""Returns grad * (1, -floor(x/y))."""
|
|
x = math_ops.conj(op.inputs[0])
|
|
y = math_ops.conj(op.inputs[1])
|
|
floor_xy = math_ops.floor_div(x, y)
|
|
gx = grad
|
|
gy = grad * math_ops.negative(floor_xy)
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
@ops.RegisterGradient("TruncateDiv")
|
|
def _TruncateDivGrad(_, unused_grad):
|
|
return None, None
|
|
|
|
|
|
@ops.RegisterGradient("RealDiv")
|
|
def _RealDivGrad(op: ops.Operation, grad):
|
|
"""RealDiv op gradient."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
cx = math_ops.conj(op.inputs[0])
|
|
cy = math_ops.conj(op.inputs[1])
|
|
gx = math_ops.realdiv(grad, cy)
|
|
gy = grad * math_ops.realdiv(math_ops.realdiv(-cx, cy), cy)
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
@ops.RegisterGradient("DivNoNan")
|
|
def _DivNoNanGrad(op: ops.Operation, grad):
|
|
"""DivNoNan op gradient."""
|
|
x = math_ops.conj(op.inputs[0])
|
|
y = math_ops.conj(op.inputs[1])
|
|
gx = math_ops.div_no_nan(grad, y)
|
|
gy = grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y)
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
@ops.RegisterGradient("Pow")
|
|
def _PowGrad(op: ops.Operation, grad):
|
|
"""Returns grad * (y*x^(y-1), z*log(x))."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
cx = math_ops.conj(x)
|
|
cy = math_ops.conj(y)
|
|
try:
|
|
skip_input_indices = op.skip_input_indices or ()
|
|
if 1 in skip_input_indices and _IsScalar(y):
|
|
return grad * cy * math_ops.pow(cx, cy - 1), None
|
|
except AttributeError:
|
|
# No gradient skipping, so do the full gradient computation
|
|
skip_input_indices = ()
|
|
|
|
if 0 in skip_input_indices:
|
|
gx = None
|
|
else:
|
|
gx = grad * cy * math_ops.pow(cx, cy - 1)
|
|
|
|
if 1 in skip_input_indices:
|
|
gy = None
|
|
else:
|
|
# Avoid false singularity at x = 0
|
|
if x.dtype.is_complex:
|
|
# real(x) < 0 is fine for the complex case
|
|
mask = math_ops.not_equal(cx, 0)
|
|
else:
|
|
# There's no sensible real value to return if x < 0, so return 0
|
|
mask = cx > 0
|
|
safe_x = array_ops.where(mask, cx, array_ops.ones_like(x))
|
|
log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
|
|
gy = grad * math_ops.conj(op.outputs[0]) * log_x
|
|
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
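# Illustration only (not used by the library): for real positive x, _PowGrad
# above applies d(x**y)/dx = y * x**(y - 1) and d(x**y)/dy = x**y * log(x),
# masking log(x) to zero where it is undefined. A hand-derived sketch for
# positive float tensors follows; the helper name is ours.
def _pow_grad_reference_sketch(x, y, grad):
  """Hand-derived Pow gradients for positive float tensors (sketch only)."""
  z = math_ops.pow(x, y)
  gx = grad * y * math_ops.pow(x, y - 1)  # chain rule through the base
  gy = grad * z * math_ops.log(x)         # chain rule through the exponent
  return gx, gy

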
def _MaximumMinimumGradInputOnly(op: ops.Operation, grad, selector_op):
  x = op.inputs[0]
  y = op.inputs[1]
  zeros = array_ops.zeros_like(grad)
  xmask = selector_op(x, y)
  xgrad = array_ops.where_v2(xmask, grad, zeros)
  ygrad = None  # Return None for ygrad since the config allows that.
  return (xgrad, ygrad)


def _MaximumMinimumGrad(op: ops.Operation, grad, selector_op):
  """Factor out the code for the gradient of Maximum or Minimum."""
  y = op.inputs[1]
  try:
    skip_input_indices = op.skip_input_indices or ()
    if 1 in skip_input_indices and _IsScalar(y):
      # When we want to get gradients for the first input only, and the second
      # input tensor is a scalar, we can do a much simpler calculation
      return _MaximumMinimumGradInputOnly(op, grad, selector_op)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    skip_input_indices = ()
  x = op.inputs[0]
  zeros = array_ops.zeros_like(grad)
  xmask = selector_op(x, y)
  if 0 in skip_input_indices:
    gx = None
  else:
    gx = array_ops.where_v2(xmask, grad, zeros)
  if 1 in skip_input_indices:
    gy = None
  else:
    gy = array_ops.where_v2(xmask, zeros, grad)
  return _ReduceGradientArgs(x, y, gx, gy)


@ops.RegisterGradient("Maximum")
|
|
def _MaximumGrad(op: ops.Operation, grad):
|
|
"""Returns grad*(x >= y, x < y) with type of grad."""
|
|
return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
|
|
|
|
|
|
@ops.RegisterGradient("Minimum")
|
|
def _MinimumGrad(op: ops.Operation, grad):
|
|
"""Returns grad*(x <= y, x > y) with type of grad."""
|
|
return _MaximumMinimumGrad(op, grad, math_ops.less_equal)
|
|
|
|
|
|
@ops.RegisterGradient("SquaredDifference")
|
|
def _SquaredDifferenceGrad(op: ops.Operation, grad):
|
|
"""Returns the gradient for (x-y)^2."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
try:
|
|
skip_input_indices = op.skip_input_indices or ()
|
|
except AttributeError:
|
|
# No gradient skipping, so do the full gradient computation
|
|
skip_input_indices = ()
|
|
|
|
with ops.control_dependencies([grad]):
|
|
# The parens ensure that if grad is IndexedSlices, it'll get multiplied by
|
|
# Tensor (not a number like 2.0) which causes it to convert to Tensor.
|
|
x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
|
|
|
|
if isinstance(grad, tensor.Tensor) and _ShapesFullySpecifiedAndEqual(
|
|
x, y, grad
|
|
):
|
|
return x_grad, -x_grad
|
|
|
|
gx = None if 0 in skip_input_indices else x_grad
|
|
gy = None if 1 in skip_input_indices else -x_grad
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
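# Illustration only (not used by the library): for z = (x - y)**2 the partials
# are dz/dx = 2 * (x - y) and dz/dy = -2 * (x - y), i.e. the pair
# (x_grad, -x_grad) computed above. A hand-derived sketch follows; the helper
# name is ours.
def _squared_difference_grad_reference_sketch(x, y, grad):
  """Hand-derived SquaredDifference gradients for float tensors (sketch)."""
  two_diff = 2.0 * (x - y)
  return grad * two_diff, grad * -two_diff

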
# Logical operations have no gradients.
ops.NotDifferentiable("Less")
ops.NotDifferentiable("LessEqual")
ops.NotDifferentiable("Greater")
ops.NotDifferentiable("GreaterEqual")
ops.NotDifferentiable("Equal")
ops.NotDifferentiable("ApproximateEqual")
ops.NotDifferentiable("NotEqual")
ops.NotDifferentiable("LogicalAnd")
ops.NotDifferentiable("LogicalOr")
ops.NotDifferentiable("LogicalNot")


@ops.RegisterGradient("Select")
|
|
def _SelectGrad(op: ops.Operation, grad):
|
|
c = op.inputs[0]
|
|
x = op.inputs[1]
|
|
zeros = array_ops.zeros_like(x)
|
|
return (
|
|
None,
|
|
array_ops.where(c, grad, zeros),
|
|
array_ops.where(c, zeros, grad),
|
|
)
|
|
|
|
|
|
@ops.RegisterGradient("SelectV2")
|
|
def _SelectGradV2(op: ops.Operation, grad):
|
|
c = op.inputs[0]
|
|
x = op.inputs[1]
|
|
y = op.inputs[2]
|
|
z = op.outputs[0]
|
|
zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype)
|
|
gx = array_ops.where_v2(c, grad, zeros)
|
|
gy = array_ops.where_v2(c, zeros, grad)
|
|
gx, _ = _ReduceGradientArgs(x, z, gx, None)
|
|
gy, _ = _ReduceGradientArgs(y, z, gy, None)
|
|
return None, gx, gy
|
|
|
|
|
|
def _MatMulGradAgainstFirstOnly(op: ops.Operation, grad):
  """Gradient for MatMul, only for the first input."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  b = math_ops.conj(op.inputs[1])
  if not t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True, grad_a=True)
  elif not t_a and t_b:
    grad_a = gen_math_ops.mat_mul(grad, b, grad_a=True)
  elif t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True, grad_a=True)
  elif t_a and t_b:
    grad_a = gen_math_ops.mat_mul(
        b, grad, transpose_a=True, transpose_b=True, grad_a=True
    )
  return grad_a, None


def _MatMulGradAgainstSecondOnly(op: ops.Operation, grad):
  """Gradient for MatMul, only for the second input."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  a = math_ops.conj(op.inputs[0])
  if not t_a and not t_b:
    grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True, grad_b=True)
  elif not t_a and t_b:
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, grad_b=True)
  elif t_a and not t_b:
    grad_b = gen_math_ops.mat_mul(a, grad, grad_b=True)
  elif t_a and t_b:
    grad_b = gen_math_ops.mat_mul(
        grad, a, transpose_a=True, transpose_b=True, grad_b=True
    )
  return None, grad_b


@ops.RegisterGradient("MatMul")
|
|
def _MatMulGrad(op: ops.Operation, grad):
|
|
"""Gradient for MatMul."""
|
|
try:
|
|
skip_input_indices = op.skip_input_indices
|
|
if skip_input_indices is not None:
|
|
if 1 in skip_input_indices:
|
|
return _MatMulGradAgainstFirstOnly(op, grad)
|
|
elif 0 in skip_input_indices:
|
|
return _MatMulGradAgainstSecondOnly(op, grad)
|
|
except AttributeError:
|
|
# No gradient skipping, so do the full gradient computation
|
|
pass
|
|
|
|
t_a = op.get_attr("transpose_a")
|
|
t_b = op.get_attr("transpose_b")
|
|
a = math_ops.conj(op.inputs[0])
|
|
b = math_ops.conj(op.inputs[1])
|
|
if not t_a and not t_b:
|
|
grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True, grad_a=True)
|
|
grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True, grad_b=True)
|
|
elif not t_a and t_b:
|
|
grad_a = gen_math_ops.mat_mul(grad, b, grad_a=True)
|
|
grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, grad_b=True)
|
|
elif t_a and not t_b:
|
|
grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True, grad_a=True)
|
|
grad_b = gen_math_ops.mat_mul(a, grad, grad_b=True)
|
|
elif t_a and t_b:
|
|
grad_a = gen_math_ops.mat_mul(
|
|
b, grad, transpose_a=True, transpose_b=True, grad_a=True
|
|
)
|
|
grad_b = gen_math_ops.mat_mul(
|
|
grad, a, transpose_a=True, transpose_b=True, grad_b=True
|
|
)
|
|
return grad_a, grad_b
|
|
|
|
|
|
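# Illustration only (not used by the library): for C = A @ B with no
# transposes, dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC, which is the first
# branch of _MatMulGrad above; the other branches are the same identities with
# the transposes folded in. A hand-derived sketch; the helper name is ours.
def _matmul_grad_reference_sketch(a, b, grad):
  """Hand-derived MatMul gradients for the no-transpose case (sketch only)."""
  grad_a = math_ops.matmul(grad, b, transpose_b=True)  # dL/dC @ B^T
  grad_b = math_ops.matmul(a, grad, transpose_a=True)  # A^T @ dL/dC
  return grad_a, grad_b

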
@ops.RegisterGradient("SparseMatMul")
|
|
def _SparseMatMulGrad(op: ops.Operation, grad):
|
|
"""Gradient for SparseMatMul."""
|
|
|
|
t_a = op.get_attr("transpose_a")
|
|
t_b = op.get_attr("transpose_b")
|
|
is_sparse = {}
|
|
is_sparse[op.inputs[0].ref()] = op.get_attr("a_is_sparse")
|
|
is_sparse[op.inputs[1].ref()] = op.get_attr("b_is_sparse")
|
|
# Use heuristic to figure out if grad might be sparse
|
|
is_sparse[grad.ref()] = not context.executing_eagerly() and (
|
|
grad.op.type == "ReluGrad")
|
|
|
|
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
|
|
"""Helper function to create SparseMatMul op."""
|
|
|
|
assert t1.ref() in is_sparse and t2.ref() in is_sparse
|
|
t1_sparse = is_sparse[t1.ref()]
|
|
t2_sparse = is_sparse[t2.ref()]
|
|
if transpose_b:
|
|
t2 = array_ops.transpose(t2)
|
|
transpose_b = False
|
|
prod = math_ops.matmul(
|
|
t1,
|
|
t2,
|
|
transpose_a=transpose_a,
|
|
transpose_b=transpose_b,
|
|
a_is_sparse=t1_sparse,
|
|
b_is_sparse=t2_sparse)
|
|
if prod.dtype != out_dtype:
|
|
prod = math_ops.cast(prod, out_dtype)
|
|
return prod
|
|
|
|
dtype_a = op.inputs[0].dtype
|
|
dtype_b = op.inputs[1].dtype
|
|
if not t_a and not t_b:
|
|
return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
|
|
_SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
|
|
elif not t_a and t_b:
|
|
return (_SparseMatMul(grad, op.inputs[1], dtype_a),
|
|
_SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
|
|
elif t_a and not t_b:
|
|
return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
|
|
_SparseMatMul(op.inputs[0], grad, dtype_b))
|
|
elif t_a and t_b:
|
|
return (_SparseMatMul(
|
|
op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True),
|
|
_SparseMatMul(
|
|
grad, op.inputs[0], dtype_b, transpose_a=True,
|
|
transpose_b=True))
|
|
|
|
|
|
@ops.RegisterGradient("Floor")
|
|
def _FloorGrad(_, unused_grad):
|
|
return [None]
|
|
|
|
|
|
@ops.RegisterGradient("Ceil")
|
|
def _CeilGrad(_, unused_grad):
|
|
return [None]
|
|
|
|
|
|
@ops.RegisterGradient("Round")
|
|
def _RoundGrad(_, unused_grad):
|
|
return [None]
|
|
|
|
|
|
@ops.RegisterGradient("Rint")
|
|
def _RintGrad(_, unused_grad):
|
|
# the gradient of Rint is zero
|
|
return [None]
|
|
|
|
|
|
@ops.RegisterGradient("BatchMatMul")
|
|
def _BatchMatMul(op: ops.Operation, grad):
|
|
"""Returns the gradient of x and y given the gradient of x * y."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
adj_x = op.get_attr("adj_x")
|
|
adj_y = op.get_attr("adj_y")
|
|
|
|
if not adj_x:
|
|
if not adj_y:
|
|
grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
|
|
grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
|
|
else:
|
|
grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
|
|
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
|
|
else:
|
|
if not adj_y:
|
|
grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
|
|
grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
|
|
else:
|
|
grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
|
|
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
|
|
|
|
return grad_x, grad_y
|
|
|
|
|
|
@ops.RegisterGradient("BatchMatMulV2")
|
|
@ops.RegisterGradient("BatchMatMulV3")
|
|
def _BatchMatMulV2(op: ops.Operation, grad):
|
|
"""Returns the gradient of x and y given the gradient of x * y."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
adj_x = op.get_attr("adj_x")
|
|
adj_y = op.get_attr("adj_y")
|
|
|
|
if not adj_x:
|
|
if not adj_y:
|
|
grad_x = math_ops.matmul(
|
|
grad, y, adjoint_a=False, adjoint_b=True, grad_a=True
|
|
)
|
|
grad_y = math_ops.matmul(
|
|
x, grad, adjoint_a=True, adjoint_b=False, grad_b=True
|
|
)
|
|
else:
|
|
grad_x = math_ops.matmul(
|
|
grad, y, adjoint_a=False, adjoint_b=False, grad_a=True
|
|
)
|
|
grad_y = math_ops.matmul(
|
|
grad, x, adjoint_a=True, adjoint_b=False, grad_b=True
|
|
)
|
|
else:
|
|
if not adj_y:
|
|
grad_x = math_ops.matmul(
|
|
y, grad, adjoint_a=False, adjoint_b=True, grad_a=True
|
|
)
|
|
grad_y = math_ops.matmul(
|
|
x, grad, adjoint_a=False, adjoint_b=False, grad_b=True
|
|
)
|
|
else:
|
|
grad_x = math_ops.matmul(
|
|
y, grad, adjoint_a=True, adjoint_b=True, grad_a=True
|
|
)
|
|
grad_y = math_ops.matmul(
|
|
grad, x, adjoint_a=True, adjoint_b=True, grad_b=True
|
|
)
|
|
|
|
# Possibly reduce along the broadcasted batch dimensions, if broadcasting
|
|
# is required.
|
|
shape_x_static = x.get_shape()
|
|
shape_y_static = y.get_shape()
|
|
output_may_have_non_empty_batch_shape = (
|
|
(shape_x_static.rank is None or shape_x_static.rank > 2) or
|
|
(shape_y_static.rank is None or shape_y_static.rank > 2))
|
|
batch_shapes_match = (
|
|
shape_x_static[:-2].is_fully_defined() and
|
|
shape_y_static[:-2].is_fully_defined() and
|
|
shape_x_static[:-2] == shape_y_static[:-2])
|
|
if (not output_may_have_non_empty_batch_shape) or batch_shapes_match:
|
|
return grad_x, grad_y
|
|
|
|
sx = array_ops.shape(x)
|
|
sy = array_ops.shape(y)
|
|
rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
|
|
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
|
|
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
|
|
return grad_x, grad_y
|
|
|
|
|
|
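# Illustration only (not used by the library): when the batch dimensions of x
# were broadcast against y's, grad_x must be summed over the broadcast batch
# axes and reshaped back to x's shape, mirroring the reduction at the end of
# _BatchMatMulV2 above. A hand-derived sketch; the helper name is ours.
def _batch_broadcast_reduce_reference_sketch(x, grad_x):
  """Reduces grad_x over broadcast batch axes back to x's shape (sketch)."""
  sx = array_ops.shape(x)
  s_grad = array_ops.shape(grad_x)
  rx, _ = gen_array_ops.broadcast_gradient_args(sx[:-2], s_grad[:-2])
  return array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)

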
ops.NotDifferentiable("Range")
|
|
ops.NotDifferentiable("LinSpace")
|
|
|
|
|
|
@ops.RegisterGradient("Complex")
|
|
def _ComplexGrad(op: ops.Operation, grad):
|
|
"""Returns the real and imaginary components of 'grad', respectively."""
|
|
x = op.inputs[0]
|
|
y = op.inputs[1]
|
|
gx = math_ops.real(grad)
|
|
gy = math_ops.imag(grad)
|
|
return _ReduceGradientArgs(x, y, gx, gy)
|
|
|
|
|
|
@ops.RegisterGradient("Real")
|
|
def _RealGrad(_, grad):
|
|
"""Returns 'grad' as the real part and set the imaginary part 0."""
|
|
zero = constant_op.constant(0, dtype=grad.dtype)
|
|
return math_ops.complex(grad, zero)
|
|
|
|
|
|
@ops.RegisterGradient("Imag")
|
|
def _ImagGrad(_, grad):
|
|
"""Returns 'grad' as the imaginary part and set the real part 0."""
|
|
zero = constant_op.constant(0, dtype=grad.dtype)
|
|
return math_ops.complex(zero, grad)
|
|
|
|
|
|
@ops.RegisterGradient("Angle")
|
|
def _AngleGrad(op: ops.Operation, grad):
|
|
"""Returns `-grad / (Im(x) + i Re(x))`."""
|
|
x = op.inputs[0]
|
|
with ops.control_dependencies([grad]):
|
|
re = math_ops.real(x)
|
|
im = math_ops.imag(x)
|
|
z = math_ops.reciprocal(math_ops.complex(im, re))
|
|
zero = constant_op.constant(0, dtype=grad.dtype)
|
|
complex_grad = math_ops.complex(grad, zero)
|
|
return -complex_grad * z
|
|
|
|
|
|
@ops.RegisterGradient("Conj")
|
|
def _ConjGrad(_, grad):
|
|
"""Returns the complex conjugate of grad."""
|
|
return math_ops.conj(grad)
|
|
|
|
|
|
@ops.RegisterGradient("ComplexAbs")
|
|
def _ComplexAbsGrad(op: ops.Operation, grad):
|
|
"""Returns the gradient of ComplexAbs."""
|
|
return math_ops.div_no_nan(
|
|
math_ops.complex(
|
|
grad, array_ops.zeros_like(grad)) * op.inputs[0],
|
|
math_ops.complex(
|
|
op.outputs[0], array_ops.zeros_like(op.outputs[0])))
|
|
|
|
|
|
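# Illustration only (not used by the library): for r = |x| with complex x, the
# value propagated back to x is grad * x / |x| (a complex tensor), with 0 / 0
# mapped to 0, which is what the div_no_nan expression above computes. A
# hand-derived sketch; the helper name is ours.
def _complex_abs_grad_reference_sketch(x, abs_x, grad):
  """Hand-derived ComplexAbs gradient (complex-valued, sketch only)."""
  complex_grad = math_ops.complex(grad, array_ops.zeros_like(grad))
  complex_abs = math_ops.complex(abs_x, array_ops.zeros_like(abs_x))
  return math_ops.div_no_nan(complex_grad * x, complex_abs)

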
@ops.RegisterGradient("Cast")
|
|
def _CastGrad(op: ops.Operation, grad):
|
|
t = [
|
|
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
|
|
dtypes.complex64, dtypes.complex128
|
|
]
|
|
src_type = op.inputs[0].dtype.base_dtype
|
|
dst_type = grad.dtype.base_dtype
|
|
if src_type in t and dst_type in t:
|
|
return math_ops.cast(grad, src_type)
|
|
else:
|
|
return None
|
|
|
|
|
|
@ops.RegisterGradient("Cross")
|
|
def _CrossGrad(op: ops.Operation, grad):
|
|
u = op.inputs[0]
|
|
v = op.inputs[1]
|
|
return (math_ops.cross(v, grad), math_ops.cross(grad, u))
|
|
|
|
|
|
@ops.RegisterGradient("Cumsum")
|
|
def _CumsumGrad(op: ops.Operation, grad):
|
|
axis = op.inputs[1]
|
|
exclusive = op.get_attr("exclusive")
|
|
reverse = op.get_attr("reverse")
|
|
return [
|
|
math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse),
|
|
None
|
|
]
|
|
|
|
|
|
@ops.RegisterGradient("Cumprod")
|
|
def _CumprodGrad(op: ops.Operation, grad):
|
|
x = op.inputs[0]
|
|
axis = op.inputs[1]
|
|
exclusive = op.get_attr("exclusive")
|
|
reverse = op.get_attr("reverse")
|
|
|
|
prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
|
|
out = math_ops.cumsum(
|
|
prod * grad, axis, exclusive=exclusive, reverse=not reverse
|
|
)
|
|
return [math_ops.div_no_nan(out, x), None]
|
|
|
|
|
|
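# Illustration only (not used by the library): for y = cumprod(x) along an
# axis (non-exclusive, forward), dy_j/dx_i = prod_j / x_i for j >= i and 0
# otherwise, so the gradient of x_i is a reversed cumulative sum of
# prod * grad divided by x_i; for x = [a, b, c] the gradient wrt b is
# grad_1 * a + grad_2 * a * c. A hand-derived sketch; the helper name is ours.
def _cumprod_grad_reference_sketch(x, grad, axis):
  """Hand-derived Cumprod gradient (non-exclusive, forward; sketch only)."""
  prod = math_ops.cumprod(x, axis)
  summed = math_ops.cumsum(prod * grad, axis, reverse=True)
  return math_ops.div_no_nan(summed, x)

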
# pylint: disable=missing-function-docstring
@ops.RegisterGradient("CumulativeLogsumexp")
def _CumulativeLogsumexpGrad(op: ops.Operation, grad):
  x = op.inputs[0]
  axis = op.inputs[1]
  cumulative_logsumexp = op.outputs[0]

  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

  # Split the incoming gradient into positive and negative part
  # in order to take logs. This is required for stable results.
  log_grad_positive = array_ops.where_v2(
      math_ops.greater(grad, 0),
      math_ops.log(grad),
      grad.dtype.min)

  log_grad_negative = array_ops.where_v2(
      math_ops.less(grad, 0),
      math_ops.log(-grad),
      grad.dtype.min)

  output_pos = math_ops.exp(
      math_ops.cumulative_logsumexp(
          log_grad_positive - cumulative_logsumexp,
          axis=axis, reverse=not reverse, exclusive=exclusive) + x)

  output_neg = math_ops.exp(
      math_ops.cumulative_logsumexp(
          log_grad_negative - cumulative_logsumexp,
          axis=axis, reverse=not reverse, exclusive=exclusive) + x)

  return [output_pos - output_neg, None]


@ops.RegisterGradient("NextAfter")
|
|
def _NextAfterGrad(op: ops.Operation, grad):
|
|
"""Returns gradient of nextafter(x1, x2) with respect to x1 and x2."""
|
|
x1 = op.inputs[0]
|
|
x2 = op.inputs[1]
|
|
s_x1 = array_ops.shape(x1)
|
|
s_x2 = array_ops.shape(x2)
|
|
r_x1, r_x2 = gen_array_ops.broadcast_gradient_args(s_x1, s_x2)
|
|
with ops.control_dependencies([grad]):
|
|
partial_x1 = array_ops.ones(s_x1, dtype=x1.dtype)
|
|
partial_x2 = array_ops.zeros(s_x2, dtype=x2.dtype)
|
|
return (array_ops.reshape(
|
|
math_ops.reduce_sum(partial_x1 * grad, r_x1), s_x1),
|
|
array_ops.reshape(
|
|
math_ops.reduce_sum(partial_x2 * grad, r_x2), s_x2))
|