867 lines
31 KiB
Python
867 lines
31 KiB
Python
![]() |
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
"""Utilities for `Model.compile`."""
|
||
|
|
||
|
|
||
|
import copy
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras import losses as losses_mod
|
||
|
from keras import metrics as metrics_mod
|
||
|
from keras.saving import saving_lib
|
||
|
from keras.utils import generic_utils
|
||
|
from keras.utils import losses_utils
|
||
|
from keras.utils import tf_utils
|
||
|
|
||
|
|
||
|
class Container:
    """Common base for the loss and metric containers used by `compile`.

    Holds the model's output names and provides helpers that reshape
    user-supplied structures (labels, sample weights, losses, metrics) so
    that they line up with the structure of the model outputs. Subclasses
    must implement `_should_broadcast` and `_copy_object`.
    """

    def __init__(self, output_names=None):
        self._output_names = output_names

    def build(self, y_pred):
        """Fills in output names from `y_pred` when none were supplied."""
        if self._output_names is not None:
            return
        # In the Subclass API there are no declared output names, so pseudo
        # names like 'output_1' are generated and used for `Metric` names.
        self._output_names = create_pseudo_output_names(y_pred)

    def _conform_to_outputs(self, outputs, struct):
        """Reshapes `struct` so that it follows the structure of `outputs`.

        The following conversions are applied, in order:

        (1) A dict is mapped to a list of outputs, using the output names.
        (2) Keys missing from a dict are filled in with `None`.
        (3) A single (non-nested) item is repeated for every output.

        Args:
          outputs: Model predictions.
          struct: Arbitrary nested structure (e.g. of labels, sample_weights,
            losses, or metrics).

        Returns:
          `struct` mapped onto the structure of `outputs`.
        """
        by_name = map_to_output_names(outputs, self._output_names, struct)
        conformed = map_missing_dict_keys(outputs, by_name)
        if tf.nest.is_nested(conformed) or not tf.nest.is_nested(outputs):
            return conformed
        # A single object was passed for several outputs: apply it to all.
        return tf.nest.map_structure(lambda _: conformed, outputs)

    def _maybe_broadcast_to_outputs(self, outputs, objects):
        """Applies losses / metrics to every output when appropriate.

        NOTE: Only meant for Metrics / Losses, not for y_true /
        sample_weight.

        Args:
          outputs: Model predictions.
          objects: Arbitrary nested structure (e.g. of losses or metrics).

        Returns:
          `objects` unchanged when `_should_broadcast` declines; otherwise a
          structure shaped like `outputs` in which each leaf holds `objects`
          (copied per output when there is more than one output).
        """
        if not self._should_broadcast(objects):
            return objects

        # With several Model outputs each Metric / Loss has to stay separate,
        # hence the copy. With a single output the user-supplied object is
        # used directly.
        needs_copy = len(tf.nest.flatten(outputs)) > 1

        def _for_one_output(_):
            if not needs_copy:
                return objects
            return tf.nest.map_structure(self._copy_object, objects)

        return tf.nest.map_structure(_for_one_output, outputs)

    def _should_broadcast(self, objects):
        """Subclass hook: whether `objects` applies to every output."""
        raise NotImplementedError

    def _copy_object(self, obj):
        """Subclass hook: returns the per-output copy of `obj`."""
        raise NotImplementedError
