# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lookup operations."""

from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


def _check_table_initializer_element_spec(element_spec):
  """Raises an error if the given table initializer element spec is invalid."""
  base_error = ("Datasets used to initialize lookup tables must "
                "produce elements in the form (key, value), where "
                "the keys and values are scalar tensors. ")
  if len(element_spec) != 2:
    raise ValueError(base_error + "However, the given dataset produces "
                     f"{len(element_spec)} components instead of two "
                     "(key, value) components. Full dataset element spec: "
                     f"{element_spec}.")
  if not isinstance(element_spec[0], tensor_spec.TensorSpec):
    raise ValueError(base_error + "However, the given dataset produces "
                     f"non-Tensor keys of type {type(element_spec[0])}.")
  if not isinstance(element_spec[1], tensor_spec.TensorSpec):
    raise ValueError(base_error + "However, the given dataset produces "
                     f"non-Tensor values of type {type(element_spec[1])}.")
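  # A rank of `None` means the shape is statically unknown, which is accepted;
  # only keys and values with a known, non-zero rank are rejected.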
  if element_spec[0].shape.rank not in (None, 0):
    raise ValueError(
        base_error + "However, the given dataset produces "
        f"non-scalar key Tensors of rank {element_spec[0].shape.rank}.")
  if element_spec[1].shape.rank not in (None, 0):
    raise ValueError(
        base_error + "However, the given dataset produces "
        f"non-scalar value Tensors of rank {element_spec[1].shape.rank}.")


@tf_export("data.experimental.DatasetInitializer")
class DatasetInitializer(lookup_ops.TableInitializerBase):
  """Creates a table initializer from a `tf.data.Dataset`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> init = tf.data.experimental.DatasetInitializer(ds)
  >>> table = tf.lookup.StaticHashTable(init, "")
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Attributes:
    dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
      first scalar is treated as a key and the second as value.

  Raises:
    ValueError: If `dataset` doesn't conform to specifications.
  """

  def __init__(self, dataset):
    """Creates a table initializer from a `tf.data.Dataset`.

    Args:
      dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
        first scalar is treated as a key and the second as value.

    Raises:
      ValueError: If `dataset` doesn't conform to specifications.
    """
    # Assert that the dataset element spec is a tuple of TensorSpecs where
    # each tensor is a scalar.
    self.dataset = dataset
    elem_spec = self.dataset.element_spec
    _check_table_initializer_element_spec(elem_spec)

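    # Infer the key and value dtypes from the dataset's element spec.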
    key_type = elem_spec[0].dtype
    value_type = elem_spec[1].dtype
    super(DatasetInitializer, self).__init__(key_type, value_type)

  def initialize(self, table):
    lookup_ops.check_table_dtypes(table, self._key_dtype, self._value_dtype)
    init_op = ged_ops.initialize_table_from_dataset(
        table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
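    # Register the op in the TABLE_INITIALIZERS collection so that graph-mode
    # initialization (e.g. tf.compat.v1.tables_initializer()) will run it.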
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op


@tf_export("data.experimental.table_from_dataset")
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table from the given dataset of
  (key, value) pairs.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...     ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs
      * The 2nd item in the `dataset` pairs has a dtype which is incompatible
        with `default_value`
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
    TypeError: If `key_dtype` is not an integer or string type.
  """
  elem_spec = dataset.element_spec
  _check_table_initializer_element_spec(elem_spec)
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("`default_value` must be specified when creating a "
                       "table from a dataset that produces values of type "
                       f"{elem_spec[1].dtype}.")
  if num_oov_buckets < 0:
    raise ValueError("`num_oov_buckets` must be greater than or equal to 0, "
                     f"got {num_oov_buckets}.")
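  # A `vocab_size` passed in as a Tensor cannot be validated here; only static
  # Python values are range-checked.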
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.")
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("`key_dtype` must be either an integer or string type, "
                    f"but got {key_dtype}")
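  # When `vocab_size` is known, truncate the dataset to that many elements and
  # assert its cardinality, so initialization fails if the dataset yields a
  # different number of (key, value) pairs.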
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
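      # If OOV buckets were requested, wrap the static table so that keys not
      # present in the dataset are hashed into one of `num_oov_buckets` extra
      # buckets instead of being mapped to `default_value`.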
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table


@tf_export("data.experimental.index_table_from_dataset")
|
||
|
def index_table_from_dataset(dataset=None,
|
||
|
num_oov_buckets=0,
|
||
|
vocab_size=None,
|
||
|
default_value=-1,
|
||
|
hasher_spec=lookup_ops.FastHashSpec,
|
||
|
key_dtype=dtypes.string,
|
||
|
name=None):
|
||
|
"""Returns an index lookup table based on the given dataset.
|
||
|
|
||
|
This operation constructs a lookup table based on the given dataset of keys.
|
||
|
|
||
|
Any lookup of an out-of-vocabulary token will return a bucket ID based on its
|
||
|
hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
|
||
|
`default_value`.
|
||
|
The bucket ID range is
|
||
|
`[vocabulary size, vocabulary size + num_oov_buckets - 1]`.
|
||
|
|
||
|
Sample Usages:
|
||
|
|
||
|
  >>> ds = tf.data.Dataset.range(100).map(lambda x: tf.strings.as_string(x * 2))
  >>> table = tf.data.experimental.index_table_from_dataset(
  ...     ds, key_dtype=tf.int64)
  >>> table.lookup(tf.constant(['0', '2', '4'], dtype=tf.string)).numpy()
  array([0, 1, 2])

  Args:
    dataset: A dataset of keys.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
    TypeError: If `key_dtype` is not an integer or string type.
  """
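  # `enumerate()` yields (index, key) pairs; swapping them to (key, index)
  # produces the (key, value) dataset that `table_from_dataset` expects, so
  # each key is mapped to its position in the input dataset.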
  return table_from_dataset(dataset.enumerate().map(lambda v, k: (k, v)),
                            num_oov_buckets, vocab_size, default_value,
                            hasher_spec, key_dtype, name)