Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/data/ops/choose_from_datasets_op.py
2023-06-19 00:49:18 +02:00

54 lines
2.2 KiB
Python

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The implementation of `tf.data.Dataset.choose_from_datasets`."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import directed_interleave_op
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
def _choose_from_datasets(  # pylint: disable=unused-private-name
    datasets, choice_dataset, stop_on_empty_dataset=True
):
  """Builds the dataset behind `tf.data.Dataset.choose_from_datasets`.

  See `Dataset.choose_from_datasets()` for the full public contract.

  Args:
    datasets: The candidate datasets handed to the directed interleave. Must be
      non-empty (truthy).
    choice_dataset: A `tf.data.Dataset` whose element spec must be compatible
      with a scalar `tf.int64` tensor spec. Its values are consumed by
      `_directed_interleave` (per the public `choose_from_datasets` docs, as
      indices into `datasets`).
    stop_on_empty_dataset: Forwarded unchanged to `_directed_interleave`.

  Returns:
    The dataset produced by `directed_interleave_op._directed_interleave`,
    driven by a `replicate_on_split`-rewritten copy of `choice_dataset`.

  Raises:
    ValueError: If `datasets` is empty.
    TypeError: If `choice_dataset` is not a `tf.data.Dataset`, or if its
      element spec is not compatible with a scalar `tf.int64` spec.
  """
  # Guard clauses: reject bad inputs before building any graph state.
  if not datasets:
    raise ValueError("Invalid `datasets`. `datasets` should not be empty.")

  if not isinstance(choice_dataset, dataset_ops.DatasetV2):
    raise TypeError(
        "Invalid `choice_dataset`. `choice_dataset` should be a "
        f"`tf.data.Dataset` but is {type(choice_dataset)}."
    )

  observed_spec = choice_dataset.element_spec
  scalar_int64_spec = tensor_spec.TensorSpec([], dtypes.int64)
  if not structure.are_compatible(observed_spec, scalar_int64_spec):
    raise TypeError(
        "Invalid `choice_dataset`. Elements of `choice_dataset` "
        "must be scalar `tf.int64` tensors but are "
        f"{observed_spec}."
    )

  # Each split gets its own replica of the choice stream, so splits make their
  # choices independently and no costly cross-split coordination is needed.
  replicated_choices = dataset_ops._apply_rewrite(  # pylint: disable=protected-access
      choice_dataset, "replicate_on_split"
  )
  return directed_interleave_op._directed_interleave(  # pylint: disable=protected-access
      replicated_choices, datasets, stop_on_empty_dataset
  )