1109 lines
41 KiB
Python
1109 lines
41 KiB
Python
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
"""Functions for saving and loading a Keras Model from HDF5 format."""
|
||
|
|
||
|
import json
|
||
|
import os
|
||
|
|
||
|
import numpy as np
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras import backend
|
||
|
from keras.optimizers import optimizer as optimizer_base
|
||
|
from keras.optimizers import optimizer_v1
|
||
|
from keras.saving.legacy import model_config as model_config_lib
|
||
|
from keras.saving.legacy import saving_utils
|
||
|
from keras.saving.legacy.saved_model import json_utils
|
||
|
from keras.utils.generic_utils import LazyLoader
|
||
|
from keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.platform import tf_logging as logging
|
||
|
|
||
|
# `h5py` is an optional dependency. When it cannot be imported, `h5py` is
# set to None; `save_model_to_hdf5` and `load_model_from_hdf5` check for this
# and raise an ImportError when called.
try:
    import h5py

    # NOTE(review): presumably the maximum number of bytes that may be written
    # into a single HDF5 object header (e.g. for attributes) — confirm against
    # its users. It is only defined when h5py imports successfully.
    HDF5_OBJECT_HEADER_LIMIT = 64512
except ImportError:
    h5py = None

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.

# `keras.engine.sequential` is bound through `LazyLoader` rather than imported
# directly. NOTE(review): presumably this defers the import until first use to
# avoid a circular import with the engine package — confirm.
sequential_lib = LazyLoader(
    "sequential_lib", globals(), "keras.engine.sequential"
)
|
||
|
|
||
|
|
||
|
def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
    """Saves a model to a HDF5 file.

    The saved model contains:
        - the model's configuration (topology)
        - the model's weights
        - the model's optimizer's state (if any)

    Thus the saved model can be reinstantiated in
    the exact same state, without any of the code
    used for model definition or training.

    Args:
        model: Keras model instance to be saved.
        filepath: One of the following:
            - String, path where to save the model
            - `h5py.File` object where to save the model
        overwrite: Whether we should overwrite any existing
            model at the target location, or instead
            ask the user with a manual prompt.
        include_optimizer: If True, save optimizer's state together.

    Raises:
        ImportError: if h5py is not available.
    """

    if h5py is None:
        raise ImportError(
            "`save_model()` using h5 format requires h5py. Could not "
            "import h5py."
        )

    # TODO(psv) Add warning when we save models that contain non-serializable
    # entities like metrics added using `add_metric` and losses added using
    # `add_loss.`
    if len(model.weights) != len(model._undeduplicated_weights):
        logging.warning(
            "Found duplicated `Variable`s in Model's `weights`. "
            "This is usually caused by `Variable`s being shared by "
            "Layers in the Model. These `Variable`s will be treated "
            "as separate `Variable`s when the Model is restored. To "
            'avoid this, please save with `save_format="tf"`.'
        )

    if not isinstance(filepath, h5py.File):
        # If file exists and should not be overwritten.
        if not overwrite and os.path.isfile(filepath):
            proceed = ask_to_proceed_with_overwrite(filepath)
            if not proceed:
                return

        # Create the parent directory if it does not exist yet. A bare file
        # name has an empty dirname (the current directory), which must not
        # be handed to `makedirs`.
        dirpath = os.path.dirname(filepath)
        if dirpath and not os.path.exists(dirpath):
            tf.io.gfile.makedirs(dirpath)

        f = h5py.File(filepath, mode="w")
        opened_new_file = True
    else:
        # A caller-provided `h5py.File` is written to but left open.
        f = filepath
        opened_new_file = False

    try:
        model_metadata = saving_utils.model_metadata(model, include_optimizer)
        for k, v in model_metadata.items():
            # Structured metadata is stored as a UTF-8 encoded JSON string.
            if isinstance(v, (dict, list, tuple)):
                f.attrs[k] = json.dumps(
                    v, default=json_utils.get_json_type
                ).encode("utf8")
            else:
                f.attrs[k] = v

        model_weights_group = f.create_group("model_weights")
        save_weights_to_hdf5_group(model_weights_group, model)

        # TODO(b/128683857): Add integration tests between tf.keras and external
        # Keras, to avoid breaking TF.js users.
        if (
            include_optimizer
            and model.optimizer
            and not isinstance(model.optimizer, optimizer_v1.TFOptimizer)
        ):
            save_optimizer_weights_to_hdf5_group(f, model.optimizer)

        f.flush()
    finally:
        if opened_new_file:
            f.close()
