238 lines
8.0 KiB
Python
238 lines
8.0 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Python utilities required by Keras."""
|
||
|
|
||
|
import inspect
|
||
|
import threading
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
|
||
|
# Global registry: registered name (e.g. "package>name") -> object.
_GLOBAL_CUSTOM_OBJECTS = dict()
# Reverse global registry: object -> its registered name.
_GLOBAL_CUSTOM_NAMES = dict()
# Thread-local custom objects set by custom_object_scope. Entries are kept in
# this object's `__dict__`, so every thread sees its own independent set.
_THREAD_LOCAL_CUSTOM_OBJECTS = threading.local()
|
||
|
|
||
|
|
||
|
@keras_export(
    "keras.saving.custom_object_scope",
    "keras.utils.custom_object_scope",
    "keras.utils.CustomObjectScope",
)
class CustomObjectScope:
    """Exposes custom classes/functions to Keras deserialization internals.

    Under a scope `with custom_object_scope(objects_dict)`, Keras methods such
    as `tf.keras.models.load_model` or `tf.keras.models.model_from_config`
    will be able to deserialize any custom object referenced by a
    saved config (e.g. a custom layer or metric).

    Scopes may be nested: on exit, the thread-local custom objects are restored
    to exactly what they were on entry. The objects are stored in a
    `threading.local`, so a scope entered in one thread is not visible from
    other threads.

    Example:

    Consider a custom regularizer `my_regularizer`:

    ```python
    layer = Dense(3, kernel_regularizer=my_regularizer)
    # Config contains a reference to `my_regularizer`
    config = layer.get_config()
    ...
    # Later:
    with custom_object_scope({'my_regularizer': my_regularizer}):
        layer = Dense.from_config(config)
    ```

    Args:
        *args: Dictionary or dictionaries of `{name: object}` pairs. When the
            same name appears in several dictionaries, the later one wins.
    """

    def __init__(self, *args):
        self.custom_objects = args
        self.backup = None

    def __enter__(self):
        """Installs the custom objects for the current thread.

        Raises:
            TypeError / ValueError: if one of the arguments cannot be used to
                update a dict. In that case the thread-local state is left
                unchanged.
        """
        self.backup = _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__.copy()
        # Merge into a local dict first so that a bad argument raises before
        # any thread-local state is modified. `with` does not call `__exit__`
        # when `__enter__` raises, so a partial update would otherwise leak
        # past the scope.
        merged = {}
        for objects in self.custom_objects:
            merged.update(objects)
        _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__.update(merged)
        return self

    def __exit__(self, *args, **kwargs):
        """Restores the thread-local custom objects saved by `__enter__`."""
        # Clear-then-update (rather than rebinding `__dict__`) restores the
        # exact pre-scope contents, which makes nested scopes unwind correctly.
        _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__.clear()
        _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__.update(self.backup)
|
||
|
|
||
|
|
||
|
@keras_export(
    "keras.saving.get_custom_objects", "keras.utils.get_custom_objects"
)
def get_custom_objects():
    """Retrieves a live reference to the global dictionary of custom objects.

    The returned dict is the registry itself, not a copy, so mutating it
    changes which objects Keras can resolve by name.

    Custom objects set using `custom_object_scope` are not added to the
    global dictionary of custom objects, and will not appear in the returned
    dictionary.

    Example:

    ```python
    get_custom_objects().clear()
    get_custom_objects()['MyObject'] = MyObject
    ```

    Returns:
        Global dictionary mapping registered class names to classes.
    """
    return _GLOBAL_CUSTOM_OBJECTS
|
||
|
|
||
|
|
||
|
@keras_export(
    "keras.saving.register_keras_serializable",
    "keras.utils.register_keras_serializable",
)
def register_keras_serializable(package="Custom", name=None):
    """Registers an object with the Keras serialization framework.

    This decorator injects the decorated class or function into the Keras custom
    object dictionary, so that it can be serialized and deserialized without
    needing an entry in the user-provided custom object dict. It also injects a
    function that Keras will call to get the object's serializable string key.

    Note that to be serialized and deserialized, classes must implement the
    `get_config()` method. Functions do not have this requirement.

    The object will be registered under the key 'package>name', where `name`
    defaults to the object name if not passed.

    Example:

    ```python
    # Note that `'my_package'` is used as the `package` argument here, and since
    # the `name` argument is not provided, `'MyDense'` is used as the `name`.
    @keras.saving.register_keras_serializable('my_package')
    class MyDense(keras.layers.Dense):
        pass

    assert keras.saving.get_registered_object('my_package>MyDense') == MyDense
    assert keras.saving.get_registered_name(MyDense) == 'my_package>MyDense'
    ```

    Args:
        package: The package that this class belongs to. This is used for the
            `key` (which is `"package>name"`) to identify the class. Note that
            this is the first argument passed into the decorator.
        name: The name to serialize this class under in this package. If not
            provided or `None`, the class' name will be used (note that this is
            the case when the decorator is used with only one argument, which
            becomes the `package`).

    Returns:
        A decorator that registers the decorated class with the passed names.

    Raises:
        ValueError (from the returned decorator): if the decorated object is a
            class without a `get_config` attribute, if the key
            `"package>name"` is already registered, or if the object itself is
            already registered under some name.
    """

    def decorator(arg):
        """Registers a class or function with the Keras serialization framework.

        Returns `arg` unchanged, so the decorated name still refers to the
        original object.
        """
        class_name = name if name is not None else arg.__name__
        registered_name = package + ">" + class_name

        if inspect.isclass(arg) and not hasattr(arg, "get_config"):
            raise ValueError(
                "Cannot register a class that does not have a "
                "get_config() method."
            )

        # Both uniqueness checks run before either registry is written, so a
        # rejected registration leaves the forward and reverse maps consistent.
        if registered_name in _GLOBAL_CUSTOM_OBJECTS:
            raise ValueError(
                f"{registered_name} has already been registered to "
                f"{_GLOBAL_CUSTOM_OBJECTS[registered_name]}"
            )

        if arg in _GLOBAL_CUSTOM_NAMES:
            raise ValueError(
                f"{arg} has already been registered to "
                f"{_GLOBAL_CUSTOM_NAMES[arg]}"
            )
        _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
        _GLOBAL_CUSTOM_NAMES[arg] = registered_name

        return arg

    return decorator
|
||
|
|
||
|
|
||
|
@keras_export(
    "keras.saving.get_registered_name", "keras.utils.get_registered_name"
)
def get_registered_name(obj):
    """Returns the name registered to an object within the Keras framework.

    Part of the Keras serialization/deserialization framework: it maps an
    object back to the string key it was registered under, so the object can
    be referred to by name in a saved config.

    Args:
        obj: The object to look up.

    Returns:
        The registered name (e.g. `"package>name"`) if `obj` was registered;
        otherwise the object's plain Python `__name__`.
    """
    if obj not in _GLOBAL_CUSTOM_NAMES:
        # Unregistered objects fall back to their ordinary Python name.
        return obj.__name__
    return _GLOBAL_CUSTOM_NAMES[obj]
|
||
|
|
||
|
|
||
|
@keras_export(
    "keras.saving.get_registered_object", "keras.utils.get_registered_object"
)
def get_registered_object(name, custom_objects=None, module_objects=None):
    """Returns the class associated with `name` if it is registered with Keras.

    Part of the Keras serialization/deserialization framework: it resolves a
    string key to the object associated with it.

    Sources are searched in this order, and the first match wins:
      1. objects exposed by an active `custom_object_scope` in this thread;
      2. the global registry (e.g. filled by `register_keras_serializable`);
      3. `custom_objects`;
      4. `module_objects`.

    Example:

    ```python
    def from_config(cls, config, custom_objects=None):
        if 'my_custom_object_name' in config:
            config['hidden_cls'] = tf.keras.saving.get_registered_object(
                config['my_custom_object_name'], custom_objects=custom_objects)
    ```

    Args:
        name: The name to look up.
        custom_objects: Optional dictionary of custom objects to look the name
            up in. Generally, custom_objects is provided by the user.
        module_objects: Optional dictionary of objects to look the name up in.
            Generally, module_objects is provided by midlevel library
            implementers.

    Returns:
        The object registered under `name` (typically an instantiable class),
        or `None` if no source contains it.
    """
    # Registries that always exist: thread-local scope first, then global.
    always_present = (
        _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__,
        _GLOBAL_CUSTOM_OBJECTS,
    )
    for registry in always_present:
        if name in registry:
            return registry[name]

    # Caller-supplied dictionaries; falsy values (None, {}) are skipped.
    for registry in (custom_objects, module_objects):
        if registry and name in registry:
            return registry[name]

    return None
|
||
|
|
||
|
|
||
|
# Aliases
# `custom_object_scope` is the snake_case spelling used in docs and examples;
# it is the very same class object as `CustomObjectScope`.
custom_object_scope = CustomObjectScope
|