1017 lines
39 KiB
Python
1017 lines
39 KiB
Python
![]() |
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
"""Utils related to keras metrics."""
|
||
|
|
||
|
import functools
|
||
|
import weakref
|
||
|
from enum import Enum
|
||
|
|
||
|
import numpy as np
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras import backend
|
||
|
from keras.utils import losses_utils
|
||
|
from keras.utils import tf_utils
|
||
|
from keras.utils.generic_utils import to_list
|
||
|
|
||
|
NEG_INF = -1e10
|
||
|
|
||
|
|
||
|
class Reduction(Enum):
    """Ways of reducing weighted metric values to a single scalar.

    Members:

    * `SUM`: Scalar sum of weighted values.
    * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by
          number of elements.
    * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of
          weights.
    """

    SUM = "sum"
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"
    WEIGHTED_MEAN = "weighted_mean"
|
||
|
|
||
|
|
||
|
def update_state_wrapper(update_state_fn):
    """Decorator to wrap metric `update_state()` with `add_update()`.

    Args:
      update_state_fn: function that accumulates metric statistics.

    Returns:
      Decorated function that wraps `update_state_fn()` with `add_update()`.
    """

    def decorated(metric_obj, *args, **kwargs):
        """Run `update_state_fn` and register its op via `add_update()`."""
        strategy = tf.distribute.get_strategy()

        def created_outside_tpu_scope(weight):
            # Same three conditions, in the same short-circuit order, that
            # make an update illegal for a single metric weight.
            return (
                backend.is_tpu_strategy(strategy)
                and not strategy.extended.variable_created_in_scope(weight)
                and not tf.distribute.in_cross_replica_context()
            )

        # Fail on the first offending weight.
        if any(created_outside_tpu_scope(w) for w in metric_obj.weights):
            raise ValueError(
                "Trying to run metric.update_state in replica context when "
                "the metric was not created in TPUStrategy scope. "
                "Make sure the keras Metric is created in TPUstrategy "
                "scope. "
            )

        with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
            update_op = update_state_fn(*args, **kwargs)

        # update_op will be None in eager execution; nothing to register then.
        if update_op is None:
            return update_op

        metric_obj.add_update(update_op)
        return update_op

    return tf.__internal__.decorator.make_decorator(update_state_fn, decorated)
|
||
|
|
||
|
|
||
|
def result_wrapper(result_fn):
    """Decorator to wrap metric `result()` function in `merge_call()`.

    Result computation is an idempotent operation that simply calculates the
    metric value using the state variables.

    If metric state variables are distributed across replicas/devices and
    `result()` is requested from the context of one device, this function
    wraps `result()` in a distribution strategy `merge_call()`. With this,
    the metric state variables will be aggregated across devices.

    When there is no replica context, or the strategy supports running without
    `merge_call`, `result_fn()` is instead evaluated inside
    `variable_sync_on_read_context()` and its output is passed through
    `tf.identity`.

    In both cases the computed result is also stored on the metric object as
    `_call_result` before being returned.

    Args:
      result_fn: function that computes the metric result.

    Returns:
      Decorated function that wraps `result_fn()` in distribution strategy
      `merge_call()`.

    Raises:
      RuntimeError: (from the decorated function, on the non-`merge_call`
        path) if the result is neither a Tensor/Variable/float/int nor a dict
        and `tf.identity` rejects it with a `ValueError` or `TypeError`.
    """

    def decorated(metric_obj, *args):
        """Decorated function with merge_call."""
        replica_context = tf.distribute.get_replica_context()

        # The purpose of using `merge_call` to call `result()` is to trigger
        # cross replica aggregation of metric state variables
        # (SyncOnReadVariable). After we introduced
        # `variable_sync_on_read_context`, in principle there is no need to use
        # `merge_call` here. However the branch still exists because:
        #
        # 1. Keras V1 training code sometimes assumes `result_t` is the same
        #    tensor across replicas (achieved by `merge_call`). With
        #    `variable_sync_on_read_context` each replica gets their own tensors
        #    residing on replica's device, thus breaking the assumption.
        # 2. Keras c/fit creates a tf.function (a.k.a, train_function) that
        #    returns the metric values of the first replica. With
        #    `variable_sync_on_read_context` since each replica gets their own
        #    tensors, the metric result tensors on the non-first replicas are
        #    not in the return value of train_function, making TF graph
        #    optimizer prune the branch that computes and aggregates those
        #    metric results. As a result, if NCCL is used to do the aggregation,
        #    the program will hang because NCCL ops are only launched on the
        #    non-pruned first replica.
        #
        # We condition on strategy_supports_no_merge_call() since we know if it
        # is True, the program uses `jit_compile` to compile replica fn, meaning
        # it is not V1 training (hence #1 is okay), and no pruning will happen
        # as compiled functions are not inlined (hence #2 is okay).
        if (
            replica_context is None
            or tf.__internal__.distribute.strategy_supports_no_merge_call()
        ):
            # NOTE(review): the nesting of the type checks under this `with`
            # was reconstructed from whitespace-stripped source -- confirm.
            with tf.__internal__.distribute.variable_sync_on_read_context():
                raw_result = result_fn(*args)
                # Results need to be wrapped in a `tf.identity` op to ensure
                # correct execution order.
                if isinstance(raw_result, (tf.Tensor, tf.Variable, float, int)):
                    result_t = tf.identity(raw_result)
                elif isinstance(raw_result, dict):
                    # Wrap each dict value individually; keys are preserved.
                    result_t = {
                        key: tf.identity(value)
                        for key, value in raw_result.items()
                    }
                else:
                    # Unknown type: let `tf.identity` try to convert it and
                    # turn conversion failures into a descriptive error.
                    try:
                        result_t = tf.identity(raw_result)
                    except (ValueError, TypeError):
                        raise RuntimeError(
                            "The output of `metric.result()` can only be a "
                            "single Tensor/Variable, or a dict of "
                            "Tensors/Variables. "
                            f"For metric {metric_obj.name}, "
                            f"got result {raw_result}."
                        )
        else:
            # TODO(psv): Test distribution of metrics using different
            # distribution strategies.

            # Creating a wrapper for merge_fn. merge_call invokes the given
            # merge_fn with distribution object as the first parameter. We
            # create a wrapper here so that the result function need not have
            # that parameter.
            def merge_fn_wrapper(distribution, merge_fn, *args):
                # We will get `PerReplica` merge function. Taking the first one
                # as all are identical copies of the function that we had passed
                # below.
                result = distribution.experimental_local_results(merge_fn)[0](
                    *args
                )

                # Wrapping result in identity so that control dependency between
                # update_op from `update_state` and result works in case result
                # returns a tensor.
                return tf.identity(result)

            # Wrapping result in merge_call. merge_call is used when we want to
            # leave replica mode and compute a value in cross replica mode.
            result_t = replica_context.merge_call(
                merge_fn_wrapper, args=(result_fn,) + args
            )

        # We are saving the result op here to be used in train/test execution
        # functions. This basically gives the result op that was generated with
        # a control dep to the updates for these workflows.
        metric_obj._call_result = result_t
        return result_t

    return tf.__internal__.decorator.make_decorator(result_fn, decorated)
