1893 lines
65 KiB
Python
1893 lines
65 KiB
Python
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Adapter module that convert different input data objects into tf.dataset."""
|
||
|
|
||
|
import abc
|
||
|
import contextlib
|
||
|
import functools
|
||
|
import itertools
|
||
|
import math
|
||
|
import random
|
||
|
|
||
|
import numpy as np
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras import backend
|
||
|
from keras.distribute import distributed_training_utils
|
||
|
from keras.engine import training_utils
|
||
|
from keras.utils import data_utils
|
||
|
from keras.utils import dataset_creator
|
||
|
from keras.utils import tf_utils
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.distribute.input_lib import (
|
||
|
DistributedDataset,
|
||
|
)
|
||
|
from tensorflow.python.eager import context
|
||
|
from tensorflow.python.framework import type_spec
|
||
|
from tensorflow.python.platform import tf_logging as logging
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
|
||
|
try:
|
||
|
import pandas as pd
|
||
|
except ImportError:
|
||
|
pd = None
|
||
|
|
||
|
# Monitoring gauge recording which DataAdapter subclass ends up handling the
# user's input, exported under "/tensorflow/api/keras/data_adapters" with a
# "method" label.
keras_data_adapter_gauge = tf.__internal__.monitoring.BoolGauge(
    "/tensorflow/api/keras/data_adapters", "keras data adapter usage", "method"
)
|
||
|
|
||
|
|
||
|
class DataAdapter(object, metaclass=abc.ABCMeta):
    """Abstract interface for converting user input into a `tf.data.Dataset`.

    In TF 2.0, `tf.data` is the preferred way to feed data to a model, so
    every supported input object is converted to a `tf.data.Dataset` when
    possible. Each concrete subclass knows how to recognize one family of
    inputs (`can_handle`) and how to wrap that family into a dataset
    (`get_dataset`).

    Because this class mainly targets TF 2.0, it makes a number of
    assumptions under the hood (eager context by default, distribution
    strategy awareness, etc.), and some legacy features — e.g. v1 dataset
    `Iterator`s — may not be supported.

    Typical usage:

    ```
    x = tf.data.Dataset.range(100)
    adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter]
    applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)]
    if len(applicable_adapters) != 1:
        raise ValueError("Expect only one adapter class to handle the input")

    dataset = applicable_adapters[0](x).get_dataset()
    for data in dataset:
        # training
    ```
    """

    @staticmethod
    def can_handle(x, y=None):
        """Whether this DataAdapter can wrap the given `x` and `y`.

        Structurally, `x` and `y` may each be a single object, a list of
        objects (multiple inputs/outputs), or a dictionary of objects
        (named inputs/outputs).

        Args:
            x: input features.
            y: target labels; may be None in the case of prediction.

        Returns:
            boolean
        """
        raise NotImplementedError

    @abc.abstractmethod
    def __init__(self, x, y=None, **kwargs):
        """Create a DataAdapter for the given data inputs.

        Callers must check `can_handle()` before constructing the adapter;
        providing an unsupported data type results in unexpected behavior.

        Args:
            x: input features.
            y: target labels; may be None in the case of prediction.
            **kwargs: other keyword arguments used while constructing the
                `tf.data.Dataset` — for example `sample_weights` (used to
                weight the loss during training for Numpy data),
                `batch_size` (needed when batching Numpy data), or
                `distribution_strategy` (when the created dataset must
                respect a strategy). Subclasses may ignore arguments they
                do not use, or raise if a required argument is missing.
        """
        if not self.can_handle(x, y):
            raise ValueError(f"{self.__class__} Cannot handle input {x}, {y}")

    @abc.abstractmethod
    def get_dataset(self):
        """Return the `tf.data.Dataset` instance for this adapter.

        The returned dataset does not repeat across epochs, so callers may
        need to create a fresh iterator at the start of each epoch. This
        behavior might change in the future.

        Returns:
            A `tf.data.Dataset`. Usable either eagerly via `iter(dataset)`
            or in graph mode by providing the iterator tensor to the Keras
            model function.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_size(self):
        """Return the number of batches in the created dataset.

        For some inputs this is known up front (e.g. Numpy data:
        number_of_elements / batch_size), while for datasets or python
        generators it is unknown since they may not have an end.

        Returns:
            int number of batches, or None when unknown. Callers can use
            this to drive the training loop, show a progress bar, or handle
            an unexpected StopIteration.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def batch_size(self):
        """Return the batch size of the created dataset.

        Known (and even required) for inputs like numpy arrays; unknown
        for datasets unless a batch is peeked.

        Returns:
            int batch size, or None when unknown.
        """
        raise NotImplementedError

    def representative_batch_size(self):
        """Return an approximate size for batches in the dataset.

        Not guaranteed to hold for every batch — only a rough
        approximation of typical batch sizes.

        Returns:
            int representative batch size, or None when unknown.
        """
        return self.batch_size()

    @abc.abstractmethod
    def has_partial_batch(self):
        """Whether the dataset ends with a partial (smaller) batch."""
        raise NotImplementedError

    @abc.abstractmethod
    def partial_batch_size(self):
        """Size of the final partial batch.

        Returns None when `has_partial_batch` is False or the batch size
        is unknown.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def should_recreate_iterator(self):
        """Returns whether a new iterator should be created every epoch."""
        raise NotImplementedError

    def get_samples(self):
        """Returns number of samples in the data, or `None`."""
        num_batches = self.get_size()
        batch = self.batch_size()
        if not num_batches or not batch:
            return None
        count = num_batches * batch
        # The last batch may be short; subtract the missing tail.
        if self.has_partial_batch():
            count -= batch - self.partial_batch_size()
        return count

    def on_epoch_end(self):
        """A hook called after each epoch."""
        pass
