# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for (block) GRU/LSTM operators."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_rnn_ops


def _block_lstm_grad(op, *grads):
  """Computes input gradients for a block LSTM op via `block_lstm_grad`.

  Registered below as the gradient function for both "BlockLSTM" and
  "BlockLSTMV2".

  Args:
    op: The forward op. Its `inputs` must unpack into exactly nine values
      (seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b) and its `outputs`
      into exactly seven values (i, cs, f, o, ci, co, h). Its "use_peephole"
      attr is forwarded to the gradient kernel.
    *grads: Exactly seven incoming gradients, positionally matching
      `op.outputs`. Only the entries for `cs` and `h` are used; the other
      five are ignored.

  Returns:
    A 9-tuple positionally aligned with `op.inputs`. The first entry
    (for `seq_len_max`) is `None`; the remaining eight are the gradients
    returned by `gen_rnn_ops.block_lstm_grad` for x, cs_prev, h_prev, w, wci,
    wcf, wco and b, in that order.
  """
  # Strict tuple unpacking doubles as an arity check on the op signature.
  (seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
  (i, cs, f, o, ci, co, h) = op.outputs

  # Of the seven output gradients, only those for `cs` (index 1) and `h`
  # (index 6) are passed on to the gradient kernel.
  (_, cs_grad, _, _, _, _, h_grad) = grads

  input_grads = gen_rnn_ops.block_lstm_grad(
      seq_len_max=seq_len_max,
      x=x,
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      i=i,
      cs=cs,
      f=f,
      o=o,
      ci=ci,
      co=co,
      h=h,
      cs_grad=cs_grad,
      h_grad=h_grad,
      use_peephole=op.get_attr("use_peephole"))
  (dx, dcs_prev, dh_prev, dw, dwci, dwcf, dwco, db) = input_grads

  # `None` in the leading slot: no gradient is produced for `seq_len_max`.
  return (None, dx, dcs_prev, dh_prev, dw, dwci, dwcf, dwco, db)


# Both op versions share the same gradient function.
ops.RegisterGradient("BlockLSTM")(_block_lstm_grad)
ops.RegisterGradient("BlockLSTMV2")(_block_lstm_grad)