401 lines
17 KiB
Python
401 lines
17 KiB
Python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""The implementation of `tf.data.Dataset.from_generator`."""
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.python.data.ops import dataset_ops
|
|
from tensorflow.python.data.ops import structured_function
|
|
from tensorflow.python.data.util import nest
|
|
from tensorflow.python.data.util import structure
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.framework import tensor_spec
|
|
from tensorflow.python.framework import type_spec
|
|
from tensorflow.python.ops import gen_dataset_ops
|
|
from tensorflow.python.ops import script_ops
|
|
|
|
|
|
def _from_generator(generator, output_types, output_shapes, args,
|
|
output_signature, name):
|
|
"""Creates a `Dataset` whose elements are generated by `generator`.
|
|
|
|
Note: The current implementation of `Dataset.from_generator()` uses
|
|
`tf.numpy_function` and inherits the same constraints. In particular, it
|
|
requires the dataset and iterator related operations to be placed
|
|
on a device in the same process as the Python program that called
|
|
`Dataset.from_generator()`. In particular, using `from_generator` will
|
|
preclude the use of tf.data service for scaling out dataset processing.
|
|
The body of `generator` will not be serialized in a `GraphDef`, and you
|
|
should not use this method if you need to serialize your model and restore
|
|
it in a different environment.
|
|
|
|
The `generator` argument must be a callable object that returns
|
|
an object that supports the `iter()` protocol (e.g. a generator function).
|
|
|
|
The elements generated by `generator` must be compatible with either the
|
|
given `output_signature` argument or with the given `output_types` and
|
|
(optionally) `output_shapes` arguments, whichever was specified.
|
|
|
|
The recommended way to call `from_generator` is to use the
|
|
`output_signature` argument. In this case the output will be assumed to
|
|
consist of objects with the classes, shapes and types defined by
|
|
`tf.TypeSpec` objects from `output_signature` argument:
|
|
|
|
>>> def gen():
|
|
... ragged_tensor = tf.ragged.constant([[1, 2], [3]])
|
|
... yield 42, ragged_tensor
|
|
>>>
|
|
>>> dataset = tf.data.Dataset.from_generator(
|
|
... gen,
|
|
... output_signature=(
|
|
... tf.TensorSpec(shape=(), dtype=tf.int32),
|
|
... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
|
|
>>>
|
|
>>> list(dataset.take(1))
|
|
[(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
|
|
<tf.RaggedTensor [[1, 2], [3]]>)]
|
|
|
|
There is also a deprecated way to call `from_generator` by either with
|
|
`output_types` argument alone or together with `output_shapes` argument.
|
|
In this case the output of the function will be assumed to consist of
|
|
`tf.Tensor` objects with the types defined by `output_types` and with the
|
|
shapes which are either unknown or defined by `output_shapes`.
|
|
|
|
Note: If `generator` depends on mutable global variables or other external
|
|
state, be aware that the runtime may invoke `generator` multiple times
|
|
(in order to support repeating the `Dataset`) and at any time
|
|
between the call to `Dataset.from_generator()` and the production of the
|
|
first element from the generator. Mutating global variables or external
|
|
state can cause undefined behavior, and we recommend that you explicitly
|
|
cache any external state in `generator` before calling
|
|
`Dataset.from_generator()`.
|
|
|
|
Note: While the `output_signature` parameter makes it possible to yield
|
|
`Dataset` elements, the scope of `Dataset.from_generator()` should be
|
|
limited to logic that cannot be expressed through tf.data operations. Using
|
|
tf.data operations within the generator function is an anti-pattern and may
|
|
result in incremental memory growth.
|
|
|
|
Args:
|
|
generator: A callable object that returns an object that supports the
|
|
`iter()` protocol. If `args` is not specified, `generator` must take no
|
|
arguments; otherwise it must take as many arguments as there are values in
|
|
`args`.
|
|
output_types: (Optional.) A (nested) structure of `tf.DType` objects
|
|
corresponding to each component of an element yielded by `generator`.
|
|
output_shapes: (Optional.) A (nested) structure of `tf.TensorShape` objects
|
|
corresponding to each component of an element yielded by `generator`.
|
|
args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated and
|
|
passed to `generator` as NumPy-array arguments.
|
|
output_signature: (Optional.) A (nested) structure of `tf.TypeSpec` objects
|
|
corresponding to each component of an element yielded by `generator`.
|
|
name: (Optional.) A name for the tf.data operations used by
|
|
`from_generator`.
|
|
|
|
Returns:
|
|
Dataset: A `Dataset`.
|
|
"""
|
|
if not callable(generator):
|
|
raise TypeError("`generator` must be a Python callable.")
|
|
|
|
if output_signature is not None:
|
|
if output_types is not None:
|
|
raise TypeError("The `output_types` argument can not be used together "
|
|
"with the `output_signature` argument.")
|
|
if output_shapes is not None:
|
|
raise TypeError("The `output_shapes` argument can not be used together "
|
|
"with the `output_signature` argument.")
|
|
for spec in nest.flatten(output_signature):
|
|
if not isinstance(spec, type_spec.TypeSpec):
|
|
raise TypeError(f"`output_signature` must contain objects that are "
|
|
f"subclass of `tf.TypeSpec` but found {type(spec)} "
|
|
f"which is not.")
|
|
else:
|
|
if output_types is None:
|
|
raise TypeError("To specify the output signature you need to provide "
|
|
"either the `output_signature` argument or the "
|
|
"`output_types` argument.")
|
|
|
|
if output_signature is None:
|
|
if output_shapes is None:
|
|
output_shapes = nest.map_structure(
|
|
lambda _: tensor_shape.TensorShape(None), output_types)
|
|
else:
|
|
output_shapes = nest.map_structure_up_to(output_types,
|
|
tensor_shape.as_shape,
|
|
output_shapes)
|
|
output_signature = nest.map_structure_up_to(output_types,
|
|
tensor_spec.TensorSpec,
|
|
output_shapes, output_types)
|
|
if all(
|
|
isinstance(x, tensor_spec.TensorSpec)
|
|
for x in nest.flatten(output_signature)):
|
|
output_types = nest.pack_sequence_as(
|
|
output_signature, [x.dtype for x in nest.flatten(output_signature)])
|
|
output_shapes = nest.pack_sequence_as(
|
|
output_signature, [x.shape for x in nest.flatten(output_signature)])
|
|
|
|
if args is None:
|
|
args = ()
|
|
else:
|
|
args = tuple(ops.convert_n_to_tensor(args, name="args"))
|
|
|
|
generator_state = dataset_ops.DatasetV2._GeneratorState(generator) # pylint: disable=protected-access
|
|
|
|
def get_iterator_id_fn(unused_dummy):
|
|
"""Creates a unique `iterator_id` for each pass over the dataset.
|
|
|
|
The returned `iterator_id` disambiguates between multiple concurrently
|
|
existing iterators.
|
|
|
|
Args:
|
|
unused_dummy: Ignored value.
|
|
|
|
Returns:
|
|
A `tf.int64` tensor whose value uniquely identifies an iterator in
|
|
`generator_state`.
|
|
"""
|
|
return script_ops.numpy_function(generator_state.get_next_id, args,
|
|
dtypes.int64)
|
|
|
|
def generator_next_fn(iterator_id_t):
|
|
"""Generates the next element from iterator with ID `iterator_id_t`.
|
|
|
|
We map this function across an infinite repetition of the
|
|
`iterator_id_t`, and raise `StopIteration` to terminate the iteration.
|
|
|
|
Args:
|
|
iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
|
|
iterator in `generator_state` from which to generate an element.
|
|
|
|
Returns:
|
|
The next element to generate from the iterator.
|
|
"""
|
|
if output_types and output_shapes:
|
|
flattened_types = [
|
|
dtypes.as_dtype(dt) for dt in nest.flatten(output_types)
|
|
]
|
|
flattened_shapes = nest.flatten(output_shapes)
|
|
|
|
def generator_py_func(iterator_id):
|
|
"""A `py_func` that will be called to invoke the iterator."""
|
|
# `next()` raises `StopIteration` when there are no more
|
|
# elements remaining to be generated.
|
|
values = next(generator_state.get_iterator(iterator_id))
|
|
|
|
# Use the same _convert function from the py_func() implementation to
|
|
# convert the returned values to arrays early, so that we can inspect
|
|
# their values.
|
|
try:
|
|
flattened_values = nest.flatten_up_to(output_types, values)
|
|
except (TypeError, ValueError) as e:
|
|
raise TypeError(
|
|
f"`generator` yielded an element that did not match the "
|
|
f"expected structure. The expected structure was "
|
|
f"{output_types}, but the yielded element was {values}.") from e
|
|
ret_arrays = []
|
|
for ret, dtype in zip(flattened_values, flattened_types):
|
|
try:
|
|
ret_arrays.append(
|
|
script_ops.FuncRegistry._convert( # pylint: disable=protected-access
|
|
ret,
|
|
dtype=dtype.as_numpy_dtype))
|
|
except (TypeError, ValueError) as e:
|
|
raise TypeError(
|
|
f"`generator` yielded an element that could not be "
|
|
f"converted to the expected type. The expected type was "
|
|
f"{dtype.name}, but the yielded element was {ret}.") from e
|
|
|
|
# Additional type and shape checking to ensure that the components of
|
|
# the generated element match the `output_types` and `output_shapes`
|
|
# arguments.
|
|
for (ret_array, expected_dtype,
|
|
expected_shape) in zip(ret_arrays, flattened_types,
|
|
flattened_shapes):
|
|
if ret_array.dtype != expected_dtype.as_numpy_dtype:
|
|
raise TypeError(
|
|
f"`generator` yielded an element of type {ret_array.dtype} "
|
|
f"where an element of type {expected_dtype.as_numpy_dtype} "
|
|
f"was expected.")
|
|
if not expected_shape.is_compatible_with(ret_array.shape):
|
|
raise TypeError(
|
|
f"`generator` yielded an element of shape {ret_array.shape} "
|
|
f"where an element of shape {expected_shape} was expected.")
|
|
|
|
return ret_arrays
|
|
|
|
flat_values = script_ops.numpy_function(generator_py_func,
|
|
[iterator_id_t], flattened_types)
|
|
|
|
# In debug mode the numpy_function will return a scalar if
|
|
# generator_py_func produces only a single value.
|
|
if not isinstance(flat_values, (list, tuple)):
|
|
flat_values = [flat_values]
|
|
|
|
# The `py_func()` op drops the inferred shapes, so we add them back in
|
|
# here.
|
|
if output_shapes is not None:
|
|
for ret_t, shape in zip(flat_values, flattened_shapes):
|
|
ret_t.set_shape(shape)
|
|
|
|
return nest.pack_sequence_as(output_types, flat_values)
|
|
else:
|
|
flat_output_types = structure.get_flat_tensor_types(output_signature)
|
|
|
|
def generator_py_func(iterator_id):
|
|
"""A `py_func` that will be called to invoke the iterator."""
|
|
# `next()` raises `StopIteration` when there are no more
|
|
# elements remaining to be generated.
|
|
values = next(generator_state.get_iterator(iterator_id.numpy()))
|
|
|
|
try:
|
|
values = structure.normalize_element(values, output_signature)
|
|
except (TypeError, ValueError) as e:
|
|
raise TypeError(
|
|
f"`generator` yielded an element that did not match the "
|
|
f"expected structure. The expected structure was "
|
|
f"{output_signature}, but the yielded element was "
|
|
f"{values}.") from e
|
|
|
|
values_spec = structure.type_spec_from_value(values)
|
|
|
|
if not structure.are_compatible(values_spec, output_signature):
|
|
raise TypeError(
|
|
f"`generator` yielded an element of {values_spec} where an "
|
|
f"element of {output_signature} was expected.")
|
|
|
|
return structure.to_tensor_list(output_signature, values)
|
|
|
|
return script_ops.eager_py_func(
|
|
generator_py_func, inp=[iterator_id_t], Tout=flat_output_types)
|
|
|
|
def finalize_fn(iterator_id_t):
|
|
"""Releases host-side state for the iterator with ID `iterator_id_t`."""
|
|
|
|
def finalize_py_func(iterator_id):
|
|
generator_state.iterator_completed(iterator_id)
|
|
# We return a dummy value so that the `finalize_fn` has a valid
|
|
# signature.
|
|
# NOTE(mrry): Explicitly create an array of `np.int64` because implicit
|
|
# casting in `py_func()` will create an array of `np.int32` on Windows,
|
|
# leading to a runtime error.
|
|
return np.array(0, dtype=np.int64)
|
|
|
|
return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
|
|
dtypes.int64)
|
|
|
|
# This function associates each traversal of `generator` with a unique
|
|
# iterator ID.
|
|
def flat_map_fn(dummy_arg):
|
|
# The `get_iterator_id_fn` gets a unique ID for the current instance of
|
|
# of the generator.
|
|
# The `generator_next_fn` gets the next element from the iterator with the
|
|
# given ID, and raises StopIteration when that iterator contains no
|
|
# more elements.
|
|
return _GeneratorDataset(
|
|
dummy_arg,
|
|
get_iterator_id_fn,
|
|
generator_next_fn,
|
|
finalize_fn,
|
|
output_signature,
|
|
name=name)
|
|
|
|
# A single-element dataset that, each time it is evaluated, contains a
|
|
# freshly-generated and unique (for the returned dataset) int64
|
|
# ID that will be used to identify the appropriate Python state, which
|
|
# is encapsulated in `generator_state`, and captured in
|
|
# `get_iterator_id_map_fn`.
|
|
dummy = 0
|
|
id_dataset = dataset_ops.Dataset.from_tensors(dummy, name=name)
|
|
|
|
# A dataset that contains all of the elements generated by a
|
|
# single iterator created from `generator`, identified by the
|
|
# iterator ID contained in `id_dataset`. Lifting the iteration
|
|
# into a flat_map here enables multiple repetitions and/or nested
|
|
# versions of the returned dataset to be created, because it forces
|
|
# the generation of a new ID for each version.
|
|
return id_dataset.flat_map(flat_map_fn, name=name)
|
|
|
|
|
|
class _GeneratorDataset(dataset_ops.DatasetSource):
  """A `Dataset` whose elements are produced by invoking a function."""

  def __init__(self,
               init_args,
               init_func,
               next_func,
               finalize_func,
               output_signature,
               name=None):
    """Constructs a `_GeneratorDataset`.

    Args:
      init_args: A (nested) structure representing the arguments to `init_func`.
      init_func: A TensorFlow function that will be called on `init_args` each
        time a C++ iterator over this dataset is constructed. Returns a (nested)
        structure representing the "state" of the dataset.
      next_func: A TensorFlow function that will be called on the result of
        `init_func` to produce each element, and that raises `OutOfRangeError`
        to terminate iteration.
      finalize_func: A TensorFlow function that will be called on the result of
        `init_func` immediately before a C++ iterator over this dataset is
        destroyed. The return value is ignored.
      output_signature: A (nested) structure of `tf.TypeSpec` objects describing
        the output of `next_func`.
      name: Optional. A name for the tf.data transformation.
    """

    def wrap(func, input_structure):
      # All three user functions are wrapped the same way; only the input
      # structure they are traced against differs.
      return structured_function.StructuredFunctionWrapper(
          func,
          self._transformation_name(),
          input_structure=input_structure)

    self._init_args = init_args
    self._init_structure = structure.type_spec_from_value(init_args)

    # `init_func` is wrapped first: its output structure is the input
    # structure of both `next_func` and `finalize_func`.
    self._init_func = wrap(init_func, self._init_structure)
    state_structure = self._init_func.output_structure
    self._next_func = wrap(next_func, state_structure)
    self._finalize_func = wrap(finalize_func, state_structure)

    self._output_signature = output_signature
    self._name = name

    # Inputs to the init function: the flattened `init_args` followed by
    # whatever tensors the traced init function captured.
    init_inputs = (
        structure.to_tensor_list(self._init_structure, self._init_args) +
        self._init_func.function.captured_inputs)
    variant_tensor = gen_dataset_ops.generator_dataset(
        init_inputs,
        self._next_func.function.captured_inputs,
        self._finalize_func.function.captured_inputs,
        init_func=self._init_func.function,
        next_func=self._next_func.function,
        finalize_func=self._finalize_func.function,
        **self._common_args)
    super().__init__(variant_tensor)

  @property
  def element_spec(self):
    """The `output_signature` this dataset was constructed with."""
    return self._output_signature

  def _transformation_name(self):
    """Returns the name passed to the function wrappers for this dataset."""
    return "Dataset.from_generator()"
|