# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Constraints: functions that impose constraints on weight values.""" import tensorflow.compat.v2 as tf from keras import backend from keras.saving.legacy import serialization as legacy_serialization from keras.saving.legacy.serialization import deserialize_keras_object from keras.saving.legacy.serialization import serialize_keras_object # isort: off from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls @keras_export("keras.constraints.Constraint") class Constraint: """Base class for weight constraints. A `Constraint` instance works like a stateless function. Users who subclass this class should override the `__call__` method, which takes a single weight parameter and return a projected version of that parameter (e.g. normalized or clipped). Constraints can be used with various Keras layers via the `kernel_constraint` or `bias_constraint` arguments. Here's a simple example of a non-negative weight constraint: >>> class NonNegative(tf.keras.constraints.Constraint): ... ... def __call__(self, w): ... return w * tf.cast(tf.math.greater_equal(w, 0.), w.dtype) >>> weight = tf.constant((-1.0, 1.0)) >>> NonNegative()(weight) >>> tf.keras.layers.Dense(4, kernel_constraint=NonNegative()) """ def __call__(self, w): """Applies the constraint to the input weight variable. By default, the inputs weight variable is not modified. Users should override this method to implement their own projection function. Args: w: Input weight variable. Returns: Projected variable (by default, returns unmodified inputs). """ return w def get_config(self): """Returns a Python dict of the object config. A constraint config is a Python dictionary (JSON-serializable) that can be used to reinstantiate the same object. Returns: Python dict containing the configuration of the constraint object. """ return {} @classmethod def from_config(cls, config): """Instantiates a weight constraint from a configuration dictionary. Example: ```python constraint = UnitNorm() config = constraint.get_config() constraint = UnitNorm.from_config(config) ``` Args: config: A Python dictionary, the output of `get_config`. Returns: A `tf.keras.constraints.Constraint` instance. """ return cls(**config) @keras_export("keras.constraints.MaxNorm", "keras.constraints.max_norm") class MaxNorm(Constraint): """MaxNorm weight constraint. Constrains the weights incident to each hidden unit to have a norm less than or equal to a desired value. Also available via the shortcut function `tf.keras.constraints.max_norm`. Args: max_value: the maximum norm value for the incoming weights. axis: integer, axis along which to calculate weight norms. For instance, in a `Dense` layer the weight matrix has shape `(input_dim, output_dim)`, set `axis` to `0` to constrain each weight vector of length `(input_dim,)`. In a `Conv2D` layer with `data_format="channels_last"`, the weight tensor has shape `(rows, cols, input_depth, output_depth)`, set `axis` to `[0, 1, 2]` to constrain the weights of each filter tensor of size `(rows, cols, input_depth)`. """ def __init__(self, max_value=2, axis=0): self.max_value = max_value self.axis = axis @doc_controls.do_not_generate_docs def __call__(self, w): norms = backend.sqrt( tf.reduce_sum(tf.square(w), axis=self.axis, keepdims=True) ) desired = backend.clip(norms, 0, self.max_value) return w * (desired / (backend.epsilon() + norms)) @doc_controls.do_not_generate_docs def get_config(self): return {"max_value": self.max_value, "axis": self.axis} @keras_export("keras.constraints.NonNeg", "keras.constraints.non_neg") class NonNeg(Constraint): """Constrains the weights to be non-negative. Also available via the shortcut function `tf.keras.constraints.non_neg`. """ def __call__(self, w): return w * tf.cast(tf.greater_equal(w, 0.0), backend.floatx()) @keras_export("keras.constraints.UnitNorm", "keras.constraints.unit_norm") class UnitNorm(Constraint): """Constrains the weights incident to each hidden unit to have unit norm. Also available via the shortcut function `tf.keras.constraints.unit_norm`. Args: axis: integer, axis along which to calculate weight norms. For instance, in a `Dense` layer the weight matrix has shape `(input_dim, output_dim)`, set `axis` to `0` to constrain each weight vector of length `(input_dim,)`. In a `Conv2D` layer with `data_format="channels_last"`, the weight tensor has shape `(rows, cols, input_depth, output_depth)`, set `axis` to `[0, 1, 2]` to constrain the weights of each filter tensor of size `(rows, cols, input_depth)`. """ def __init__(self, axis=0): self.axis = axis @doc_controls.do_not_generate_docs def __call__(self, w): return w / ( backend.epsilon() + backend.sqrt( tf.reduce_sum(tf.square(w), axis=self.axis, keepdims=True) ) ) @doc_controls.do_not_generate_docs def get_config(self): return {"axis": self.axis} @keras_export("keras.constraints.MinMaxNorm", "keras.constraints.min_max_norm") class MinMaxNorm(Constraint): """MinMaxNorm weight constraint. Constrains the weights incident to each hidden unit to have the norm between a lower bound and an upper bound. Also available via the shortcut function `tf.keras.constraints.min_max_norm`. Args: min_value: the minimum norm for the incoming weights. max_value: the maximum norm for the incoming weights. rate: rate for enforcing the constraint: weights will be rescaled to yield `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`. Effectively, this means that rate=1.0 stands for strict enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step to slowly move towards a value inside the desired interval. axis: integer, axis along which to calculate weight norms. For instance, in a `Dense` layer the weight matrix has shape `(input_dim, output_dim)`, set `axis` to `0` to constrain each weight vector of length `(input_dim,)`. In a `Conv2D` layer with `data_format="channels_last"`, the weight tensor has shape `(rows, cols, input_depth, output_depth)`, set `axis` to `[0, 1, 2]` to constrain the weights of each filter tensor of size `(rows, cols, input_depth)`. """ def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0): self.min_value = min_value self.max_value = max_value self.rate = rate self.axis = axis @doc_controls.do_not_generate_docs def __call__(self, w): norms = backend.sqrt( tf.reduce_sum(tf.square(w), axis=self.axis, keepdims=True) ) desired = ( self.rate * backend.clip(norms, self.min_value, self.max_value) + (1 - self.rate) * norms ) return w * (desired / (backend.epsilon() + norms)) @doc_controls.do_not_generate_docs def get_config(self): return { "min_value": self.min_value, "max_value": self.max_value, "rate": self.rate, "axis": self.axis, } @keras_export( "keras.constraints.RadialConstraint", "keras.constraints.radial_constraint" ) class RadialConstraint(Constraint): """Constrains `Conv2D` kernel weights to be the same for each radius. Also available via the shortcut function `tf.keras.constraints.radial_constraint`. For example, the desired output for the following 4-by-4 kernel: ``` kernel = [[v_00, v_01, v_02, v_03], [v_10, v_11, v_12, v_13], [v_20, v_21, v_22, v_23], [v_30, v_31, v_32, v_33]] ``` is this:: ``` kernel = [[v_11, v_11, v_11, v_11], [v_11, v_33, v_33, v_11], [v_11, v_33, v_33, v_11], [v_11, v_11, v_11, v_11]] ``` This constraint can be applied to any `Conv2D` layer version, including `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or `"channels_first"` data format. The method assumes the weight tensor is of shape `(rows, cols, input_depth, output_depth)`. """ @doc_controls.do_not_generate_docs def __call__(self, w): w_shape = w.shape if w_shape.rank is None or w_shape.rank != 4: raise ValueError( "The weight tensor must have rank 4. " f"Received weight tensor with shape: {w_shape}" ) height, width, channels, kernels = w_shape w = backend.reshape(w, (height, width, channels * kernels)) # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once # backend.switch is supported. w = backend.map_fn( self._kernel_constraint, backend.stack(tf.unstack(w, axis=-1), axis=0), ) return backend.reshape( backend.stack(tf.unstack(w, axis=0), axis=-1), (height, width, channels, kernels), ) def _kernel_constraint(self, kernel): """Radially constraints a kernel with shape (height, width, channels).""" padding = backend.constant([[1, 1], [1, 1]], dtype="int32") kernel_shape = backend.shape(kernel)[0] start = backend.cast(kernel_shape / 2, "int32") kernel_new = backend.switch( backend.cast(tf.math.floormod(kernel_shape, 2), "bool"), lambda: kernel[start - 1 : start, start - 1 : start], lambda: kernel[start - 1 : start, start - 1 : start] + backend.zeros((2, 2), dtype=kernel.dtype), ) index = backend.switch( backend.cast(tf.math.floormod(kernel_shape, 2), "bool"), lambda: backend.constant(0, dtype="int32"), lambda: backend.constant(1, dtype="int32"), ) while_condition = lambda index, *args: backend.less(index, start) def body_fn(i, array): return i + 1, tf.pad( array, padding, constant_values=kernel[start + i, start + i] ) _, kernel_new = tf.compat.v1.while_loop( while_condition, body_fn, [index, kernel_new], shape_invariants=[index.get_shape(), tf.TensorShape([None, None])], ) return kernel_new # Aliases. max_norm = MaxNorm non_neg = NonNeg unit_norm = UnitNorm min_max_norm = MinMaxNorm radial_constraint = RadialConstraint # Legacy aliases. maxnorm = max_norm nonneg = non_neg unitnorm = unit_norm @keras_export("keras.constraints.serialize") def serialize(constraint, use_legacy_format=False): if use_legacy_format: return legacy_serialization.serialize_keras_object(constraint) return serialize_keras_object(constraint) @keras_export("keras.constraints.deserialize") def deserialize(config, custom_objects=None, use_legacy_format=False): if use_legacy_format: return legacy_serialization.deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, printable_module_name="constraint", ) return deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, printable_module_name="constraint", ) @keras_export("keras.constraints.get") def get(identifier): """Retrieves a Keras constraint function.""" if identifier is None: return None if isinstance(identifier, dict): use_legacy_format = "module" not in identifier return deserialize(identifier, use_legacy_format=use_legacy_format) elif isinstance(identifier, str): config = {"class_name": str(identifier), "config": {}} return deserialize(config) elif callable(identifier): return identifier else: raise ValueError( f"Could not interpret constraint function identifier: {identifier}" )