# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""
from tensorflow.python.compat import compat
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.types import core
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `gather`.
    ids: The `ids` argument that was passed to `gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` with the same type as `params`. If `max_norm` is `None`,
    `params` is returned unchanged.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    # `ndims` is None only when the rank is statically unknown. A static rank
    # of 0 (scalar ids) is still a usable Python int, so compare against None
    # instead of relying on truthiness (which would discard rank 0).
    if rank is not None:
      return rank, True
    return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  # `gather` yields shape `shape(ids) + shape(embedding)`, so the embedding
  # axes are everything from `ids_rank` onwards. Use a Python list of axes when
  # both ranks are static, otherwise build the axes as a tensor.
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank))
            if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))


def _colocate_with(param):
  """Returns a colocation context manager for `param`.

  Inside a function graph, objects exposing a `handle` attribute (e.g. resource
  variables) are colocated via their captured handle rather than via the
  object itself; otherwise `ops.colocate_with(param)` is used directly.
  """
  if ops.inside_function() and hasattr(param, "handle"):
    # The `ops.colocate_with` will hard-code a device string if `param.device`
    # is known, which will then break serving. We capture it here so that it
    # produces a tensor without a device.
    return ops.colocate_with(ops.get_default_graph().capture(param.handle))
  else:
    return ops.colocate_with(param)


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty, or if `partition_strategy` is neither
      `"mod"` nor `"div"` when more than one partition is involved.
  """
  if params is None:
    raise ValueError("params must be specified")
  if isinstance(params, (list, tuple)) and not params:
    raise ValueError("Length of params is currently 0. "
                     "Need at least one param.")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  # NOTE(review): a non-empty tuple is wrapped as a single param here rather
  # than treated as a list of shards — confirm this is intended.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.BaseResourceVariable)
        for p in params):
      params = indexed_slices.convert_n_to_tensor_or_indexed_slices(
          params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with _colocate_with(params[0]):
        result = _clip(
            array_ops.gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with _colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in range(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in range(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with _colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops_stack.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        # The first `extras` partitions hold `ids_per_partition + 1` ids each,
        # the rest hold `ids_per_partition`; the max picks the right formula.
        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError(
            f"Unrecognized partition strategy: {partition_strategy}. "
            "Must be one of either `mod` or `div`.")

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in range(np):
        pids = gather_ids[p]
        with ops.device_v2(None):
          with _colocate_with(params[p]):
            result = array_ops.gather(params[p], pids)
            if transform_fn:
              # If transform_fn is provided, the clip_by_norm precedes
              # the transform and hence must be co-located. See below
              # for the counterpart if transform_fn is not provided.
              result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with _colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`.  It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor.  `params` may be
  a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
  with a partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  If the input ids are ragged tensors, partition variables are not supported
  and the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  **Behavior Difference between CPU and GPU**

  Please note that when using `tf.nn.embedding_lookup` on a GPU, if an
  out-of-bound index is encountered, a value of 0 will be stored in the
  corresponding output value. On the other hand, when using
  `tf.nn.embedding_lookup` on a CPU, an error will be returned if an
  out-of-bound index is found.

  This behavior difference can impact the results of your computation,
  especially when dealing with indices that may go beyond the bounds of the
  tensor. Make sure to be mindful of this distinction when using the
  `tf.nn.embedding_lookup` function in your computations.

  **Usage Example**

  Here's an example demonstrating how to use `tf.nn.embedding_lookup`:

  ```python
  import tensorflow as tf

  # Example embedding matrix and indices
  embedding_matrix = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
  indices = tf.constant([1, 0, 2])

  # Perform embedding lookup
  embeddings = tf.nn.embedding_lookup(embedding_matrix, indices)

  # Print the result
  print("Embeddings:")
  print(embeddings.numpy())
  ```

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` or a `RaggedTensor` with type `int32` or `int64` containing
      the ids to be looked up in `params`.
      Caution: Out-of-bounds indices will result in undefined behavior, which
        will differ between devices and backends.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range.  If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` or a `RaggedTensor`, depending on the input, with the same type
    as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)


@tf_export("nn.embedding_lookup", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_v2(params, ids, max_norm=None, name=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`.  It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means we
  assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition strategy.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

    For instance, if `params` is a 5x2 matrix:

    ```python
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    ```

    or a list of matrices:

    ```python
    params[0]: [[1, 2], [3, 4]]
    params[1]: [[5, 6], [7, 8]]
    params[2]: [[9, 10]]
    ```

    and `ids` is:

    ```python
    [0, 3, 4]
    ```

    The output will be a 3x2 matrix:

    ```python
    [[1, 2], [7, 8], [9, 10]]
    ```

  Raises:
    ValueError: If `params` is empty.
  """
  return embedding_lookup(params, ids, "div", name, max_norm=max_norm)


