# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding layer."""

import tensorflow.compat.v2 as tf

from keras import backend
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.dtensor import utils
from keras.engine import base_layer_utils
from keras.engine.base_layer import Layer
from keras.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.Embedding")
class Embedding(Layer):
    """Turns positive integers (indexes) into dense vectors of fixed size.

    e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`

    This layer can only be used on positive integer inputs of a fixed range.
    The `tf.keras.layers.TextVectorization`, `tf.keras.layers.StringLookup`,
    and `tf.keras.layers.IntegerLookup` preprocessing layers can help prepare
    inputs for an `Embedding` layer.

    This layer accepts `tf.Tensor`, `tf.RaggedTensor` and `tf.SparseTensor`
    input.

    Example:

    >>> model = tf.keras.Sequential()
    >>> model.add(tf.keras.layers.Embedding(1000, 64, input_length=10))
    >>> # The model will take as input an integer matrix of size (batch,
    >>> # input_length), and the largest integer (i.e. word index) in the
    >>> # input should be no larger than 999 (vocabulary size).
    >>> # Now model.output_shape is (None, 10, 64), where `None` is the batch
    >>> # dimension.
    >>> input_array = np.random.randint(1000, size=(32, 10))
    >>> model.compile('rmsprop', 'mse')
    >>> output_array = model.predict(input_array)
    >>> print(output_array.shape)
    (32, 10, 64)

    Args:
      input_dim: Integer. Size of the vocabulary,
        i.e. maximum integer index + 1.
      output_dim: Integer. Dimension of the dense embedding.
      embeddings_initializer: Initializer for the `embeddings`
        matrix (see `keras.initializers`).
      embeddings_regularizer: Regularizer function applied to
        the `embeddings` matrix (see `keras.regularizers`).
      embeddings_constraint: Constraint function applied to
        the `embeddings` matrix (see `keras.constraints`).
      mask_zero: Boolean, whether or not the input value 0 is a special
        "padding" value that should be masked out. This is useful when using
        recurrent layers which may take variable length input. If this is
        `True`, then all subsequent layers in the model need to support
        masking or an exception will be raised. If `mask_zero` is set to
        `True`, as a consequence, index 0 cannot be used in the vocabulary
        (`input_dim` should equal size of vocabulary + 1). See the masking
        example below.
      input_length: Length of input sequences, when it is constant.
        This argument is required if you are going to connect
        `Flatten` then `Dense` layers downstream
        (without it, the shape of the dense outputs cannot be computed).
      sparse: If True, calling this layer returns a `tf.SparseTensor`. If
        False, the layer returns a dense `tf.Tensor`. For an entry with no
        features in a sparse tensor (entry with value 0), the embedding
        vector of index 0 is returned by default. See the sparse example
        below.

    Input shape:
      2D tensor with shape: `(batch_size, input_length)`.

    Output shape:
      3D tensor with shape: `(batch_size, input_length, output_dim)`.
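
    With `mask_zero=True`, index 0 is treated as padding, and the layer
    computes a boolean mask that downstream layers can use to skip the
    padded positions (the mask is simply `inputs != 0`):

    >>> layer = tf.keras.layers.Embedding(1000, 64, mask_zero=True)
    >>> layer.compute_mask(tf.constant([[5, 7, 0, 0]])).numpy()
    array([[ True,  True, False, False]])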

    **Note on variable placement:**
    By default, if a GPU is available, the embedding matrix will be placed on
    the GPU. This achieves the best performance, but it might cause issues:

    - You may be using an optimizer that does not support sparse GPU kernels.
      In this case you will see an error upon training your model.
    - Your embedding matrix may be too large to fit on your GPU. In this case
      you will see an Out Of Memory (OOM) error.

    In such cases, you should place the embedding matrix in CPU memory.
    You can do so with a device scope, as such:

    ```python
    with tf.device('cpu:0'):
        embedding_layer = Embedding(...)
        embedding_layer.build()
    ```

    The pre-built `embedding_layer` instance can then be added to a
    `Sequential` model (e.g. `model.add(embedding_layer)`), called in a
    Functional model (e.g. `x = embedding_layer(x)`), or used in a subclassed
    model.
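
    When `sparse=True`, calling the layer on dense integer input returns the
    embeddings as a `tf.SparseTensor` rather than a dense `tf.Tensor`:

    >>> layer = tf.keras.layers.Embedding(1000, 64, sparse=True)
    >>> out = layer(tf.constant([[5, 7, 2]]))
    >>> isinstance(out, tf.SparseTensor)
    True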
    """

    @utils.allow_initializer_layout
    def __init__(
        self,
        input_dim,
        output_dim,
        embeddings_initializer="uniform",
        embeddings_regularizer=None,
        activity_regularizer=None,
        embeddings_constraint=None,
        mask_zero=False,
        input_length=None,
        sparse=False,
        **kwargs,
    ):
        if "input_shape" not in kwargs:
            if input_length:
                kwargs["input_shape"] = (input_length,)
            else:
                kwargs["input_shape"] = (None,)
        if input_dim <= 0 or output_dim <= 0:
            raise ValueError(
                "Both `input_dim` and `output_dim` should be positive. "
                f"Received input_dim = {input_dim} "
                f"and output_dim = {output_dim}"
            )
        if (
            not base_layer_utils.v2_dtype_behavior_enabled()
            and "dtype" not in kwargs
        ):
            # In TF1, the dtype defaults to the input dtype, which is
            # typically int32, so explicitly set it to floatx.
            kwargs["dtype"] = backend.floatx()
        # We set autocast to False, as we do not want to cast floating-point
        # inputs to self.dtype. In call(), we cast to int32, and casting to
        # self.dtype before casting to int32 might cause the int32 values to
        # be different due to a loss of precision.
        kwargs["autocast"] = False
        use_one_hot_matmul = kwargs.pop("use_one_hot_matmul", False)
        super().__init__(**kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embeddings_initializer = initializers.get(embeddings_initializer)
        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.embeddings_constraint = constraints.get(embeddings_constraint)
        self.mask_zero = mask_zero
        self.supports_masking = mask_zero
        self.input_length = input_length
        self.sparse = sparse
        if self.sparse and self.mask_zero:
            raise ValueError(
                "`mask_zero` cannot be enabled when "
                "`tf.keras.layers.Embedding` is used with `tf.SparseTensor` "
                "input."
            )
        # Make this flag private and do not serialize it for now.
        # It will be part of the public API after further testing.
        self._use_one_hot_matmul = use_one_hot_matmul

    @tf_utils.shape_type_conversion
    def build(self, input_shape=None):
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            initializer=self.embeddings_initializer,
            name="embeddings",
            regularizer=self.embeddings_regularizer,
            constraint=self.embeddings_constraint,
            experimental_autocast=False,
        )
        self.built = True

    def compute_mask(self, inputs, mask=None):
        if not self.mask_zero:
            return None
        return tf.not_equal(inputs, 0)

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        if self.input_length is None:
            return input_shape + (self.output_dim,)
        else:
            # input_length can be a tuple if the input is 3D or higher.
            if isinstance(self.input_length, (list, tuple)):
                in_lens = list(self.input_length)
            else:
                in_lens = [self.input_length]
            if len(in_lens) != len(input_shape) - 1:
                raise ValueError(
                    f'"input_length" is {self.input_length}, but received '
                    f"input has shape {input_shape}"
                )
            else:
                for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
                    if s1 is not None and s2 is not None and s1 != s2:
                        raise ValueError(
                            f'"input_length" is {self.input_length}, but '
                            f"received input has shape {input_shape}"
                        )
                    elif s1 is None:
                        in_lens[i] = s2
            return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)

    def call(self, inputs):
        dtype = backend.dtype(inputs)
        if dtype != "int32" and dtype != "int64":
            inputs = tf.cast(inputs, "int32")
        if isinstance(inputs, tf.sparse.SparseTensor):
            if self.sparse:
                # Gather the embedding rows for the non-zero input entries
                # and flatten them into the values of the output
                # SparseTensor.
                embedding_values = tf.nn.embedding_lookup(
                    params=self.embeddings, ids=inputs.values
                )
                embedding_values = tf.reshape(embedding_values, [-1])
                # Build the indices of the output SparseTensor: each input
                # index is repeated output_dim times and paired with its
                # position along the new trailing embedding axis.
                indices_values_embed_axis = tf.range(self.output_dim)
                repeat_times = [inputs.indices.shape[0]]
                indices_values_embed_axis = tf.expand_dims(
                    tf.tile(indices_values_embed_axis, repeat_times), -1
                )
                indices_values_embed_axis = tf.cast(
                    indices_values_embed_axis, dtype=tf.int64
                )
                current_indices = tf.repeat(
                    inputs.indices, [self.output_dim], axis=0
                )
                new_indices = tf.concat(
                    [current_indices, indices_values_embed_axis], 1
                )
                new_shape = tf.concat(
                    [tf.cast(inputs.shape, dtype=tf.int64), [self.output_dim]],
                    axis=-1,
                )
                out = tf.SparseTensor(
                    indices=new_indices,
                    values=embedding_values,
                    dense_shape=new_shape,
                )
            else:
                sparse_inputs_expanded = tf.sparse.expand_dims(inputs, axis=-1)
                out = tf.nn.safe_embedding_lookup_sparse(
                    embedding_weights=self.embeddings,
                    sparse_ids=sparse_inputs_expanded,
                    default_id=0,
                )
        elif self._use_one_hot_matmul:
            # Note that we change the dtype of the one_hot to be the same as
            # the weight tensor, since the input data are usually ints, and
            # weights are floats. tf.nn.embedding_lookup supports int ids,
            # but the one-hot matmul needs both inputs and weights to share
            # a dtype.
            one_hot_data = tf.one_hot(
                inputs, depth=self.input_dim, dtype=self.dtype
            )
            out = tf.matmul(one_hot_data, self.embeddings)
        else:
            out = tf.nn.embedding_lookup(self.embeddings, inputs)

        if self.sparse and not isinstance(out, tf.SparseTensor):
            out = tf.sparse.from_dense(out)

        if (
            self._dtype_policy.compute_dtype
            != self._dtype_policy.variable_dtype
        ):
            # Instead of casting the variable as in most layers, cast the
            # output, as this is mathematically equivalent but is faster.
            out = tf.cast(out, self._dtype_policy.compute_dtype)
        return out

    def get_config(self):
        config = {
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "embeddings_initializer": initializers.serialize(
                self.embeddings_initializer
            ),
            "embeddings_regularizer": regularizers.serialize(
                self.embeddings_regularizer
            ),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "embeddings_constraint": constraints.serialize(
                self.embeddings_constraint
            ),
            "mask_zero": self.mask_zero,
            "input_length": self.input_length,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))