# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn for sequential input.

NOTE: This API is a work in progress and will likely be changing frequently.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf

from keras import backend
from keras.feature_column import base_feature_layer as kfc

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.experimental.SequenceFeatures")
class SequenceFeatures(kfc._BaseFeaturesLayer):
    """A layer for sequence input.

    All `feature_columns` must be sequence dense columns with the same
    `sequence_length`. The output of this layer can be fed into sequence
    networks, such as RNN.

    The output of this layer is a 3D `Tensor` of shape `[batch_size, T, D]`.
    `T` is the maximum sequence length for this batch, which could differ from
    batch to batch.

    If multiple `feature_columns` are given with `Di` `num_elements` each,
    their outputs are concatenated. So, the final `Tensor` has shape
    `[batch_size, T, D0 + D1 + ... + Dn]`.

    Example:

    ```python

    import tensorflow as tf

    # Behavior of some cells or feature columns may depend on whether we are in
    # training or inference mode, e.g. applying dropout.
    training = True
    rating = tf.feature_column.sequence_numeric_column('rating')
    watches = tf.feature_column.sequence_categorical_column_with_identity(
        'watches', num_buckets=1000)
    watches_embedding = tf.feature_column.embedding_column(watches,
                                                           dimension=10)
    columns = [rating, watches_embedding]

    features = {
        'rating': tf.sparse.from_dense([[1.0, 1.1, 0, 0, 0],
                                        [2.0, 2.1, 2.2, 2.3, 2.5]]),
        'watches': tf.sparse.from_dense([[2, 85, 0, 0, 0],
                                         [33, 78, 2, 73, 1]])
    }

    sequence_input_layer = tf.keras.experimental.SequenceFeatures(columns)
    sequence_input, sequence_length = sequence_input_layer(
        features, training=training)
    sequence_length_mask = tf.sequence_mask(sequence_length)
    hidden_size = 32
    rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
    # `return_state=True` makes the RNN layer return `[output, final_state]`,
    # which is what the two-name unpacking below expects.
    rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
    outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
    ```
    """

    def __init__(self, feature_columns, trainable=True, name=None, **kwargs):
        """Constructs a SequenceFeatures layer.

        Args:
          feature_columns: An iterable of dense sequence columns. Valid columns
            are
            - `embedding_column` that wraps a
              `sequence_categorical_column_with_*`
            - `sequence_numeric_column`.
          trainable: Boolean, whether the layer's variables will be updated via
            gradient descent during training.
          name: Name to give to the SequenceFeatures.
          **kwargs: Keyword arguments to construct a layer.

        Raises:
          ValueError: If any of the `feature_columns` is not a
            `SequenceDenseColumn`.
        """
        super().__init__(
            feature_columns=feature_columns,
            trainable=trainable,
            name=name,
            expected_column_type=tf.__internal__.feature_column.SequenceDenseColumn,  # noqa: E501
            **kwargs
        )

    @property
    def _is_feature_layer(self):
        """Always `True`; flags this layer as a feature layer.

        NOTE(review): the code that reads this flag is not in this file.
        """
        return True

    def _target_shape(self, input_shape, total_elements):
        """Returns the 3D target shape for a column's dense output.

        Keeps the first two entries of `input_shape` (per the class docstring,
        presumably batch size and `T`) and collapses everything after them
        into a single dimension of size `total_elements`.

        NOTE(review): presumably invoked by the base class while reshaping
        each column's dense tensor; the caller is not visible here.
        """
        return (input_shape[0], input_shape[1], total_elements)

    def call(self, features, training=None):
        """Returns sequence input corresponding to the `feature_columns`.

        Args:
          features: A dict mapping keys to tensors.
          training: Python boolean or None, indicating whether the layer is
            being run in training mode. This argument is passed to the call
            method of any `FeatureColumn` that takes a `training` argument. For
            example, if a `FeatureColumn` performed dropout, the column could
            expose a `training` argument to control whether the dropout should
            be applied. If `None`, defaults to
            `tf.keras.backend.learning_phase()`.

        Returns:
          An `(input_layer, sequence_length)` tuple where:
          - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
              `T` is the maximum sequence length for this batch, which could
              differ from batch to batch. `D` is the sum of `num_elements` for
              all `feature_columns`.
          - sequence_length: An int `Tensor` of shape `[batch_size]`. The
              sequence length for each example.

        Raises:
          ValueError: If features are not a dictionary.
        """
        if not isinstance(features, dict):
            # Format into a single message; passing `features` as a second
            # positional argument would make the exception message a tuple.
            raise ValueError(
                "We expected a dictionary here. Instead we got: {!r}".format(
                    features
                )
            )
        if training is None:
            training = backend.learning_phase()
        transformation_cache = (
            tf.__internal__.feature_column.FeatureTransformationCache(features)
        )
        output_tensors = []
        sequence_lengths = []

        for column in self._feature_columns:
            with backend.name_scope(column.name):
                try:
                    (
                        dense_tensor,
                        sequence_length,
                    ) = column.get_sequence_dense_tensor(
                        transformation_cache,
                        self._state_manager,
                        training=training,
                    )
                except TypeError:
                    # Fall back for columns whose `get_sequence_dense_tensor`
                    # does not accept a `training` keyword.
                    # NOTE(review): this also catches TypeErrors raised inside
                    # the method itself; the retry would then typically raise
                    # again without `training`.
                    (
                        dense_tensor,
                        sequence_length,
                    ) = column.get_sequence_dense_tensor(
                        transformation_cache, self._state_manager
                    )
                # Flattens the final dimension to produce a 3D Tensor.
                output_tensors.append(
                    self._process_dense_tensor(column, dense_tensor)
                )
                sequence_lengths.append(sequence_length)

        # Check and process sequence lengths.
        kfc._verify_static_batch_size_equality(
            sequence_lengths, self._feature_columns
        )
        sequence_length = _assert_all_equal_and_return(sequence_lengths)

        return self._verify_and_concat_tensors(output_tensors), sequence_length


def _assert_all_equal_and_return(tensors, name=None):
    """Asserts that all tensors are equal and returns the first one.

    Args:
      tensors: A non-empty list of tensors. An empty list raises `IndexError`
        when `tensors[0]` is accessed.
      name: Optional name scope; defaults to "assert_all_equal".

    Returns:
      `tensors[0]` itself when only one tensor is given. Otherwise an identity
      of `tensors[0]` created under control dependencies on one
      `tf.compat.v1.assert_equal(tensors[0], t)` op per remaining tensor.
    """
    with backend.name_scope(name or "assert_all_equal"):
        if len(tensors) == 1:
            return tensors[0]
        assert_equal_ops = [
            tf.compat.v1.assert_equal(tensors[0], t) for t in tensors[1:]
        ]
        with tf.control_dependencies(assert_equal_ops):
            return tf.identity(tensors[0])