216 lines
8.4 KiB
Python
216 lines
8.4 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Keras initializer serialization / deserialization."""
|
||
|
|
||
|
import threading
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras.initializers import initializers
|
||
|
from keras.initializers import initializers_v1
|
||
|
from keras.saving.legacy import serialization as legacy_serialization
|
||
|
from keras.utils import generic_utils
|
||
|
from keras.utils import tf_inspect as inspect
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python import tf2
|
||
|
from tensorflow.python.ops import init_ops
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
|
||
|
# Holder for the name -> initializer registry. `LOCAL.ALL_OBJECTS` (and its
# companion flag `LOCAL.GENERATED_WITH_V2`) act like module globals that get
# mutated at runtime, so they live on a `threading.local`: every thread builds
# and mutates its own copy instead of racing on a shared one.
LOCAL = threading.local()
|
||
|
|
||
|
|
||
|
def populate_deserializable_objects():
    """Populates `LOCAL.ALL_OBJECTS` with every built-in initializer.

    The mapping is cached per thread (`LOCAL` is a `threading.local`). It is
    rebuilt only when it is empty, or when it was generated under a different
    TF2-behavior setting than the current one. A rebuild is needed in that
    case because the TF1 and TF2 branches below register different objects.

    The new mapping is assembled in a local dict and published to `LOCAL`
    only once it is complete. A failure part-way through therefore never
    leaves a half-populated dict that later calls would mistake for a valid
    cache.
    """
    # Read the TF2 flag once so the cache check, the branch taken below and
    # the recorded `GENERATED_WITH_V2` value are guaranteed to agree.
    tf2_enabled = tf.__internal__.tf2.enabled()

    if (
        getattr(LOCAL, "ALL_OBJECTS", None)
        and getattr(LOCAL, "GENERATED_WITH_V2", None) == tf2_enabled
    ):
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    all_objects = {
        # Compatibility aliases (need to exist in both V1 and V2).
        "ConstantV2": initializers.Constant,
        "GlorotNormalV2": initializers.GlorotNormal,
        "GlorotUniformV2": initializers.GlorotUniform,
        "HeNormalV2": initializers.HeNormal,
        "HeUniformV2": initializers.HeUniform,
        "IdentityV2": initializers.Identity,
        "LecunNormalV2": initializers.LecunNormal,
        "LecunUniformV2": initializers.LecunUniform,
        "OnesV2": initializers.Ones,
        "OrthogonalV2": initializers.Orthogonal,
        "RandomNormalV2": initializers.RandomNormal,
        "RandomUniformV2": initializers.RandomUniform,
        "TruncatedNormalV2": initializers.TruncatedNormal,
        "VarianceScalingV2": initializers.VarianceScaling,
        "ZerosV2": initializers.Zeros,
        # Out of an abundance of caution we also include these aliases that
        # have a non-zero probability of having been included in saved
        # configs in the past.
        "glorot_normalV2": initializers.GlorotNormal,
        "glorot_uniformV2": initializers.GlorotUniform,
        "he_normalV2": initializers.HeNormal,
        "he_uniformV2": initializers.HeUniform,
        "lecun_normalV2": initializers.LecunNormal,
        "lecun_uniformV2": initializers.LecunUniform,
    }

    if tf2_enabled:
        # For V2, entries are generated automatically based on the content of
        # initializers.py.
        builtin_objs = {}
        base_cls = initializers.Initializer
        generic_utils.populate_dict_with_module_objects(
            builtin_objs,
            [initializers],
            obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls),
        )
    else:
        # V1 initializers.
        builtin_objs = {
            "Constant": tf.compat.v1.constant_initializer,
            "GlorotNormal": tf.compat.v1.glorot_normal_initializer,
            "GlorotUniform": tf.compat.v1.glorot_uniform_initializer,
            "Identity": tf.compat.v1.initializers.identity,
            "Ones": tf.compat.v1.ones_initializer,
            "Orthogonal": tf.compat.v1.orthogonal_initializer,
            "VarianceScaling": tf.compat.v1.variance_scaling_initializer,
            "Zeros": tf.compat.v1.zeros_initializer,
            "HeNormal": initializers_v1.HeNormal,
            "HeUniform": initializers_v1.HeUniform,
            "LecunNormal": initializers_v1.LecunNormal,
            "LecunUniform": initializers_v1.LecunUniform,
            "RandomNormal": initializers_v1.RandomNormal,
            "RandomUniform": initializers_v1.RandomUniform,
            "TruncatedNormal": initializers_v1.TruncatedNormal,
        }

    for key, value in builtin_objs.items():
        all_objects[key] = value
        # Functional aliases (e.g. "RandomNormal" -> "random_normal").
        all_objects[generic_utils.to_snake_case(key)] = value

    # More compatibility aliases.
    all_objects["normal"] = all_objects["random_normal"]
    all_objects["uniform"] = all_objects["random_uniform"]
    all_objects["one"] = all_objects["ones"]
    all_objects["zero"] = all_objects["zeros"]

    # Publish only after the mapping is complete, so the cache check above
    # can never accept a partially built dict.
    LOCAL.ALL_OBJECTS = all_objects
    LOCAL.GENERATED_WITH_V2 = tf2_enabled
|
||
|
|
||
|
|
||
|
# For backwards compatibility, expose every entry of `LOCAL.ALL_OBJECTS` as a
# module-level name (e.g. `Zeros`, `zeros`, `ZerosV2`). These names are a
# snapshot taken once, at import time, in the importing thread. We make no
# guarantees that they match the TF1/TF2 variant in effect when they are later
# used; `deserialize` re-populates the registry and should be preferred.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)
|
||
|
|
||
|
# Utility functions
|
||
|
|
||
|
|
||
|
@keras_export("keras.initializers.serialize")
def serialize(initializer, use_legacy_format=False):
    """Serializes an initializer via the Keras serialization utilities.

    Args:
        initializer: The initializer to serialize.
        use_legacy_format: Whether to request the legacy serialization format.
            Both settings currently produce the same result. The non-legacy
            path is slated to move to the new serialization library.

    Returns:
        The output of `legacy_serialization.serialize_keras_object` for
        `initializer`.
    """
    serialize_fn = legacy_serialization.serialize_keras_object
    if not use_legacy_format:
        # To be replaced by new serialization_lib
        return serialize_fn(initializer)
    return serialize_fn(initializer)
|
||
|
|
||
|
|
||
|
@keras_export("keras.initializers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Return an `Initializer` object from its config.

    Args:
        config: The serialized initializer: a config dict, or the string name
            of a built-in initializer.
        custom_objects: Optional mapping of names to custom objects. It is
            forwarded unchanged to the underlying deserializer.
        use_legacy_format: Whether to request the legacy serialization format.
            Both settings currently take the same code path. The non-legacy
            one is slated to move to the new serialization library.

    Returns:
        The deserialized initializer.
    """
    # Refresh the lookup table first so it matches the current TF1/TF2 mode.
    populate_deserializable_objects()
    lookup_kwargs = dict(
        module_objects=LOCAL.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="initializer",
    )
    if not use_legacy_format:
        # To be replaced by new serialization_lib
        return legacy_serialization.deserialize_keras_object(
            config, **lookup_kwargs
        )
    return legacy_serialization.deserialize_keras_object(
        config, **lookup_kwargs
    )
