# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""`jax.experimental.rnn`: GPU accelerated RNN
----------------------------------------------

This module provides experimental support for CUDNN-backed LSTM.

Currently, the only supported RNN flavor is LSTM with double-bias.

We use notations and variable names similar to
https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
and the CUDNN_LSTM entry in
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t.

Note that a bidirectional LSTM is treated as having twice the number of
layers, where a forward layer i is followed by a reverse layer i. Each
direction has its own associated weights. We use pseudo-layer to denote such
layers following CUDNN documentation
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnGetRNNWeightParams.

CUDNN takes an opaque 1D weight array that densely packs all the weight arrays
in a sparsely documented layout. Through trial-and-error and testing, we
believe the layout is the following. Assume a 2-layer bi-LSTM with double-bias,
so 4 pseudo-layers in total (forward-0, reverse-0, forward-1, reverse-1).
There are 4 kinds of weights: W_ih, W_hh, b_ih and b_hh, where

W_ih = (W_ii, W_if, W_ig, W_io) concatenated on leading axis,
W_hh = (W_hi, W_hf, W_hg, W_ho) concatenated on leading axis,
b_ih = (b_ii, b_if, b_ig, b_io) concatenated on leading axis,
b_hh = (b_hi, b_hf, b_hg, b_ho) concatenated on leading axis.

Say W_ih^0 denotes W_ih from pseudo-layer 0. The linear weights from all
pseudo-layers are packed together, followed by the bias weights from all
pseudo-layers. In particular, for each pseudo-layer, W_ih is followed by W_hh
and b_ih by b_hh:

(W_ih^0, W_hh^0, W_ih^1, W_hh^1, W_ih^2, W_hh^2, W_ih^3, W_hh^3,
 b_ih^0, b_hh^0, b_ih^1, b_hh^1, b_ih^2, b_hh^2, b_ih^3, b_hh^3)

See `_get_params_shapes_in_lstm` and the weight-unpacking example at the end
of this docstring.

Example usage:

```
x = jax.random.normal(
    k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
h_0 = jax.random.normal(
    k2, (num_directions * num_layers, batch_size, hidden_size),
    dtype=jnp.float32)
c_0 = jax.random.normal(
    k3, (num_directions * num_layers, batch_size, hidden_size),
    dtype=jnp.float32)
seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len

weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
                               bidirectional)
y, h_n, c_n = rnn.lstm(
    x, h_0, c_0, weights, seq_lengths=seq_lengths, input_size=input_size,
    hidden_size=hidden_size, num_layers=num_layers, dropout=0.0,
    bidirectional=bidirectional)
```

TODO:
  - Add support for input and weight dtypes other than float32.
  - Support ragged inputs.
  - Support RNNs other than LSTM.
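
The per-pseudo-layer shapes implied by this layout can be inspected with
`unpack_lstm_weights` (a minimal sketch; the sizes below are arbitrary and
`k4` is a PRNG key as in the example above):

```
w = rnn.init_lstm_weight(k4, input_size=2, hidden_size=3, num_layers=1,
                         bidirectional=False)
assert w.shape == (84,)  # 4 * hidden_size * (input_size + hidden_size + 2)
W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(
    w, input_size=2, hidden_size=3, num_layers=1, bidirectional=False)
assert W_ih[0].shape == (12, 2)  # (4 * hidden_size, input_size)
assert W_hh[0].shape == (12, 3)  # (4 * hidden_size, hidden_size)
assert b_ih[0].shape == (12,)
assert b_hh[0].shape == (12,)
```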
""" from functools import partial import math from typing import Any, Dict, List, Tuple import jax import numpy as np from jax._src import core from jax.interpreters import mlir from jax.interpreters import xla from jax._src.custom_derivatives import custom_vjp from jax._src.typing import Array, Shape import jax.numpy as jnp try: from jax._src.lib import gpu_rnn except ImportError: gpu_rnn = None # type: ignore[assignment] PRNGKeyArray = Any sigmoid = jax.nn.sigmoid tanh = jax.nn.tanh def _W_ih_l(layer_i: int, input_size: int, hidden_size: int, bidirectional: bool) -> Shape: """Shape of W_ii|W_if|W_ig|W_io. Note that layer_i is an index of pseudo-layers. """ if layer_i == 0 or (layer_i == 1 and bidirectional): return (4 * hidden_size, input_size) else: num_directions = 2 if bidirectional else 1 return (4 * hidden_size, num_directions * hidden_size) def _W_hh_l(layer_i: int, input_size: int, hidden_size: int, bidirectional: bool) -> Shape: """Shape of W_hi|W_hf|W_hg|W_ho.""" return (4 * hidden_size, hidden_size) def _b_ih_l(layer_i: int, input_size: int, hidden_size: int, bidirectional: bool) -> Shape: """Shape of b_ii|b_if|b_ig|b_io.""" return (4 * hidden_size,) def _b_hh_l(layer_i: int, input_size: int, hidden_size: int, bidirectional: bool) -> Shape: """Shape of b_hi|b_hf|b_hg|b_ho.""" return (4 * hidden_size,) def _get_params_shapes_in_lstm(input_size: int, hidden_size: int, num_layers: int, bidirectional: bool) -> List[Shape]: """Get flat param shapes in LSTM. See module docstring for layout.""" layer_shapes = [] num_directions = 2 if bidirectional else 1 num_pseudo_layers = num_layers * num_directions linear_weights = [_W_ih_l, _W_hh_l] for i in range(num_pseudo_layers): for w_kind in linear_weights: layer_shape = w_kind(i, input_size, hidden_size, bidirectional) layer_shapes.append(layer_shape) bias_weights = [_b_ih_l, _b_hh_l] for i in range(num_pseudo_layers): for w_kind in bias_weights: layer_shape = w_kind(i, input_size, hidden_size, bidirectional) layer_shapes.append(layer_shape) return layer_shapes def get_num_params_in_lstm(input_size: int, hidden_size: int, num_layers: int, bidirectional: bool) -> int: """Get param count in LSTM.""" layer_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional) param_count = sum([math.prod(shape) for shape in layer_shapes]) return param_count def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int, num_layers: int, bidirectional: bool): """Random initialize LSTM weights from U(-k, k), k=sqrt(1/hidden_size).""" param_count = get_num_params_in_lstm(input_size, hidden_size, num_layers, bidirectional) k = np.sqrt(1.0 / hidden_size) return jax.random.uniform( rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k) def unpack_lstm_weights( weights: Array, input_size: int, hidden_size: int, num_layers: int, bidirectional: bool ) -> Tuple[Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]: """Unpack cudnn LSTM weights into individual weights. CUDNN LSTM weight layout: (num_layers, num_directions, W_ih, W_hh, b_ih, b_hh) Returns W_ih, W_hh, b_ih, b_hh. e.g. W_ih[2][1] is the concat weights of 4 weights (W_ii, W_if, W_ig, W_io), each of shape (hidden_size, input_size) at 2nd layer for the reverse direction. See notations from https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM. 
""" flat_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional) flat_shapes_offset = 0 w_offsets = 0 num_directions = 2 if bidirectional else 1 num_pseudo_layers = num_layers * num_directions W_ih: Dict[int, Array] = {} W_hh: Dict[int, Array] = {} for l in range(num_pseudo_layers): for w_kind in [W_ih, W_hh]: shape = flat_shapes[flat_shapes_offset] flat_shapes_offset += 1 num_elems = math.prod(shape) w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape) w_offsets += num_elems b_ih: Dict[int, Array] = {} b_hh: Dict[int, Array] = {} for l in range(num_pseudo_layers): for w_kind in [b_ih, b_hh]: shape = flat_shapes[flat_shapes_offset] flat_shapes_offset += 1 num_elems = math.prod(shape) w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape) w_offsets += num_elems return W_ih, W_hh, b_ih, b_hh @partial(custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) def lstm(x: Array, h_0: Array, c_0: Array, weights: Array, seq_lengths: Array, input_size: int, hidden_size: int, num_layers: int, dropout: float, bidirectional: bool) -> Tuple[Array, Array, Array]: """LSTM via CuDNN or HIPDNN (not-yet-supported). Assume batch-first inputs. Arguments: x: (batch_size, max_seq_length, input_size) h_0: (num_directions * num_layers, batch_size, hidden_size) c_0: (num_directions * num_layers, batch_size, hidden_size) weights: (num_params,) where num_params = get_num_params_in_lstm(...) seq_lengths: (batch_size,) Returns: (y, h_n, c_n, reserve_space). y: (batch_size, max_seq_length, hidden_size * num_directions) h_n: (num_directions * num_layers, batch_size, hidden_size) c_n: (num_directions * num_layers, batch_size, hidden_size) """ (y, h_n, c_n), _ = lstm_fwd( x, h_0, c_0, weights, seq_lengths, input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, bidirectional=bidirectional) return y, h_n, c_n @partial(jax.jit, static_argnums=(8, 9, 10, 11, 12)) def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array], W_hh: Dict[int, Array], b_ih: Dict[int, Array], b_hh: Dict[int, Array], seq_lengths: Array, input_size: int, hidden_size: int, num_layers: int, dropout: float, bidirectional: bool) -> Tuple[Array, Array, Array]: """Reference implementation of LSTM. See https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#lstm https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t """ if seq_lengths.dtype != jnp.dtype("int32"): raise NotImplementedError("`seq_lengths` can only be int32.") if dropout != 0.0: raise NotImplementedError( 'Dropout not supported in LSTM reference because we cannot determine CUDNN dropout mask.' ) # TODO(zhangqiaorjc): Handle ragged seq_lengths. 
  # batch_size, max_seq_length = x.shape[0], x.shape[1]
  # assert seq_lengths.shape == (batch_size,)
  # for i in range(batch_size):
  #   if int(seq_lengths[i]) != max_seq_length:
  #     raise NotImplementedError('Does not yet support ragged sequences.')

  def lstm_cell(carry, x, *, W_ih, W_hh, b_ih, b_hh):
    h, c = carry
    W_ii, W_if, W_ig, W_io = jnp.split(W_ih, 4, axis=0)
    W_hi, W_hf, W_hg, W_ho = jnp.split(W_hh, 4, axis=0)
    b_ii, b_if, b_ig, b_io = jnp.split(b_ih, 4, axis=0)
    b_hi, b_hf, b_hg, b_ho = jnp.split(b_hh, 4, axis=0)
    i = sigmoid(x @ W_ii.T + b_ii[None] + h @ W_hi.T + b_hi[None])
    f = sigmoid(x @ W_if.T + b_if[None] + h @ W_hf.T + b_hf[None])
    g = tanh(x @ W_ig.T + b_ig[None] + h @ W_hg.T + b_hg[None])
    o = sigmoid(x @ W_io.T + b_io[None] + h @ W_ho.T + b_ho[None])
    c = f * c + i * g
    h = o * tanh(c)
    return (h, c), h

  # We also output the carry at every step so that we can later slice out the
  # correct carry according to seq_lengths. This uses more memory, but it is
  # faster than using `jnp.where` inside the scan loop.
  def scan_fn(cell, carry, x):
    carry, y = cell(carry, x)
    return carry, (carry, y)

  seq_first_y = x.transpose(1, 0, 2)
  if not bidirectional:
    final_h = []
    final_c = []
    for l in range(num_layers):
      cell = partial(
          lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
      cell_fn = partial(scan_fn, cell)
      out = jax.lax.scan(cell_fn, (h_0[l], c_0[l]), seq_first_y)
      (h_t, c_t), seq_first_y = _extract_output(seq_lengths, out)
      final_h.append(h_t)
      final_c.append(c_t)
    h_n = jnp.stack(final_h)
    c_n = jnp.stack(final_c)
    return seq_first_y.transpose(1, 0, 2), h_n, c_n

  # Bidirectional: pseudo-layers alternate forward (even l) and reverse (odd l).
  final_h = []
  final_c = []
  for l in range(num_layers * 2):
    cell = partial(
        lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
    cell_fn = partial(scan_fn, cell)
    if l % 2 == 0:
      out = jax.lax.scan(cell_fn, (h_0[l], c_0[l]), seq_first_y)
      (h_t, c_t), seq_first_y_fwd = _extract_output(seq_lengths, out)
    else:
      # Reverse the sequence while keeping the padding at the end.
      seq_first_y_reversed = _flip_sequence(seq_first_y, seq_lengths)
      out = jax.lax.scan(cell_fn, (h_0[l], c_0[l]), seq_first_y_reversed)
      (h_t, c_t), seq_first_y_bwd = _extract_output(seq_lengths, out)
      # Align the reversed outputs with the original sequence order.
      seq_first_y_bwd = _flip_sequence(seq_first_y_bwd, seq_lengths)
      # Inputs to the next layer are the forward and reverse outputs,
      # concatenated along the feature axis.
      seq_first_y = jnp.concatenate([seq_first_y_fwd, seq_first_y_bwd],
                                    axis=-1)  # pytype: disable=name-error
    final_h.append(h_t)
    final_c.append(c_t)
  h_n = jnp.stack(final_h)
  c_n = jnp.stack(final_c)
  return seq_first_y.transpose(1, 0, 2), h_n, c_n


def _extract_output(seq_lengths: Array,
                    out) -> Tuple[Tuple[Array, Array], Array]:
  _, ((hs, cs), seq_first_y) = out
  h_t = _select_last_carry(hs, seq_lengths)
  c_t = _select_last_carry(cs, seq_lengths)

  # Zero out the outputs at padded positions.
  # mask: [seq_len, batch]; seq_lengths[None]: [1, batch];
  # jnp.arange(...)[:, None]: [seq_len, 1]
  mask = seq_lengths[None] > jnp.arange(seq_first_y.shape[0],
                                        dtype=jnp.int32)[:, None]
  # [seq_len, batch, hidden_size]
  seq_first_y = jnp.where(
      mask[..., None],  # [seq_len, batch, 1]
      seq_first_y,  # [seq_len, batch, hidden_size]
      0)
  return (h_t, c_t), seq_first_y


def _select_last_carry(carry_seq: Array, seq_lengths: Array):
  return carry_seq[seq_lengths - 1, jnp.arange(carry_seq.shape[1])]


def _flip_sequence(sequences: Array, seq_lengths: Array) -> Array:
  max_steps = sequences.shape[0]
  roll_amounts = max_steps - seq_lengths
  # `roll` initially puts the padding at the front, so when the sequence is
  # reversed (via [::-1]) the padding stays at the end.
  return jax.vmap(partial(jnp.roll, axis=0), in_axes=(1, 0),
                  out_axes=1)(sequences, roll_amounts)[::-1]


def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
             input_size: int, hidden_size: int, num_layers: int,
             dropout: float, bidirectional: bool):
  if seq_lengths.dtype != jnp.dtype("int32"):
    raise NotImplementedError("`seq_lengths` can only be int32.")
  if jax._src.lib.version < (0, 4, 9):
    y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
        x, h_0, c_0, w, seq_lengths, input_size=input_size,
        hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
        bidirectional=bidirectional)
    return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, workspace,
                           reserve_space)
  else:
    y, h_n, c_n, reserve_space = rnn_fwd_p.bind(
        x, h_0, c_0, w, seq_lengths, input_size=input_size,
        hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
        bidirectional=bidirectional)
    return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, reserve_space)


def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
                      input_size: int, hidden_size: int, num_layers: int,
                      dropout: float, bidirectional: bool):
  batch_size, max_seq_length = x_aval.shape[0], x_aval.shape[1]
  num_directions = 2 if bidirectional else 1
  output_shape = (batch_size, max_seq_length, num_directions * hidden_size)
  output_aval = core.ShapedArray(output_shape, x_aval.dtype)
  if jax._src.lib.version < (0, 4, 9):
    workspace_size, reserve_space_size = (
        gpu_rnn.compute_rnn_workspace_reserve_space_sizes(  # pytype: disable=attribute-error
            input_size, hidden_size, num_layers, batch_size, max_seq_length,
            dropout, bidirectional))
    workspace_aval = core.ShapedArray((workspace_size,), jnp.float32)
    reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
    return output_aval, h_0_aval, c_0_aval, workspace_aval, reserve_space_aval
  else:
    _, reserve_space_size = (
        gpu_rnn.compute_rnn_workspace_reserve_space_sizes(  # pytype: disable=attribute-error
            input_size, hidden_size, num_layers, batch_size, max_seq_length,
            dropout, bidirectional))
    reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
    return output_aval, h_0_aval, c_0_aval, reserve_space_aval


rnn_fwd_p = core.Primitive('rnn_fwd')
rnn_fwd_p.multiple_results = True
rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
if gpu_rnn:
  mlir.register_lowering(rnn_fwd_p,
                         gpu_rnn.cudnn_rnn_lowering,
                         platform='cuda')


def lstm_bwd(input_size: int, hidden_size: int, num_layers: int,
             dropout: float, bidirectional: bool, residuals, gradients):
  if jax._src.lib.version < (0, 4, 9):
    x, h_0, c_0, w, seq_lengths, y, workspace, reserve_space = residuals
    dy, dh_n, dc_n = gradients
    dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
        dy, dh_n, dc_n, x, h_0, c_0, w, y, workspace, reserve_space,
        seq_lengths, input_size=input_size, hidden_size=hidden_size,
        num_layers=num_layers, dropout=dropout, bidirectional=bidirectional)
    return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
  else:
    x, h_0, c_0, w, seq_lengths, y, reserve_space = residuals
    dy, dh_n, dc_n = gradients
    dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
        dy, dh_n, dc_n, x, h_0, c_0, w, y, reserve_space, seq_lengths,
        input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
        dropout=dropout, bidirectional=bidirectional)
    return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))


# The backward primitive's residuals differ across jaxlib versions: older
# versions also thread an explicit CUDNN workspace buffer through the
# backward pass.
if jax._src.lib.version < (0, 4, 9):
  def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval,
                            c0_aval, w_aval, y_aval, workspace_aval,
                            reserve_space_aval, seq_lengths_aval,
                            input_size: int, hidden_size: int,
                            num_layers: int, dropout: float,
                            bidirectional: bool):
    return x_aval, h0_aval, c0_aval, w_aval
else:
  def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval,  # type: ignore
                            c0_aval, w_aval, y_aval, reserve_space_aval,
                            seq_lengths_aval, input_size: int,
                            hidden_size: int, num_layers: int, dropout: float,
                            bidirectional: bool):
    return x_aval, h0_aval, c0_aval, w_aval


rnn_bwd_p = core.Primitive('rnn_bwd')
rnn_bwd_p.multiple_results = True
rnn_bwd_p.def_impl(partial(xla.apply_primitive, rnn_bwd_p))
rnn_bwd_p.def_abstract_eval(rnn_bwd_abstract_eval)
if gpu_rnn:
  mlir.register_lowering(
      rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')

lstm.defvjp(lstm_fwd, lstm_bwd)
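

# Illustrative sketch (not executed): `lstm_ref` consumes the unpacked
# per-pseudo-layer weights, so it can serve as a cross-check of the
# CUDNN-backed `lstm`. The variables below (x, h_0, c_0, weights, seq_lengths
# and the size arguments) are assumed to be set up as in the module docstring
# example.
#
#   W_ih, W_hh, b_ih, b_hh = unpack_lstm_weights(
#       weights, input_size, hidden_size, num_layers, bidirectional)
#   y_ref, h_n_ref, c_n_ref = lstm_ref(
#       x, h_0, c_0, W_ih, W_hh, b_ih, b_hh, seq_lengths,
#       input_size, hidden_size, num_layers, 0.0, bidirectional)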