289 lines
9.3 KiB
Python
289 lines
9.3 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
||
|
"""Layer serialization/deserialization functions."""
|
||
|
|
||
|
import threading
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras.engine import base_layer
|
||
|
from keras.engine import input_layer
|
||
|
from keras.engine import input_spec
|
||
|
from keras.layers import activation
|
||
|
from keras.layers import attention
|
||
|
from keras.layers import convolutional
|
||
|
from keras.layers import core
|
||
|
from keras.layers import locally_connected
|
||
|
from keras.layers import merging
|
||
|
from keras.layers import pooling
|
||
|
from keras.layers import regularization
|
||
|
from keras.layers import reshaping
|
||
|
from keras.layers import rnn
|
||
|
from keras.layers.normalization import batch_normalization
|
||
|
from keras.layers.normalization import batch_normalization_v1
|
||
|
from keras.layers.normalization import group_normalization
|
||
|
from keras.layers.normalization import layer_normalization
|
||
|
from keras.layers.normalization import unit_normalization
|
||
|
from keras.layers.preprocessing import category_encoding
|
||
|
from keras.layers.preprocessing import discretization
|
||
|
from keras.layers.preprocessing import hashed_crossing
|
||
|
from keras.layers.preprocessing import hashing
|
||
|
from keras.layers.preprocessing import image_preprocessing
|
||
|
from keras.layers.preprocessing import integer_lookup
|
||
|
from keras.layers.preprocessing import (
|
||
|
normalization as preprocessing_normalization,
|
||
|
)
|
||
|
from keras.layers.preprocessing import string_lookup
|
||
|
from keras.layers.preprocessing import text_vectorization
|
||
|
from keras.layers.rnn import cell_wrappers
|
||
|
from keras.layers.rnn import gru
|
||
|
from keras.layers.rnn import lstm
|
||
|
from keras.saving.legacy import serialization as legacy_serialization
|
||
|
from keras.saving.legacy.saved_model import json_utils
|
||
|
from keras.utils import generic_utils
|
||
|
from keras.utils import tf_inspect as inspect
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
|
||
|
# Modules scanned for `Layer` subclasses every time the registry of built-in
# layers is (re)built. The declaration order is significant and is kept
# exactly as is.
ALL_MODULES = (
    base_layer,
    input_layer,
    activation,
    attention,
    convolutional,
    core,
    locally_connected,
    merging,
    batch_normalization_v1,
    group_normalization,
    layer_normalization,
    unit_normalization,
    pooling,
    image_preprocessing,
    regularization,
    reshaping,
    rnn,
    hashing,
    hashed_crossing,
    category_encoding,
    discretization,
    integer_lookup,
    preprocessing_normalization,
    string_lookup,
    text_vectorization,
)

# Modules scanned after `ALL_MODULES`, and only when TF2 behavior is enabled,
# so that certain V1 objects get overwritten with their V2 versions.
ALL_V2_MODULES = (
    batch_normalization,
    layer_normalization,
    cell_wrappers,
    gru,
    lstm,
)

# The registry (`ALL_OBJECTS`) is a mutable global by design. Storing it on a
# thread-local object gives each thread its own copy, which avoids concurrent
# mutations.
LOCAL = threading.local()
|
||
|
|
||
|
|
||
|
def populate_deserializable_objects():
    """Populates the thread-local dict `LOCAL.ALL_OBJECTS` with built-in layers.

    The dict maps names to the objects used when deserializing layers. It is
    rebuilt when it is missing or empty for the calling thread, or when it was
    generated under a different TF2 setting than the one currently reported by
    `tf.__internal__.tf2.enabled()` (tracked in `LOCAL.GENERATED_WITH_V2`).
    Otherwise the call is a no-op.

    Besides the `Layer` subclasses found in `ALL_MODULES` (and, under TF2,
    `ALL_V2_MODULES`), the dict receives backward-compatibility aliases for
    `BatchNormalization`, the model classes, the feature-column layers and the
    functional versions of the merging layers.
    """
    # Read the TF2 flag exactly once so that the registry contents and the
    # `GENERATED_WITH_V2` tag recorded for them can never disagree.
    tf2_enabled = tf.__internal__.tf2.enabled()

    if (
        getattr(LOCAL, "ALL_OBJECTS", None)
        and LOCAL.GENERATED_WITH_V2 == tf2_enabled
    ):
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    # The fresh dict is published on LOCAL before it is filled (as before),
    # so the attributes exist as soon as population starts.
    LOCAL.ALL_OBJECTS = registry = {}
    LOCAL.GENERATED_WITH_V2 = tf2_enabled

    base_cls = base_layer.Layer

    def is_layer_class(obj):
        # Only classes deriving from the base `Layer` are registered.
        return inspect.isclass(obj) and issubclass(obj, base_cls)

    generic_utils.populate_dict_with_module_objects(
        registry,
        ALL_MODULES,
        obj_filter=is_layer_class,
    )

    # Overwrite certain V1 objects with V2 versions
    if tf2_enabled:
        generic_utils.populate_dict_with_module_objects(
            registry,
            ALL_V2_MODULES,
            obj_filter=is_layer_class,
        )

    # These deserialization aliases are added for backward compatibility,
    # as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2"
    # were used as class name for v1 and v2 version of BatchNormalization,
    # respectively. Here we explicitly convert them to their canonical names.
    registry["BatchNormalizationV1"] = batch_normalization_v1.BatchNormalization
    registry["BatchNormalizationV2"] = batch_normalization.BatchNormalization

    # Prevent circular dependencies: these modules are imported lazily, at
    # call time, rather than at the top of the file.
    from keras import models
    from keras.feature_column.sequence_feature_column import (
        SequenceFeatures,
    )
    from keras.premade_models.linear import (
        LinearModel,
    )
    from keras.premade_models.wide_deep import (
        WideDeepModel,
    )

    registry["Input"] = input_layer.Input
    registry["InputSpec"] = input_spec.InputSpec
    registry["Functional"] = models.Functional
    registry["Model"] = models.Model
    registry["SequenceFeatures"] = SequenceFeatures
    registry["Sequential"] = models.Sequential
    registry["LinearModel"] = LinearModel
    registry["WideDeepModel"] = WideDeepModel

    # `DenseFeatures` has separate V1 and V2 implementations; register the
    # one matching the TF2 setting the registry is being generated for.
    if tf2_enabled:
        from keras.feature_column.dense_features_v2 import (
            DenseFeatures,
        )
    else:
        from keras.feature_column.dense_features import (
            DenseFeatures,
        )
    registry["DenseFeatures"] = DenseFeatures

    # Merging layers, function versions.
    for merge_fn_name in (
        "add",
        "subtract",
        "multiply",
        "average",
        "maximum",
        "minimum",
        "concatenate",
        "dot",
    ):
        registry[merge_fn_name] = getattr(merging, merge_fn_name)
