# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training-related utilities."""

import abc
import atexit
import collections
import functools
import multiprocessing.pool
import threading
import time

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend
from keras import callbacks as cbks
from keras import losses
from keras import metrics as metrics_module
from keras.utils import data_utils
from keras.utils import generic_utils
from keras.utils import losses_utils
from keras.utils import tf_inspect

# isort: off
from tensorflow.python.platform import tf_logging as logging


def is_composite_or_composite_value(tensor):
    """Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
    # TODO(b/125094323): This should be isinstance(CompositeTensor) or
    # isinstance(CompositeTensorValue) once we support that.
    return isinstance(
        tensor,
        (
            tf.__internal__.CompositeTensor,
            tf.compat.v1.SparseTensorValue,
            tf.compat.v1.ragged.RaggedTensorValue,
        ),
    )


class Aggregator(object, metaclass=abc.ABCMeta):
    """Abstract base class used to aggregate batch-level outputs of a loop.

    Attributes:
      use_steps: Whether the loop is using `step` or `batch_size`.
      num_samples: Total number of samples: `batch_size * num_batches`.
      steps: Total number of steps.
      batch_size: Batch size. It is used for validation checks between inputs
        and outputs.
      results: What to return at the end of the aggregation loop.
    """

    def __init__(
        self, use_steps, num_samples=None, steps=None, batch_size=None
    ):
        self.use_steps = use_steps
        self.num_samples = num_samples
        self.steps = steps
        self.batch_size = batch_size
        self.results = []

    @abc.abstractmethod
    def create(self, batch_outs):
        """Creates the initial results from the first batch outputs.

        Args:
          batch_outs: A list of batch-level outputs.
        """
        raise NotImplementedError("Must be implemented in subclasses.")

    @abc.abstractmethod
    def aggregate(self, batch_outs, batch_start=None, batch_end=None):
        """Aggregates batch-level results into total results.

        Args:
          batch_outs: A list of batch-level outputs.
          batch_start: The start index of this batch. Always `None` if
            `use_steps` is `True`.
          batch_end: The end index of this batch. Always `None` if `use_steps`
            is `True`.
        """
        raise NotImplementedError("Must be implemented in subclasses.")

    @abc.abstractmethod
    def finalize(self):
        """Prepares the total results to be returned."""
        raise NotImplementedError("Must be implemented in subclasses.")


class MetricsAggregator(Aggregator):
    """Aggregator that calculates loss and metrics info.

    Attributes:
      use_steps: Whether the loop is using `step` or `batch_size`.
      num_samples: Total number of samples: `batch_size * num_batches`.
      steps: Total number of steps, i.e. number of times to iterate over a
        dataset to cover all samples.
    """

    def __init__(self, use_steps, num_samples=None, steps=None):
        super().__init__(
            use_steps=use_steps,
            num_samples=num_samples,
            steps=steps,
            batch_size=None,
        )

    def create(self, batch_outs):
        self.results = [0.0] * len(batch_outs)

    def aggregate(self, batch_outs, batch_start=None, batch_end=None):
        # Loss.
        if self.use_steps:
            self.results[0] += batch_outs[0]
        else:
            self.results[0] += batch_outs[0] * (batch_end - batch_start)
        # Metrics (always stateful, just grab current values).
        self.results[1:] = batch_outs[1:]

    def finalize(self):
        if not self.results:
            raise ValueError("Empty training data.")
        self.results[0] /= self.num_samples or self.steps
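

# --- Illustrative sketch (not part of the original module) ---
# A minimal example of how a loop might drive `MetricsAggregator` when
# iterating by sample counts. The batch outputs below are made-up numbers:
# each batch reports [loss, accuracy]; the loss is averaged over samples,
# while the (stateful) metric simply keeps its latest value.
def _example_metrics_aggregation():
    agg = MetricsAggregator(use_steps=False, num_samples=100)
    agg.create([0.0, 0.0])
    agg.aggregate([0.5, 0.8], batch_start=0, batch_end=50)
    agg.aggregate([0.25, 0.9], batch_start=50, batch_end=100)
    agg.finalize()
    # (0.5 * 50 + 0.25 * 50) / 100 == 0.375; the metric keeps its last
    # reported value, so agg.results == [0.375, 0.9].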


def _append_sparse_tensor_value(target, to_append):
    """Append sparse tensor value objects."""
    # Make sure the sparse tensors are of the same size (except for the 0th
    # dim).
    if len(target.dense_shape) != len(to_append.dense_shape):
        raise RuntimeError(
            "Unable to concatenate %s and %s. The inner dense shapes do not "
            "have the same number of dimensions (%s vs %s)"
            % (target, to_append, target.dense_shape, to_append.dense_shape)
        )

    if target.dense_shape[1:] != to_append.dense_shape[1:]:
        raise RuntimeError(
            "Unable to concatenate %s and %s. The inner dense shapes do not "
            "match inner dimensions (%s vs %s)"
            % (
                target,
                to_append,
                target.dense_shape[1:],
                to_append.dense_shape[1:],
            )
        )

    # Add the to_append indices to target, updating the 0th value, and
    # keeping track of the maximum so we know the final dense_shape of this
    # tensor.
    base_dim0_value = target.dense_shape[0]
    max_dim0_value = target.dense_shape[0]
    new_indices = target.indices
    for index in to_append.indices:
        # Here, we iterate through the sparse indices of the tensor to
        # append. For each index, we update its zeroth value (the batch
        # index) by adding the number of batch items in the tensor we are
        # appending to (so an index of [0, 0, 1] for a value that is being
        # appended to a tensor with 0th dim size 3 would become [3, 0, 1].)
        index[0] += base_dim0_value
        max_dim0_value = max(max_dim0_value, index[0])
        new_indices = np.append(new_indices, [index], axis=0)

    # Extend the values array to contain all of the appended values. These
    # will be in the same order as the indices added above.
    new_values = np.concatenate((target.values, to_append.values), axis=0)

    # Create a new dense shape by replacing the value for the 0th dimension
    # with the new max dim0 value.
    new_dense_shape = list(target.dense_shape)
    new_dense_shape[0] = max_dim0_value + 1
    new_dense_shape = tuple(new_dense_shape)

    return tf.compat.v1.SparseTensorValue(
        indices=new_indices, values=new_values, dense_shape=new_dense_shape
    )


def _append_ragged_tensor_value(target, to_append):
    """Append ragged tensor value objects."""
    # Make sure the ragged tensors are of the same size (save for the 0th
    # dim).
    if len(target.shape) != len(to_append.shape):
        raise RuntimeError(f"Unable to concatenate {target} and {to_append}")

    if target.shape[1:] != to_append.shape[1:]:
        raise RuntimeError(f"Unable to concatenate {target} and {to_append}")

    adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1]
    new_row_splits = np.append(target.row_splits, adjusted_row_splits)
    if isinstance(target.values, tf.compat.v1.ragged.RaggedTensorValue):
        new_values = _append_ragged_tensor_value(
            target.values, to_append.values
        )
    else:
        new_values = np.concatenate((target.values, to_append.values), axis=0)

    return tf.compat.v1.ragged.RaggedTensorValue(new_values, new_row_splits)


def _append_composite_tensor(target, to_append):
    """Helper function to append composite tensors to each other in the 0
    axis.

    In order to support batching within a fit/evaluate/predict call, we need
    to be able to aggregate within a CompositeTensor. Unfortunately, the CT
    API currently does not make this easy - especially in V1 mode, where
    we're working with CompositeTensor Value objects that have no connection
    with the CompositeTensors that created them.

    Args:
      target: CompositeTensor or CompositeTensor value object that will be
        appended to.
      to_append: CompositeTensor or CompositeTensor value object to append to
        `target`.

    Returns:
      A CompositeTensor or CompositeTensor value object.

    Raises:
      RuntimeError: if concatenation is not possible.
    """
    if type(target) is not type(to_append):
        raise RuntimeError(
            f"Unable to concatenate {type(target)} and {type(to_append)}"
        )

    # Perform type-specific concatenation.
    # TODO(b/125094323): This should be replaced by a simple call to
    # target.append() that should work on all of the below classes.

    # If we're seeing a CompositeTensor here, we know it's because we're in
    # Eager mode (or else we'd have evaluated the CT to a CT Value object
    # already). Therefore, it's safe to call concat() on it without
    # evaluating the result any further. If not - that is, if we're seeing a
    # SparseTensorValue or a RaggedTensorValue - we need to hand-update it
    # since we're outside of the graph anyways.
    if isinstance(target, tf.SparseTensor):
        # We need to invoke the sparse version of concatenate here - tf.concat
        # won't work.
        return tf.compat.v1.sparse_concat(
            sp_inputs=[target, to_append], axis=0
        )
    elif isinstance(target, tf.RaggedTensor):
        return tf.concat([target, to_append], axis=0)
    elif isinstance(target, tf.compat.v1.SparseTensorValue):
        return _append_sparse_tensor_value(target, to_append)
    elif isinstance(target, tf.compat.v1.ragged.RaggedTensorValue):
        return _append_ragged_tensor_value(target, to_append)
    else:
        raise RuntimeError(
            f"Attempted to concatenate unsupported object {type(target)}."
        )
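

# --- Illustrative sketch (not part of the original module) ---
# How `_append_composite_tensor` merges two ragged *value* objects along
# axis 0 outside of a graph. The literal values here are arbitrary.
def _example_append_ragged_values():
    v1 = tf.compat.v1.ragged.RaggedTensorValue(
        values=np.array([1, 2, 3]), row_splits=np.array([0, 2, 3])
    )
    v2 = tf.compat.v1.ragged.RaggedTensorValue(
        values=np.array([4]), row_splits=np.array([0, 1])
    )
    merged = _append_composite_tensor(v1, v2)
    # merged.values -> [1, 2, 3, 4]; merged.row_splits -> [0, 2, 3, 4],
    # i.e. the rows [1, 2], [3] and [4] concatenated along the 0th axis.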


class ConcatAggregator(Aggregator):
    """Combine tensor-likes which cannot be merged on the fly.

    This class expects to aggregate a single tensor-like rather than a nested
    structure of tensor-likes.
    """

    def __init__(self, batch_size):
        self.composite = None
        super().__init__(
            use_steps=True, num_samples=None, steps=None, batch_size=batch_size
        )

    def create(self, batch_element):
        self.composite = is_composite_or_composite_value(batch_element)

    def aggregate(self, batch_element, batch_start=None, batch_end=None):
        # TODO(psv): Add num_samples check here to detect when output batch
        # #samples is < batch size and != input batch #samples.
        if self.batch_size and self.batch_size < batch_element.shape[0]:
            raise ValueError(
                "Mismatch between expected batch size and model output batch "
                "size. Output shape = {}, "
                "expected output shape = shape {}".format(
                    batch_element.shape,
                    (self.batch_size,) + batch_element.shape[1:],
                )
            )
        self.results.append(batch_element)

    def finalize(self):
        # Special case of single batch inference which skips a copy.
        if len(self.results) == 1:
            self.results = self.results[0]

        elif self.composite:
            # TODO(taylorrobie): efficiently concatenate.
            results = self.results[0]
            for r in self.results[1:]:
                results = _append_composite_tensor(results, r)
            self.results = results

        else:
            self.results = np.concatenate(self.results, axis=0)


_COPY_THREADS = 4
_COPY_POOL = None


def get_copy_pool():
    """Shared threadpool for copying arrays.

    Pool instantiation takes ~ 2ms, so a singleton pool is used rather than
    creating a pool per SliceAggregator.

    Returns:
      The global copy threadpool.
    """
    global _COPY_POOL
    if _COPY_POOL is None:
        _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS)
        atexit.register(_COPY_POOL.close)
    return _COPY_POOL


class SliceAggregator(Aggregator):
    """Combine arrays where the final size is known.

    This class expects to aggregate a single tensor-like rather than a nested
    structure of tensor-likes.

    NumPy copies are an operation that threads handle quite well because all
    of the heavy lifting is in C and does not need the GIL. Moreover, we can
    perform lock-free writes to the same buffer in multiple threads because
    the nature of result aggregation guarantees that either the indices are
    disjoint or the aggregator will throw an exception in finalize. Moreover,
    because aggregation is performed on the slowest varying dimension,
    assignments for a given batch will write to contiguous blocks of memory,
    further minimizing contention.

    There is, however, some scheduling and context switching overhead which
    will offset the gains from pipelining the slice assignment. Below a given
    threshold it is faster to simply assign in the main thread rather than
    enqueue the assignment in a side thread. The exact threshold will vary
    from system to system, but the time is not very sensitive to the exact
    transition, so a value of 2 ** 14 was chosen which should be reasonable
    on most systems.
    """

    _BINARY_SIZE_THRESHOLD = 2**14
    _MAX_COPY_SECONDS = 300

    def __init__(self, num_samples, batch_size):
        self._async_copies = []
        self._pool = get_copy_pool()
        self._errors = []
        super().__init__(
            use_steps=False,
            num_samples=num_samples,
            steps=None,
            batch_size=batch_size,
        )

    def create(self, batch_element):
        # This step does not need to be pipelined because NumPy empty array
        # initialization is effectively instantaneous.
        shape = (self.num_samples,) + batch_element.shape[1:]
        dtype = batch_element.dtype

        self.results = np.empty(shape=shape, dtype=dtype)

    def aggregate(self, batch_element, batch_start, batch_end):
        # Fail early.
        if self._errors:
            raise self._errors[0]

        # In the special case of single batch inference, no copy is needed.
        if batch_end - batch_start == self.num_samples:
            if self.num_samples != batch_element.shape[0]:
                raise ValueError(
                    "Mismatch between expected batch size and model "
                    "output batch size. Output shape = {}, "
                    "expected output shape = shape {}".format(
                        batch_element.shape, self.results.shape
                    )
                )

            self.results = batch_element
            return

        # This is an approximate threshold, so we don't need to consider the
        # number of bytes per element.
        num_elements = np.prod(batch_element.shape)
        if num_elements < self._BINARY_SIZE_THRESHOLD:
            self.results[batch_start:batch_end] = batch_element
        else:
            is_finished = threading.Event()
            self._pool.apply_async(
                self._slice_assign,
                args=(batch_element, batch_start, batch_end, is_finished),
            )
            self._async_copies.append(is_finished)

    def _slice_assign(
        self, batch_element, batch_start, batch_end, is_finished
    ):
        """Legacy utility method to slice input arrays."""
        try:
            self.results[batch_start:batch_end] = batch_element

        except Exception as e:
            # `_slice_assign` should only be called in threads and exceptions
            # raised in threads do not carry over to the main thread. So
            # instead we perform a broad catch in the thread and then store
            # the exception to be re-raised in the main thread.
            self._errors.append(e)

        finally:
            is_finished.set()

    def finalize(self):
        start_time = time.time()
        for is_finished in self._async_copies:
            timeout = max(
                [0.0, self._MAX_COPY_SECONDS - (time.time() - start_time)]
            )
            if not is_finished.wait(timeout):
                raise ValueError("Timed out waiting for copy to complete.")

        if self._errors:
            raise self._errors[0]


class OutputsAggregator(Aggregator):
    """Aggregator that concatenates outputs."""

    _structure = None

    def create(self, batch_outs):
        # SparseTensorValue is a named tuple which nest will flatten, so we
        # need to guard it to properly handle the structure.
        self._structure = tf.__internal__.nest.get_traverse_shallow_structure(
            lambda x: not is_composite_or_composite_value(x), batch_outs
        )
        batch_outs = tf.__internal__.nest.flatten_up_to(
            self._structure, batch_outs
        )

        for batch_element in batch_outs:
            if is_composite_or_composite_value(batch_element):
                # If the output is not a ndarray, it will be either a
                # composite tensor or a composite tensor's Value object. In
                # either case, we can't allocate an array to hold the object -
                # we'll handle it later.
                self.results.append(ConcatAggregator(self.batch_size))
            elif isinstance(batch_element, np.ndarray):
                self.results.append(
                    (
                        ConcatAggregator(self.batch_size)
                        if self.use_steps
                        else SliceAggregator(self.num_samples, self.batch_size)
                    )
                )
            else:
                # This is not a ndarray, a CompositeTensor, or a
                # CompositeTensorValue. Fail fast rather than trying to
                # concatenate it.
                raise RuntimeError(
                    "Attempted to aggregate unsupported object {}.".format(
                        batch_element
                    )
                )

            self.results[-1].create(batch_element)

    def aggregate(self, batch_outs, batch_start=None, batch_end=None):
        batch_outs = tf.__internal__.nest.flatten_up_to(
            self._structure, batch_outs
        )
        for batch_element, result in zip(batch_outs, self.results):
            result.aggregate(batch_element, batch_start, batch_end)

    def finalize(self):
        for result in self.results:
            result.finalize()
        self.results = [i.results for i in self.results]
        self.results = tf.nest.pack_sequence_as(self._structure, self.results)
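

# --- Illustrative sketch (not part of the original module) ---
# Aggregating two prediction batches into one array with
# `OutputsAggregator`; since `num_samples` is known up front and
# `use_steps` is False, the dense branch uses `SliceAggregator` internally.
def _example_outputs_aggregation():
    agg = OutputsAggregator(use_steps=False, num_samples=4, batch_size=2)
    first = np.zeros((2, 3))
    second = np.ones((2, 3))
    agg.create(first)
    agg.aggregate(first, 0, 2)
    agg.aggregate(second, 2, 4)
    agg.finalize()
    # agg.results is now a single (4, 3) array: two rows of zeros
    # followed by two rows of ones.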


def get_progbar(model, count_mode, include_metrics=True):
    """Get Progbar."""
    if include_metrics:
        stateful_metric_names = getattr(model, "metrics_names", None)
        if stateful_metric_names:
            stateful_metric_names = stateful_metric_names[1:]  # Exclude `loss`
    else:
        stateful_metric_names = None
    return cbks.ProgbarLogger(
        count_mode, stateful_metrics=stateful_metric_names
    )


def check_num_samples(ins, batch_size=None, steps=None, steps_name="steps"):
    """Determine the number of samples provided for training and evaluation.

    The number of samples is not defined when running with `steps`,
    in which case the number of samples is set to `None`.

    Args:
      ins: List of tensors to be fed to the Keras function.
      batch_size: Integer batch size or `None` if not defined.
      steps: Total number of steps (batches of samples) before declaring
        `_predict_loop` finished. Ignored with the default value of `None`.
      steps_name: The public API's parameter name for `steps`.

    Raises:
      ValueError: when `steps` is `None` and the attribute `ins.shape`
        does not exist. Also raises ValueError when `steps` is not `None`
        and `batch_size` is not `None` because they are mutually
        exclusive.

    Returns:
      When steps is `None`, returns the number of samples to be
      processed based on the size of the first dimension of the
      first input numpy array. When steps is not `None` and
      `batch_size` is `None`, returns `None`.
    """
    if steps is not None and batch_size is not None:
        raise ValueError(
            "If " + steps_name + " is set, the `batch_size` must be None."
        )
    if check_steps_argument(ins, steps, steps_name):
        return None

    if hasattr(ins[0], "shape"):
        return int(ins[0].shape[0])
    return None  # Edge case where ins == [static_learning_phase]


def standardize_single_array(x, expected_shape=None):
    """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1."""
    if x is None:
        return None

    if is_composite_or_composite_value(x):
        return x

    if isinstance(x, int):
        raise ValueError(
            f"Expected an array data type but received an integer: {x}"
        )

    if (
        x.shape is not None
        and len(x.shape) == 1
        and (expected_shape is None or len(expected_shape) != 1)
    ):
        if tf.is_tensor(x):
            x = tf.compat.v1.expand_dims(x, axis=1)
        else:
            x = np.expand_dims(x, 1)
    return x
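

# --- Illustrative sketch (not part of the original module) ---
# `standardize_single_array` turns rank-1 data into a column vector unless
# the model genuinely expects rank-1 input.
def _example_standardize_single_array():
    expanded = standardize_single_array(np.array([1.0, 2.0, 3.0]))
    # expanded.shape == (3, 1)
    kept = standardize_single_array(np.array([1.0, 2.0]), expected_shape=(2,))
    # kept.shape == (2,): a rank-1 expected shape suppresses the expansion.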


def get_composite_shape(tensor):
    """Returns the shape of the passed composite tensor."""
    if isinstance(tensor, tf.compat.v1.SparseTensorValue):
        # SparseTensorValues use a 'dense_shape' attribute
        return tensor.dense_shape
    else:
        return tensor.shape


def standardize_input_data(
    data, names, shapes=None, check_batch_axis=True, exception_prefix=""
):
    """Normalizes inputs and targets provided by users.

    Users may pass data as a list of arrays, dictionary of arrays,
    or as a single array. We normalize this to an ordered list of
    arrays (same order as `names`), while checking that the provided
    arrays have shapes that match the network's expectations.

    Args:
      data: User-provided input data (polymorphic).
      names: List of expected array names.
      shapes: Optional list of expected array shapes.
      check_batch_axis: Boolean; whether to check that the batch axis of the
        arrays matches the expected value found in `shapes`.
      exception_prefix: String prefix used for exception formatting.

    Returns:
      List of standardized input arrays (one array per model input).

    Raises:
      ValueError: in case of improperly formatted user-provided data.
    """
    try:
        data_len = len(data)
    except TypeError:
        # For instance if data is `None` or a symbolic Tensor.
        data_len = None

    if not names:
        if data_len and not isinstance(data, dict):
            raise ValueError(
                "Error when checking model "
                + exception_prefix
                + ": expected no data, but got:",
                data,
            )
        return []
    if data is None:
        return [None for _ in range(len(names))]

    if isinstance(data, dict):
        try:
            data = [
                data[x].values
                if data[x].__class__.__name__ == "DataFrame"
                else data[x]
                for x in names
            ]
        except KeyError as e:
            raise ValueError(
                'No data provided for "'
                + e.args[0]
                + '". Need data for each key in: '
                + str(names)
            )
    elif isinstance(data, (list, tuple)):
        if isinstance(data[0], (list, tuple)):
            data = [np.asarray(d) for d in data]
        elif len(names) == 1 and isinstance(data[0], (float, int)):
            data = [np.asarray(data)]
        else:
            data = [
                x.values if x.__class__.__name__ == "DataFrame" else x
                for x in data
            ]
    else:
        data = data.values if data.__class__.__name__ == "DataFrame" else data
        data = [data]

    if shapes is not None:
        data = [
            standardize_single_array(x, shape)
            for (x, shape) in zip(data, shapes)
        ]
    else:
        data = [standardize_single_array(x) for x in data]

    if len(data) != len(names):
        if data and hasattr(data[0], "shape"):
            raise ValueError(
                "Error when checking model "
                + exception_prefix
                + ": the list of Numpy arrays that you are passing to "
                "your model is not the size the model expected. "
                "Expected to see "
                + str(len(names))
                + " array(s), "
                + "for inputs "
                + str(names)
                + " but instead got the following list of "
                + str(len(data))
                + " arrays: "
                + str(data)[:200]
                + "..."
            )
        elif len(names) > 1:
            raise ValueError(
                "Error when checking model "
                + exception_prefix
                + ": you are passing a list as input to your model, "
                "but the model expects a list of "
                + str(len(names))
                + " Numpy arrays instead. The list you passed was: "
                + str(data)[:200]
            )
        elif len(data) == 1 and not hasattr(data[0], "shape"):
            raise TypeError(
                "Error when checking model "
                + exception_prefix
                + ": data should be a Numpy array, or list/dict of "
                "Numpy arrays. Found: " + str(data)[:200] + "..."
            )
        elif len(names) == 1:
            data = [np.asarray(data)]

    # Check shapes compatibility.
    if shapes:
        for i in range(len(names)):
            if shapes[i] is not None:
                if tf.is_tensor(data[i]):
                    tensorshape = data[i].shape
                    if not tensorshape:
                        continue
                    data_shape = tuple(tensorshape.as_list())
                elif is_composite_or_composite_value(data[i]):
                    tensorshape = get_composite_shape(data[i])
                    data_shape = tuple(tensorshape.as_list())
                else:
                    data_shape = data[i].shape

                shape = shapes[i]
                if len(data_shape) != len(shape):
                    raise ValueError(
                        "Error when checking "
                        + exception_prefix
                        + ": expected "
                        + names[i]
                        + " to have "
                        + str(len(shape))
                        + " dimensions, but got array with shape "
                        + str(data_shape)
                    )
                if not check_batch_axis:
                    data_shape = data_shape[1:]
                    shape = shape[1:]
                for dim, ref_dim in zip(data_shape, shape):
                    if (
                        ref_dim != dim
                        and ref_dim is not None
                        and dim is not None
                    ):
                        raise ValueError(
                            "Error when checking "
                            + exception_prefix
                            + ": expected "
                            + names[i]
                            + " to have shape "
                            + str(shape)
                            + " but got array with shape "
                            + str(data_shape)
                        )
    return data
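

# --- Illustrative sketch (not part of the original module) ---
# A dict of arrays is reordered into a list that follows the model's input
# names. The names and shapes below are hypothetical.
def _example_standardize_input_data():
    data = standardize_input_data(
        {"b": np.zeros((4, 2)), "a": np.ones((4, 3))},
        names=["a", "b"],
        shapes=[(None, 3), (None, 2)],
        exception_prefix="input",
    )
    # data == [array of ones with shape (4, 3), array of zeros with shape
    # (4, 2)], regardless of the dict's ordering.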


def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
    """Maps `sample_weight` or `class_weight` to model outputs.

    Args:
      x_weight: User-provided `sample_weight` or `class_weight` argument.
      output_names: List of output names (strings) in the model.
      weight_type: A string used purely for exception printing.

    Returns:
      A list of `sample_weight` or `class_weight` with exactly one element
      per model output.

    Raises:
      ValueError: In case of invalid user-provided argument.
    """
    if x_weight is None or (
        isinstance(x_weight, (list, tuple)) and len(x_weight) == 0
    ):
        return [None for _ in output_names]
    if len(output_names) == 1:
        if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
            return x_weight
        if isinstance(x_weight, dict) and output_names[0] in x_weight:
            return [x_weight[output_names[0]]]
        else:
            return [x_weight]
    if isinstance(x_weight, (list, tuple)):
        if len(x_weight) != len(output_names):
            raise ValueError(
                "Provided `"
                + weight_type
                + "` was a list of "
                + str(len(x_weight))
                + " elements, but the model has "
                + str(len(output_names))
                + " outputs. You should provide one `"
                + weight_type
                + "` array per model output."
            )
        return x_weight
    if isinstance(x_weight, collections.abc.Mapping):
        generic_utils.check_for_unexpected_keys(
            weight_type, x_weight, output_names
        )
        x_weights = []
        for name in output_names:
            x_weights.append(x_weight.get(name))
        return x_weights
    else:
        raise TypeError(
            "The model has multiple outputs, so `"
            + weight_type
            + "` should be either a list or a dict. Provided `"
            + weight_type
            + "` type not understood: "
            + str(x_weight)
        )
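

# --- Illustrative sketch (not part of the original module) ---
# Weights supplied as a dict keyed by (hypothetical) output names are
# reordered into one entry per output; outputs without weights get `None`.
def _example_standardize_sample_weights():
    weights = standardize_sample_or_class_weights(
        {"out_b": np.ones(4)}, ["out_a", "out_b"], "sample_weight"
    )
    # weights == [None, array([1., 1., 1., 1.])]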


def standardize_class_weights(class_weight, output_names):
    return standardize_sample_or_class_weights(
        class_weight, output_names, "class_weight"
    )


def standardize_sample_weights(sample_weight, output_names):
    return standardize_sample_or_class_weights(
        sample_weight, output_names, "sample_weight"
    )


def check_array_lengths(inputs, targets, weights=None):
    """Does user input validation for numpy arrays.

    Args:
      inputs: list of Numpy arrays of inputs.
      targets: list of Numpy arrays of targets.
      weights: list of Numpy arrays of sample weights.

    Raises:
      ValueError: in case of incorrectly formatted data.
    """

    def is_tensor_or_composite_tensor(x):
        return tf.is_tensor(x) or is_composite_or_composite_value(x)

    def set_of_lengths(x):
        # Returns a set with the variation between
        # different shapes, with None => 0
        if x is None:
            return {}
        else:
            return set(
                [
                    y.shape[0]
                    for y in x
                    if y is not None and not is_tensor_or_composite_tensor(y)
                ]
            )

    set_x = set_of_lengths(inputs)
    set_y = set_of_lengths(targets)
    set_w = set_of_lengths(weights)
    if len(set_x) > 1:
        raise ValueError(
            "All input arrays (x) should have "
            "the same number of samples. Got array shapes: "
            + str([x.shape for x in inputs])
        )
    if len(set_y) > 1:
        raise ValueError(
            "All target arrays (y) should have "
            "the same number of samples. Got array shapes: "
            + str([y.shape for y in targets])
        )
    if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
        raise ValueError(
            "Input arrays should have "
            "the same number of samples as target arrays. "
            "Found "
            + str(list(set_x)[0])
            + " input samples and "
            + str(list(set_y)[0])
            + " target samples."
        )
    if len(set_w) > 1:
        raise ValueError(
            "All sample_weight arrays should have "
            "the same number of samples. Got array shapes: "
            + str([w.shape for w in weights])
        )
    if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
        raise ValueError(
            "Sample_weight arrays should have "
            "the same number of samples as target arrays. Got "
            + str(list(set_y)[0])
            + " target samples and "
            + str(list(set_w)[0])
            + " sample_weight samples."
        )


def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
    """Does validation on the compatibility of targets and loss functions.

    This helps prevent users from using loss functions incorrectly. This
    check is purely for UX purposes.

    Args:
      targets: list of Numpy arrays of targets.
      loss_fns: list of loss functions.
      output_shapes: list of shapes of model outputs.

    Raises:
      ValueError: if a loss function or target array
        is incompatible with an output.
    """
    key_loss_fns = {
        losses.mean_squared_error,
        losses.binary_crossentropy,
        losses.categorical_crossentropy,
    }
    key_loss_classes = (
        losses.MeanSquaredError,
        losses.BinaryCrossentropy,
        losses.CategoricalCrossentropy,
    )
    for y, loss, shape in zip(targets, loss_fns, output_shapes):
        if y is None or loss is None or tf.is_tensor(y):
            continue
        if losses.is_categorical_crossentropy(loss):
            if y.shape[-1] == 1:
                raise ValueError(
                    "You are passing a target array of shape "
                    + str(y.shape)
                    + " while using as loss `categorical_crossentropy`. "
                    "`categorical_crossentropy` expects "
                    "targets to be binary matrices (1s and 0s) "
                    "of shape (samples, classes). "
                    "If your targets are integer classes, "
                    "you can convert them to the expected format via:\n"
                    "```\n"
                    "from keras.utils import to_categorical\n"
                    "y_binary = to_categorical(y_int)\n"
                    "```\n"
                    "\n"
                    "Alternatively, you can use the loss function "
                    "`sparse_categorical_crossentropy` instead, "
                    "which does expect integer targets."
                )

        is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
        if isinstance(loss, key_loss_classes) or (
            is_loss_wrapper and (loss.fn in key_loss_fns)
        ):
            for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
                if out_dim is not None and target_dim != out_dim:
                    loss_name = loss.name
                    if loss_name is None:
                        loss_type = loss.fn if is_loss_wrapper else type(loss)
                        loss_name = loss_type.__name__
                    raise ValueError(
                        "A target array with shape "
                        + str(y.shape)
                        + " was passed for an output of shape "
                        + str(shape)
                        + " while using as loss `"
                        + loss_name
                        + "`. "
                        "This loss expects targets to have the same shape "
                        "as the output."
                    )


def collect_per_output_metric_info(
    metrics,
    output_names,
    output_shapes,
    loss_fns,
    from_serialized=False,
    is_weighted=False,
):
    """Maps metric names and functions to model outputs.

    Args:
      metrics: a list or a list of lists or a dict of metric functions.
      output_names: a list of the names (strings) of model outputs.
      output_shapes: a list of the shapes of model outputs.
      loss_fns: a list of the loss functions corresponding to the model
        outputs.
      from_serialized: whether the model the metrics are being sourced from
        is being initialized from a serialized format.
      is_weighted: Boolean indicating whether the given metrics are weighted.

    Returns:
      A list (one entry per model output) of dicts.
      For instance, if the model has 2 outputs, and for the first output
      we want to compute "binary_accuracy" and "binary_crossentropy",
      and just "binary_accuracy" for the second output,
      the list would look like: `[{
          'acc': binary_accuracy(),
          'ce': binary_crossentropy(),
      }, {
          'acc': binary_accuracy(),
      }]`

    Raises:
      TypeError: if an incorrect type is passed for the `metrics` argument.
    """
    if not metrics:
        return [{} for _ in output_names]

    if isinstance(metrics, list):
        any_sub_list = any(isinstance(m, list) for m in metrics)
        if any_sub_list:
            if len(metrics) != len(output_names):
                raise ValueError(
                    "When passing a list of lists as `metrics`, "
                    "it should have one entry per model output. "
                    "The model has "
                    + str(len(output_names))
                    + " outputs, but you passed metrics="
                    + str(metrics)
                )
            # User has provided a list of len = len(outputs).
            nested_metrics = [generic_utils.to_list(m) for m in metrics]
        else:
            # If it is a single list we then apply all metrics to all
            # outputs.
            if len(output_names) > 1:
                nested_metrics = []
                for _ in output_names:
                    nested_metrics.append(
                        [metrics_module.clone_metric(m) for m in metrics]
                    )
            else:
                nested_metrics = [metrics]
    elif isinstance(metrics, collections.abc.Mapping):
        generic_utils.check_for_unexpected_keys(
            "metrics", metrics, output_names
        )
        nested_metrics = []
        for name in output_names:
            output_metrics = generic_utils.to_list(metrics.get(name, []))
            nested_metrics.append(output_metrics)
    else:
        raise TypeError(
            "Type of `metrics` argument not understood. "
            "Expected a list or dictionary, found: " + str(metrics)
        )

    per_output_metrics = []
    for i, metrics in enumerate(nested_metrics):
        metrics_dict = collections.OrderedDict()
        for metric in metrics:
            metric_name = get_metric_name(metric, is_weighted)
            metric_fn = get_metric_function(
                metric, output_shape=output_shapes[i], loss_fn=loss_fns[i]
            )
            metric_fn._from_serialized = from_serialized

            # If the metric function is not stateful, we create a stateful
            # version.
            if not isinstance(metric_fn, metrics_module.Metric):
                metric_fn = metrics_module.MeanMetricWrapper(
                    metric_fn, name=metric_name
                )
                # If the metric is being revived from something stateless,
                # such as a string (e.g. "accuracy"), we may need to later
                # reapply transformations such as renaming.
                metric_fn._from_serialized = False
            metrics_dict[metric_name] = metric_fn
        per_output_metrics.append(metrics_dict)

    return per_output_metrics


def batch_shuffle(index_array, batch_size):
    """Shuffles an array in a batch-wise fashion.

    Useful for shuffling HDF5 arrays
    (where one cannot access arbitrary indices).

    Args:
      index_array: array of indices to be shuffled.
      batch_size: integer.

    Returns:
      The `index_array` array, shuffled in a batch-wise fashion.
    """
    batch_count = int(len(index_array) / batch_size)
    # To reshape we need to be cleanly divisible by batch size; we stash
    # extra items and reappend them after shuffling.
    last_batch = index_array[batch_count * batch_size :]
    index_array = index_array[: batch_count * batch_size]
    index_array = index_array.reshape((batch_count, batch_size))
    np.random.shuffle(index_array)
    index_array = index_array.flatten()
    return np.append(index_array, last_batch)
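

# --- Illustrative sketch (not part of the original module) ---
# Batch-wise shuffling keeps each batch contiguous: only the order of the
# batches changes, and the remainder is re-appended at the end.
def _example_batch_shuffle():
    shuffled = batch_shuffle(np.arange(7), batch_size=2)
    # e.g. array([4, 5, 0, 1, 2, 3, 6]): the pairs (0, 1), (2, 3) and
    # (4, 5) stay intact, and the leftover index 6 stays last.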


def standardize_weights(
    y, sample_weight=None, class_weight=None, sample_weight_mode=None
):
    """Performs sample weight validation and standardization.

    Everything gets normalized to a single sample-wise (or timestep-wise)
    weight array. If both `sample_weight` and `class_weight` are provided,
    the weights are multiplied.

    Args:
      y: Numpy array or Tensor of model targets to be weighted.
      sample_weight: User-provided `sample_weight` argument.
      class_weight: User-provided `class_weight` argument.
      sample_weight_mode: One of `None` or `"temporal"`. `"temporal"`
        indicates that we expect 2D weight data that will be applied to the
        last 2 dimensions of the targets (i.e. we are weighting timesteps,
        not samples).

    Returns:
      A numpy array of target weights, one entry per sample to weight.

    Raises:
      ValueError: In case of invalid user-provided arguments.
    """
    # Iterator may return sample_weight as 1-tuple
    if isinstance(sample_weight, tuple):
        sample_weight = sample_weight[0]
    if sample_weight_mode is not None and sample_weight_mode != "samplewise":
        if sample_weight_mode != "temporal":
            raise ValueError(
                '`sample_weight_mode` should be None or "temporal". Found: '
                + str(sample_weight_mode)
            )
        if len(y.shape) < 3:
            raise ValueError(
                "Found a sample_weight array for an input with shape "
                + str(y.shape)
                + ". "
                "Timestep-wise sample weighting (use of "
                'sample_weight_mode="temporal") is restricted to '
                "outputs that are at least 3D, i.e. that have "
                "a time dimension."
            )
        if sample_weight is not None and len(sample_weight.shape) != 2:
            raise ValueError(
                "Found a sample_weight array with shape "
                + str(sample_weight.shape)
                + ". "
                "In order to use timestep-wise sample weighting, "
                "you should pass a 2D sample_weight array."
            )
    else:
        if sample_weight is not None and len(sample_weight.shape) != 1:
            raise ValueError(
                "Found a sample_weight array with shape {}. In order to "
                "use timestep-wise sample weights, you should specify "
                'sample_weight_mode="temporal" in compile(); found "{}" '
                "instead. If you just mean to use sample-wise weights, "
                "make sure your sample_weight array is 1D.".format(
                    sample_weight.shape, sample_weight_mode
                )
            )

    if sample_weight is not None:
        if len(sample_weight.shape) > len(y.shape):
            raise ValueError(
                "Found a sample_weight with shape "
                + str(sample_weight.shape)
                + ". Expected sample_weight with rank less than or equal to "
                + str(len(y.shape))
            )

        if (
            not tf.is_tensor(sample_weight)
            and y.shape[: sample_weight.ndim] != sample_weight.shape
        ):
            raise ValueError(
                "Found a sample_weight array with shape "
                + str(sample_weight.shape)
                + " for an input with shape "
                + str(y.shape)
                + ". sample_weight cannot be broadcast."
            )

    # Class weights applied per-sample.
    class_sample_weight = None
    if isinstance(class_weight, dict):
        if len(y.shape) > 2:
            raise ValueError(
                "`class_weight` not supported for 3+ dimensional targets."
            )

        if tf.is_tensor(y):
            # Few classes are expected, so densifying is reasonable.
            keys = np.array(sorted(class_weight.keys()))
            values = np.array([class_weight[i] for i in keys])
            weight_vector = np.zeros(np.max(keys) + 1)
            weight_vector[:] = np.nan
            weight_vector[keys] = values

            y_classes = tf.__internal__.smart_cond.smart_cond(
                len(y.shape.as_list()) == 2 and backend.shape(y)[1] > 1,
                lambda: backend.argmax(y, axis=1),
                lambda: tf.cast(backend.reshape(y, (-1,)), tf.int64),
            )
            class_sample_weight = tf.compat.v1.gather(weight_vector, y_classes)
            tf.debugging.check_numerics(
                class_sample_weight,
                "Invalid classes or class weights detected. NaN values "
                "indicate that an appropriate class weight could not be "
                "determined.",
            )
            class_sample_weight = tf.cast(
                class_sample_weight, backend.floatx()
            )
            if sample_weight is not None:
                sample_weight = tf.cast(
                    tf.convert_to_tensor(sample_weight), backend.floatx()
                )
        else:
            y_classes = y
            if len(y.shape) == 2:
                if y.shape[1] > 1:
                    y_classes = np.argmax(y, axis=1)
                elif y.shape[1] == 1:
                    y_classes = np.reshape(y, y.shape[0])

            class_sample_weight = np.asarray(
                [class_weight[cls] for cls in y_classes if cls in class_weight]
            )

            if len(class_sample_weight) != len(y_classes):
                # subtract the sets to pick all missing classes
                existing_classes = set(y_classes)
                existing_class_weight = set(class_weight.keys())
                raise ValueError(
                    "`class_weight` must contain all classes in the data."
                    " The classes %s exist in the data but not in "
                    "`class_weight`."
                    % (existing_classes - existing_class_weight)
                )

    if class_sample_weight is not None and sample_weight is not None:
        # Multiply weights if both are provided.
        return class_sample_weight * sample_weight
    if sample_weight is not None:
        return sample_weight
    if class_sample_weight is not None:
        return class_sample_weight
    return None
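

# --- Illustrative sketch (not part of the original module) ---
# When both `sample_weight` and `class_weight` are given, the per-class
# weights are gathered per sample and multiplied into the sample weights.
def _example_standardize_weights():
    y = np.array([0, 1, 1, 0])
    weights = standardize_weights(
        y,
        sample_weight=np.array([1.0, 1.0, 2.0, 1.0]),
        class_weight={0: 1.0, 1: 3.0},
    )
    # weights == array([1., 3., 6., 1.])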


def has_symbolic_tensors(ls):
    if tf.executing_eagerly():
        return False
    return has_tensors(ls)


def has_tensors(ls):
    """Returns true if `ls` contains tensors."""
    # Note: at some point in time ragged tensors didn't count as tensors, so
    # this returned false for ragged tensors. Making this return true fails
    # some tests which would then require a steps_per_epoch argument.
    if isinstance(ls, (list, tuple)):
        return any(
            tf.is_tensor(v) and not isinstance(v, tf.RaggedTensor) for v in ls
        )
    if isinstance(ls, dict):
        return any(
            tf.is_tensor(v) and not isinstance(v, tf.RaggedTensor)
            for _, v in ls.items()
        )
    return tf.is_tensor(ls) and not isinstance(ls, tf.RaggedTensor)


def get_metric_name(metric, weighted=False):
    """Returns the name corresponding to the given metric input.

    Args:
      metric: Metric function name or reference.
      weighted: Boolean indicating if the given metric is weighted.

    Returns:
      The metric name.
    """
    if tf.__internal__.tf2.enabled():
        # We keep the string that the user has set in compile as the metric
        # name.
        if isinstance(metric, str):
            return metric

        metric = metrics_module.get(metric)
        return metric.name if hasattr(metric, "name") else metric.__name__
    else:
        metric_name_prefix = "weighted_" if weighted else ""
        if metric in ("accuracy", "acc", "crossentropy", "ce"):
            if metric in ("accuracy", "acc"):
                suffix = "acc"
            elif metric in ("crossentropy", "ce"):
                suffix = "ce"
        else:
            metric_fn = metrics_module.get(metric)
            # Get metric name as string
            if hasattr(metric_fn, "name"):
                suffix = metric_fn.name
            else:
                suffix = metric_fn.__name__
        metric_name = metric_name_prefix + suffix
        return metric_name


def get_metric_function(metric, output_shape=None, loss_fn=None):
    """Returns the metric function corresponding to the given metric input.

    Args:
      metric: Metric function name or reference.
      output_shape: The shape of the output that this metric will be
        calculated for.
      loss_fn: The loss function used.

    Returns:
      The metric function.
    """
    if metric not in ["accuracy", "acc", "crossentropy", "ce"]:
        return metrics_module.get(metric)

    is_sparse_categorical_crossentropy = isinstance(
        loss_fn, losses.SparseCategoricalCrossentropy
    ) or (
        isinstance(loss_fn, losses.LossFunctionWrapper)
        and loss_fn.fn == losses.sparse_categorical_crossentropy
    )

    is_binary_crossentropy = isinstance(
        loss_fn, losses.BinaryCrossentropy
    ) or (
        isinstance(loss_fn, losses.LossFunctionWrapper)
        and loss_fn.fn == losses.binary_crossentropy
    )

    if metric in ["accuracy", "acc"]:
        if output_shape[-1] == 1 or is_binary_crossentropy:
            return metrics_module.binary_accuracy
        elif is_sparse_categorical_crossentropy:
            return metrics_module.sparse_categorical_accuracy
        # If the output_shape[-1] is not 1, then we know output is
        # `categorical`. We assume it is sparse categorical only if loss is
        # explicitly given as sparse categorical crossentropy loss.
        return metrics_module.categorical_accuracy
    else:
        if output_shape[-1] == 1 or is_binary_crossentropy:
            return metrics_module.binary_crossentropy
        elif is_sparse_categorical_crossentropy:
            return metrics_module.sparse_categorical_crossentropy
        return metrics_module.categorical_crossentropy
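

# --- Illustrative sketch (not part of the original module) ---
# The generic names "acc"/"accuracy" resolve to a concrete metric based on
# the output shape and the loss that was configured.
def _example_get_metric_function():
    binary = get_metric_function("acc", output_shape=(None, 1))
    # binary is metrics_module.binary_accuracy
    sparse = get_metric_function(
        "acc",
        output_shape=(None, 10),
        loss_fn=losses.SparseCategoricalCrossentropy(),
    )
    # sparse is metrics_module.sparse_categorical_accuracy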


def call_metric_function(
    metric_fn, y_true, y_pred=None, weights=None, mask=None
):
    """Invokes metric function and returns the metric result tensor."""
    if mask is not None:
        mask = tf.cast(mask, y_pred.dtype)
        if weights is None:
            # Use mask as sample weight.
            weights = mask
        else:
            # Update dimensions of weights to match with mask.
            weights = tf.cast(weights, dtype=y_pred.dtype)
            mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
                mask, sample_weight=weights
            )
            weights *= mask

    if y_pred is not None:
        return metric_fn(y_true, y_pred, sample_weight=weights)
    # `Mean` metric only takes a single value.
    return metric_fn(y_true, sample_weight=weights)


def get_loss_function(loss):
    """Returns the loss corresponding to the loss input in `compile` API."""
    if loss is None or isinstance(loss, losses.Loss):
        return loss

    if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
        # It is not safe to assume that the loss takes no constructor
        # arguments.
        raise ValueError(
            "Received uninstantiated Loss class: {}\n"
            "Please call loss classes "
            "before passing them to Model.compile.".format(loss)
        )

    # Deserialize loss configuration, if needed.
    if isinstance(loss, collections.abc.Mapping):
        loss = losses.get(loss)

    # Custom callable class.
    if callable(loss) and not hasattr(loss, "__name__"):
        return loss

    # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
    # in `LossFunctionWrapper` class.
    loss_fn = losses.get(loss)

    # For losses which are given as strings/functions in the compile API,
    # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
    # (both in distribution strategy context and otherwise).
    return losses.LossFunctionWrapper(
        loss_fn,
        name=loss_fn.__name__,
        reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
    )


def validate_dataset_input(x, y, sample_weight, validation_split=None):
    """Validates user input arguments when a dataset iterator is passed.

    Args:
      x: Input data. A `tf.data` dataset or iterator.
      y: Target data. It could be either Numpy array(s) or TensorFlow
        tensor(s). Expected to be `None` when `x` is a dataset iterator.
      sample_weight: An optional sample-weight array passed by the user to
        weight the importance of each sample in `x`. Expected to be `None`
        when `x` is a dataset iterator.
      validation_split: Float between 0 and 1. Fraction of the training data
        to be used as validation data. Expected to be `None` when `x` is a
        dataset iterator.

    Raises:
      ValueError: if argument `y` or `sample_weight` or `validation_split`
        are provided by user.
    """
    if y is not None:
        raise ValueError(
            "You passed a dataset or dataset iterator (%s) as "
            "input `x` to your model. In that case, you should "
            "not specify a target (`y`) argument, since the dataset "
            "or dataset iterator generates both input data and "
            "target data. "
            "Received: %s" % (x, y)
        )
    if sample_weight is not None:
        raise ValueError(
            "`sample_weight` argument is not supported when input "
            "`x` is a dataset or a dataset iterator. Instead, you "
            "can provide sample_weight as the third element of your "
            "dataset, i.e. (inputs, targets, sample_weight). "
            "Received: x=%s, sample_weight=%s" % (x, sample_weight)
        )
    if validation_split is not None and validation_split != 0.0:
        raise ValueError(
            "`validation_split` argument is not supported when "
            "input `x` is a dataset or a dataset iterator. "
            "Received: x=%s, validation_split=%f" % (x, validation_split)
        )


def validate_input_types(inp, orig_inp, allow_dict=True, field_name="inputs"):
    """Helper function to validate either inputs or targets."""
    if isinstance(inp, (list, tuple)):
        if not all(isinstance(v, np.ndarray) or tf.is_tensor(v) for v in inp):
            raise ValueError(
                "Please provide as model inputs either a single array or a "
                f"list of arrays. You passed: {field_name}={str(orig_inp)}"
            )
    elif isinstance(inp, dict):
        if not allow_dict:
            raise ValueError(
                f"You cannot pass a dictionary as model {field_name}."
            )
    elif not isinstance(inp, np.ndarray) and not tf.is_tensor(inp):
        raise ValueError(
            "Please provide as model inputs either a single array or a list "
            "of arrays. You passed: {}={}".format(field_name, orig_inp)
        )


def check_generator_arguments(
    y=None, sample_weight=None, validation_split=None
):
    """Validates arguments passed when using a generator."""
    if y is not None:
        raise ValueError(
            "`y` argument is not supported when data is "
            "a generator or Sequence instance. Instead pass targets"
            " as the second element of the generator."
        )
    if sample_weight is not None:
        raise ValueError(
            "`sample_weight` argument is not supported when data is "
            "a generator or Sequence instance. Instead pass sample"
            " weights as the third element of the generator."
        )
    if validation_split:
        raise ValueError(
            "If your data is in the form of a Python generator, "
            "you cannot use `validation_split`."
        )


def check_steps_argument(input_data, steps, steps_name):
    """Validates `steps` argument based on input data's type.

    The cases when `steps` value must be provided are when
      1. input data passed is an iterator.
      2. model was built on top of symbolic tensors, input data is not
         required and is `None`.
      3. input data passed is a symbolic tensor.

    Args:
      input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s)
        or tf.data.Dataset iterator or `None`.
      steps: Integer or `None`. Total number of steps (batches of samples)
        to execute.
      steps_name: The public API's parameter name for `steps`.

    Returns:
      boolean, True if `steps` argument is required, else False.

    Raises:
      ValueError: if `steps` argument is required for given input data type
        but not provided.
    """
    is_x_iterator = isinstance(
        input_data, (tf.compat.v1.data.Iterator, tf.data.Iterator)
    )
    if (
        input_data is None
        or is_x_iterator
        or has_symbolic_tensors(input_data)
        or (isinstance(input_data, list) and not input_data)
    ):
        if steps is None:
            input_type_str = (
                "a Dataset iterator" if is_x_iterator else "data tensors"
            )
            raise ValueError(
                "When using {input_type} as input to a model, you should"
                " specify the `{steps_name}` argument.".format(
                    input_type=input_type_str, steps_name=steps_name
                )
            )
        return True

    if isinstance(input_data, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
        return True

    if steps is not None:
        list_types = (np.ndarray, list, tuple)
        if isinstance(input_data, list_types) or (
            isinstance(input_data, dict)
            and any(isinstance(v, list_types) for v in input_data.values())
        ):
            logging.warning(
                "When passing input data as arrays, do not specify "
                "`steps_per_epoch`/`steps` argument. "
                "Please use `batch_size` instead."
            )
    return False


def cast_single_tensor(x, dtype=None):
    if isinstance(x, np.ndarray):
        x = tf.convert_to_tensor(x)
    dtype = dtype or backend.floatx()
    if x.dtype.is_floating:
        return tf.cast(x, dtype=dtype)
    return x
def cast_if_floating_dtype_and_mismatch(targets, outputs):
    """Returns target data tensors using correct datatype.

    Checks that each target and output pair have the same datatype. If not,
    casts the target to the output's datatype.

    Args:
      targets: tensor or list of targets.
      outputs: tensor or list of outputs.

    Returns:
      Targets in appropriate datatype.
    """
    if tf.is_tensor(targets):
        # There is one target, so `outputs[0]` should be the only output.
        return cast_single_tensor(targets, dtype=outputs[0].dtype)
    new_targets = []
    for target, out in zip(targets, outputs):
        if isinstance(target, np.ndarray):
            target = tf.convert_to_tensor(target)
        if target.dtype != out.dtype:
            new_targets.append(cast_single_tensor(target, dtype=out.dtype))
        else:
            new_targets.append(target)
    return new_targets


def cast_if_floating_dtype(x, dtype=None):
    """Casts the given data tensors to the default floating point type.

    Casts only if the input is already a floating point type.

    Args:
      x: tensor or list/tuple of tensors.
      dtype: The dtype to which Tensors should be cast.

    Returns:
      Converted input.
    """
    return tf.nest.map_structure(
        functools.partial(cast_single_tensor, dtype=dtype), x
    )


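# Example (editor's illustrative sketch, not from the original module):
# floating inputs are cast to `backend.floatx()` (float32 by default);
# integer inputs pass through unchanged.
#
#   cast_if_floating_dtype(
#       {"a": np.zeros(2, np.float64), "b": np.zeros(2, np.int32)}
#   )
#   # -> {"a": <float32 tensor>, "b": <int32 tensor>}

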
def cast_to_model_input_dtypes(x, model):
    """Casts the given data tensors to the dtypes of the model inputs.

    Args:
      x: tensor or list/tuple of tensors.
      model: The model.

    Returns:
      Converted input. Each tensor is cast to the corresponding input in
      `model.inputs`.
    """
    input_dtypes = tf.nest.map_structure(lambda t: t.dtype, model.inputs)
    return tf.nest.map_structure(tf.cast, x, input_dtypes)


def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
    """Prepares sample weight modes for the model.

    Args:
      training_endpoints: List of model _TrainingEndpoints.
      sample_weight_mode: sample weight mode user input passed from compile
        API.

    Raises:
      ValueError: In case of invalid `sample_weight_mode` input.
    """

    if isinstance(sample_weight_mode, collections.abc.Mapping):
        generic_utils.check_for_unexpected_keys(
            "sample_weight_mode",
            sample_weight_mode,
            [e.output_name for e in training_endpoints],
        )

        for end_point in training_endpoints:
            if not end_point.should_skip_target_weights():
                if end_point.output_name not in sample_weight_mode:
                    raise ValueError(
                        "Output "
                        + end_point.output_name
                        + " missing from `_sample_weight_modes` dictionary"
                    )
                else:
                    end_point.sample_weight_mode = sample_weight_mode.get(
                        end_point.output_name
                    )
    elif isinstance(sample_weight_mode, (list, tuple)):
        if len(sample_weight_mode) != len(training_endpoints):
            raise ValueError(
                "When passing a list as sample_weight_mode, "
                "it should have one entry per model output. "
                "The model has "
                + str(len(training_endpoints))
                + " outputs, but you passed "
                + str(len(sample_weight_mode))
                + " sample_weight_modes."
            )
        for mode, endpoint in zip(sample_weight_mode, training_endpoints):
            if not endpoint.should_skip_target_weights():
                endpoint.sample_weight_mode = mode
    else:
        for endpoint in training_endpoints:
            if not endpoint.should_skip_target_weights():
                endpoint.sample_weight_mode = sample_weight_mode


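# Example (editor's illustrative sketch; `endpoints` is a hypothetical list
# of `_TrainingEndpoint`s with output names "a" and "b"): a dict assigns
# modes by output name, a list assigns them positionally, and a single
# value is broadcast to every endpoint.
#
#   prepare_sample_weight_modes(endpoints, {"a": "temporal", "b": None})
#   prepare_sample_weight_modes(endpoints, ["temporal", None])
#   prepare_sample_weight_modes(endpoints, "temporal")

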
def prepare_loss_functions(loss, output_names):
    """Converts loss to a list of loss functions.

    Args:
      loss: String (name of objective function), objective function or
        `tf.keras.losses.Loss` instance. See `tf.keras.losses`. If the model
        has multiple outputs, you can use a different loss on each output by
        passing a dictionary or a list of losses. The loss value that will
        be minimized by the model will then be the sum of all individual
        losses.
      output_names: List of model output names.

    Returns:
      A list of loss objective functions.

    Raises:
      ValueError: If loss is a dict with keys not in model output names,
        or if loss is a list with len not equal to model outputs.
    """
    if isinstance(loss, collections.abc.Mapping):
        generic_utils.check_for_unexpected_keys("loss", loss, output_names)
        loss_functions = []
        for name in output_names:
            if name not in loss:
                logging.warning(
                    f"Output {name} missing from loss dictionary. We assume "
                    "this was done on purpose. The fit and evaluate APIs "
                    f"will not be expecting any data to be passed to {name}."
                )
            loss_functions.append(get_loss_function(loss.get(name, None)))
    elif isinstance(loss, str):
        loss_functions = [get_loss_function(loss) for _ in output_names]
    elif isinstance(loss, collections.abc.Sequence):
        if len(loss) != len(output_names):
            raise ValueError(
                "When passing a list as loss, it should have one entry "
                "per model output. The model has {} outputs, but you "
                "passed loss={}".format(len(output_names), loss)
            )
        loss_functions = tf.nest.map_structure(get_loss_function, loss)
    else:
        loss_functions = [
            get_loss_function(loss) for _ in range(len(output_names))
        ]

    return loss_functions


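# Example (editor's illustrative sketch, not from the original module):
# each form below yields one loss function per output name.
#
#   prepare_loss_functions("mse", ["out_1", "out_2"])
#   prepare_loss_functions({"out_1": "mse"}, ["out_1", "out_2"])
#   # logs a warning that "out_2" is missing and maps it to None
#   prepare_loss_functions(["mse", "mae"], ["out_1", "out_2"])

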
def prepare_loss_weights(training_endpoints, loss_weights=None):
    """Converts loss weights to a list of loss weights.

    The result loss weights will be populated on the training endpoints.

    Args:
      training_endpoints: List of model training endpoints.
      loss_weights: Optional list or dictionary specifying scalar
        coefficients (Python floats) to weight the loss contributions of
        different model outputs. The loss value that will be minimized by
        the model will then be the *weighted sum* of all individual losses,
        weighted by the `loss_weights` coefficients. If a list, it is
        expected to have a 1:1 mapping to the model's outputs. If a dict,
        it is expected to map output names (strings) to scalar coefficients.

    Raises:
      ValueError: If loss weight is a dict with key not in model output
        names, or if loss is a list with len not equal to model outputs.
      TypeError: If `loss_weights` is neither a list nor a dict.
    """
    if loss_weights is None:
        for e in training_endpoints:
            e.loss_weight = 1.0
    elif isinstance(loss_weights, collections.abc.Mapping):
        generic_utils.check_for_unexpected_keys(
            "loss_weights",
            loss_weights,
            [e.output_name for e in training_endpoints],
        )
        for e in training_endpoints:
            e.loss_weight = loss_weights.get(e.output_name, 1.0)
    elif isinstance(loss_weights, list):
        if len(loss_weights) != len(training_endpoints):
            raise ValueError(
                "When passing a list as loss_weights, "
                "it should have one entry per model output. "
                "The model has "
                + str(len(training_endpoints))
                + " outputs, but you passed loss_weights="
                + str(loss_weights)
            )
        for w, e in zip(loss_weights, training_endpoints):
            e.loss_weight = w
    else:
        raise TypeError(
            "Could not interpret loss_weights argument: "
            + str(loss_weights)
            + " - expected a list or a dict."
        )


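# Example (editor's illustrative sketch; `endpoints` is a hypothetical list
# of training endpoints named "a" and "b"): omitted dict entries default to
# a weight of 1.0.
#
#   prepare_loss_weights(endpoints, {"a": 0.5})
#   # the endpoint named "a" gets loss_weight 0.5, all others get 1.0
#   prepare_loss_weights(endpoints, [0.5, 2.0])  # positional mapping

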
# TODO(rohanj): This is a hack to avoid depending on feature_column and
# creating a cyclical dependency. Figure out a cleaner solution.
def is_feature_layer(layer):
    """Returns whether `layer` is a FeatureLayer or not."""
    return getattr(layer, "_is_feature_layer", False)


def is_eager_dataset_or_iterator(data):
    """Returns whether `data` is a dataset or iterator usable in eager mode."""
    return tf.executing_eagerly() and isinstance(
        data, (tf.compat.v1.data.Dataset, tf.data.Dataset, tf.data.Iterator)
    )


def get_dataset_graph_def(dataset):
    """Returns the deserialized graph def of a `tf.data.Dataset`."""
    if tf.executing_eagerly():
        graph_def_str = dataset._as_serialized_graph().numpy()
    else:
        graph_def_str = backend.get_value(dataset._as_serialized_graph())
    return tf.compat.v1.GraphDef().FromString(graph_def_str)


def verify_dataset_shuffled(x):
    """Verifies that the dataset is shuffled.

    Args:
      x: Dataset passed as an input to the model.

    Returns:
      boolean, whether the input dataset is shuffled or not.
    """
    assert isinstance(x, tf.data.Dataset)
    graph_def = get_dataset_graph_def(x)
    for node in graph_def.node:
        if node.op.startswith("ShuffleDataset"):
            return True
    # Also check graph_def.library.function for ds.interleave or ds.flat_map
    for function in graph_def.library.function:
        for node in function.node_def:
            if node.op.startswith("ShuffleDataset"):
                return True
    logging.warning(
        "Expected a shuffled dataset but input dataset `x` is "
        "not shuffled. Please invoke `shuffle()` on input dataset."
    )
    return False


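# Example (editor's illustrative sketch, not from the original module): the
# check inspects the serialized dataset graph for `ShuffleDataset` ops.
#
#   ds = tf.data.Dataset.range(10)
#   verify_dataset_shuffled(ds)              # warns, returns False
#   verify_dataset_shuffled(ds.shuffle(10))  # returns True

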
def is_dataset_or_iterator(data):
    """Returns whether `data` is a dataset or dataset iterator (v1 or v2)."""
    return isinstance(
        data,
        (
            tf.compat.v1.data.Dataset,
            tf.data.Dataset,
            tf.compat.v1.data.Iterator,
            tf.data.Iterator,
        ),
    )


def get_iterator(dataset):
    """Create and initialize an iterator from a dataset."""
    if tf.executing_eagerly():
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    else:
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        initialize_iterator(iterator)
    return iterator


def initialize_iterator(iterator):
    """Runs the iterator's initializer op in graph mode (no-op when eager)."""
    if not tf.executing_eagerly():
        init_op = iterator.initializer
        backend.get_session((init_op,)).run(init_op)


def extract_tensors_from_dataset(dataset):
    """Extract tuple of tensors `inputs, targets, sample_weight` from a dataset.

    Args:
      dataset: Dataset instance.

    Returns:
      Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
    """
    iterator = get_iterator(dataset)
    inputs, targets, sample_weight = unpack_iterator_input(iterator)
    return inputs, targets, sample_weight


def unpack_iterator_input(iterator):
    """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.

    Args:
      iterator: Instance of a dataset iterator.

    Returns:
      Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
    """
    try:
        next_element = iterator.get_next()
    except tf.errors.OutOfRangeError:
        raise RuntimeError(
            "Your dataset iterator ran out of data; "
            "make sure that your dataset can generate "
            "the required number of samples."
        )

    if isinstance(next_element, (list, tuple)):
        if len(next_element) not in [2, 3]:
            raise ValueError(
                "Please provide model inputs as a list or tuple of 2 or 3 "
                "elements: (input, target) or (input, target, "
                "sample_weights). Received %s" % next_element
            )
        if len(next_element) == 2:
            x, y = next_element
            weights = None
        else:
            x, y, weights = next_element
    else:
        x = next_element
        y = None
        weights = None
    return x, y, weights


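# Example (editor's illustrative sketch, not from the original module):
# 2-tuples unpack to (x, y, None), 3-tuples to (x, y, weights), and bare
# elements to (x, None, None).
#
#   ds = tf.data.Dataset.from_tensors(([1.0], [2.0]))
#   unpack_iterator_input(get_iterator(ds))  # -> (x, y, None)

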
def infer_steps_for_dataset(
    model, dataset, steps, epochs=1, steps_name="steps"
):
    """Infers steps_per_epoch needed to loop through a dataset.

    Args:
      model: Keras model instance.
      dataset: Input data of type tf.data.Dataset.
      steps: Number of steps to draw from the dataset (may be None if
        unknown).
      epochs: Number of times to iterate over the dataset.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error
        message formatting.

    Returns:
      Integer or `None`. Inferred number of steps to loop through the
      dataset. `None` is returned if 1) the size of the dataset is unknown
      and `steps` was not specified, or 2) this is multi-worker training and
      auto sharding is enabled.

    Raises:
      ValueError: In case of invalid argument values.
    """
    assert isinstance(dataset, tf.data.Dataset)
    if model._in_multi_worker_mode() and (
        dataset.options().experimental_distribute.auto_shard_policy
        != tf.data.experimental.AutoShardPolicy.OFF
    ):
        # If the dataset would be auto-sharded, we should not infer a local
        # steps_per_epoch due to the possible imbalanced sharding between
        # workers.
        return None

    size = backend.get_value(tf.data.experimental.cardinality(dataset))
    if size == tf.data.experimental.INFINITE_CARDINALITY and steps is None:
        raise ValueError(
            "When passing an infinitely repeating dataset, you "
            "must specify the `%s` argument." % (steps_name,)
        )
    if size >= 0:
        if steps is not None and steps * epochs > size:
            if epochs > 1:
                raise ValueError(
                    "The dataset you passed contains %s batches, but you "
                    "passed `epochs=%s` and `%s=%s`, which is a total of "
                    "%s steps. We cannot draw that many steps from this "
                    "dataset. We suggest setting `%s=%s`."
                    % (
                        size,
                        epochs,
                        steps_name,
                        steps,
                        steps * epochs,
                        steps_name,
                        size // epochs,
                    )
                )
            else:
                raise ValueError(
                    "The dataset you passed contains %s batches, but you "
                    "passed `%s=%s`. We cannot draw that many steps from "
                    "this dataset. We suggest setting `%s=%s`."
                    % (size, steps_name, steps, steps_name, size)
                )
    if steps is None:
        if size >= 0:
            return size
        return None
    return steps


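# Example (editor's illustrative sketch; `model` stands for a hypothetical
# single-worker Keras model): a batched finite dataset reports its
# cardinality, while an infinite one requires an explicit `steps`.
#
#   ds = tf.data.Dataset.range(10).batch(2)
#   infer_steps_for_dataset(model, ds, steps=None)  # -> 5
#   infer_steps_for_dataset(model, ds.repeat(), steps=None)
#   # -> raises ValueError: `steps` must be specified

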
class ModelInputs:
    """Encapsulates model inputs.

    Allows for transforming model inputs while keeping the same structure.
    """

    def __init__(self, inputs):
        self._inputs = inputs
        self._is_dict = isinstance(self._inputs, dict)
        self._is_single_input = not isinstance(
            self._inputs, (list, tuple, dict)
        )

        self._flattened_inputs = []
        self._input_names = []

        if self._is_dict:
            for k in sorted(self._inputs.keys()):
                self._flattened_inputs.append(self._inputs[k])
                self._input_names.append(k)
        else:
            self._flattened_inputs = tf.nest.flatten(self._inputs)
            self._input_names = [
                "input_%d" % (i + 1)
                for i in range(len(self._flattened_inputs))
            ]

    def get_input_names(self):
        """Returns keys to name inputs by.

        In case inputs provided were a list, tuple or single entry, we make
        up a key 'input_%d'. For the dictionary case, we return a sorted
        list of keys.
        """
        return self._input_names

    def get_symbolic_inputs(self, return_single_as_list=False):
        """Returns inputs to be set as self.inputs for a model."""
        # TODO(karmel): There is a side-effect here where what you get
        # with as_list and as_dict depends on whether you have called this
        # method first, since it modifies in place.
        for i, (k, v) in enumerate(
            zip(self._input_names, self._flattened_inputs)
        ):
            if isinstance(v, (list, float, int)):
                v = np.asarray(v)
                if v.ndim == 1:
                    v = np.expand_dims(v, 1)

            if isinstance(v, np.ndarray):
                # We fix the placeholder shape except the batch size.
                # This is suboptimal, but it is the best we can do with the
                # info we have. The user should call
                # `model._set_inputs(placeholders)` to specify custom
                # placeholders if the need arises.
                shape = (None,) + tuple(v.shape[1:])
                if shape == (None,):
                    shape = (None, 1)
                dtype = tf.as_dtype(v.dtype)
                if dtype.is_floating:
                    dtype = backend.floatx()
                v = backend.placeholder(shape=shape, name=k, dtype=dtype)
            elif isinstance(v, tf.TensorSpec):
                shape = (None,) + tuple(v.shape.as_list()[1:])
                if shape == (None,):
                    shape = (None, 1)
                v = backend.placeholder(shape=shape, name=k, dtype=v.dtype)

            self._flattened_inputs[i] = v

        if self._is_dict:
            return dict(zip(self._input_names, self._flattened_inputs))
        if self._is_single_input and not return_single_as_list:
            return self._flattened_inputs[0]
        return self._flattened_inputs

    def as_dict(self):
        """An iterable over a dictionary version of inputs."""
        for k, v in zip(self._input_names, self._flattened_inputs):
            yield k, v

    def as_list(self):
        """Returns the inputs as a list."""
        return self._flattened_inputs


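# Example (editor's illustrative sketch, not from the original module):
# list inputs are named positionally, dict inputs keep their (sorted) keys.
#
#   mi = ModelInputs(np.zeros((2, 4)))
#   mi.get_input_names()  # -> ["input_1"]
#   mi = ModelInputs({"b": np.zeros((2, 1)), "a": np.zeros((2, 1))})
#   mi.get_input_names()  # -> ["a", "b"]

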
# Allow use of methods not exposed to the user.


def generic_output_names(outputs_list):
    """Generates default output names "output_1", "output_2", ..."""
    return ["output_%d" % (i + 1) for i in range(len(outputs_list))]


def should_run_validation(validation_freq, epoch):
    """Checks if validation should be run this epoch.

    Args:
      validation_freq: Integer or list. If an integer, specifies how many
        training epochs to run before a new validation run is performed. If
        a list, specifies the epochs on which to run validation.
      epoch: Integer, the number of the training epoch just completed.

    Returns:
      Bool, True if validation should be run.

    Raises:
      ValueError: if `validation_freq` is an Integer and less than 1, or if
        it is neither an Integer nor a Sequence.
    """
    # `epoch` is 0-indexed internally but 1-indexed in the public API.
    one_indexed_epoch = epoch + 1

    if isinstance(validation_freq, int):
        if validation_freq < 1:
            raise ValueError("`validation_freq` cannot be less than 1.")
        return one_indexed_epoch % validation_freq == 0

    if not isinstance(validation_freq, collections.abc.Container):
        raise ValueError(
            "`validation_freq` must be an Integer or "
            "`collections.abc.Container` (e.g. list, tuple, etc.)"
        )
    return one_indexed_epoch in validation_freq


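# Example (editor's illustrative sketch, not from the original module):
# `epoch` is 0-indexed here, so epoch 3 is the 4th epoch.
#
#   should_run_validation(2, epoch=3)       # 4 % 2 == 0 -> True
#   should_run_validation([1, 4], epoch=3)  # 4 in [1, 4] -> True
#   should_run_validation(3, epoch=3)       # 4 % 3 != 0 -> False

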
def split_training_and_validation_data(x, y, sample_weights, validation_split):
    """Split input data into train/eval section based on validation_split."""
    if has_symbolic_tensors(x):
        raise ValueError(
            "If your data is in the form of symbolic tensors, "
            "you cannot use `validation_split`."
        )
    if hasattr(x[0], "shape"):
        split_at = int(x[0].shape[0] * (1.0 - validation_split))
    else:
        split_at = int(len(x[0]) * (1.0 - validation_split))
    x, val_x = (
        generic_utils.slice_arrays(x, 0, split_at),
        generic_utils.slice_arrays(x, split_at),
    )
    y, val_y = (
        generic_utils.slice_arrays(y, 0, split_at),
        generic_utils.slice_arrays(y, split_at),
    )
    if sample_weights:
        sample_weights, val_sample_weights = (
            generic_utils.slice_arrays(sample_weights, 0, split_at),
            generic_utils.slice_arrays(sample_weights, split_at),
        )
    else:
        val_sample_weights = None
    return x, y, sample_weights, val_x, val_y, val_sample_weights


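# Example (editor's illustrative sketch, not from the original module):
# with 10 samples and `validation_split=0.2`, `split_at = int(10 * 0.8)`,
# i.e. 8, so the first 8 samples train and the last 2 validate. The split
# is positional; the data is not shuffled first.
#
#   x, y, sw, val_x, val_y, val_sw = split_training_and_validation_data(
#       [np.arange(10)], [np.arange(10)], [], 0.2
#   )

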
def unpack_validation_data(validation_data, raise_if_ambiguous=True):
    """Unpack validation data based on input type.

    The validation data is not touched if it is a dataset or dataset
    iterator. For other types of input (Numpy or tensor), it will be
    unpacked into a tuple of 3: x, y, and sample weights.

    Args:
      validation_data: dataset, dataset iterator, or numpy, tensor tuple.
      raise_if_ambiguous: boolean on whether to fail if validation_data
        cannot be parsed. Otherwise simply return validation_data, None,
        None and defer the decision to the caller.

    Returns:
      tuple of 3, (x, y, sample_weights) for numpy and tensor input.
    """
    if isinstance(
        validation_data,
        (
            tf.compat.v1.data.Iterator,
            tf.data.Iterator,
            tf.data.Dataset,
            data_utils.Sequence,
        ),
    ) or not hasattr(validation_data, "__len__"):
        val_x = validation_data
        val_y = None
        val_sample_weight = None
    elif len(validation_data) == 2:
        try:
            (
                val_x,
                val_y,
            ) = validation_data
            val_sample_weight = None
        except ValueError:
            val_x, val_y, val_sample_weight = validation_data, None, None
    elif len(validation_data) == 3:
        try:
            (
                val_x,
                val_y,
                val_sample_weight,
            ) = validation_data
        except ValueError:
            val_x, val_y, val_sample_weight = validation_data, None, None
    else:
        if raise_if_ambiguous:
            raise ValueError(
                "When passing a `validation_data` argument, "
                "it must contain either 2 items (x_val, y_val), "
                "or 3 items (x_val, y_val, val_sample_weights), "
                "or alternatively it could be a dataset or a "
                "dataset iterator. "
                "However we received `validation_data=%s`" % validation_data
            )
        val_x, val_y, val_sample_weight = validation_data, None, None
    return val_x, val_y, val_sample_weight


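# Example (editor's illustrative sketch; `x_val`, `y_val`, and `w_val` are
# hypothetical arrays): tuples of length 2 or 3 are unpacked; datasets pass
# through untouched.
#
#   unpack_validation_data((x_val, y_val))         # -> (x_val, y_val, None)
#   unpack_validation_data((x_val, y_val, w_val))  # -> (x_val, y_val, w_val)
#   unpack_validation_data(tf.data.Dataset.range(4))
#   # -> (<dataset>, None, None)

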
class TrainingLoop:
    """TrainingLoop is a wrapper class around the training logic.

    This class encapsulates the different logic of fit/eval/predict with
    regard to different data inputs and model conditions.

    Note that TrainingLoop is stateless, which means it doesn't contain any
    internal fields and can be reused with different models and inputs.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs,
    ):
        """Train the model with the inputs and targets."""
        raise NotImplementedError()

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Returns the loss value & metrics values for the model in test
        mode."""
        raise NotImplementedError()

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Runs inference with the model on the given input `x`."""
        raise NotImplementedError()