# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""The implementation of `tf.data.Dataset.sample_from_datasets`."""
|
|
|
|
from tensorflow.python.data.ops import dataset_ops
|
|
from tensorflow.python.data.ops import directed_interleave_op
|
|
from tensorflow.python.data.ops import map_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import gen_stateless_random_ops
|
|
from tensorflow.python.ops import math_ops
|
|
|
|
|
|
def _sample_from_datasets(datasets,  # pylint: disable=unused-private-name
                          weights=None,
                          seed=None,
                          stop_on_empty_dataset=False,
                          rerandomize_each_iteration=None):
  """See `Dataset.sample_from_datasets()` for details.

  Builds a "selector" dataset of input indices and passes it, together with
  the inputs, to `directed_interleave_op._directed_interleave`.

  Args:
    datasets: A non-empty sequence of input datasets.
    weights: `None` for uniform selection; a `tf.Tensor` whose shape is
      compatible with `[len(datasets)]`; a Python sequence with
      `len(datasets)` entries; or a `Dataset` whose elements each supply the
      weights for one draw. For a Python sequence, every dataset whose weight
      is not strictly positive is dropped up front.
    seed: Forwarded to `Dataset.random`, whose output is batched into pairs
      that seed each stateless draw.
    stop_on_empty_dataset: Forwarded to `_directed_interleave`.
    rerandomize_each_iteration: Forwarded to `Dataset.random`.

  Returns:
    The dataset built by `_directed_interleave`. For non-`Dataset` weights,
    when exactly one input remains, that input is returned directly (an empty
    `datasets[0].take(0)` if every Python-side weight was non-positive).

  Raises:
    ValueError: If `datasets` is empty, or if non-`Dataset` `weights` do not
      match the number of datasets.
    TypeError: If non-`Dataset` `weights` are not `tf.float32`/`tf.float64`.
  """

  def _skip_datasets_with_zero_weight(candidates, raw_weights):
    # Keeps only the (dataset, weight) pairs with a strictly positive weight.
    # If nothing survives, one empty dataset with weight 1.0 stands in.
    survivors = [
        (candidate, w)
        for candidate, w in zip(candidates, raw_weights)
        if w > 0
    ]
    if not survivors:
      return [candidates[0].take(0)], [1.]
    return zip(*survivors)

  def _seed_pairs():
    # Random integer stream grouped in twos; each pair seeds one draw.
    return dataset_ops.Dataset.random(
        seed=seed,
        rerandomize_each_iteration=rerandomize_each_iteration).batch(2)

  def _draw_index(row_logits, seed_pair):
    # Draws a single categorical sample. `num_samples=1` plus a batch of one
    # yields a [1, 1] result, which is squeezed to a scalar input index.
    return array_ops.squeeze(
        gen_stateless_random_ops.stateless_multinomial(
            row_logits, 1, seed=seed_pair),
        axis=[0, 1])

  def _to_logits(*components):
    # The `stateless_multinomial()` op expects log-probabilities, as opposed
    # to weights. NOTE(review): packing the element's components into a tuple
    # presumably supplies the leading batch dimension — confirm.
    return math_ops.log(components, name="logits")

  def _static_logits(inputs, raw_weights):
    # Validates non-`Dataset` weights and returns `(inputs, logits)`, where
    # `inputs` may have been pruned of zero-weight datasets.
    if raw_weights is None:
      # Uniform selection: every input gets the same logit.
      return inputs, [[1.0] * len(inputs)]

    if isinstance(raw_weights, ops.Tensor):
      # Tensor values are not inspected here; only the static shape is.
      if not raw_weights.shape.is_compatible_with([len(inputs)]):
        raise ValueError(f"Invalid `weights`. The shape of `weights` "
                         f"should be compatible with `[len(datasets)]` "
                         f"but is {raw_weights.shape}.")
    else:
      if len(inputs) != len(raw_weights):
        raise ValueError(f"Invalid `weights`. `weights` should have the "
                         f"same length as `datasets` but got "
                         f"`len(weights)={len(raw_weights)}` vs. "
                         f"`len(datasets)={len(inputs)}`.")
      # Python-side weights are visible now, so never-selected inputs can
      # be removed before building the selector.
      inputs, raw_weights = _skip_datasets_with_zero_weight(
          inputs, raw_weights)

    weight_tensor = ops.convert_to_tensor(raw_weights, name="weights")
    if weight_tensor.dtype not in (dtypes.float32, dtypes.float64):
      raise TypeError(f"Invalid `weights`. `weights` type must be either "
                      f"`tf.float32` or `tf.float64` but is "
                      f"{weight_tensor.dtype}.")
    # Log-weights with a leading batch axis of size 1 (assumes 1-D weights).
    return inputs, array_ops.expand_dims(
        math_ops.log(weight_tensor, name="logits"), 0)

  def _interleave(selector_input, inputs):
    return directed_interleave_op._directed_interleave(  # pylint: disable=protected-access
        selector_input, inputs, stop_on_empty_dataset
    )

  if not datasets:
    raise ValueError("Invalid `datasets`. `datasets` should not be empty.")

  if isinstance(weights, dataset_ops.DatasetV2):
    # Each `weights` element is zipped with one seed pair, so it drives
    # exactly one selection.
    logits_and_seeds = dataset_ops.Dataset.zip(
        (weights.map(_to_logits), _seed_pairs()))
    varying_selector = map_op._MapDataset(  # pylint: disable=protected-access
        logits_and_seeds,
        _draw_index,
        use_inter_op_parallelism=False)
    return _interleave(varying_selector, datasets)

  datasets, logits = _static_logits(datasets, weights)

  # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When
  # it is a `Dataset`, it is possible that evaluating it has a side effect
  # the user depends on.
  if len(datasets) == 1:
    return datasets[0]

  def _draw_with_constant_logits(seed_pair):
    return _draw_index(logits, seed_pair)

  constant_selector = map_op._MapDataset(  # pylint: disable=protected-access
      _seed_pairs(),
      _draw_with_constant_logits,
      use_inter_op_parallelism=False)
  return _interleave(constant_selector, datasets)
|