399 lines
13 KiB
Python
399 lines
13 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
|
|
"""Constraints: functions that impose constraints on weight values."""
|
|
|
|
import tensorflow.compat.v2 as tf
|
|
|
|
from keras import backend
|
|
from keras.saving.legacy import serialization as legacy_serialization
|
|
from keras.saving.legacy.serialization import deserialize_keras_object
|
|
from keras.saving.legacy.serialization import serialize_keras_object
|
|
|
|
# isort: off
|
|
from tensorflow.python.util.tf_export import keras_export
|
|
from tensorflow.tools.docs import doc_controls
|
|
|
|
|
|
@keras_export("keras.constraints.Constraint")
class Constraint:
    """Base class for weight constraints.

    A `Constraint` instance behaves like a stateless function. Subclasses
    should override the `__call__` method, which takes a single weight
    parameter and returns a projected version of that parameter
    (e.g. normalized or clipped). Constraints can be used with various Keras
    layers via the `kernel_constraint` or `bias_constraint` arguments.

    Here's a simple example of a non-negative weight constraint:

    >>> class NonNegative(tf.keras.constraints.Constraint):
    ...
    ...     def __call__(self, w):
    ...         return w * tf.cast(tf.math.greater_equal(w, 0.), w.dtype)

    >>> weight = tf.constant((-1.0, 1.0))
    >>> NonNegative()(weight)
    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 1.],
    dtype=float32)>

    >>> tf.keras.layers.Dense(4, kernel_constraint=NonNegative())
    """

    def __call__(self, w):
        """Applies the constraint to the input weight variable.

        The base implementation is the identity: the weight is handed back
        untouched. Subclasses override this method with their own
        projection function.

        Args:
          w: Input weight variable.

        Returns:
          Projected variable (the unmodified input in this base class).
        """
        projected = w
        return projected

    def get_config(self):
        """Returns a Python dict of the object config.

        A constraint config is a JSON-serializable Python dictionary from
        which an equivalent object can be re-created. The base class has no
        parameters, so its config is empty.

        Returns:
          Python dict containing the configuration of the constraint object.
        """
        return dict()

    @classmethod
    def from_config(cls, config):
        """Instantiates a weight constraint from a configuration dictionary.

        The entries of `config` are forwarded to the class constructor as
        keyword arguments.

        Example:

        ```python
        constraint = UnitNorm()
        config = constraint.get_config()
        constraint = UnitNorm.from_config(config)
        ```

        Args:
          config: A Python dictionary, the output of `get_config`.

        Returns:
          A `tf.keras.constraints.Constraint` instance.
        """
        instance = cls(**config)
        return instance
|
|
|
|
|
|
@keras_export("keras.constraints.MaxNorm", "keras.constraints.max_norm")
class MaxNorm(Constraint):
    """MaxNorm weight constraint.

    Constrains the weights incident to each hidden unit
    to have a norm less than or equal to a desired value.

    Also available via the shortcut function `tf.keras.constraints.max_norm`.

    Args:
      max_value: the maximum norm value for the incoming weights.
      axis: integer, axis along which to calculate weight norms.
        For instance, in a `Dense` layer the weight matrix
        has shape `(input_dim, output_dim)`,
        set `axis` to `0` to constrain each weight vector
        of length `(input_dim,)`.
        In a `Conv2D` layer with `data_format="channels_last"`,
        the weight tensor has shape
        `(rows, cols, input_depth, output_depth)`,
        set `axis` to `[0, 1, 2]`
        to constrain the weights of each filter tensor of size
        `(rows, cols, input_depth)`.

    """

    def __init__(self, max_value=2, axis=0):
        self.axis = axis
        self.max_value = max_value

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        # L2 norm along `axis`; keepdims lets the result broadcast against w.
        sum_of_squares = tf.reduce_sum(
            tf.square(w), axis=self.axis, keepdims=True
        )
        norms = backend.sqrt(sum_of_squares)
        # Norms already below the cap are kept; larger ones become max_value.
        capped_norms = backend.clip(norms, 0, self.max_value)
        # epsilon keeps the division finite when a norm is zero.
        scale = capped_norms / (backend.epsilon() + norms)
        return w * scale

    @doc_controls.do_not_generate_docs
    def get_config(self):
        return dict(max_value=self.max_value, axis=self.axis)
