# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for testing of tf.data."""
from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops


def assert_next(transformations):
  """A transformation that asserts which transformations happen next.

  Transformations should be referred to by their base name, not including
  version suffix. For example, use "Batch" instead of "BatchV2". "Batch" will
  match any of "Batch", "BatchV1", "BatchV2", etc.

  Args:
    transformations: A `tf.string` vector `tf.Tensor` identifying the
      transformations that are expected to happen next.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return _AssertNextDataset(dataset, transformations)

  return _apply_fn
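

# Illustrative usage sketch (an editorial addition, not part of the original
# module). `assert_next` is applied at the point in the pipeline whose
# downstream ops should be checked; the assertion is validated when the
# dataset is iterated:
#
#   dataset = dataset_ops.Dataset.range(10)
#   dataset = dataset.apply(assert_next(["Map"]))
#   dataset = dataset.map(lambda x: x + 1)  # "Map" matches MapDataset, etc.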


def assert_prev(transformations):
  r"""Asserts which transformations, with which attributes, happened previously.

  Each transformation is represented as a tuple in the input.

  The first element is the base op name of the transformation, not including
  version suffix. For example, use "BatchDataset" instead of
  "BatchDatasetV2". "BatchDataset" will match any of "BatchDataset",
  "BatchDatasetV1", "BatchDatasetV2", etc.

  The second element is a dict of attribute name-value pairs. Attribute
  values must be of type bool, int, or string.

  Example usage:

  >>> dataset_ops.Dataset.from_tensors(0) \
  ...     .map(lambda x: x) \
  ...     .batch(1, deterministic=True, num_parallel_calls=8) \
  ...     .apply(assert_prev([("ParallelBatchDataset", {"deterministic": True}),
  ...                         ("MapDataset", {})]))

  Args:
    transformations: A list of tuples identifying the (required) transformation
      name, with (optional) attribute name-value pairs, that are expected to
      have happened previously.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return _AssertPrevDataset(dataset, transformations)

  return _apply_fn
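

# Editorial note, inferred from the docstring example above: transformations
# are listed nearest-first, i.e. starting from the op immediately upstream of
# the assertion point, which is why the batch op appears before the map op.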


def non_serializable():
  """A non-serializable identity transformation.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return _NonSerializableDataset(dataset)

  return _apply_fn
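

# Illustrative usage sketch (an editorial addition, not part of the original
# module). The op passes elements through unchanged but cannot be serialized,
# e.g. to exercise failure paths in serialization-related tests:
#
#   dataset = dataset_ops.Dataset.range(10)
#   dataset = dataset.apply(non_serializable())
#   dataset = dataset.map(lambda x: x + 1)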


def sleep(sleep_microseconds):
  """Sleeps for `sleep_microseconds` before producing each input element.

  Args:
    sleep_microseconds: The number of microseconds to sleep before producing an
      input element.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _SleepDataset(dataset, sleep_microseconds)

  return _apply_fn
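

# Illustrative usage sketch (an editorial addition, not part of the original
# module). `sleep` simulates a slow producer, e.g. when testing prefetching:
#
#   dataset = dataset_ops.Dataset.range(10)
#   dataset = dataset.apply(sleep(1000))  # sleep 1000 microseconds per element
#   dataset = dataset.prefetch(1)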


class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that asserts which transformations happen next."""

  def __init__(self, input_dataset, transformations):
    """See `assert_next()` for details."""
    self._input_dataset = input_dataset
    if transformations is None:
      raise ValueError(
          "Invalid `transformations`. `transformations` should not be None.")
    self._transformations = ops.convert_to_tensor(
        transformations, dtype=dtypes.string, name="transformations")
    variant_tensor = (
        gen_experimental_dataset_ops.experimental_assert_next_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._transformations,
            **self._flat_structure))
    super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)


class _AssertPrevDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that asserts which transformations happened previously."""

  def __init__(self, input_dataset, transformations):
    """See `assert_prev()` for details."""
    self._input_dataset = input_dataset
    if transformations is None:
      raise ValueError("`transformations` cannot be None")

    def serialize_transformation(op_name, attributes):
      proto = attr_value_pb2.NameAttrList(name=op_name)
      if attributes is None or isinstance(attributes, set):
        attributes = dict()
      for (name, value) in attributes.items():
        if isinstance(value, bool):
          proto.attr[name].b = value
        elif isinstance(value, int):
          proto.attr[name].i = value
        elif isinstance(value, str):
          proto.attr[name].s = value.encode()
        else:
          raise ValueError(
              f"attribute value type ({type(value)}) must be bool, int, or str")
      return text_format.MessageToString(proto)
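
    # Editorial note (assumed illustration, not in the original source): a
    # tuple such as ("ParallelBatchDataset", {"deterministic": True})
    # serializes to a `NameAttrList` text proto roughly of the form:
    #
    #   name: "ParallelBatchDataset"
    #   attr {
    #     key: "deterministic"
    #     value {
    #       b: true
    #     }
    #   }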
    self._transformations = ops.convert_to_tensor(
        [serialize_transformation(*x) for x in transformations],
        dtype=dtypes.string,
        name="transformations")
    variant_tensor = (
        gen_experimental_dataset_ops.assert_prev_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._transformations,
            **self._flat_structure))
    super(_AssertPrevDataset, self).__init__(input_dataset, variant_tensor)


class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that performs a non-serializable identity transformation."""

  def __init__(self, input_dataset):
    """See `non_serializable()` for details."""
    self._input_dataset = input_dataset
    variant_tensor = (
        gen_experimental_dataset_ops.experimental_non_serializable_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            **self._flat_structure))
    super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)


class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that sleeps before producing each upstream element."""

  def __init__(self, input_dataset, sleep_microseconds):
    self._input_dataset = input_dataset
    self._sleep_microseconds = sleep_microseconds
    variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._sleep_microseconds,
        **self._flat_structure)
    super(_SleepDataset, self).__init__(input_dataset, variant_tensor)