# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets and Iterators."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated(None, "Use `tf.data.Dataset.get_single_element()`.")
@tf_export("data.experimental.get_single_element")
def get_single_element(dataset):
  """Returns the single element of the `dataset` as a nested structure of tensors.

  The function enables you to use a `tf.data.Dataset` in a stateless
  "tensor-in tensor-out" expression, without creating an iterator.
  This makes it easy to transform tensors using the optimized
  `tf.data.Dataset` abstraction.

  For example, let's consider a `preprocessing_fn` which takes the raw
  features as input and returns the processed feature along with its label.

  ```python
  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  raw_features = ...  # input batch of BATCH_SIZE elements.
  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  processed_features = tf.data.experimental.get_single_element(dataset)
  ```

  In the above example, the `raw_features` tensor of length=BATCH_SIZE was
  converted to a `tf.data.Dataset`. Next, each `raw_feature` was mapped using
  the `preprocessing_fn` and the processed features were grouped into a single
  batch. The final `dataset` contains only one element, which is a batch of all
  the processed features.

  NOTE: The `dataset` should contain only one element.

  Now, instead of creating an iterator for the `dataset` and retrieving the
  batch of features, the `tf.data.experimental.get_single_element()` function
  is used to skip the iterator creation process and directly output the batch
  of features.

  This can be particularly useful when your tensor transformations are
  expressed as `tf.data.Dataset` operations, and you want to use those
  transformations while serving your model.

  # Keras

  ```python

  model = ...  # A pre-built or custom model

  class PreprocessingModel(tf.keras.Model):
    def __init__(self, model):
      super().__init__()
      self.model = model

    @tf.function(input_signature=[...])
    def serving_fn(self, data):
      ds = tf.data.Dataset.from_tensor_slices(data)
      ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
      ds = ds.batch(batch_size=BATCH_SIZE)
      return tf.argmax(
          self.model(tf.data.experimental.get_single_element(ds)),
          axis=-1
      )

  preprocessing_model = PreprocessingModel(model)
  your_exported_model_dir = ...  # save the model to this path.
  tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                      signatures={'serving_default': preprocessing_model.serving_fn})
  ```

  # Estimator

  In the case of estimators, you generally need to define a `serving_input_fn`
  which provides the processed features that the model requires at inference
  time.

  ```python
  def serving_input_fn():

    raw_feature_spec = ...  # Spec for the raw_features
    input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = input_fn()
    raw_features = serving_input_receiver.features

    def preprocessing_fn(raw_feature):
      # ... the raw_feature is preprocessed as per the use-case
      return feature

    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    processed_features = tf.data.experimental.get_single_element(dataset)

    # Please note that the value of `BATCH_SIZE` should be equal to
    # the size of the leading dimension of `raw_features`. This ensures
    # that `dataset` has only one element, which is a pre-requisite for
    # using `tf.data.experimental.get_single_element(dataset)`.

    return tf.estimator.export.ServingInputReceiver(
        processed_features, serving_input_receiver.receiver_tensors)

  estimator = ...  # A pre-built or custom estimator
  your_exported_model_dir = ...  # save the model to this path.
  estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
      one element.
  """
  # Validate eagerly so callers get a descriptive TypeError instead of an
  # AttributeError from the `get_single_element()` call below.
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError(
        f"Invalid `dataset`. Expected a `tf.data.Dataset` object "
        f"but got {type(dataset)}.")

  # Delegate to the non-deprecated `tf.data.Dataset.get_single_element()`.
  return dataset.get_single_element()