|
||
|
|
||
|
|
||
|
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
    """Loads a model saved via `save_model_to_hdf5`.

    Args:
        filepath: One of the following:
            - String, path to the saved model
            - `h5py.File` object from which to load the model
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model
            after loading.

    Returns:
        A Keras model instance. If an optimizer was found
        as part of the saved model, the model is already
        compiled. Otherwise, the model is uncompiled and
        a warning will be displayed. When `compile` is set
        to False, the compilation is omitted without any
        warning.

    Raises:
        ImportError: if h5py is not available.
        ValueError: In case of an invalid savefile.
    """
    if h5py is None:
        raise ImportError(
            "`load_model()` using h5 format requires h5py. Could not "
            "import h5py."
        )

    if not custom_objects:
        custom_objects = {}

    # Only a file opened here is closed here; a caller-provided `h5py.File`
    # is left open.
    opened_new_file = not isinstance(filepath, h5py.File)
    if opened_new_file:
        f = h5py.File(filepath, mode="r")
    else:
        f = filepath

    model = None
    try:
        # instantiate model
        model_config = f.attrs.get("model_config")
        if model_config is None:
            raise ValueError(
                f"No model config found in the file at {filepath}."
            )
        # The attribute may come back as bytes; decode before JSON parsing.
        if hasattr(model_config, "decode"):
            model_config = model_config.decode("utf-8")
        model_config = json_utils.decode(model_config)
        model = model_config_lib.model_from_config(
            model_config, custom_objects=custom_objects
        )

        # set weights
        load_weights_from_hdf5_group(f["model_weights"], model)

        if compile:
            # instantiate optimizer
            training_config = f.attrs.get("training_config")
            if hasattr(training_config, "decode"):
                training_config = training_config.decode("utf-8")
            if training_config is None:
                logging.warning(
                    "No training configuration found in the save file, so "
                    "the model was *not* compiled. Compile it manually."
                )
                return model
            training_config = json_utils.decode(training_config)

            # Compile model.
            model.compile(
                **saving_utils.compile_args_from_training_config(
                    training_config, custom_objects
                ),
                from_serialized=True,
            )
            saving_utils.try_build_compiled_arguments(model)

            # Set optimizer weights.
            if "optimizer_weights" in f:
                # The optimizer's weights have to be created before the saved
                # values can be assigned to them.
                try:
                    if isinstance(model.optimizer, optimizer_base.Optimizer):
                        model.optimizer.build(model.trainable_variables)
                    else:
                        model.optimizer._create_all_weights(
                            model.trainable_variables
                        )
                except (NotImplementedError, AttributeError):
                    logging.warning(
                        "Error when creating the weights of optimizer %s, "
                        "making it impossible to restore the saved optimizer "
                        "state. As a result, your model is starting with "
                        "a freshly initialized optimizer.",
                        type(model.optimizer).__name__,
                    )

                optimizer_weight_values = (
                    load_optimizer_weights_from_hdf5_group(f)
                )
                try:
                    model.optimizer.set_weights(optimizer_weight_values)
                except ValueError:
                    logging.warning(
                        "Error in loading the saved optimizer "
                        "state. As a result, your model is "
                        "starting with a freshly initialized "
                        "optimizer."
                    )
    finally:
        if opened_new_file:
            f.close()
    return model
