# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FeatureColumn serialization, deserialization logic."""

import six

from tensorflow.python.feature_column import feature_column_v2_types as fc_types
from tensorflow.python.ops import init_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls

_FEATURE_COLUMN_DEPRECATION_WARNING = """\
    Warning: tf.feature_column is not recommended for new code. Instead,
    feature preprocessing can be done directly using either [Keras preprocessing
    layers](https://www.tensorflow.org/guide/migrate/migrating_feature_columns)
    or through the one-stop utility [`tf.keras.utils.FeatureSpace`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/FeatureSpace)
    built on top of them. See the [migration guide](https://tensorflow.org/guide/migrate)
    for details.
    """

_FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING = (
    'Use Keras preprocessing layers instead, either directly or via the '
    '`tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has '
    'a functional equivalent in `tf.keras.layers` for feature preprocessing '
    'when training a Keras model.')

# Classes that `deserialize_feature_column` can resolve by class name without
# the caller supplying `custom_objects`. Extended by `register_feature_column`.
_FEATURE_COLUMNS = [init_ops.TruncatedNormal]


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export(
    '__internal__.feature_column.serialize_feature_column',
    v1=[])
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  This method should only be used to serialize parent FeatureColumns when
  implementing FeatureColumn.get_config(), else serialize_feature_columns()
  is preferable.

  This serialization also keeps information of the FeatureColumn class, so
  deserialization is possible without knowing the class type. For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
      'class_name': 'NumericColumn',
      'config': {
          'key': 'price',
          'shape': (1,),
          'default_value': None,
          'dtype': 'float32',
          'normalizer_fn': None
      }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    Keras serialization for FeatureColumns, leaves string keys unaffected.

  Raises:
    ValueError if called with input that is not string or FeatureColumn.
  """
  if isinstance(fc, six.string_types):
    return fc
  elif isinstance(fc, fc_types.FeatureColumn):
    return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
  else:
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export('__internal__.feature_column.deserialize_feature_column', v1=[])
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication.

  Raises:
    ValueError if `config` has invalid format (e.g: expected keys missing,
    or refers to unknown classes).

  Returns:
    A FeatureColumn corresponding to the input `config`. String configs are
    returned unchanged.
  """
  # TODO(b/118939620): Simplify code if Keras utils support object deduping.
  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS
  }
  if columns_by_name is None:
    columns_by_name = {}

  (cls, cls_config) = _class_and_config_for_serialized_keras_object(
      config,
      module_objects=module_feature_column_classes,
      custom_objects=custom_objects,
      printable_module_name='feature_column_v2')

  # `custom_objects` may hold plain functions; `issubclass` would raise a
  # TypeError on those, so check that we actually resolved a class first.
  if not (tf_inspect.isclass(cls) and
          issubclass(cls, fc_types.FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(  # pylint: disable=protected-access
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)


def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Returns a list of Keras-style config dicts that represent the input
  FeatureColumns and can be used with `deserialize_feature_columns` for
  reconstructing the original columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    Keras serialization for the list of FeatureColumns.

  Raises:
    ValueError if called with input that is not a list of FeatureColumns.
  """
  return [serialize_feature_column(fc) for fc in feature_columns]


def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumns configs.

  Returns a list of FeatureColumns given a list of config dicts acquired by
  `serialize_feature_columns`.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    FeatureColumn objects corresponding to the input configs.

  Raises:
    ValueError if called with input that is not a list of FeatureColumns.
  """
  # One shared dict across the whole list so that columns with the same
  # deduping name resolve to the same instance.
  columns_by_name = {}
  return [
      deserialize_feature_column(c, custom_objects, columns_by_name)
      for c in configs
  ]


def _column_name_with_class_name(fc):
  """Returns a unique name for the feature column used during deduping.

  Without this two FeatureColumns that have the same name and where
  one wraps the other, such as an IndicatorColumn wrapping a
  SequenceCategoricalColumn, will fail to deserialize because they will have the
  same name in columns_by_name, causing the wrong column to be returned.

  Args:
    fc: A FeatureColumn.

  Returns:
    A unique name as a string.
  """
  return fc.__class__.__name__ + ':' + fc.name


def _serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Args:
    instance: The object to serialize; decorators are unwrapped first.

  Returns:
    None if `instance` is None; a `{'class_name', 'config'}` dict if the
    object has `get_config`; otherwise the object's `__name__`.

  Raises:
    ValueError: If the object has neither `get_config` nor `__name__`.
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  if hasattr(instance, 'get_config'):
    name = instance.__class__.__name__
    config = instance.get_config()
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, six.string_types):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes)
      try:
        serialized_item = _serialize_keras_object(item)
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          # Marker read back by
          # `_class_and_config_for_serialized_keras_object` to know this
          # nested dict must be deserialized into an object.
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        # Not serializable by this helper: keep the raw value.
        serialization_config[key] = item

    return {'class_name': name, 'config': serialization_config}
  if hasattr(instance, '__name__'):
    return instance.__name__
  raise ValueError('Cannot serialize {}'.format(instance))


def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: None, a config dict, a registered name, or a function.
    module_objects: Optional dict from name to object for built-in lookups.
    custom_objects: Optional dict from name to object; takes precedence over
      `module_objects`.
    printable_module_name: Label used in error messages.

  Returns:
    The deserialized object, or None if `identifier` is None.

  Raises:
    ValueError: If the identifier cannot be interpreted or names an unknown
      object.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # Hand over a copy so `from_config` cannot mutate the caller's dict.
        return cls.from_config(
            cls_config, custom_objects=dict(custom_objects or {}))
      return cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # in this case by convention `config` holds
      # the kwargs of the function.
      return cls(**cls_config)
  elif isinstance(identifier, six.string_types):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    else:
      # `module_objects` defaults to None; treat that as "nothing registered"
      # so an unknown name raises the ValueError below, not AttributeError.
      obj = (module_objects or {}).get(object_name)
      if obj is None:
        raise ValueError(
            'Unknown ' + printable_module_name + ': ' + object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))


def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Args:
    config: A dict with 'class_name' and 'config' keys.
    module_objects: Optional dict from name to object for built-in lookups.
    custom_objects: Optional dict from name to object; takes precedence over
      `module_objects`.
    printable_module_name: Label used in error messages.

  Returns:
    A `(cls, cls_config)` tuple. `cls_config` is a shallow copy of
    `config['config']` with passively-serialized objects and registered
    custom function names replaced by the actual objects; the input `config`
    is left untouched.

  Raises:
    ValueError: If `config` is malformed or the class name is unknown.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(
      class_name, custom_objects=custom_objects, module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  raw_config = config['config']

  deserialized_objects = {}
  for key, item in raw_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      deserialized_objects[key] = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        deserialized_objects[key] = registered

  # Apply the replacements to a copy so the caller's serialized config stays
  # intact (and JSON-serializable) after deserialization.
  cls_config = dict(raw_config)
  cls_config.update(deserialized_objects)

  return (cls, cls_config)


def _get_registered_object(name, custom_objects=None, module_objects=None):
  """Looks up `name` in `custom_objects`, then `module_objects`.

  Args:
    name: The registered name to resolve.
    custom_objects: Optional dict from name to object; checked first.
    module_objects: Optional dict from name to object; checked second.

  Returns:
    The registered object, or None if `name` is in neither dict.
  """
  if custom_objects and name in custom_objects:
    return custom_objects[name]
  elif module_objects and name in module_objects:
    return module_objects[name]
  return None


def register_feature_column(fc):
  """Decorator that registers a FeatureColumn for serialization."""
  _FEATURE_COLUMNS.append(fc)
  return fc