148 lines
5.5 KiB
Python
148 lines
5.5 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Python wrappers for Datasets and Iterators."""
|
|
from tensorflow.python.data.ops import dataset_ops
|
|
from tensorflow.python.util import deprecation
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
@deprecation.deprecated(None, "Use `tf.data.Dataset.get_single_element()`.")
@tf_export("data.experimental.get_single_element")
def get_single_element(dataset):
  """Returns the single element of the `dataset` as a nested structure of tensors.

  The function enables you to use a `tf.data.Dataset` in a stateless
  "tensor-in tensor-out" expression, without creating an iterator.
  This facilitates the ease of data transformation on tensors using the
  optimized `tf.data.Dataset` abstraction on top of them.

  For example, let's consider a `preprocessing_fn` which would take as an
  input the raw features and returns the processed feature along with
  its label.

  ```python
  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  raw_features = ...  # input batch of BATCH_SIZE elements.
  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  processed_features = tf.data.experimental.get_single_element(dataset)
  ```

  In the above example, the `raw_features` tensor of length=BATCH_SIZE
  was converted to a `tf.data.Dataset`. Next, each of the `raw_feature` was
  mapped using the `preprocessing_fn` and the processed features were
  grouped into a single batch. The final `dataset` contains only one element
  which is a batch of all the processed features.

  NOTE: The `dataset` should contain only one element.

  Now, instead of creating an iterator for the `dataset` and retrieving the
  batch of features, the `tf.data.experimental.get_single_element()` function
  is used to skip the iterator creation process and directly output the batch
  of features.

  This can be particularly useful when your tensor transformations are
  expressed as `tf.data.Dataset` operations, and you want to use those
  transformations while serving your model.

  # Keras

  ```python

  model = ...  # A pre-built or custom model

  class PreprocessingModel(tf.keras.Model):
    def __init__(self, model):
      super().__init__()
      self.model = model

    @tf.function(input_signature=[...])
    def serving_fn(self, data):
      ds = tf.data.Dataset.from_tensor_slices(data)
      ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
      ds = ds.batch(batch_size=BATCH_SIZE)
      return tf.argmax(
          self.model(tf.data.experimental.get_single_element(ds)),
          axis=-1
      )

  preprocessing_model = PreprocessingModel(model)
  your_exported_model_dir = ...  # save the model to this path.
  tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                      signatures={'serving_default': preprocessing_model.serving_fn})
  ```

  # Estimator

  In the case of estimators, you need to generally define a `serving_input_fn`
  which would require the features to be processed by the model while
  inferencing.

  ```python
  def serving_input_fn():

    raw_feature_spec = ...  # Spec for the raw_features
    input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = input_fn()
    raw_features = serving_input_receiver.features

    def preprocessing_fn(raw_feature):
      # ... the raw_feature is preprocessed as per the use-case
      return feature

    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    processed_features = tf.data.experimental.get_single_element(dataset)

    # Please note that the value of `BATCH_SIZE` should be equal to
    # the size of the leading dimension of `raw_features`. This ensures
    # that `dataset` has only one element, which is a prerequisite for
    # using `tf.data.experimental.get_single_element(dataset)`.

    return tf.estimator.export.ServingInputReceiver(
        processed_features, serving_input_receiver.receiver_tensors)

  estimator = ...  # A pre-built or custom estimator
  estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
      one element.
  """
  # Reject non-dataset inputs up front so the caller gets a clear TypeError
  # rather than an attribute error from the delegated call below.
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError(
        "Invalid `dataset`. Expected a `tf.data.Dataset` object "
        f"but got {type(dataset)}.")

  # This wrapper only validates and forwards to the method on the dataset.
  return dataset.get_single_element()
|