137 lines
4.0 KiB
Python
137 lines
4.0 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Hinge metrics."""
|
||
|
|
||
|
from keras.dtensor import utils as dtensor_utils
|
||
|
from keras.losses import categorical_hinge
|
||
|
from keras.losses import hinge
|
||
|
from keras.losses import squared_hinge
|
||
|
from keras.metrics import base_metric
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
|
||
|
|
||
|
@keras_export("keras.metrics.Hinge")
class Hinge(base_metric.MeanMetricWrapper):
    """Computes the hinge metric between `y_true` and `y_pred`.

    Labels in `y_true` should be -1 or 1. Binary labels given as 0 or 1 are
    converted to -1 or 1 before the metric is computed.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.Hinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.3

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.1

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.Hinge()])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="hinge", dtype=None):
        # The per-sample function is the `hinge` loss. All aggregation is
        # left to the MeanMetricWrapper base class.
        super().__init__(hinge, name=name, dtype=dtype)
|
||
|
|
||
|
|
||
|
@keras_export("keras.metrics.SquaredHinge")
class SquaredHinge(base_metric.MeanMetricWrapper):
    """Computes the squared hinge metric between `y_true` and `y_pred`.

    Labels in `y_true` should be -1 or 1. Binary labels given as 0 or 1 are
    converted to -1 or 1 before the metric is computed.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.SquaredHinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.86

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.46

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.SquaredHinge()])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="squared_hinge", dtype=None):
        # The per-sample function is the `squared_hinge` loss. All
        # aggregation is left to the MeanMetricWrapper base class.
        super().__init__(squared_hinge, name=name, dtype=dtype)
|
||
|
|
||
|
|
||
|
@keras_export("keras.metrics.CategoricalHinge")
class CategoricalHinge(base_metric.MeanMetricWrapper):
    """Computes the categorical hinge metric between `y_true` and `y_pred`.

    The value reported is produced by the `categorical_hinge` loss function
    and aggregated by the mean-metric wrapper base class.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.CategoricalHinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.4000001

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.2

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.CategoricalHinge()])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="categorical_hinge", dtype=None):
        # The per-sample function is the `categorical_hinge` loss. All
        # aggregation is left to the MeanMetricWrapper base class.
        super().__init__(categorical_hinge, name=name, dtype=dtype)
|