# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import structure
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  For example,

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/cpu:0"))
  >>> for element in dataset:
  ...   print(f'Tensor {element} is on device {element.device}')
  Tensor 1 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 2 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 3 is on device /job:localhost/replica:0/task:0/device:CPU:0

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Copy to `device` first and prefetch second, so the prefetch buffer is
    # applied to the dataset produced by the copy.
    return dataset.apply(
        copy_to_device(target_device=device)).prefetch(buffer_size)

  return _apply_fn


@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Both device names are captured by the closure; the dataset is only
    # wrapped when the returned function is applied.
    return _CopyToDeviceDataset(
        dataset, target_device=target_device, source_device=source_device)

  return _apply_fn


# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device.

  The dataset is a generator dataset created under `target_device`. Its
  init / next / finalize functions are thin wrappers that issue a remote call
  to `source_device`, where an iterator over the input dataset is created,
  advanced and destroyed respectively.
  """

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
    self._target_device = target_device
    # `from_string` is called on the class; no throwaway instance is needed.
    spec = framework_device.DeviceSpec.from_string(self._target_device)
    # Remembered so that `make_one_shot_iterator` can reject GPU targets.
    self._is_gpu_target = (spec.device_type == "GPU")
    # The source device is kept both as a plain string (for `ops.device`) and
    # as a tensor (for the `target` argument of `remote_call`).
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @def_function.function()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # The handle is only produced after the iterator has been bound to the
      # dataset.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func.get_concrete_function()

    @def_function.function()
    def _remote_init_func():
      """Runs `_init_func` on the source device via a remote call."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func

      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      # NOTE(review): this `return` is placed outside the `ops.device` scope;
      # confirm the intended nesting against the canonical file.
      return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func.get_concrete_function()

    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      """Runs `_next_func` on the source device via a remote call."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func

      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # `ignore_lookup_error=True`: a missing resource is not treated as an
      # error when finalizing.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      """Runs `_finalize_func` on the source device via a remote call."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Register the three wrapper functions with the current default graph
    # before the generator dataset op that references them is created.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)

    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0 arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
  # GPU
  def make_one_shot_iterator(self):
    """Creates a one-shot iterator, unless the target device is a GPU.

    Returns:
      The result of the parent class's `make_one_shot_iterator()`.

    Raises:
      ValueError: If the target device type is "GPU".
    """
    if self._is_gpu_target:
      raise ValueError(
          "`make_one_shot_iterator` is not compatible with GPU execution. "
          "Please use `Dataset.make_initializable_iterator()` instead."
      )
    return super(_CopyToDeviceDataset, self).make_one_shot_iterator()


class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input using a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details.

    Args:
      input_dataset: The `Dataset` whose elements are passed to `map_func`.
      map_func: The function to apply to each element of `input_dataset`.
      use_inter_op_parallelism: Forwarded unchanged to the underlying
        `experimental_map_dataset` op.
    """
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    # `_map_func` must be set before `self._flat_structure` is read below:
    # `element_spec` on this class is derived from the wrapped function's
    # output structure.
    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
    variant_tensor = ged_ops.experimental_map_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **self._flat_structure)
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the list of wrapped functions used by this dataset."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The output structure of the wrapped `map_func`."""
    return self._map_func.output_structure

  def _transformation_name(self):
    """Returns the name passed to the function wrapper for this dataset."""
    return "map_on_gpu()"


def map_on_gpu(map_func):
  """Maps `map_func` across the elements of this dataset.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by the input dataset's `output_shapes` and
      `output_types`) to another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # The dataset is only wrapped when the returned function is applied.
    return _MapOnGpuDataset(dataset, map_func)

  return _apply_fn