|
||
|
|
||
|
|
||
|
@keras_export("keras.layers.serialize")
def serialize(layer, use_legacy_format=False):
    """Returns a JSON-compatible representation of a `Layer` object.

    Args:
        layer: The `Layer` object to serialize.
        use_legacy_format: Boolean flag. Both values are currently handled by
            the legacy serializer, so it does not change the result.

    Returns:
        A JSON-serializable dict representing the object's config.

    Example:

    ```python
    from pprint import pprint
    model = tf.keras.models.Sequential()
    model.add(tf.keras.Input(shape=(16,)))
    model.add(tf.keras.layers.Dense(32, activation='relu'))

    pprint(tf.keras.layers.serialize(model))
    # prints the configuration of the model, as a dict.
    ```
    """
    # To be replaced by new serialization_lib. Until then the legacy and the
    # non-legacy paths are the same call, hence the single return.
    return legacy_serialization.serialize_keras_object(layer)
|
||
|
|
||
|
|
||
|
@keras_export("keras.layers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Builds a layer instance out of a config dictionary.

    Args:
        config: dict of the form {'class_name': str, 'config': dict}
        custom_objects: dict mapping class names (or function names) of custom
            (non-Keras) objects to class/functions
        use_legacy_format: Boolean flag. Both values are currently handled by
            the legacy deserializer, so it does not change the result.

    Returns:
        Layer instance (may be Model, Sequential, Network, Layer...)

    Example:

    ```python
    # Configuration of Dense(32, activation='relu')
    config = {
      'class_name': 'Dense',
      'config': {
        'activation': 'relu',
        'activity_regularizer': None,
        'bias_constraint': None,
        'bias_initializer': {'class_name': 'Zeros', 'config': {}},
        'bias_regularizer': None,
        'dtype': 'float32',
        'kernel_constraint': None,
        'kernel_initializer': {'class_name': 'GlorotUniform',
                               'config': {'seed': None}},
        'kernel_regularizer': None,
        'name': 'dense',
        'trainable': True,
        'units': 32,
        'use_bias': True
      }
    }
    dense_layer = tf.keras.layers.deserialize(config)
    ```
    """
    # Make sure the calling thread's registry of built-in layers is ready
    # before it is handed to the deserializer.
    populate_deserializable_objects()

    # To be replaced by new serialization_lib. Until then the legacy and the
    # non-legacy paths are the same call, hence the single return.
    return legacy_serialization.deserialize_keras_object(
        config,
        module_objects=LOCAL.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="layer",
    )
|
||
|
|
||
|
|
||
|
def get_builtin_layer(class_name):
    """Returns class if `class_name` is registered, else returns None.

    Args:
        class_name: Name to look up in the calling thread's registry of
            built-in layers (`LOCAL.ALL_OBJECTS`).

    Returns:
        The object registered under `class_name`, or `None` when the name is
        not in the registry.
    """
    # Always delegate to `populate_deserializable_objects()` instead of only
    # checking that the registry attribute exists: the call returns right
    # away when the registry is current, and rebuilds it when it is missing
    # or was generated under a different TF2 setting. This keeps lookups
    # consistent with `deserialize()` and `deserialize_from_json()`.
    populate_deserializable_objects()
    return LOCAL.ALL_OBJECTS.get(class_name)
|
||
|
|
||
|
|
||
|
def deserialize_from_json(json_string, custom_objects=None):
    """Builds a layer from its JSON string representation.

    Args:
        json_string: JSON string to decode.
        custom_objects: Optional custom objects, forwarded both to the JSON
            decoding step and to `deserialize`.

    Returns:
        The result of calling `deserialize` on the decoded config.
    """
    populate_deserializable_objects()

    # Name lookups during decoding use the built-in registry plus whatever
    # custom objects the caller supplied.
    lookup_kwargs = dict(
        module_objects=LOCAL.ALL_OBJECTS,
        custom_objects=custom_objects,
    )
    decoded_config = json_utils.decode_and_deserialize(
        json_string, **lookup_kwargs
    )
    return deserialize(decoded_config, custom_objects)
|