|
||
|
|
||
|
|
||
|
def weakmethod(method):
    """Creates a weak reference to the bound method.

    Args:
      method: a bound method, i.e. an object exposing `__self__` (the
        instance) and `__func__` (the underlying function).

    Returns:
      A function accepting the same arguments as `method`. It looks the
      instance up through a weak reference on every call, so holding the
      returned function does not keep the instance alive.

    Raises:
      ReferenceError: raised by the returned function if the instance has
        already been garbage collected when it is called.
    """
    # Python 3 bound methods expose `__func__` / `__self__`; the Python 2 only
    # `im_func` / `im_self` / `im_class` attributes no longer exist.
    func = method.__func__
    cls = type(method.__self__)
    instance_ref = weakref.ref(method.__self__)

    # Copy metadata from the underlying function, not from the bound method:
    # `functools.wraps` stores its argument in `__wrapped__`, and a bound
    # method stored there would hold a strong reference to the instance.
    @functools.wraps(func)
    def inner(*args, **kwargs):
        instance = instance_ref()
        if instance is None:
            # Binding to `None` would return the plain function and silently
            # treat the first positional argument as `self`.
            raise ReferenceError(
                f"The instance bound to `{func.__name__}` no longer exists."
            )
        return func.__get__(instance, cls)(*args, **kwargs)

    # Drop the only strong reference to the bound method held by this frame.
    del method
    return inner
|
||
|
|
||
|
|
||
|
def assert_thresholds_range(thresholds):
    """Validates that every threshold is a value inside `[0, 1]`.

    Args:
      thresholds: iterable of threshold values, or `None` to skip validation.

    Raises:
      ValueError: if any entry is `None`, below 0, or above 1. The message
        lists all offending entries.
    """
    if thresholds is None:
        return

    rejected = []
    for value in thresholds:
        if value is None or value < 0 or value > 1:
            rejected.append(value)

    if rejected:
        raise ValueError(
            "Threshold values must be in [0, 1]. "
            f"Received: {rejected}"
        )
|
||
|
|
||
|
|
||
|
def parse_init_thresholds(thresholds, default_threshold=0.5):
    """Normalizes a `thresholds` argument through `to_list`.

    Args:
      thresholds: threshold value(s), or `None` to fall back to
        `default_threshold`.
      default_threshold: value used when `thresholds` is `None`. It is not
        range-checked.

    Returns:
      The result of `to_list` applied to `thresholds`, or to
      `default_threshold` when `thresholds` is `None`.

    Raises:
      ValueError: propagated from `assert_thresholds_range` when a provided
        threshold is rejected.
    """
    if thresholds is None:
        return to_list(default_threshold)

    # Only user-provided thresholds are validated.
    assert_thresholds_range(to_list(thresholds))
    return to_list(thresholds)
