# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Legacy module implementing RNN Cells.
|
||
|
|
||
|
This module provides a number of basic commonly used RNN cells, such as LSTM
|
||
|
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
|
||
|
operators that allow adding dropouts, projections, or embeddings for inputs.
|
||
|
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
|
||
|
calling the `rnn` ops several times.
|
||
|
"""
|
||
|
|
||
|
|
||
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import warnings

import tensorflow.compat.v2 as tf

from keras import activations
from keras import backend
from keras import initializers
from keras.engine import base_layer_utils
from keras.engine import input_spec
from keras.legacy_tf_layers import base as base_layer
from keras.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"


def _hasattr(obj, attr_name):
    try:
        getattr(obj, attr_name)
    except AttributeError:
        return False
    else:
        return True


def _concat(prefix, suffix, static=False):
    """Concat that enables int, Tensor, or TensorShape values.

    This function takes a size specification, which can be an integer, a
    TensorShape, or a Tensor, and converts it into a concatenated Tensor
    (if static = False) or a list of integers (if static = True).

    Args:
      prefix: The prefix; usually the batch size (and/or time step size).
        (TensorShape, int, or Tensor.)
      suffix: TensorShape, int, or Tensor.
      static: If `True`, return a python list with possibly unknown
        dimensions. Otherwise return a `Tensor`.

    Returns:
      shape: the concatenation of prefix and suffix.

    Raises:
      ValueError: if `suffix` is not a scalar or vector (or TensorShape).
      ValueError: if prefix or suffix was `None` and asked for dynamic
        Tensors out.
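
    Example (an illustrative sketch; the sizes below are made up):

    ```python
    _concat(2, 3)               # => a Tensor containing [2, 3]
    _concat(2, 3, static=True)  # => the python list [2, 3]
    ```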
"""
|
||
|
    if isinstance(prefix, tf.Tensor):
        p = prefix
        p_static = tf.get_static_value(prefix)
        if p.shape.ndims == 0:
            p = tf.compat.v1.expand_dims(p, 0)
        elif p.shape.ndims != 1:
            raise ValueError(
                "Prefix tensor must be either a scalar or vector, "
                f"but received tensor: {p}"
            )
    else:
        p = tf.TensorShape(prefix)
        p_static = p.as_list() if p.ndims is not None else None
        p = (
            tf.constant(p.as_list(), dtype=tf.int32)
            if p.is_fully_defined()
            else None
        )
    if isinstance(suffix, tf.Tensor):
        s = suffix
        s_static = tf.get_static_value(suffix)
        if s.shape.ndims == 0:
            s = tf.compat.v1.expand_dims(s, 0)
        elif s.shape.ndims != 1:
            raise ValueError(
                "suffix tensor must be either a scalar or vector, "
                f"but received tensor: {s}"
            )
    else:
        s = tf.TensorShape(suffix)
        s_static = s.as_list() if s.ndims is not None else None
        s = (
            tf.constant(s.as_list(), dtype=tf.int32)
            if s.is_fully_defined()
            else None
        )

    if static:
        shape = tf.TensorShape(p_static).concatenate(s_static)
        shape = shape.as_list() if shape.ndims is not None else None
    else:
        if p is None or s is None:
            raise ValueError(
                "Prefix or suffix can't be None. "
                f"Received prefix = {prefix} and suffix = {suffix}"
            )
        shape = tf.concat((p, s), 0)
    return shape


def _zero_state_tensors(state_size, batch_size, dtype):
    """Create tensors of zeros based on state_size, batch_size, and dtype."""

    def get_state_shape(s):
        """Combine s with batch_size to get a proper tensor shape."""
        c = _concat(batch_size, s)
        size = tf.zeros(c, dtype=dtype)
        if not tf.executing_eagerly():
            c_static = _concat(batch_size, s, static=True)
            size.set_shape(c_static)
        return size

    return tf.nest.map_structure(get_state_shape, state_size)


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.RNNCell"])
@tf_export(v1=["nn.rnn_cell.RNNCell"])
class RNNCell(base_layer.Layer):
    """Abstract object representing an RNN cell.

    Every `RNNCell` must have the properties below and implement `call` with
    the signature `(output, next_state) = call(input, state)`. The optional
    third input argument, `scope`, is allowed for backwards compatibility
    purposes; but should be left off for new subclasses.

    This definition of cell differs from the definition used in the
    literature. In the literature, 'cell' refers to an object with a single
    scalar output. This definition refers to a horizontal array of such units.

    An RNN cell, in the most abstract setting, is anything that has
    a state and performs some operation that takes a matrix of inputs.
    This operation results in an output matrix with `self.output_size`
    columns. If `self.state_size` is an integer, this operation also results
    in a new state matrix with `self.state_size` columns. If
    `self.state_size` is a (possibly nested tuple of) TensorShape object(s),
    then it should return a matching structure of Tensors having shape
    `[batch_size].concatenate(s)` for each `s` in `self.state_size`.
    """

    def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
        super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)
        # Attribute that indicates whether the cell is a TF RNN cell, due to
        # the slight difference between TF and Keras RNN cells. Notably, the
        # state is not wrapped in a list for TF cells with a single tensor
        # state, whereas a Keras cell will wrap the state into a list, and
        # call() will have to unwrap them.
        self._is_tf_rnn_cell = True

    def __call__(self, inputs, state, scope=None):
        """Run this RNN cell on inputs, starting from the given state.

        Args:
          inputs: `2-D` tensor with shape `[batch_size, input_size]`.
          state: if `self.state_size` is an integer, this should be a
            `2-D Tensor` with shape `[batch_size, self.state_size]`.
            Otherwise, if `self.state_size` is a tuple of integers, this
            should be a tuple with shapes `[batch_size, s] for s in
            self.state_size`.
          scope: VariableScope for the created subgraph; defaults to class
            name.

        Returns:
          A pair containing:

          - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
          - New state: Either a single `2-D` tensor, or a tuple of tensors
            matching the arity and shapes of `state`.
        """
        if scope is not None:
            with tf.compat.v1.variable_scope(
                scope, custom_getter=self._rnn_get_variable
            ) as scope:
                return super().__call__(inputs, state, scope=scope)
        else:
            scope_attrname = "rnncell_scope"
            scope = getattr(self, scope_attrname, None)
            if scope is None:
                scope = tf.compat.v1.variable_scope(
                    tf.compat.v1.get_variable_scope(),
                    custom_getter=self._rnn_get_variable,
                )
                setattr(self, scope_attrname, scope)
            with scope:
                return super().__call__(inputs, state)

    def _rnn_get_variable(self, getter, *args, **kwargs):
        variable = getter(*args, **kwargs)
        if tf.compat.v1.executing_eagerly_outside_functions():
            trainable = variable.trainable
        else:
            trainable = variable in tf.compat.v1.trainable_variables() or (
                base_layer_utils.is_split_variable(variable)
                and list(variable)[0] in tf.compat.v1.trainable_variables()
            )
        if trainable and all(
            variable is not v for v in self._trainable_weights
        ):
            self._trainable_weights.append(variable)
        elif not trainable and all(
            variable is not v for v in self._non_trainable_weights
        ):
            self._non_trainable_weights.append(variable)
        return variable

    @property
    def state_size(self):
        """size(s) of state(s) used by this cell.

        It can be represented by an Integer, a TensorShape or a tuple of
        Integers or TensorShapes.
        """
        raise NotImplementedError("Abstract method")

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell."""
        raise NotImplementedError("Abstract method")

    def build(self, _):
        # This tells the parent Layer object that it's OK to call
        # self.add_weight() inside the call() method.
        pass

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        if inputs is not None:
            # Validate the given batch_size and dtype against inputs if
            # provided.
            inputs = tf.convert_to_tensor(inputs, name="inputs")
            if batch_size is not None:
                if tf.is_tensor(batch_size):
                    static_batch_size = tf.get_static_value(
                        batch_size, partial=True
                    )
                else:
                    static_batch_size = batch_size
                if inputs.shape.dims[0].value != static_batch_size:
                    raise ValueError(
                        "batch size from input tensor is different from the "
                        "input param. Input tensor batch: "
                        f"{inputs.shape.dims[0].value}, "
                        f"batch_size: {batch_size}"
                    )

            if dtype is not None and inputs.dtype != dtype:
                raise ValueError(
                    "dtype from input tensor is different from the "
                    f"input param. Input tensor dtype: {inputs.dtype}, "
                    f"dtype: {dtype}"
                )

            batch_size = (
                inputs.shape.dims[0].value or tf.compat.v1.shape(inputs)[0]
            )
            dtype = inputs.dtype
        if batch_size is None or dtype is None:
            raise ValueError(
                "batch_size and dtype cannot be None while constructing "
                f"initial state: batch_size={batch_size}, dtype={dtype}"
            )
        return self.zero_state(batch_size, dtype)

    def zero_state(self, batch_size, dtype):
        """Return zero-filled state tensor(s).

        Args:
          batch_size: int, float, or unit Tensor representing the batch size.
          dtype: the data type to use for the state.

        Returns:
          If `state_size` is an int or TensorShape, then the return value is
          a `N-D` tensor of shape `[batch_size, state_size]` filled with
          zeros.

          If `state_size` is a nested list or tuple, then the return value is
          a nested list or tuple (of the same structure) of `2-D` tensors
          with the shapes `[batch_size, s]` for each s in `state_size`.
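
        Example (illustrative sketch; assumes a concrete `cell` whose
        `state_size` is the integer 4 and a made-up batch size):

        ```python
        state = cell.zero_state(batch_size=2, dtype=tf.float32)
        # state is a zero-filled Tensor of shape [2, 4]
        ```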
"""
|
||
|
        # Try to use the last cached zero_state. This is done to avoid
        # recreating zeros, especially when eager execution is enabled.
        state_size = self.state_size
        is_eager = tf.executing_eagerly()
        if is_eager and _hasattr(self, "_last_zero_state"):
            (
                last_state_size,
                last_batch_size,
                last_dtype,
                last_output,
            ) = getattr(self, "_last_zero_state")
            if (
                last_batch_size == batch_size
                and last_dtype == dtype
                and last_state_size == state_size
            ):
                return last_output
        with backend.name_scope(type(self).__name__ + "ZeroState"):
            output = _zero_state_tensors(state_size, batch_size, dtype)
        if is_eager:
            self._last_zero_state = (state_size, batch_size, dtype, output)
        return output

    def get_config(self):
        return super().get_config()

    @property
    def _use_input_spec_as_call_signature(self):
        # We do not store the shape information for the state argument in the
        # call function for legacy RNN cells, so do not generate an input
        # signature.
        return False


class LayerRNNCell(RNNCell):
    """Subclass of RNNCells that act like proper `tf.Layer` objects.

    For backwards compatibility purposes, most `RNNCell` instances allow
    their `call` methods to instantiate variables via
    `tf.compat.v1.get_variable`. The underlying variable scope thus keeps
    track of any variables, and returns cached versions. This is atypical of
    `tf.layer` objects, which separate this part of layer building into a
    `build` method that is only called once.

    Here we provide a subclass for `RNNCell` objects that act exactly as
    `Layer` objects do. They must provide a `build` method and their
    `call` methods do not access Variables via `tf.compat.v1.get_variable`.
    """

    def __call__(self, inputs, state, scope=None, *args, **kwargs):
        """Run this RNN cell on inputs, starting from the given state.

        Args:
          inputs: `2-D` tensor with shape `[batch_size, input_size]`.
          state: if `self.state_size` is an integer, this should be a `2-D
            Tensor` with shape `[batch_size, self.state_size]`. Otherwise,
            if `self.state_size` is a tuple of integers, this should be a
            tuple with shapes `[batch_size, s] for s in self.state_size`.
          scope: optional cell scope.
          *args: Additional positional arguments.
          **kwargs: Additional keyword arguments.

        Returns:
          A pair containing:

          - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
          - New state: Either a single `2-D` tensor, or a tuple of tensors
            matching the arity and shapes of `state`.
        """
        # Bypass RNNCell's variable capturing semantics for LayerRNNCell.
        # Instead, it is up to subclasses to provide a proper build
        # method. See the class docstring for more details.
        return base_layer.Layer.__call__(
            self, inputs, state, scope=scope, *args, **kwargs
        )


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicRNNCell"])
@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
class BasicRNNCell(LayerRNNCell):
    """The most basic RNN cell.

    Note that this cell is not optimized for performance.

    Args:
      num_units: int, The number of units in the RNN cell.
      activation: Nonlinearity to use. Default: `tanh`. It can also be a
        string that is within the Keras activation function names.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes,
        like `trainable` etc when constructing the cell from configs of
        get_config().
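
    Example (illustrative sketch; the shapes are made up and TF1-style
    `tf.compat.v1` usage is assumed):

    ```python
    cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(num_units=4)
    inputs = tf.ones([2, 3])  # [batch_size, input_size]
    state = cell.zero_state(batch_size=2, dtype=tf.float32)
    output, new_state = cell(inputs, state)  # both shaped [2, 4]
    ```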
"""
|
||
|
|
||
|
    def __init__(
        self,
        num_units,
        activation=None,
        reuse=None,
        name=None,
        dtype=None,
        **kwargs,
    ):
        warnings.warn(
            "`tf.nn.rnn_cell.BasicRNNCell` is deprecated and will be "
            "removed in a future version. This class "
            "is equivalent to `tf.keras.layers.SimpleRNNCell`, "
            "and will be replaced by that in Tensorflow 2.0.",
            stacklevel=2,
        )
        super().__init__(_reuse=reuse, name=name, dtype=dtype, **kwargs)
        _check_supported_dtypes(self.dtype)
        if tf.executing_eagerly() and tf.config.list_logical_devices("GPU"):
            logging.warning(
                "%s: Note that this cell is not optimized for performance.",
                self,
            )

        # Inputs must be 2-dimensional.
        self.input_spec = input_spec.InputSpec(ndim=2)

        self._num_units = num_units
        if activation:
            self._activation = activations.get(activation)
        else:
            self._activation = tf.tanh

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    @tf_utils.shape_type_conversion
    def build(self, inputs_shape):
        if inputs_shape[-1] is None:
            raise ValueError(
                "Expected inputs.shape[-1] to be known, "
                f"received shape: {inputs_shape}"
            )
        _check_supported_dtypes(self.dtype)

        input_depth = inputs_shape[-1]
        self._kernel = self.add_weight(
            _WEIGHTS_VARIABLE_NAME,
            shape=[input_depth + self._num_units, self._num_units],
        )
        self._bias = self.add_weight(
            _BIAS_VARIABLE_NAME,
            shape=[self._num_units],
            initializer=tf.compat.v1.zeros_initializer(dtype=self.dtype),
        )

        self.built = True

    def call(self, inputs, state):
        """Most basic RNN: output = new_state = act(W * input + U * state +
        B)."""
        _check_rnn_cell_input_dtypes([inputs, state])
        gate_inputs = tf.matmul(tf.concat([inputs, state], 1), self._kernel)
        gate_inputs = tf.nn.bias_add(gate_inputs, self._bias)
        output = self._activation(gate_inputs)
        return output, output

    def get_config(self):
        config = {
            "num_units": self._num_units,
            "activation": activations.serialize(self._activation),
            "reuse": self._reuse,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.GRUCell"])
@tf_export(v1=["nn.rnn_cell.GRUCell"])
class GRUCell(LayerRNNCell):
    """Gated Recurrent Unit cell.

    Note that this cell is not optimized for performance. Please use
    `tf.compat.v1.keras.layers.CuDNNGRU` for better performance on GPU, or
    `tf.raw_ops.GRUBlockCell` for better performance on CPU.

    Args:
      num_units: int, The number of units in the GRU cell.
      activation: Nonlinearity to use. Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.
      kernel_initializer: (optional) The initializer to use for the weight
        and projection matrices.
      bias_initializer: (optional) The initializer to use for the bias.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes,
        like `trainable` etc when constructing the cell from configs of
        get_config().

    References:
      Learning Phrase Representations using RNN Encoder Decoder for
      Statistical Machine Translation:
        [Cho et al., 2014]
        (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179)
        ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf))
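
    Example (illustrative sketch; the shapes are made up and TF1-style
    `tf.compat.v1` usage is assumed):

    ```python
    cell = tf.compat.v1.nn.rnn_cell.GRUCell(num_units=4)
    inputs = tf.ones([2, 3])  # [batch_size, input_size]
    state = cell.zero_state(batch_size=2, dtype=tf.float32)
    output, new_state = cell(inputs, state)  # both shaped [2, 4]
    ```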
"""
|
||
|
|
||
|
    def __init__(
        self,
        num_units,
        activation=None,
        reuse=None,
        kernel_initializer=None,
        bias_initializer=None,
        name=None,
        dtype=None,
        **kwargs,
    ):
        warnings.warn(
            "`tf.nn.rnn_cell.GRUCell` is deprecated and will be removed "
            "in a future version. This class "
            "is equivalent to `tf.keras.layers.GRUCell`, "
            "and will be replaced by that in Tensorflow 2.0.",
            stacklevel=2,
        )
        super().__init__(_reuse=reuse, name=name, dtype=dtype, **kwargs)
        _check_supported_dtypes(self.dtype)

        if tf.executing_eagerly() and tf.config.list_logical_devices("GPU"):
            logging.warning(
                "%s: Note that this cell is not optimized for performance. "
                "Please use tf.compat.v1.keras.layers.CuDNNGRU for better "
                "performance on GPU.",
                self,
            )
        # Inputs must be 2-dimensional.
        self.input_spec = input_spec.InputSpec(ndim=2)

        self._num_units = num_units
        if activation:
            self._activation = activations.get(activation)
        else:
            self._activation = tf.tanh
        self._kernel_initializer = initializers.get(kernel_initializer)
        self._bias_initializer = initializers.get(bias_initializer)

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    @tf_utils.shape_type_conversion
    def build(self, inputs_shape):
        if inputs_shape[-1] is None:
            raise ValueError(
                "Expected inputs.shape[-1] to be known, "
                f"received shape: {inputs_shape}"
            )
        _check_supported_dtypes(self.dtype)
        input_depth = inputs_shape[-1]
        self._gate_kernel = self.add_weight(
            f"gates/{_WEIGHTS_VARIABLE_NAME}",
            shape=[input_depth + self._num_units, 2 * self._num_units],
            initializer=self._kernel_initializer,
        )
        self._gate_bias = self.add_weight(
            f"gates/{_BIAS_VARIABLE_NAME}",
            shape=[2 * self._num_units],
            initializer=(
                self._bias_initializer
                if self._bias_initializer is not None
                else tf.compat.v1.constant_initializer(1.0, dtype=self.dtype)
            ),
        )
        self._candidate_kernel = self.add_weight(
            f"candidate/{_WEIGHTS_VARIABLE_NAME}",
            shape=[input_depth + self._num_units, self._num_units],
            initializer=self._kernel_initializer,
        )
        self._candidate_bias = self.add_weight(
            f"candidate/{_BIAS_VARIABLE_NAME}",
            shape=[self._num_units],
            initializer=(
                self._bias_initializer
                if self._bias_initializer is not None
                else tf.compat.v1.zeros_initializer(dtype=self.dtype)
            ),
        )

        self.built = True

    def call(self, inputs, state):
        """Gated recurrent unit (GRU) with nunits cells."""
        _check_rnn_cell_input_dtypes([inputs, state])

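        # For reference, the update computed below is (with `.` denoting a
        # matmul over the concatenated operands):
        #   r, u = split(sigmoid([x, h] . W_gates + b_gates))  # reset/update
        #   c = activation([x, r * h] . W_candidate + b_candidate)
        #   new_h = u * h + (1 - u) * c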
        gate_inputs = tf.matmul(
            tf.concat([inputs, state], 1), self._gate_kernel
        )
        gate_inputs = tf.nn.bias_add(gate_inputs, self._gate_bias)

        value = tf.sigmoid(gate_inputs)
        r, u = tf.split(value=value, num_or_size_splits=2, axis=1)

        r_state = r * state

        candidate = tf.matmul(
            tf.concat([inputs, r_state], 1), self._candidate_kernel
        )
        candidate = tf.nn.bias_add(candidate, self._candidate_bias)

        c = self._activation(candidate)
        new_h = u * state + (1 - u) * c
        return new_h, new_h

    def get_config(self):
        config = {
            "num_units": self._num_units,
            "kernel_initializer": initializers.serialize(
                self._kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self._bias_initializer),
            "activation": activations.serialize(self._activation),
            "reuse": self._reuse,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMStateTuple"])
@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"])
class LSTMStateTuple(_LSTMStateTuple):
    """Tuple used by LSTM Cells for `state_size`, `zero_state`, & output state.

    Stores two elements, `(c, h)`, in that order, where `c` is the cell
    state and `h` is the output.

    Only used when `state_is_tuple=True`.
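
    Example (illustrative sketch with made-up tensors):

    ```python
    c = tf.zeros([2, 4])  # cell state
    h = tf.zeros([2, 4])  # output
    state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(c, h)
    state.c is c and state.h is h  # True
    ```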
"""
|
||
|
|
||
|
    __slots__ = ()

    @property
    def dtype(self):
        (c, h) = self
        if c.dtype != h.dtype:
            raise TypeError(
                "Inconsistent dtypes for internal state: "
                f"{c.dtype} vs {h.dtype}"
            )
        return c.dtype


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicLSTMCell"])
@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"])
class BasicLSTMCell(LayerRNNCell):
    """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead.

    Basic LSTM recurrent network cell.

    The implementation is based on http://arxiv.org/abs/1409.2329.

    We add forget_bias (default: 1) to the biases of the forget gate in order
    to reduce the scale of forgetting in the beginning of the training.

    It does not allow cell clipping or a projection layer, and does not
    use peep-hole connections: it is the basic baseline.

    For advanced models, please use the full
    `tf.compat.v1.nn.rnn_cell.LSTMCell` that follows.

    Note that this cell is not optimized for performance. Please use
    `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU, or
    `tf.raw_ops.LSTMBlockCell` for better performance on CPU.
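
    Example (illustrative sketch; the shapes are made up and TF1-style
    `tf.compat.v1` usage is assumed):

    ```python
    cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_units=4)
    inputs = tf.ones([2, 3])  # [batch_size, input_size]
    state = cell.zero_state(batch_size=2, dtype=tf.float32)  # LSTMStateTuple
    output, new_state = cell(inputs, state)  # output shaped [2, 4]
    ```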
"""
|
||
|
|
||
|
    def __init__(
        self,
        num_units,
        forget_bias=1.0,
        state_is_tuple=True,
        activation=None,
        reuse=None,
        name=None,
        dtype=None,
        **kwargs,
    ):
        """Initialize the basic LSTM cell.

        Args:
          num_units: int, The number of units in the LSTM cell.
          forget_bias: float, The bias added to forget gates (see above).
            Must be set to `0.0` manually when restoring from
            CudnnLSTM-trained checkpoints.
          state_is_tuple: If True, accepted and returned states are 2-tuples
            of the `c_state` and `m_state`. If False, they are concatenated
            along the column axis. The latter behavior will soon be
            deprecated.
          activation: Activation function of the inner states. Default:
            `tanh`. It can also be a string that is within the Keras
            activation function names.
          reuse: (optional) Python boolean describing whether to reuse
            variables in an existing scope. If not `True`, and the existing
            scope already has the given variables, an error is raised.
          name: String, the name of the layer. Layers with the same name
            will share weights, but to avoid mistakes we require reuse=True
            in such cases.
          dtype: Default dtype of the layer (default of `None` means use the
            type of the first input). Required when `build` is called before
            `call`.
          **kwargs: Dict, keyword named properties for common layer
            attributes, like `trainable` etc when constructing the cell from
            configs of get_config(). When restoring from CudnnLSTM-trained
            checkpoints, must use `CudnnCompatibleLSTMCell` instead.
        """
        warnings.warn(
            "`tf.nn.rnn_cell.BasicLSTMCell` is deprecated and will be "
            "removed in a future version. This class "
            "is equivalent to `tf.keras.layers.LSTMCell`, "
            "and will be replaced by that in Tensorflow 2.0.",
            stacklevel=2,
        )
        super().__init__(_reuse=reuse, name=name, dtype=dtype, **kwargs)
        _check_supported_dtypes(self.dtype)
        if not state_is_tuple:
            logging.warning(
                "%s: Using a concatenated state is slower and will soon be "
                "deprecated. Use state_is_tuple=True.",
                self,
            )
        if tf.executing_eagerly() and tf.config.list_logical_devices("GPU"):
            logging.warning(
                "%s: Note that this cell is not optimized for performance. "
                "Please use tf.compat.v1.keras.layers.CuDNNLSTM for better "
                "performance on GPU.",
                self,
            )

        # Inputs must be 2-dimensional.
        self.input_spec = input_spec.InputSpec(ndim=2)

        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        if activation:
            self._activation = activations.get(activation)
        else:
            self._activation = tf.tanh

    @property
    def state_size(self):
        return (
            LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple
            else 2 * self._num_units
        )

    @property
    def output_size(self):
        return self._num_units

    @tf_utils.shape_type_conversion
    def build(self, inputs_shape):
        if inputs_shape[-1] is None:
            raise ValueError(
                "Expected inputs.shape[-1] to be known, "
                f"received shape: {inputs_shape}"
            )
        _check_supported_dtypes(self.dtype)
        input_depth = inputs_shape[-1]
        h_depth = self._num_units
        self._kernel = self.add_weight(
            _WEIGHTS_VARIABLE_NAME,
            shape=[input_depth + h_depth, 4 * self._num_units],
        )
        self._bias = self.add_weight(
            _BIAS_VARIABLE_NAME,
            shape=[4 * self._num_units],
            initializer=tf.compat.v1.zeros_initializer(dtype=self.dtype),
        )

        self.built = True

    def call(self, inputs, state):
        """Long short-term memory cell (LSTM).

        Args:
          inputs: `2-D` tensor with shape `[batch_size, input_size]`.
          state: An `LSTMStateTuple` of state tensors, each shaped
            `[batch_size, num_units]`, if `state_is_tuple` has been set to
            `True`. Otherwise, a `Tensor` shaped
            `[batch_size, 2 * num_units]`.

        Returns:
          A pair containing the new hidden state, and the new state (either a
          `LSTMStateTuple` or a concatenated state, depending on
          `state_is_tuple`).
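
        In symbols, matching the gate order `i, j, f, o` used below, the
        update computed by this method is:

        ```python
        new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * activation(j)
        new_h = activation(new_c) * sigmoid(o)
        ```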
"""
|
||
|
        _check_rnn_cell_input_dtypes([inputs, state])

        sigmoid = tf.sigmoid
        one = tf.constant(1, dtype=tf.int32)
        # Parameters of gates are concatenated into one multiply for
        # efficiency.
        if self._state_is_tuple:
            c, h = state
        else:
            c, h = tf.split(value=state, num_or_size_splits=2, axis=one)

        gate_inputs = tf.matmul(tf.concat([inputs, h], 1), self._kernel)
        gate_inputs = tf.nn.bias_add(gate_inputs, self._bias)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = tf.split(value=gate_inputs, num_or_size_splits=4, axis=one)

        forget_bias_tensor = tf.constant(self._forget_bias, dtype=f.dtype)
        # Note that using `add` and `multiply` instead of `+` and `*` gives a
        # performance improvement, so we use them at the cost of readability.
        add = tf.add
        multiply = tf.multiply
        new_c = add(
            multiply(c, sigmoid(add(f, forget_bias_tensor))),
            multiply(sigmoid(i), self._activation(j)),
        )
        new_h = multiply(self._activation(new_c), sigmoid(o))

        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = tf.concat([new_c, new_h], 1)
        return new_h, new_state

    def get_config(self):
        config = {
            "num_units": self._num_units,
            "forget_bias": self._forget_bias,
            "state_is_tuple": self._state_is_tuple,
            "activation": activations.serialize(self._activation),
            "reuse": self._reuse,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMCell"])
@tf_export(v1=["nn.rnn_cell.LSTMCell"])
class LSTMCell(LayerRNNCell):
    """Long short-term memory unit (LSTM) recurrent network cell.

    The default non-peephole implementation is based on (Gers et al., 1999).
    The peephole implementation is based on (Sak et al., 2014).

    The class uses optional peep-hole connections, optional cell clipping,
    and an optional projection layer.

    Note that this cell is not optimized for performance. Please use
    `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU, or
    `tf.raw_ops.LSTMBlockCell` for better performance on CPU.

    References:
      Long short-term memory recurrent neural network architectures for
      large scale acoustic modeling:
        [Sak et al., 2014]
        (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html)
        ([pdf]
        (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf))
      Learning to forget:
        [Gers et al., 1999]
        (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218)
        ([pdf](https://arxiv.org/pdf/1409.2329.pdf))
      Long Short-Term Memory:
        [Hochreiter et al., 1997]
        (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735)
        ([pdf](http://ml.jku.at/publications/older/3504.pdf))
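
    Example (illustrative sketch; the shapes are made up and TF1-style
    `tf.compat.v1` usage is assumed):

    ```python
    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=4, num_proj=2)
    inputs = tf.ones([2, 3])  # [batch_size, input_size]
    state = cell.zero_state(batch_size=2, dtype=tf.float32)
    output, new_state = cell(inputs, state)  # output shaped [2, 2] (num_proj)
    ```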
"""
|
||
|
|
||
|
    def __init__(
        self,
        num_units,
        use_peepholes=False,
        cell_clip=None,
        initializer=None,
        num_proj=None,
        proj_clip=None,
        num_unit_shards=None,
        num_proj_shards=None,
        forget_bias=1.0,
        state_is_tuple=True,
        activation=None,
        reuse=None,
        name=None,
        dtype=None,
        **kwargs,
    ):
        """Initialize the parameters for an LSTM cell.

        Args:
          num_units: int, The number of units in the LSTM cell.
          use_peepholes: bool, set True to enable diagonal/peephole
            connections.
          cell_clip: (optional) A float value, if provided the cell state is
            clipped by this value prior to the cell output activation.
          initializer: (optional) The initializer to use for the weight and
            projection matrices.
          num_proj: (optional) int, The output dimensionality for the
            projection matrices. If None, no projection is performed.
          proj_clip: (optional) A float value. If `num_proj > 0` and
            `proj_clip` is provided, then the projected values are clipped
            elementwise to within `[-proj_clip, proj_clip]`.
          num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a
            variable_scope partitioner instead.
          num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a
            variable_scope partitioner instead.
          forget_bias: Biases of the forget gate are initialized by default
            to 1 in order to reduce the scale of forgetting at the beginning
            of the training. Must be set manually to `0.0` when restoring
            from CudnnLSTM trained checkpoints.
          state_is_tuple: If True, accepted and returned states are 2-tuples
            of the `c_state` and `m_state`. If False, they are concatenated
            along the column axis. This latter behavior will soon be
            deprecated.
          activation: Activation function of the inner states. Default:
            `tanh`. It can also be a string that is within the Keras
            activation function names.
          reuse: (optional) Python boolean describing whether to reuse
            variables in an existing scope. If not `True`, and the existing
            scope already has the given variables, an error is raised.
          name: String, the name of the layer. Layers with the same name
            will share weights, but to avoid mistakes we require reuse=True
            in such cases.
          dtype: Default dtype of the layer (default of `None` means use the
            type of the first input). Required when `build` is called before
            `call`.
          **kwargs: Dict, keyword named properties for common layer
            attributes, like `trainable` etc when constructing the cell from
            configs of get_config(). When restoring from CudnnLSTM-trained
            checkpoints, use `CudnnCompatibleLSTMCell` instead.
        """
        warnings.warn(
            "`tf.nn.rnn_cell.LSTMCell` is deprecated and will be "
            "removed in a future version. This class "
            "is equivalent to `tf.keras.layers.LSTMCell`, "
            "and will be replaced by that in Tensorflow 2.0.",
            stacklevel=2,
        )
        super().__init__(_reuse=reuse, name=name, dtype=dtype, **kwargs)
        _check_supported_dtypes(self.dtype)
        if not state_is_tuple:
            logging.warning(
                "%s: Using a concatenated state is slower and will soon be "
                "deprecated. Use state_is_tuple=True.",
                self,
            )
        if num_unit_shards is not None or num_proj_shards is not None:
            logging.warning(
                "%s: The num_unit_shards and proj_unit_shards parameters are "
                "deprecated and will be removed in Jan 2017. "
                "Use a variable scope with a partitioner instead.",
                self,
            )
        if tf.executing_eagerly() and tf.config.list_logical_devices("GPU"):
            logging.warning(
                "%s: Note that this cell is not optimized for performance. "
                "Please use tf.compat.v1.keras.layers.CuDNNLSTM for better "
                "performance on GPU.",
                self,
            )

        # Inputs must be 2-dimensional.
        self.input_spec = input_spec.InputSpec(ndim=2)

        self._num_units = num_units
        self._use_peepholes = use_peepholes
        self._cell_clip = cell_clip
        self._initializer = initializers.get(initializer)
        self._num_proj = num_proj
        self._proj_clip = proj_clip
        self._num_unit_shards = num_unit_shards
        self._num_proj_shards = num_proj_shards
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        if activation:
            self._activation = activations.get(activation)
        else:
            self._activation = tf.tanh

        if num_proj:
            self._state_size = (
                LSTMStateTuple(num_units, num_proj)
                if state_is_tuple
                else num_units + num_proj
            )
            self._output_size = num_proj
        else:
            self._state_size = (
                LSTMStateTuple(num_units, num_units)
                if state_is_tuple
                else 2 * num_units
            )
            self._output_size = num_units

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    @tf_utils.shape_type_conversion
    def build(self, inputs_shape):
        if inputs_shape[-1] is None:
            raise ValueError(
                "Expected inputs.shape[-1] to be known, "
                f"received shape: {inputs_shape}"
            )
        _check_supported_dtypes(self.dtype)
        input_depth = inputs_shape[-1]
        h_depth = self._num_units if self._num_proj is None else self._num_proj
        maybe_partitioner = (
            tf.compat.v1.fixed_size_partitioner(self._num_unit_shards)
            if self._num_unit_shards is not None
            else None
        )
        self._kernel = self.add_weight(
            _WEIGHTS_VARIABLE_NAME,
            shape=[input_depth + h_depth, 4 * self._num_units],
            initializer=self._initializer,
            partitioner=maybe_partitioner,
        )
        if self.dtype is None:
            initializer = tf.compat.v1.zeros_initializer
        else:
            initializer = tf.compat.v1.zeros_initializer(dtype=self.dtype)
        self._bias = self.add_weight(
            _BIAS_VARIABLE_NAME,
            shape=[4 * self._num_units],
            initializer=initializer,
        )
        if self._use_peepholes:
            self._w_f_diag = self.add_weight(
                "w_f_diag",
                shape=[self._num_units],
                initializer=self._initializer,
            )
            self._w_i_diag = self.add_weight(
                "w_i_diag",
                shape=[self._num_units],
                initializer=self._initializer,
            )
            self._w_o_diag = self.add_weight(
                "w_o_diag",
                shape=[self._num_units],
                initializer=self._initializer,
            )

        if self._num_proj is not None:
            maybe_proj_partitioner = (
                tf.compat.v1.fixed_size_partitioner(self._num_proj_shards)
                if self._num_proj_shards is not None
                else None
            )
            self._proj_kernel = self.add_weight(
                f"projection/{_WEIGHTS_VARIABLE_NAME}",
                shape=[self._num_units, self._num_proj],
                initializer=self._initializer,
                partitioner=maybe_proj_partitioner,
            )

        self.built = True

    def call(self, inputs, state):
        """Run one step of LSTM.

        Args:
          inputs: input Tensor, must be 2-D, `[batch, input_size]`.
          state: if `state_is_tuple` is False, this must be a state Tensor,
            `2-D, [batch, state_size]`. If `state_is_tuple` is True, this
            must be a tuple of state Tensors, both `2-D`, with column sizes
            `c_state` and `m_state`.

        Returns:
          A tuple containing:

          - A `2-D, [batch, output_dim]`, Tensor representing the output of
            the LSTM after reading `inputs` when previous state was `state`.
            Here output_dim is:
              num_proj if num_proj was set,
              num_units otherwise.
          - Tensor(s) representing the new state of LSTM after reading
            `inputs` when the previous state was `state`. Same type and
            shape(s) as `state`.

        Raises:
          ValueError: If input size cannot be inferred from inputs via
            static shape inference.
        """
        _check_rnn_cell_input_dtypes([inputs, state])

        num_proj = self._num_units if self._num_proj is None else self._num_proj
        sigmoid = tf.sigmoid

        if self._state_is_tuple:
            (c_prev, m_prev) = state
        else:
            c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
            m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])

        input_size = inputs.get_shape().with_rank(2).dims[1].value
        if input_size is None:
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]. "
                f"Received input shape: {inputs.get_shape()}"
            )

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        lstm_matrix = tf.matmul(tf.concat([inputs, m_prev], 1), self._kernel)
        lstm_matrix = tf.nn.bias_add(lstm_matrix, self._bias)

        i, j, f, o = tf.split(value=lstm_matrix, num_or_size_splits=4, axis=1)
        # Diagonal connections
        if self._use_peepholes:
            c = sigmoid(
                f + self._forget_bias + self._w_f_diag * c_prev
            ) * c_prev + sigmoid(
                i + self._w_i_diag * c_prev
            ) * self._activation(
                j
            )
        else:
            c = sigmoid(f + self._forget_bias) * c_prev + sigmoid(
                i
            ) * self._activation(j)

        if self._cell_clip is not None:
            c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)

        if self._use_peepholes:
            m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
        else:
            m = sigmoid(o) * self._activation(c)

        if self._num_proj is not None:
            m = tf.matmul(m, self._proj_kernel)

            if self._proj_clip is not None:
                m = tf.clip_by_value(m, -self._proj_clip, self._proj_clip)

        new_state = (
            LSTMStateTuple(c, m)
            if self._state_is_tuple
            else tf.concat([c, m], 1)
        )
        return m, new_state

    def get_config(self):
        config = {
            "num_units": self._num_units,
            "use_peepholes": self._use_peepholes,
            "cell_clip": self._cell_clip,
            "initializer": initializers.serialize(self._initializer),
            "num_proj": self._num_proj,
            "proj_clip": self._proj_clip,
            "num_unit_shards": self._num_unit_shards,
            "num_proj_shards": self._num_proj_shards,
            "forget_bias": self._forget_bias,
            "state_is_tuple": self._state_is_tuple,
            "activation": activations.serialize(self._activation),
            "reuse": self._reuse,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.MultiRNNCell"])
@tf_export(v1=["nn.rnn_cell.MultiRNNCell"])
class MultiRNNCell(RNNCell):
    """RNN cell composed sequentially of multiple simple cells.

    Example:

    ```python
    num_units = [128, 64]
    cells = [BasicLSTMCell(num_units=n) for n in num_units]
    stacked_rnn_cell = MultiRNNCell(cells)
    ```
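
    The stacked cell is itself an `RNNCell`, so it can be stepped like any
    single cell (an illustrative sketch; the shapes below are made up):

    ```python
    inputs = tf.ones([2, 32])  # [batch_size, input_size]
    state = stacked_rnn_cell.zero_state(batch_size=2, dtype=tf.float32)
    output, new_state = stacked_rnn_cell(inputs, state)  # output: [2, 64]
    ```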
"""
|
||
|
|
||
|
    def __init__(self, cells, state_is_tuple=True):
        """Create an RNN cell composed sequentially of a number of RNNCells.

        Args:
          cells: list of RNNCells that will be composed in this order.
          state_is_tuple: If True, accepted and returned states are n-tuples,
            where `n = len(cells)`. If False, the states are all concatenated
            along the column axis. This latter behavior will soon be
            deprecated.

        Raises:
          ValueError: if cells is empty (not allowed), or at least one of the
            cells returns a state tuple but the flag `state_is_tuple` is
            `False`.
        """
        logging.warning(
            "`tf.nn.rnn_cell.MultiRNNCell` is deprecated. This class "
            "is equivalent to `tf.keras.layers.StackedRNNCells`, "
            "and will be replaced by that in Tensorflow 2.0."
        )
        super().__init__()
        if not cells:
            raise ValueError("Must specify at least one cell for MultiRNNCell.")
        if not tf.nest.is_nested(cells):
            raise TypeError(
                f"cells must be a list or tuple, but received: {cells}."
            )

        if len(set(id(cell) for cell in cells)) < len(cells):
            logging.log_first_n(
                logging.WARN,
                "At least two cells provided to MultiRNNCell "
                "are the same object and will share weights.",
                1,
            )

        self._cells = cells
        for cell_number, cell in enumerate(self._cells):
            # Add Trackable dependencies on these cells so their variables
            # get saved with this object when using object-based saving.
            if isinstance(cell, tf.__internal__.tracking.Trackable):
                # TODO(allenl): Track down non-Trackable callers.
                self._track_trackable(cell, name="cell-%d" % (cell_number,))
        self._state_is_tuple = state_is_tuple
        if not state_is_tuple:
            if any(tf.nest.is_nested(c.state_size) for c in self._cells):
                raise ValueError(
                    "Some cells return tuples of states, but the flag "
                    "state_is_tuple is not set. "
                    f"State sizes are: {[c.state_size for c in self._cells]}"
                )

    @property
    def state_size(self):
        if self._state_is_tuple:
            return tuple(cell.state_size for cell in self._cells)
        else:
            return sum(cell.state_size for cell in self._cells)

    @property
    def output_size(self):
        return self._cells[-1].output_size

    def zero_state(self, batch_size, dtype):
        with backend.name_scope(type(self).__name__ + "ZeroState"):
            if self._state_is_tuple:
                return tuple(
                    cell.zero_state(batch_size, dtype) for cell in self._cells
                )
            else:
                # We know here that state_size of each cell is not a tuple
                # and presumably does not contain TensorArrays or anything
                # else fancy
                return super().zero_state(batch_size, dtype)

    @property
    def trainable_weights(self):
        if not self.trainable:
            return []
        weights = []
        for cell in self._cells:
            if isinstance(cell, base_layer.Layer):
                weights += cell.trainable_weights
        return weights

    @property
    def non_trainable_weights(self):
        weights = []
        for cell in self._cells:
            if isinstance(cell, base_layer.Layer):
                weights += cell.non_trainable_weights
        if not self.trainable:
            trainable_weights = []
            for cell in self._cells:
                if isinstance(cell, base_layer.Layer):
                    trainable_weights += cell.trainable_weights
            return trainable_weights + weights
        return weights

    def call(self, inputs, state):
        """Run this multi-layer cell on inputs, starting from state."""
        cur_state_pos = 0
        cur_inp = inputs
        new_states = []
        for i, cell in enumerate(self._cells):
            with tf.compat.v1.variable_scope("cell_%d" % i):
                if self._state_is_tuple:
                    if not tf.nest.is_nested(state):
                        raise ValueError(
                            "Expected state to be a tuple of length "
                            f"{len(self.state_size)}"
                            f", but received: {state}"
                        )
                    cur_state = state[i]
                else:
                    cur_state = tf.slice(
                        state, [0, cur_state_pos], [-1, cell.state_size]
                    )
                    cur_state_pos += cell.state_size
                cur_inp, new_state = cell(cur_inp, cur_state)
                new_states.append(new_state)

        new_states = (
            tuple(new_states)
            if self._state_is_tuple
            else tf.concat(new_states, 1)
        )

        return cur_inp, new_states


def _check_rnn_cell_input_dtypes(inputs):
    """Check whether the input tensors have supported dtypes.

    Default RNN cells only support floating point and complex dtypes, since
    the activation functions (tanh and sigmoid) only allow those types. This
    function raises a proper error message if an input is not of a supported
    type.

    Args:
      inputs: tensor or nested structure of tensors that are fed to an RNN
        cell as input or state.

    Raises:
      ValueError: if any of the input tensors has a dtype that is not float
        or complex.
    """
    for t in tf.nest.flatten(inputs):
        _check_supported_dtypes(t.dtype)


def _check_supported_dtypes(dtype):
    if dtype is None:
        return
    dtype = tf.as_dtype(dtype)
    if not (dtype.is_floating or dtype.is_complex):
        raise ValueError(
            "RNN cell only supports floating point inputs, "
            f"but received dtype: {dtype}"
        )