# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Value for RaggedTensor."""

import numpy as np

from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["ragged.RaggedTensorValue"])
@dispatch.register_dispatchable_type
class RaggedTensorValue:
  """Represents the value of a `RaggedTensor`.

  Warning: `RaggedTensorValue` should only be used in graph mode; in
  eager mode, the `tf.RaggedTensor` class contains its value directly.

  See `tf.RaggedTensor` for a description of ragged tensors.
  """

  def __init__(self, values, row_splits):
    """Creates a `RaggedTensorValue`.

    Args:
      values: A numpy array of any type and shape; or a RaggedTensorValue.
      row_splits: A 1-D int32 or int64 numpy array.

    Raises:
      TypeError: If `row_splits` is not a 1-D int32 or int64 numpy array, or
        if `values` is neither a numpy array nor a `RaggedTensorValue`.
      ValueError: If `values` is a `RaggedTensorValue` whose `row_splits`
        dtype differs from the dtype of `row_splits`.
    """
    if not (isinstance(row_splits, (np.ndarray, np.generic)) and
            row_splits.dtype in (np.int64, np.int32) and
            row_splits.ndim == 1):
      raise TypeError("row_splits must be a 1D int32 or int64 numpy array")
    if not isinstance(values, (np.ndarray, np.generic, RaggedTensorValue)):
      raise TypeError("values must be a numpy array or a RaggedTensorValue")
    # Every nesting level must use the same integer dtype for its splits.
    if (isinstance(values, RaggedTensorValue) and
        row_splits.dtype != values.row_splits.dtype):
      raise ValueError("row_splits and values.row_splits must have "
                       "the same dtype")
    self._values = values
    self._row_splits = row_splits

  row_splits = property(
      lambda self: self._row_splits,
      doc="""The split indices for the ragged tensor value.""")
  values = property(
      lambda self: self._values,
      doc="""The concatenated values for all rows in this tensor.""")
  dtype = property(
      lambda self: self._values.dtype,
      doc="""The numpy dtype of values in this tensor.""")

  @property
  def flat_values(self):
    """The innermost `values` array for this ragged tensor value."""
    rt_values = self.values
    # Descend through nested ragged values until a non-ragged value remains.
    while isinstance(rt_values, RaggedTensorValue):
      rt_values = rt_values.values
    return rt_values

  @property
  def nested_row_splits(self):
    """The row_splits for all ragged dimensions in this ragged tensor value.

    Returns:
      A tuple of row_splits arrays, ordered from the outermost ragged
      dimension to the innermost.
    """
    rt_nested_splits = [self.row_splits]
    rt_values = self.values
    while isinstance(rt_values, RaggedTensorValue):
      rt_nested_splits.append(rt_values.row_splits)
      rt_values = rt_values.values
    return tuple(rt_nested_splits)

  @property
  def ragged_rank(self):
    """The number of ragged dimensions in this ragged tensor value."""
    if isinstance(self._values, RaggedTensorValue):
      return self._values.ragged_rank + 1
    return 1

  @property
  def shape(self):
    """A tuple indicating the shape of this RaggedTensorValue.

    The first entry is the number of rows (`len(row_splits) - 1`), the second
    is `None` for the ragged dimension, and the remaining entries are taken
    from `values.shape[1:]`.
    """
    return (self._row_splits.shape[0] - 1,) + (None,) + self._values.shape[1:]

  @property
  def _nested_row_partitions(self):
    """The row_partitions representing this shape."""
    return [RowPartition.from_row_splits(rs) for rs in self.nested_row_splits]

  def __str__(self):
    """Returns a readable string showing the nested-list form of the value."""
    # The format string needs a conversion specifier; an empty format string
    # silently produced "" for every instance. The argument is wrapped in a
    # 1-tuple so the list is always formatted as a single value.
    return "<tf.RaggedTensorValue %s>" % (self.to_list(),)

  def __repr__(self):
    """Returns a string showing the `values` and `row_splits` components."""
    return "tf.RaggedTensorValue(values=%r, row_splits=%r)" % (
        self._values, self._row_splits)

  def to_list(self):
    """Returns this ragged tensor value as a nested Python list."""
    if isinstance(self._values, RaggedTensorValue):
      values_as_list = self._values.to_list()
    else:
      values_as_list = self._values.tolist()
    # Each adjacent pair of row_splits delimits one row of `values`.
    return [
        values_as_list[start:limit]
        for start, limit in zip(self._row_splits[:-1], self._row_splits[1:])
    ]