|
||
|
|
||
|
|
||
|
class ConfusionMatrix(Enum):
    """Identifiers for the four confusion matrix counters.

    Members are used as the keys of the `variables_to_update` dictionaries
    accepted by the confusion matrix update helpers in this module.
    """

    TRUE_POSITIVES = "tp"
    FALSE_POSITIVES = "fp"
    TRUE_NEGATIVES = "tn"
    FALSE_NEGATIVES = "fn"
|
||
|
|
||
|
|
||
|
class AUCCurve(Enum):
    """Type of AUC Curve (ROC or PR)."""

    ROC = "ROC"
    PR = "PR"

    @staticmethod
    def from_str(key):
        """Returns the member named by `key`.

        Args:
          key: one of `"pr"`, `"PR"`, `"roc"` or `"ROC"`.

        Returns:
          The matching `AUCCurve` member.

        Raises:
          ValueError: if `key` is not one of the accepted spellings.
        """
        if key in ("pr", "PR"):
            return AUCCurve.PR
        if key in ("roc", "ROC"):
            return AUCCurve.ROC
        raise ValueError(
            f'Invalid AUC curve value: "{key}". '
            'Expected values are ["PR", "ROC"]'
        )
|
||
|
|
||
|
|
||
|
class AUCSummationMethod(Enum):
    """Type of AUC summation method.

    See https://en.wikipedia.org/wiki/Riemann_sum for background.

    Contains the following values:
    * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
      `PR` curve, interpolates (true/false) positives but not the ratio that is
      precision (see Davis & Goadrich 2006 for details).
    * 'minoring': Applies left summation for increasing intervals and right
      summation for decreasing intervals.
    * 'majoring': Applies right summation for increasing intervals and left
      summation for decreasing intervals.
    """

    INTERPOLATION = "interpolation"
    MAJORING = "majoring"
    MINORING = "minoring"

    @staticmethod
    def from_str(key):
        """Returns the member named by `key`.

        Args:
          key: a method name, either all lower case or capitalized (for
            example `"majoring"` or `"Majoring"`).

        Returns:
          The matching `AUCSummationMethod` member.

        Raises:
          ValueError: if `key` is not one of the accepted spellings.
        """
        if key in ("interpolation", "Interpolation"):
            return AUCSummationMethod.INTERPOLATION
        if key in ("majoring", "Majoring"):
            return AUCSummationMethod.MAJORING
        if key in ("minoring", "Minoring"):
            return AUCSummationMethod.MINORING
        raise ValueError(
            f'Invalid AUC summation method value: "{key}". '
            'Expected values are ["interpolation", "majoring", "minoring"]'
        )
