# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for convolutional-recurrent layers."""
import numpy as np
import tensorflow.compat.v2 as tf
from keras import backend
from keras.engine import base_layer
from keras.engine.input_spec import InputSpec
from keras.layers.rnn.base_rnn import RNN
from keras.utils import conv_utils
from keras.utils import generic_utils
from keras.utils import tf_utils


class ConvRNN(RNN):
"""N-Dimensional Base class for convolutional-recurrent layers.
Args:
rank: Integer, rank of the convolution, e.g. "2" for 2D convolutions.
      cell: An RNN cell instance. An RNN cell is a class that has:
        - a `call(input_at_t, states_at_t)` method, returning
          `(output_at_t, states_at_t_plus_1)`. The call method of the cell
          can also take the optional argument `constants`, see section
          "Note on passing external constants" below.
        - a `state_size` attribute. This can be a single integer (single
          state), in which case it is the number of channels of the
          recurrent state (which should be the same as the number of
          channels of the cell output). This can also be a list/tuple of
          integers (one size per state). In this case, the first entry
          (`state_size[0]`) should be the same as the size of the cell
          output.
      return_sequences: Boolean. Whether to return the last output in the
        output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state in addition to the
output.
go_backwards: Boolean (default False). If True, process the input sequence
backwards and return the reversed sequence.
stateful: Boolean (default False). If True, the last state for each sample
at index i in a batch will be used as initial state for the sample of
index i in the following batch.
input_shape: Use this argument to specify the shape of the input when this
layer is the first one in a model.
Call arguments:
inputs: A (2 + `rank`)D tensor.
mask: Binary tensor of shape `(samples, timesteps)` indicating whether a
given timestep should be masked.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is for use with cells that use dropout.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
constants: List of constant tensors to be passed to the cell at each
timestep.
Input shape:
(3 + `rank`)D tensor with shape: `(samples, timesteps, channels,
img_dimensions...)`
if data_format='channels_first' or shape: `(samples, timesteps,
img_dimensions..., channels)` if data_format='channels_last'.
Output shape:
- If `return_state`: a list of tensors. The first tensor is the output.
The remaining tensors are the last states,
each (2 + `rank`)D tensor with shape: `(samples, filters,
new_img_dimensions...)` if data_format='channels_first'
or shape: `(samples, new_img_dimensions..., filters)` if
data_format='channels_last'. img_dimension values might have changed
due to padding.
- If `return_sequences`: (3 + `rank`)D tensor with shape: `(samples,
timesteps, filters, new_img_dimensions...)` if
data_format='channels_first'
or shape: `(samples, timesteps, new_img_dimensions..., filters)` if
data_format='channels_last'.
- Else, (2 + `rank`)D tensor with shape: `(samples, filters,
new_img_dimensions...)` if data_format='channels_first'
or shape: `(samples, new_img_dimensions..., filters)` if
data_format='channels_last'.
Masking: This layer supports masking for input data with a variable number
of timesteps.
    Note on using statefulness in RNNs: You can set RNN layers to be
      'stateful', which means that the states computed for the samples in one
      batch will be reused as initial states for the samples in the next
      batch. This assumes a one-to-one mapping between samples in different
      successive batches.
      To enable statefulness (a usage sketch follows this list):
        - Specify `stateful=True` in the layer constructor.
        - Specify a fixed batch size for your model:
          - If using a Sequential model, pass `batch_input_shape=(...)` to
            the first layer in your model.
          - If using a functional model with 1 or more Input layers, pass
            `batch_shape=(...)` to all the first layers in your model. This
            is the expected shape of your inputs *including the batch size*.
            It should be a tuple of integers, e.g. `(32, 10, 100, 100, 32)`
            for a rank-2 convolution. Note that the image dimensions should
            be specified too.
        - Specify `shuffle=False` when calling `fit()`.
      To reset the states of your model, call `.reset_states()` on either a
      specific layer, or on your entire model.
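
      A minimal stateful sketch using the concrete `ConvLSTM2D` subclass of
      this base class (shapes are illustrative):

      ```python
      import tensorflow as tf

      # Stateful rank-2 convolutional RNN; the batch size is fixed to 4.
      layer = tf.keras.layers.ConvLSTM2D(
          filters=8,
          kernel_size=(3, 3),
          padding="same",
          stateful=True,
          batch_input_shape=(4, 10, 32, 32, 1),
      )
      layer(tf.random.normal((4, 10, 32, 32, 1)))  # first batch
      layer(tf.random.normal((4, 10, 32, 32, 1)))  # resumes from last states
      layer.reset_states()
      ```
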
Note on specifying the initial state of RNNs: You can specify the initial
state of RNN layers symbolically by calling them with the keyword argument
`initial_state`. The value of `initial_state` should be a tensor or list
of tensors representing the initial state of the RNN layer. You can
specify the initial state of RNN layers numerically by calling
`reset_states` with the keyword argument `states`. The value of `states`
should be a numpy array or list of numpy arrays representing the initial
state of the RNN layer.
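
      A sketch with `ConvLSTM2D`, whose cell carries two states (`h` and `c`),
      each with `filters` channels and the spatial dimensions of the
      convolution output:

      ```python
      import tensorflow as tf

      layer = tf.keras.layers.ConvLSTM2D(16, (3, 3), padding="same")
      inputs = tf.random.normal((4, 10, 32, 32, 3))
      init_h = tf.zeros((4, 32, 32, 16))  # one tensor per cell state
      init_c = tf.zeros((4, 32, 32, 16))
      outputs = layer(inputs, initial_state=[init_h, init_c])
      ```
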
Note on passing external constants to RNNs: You can pass "external"
constants to the cell using the `constants` keyword argument of
`RNN.__call__` (as well as `RNN.call`) method. This requires that the
`cell.call` method accepts the same keyword argument `constants`. Such
      constants can be used to condition the cell transformation on additional
      static inputs (not changing over time), as in an attention mechanism.
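
      A minimal sketch (`my_cell` stands for a hypothetical user-defined cell
      whose `call` signature accepts a `constants` keyword):

      ```python
      import tensorflow as tf

      # `my_cell` is assumed to implement
      # `call(inputs, states, constants=...)`.
      inputs = tf.random.normal((4, 10, 32, 32, 3))
      static_map = tf.random.normal((4, 32, 32, 8))  # constant over time
      layer = ConvRNN(rank=2, cell=my_cell)
      outputs = layer(inputs, constants=[static_map])
      ```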
"""

    def __init__(
self,
rank,
cell,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
**kwargs,
):
if unroll:
raise TypeError(
"Unrolling is not possible with convolutional RNNs. "
f"Received: unroll={unroll}"
)
if isinstance(cell, (list, tuple)):
            # `StackedConvRNN3DCells` isn't implemented yet.
            raise TypeError(
                "It is not possible at the moment to "
                "stack convolutional cells. Only pass a single cell "
                "instance as the `cell` argument. Received: "
                f"cell={cell}"
            )
super().__init__(
cell,
return_sequences,
return_state,
go_backwards,
stateful,
unroll,
**kwargs,
)
self.rank = rank
self.input_spec = [InputSpec(ndim=rank + 3)]
self.states = None
self._num_constants = None

    @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
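        # Apply the cell's convolution arithmetic to each image dimension.
        # E.g. (illustrative numbers): kernel_size=3, strides=1 and
        # padding="valid" map a 64-pixel image dimension to 62.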
if isinstance(input_shape, list):
input_shape = input_shape[0]
cell = self.cell
if cell.data_format == "channels_first":
img_dims = input_shape[3:]
elif cell.data_format == "channels_last":
img_dims = input_shape[2:-1]
norm_img_dims = tuple(
[
conv_utils.conv_output_length(
img_dims[idx],
cell.kernel_size[idx],
padding=cell.padding,
stride=cell.strides[idx],
dilation=cell.dilation_rate[idx],
)
for idx in range(len(img_dims))
]
)
if cell.data_format == "channels_first":
output_shape = input_shape[:2] + (cell.filters,) + norm_img_dims
elif cell.data_format == "channels_last":
output_shape = input_shape[:2] + norm_img_dims + (cell.filters,)
if not self.return_sequences:
output_shape = output_shape[:1] + output_shape[2:]
if self.return_state:
output_shape = [output_shape]
if cell.data_format == "channels_first":
output_shape += [
(input_shape[0], cell.filters) + norm_img_dims
for _ in range(2)
]
elif cell.data_format == "channels_last":
output_shape += [
(input_shape[0],) + norm_img_dims + (cell.filters,)
for _ in range(2)
]
return output_shape

    @tf_utils.shape_type_conversion
def build(self, input_shape):
# Note input_shape will be list of shapes of initial states and
# constants if these are passed in __call__.
if self._num_constants is not None:
constants_shape = input_shape[-self._num_constants :]
else:
constants_shape = None
if isinstance(input_shape, list):
input_shape = input_shape[0]
batch_size = input_shape[0] if self.stateful else None
self.input_spec[0] = InputSpec(
shape=(batch_size, None) + input_shape[2 : self.rank + 3]
)
# allow cell (if layer) to build before we set or validate state_spec
if isinstance(self.cell, base_layer.Layer):
step_input_shape = (input_shape[0],) + input_shape[2:]
if constants_shape is not None:
self.cell.build([step_input_shape] + constants_shape)
else:
self.cell.build(step_input_shape)
# set or validate state_spec
if hasattr(self.cell.state_size, "__len__"):
state_size = list(self.cell.state_size)
else:
state_size = [self.cell.state_size]
if self.state_spec is not None:
# initial_state was passed in call, check compatibility
if self.cell.data_format == "channels_first":
ch_dim = 1
elif self.cell.data_format == "channels_last":
ch_dim = self.rank + 1
if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
raise ValueError(
"An `initial_state` was passed that is not compatible with "
"`cell.state_size`. Received state shapes "
f"{[spec.shape for spec in self.state_spec]}. "
f"However `cell.state_size` is {self.cell.state_size}"
)
else:
img_dims = tuple((None for _ in range(self.rank)))
if self.cell.data_format == "channels_first":
self.state_spec = [
InputSpec(shape=(None, dim) + img_dims)
for dim in state_size
]
elif self.cell.data_format == "channels_last":
self.state_spec = [
InputSpec(shape=(None,) + img_dims + (dim,))
for dim in state_size
]
if self.stateful:
self.reset_states()
self.built = True

    def get_initial_state(self, inputs):
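        # Build an all-zeros initial state with the layout the cell expects:
        # collapse the time axis of a zeros-like input, then pass the result
        # through the cell's input convolution with a zero kernel so that the
        # channel count and image dimensions match the cell output.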
# (samples, timesteps, img_dims..., filters)
initial_state = backend.zeros_like(inputs)
# (samples, img_dims..., filters)
initial_state = backend.sum(initial_state, axis=1)
shape = list(self.cell.kernel_shape)
shape[-1] = self.cell.filters
initial_state = self.cell.input_conv(
initial_state,
tf.zeros(tuple(shape), initial_state.dtype),
padding=self.cell.padding,
)
if hasattr(self.cell.state_size, "__len__"):
return [initial_state for _ in self.cell.state_size]
else:
return [initial_state]

    def call(
self,
inputs,
mask=None,
training=None,
initial_state=None,
constants=None,
):
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
inputs, initial_state, constants = self._process_inputs(
inputs, initial_state, constants
)
if isinstance(mask, list):
mask = mask[0]
timesteps = backend.int_shape(inputs)[1]
kwargs = {}
if generic_utils.has_arg(self.cell.call, "training"):
kwargs["training"] = training
if constants:
if not generic_utils.has_arg(self.cell.call, "constants"):
raise ValueError(
f"RNN cell {self.cell} does not support constants. "
f"Received: constants={constants}"
)

            def step(inputs, states):
                constants = states[-self._num_constants :]
                states = states[: -self._num_constants]
                return self.cell.call(
                    inputs, states, constants=constants, **kwargs
                )

        else:

            def step(inputs, states):
                return self.cell.call(inputs, states, **kwargs)
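
        # `backend.rnn` iterates `step` over the time axis, threading the
        # recurrent states (and any constants) through each call.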
last_output, outputs, states = backend.rnn(
step,
inputs,
initial_state,
constants=constants,
go_backwards=self.go_backwards,
mask=mask,
input_length=timesteps,
return_all_outputs=self.return_sequences,
)
if self.stateful:
updates = [
backend.update(self_state, state)
for self_state, state in zip(self.states, states)
]
self.add_update(updates)
if self.return_sequences:
output = outputs
else:
output = last_output
if self.return_state:
if not isinstance(states, (list, tuple)):
states = [states]
else:
states = list(states)
return [output] + states
return output

    def reset_states(self, states=None):
if not self.stateful:
raise AttributeError("Layer must be stateful.")
input_shape = self.input_spec[0].shape
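        # Derive the per-state shape from the layer's output shape, keeping
        # only the primary output and dropping the time axis if present.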
state_shape = self.compute_output_shape(input_shape)
if self.return_state:
state_shape = state_shape[0]
if self.return_sequences:
state_shape = state_shape[:1].concatenate(state_shape[2:])
if None in state_shape:
            raise ValueError(
                "If an RNN is stateful, it needs to know "
                "its batch size. Specify the batch size "
                "of your input tensors: \n"
                "- If using a Sequential model, "
                "specify the batch size by passing "
                "a `batch_input_shape` "
                "argument to your first layer.\n"
                "- If using the functional API, specify "
                "the batch size by passing a "
                "`batch_shape` argument to your Input layer.\n"
                "The same applies to the image dimensions "
                "(rows and columns), which must be fully specified as well."
            )

        # Helper: place `nb_channels` in the channel axis of `state_shape`.
        def get_tuple_shape(nb_channels):
result = list(state_shape)
if self.cell.data_format == "channels_first":
result[1] = nb_channels
elif self.cell.data_format == "channels_last":
result[self.rank + 1] = nb_channels
else:
raise KeyError(
"Cell data format must be one of "
'{"channels_first", "channels_last"}. Received: '
f"cell.data_format={self.cell.data_format}"
)
return tuple(result)

        # initialize state if None
        if self.states[0] is None:
if hasattr(self.cell.state_size, "__len__"):
self.states = [
backend.zeros(get_tuple_shape(dim))
for dim in self.cell.state_size
]
else:
self.states = [
backend.zeros(get_tuple_shape(self.cell.state_size))
]
elif states is None:
if hasattr(self.cell.state_size, "__len__"):
for state, dim in zip(self.states, self.cell.state_size):
backend.set_value(state, np.zeros(get_tuple_shape(dim)))
else:
backend.set_value(
self.states[0],
np.zeros(get_tuple_shape(self.cell.state_size)),
)
else:
if not isinstance(states, (list, tuple)):
states = [states]
if len(states) != len(self.states):
raise ValueError(
f"Layer {self.name} expects {len(self.states)} states, "
f"but it received {len(states)} state values. "
f"States received: {states}"
)
for index, (value, state) in enumerate(zip(states, self.states)):
if hasattr(self.cell.state_size, "__len__"):
dim = self.cell.state_size[index]
else:
dim = self.cell.state_size
if value.shape != get_tuple_shape(dim):
raise ValueError(
"State {index} is incompatible with layer "
f"{self.name}: expected shape={get_tuple_shape(dim)}, "
f"found shape={value.shape}"
)
backend.set_value(state, value)