|
||
|
|
||
|
|
||
|
@keras_export("keras.initializers.get")
def get(identifier):
    """Retrieve a Keras initializer by the identifier.

    The `identifier` may be the string name of an initializer function or
    class (case-sensitive).

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.get(identifier)
    <...keras.initializers.initializers.Ones...>

    You can also specify the `config` of the initializer by passing a dict
    containing `class_name` and `config` as the identifier. Note that the
    `class_name` must map to a known initializer class. A dict without a
    `module` key is deserialized with the legacy format.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.get(cfg)
    <...keras.initializers.initializers.Ones...>

    If the `identifier` is a class, this method returns a new instance of the
    class created with its no-argument constructor. Any other callable (for
    example an already-built initializer instance) is returned unchanged.

    Args:
        identifier: `None`, a string name, a config dict (with `class_name`
            and `config` keys), an initializer class, or another callable.

    Returns:
        Initializer instance based on the input identifier, or `None` if
        `identifier` is `None`.

    Raises:
        ValueError: If the input identifier is not a supported type or is in
            a bad format.
    """

    if identifier is None:
        return None
    if isinstance(identifier, dict):
        # The presence of a "module" key is what distinguishes non-legacy
        # configs from legacy ones.
        use_legacy_format = "module" not in identifier
        return deserialize(identifier, use_legacy_format=use_legacy_format)
    elif isinstance(identifier, str):
        # Normalize `str` subclasses to a plain `str` before lookup.
        identifier = str(identifier)
        return deserialize(identifier)
    elif callable(identifier):
        if inspect.isclass(identifier):
            identifier = identifier()
        return identifier
    else:
        raise ValueError(
            "Could not interpret initializer identifier: " + str(identifier)
        )
|