|
||
|
|
||
|
|
||
|
def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False,
):
    """Update confusion matrix variables with memory efficient alternative.

    Note that the thresholds need to be evenly distributed within the list, eg,
    the diff between consecutive elements are the same.

    To compute TP/FP/TN/FN, we are measuring a binary classifier
      C(t) = (predictions >= t)
    at each threshold 't'. So we have
      TP(t) = sum( C(t) * true_labels )
      FP(t) = sum( C(t) * false_labels )

    But, computing C(t) requires computation for each t. To make it fast,
    observe that C(t) is a cumulative integral, and so if we have
      thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    where n = num_thresholds, and if we can compute the bucket function
      B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
    then we get
      C(t_i) = sum( B(j), j >= i )
    which is the reversed cumulative sum in tf.cumsum().

    We can compute B(i) efficiently by taking advantage of the fact that
    our thresholds are evenly distributed, in that
      width = 1.0 / (num_thresholds - 1)
      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    Given a prediction value p, we can map it to its bucket by
      bucket_index(p) = floor( p * (num_thresholds - 1) )
    so we can use tf.math.unsorted_segment_sum() to update the buckets in one
    pass.

    Consider following example:
    y_true = [0, 0, 1, 1]
    y_pred = [0.1, 0.5, 0.3, 0.9]
    thresholds = [0.0, 0.5, 1.0]
    num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
    bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
                         = tf.math.floor([0.2, 1.0, 0.6, 1.8])
                         = [0, 0, 0, 1]
    # The meaning of this bucket is that if any of the label is true,
    # then 1 will be added to the corresponding bucket with the index.
    # Eg, if the label for 0.2 is true, then 1 will be added to bucket 0. If the
    # label for 1.8 is true, then 1 will be added to bucket 1.
    #
    # Note the second item "1.0" is floored to 0, since the value need to be
    # strictly larger than the bucket lower bound.
    # In the implementation, we use tf.math.ceil() - 1 to achieve this.
    tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices,
                                                   num_segments=num_thresholds)
                    = [1, 1, 0]
    # For [1, 1, 0] here, it means there is 1 true value contributed by bucket
    # 0, and 1 value contributed by bucket 1. When we aggregate them to
    # together, the result become [a + b + c, b + c, c], since large thresholds
    # will always contribute to the value for smaller thresholds.
    true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
                  = [2, 1, 0]

    This implementation exhibits a run time and space complexity of O(T + N),
    where T is the number of thresholds and N is the size of predictions.
    Metrics that rely on standard implementation instead exhibit a complexity of
    O(T * N).

    Args:
      variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
        and corresponding variables to update as values.
      y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
        cast to `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values are
        in the range `[0, 1]`.
      thresholds: A sorted floating point `Tensor` with value in `[0, 1]`.
        It need to be evenly distributed (the diff between each element need to
        be the same).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
        `variables_to_update` must have a second dimension equal to the number
        of labels in y_true and y_pred, and those tensors must not be
        RaggedTensors.
      sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
        as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `y_true`
        dimension).
      label_weights: Optional tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN
        without explicit multilabel handling (i.e. when the data is to be
        flattened).
      thresholds_with_epsilon: Optional boolean indicating whether the leading
        and tailing thresholds has any epsilon added for floating point
        imprecisions.  It will change how we handle the leading and tailing
        bucket.

    Returns:
      Update op.
    """
    num_thresholds = thresholds.shape.as_list()[0]

    # Missing weights fall back to the scalar 1.0 so the multiplications below
    # leave the labels unchanged.
    # NOTE(review): the `if not multi_label` reshapes are placed inside the
    # `else` branches; nesting was reconstructed from whitespace-stripped
    # source -- confirm against the upstream file.
    if sample_weights is None:
        sample_weights = 1.0
    else:
        sample_weights = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weights, dtype=y_pred.dtype), y_pred
        )
        if not multi_label:
            sample_weights = tf.reshape(sample_weights, [-1])
    if label_weights is None:
        label_weights = 1.0
    else:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred
        )
        if not multi_label:
            label_weights = tf.reshape(label_weights, [-1])
    weights = tf.cast(tf.multiply(sample_weights, label_weights), y_true.dtype)

    # We shouldn't need this, but in case there are predict value that is out of
    # the range of [0.0, 1.0]
    y_pred = tf.clip_by_value(y_pred, clip_value_min=0.0, clip_value_max=1.0)

    # Round-trip through bool so every non-zero label becomes exactly 1.
    y_true = tf.cast(tf.cast(y_true, tf.bool), y_true.dtype)
    if not multi_label:
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

    true_labels = tf.multiply(y_true, weights)
    false_labels = tf.multiply((1.0 - y_true), weights)

    # Compute the bucket indices for each prediction value.
    # Since the predict value has to be strictly greater than the thresholds,
    # eg, buckets like [0, 0.5], (0.5, 1], and 0.5 belongs to first bucket.
    # We have to use math.ceil(val) - 1 for the bucket.
    bucket_indices = tf.math.ceil(y_pred * (num_thresholds - 1)) - 1

    if thresholds_with_epsilon:
        # In this case, the first bucket should actually take into account since
        # the any prediction between [0.0, 1.0] should be larger than the first
        # threshold. We change the bucket value from -1 to 0.
        bucket_indices = tf.nn.relu(bucket_indices)

    bucket_indices = tf.cast(bucket_indices, tf.int32)

    if multi_label:
        # We need to run bucket segment sum for each of the label class. In the
        # multi_label case, the rank of the label is 2. We first transpose it so
        # that the label dim becomes the first and we can parallel run though
        # them.
        true_labels = tf.transpose(true_labels)
        false_labels = tf.transpose(false_labels)
        bucket_indices = tf.transpose(bucket_indices)

        def gather_bucket(label_and_bucket_index):
            # Receives one (labels, bucket indices) pair per label class and
            # sums the label weights falling into each bucket.
            label, bucket_index = (
                label_and_bucket_index[0],
                label_and_bucket_index[1],
            )
            return tf.math.unsorted_segment_sum(
                data=label,
                segment_ids=bucket_index,
                num_segments=num_thresholds,
            )

        tp_bucket_v = tf.vectorized_map(
            gather_bucket, (true_labels, bucket_indices), warn=False
        )
        fp_bucket_v = tf.vectorized_map(
            gather_bucket, (false_labels, bucket_indices), warn=False
        )
        # Reverse cumulative sum along the bucket axis, then transpose back.
        tp = tf.transpose(tf.cumsum(tp_bucket_v, reverse=True, axis=1))
        fp = tf.transpose(tf.cumsum(fp_bucket_v, reverse=True, axis=1))
    else:
        tp_bucket_v = tf.math.unsorted_segment_sum(
            data=true_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        fp_bucket_v = tf.math.unsorted_segment_sum(
            data=false_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        tp = tf.cumsum(tp_bucket_v, reverse=True)
        fp = tf.cumsum(fp_bucket_v, reverse=True)

    # fn = sum(true_labels) - tp
    # tn = sum(false_labels) - fp
    # The totals are only computed when a TN or FN variable was requested.
    if (
        ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
        or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
    ):
        if multi_label:
            total_true_labels = tf.reduce_sum(true_labels, axis=1)
            total_false_labels = tf.reduce_sum(false_labels, axis=1)
        else:
            total_true_labels = tf.reduce_sum(true_labels)
            total_false_labels = tf.reduce_sum(false_labels)

    # Accumulate into whichever of the four variables the caller supplied.
    update_ops = []
    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
        update_ops.append(variable.assign_add(tp))
    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
        update_ops.append(variable.assign_add(fp))
    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
        tn = total_false_labels - fp
        update_ops.append(variable.assign_add(tn))
    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
        fn = total_true_labels - tp
        update_ops.append(variable.assign_add(fn))
    return tf.group(update_ops)
|
||
|
|
||
|
|
||
|
def is_evenly_distributed_thresholds(thresholds):
    """Check if the thresholds list is evenly distributed.

    Evenly distributed thresholds allow a less memory-hungry code path when
    calculating metrics like AUC, where each individual threshold needs to be
    evaluated.

    Args:
      thresholds: A python list or tuple, or 1D numpy array whose value is
        ranged in [0, 1].

    Returns:
      boolean, whether the values in the inputs are evenly distributed. Inputs
      with fewer than 3 values always yield `False`.
    """
    count = len(thresholds)
    if count < 3:
        return False

    # Reference grid i / (count - 1) for i in [0, count), built in float32.
    last_index = count - 1
    reference = np.arange(count, dtype=np.float32) / last_index
    return np.allclose(thresholds, reference, atol=backend.epsilon())
|
||
|
|
||
|
|
||
|
def update_confusion_matrix_variables(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    top_k=None,
    class_id=None,
    sample_weight=None,
    multi_label=False,
    label_weights=None,
    thresholds_distributed_evenly=False,
):
    """Returns op to update the given confusion matrix variables.

    For every pair of values in y_true and y_pred:

    true_positive: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positive: y_true == False and y_pred > thresholds

    The results will be weighted and added together. When multiple thresholds
    are provided, we will repeat the same for every threshold.

    For estimation of these metrics over a stream of data, the function creates
    an `update_op` operation that updates the given variables.

    If `sample_weight` is `None`, weights default to 1.
    Use weights of 0 to mask values.

    Args:
      variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
        and corresponding variables to update as values.
      y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values are
        in the range `[0, 1]`.
      thresholds: A float value, float tensor, python list, or tuple of float
        thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
      top_k: Optional int, indicates that the positive labels should be limited
        to the top k predictions.
      class_id: Optional int, limits the prediction and labels to the class
        specified by this argument.
      sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
        as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `y_true`
        dimension).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
        `variables_to_update` must have a second dimension equal to the number
        of labels in y_true and y_pred, and those tensors must not be
        RaggedTensors.
      label_weights: (optional) tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN
        without explicit multilabel handling (i.e. when the data is to be
        flattened).
      thresholds_distributed_evenly: Boolean, whether the thresholds are evenly
        distributed within the list. An optimized method will be used if this is
        the case. See _update_confusion_matrix_variables_optimized() for more
        details.

    Returns:
      Update op, or `None` when `variables_to_update` is `None`.

    Raises:
      ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
        `sample_weight` is not `None` and its shape doesn't match `y_pred`, or
        if `variables_to_update` contains invalid keys, or if `label_weights`
        is given together with `multi_label=True`.
    """
    if multi_label and label_weights is not None:
        raise ValueError(
            "`label_weights` for multilabel data should be handled "
            "outside of `update_confusion_matrix_variables` when "
            "`multi_label` is True."
        )
    if variables_to_update is None:
        return
    # Require at least one recognized key before doing any tensor work.
    if not any(
        key for key in variables_to_update if key in list(ConfusionMatrix)
    ):
        raise ValueError(
            "Please provide at least one valid confusion matrix "
            "variable to update. Valid variable key options are: "
            f'"{list(ConfusionMatrix)}". '
            f'Received: "{variables_to_update.keys()}"'
        )

    # All inputs are cast to the dtype of the first supplied variable.
    variable_dtype = list(variables_to_update.values())[0].dtype

    y_true = tf.cast(y_true, dtype=variable_dtype)
    y_pred = tf.cast(y_pred, dtype=variable_dtype)

    if thresholds_distributed_evenly:
        # Check whether the thresholds has any leading or tailing epsilon added
        # for floating point imprecision. The leading and tailing threshold will
        # be handled bit differently as the corner case.  At this point,
        # thresholds should be a list/array with more than 2 items, and ranged
        # between [0, 1]. See is_evenly_distributed_thresholds() for more
        # details.
        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

    thresholds = tf.convert_to_tensor(thresholds, dtype=variable_dtype)
    num_thresholds = thresholds.shape.as_list()[0]

    if multi_label:
        # True when `thresholds` is rank 1, i.e. a single set of thresholds.
        one_thresh = tf.equal(
            tf.cast(1, dtype=tf.int32),
            tf.rank(thresholds),
            name="one_set_of_thresholds_cond",
        )
    else:
        [y_pred, y_true], _ = ragged_assert_compatible_and_get_flat_values(
            [y_pred, y_true], sample_weight
        )
        one_thresh = tf.cast(True, dtype=tf.bool)

    invalid_keys = [
        key for key in variables_to_update if key not in list(ConfusionMatrix)
    ]
    if invalid_keys:
        raise ValueError(
            f'Invalid keys: "{invalid_keys}". '
            f'Valid variable key options are: "{list(ConfusionMatrix)}"'
        )

    if sample_weight is None:
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true
        )
    else:
        sample_weight = tf.cast(sample_weight, dtype=variable_dtype)
        (
            y_pred,
            y_true,
            sample_weight,
        ) = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true, sample_weight=sample_weight
        )
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    if top_k is not None:
        y_pred = _filter_top_k(y_pred, top_k)
    if class_id is not None:
        # Preserve dimension to match with sample_weight
        y_true = y_true[..., class_id, None]
        y_pred = y_pred[..., class_id, None]

    if thresholds_distributed_evenly:
        # Delegate to the bucket-based implementation; nothing below runs.
        return _update_confusion_matrix_variables_optimized(
            variables_to_update,
            y_true,
            y_pred,
            thresholds,
            multi_label=multi_label,
            sample_weights=sample_weight,
            label_weights=label_weights,
            thresholds_with_epsilon=thresholds_with_epsilon,
        )

    pred_shape = tf.shape(y_pred)
    num_predictions = pred_shape[0]
    if y_pred.shape.ndims == 1:
        num_labels = 1
    else:
        num_labels = tf.math.reduce_prod(pred_shape[1:], axis=0)
    # Tile thresholds across labels only when one set of thresholds is shared.
    thresh_label_tile = tf.where(
        one_thresh, num_labels, tf.ones([], dtype=tf.int32)
    )

    # Reshape predictions and labels, adding a dim for thresholding.
    if multi_label:
        predictions_extra_dim = tf.expand_dims(y_pred, 0)
        labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype=tf.bool), 0)
    else:
        # Flatten predictions and labels when not multilabel.
        predictions_extra_dim = tf.reshape(y_pred, [1, -1])
        labels_extra_dim = tf.reshape(tf.cast(y_true, dtype=tf.bool), [1, -1])

    # Tile the thresholds for every prediction.
    if multi_label:
        thresh_pretile_shape = [num_thresholds, 1, -1]
        thresh_tiles = [1, num_predictions, thresh_label_tile]
        data_tiles = [num_thresholds, 1, 1]
    else:
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

    thresh_tiled = tf.tile(
        tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles)
    )

    # Tile the predictions for every threshold.
    preds_tiled = tf.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold.
    pred_is_pos = tf.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = tf.tile(labels_extra_dim, data_tiles)

    if sample_weight is not None:
        sample_weight = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weight, dtype=variable_dtype), y_pred
        )
        weights_tiled = tf.tile(
            tf.reshape(sample_weight, thresh_tiles), data_tiles
        )
    else:
        weights_tiled = None

    if label_weights is not None and not multi_label:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred
        )
        label_weights_tiled = tf.tile(
            tf.reshape(label_weights, thresh_tiles), data_tiles
        )
        # Combine with the sample weights when both are present.
        if weights_tiled is None:
            weights_tiled = label_weights_tiled
        else:
            weights_tiled = tf.multiply(weights_tiled, label_weights_tiled)

    update_ops = []

    def weighted_assign_add(label, pred, weights, var):
        # Adds the (optionally weighted) count of positions where both `label`
        # and `pred` are true, summed over axis 1, to `var`.
        label_and_pred = tf.cast(tf.logical_and(label, pred), dtype=var.dtype)
        if weights is not None:
            label_and_pred *= tf.cast(weights, dtype=var.dtype)
        return var.assign_add(tf.reduce_sum(label_and_pred, 1))

    # Maps each confusion matrix entry to the (label mask, prediction mask)
    # pair whose conjunction counts it. Negated masks are only built when a
    # requested variable needs them.
    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        pred_is_neg = tf.logical_not(pred_is_pos)
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

    if update_fp or update_tn:
        label_is_neg = tf.logical_not(label_is_pos)
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
                label_is_neg,
                pred_is_neg,
            )

    for matrix_cond, (label, pred) in loop_vars.items():

        # `loop_vars` may hold entries the caller did not ask for; skip those.
        if matrix_cond in variables_to_update:
            update_ops.append(
                weighted_assign_add(
                    label, pred, weights_tiled, variables_to_update[matrix_cond]
                )
            )

    return tf.group(update_ops)