|
|
|
|
|
|
@keras_export("keras.constraints.NonNeg", "keras.constraints.non_neg")
class NonNeg(Constraint):
    """Constrains the weights to be non-negative.

    Negative entries are replaced by zero; non-negative entries are returned
    unchanged.

    Also available via the shortcut function `tf.keras.constraints.non_neg`.
    """

    def __call__(self, w):
        """Zeroes out the negative entries of `w`.

        Args:
          w: Input weight variable.

        Returns:
          `w` multiplied elementwise by a 0/1 mask that is 1 where
          `w >= 0`.
        """
        # Cast the mask to the weight's own dtype rather than the global
        # `backend.floatx()`: a weight stored in another float type
        # (e.g. float16 or float64) would otherwise make the multiplication
        # fail with a dtype mismatch.
        non_negative_mask = tf.cast(tf.greater_equal(w, 0.0), w.dtype)
        return w * non_negative_mask
|
|
|
|
|
|
@keras_export("keras.constraints.UnitNorm", "keras.constraints.unit_norm")
class UnitNorm(Constraint):
    """Constrains the weights incident to each hidden unit to have unit norm.

    Also available via the shortcut function `tf.keras.constraints.unit_norm`.

    Args:
      axis: integer, axis along which to calculate weight norms.
        For instance, in a `Dense` layer the weight matrix
        has shape `(input_dim, output_dim)`,
        set `axis` to `0` to constrain each weight vector
        of length `(input_dim,)`.
        In a `Conv2D` layer with `data_format="channels_last"`,
        the weight tensor has shape
        `(rows, cols, input_depth, output_depth)`,
        set `axis` to `[0, 1, 2]`
        to constrain the weights of each filter tensor of size
        `(rows, cols, input_depth)`.
    """

    def __init__(self, axis=0):
        self.axis = axis

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        # L2 norm along `axis`; keepdims lets the result broadcast against w.
        sum_of_squares = tf.reduce_sum(
            tf.square(w), axis=self.axis, keepdims=True
        )
        # epsilon keeps the division finite when a norm is zero.
        denominator = backend.epsilon() + backend.sqrt(sum_of_squares)
        return w / denominator

    @doc_controls.do_not_generate_docs
    def get_config(self):
        return dict(axis=self.axis)
|
|
|
|
|
|
@keras_export("keras.constraints.MinMaxNorm", "keras.constraints.min_max_norm")
class MinMaxNorm(Constraint):
    """MinMaxNorm weight constraint.

    Constrains the weights incident to each hidden unit
    to have the norm between a lower bound and an upper bound.

    Also available via the shortcut function
    `tf.keras.constraints.min_max_norm`.

    Args:
      min_value: the minimum norm for the incoming weights.
      max_value: the maximum norm for the incoming weights.
      rate: rate for enforcing the constraint: weights will be
        rescaled to yield
        `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
        Effectively, this means that rate=1.0 stands for strict
        enforcement of the constraint, while rate<1.0 means that
        weights will be rescaled at each step to slowly move
        towards a value inside the desired interval.
      axis: integer, axis along which to calculate weight norms.
        For instance, in a `Dense` layer the weight matrix
        has shape `(input_dim, output_dim)`,
        set `axis` to `0` to constrain each weight vector
        of length `(input_dim,)`.
        In a `Conv2D` layer with `data_format="channels_last"`,
        the weight tensor has shape
        `(rows, cols, input_depth, output_depth)`,
        set `axis` to `[0, 1, 2]`
        to constrain the weights of each filter tensor of size
        `(rows, cols, input_depth)`.
    """

    def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
        self.axis = axis
        self.rate = rate
        self.max_value = max_value
        self.min_value = min_value

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        # L2 norm along `axis`; keepdims lets the result broadcast against w.
        sum_of_squares = tf.reduce_sum(
            tf.square(w), axis=self.axis, keepdims=True
        )
        norms = backend.sqrt(sum_of_squares)
        clipped_norms = backend.clip(norms, self.min_value, self.max_value)
        # Blend the clipped norm with the current norm according to `rate`.
        desired = self.rate * clipped_norms + (1 - self.rate) * norms
        # epsilon keeps the division finite when a norm is zero.
        scale = desired / (backend.epsilon() + norms)
        return w * scale

    @doc_controls.do_not_generate_docs
    def get_config(self):
        return dict(
            min_value=self.min_value,
            max_value=self.max_value,
            rate=self.rate,
            axis=self.axis,
        )
