Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/initializers/__init__.py

216 lines
8.4 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializer serialization / deserialization."""
import threading
import tensorflow.compat.v2 as tf
from keras.initializers import initializers
from keras.initializers import initializers_v1
from keras.saving.legacy import serialization as legacy_serialization
from keras.utils import generic_utils
from keras.utils import tf_inspect as inspect
# isort: off
from tensorflow.python import tf2
from tensorflow.python.ops import init_ops
from tensorflow.python.util.tf_export import keras_export
# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
#
# `populate_deserializable_objects()` lazily creates two attributes on LOCAL:
#   - ALL_OBJECTS: dict mapping initializer names/aliases to the registered
#     objects.
#   - GENERATED_WITH_V2: the value of `tf.__internal__.tf2.enabled()` at the
#     time ALL_OBJECTS was generated (None until the first generation).
LOCAL = threading.local()
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in initializer.

    The registry lives on the thread-local `LOCAL` as `LOCAL.ALL_OBJECTS` and
    is tagged with `LOCAL.GENERATED_WITH_V2`, the value of
    `tf.__internal__.tf2.enabled()` it was generated for. If a non-empty
    registry with a matching tag already exists, the call is a no-op;
    otherwise the registry is rebuilt from scratch.

    The registry contains, in order of insertion:
      * `...V2` compatibility aliases pointing at the `initializers` classes;
      * either the `Initializer` subclasses collected from the `initializers`
        module (TF2 enabled) or a fixed table of V1 initializers (TF2
        disabled), each under its own name and its snake_case alias;
      * the short aliases `normal`, `uniform`, `one` and `zero`.
    """
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    # Read the flag exactly once so the cache check, the branch taken below
    # and the tag recorded at the end can never disagree with each other.
    v2_enabled = tf.__internal__.tf2.enabled()

    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    # Build into a local dict and publish it only once it is complete, so a
    # failure part-way through cannot leave a partial registry that is
    # already tagged as valid for this TF version.
    objects = {
        # Compatibility aliases (need to exist in both V1 and V2).
        "ConstantV2": initializers.Constant,
        "GlorotNormalV2": initializers.GlorotNormal,
        "GlorotUniformV2": initializers.GlorotUniform,
        "HeNormalV2": initializers.HeNormal,
        "HeUniformV2": initializers.HeUniform,
        "IdentityV2": initializers.Identity,
        "LecunNormalV2": initializers.LecunNormal,
        "LecunUniformV2": initializers.LecunUniform,
        "OnesV2": initializers.Ones,
        "OrthogonalV2": initializers.Orthogonal,
        "RandomNormalV2": initializers.RandomNormal,
        "RandomUniformV2": initializers.RandomUniform,
        "TruncatedNormalV2": initializers.TruncatedNormal,
        "VarianceScalingV2": initializers.VarianceScaling,
        "ZerosV2": initializers.Zeros,
        # Out of an abundance of caution we also include these aliases that
        # have a non-zero probability of having been included in saved
        # configs in the past.
        "glorot_normalV2": initializers.GlorotNormal,
        "glorot_uniformV2": initializers.GlorotUniform,
        "he_normalV2": initializers.HeNormal,
        "he_uniformV2": initializers.HeUniform,
        "lecun_normalV2": initializers.LecunNormal,
        "lecun_uniformV2": initializers.LecunUniform,
    }

    if v2_enabled:
        # For V2, entries are generated automatically based on the content of
        # initializers.py.
        versioned_objs = {}
        base_cls = initializers.Initializer
        generic_utils.populate_dict_with_module_objects(
            versioned_objs,
            [initializers],
            obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls),
        )
    else:
        # V1 initializers.
        versioned_objs = {
            "Constant": tf.compat.v1.constant_initializer,
            "GlorotNormal": tf.compat.v1.glorot_normal_initializer,
            "GlorotUniform": tf.compat.v1.glorot_uniform_initializer,
            "Identity": tf.compat.v1.initializers.identity,
            "Ones": tf.compat.v1.ones_initializer,
            "Orthogonal": tf.compat.v1.orthogonal_initializer,
            "VarianceScaling": tf.compat.v1.variance_scaling_initializer,
            "Zeros": tf.compat.v1.zeros_initializer,
            "HeNormal": initializers_v1.HeNormal,
            "HeUniform": initializers_v1.HeUniform,
            "LecunNormal": initializers_v1.LecunNormal,
            "LecunUniform": initializers_v1.LecunUniform,
            "RandomNormal": initializers_v1.RandomNormal,
            "RandomUniform": initializers_v1.RandomUniform,
            "TruncatedNormal": initializers_v1.TruncatedNormal,
        }

    for key, value in versioned_objs.items():
        objects[key] = value
        # Functional aliases.
        objects[generic_utils.to_snake_case(key)] = value

    # More compatibility aliases.
    objects["normal"] = objects["random_normal"]
    objects["uniform"] = objects["random_uniform"]
    objects["one"] = objects["ones"]
    objects["zero"] = objects["zeros"]

    # Publish the finished registry together with the version it was built
    # for.
    LOCAL.ALL_OBJECTS = objects
    LOCAL.GENERATED_WITH_V2 = v2_enabled