|
||
|
|
||
|
|
||
|
def _filter_top_k(x, k):
    """Keeps the top-k entries of the last axis of `x`, masking out the rest.

    Entries that are not among the `k` selected by `tf.math.top_k` along the
    last dimension are replaced by `NEG_INF`. This is used by the top-k
    variants of the recall and precision metrics, where labels are dense and
    share the shape of the predictions.

    Args:
      x: tensor with any dimensions.
      k: the number of values to keep along the last dimension.

    Returns:
      Tensor shaped like `x` in which the selected entries keep their value
      and every other entry equals `NEG_INF`.
    """
    _, kept_indices = tf.math.top_k(x, k, sorted=False)
    depth = tf.shape(x)[-1]
    # One-hot encoding yields a [..., k, depth] tensor; summing over the `k`
    # axis folds the individual picks into a single 0/1 mask over the last
    # axis of `x`.
    picks_one_hot = tf.one_hot(kept_indices, depth, axis=-1)
    keep_mask = tf.reduce_sum(picks_one_hot, axis=-2)
    drop_mask = 1 - keep_mask
    return x * keep_mask + NEG_INF * drop_mask
|
||
|
|
||
|
|
||
|
def ragged_assert_compatible_and_get_flat_values(values, mask=None):
    """Validates ragged inputs and replaces them with their `flat_values`.

    When every input is a `tf.RaggedTensor` (and `mask` is either `None` or a
    `tf.RaggedTensor` too), the nested row splits of all inputs are asserted
    to match and each input is swapped for its `flat_values` with a trailing
    axis of size 1 appended. Non-ragged inputs are returned unchanged.

    Note: If the tensors are dense, their compatibility is not checked.
    Note: Although two ragged tensors with different ragged ranks could have
      identical overall rank and dimension sizes and hence be compatible,
      those cases are not supported.

    Args:
      values: A potentially ragged tensor, or a list of potentially ragged
        tensors of the same ragged_rank.
      mask: A potentially ragged tensor of the same ragged_rank as the
        elements in `values`.

    Returns:
      A tuple `(values, mask)`. If the inputs were ragged, `values` (a single
      tensor or a list, mirroring the argument) and `mask` hold the expanded
      `flat_values` of the inputs; otherwise both are the original arguments.

    Raises:
      TypeError: If only some of the inputs are ragged, or if a ragged `mask`
        is passed together with non-ragged `values`.
      ValueError: If `_assert_splits_match` finds that the inputs have a
        different number of row-splits tensors.
    """
    values_is_list = isinstance(values, list)
    # Work on a list view so single tensors and lists share one code path;
    # `values` itself is left untouched for the error messages below.
    inputs = values if values_is_list else [values]
    ragged_flags = [isinstance(item, tf.RaggedTensor) for item in inputs]
    is_all_ragged = all(ragged_flags)
    is_any_ragged = any(ragged_flags)
    mask_is_ragged = isinstance(mask, tf.RaggedTensor)

    if not (is_all_ragged and (mask is None or mask_is_ragged)):
        if is_any_ragged:
            raise TypeError(
                "Some of the inputs are not tf.RaggedTensor. "
                f"Input received: {values}"
            )
        # Here the values are empty or dense, so a ragged mask cannot be
        # matched against them.
        if mask_is_ragged:
            raise TypeError(
                "Ragged mask is not allowed with non-ragged inputs. "
                f"Input received: {values}, mask received: {mask}"
            )
        return values, mask

    # NOTE: only the row splits are compared here; compatibility of the
    # flat_values themselves is left to the downstream shape checks.
    splits_per_input = [item.nested_row_splits for item in inputs]
    split_assertions = _assert_splits_match(splits_per_input)

    # A ragged mask must follow the same row partitioning as the inputs.
    if mask_is_ragged:
        mask_assertions = _assert_splits_match(
            [splits_per_input[0], mask.nested_row_splits]
        )
        with tf.control_dependencies(mask_assertions):
            mask = tf.expand_dims(mask.flat_values, -1)

    with tf.control_dependencies(split_assertions):
        flattened = [tf.expand_dims(item.flat_values, -1) for item in inputs]

    # Mirror the container type of the argument: list in, list out.
    return (flattened if values_is_list else flattened[0]), mask