|
||
|
|
||
|
|
||
|
def preprocess_weights_for_loading(
    layer, weights, original_keras_version=None, original_backend=None
):
    """Preprocess layer weights between different Keras formats.

    Converts layers weights from Keras 1 format to Keras 2 and also weights of
    cuDNN layers in Keras 2.

    Args:
        layer: Layer instance.
        weights: List of weights values (Numpy arrays).
        original_keras_version: Keras version for the weights, as a string.
        original_backend: Keras backend the weights were trained with,
            as a string.

    Returns:
        A list of weights values (Numpy arrays).
    """
    # Work on a copy so the in-place conversions below (`weights[0] = ...`)
    # never mutate the caller's list.
    weights = list(weights)

    def convert_nested_bidirectional(weights):
        """Converts layers nested in `Bidirectional` wrapper.

        This function uses `preprocess_weights_for_loading()` for converting
        layers.

        Args:
            weights: List of weights values (Numpy arrays).

        Returns:
            A list of weights values (Numpy arrays).
        """
        # The first half belongs to the forward layer, the second half to
        # the backward layer.
        num_weights_per_layer = len(weights) // 2
        forward_weights = preprocess_weights_for_loading(
            layer.forward_layer,
            weights[:num_weights_per_layer],
            original_keras_version,
            original_backend,
        )
        backward_weights = preprocess_weights_for_loading(
            layer.backward_layer,
            weights[num_weights_per_layer:],
            original_keras_version,
            original_backend,
        )
        return forward_weights + backward_weights

    def convert_nested_time_distributed(weights):
        """Converts layers nested in `TimeDistributed` wrapper.

        This function uses `preprocess_weights_for_loading()` for converting
        nested layers.

        Args:
            weights: List of weights values (Numpy arrays).

        Returns:
            A list of weights values (Numpy arrays).
        """
        return preprocess_weights_for_loading(
            layer.layer, weights, original_keras_version, original_backend
        )

    def convert_nested_model(weights):
        """Converts layers nested in `Model` or `Sequential`.

        This function uses `preprocess_weights_for_loading()` for converting
        nested layers.

        Args:
            weights: List of weights values (Numpy arrays).

        Returns:
            A list of weights values (Numpy arrays).
        """
        # Saved values are ordered as all trainable weights followed by all
        # non-trainable weights; each sublayer consumes its share of both.
        trainable_weights = weights[: len(layer.trainable_weights)]
        non_trainable_weights = weights[len(layer.trainable_weights) :]

        new_trainable_weights = []
        new_non_trainable_weights = []

        for sublayer in layer.layers:
            num_trainable_weights = len(sublayer.trainable_weights)
            num_non_trainable_weights = len(sublayer.non_trainable_weights)
            if sublayer.weights:
                preprocessed = preprocess_weights_for_loading(
                    layer=sublayer,
                    weights=(
                        trainable_weights[:num_trainable_weights]
                        + non_trainable_weights[:num_non_trainable_weights]
                    ),
                    original_keras_version=original_keras_version,
                    original_backend=original_backend,
                )
                new_trainable_weights.extend(
                    preprocessed[:num_trainable_weights]
                )
                new_non_trainable_weights.extend(
                    preprocessed[num_trainable_weights:]
                )

            trainable_weights = trainable_weights[num_trainable_weights:]
            non_trainable_weights = non_trainable_weights[
                num_non_trainable_weights:
            ]
        # NOTE(review): this appends the nested model's own
        # `_trainable_weights` / `_non_trainable_weights` entries as-is, not
        # the saved values left over in `trainable_weights` /
        # `non_trainable_weights` — confirm this is intended.
        new_trainable_weights += layer._trainable_weights
        new_non_trainable_weights += layer._non_trainable_weights
        return new_trainable_weights + new_non_trainable_weights

    class_name = layer.__class__.__name__

    # Convert layers nested in Bidirectional/TimeDistributed/Model/Sequential.
    # These run for both Keras 1->2 conversion and cuDNN conversion. Each
    # helper recurses with the original version/backend, so the nested
    # weights come back fully converted.
    if class_name == "Bidirectional":
        weights = convert_nested_bidirectional(weights)
    elif class_name == "TimeDistributed":
        weights = convert_nested_time_distributed(weights)
    elif class_name in ["Model", "Sequential", "Functional"]:
        weights = convert_nested_model(weights)

    # Keras 1 -> Keras 2 conversions. `TimeDistributed` is deliberately not
    # handled here: its inner weights were already converted above, and
    # converting them again would re-apply non-idempotent transforms (e.g.
    # the Conv1D reshape or the Conv2DTranspose transpose).
    if original_keras_version == "1":
        if class_name == "Conv1D":
            shape = weights[0].shape
            # Handle Keras 1.1 format
            if (
                shape[:2] != (layer.kernel_size[0], 1)
                or shape[3] != layer.filters
            ):
                # Legacy shape:
                # (filters, input_dim, filter_length, 1)
                assert shape[0] == layer.filters and shape[2:] == (
                    layer.kernel_size[0],
                    1,
                )
                weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
            # Both layouts are now (filter_length, 1, input_dim, filters);
            # drop the singleton axis.
            weights[0] = weights[0][:, 0, :, :]

        elif class_name == "Conv2D":
            if layer.data_format == "channels_first":
                # old: (filters, stack_size, kernel_rows, kernel_cols)
                # new: (kernel_rows, kernel_cols, stack_size, filters)
                weights[0] = np.transpose(weights[0], (2, 3, 1, 0))

        elif class_name == "Conv2DTranspose":
            if layer.data_format == "channels_last":
                # old: (kernel_rows, kernel_cols, stack_size, filters)
                # new: (kernel_rows, kernel_cols, filters, stack_size)
                weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
            if layer.data_format == "channels_first":
                # old: (filters, stack_size, kernel_rows, kernel_cols)
                # new: (kernel_rows, kernel_cols, filters, stack_size)
                weights[0] = np.transpose(weights[0], (2, 3, 0, 1))

        elif class_name == "Conv3D":
            if layer.data_format == "channels_first":
                # old: (filters, stack_size, ...)
                # new: (..., stack_size, filters)
                weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))

        elif class_name == "GRU":
            # Keras 1 stored one (kernel, recurrent, bias) triple per gate.
            if len(weights) == 9:
                kernel = np.concatenate(
                    [weights[0], weights[3], weights[6]], axis=-1
                )
                recurrent_kernel = np.concatenate(
                    [weights[1], weights[4], weights[7]], axis=-1
                )
                bias = np.concatenate(
                    [weights[2], weights[5], weights[8]], axis=-1
                )
                weights = [kernel, recurrent_kernel, bias]

        elif class_name == "LSTM":
            if len(weights) == 12:
                # old: i, c, f, o
                # new: i, f, c, o
                kernel = np.concatenate(
                    [weights[0], weights[6], weights[3], weights[9]], axis=-1
                )
                recurrent_kernel = np.concatenate(
                    [weights[1], weights[7], weights[4], weights[10]], axis=-1
                )
                bias = np.concatenate(
                    [weights[2], weights[8], weights[5], weights[11]], axis=-1
                )
                weights = [kernel, recurrent_kernel, bias]

        elif class_name == "ConvLSTM2D":
            if len(weights) == 12:
                kernel = np.concatenate(
                    [weights[0], weights[6], weights[3], weights[9]], axis=-1
                )
                recurrent_kernel = np.concatenate(
                    [weights[1], weights[7], weights[4], weights[10]], axis=-1
                )
                bias = np.concatenate(
                    [weights[2], weights[8], weights[5], weights[11]], axis=-1
                )
                if layer.data_format == "channels_first":
                    # old: (filters, stack_size, kernel_rows, kernel_cols)
                    # new: (kernel_rows, kernel_cols, stack_size, filters)
                    kernel = np.transpose(kernel, (2, 3, 1, 0))
                    recurrent_kernel = np.transpose(
                        recurrent_kernel, (2, 3, 1, 0)
                    )
                weights = [kernel, recurrent_kernel, bias]

    conv_layers = [
        "Conv1D",
        "Conv2D",
        "Conv3D",
        "Conv2DTranspose",
        "ConvLSTM2D",
    ]
    if class_name in conv_layers:
        # If the stored kernel shape differs from the layer's, reorder axes.
        if backend.int_shape(layer.weights[0]) != weights[0].shape:
            weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
            if class_name == "ConvLSTM2D":
                weights[1] = np.transpose(weights[1], (3, 2, 0, 1))

    # convert cuDNN layers
    return _convert_rnn_weights(layer, weights)
