Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/layers/rnn/base_rnn.py

980 lines
41 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for recurrent layers."""
import collections
import numpy as np
import tensorflow.compat.v2 as tf
from keras import backend
from keras.engine import base_layer
from keras.engine.input_spec import InputSpec
from keras.layers.rnn import rnn_utils
from keras.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
from keras.layers.rnn.stacked_rnn_cells import StackedRNNCells
from keras.saving.legacy import serialization
from keras.saving.legacy.saved_model import layer_serialization
from keras.utils import generic_utils
# isort: off
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
@keras_export("keras.layers.RNN")
class RNN(base_layer.Layer):
"""Base class for recurrent layers.
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
Args:
cell: A RNN cell instance or a list of RNN cell instances.
A RNN cell is a class that has:
- A `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`. The call method of the
cell can also take the optional argument `constants`, see
section "Note on passing external constants" below.
- A `state_size` attribute. This can be a single integer
(single state) in which case it is the size of the recurrent
state. This can also be a list/tuple of integers (one size per state).
The `state_size` can also be TensorShape or tuple/list of
TensorShape, to represent high dimension state.
- A `output_size` attribute. This can be a single integer or a
TensorShape, which represent the shape of the output. For backward
compatible reason, if this attribute is not available for the
cell, the value will be inferred by the first element of the
`state_size`.
- A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
method that creates a tensor meant to be fed to `call()` as the
initial state, if the user didn't specify any initial state via other
means. The returned initial state should have a shape of
[batch_size, cell.state_size]. The cell might choose to create a
tensor full of zeros, or full of other values based on the cell's
implementation.
`inputs` is the input tensor to the RNN layer, which should
contain the batch size as its shape[0], and also dtype. Note that
the shape[0] might be `None` during the graph construction. Either
the `inputs` or the pair of `batch_size` and `dtype` are provided.
`batch_size` is a scalar tensor that represents the batch size
of the inputs. `dtype` is `tf.DType` that represents the dtype of
the inputs.
For backward compatibility, if this method is not implemented
by the cell, the RNN layer will create a zero filled tensor with the
size of [batch_size, cell.state_size].
In the case that `cell` is a list of RNN cell instances, the cells
will be stacked on top of each other in the RNN, resulting in an
efficient stacked RNN.
return_sequences: Boolean (default `False`). Whether to return the last
output in the output sequence, or the full sequence.
return_state: Boolean (default `False`). Whether to return the last state
in addition to the output.
go_backwards: Boolean (default `False`).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default `False`). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default `False`).
If True, the network will be unrolled, else a symbolic loop will be
used. Unrolling can speed-up a RNN, although it tends to be more
memory-intensive. Unrolling is only suitable for short sequences.
time_major: The shape format of the `inputs` and `outputs` tensors.
If True, the inputs and outputs will be in shape
`(timesteps, batch, ...)`, whereas in the False case, it will be
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of the
RNN calculation. However, most TensorFlow data is batch-major, so by
default this function accepts input and emits output in batch-major
form.
zero_output_for_mask: Boolean (default `False`).
Whether the output should use zeros for the masked timesteps. Note that
this field is only used when `return_sequences` is True and mask is
provided. It can useful if you want to reuse the raw output sequence of
the RNN without interference from the masked timesteps, eg, merging
bidirectional RNNs.
Call arguments:
inputs: Input tensor.
mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether
a given timestep should be masked. An individual `True` entry indicates
that the corresponding timestep should be utilized, while a `False`
entry indicates that the corresponding timestep should be ignored.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is for use with cells that use dropout.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
constants: List of constant tensors to be passed to the cell at each
timestep.
Input shape:
N-D tensor with shape `[batch_size, timesteps, ...]` or
`[timesteps, batch_size, ...]` when time_major is True.
Output shape:
- If `return_state`: a list of tensors. The first tensor is
the output. The remaining tensors are the last states,
each with shape `[batch_size, state_size]`, where `state_size` could
be a high dimension tensor shape.
- If `return_sequences`: N-D tensor with shape
`[batch_size, timesteps, output_size]`, where `output_size` could
be a high dimension tensor shape, or
`[timesteps, batch_size, output_size]` when `time_major` is True.
- Else, N-D tensor with shape `[batch_size, output_size]`, where
`output_size` could be a high dimension tensor shape.
Masking:
This layer supports masking for input data with a variable number
of timesteps. To introduce masks to your data,
use an [tf.keras.layers.Embedding] layer with the `mask_zero` parameter
set to `True`.
Note on using statefulness in RNNs:
You can set RNN layers to be 'stateful', which means that the states
computed for the samples in one batch will be reused as initial states
for the samples in the next batch. This assumes a one-to-one mapping
between samples in different successive batches.
To enable statefulness:
- Specify `stateful=True` in the layer constructor.
- Specify a fixed batch size for your model, by passing
If sequential model:
`batch_input_shape=(...)` to the first layer in your model.
Else for functional model with 1 or more Input layers:
`batch_shape=(...)` to all the first layers in your model.
This is the expected shape of your inputs
*including the batch size*.
It should be a tuple of integers, e.g. `(32, 10, 100)`.
- Specify `shuffle=False` when calling `fit()`.
To reset the states of your model, call `.reset_states()` on either
a specific layer, or on your entire model.
Note on specifying the initial state of RNNs:
You can specify the initial state of RNN layers symbolically by
calling them with the keyword argument `initial_state`. The value of
`initial_state` should be a tensor or list of tensors representing
the initial state of the RNN layer.
You can specify the initial state of RNN layers numerically by
calling `reset_states` with the keyword argument `states`. The value of
`states` should be a numpy array or list of numpy arrays representing
the initial state of the RNN layer.
Note on passing external constants to RNNs:
You can pass "external" constants to the cell using the `constants`
keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
requires that the `cell.call` method accepts the same keyword argument
`constants`. Such constants can be used to condition the cell
transformation on additional static inputs (not changing over time),
a.k.a. an attention mechanism.
Examples:
```python
from keras.layers import RNN
from keras import backend
# First, let's define a RNN Cell, as a layer subclass.
class MinimalRNNCell(keras.layers.Layer):
def __init__(self, units, **kwargs):
self.units = units
self.state_size = units
super(MinimalRNNCell, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
prev_output = states[0]
h = backend.dot(inputs, self.kernel)
output = h + backend.dot(prev_output, self.recurrent_kernel)
return output, [output]
# Let's use this cell in a RNN layer:
cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)
# Here's how to use the cell to build a stacked RNN:
cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras.Input((None, 5))
layer = RNN(cells)
y = layer(x)
```
"""
def __init__(
self,
cell,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
time_major=False,
**kwargs,
):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
if "call" not in dir(cell):
raise ValueError(
"Argument `cell` should have a `call` method. "
f"The RNN was passed: cell={cell}"
)
if "state_size" not in dir(cell):
raise ValueError(
"The RNN cell should have a `state_size` attribute "
"(tuple of integers, one integer per RNN state). "
f"Received: cell={cell}"
)
# If True, the output for masked timestep will be zeros, whereas in the
# False case, output from previous timestep is returned for masked
# timestep.
self.zero_output_for_mask = kwargs.pop("zero_output_for_mask", False)
if "input_shape" not in kwargs and (
"input_dim" in kwargs or "input_length" in kwargs
):
input_shape = (
kwargs.pop("input_length", None),
kwargs.pop("input_dim", None),
)
kwargs["input_shape"] = input_shape
super().__init__(**kwargs)
self.cell = cell
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
self.time_major = time_major
self.supports_masking = True
# The input shape is unknown yet, it could have nested tensor inputs,
# and the input spec will be the list of specs for nested inputs, the
# structure of the input_spec will be the same as the input.
self.input_spec = None
self.state_spec = None
self._states = None
self.constants_spec = None
self._num_constants = 0
if stateful:
if tf.distribute.has_strategy():
raise ValueError(
"Stateful RNNs (created with `stateful=True`) "
"are not yet supported with tf.distribute.Strategy."
)
@property
def _use_input_spec_as_call_signature(self):
if self.unroll:
# When the RNN layer is unrolled, the time step shape cannot be
# unknown. The input spec does not define the time step (because
# this layer can be called with any time step value, as long as it
# is not None), so it cannot be used as the call function signature
# when saving to SavedModel.
return False
return super()._use_input_spec_as_call_signature
@property
def states(self):
if self._states is None:
state = tf.nest.map_structure(lambda _: None, self.cell.state_size)
return state if tf.nest.is_nested(self.cell.state_size) else [state]
return self._states
@states.setter
# Automatic tracking catches "self._states" which adds an extra weight and
# breaks HDF5 checkpoints.
@tf.__internal__.tracking.no_automatic_dependency_tracking
def states(self, states):
self._states = states
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
# Check whether the input shape contains any nested shapes. It could be
# (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
# numpy inputs.
try:
input_shape = tf.TensorShape(input_shape)
except (ValueError, TypeError):
# A nested tensor input
input_shape = tf.nest.flatten(input_shape)[0]
batch = input_shape[0]
time_step = input_shape[1]
if self.time_major:
batch, time_step = time_step, batch
if rnn_utils.is_multiple_state(self.cell.state_size):
state_size = self.cell.state_size
else:
state_size = [self.cell.state_size]
def _get_output_shape(flat_output_size):
output_dim = tf.TensorShape(flat_output_size).as_list()
if self.return_sequences:
if self.time_major:
output_shape = tf.TensorShape(
[time_step, batch] + output_dim
)
else:
output_shape = tf.TensorShape(
[batch, time_step] + output_dim
)
else:
output_shape = tf.TensorShape([batch] + output_dim)
return output_shape
if getattr(self.cell, "output_size", None) is not None:
# cell.output_size could be nested structure.
output_shape = tf.nest.flatten(
tf.nest.map_structure(_get_output_shape, self.cell.output_size)
)
output_shape = (
output_shape[0] if len(output_shape) == 1 else output_shape
)
else:
# Note that state_size[0] could be a tensor_shape or int.
output_shape = _get_output_shape(state_size[0])
if self.return_state:
def _get_state_shape(flat_state):
state_shape = [batch] + tf.TensorShape(flat_state).as_list()
return tf.TensorShape(state_shape)
state_shape = tf.nest.map_structure(_get_state_shape, state_size)
return generic_utils.to_list(output_shape) + tf.nest.flatten(
state_shape
)
else:
return output_shape
def compute_mask(self, inputs, mask):
# Time step masks must be the same for each input.
# This is because the mask for an RNN is of size [batch, time_steps, 1],
# and specifies which time steps should be skipped, and a time step
# must be skipped for all inputs.
# TODO(scottzhu): Should we accept multiple different masks?
mask = tf.nest.flatten(mask)[0]
output_mask = mask if self.return_sequences else None
if self.return_state:
state_mask = [None for _ in self.states]
return [output_mask] + state_mask
else:
return output_mask
def build(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
# The input_shape here could be a nest structure.
# do the tensor_shape to shapes here. The input could be single tensor,
# or a nested structure of tensors.
def get_input_spec(shape):
"""Convert input shape to InputSpec."""
if isinstance(shape, tf.TensorShape):
input_spec_shape = shape.as_list()
else:
input_spec_shape = list(shape)
batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
if not self.stateful:
input_spec_shape[batch_index] = None
input_spec_shape[time_step_index] = None
return InputSpec(shape=tuple(input_spec_shape))
def get_step_input_shape(shape):
if isinstance(shape, tf.TensorShape):
shape = tuple(shape.as_list())
# remove the timestep from the input_shape
return shape[1:] if self.time_major else (shape[0],) + shape[2:]
def get_state_spec(shape):
state_spec_shape = tf.TensorShape(shape).as_list()
# append batch dim
state_spec_shape = [None] + state_spec_shape
return InputSpec(shape=tuple(state_spec_shape))
# Check whether the input shape contains any nested shapes. It could be
# (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
# numpy inputs.
try:
input_shape = tf.TensorShape(input_shape)
except (ValueError, TypeError):
# A nested tensor input
pass
if not tf.nest.is_nested(input_shape):
# This indicates the there is only one input.
if self.input_spec is not None:
self.input_spec[0] = get_input_spec(input_shape)
else:
self.input_spec = [get_input_spec(input_shape)]
step_input_shape = get_step_input_shape(input_shape)
else:
if self.input_spec is not None:
self.input_spec[0] = tf.nest.map_structure(
get_input_spec, input_shape
)
else:
self.input_spec = generic_utils.to_list(
tf.nest.map_structure(get_input_spec, input_shape)
)
step_input_shape = tf.nest.map_structure(
get_step_input_shape, input_shape
)
# allow cell (if layer) to build before we set or validate state_spec.
if isinstance(self.cell, base_layer.Layer) and not self.cell.built:
with backend.name_scope(self.cell.name):
self.cell.build(step_input_shape)
self.cell.built = True
# set or validate state_spec
if rnn_utils.is_multiple_state(self.cell.state_size):
state_size = list(self.cell.state_size)
else:
state_size = [self.cell.state_size]
if self.state_spec is not None:
# initial_state was passed in call, check compatibility
self._validate_state_spec(state_size, self.state_spec)
else:
if tf.nest.is_nested(state_size):
self.state_spec = tf.nest.map_structure(
get_state_spec, state_size
)
else:
self.state_spec = [
InputSpec(shape=[None] + tf.TensorShape(dim).as_list())
for dim in state_size
]
# ensure the generated state_spec is correct.
self._validate_state_spec(state_size, self.state_spec)
if self.stateful:
self.reset_states()
super().build(input_shape)
@staticmethod
def _validate_state_spec(cell_state_sizes, init_state_specs):
"""Validate the state spec between the initial_state and the state_size.
Args:
cell_state_sizes: list, the `state_size` attribute from the cell.
init_state_specs: list, the `state_spec` from the initial_state that
is passed in `call()`.
Raises:
ValueError: When initial state spec is not compatible with the state
size.
"""
validation_error = ValueError(
"An `initial_state` was passed that is not compatible with "
"`cell.state_size`. Received `state_spec`={}; "
"however `cell.state_size` is "
"{}".format(init_state_specs, cell_state_sizes)
)
flat_cell_state_sizes = tf.nest.flatten(cell_state_sizes)
flat_state_specs = tf.nest.flatten(init_state_specs)
if len(flat_cell_state_sizes) != len(flat_state_specs):
raise validation_error
for cell_state_spec, cell_state_size in zip(
flat_state_specs, flat_cell_state_sizes
):
if not tf.TensorShape(
# Ignore the first axis for init_state which is for batch
cell_state_spec.shape[1:]
).is_compatible_with(tf.TensorShape(cell_state_size)):
raise validation_error
@doc_controls.do_not_doc_inheritable
def get_initial_state(self, inputs):
get_initial_state_fn = getattr(self.cell, "get_initial_state", None)
if tf.nest.is_nested(inputs):
# The input are nested sequences. Use the first element in the seq
# to get batch size and dtype.
inputs = tf.nest.flatten(inputs)[0]
input_shape = tf.shape(inputs)
batch_size = input_shape[1] if self.time_major else input_shape[0]
dtype = inputs.dtype
if get_initial_state_fn:
init_state = get_initial_state_fn(
inputs=None, batch_size=batch_size, dtype=dtype
)
else:
init_state = rnn_utils.generate_zero_filled_state(
batch_size, self.cell.state_size, dtype
)
# Keras RNN expect the states in a list, even if it's a single state
# tensor.
if not tf.nest.is_nested(init_state):
init_state = [init_state]
# Force the state to be a list in case it is a namedtuple eg
# LSTMStateTuple.
return list(init_state)
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = rnn_utils.standardize_args(
inputs, initial_state, constants, self._num_constants
)
if initial_state is None and constants is None:
return super().__call__(inputs, **kwargs)
# If any of `initial_state` or `constants` are specified and are Keras
# tensors, then add them to the inputs and temporarily modify the
# input_spec to include them.
additional_inputs = []
additional_specs = []
if initial_state is not None:
additional_inputs += initial_state
self.state_spec = tf.nest.map_structure(
lambda s: InputSpec(shape=backend.int_shape(s)), initial_state
)
additional_specs += self.state_spec
if constants is not None:
additional_inputs += constants
self.constants_spec = [
InputSpec(shape=backend.int_shape(constant))
for constant in constants
]
self._num_constants = len(constants)
additional_specs += self.constants_spec
# additional_inputs can be empty if initial_state or constants are
# provided but empty (e.g. the cell is stateless).
flat_additional_inputs = tf.nest.flatten(additional_inputs)
is_keras_tensor = (
backend.is_keras_tensor(flat_additional_inputs[0])
if flat_additional_inputs
else True
)
for tensor in flat_additional_inputs:
if backend.is_keras_tensor(tensor) != is_keras_tensor:
raise ValueError(
"The initial state or constants of an RNN layer cannot be "
"specified via a mix of Keras tensors and non-Keras "
'tensors (a "Keras tensor" is a tensor that was returned '
"by a Keras layer or by `Input` during Functional "
"model construction). Received: "
f"initial_state={initial_state}, constants={constants}"
)
if is_keras_tensor:
# Compute the full input spec, including state and constants
full_input = [inputs] + additional_inputs
if self.built:
# Keep the input_spec since it has been populated in build()
# method.
full_input_spec = self.input_spec + additional_specs
else:
# The original input_spec is None since there could be a nested
# tensor input. Update the input_spec to match the inputs.
full_input_spec = (
generic_utils.to_list(
tf.nest.map_structure(lambda _: None, inputs)
)
+ additional_specs
)
# Perform the call with temporarily replaced input_spec
self.input_spec = full_input_spec
output = super().__call__(full_input, **kwargs)
# Remove the additional_specs from input spec and keep the rest. It
# is important to keep since the input spec was populated by
# build(), and will be reused in the stateful=True.
self.input_spec = self.input_spec[: -len(additional_specs)]
return output
else:
if initial_state is not None:
kwargs["initial_state"] = initial_state
if constants is not None:
kwargs["constants"] = constants
return super().__call__(inputs, **kwargs)
def call(
self,
inputs,
mask=None,
training=None,
initial_state=None,
constants=None,
):
# The input should be dense, padded with zeros. If a ragged input is fed
# into the layer, it is padded and the row lengths are used for masking.
inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
is_ragged_input = row_lengths is not None
self._validate_args_if_ragged(is_ragged_input, mask)
inputs, initial_state, constants = self._process_inputs(
inputs, initial_state, constants
)
self._maybe_reset_cell_dropout_mask(self.cell)
if isinstance(self.cell, StackedRNNCells):
for cell in self.cell.cells:
self._maybe_reset_cell_dropout_mask(cell)
if mask is not None:
# Time step masks must be the same for each input.
# TODO(scottzhu): Should we accept multiple different masks?
mask = tf.nest.flatten(mask)[0]
if tf.nest.is_nested(inputs):
# In the case of nested input, use the first element for shape
# check.
input_shape = backend.int_shape(tf.nest.flatten(inputs)[0])
else:
input_shape = backend.int_shape(inputs)
timesteps = input_shape[0] if self.time_major else input_shape[1]
if self.unroll and timesteps is None:
raise ValueError(
"Cannot unroll a RNN if the "
"time dimension is undefined. \n"
"- If using a Sequential model, "
"specify the time dimension by passing "
"an `input_shape` or `batch_input_shape` "
"argument to your first layer. If your "
"first layer is an Embedding, you can "
"also use the `input_length` argument.\n"
"- If using the functional API, specify "
"the time dimension by passing a `shape` "
"or `batch_shape` argument to your Input layer."
)
kwargs = {}
if generic_utils.has_arg(self.cell.call, "training"):
kwargs["training"] = training
# TF RNN cells expect single tensor as state instead of list wrapped
# tensor.
is_tf_rnn_cell = getattr(self.cell, "_is_tf_rnn_cell", None) is not None
# Use the __call__ function for callable objects, eg layers, so that it
# will have the proper name scopes for the ops, etc.
cell_call_fn = (
self.cell.__call__ if callable(self.cell) else self.cell.call
)
if constants:
if not generic_utils.has_arg(self.cell.call, "constants"):
raise ValueError(
f"RNN cell {self.cell} does not support constants. "
f"Received: constants={constants}"
)
def step(inputs, states):
constants = states[-self._num_constants :]
states = states[: -self._num_constants]
states = (
states[0] if len(states) == 1 and is_tf_rnn_cell else states
)
output, new_states = cell_call_fn(
inputs, states, constants=constants, **kwargs
)
if not tf.nest.is_nested(new_states):
new_states = [new_states]
return output, new_states
else:
def step(inputs, states):
states = (
states[0] if len(states) == 1 and is_tf_rnn_cell else states
)
output, new_states = cell_call_fn(inputs, states, **kwargs)
if not tf.nest.is_nested(new_states):
new_states = [new_states]
return output, new_states
last_output, outputs, states = backend.rnn(
step,
inputs,
initial_state,
constants=constants,
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
input_length=row_lengths if row_lengths is not None else timesteps,
time_major=self.time_major,
zero_output_for_mask=self.zero_output_for_mask,
return_all_outputs=self.return_sequences,
)
if self.stateful:
updates = [
tf.compat.v1.assign(
self_state, tf.cast(state, self_state.dtype)
)
for self_state, state in zip(
tf.nest.flatten(self.states), tf.nest.flatten(states)
)
]
self.add_update(updates)
if self.return_sequences:
output = backend.maybe_convert_to_ragged(
is_ragged_input,
outputs,
row_lengths,
go_backwards=self.go_backwards,
)
else:
output = last_output
if self.return_state:
if not isinstance(states, (list, tuple)):
states = [states]
else:
states = list(states)
return generic_utils.to_list(output) + states
else:
return output
def _process_inputs(self, inputs, initial_state, constants):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if isinstance(inputs, collections.abc.Sequence) and not isinstance(
inputs, tuple
):
# get initial_state from full input spec
# as they could be copied to multiple GPU.
if not self._num_constants:
initial_state = inputs[1:]
else:
initial_state = inputs[1 : -self._num_constants]
constants = inputs[-self._num_constants :]
if len(initial_state) == 0:
initial_state = None
inputs = inputs[0]
if self.stateful:
if initial_state is not None:
# When layer is stateful and initial_state is provided, check if
# the recorded state is same as the default value (zeros). Use
# the recorded state if it is not same as the default.
non_zero_count = tf.add_n(
[
tf.math.count_nonzero(s)
for s in tf.nest.flatten(self.states)
]
)
# Set strict = True to keep the original structure of the state.
initial_state = tf.compat.v1.cond(
non_zero_count > 0,
true_fn=lambda: self.states,
false_fn=lambda: initial_state,
strict=True,
)
else:
initial_state = self.states
initial_state = tf.nest.map_structure(
# When the layer has a inferred dtype, use the dtype from the
# cell.
lambda v: tf.cast(
v, self.compute_dtype or self.cell.compute_dtype
),
initial_state,
)
elif initial_state is None:
initial_state = self.get_initial_state(inputs)
if len(initial_state) != len(self.states):
raise ValueError(
f"Layer has {len(self.states)} "
f"states but was passed {len(initial_state)} initial "
f"states. Received: initial_state={initial_state}"
)
return inputs, initial_state, constants
def _validate_args_if_ragged(self, is_ragged_input, mask):
if not is_ragged_input:
return
if mask is not None:
raise ValueError(
f"The mask that was passed in was {mask}, which "
"cannot be applied to RaggedTensor inputs. Please "
"make sure that there is no mask injected by upstream "
"layers."
)
if self.unroll:
raise ValueError(
"The input received contains RaggedTensors and does "
"not support unrolling. Disable unrolling by passing "
"`unroll=False` in the RNN Layer constructor."
)
def _maybe_reset_cell_dropout_mask(self, cell):
if isinstance(cell, DropoutRNNCellMixin):
cell.reset_dropout_mask()
cell.reset_recurrent_dropout_mask()
def reset_states(self, states=None):
"""Reset the recorded states for the stateful RNN layer.
Can only be used when RNN layer is constructed with `stateful` = `True`.
Args:
states: Numpy arrays that contains the value for the initial state,
which will be feed to cell at the first time step. When the value is
None, zero filled numpy array will be created based on the cell
state size.
Raises:
AttributeError: When the RNN layer is not stateful.
ValueError: When the batch size of the RNN layer is unknown.
ValueError: When the input numpy array is not compatible with the RNN
layer state, either size wise or dtype wise.
"""
if not self.stateful:
raise AttributeError("Layer must be stateful.")
spec_shape = None
if self.input_spec is not None:
spec_shape = tf.nest.flatten(self.input_spec[0])[0].shape
if spec_shape is None:
# It is possible to have spec shape to be None, eg when construct a
# RNN with a custom cell, or standard RNN layers (LSTM/GRU) which we
# only know it has 3 dim input, but not its full shape spec before
# build().
batch_size = None
else:
batch_size = spec_shape[1] if self.time_major else spec_shape[0]
if not batch_size:
raise ValueError(
"If a RNN is stateful, it needs to know "
"its batch size. Specify the batch size "
"of your input tensors: \n"
"- If using a Sequential model, "
"specify the batch size by passing "
"a `batch_input_shape` "
"argument to your first layer.\n"
"- If using the functional API, specify "
"the batch size by passing a "
"`batch_shape` argument to your Input layer."
)
# initialize state if None
if tf.nest.flatten(self.states)[0] is None:
if getattr(self.cell, "get_initial_state", None):
flat_init_state_values = tf.nest.flatten(
self.cell.get_initial_state(
inputs=None,
batch_size=batch_size,
# Use variable_dtype instead of compute_dtype, since the
# state is stored in a variable
dtype=self.variable_dtype or backend.floatx(),
)
)
else:
flat_init_state_values = tf.nest.flatten(
rnn_utils.generate_zero_filled_state(
batch_size,
self.cell.state_size,
self.variable_dtype or backend.floatx(),
)
)
flat_states_variables = tf.nest.map_structure(
backend.variable, flat_init_state_values
)
self.states = tf.nest.pack_sequence_as(
self.cell.state_size, flat_states_variables
)
if not tf.nest.is_nested(self.states):
self.states = [self.states]
elif states is None:
for state, size in zip(
tf.nest.flatten(self.states),
tf.nest.flatten(self.cell.state_size),
):
backend.set_value(
state,
np.zeros([batch_size] + tf.TensorShape(size).as_list()),
)
else:
flat_states = tf.nest.flatten(self.states)
flat_input_states = tf.nest.flatten(states)
if len(flat_input_states) != len(flat_states):
raise ValueError(
f"Layer {self.name} expects {len(flat_states)} "
f"states, but it received {len(flat_input_states)} "
f"state values. States received: {states}"
)
set_value_tuples = []
for i, (value, state) in enumerate(
zip(flat_input_states, flat_states)
):
if value.shape != state.shape:
raise ValueError(
f"State {i} is incompatible with layer {self.name}: "
f"expected shape={(batch_size, state)} "
f"but found shape={value.shape}"
)
set_value_tuples.append((state, value))
backend.batch_set_value(set_value_tuples)
def get_config(self):
config = {
"return_sequences": self.return_sequences,
"return_state": self.return_state,
"go_backwards": self.go_backwards,
"stateful": self.stateful,
"unroll": self.unroll,
"time_major": self.time_major,
}
if self._num_constants:
config["num_constants"] = self._num_constants
if self.zero_output_for_mask:
config["zero_output_for_mask"] = self.zero_output_for_mask
config["cell"] = serialization.serialize_keras_object(self.cell)
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
from keras.layers import deserialize as deserialize_layer
cell = deserialize_layer(
config.pop("cell"), custom_objects=custom_objects
)
num_constants = config.pop("num_constants", 0)
layer = cls(cell, **config)
layer._num_constants = num_constants
return layer
@property
def _trackable_saved_model_saver(self):
return layer_serialization.RNNSavedModelSaver(self)