Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/data/ops/scan_op.py
2023-06-19 00:49:18 +02:00

161 lines
6.9 KiB
Python

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The implementation of `tf.data.Dataset.scan`."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.compat import collections_abc
def _scan(input_dataset,
          initial_state,
          scan_func,
          use_default_device=None,
          name=None):
  """Creates a `_ScanDataset` that scans `scan_func` over `input_dataset`.

  Args:
    input_dataset: The dataset whose elements are scanned.
    initial_state: The starting value of the scan state.
    scan_func: A function taking `(state, element)` and returning the pair
      `(new_state, output)`.
    use_default_device: Optional flag forwarded to the underlying dataset op
      only when it is not `None`.
    name: Optional name for the resulting dataset.

  Returns:
    A `_ScanDataset` instance.
  """
  scanned_dataset = _ScanDataset(
      input_dataset,
      initial_state,
      scan_func,
      use_default_device=use_default_device,
      name=name)
  return scanned_dataset
class _ScanDataset(dataset_ops.UnaryDataset):
  """A dataset that scans a function across its input.

  `scan_func` is wrapped with the input structure `(state, element)` and must
  return a pair `(new_state, output)`. The new state is validated against the
  current state structure:

  - each component class must be a subclass of the corresponding state class;
  - each component dtype must equal the corresponding state dtype.

  If combining a state shape with the new state shape (via
  `most_specific_compatible_shape`) changes it, the state structure is rebuilt
  with the combined shapes and `scan_func` is wrapped again. This repeats until
  no state shape changes.
  """

  def __init__(self,
               input_dataset,
               initial_state,
               scan_func,
               use_default_device=None,
               name=None):
    """See `scan()` for details.

    Args:
      input_dataset: The dataset whose elements are scanned.
      initial_state: The initial scan state; normalized with
        `structure.normalize_element` before use.
      scan_func: Function mapping `(state, element)` to `(new_state, output)`.
      use_default_device: Optional value forwarded to `ged_ops.scan_dataset`
        only when it is not `None`.
      name: Optional dataset name, stored as `self._name` before the op is
        built.

    Raises:
      TypeError: If `scan_func` does not return a pair, or if the classes or
        dtypes of the returned state do not match the current state.
    """
    self._input_dataset = input_dataset
    self._initial_state = structure.normalize_element(initial_state)

    # Compute initial values for the state classes, shapes and types based on
    # the initial state. The shapes may be refined by wrapping `scan_func` one
    # or more times below.
    self._state_structure = structure.type_spec_from_value(self._initial_state)

    # Iteratively rewrap the scan function until the state shapes held in
    # `self._state_structure` reach a fixed point.
    need_to_rerun = True
    while need_to_rerun:
      wrapped_func = structured_function.StructuredFunctionWrapper(
          scan_func,
          self._transformation_name(),
          input_structure=(self._state_structure, input_dataset.element_spec),
          add_to_graph=False)
      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
              and len(wrapped_func.output_types) == 2):
        raise TypeError(f"Invalid `scan_func`. `scan_func` should return a "
                        f"pair consisting of new state and the output value "
                        f"but its return type is "
                        f"{wrapped_func.output_structure}.")

      # Extract and validate class information from the returned values.
      new_state_classes, output_classes = wrapped_func.output_classes
      # Nothing in this class reads `_output_classes`. It is still assigned, as
      # before, in case code outside this class depends on it.
      self._output_classes = output_classes
      old_state_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_class, old_state_class in zip(
          nest.flatten(new_state_classes), nest.flatten(old_state_classes)):
        if not issubclass(new_state_class, old_state_class):
          raise TypeError(f"Invalid `scan_func`. The element classes for the "
                          f"new state must match the initial state. Expected "
                          f"{old_state_classes}, got {new_state_classes}.")

      # Extract and validate type information from the returned values.
      new_state_types, output_types = wrapped_func.output_types
      old_state_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_type, old_state_type in zip(
          nest.flatten(new_state_types), nest.flatten(old_state_types)):
        if new_state_type != old_state_type:
          raise TypeError(f"Invalid `scan_func`. The element types for the "
                          f"new state must match the initial state. Expected "
                          f"{old_state_types}, got {new_state_types}.")

      # Extract shape information from the returned values.
      new_state_shapes, output_shapes = wrapped_func.output_shapes
      old_state_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
          self._state_structure)
      # Recomputed on every pass, so the value kept is the one from the final
      # wrapping of `scan_func`.
      self._element_spec = structure.convert_legacy_structure(
          output_types, output_shapes, output_classes)

      flat_state_shapes = nest.flatten(old_state_shapes)
      flat_new_state_shapes = nest.flatten(new_state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Rerun if any state shape with known rank differs from its combination
      # with the new state shape. Shapes of unknown rank never trigger a rerun.
      need_to_rerun = any(
          original_shape.ndims is not None and (
              weakened_shape.ndims is None or
              original_shape.as_list() != weakened_shape.as_list())
          for original_shape, weakened_shape in zip(flat_state_shapes,
                                                    weakened_state_shapes))

      if need_to_rerun:
        # TODO(b/110122868): Support a "most specific compatible structure"
        # method for combining structures, to avoid using legacy structures
        # in this method.
        self._state_structure = structure.convert_legacy_structure(
            old_state_types,
            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
            old_state_classes)

    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())
    # Assigned before `self._common_args` is read below. Presumably the
    # base-class property consumes it there; it is defined outside this file,
    # so verify.
    self._name = name

    # Forward `use_default_device` only when the caller supplied it.
    # Otherwise `ged_ops.scan_dataset` falls back to its own default.
    optional_op_kwargs = {}
    if use_default_device is not None:
      optional_op_kwargs["use_default_device"] = use_default_device

    # pylint: disable=protected-access
    variant_tensor = ged_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **optional_op_kwargs,
        **self._common_args)
    # pylint: enable=protected-access
    super().__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the wrapped functions used by this dataset (only `scan_func`)."""
    return [self._scan_func]

  @property
  def element_spec(self):
    """Spec of the `output` half of `scan_func`'s return value."""
    return self._element_spec

  def _transformation_name(self):
    """Name passed to `StructuredFunctionWrapper` when wrapping `scan_func`."""
    return "Dataset.scan()"