|
||
|
|
||
|
|
||
|
def _convert_rnn_weights(layer, weights):
    """Converts RNN weights between the native Keras and cuDNN layouts.

    Conversion works in both directions for `LSTM` <-> `CuDNNLSTM` and for
    `GRU(reset_after=True)` <-> `CuDNNGRU`. A default
    `GRU(reset_after=False)` is not compatible with `CuDNNGRU`.

    Per gate, input kernels are transposed and switched between C and
    Fortran memory order, and recurrent kernels are transposed. LSTM biases
    are summed (cuDNN -> Keras) or split into two halves (Keras -> cuDNN);
    GRU biases are reshaped.

    Only weight lists with exactly three entries (kernel, recurrent kernel,
    bias) are converted. Any other list, e.g. from a layer built with
    `use_bias=False`, is returned unchanged. A list whose detected source
    format already matches the target layer is also returned unchanged.

    Args:
        layer: Target layer instance.
        weights: List of source weights values (input kernels, recurrent
            kernels, [biases]) as NumPy arrays.

    Returns:
        A list of converted weights values (NumPy arrays).

    Raises:
        ValueError: if the bias shape is not recognized, or if the GRU
            source and target variants are incompatible.
    """

    def map_gates(stacked, fn, gate_count):
        # Per-gate blocks sit side by side along axis 1: apply `fn` to each
        # block and stitch the results back together.
        blocks = np.hsplit(stacked, gate_count)
        return np.hstack([fn(block) for block in blocks])

    def relayout_input_kernel(from_cudnn):
        # Same shape, different memory layout, e.g.:
        #     Keras              cuDNN
        #     [[0, 1, 2],  <-->  [[0, 2, 4],
        #      [3, 4, 5]]         [1, 3, 5]]
        layout = "F" if from_cudnn else "C"
        return lambda block: block.T.reshape(block.shape, order=layout)

    def transpose(block):
        return block.T

    target = layer.__class__.__name__
    if len(weights) != 3:
        # Only (kernel, recurrent kernel, bias) triples are converted.
        return weights

    if target in ("LSTM", "CuDNNLSTM"):
        gate_count = 4
        hidden = weights[1].shape[0]
        bias_shape = weights[2].shape
        # cuDNN keeps separate input and recurrent biases (8 * units),
        # Keras LSTM a single set (4 * units).
        if bias_shape == (2 * hidden * gate_count,):
            origin = "CuDNNLSTM"
        elif bias_shape == (hidden * gate_count,):
            origin = "LSTM"
        else:
            raise ValueError("Invalid bias shape: " + str(bias_shape))
        if origin == target:
            return weights

        from_cudnn = origin == "CuDNNLSTM"
        kernel = map_gates(
            weights[0], relayout_input_kernel(from_cudnn), gate_count
        )
        recurrent_kernel = map_gates(weights[1], transpose, gate_count)
        if from_cudnn:
            # Fold the two cuDNN bias sets into the single Keras set.
            bias = np.sum(np.split(weights[2], 2, axis=0), axis=0)
        else:
            # Any split is valid as long as the two halves sum to the
            # original bias.
            bias = np.tile(0.5 * weights[2], 2)
        return [kernel, recurrent_kernel, bias]

    if target in ("GRU", "CuDNNGRU"):
        gate_count = 3
        hidden = weights[1].shape[0]
        bias_shape = weights[2].shape
        # CuDNNGRU always has biases, so the bias shape identifies the source.
        if bias_shape == (2 * hidden * gate_count,):
            origin = "CuDNNGRU"
        elif bias_shape == (2, hidden * gate_count):
            origin = "GRU(reset_after=True)"
        elif bias_shape == (hidden * gate_count,):
            origin = "GRU(reset_after=False)"
        else:
            raise ValueError("Invalid bias shape: " + str(bias_shape))

        if target == "CuDNNGRU":
            wanted = "CuDNNGRU"
        elif layer.reset_after:
            wanted = "GRU(reset_after=True)"
        else:
            wanted = "GRU(reset_after=False)"

        if origin == wanted:
            return weights
        pair = (origin, wanted)
        if "GRU(reset_after=False)" in pair:
            raise ValueError("%s is not compatible with %s" % pair)

        from_cudnn = origin == "CuDNNGRU"
        kernel = map_gates(
            weights[0], relayout_input_kernel(from_cudnn), gate_count
        )
        recurrent_kernel = map_gates(weights[1], transpose, gate_count)
        bias = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1)
        return [kernel, recurrent_kernel, bias]

    return weights
