# Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/data/experimental/ops/distribute.py
# 2023-06-19 00:49:18 +02:00
#
# 399 lines
# 15 KiB
# Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distribution Strategy-related dataset transformations."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export
# Public sentinel constant, exported below as `tf.data.experimental.SHARD_HINT`.
# NOTE(review): the meaning of the value -1 is not established in this file —
# presumably it is a placeholder shard count resolved later by a distribution
# strategy; confirm against the public API docs before relying on it.
SHARD_HINT = -1
tf_export("data.experimental.SHARD_HINT").export_constant(
    __name__, "SHARD_HINT")
class _AutoShardDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that shards its input automatically across workers.

  Sharding is performed with graph rewrites. The strategy is selected by the
  `experimental_distribute.auto_shard_policy` option of the input dataset:

  * FILE: walk up the dataset graph until a reader dataset is found, then
    insert a ShardDataset op in front of it so that each worker only sees a
    subset of the files.
  * DATA: insert a ShardDataset op at the end of the input pipeline, ahead of
    a terminal PrefetchDataset if there is one. Any RebatchDatasetV2 in the
    pipeline is rewritten to the legacy RebatchDataset for correctness,
    because RebatchDatasetV2 is incompatible with data sharding.
  * AUTO: attempt file-based sharding; if no reader dataset can be found,
    fall back to data-based sharding.
  * OFF: leave the dataset unchanged.

  Attributes:
    num_workers: Total number of workers to shard this dataset across.
    index: The current worker index (out of the total number of workers) this
      dataset is for.
    num_replicas: The total number of replicas across all workers. Only used
      when sharding by data (DATA or AUTO), to rewrite RebatchDatasetV2 into
      RebatchDataset.

  Raises:
    NotFoundError: If no suitable reader dataset can be found to begin
      automatically sharding the dataset.
  """

  def __init__(self, input_dataset, num_workers, index, num_replicas=None):
    self._input_dataset = input_dataset
    # Sharding does not change the structure of individual elements.
    self._element_spec = input_dataset.element_spec
    upstream_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    # The policy enum is handed to the op as a plain integer.
    policy = int(
        input_dataset.options().experimental_distribute.auto_shard_policy)
    variant_tensor = ged_ops.auto_shard_dataset(
        upstream_variant,
        num_workers=num_workers,
        index=index,
        auto_shard_policy=policy,
        num_replicas=num_replicas,
        **self._flat_structure)
    super().__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec
def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None):  # pylint: disable=invalid-name
  # TF1-compatible factory: builds an `_AutoShardDataset` and wraps it in a
  # `DatasetV1Adapter` so it can be consumed through the V1 `Dataset` API.
  sharded = _AutoShardDataset(
      input_dataset, num_workers, index, num_replicas=num_replicas)
  return dataset_ops.DatasetV1Adapter(sharded)
class _LegacyRebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that divides its input batches into `num_replicas` sub-batches.

  For each batch in the input dataset, _LegacyRebatchDataset will produce
  `num_replicas` smaller batches whose sizes add up to the original batch size.

  For example:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _LegacyRebatchDataset(ds, num_replicas=3)
  for elem in ds:
    print(elem)
  >> [0, 1], [2, 3], [], [4, 5], [6, 7], []
  ```
  """

  def __init__(self, input_dataset, num_replicas):
    """Creates a _LegacyRebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      num_replicas: A `tf.int64` scalar (or Python integer), representing the
        number of sub-batches to split each batch from `input_dataset` into.

    Raises:
      ValueError: If a component of `input_dataset`'s elements has a known
        rank of 0, i.e. the dataset has not been batched.
    """
    # Statically known replica count, or None when `num_replicas` is a
    # symbolic tensor whose value is not known at trace time. Comparing a
    # symbolic tensor inside a Python `and` (as the batch-size check below
    # would otherwise do) raises, because it cannot be converted to a bool.
    static_num_replicas = tensor_util.constant_value(num_replicas)

    def recalculate_batch_size(type_spec):
      """Returns the per-replica batch size for `type_spec`, or None if unknown.

      The result is known only when the global batch size and the number of
      replicas are both static and the former divides evenly by the latter.
      """
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      if not isinstance(output_shape, tensor_shape.TensorShape):
        return None
      # If the output shape is unknown, we set the batch dimension to unknown.
      if output_shape.rank is None:
        return None
      if len(output_shape) < 1:
        raise ValueError(
            "Invalid `input_dataset`. Expected a dataset whose elements "
            "have rank >= 1 but found a dataset whose elements are scalars. "
            "Fix the issue by adding the `batch` transformation to the "
            "dataset.")
      global_batch = output_shape.dims[0].value
      if global_batch is None or static_num_replicas is None:
        return None
      replicas = int(static_num_replicas)
      if global_batch % replicas == 0:
        return global_batch // replicas
      # Set the batch dimension to unknown. If the global batch size does not
      # divide num_replicas evenly, the minibatches may have different sizes.
      return None

    def rebatch(type_spec):
      # pylint: disable=protected-access
      batch_size = recalculate_batch_size(type_spec)
      return type_spec._unbatch()._batch(batch_size)
      # pylint: enable=protected-access

    self._element_spec = nest.map_structure(
        rebatch, dataset_ops.get_structure(input_dataset))

    # auto_shard rewrite assumes that there's normalize_to_dense before
    # rebatch_dataset.
    # LINT.IfChange
    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_replicas=num_replicas,
        **self._flat_structure)
    # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
    super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec
class _RemoteDataset(dataset_ops.DatasetSource):
  """A dataset source rebuilt from a serialized graph def on a given device."""

  def __init__(self, graph_def, device, element_spec):
    # The element structure is supplied by the caller rather than recovered
    # from `graph_def`.
    self._elem_spec = element_spec
    # Only the deserializing op is placed under the requested device.
    with ops.device(device):
      remote_variant = ged_ops.dataset_from_graph(graph_def)
    super().__init__(remote_variant)

  @property
  def element_spec(self):
    return self._elem_spec
def replicate(dataset, devices):
  """Replicates `dataset` onto each device in `devices`.

  Args:
    dataset: A `tf.data.Dataset` object.
    devices: A list of devices to replicate the dataset on.

  Returns:
    A dictionary mapping device name to a dataset on that device. If
    `devices` holds exactly one device and it is the device `dataset` already
    lives on, `dataset` itself is returned under that key.

  Raises:
    TypeError: If `dataset` is not a `tf.data.Dataset`.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError(
        f"Invalid `dataset`. Expected a `tf.data.Dataset` object but "
        f"got {type(dataset)}.")

  # pylint: disable=protected-access
  home_device = dataset._variant_tensor.device
  # Nothing to copy when the only target is where the dataset already is.
  if len(devices) == 1 and devices[0] == home_device:
    return {devices[0]: dataset}

  # Serialize the pipeline next to the original dataset, stripping device
  # assignments so each copy can be placed on its own target device.
  with ops.colocate_with(dataset._variant_tensor):
    dataset = dataset._apply_debug_options()
    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=ExternalStatePolicy.WARN)

  return {
      target: _RemoteDataset(graph_def, target, dataset.element_spec)
      for target in devices
  }
def batch_sizes_for_worker(global_batch_size, num_workers,
                           num_replicas_per_worker, worker_index):
  """Determines how to rebatch a dataset for the given worker.

  Given the global batch size, number of workers, number of replicas per
  worker, and worker index, returns the batch sizes for rebatching a dataset
  on worker `worker_index` of `num_workers`, such that each global step
  (across all workers and replicas) consumes `global_batch_size` elements.
  The returned value should be passed as the `batch_sizes` input parameter to
  `tf.data.experimental.rebatch()`. The returned batch sizes meet the
  following constraints:

  Let G = global_batch_size, W = num_workers, R = num_replicas_per_worker

  (A) for any worker, len(batch_sizes) = W * R
  (B) for any worker, sum(batch_sizes) == G
  (C) for any global step (i.e. R iterations on each worker), the sum of
      batches consumed by replicas across all workers is G.
  (D) any two batch sizes of any two replicas differ by at most one.

  For example, suppose we have G = 7, W = 2, R = 2, and suppose we have two
  files which each contain 7 elements (the pipeline below is schematic):

  ```python
  # WORKER 0
  batch_sizes_0 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=0)
  print(batch_sizes_0)
  >> [2, 2, 2, 1]

  dataset_0 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_0 = dataset_0.shard(2, index=0)
  dataset_0 = dataset_0.batch(7)
  dataset_0 = dataset_0.apply(tf.data.experimental.rebatch(batch_sizes_0))
  for elem in dataset_0:
    print(elem)
  >> [[A0, A1], [A2, A3], [A4, A5], [A6]]

  # WORKER 1
  batch_sizes_1 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=1)
  print(batch_sizes_1)
  >> [2, 1, 2, 2]

  dataset_1 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_1 = dataset_1.shard(2, index=1)
  dataset_1 = dataset_1.batch(7)
  dataset_1 = dataset_1.apply(tf.data.experimental.rebatch(batch_sizes_1))
  for elem in dataset_1:
    print(elem)
  >> [[B0, B1], [B2], [B3, B4], [B5, B6]]
  ```

  The above example will produce the following elements:

  Step 1:
    Worker 0 Replica 0: [A0, A1]
    Worker 0 Replica 1: [A2, A3]
    Worker 1 Replica 0: [B0, B1]
    Worker 1 Replica 1: [B2]
  Total batch size = 7

  Step 2:
    Worker 0 Replica 0: [A4, A5]
    Worker 0 Replica 1: [A6]
    Worker 1 Replica 0: [B3, B4]
    Worker 1 Replica 1: [B5, B6]
  Total batch size = 7

  Args:
    global_batch_size: A `tf.int64` scalar, representing the global batch size.
    num_workers: An integer representing the number of workers the dataset
      will be distributed across.
    num_replicas_per_worker: An integer representing the number of replicas
      per worker. All workers are assumed to have the same number of replicas.
    worker_index: An integer index of the worker to be rebatched.

  Returns:
    A `tf.int64` vector, representing the batch sizes to rebatch the dataset
    into.
  """
  # Constraint (A): a global batch is split into N = W * R sub-batches.
  total_pieces = num_workers * num_replicas_per_worker
  # Rotating worker 0's layout by R * worker_index satisfies constraint (C).
  shift = worker_index * num_replicas_per_worker

  static_size = tensor_util.constant_value(global_batch_size)
  if static_size is not None:
    # Continue the computation with the statically known value.
    global_batch_size = static_size

  # Constraints (B) and (D) together force every sub-batch to have size
  # floor(G/N) or ceil(G/N); exactly G - N * floor(G/N) of them get the
  # larger size.
  floor = global_batch_size // total_pieces
  num_ceil = global_batch_size - (total_pieces * floor)

  if static_size is not None:
    # Known batch size: emit a constant tensor directly, which gives better
    # downstream shape inference than building it from TF ops.
    layout = [
        floor + 1 if position < num_ceil else floor
        for position in range(total_pieces)
    ]
    rotated = layout[shift:] + layout[:shift]
    return ops.convert_to_tensor(
        rotated, dtype=dtypes.int64, name="batch_sizes")

  # Dynamic batch size: worker 0's layout is floor everywhere, plus one for
  # the first `num_ceil` positions.
  unit = array_ops.ones(total_pieces, dtype=dtypes.int64)
  scaled = floor * unit
  bump = array_ops.concat(
      [
          array_ops.ones(num_ceil, dtype=dtypes.int64),
          array_ops.zeros(total_pieces - num_ceil, dtype=dtypes.int64),
      ],
      axis=0)
  layout = scaled + bump
  return array_ops.concat([layout[shift:], layout[:shift]], axis=0)
def compute_batch_size(dataset):
  """An operation that returns the batch size of the dataset.

  This op tries to infer the batch size statically by walking up the dataset
  tree from the final dataset node and returning the batch size of the first
  batching dataset (such as from .batch() and .padded_batch()) that it
  encounters. This differs from using the `element_spec` of a dataset in that
  it does not account for partial batches.

  This operation may fail if it encounters contradictory batch sizes (for
  example, if the dataset is created by zipping together two datasets with
  different batch sizes), if there are no explicit batching transformations,
  or if there are operations downstream from the batching transformation that
  may modify its batch size. In these cases, it returns a -1.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.int64` Tensor representing the batch size of the dataset sans partial
    batches. If this cannot be inferred statically, the value of this tensor
    will be -1.
  """

  def get_static_batch_dim(type_spec):
    """Returns the static leading dimension of `type_spec`, or None.

    None means "not statically known": the shape is not a `TensorShape`, its
    rank is unknown, or it has no leading dimension at all (rank 0).
    """
    try:
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    except NotImplementedError:
      return None
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None
    # A scalar component has no batch dimension; indexing `dims[0]` on it
    # would raise IndexError instead of reporting "unknown".
    if output_shape.rank is None or output_shape.rank == 0:
      return None
    return output_shape.dims[0].value

  batch_dims = [
      get_static_batch_dim(type_spec)
      for type_spec in nest.flatten(dataset_ops.get_structure(dataset))
  ]

  # Answer statically only when there is at least one component and every
  # component's batch dimension is known. (An empty structure would satisfy
  # `all(...)` vacuously and then fail on `batch_dims[0]`.)
  if batch_dims and all(d is not None for d in batch_dims):
    if all(d == batch_dims[0] for d in batch_dims):
      # If all batch dimensions are known and equal, return that directly.
      batch_dim = batch_dims[0]
    else:
      # If all batch dimensions are known but not all equal, return -1.
      batch_dim = -1

    return constant_op.constant(
        batch_dim, dtype=dtypes.int64, name="static_batch_size")

  # If any batch dimensions are unknown, use compute_batch_size op.
  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access
# The TF1 factory function reuses the V2 class's docstring as its own.
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__