|
||
|
|
||
|
|
||
|
def _assert_splits_match(nested_splits_lists):
    """Builds assertions that all given splits lists are identical.

    First verifies statically that every splits list contains the same number
    of `splits` tensors, then creates one `tf.debugging.assert_equal` per pair
    of corresponding tensors, comparing each list against the first one.

    Args:
      nested_splits_lists: A list of nested_splits_lists, where each
        split_list is a list of `splits` tensors from a `RaggedTensor`,
        ordered from outermost ragged dimension to innermost ragged dimension.

    Returns:
      A list of control dependency op tensors (empty when fewer than two
      splits lists are given).

    Raises:
      ValueError: If the splits lists do not all have the same length.
    """
    error_msg = (
        "Inputs must have identical ragged splits. "
        f"Input received: {nested_splits_lists}"
    )

    # Static check: every list must be as deep as the first one. The first
    # element is only looked up while iterating, so empty input is accepted.
    depth_mismatch = any(
        len(splits_list) != len(nested_splits_lists[0])
        for splits_list in nested_splits_lists
    )
    if depth_mismatch:
        raise ValueError(error_msg)

    # Value check: compare each remaining list against the first, level by
    # level.
    assert_ops = []
    for other_splits in nested_splits_lists[1:]:
        for expected, actual in zip(nested_splits_lists[0], other_splits):
            assert_ops.append(
                tf.debugging.assert_equal(
                    expected, actual, message=error_msg
                )
            )
    return assert_ops
