# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator dispatch for RaggedTensors."""

from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect


@dispatch.dispatch_for_unary_elementwise_apis(ragged_tensor.Ragged)
def ragged_unary_elementwise_op(op, x):
  """Unary elementwise api handler for RaggedTensors."""
  x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
  return x.with_values(op(x.values))
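# A minimal illustration of what the handler above does (not executed here;
# it assumes elementwise dispatch is registered as in this module): a unary
# op such as tf.abs is applied to the RaggedTensor's values, and the row
# partitions are preserved unchanged, e.g.
#
#   rt = tf.ragged.constant([[-1.0, 2.0], [-3.0]])
#   tf.abs(rt)  # -> <tf.RaggedTensor [[1.0, 2.0], [3.0]]>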


# TODO(martinz): This is deprecated. Delete.
def ragged_binary_elementwise_op(op, x, y):
  """Binary elementwise api handler for RaggedTensors."""
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)

  # Convert args to tensors.
  x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      x, preferred_dtype=(y.dtype if y_is_ragged else None))
  y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      y, preferred_dtype=x.dtype)

  if x_is_ragged and y_is_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  # Perform broadcasting, when appropriate.
  if ((x_is_ragged and y_is_ragged) or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
    # If both x and y are ragged, they must have the same row_splits_dtype now.
    if x_is_ragged:
      dim_size_dtype = x.row_splits.dtype
    else:
      dim_size_dtype = y.row_splits.dtype

    shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        x, dim_size_dtype=dim_size_dtype)
    shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        y, dim_size_dtype=dim_size_dtype)
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(shape_x, shape_y)
    x = ragged_tensor_shape.broadcast_to(
        x, bcast_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, bcast_shape, broadcast_inner_dimensions=False)

  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = op(x_values, y_values)
  if isinstance(mapped_values, bool):
    return mapped_values  # Special case for tensor_equals.
  if ragged_tensor.is_ragged(x):
    return x.with_flat_values(mapped_values)
  else:
    return y.with_flat_values(mapped_values)
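# Illustrative sketch of the broadcasting behavior handled above (not
# executed here; results assume the standard ragged/dense broadcasting
# rules): a dense scalar or a compatible dense tensor is broadcast against
# the ragged operand before the elementwise op is applied, e.g.
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   tf.add(rt, 10)   # -> <tf.RaggedTensor [[11, 12], [13]]>
#   tf.add(rt, rt)   # -> <tf.RaggedTensor [[2, 4], [6]]>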


# TODO(edloper): Update the documentation generation tools to automatically
# build lists of which types are supported by which ops (and then delete all
# the following code).


# We don't need to register a separate delegation handler for these v1 ops,
# since they delegate to the v2 ops (which already have a handler). But we
# still want to include them in the ragged_op_list() output.
_V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS = [
    math_ops.reduce_sum,
    math_ops.reduce_prod,
    math_ops.reduce_min,
    math_ops.reduce_max,
    math_ops.reduce_mean,
    math_ops.reduce_variance,
    math_ops.reduce_std,
    math_ops.reduce_any,
    math_ops.reduce_all,
    string_ops.string_to_number,
    string_ops.string_to_hash_bucket,
    string_ops.reduce_join_v2,
]
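# For example (illustrative): the v1 tf.reduce_sum delegates to the v2
# tf.math.reduce_sum, whose ragged dispatcher is registered elsewhere, so
# listing it here only makes it appear in the ragged_op_list() output:
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   tf.reduce_sum(rt, axis=1)   # -> [3, 3]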


def _ragged_op_signature(op, ragged_args, ragged_varargs=False):
  """Returns a signature for the given op, marking ragged args in bold."""
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  arg_names = argspec.args

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults.
  if argspec.defaults is not None:
    for pos in range(-1, -len(argspec.defaults) - 1, -1):
      arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args.
  if argspec.varargs:
    if ragged_varargs:
      arg_names.append('***' + argspec.varargs + '**')
    else:
      arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
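# Illustrative output (the exact argument list depends on the op's argspec):
#
#   _ragged_op_signature(math_ops.reduce_sum, [0])
#   # -> "* `tf.math.reduce_sum`(**input_tensor**, axis=`None`, "
#   #    "keepdims=`False`, name=`None`)"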


def _op_is_in_tf_version(op, version):
  if version == 1:
    return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or
            op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)
  elif version == 2:
    return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
  else:
    raise ValueError('Expected version 1 or 2.')
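# For example (illustrative): math_ops.reduce_sum carries only v2 export
# names, so for version=1 it is picked up via the delegation list above
# rather than via tf_export.get_v1_names().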


def ragged_op_list(tf_version=2):
  """Returns a string listing operations that have dispatchers registered."""
  lines = []
  api_signatures = dispatch.type_based_dispatch_signatures_for(
      ragged_tensor.RaggedTensor)
  for api, signatures in api_signatures.items():
    arg_names = tf_inspect.getargspec(api).args
    ragged_args = set()
    for signature in signatures:
      for arg in signature:
        ragged_args.add(arg if isinstance(arg, int) else arg_names.index(arg))
    if _op_is_in_tf_version(api, tf_version):
      lines.append(_ragged_op_signature(api, ragged_args))

  lines.append(
      _ragged_op_signature(logging_ops.print_v2, [], ragged_varargs=True))
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')
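# Typical use (illustrative; intended for the documentation generation noted
# in the TODO above): printing the result yields one Markdown bullet per op,
# e.g.
#
#   print(ragged_op_list())
#   # ### Additional ops that support `RaggedTensor`
#   # Arguments that accept `RaggedTensor`s are marked in **bold**.
#   # * `tf.math.abs`(**x**, name=`None`)
#   # ...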
|