|
||
|
|
||
|
|
||
|
class LossesContainer(Container):
    """A container class for losses passed to `Model.compile()`.

    Args:
      losses: Struct of loss function(s). See `Model.compile()` doc for more
        information.
      loss_weights: Weights of the losses contributions of different model
        outputs. See `Model.compile()` doc for more information.
      output_names: List of string. Per-output metric names.
      total_loss_mean: A `keras.metrics.Mean` instance that is used to track the
        mean of all losses (including compiled and regularization losses).
    """

    def __init__(
        self, losses, loss_weights=None, output_names=None, total_loss_mean=None
    ):
        super().__init__(output_names=output_names)

        # Keep user-supplied values untouched for recompiling and serialization.
        self._user_losses = losses
        self._user_loss_weights = loss_weights

        # Working copies; `build()` rewrites these into flat per-output lists.
        self._losses = losses
        self._loss_weights = loss_weights
        self._per_output_metrics = None  # Per-output losses become metrics.

        # Mean of the total loss.
        self._total_loss_mean = total_loss_mean or metrics_mod.Mean(name="loss")
        self._built = False

    def get_config(self):
        """Returns a config dict with the serialized losses and loss mean.

        Returns:
          A dict with keys `"losses"` (list of serialized non-`None` losses)
          and `"total_loss_mean"` (the serialized total-loss metric).
        """
        # Flatten into a local so that a single loss (e.g. a string) is
        # serialized as a one-element list. `self._losses` is deliberately
        # left untouched: turning a single, not-yet-built loss into a list
        # here would make `_should_broadcast` return False in `build()`, and
        # the loss would no longer be applied to every output.
        flat_losses = tf.nest.flatten(self._losses)
        return {
            "losses": [
                saving_lib.serialize_keras_object(obj)
                for obj in flat_losses
                if obj is not None
            ],
            "total_loss_mean": saving_lib.serialize_keras_object(
                self._total_loss_mean
            ),
        }

    @classmethod
    def from_config(cls, config):
        """Returns the `LossesContainer` instance given the `config`."""
        deserialized_config = {}
        for key, value in config.items():
            if isinstance(value, list):
                deserialized_config[key] = [
                    saving_lib.deserialize_keras_object(item) for item in value
                ]
            else:
                deserialized_config[key] = saving_lib.deserialize_keras_object(
                    value
                )
        return cls(**deserialized_config)

    @property
    def metrics(self):
        """Per-output loss metrics, preceded by the total-loss mean.

        Returns an empty list until the container has been built.
        """
        if not self._built:
            return []
        per_output_metrics = [
            metric_obj
            for metric_obj in tf.nest.flatten(self._per_output_metrics)
            if metric_obj is not None
        ]
        return [self._total_loss_mean] + per_output_metrics

    def build(self, y_pred):
        """One-time setup of loss objects.

        Broadcasts and conforms the losses and loss weights to the structure
        of `y_pred`, converts the losses to `Loss` objects, flattens both
        into per-output lists and creates the per-output loss metrics.

        Args:
          y_pred: Model predictions, used only for their structure.
        """
        super().build(y_pred)

        self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
        self._losses = self._conform_to_outputs(y_pred, self._losses)
        self._losses = tf.nest.map_structure(
            self._get_loss_object, self._losses
        )
        self._losses = tf.nest.flatten(self._losses)

        self._loss_weights = self._maybe_broadcast_to_outputs(
            y_pred, self._loss_weights
        )
        self._loss_weights = self._conform_to_outputs(
            y_pred, self._loss_weights
        )
        self._loss_weights = tf.nest.flatten(self._loss_weights)

        self._create_metrics()
        self._built = True

    @property
    def built(self):
        """Whether `build()` has completed."""
        return self._built

    def _create_metrics(self):
        """Creates per-output loss metrics, but only for multi-output Models."""
        if len(self._output_names) == 1:
            self._per_output_metrics = [None]
        else:
            self._per_output_metrics = []
            for loss_obj, output_name in zip(self._losses, self._output_names):
                if loss_obj is None:
                    self._per_output_metrics.append(None)
                else:
                    self._per_output_metrics.append(
                        metrics_mod.Mean(output_name + "_loss")
                    )

    def __call__(
        self, y_true, y_pred, sample_weight=None, regularization_losses=None
    ):
        """Computes the overall loss.

        Args:
          y_true: An arbitrary structure of Tensors representing the ground
            truth.
          y_pred: An arbitrary structure of Tensors representing a Model's
            outputs.
          sample_weight: An arbitrary structure of Tensors representing the
            per-sample loss weights. If one Tensor is passed, it is used for all
            losses. If multiple Tensors are passed, the structure should match
            `y_pred`.
          regularization_losses: Additional losses to be added to the total
            loss.

        Returns:
          The total loss as a `tf.Tensor`, or `None` if no loss results.
        """
        y_true = self._conform_to_outputs(y_pred, y_true)
        sample_weight = self._conform_to_outputs(y_pred, sample_weight)

        if not self._built:
            self.build(y_pred)

        y_pred = tf.nest.flatten(y_pred)
        y_true = tf.nest.flatten(y_true)
        sample_weight = tf.nest.flatten(sample_weight)

        loss_values = []  # Used for gradient calculation.
        total_loss_mean_values = []  # Used for loss metric calculation.
        batch_dim = None
        zip_args = (
            y_true,
            y_pred,
            sample_weight,
            self._losses,
            self._loss_weights,
            self._per_output_metrics,
        )
        for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
            if (
                y_t is None or loss_obj is None
            ):  # Ok to have no loss for an output.
                continue

            y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
            sw = losses_utils.apply_mask(y_p, sw, losses_utils.get_mask(y_p))
            loss_value = loss_obj(y_t, y_p, sample_weight=sw)

            total_loss_mean_value = loss_value
            # Correct for the `Mean` loss metrics counting each replica as a
            # batch.
            if loss_obj.reduction == losses_utils.ReductionV2.SUM:
                total_loss_mean_value *= (
                    tf.distribute.get_strategy().num_replicas_in_sync
                )

            # The batch size is taken from the first output that has a loss
            # and is reused as the weight for every metric update below.
            if batch_dim is None:
                if tf_utils.is_ragged(y_t):
                    batch_dim = y_t.nrows()
                else:
                    batch_dim = tf.shape(y_t)[0]

            # The per-output metric is updated before `loss_weight` is
            # applied, so it tracks the unweighted per-output loss.
            if metric_obj is not None:
                metric_obj.update_state(
                    total_loss_mean_value, sample_weight=batch_dim
                )

            if loss_weight is not None:
                loss_value *= loss_weight
                total_loss_mean_value *= loss_weight

            if (
                loss_obj.reduction
                == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
                or loss_obj.reduction == losses_utils.ReductionV2.AUTO
            ):
                loss_value = losses_utils.scale_loss_for_distribution(
                    loss_value
                )

            loss_values.append(loss_value)
            total_loss_mean_values.append(total_loss_mean_value)

        if regularization_losses:
            regularization_losses = losses_utils.cast_losses_to_common_dtype(
                regularization_losses
            )
            reg_loss = tf.add_n(regularization_losses)
            total_loss_mean_values.append(reg_loss)
            loss_values.append(
                losses_utils.scale_loss_for_distribution(reg_loss)
            )

        if loss_values:
            total_loss_mean_values = losses_utils.cast_losses_to_common_dtype(
                total_loss_mean_values
            )
            total_total_loss_mean_value = tf.add_n(total_loss_mean_values)
            self._total_loss_mean.update_state(
                total_total_loss_mean_value, sample_weight=batch_dim
            )

            loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
            total_loss = tf.add_n(loss_values)
            return total_loss
        else:
            return None

    def reset_state(self):
        """Resets the state of loss metrics."""
        if not self._built:
            return
        metrics = [self._total_loss_mean] + tf.nest.flatten(
            self._per_output_metrics
        )
        for metric_obj in metrics:
            if metric_obj is not None:
                metric_obj.reset_state()

    def _get_loss_object(self, loss):
        """Returns a `Loss` object.

        Converts the user-supplied loss to a `Loss` object. Also allows
        `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

        Args:
          loss: A string, function, or `Loss` object.

        Returns:
          A `Loss` object, or `None` when `loss` is `None`.

        Raises:
          ValueError: If no name can be determined for a loss that is not a
            `Loss` instance.
        """
        if loss is None:
            return None  # Ok to have no loss for an output.

        loss = losses_mod.get(loss)
        if not isinstance(loss, losses_mod.Loss):
            loss_name = get_custom_object_name(loss)
            if loss_name is None:
                raise ValueError(f"Loss should be a callable, received: {loss}")
            loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
        loss._allow_sum_over_batch_size = True
        return loss

    def _should_broadcast(self, obj):
        """Only a single (non-nested) loss is broadcast to all outputs."""
        return not tf.nest.is_nested(obj)

    def _copy_object(self, obj):
        return obj  # Losses don't need to be copied.