|
||
|
|
||
|
|
||
|
def binary_matches(y_true, y_pred, threshold=0.5):
    """Creates float Tensor, 1 for label-prediction match, 0 for mismatch.

    Predictions are binarized first: values strictly greater than `threshold`
    become 1 and all others 0 (in the dtype of `y_pred`). The binarized
    predictions are then compared element-wise with `y_true`.

    Args:
      y_true: Ground truth values, of shape (batch_size, d0, .. dN).
      y_pred: The predicted values, of shape (batch_size, d0, .. dN).
      threshold: (Optional) Float representing the threshold for deciding
        whether prediction values are 1 or 0.

    Returns:
      Binary matches, of shape (batch_size, d0, .. dN), cast to
      `backend.floatx()`.
    """
    predictions = tf.convert_to_tensor(y_pred)
    pred_dtype = predictions.dtype
    # Compare in the dtype of the predictions so the threshold and the
    # predictions are directly comparable.
    cutoff = tf.cast(threshold, pred_dtype)
    hard_predictions = tf.cast(predictions > cutoff, pred_dtype)
    is_match = tf.equal(y_true, hard_predictions)
    return tf.cast(is_match, backend.floatx())
|
||
|
|
||
|
|
||
|
def sparse_categorical_matches(y_true, y_pred):
    """Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch.

    The predicted class is the argmax over the last axis of `y_pred`, so
    either logits or probabilities may be passed: both give the same argmax.

    Args:
      y_true: Integer ground truth values.
      y_pred: The prediction values.

    Returns:
      Match tensor: 1.0 for label-prediction match, 0.0 for mismatch. When
      `y_true` had to be squeezed, the result is reshaped back to the
      original shape of `y_true`.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)
    label_shape = tf.shape(y_true)

    ranks_known = (
        y_pred.shape.ndims is not None and y_true.shape.ndims is not None
    )
    # Labels with the same rank as the predictions carry a trailing axis,
    # e.g. (num_samples, 1); it is dropped here and restored on the result.
    squeeze_labels = ranks_known and (
        len(backend.int_shape(y_true)) == len(backend.int_shape(y_pred))
    )
    if squeeze_labels:
        y_true = tf.squeeze(y_true, [-1])

    predicted_class = tf.math.argmax(y_pred, axis=-1)

    # `tf.equal` needs both operands in one dtype; the labels' dtype wins.
    label_dtype = backend.dtype(y_true)
    if backend.dtype(predicted_class) != label_dtype:
        predicted_class = tf.cast(predicted_class, label_dtype)

    matches = tf.cast(tf.equal(y_true, predicted_class), backend.floatx())
    if squeeze_labels:
        return tf.reshape(matches, shape=label_shape)
    return matches
|
||
|
|
||
|
|
||
|
def sparse_top_k_categorical_matches(y_true, y_pred, k=5):
    """Creates float Tensor, 1.0 for label-TopK_prediction match, 0.0 for
    mismatch.

    A label counts as matched when `tf.math.in_top_k` reports it among the
    top `k` predictions of its row.

    Args:
      y_true: tensor of true targets.
      y_pred: tensor of predicted targets.
      k: (Optional) Number of top elements to look at for computing accuracy.
        Defaults to 5.

    Returns:
      Match tensor: 1.0 for label-prediction match, 0.0 for mismatch. If
      `y_true` was flattened, the result is reshaped to the original shape of
      `y_true`.
    """
    labels = tf.convert_to_tensor(y_true)
    scores = tf.convert_to_tensor(y_pred)
    label_rank = labels.shape.ndims
    score_rank = scores.shape.ndims
    label_shape = tf.shape(labels)

    ranks_known = label_rank is not None and score_rank is not None
    flatten_scores = ranks_known and score_rank > 2
    flatten_labels = ranks_known and label_rank > 1

    # Collapse the inputs to the layout `in_top_k` is given here: scores as
    # (num_samples, num_classes) and labels as (num_samples,).
    if flatten_scores:
        scores = tf.reshape(scores, [-1, scores.shape[-1]])
    if flatten_labels:
        labels = tf.reshape(labels, [-1])

    hits = tf.math.in_top_k(
        predictions=scores, targets=tf.cast(labels, "int32"), k=k
    )
    matches = tf.cast(hits, dtype=backend.floatx())

    # The returned matches are expected to have the shape of the y_true input.
    if flatten_labels:
        return tf.reshape(matches, shape=label_shape)
    return matches
|