|
|
|
|
|
|
@keras_export(
    "keras.constraints.RadialConstraint", "keras.constraints.radial_constraint"
)
class RadialConstraint(Constraint):
    """Constrains `Conv2D` kernel weights to be the same for each radius.

    Also available via the shortcut function
    `tf.keras.constraints.radial_constraint`.

    For example, the desired output for the following 4-by-4 kernel:

    ```
        kernel = [[v_00, v_01, v_02, v_03],
                  [v_10, v_11, v_12, v_13],
                  [v_20, v_21, v_22, v_23],
                  [v_30, v_31, v_32, v_33]]
    ```

    is this:

    ```
        kernel = [[v_11, v_11, v_11, v_11],
                  [v_11, v_33, v_33, v_11],
                  [v_11, v_33, v_33, v_11],
                  [v_11, v_11, v_11, v_11]]
    ```

    This constraint can be applied to any `Conv2D` layer version, including
    `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"`
    or `"channels_first"` data format. The method assumes the weight tensor is
    of shape `(rows, cols, input_depth, output_depth)`.
    """

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        """Applies the radial constraint to every 2D slice of a rank-4 weight.

        Args:
          w: Weight tensor; its four dimensions are unpacked as
            `(height, width, channels, kernels)`.

        Returns:
          A tensor reshaped back to `(height, width, channels, kernels)`
          whose 2D slices have each been passed through
          `_kernel_constraint`.

        Raises:
          ValueError: If the rank of `w` is unknown or is not 4.
        """
        w_shape = w.shape
        if w_shape.rank is None or w_shape.rank != 4:
            raise ValueError(
                "The weight tensor must have rank 4. "
                f"Received weight tensor with shape: {w_shape}"
            )

        height, width, channels, kernels = w_shape
        # Merge the last two axes so that every (height, width) slice can be
        # processed independently.
        w = backend.reshape(w, (height, width, channels * kernels))
        # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once
        # backend.switch is supported.
        # Move the merged axis to the front: map_fn iterates over axis 0.
        w = backend.map_fn(
            self._kernel_constraint,
            backend.stack(tf.unstack(w, axis=-1), axis=0),
        )
        # Move the merged axis back to the end and restore the rank-4 shape.
        return backend.reshape(
            backend.stack(tf.unstack(w, axis=0), axis=-1),
            (height, width, channels, kernels),
        )

    def _kernel_constraint(self, kernel):
        """Radially constrains a single 2D kernel slice.

        The slice is rebuilt from the inside out: a small seed patch is
        repeatedly wrapped in a one-element border, each border filled with
        one value read from the diagonal of `kernel`.

        Args:
          kernel: A rank-2 slice, as produced by `__call__`. Only its first
            dimension is used to size the result (presumably the slice is
            square -- confirm against callers).

        Returns:
          The rebuilt rank-2 tensor.
        """
        # One element of padding on each side of both dimensions.
        padding = backend.constant([[1, 1], [1, 1]], dtype="int32")

        kernel_shape = backend.shape(kernel)[0]
        # Half of the kernel size, truncated to an integer.
        start = backend.cast(kernel_shape / 2, "int32")

        # Seed patch: a 1x1 slice when the size is odd; when it is even, the
        # same 1x1 slice broadcast to 2x2 by adding a 2x2 block of zeros.
        # NOTE(review): the seed is read from index `start - 1` and the ring
        # values in `body_fn` from `kernel[start + i, start + i]`; confirm
        # this matches the layout shown in the class docstring example.
        kernel_new = backend.switch(
            backend.cast(tf.math.floormod(kernel_shape, 2), "bool"),
            lambda: kernel[start - 1 : start, start - 1 : start],
            lambda: kernel[start - 1 : start, start - 1 : start]
            + backend.zeros((2, 2), dtype=kernel.dtype),
        )
        # The loop counter starts at 0 for odd sizes and at 1 for even sizes.
        index = backend.switch(
            backend.cast(tf.math.floormod(kernel_shape, 2), "bool"),
            lambda: backend.constant(0, dtype="int32"),
            lambda: backend.constant(1, dtype="int32"),
        )
        # Keep adding rings until the counter reaches `start`.
        while_condition = lambda index, *args: backend.less(index, start)

        def body_fn(i, array):
            # Wrap the current array in a border of width one filled with a
            # single value taken from the diagonal of the input kernel.
            return i + 1, tf.pad(
                array, padding, constant_values=kernel[start + i, start + i]
            )

        # The array grows on every iteration, so its shape invariant leaves
        # both dimensions unspecified.
        _, kernel_new = tf.compat.v1.while_loop(
            while_condition,
            body_fn,
            [index, kernel_new],
            shape_invariants=[index.get_shape(), tf.TensorShape([None, None])],
        )
        return kernel_new