@tf_export(v1=["nn.embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def embedding_lookup_sparse(
    params,
    sp_ids,
    sp_weights,
    partition_strategy="mod",
    name=None,
    combiner=None,
    max_norm=None,
    allow_fast_lookup=False,
):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s or `RaggedTensor`s
  with rank of 2. For `SparseTensor`s with left-aligned non-zero entries which
  can be described as `RaggedTensor`s, use of `RaggedTensor`s can yield higher
  performance.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary or a `RaggedTensor` with rank 2.
    sp_weights: `SparseTensor` or `RaggedTensor` of same type and shape as
      `sp_ids`, containing float / double weights corresponding to `sp_ids`,
      or `None` if all weights are assumed to be 1.0.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    allow_fast_lookup: An optional boolean specifying whether to allow
      simplified embedding lookups when `params` is a single tensor and
      `max_norm` is `None`. Setting this flag to `True` during training can
      cause the use of dense gradients with increased memory footprint.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor` or `RaggedTensor`, or if
    `sp_weights` is neither `None` nor of the same type as `sp_ids`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError(
        f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  # NOTE(review): only SparseTensor is accepted here; RaggedTensor inputs are
  # presumably routed elsewhere via `dispatch.add_dispatch_support` — confirm.
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError(f"sp_ids must be SparseTensor, got {type(sp_ids)}")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError(f"sp_weights must be either None or SparseTensor, "
                      f"got {type(sp_weights)}")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    # Row index of each sparse entry: entries sharing a row are combined.
    segment_ids = sp_ids.indices[:, 0]
    ids = sp_ids.values

    return embedding_lookup_sparse_impl(
        params,
        segment_ids,
        sp_weights,
        ids,
        combiner,
        ignore_weights,
        max_norm,
        allow_fast_lookup,
        partition_strategy,
        name,
    )


@tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(
    params,
    sp_ids,
    sp_weights,
    combiner=None,
    max_norm=None,
    name=None,
    allow_fast_lookup=False,
):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  `params` is a dense tensor or a list of dense tensors, and `sp_ids` is a 2D
  `tf.SparseTensor` or `tf.RaggedTensor` indicating the indices of `params` to
  gather.

  This op is best described with an example. Suppose `params` is an embedding
  table of size `(4, 2)` and `sp_ids` has 3 rows. Since `sp_ids` is sparse or
  ragged, not every row has the same number of elements. The output has shape
  (3, 2). Each row of `sp_ids` is a list of indices, where each index selects a
  row of `params`. For a given row of `sp_ids`, the rows of `params` are
  gathered based on the indices in `sp_ids`, then combined by taking their sum
  or mean.

  >>> params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=tf.float32)
  >>> sp_ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [2, 0]],
  ...                          values=[0, 1, 3, 2], dense_shape=(3, 2))
  >>> tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights=None,
  ...                               combiner='sum').numpy()
  array([[4., 6.], [7., 8.], [5., 6.]], dtype=float32)

  In this example, `sp_ids` has 3 rows, so the output has 3 rows. Row 0 of
  `sp_ids` has values 0 and 1, so it selects rows 0 and 1 from `params`, which
  are `[1, 2]` and `[3, 4]`. The rows are summed since `combiner='sum'`,
  resulting in the output row of `[4, 6]`.

  Since row 1 and 2 of `sp_ids` only have one value each, they simply select the
  corresponding row from `params` as the output row. Row 1 has value `3` so
  it selects the `params` elements `[7, 8]` and row 2 has the value 2 so it
  selects the `params` elements `[5, 6]`.

  If `sp_weights` is specified, it must have the same shape as `sp_ids`.
  `sp_weights` is used to assign a weight to each slice of `params`. For
  example:

  >>> params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=tf.float32)
  >>> sp_ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [2, 0]],
  ...                          values=[0, 1, 3, 2], dense_shape=(3, 2))
  >>> sparse_weights = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [2, 0]],
  ...                                  values=[0.1, 1.0, 0.5, 2.0],
  ...                                  dense_shape=(3, 2))
  >>> tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights=sparse_weights,
  ...                               combiner='sum').numpy()
  array([[3.1, 4.2], [3.5, 4.], [10., 12.]], dtype=float32)

  In general, `params` can have shape `(p0, ..., pn)` and `sp_ids` can have `M`
  rows, where each row can have any number of elements. The output has shape
  `(M, p1, ..., pn)`. Each slice of the output `output[i, ...]` is obtained as
  follows: The `combiner` argument is used to combine the values
  `params[sp_ids[i, j], ...] * sp_weights[i, j]` for each `j` in `range(0,
  len(sp_ids[i]))`, e.g. by taking the sum or mean of the values.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s or `RaggedTensor`s
  with rank of 2. For `SparseTensor`s with left-aligned non-zero entries which
  can be described as `RaggedTensor`s, use of `RaggedTensor`s can yield higher
  performance.

  This op assumes that all id values lie in the range [0, p0), where p0
  is `params.shape[0]`. If you want a version of this op that prunes id values
  less than 0, see `tf.nn.safe_embedding_lookup_sparse`

  If `len(params) > 1`, each element of `sp_ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means we
  assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition
      strategy.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary or a `RaggedTensor` with rank 2.
    sp_weights: `SparseTensor` or `RaggedTensor` of same type and shape as
      `sp_ids`, containing float / double weights corresponding to `sp_ids`,
      or `None` if all weights are assumed to be 1.0.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    name: Optional name for the op.
    allow_fast_lookup: An optional boolean specifying whether to allow
      simplified embedding lookups when `params` is a single tensor and
      `max_norm` is `None`. Setting this flag to `True` during training can
      cause the use of dense gradients with increased memory footprint.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  return embedding_lookup_sparse(
      params,
      sp_ids,
      sp_weights,
      "div",
      name,
      combiner,
      max_norm,
      allow_fast_lookup,
  )


@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse_v2(
    embedding_weights,
    sparse_ids,
    sparse_weights=None,
    combiner="mean",
    default_id=None,
    max_norm=None,
    name=None,
    allow_fast_lookup=False,
):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of num of shards.

  This is similar to `tf.nn.embedding_lookup_sparse`, except invalid IDs (< 0)
  are pruned from input IDs and weights, as well as any IDs with non-positive
  weight. For an entry with no features, the embedding vector for `default_id`
  is returned, or the 0-vector if `default_id` is not supplied. See
  `tf.nn.embedding_lookup_sparse` for more information on how sparse embedding
  lookups work in general.

  The ids and weights may be multi-dimensional `SparseTensor`s or
  `RaggedTensor`s with rank of 2. For `SparseTensor`s with left-aligned non-zero
  entries which can be described as `RaggedTensor`s, use of `RaggedTensor`s can
  yield higher performance.

  If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned
  between the elements of `embedding_weights` according to the "div" partition
  strategy, which means we assign ids to partitions in a contiguous manner. For
  instance, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(embedding_weights)` partitions will be assigned one
  more id.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors following "div"
      partition strategy.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids, where `d_0` is typically batch size, or a `RaggedTensor` with rank 2.
    sparse_weights: `SparseTensor` or `RaggedTensor` of same type and shape as
      `sparse_ids`, containing float weights corresponding to `sparse_ids`, or
      `None` if all weights are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
      default.
    default_id: The id to use for an entry with no features. Defaults to
      0-vector.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining (as in `tf.nn.embedding_lookup_sparse`).
    name: A name for this operation (optional).
    allow_fast_lookup: An optional boolean specifying whether to allow
      simplified embedding lookups when `embedding_weights` is a single tensor
      and `max_norm` is `None`. Setting this flag to `True` during training can
      cause the use of dense gradients with increased memory footprint.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sparse_ids`,
    the op looks up the embeddings for all ids in that row, multiplies them by
    the corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id -1, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    `default_id` is 0.

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  return safe_embedding_lookup_sparse(
      embedding_weights,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy="div",
      max_norm=max_norm,
      allow_fast_lookup=allow_fast_lookup,
  )


@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(
    embedding_weights,
    sparse_ids,
    sparse_weights=None,
    combiner="mean",
    default_id=None,
    name=None,
    partition_strategy="div",
    max_norm=None,
    allow_fast_lookup=False,
):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of `P`.  `embedding_weights`
  may be a `PartitionedVariable` as returned by using
  `tf.compat.v1.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional `SparseTensor`s or
  `RaggedTensor`s with rank of 2. For `SpareTensor`s with left-aligned non-zero
  entries which can be described as `RaggedTensor`s, use of `RaggedTensor`s can
  yield higher performance. Embeddings are always aggregated along the last
  dimension.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list tensors all of same shape except for the first
      dimension, representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids, where `d_0` is typically batch size, or a `RaggedTensor` with rank 2.
    sparse_weights: `SparseTensor` or `RaggedTensor` of same type and shape as
      `sparse_ids`, containing float weights corresponding to `sparse_ids`, or
      `None` if all weights are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
      default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy. Currently
      `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.
    allow_fast_lookup: An optional boolean specifying whether to allow
      simplified embedding lookups when `params` is a single tensor and
      `max_norm` is `None`. Setting this flag to `True` during training can
      cause the use of dense gradients with increased memory footprint.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id -1, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    `default_id` is 0.

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if embedding_weights is None:
    raise ValueError(f"Missing embedding_weights {embedding_weights}.")
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError(f"Missing embedding_weights {embedding_weights}.")

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (resource_variable_ops.is_resource_variable(w)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, "embedding_lookup", embedding_weights +
                      [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None else original_rank_dim)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm,
        allow_fast_lookup=allow_fast_lookup,
    )

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops_stack.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids back into higher-dimensional dense result.
final_result = array_ops.reshape( result, array_ops.concat([ array_ops.slice( math_ops.cast(original_shape, dtypes.int32), [0], [original_rank - 1]), array_ops.slice(array_ops.shape(result), [1], [-1]) ], 0)) final_result.set_shape( tensor_shape.unknown_shape( (tensor_shape.Dimension(original_rank_dim) - 1).value ).concatenate(result.get_shape()[1:]) ) return final_result def embedding_lookup_sparse_impl( params, segment_ids, sp_weights, ids, combiner, ignore_weights, max_norm, allow_fast_lookup, partition_strategy, name, ): """Implementation of sparse embedding aggregation.""" need_sparse_segment_gradient = False # Ensure we can query the devices below. segment_ids = ops.convert_to_tensor(segment_ids, name="segment_ids") if len(params) == 1 and not isinstance( params[0], (core.Tensor, composite_tensor.CompositeTensor) ): params = [ops.convert_to_tensor(params[0], name="params")] # Note that if the params are on a different device (e.g., CPU), we must use # embedding_lookup() so that the gather operation is colocated with them. if ( len(params) == 1 and not isinstance(params[0], composite_tensor.CompositeTensor) and params[0].device == segment_ids.device and max_norm is None and ( allow_fast_lookup or (ignore_weights and compat.forward_compatible(2023, 9, 26)) ) ): idx = ids embeddings = params[0] if isinstance(embeddings, resource_variable_ops.BaseResourceVariable): # Avoid a redundant copy due to copy-on-read semantics for # sparsely-updated variables. 
embeddings = embeddings.read_value_no_copy() if not allow_fast_lookup: need_sparse_segment_gradient = True else: ids, idx = array_ops.unique(ids) embeddings = embedding_lookup( params, ids, partition_strategy=partition_strategy, max_norm=max_norm ) if not ignore_weights: if segment_ids.dtype != dtypes.int32: segment_ids = math_ops.cast(segment_ids, dtypes.int32) weights = sp_weights.values embeddings = array_ops.gather(embeddings, idx) original_dtype = embeddings.dtype if embeddings.dtype in (dtypes.float16, dtypes.bfloat16): # Cast low-precision embeddings to float32 during the computation to # avoid numerical issues. embeddings = math_ops.cast(embeddings, dtypes.float32) if weights.dtype != embeddings.dtype: weights = math_ops.cast(weights, embeddings.dtype) # Reshape weights to allow broadcast ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0) ones = array_ops.ones(ones_shape, dtype=dtypes.int32) bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0) orig_weights_shape = weights.get_shape() weights = array_ops.reshape(weights, bcast_weights_shape) # Set the weight shape, since after reshaping to bcast_weights_shape, # the shape becomes None. 
if embeddings.get_shape().ndims is not None: weights.set_shape( orig_weights_shape.concatenate( [1 for _ in range(embeddings.get_shape().ndims - 1)] ) ) embeddings *= weights if combiner == "sum": embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name) elif combiner == "mean": embeddings = math_ops.segment_sum(embeddings, segment_ids) weight_sum = math_ops.segment_sum(weights, segment_ids) embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name) elif combiner == "sqrtn": embeddings = math_ops.segment_sum(embeddings, segment_ids) weights_squared = math_ops.pow(weights, 2) weight_sum = math_ops.segment_sum(weights_squared, segment_ids) weight_sum_sqrt = math_ops.sqrt(weight_sum) embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt, name=name) else: assert False, "Unrecognized combiner" if embeddings.dtype != original_dtype: embeddings = math_ops.cast(embeddings, original_dtype) else: if segment_ids.dtype not in (dtypes.int32, dtypes.int64): segment_ids = math_ops.cast(segment_ids, dtypes.int32) assert idx is not None if combiner == "sum": embeddings = math_ops.sparse_segment_sum( embeddings, idx, segment_ids, name=name, sparse_gradient=need_sparse_segment_gradient, ) elif combiner == "mean": embeddings = math_ops.sparse_segment_mean( embeddings, idx, segment_ids, name=name, sparse_gradient=need_sparse_segment_gradient, ) elif combiner == "sqrtn": embeddings = math_ops.sparse_segment_sqrt_n( embeddings, idx, segment_ids, name=name, sparse_gradient=need_sparse_segment_gradient, ) else: assert False, "Unrecognized combiner" return embeddings def _prune_invalid_ids(sparse_ids, sparse_weights): """Prune invalid IDs (< 0) from the input ids and weights.""" is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) if sparse_weights is not None: is_id_valid = math_ops.logical_and( is_id_valid, array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) if sparse_weights is not 
None: sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) return sparse_ids, sparse_weights def _prune_invalid_weights(sparse_ids, sparse_weights): """Prune invalid weights (< 0) from the input ids and weights.""" if sparse_weights is not None: is_weights_valid = math_ops.greater(sparse_weights.values, 0) sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) return sparse_ids, sparse_weights