3RNN/Lib/site-packages/tensorflow/python/feature_column/serialization.py
2024-05-26 19:49:15 +02:00

354 lines
13 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FeatureColumn serialization, deserialization logic."""
import six
from tensorflow.python.feature_column import feature_column_v2_types as fc_types
from tensorflow.python.ops import init_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
# Markdown notice passed to `doc_controls.header` on the exported
# serialize/deserialize functions below.
_FEATURE_COLUMN_DEPRECATION_WARNING = """\
Warning: tf.feature_column is not recommended for new code. Instead,
feature preprocessing can be done directly using either [Keras preprocessing
layers](https://www.tensorflow.org/guide/migrate/migrating_feature_columns)
or through the one-stop utility [`tf.keras.utils.FeatureSpace`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/FeatureSpace)
built on top of them. See the [migration guide](https://tensorflow.org/guide/migrate)
for details.
"""
# Instructions string passed to `deprecation.deprecated` on
# `serialize_feature_column`.
_FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING = (
    'Use Keras preprocessing layers instead, either directly or via the '
    '`tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has '
    'a functional equivalent in `tf.keras.layers` for feature preprocessing '
    'when training a Keras model.')
# Registry of classes that `deserialize_feature_column` resolves by
# `__name__`. `register_feature_column` appends to it.
# NOTE(review): TruncatedNormal is presumably pre-registered so serialized
# initializer configs can be resolved by class name; confirm with callers.
_FEATURE_COLUMNS = [init_ops.TruncatedNormal]
@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export(
    '__internal__.feature_column.serialize_feature_column',
    v1=[])
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  Intended for serializing parent FeatureColumns inside an implementation of
  `FeatureColumn.get_config()`; elsewhere, prefer
  `serialize_feature_columns()`.

  Unlike `fc.get_config()`, the result also records the FeatureColumn's class
  name, so it can be deserialized without knowing the class in advance. For
  example, given:

    a = numeric_column('price')

  `a.get_config()` returns:

    {
        'key': 'price',
        'shape': (1,),
        'default_value': None,
        'dtype': 'float32',
        'normalizer_fn': None
    }

  while `serialize_feature_column(a)` returns:

    {
        'class_name': 'NumericColumn',
        'config': {
            'key': 'price',
            'shape': (1,),
            'default_value': None,
            'dtype': 'float32',
            'normalizer_fn': None
        }
    }

  Args:
    fc: A FeatureColumn or a raw feature key string.

  Returns:
    A `{'class_name': ..., 'config': ...}` dict for FeatureColumns; string
    keys are returned unchanged.

  Raises:
    ValueError: If `fc` is neither a string nor a FeatureColumn.
  """
  if isinstance(fc, six.string_types):
    # Raw feature keys need no wrapping.
    return fc
  if not isinstance(fc, fc_types.FeatureColumn):
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
  return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export('__internal__.feature_column.deserialize_feature_column', v1=[])
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable. Returns a FeatureColumn for this config.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. Newly deserialized columns are added to it.

  Raises:
    ValueError if `config` has invalid format (e.g: expected keys missing,
    or refers to unknown classes, or resolves to something that is not a
    FeatureColumn class).

  Returns:
    A FeatureColumn corresponding to the input `config`, or `config` itself
    if it is a string. If a column with the same class name and name is
    already present in `columns_by_name`, that existing instance is returned.
  """
  # TODO(b/118939620): Simplify code if Keras utils support object deduping.
  if isinstance(config, six.string_types):
    # Raw string keys round-trip unchanged (see serialize_feature_column).
    return config
  # A dict from class_name to class for all registered FeatureColumns.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS
  }
  if columns_by_name is None:
    columns_by_name = {}

  (cls, cls_config) = _class_and_config_for_serialized_keras_object(
      config,
      module_objects=module_feature_column_classes,
      custom_objects=custom_objects,
      printable_module_name='feature_column_v2')

  # `cls` may be any registered object, e.g. a function supplied through
  # custom_objects. issubclass() raises TypeError for non-classes, so check
  # for a class first to keep the documented ValueError contract.
  if not (isinstance(cls, type) and issubclass(cls, fc_types.FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)
  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)
def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Each element is serialized with `serialize_feature_column`. The resulting
  list of Keras-style config dicts can be passed to
  `deserialize_feature_columns` to rebuild the original columns.

  Args:
    feature_columns: An iterable of FeatureColumns (raw string keys are
      passed through unchanged).

  Returns:
    A list with the serialization of each element, in input order.

  Raises:
    ValueError: If an element is neither a FeatureColumn nor a string.
  """
  return list(map(serialize_feature_column, feature_columns))
def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumn configs.

  Inverse of `serialize_feature_columns`. Each config dict is turned back
  into a FeatureColumn, and raw string keys are returned as-is.

  Args:
    configs: A list of dicts with the serialization of feature columns
      acquired by `serialize_feature_columns`.
    custom_objects: A dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    A list of FeatureColumn objects, in the same order as `configs`.

  Raises:
    ValueError: Propagated from `deserialize_feature_column` when a config is
      malformed or refers to an unknown class.
  """
  # One dedup table is shared across the whole list. Columns with the same
  # class name and name therefore deserialize to a single shared instance.
  shared_columns_by_name = {}
  deserialized = []
  for single_config in configs:
    deserialized.append(
        deserialize_feature_column(single_config, custom_objects,
                                   shared_columns_by_name))
  return deserialized
def _column_name_with_class_name(fc):
"""Returns a unique name for the feature column used during deduping.
Without this two FeatureColumns that have the same name and where
one wraps the other, such as an IndicatorColumn wrapping a
SequenceCategoricalColumn, will fail to deserialize because they will have the
same name in columns_by_name, causing the wrong column to be returned.
Args:
fc: A FeatureColumn.
Returns:
A unique name as a string.
"""
return fc.__class__.__name__ + ':' + fc.name
def _serialize_keras_object(instance):
  """Serializes a Keras-style object into a JSON-compatible representation.

  Args:
    instance: The object to serialize. It is first unwrapped with
      `tf_decorator.unwrap`.

  Returns:
    - `None` if the unwrapped object is `None`.
    - For objects exposing `get_config()`: a dict
      `{'class_name': ..., 'config': ...}`. Each non-string config value is
      serialized recursively when possible. Results that are dicts produced
      from non-dict values are tagged with `'__passive_serialization__': True`.
      Values that cannot be serialized are kept as-is.
    - Otherwise, the object's `__name__` (e.g. for plain functions).

  Raises:
    ValueError: If the object has neither `get_config` nor `__name__`.
  """
  _, target = tf_decorator.unwrap(instance)
  if target is None:
    return None

  if not hasattr(target, 'get_config'):
    # Objects without a config (e.g. functions) are stored by name.
    if hasattr(target, '__name__'):
      return target.__name__
    raise ValueError('Cannot serialize', target)

  class_name = target.__class__.__name__
  encoded_config = {}
  for field, value in target.get_config().items():
    if not isinstance(value, six.string_types):
      # Any object of a different type needs to be converted to string or
      # dict for serialization (e.g. custom functions, custom classes).
      try:
        encoded = _serialize_keras_object(value)
      except ValueError:
        # Not serializable here; keep the raw value.
        encoded = value
      else:
        if isinstance(encoded, dict) and not isinstance(value, dict):
          encoded['__passive_serialization__'] = True
      value = encoded
    encoded_config[field] = value
  return {'class_name': class_name, 'config': encoded_config}
def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: One of `None`; a config dict with `class_name` and `config`
      keys; a string naming an entry of `custom_objects` or `module_objects`;
      or a function, which is returned unchanged.
    module_objects: Optional dict from name to class/function. It is consulted
      for names not found in `custom_objects`.
    custom_objects: Optional dict from name to class/function, consulted
      first.
    printable_module_name: Human-readable kind of object, used in error
      messages.

  Returns:
    The deserialized object, or `None` if `identifier` is `None`. Classes
    named by a string are instantiated with no arguments.

  Raises:
    ValueError: If `identifier` cannot be interpreted or names an unknown
      object.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
        identifier, module_objects, custom_objects, printable_module_name)
    if not hasattr(cls, 'from_config'):
      # Then `cls` may be a function returning a class; by convention
      # `cls_config` holds the kwargs of the function.
      return cls(**cls_config)
    arg_spec = tf_inspect.getfullargspec(cls.from_config)
    if 'custom_objects' in arg_spec.args:
      # Hand over a shallow copy (presumably so from_config cannot mutate
      # the caller's dict).
      return cls.from_config(
          cls_config, custom_objects=dict(custom_objects or {}))
    return cls.from_config(cls_config)

  if isinstance(identifier, six.string_types):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif module_objects:
      obj = module_objects.get(object_name)
    else:
      # module_objects defaults to None; guard so an unknown name is reported
      # as ValueError below rather than crashing on None.get().
      obj = None
    if obj is None:
      raise ValueError('Unknown ' + printable_module_name + ': ' +
                       object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj

  if tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier

  raise ValueError('Could not interpret serialized %s: %s' %
                   (printable_module_name, identifier))
def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Args:
    config: A dict with `class_name` and `config` keys.
    module_objects: Optional dict from name to class. It resolves
      `class_name` and nested passive-serialized items.
    custom_objects: Optional dict from name to class/function; takes
      precedence over `module_objects`.
    printable_module_name: Human-readable kind of object, used in error
      messages.

  Returns:
    A `(cls, cls_config)` tuple. `cls_config` is a shallow copy of
    `config['config']` with some values replaced:
    - nested dicts marked with `'__passive_serialization__'` become their
      deserialized objects;
    - strings naming a function in `custom_objects` become that function.
    The input `config` is not modified.

  Raises:
    ValueError: If `config` is malformed or `class_name` is unknown.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(
      class_name, custom_objects=custom_objects, module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  # Work on a shallow copy. Replacing nested items below (and anything
  # from_config does later) must not leak live objects into the caller's
  # config.
  cls_config = dict(config['config'])
  deserialized_objects = {}
  for key, item in cls_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      deserialized_objects[key] = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the
      # custom objects during deserialization, we convert the string back to
      # the original function.
      # Note that a potential issue is that a string field could have a
      # naming conflict with a custom function name, but this should be a
      # rare case. This issue does not occur if a string field has a naming
      # conflict with a custom object, since the config of an object will
      # always be a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        deserialized_objects[key] = registered
  cls_config.update(deserialized_objects)
  return (cls, cls_config)
def _get_registered_object(name, custom_objects=None, module_objects=None):
if custom_objects and name in custom_objects:
return custom_objects[name]
elif module_objects and name in module_objects:
return module_objects[name]
return None
def register_feature_column(fc):
  """Class decorator that makes `fc` resolvable by deserialize_feature_column.

  Appends `fc` to the module-level `_FEATURE_COLUMNS` registry.
  `deserialize_feature_column` looks up entries of that registry by
  `__name__`.

  Args:
    fc: The FeatureColumn class to register.

  Returns:
    `fc`, unmodified, so this can be used as a decorator.
  """
  _FEATURE_COLUMNS.extend((fc,))
  return fc