# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The implementation of `tf.data.Dataset.choose_from_datasets`."""

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import directed_interleave_op
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec


def _choose_from_datasets(  # pylint: disable=unused-private-name
    datasets, choice_dataset, stop_on_empty_dataset=True
):
  """See `Dataset.choose_from_datasets()` for details.

  Validates the arguments, applies the "replicate_on_split" rewrite to
  `choice_dataset`, and delegates to
  `directed_interleave_op._directed_interleave`.

  Args:
    datasets: The datasets to choose from. Must be non-empty (truthy).
    choice_dataset: A `dataset_ops.DatasetV2` whose `element_spec` is
      compatible with a scalar `tf.int64` `TensorSpec`.
    stop_on_empty_dataset: Forwarded unchanged to `_directed_interleave`.
      Defaults to `True`.

  Returns:
    Whatever `directed_interleave_op._directed_interleave` returns for the
    rewritten `choice_dataset`, `datasets` and `stop_on_empty_dataset`.

  Raises:
    ValueError: If `datasets` is empty.
    TypeError: If `choice_dataset` is not a `dataset_ops.DatasetV2`, or if its
      `element_spec` is not compatible with a scalar `tf.int64` tensor.
  """

  # Cheapest check first: nothing below is meaningful without any datasets.
  if not datasets:
    raise ValueError("Invalid `datasets`. `datasets` should not be empty.")
  # The type check must precede the `element_spec` access below.
  if not isinstance(choice_dataset, dataset_ops.DatasetV2):
    raise TypeError(
        "Invalid `choice_dataset`. `choice_dataset` should be a "
        f"`tf.data.Dataset` but is {type(choice_dataset)}."
    )
  if not structure.are_compatible(
      choice_dataset.element_spec, tensor_spec.TensorSpec([], dtypes.int64)
  ):
    raise TypeError(
        "Invalid `choice_dataset`. Elements of `choice_dataset` "
        "must be scalar `tf.int64` tensors but are "
        f"{choice_dataset.element_spec}."
    )
  # Replicates the `choice_dataset` component so that each split makes choices
  # independently. This avoids the need for prohibitively expensive
  # cross-split coordination.
  # pylint: disable=protected-access
  choice_dataset = dataset_ops._apply_rewrite(
      choice_dataset, "replicate_on_split"
  )
  return directed_interleave_op._directed_interleave(  # pylint: disable=protected-access
      choice_dataset, datasets, stop_on_empty_dataset
  )