# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing RNN wrappers."""
# Note that all the APIs under this module are exported as tf.nn.*. This is
# because these APIs originated in tf.nn.rnn_cell_impl. They are ported here
# to avoid a cyclic dependency issue for serialization. These APIs will
# probably be deprecated and removed in the future, since similar APIs are
# available in the existing Keras RNN API.
import hashlib
import numbers
import sys
import types as python_types
import warnings
import tensorflow.compat.v2 as tf
from keras.layers.rnn import lstm
from keras.layers.rnn.abstract_rnn_cell import AbstractRNNCell
from keras.saving import serialization_lib
from keras.saving.legacy import serialization as legacy_serialization
from keras.utils import generic_utils
from keras.utils import tf_inspect
# isort: off
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.deprecation import deprecated
class _RNNCellWrapper(AbstractRNNCell):
"""Base class for cells wrappers V2 compatibility.
This class along with `rnn_cell_impl._RNNCellWrapperV1` allows to define
wrappers that are compatible with V1 and V2, and defines helper methods for
this purpose.
"""
def __init__(self, cell, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cell = cell
cell_call_spec = tf_inspect.getfullargspec(cell.call)
self._call_spec.expects_training_arg = (
"training" in cell_call_spec.args
) or (cell_call_spec.varkw is not None)
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Calls the wrapped cell and performs the wrapping logic.
This method is called from the wrapper's `call` or `__call__` methods.
Args:
inputs: A tensor with wrapped cell's input.
state: A tensor or tuple of tensors with wrapped cell's state.
cell_call_fn: Wrapped cell's method to use for step computation
(cell's `__call__` or 'call' method).
**kwargs: Additional arguments.
Returns:
A pair containing:
- Output: A tensor with cell's output.
- New state: A tensor or tuple of tensors with new wrapped cell's
state.
"""
raise NotImplementedError
def call(self, inputs, state, **kwargs):
"""Runs the RNN cell step computation.
When `call` is being used, we assume that the wrapper object has been
built, and therefore the wrapped cell has been built via its `build`
method and its `call` method can be used directly.
This allows using the wrapped cell and the non-wrapped cell
equivalently when using `call` and `build`.
Args:
inputs: A tensor with wrapped cell's input.
state: A tensor or tuple of tensors with wrapped cell's state.
**kwargs: Additional arguments passed to the wrapped cell's `call`.
Returns:
A pair containing:
- Output: A tensor with cell's output.
- New state: A tensor or tuple of tensors with new wrapped cell's
state.
"""
return self._call_wrapped_cell(
inputs, state, cell_call_fn=self.cell.call, **kwargs
)
def build(self, inputs_shape):
"""Builds the wrapped cell."""
self.cell.build(inputs_shape)
self.built = True
@property
def wrapped_cell(self):
return self.cell
@property
def state_size(self):
return self.cell.state_size
@property
def output_size(self):
return self.cell.output_size
def zero_state(self, batch_size, dtype):
with tf.name_scope(type(self).__name__ + "ZeroState"):
return self.cell.zero_state(batch_size, dtype)
def get_config(self):
config = {
"cell": {
"class_name": self.cell.__class__.__name__,
"config": self.cell.get_config(),
},
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
config = config.copy()
from keras.layers.serialization import deserialize as deserialize_layer
cell = deserialize_layer(
config.pop("cell"), custom_objects=custom_objects
)
return cls(cell, **config)
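# Illustrative sketch of the subclass contract defined by `_RNNCellWrapper`
# above (the `_IdentityWrapper` name is hypothetical, not part of this
# module): a concrete wrapper only needs to override `_call_wrapped_cell`.
#
#   class _IdentityWrapper(_RNNCellWrapper):
#       def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
#           # Delegate the step computation to the wrapped cell unchanged.
#           return cell_call_fn(inputs, state, **kwargs)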
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(_RNNCellWrapper):
"""Operator adding dropout to inputs and outputs of the given cell."""
def __init__(
self,
cell,
input_keep_prob=1.0,
output_keep_prob=1.0,
state_keep_prob=1.0,
variational_recurrent=False,
input_size=None,
dtype=None,
seed=None,
dropout_state_filter_visitor=None,
**kwargs,
):
"""Create a cell with added input, state, and/or output dropout.
If `variational_recurrent` is set to `True` (**NOT** the default
behavior), then the same dropout mask is applied at every step, as
described in: [A Theoretically Grounded Application of Dropout in
Recurrent Neural Networks. Y. Gal, Z.
Ghahramani](https://arxiv.org/abs/1512.05287).
Otherwise a different dropout mask is applied at every time step.
Note that by default (unless a custom `dropout_state_filter_visitor` is
provided), the memory state (the `c` component of any `LSTMStateTuple`)
passing through a `DropoutWrapper` is never modified. This behavior is
described in the above article.
Args:
cell: an RNNCell whose inputs, state, and/or outputs will have dropout
applied.
input_keep_prob: unit Tensor or float between 0 and 1, input keep
probability; if it is constant and 1, no input dropout will be
added.
output_keep_prob: unit Tensor or float between 0 and 1, output keep
probability; if it is constant and 1, no output dropout will be
added.
state_keep_prob: unit Tensor or float between 0 and 1, state keep
probability; if it is constant and 1, no state dropout will be
added. State dropout is performed on the outgoing states of the
cell. **Note** the state components to which dropout is applied when
`state_keep_prob` is in `(0, 1)` are also determined by the argument
`dropout_state_filter_visitor` (e.g. by default dropout is never
applied to the `c` component of an `LSTMStateTuple`).
variational_recurrent: Python bool. If `True`, then the same dropout
pattern is applied across all time steps per run call. If this
parameter is set, `input_size` **must** be provided.
input_size: (optional) (possibly nested tuple of) `TensorShape`
objects containing the depth(s) of the input tensors expected to be
passed in to the `DropoutWrapper`. Required and used **iff**
`variational_recurrent = True` and `input_keep_prob < 1`.
dtype: (optional) The `dtype` of the input, state, and output tensors.
Required and used **iff** `variational_recurrent = True`.
seed: (optional) integer, the randomness seed.
dropout_state_filter_visitor: (optional), default: (see below).
Function that takes any hierarchical level of the state and returns
a scalar or depth=1 structure of Python booleans describing which
terms in the state should be dropped out. In addition, if the
function returns `True`, dropout is applied across this sublevel.
If the function returns `False`, dropout is not applied across this
entire sublevel. Default behavior: perform dropout on all terms
except the memory (`c`) state of `LSTMCellState` objects, and don't
try to apply dropout to
`TensorArray` objects:
```
def dropout_state_filter_visitor(s):
# Never perform dropout on the c state.
if isinstance(s, LSTMCellState):
return LSTMCellState(c=False, h=True)
elif isinstance(s, TensorArray):
return False
return True
```
**kwargs: dict of keyword arguments for base layer.
Raises:
TypeError: if `cell` is not an `RNNCell`, or if
`dropout_state_filter_visitor` is provided but not `callable`.
ValueError: if any of the keep_probs are not between 0 and 1.
"""
if isinstance(cell, lstm.LSTMCell):
raise ValueError(
"keras LSTM cell does not work with DropoutWrapper. "
"Please use LSTMCell(dropout=x, recurrent_dropout=y) "
"instead."
)
super().__init__(cell, dtype=dtype, **kwargs)
if dropout_state_filter_visitor is not None and not callable(
dropout_state_filter_visitor
):
raise TypeError(
"dropout_state_filter_visitor must be callable. "
f"Received: {dropout_state_filter_visitor}"
)
self._dropout_state_filter = (
dropout_state_filter_visitor
or _default_dropout_state_filter_visitor
)
with tf.name_scope("DropoutWrapperInit"):
def tensor_and_const_value(v):
tensor_value = tf.convert_to_tensor(v)
const_value = tf.get_static_value(tensor_value)
return (tensor_value, const_value)
for prob, attr in [
(input_keep_prob, "input_keep_prob"),
(state_keep_prob, "state_keep_prob"),
(output_keep_prob, "output_keep_prob"),
]:
tensor_prob, const_prob = tensor_and_const_value(prob)
if const_prob is not None:
if const_prob < 0 or const_prob > 1:
raise ValueError(
f"Parameter {attr} must be between 0 and 1. "
f"Received {const_prob}"
)
setattr(self, f"_{attr}", float(const_prob))
else:
setattr(self, f"_{attr}", tensor_prob)
# Set variational_recurrent, seed before running the code below
self._variational_recurrent = variational_recurrent
self._input_size = input_size
self._seed = seed
self._recurrent_input_noise = None
self._recurrent_state_noise = None
self._recurrent_output_noise = None
if variational_recurrent:
if dtype is None:
raise ValueError(
"When variational_recurrent=True, dtype must be provided"
)
def convert_to_batch_shape(s):
# Prepend a 1 for the batch dimension; for recurrent
# variational dropout we use the same dropout mask for all
# batch elements.
return tf.concat(([1], tf.TensorShape(s).as_list()), 0)
def batch_noise(s, inner_seed):
shape = convert_to_batch_shape(s)
return tf.random.uniform(shape, seed=inner_seed, dtype=dtype)
if (
not isinstance(self._input_keep_prob, numbers.Real)
or self._input_keep_prob < 1.0
):
if input_size is None:
raise ValueError(
"When variational_recurrent=True and input_keep_prob < "
"1.0 or is unknown, input_size must be provided"
)
self._recurrent_input_noise = _enumerated_map_structure_up_to(
input_size,
lambda i, s: batch_noise(
s, inner_seed=self._gen_seed("input", i)
),
input_size,
)
self._recurrent_state_noise = _enumerated_map_structure_up_to(
cell.state_size,
lambda i, s: batch_noise(
s, inner_seed=self._gen_seed("state", i)
),
cell.state_size,
)
self._recurrent_output_noise = _enumerated_map_structure_up_to(
cell.output_size,
lambda i, s: batch_noise(
s, inner_seed=self._gen_seed("output", i)
),
cell.output_size,
)
def _gen_seed(self, salt_prefix, index):
if self._seed is None:
return None
salt = "%s_%d" % (salt_prefix, index)
string = (str(self._seed) + salt).encode("utf-8")
return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
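# Sketch of the derivation in `_gen_seed` above: with `seed=42`,
# `_gen_seed("input", 0)` is roughly equivalent to
#   int(hashlib.md5("42input_0".encode("utf-8")).hexdigest()[:8], 16) & 0x7FFFFFFF
# i.e. a deterministic 31-bit seed per (salt_prefix, index) pair.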
def _variational_recurrent_dropout_value(
self, unused_index, value, noise, keep_prob
):
"""Performs dropout given the pre-calculated noise tensor."""
# uniform [keep_prob, 1.0 + keep_prob)
random_tensor = keep_prob + noise
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
binary_tensor = tf.floor(random_tensor)
ret = tf.divide(value, keep_prob) * binary_tensor
ret.set_shape(value.get_shape())
return ret
def _dropout(
self,
values,
salt_prefix,
recurrent_noise,
keep_prob,
shallow_filtered_substructure=None,
):
"""Decides whether to perform standard dropout or recurrent dropout."""
if shallow_filtered_substructure is None:
# Put something so we traverse the entire structure; inside the
# dropout function we check whether the leaves of this are bool or not.
shallow_filtered_substructure = values
if not self._variational_recurrent:
def dropout(i, do_dropout, v):
if not isinstance(do_dropout, bool) or do_dropout:
return tf.nn.dropout(
v,
rate=1.0 - keep_prob,
seed=self._gen_seed(salt_prefix, i),
)
else:
return v
return _enumerated_map_structure_up_to(
shallow_filtered_substructure,
dropout,
*[shallow_filtered_substructure, values],
)
else:
def dropout(i, do_dropout, v, n):
if not isinstance(do_dropout, bool) or do_dropout:
return self._variational_recurrent_dropout_value(
i, v, n, keep_prob
)
else:
return v
return _enumerated_map_structure_up_to(
shallow_filtered_substructure,
dropout,
*[shallow_filtered_substructure, values, recurrent_noise],
)
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Runs the wrapped cell and applies dropout.
Args:
inputs: A tensor with wrapped cell's input.
state: A tensor or tuple of tensors with wrapped cell's state.
cell_call_fn: Wrapped cell's method to use for step computation
(cell's `__call__` or 'call' method).
**kwargs: Additional arguments.
Returns:
A pair containing:
- Output: A tensor with cell's output.
- New state: A tensor or tuple of tensors with new wrapped cell's
state.
"""
def _should_dropout(p):
return (not isinstance(p, float)) or p < 1
if _should_dropout(self._input_keep_prob):
inputs = self._dropout(
inputs,
"input",
self._recurrent_input_noise,
self._input_keep_prob,
)
output, new_state = cell_call_fn(inputs, state, **kwargs)
if _should_dropout(self._state_keep_prob):
# Identify which subsets of the state to perform dropout on and
# which ones to keep.
shallow_filtered_substructure = (
tf.__internal__.nest.get_traverse_shallow_structure(
self._dropout_state_filter, new_state
)
)
new_state = self._dropout(
new_state,
"state",
self._recurrent_state_noise,
self._state_keep_prob,
shallow_filtered_substructure,
)
if _should_dropout(self._output_keep_prob):
output = self._dropout(
output,
"output",
self._recurrent_output_noise,
self._output_keep_prob,
)
return output, new_state
def get_config(self):
"""Returns the config of the dropout wrapper."""
config = {
"input_keep_prob": self._input_keep_prob,
"output_keep_prob": self._output_keep_prob,
"state_keep_prob": self._state_keep_prob,
"variational_recurrent": self._variational_recurrent,
"input_size": self._input_size,
"seed": self._seed,
}
if self._dropout_state_filter != _default_dropout_state_filter_visitor:
(
function,
function_type,
function_module,
) = _serialize_function_to_config(self._dropout_state_filter)
config.update(
{
"dropout_fn": function,
"dropout_fn_type": function_type,
"dropout_fn_module": function_module,
}
)
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
if "dropout_fn" in config:
config = config.copy()
dropout_state_filter = _parse_config_to_function(
config,
custom_objects,
"dropout_fn",
"dropout_fn_type",
"dropout_fn_module",
)
config.pop("dropout_fn")
config["dropout_state_filter_visitor"] = dropout_state_filter
return super(DropoutWrapper, cls).from_config(
config, custom_objects=custom_objects
)
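# Illustrative usage sketch for `DropoutWrapper` above (eager TF2; the cell
# choice and tensor shapes are hypothetical, not taken from this module):
#
#   cell = tf.keras.layers.SimpleRNNCell(8)
#   wrapped = DropoutWrapper(
#       cell, input_keep_prob=0.9, output_keep_prob=0.9, state_keep_prob=0.9
#   )
#   layer = tf.keras.layers.RNN(wrapped)
#   outputs = layer(tf.random.normal([4, 10, 5]))  # expected shape: (4, 8)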
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(_RNNCellWrapper):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
def __init__(self, cell, residual_fn=None, **kwargs):
"""Constructs a `ResidualWrapper` for `cell`.
Args:
cell: An instance of `RNNCell`.
residual_fn: (Optional) The function to map raw cell inputs and raw
cell outputs to the actual cell outputs of the residual network.
Defaults to calling `nest.map_structure(lambda i, o: i + o, inputs,
outputs)`.
**kwargs: dict of keyword arguments for base layer.
"""
super().__init__(cell, **kwargs)
self._residual_fn = residual_fn
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Run the cell and apply the residual_fn.
Args:
inputs: cell inputs.
state: cell state.
cell_call_fn: Wrapped cell's method to use for step computation
(cell's `__call__` or 'call' method).
**kwargs: Additional arguments passed to the wrapped cell's `call`.
Returns:
Tuple of cell outputs and new state.
Raises:
TypeError: If cell inputs and outputs have different structure (type).
ValueError: If cell inputs and outputs have different structure
(value).
"""
outputs, new_state = cell_call_fn(inputs, state, **kwargs)
# Ensure shapes match
def assert_shape_match(inp, out):
inp.get_shape().assert_is_compatible_with(out.get_shape())
def default_residual_fn(inputs, outputs):
tf.nest.assert_same_structure(inputs, outputs)
tf.nest.map_structure(assert_shape_match, inputs, outputs)
return tf.nest.map_structure(
lambda inp, out: inp + out, inputs, outputs
)
res_outputs = (self._residual_fn or default_residual_fn)(
inputs, outputs
)
return (res_outputs, new_state)
def get_config(self):
"""Returns the config of the residual wrapper."""
if self._residual_fn is not None:
(
function,
function_type,
function_module,
) = _serialize_function_to_config(self._residual_fn)
config = {
"residual_fn": function,
"residual_fn_type": function_type,
"residual_fn_module": function_module,
}
else:
config = {}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
if "residual_fn" in config:
config = config.copy()
residual_function = _parse_config_to_function(
config,
custom_objects,
"residual_fn",
"residual_fn_type",
"residual_fn_module",
)
config["residual_fn"] = residual_function
return super(ResidualWrapper, cls).from_config(
config, custom_objects=custom_objects
)
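# Illustrative usage sketch for `ResidualWrapper` above (hypothetical shapes;
# the input depth must match the cell's output size so the default
# residual_fn can add them elementwise):
#
#   cell = tf.keras.layers.SimpleRNNCell(5)
#   layer = tf.keras.layers.RNN(ResidualWrapper(cell))
#   outputs = layer(tf.random.normal([4, 10, 5]))  # expected shape: (4, 5)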
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(_RNNCellWrapper):
"""Operator that ensures an RNNCell runs on a particular device."""
def __init__(self, cell, device, **kwargs):
"""Construct a `DeviceWrapper` for `cell` with device `device`.
Ensures the wrapped `cell` is called with `tf.device(device)`.
Args:
cell: An instance of `RNNCell`.
device: A device string or function, for passing to `tf.device`.
**kwargs: dict of keyword arguments for base layer.
"""
super().__init__(cell, **kwargs)
self._device = device
def zero_state(self, batch_size, dtype):
with tf.name_scope(type(self).__name__ + "ZeroState"):
with tf.compat.v1.device(self._device):
return self.cell.zero_state(batch_size, dtype)
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Run the cell on specified device."""
with tf.compat.v1.device(self._device):
return cell_call_fn(inputs, state, **kwargs)
def get_config(self):
config = {"device": self._device}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
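# Illustrative usage sketch for `DeviceWrapper` above (the device string is
# hypothetical; any valid `tf.device` specification works):
#
#   wrapped = DeviceWrapper(tf.keras.layers.SimpleRNNCell(8), "/cpu:0")
#   layer = tf.keras.layers.RNN(wrapped)
#   outputs = layer(tf.random.normal([4, 10, 5]))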
def _serialize_function_to_config(function):
"""Serialize the function for get_config()."""
if isinstance(function, python_types.LambdaType):
output = generic_utils.func_dump(function)
output_type = "lambda"
module = function.__module__
elif callable(function):
output = function.__name__
output_type = "function"
module = function.__module__
else:
raise ValueError(
f"Unrecognized function type for input: {type(function)}"
)
return output, output_type, module
def _parse_config_to_function(
config,
custom_objects,
func_attr_name,
func_type_attr_name,
module_attr_name,
):
"""Reconstruct the function from the config."""
globs = globals()
module = config.pop(module_attr_name, None)
if module in sys.modules:
globs.update(sys.modules[module].__dict__)
elif module is not None:
# Note: we don't know the name of the function if it's a lambda.
warnings.warn(
"{} is not loaded, but a layer uses it. "
"It may cause errors.".format(module),
UserWarning,
stacklevel=2,
)
if custom_objects:
globs.update(custom_objects)
function_type = config.pop(func_type_attr_name)
if function_type == "function":
# Simple lookup in custom objects
function = legacy_serialization.deserialize_keras_object(
config[func_attr_name],
custom_objects=custom_objects,
printable_module_name="function in wrapper",
)
elif function_type == "lambda":
if serialization_lib.in_safe_mode():
raise ValueError(
"Requested the deserialization of a layer with a "
"Python `lambda` inside it. "
"This carries a potential risk of arbitrary code execution "
"and thus it is disallowed by default. If you trust the "
"source of the saved model, you can pass `safe_mode=False` to "
"the loading function in order to allow "
"`lambda` loading."
)
# Unsafe deserialization from bytecode
function = generic_utils.func_load(config[func_attr_name], globs=globs)
else:
raise TypeError(
f"Unknown function type received: {function_type}. "
"Expected types are ['function', 'lambda']"
)
return function
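# Illustrative round-trip sketch for the two helpers above (`my_residual_fn`
# is a hypothetical named function supplied via `custom_objects`):
#
#   fn, fn_type, fn_module = _serialize_function_to_config(my_residual_fn)
#   restored = _parse_config_to_function(
#       {"residual_fn": fn,
#        "residual_fn_type": fn_type,
#        "residual_fn_module": fn_module},
#       {"my_residual_fn": my_residual_fn},
#       "residual_fn",
#       "residual_fn_type",
#       "residual_fn_module",
#   )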
def _default_dropout_state_filter_visitor(substate):
return not isinstance(substate, tf.TensorArray)
def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
ix = [0]
def enumerated_fn(*inner_args, **inner_kwargs):
r = map_fn(ix[0], *inner_args, **inner_kwargs)
ix[0] += 1
return r
return tf.__internal__.nest.map_structure_up_to(
shallow_structure, enumerated_fn, *args, **kwargs
)
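# Illustrative sketch for `_enumerated_map_structure_up_to` above
# (hypothetical structure; leaf order follows `tf.nest` flattening, so the
# exact indices depend on that ordering):
#
#   _enumerated_map_structure_up_to(
#       {"a": 1, "b": 2}, lambda i, v: (i, v), {"a": 1, "b": 2}
#   )
#   # -> roughly {"a": (0, 1), "b": (1, 2)}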