|
|
|
|
|
|
# Snake-case aliases for the constraint classes defined above.

radial_constraint = RadialConstraint
min_max_norm = MinMaxNorm
unit_norm = UnitNorm
non_neg = NonNeg
max_norm = MaxNorm

# Legacy aliases (older spellings without the underscore).
unitnorm = UnitNorm
nonneg = NonNeg
maxnorm = MaxNorm
|
|
|
|
|
|
@keras_export("keras.constraints.serialize")
def serialize(constraint, use_legacy_format=False):
    """Serializes a constraint object.

    Args:
      constraint: The constraint to serialize.
      use_legacy_format: If true, the serializer is looked up on the
        `legacy_serialization` module; otherwise the module-level
        `serialize_keras_object` import is used.

    Returns:
      Whatever the selected `serialize_keras_object` function returns for
      `constraint`.
    """
    serialize_fn = (
        legacy_serialization.serialize_keras_object
        if use_legacy_format
        else serialize_keras_object
    )
    return serialize_fn(constraint)
|
|
|
|
|
|
@keras_export("keras.constraints.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Builds a constraint from its serialized config.

    Args:
      config: The serialized constraint configuration.
      custom_objects: Optional custom objects forwarded to the deserializer.
      use_legacy_format: If true, the deserializer is looked up on the
        `legacy_serialization` module; otherwise the module-level
        `deserialize_keras_object` import is used.

    Returns:
      Whatever the selected `deserialize_keras_object` function returns. The
      names defined in this module are offered to it as `module_objects`.
    """
    deserialize_fn = (
        legacy_serialization.deserialize_keras_object
        if use_legacy_format
        else deserialize_keras_object
    )
    return deserialize_fn(
        config,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name="constraint",
    )
|
|
|
|
|
|
@keras_export("keras.constraints.get")
def get(identifier):
    """Retrieves a Keras constraint function.

    Args:
      identifier: One of
        - `None`, which is returned as-is;
        - a dict config, which is passed to `deserialize` (in legacy format
          when the dict has no `"module"` key);
        - a string, treated as a class name and deserialized with an empty
          config;
        - a callable, which is returned unchanged.

    Returns:
      The resolved constraint, or `None` when `identifier` is `None`.

    Raises:
      ValueError: If `identifier` is none of the supported kinds.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # Configs without a "module" entry are handled as legacy-format.
        is_legacy = "module" not in identifier
        return deserialize(identifier, use_legacy_format=is_legacy)

    if isinstance(identifier, str):
        name_only_config = {"class_name": str(identifier), "config": {}}
        return deserialize(name_only_config)

    if callable(identifier):
        return identifier

    raise ValueError(
        f"Could not interpret constraint function identifier: {identifier}"
    )
|