|
||
|
|
||
|
|
||
|
def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
    """Writes an optimizer's variables into an `optimizer_weights` subgroup.

    The variable names (UTF-8 encoded) are recorded in the subgroup's
    `weight_names` attribute, and each value is stored in a dataset of that
    name. Nothing is written when the optimizer has no variables.

    Args:
        hdf5_group: HDF5 group under which `optimizer_weights` is created.
        optimizer: optimizer instance; `optimizer_base.Optimizer` instances
            expose `variables`, other optimizers expose `weights`.
    """
    if isinstance(optimizer, optimizer_base.Optimizer):
        variables = optimizer.variables
    else:
        variables = optimizer.weights
    if not variables:
        return

    group = hdf5_group.create_group("optimizer_weights")
    encoded_names = [str(var.name).encode("utf8") for var in variables]
    save_attributes_to_hdf5_group(group, "weight_names", encoded_names)
    for encoded, value in zip(encoded_names, backend.batch_get_value(variables)):
        dataset = group.create_dataset(encoded, value.shape, dtype=value.dtype)
        if value.shape:
            dataset[:] = value
        else:
            # Scalar datasets cannot be sliced; assign via the empty index.
            dataset[()] = value
|
||
|
|
||
|
|
||
|
def load_optimizer_weights_from_hdf5_group(hdf5_group):
    """Load optimizer weights from a HDF5 group.

    Args:
        hdf5_group: A pointer to a HDF5 group that contains an
            `optimizer_weights` subgroup.

    Returns:
        List of the stored optimizer weight values, in the order listed by
        the subgroup's `weight_names` attribute. Each entry is the h5py
        dataset under `optimizer_weights` (not a copied NumPy array), so the
        backing file must still be open when the values are consumed.
    """
    weights_group = hdf5_group["optimizer_weights"]
    optimizer_weight_names = load_attributes_from_hdf5_group(
        weights_group, "weight_names"
    )
    return [
        weights_group[weight_name] for weight_name in optimizer_weight_names
    ]
|
||
|
|
||
|
|
||
|
def save_subset_weights_to_hdf5_group(f, weights):
    """Writes the given weight variables as datasets of an HDF5 group.

    The variable names (UTF-8 encoded) are recorded in the group's
    `weight_names` attribute, and each value is stored in a dataset of that
    name.

    Args:
        f: HDF5 group to write into.
        weights: List of weight variables.
    """
    values = backend.batch_get_value(weights)
    encoded_names = [var.name.encode("utf8") for var in weights]
    save_attributes_to_hdf5_group(f, "weight_names", encoded_names)
    for encoded, value in zip(encoded_names, values):
        dataset = f.create_dataset(encoded, value.shape, dtype=value.dtype)
        if value.shape:
            dataset[:] = value
        else:
            # Scalar datasets cannot be sliced; assign via the empty index.
            dataset[()] = value
