Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/data/experimental/ops/shuffle_ops.py

273 lines
11 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental shuffle ops."""
import functools
import numpy as np
from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import random_seed
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that performs `shuffle` and `repeat` as one fused op."""

  def __init__(self, input_dataset, buffer_size, count=None, seed=None):
    """Creates the fused shuffle-and-repeat dataset.

    Args:
      input_dataset: The dataset whose elements are shuffled and repeated.
      buffer_size: Value convertible to a `tf.int64` tensor; forwarded to the
        op as `buffer_size`.
      count: (Optional.) Value convertible to a `tf.int64` tensor; forwarded
        to the op as `count`. `None` is encoded as the constant `-1`.
      seed: (Optional.) Seed passed to `random_seed.get_seed`, whose two
        results are forwarded to the op as `seed` and `seed2`.
    """
    # NOTE(review): `self._flat_structure` is read below, before the base
    # initializer runs; it presumably relies on `_input_dataset` already being
    # set, so this assignment has to stay first.
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")

    if count is not None:
      self._count = ops.convert_to_tensor(
          count, dtype=dtypes.int64, name="count")
    else:
      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")

    self._seed, self._seed2 = random_seed.get_seed(seed)

    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    variant_tensor = gen_dataset_ops.shuffle_and_repeat_dataset(
        input_variant,
        buffer_size=self._buffer_size,
        count=self._count,
        seed=self._seed,
        seed2=self._seed2,
        **self._flat_structure)
    super().__init__(input_dataset, variant_tensor)
@deprecation.deprecated(
    None, "Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by "
    "`tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take "
    "care of using the fused implementation.")
