# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum dense layer."""

import re

import tensorflow.compat.v2 as tf

from keras import activations
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.engine.base_layer import Layer

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export(
    "keras.layers.EinsumDense", "keras.layers.experimental.EinsumDense"
)
class EinsumDense(Layer):
    """A layer that uses `tf.einsum` as the backing computation.

    This layer can perform einsum calculations of arbitrary dimensionality.

    Args:
      equation: An equation describing the einsum to perform. This equation
        must be a valid einsum string of the form `ab,bc->ac`,
        `...ab,bc->...ac`, or `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can
        be any valid einsum axis expression sequence.
      output_shape: The expected shape of the output tensor (excluding the
        batch dimension and any dimensions represented by ellipses). You can
        specify None for any dimension that is unknown or can be inferred from
        the input shape.
      activation: Activation function to use. If you don't specify anything,
        no activation is applied (that is, a "linear" activation:
        `a(x) = x`).
      bias_axes: A string containing the output dimension(s) to apply a bias
        to. Each character in the `bias_axes` string should correspond to a
        character in the output portion of the `equation` string.
      kernel_initializer: Initializer for the `kernel` weights matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation").
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix.
      bias_constraint: Constraint function applied to the bias vector.

    Examples:

    **Biased dense layer with einsums**

    This example shows how to instantiate a standard Keras dense layer using
    einsum operations. This example is equivalent to
    `tf.keras.layers.Dense(64, use_bias=True)`.

    >>> layer = tf.keras.layers.EinsumDense("ab,bc->ac",
    ...                                     output_shape=64,
    ...                                     bias_axes="c")
    >>> input_tensor = tf.keras.Input(shape=[32])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 64) dtype=...>

    **Applying a dense layer to a sequence**

    This example shows how to instantiate a layer that applies the same dense
    operation to every element in a sequence. Here, the `output_shape` has two
    values (since there are two non-batch dimensions in the output); the first
    dimension in the `output_shape` is `None`, because the sequence dimension
    `b` has an unknown shape.

    >>> layer = tf.keras.layers.EinsumDense("abc,cd->abd",
    ...                                     output_shape=(None, 64),
    ...                                     bias_axes="d")
    >>> input_tensor = tf.keras.Input(shape=[32, 128])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 32, 64) dtype=...>

    **Applying a dense layer to a sequence using ellipses**

    This example shows how to instantiate a layer that applies the same dense
    operation to every element in a sequence, but uses the ellipsis notation
    instead of specifying the batch and sequence dimensions.

    Because we are using ellipsis notation and have specified only one axis,
    the `output_shape` arg is a single value. When instantiated in this way,
    the layer can handle any number of sequence dimensions - including the
    case where no sequence dimension exists.

    >>> layer = tf.keras.layers.EinsumDense("...x,xy->...y",
    ...                                     output_shape=64,
    ...                                     bias_axes="y")
    >>> input_tensor = tf.keras.Input(shape=[32, 128])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 32, 64) dtype=...>
    """

    def __init__(
        self,
        equation,
        output_shape,
        activation=None,
        bias_axes=None,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        # Forward `activity_regularizer` to the base layer; otherwise the
        # argument is accepted here but silently discarded, and `get_config`
        # would serialize whatever the base class defaulted to.
        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
        self.equation = equation
        # Always store the user-facing (partial) output shape as a list so
        # that it can be extended during shape analysis and serialized.
        if isinstance(output_shape, int):
            self.partial_output_shape = [output_shape]
        else:
            self.partial_output_shape = list(output_shape)
        self.bias_axes = bias_axes
        self.activation = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    def build(self, input_shape):
        """Creates the `kernel` weight and, if requested, the `bias` weight.

        The weight shapes and the full output shape are derived from the
        equation, `bias_axes`, the given input shape and the partial output
        shape supplied at construction time.

        Args:
          input_shape: Shape of the layer input; converted to a
            `tf.TensorShape` before analysis.
        """
        input_shape = tf.TensorShape(input_shape)
        kernel_shape, bias_shape, self.full_output_shape = (
            _analyze_einsum_string(
                self.equation,
                self.bias_axes,
                input_shape,
                self.partial_output_shape,
            )
        )
        self.kernel = self.add_weight(
            "kernel",
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        # A bias shape of None means no `bias_axes` were requested.
        if bias_shape is not None:
            self.bias = self.add_weight(
                "bias",
                shape=bias_shape,
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.bias = None
        super().build(input_shape)

    def compute_output_shape(self, _):
        """Returns the full output shape computed in `build`.

        The argument is ignored; the result comes from
        `self.full_output_shape`, which is only set once `build` has run.
        """
        return tf.TensorShape(self.full_output_shape)

    def get_config(self):
        """Returns the layer configuration merged over the base config."""
        config = {
            "output_shape": self.partial_output_shape,
            "equation": self.equation,
            "activation": activations.serialize(self.activation),
            "bias_axes": self.bias_axes,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        # Entries defined by this layer take precedence over the base config.
        return {**base_config, **config}

    def call(self, inputs):
        """Computes `einsum(equation, inputs, kernel)` plus bias/activation.

        Args:
          inputs: The input tensor passed as the first einsum operand.

        Returns:
          The einsum result, with the bias added when one exists and the
          activation applied when one is set.
        """
        ret = tf.einsum(self.equation, inputs, self.kernel)
        if self.bias is not None:
            ret += self.bias
        if self.activation is not None:
            ret = self.activation(ret)
        return ret


def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
    """Analyzes an einsum string to determine the required weight shape.

    Args:
      equation: The einsum equation; must have the form `[X],[Y]->[Z]`,
        `...[X],[Y]->...[Z]` or `[X]...,[Y]->[Z]...`.
      bias_axes: String of output axes that receive a bias, or None.
      input_shape: Shape of the layer input (supports `len` and indexing).
      output_shape: Partial output shape (an int or a sequence), excluding
        the batch dimension and any elided dimensions.

    Returns:
      The `(weight_shape, bias_shape, full_output_shape)` tuple produced by
      `_analyze_split_string`.

    Raises:
      ValueError: If the equation matches none of the supported forms, or if
        `_analyze_split_string` rejects the shapes.
    """
    # Replace each ellipsis with a single non-letter placeholder so that the
    # patterns below can tell elided and non-elided forms apart.
    dot_replaced_string = re.sub(r"\.\.\.", "0", equation)

    # This is the case where no ellipses are present in the string.
    split_string = re.match(
        r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape
        )

    # This is the case where ellipses are present on the left.
    split_string = re.match(
        r"0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape, left_elided=True
        )

    # This is the case where ellipses are present on the right.
    split_string = re.match(
        r"([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape
        )

    raise ValueError(
        f"Invalid einsum equation '{equation}'. Equations must be in the form "
        "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...."
    )


def _analyze_split_string(
    split_string, bias_axes, input_shape, output_shape, left_elided=False
):
    """Analyze a pre-split einsum string to find the weight shape.

    Args:
      split_string: A regex match whose groups 1, 2 and 3 hold the input,
        weight and output axis specs respectively.
      bias_axes: String of output axes that receive a bias, or None.
      input_shape: Shape of the layer input (supports `len` and indexing).
      output_shape: Partial output shape (an int or a sequence), excluding
        the batch dimension and any elided dimensions. It is copied, not
        mutated.
      left_elided: Whether the elided dimensions sit on the left of the
        specs. When False, any elided dimensions are taken to be on the
        right.

    Returns:
      A `(weight_shape, bias_shape, output_shape)` tuple. `bias_shape` is
      None when `bias_axes` is None; `output_shape` is the full output shape
      including the batch dimension and elided dimensions.

    Raises:
      ValueError: If a dimension shared by input and output has mismatching
        sizes, if an output dimension appears in neither the input nor the
        weight spec, if a weight dimension appears in neither the input nor
        the output spec, or if a bias axis is not part of the output spec.
    """
    input_spec = split_string.group(1)
    weight_spec = split_string.group(2)
    output_spec = split_string.group(3)
    # Number of input dimensions covered by the ellipsis (if any).
    elided = len(input_shape) - len(input_spec)

    if isinstance(output_shape, int):
        output_shape = [output_shape]
    else:
        output_shape = list(output_shape)

    # The batch dimension is never part of the user-supplied output shape.
    output_shape.insert(0, input_shape[0])

    if elided > 0 and left_elided:
        for i in range(1, elided):
            # We already inserted the 0th input dimension at dim 0, so we need
            # to start at location 1 here. Inserting at position `i` (rather
            # than always at 1) keeps the elided dimensions in input order.
            output_shape.insert(i, input_shape[i])
    elif elided > 0 and not left_elided:
        output_shape.extend(
            input_shape[i]
            for i in range(len(input_shape) - elided, len(input_shape))
        )

    if left_elided:
        # If we have beginning dimensions elided, we need to use negative
        # indexing to determine where in the input dimension our values are.
        input_dim_map = {
            dim: (i + elided) - len(input_shape)
            for i, dim in enumerate(input_spec)
        }
        # Because we've constructed the full output shape already, we don't
        # need to do negative indexing.
        output_dim_map = {
            dim: (i + elided) for i, dim in enumerate(output_spec)
        }
    else:
        input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
        output_dim_map = {dim: i for i, dim in enumerate(output_spec)}

    # Dimensions shared by input and output must agree unless the output
    # dimension was left unspecified (None).
    for dim in input_spec:
        input_shape_at_dim = input_shape[input_dim_map[dim]]
        if dim in output_dim_map:
            output_shape_at_dim = output_shape[output_dim_map[dim]]
            if (
                output_shape_at_dim is not None
                and output_shape_at_dim != input_shape_at_dim
            ):
                raise ValueError(
                    "Input shape and output shape do not match at shared "
                    f"dimension '{dim}'. Input shape is {input_shape_at_dim}, "
                    "and output shape "
                    f"is {output_shape[output_dim_map[dim]]}."
                )

    for dim in output_spec:
        if dim not in input_spec and dim not in weight_spec:
            raise ValueError(
                f"Dimension '{dim}' was specified in the output "
                f"'{output_spec}' but has no corresponding dim in the input "
                f"spec '{input_spec}' or weight spec '{weight_spec}'"
            )

    # Each weight dimension takes its size from the input when the axis is
    # present there, and from the output otherwise.
    weight_shape = []
    for dim in weight_spec:
        if dim in input_dim_map:
            weight_shape.append(input_shape[input_dim_map[dim]])
        elif dim in output_dim_map:
            weight_shape.append(output_shape[output_dim_map[dim]])
        else:
            raise ValueError(
                f"Weight dimension '{dim}' did not have a match in either "
                f"the input spec '{input_spec}' or the output "
                f"spec '{output_spec}'. For this layer, the weight must "
                "be fully specified."
            )

    if bias_axes is not None:
        num_left_elided = elided if left_elided else 0
        idx_map = {
            char: output_shape[i + num_left_elided]
            for i, char in enumerate(output_spec)
        }

        for char in bias_axes:
            if char not in output_spec:
                raise ValueError(
                    f"Bias dimension '{char}' was requested, but is not part "
                    f"of the output spec '{output_spec}'"
                )

        # The bias covers the output axes from the first biased axis to the
        # end of the output spec; non-biased axes in that range get size 1.
        first_bias_location = min(
            output_spec.find(char) for char in bias_axes
        )
        bias_output_spec = output_spec[first_bias_location:]

        bias_shape = [
            idx_map[char] if char in bias_axes else 1
            for char in bias_output_spec
        ]

        if not left_elided:
            # Right-elided output dimensions trail the spec; pad with 1s so
            # the bias lines up with them.
            bias_shape.extend([1] * elided)
    else:
        bias_shape = None

    return weight_shape, bias_shape, output_shape