|
||
|
|
||
|
|
||
|
def save_weights_to_hdf5_group(f, model):
    """Writes all of a model's weights into an HDF5 group.

    Records the layer names, the backend name and the Keras version as
    attributes of `f`. Each layer's weights go in a subgroup named after the
    layer, and the model's own weights go in `top_level_model_weights`.

    Args:
        f: HDF5 group.
        model: Model instance.
    """
    from keras import __version__ as keras_version

    encoded_layer_names = [lyr.name.encode("utf8") for lyr in model.layers]
    save_attributes_to_hdf5_group(f, "layer_names", encoded_layer_names)
    f.attrs["backend"] = backend.backend().encode("utf8")
    f.attrs["keras_version"] = str(keras_version).encode("utf8")

    # Create the per-layer groups in name order so that group names grow
    # strictly, which avoids prefix issues between them.
    for lyr in sorted(model.layers, key=lambda item: item.name):
        save_subset_weights_to_hdf5_group(
            f.create_group(lyr.name), _legacy_weights(lyr)
        )

    own_weights = model._trainable_weights + model._non_trainable_weights
    save_subset_weights_to_hdf5_group(
        f.create_group("top_level_model_weights"), own_weights
    )
|
||
|
|
||
|
|
||
|
def load_subset_weights_from_hdf5_group(f):
    """Reads the weight values stored in an HDF5 group.

    Values are returned in the order listed by the group's `weight_names`
    attribute, each copied into a NumPy array.

    Args:
        f: A pointer to a HDF5 group.

    Returns:
        List of NumPy arrays of the weight values.
    """
    stored_names = load_attributes_from_hdf5_group(f, "weight_names")
    return [np.asarray(f[stored_name]) for stored_name in stored_names]
|
||
|
|
||
|
|
||
|
def load_weights_from_hdf5_group(f, model):
    """Loads saved weights into `model` by layer order (topological loading).

    Layers without weights are ignored on both sides: model layers for which
    `_legacy_weights` is empty, and saved layers whose `weight_names` are
    empty. The remaining layers are matched positionally. Every saved value
    list goes through `preprocess_weights_for_loading` before assignment.
    All assignments are then made in a single `backend.batch_set_value`
    call. Afterwards, `finalize_state()` is called on every layer returned
    by `model._flatten_layers()`.

    Args:
        f: A pointer to a HDF5 group.
        model: Model instance.

    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file.
    """

    def decoded_attr(key, missing):
        # Attributes may come back as bytes; decode those to str.
        if key not in f.attrs:
            return missing
        raw = f.attrs[key]
        return raw.decode("utf8") if hasattr(raw, "decode") else raw

    original_keras_version = decoded_attr("keras_version", "1")
    original_backend = decoded_attr("backend", None)

    filtered_layers = [lyr for lyr in model.layers if _legacy_weights(lyr)]

    layer_names = [
        saved_name
        for saved_name in load_attributes_from_hdf5_group(f, "layer_names")
        if load_attributes_from_hdf5_group(f[saved_name], "weight_names")
    ]
    if len(layer_names) != len(filtered_layers):
        raise ValueError(
            "Layer count mismatch when loading weights from file. "
            f"Model expected {len(filtered_layers)} layers, found "
            f"{len(layer_names)} saved layers."
        )

    # All assignments are collected and applied in one backend call, which
    # is faster in TensorFlow.
    assignments = []
    for k, (name, layer) in enumerate(zip(layer_names, filtered_layers)):
        symbolic_weights = _legacy_weights(layer)
        weight_values = preprocess_weights_for_loading(
            layer,
            load_subset_weights_from_hdf5_group(f[name]),
            original_keras_version,
            original_backend,
        )
        if len(weight_values) != len(symbolic_weights):
            raise ValueError(
                f"Weight count mismatch for layer #{k} (named {layer.name} in "
                f"the current model, {name} in the save file). "
                f"Layer expects {len(symbolic_weights)} weight(s). Received "
                f"{len(weight_values)} saved weight(s)"
            )
        assignments.extend(zip(symbolic_weights, weight_values))

    if "top_level_model_weights" in f:
        symbolic_weights = (
            model._trainable_weights + model._non_trainable_weights
        )
        weight_values = load_subset_weights_from_hdf5_group(
            f["top_level_model_weights"]
        )
        if len(weight_values) != len(symbolic_weights):
            raise ValueError(
                "Weight count mismatch for top-level weights when loading "
                "weights from file. "
                f"Model expects {len(symbolic_weights)} top-level weight(s). "
                f"Received {len(weight_values)} saved top-level weight(s)"
            )
        assignments.extend(zip(symbolic_weights, weight_values))

    backend.batch_set_value(assignments)

    # Let every layer run its own post-load finalization.
    for flat_layer in model._flatten_layers():
        flat_layer.finalize_state()
