# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bidirectional wrapper for RNNs."""


import copy

import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine.base_layer import Layer
from keras.engine.input_spec import InputSpec
from keras.layers.rnn import rnn_utils
from keras.layers.rnn.base_wrapper import Wrapper
from keras.saving.legacy import serialization
from keras.utils import generic_utils
from keras.utils import tf_inspect
from keras.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.Bidirectional")
class Bidirectional(Wrapper):
    """Bidirectional wrapper for RNNs.

    Args:
      layer: `keras.layers.RNN` instance, such as `keras.layers.LSTM` or
        `keras.layers.GRU`. It could also be a `keras.layers.Layer` instance
        that meets the following criteria:
        1. Be a sequence-processing layer (accepts 3D+ inputs).
        2. Have a `go_backwards`, `return_sequences` and `return_state`
          attribute (with the same semantics as for the `RNN` class).
        3. Have an `input_spec` attribute.
        4. Implement serialization via `get_config()` and `from_config()`.
        Note that the recommended way to create new RNN layers is to write a
        custom RNN cell and use it with `keras.layers.RNN`, instead of
        subclassing `keras.layers.Layer` directly.
        - When `return_sequences` is True, the output of the masked
          timestep will be zero regardless of the layer's original
          `zero_output_for_mask` value.
      merge_mode: Mode by which outputs of the forward and backward RNNs will
        be combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None,
        the outputs will not be combined, they will be returned as a list.
        Default value is 'concat'.
      backward_layer: Optional `keras.layers.RNN`, or `keras.layers.Layer`
        instance to be used to handle backwards input processing.
        If `backward_layer` is not provided, the layer instance passed as the
        `layer` argument will be used to generate the backward layer
        automatically.
        Note that the provided `backward_layer` layer should have properties
        matching those of the `layer` argument, in particular it should have
        the same values for `stateful`, `return_state`, `return_sequences`,
        etc. In addition, `backward_layer` and `layer` should have different
        `go_backwards` argument values.
        A `ValueError` will be raised if these requirements are not met.

    Call arguments:
      The call arguments for this layer are the same as those of the wrapped
      RNN layer.
      Beware that when passing the `initial_state` argument during the call
      of this layer, the first half in the list of elements in the
      `initial_state` list will be passed to the forward RNN call and the
      last half in the list of elements will be passed to the backward RNN
      call (see the last example below).

    Raises:
      ValueError:
        1. If `layer` or `backward_layer` is not a `Layer` instance.
        2. In case of invalid `merge_mode` argument.
        3. If `backward_layer` has mismatched properties compared to `layer`.

    Examples:

    ```python
    model = Sequential()
    model.add(Bidirectional(LSTM(10, return_sequences=True),
                            input_shape=(5, 10)))
    model.add(Bidirectional(LSTM(10)))
    model.add(Dense(5))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

    # With custom backward layer
    model = Sequential()
    forward_layer = LSTM(10, return_sequences=True)
    backward_layer = LSTM(10, activation='relu', return_sequences=True,
                          go_backwards=True)
    model.add(Bidirectional(forward_layer, backward_layer=backward_layer,
                            input_shape=(5, 10)))
    model.add(Dense(5))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    ```
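
    A sketch of passing `initial_state` explicitly (the state size of 32 here
    is only an assumption for illustration): the first half of the list is
    fed to the forward RNN and the second half to the backward RNN.

    ```python
    inputs = Input(shape=(5, 10))
    forward_h = Input(shape=(32,))
    forward_c = Input(shape=(32,))
    backward_h = Input(shape=(32,))
    backward_c = Input(shape=(32,))
    outputs = Bidirectional(LSTM(32))(
        inputs,
        initial_state=[forward_h, forward_c, backward_h, backward_c])
    ```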
    """

    def __init__(
        self,
        layer,
        merge_mode="concat",
        weights=None,
        backward_layer=None,
        **kwargs,
    ):
        if not isinstance(layer, Layer):
            raise ValueError(
                "Please initialize `Bidirectional` layer with a "
                f"`tf.keras.layers.Layer` instance. Received: {layer}"
            )
        if backward_layer is not None and not isinstance(
            backward_layer, Layer
        ):
            raise ValueError(
                "`backward_layer` needs to be a `tf.keras.layers.Layer` "
                f"instance. Received: {backward_layer}"
            )
        if merge_mode not in ["sum", "mul", "ave", "concat", None]:
            raise ValueError(
                f"Invalid merge mode. Received: {merge_mode}. "
                "Merge mode should be one of "
                '{"sum", "mul", "ave", "concat", None}'
            )
        # We don't want to track `layer` since we're already tracking the two
        # copies of it we actually run.
        self._setattr_tracking = False
        super().__init__(layer, **kwargs)
        self._setattr_tracking = True

        # Recreate the forward layer from the original layer config, so that
        # it will not carry over any state from the layer.
        self.forward_layer = self._recreate_layer_from_config(layer)

        if backward_layer is None:
            self.backward_layer = self._recreate_layer_from_config(
                layer, go_backwards=True
            )
        else:
            self.backward_layer = backward_layer

            # Keep the custom backward layer config, so that we can save it
            # later. The layer's name might be updated below with prefix
            # 'backward_', and we want to preserve the original config.
            self._backward_layer_config = (
                serialization.serialize_keras_object(backward_layer)
            )

        self.forward_layer._name = "forward_" + self.forward_layer.name
        self.backward_layer._name = "backward_" + self.backward_layer.name

        self._verify_layer_config()

        def force_zero_output_for_mask(layer):
            # Force the zero_output_for_mask to be True if returning
            # sequences.
            if getattr(layer, "zero_output_for_mask", None) is not None:
                layer.zero_output_for_mask = layer.return_sequences

        force_zero_output_for_mask(self.forward_layer)
        force_zero_output_for_mask(self.backward_layer)

        self.merge_mode = merge_mode
        if weights:
            nw = len(weights)
            self.forward_layer.initial_weights = weights[: nw // 2]
            self.backward_layer.initial_weights = weights[nw // 2 :]
        self.stateful = layer.stateful
        self.return_sequences = layer.return_sequences
        self.return_state = layer.return_state
        self.supports_masking = True
        self._trainable = kwargs.get("trainable", layer.trainable)
        self._num_constants = 0
        self.input_spec = layer.input_spec

    @property
    def _use_input_spec_as_call_signature(self):
        return self.layer._use_input_spec_as_call_signature

    def _verify_layer_config(self):
        """Ensure the forward and backward layers have valid common properties."""
        if self.forward_layer.go_backwards == self.backward_layer.go_backwards:
            raise ValueError(
                "Forward layer and backward layer should have different "
                "`go_backwards` value. "
                "forward_layer.go_backwards = "
                f"{self.forward_layer.go_backwards}, "
                "backward_layer.go_backwards = "
                f"{self.backward_layer.go_backwards}"
            )

        common_attributes = ("stateful", "return_sequences", "return_state")
        for a in common_attributes:
            forward_value = getattr(self.forward_layer, a)
            backward_value = getattr(self.backward_layer, a)
            if forward_value != backward_value:
                raise ValueError(
                    "Forward layer and backward layer are expected to have "
                    f'the same value for attribute "{a}", got '
                    f'"{forward_value}" for forward layer and '
                    f'"{backward_value}" for backward layer'
                )

    def _recreate_layer_from_config(self, layer, go_backwards=False):
        # When recreating the layer from its config, it is possible that the
        # layer is an RNN layer that contains custom cells. In this case we
        # inspect the layer and pass the custom cell class as part of the
        # `custom_objects` argument when calling `from_config`. See
        # https://github.com/tensorflow/tensorflow/issues/26581 for more
        # detail.
        config = layer.get_config()
        if go_backwards:
            config["go_backwards"] = not config["go_backwards"]
        if (
            "custom_objects"
            in tf_inspect.getfullargspec(layer.__class__.from_config).args
        ):
            custom_objects = {}
            cell = getattr(layer, "cell", None)
            if cell is not None:
                custom_objects[cell.__class__.__name__] = cell.__class__
                # For StackedRNNCells
                stacked_cells = getattr(cell, "cells", [])
                for c in stacked_cells:
                    custom_objects[c.__class__.__name__] = c.__class__
            return layer.__class__.from_config(
                config, custom_objects=custom_objects
            )
        else:
            return layer.__class__.from_config(config)
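
    # Illustrative note: if the wrapped layer is, say, `RNN(MyCustomCell(8))`
    # (a hypothetical custom cell class), `MyCustomCell` is collected into
    # `custom_objects` above so that `from_config` can rebuild the cell
    # without a global custom-object registration.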

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        output_shape = self.forward_layer.compute_output_shape(input_shape)
        if self.return_state:
            state_shape = tf_utils.convert_shapes(
                output_shape[1:], to_tuples=False
            )
            output_shape = tf_utils.convert_shapes(
                output_shape[0], to_tuples=False
            )
        else:
            output_shape = tf_utils.convert_shapes(
                output_shape, to_tuples=False
            )

        if self.merge_mode == "concat":
            output_shape = output_shape.as_list()
            output_shape[-1] *= 2
            output_shape = tf.TensorShape(output_shape)
        elif self.merge_mode is None:
            output_shape = [output_shape, copy.copy(output_shape)]

        if self.return_state:
            if self.merge_mode is None:
                return output_shape + state_shape + copy.copy(state_shape)
            return [output_shape] + state_shape + copy.copy(state_shape)
        return output_shape
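
    # Worked example (illustrative): with the default merge_mode="concat",
    # wrapping LSTM(16, return_sequences=True) around inputs of shape
    # (batch, 5, 10) yields an output shape of (batch, 5, 32), since the
    # forward and backward outputs are concatenated on the last axis
    # (16 * 2 = 32).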

    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        """`Bidirectional.__call__` implements the same API as the wrapped
        `RNN`."""
        inputs, initial_state, constants = rnn_utils.standardize_args(
            inputs, initial_state, constants, self._num_constants
        )

        if isinstance(inputs, list):
            if len(inputs) > 1:
                initial_state = inputs[1:]
            inputs = inputs[0]

        if initial_state is None and constants is None:
            return super().__call__(inputs, **kwargs)

        # Applies the same workaround as in `RNN.__call__`
        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            # Check if `initial_state` can be split into half
            num_states = len(initial_state)
            if num_states % 2 > 0:
                raise ValueError(
                    "When passing `initial_state` to a Bidirectional RNN, "
                    "the state should be a list containing the states of "
                    "the underlying RNNs. "
                    f"Received: {initial_state}"
                )

            kwargs["initial_state"] = initial_state
            additional_inputs += initial_state
            state_specs = tf.nest.map_structure(
                lambda state: InputSpec(shape=backend.int_shape(state)),
                initial_state,
            )
            self.forward_layer.state_spec = state_specs[: num_states // 2]
            self.backward_layer.state_spec = state_specs[num_states // 2 :]
            additional_specs += state_specs
        if constants is not None:
            kwargs["constants"] = constants
            additional_inputs += constants
            constants_spec = [
                InputSpec(shape=backend.int_shape(constant))
                for constant in constants
            ]
            self.forward_layer.constants_spec = constants_spec
            self.backward_layer.constants_spec = constants_spec
            additional_specs += constants_spec

            self._num_constants = len(constants)
            self.forward_layer._num_constants = self._num_constants
            self.backward_layer._num_constants = self._num_constants

        is_keras_tensor = backend.is_keras_tensor(
            tf.nest.flatten(additional_inputs)[0]
        )
        for tensor in tf.nest.flatten(additional_inputs):
            if backend.is_keras_tensor(tensor) != is_keras_tensor:
                raise ValueError(
                    "The initial state of a Bidirectional"
                    " layer cannot be specified with a mix of"
                    " Keras tensors and non-Keras tensors"
                    ' (a "Keras tensor" is a tensor that was'
                    " returned by a Keras layer, or by `Input`)"
                )

        if is_keras_tensor:
            # Compute the full input spec, including state
            full_input = [inputs] + additional_inputs
            # The original input_spec is None since there could be a nested
            # tensor input. Update the input_spec to match the inputs.
            full_input_spec = [
                None for _ in range(len(tf.nest.flatten(inputs)))
            ] + additional_specs
            # Removing kwargs since the values are passed with the input
            # list.
            kwargs["initial_state"] = None
            kwargs["constants"] = None

            # Perform the call with temporarily replaced input_spec
            original_input_spec = self.input_spec
            self.input_spec = full_input_spec
            output = super().__call__(full_input, **kwargs)
            self.input_spec = original_input_spec
            return output
        else:
            return super().__call__(inputs, **kwargs)

    def call(
        self,
        inputs,
        training=None,
        mask=None,
        initial_state=None,
        constants=None,
    ):
        """`Bidirectional.call` implements the same API as the wrapped
        `RNN`."""
        kwargs = {}
        if generic_utils.has_arg(self.layer.call, "training"):
            kwargs["training"] = training
        if generic_utils.has_arg(self.layer.call, "mask"):
            kwargs["mask"] = mask
        if generic_utils.has_arg(self.layer.call, "constants"):
            kwargs["constants"] = constants

        if generic_utils.has_arg(self.layer.call, "initial_state"):
            if isinstance(inputs, list) and len(inputs) > 1:
                # initial_states are keras tensors, which means they are
                # passed in together with inputs as a list. The
                # initial_states need to be split into a forward and a
                # backward section, and be fed to the layers accordingly.
                forward_inputs = [inputs[0]]
                backward_inputs = [inputs[0]]
                pivot = (len(inputs) - self._num_constants) // 2 + 1
                # add forward initial state
                forward_inputs += inputs[1:pivot]
                if not self._num_constants:
                    # add backward initial state
                    backward_inputs += inputs[pivot:]
                else:
                    # add backward initial state
                    backward_inputs += inputs[pivot : -self._num_constants]
                    # add constants for forward and backward layers
                    forward_inputs += inputs[-self._num_constants :]
                    backward_inputs += inputs[-self._num_constants :]
                forward_state, backward_state = None, None
                if "constants" in kwargs:
                    kwargs["constants"] = None
            elif initial_state is not None:
                # initial_states are not keras tensors, e.g. eager tensors
                # from a np array. They are only passed in from the kwarg
                # initial_state, and should be passed to the forward/backward
                # layer via the kwarg initial_state as well.
                forward_inputs, backward_inputs = inputs, inputs
                half = len(initial_state) // 2
                forward_state = initial_state[:half]
                backward_state = initial_state[half:]
            else:
                forward_inputs, backward_inputs = inputs, inputs
                forward_state, backward_state = None, None

            y = self.forward_layer(
                forward_inputs, initial_state=forward_state, **kwargs
            )
            y_rev = self.backward_layer(
                backward_inputs, initial_state=backward_state, **kwargs
            )
        else:
            y = self.forward_layer(inputs, **kwargs)
            y_rev = self.backward_layer(inputs, **kwargs)

        if self.return_state:
            states = y[1:] + y_rev[1:]
            y = y[0]
            y_rev = y_rev[0]

        if self.return_sequences:
            time_dim = (
                0 if getattr(self.forward_layer, "time_major", False) else 1
            )
            y_rev = backend.reverse(y_rev, time_dim)
        if self.merge_mode == "concat":
            output = backend.concatenate([y, y_rev])
        elif self.merge_mode == "sum":
            output = y + y_rev
        elif self.merge_mode == "ave":
            output = (y + y_rev) / 2
        elif self.merge_mode == "mul":
            output = y * y_rev
        elif self.merge_mode is None:
            output = [y, y_rev]
        else:
            raise ValueError(
                "Unrecognized value for `merge_mode`. "
                f"Received: {self.merge_mode}. "
                'Expected values are ["concat", "sum", "ave", "mul"]'
            )

        if self.return_state:
            if self.merge_mode is None:
                return output + states
            return [output] + states
        return output

    def reset_states(self):
        self.forward_layer.reset_states()
        self.backward_layer.reset_states()

    def build(self, input_shape):
        with backend.name_scope(self.forward_layer.name):
            self.forward_layer.build(input_shape)
        with backend.name_scope(self.backward_layer.name):
            self.backward_layer.build(input_shape)
        self.built = True

    def compute_mask(self, inputs, mask):
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_sequences:
            if not self.merge_mode:
                output_mask = [mask, mask]
            else:
                output_mask = mask
        else:
            output_mask = [None, None] if not self.merge_mode else None

        if self.return_state:
            states = self.forward_layer.states
            state_mask = [None for _ in states]
            if isinstance(output_mask, list):
                return output_mask + state_mask * 2
            return [output_mask] + state_mask * 2
        return output_mask

    @property
    def constraints(self):
        constraints = {}
        if hasattr(self.forward_layer, "constraints"):
            constraints.update(self.forward_layer.constraints)
            constraints.update(self.backward_layer.constraints)
        return constraints

    def get_config(self):
        config = {"merge_mode": self.merge_mode}
        if self._num_constants:
            config["num_constants"] = self._num_constants

        if hasattr(self, "_backward_layer_config"):
            config["backward_layer"] = self._backward_layer_config
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # Instead of updating the input, create a copy and use that.
        config = copy.deepcopy(config)
        num_constants = config.pop("num_constants", 0)
        # Handle forward layer instantiation (as would parent class).
        from keras.layers import deserialize as deserialize_layer

        config["layer"] = deserialize_layer(
            config["layer"], custom_objects=custom_objects
        )
        # Handle (optional) backward layer instantiation.
        backward_layer_config = config.pop("backward_layer", None)
        if backward_layer_config is not None:
            backward_layer = deserialize_layer(
                backward_layer_config, custom_objects=custom_objects
            )
            config["backward_layer"] = backward_layer
        # Instantiate the wrapper, adjust it and return it.
        layer = cls(**config)
        layer._num_constants = num_constants
        return layer