@tf_export("data.experimental.shuffle_and_repeat")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
  """Shuffles and repeats a Dataset, reshuffling with each repetition.

  >>> d = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> d = d.apply(tf.data.experimental.shuffle_and_repeat(2, count=2))
  >>> [elem.numpy() for elem in d] # doctest: +SKIP
  [2, 3, 1, 1, 3, 2]

  Applying the transformation

  ```python
  dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size, count, seed))
  ```

  produces the same output as

  ```python
  dataset.shuffle(
    buffer_size, seed=seed, reshuffle_each_iteration=True).repeat(count)
  ```

  In each repetition, the dataset first fills a buffer with `buffer_size`
  elements and then samples elements from that buffer at random, replacing
  every selected element with a new one. For perfect shuffling, set the
  buffer size equal to the full size of the dataset.

  For instance, if your dataset contains 10,000 elements but `buffer_size` is
  set to 1,000, then `shuffle` will initially select a random element from
  only the first 1,000 elements in the buffer. Once an element is selected,
  its space in the buffer is replaced by the next (i.e. 1,001-st) element,
  maintaining the 1,000 element buffer.

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
      number of elements that will be buffered when prefetching.
    count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      number of times the dataset should be repeated. The default behavior
      (if `count` is `None` or `-1`) is for the dataset to be repeated
      indefinitely.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to create the distribution. See
      `tf.random.set_seed` for behavior.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _shuffle_and_repeat_fn(dataset):
    """Wraps `dataset` in the fused shuffle-and-repeat dataset."""
    return _ShuffleAndRepeatDataset(
        dataset, buffer_size=buffer_size, count=count, seed=seed)

  return _shuffle_and_repeat_fn
def _process_file_infos(file_infos):
"""Computes aggregate information about files to read.
The method collects information about the files to read, the total number of
elements, and arrays that can be used to account for elements to be skipped,
which can be specified via the "skip" and "take" keys.
To account for elements to skip, the range of each file can be divided into
three regions:
- S (elements to skip)
- T (elements to read)
- R (remainder of elements that will also be skipped)
The `thresholds` and `offsets` arrays are initialized as follows:
`thresholds = [0, T_1, T_1 + T_2, ...]` and
`offsets = [S_1, S_1 + R_1 + S_2, S_1 + R_1 + S_2 + R_2 + S_3, ...]`
This makes it possible to map an index from a contiguous range
`(0...num_elements_to_read)` to an index in the range of all elements,
skipping over elements as per the "skip" and "take" keys values. In
particular, for a given input index `X`, we find the greatest `thresholds`
value that is smaller or equal to `X`. Let `t(X)` denotes such index in the
`thresholds` array. The output index is computed as `X + offsets[t(X)]`.
Args:
file_infos: See `file_infos` argument of `index_shuffle` for details.
Returns:
A dictionary containing the following keys:
- `files`, the vector of pathnames of files to read
- `num_elements`, an integer identifying the total number of elements
- `offsets`, the vector of offsets to use for index adjustment (in case
any elements should be skipped)
- `thresholds`, the vector of thresholds to use for index adjustment (in
case any elements should be skipped)
"""
files = []
num_elements = 0
offsets = np.int64([])
offset_sum = 0
thresholds = np.int64([])
threshold_sum = 0
adjustment_needed = False
for file_info in file_infos:
files.append(file_info["path"])
skip = 0
if "skip" in file_info:
if file_info["skip"] < -1:
raise ValueError("`skip` should be greater than `-1` but got {}".format(
file_info["skip"]))
if file_info["skip"] == -1:
skip = file_info["num_elements"]
else:
skip = min(file_info["skip"], file_info["num_elements"])
take = file_info["num_elements"] - skip
if "take" in file_info:
if file_info["take"] < -1:
raise ValueError("`take` should be greater than `-1` but got {}".format(
file_info["take"]))
# `file_info["take"] == -1` is a no-op
if file_info["take"] != -1:
take = min(file_info["take"], take)
remainder = file_info["num_elements"] - skip - take
if take != file_info["num_elements"]:
adjustment_needed = True
num_elements += take
offsets = np.append(offsets, offset_sum + skip)
offset_sum += skip + remainder
thresholds = np.append(thresholds, threshold_sum)
threshold_sum += take
result = {"files": files, "num_elements": num_elements}
if adjustment_needed:
result["offsets"] = offsets
result["thresholds"] = thresholds
return result
def _adjust_index(index, thresholds, offsets):
  """Shifts `index` by the offset associated with the threshold it reached.

  Args:
    index: The index to adjust.
    thresholds: The thresholds vector, as produced by `_process_file_infos`.
    offsets: The offsets vector, as produced by `_process_file_infos`.

  Returns:
    `index` plus the entry of `offsets` selected by the thresholds that do
    not exceed `index`.
  """
  # Count the thresholds that are <= `index`. Since `_process_file_infos`
  # emits thresholds as running sums (non-decreasing), that count minus one is
  # the position of the greatest threshold not exceeding `index`.
  reached = math_ops.less_equal(thresholds, index)
  reached_thresholds = array_ops.boolean_mask(thresholds, reached)
  t_index = array_ops.shape(reached_thresholds)[0] - 1
  selected_offset = array_ops.gather(offsets, t_index)
  return index + selected_offset
# TODO(jsimsa): Expose this method in the public API. When we do, consider
# defining `FileInfo` as a public API to encapsulate the information provided
# through the `file_infos` argument.
def index_shuffle(file_infos,
                  reader_factory,
                  seed=None,
                  reshuffle_each_iteration=False,
                  num_parallel_calls=dataset_ops.AUTOTUNE):
  """Creates a (globally) shuffled dataset from the given set of files.

  `tf.data.Dataset.shuffle()` shuffles the elements of its input in a
  streaming fashion using an in-memory buffer. In contrast,
  `tf.data.experimental.index_shuffle()` performs a global shuffle of element
  indices and then reads the data in the shuffled order.

  The advantage of `index_shuffle()` is that it can perform a global shuffle
  of datasets that do not fit into memory (as long as the array of their
  indices does) and that the shuffling logic it provides is compatible with
  symbolic checkpointing. The disadvantage is that reading data in a shuffled
  random order will in general not be as efficient as reading data
  sequentially.

  Args:
    file_infos: A list of dictionaries that describe each file of the input
      dataset. Each dictionary is expected to contain the "path" key, which
      identifies the path of the file, and the "num_elements" key, which
      identifies the number of elements in the file. In addition, the "skip"
      and "take" keys can be used to identify the number of elements to skip
      and take respectively. By default, no elements are skipped and all
      elements are taken.
    reader_factory: A function that maps a sequence of filenames to an instance
      of `tf.data.Dataset` that reads data from the files.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to shuffle the order of elements. Defaults to a
      non-deterministic seed.
    reshuffle_each_iteration: (Optional.) A `tf.bool` scalar `tf.Tensor`, that
      determines whether to change the shuffle order each iteration. Defaults
      to `False`.
    num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, that
      determines the maximum number of random access operations to perform
      in parallel. By default, the tf.data runtime uses autotuning to determine
      the value dynamically.

  Returns:
    A `tf.data.Dataset` object, representing a globally shuffled dataset of
    the input data.
  """
  file_summary = _process_file_infos(file_infos)
  total_elements = file_summary["num_elements"]
  # The adjustment arrays are present only when some elements are skipped.
  needs_adjustment = "thresholds" in file_summary and "offsets" in file_summary

  def sequential_index_shuffle(seeds):
    """Builds the dataset that reads every element once, in shuffled order."""
    index_dataset = dataset_ops.Dataset.range(total_elements)

    def read_element(dataset, index):
      """Reads from `dataset` the element that the shuffle maps `index` to."""
      # Shuffle the index within `[0, total_elements - 1]`.
      target_index = stateless_random_ops.index_shuffle(
          index, seeds, total_elements - 1)
      # Translate the index into the non-contiguous range when elements are
      # skipped.
      if needs_adjustment:
        target_index = _adjust_index(
            target_index, file_summary["thresholds"], file_summary["offsets"])
      # Perform the random-access read.
      return random_access.at(dataset, target_index)

    # `reader_factory()` is evaluated here, once, and bound as the first
    # argument so that the reader dataset is not created on every lookup.
    reader = reader_factory(file_summary["files"])
    return index_dataset.map(
        functools.partial(read_element, reader),
        num_parallel_calls=num_parallel_calls)

  # Draw two random values and batch them into a single two-element vector,
  # which is handed to `sequential_index_shuffle` as `seeds`.
  seed_dataset = dataset_ops.Dataset.random(
      seed=seed,
      rerandomize_each_iteration=reshuffle_each_iteration,
  ).take(2).batch(2, drop_remainder=True)
  return seed_dataset.flat_map(sequential_index_shuffle)