|
||
|
|
||
|
|
||
|
def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):
    """Implements name-based weight loading (instead of topological loading).

    Layers that have no matching name are skipped.

    Each saved layer group listed in the file's `layer_names` attribute is
    matched against the model's layers by `layer.name`. Every model layer
    carrying that name receives the saved values, each preprocessed
    independently from the raw values stored in the file. If the file has a
    `top_level_model_weights` group, its values are matched against
    `model._trainable_weights + model._non_trainable_weights`. All
    assignments are applied in a single `backend.batch_set_value` call, and
    afterwards `finalize_state()` is called on every layer returned by
    `model._flatten_layers()`.

    Args:
        f: A pointer to a HDF5 group.
        model: Model instance.
        skip_mismatch: Boolean, whether to skip loading of layers
            where there is a mismatch in the number of weights,
            or a mismatch in the shape of the weights. On a shape mismatch
            only the offending weight is skipped; the remaining weights of
            the same layer are still loaded.

    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file and skip_mismatch=False.
    """

    def _read_str_attr(key, default):
        # HDF5 attributes may be returned as bytes; normalize them to `str`.
        if key not in f.attrs:
            return default
        value = f.attrs[key]
        return value.decode("utf8") if hasattr(value, "decode") else value

    original_keras_version = _read_str_attr("keras_version", "1")
    original_backend = _read_str_attr("backend", None)

    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, "layer_names")

    # Reverse index of layer name to list of layers with name.
    index = {}
    for layer in model.layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = f[name]
        saved_values = load_subset_weights_from_hdf5_group(g)
        for layer in index.get(name, []):
            symbolic_weights = _legacy_weights(layer)
            # Preprocess into a fresh local so that several layers sharing
            # the same name each start from the raw saved values rather than
            # from the previous layer's preprocessed output.
            weight_values = preprocess_weights_for_loading(
                layer, saved_values, original_keras_version, original_backend
            )
            if len(weight_values) != len(symbolic_weights):
                if skip_mismatch:
                    logging.warning(
                        f"Skipping loading of weights for layer #{k} (named "
                        f"{layer.name}) due to mismatch in number of weights. "
                        f"Layer expects {len(symbolic_weights)} weight(s). "
                        f"Received {len(weight_values)} saved weight(s)"
                    )
                    continue
                raise ValueError(
                    f"Weight count mismatch for layer #{k} "
                    f"(named {layer.name}). "
                    f"Layer expects {len(symbolic_weights)} weight(s). "
                    f"Received {len(weight_values)} saved weight(s)"
                )
            # Set values. Lengths are equal here, so `zip` pairs every weight.
            for symbolic_weight, value in zip(symbolic_weights, weight_values):
                expected_shape = backend.int_shape(symbolic_weight)
                received_shape = value.shape
                if expected_shape == received_shape:
                    weight_value_tuples.append((symbolic_weight, value))
                    continue
                if skip_mismatch:
                    logging.warning(
                        f"Skipping loading weights for layer #{k} (named "
                        f"{layer.name}) due to mismatch in shape for "
                        f"weight {symbolic_weight.name}. "
                        f"Weight expects shape {expected_shape}. "
                        "Received saved weight "
                        f"with shape {received_shape}"
                    )
                    continue
                raise ValueError(
                    f"Shape mismatch in layer #{k} (named {layer.name}) "
                    f"for weight {symbolic_weight.name}. "
                    f"Weight expects shape {expected_shape}. "
                    "Received saved weight "
                    f"with shape {received_shape}"
                )

    if "top_level_model_weights" in f:
        symbolic_weights = (
            model._trainable_weights + model._non_trainable_weights
        )
        weight_values = load_subset_weights_from_hdf5_group(
            f["top_level_model_weights"]
        )

        if len(weight_values) != len(symbolic_weights):
            if skip_mismatch:
                logging.warning(
                    "Skipping loading top-level weights for model due to "
                    "mismatch in number of weights. "
                    f"Model expects {len(symbolic_weights)} "
                    "top-level weight(s). "
                    f"Received {len(weight_values)} saved top-level weight(s)"
                )
            else:
                raise ValueError(
                    "Weight count mismatch for top-level weights of model. "
                    f"Model expects {len(symbolic_weights)} "
                    "top-level weight(s). "
                    f"Received {len(weight_values)} saved top-level weight(s)"
                )
        else:
            for symbolic_weight, value in zip(symbolic_weights, weight_values):
                expected_shape = backend.int_shape(symbolic_weight)
                received_shape = value.shape
                if expected_shape == received_shape:
                    weight_value_tuples.append((symbolic_weight, value))
                elif skip_mismatch:
                    logging.warning(
                        "Skipping loading top-level weight for model due "
                        "to mismatch in shape for "
                        f"weight {symbolic_weight.name}. "
                        f"Weight expects shape {expected_shape}. "
                        "Received saved weight "
                        f"with shape {received_shape}"
                    )
                else:
                    raise ValueError(
                        "Shape mismatch in model for top-level weight "
                        f"{symbolic_weight.name}. "
                        f"Weight expects shape {expected_shape}. "
                        "Received saved weight "
                        f"with shape {received_shape}"
                    )

    backend.batch_set_value(weight_value_tuples)

    # Perform any layer defined finalization of the layer state.
    for layer in model._flatten_layers():
        layer.finalize_state()