|
||
|
|
||
|
|
||
|
class TensorLikeDataAdapter(DataAdapter):
    """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""

    @staticmethod
    def can_handle(x, y=None):
        # TODO(kaftan): Check performance implications of using a flatten
        # here for other types of inputs.
        flat_inputs = tf.nest.flatten(x)
        if y is not None:
            flat_inputs += tf.nest.flatten(y)

        tensor_types = _get_tensor_types()

        def _is_tensor(v):
            # True only for the tensor/ndarray types reported by
            # `_get_tensor_types()` (defined elsewhere in this module).
            if isinstance(v, tensor_types):
                return True
            return False

        return all(_is_tensor(v) for v in flat_inputs)

    def __init__(
        self,
        x,
        y=None,
        sample_weights=None,
        sample_weight_modes=None,
        batch_size=None,
        epochs=1,
        steps=None,
        shuffle=False,
        **kwargs,
    ):
        """Builds an infinite, index-shuffled dataset over in-memory tensors.

        Args:
            x: input features (tensors/ndarrays, possibly nested).
            y: target labels; may be None for prediction.
            sample_weights: optional per-sample loss weights.
            sample_weight_modes: optional weight modes, broadcast against
                `sample_weights`.
            batch_size: batch size; if falsy it is derived from `steps`
                (ceil(num_samples / steps)) or defaults to 32.
            epochs: number of epochs the index stream should cover.
            steps: optional step count used to derive `batch_size`.
            shuffle: False, True, or the string "batch" (shuffle only
                within/among batches; case-insensitive).
            **kwargs: forwarded to `DataAdapter.__init__`.
        """
        super().__init__(x, y, **kwargs)
        x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
        sample_weight_modes = broadcast_sample_weight_modes(
            sample_weights, sample_weight_modes
        )

        # If sample_weights are not specified for an output use 1.0 as
        # weights.
        (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
            y, sample_weights, sample_weight_modes, check_all_flat=True
        )

        inputs = pack_x_y_sample_weight(x, y, sample_weights)

        # Take any one leading dimension as the sample count; mismatched
        # cardinalities are presumably rejected by `_check_data_cardinality`
        # below (defined elsewhere in this module — TODO confirm).
        num_samples = set(
            int(i.shape[0]) for i in tf.nest.flatten(inputs)
        ).pop()
        _check_data_cardinality(inputs)

        # If batch_size is not passed but steps is, calculate from the input
        # data. Default to 32 for backwards compat.
        if not batch_size:
            batch_size = int(math.ceil(num_samples / steps)) if steps else 32

        self._size = int(math.ceil(num_samples / batch_size))
        self._batch_size = batch_size

        num_full_batches = int(num_samples // batch_size)
        self._partial_batch_size = num_samples % batch_size

        # Normalize string shuffle modes (e.g. "Batch" -> "batch").
        if isinstance(shuffle, str):
            shuffle = shuffle.lower()

        self._shuffle = shuffle
        # Vectorized version of shuffle.
        # This is a performance improvement over using `from_tensor_slices`.
        # The indices of the data are shuffled and batched, and these indices
        # are then zipped with the data and used to extract a batch of the
        # data at each step. The performance improvements here come from:
        # 1. vectorized batch using gather
        # 2. parallelized map
        # 3. pipelined permutation generation
        # 4. optimized permutation batching
        # 5. disabled static optimizations

        indices_dataset = tf.data.Dataset.range(1)
        # For "batch" shuffle the epoch repeat happens later, inside
        # `slice_batch_indices`, after the batch order is shuffled.
        if shuffle != "batch":
            indices_dataset = indices_dataset.repeat(epochs)

        def permutation(_):
            # It turns out to be more performant to make a new set of indices
            # rather than reusing the same range Tensor. (presumably because
            # of buffer forwarding.)
            indices = tf.range(num_samples, dtype=tf.int64)
            if shuffle and shuffle != "batch":
                indices = tf.random.shuffle(indices)
            return indices

        # We prefetch a single element. Computing large permutations can take
        # quite a while so we don't want to wait for prefetching over an
        # epoch boundary to trigger the next permutation. On the other hand,
        # too many simultaneous shuffles can contend on a hardware level and
        # degrade all performance.
        indices_dataset = indices_dataset.map(permutation).prefetch(1)

        def slice_batch_indices(indices):
            """Convert a Tensor of indices into a dataset of batched indices.

            This step can be accomplished in several ways. The most natural
            is to slice the Tensor in a Dataset map. (With a condition on
            the upper index to handle the partial batch.) However it turns
            out that coercing the Tensor into a shape which is divisible by
            the batch size (and handling the last partial batch separately)
            allows for a much more favorable memory access pattern and
            improved performance.

            Args:
                indices: Tensor which determines the data order for an
                    entire epoch.

            Returns:
                A Dataset of batched indices.
            """
            num_in_full_batch = num_full_batches * batch_size
            first_k_indices = tf.slice(indices, [0], [num_in_full_batch])
            first_k_indices = tf.reshape(
                first_k_indices, [num_full_batches, batch_size]
            )

            flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices)
            # Append the trailing partial batch of indices, if any.
            if self._partial_batch_size:
                index_remainder = tf.data.Dataset.from_tensors(
                    tf.slice(
                        indices, [num_in_full_batch], [self._partial_batch_size]
                    )
                )
                flat_dataset = flat_dataset.concatenate(index_remainder)

            if shuffle == "batch":
                # 1024 is a magic constant that has not been properly
                # evaluated
                flat_dataset = flat_dataset.shuffle(1024).repeat(epochs)
            return flat_dataset

        indices_dataset = indices_dataset.flat_map(slice_batch_indices)

        dataset = self.slice_inputs(indices_dataset, inputs)

        if shuffle == "batch":

            def shuffle_batch(*batch):
                # Shuffle samples within each batch (the batch boundaries
                # themselves were shuffled above).
                return tf.nest.map_structure(tf.random.shuffle, batch)

            dataset = dataset.map(shuffle_batch)

        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA
        )
        dataset = dataset.with_options(options)

        self._dataset = dataset

    def slice_inputs(self, indices_dataset, inputs):
        """Slice inputs into a Dataset of batches.

        Given a Dataset of batch indices and the unsliced inputs,
        this step slices the inputs in a parallelized fashion
        and produces a dataset of input batches.

        Args:
            indices_dataset: A Dataset of batched indices
            inputs: A python data structure that contains the inputs,
                targets, and possibly sample weights.

        Returns:
            A Dataset of input batches matching the batch indices.
        """
        # Pair each batch of indices with the (repeated) full input tensors.
        dataset = tf.data.Dataset.zip(
            (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat())
        )

        def grab_batch(i, data):
            # Vectorized batch extraction along the sample axis.
            return tf.nest.map_structure(
                lambda d: tf.gather(d, i, axis=0), data
            )

        dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)

        # Default optimizations are disabled to avoid the overhead of
        # (unnecessary) input pipeline graph serialization and
        # deserialization
        options = tf.data.Options()
        options.experimental_optimization.apply_default_optimizations = False
        if self._shuffle:
            # See b/141490660 for more details.
            options.experimental_external_state_policy = (
                tf.data.experimental.ExternalStatePolicy.IGNORE
            )
        dataset = dataset.with_options(options)
        return dataset

    def get_dataset(self):
        return self._dataset

    def get_size(self):
        return self._size

    def batch_size(self):
        return self._batch_size

    def has_partial_batch(self):
        return self._partial_batch_size > 0

    def partial_batch_size(self):
        # Normalize 0 to None per the DataAdapter contract.
        return self._partial_batch_size or None

    def should_recreate_iterator(self):
        # An infinite dataset is always created here.
        return False
|
||
|
|
||
|
|
||
|
class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
    """Adapter that handles array-like data without forcing it into memory.

    This adapter handles array-like datasets that may be too big to fully
    fit into memory.

    Specifically, this adapter handles any Python class which implements:
    `__get_item__`, `__len__`, `shape`, and `dtype` with the same meanings
    as Numpy, but it ignores any case where all the inputs are Tensors or
    Numpy arrays (because that case is handled by the base
    TensorLikeDataAdapter).

    It ignores scipy sparse matrices and Composite Tensors because those are
    handled by the CompositeTensorDataAdapter.

    It also does not handle lists/tuples of scalars, because those are
    handled by the ListsOfScalarsDataAdapter.
    """

    @staticmethod
    def can_handle(x, y=None):
        flat_inputs = tf.nest.flatten(x)
        if y is not None:
            flat_inputs += tf.nest.flatten(y)

        def _is_array_like(v):
            """Return True if v is a Tensor, array, or is array-like."""
            return (
                hasattr(v, "__getitem__")
                and hasattr(v, "shape")
                and hasattr(v, "dtype")
                and hasattr(v, "__len__")
            )

        # Defer to the more specific adapters when they apply; only claim
        # inputs that neither of them handles.
        if not TensorLikeDataAdapter.can_handle(
            x, y
        ) and not CompositeTensorDataAdapter.can_handle(x, y):
            return all(_is_array_like(v) for v in flat_inputs)
        else:
            return False

    def __init__(self, *args, **kwargs):
        # Warn once at construction: this path goes through py_function and
        # is slower than a native tf.data pipeline.
        logging.warning(
            "Keras is training/fitting/evaluating on array-like data. Keras "
            "may not be optimized for this format, so if your input data "
            "format is supported by TensorFlow I/O "
            "(https://github.com/tensorflow/io) we recommend using that to "
            "load a Dataset instead."
        )

        super().__init__(*args, **kwargs)

    def slice_inputs(self, indices_dataset, inputs):
        """Slice inputs into a Dataset of batches.

        Given a Dataset of batch indices and the unsliced inputs,
        this step slices the inputs in a parallelized fashion
        and produces a dataset of input batches.

        Args:
            indices_dataset: A Dataset of batched indices
            inputs: A python data structure that contains the inputs,
                targets, and possibly sample weights.

        Returns:
            A Dataset of input batches matching the batch indices.
        """
        flat_inputs = tf.nest.flatten(inputs)

        def dynamic_shape_like(t):
            # Same rank as `t`, with an unknown leading (batch) dimension.
            shape = list(t.shape)
            shape[0] = None
            return tuple(shape)

        flat_dtypes = [inp.dtype for inp in flat_inputs]
        # Non-contiguous gather is only needed when sample order is
        # permuted (i.e. full shuffle, not "batch" shuffle).
        contiguous = True
        if self._shuffle and self._shuffle != "batch":
            contiguous = False

        def grab_batch(indices):
            """Grab a batch of data from the inputs."""
            # This uses a py_function to avoid converting the array-like
            # into a Tensor before slicing it, because converting the
            # array-like to a Tensor may force it into memory.
            def py_method(ind):
                def slice_array(data):
                    return training_utils.slice_arrays(
                        data, ind.numpy(), contiguous=contiguous
                    )

                return [slice_array(inp) for inp in flat_inputs]

            flat_out = tf.py_function(py_method, [indices], flat_dtypes)
            # py_function loses static shape info; restore what we know.
            for v, original_inp in zip(flat_out, flat_inputs):
                v.set_shape(dynamic_shape_like(original_inp))
            return tf.nest.pack_sequence_as(inputs, flat_out)

        dataset = indices_dataset.map(
            grab_batch, num_parallel_calls=tf.data.AUTOTUNE
        )

        return dataset
|
||
|
|
||
|
|
||
|
class DatasetCreatorAdapter(DataAdapter):
    """Adapter that handles dataset functions (`DatasetCreator` inputs).

    The wrapped `DatasetCreator` is a callable handed to the distribution
    strategy in `get_dataset`, so the actual `tf.data.Dataset` is
    materialized lazily, per worker.
    """

    def __init__(self, x, y, steps=None, distribution_strategy=None, **kwargs):
        """Create the adapter.

        Args:
            x: a `dataset_creator.DatasetCreator` instance.
            y: unused; present for interface compatibility.
            steps: number of steps to run; required because the size of the
                dataset produced by a dataset function cannot be inferred.
            distribution_strategy: strategy used to distribute the dataset
                created from `x`.
            **kwargs: forwarded to `DataAdapter.__init__`.

        Raises:
            TypeError: if `x` is not a `DatasetCreator`.
            ValueError: if `steps` is not provided.
        """
        super().__init__(x, **kwargs)

        if not isinstance(x, dataset_creator.DatasetCreator):
            raise TypeError(
                "The input of a `DatasetCreatorAdapter` should be a "
                "`DatasetCreator` but it received type {}.".format(type(x))
            )
        if steps is None:
            raise ValueError(
                "When using a "
                "`tf.keras.utils.experimental.DatasetCreator`, "
                "`steps_per_epoch`, `validation_steps` or `steps` "
                "argument must be provided in `Model.fit`, "
                "`Model.evaluate`, or `Model.predict`."
            )
        self.dataset_creator = x
        self.steps = steps
        self.strategy = distribution_strategy

    @staticmethod
    def can_handle(x, y=None):
        """Return True only for `DatasetCreator` inputs.

        Fix: previously this fell off the end and implicitly returned
        `None` for any other input; now it returns an explicit `False`,
        matching the boolean contract of `DataAdapter.can_handle`.
        """
        if isinstance(x, dataset_creator.DatasetCreator):
            assert y is None
            return True
        return False

    def should_recreate_iterator(self):
        # We expect users to shuffle the dataset in their `dataset_fn`
        # supplied to `DatasetCreator`. Since that is a buffered shuffle, we
        # intend to not reset the dataset so the batches that are not
        # shuffled can still be pulled.
        return False

    def get_size(self):
        return None  # To be inferred by `DataHandler`.

    def get_dataset(self):
        # Build per-replica datasets through the strategy, honoring any
        # input options carried on the creator.
        return self.strategy.distribute_datasets_from_function(
            self.dataset_creator, options=self.dataset_creator.input_options
        )

    def batch_size(self):
        # Unknown for a dataset function; callers must not rely on it.
        raise NotImplementedError()

    def has_partial_batch(self):
        raise NotImplementedError()

    def partial_batch_size(self):
        raise NotImplementedError()
|
||
|
|
||
|
|
||
|
class CompositeTensorDataAdapter(DataAdapter):
    """Adapter that handles composite tensors (e.g. sparse/ragged tensors)."""

    @staticmethod
    def can_handle(x, y=None):
        flat_inputs = tf.nest.flatten(x)
        if y is not None:
            flat_inputs += tf.nest.flatten(y)

        def _is_composite(v):
            # Dataset/iterator/DistributedDataset inherits from
            # CompositeTensor but should be handled by DatasetAdapter and
            # GeneratorAdapter.
            if (
                tf_utils.is_extension_type(v)
                and not isinstance(v, (tf.data.Dataset, tf.data.Iterator))
                and not _is_distributed_dataset(v)
            ):
                return True
            # Support Scipy sparse tensors if scipy is installed
            return _is_scipy_sparse(v)

        def _is_tensor_or_composite(v):
            if isinstance(v, (tf.Tensor, np.ndarray)):
                return True
            return _is_composite(v)

        # Claim the input when at least one element is composite and every
        # element is either a plain tensor/array or composite.
        return any(_is_composite(v) for v in flat_inputs) and all(
            _is_tensor_or_composite(v) for v in flat_inputs
        )

    def __init__(
        self,
        x,
        y=None,
        sample_weights=None,
        sample_weight_modes=None,
        batch_size=None,
        steps=None,
        shuffle=False,
        **kwargs,
    ):
        """Builds a batched (optionally shuffled) dataset from composites.

        Unlike TensorLikeDataAdapter, this goes through
        `from_tensor_slices` directly rather than the vectorized
        index-gather pipeline.
        """
        super().__init__(x, y, **kwargs)
        x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
        sample_weight_modes = broadcast_sample_weight_modes(
            sample_weights, sample_weight_modes
        )

        # If sample_weights are not specified for an output use 1.0 as
        # weights.
        (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
            y, sample_weights, sample_weight_modes, check_all_flat=True
        )

        inputs = pack_x_y_sample_weight(x, y, sample_weights)

        dataset = tf.data.Dataset.from_tensor_slices(inputs)
        num_samples = int(tf.nest.flatten(x)[0].shape[0])
        if shuffle:
            # Buffer covering the whole data: a full uniform shuffle.
            dataset = dataset.shuffle(num_samples)

        # If batch_size is not passed but steps is, calculate from the input
        # data. Default to 32 for backwards compatibility.
        if not batch_size:
            batch_size = int(math.ceil(num_samples / steps)) if steps else 32

        dataset = dataset.batch(batch_size)
        self._size = int(math.ceil(num_samples / batch_size))
        self._batch_size = batch_size
        # A partial batch exists iff the sample count is not a multiple of
        # the batch size.
        self._has_partial_batch = self._size != (num_samples // batch_size)

        self._partial_batch_size = None
        if self._has_partial_batch:
            self._partial_batch_size = (
                num_samples - (self._size - 1) * self._batch_size
            )

        self._dataset = dataset

    def get_dataset(self):
        return self._dataset

    def get_size(self):
        return self._size

    def batch_size(self):
        return self._batch_size

    def has_partial_batch(self):
        return self._has_partial_batch

    def partial_batch_size(self):
        return self._partial_batch_size

    def should_recreate_iterator(self):
        # The dataset is finite (no repeat), so a fresh iterator is needed
        # each epoch.
        return True
|
||
|
|
||
|
|
||
|
class ListsOfScalarsDataAdapter(DataAdapter):
    """Adapter that handles lists of scalars and lists of lists of scalars."""

    @staticmethod
    def can_handle(x, y=None):
        """Return True when `x` (and `y`, if given) are scalar lists."""
        if not ListsOfScalarsDataAdapter._is_list_of_scalars(x):
            return False
        if y is None:
            return True
        return ListsOfScalarsDataAdapter._is_list_of_scalars(y)

    @staticmethod
    def _is_list_of_scalars(inp):
        """Recursively check whether `inp` is a scalar or nests to one.

        Only the first element of a list/tuple is inspected at each level,
        which keeps the check cheap.
        """
        if isinstance(inp, (float, int, str, bytes, bytearray)):
            return True
        if isinstance(inp, (list, tuple)) and inp:
            return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0])
        return False

    def __init__(
        self,
        x,
        y=None,
        sample_weights=None,
        sample_weight_modes=None,
        batch_size=None,
        shuffle=False,
        **kwargs,
    ):
        """Promote scalar lists to ndarrays and delegate to the tensor adapter."""
        super().__init__(x, y, **kwargs)
        x = np.asarray(x)
        if y is not None:
            y = np.asarray(y)
        if sample_weights is not None:
            sample_weights = np.asarray(sample_weights)
        sample_weight_modes = broadcast_sample_weight_modes(
            sample_weights, sample_weight_modes
        )

        # All real work is performed by an internal TensorLikeDataAdapter
        # over the converted arrays.
        self._internal_adapter = TensorLikeDataAdapter(
            x,
            y=y,
            sample_weights=sample_weights,
            sample_weight_modes=sample_weight_modes,
            batch_size=batch_size,
            shuffle=shuffle,
            **kwargs,
        )

    def get_dataset(self):
        return self._internal_adapter.get_dataset()

    def get_size(self):
        return self._internal_adapter.get_size()

    def batch_size(self):
        return self._internal_adapter.batch_size()

    def has_partial_batch(self):
        return self._internal_adapter.has_partial_batch()

    def partial_batch_size(self):
        return self._internal_adapter.partial_batch_size()

    def should_recreate_iterator(self):
        return True
|
||
|
|
||
|
|
||
|
class DatasetAdapter(DataAdapter):
    """Adapter that handles `tf.data.Dataset` (including distributed ones)."""

    @staticmethod
    def can_handle(x, y=None):
        return isinstance(
            x, (tf.compat.v1.data.Dataset, tf.data.Dataset)
        ) or _is_distributed_dataset(x)

    def __init__(self, x, y=None, sample_weights=None, steps=None, **kwargs):
        super().__init__(x, y, **kwargs)
        # Note that the dataset instance is immutable, its fine to reuse the
        # user provided dataset.
        self._dataset = x

        # The user-provided steps.
        self._user_steps = steps

        self._validate_args(y, sample_weights, steps)

    def get_dataset(self):
        return self._dataset

    def get_size(self):
        return  # Inferred in `DataHandler`.

    def batch_size(self):
        # The dataset's batching is opaque from here.
        return None

    def has_partial_batch(self):
        return False

    def partial_batch_size(self):
        return None

    def should_recreate_iterator(self):
        # Since DistributedDatasets have no cardinality, the user must
        # provide all steps that need to be run, calling `.repeat()` as
        # needed.
        if _is_distributed_dataset(self._dataset):
            return False

        # If user doesn't supply `steps`, or if they supply `steps` that
        # exactly equals the size of the `Dataset`, create a new iterator
        # each epoch.
        return (
            self._user_steps is None
            or tf.data.experimental.cardinality(self._dataset).numpy()
            == self._user_steps
        )

    def _validate_args(self, y, sample_weights, steps):
        """Validates `__init__` arguments."""
        # Arguments that shouldn't be passed.
        if not is_none_or_empty(y):
            raise ValueError(
                "`y` argument is not supported when using dataset as input."
            )
        if not is_none_or_empty(sample_weights):
            raise ValueError(
                "`sample_weight` argument is not supported when using "
                "dataset as input."
            )

        if steps is None:
            if _is_distributed_dataset(self._dataset):
                raise ValueError(
                    "When providing a distributed dataset, you must "
                    "specify the number of steps to run."
                )

            size = tf.data.experimental.cardinality(self._dataset).numpy()
            # NOTE(review): the `and steps is None` clause below is
            # redundant — this branch is already guarded by `steps is None`.
            if (
                size == tf.data.experimental.INFINITE_CARDINALITY
                and steps is None
            ):
                raise ValueError(
                    "When providing an infinite dataset, you must specify "
                    "the number of steps to run (if you did not intend to "
                    "create an infinite dataset, make sure to not call "
                    "`repeat()` on the dataset)."
                )
|
||
|
|
||
|
|
||
|
class GeneratorDataAdapter(DataAdapter):
|
||
|
"""Adapter that handles python generators and iterators."""
|
||
|
|
||
|
@staticmethod
|
||
|
def can_handle(x, y=None):
|
||
|
return (
|
||
|
(hasattr(x, "__next__") or hasattr(x, "next"))
|
||
|
and hasattr(x, "__iter__")
|
||
|
and not isinstance(x, data_utils.Sequence)
|
||
|
)
|
||
|
|
||
|
    def __init__(
        self,
        x,
        y=None,
        sample_weights=None,
        workers=1,
        use_multiprocessing=False,
        max_queue_size=10,
        model=None,
        **kwargs,
    ):
        """Wraps a python generator/iterator into a `tf.data.Dataset`.

        The first batch is consumed (and stitched back) to infer the
        output signature of the generator.
        """
        # Generators should never shuffle as exhausting the generator in
        # order to shuffle the batches is inefficient.
        kwargs.pop("shuffle", None)

        if not is_none_or_empty(y):
            raise ValueError(
                "`y` argument is not supported when using "
                "python generator as input."
            )
        if not is_none_or_empty(sample_weights):
            raise ValueError(
                "`sample_weight` argument is not supported when using "
                "python generator as input."
            )

        super().__init__(x, y, **kwargs)

        # Since we have to know the dtype of the python generator when we
        # build the dataset, we have to look at a batch to infer the
        # structure.
        peek, x = self._peek_and_restore(x)
        peek = self._standardize_batch(peek)
        peek = _process_tensorlike(peek)

        # Need to build the Model on concrete input shapes.
        if model is not None and not model.built:
            concrete_x, _, _ = unpack_x_y_sample_weight(peek)
            try:
                model.distribute_strategy.run(
                    lambda x: model(x, training=False), args=(concrete_x,)
                )
            except NotImplementedError:
                # The above call may fail if the model is a container-like
                # class that does not implement its own forward pass (e.g. a
                # GAN or VAE where the forward pass is handled by
                # subcomponents). Such a model does not need to be built.
                pass

        self._first_batch_size = int(tf.nest.flatten(peek)[0].shape[0])

        def _get_tensor_spec(t):
            # TODO(b/226395276): Remove _with_tensor_ranks_only usage.
            return type_spec.type_spec_from_value(t)._with_tensor_ranks_only()

        output_signature = tf.nest.map_structure(_get_tensor_spec, peek)

        # Note that dataset API takes a callable that creates a generator
        # object, rather than generator itself, which is why we define a
        # function here.
        generator_fn = self._handle_multiprocessing(
            x, workers, use_multiprocessing, max_queue_size
        )

        def wrapped_generator():
            # Standardize every batch the same way the peeked batch was.
            for data in generator_fn():
                yield self._standardize_batch(data)

        dataset = tf.data.Dataset.from_generator(
            wrapped_generator, output_signature=output_signature
        )

        if workers == 1 and not use_multiprocessing:
            # No enqueuer is pipelining for us; prefetch one batch instead.
            dataset = dataset.prefetch(1)

        self._dataset = dataset
|
||
|
|
||
|
def _standardize_batch(self, data):
|
||
|
"""Standardizes a batch output by a generator."""
|
||
|
# Removes `None`s.
|
||
|
x, y, sample_weight = unpack_x_y_sample_weight(data)
|
||
|
data = pack_x_y_sample_weight(x, y, sample_weight)
|
||
|
|
||
|
data = tf.__internal__.nest.list_to_tuple(data)
|
||
|
|
||
|
def _convert_dtype(t):
|
||
|
if isinstance(t, np.ndarray) and issubclass(
|
||
|
t.dtype.type, np.floating
|
||
|
):
|
||
|
return np.array(t, dtype=backend.floatx())
|
||
|
return t
|
||
|
|
||
|
data = tf.nest.map_structure(_convert_dtype, data)
|
||
|
return data
|
||
|
|
||
|
@staticmethod
|
||
|
def _peek_and_restore(x):
|
||
|
peek = next(x)
|
||
|
return peek, itertools.chain([peek], x)
|
||
|
|
||
|
def _handle_multiprocessing(
|
||
|
self, x, workers, use_multiprocessing, max_queue_size
|
||
|
):
|
||
|
"""Create a callable, possibly including an Enqueuer."""
|
||
|
if workers > 1 or (workers > 0 and use_multiprocessing):
|
||
|
|
||
|
def generator_fn():
|
||
|
enqueuer = data_utils.GeneratorEnqueuer(
|
||
|
x, use_multiprocessing=use_multiprocessing
|
||
|
)
|
||
|
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
|
||
|
return enqueuer.get()
|
||
|
|
||
|
else:
|
||
|
generator_fn = lambda: x
|
||
|
return generator_fn
|
||
|
|
||
|
    def get_dataset(self):
        """Returns the `tf.data.Dataset` built from the generator."""
        return self._dataset
|
||
|
|
||
|
    def get_size(self):
        # The number of batches a generator yields cannot be known up front.
        return None
|
||
|
|
||
|
    def batch_size(self):
        # A generator controls its own batching; no static batch size exists.
        return None
|
||
|
|
||
|
    def representative_batch_size(self):
        # Size of the first peeked batch; later batches may differ.
        return self._first_batch_size
|
||
|
|
||
|
    def has_partial_batch(self):
        # Partial batches cannot be detected without exhausting the stream.
        return False
|
||
|
|
||
|
def partial_batch_size(self):
|
||
|
return
|
||
|
|
||
|
    def should_recreate_iterator(self):
        # A generator can only be consumed once, so the same iterator is
        # reused across epochs rather than recreated.
        return False
|
||
|
|
||
|
|
||
|
class KerasSequenceAdapter(GeneratorDataAdapter):
    """Adapter that handles `keras.utils.Sequence`."""

    @staticmethod
    def can_handle(x, y=None):
        # `y` is ignored: a `Sequence` yields complete batches itself.
        return isinstance(x, data_utils.Sequence)

    def __init__(
        self,
        x,
        y=None,
        sample_weights=None,
        shuffle=False,
        workers=1,
        use_multiprocessing=False,
        max_queue_size=10,
        model=None,
        **kwargs,
    ):
        # A `Sequence` produces `(x, y[, sample_weight])` tuples itself, so
        # standalone `y` / `sample_weights` arguments are rejected.
        if not is_none_or_empty(y):
            raise ValueError(
                "`y` argument is not supported when using "
                "`keras.utils.Sequence` as input."
            )
        if not is_none_or_empty(sample_weights):
            raise ValueError(
                "`sample_weight` argument is not supported when using "
                "`keras.utils.Sequence` as input."
            )

        self._shuffle_sequence = shuffle
        self._keras_sequence = x
        self._enqueuer = None
        super().__init__(
            x,
            shuffle=False,  # Shuffle is handled in the override below.
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            model=model,
            **kwargs,
        )

    @staticmethod
    def _peek_and_restore(x):
        # A `Sequence` supports random access, so peeking at the first batch
        # does not consume it — return `x` itself as the restored input.
        return x[0], x

    def _handle_multiprocessing(
        self, x, workers, use_multiprocessing, max_queue_size
    ):
        """Create a callable, possibly including an OrderedEnqueuer."""
        if workers > 1 or (workers > 0 and use_multiprocessing):

            def generator_fn():
                # The enqueuer is stored on `self` so `on_epoch_end` can
                # stop its workers.
                self._enqueuer = data_utils.OrderedEnqueuer(
                    x,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=self._shuffle_sequence,
                )
                self._enqueuer.start(
                    workers=workers, max_queue_size=max_queue_size
                )
                return self._enqueuer.get()

        else:

            def generator_fn():
                order = range(len(x))
                if self._shuffle_sequence:
                    # Match the shuffle convention in OrderedEnqueuer.
                    order = list(order)
                    random.shuffle(order)

                for i in order:
                    yield x[i]

        return generator_fn

    def get_size(self):
        # Unlike a plain generator, a `Sequence` knows its length.
        return len(self._keras_sequence)

    def should_recreate_iterator(self):
        # A `Sequence` can be re-iterated (and reshuffled) every epoch.
        return True

    def on_epoch_end(self):
        if self._enqueuer:
            self._enqueuer.stop()
        self._keras_sequence.on_epoch_end()
|
||
|
|
||
|
|
||
|
# All adapters `select_data_adapter` considers; exactly one `can_handle`
# must match a given `(x, y)` pair.
ALL_ADAPTER_CLS = [
    ListsOfScalarsDataAdapter,
    TensorLikeDataAdapter,
    GenericArrayLikeDataAdapter,
    DatasetAdapter,
    GeneratorDataAdapter,
    KerasSequenceAdapter,
    CompositeTensorDataAdapter,
    DatasetCreatorAdapter,
]
|
||
|
|
||
|
|
||
|
def select_data_adapter(x, y):
    """Selects a data adapter that can handle a given x and y."""
    matches = [
        adapter for adapter in ALL_ADAPTER_CLS if adapter.can_handle(x, y)
    ]
    if not matches:
        # TODO(scottzhu): This should be a less implementation-specific error.
        raise ValueError(
            "Failed to find data adapter that can handle input: {}, {}".format(
                _type_name(x), _type_name(y)
            )
        )
    if len(matches) > 1:
        # Adapters are meant to be mutually exclusive; more than one match
        # indicates a bug in the `can_handle` predicates.
        raise RuntimeError(
            "Data adapters should be mutually exclusive for "
            "handling inputs. Found multiple adapters {} to handle "
            "input: {}, {}".format(matches, _type_name(x), _type_name(y))
        )
    chosen = matches[0]
    # Instrument the data adapter usage before returning it
    keras_data_adapter_gauge.get_cell(chosen.__name__).set(True)
    return chosen
|
||
|
|
||
|
|
||
|
def _type_name(x):
|
||
|
"""Generates a description of the type of an object."""
|
||
|
if isinstance(x, dict):
|
||
|
key_types = set(_type_name(key) for key in x.keys())
|
||
|
val_types = set(_type_name(key) for key in x.values())
|
||
|
return f"({type(x)} containing {key_types} keys and {val_types} values)"
|
||
|
if isinstance(x, (list, tuple)):
|
||
|
types = set(_type_name(val) for val in x)
|
||
|
return f"({type(x)} containing values of types {types})"
|
||
|
return str(type(x))
|
||
|
|
||
|
|
||
|
def _process_tensorlike(inputs):
    """Process tensor-like inputs.

    This function:

    (1) Converts `Numpy` arrays to `Tensor`s.
    (2) Converts `Scipy` sparse matrices to `SparseTensor`s.
    (3) Converts `pandas.Series` to `Tensor`s
    (4) Converts `list`s to `tuple`s (for `tf.data` support).

    Args:
      inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.

    Returns:
      Structure of `Tensor`s or tensor-like.
    """

    def _convert_one(x):
        if _is_pandas_series(x):
            # Series become rank-2 column vectors before tensor conversion.
            x = np.expand_dims(x.to_numpy(), axis=-1)

        if isinstance(x, np.ndarray):
            # Floating arrays are re-cast to the Keras default float dtype.
            if issubclass(x.dtype.type, np.floating):
                return tf.convert_to_tensor(x, dtype=backend.floatx())
            return tf.convert_to_tensor(x, dtype=None)
        if _is_scipy_sparse(x):
            return _scipy_sparse_to_sparse_tensor(x)
        return x

    converted = tf.nest.map_structure(_convert_one, inputs)
    return tf.__internal__.nest.list_to_tuple(converted)
|
||
|
|
||
|
|
||
|
def is_none_or_empty(inputs):
    """Returns True when `inputs` is None or an empty nested structure.

    Plain truthiness (`not inputs`) cannot be used here because NumPy
    arrays raise "The truth value of an array with more than one element
    is ambiguous. Use a.any() or a.all()" when coerced to bool.
    """
    if inputs is None:
        return True
    return not tf.nest.flatten(inputs)
|
||
|
|
||
|
|
||
|
def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
    """Match sample_weight_modes structure with output structure."""
    # No outputs to match against: return modes unchanged.
    if target_structure is None or not tf.nest.flatten(target_structure):
        return sample_weight_modes

    # A single string mode is broadcast to every output.
    if isinstance(sample_weight_modes, str):
        if isinstance(target_structure, dict):
            return {key: sample_weight_modes for key in target_structure.keys()}
        return [sample_weight_modes for _ in target_structure]

    if sample_weight_modes:
        try:
            tf.nest.assert_same_structure(
                training_utils.list_to_tuple(target_structure),
                training_utils.list_to_tuple(sample_weight_modes),
            )
        except (ValueError, TypeError):
            # "..."-placeholder summaries of both structures, used only in
            # the log/error messages below.
            target_str = str(
                tf.nest.map_structure(lambda _: "...", target_structure)
            )
            mode_str = str(
                tf.nest.map_structure(lambda _: "...", sample_weight_modes)
            )

            # Attempt to coerce sample_weight_modes to the target structure.
            # This implicitly depends on the fact that Model flattens outputs
            # for its internal representation.
            try:
                sample_weight_modes = tf.nest.pack_sequence_as(
                    target_structure, tf.nest.flatten(sample_weight_modes)
                )
                # NOTE(review): the message renders target_str as the "from"
                # structure and mode_str as the "to" structure — this looks
                # swapped (the modes were coerced *to* the target); confirm
                # the intended wording before relying on this warning text.
                logging.warning(
                    "sample_weight modes were coerced from\n "
                    "{}\n to \n {}".format(target_str, mode_str)
                )
            except (ValueError, TypeError):
                raise ValueError(
                    "Unable to match target structure and sample_weight_modes "
                    "structure:\n {}\n to \n {}".format(
                        target_str, mode_str
                    )
                )

    return sample_weight_modes
|
||
|
|
||
|
|
||
|
class DataHandler:
    """Handles iterating over epoch-level `tf.data.Iterator` objects."""

    def __init__(
        self,
        x,
        y=None,
        sample_weight=None,
        batch_size=None,
        steps_per_epoch=None,
        initial_epoch=0,
        epochs=1,
        shuffle=False,
        class_weight=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        model=None,
        steps_per_execution=None,
        distribute=True,
    ):
        """Initializes a `DataHandler`.

        Arguments:
          x: See `Model.fit`.
          y: See `Model.fit`.
          sample_weight: See `Model.fit`.
          batch_size: See `Model.fit`.
          steps_per_epoch: See `Model.fit`.
          initial_epoch: See `Model.fit`.
          epochs: See `Model.fit`.
          shuffle: See `Model.fit`.
          class_weight: See `Model.fit`.
          max_queue_size: See `Model.fit`.
          workers: See `Model.fit`.
          use_multiprocessing: See `Model.fit`.
          model: The `Model` instance. Needed in order to correctly `build` the
            `Model` using generator-like inputs (see `GeneratorDataAdapter`).
          steps_per_execution: See `Model.compile`.
          distribute: Whether to distribute the `tf.dataset`.
            `PreprocessingLayer.adapt` does not support distributed datasets,
            `Model` should always set this to `True`.
        """

        self._initial_epoch = initial_epoch
        self._initial_step = 0
        self._epochs = epochs
        self._insufficient_data = False
        self._model = model

        self._steps_per_epoch = steps_per_epoch

        # `steps_per_execution_value` is the cached initial value.
        # `steps_per_execution` is mutable and may be changed by the DataAdapter
        # to handle partial executions.
        if steps_per_execution is None:
            self._steps_per_execution = tf.Variable(1)
        else:
            self._steps_per_execution = steps_per_execution

        adapter_cls = select_data_adapter(x, y)
        self._adapter = adapter_cls(
            x,
            y,
            batch_size=batch_size,
            steps=steps_per_epoch,
            epochs=epochs - initial_epoch,
            sample_weights=sample_weight,
            shuffle=shuffle,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            distribution_strategy=tf.distribute.get_strategy(),
            model=model,
        )

        strategy = tf.distribute.get_strategy()

        self._current_step = 0
        # Extra steps each execution covers beyond the current one.
        self._step_increment = self._steps_per_execution.numpy().item() - 1
        self._insufficient_data = False

        self._configure_dataset_and_inferred_steps(
            strategy, x, steps_per_epoch, class_weight, distribute
        )

    def _configure_dataset_and_inferred_steps(
        self, strategy, x, steps_per_epoch, class_weight, distribute
    ):
        """Configure the `_dataset` and `_inferred_steps` attributes."""
        del x
        dataset = self._adapter.get_dataset()
        if class_weight:
            # Convert `class_weight` into per-sample weights via a map fn.
            dataset = dataset.map(_make_class_weight_map_fn(class_weight))
        self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)

        # `PreprocessingLayer.adapt` does not currently support distributed
        # datasets, so we pass `distribute=False` there.
        if distribute and not _is_distributed_dataset(dataset):
            dataset = strategy.experimental_distribute_dataset(dataset)
        self._dataset = dataset
        self._validate_data_handler()

    def enumerate_epochs(self):
        """Yields `(epoch, tf.data.Iterator)`."""
        with self._truncate_execution_to_epoch():
            data_iterator = iter(self._dataset)
            for epoch in range(self._initial_epoch, self._epochs):
                if self._insufficient_data:  # Set by `catch_stop_iteration`.
                    break
                if self._adapter.should_recreate_iterator():
                    data_iterator = iter(self._dataset)
                    # Re-infer steps after recreating the iterator: the
                    # adapter (e.g. a `Sequence`) may change size per epoch.
                    if not isinstance(self._dataset, DistributedDataset):
                        steps = self._infer_steps(
                            self._steps_per_epoch, self._dataset
                        )
                        if steps is not None:
                            self._inferred_steps = steps
                yield epoch, data_iterator
                self._adapter.on_epoch_end()

    @contextlib.contextmanager
    def _truncate_execution_to_epoch(self):
        """Truncates steps per execution to at most one epoch."""
        should_truncate = (
            self._inferred_steps is not None
            and self._steps_per_execution.numpy().item() > self._inferred_steps
        )
        original_value = self._steps_per_execution.numpy().item()
        try:
            if should_truncate:
                self._steps_per_execution.assign(self._inferred_steps)
            yield
        finally:
            # Restore the caller-provided value even on error.
            if should_truncate:
                self._steps_per_execution.assign(original_value)

    def sync(self):
        # Block until pending async eager ops have finished (or raised).
        context.async_wait()

    @contextlib.contextmanager
    def catch_stop_iteration(self):
        """Catches errors when an iterator runs out of data."""
        with distributed_training_utils.maybe_preemption_handler_scope(
            self._model
        ):
            try:
                yield
                self.sync()
            except (StopIteration, tf.errors.OutOfRangeError):
                if self._inferred_steps is None:
                    # First exhaustion of an unknown-size dataset: record
                    # how many steps one epoch actually has.
                    self._inferred_steps = self._current_step
                else:
                    # Data ran out before the expected step count.
                    self._insufficient_data = True
                    total_epochs = self._epochs - self._initial_epoch
                    logging.warning(
                        "Your input ran out of data; interrupting training. "
                        "Make sure that your dataset or generator can generate "
                        "at least `steps_per_epoch * epochs` batches (in this "
                        "case, {} batches). You may need to use the repeat() "
                        "function when building your dataset.".format(
                            total_epochs * self._inferred_steps
                        )
                    )

    def steps(self):
        """Yields steps for the current epoch."""
        self._current_step = self._initial_step
        self._initial_step = 0
        # `self._inferred_steps` can be changed by `catch_stop_iteration`.
        while (
            self._inferred_steps is None
            or self._current_step < self._inferred_steps
        ):
            if self._insufficient_data:  # Set by `catch_stop_iteration`.
                break
            original_spe = self._steps_per_execution.numpy().item()
            can_run_full_execution = (
                original_spe == 1
                or self._inferred_steps is None
                or self._inferred_steps - self._current_step >= original_spe
            )

            if can_run_full_execution:
                self._step_increment = original_spe - 1
                yield self._current_step
                self._current_step += original_spe
            else:
                # Last partial execution: temporarily shrink
                # `steps_per_execution` to exactly the remaining steps.
                steps_remaining = self._inferred_steps - self._current_step
                self._steps_per_execution.assign(steps_remaining)
                self._step_increment = steps_remaining - 1
                yield self._current_step
                self._current_step += steps_remaining
                self._steps_per_execution.assign(original_spe)

    @property
    def step_increment(self):
        """The number to increment the step for `on_batch_end` methods."""
        return self._step_increment

    @property
    def inferred_steps(self):
        """The inferred steps per epoch of the created `Dataset`.

        This will be `None` in the case where:

        (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`,
        (2) `steps_per_epoch` was not provided, and
        (3) The first epoch of iteration has not yet completed.

        Returns:
          The inferred steps per epoch of the created `Dataset`.
        """
        return self._inferred_steps

    @property
    def should_sync(self):
        # Catch OutOfRangeError for Datasets of unknown size.
        # This blocks until the batch has finished executing.
        # TODO(b/150292341): Allow multiple async steps here.
        return self._inferred_steps is None

    def _log_indefinite_training_warning(self):
        logging.warning(
            "The training loop will run indefinitely since you have "
            "set `steps_per_epoch=-1`. Please use batch-level "
            "callbacks to save checkpoints or log training progress, "
            "etc"
        )

    def _infer_steps(self, steps, dataset):
        """Infers steps_per_epoch needed to loop through a dataset."""
        # -1 is the explicit "run forever" sentinel.
        if steps == -1:
            self._log_indefinite_training_warning()
            return None

        if steps is not None:
            return steps

        adapter_steps = self._adapter.get_size()
        if adapter_steps is not None:
            return adapter_steps

        # tf.distribute's `PerWorkerDataset` does not inherit from
        # `tf.data.Dataset` and in those cases we give up on inferring steps.
        if not isinstance(dataset, tf.data.Dataset):
            return None

        size = tf.data.experimental.cardinality(dataset)
        if size == tf.data.experimental.INFINITE_CARDINALITY and steps is None:
            raise ValueError(
                "When passing an infinitely repeating dataset, please specify "
                "a `steps_per_epoch` value so that epoch level "
                "callbacks continue to work. The value can be arbitrary, or a "
                "number that you think correctly defines the size of an epoch. "
                "Epoch-level callbacks will then be called at this interval."
            )
        if size >= 0:
            return size.numpy().item()
        # UNKNOWN_CARDINALITY: steps discovered on first exhaustion instead.
        return None

    @property
    def _samples(self):
        return self._adapter.get_samples()

    def _validate_data_handler(self):
        # TODO(b/152094471): Support this with DistIter.get_next_as_optional.
        if (
            self._steps_per_execution.numpy().item() > 1
            and self._inferred_steps is None
        ):
            raise ValueError(
                "Could not infer the size of the data. With "
                "`steps_per_execution > 1`, you must specify the number of "
                "steps to run."
            )
|
||
|
|
||
|
|
||
|
class _ClusterCoordinatorDataHandler(DataHandler):
    """A `DataHandler` that is compatible with `ClusterCoordinator`."""

    def __init__(self, x, y=None, **kwargs):
        # Raw tensors/arrays must be wrapped in a `DatasetCreator` so each
        # worker can build its own dataset under parameter-server training.
        if not _is_distributed_dataset(x) and not isinstance(
            x, (dataset_creator.DatasetCreator, tf.data.Dataset)
        ):
            x = self._convert_to_dataset_creator(x, y, **kwargs)

        super().__init__(x=x, **kwargs)

    def _convert_to_dataset_creator(self, x, y, **kwargs):
        """Converts non-tf.data.Dataset to `DatasetCreator` instances."""

        def _dataset_fn(input_context):
            del input_context
            data_adapter_cls = select_data_adapter(x, y)
            return data_adapter_cls(x=x, y=y, **kwargs).get_dataset()

        # This check is needed because types like `tf.data.Dataset` don't work
        # with PSS yet. So only apply this logic to the types we can support.
        if isinstance(x, _get_tensor_types()) and isinstance(
            y, _get_tensor_types()
        ):
            return dataset_creator.DatasetCreator(_dataset_fn)
        else:
            raise NotImplementedError(
                "Only `tf.keras.utils.experimental.DatasetCreator`, "
                "`tf.Tensor`, numpy arrays and pandas dataframes are "
                "supported types at this time."
            )

    def _configure_dataset_and_inferred_steps(
        self, strategy, x, steps_per_epoch, class_weight, distribute
    ):
        # Unlike the base class, the dataset is built per worker through the
        # cluster coordinator rather than iterated locally.
        if isinstance(x, dataset_creator.DatasetCreator):

            def per_worker_dataset_fn():

                return strategy.distribute_datasets_from_function(
                    x, options=x.input_options
                )

            coordinator = self._model._cluster_coordinator
            self._dataset = coordinator.create_per_worker_dataset(
                per_worker_dataset_fn
            )
        else:
            assert distribute
            if not _is_distributed_dataset(x):
                x = strategy.experimental_distribute_dataset(x)

            coordinator = self._model._cluster_coordinator
            self._dataset = coordinator.create_per_worker_dataset(x)

        # Steps cannot be inferred from a per-worker dataset; rely on the
        # caller-provided `steps_per_epoch` (or run indefinitely for -1).
        if steps_per_epoch == -1:
            self._inferred_steps = None
            self._log_indefinite_training_warning()
        else:
            self._inferred_steps = steps_per_epoch

    def sync(self):
        # Wait for all scheduled per-worker functions to finish.
        self._model._cluster_coordinator.join()
|
||
|
|
||
|
|
||
|
@keras_export("keras.__internal__.utils.get_data_handler", v1=[])
def get_data_handler(*args, **kwargs):
    """Creates a `DataHandler`, providing standardized access to a `Dataset`.

    See `DataHandler` for the list and definition of the arguments. See the
    implementation of `Model.fit()`, `evaluate()`, or `predict()` methods
    for complete usage examples. As a rule of thumb, `get_data_handler()`
    accepts the same inputs as the `x` argument of `Model.fit()`.

    Example:

    ```python
    def step(iterator):
      data = next(iterator)
      # result <= Do something with data
      return result
    tf_step = tf.function(step, reduce_retracing=True)

    # Assume x is a tf.data Dataset.
    data_handler = data_adapter.get_data_handler(x=x)
    # Epoch iteration
    for epo_idx, iterator in data_handler.enumerate_epochs():
        # Stop on dataset exhaustion.
        with data_handler.catch_stop_iteration():
            for step in data_handler.steps():  # Step iteration
                step_result = step(iterator)
    ```

    Args:
      *args: Arguments passed to the `DataHandler` constructor.
      **kwargs: Arguments passed to the `DataHandler` constructor.

    Returns:
      A `DataHandler` object. If the model's cluster coordinate is set (e.g.
      the model was defined under a parameter-server strategy), returns a
      `_ClusterCoordinatorDataHandler`.
    """
    # Use `.get()` rather than `kwargs["model"]`: `model` is optional for
    # `DataHandler` (defaults to None), so its absence must not raise a
    # KeyError here. `getattr(None, ..., None)` is falsy, selecting the
    # plain `DataHandler` path.
    if getattr(kwargs.get("model"), "_cluster_coordinator", None):
        return _ClusterCoordinatorDataHandler(*args, **kwargs)
    return DataHandler(*args, **kwargs)
|
||
|
|
||
|
|
||
|
def _make_class_weight_map_fn(class_weight):
    """Applies class weighting to a `Dataset`.

    The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
    `y` must be a single `Tensor`.

    Args:
      class_weight: A map where the keys are integer class ids and values are
        the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`

    Returns:
      A function that can be used with `tf.data.Dataset.map` to apply class
      weighting.
    """
    # Keys must be exactly 0..num_classes-1 so they can index a dense tensor.
    class_ids = list(sorted(class_weight.keys()))
    expected_class_ids = list(range(len(class_ids)))
    if class_ids != expected_class_ids:
        error_msg = (
            "Expected `class_weight` to be a dict with keys from 0 to one less "
            "than the number of classes, found {}"
        ).format(class_weight)
        raise ValueError(error_msg)

    # Dense lookup table: position i holds the weight for class i.
    class_weight_tensor = tf.convert_to_tensor(
        [class_weight[int(c)] for c in class_ids]
    )

    def _class_weights_map_fn(*data):
        """Convert `class_weight` to `sample_weight`."""
        x, y, sw = unpack_x_y_sample_weight(data)

        if tf.nest.is_nested(y):
            raise ValueError(
                "`class_weight` is only supported for Models with a single "
                "output."
            )

        if y.shape.rank > 2:
            raise ValueError(
                "`class_weight` not supported for 3+ dimensional targets."
            )

        # One-hot / probability targets (rank 2, width > 1) are reduced via
        # argmax; sparse integer targets are flattened and cast.
        y_classes = tf.__internal__.smart_cond.smart_cond(
            y.shape.rank == 2 and backend.shape(y)[1] > 1,
            lambda: backend.argmax(y, axis=1),
            lambda: tf.cast(backend.reshape(y, (-1,)), tf.int64),
        )

        cw = tf.gather(class_weight_tensor, y_classes)
        if sw is not None:
            cw = tf.cast(cw, sw.dtype)
            # `class_weight` and `sample_weight` are multiplicative.
            sw = sw * cw
        else:
            sw = cw
        return x, y, sw

    return _class_weights_map_fn
|
||
|
|
||
|
|
||
|
def train_validation_split(arrays, validation_split):
    """Split arrays into train and validation subsets in deterministic order.

    The last part of data will become validation data.

    Args:
      arrays: Tensors to split. Allowed inputs are arbitrarily nested
        structures of Tensors and NumPy arrays.
      validation_split: Float between 0 and 1. The proportion of the dataset
        to include in the validation split. The rest of the dataset will be
        included in the training split.
    Returns:
      `(train_arrays, validation_arrays)`
    """

    def _can_split(t):
        # `None` leaves are preserved untouched in both splits.
        return isinstance(t, _get_tensor_types()) or t is None

    flat_arrays = tf.nest.flatten(arrays)
    unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
    if unsplitable:
        raise ValueError(
            "`validation_split` is only supported for Tensors or NumPy "
            "arrays, found following types in the input: {}".format(unsplitable)
        )

    if all(t is None for t in flat_arrays):
        return arrays, arrays

    # Any non-None leaf determines the batch dimension; all leaves are
    # assumed to share the same batch shape (or be `None`).
    first_non_none = next(t for t in flat_arrays if t is not None)

    batch_dim = int(first_non_none.shape[0])
    split_at = int(math.floor(batch_dim * (1.0 - validation_split)))

    if split_at in (0, batch_dim):
        raise ValueError(
            "Training data contains {batch_dim} samples, which is not "
            "sufficient to split it into a validation and training set as "
            "specified by `validation_split={validation_split}`. Either "
            "provide more data, or a different value for the "
            "`validation_split` argument.".format(
                batch_dim=batch_dim, validation_split=validation_split
            )
        )

    def _slice(t, start, end):
        return t if t is None else t[start:end]

    train_arrays = tf.nest.map_structure(
        functools.partial(_slice, start=0, end=split_at), arrays
    )
    val_arrays = tf.nest.map_structure(
        functools.partial(_slice, start=split_at, end=batch_dim), arrays
    )

    return train_arrays, val_arrays
|
||
|
|
||
|
|
||
|
@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
def unpack_x_y_sample_weight(data):
    """Unpacks user-provided data tuple.

    This is a convenience utility to be used when overriding
    `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
    This utility makes it easy to support data of the form `(x,)`,
    `(x, y)`, or `(x, y, sample_weight)`.

    Standalone usage:

    >>> features_batch = tf.ones((10, 5))
    >>> labels_batch = tf.zeros((10, 5))
    >>> data = (features_batch, labels_batch)
    >>> # `y` and `sample_weight` will default to `None` if not provided.
    >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
    >>> sample_weight is None
    True

    Example in overridden `Model.train_step`:

    ```python
    class MyModel(tf.keras.Model):

      def train_step(self, data):
        # If `sample_weight` is not provided, all samples will be weighted
        # equally.
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
          y_pred = self(x, training=True)
          loss = self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses)
          trainable_variables = self.trainable_variables
          gradients = tape.gradient(loss, trainable_variables)
          self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}
    ```

    Args:
      data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.

    Returns:
      The unpacked tuple, with `None`s for `y` and `sample_weight` if they are
      not provided.
    """
    # Lists are treated the same as tuples.
    if isinstance(data, list):
        data = tuple(data)
    # A non-tuple is a bare `x`.
    if not isinstance(data, tuple):
        return (data, None, None)
    if len(data) == 1:
        return (data[0], None, None)
    if len(data) == 2:
        return (data[0], data[1], None)
    if len(data) == 3:
        return (data[0], data[1], data[2])
    error_msg = (
        "Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
        "or `(x, y, sample_weight)`, found: {}"
    ).format(data)
    raise ValueError(error_msg)
|
||
|
|
||
|
|
||
|
@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
    """Packs user-provided data into a tuple.

    This is a convenience utility for packing data into the tuple formats
    that `Model.fit` uses.

    Standalone usage:

    >>> x = tf.ones((10, 1))
    >>> data = tf.keras.utils.pack_x_y_sample_weight(x)
    >>> isinstance(data, tf.Tensor)
    True
    >>> y = tf.ones((10, 1))
    >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
    >>> isinstance(data, tuple)
    True
    >>> x, y = data

    Args:
      x: Features to pass to `Model`.
      y: Ground-truth targets to pass to `Model`.
      sample_weight: Sample weight for each element.

    Returns:
      Tuple in the format used in `Model.fit`.
    """
    if y is None:
        # For single x-input, we do no tuple wrapping since in this case
        # there is no ambiguity. This also makes NumPy and Dataset
        # consistent in that the user does not have to wrap their Dataset
        # data in an unnecessary tuple.
        #
        # Bug fix: the original check was `isinstance(x, tuple or list)`,
        # which evaluates to `isinstance(x, tuple)` because `tuple or list`
        # is just `tuple`. A list `x` therefore escaped unwrapped and would
        # later be mis-unpacked by `unpack_x_y_sample_weight` (which treats
        # lists like tuples) as `(x, y, ...)`. Check both types explicitly.
        if not isinstance(x, (tuple, list)):
            return x
        else:
            return (x,)
    elif sample_weight is None:
        return (x, y)
    else:
        return (x, y, sample_weight)
|
||
|
|
||
|
|
||
|
def single_batch_iterator(
    strategy, x, y=None, sample_weight=None, class_weight=None
):
    """Creates a single-batch dataset.

    Args:
      strategy: Distribution strategy used to distribute the dataset.
      x: Input features.
      y: Optional targets.
      sample_weight: Optional per-sample weights (ignored when `y` is None).
      class_weight: Optional class-id -> weight map, applied via
        `_make_class_weight_map_fn`.

    Returns:
      An iterator over a distributed one-element `tf.data.Dataset`.
    """
    x, y, sample_weight = _process_tensorlike((x, y, sample_weight))
    if y is None:
        batch = (x,)
    elif sample_weight is None:
        batch = (x, y)
    else:
        batch = (x, y, sample_weight)

    _check_data_cardinality(batch)
    dataset = tf.data.Dataset.from_tensors(batch)
    if class_weight:
        dataset = dataset.map(_make_class_weight_map_fn(class_weight))
    return iter(strategy.experimental_distribute_dataset(dataset))
|
||
|
|
||
|
|
||
|
def _check_data_cardinality(data):
    """Raises ValueError when x/y/sample_weight disagree on sample count."""
    sizes = {int(i.shape[0]) for i in tf.nest.flatten(data)}
    if len(sizes) <= 1:
        return
    msg = "Data cardinality is ambiguous:\n"
    for label, single_data in zip(["x", "y", "sample_weight"], data):
        per_leaf = ", ".join(
            str(i.shape[0]) for i in tf.nest.flatten(single_data)
        )
        msg += " {} sizes: {}\n".format(label, per_leaf)
    msg += "Make sure all arrays contain the same number of samples."
    raise ValueError(msg)
|
||
|
|
||
|
|
||
|
def _get_tensor_types():
    """Returns the tensor-like types, including pandas ones when available."""
    types = (tf.Tensor, np.ndarray)
    if pd is not None:
        types += (pd.Series, pd.DataFrame)
    return types
|
||
|
|
||
|
|
||
|
def _is_scipy_sparse(x):
|
||
|
try:
|
||
|
from scipy.sparse import issparse
|
||
|
|
||
|
return issparse(x)
|
||
|
except ImportError:
|
||
|
return False
|
||
|
|
||
|
|
||
|
def _is_pandas_series(x):
    """Whether `x` is a `pandas.Series` (False when pandas is unavailable)."""
    return pd is not None and isinstance(x, pd.Series)
|
||
|
|
||
|
|
||
|
def _scipy_sparse_to_sparse_tensor(t):
    """Converts a SciPy sparse matrix to a SparseTensor."""
    # COO format exposes explicit (row, col, value) triplets.
    coo = t.tocoo()
    values = coo.data
    if issubclass(values.dtype.type, np.floating):
        values = values.astype(backend.floatx())
    # Stack row/col into an (nnz, 2) index matrix, as SparseTensor expects.
    indices = np.stack([coo.row, coo.col], axis=1)
    return tf.SparseTensor(indices, values, coo.shape)
|
||
|
|
||
|
|
||
|
def _is_distributed_dataset(ds):
    """Whether `ds` is a `tf.distribute.DistributedDataset`."""
    return isinstance(ds, tf.distribute.DistributedDataset)
|