# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities for tf.data functionality."""

import os
import random
import re

from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


def default_test_combinations():
  """Returns the default test combinations for tf.data tests."""
  return combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"])


def eager_only_combinations():
  """Returns the default test combinations for eager-mode-only tf.data tests."""
  return combinations.combine(tf_api_version=[1, 2], mode="eager")


def graph_only_combinations():
  """Returns the default test combinations for graph-mode-only tf.data tests."""
  return combinations.combine(tf_api_version=[1, 2], mode="graph")


def v1_only_combinations():
  """Returns the default test combinations for v1-only tf.data tests."""
  return combinations.combine(tf_api_version=1, mode=["eager", "graph"])


def v2_only_combinations():
  """Returns the default test combinations for v2-only tf.data tests."""
  return combinations.combine(tf_api_version=2, mode=["eager", "graph"])


def v2_eager_only_combinations():
  """Returns the default test combinations for v2 eager-only tf.data tests."""
  return combinations.combine(tf_api_version=2, mode="eager")
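

# A typical tf.data test applies these combinations with
# `@combinations.generate` (a sketch; `MyDatasetTest` is hypothetical and
# also extends absl's `parameterized.TestCase`):
#
#   class MyDatasetTest(DatasetTestBase, parameterized.TestCase):
#
#     @combinations.generate(default_test_combinations())
#     def testBasic(self):
#       dataset = dataset_ops.Dataset.range(10)
#       self.assertDatasetProduces(dataset, expected_output=list(range(10)))

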
class DatasetTestBase(test.TestCase):
"""Base class for dataset tests."""

  def assert_op_cancelled(self, op):
with self.assertRaises(errors.CancelledError):
self.evaluate(op)

  def assertValuesEqual(self, expected, actual):
"""Asserts that two values are equal."""
if isinstance(expected, dict):
self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
for k in expected.keys():
self.assertValuesEqual(expected[k], actual[k])
elif sparse_tensor.is_sparse(expected):
self.assertAllEqual(expected.indices, actual.indices)
self.assertAllEqual(expected.values, actual.values)
self.assertAllEqual(expected.dense_shape, actual.dense_shape)
else:
self.assertAllEqual(expected, actual)

  def getNext(self, dataset, requires_initialization=False, shared_name=None):
    """Returns a callable that returns the next element of the dataset.

    Example use:
    ```python
    # In both graph and eager modes
    dataset = ...
    get_next = self.getNext(dataset)
    result = self.evaluate(get_next())
    ```

    Args:
      dataset: A dataset whose elements will be returned.
      requires_initialization: Indicates that when the test is executed in graph
        mode, it should use an initializable iterator to iterate through the
        dataset (e.g. when it contains stateful nodes). Defaults to False.
      shared_name: (Optional.) If non-empty, the returned iterator will be
        shared under the given name across multiple sessions that share the same
        devices (e.g. when using a remote server).

    Returns:
      A callable that returns the next element of `dataset`. Any `TensorArray`
      objects `dataset` outputs are stacked.
    """
def ta_wrapper(gn):
def _wrapper():
r = gn()
if isinstance(r, tensor_array_ops.TensorArray):
return r.stack()
else:
return r
return _wrapper
    # Create an anonymous iterator if we are executing eagerly or building a
    # graph inside a tf.function.
if context.executing_eagerly() or ops.inside_function():
iterator = iter(dataset)
return ta_wrapper(iterator._next_internal) # pylint: disable=protected-access
else:
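      # In graph mode, create either an initializable iterator (required when
      # the dataset contains stateful ops) or a one-shot iterator.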
if requires_initialization:
iterator = dataset_ops.make_initializable_iterator(dataset, shared_name)
self.evaluate(iterator.initializer)
else:
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
return ta_wrapper(lambda: get_next)

  def _compareOutputToExpected(self, result_values, expected_values,
                               assert_items_equal):
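    """Compares produced values against expected values."""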
if assert_items_equal:
# TODO(shivaniagrawal): add support for nested elements containing sparse
# tensors when needed.
self.assertItemsEqual(result_values, expected_values)
return
for i in range(len(result_values)):
nest.assert_same_structure(result_values[i], expected_values[i])
for result_value, expected_value in zip(
nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
self.assertValuesEqual(expected_value, result_value)

  def getDatasetOutput(self, dataset, requires_initialization=False):
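    """Returns the full list of elements produced by `dataset`."""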
get_next = self.getNext(
dataset, requires_initialization=requires_initialization)
return self.getIteratorOutput(get_next)

  def getIteratorOutput(self, get_next):
"""Evaluates `get_next` until end of input, returning the results."""
results = []
while True:
try:
results.append(self.evaluate(get_next()))
except errors.OutOfRangeError:
break
return results

  def assertDatasetProduces(self,
                            dataset,
                            expected_output=None,
                            expected_shapes=None,
                            expected_error=None,
                            requires_initialization=False,
                            num_test_iterations=1,
                            assert_items_equal=False,
                            expected_error_iter=1):
    """Asserts that a dataset produces the expected output / error.

    Args:
dataset: A dataset to check for the expected output / error.
expected_output: A list of elements that the dataset is expected to
produce.
expected_shapes: A list of TensorShapes which is expected to match
output_shapes of dataset.
expected_error: A tuple `(type, predicate)` identifying the expected error
`dataset` should raise. The `type` should match the expected exception
type, while `predicate` should either be 1) a unary function that inputs
the raised exception and returns a boolean indicator of success or 2) a
regular expression that is expected to match the error message
partially.
requires_initialization: Indicates that when the test is executed in graph
mode, it should use an initializable iterator to iterate through the
dataset (e.g. when it contains stateful nodes). Defaults to False.
num_test_iterations: Number of times `dataset` will be iterated. Defaults
to 1.
assert_items_equal: Tests expected_output has (only) the same elements
regardless of order.
expected_error_iter: How many times to iterate before expecting an error,
if an error is expected.
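
    Example use (a minimal sketch):

    ```python
    self.assertDatasetProduces(
        dataset_ops.Dataset.range(3), expected_output=[0, 1, 2])
    ```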
"""
    self.assertTrue(
        expected_error is not None or expected_output is not None,
        "Exactly one of expected_output or expected_error should be provided.")
    if expected_error:
      self.assertTrue(
          expected_output is None,
          "Exactly one of expected_output or expected_error should be provided."
      )
with self.assertRaisesWithPredicateMatch(expected_error[0],
expected_error[1]):
get_next = self.getNext(
dataset, requires_initialization=requires_initialization)
for _ in range(expected_error_iter):
self.evaluate(get_next())
return
if expected_shapes:
self.assertEqual(expected_shapes,
dataset_ops.get_legacy_output_shapes(dataset))
self.assertGreater(num_test_iterations, 0)
for _ in range(num_test_iterations):
get_next = self.getNext(
dataset, requires_initialization=requires_initialization)
result = []
for _ in range(len(expected_output)):
try:
result.append(self.evaluate(get_next()))
except errors.OutOfRangeError:
raise AssertionError(
"Dataset ended early, producing %d elements out of %d. "
"Dataset output: %s" %
(len(result), len(expected_output), str(result)))
self._compareOutputToExpected(result, expected_output, assert_items_equal)
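      # After exhaustion, every further call must keep raising
      # OutOfRangeError; check this twice.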
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())

  def assertDatasetsEqual(self, dataset1, dataset2):
"""Checks that datasets are equal. Supports both graph and eager mode."""
self.assertTrue(
structure.are_compatible(
dataset_ops.get_structure(dataset1),
dataset_ops.get_structure(dataset2)))
flattened_types = nest.flatten(
dataset_ops.get_legacy_output_types(dataset1))
next1 = self.getNext(dataset1)
next2 = self.getNext(dataset2)
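    # Draw from both datasets in lockstep; they must reach end-of-input at
    # the same step.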
while True:
try:
op1 = self.evaluate(next1())
except errors.OutOfRangeError:
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next2())
break
op2 = self.evaluate(next2())
op1 = nest.flatten(op1)
op2 = nest.flatten(op2)
assert len(op1) == len(op2)
for i in range(len(op1)):
if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]):
self.assertValuesEqual(op1[i], op2[i])
elif flattened_types[i] == dtypes.string:
self.assertAllEqual(op1[i], op2[i])
else:
self.assertAllClose(op1[i], op2[i])

  def assertDatasetsRaiseSameError(self,
                                   dataset1,
                                   dataset2,
                                   exception_class,
                                   replacements=None):
    """Checks that datasets raise the same error on the first get_next call."""
if replacements is None:
replacements = []
next1 = self.getNext(dataset1)
next2 = self.getNext(dataset2)
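    # `replacements` is a list of (old, new, count) substitutions applied to
    # the first error message before matching it against the second.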
try:
self.evaluate(next1())
raise ValueError(
"Expected dataset to raise an error of type %s, but it did not." %
repr(exception_class))
except exception_class as e:
expected_message = e.message
for old, new, count in replacements:
expected_message = expected_message.replace(old, new, count)
# Check that the first segment of the error messages are the same.
      with self.assertRaisesRegex(exception_class,
                                  re.escape(expected_message)):
self.evaluate(next2())

  def structuredDataset(self, dataset_structure, shape=None,
                        dtype=dtypes.int64):
    """Returns a singleton dataset with the given structure.
if shape is None:
shape = []
if dataset_structure is None:
return dataset_ops.Dataset.from_tensors(
array_ops.zeros(shape, dtype=dtype))
else:
return dataset_ops.Dataset.zip(
tuple([
self.structuredDataset(substructure, shape, dtype)
for substructure in dataset_structure
]))

  def verifyRandomAccess(self, dataset, expected):
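    """Tests random access of every element, then an out-of-range index."""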
self.verifyRandomAccessInfiniteCardinality(dataset, expected)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(random_access.at(dataset, index=len(expected)))

  def verifyRandomAccessInfiniteCardinality(self, dataset, expected):
"""Tests randomly accessing elements of a dataset."""
# Tests accessing the elements in a shuffled order with repeats.
len_expected = len(expected)
indices = list(range(len_expected)) * 2
random.shuffle(indices)
for i in indices:
self.assertAllEqual(expected[i],
self.evaluate(random_access.at(dataset, i)))
# Tests accessing the elements in order.
    indices = sorted(set(indices))
for i in indices:
self.assertAllEqual(expected[i],
self.evaluate(random_access.at(dataset, i)))

  def textFileInitializer(self, vals):
file = os.path.join(self.get_temp_dir(), "text_file_initializer")
with open(file, "w") as f:
f.write("\n".join(str(v) for v in vals) + "\n")
return lookup_ops.TextFileInitializer(file, dtypes.int64,
lookup_ops.TextFileIndex.LINE_NUMBER,
dtypes.int64,
lookup_ops.TextFileIndex.WHOLE_LINE)

  def keyValueTensorInitializer(self, vals):
keys_tensor = constant_op.constant(
list(range(len(vals))), dtype=dtypes.int64)
vals_tensor = constant_op.constant(vals)
return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)

  def datasetInitializer(self, vals):
keys = dataset_ops.Dataset.range(len(vals))
values = dataset_ops.Dataset.from_tensor_slices(vals)
ds = dataset_ops.Dataset.zip((keys, values))
return data_lookup_ops.DatasetInitializer(ds)

  def lookupTableInitializer(self, init_source, vals):
    """Returns a lookup table initializer for the given source and values.

    Args:
      init_source: One of ["textfile", "keyvaluetensor", "dataset"], indicating
        what type of initializer to use.
vals: The initializer values. The keys will be `range(len(vals))`.
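
    Example use (a sketch; the `StaticHashTable` construction is
    illustrative):

    ```python
    initializer = self.lookupTableInitializer("keyvaluetensor", [10, 11, 12])
    table = lookup_ops.StaticHashTable(initializer, default_value=-1)
    ```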
"""
if init_source == "textfile":
return self.textFileInitializer(vals)
elif init_source == "keyvaluetensor":
return self.keyValueTensorInitializer(vals)
elif init_source == "dataset":
return self.datasetInitializer(vals)
else:
raise ValueError("Unrecognized init_source: " + init_source)

  def graphRoundTrip(self, dataset, allow_stateful=False):
"""Converts a dataset to a graph and back."""
graph = gen_dataset_ops.dataset_to_graph(
dataset._variant_tensor, allow_stateful=allow_stateful) # pylint: disable=protected-access
return dataset_ops.from_variant(
gen_experimental_dataset_ops.dataset_from_graph(graph),
dataset.element_spec)

  def structuredElement(self, element_structure, shape=None,
                        dtype=dtypes.int64):
"""Returns an element with the given structure."""
if shape is None:
shape = []
if element_structure is None:
return array_ops.zeros(shape, dtype=dtype)
else:
return tuple([
self.structuredElement(substructure, shape, dtype)
for substructure in element_structure
])

  def checkDeterminism(self, dataset_fn, expect_determinism, expected_elements):
    """Tests whether a dataset produces its elements deterministically.

    `dataset_fn` takes a delay_ms argument, which tells it how long to delay
    production of the first dataset element. This gives us a way to trigger
    out-of-order production of dataset elements.

    Args:
dataset_fn: A function taking a delay_ms argument.
expect_determinism: Whether to expect deterministic ordering.
expected_elements: The elements expected to be produced by the dataset,
assuming the dataset produces elements in deterministic order.
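
    Example use (an illustrative sketch; `make_delayed_dataset` is
    hypothetical):

    ```python
    def dataset_fn(delay_ms):
      # Build a dataset whose first element is delayed by `delay_ms`, e.g.
      # via a parallel, non-deterministic map that sleeps on element 0.
      return make_delayed_dataset(delay_ms)

    self.checkDeterminism(
        dataset_fn, expect_determinism=False,
        expected_elements=list(range(10)))
    ```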
"""
if expect_determinism:
dataset = dataset_fn(100)
actual = self.getDatasetOutput(dataset)
self.assertAllEqual(expected_elements, actual)
return
# We consider the test a success if it succeeds under any delay_ms. The
# delay_ms needed to observe non-deterministic ordering varies across
# test machines. Usually 10 or 100 milliseconds is enough, but on slow
# machines it could take longer.
for delay_ms in [10, 100, 1000, 20000, 100000]:
dataset = dataset_fn(delay_ms)
actual = self.getDatasetOutput(dataset)
self.assertCountEqual(expected_elements, actual)
for i in range(len(actual)):
if actual[i] != expected_elements[i]:
return
self.fail("Failed to observe nondeterministic ordering")

  def configureDevicesForMultiDeviceTest(self, num_devices):
    """Configures the number of logical devices for multi-device tests.

    Returns a list of device names. If invoked in a GPU-enabled runtime, the
    last device name will be for a GPU device; otherwise, all device names
    will be for CPU devices.

    Args:
num_devices: The number of devices to configure.
Returns:
A list of device names to use for a multi-device test.
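
    Example use:

    ```python
    devices = self.configureDevicesForMultiDeviceTest(3)
    # On a GPU-enabled host: ["/device:CPU:0", "/device:CPU:1", "/device:GPU:0"]
    # On a CPU-only host: ["/device:CPU:0", "/device:CPU:1", "/device:CPU:2"]
    ```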
"""
cpus = config.list_physical_devices("CPU")
gpus = config.list_physical_devices("GPU")
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration() for _ in range(num_devices)
])
devices = ["/device:CPU:" + str(i) for i in range(num_devices - 1)]
if gpus:
devices.append("/device:GPU:0")
else:
devices.append("/device:CPU:" + str(num_devices - 1))
return devices