|
||
|
|
||
|
|
||
|
def save_attributes_to_hdf5_group(group, name, data):
    """Saves attributes (data) of the specified name into the HDF5 group.

    This method deals with an inherent problem of HDF5 file which is not
    able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes. If the
    whole of `data` fits, it is stored as the single attribute `name`.
    Otherwise it is split with `np.array_split` into the smallest number of
    chunks that each fit, and the chunks are stored as `name0`, `name1`, ...

    Args:
        group: A pointer to a HDF5 group.
        name: A name of the attributes to save.
        data: Attributes data to store (an iterable of sized items, e.g. a
            list of byte strings).

    Raises:
        RuntimeError: If any single attribute is too large to be saved.
    """
    # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
    # because in that case even chunking the array would not make the saving
    # possible.
    bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]

    # Expecting this to never be true.
    if bad_attributes:
        raise RuntimeError(
            "The following attributes cannot be saved to HDF5 file because "
            f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
            f"bytes: {bad_attributes}"
        )

    data_npy = np.asarray(data)

    num_chunks = 1
    chunked_data = np.array_split(data_npy, num_chunks)

    # `len(x)` above counts items, which is characters for `str`. numpy,
    # however, stores strings at a fixed width per element: 4 bytes per
    # character for `str`, padded to the longest element. A single element
    # can therefore still exceed the limit. Once every chunk holds a single
    # element, further splitting cannot help, so fail instead of looping
    # forever.
    while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
        if num_chunks >= len(data_npy):
            oversized = [
                x
                for x in data
                if np.asarray(x).nbytes > HDF5_OBJECT_HEADER_LIMIT
            ]
            raise RuntimeError(
                "The following attributes cannot be saved to HDF5 file because "
                f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
                f"bytes: {oversized}"
            )
        num_chunks += 1
        chunked_data = np.array_split(data_npy, num_chunks)

    if num_chunks > 1:
        for chunk_id, chunk_data in enumerate(chunked_data):
            group.attrs["%s%d" % (name, chunk_id)] = chunk_data
    else:
        group.attrs[name] = data
|
||
|
|
||
|
|
||
|
def load_attributes_from_hdf5_group(group, name):
    """Reads the attribute list `name` from an HDF5 group.

    HDF5 cannot hold more than HDF5_OBJECT_HEADER_LIMIT bytes in a single
    attribute, so large lists may have been written in pieces named
    `name0`, `name1`, ... If `name` itself is present, it is read directly.
    Otherwise the numbered pieces are read in order, starting at index 0
    and stopping at the first missing index, and their items are
    concatenated.

    Args:
        group: A pointer to a HDF5 group.
        name: A name of the attributes to load.

    Returns:
        data: Attributes data, as a list. Items that expose `decode` (such as
        byte strings) are decoded as UTF-8; other items are returned as-is.
        An empty list is returned when neither `name` nor `name0` exists.
    """

    def _as_text(item):
        # Byte strings are normalized to `str`; anything else passes through.
        return item.decode("utf8") if hasattr(item, "decode") else item

    if name in group.attrs:
        return [_as_text(item) for item in group.attrs[name]]

    collected = []
    piece = 0
    piece_key = "%s%d" % (name, piece)
    while piece_key in group.attrs:
        collected.extend(_as_text(item) for item in group.attrs[piece_key])
        piece += 1
        piece_key = "%s%d" % (name, piece)
    return collected
|
||
|
|
||
|
|
||
|
def _legacy_weights(layer):
|
||
|
"""DO NOT USE.
|
||
|
|
||
|
For legacy reason, the layer.weights was in the order of
|
||
|
[self.trainable_weights + self.non_trainable_weights], and this order was
|
||
|
used for preserving the weights in h5 format. The new order of layer.weights
|
||
|
are the same as layer.get_weights() which is more intuitive for user. To
|
||
|
keep supporting the existing saved h5 file, this method should be used to
|
||
|
save/load weights. In future version, we will delete this method and
|
||
|
introduce a breaking change for h5 and stay with the new order for weights.
|
||
|
|
||
|
Args:
|
||
|
layer: a `tf.keras.Model` or `tf.keras.layers.Layer` instance.
|
||
|
|
||
|
Returns:
|
||
|
A list of variables with the order of trainable_weights, followed by
|
||
|
non_trainable_weights.
|
||
|
"""
|
||
|
weights = layer.trainable_weights + layer.non_trainable_weights
|
||
|
if any(not isinstance(w, tf.Variable) for w in weights):
|
||
|
raise NotImplementedError(
|
||
|
"Save or restore weights that is not an instance of `tf.Variable` "
|
||
|
"is not supported in h5, use `save_format='tf'` instead. Received "
|
||
|
f"a model or layer {layer.__class__.__name__} "
|
||
|
f"with weights {weights}"
|
||
|
)
|
||
|
return weights
|