# For backwards compatibility, we populate this file with the objects
# from ALL_OBJECTS. We make no guarantees as to whether these objects will
# be using their correct version.
# The registry is generated once at import time (for the importing thread)
# and every entry is then exposed as a module-level name.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)
# Utility functions
@keras_export("keras.initializers.serialize")
def serialize(initializer, use_legacy_format=False):
    """Serialize an initializer with the legacy Keras serializer.

    Args:
        initializer: The object handed to
            `legacy_serialization.serialize_keras_object`.
        use_legacy_format: Accepted for API compatibility. Both values
            currently take the same legacy code path.

    Returns:
        Whatever `legacy_serialization.serialize_keras_object` returns for
        `initializer`.
    """
    # The non-legacy path is to be replaced by new serialization_lib; until
    # then `use_legacy_format` does not change the result.
    serialize_fn = legacy_serialization.serialize_keras_object
    return serialize_fn(initializer)
@keras_export("keras.initializers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Return an `Initializer` object from its config.

    Args:
        config: The configuration (or name) handed to
            `legacy_serialization.deserialize_keras_object`.
        custom_objects: Optional extra objects forwarded to the deserializer
            alongside the built-in registry.
        use_legacy_format: Accepted for API compatibility. Both values
            currently take the same legacy code path.

    Returns:
        Whatever `legacy_serialization.deserialize_keras_object` returns for
        `config`.
    """
    # Make sure this thread's registry exists and matches the active TF
    # version before it is used for lookups.
    populate_deserializable_objects()

    lookup_kwargs = dict(
        module_objects=LOCAL.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="initializer",
    )
    # The non-legacy path is to be replaced by new serialization_lib; until
    # then `use_legacy_format` does not change the result.
    return legacy_serialization.deserialize_keras_object(
        config, **lookup_kwargs
    )
@keras_export("keras.initializers.get")
def get(identifier):
    """Retrieve a Keras initializer by the identifier.

    The `identifier` may be the string name of an initializer function or
    class (case-sensitive).

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.get(identifier)
    <...keras.initializers.initializers.Ones...>

    You can also pass a dict containing `class_name` and `config` as the
    identifier. Note that the `class_name` must map to an `Initializer`
    class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.get(cfg)
    <...keras.initializers.initializers.Ones...>

    If the `identifier` is a class, a new instance created by calling its
    constructor without arguments is returned. Any other callable is
    returned unchanged.

    Args:
        identifier: `None`, a string, a dict, or a callable that identifies
            the initializer.

    Returns:
        `None` if `identifier` is `None`; otherwise the initializer based on
        the input identifier.

    Raises:
        ValueError: If the input identifier is not a supported type or in a
            bad format.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # Configs without a "module" key are treated as legacy-format.
        return deserialize(
            identifier, use_legacy_format="module" not in identifier
        )

    if isinstance(identifier, str):
        return deserialize(str(identifier))

    if not callable(identifier):
        raise ValueError(
            f"Could not interpret initializer identifier: {identifier!s}"
        )

    # Classes are instantiated; other callables are passed through as-is.
    return identifier() if inspect.isclass(identifier) else identifier