# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


@ops.RegisterGradient("Pack")
def _PackGrad(op: ops.Operation, grad):
  """Gradient for pack op."""
  return array_ops_stack.unstack(
      grad, num=op.get_attr("N"), axis=op.get_attr("axis"))


@ops.RegisterGradient("Unpack")
def _UnpackGrad(op: ops.Operation, *grads):
  """Gradient for unpack op."""
  return array_ops_stack.stack(grads, axis=op.get_attr("axis"))


def _ConcatGradHelper(
    op: ops.Operation, grad, start_value_index, end_value_index, dim_index
):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect
      to each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index of the last value in the op.inputs.
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    Tensors representing the partial gradients with respect to each input
    of the op.

  Raises:
    ValueError: if concat_dim/axis is not statically known.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.zeros(
            array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
        array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
    ], 0)
    begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        tensor.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation, just return grad.
  if len(op.inputs) == 2:
    return grad + [None] if end_value_index <= dim_index else [None] + grad

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, tensor.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into
        # a constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes.
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX. A speedup was seen in a large majority
      # of cases when switching implementations at N=16, but it is possible
      # that there will be a small number of performance regressions.
      if len(sizes) > 16:
        # Extract the size of each input along the concat dimension.
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops_stack.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, indexed_slices_lib.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes.
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced
      # accordingly. This is like the Tensor case, except
      # shape(grad.values)[0] is not equal to shape(sizes[i])[0], since only
      # a subset of the dim-0 values are stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices
      # gradients only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None]
          if end_value_index <= dim_index else [None] + out_grads)


@ops.RegisterGradient("Concat")
def _ConcatGrad(op: ops.Operation, grad):
  return _ConcatGradHelper(
      op,
      grad,
      start_value_index=1,
      end_value_index=len(op.inputs),
      dim_index=0)


@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op: ops.Operation, grad):
  return _ConcatGradHelper(
      op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)


ops.NotDifferentiable("ConcatOffset")
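

# Illustrative sketch for the concat gradients above (comments only, not
# executed; assumes eager execution with `tf` = tensorflow). For a dense
# upstream gradient, the concat gradient just splits `grad` back into the
# original input shapes along the concat axis:
#
#   a = tf.ones([2, 3])
#   b = tf.ones([1, 3])
#   with tf.GradientTape() as tape:
#     tape.watch([a, b])
#     y = tf.concat([a, b], axis=0)  # shape [3, 3]
#   da, db = tape.gradient(y, [a, b], output_gradients=tf.ones([3, 3]))
#   # da has shape [2, 3], db has shape [1, 3]: the upstream gradient is
#   # sliced apart, which is what _ConcatGradHelper does via split/slice.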


@ops.RegisterGradient("Slice")
def _SliceGrad(op: ops.Operation, grad):
  """Gradient for Slice op."""
  # Create an Nx2 padding where the first column represents how many
  # zeros are to be prepended for each dimension, and the second
  # column indicates how many zeros are appended.
  #
  # The number of zeros to append is the shape of the input
  # elementwise-subtracted by both the begin vector and sizes vector.
  #
  # Some more reshaping is needed to assemble this tensor with the
  # right dimensions.
  input_vec = op.inputs[0]
  begin_vec = op.inputs[1]
  input_rank = array_ops.rank(input_vec)
  index_dtype = begin_vec.dtype
  slice_size = array_ops.shape(op.outputs[0], out_type=index_dtype)
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    return gen_xla_ops.xla_dynamic_update_slice(
        array_ops.zeros_like(input_vec), grad, begin_vec), None, None

  shape = array_ops_stack.stack([input_rank, 1])
  before_pad = array_ops.reshape(begin_vec, shape)
  after_pad = array_ops.reshape(
      array_ops.shape(input_vec, out_type=index_dtype) - slice_size - begin_vec,
      shape)
  paddings = array_ops.concat([before_pad, after_pad], 1)
  return array_ops.pad(grad, paddings), None, None


@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op: ops.Operation, grad):
  """Gradient for StridedSlice op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
  # same dtype so we build a shape of the same type as other args.
  # Note that the choice of `begin` for specifying `out_type` is arbitrary.
  # We could choose any of {begin|end|strides}.dtype since they are required
  # to be the same.
  x = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  x_static = tensor_util.constant_value(x)
  x = x_static if x_static is not None else x
  begin_static = tensor_util.constant_value(begin)
  begin = begin_static if begin_static is not None else begin
  end_static = tensor_util.constant_value(end)
  end = end_static if end_static is not None else end
  strides_static = tensor_util.constant_value(strides)
  strides = strides_static if strides_static is not None else strides

  return array_ops.strided_slice_grad(
      x,
      begin,
      end,
      strides,
      grad,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None


@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op: ops.Operation, grad):
  """Gradient for StridedSliceGrad op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]

  return None, None, None, None, array_ops.strided_slice(
      grad,
      begin,
      end,
      strides,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask"))


@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op: ops.Operation, grad):  # pylint:disable=missing-function-docstring
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  begin_mask = op.get_attr("begin_mask")
  end_mask = op.get_attr("end_mask")
  ellipsis_mask = op.get_attr("ellipsis_mask")
  new_axis_mask = op.get_attr("new_axis_mask")
  shrink_axis_mask = op.get_attr("shrink_axis_mask")

  def Apply(f, *args):
    return f(*args,
             begin_mask=begin_mask,
             end_mask=end_mask,
             shrink_axis_mask=shrink_axis_mask,
             new_axis_mask=new_axis_mask,
             ellipsis_mask=ellipsis_mask)

  dy = Apply(array_ops.strided_slice, grad, begin, end, strides)
  dx = Apply(array_ops.tensor_strided_slice_update, grad, begin, end, strides,
             array_ops.zeros_like(dy))

  # The value is potentially broadcast to the shape of the strided slice, so
  # we may need to adjust dy.
  slice_shape = array_ops.shape(dy, out_type=begin.dtype)
  value_shape = array_ops.shape(op.inputs[4], out_type=slice_shape.dtype)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      slice_shape, value_shape)
  dy_reshaped = math_ops.reduce_sum(dy, axis=reduction_axes, keepdims=True)
  dy = array_ops.reshape(dy_reshaped, value_shape)

  return dx, None, None, None, dy


@ops.RegisterGradient("Split")
def _SplitGrad(op: ops.Operation, *grads):
  return None, array_ops.concat(list(grads), op.inputs[0])


@ops.RegisterGradient("SplitV")
def _SplitVGrad(op: ops.Operation, *grads):
  returnval = array_ops.concat(list(grads), op.inputs[2])
  returnval = [returnval] + [None] * (len(op.inputs) - 1)
  return returnval


ops.NotDifferentiable("Const")


@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  return array_ops.diag_part(grad)


@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  return array_ops.diag(grad)


@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  return array_ops.matrix_diag_part(grad)


@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op: ops.Operation, grad):
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1]), None, None, None, None


@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op: ops.Operation, grad):
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None


@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op: ops.Operation, grad):
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
    return array_ops.matrix_diag(grad)
  else:
    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)


@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op: ops.Operation, grad):
  """Gradient for MatrixDiagPartV2."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1]), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None


@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op: ops.Operation, grad):
  """Gradient for MatrixDiagPartV3."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  align = op.get_attr("align")
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1],
        align=align), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
        align=align), None, None


@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op: ops.Operation, grad):
  """Gradient for MatrixSetDiag."""
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)
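

# Illustrative sketch for _MatrixSetDiagGrad above (comments only, not
# executed): for z = tf.linalg.set_diag(x, d) with a 2x2 `x` and an upstream
# gradient dz of shape [2, 2],
#   dx = dz with its main diagonal zeroed   (those entries of x were
#                                            overwritten by d)
#   dd = tf.linalg.diag_part(dz)            (the diagonal of z came from d)
# which matches the (grad_input, grad_diag) pair returned above.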


@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op: ops.Operation, grad):
  """Gradient for MatrixSetDiagV2."""
  diag_shape = op.inputs[1].get_shape()
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = cond.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = cond.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = cond.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1, max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2])
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op: ops.Operation, grad):
  """Gradient for MatrixSetDiagV3."""
  diag_shape = op.inputs[1].get_shape()
  align = op.get_attr("align")
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = cond.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = cond.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = cond.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1, max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad,
      array_ops.zeros(diag_shape, dtype=grad.dtype),
      k=op.inputs[2],
      align=align)
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op: ops.Operation, grad):
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)


# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NotDifferentiable("EditDistance")


@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  return None, math_ops.reduce_sum(grad)


ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")


@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op: ops.Operation, _):
  raise LookupError(
      "Gradient explicitly disabled. Reason: %s" % op.get_attr("message"))
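

# Worked example for the `diag_shape` inference in _MatrixSetDiagGradV2/V3
# above (comments only, values chosen for illustration): for a gradient of
# shape [..., 3, 4] and a diagonal band k = (-1, 1),
#   d_lower = -1, d_upper = 1
#   y_offset = 0 (d_upper >= 0), x_offset = 0 (d_lower <= 0)
#   max_diag_len = min(3 + 0, 4 + 0) = 3
#   postfix = [d_upper - d_lower + 1, max_diag_len] = [3, 3]
# so the diagonal gradient has shape batch_shape + [3, 3], matching the shape
# produced by tf.linalg.diag_part(grad, k=(-1, 1)).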


def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings."""
  if not isinstance(indexed_slices, indexed_slices_lib.IndexedSlices):
    # If it is not IndexedSlices, it had better be a tensor.
    return indexed_slices
  if indexed_slices.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(indexed_slices))
  return math_ops.unsorted_segment_sum(indexed_slices.values,
                                       indexed_slices.indices,
                                       indexed_slices.dense_shape[0])


@ops.RegisterGradient("Gather")
def _GatherGrad(op: ops.Operation, grad):
  """Gradient for Gather op."""
  # params can be large, so colocate the shape calculation with it.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Build appropriately shaped IndexedSlices.
  indices = op.inputs[1]
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), values_shape)
  indices = array_ops.reshape(indices, size)
  return [
      indexed_slices_lib.IndexedSlices(values, indices, params_shape), None
  ]


def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results."""
  batch_indices = indices
  indices_dtype = indices.dtype.base_dtype
  casted_params_shape = math_ops.cast(params_shape, indices_dtype)
  accum_dim_value = array_ops.ones((), dtype=indices_dtype)
  for dim in range(batch_dims, 0, -1):
    dim_value = casted_params_shape[dim - 1]
    accum_dim_value *= casted_params_shape[dim]
    start = array_ops.zeros((), dtype=indices_dtype)
    step = array_ops.ones((), dtype=indices_dtype)
    dim_indices = math_ops.range(start, dim_value, step)
    dim_indices *= accum_dim_value
    dim_shape = array_ops.concat([
        array_ops.tile([1], [dim - 1]), [dim_value],
        array_ops.tile([1], [array_ops.rank(indices) - dim])
    ], axis=0)
    batch_indices += array_ops.reshape(dim_indices, dim_shape)

  return batch_indices


def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions."""

  # Axis is the first non-batch dimension.
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  if batch_dims:
    values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = values_shape[:batch_dims]
    inner_shape = values_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flat_values_shape)

  indices = array_ops.reshape(indices, indices_size)
  params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if batch_dims:
    # Put back the batch dimensions.
    params_grad = array_ops.reshape(
        params_grad, array_ops.concat([outer_shape, flat_values_shape], 0))

  return params_grad


@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op: ops.Operation, grad):
  """Gradient for GatherV2 op."""
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse models; array_ops.shape raises an
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  if batch_dims < 0:
    if indices.shape.ndims is None:
      raise ValueError(
          f"Currently, it is unsupported to take the gradient of tf.gather "
          f"when batch_dims < 0 and the rank of the indices is unknown. Please "
          f"pass a positive batch_dims or use tf.ensure_shape to update the "
          f"shape of indices when calling tf.gather. Got "
          f"batch_dims={batch_dims} and indices={indices}")
    batch_dims += indices.shape.ndims

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = indexed_slices_lib.IndexedSlices(values, indices,
                                                   params_shape)
  else:
    # Handle axis by transposing the axis dimension to be the first non-batch
    # dimension, compute the gradient and transpose the result back.
    outer_shape = params_shape[:axis]
    inner_shape = params_shape[axis:][1:]
    values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

    values_dims = array_ops.size(values_shape)
    axis_dims = array_ops.size(outer_shape)

    outer_batches_indices = math_ops.range(batch_dims)
    batch_axis_indices = math_ops.range(batch_dims, axis_dims)
    inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)

    # Move values[axis] up to values[batch_dims].
    transpose_dims = array_ops.concat([
        outer_batches_indices, [axis_dims], batch_axis_indices,
        inner_axes_indices
    ], 0)
    values_transpose = array_ops.transpose(values, transpose_dims)

    params_shape_transpose = array_ops.gather(params_shape, transpose_dims)
    params_grad = _BatchGatherGrad(params_shape_transpose, values_transpose,
                                   indices, batch_dims, params_shape[axis])

    # Inverts the above transpose by moving dimension batch_dims back to its
    # original position.
    invert_transpose_dims = array_ops.concat([
        outer_batches_indices, batch_axis_indices + 1, [batch_dims],
        inner_axes_indices
    ], 0)
    params_grad = array_ops.transpose(params_grad, invert_transpose_dims)

  if not isinstance(params_grad, indexed_slices_lib.IndexedSlices):
    # Prevents mismatches in shapes when some tensor dimensions are zero.
    params_grad = array_ops.reshape(
        params_grad, array_ops.shape(params))

  return [params_grad, None, None]


@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op: ops.Operation, grad):
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = indexed_slices_lib.IndexedSlices(
        grad, array_ops.squeeze(indices, axis=-1), ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op: ops.Operation, grad):  # pylint: disable=missing-docstring
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = gen_resource_variable_ops.variable_shape(ref, indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = indexed_slices_lib.IndexedSlices(
        grad, array_ops.squeeze(indices, axis=-1), ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op: ops.Operation, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op: ops.Operation, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics_v2(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  return grad


@ops.RegisterGradient("_EagerConst")
def _EagerConstGrad(_, grad):
  raise AssertionError(
      "This op should never interact with gradient APIs. Please file a bug.")
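

# Illustrative sketch for the gather gradients above (comments only, not
# executed): with params of shape [5, 3] and indices = [1, 1, 4], tf.gather
# yields a [3, 3] result, and _GatherGrad/_GatherV2Grad return an
# IndexedSlices whose values are the upstream [3, 3] gradient and whose
# indices are [1, 1, 4]; rows gathered more than once (index 1 here) are
# summed when the IndexedSlices is densified. _GatherNdGrad likewise
# scatter_nd's the upstream gradient back into the shape of `ref`.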


@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  return grad


@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  return grad


ops.NotDifferentiable("StopGradient")


@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op: ops.Operation, grad):
  return [
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          array_ops.shape(op.inputs[0])), None
  ]


ops.NotDifferentiable("InvertPermutation")


def _ReshapeToInput(op: ops.Operation, grad):
  """Reshapes the gradient to the shape of the original input."""
  return array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0]))


@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op: ops.Operation, grad):
  return [_ReshapeToInput(op, grad), None]


@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op: ops.Operation, grad):
  return _ReshapeToInput(op, grad)


@ops.RegisterGradient("Transpose")
def _TransposeGrad(op: ops.Operation, grad):
  """Returns unshuffle(grad)."""
  p = op.inputs[1]
  return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]


@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op: ops.Operation, grad):
  """Returns conj(unshuffle(grad))."""
  p = op.inputs[1]
  return [
      array_ops.transpose(
          grad, array_ops.invert_permutation(p), conjugate=True), None
  ]


ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")


@ops.RegisterGradient("Tile")
def _TileGrad(op: ops.Operation, grad):
  """Sum reduces grad along the tiled dimensions."""
  input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape. For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops_stack.stack([op.inputs[1], input_shape])),
      [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)

  # Sum reduces grad along the first dimension for IndexedSlices.
  if isinstance(grad, indexed_slices_lib.IndexedSlices):
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    grad = math_ops.unsorted_segment_sum(
        grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)

  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference.
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]


ops.NotDifferentiable("BroadcastGradientArgs")


def _PadGrad(op: ops.Operation, grad):
  """Gradient for Pad."""
  # Pad introduces values around the original tensor, so the gradient function
  # slices the original shape out of the gradient.
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops_stack.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  if len(op.inputs) == 3:
    return x_grad, None, None
  else:
    return x_grad, None


ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)
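

# Illustrative sketch for _PadGrad above (comments only, not executed): if x
# has shape [2, 3] and paddings = [[1, 1], [2, 2]], the padded output (and
# therefore the upstream gradient) has shape [4, 7], and the gradient w.r.t.
# x is simply grad[1:3, 2:5], i.e.
#   array_ops.slice(grad, begin=[1, 2], size=[2, 3])
# because the padded values do not depend on x.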


# ReverseSequence is just a permutation. The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op: ops.Operation, grad):
  seq_lengths = op.inputs[1]
  return [
      array_ops.reverse_sequence(
          grad,
          batch_axis=op.get_attr("batch_dim"),
          seq_axis=op.get_attr("seq_dim"),
          seq_lengths=seq_lengths), None
  ]


@ops.RegisterGradient("Reverse")
def _ReverseGrad(op: ops.Operation, grad):
  reverse_dims = op.inputs[1]
  return gen_array_ops.reverse(grad, reverse_dims), None


@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op: ops.Operation, grad):
  axis = op.inputs[1]
  return array_ops.reverse_v2(grad, axis), None


@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op: ops.Operation, grad):
  # Its gradient is the opposite op: BatchToSpace.
  block_size = op.get_attr("block_size")
  return [
      array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op: ops.Operation, grad):
  # Its gradient is the opposite op: BatchToSpaceND.
  return [
      array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op: ops.Operation, grad):
  # Its gradient is the opposite op: SpaceToBatch.
  block_size = op.get_attr("block_size")
  return [
      array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op: ops.Operation, grad):
  # Its gradient is the opposite op: SpaceToBatchND.
  return [
      array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op: ops.Operation, grad):
  # Its gradient is the opposite op: DepthToSpace.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)


@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op: ops.Operation, grad):
  # Its gradient is the opposite op: SpaceToDepth.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)


ops.NotDifferentiable("OneHot")


@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op: ops.Operation, grad):
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op: ops.Operation, grad):
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  return grad


@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  return [grad, None, None]


@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  # Only propagate the gradient for the unquantized input.
  return [grad, None, None, None]
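

# Note on the data-movement gradients above (Reverse*, SpaceToBatch/
# BatchToSpace, SpaceToDepth/DepthToSpace, MirrorPad/MirrorPadGrad): each op
# is a pure rearrangement of its input, so its gradient is the inverse
# rearrangement applied to the upstream gradient. Illustrative sketch
# (comments only, not executed):
#   y = tf.nn.space_to_depth(x, block_size=2)    # [1, 4, 4, 1] -> [1, 2, 2, 4]
#   dx = tf.nn.depth_to_space(dy, block_size=2)  # [1, 2, 2, 4] -> [1, 4, 4, 1]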


@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op: ops.Operation, grad):  # pylint:disable=missing-function-docstring
  input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, rows_in, cols_in, channels = array_ops_stack.unstack(input_bhwc)

  output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  rows_out, cols_out = array_ops_stack.unstack(output_bhwc[1:3])

  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for the input run from 1 to rows_in * cols_in.
  input_indices_num = rows_in * cols_in
  # XLA version of extract_image_patches does not support int64,
  # using float32 instead.
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num + 1, dtype=ops.dtypes.float32),
      (1, rows_in, cols_in, 1),
  )
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("rates"), op.get_attr("padding"))
  input_idx_patched = math_ops.cast(input_idx_patched, dtypes.int64)

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  # Shift all input indices back. Padding locations will have the value -1,
  # which is ignored by the segmented sum.
  segment_ids = array_ops.reshape(input_idx_patched, [-1]) - 1
  grad_out = math_ops.unsorted_segment_sum(
      grad_flat, segment_ids, num_segments=input_indices_num)
  grad_out = array_ops.reshape(
      grad_out, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]


@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op: ops.Operation, grad):  # pylint:disable=missing-function-docstring
  batch_size, planes_in, rows_in, cols_in, channels = [
      dim.value for dim in op.inputs[0].shape.dims
  ]
  input_bphwc = array_ops.shape(op.inputs[0])
  batch_size = input_bphwc[0]
  channels = input_bphwc[4]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for the input run from 1 to planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ], axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
           ksize_c, channels)), (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]


@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op: ops.Operation, grad):
  indices = op.inputs[0]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [None, updates_grad, None]


@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op: ops.Operation, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.tensor_scatter_update(
      array_ops.identity(grad), indices,
      array_ops.zeros_like(op.inputs[2], dtype=grad.dtype))
  return [tensor_grad, None, updates_grad]


@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op: ops.Operation, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, updates_grad]


def _TensorScatterMinOrMaxGrad(op: ops.Operation, grad):
  """Gradient for TensorScatterMin and TensorScatterMax."""
  indices = op.inputs[1]
  x = op.inputs[0]
  y = op.inputs[2]
  output = op.outputs[0]
  x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
  y_output = array_ops.gather_nd(output, indices)
  y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
  ys_indicators = array_ops.scatter_nd(
      indices, y_indicators, array_ops.shape(x, out_type=indices.dtype))
  indicators = x_indicators + ys_indicators  # All elements are >= 1.
  # If there are multiple minimum or maximum elements then the gradient will
  # be divided between them.
  x_grad = grad * x_indicators / indicators
  y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
  return [x_grad, None, y_grad]
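

# Illustrative sketch for _TensorScatterMinOrMaxGrad above (comments only,
# not executed): with x = [1., 5.], indices = [[1]], updates y = [5.] and a
# TensorScatterMax, the output is [1., 5.]; position 1 is a tie between x[1]
# and y[0], so indicators = [1., 2.] and an upstream gradient [g0, g1] is
# split as
#   dx = [g0, g1 / 2]    dy = [g1 / 2]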


@ops.RegisterGradient("TensorScatterMax")
def _TensorScatterMaxGrad(op: ops.Operation, grad):
  """Gradient for TensorScatterMax op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterMin")
def _TensorScatterMinGrad(op: ops.Operation, grad):
  """Gradient for TensorScatterMin op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op: ops.Operation, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, -updates_grad]


@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op: ops.Operation, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [grad, None, updates_grad]


@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op: ops.Operation, grad):  # pylint:disable=missing-function-docstring
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  shape_dtype = dtypes.int32
  if isinstance(broadcast_shape, tensor.Tensor):
    shape_dtype = broadcast_shape.dtype

  input_value_shape = array_ops.shape(input_value, out_type=shape_dtype)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    broadcast_shape_static = tensor_shape.TensorShape(
        tensor_util.try_evaluate_constant(broadcast_shape))
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=shape_dtype)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]
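

# Illustrative sketch for _BroadcastToGrad above (comments only, not
# executed): broadcasting x of shape [1, 3] to [2, 3] copies the row, so the
# gradient w.r.t. x is the upstream [2, 3] gradient summed over the broadcast
# axis and reshaped back:
#   dx = tf.reshape(tf.reduce_sum(dy, axis=0, keepdims=True), [1, 3])
# which is what broadcast_gradient_args + reduce_sum + reshape compute above.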