|
||
|
|
||
|
|
||
|
class MetricsContainer(Container):
    """Holds the metrics and weighted metrics given to `Model.compile`."""

    def __init__(
        self,
        metrics=None,
        weighted_metrics=None,
        output_names=None,
        from_serialized=False,
    ):
        """Initializes a container for metrics.

        Arguments:
          metrics: see the `metrics` argument from `tf.keras.Model.compile`.
          weighted_metrics: see the `weighted_metrics` argument from
            `tf.keras.Model.compile`.
          output_names: A list of strings of names of outputs for the model.
          from_serialized: Whether the model being compiled is from a serialized
            model. Used to avoid redundantly applying pre-processing renaming
            steps.
        """
        super().__init__(output_names=output_names)

        self._check_duplicated_metrics(metrics, weighted_metrics)

        # The user-supplied values are kept as-is for recompiling and
        # serialization; the working copies below are rewritten by `build()`.
        self._user_metrics = metrics
        self._user_weighted_metrics = weighted_metrics
        self._metrics = metrics
        self._weighted_metrics = weighted_metrics

        self._built = False
        self._from_serialized = from_serialized

    def _check_duplicated_metrics(self, metrics, weighted_metrics):
        """Rejects `Metric` instances that are supplied more than once.

        Metrics are stateful: an instance shared between `metrics` and
        `weighted_metrics` would be updated twice and report a wrong value.

        Args:
          metrics: User provided metrics list.
          weighted_metrics: User provided weighted metrics list.

        Raises:
          ValueError, when a duplicated metric instance is found in the user
            provided metrics and weighted metrics.
        """
        everything = tf.nest.flatten(metrics) + tf.nest.flatten(
            weighted_metrics
        )
        # Only `Metric` objects are checked. Strings and functions are
        # converted to unique `Metric` instances later on.
        metric_instances = [
            item for item in everything if isinstance(item, metrics_mod.Metric)
        ]

        seen = set()
        duplicated = []
        for instance in metric_instances:
            if instance in seen:
                duplicated.append(instance)
            seen.add(instance)

        if not duplicated:
            return
        raise ValueError(
            "Found duplicated metrics object in the user provided "
            "metrics and weighted metrics. This will cause the same "
            "metric object to be updated multiple times, and report "
            "wrong results. \n"
            f"Duplicated items: {duplicated}"
        )

    @property
    def metrics(self):
        """All metrics in this container (empty until built)."""
        return self._metrics_in_order if self._built else []

    @property
    def unweighted_metrics(self):
        """Metrics in the container that should not be passed sample_weight."""
        if self._built:
            return tf.nest.flatten(self._metrics)
        return None

    @property
    def weighted_metrics(self):
        """Metrics in this container that should be passed `sample_weight`."""
        if self._built:
            return tf.nest.flatten(self._weighted_metrics)
        return None

    def build(self, y_pred, y_true):
        """One-time setup of metric objects.

        Args:
          y_pred: Model predictions.
          y_true: Labels, passed along when converting metrics to `Metric`
            objects.
        """
        super().build(y_pred)
        internal_nest = tf.__internal__.nest

        self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
        self._metrics = self._conform_to_outputs(y_pred, self._metrics)

        self._weighted_metrics = self._maybe_broadcast_to_outputs(
            y_pred, self._weighted_metrics
        )
        self._weighted_metrics = self._conform_to_outputs(
            y_pred, self._weighted_metrics
        )

        # Lists are turned into tuples everywhere, since `tf.data` turns
        # lists into `Tensor`s.
        y_pred = internal_nest.list_to_tuple(y_pred)
        y_true = internal_nest.list_to_tuple(y_true)
        self._metrics = internal_nest.list_to_tuple(self._metrics)
        self._weighted_metrics = internal_nest.list_to_tuple(
            self._weighted_metrics
        )

        # Convert to `Metric` objects; the conversion may look at the
        # matching label / prediction to pick a concrete metric.
        self._metrics = internal_nest.map_structure_up_to(
            y_pred, self._get_metric_objects, self._metrics, y_true, y_pred
        )
        self._weighted_metrics = internal_nest.map_structure_up_to(
            y_pred,
            self._get_metric_objects,
            self._weighted_metrics,
            y_true,
            y_pred,
        )

        self._metrics = internal_nest.flatten_up_to(
            y_pred, self._metrics, check_types=False
        )
        self._weighted_metrics = internal_nest.flatten_up_to(
            y_pred, self._weighted_metrics, check_types=False
        )

        # From here on, metrics and weighted_metrics are flattened up to the
        # outputs.
        #
        # A model loaded from a serialized form already carries its final
        # metric names, so the renaming step is skipped for it.
        if not self._from_serialized:
            self._set_metric_names()
        self._create_ordered_metrics()
        self._built = True

    @property
    def built(self):
        """Whether `build()` has completed."""
        return self._built

    def _set_metric_names(self):
        """Sets unique metric names.

        For multi-output models the output name is prepended to the metric
        name. For weighted metrics, "weighted_" is added when the name would
        otherwise clash with an already registered name.

        Raises:
          ValueError: If two metrics end up with the same name.
        """
        taken_names = set()
        is_multi_output = len(self._output_names) > 1
        per_output = zip(
            self._output_names, self._metrics, self._weighted_metrics
        )
        for output_name, output_metrics, weighted_output_metrics in per_output:
            for metric in output_metrics:
                if metric is None:
                    continue
                if is_multi_output:
                    metric._name = output_name + "_" + metric._name
                if metric._name in taken_names:
                    raise ValueError(
                        f"Found two metrics with the same name: {metric._name}. "
                        "All the metrics added to the model need to have "
                        "unique names."
                    )
                taken_names.add(metric._name)

            for weighted in weighted_output_metrics:
                if weighted is None:
                    continue
                if is_multi_output:
                    prefixed = output_name + "_" + weighted._name
                    if prefixed in taken_names:
                        weighted._name = (
                            output_name + "_weighted_" + weighted._name
                        )
                    else:
                        weighted._name = prefixed
                elif weighted._name in taken_names:
                    weighted._name = "weighted_" + weighted._name

                if weighted._name in taken_names:
                    raise ValueError(
                        "Found two weighted metrics with the same name: "
                        f"{weighted._name}.All the metrics added to the model need "
                        "to have unique names."
                    )
                taken_names.add(weighted._name)

    def _create_ordered_metrics(self):
        """Cache the flat order needed when return metrics, for backcompat."""
        ordered = self._metrics_in_order = []
        for output_metrics, output_weighted_metrics in zip(
            self._metrics, self._weighted_metrics
        ):
            # Per output: unweighted metrics first, then weighted ones.
            for group in (output_metrics, output_weighted_metrics):
                ordered.extend(
                    obj for obj in tf.nest.flatten(group) if obj is not None
                )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Updates the state of per-output metrics.

        Args:
          y_true: Labels.
          y_pred: Model predictions.
          sample_weight: Optional sample weights; only passed to the weighted
            metrics (after the prediction mask has been applied).
        """
        y_true = self._conform_to_outputs(y_pred, y_true)
        sample_weight = self._conform_to_outputs(y_pred, sample_weight)

        if not self._built:
            self.build(y_pred, y_true)

        flat_pred = tf.nest.flatten(y_pred)
        flat_true = [] if y_true is None else tf.nest.flatten(y_true)
        flat_weights = tf.nest.flatten(sample_weight)

        per_output = zip(
            flat_true,
            flat_pred,
            flat_weights,
            self._metrics,
            self._weighted_metrics,
        )
        for y_t, y_p, sw, metric_objs, weighted_metric_objs in per_output:
            # Ok to have no labels or no metrics for an output.
            if y_t is None:
                continue
            if all(m is None for m in metric_objs) and all(
                wm is None for wm in weighted_metric_objs
            ):
                continue

            y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
            mask = losses_utils.get_mask(y_p)
            sw = losses_utils.apply_mask(y_p, sw, mask)

            # Unweighted metrics only see the mask as their sample weight.
            for metric_obj in metric_objs:
                if metric_obj is not None:
                    metric_obj.update_state(y_t, y_p, sample_weight=mask)

            for weighted_metric_obj in weighted_metric_objs:
                if weighted_metric_obj is not None:
                    weighted_metric_obj.update_state(
                        y_t, y_p, sample_weight=sw
                    )

    def reset_state(self):
        """Resets the state of all `Metric`s in this container."""
        if not self._built:
            # Before building, `Metric` objects passed in directly by the
            # user still need a reset. The user values may also contain
            # `str`s or `function`s, which the check below skips.
            candidates = tf.nest.flatten(self._user_metrics) + tf.nest.flatten(
                self._user_weighted_metrics
            )
        else:
            candidates = self._metrics_in_order

        for candidate in candidates:
            if isinstance(candidate, metrics_mod.Metric):
                candidate.reset_state()

    def _get_metric_objects(self, metrics, y_t, y_p):
        """Convert user-supplied metrics to `Metric` objects."""
        return [
            self._get_metric_object(metric, y_t, y_p)
            for metric in tf.nest.flatten(metrics)
        ]

    def _get_metric_object(self, metric, y_t, y_p):
        """Converts user-supplied metric to a `Metric` object.

        Args:
          metric: A string, function, or `Metric` object.
          y_t: Sample of label.
          y_p: Sample of output.

        Returns:
          A `Metric` object, or `None` when `metric` is `None`.

        Raises:
          ValueError: If no name can be determined for a metric that is not a
            `Metric` instance.
        """
        if metric is None:
            return None  # Ok to have no metric for an output.

        # Convenience feature: the generic names below are resolved to the
        # binary, categorical or sparse categorical variant based on the
        # shapes of the label and the prediction.
        key = str(metric).lower()
        if key in ("accuracy", "acc", "crossentropy", "ce"):
            true_dims = y_t.shape.as_list()
            pred_dims = y_p.shape.as_list()
            true_last_dim = true_dims[-1]
            pred_last_dim = pred_dims[-1]

            is_binary = pred_last_dim == 1
            is_sparse_categorical = len(true_dims) < len(pred_dims) or (
                true_last_dim == 1 and pred_last_dim > 1
            )

            wants_accuracy = key in ("accuracy", "acc")
            if is_binary:
                metric_obj = (
                    metrics_mod.binary_accuracy
                    if wants_accuracy
                    else metrics_mod.binary_crossentropy
                )
            elif is_sparse_categorical:
                metric_obj = (
                    metrics_mod.sparse_categorical_accuracy
                    if wants_accuracy
                    else metrics_mod.sparse_categorical_crossentropy
                )
            else:
                metric_obj = (
                    metrics_mod.categorical_accuracy
                    if wants_accuracy
                    else metrics_mod.categorical_crossentropy
                )
        else:
            metric_obj = metrics_mod.get(metric)

        if isinstance(metric_obj, losses_mod.Loss):
            metric_obj._allow_sum_over_batch_size = True

        if isinstance(metric_obj, metrics_mod.Metric):
            return metric_obj

        # Anything that is not a `Metric` yet gets wrapped, named after the
        # string the user passed or after the object itself.
        if isinstance(metric, str):
            metric_name = metric
        else:
            metric_name = get_custom_object_name(metric)
            if metric_name is None:
                raise ValueError(
                    f"Metric should be a callable, received: {metric}"
                )

        return metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)

    def _should_broadcast(self, obj):
        """Whether `obj` is a single metric or a flat list/tuple of metrics."""
        # A single item, e.g. 'mse'.
        if not tf.nest.is_nested(obj):
            return True
        # A flat sequence, e.g. ['mse'] or ['mse', 'mae'].
        if not isinstance(obj, (list, tuple)):
            return False
        return not any(tf.nest.is_nested(item) for item in obj)

    def _copy_object(self, obj):
        """Rebuilds a `Metric` from its config; other objects pass through."""
        if not isinstance(obj, metrics_mod.Metric):
            return obj  # Can be a function or `None`.
        return obj.__class__.from_config(obj.get_config())
|
||
|
|
||
|
|
||
|
def create_pseudo_output_names(outputs):
    """Returns pseudo names for the outputs of a subclassed Model.

    Delegates to `_create_pseudo_names` with the `"output_"` prefix.
    """
    output_names = _create_pseudo_names(outputs, prefix="output_")
    return output_names
|
||
|
|
||
|
|
||
|
def create_pseudo_input_names(inputs):
    """Returns pseudo names for the inputs of a subclassed Model.

    Delegates to `_create_pseudo_names` with the `"input_"` prefix.
    """
    input_names = _create_pseudo_names(inputs, prefix="input_")
    return input_names
|
||
|
|
||
|
|
||
|
def _create_pseudo_names(tensors, prefix):
    """Builds pseudo {input | output} names for subclassed Models.

    Warning: these names are only meant as defaults for `Metrics` and
    `SavedModel`. Nothing else should rely on a `Model`'s input or output
    names.

    Example with dict:

    `{'a': [x1, x2], 'b': x3}` becomes:
    `['a_1', 'a_2', 'b']`

    Example with list:

    `[x, y]` becomes:
    `['output_1', 'output_2']`

    Args:
      tensors: `Model`'s outputs or inputs.
      prefix: 'output_' for outputs, 'input_' for inputs.

    Returns:
      Flattened list of pseudo names.
    """

    def _shift_index(element):
        # Integer path elements become 1-based, so that names start at
        # "output_1" rather than "output_0".
        return element + 1 if isinstance(element, int) else element

    def _name_for(path):
        if not path:
            return prefix + "1"  # Single output.
        joined = "_".join(str(part) for part in path)
        # The prefix is only added when the path starts with an index.
        return prefix + joined if isinstance(path[0], int) else joined

    raw_paths = list(tf.__internal__.nest.yield_flat_paths(tensors))
    shifted_paths = tf.nest.map_structure(_shift_index, raw_paths)
    return [_name_for(path) for path in shifted_paths]
|
||
|
|
||
|
|
||
|
def map_to_output_names(y_pred, output_names, struct):
    """Maps a dict to a list using `output_names` as keys.

    This is a convenience feature only. When a `Model`'s outputs
    are a list, you can specify per-output losses and metrics as
    a dict, where the keys are the output names. If you specify
    per-output losses and metrics via the same structure as the
    `Model`'s outputs (recommended), no mapping is performed.

    For the Functional API, the output names are the names of the
    last layer of each output. For the Subclass API, the output names
    are determined by `create_pseudo_output_names` (For example:
    `['output_1', 'output_2']` for a list of outputs).

    This mapping preserves backwards compatibility for `compile` and
    `fit`.

    Args:
      y_pred: Sample outputs of the Model, to determine if this convenience
        feature should be applied (`struct` is returned unmodified if `y_pred`
        isn't a single output or a flat list/tuple).
      output_names: List. The names of the outputs of the Model.
      struct: The structure to map.

    Returns:
      `struct` mapped to a list in same order as `output_names` (or the
      single mapped item when there is exactly one output name). `struct`
      itself is never modified.

    Raises:
      ValueError: If `struct` is a dict containing keys that are not output
        names.
    """
    single_output = not tf.nest.is_nested(y_pred)
    outputs_are_flat_list = (
        not single_output
        and isinstance(y_pred, (list, tuple))
        and not any(tf.nest.is_nested(y_p) for y_p in y_pred)
    )

    if (single_output or outputs_are_flat_list) and isinstance(struct, dict):
        output_names = output_names or create_pseudo_output_names(y_pred)
        # Pop from a shallow copy so the caller's dict stays intact; whatever
        # is left afterwards are keys that match no output.
        unmatched = copy.copy(struct)
        new_struct = [unmatched.pop(name, None) for name in output_names]
        if unmatched:
            # Report the struct as it was received, not the leftover copy.
            raise ValueError(
                "Found unexpected losses or metrics that do not correspond "
                f"to any Model output: {unmatched.keys()}. "
                f"Valid model output names: {output_names}. "
                f"Received struct is: {struct}."
            )
        if len(new_struct) == 1:
            return new_struct[0]
        return new_struct
    else:
        return struct
|
||
|
|
||
|
|
||
|
def map_missing_dict_keys(y_pred, struct):
    """Adds a `None` entry to `struct` for every `y_pred` key it lacks.

    Only applies when both arguments are dicts; otherwise `struct` is
    returned as-is. The input dict is not modified: a shallow copy is
    filled in and returned.
    """
    if not (isinstance(y_pred, dict) and isinstance(struct, dict)):
        return struct
    filled = copy.copy(struct)
    for output_key in y_pred.keys():
        filled.setdefault(output_key, None)
    return filled
|
||
|
|
||
|
|
||
|
def match_dtype_and_rank(y_t, y_p, sw):
    """Aligns labels and sample weights with the predictions.

    Rank: a rank-1 `y_t` / `sw` gets a trailing axis when `y_p` has rank 2.
    Dtype: `y_t` is cast to the dtype of `y_p` when both are floating or
    both are integer; `sw`, when given, is always cast to the dtype of
    `y_p`.

    Args:
      y_t: Labels.
      y_p: Predictions (returned unchanged).
      sw: Sample weights, or `None`.

    Returns:
      The tuple `(y_t, y_p, sw)` after the adjustments above.
    """

    def _lift_to_rank_2(tensor):
        if tensor.shape.rank == 1 and y_p.shape.rank == 2:
            return tf.expand_dims(tensor, axis=-1)
        return tensor

    y_t = _lift_to_rank_2(y_t)
    if sw is not None:
        sw = _lift_to_rank_2(sw)

    # The cast is mainly there for custom loss functions, which do not take
    # care of casting dtypes themselves.
    label_dtype = y_t.dtype
    pred_dtype = y_p.dtype
    both_floating = label_dtype.is_floating and pred_dtype.is_floating
    if both_floating or (label_dtype.is_integer and pred_dtype.is_integer):
        y_t = tf.cast(y_t, pred_dtype)

    if sw is not None:
        sw = tf.cast(sw, pred_dtype)
    return y_t, y_p, sw
|
||
|
|
||
|
|
||
|
def get_custom_object_name(obj):
    """Picks the name to use for a custom loss or metric callable.

    The lookup order is: the object's `name` attribute, then its
    `__name__`, then the snake-cased name of its class.

    Args:
      obj: Custom loss or metric callable.

    Returns:
      Name to use, or `None` if the object was not recognized.
    """
    # A `name` attribute wins; this lets a `Loss` instance act as a `Metric`.
    if hasattr(obj, "name"):
        return obj.name
    # Plain functions.
    if hasattr(obj, "__name__"):
        return obj.__name__
    # Any other class instance.
    if hasattr(obj, "__class__"):
        return generic_utils.to_snake_case(obj.__class__.__name__)
    # NOTE(review): every Python object exposes `__class__`, so this
    # fallback looks unreachable - confirm before relying on a `None` result.
    return None
|