3RNN/Lib/site-packages/tensorflow/python/ops/check_ops.py

2367 lines
83 KiB
Python
Raw Normal View History

2024-05-26 19:49:15 +02:00
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""Asserts and Boolean Checks."""
import collections
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# Dtypes grouped as "numeric": floating point, signed/unsigned integer,
# quantized integer, complex, and bfloat16 types.
NUMERIC_TYPES = frozenset([
    dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16,
    dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32,
    dtypes.uint64, dtypes.qint8, dtypes.qint16, dtypes.qint32, dtypes.quint8,
    dtypes.quint16, dtypes.complex64, dtypes.complex128, dtypes.bfloat16
])
# Public names exported by this module (what `from ... import *` exposes).
__all__ = [
    'assert_negative',
    'assert_positive',
    'assert_proper_iterable',
    'assert_non_negative',
    'assert_non_positive',
    'assert_equal',
    'assert_none_equal',
    'assert_near',
    'assert_integer',
    'assert_less',
    'assert_less_equal',
    'assert_greater',
    'assert_greater_equal',
    'assert_rank',
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_type',
    'assert_shapes',
    'is_non_decreasing',
    'is_numeric_tensor',
    'is_strictly_increasing',
]
def _maybe_constant_value_string(t):
  """Returns a string describing `t` for use in a static assertion message.

  Non-tensors are stringified directly. For a `Tensor` whose value can be
  computed statically, the string form of that constant value is returned;
  otherwise the string form of the tensor itself is returned.

  Args:
    t: Any object, typically one element of an assert op's `data` list.

  Returns:
    A `str`.
  """
  if not isinstance(t, tensor_lib.Tensor):
    return str(t)
  const_t = tensor_util.constant_value(t)
  if const_t is not None:
    return str(const_t)
  # Callers join these results with '\n'.join(), which only accepts strings.
  # Returning the Tensor object itself would raise a TypeError there and hide
  # the actual assertion failure.
  return str(t)
def _assert_static(condition, data):
  """Raises an InvalidArgumentError when a statically evaluated check fails.

  Args:
    condition: A Python/NumPy boolean computed before graph execution.
    data: Sequence of message items. Each is rendered (using its constant
      value when one is available) and the results are joined line by line
      into the error message.

  Raises:
    InvalidArgumentError: if `condition` is falsy.
  """
  if condition:
    return
  rendered = list(map(_maybe_constant_value_string, data))
  raise errors.InvalidArgumentError(
      node_def=None, op=None, message='\n'.join(rendered))
def _shape_and_dtype_str(tensor):
  """Describes `tensor` as 'shape=<shape> dtype=<dtype name>'."""
  shape_part = 'shape=%s' % (tensor.shape,)
  dtype_part = 'dtype=%s' % (tensor.dtype.name,)
  return ' '.join((shape_part, dtype_part))
def _unary_assert_doc(sym, sym_name):
  """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.

  Args:
    sym: Mathematical symbol for the check performed on each element, e.g.
      "> 0".
    sym_name: English-language name for the op described by sym.

  Returns:
    Decorator that adds the appropriate docstring to the function for symbol
    `sym`.
  """

  def _decorator(func):
    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.

    Args:
      func: Function for a TensorFlow op.

    Returns:
      `func` itself, with `__doc__` replaced by the rendered template.
    """
    opname = func.__name__
    cap_sym_name = sym_name.capitalize()
    # Unary asserts take a single tensor, so the usage example passes only
    # `x` (previously it showed a non-existent second argument `y`).
    func.__doc__ = """
Assert the condition `x {sym}` holds element-wise.
When running in graph mode, you should add a dependency on this operation
to ensure that it runs. Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.debugging.{opname}(x)]):
output = tf.reduce_sum(x)
```
{sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
If `x` is empty this is trivially satisfied.
Args:
x: Numeric `Tensor`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "{opname}".
Returns:
Op that raises `InvalidArgumentError` if `x {sym}` is False.
@compatibility(eager)
returns None
@end_compatibility
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym}` is False. The check can be performed immediately during
eager execution or if `x` is statically known.
""".format(
        sym=sym, sym_name=cap_sym_name, opname=opname)
    return func

  return _decorator
def _binary_assert_doc(sym, test_var):
  """Builds a decorator that installs the shared v1 binary-assert docstring.

  Args:
    sym: Binary comparison symbol rendered into the docstring, e.g. "==".
    test_var: String spliced into the usage examples as the value fed for
      the right-hand-side operand `b`.

  Returns:
    A decorator that overwrites `func.__doc__` with the rendered template
    (using `func.__name__` as the op name) and returns `func` itself.
  """
  template = """
Assert the condition `x {sym} y` holds element-wise.
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
When running in graph mode, you should add a dependency on this operation
to ensure that it runs. Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
output = tf.reduce_sum(x)
```
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "{opname}".
Returns:
Op that raises `InvalidArgumentError` if `x {sym} y` is False.
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym} y` is False. The check can be performed immediately during
eager execution or if `x` and `y` are statically known.
@compatibility(TF2)
`tf.compat.v1.{opname}` is compatible with eager execution and
`tf.function`.
Please use `tf.debugging.{opname}` instead when migrating to TF2. Apart
from `data`, all arguments are supported with the same argument name.
If you want to ensure the assert statements run before the
potentially-invalid computation, please use `tf.control_dependencies`,
as tf.function auto-control dependencies are insufficient for assert
statements.
#### Structural Mapping to Native TF2
Before:
```python
tf.compat.v1.{opname}(
x=x, y=y, data=data, summarize=summarize,
message=message, name=name)
```
After:
```python
tf.debugging.{opname}(
x=x, y=y, message=message,
summarize=summarize, name=name)
```
#### TF1 & TF2 Usage Example
TF1:
>>> g = tf.Graph()
>>> with g.as_default():
... a = tf.compat.v1.placeholder(tf.float32, [2])
... b = tf.compat.v1.placeholder(tf.float32, [2])
... result = tf.compat.v1.{opname}(a, b,
... message='"a {sym} b" does not hold for the given inputs')
... with tf.compat.v1.control_dependencies([result]):
... sum_node = a + b
>>> sess = tf.compat.v1.Session(graph=g)
>>> val = sess.run(sum_node, feed_dict={{a: [1, 2], b:{test_var}}})
TF2:
>>> a = tf.Variable([1, 2], dtype=tf.float32)
>>> b = tf.Variable({test_var}, dtype=tf.float32)
>>> assert_op = tf.debugging.{opname}(a, b, message=
... '"a {sym} b" does not hold for the given inputs')
>>> # When working with tf.control_dependencies
>>> with tf.control_dependencies([assert_op]):
... val = a + b
@end_compatibility
"""

  def _decorator(func):
    """Renders the template for `func` and attaches it as the docstring."""
    func.__doc__ = template.format(
        sym=sym, opname=func.__name__, test_var=test_var)
    return func

  return _decorator
def _binary_assert_doc_v2(sym, opname, test_var):
  """Common docstring for v2 assert_* ops that compare two tensors element-wise.

  Args:
    sym: Binary operation symbol, e.g. "==".
    opname: Name for the symbol, e.g. "assert_equal".
    test_var: A number used in the docstring example.

  Returns:
    Decorator that adds the appropriate docstring to the function for
    symbol `sym`.
  """

  def _decorator(func):
    """Decorator that adds docstring to the function for symbol `sym`.

    Args:
      func: Function for a TensorFlow op.

    Returns:
      `func` itself, with `__doc__` replaced by the rendered template.
    """
    # The Raises section renders `x {sym} y`; it previously hard-coded
    # `x == y` for every comparison op.
    func.__doc__ = """
Assert the condition `x {sym} y` holds element-wise.
This Op checks that `x[i] {sym} y[i]` holds for every pair of (possibly
broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
trivially satisfied.
If `x` {sym} `y` does not hold, `message`, as well as the first `summarize`
entries of `x` and `y` are printed, and `InvalidArgumentError` is raised.
When using inside `tf.function`, this API takes effect during execution.
It's recommended to use this API with `tf.control_dependencies` to
ensure the correct execution order.
In the following example, without `tf.control_dependencies`, errors may
not be raised at all.
Check `tf.control_dependencies` for more details.
>>> def check_size(x):
... with tf.control_dependencies([
... tf.debugging.{opname}(tf.size(x), {test_var},
... message='Bad tensor size')]):
... return x
>>> check_size(tf.ones([2, 3], tf.float32))
Traceback (most recent call last):
...
InvalidArgumentError: ...
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
message: A string to prefix to the default message. (optional)
summarize: Print this many entries of each tensor. (optional)
name: A name for this operation (optional). Defaults to "{opname}".
Returns:
Op that raises `InvalidArgumentError` if `x {sym} y` is False. This can
be used with `tf.control_dependencies` inside of `tf.function`s to
block followup computation until the check has executed.
@compatibility(eager)
returns None
@end_compatibility
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym} y` is False. The check can be performed immediately during eager
execution or if `x` and `y` are statically known.
""".format(
        sym=sym, opname=opname, test_var=test_var)
    return func

  return _decorator
def _make_assert_msg_data(sym, x, y, summarize, test_op):
  """Builds the default error-message pieces for a failed eager binary assert.

  Helper for _binary_assert. It reads concrete values via `.numpy()`, so it
  is only usable when executing eagerly.

  Args:
    sym: Comparison symbol shown in the header line, e.g. "==".
    x: First assertion input, already converted to a tensor.
    y: Second assertion input.
    summarize: Maximum number of elements of each tensor to include.
    test_op: Boolean tensor that is True wherever the comparison passed.

  Returns:
    A list of strings and NumPy arrays that, once stringified and joined,
    forms the error message.
  """
  pieces = ['Condition x %s y did not hold.' % sym]
  if summarize <= 0:
    return pieces

  # Listing the individual mismatches only makes sense when x and y share
  # the same non-scalar shape; under broadcasting the indices would be more
  # confusing than useful.
  if x.shape == y.shape and x.shape.as_list():
    failed = math_ops.logical_not(test_op)
    bad_indices = array_ops.where(failed).numpy()
    bad_x = array_ops.boolean_mask(x, failed).numpy().reshape((-1,))
    bad_y = array_ops.boolean_mask(y, failed).numpy().reshape((-1,))
    shown = min(summarize, bad_indices.shape[0])
    pieces += [
        'Indices of first %d different values:' % shown,
        bad_indices[:shown],
        'Corresponding x values:',
        bad_x[:shown],
        'Corresponding y values:',
        bad_y[:shown],
    ]

  # reshape((-1,)) is the fastest way to get a flat array view.
  flat_x = x.numpy().reshape((-1,))
  flat_y = y.numpy().reshape((-1,))
  n_x = min(flat_x.size, summarize)
  n_y = min(flat_y.size, summarize)
  pieces += [
      'First %d elements of x:' % n_x,
      flat_x[:n_x],
      'First %d elements of y:' % n_y,
      flat_y[:n_y],
  ]
  return pieces
def _pretty_print(data_item, summarize):
  """Renders one `data` entry of an assert_* call as a string (eager mode).

  Non-tensor items are stringified as-is. Tensors are converted with
  `.numpy()`: scalars print their value, and higher-rank tensors print a list
  of at most `summarize` flattened elements, followed by '...' if truncated.

  Args:
    data_item: A Tensor or any other value from the assert's `data` list.
    summarize: How many elements of a tensor-valued item to keep.

  Returns:
    The string representation of `data_item`.
  """
  if not isinstance(data_item, tensor_lib.Tensor):
    return str(data_item)
  value = data_item.numpy()
  # Zero-dimensional tensors come back from .numpy() as NumPy scalars.
  if np.isscalar(value):
    return str(value)
  flat = value.reshape((-1,))
  shown = [str(v) for v in flat[:summarize]]
  if flat.size > len(shown):
    shown.append('...')
  return str(shown)
def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc() above.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      e.g. "==".
    opname: Name of the assert op in the public API, e.g. "assert_equal".
    op_func: Function that, if passed the two Tensor inputs to the assertion (x
      and y), will return the test to be passed to reduce_all(), e.g.
      math_ops.equal for assert_equal().
    static_func: Function that, if passed numpy ndarray versions of the two
      inputs to the assertion, will return a Boolean ndarray containing True
      in all positions where the assertion PASSES, e.g. np.equal for
      assert_equal().
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to the value of
      `opname`.

  Returns:
    See docstring template in _binary_assert_doc().

  Raises:
    InvalidArgumentError: when executing eagerly and the assertion fails, or
      in graph mode when both `x` and `y` have constant values that violate
      the assertion.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Must be an int: it is used as a slice bound (e.g. in _pretty_print),
        # and NumPy raises TypeError for float slice indices such as 1e9.
        summarize = 10**9  # Code below will find exact size of x and y.
      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join(_pretty_print(d, summarize) for d in data)))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      # Fail at graph-construction time when both inputs are constants.
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      if x_static is not None and y_static is not None:
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_assert.Assert(condition, data, summarize=summarize)
@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  Ops that expect an iterable of `Tensor`s can call this to validate their
  input. It is needed because `Tensor`, `ndarray`, and byte/text objects are
  themselves iterable, yet are almost never what such an op means to receive.

  Args:
    values: Object to be checked.

  Raises:
    TypeError: If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  # Iterable types that callers almost certainly did not intend as a
  # collection of values.
  disallowed = ((tensor_lib.Tensor, sparse_tensor.SparseTensor, np.ndarray) +
                compat.bytes_or_text_types)
  if isinstance(values, disallowed):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable. Found: %s' %
        type(values))
  if hasattr(values, '__iter__'):
    return
  raise TypeError(
      'Expected argument "values" to be iterable. Found: %s' % type(values))
@tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] < 0`; an empty `x` passes
  trivially.

  When some element is not negative, `InvalidArgumentError` is raised and the
  error reports `message` together with the first `summarize` entries of `x`.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. Combine it
    with `tf.control_dependencies` inside of `tf.function`s to hold back
    followup computation until the check has executed.

  @compatibility(eager)
  returns None
  @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 endpoint: forwards to the v1 implementation without `data`.
  return assert_negative(x=x, summarize=summarize, message=message,
                         name=name)
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_unary_assert_doc.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors have no graph name, so describe them by shape/dtype.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [message, 'Condition x < 0 did not hold element-wise:',
              'x (%s) = ' % x_label, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less(x, zero, data=data, summarize=summarize)
@tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] > 0`; an empty `x` passes
  trivially.

  When some element is not positive, `InvalidArgumentError` is raised and the
  error reports `message` together with the first `summarize` entries of `x`.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. Combine it
    with `tf.control_dependencies` inside of `tf.function`s to hold back
    followup computation until the check has executed.

  @compatibility(eager)
  returns None
  @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 endpoint: forwards to the v1 implementation without `data`.
  return assert_positive(x=x, message=message, summarize=summarize,
                         name=name)
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_unary_assert_doc.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors have no graph name, so describe them by shape/dtype.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [message, 'Condition x > 0 did not hold element-wise:',
              'x (%s) = ' % x_label, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x > 0 is checked as 0 < x.
    return assert_less(zero, x, data=data, summarize=summarize)
@tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] >= 0`; an empty `x` passes
  trivially.

  When some element is below 0, `InvalidArgumentError` is raised and the
  error reports `message` together with the first `summarize` entries of `x`.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. Combine
    it with `tf.control_dependencies` inside of `tf.function`s to hold back
    followup computation until the check has executed.

  @compatibility(eager)
  returns None
  @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during
      eager execution or if `x` is statically known.
  """
  # Thin v2 endpoint: forwards to the v1 implementation without `data`.
  return assert_non_negative(x=x, message=message, summarize=summarize,
                             name=name)
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_unary_assert_doc.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors have no graph name, so describe them by shape/dtype.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [message, 'Condition x >= 0 did not hold element-wise:',
              'x (%s) = ' % x_label, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x >= 0 is checked as 0 <= x.
    return assert_less_equal(zero, x, data=data, summarize=summarize)
@tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] <= 0`; an empty `x` passes
  trivially.

  When some element is above 0, `InvalidArgumentError` is raised and the
  error reports `message` together with the first `summarize` entries of `x`.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. Combine
    it with `tf.control_dependencies` inside of `tf.function`s to hold back
    followup computation until the check has executed.

  @compatibility(eager)
  returns None
  @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during
      eager execution or if `x` is statically known.
  """
  # Thin v2 endpoint: forwards to the v1 implementation without `data`.
  return assert_non_positive(x=x, message=message, summarize=summarize,
                             name=name)
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_unary_assert_doc.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors have no graph name, so describe them by shape/dtype.
      if context.executing_eagerly():
        x_label = _shape_and_dtype_str(x)
      else:
        x_label = x.name
      # The comma after the condition string is required. Without it, Python
      # concatenates the two adjacent literals before applying '%', which
      # fuses the header and the 'x (...) = ' label into one data entry.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % x_label, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less_equal(x, zero, data=data, summarize=summarize)
@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('==', 'assert_equal', 3)
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  # Docstring is generated by @_binary_assert_doc_v2. This v2 endpoint
  # forwards to the v1 implementation, which additionally accepts `data`.
  return assert_equal(x=x, y=y, message=message, summarize=summarize,
                      name=name)
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('==', '[1, 2]')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # Docstring is generated by @_binary_assert_doc.
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    if x is not y:
      return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal,
                            x, y, data, summarize, message, name)
    # Same object on both sides: the element-wise comparison is skipped
    # entirely. Note this also means NaN elements are not reported.
    if context.executing_eagerly():
      return None
    return control_flow_ops.no_op()
@tf_export('debugging.assert_none_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('!=', 'assert_none_equal', 6)
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  # Docstring is generated by @_binary_assert_doc_v2. Note the positional
  # order here is (summarize, message), unlike the other v2 binary asserts.
  return assert_none_equal(x=x, y=y, message=message, summarize=summarize,
                           name=name)
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=', '[2, 1]')
def assert_none_equal(
    x, y, data=None, summarize=None, message=None, name=None):
  # Docstring is generated by @_binary_assert_doc.
  return _binary_assert(
      sym='!=', opname='assert_none_equal', op_func=math_ops.not_equal,
      static_func=np.not_equal, x=x, y=y, data=data, summarize=summarize,
      message=message, name=name)
@tf_export('debugging.assert_near', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])` holds
  for every pair of (possibly broadcast) elements of `x` and `y`. If both `x`
  and `y` are empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` are `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`. This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance. Default is `10 * eps`.
    atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance. Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.

  @compatibility(eager)
  returns None
  @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      some pair of elements of `x` and `y` is not close. The check can be
      performed immediately during eager execution or if `x` and `y` are
      statically known.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)
@tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])```.

  If both `x` and `y` are empty, this is trivially satisfied.

  The default `atol` and `rtol` are `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`. This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance. Default is `10 * eps`.
    atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance. Default is `10 * eps`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Tolerances are real-valued even when x and y are complex.
    dtype = x.dtype
    if dtype.is_complex:
      dtype = dtype.real_dtype
    eps = np.finfo(dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)

    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    # The comparison is strict. Note that with `<=`, an infinite tolerance
    # (y = inf with rtol > 0) would accept any finite x, since
    # abs(x - inf) == inf.
    condition = math_ops.reduce_all(math_ops.less(diff, tol))
    return control_flow_assert.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('<', 'assert_less', 3)
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  # Docstring is generated by @_binary_assert_doc_v2; forwards to the v1 op.
  return assert_less(x=x, y=y, message=message, summarize=summarize,
                     name=name)
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('<', '[2, 3]')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  # Docstring is generated by @_binary_assert_doc.
  return _binary_assert(
      sym='<', opname='assert_less', op_func=math_ops.less,
      static_func=np.less, x=x, y=y, data=data, summarize=summarize,
      message=message, name=name)
@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('<=', 'assert_less_equal', 3)
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards to the v1 implementation, leaving `data` at its
  # default.
  return assert_less_equal(
      x=x,
      y=y,
      message=message,
      summarize=summarize,
      name=name,
  )
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=', '[1, 3]')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  # Delegate to the shared binary-assert helper, handing it both the
  # TensorFlow and the NumPy flavour of the '<=' comparison.
  tf_compare, np_compare = math_ops.less_equal, np.less_equal
  return _binary_assert('<=', 'assert_less_equal', tf_compare, np_compare,
                        x, y, data, summarize, message, name)
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('>', 'assert_greater', 9)
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards to the v1 implementation, leaving `data` at its
  # default.
  return assert_greater(
      x=x,
      y=y,
      message=message,
      summarize=summarize,
      name=name,
  )
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('>', '[0, 1]')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # Delegate to the shared binary-assert helper, handing it both the
  # TensorFlow and the NumPy flavour of the '>' comparison.
  tf_compare, np_compare = math_ops.greater, np.greater
  return _binary_assert('>', 'assert_greater', tf_compare, np_compare, x, y,
                        data, summarize, message, name)
@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('>=', 'assert_greater_equal', 9)
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards to the v1 implementation, leaving `data` at its
  # default.
  return assert_greater_equal(
      x=x,
      y=y,
      message=message,
      summarize=summarize,
      name=name,
  )
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=', '[1, 0]')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  # Delegate to the shared binary-assert helper, handing it both the
  # TensorFlow and the NumPy flavour of the '>=' comparison.
  tf_compare, np_compare = math_ops.greater_equal, np.greater_equal
  return _binary_assert('>=', 'assert_greater_equal', tf_compare, np_compare,
                        x, y, data, summarize, message, name)
def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Checks that the rank of `x` satisfies a condition relative to `rank`.

  When both the rank of `x` and the value of `rank` are known at build time,
  the check is decided in Python. Otherwise a runtime `Assert` op is built.

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor`; must have dtype int32.
    static_condition: Python callable `(actual_rank, given_rank) -> bool` used
      when both ranks are statically known.
    dynamic_condition: Op `(actual_rank, given_rank) -> bool Tensor` used when
      either rank is only known at runtime.
    data: Tensors to print if the condition is false. Defaults to the error
      message and first few entries of `x`.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the static check passed, otherwise an op raising
    `InvalidArgumentError` when `dynamic_condition` fails.

  Raises:
    ValueError: `('Rank must be a scalar.',)` if `rank` is a known non-scalar,
      or `('Static rank condition failed', x_rank, rank)` if the static check
      fails. Callers inspect `args[0]` to tell these apart.
  """
  assert_type(rank, dtypes.int32)

  # Try to resolve `rank` to a constant so the check can be done in Python.
  static_rank = tensor_util.constant_value(rank)
  if static_rank is not None:
    if static_rank.ndim != 0:
      raise ValueError('Rank must be a scalar.')
    static_x_rank = x.get_shape().ndims
    if static_x_rank is not None:
      if static_condition(static_x_rank, static_rank):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', static_x_rank, static_rank)

  condition = dynamic_condition(array_ops.rank(x), rank)

  if static_rank is None:
    # `rank` itself is unknown, so also verify at runtime that it is a scalar.
    # This catches calls such as assert_rank(x, [n]) instead of
    # assert_rank(x, n).
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_assert.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  This Op checks that the rank of `x` is equal to `rank`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    ValueError: If static checks determine `x` does not have rank `rank`.
      A static check is possible when both the rank of `x` and the value of
      `rank` are known at the time this function is called.
    InvalidArgumentError: `x` does not have rank `rank`, but the rank cannot
      be statically determined (raised when the returned op runs).
  """
  # Forward to the v1 implementation; `data` and `summarize` keep their
  # defaults.
  return assert_rank(x=x, rank=rank, message=message, name=name)
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor` or `SparseTensor`.
    rank: Scalar integer `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank, or if `rank`
      is statically known but is not a scalar.
  """
  with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    # SparseTensors are passed through as-is; everything else is converted.
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
    dynamic_condition = math_ops.equal

    # No `x.name` is used in eager mode or for SparseTensors.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % name, rank, 'Received shape: ',
          array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)
    except ValueError as e:
      if e.args[0] == 'Static rank condition failed':
        # args are ('Static rank condition failed', x_rank, rank).
        raise ValueError(
            '%sTensor %s must have rank %d. Received rank %d, shape %s' %
            (message, name, e.args[2], e.args[1], x.get_shape()))
      # Re-raise any other ValueError unchanged so its full argument list and
      # traceback survive. This matches assert_rank_at_least/assert_rank_in.
      raise

  return assert_op
@tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  This Op checks that the rank of `x` is greater than or equal to `rank`.

  If the rank of `x` is lower than `rank`, `message` and the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Forward to the v1 implementation; `data` and `summarize` keep defaults.
  return assert_rank_at_least(name=name, message=message, rank=rank, x=x)
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor` or `SparseTensor`.
    rank: Scalar `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(
      name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    # Accept SparseTensors the same way `assert_rank` and `assert_rank_in` do.
    # `assert_shapes` forwards SparseTensors here when a shape spec starts
    # with `...`/`*`, so converting unconditionally would break it.
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
    dynamic_condition = math_ops.greater_equal

    # No `x.name` is used in eager mode or for SparseTensors.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)
    except ValueError as e:
      if e.args[0] == 'Static rank condition failed':
        # args are ('Static rank condition failed', x_rank, rank).
        raise ValueError(
            '%sTensor %s must have rank at least %d. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      raise

  return assert_op
def _static_rank_in(actual_rank, given_ranks):
return actual_rank in given_ranks
def _dynamic_rank_in(actual_rank, given_ranks):
  """Builds a boolean tensor: does `actual_rank` equal any of `given_ranks`?

  An empty `given_ranks` yields a constant `False` tensor. Otherwise the
  per-candidate equality tests are OR-ed together from left to right.
  """
  if len(given_ranks) == 0:
    return ops.convert_to_tensor(False)
  first, *remaining = given_ranks
  matches = math_ops.equal(first, actual_rank)
  for candidate in remaining:
    matches = math_ops.logical_or(
        matches, math_ops.equal(candidate, actual_rank))
  return matches
def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Checks that the rank of `x` satisfies a condition relative to `ranks`.

  The check is decided in Python only when every entry of `ranks` and the
  rank of `x` are known at build time. Otherwise a runtime `Assert` is built.

  Args:
    x: Numeric `Tensor`.
    ranks: Sequence of scalar int32 `Tensor`s.
    static_condition: Python callable `(actual_rank, given_ranks) -> bool`
      used when all ranks are statically known.
    dynamic_condition: Op `(actual_rank, given_ranks) -> bool Tensor` used
      otherwise.
    data: Tensors to print if the condition is false. Defaults to the error
      message and first few entries of `x`.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the static check passed, otherwise an op raising
    `InvalidArgumentError` when `dynamic_condition` fails.

  Raises:
    ValueError: `('Rank must be a scalar.',)` if any known rank is not a
      scalar, or `('Static rank condition failed', x_rank, ranks_static)` if
      the static check fails.
  """
  for candidate in ranks:
    assert_type(candidate, dtypes.int32)

  # Try to resolve every entry of `ranks` to a constant.
  ranks_static = tuple(tensor_util.constant_value(r) for r in ranks)
  if all(r is not None for r in ranks_static):
    if any(r.ndim != 0 for r in ranks_static):
      raise ValueError('Rank must be a scalar.')
    x_rank_static = x.get_shape().ndims
    if x_rank_static is not None:
      if not static_condition(x_rank_static, ranks_static):
        raise ValueError(
            'Static rank condition failed', x_rank_static, ranks_static)
      return control_flow_ops.no_op(name='static_checks_determined_all_ok')

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # For each rank whose value is unknown, add a runtime check that it is a
  # scalar. This catches assert_rank(x, [n]) written instead of
  # assert_rank(x, n).
  unresolved = [r for r, r_static in zip(ranks, ranks_static)
                if r_static is None]
  for r in unresolved:
    scalar_check = assert_rank(
        r, 0, data=['Rank must be a scalar. Received rank: ', r])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_assert.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  This Op checks that the rank of `x` is one of the values in `ranks`.

  If the rank of `x` is not among them, `message` and the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Forward to the v1 implementation; `data` and `summarize` keep defaults.
  return assert_rank_in(name=name, message=message, ranks=ranks, x=x)
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor` or `SparseTensor`.
    ranks: Iterable of scalar `Tensor` objects. Any iterable is accepted,
      including a one-shot generator.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Materialize `ranks` once. Otherwise a generator would be consumed by the
  # name-scope `tuple(ranks)` below, and the later conversion would see an
  # empty sequence.
  ranks = tuple(ranks)
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + ranks + tuple(data or [])):
    # SparseTensors are passed through as-is; everything else is converted.
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
    message = _message_prefix(message)

    # No `x.name` is used in eager mode or for SparseTensors.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)
    except ValueError as e:
      if e.args[0] == 'Static rank condition failed':
        # e.args[2] holds the `tensor_util.constant_value` result for each
        # rank (array-like scalars). Convert them to Python ints so the
        # message reads e.g. "(2, 4)" rather than a tuple of array reprs.
        expected_ranks = tuple(int(r) for r in e.args[2])
        raise ValueError(
            '%sTensor %s must have rank in %s. Received rank %d, '
            'shape %s' % (message, name, expected_ranks, e.args[1],
                          x.get_shape()))
      raise

  return assert_op
@tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  If the dtype of `x` is not an integer type, `message` and the dtype of `x`
  are printed, and `InvalidArgumentError` is raised.

  The dtype is always known statically, so this function returns nothing.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is not a non-quantized integer type.
  """
  # The v1 result (a no_op) is deliberately discarded.
  assert_integer(name=name, message=message, x=x)
@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is anything other than non-quantized integer.

  Returns:
    A `no_op` that does nothing. The dtype check happens when this function
    is called, not when the op runs.
  """
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if x.dtype.is_integer:
      return control_flow_ops.no_op('statically_determined_was_integer')
    # In eager mode a fixed label is used instead of `x.name`.
    name = 'tensor' if context.executing_eagerly() else x.name
    raise TypeError(
        '%sExpected "x" to be integer type. Found: %s of dtype %s'
        % (_message_prefix(message), name, x.dtype))
@tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  The dtype is always known statically, so this function returns nothing.

  Example:

  >>> a = tf.Variable(1.0)
  >>> tf.debugging.assert_type(a, tf_type= tf.float32)

  >>> b = tf.constant(21)
  >>> tf.debugging.assert_type(b, tf_type=tf.bool)
  Traceback (most recent call last):
  ...
  TypeError: ...

  >>> c = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
  ...                dense_shape=[3, 4])
  >>> tf.debugging.assert_type(c, tf_type= tf.int32)

  Args:
    tensor: A `Tensor`, `SparseTensor` or `tf.Variable` .
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  # The v1 result (a no_op) is deliberately discarded.
  assert_type(name=name, message=message, tf_type=tf_type, tensor=tensor)
@tf_export(v1=['debugging.assert_type', 'assert_type'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` is of the specified type.

  Args:
    tensor: A `Tensor` or `SparseTensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc). Anything accepted by `dtypes.as_dtype`.
    message: A string to prefix to the default message.
    name:  A name to give this `Op`.  Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.

  Returns:
    A `no_op` that does nothing. The dtype is compared when this function is
    called.
  """
  expected = dtypes.as_dtype(tf_type)
  with ops.name_scope(name, 'assert_type', [tensor]):
    # SparseTensors are inspected directly; everything else is converted.
    if not isinstance(tensor, sparse_tensor.SparseTensor):
      tensor = ops.convert_to_tensor(tensor, name='tensor')
    if tensor.dtype != expected:
      prefix = _message_prefix(message)
      # Fall back to a generic label when the object has no usable `.name`.
      label = getattr(tensor, "name", "tensor")
      raise TypeError(
          f'{prefix}{label} must be of type {expected!r}; '
          f'got {tensor.dtype!r}')
    return control_flow_ops.no_op('statically_determined_correct_type')
def _dimension_sizes(x):
  """Returns the per-dimension sizes of `x`.

  The result takes one of three forms:
    * rank statically 0: the tuple `(1,)`. A scalar is treated as rank 1 with
      size 1.
    * rank statically > 0: a list with one entry per dimension. The entry is a
      Python `int` when the size is statically known, otherwise a scalar
      tensor sliced from `array_ops.shape(x)`.
    * rank unknown: a 1-D tensor that evaluates to `[1]` if `x` turns out to
      be a scalar at runtime, and to `array_ops.shape(x)` otherwise.

  Args:
    x: A `Tensor`.

  Returns:
    Dimension sizes in one of the forms above.
  """
  shape_tensor = array_ops.shape(x)
  static_rank = x.get_shape().rank
  if static_rank is not None:
    if static_rank == 0:
      return (1,)
    if static_rank > 0:
      static_dims = x.get_shape().as_list()
      return [
          shape_tensor[axis] if dim is None else int(dim)
          for axis, dim in enumerate(static_dims)
      ]
  # The rank is not known statically, so the scalar case is decided at runtime.
  is_scalar = math_ops.equal(array_ops.rank(x), 0)
  return cond.cond(
      is_scalar, lambda: array_ops.constant([1]), lambda: shape_tensor)
def _symbolic_dimension_sizes(symbolic_shape):
# If len(symbolic_shape) == 0 construct a tuple
if not symbolic_shape:
return tuple([1])
return symbolic_shape
def _has_known_value(dimension_size):
not_none = dimension_size is not None
try:
int(dimension_size)
can_be_parsed_as_int = True
except (ValueError, TypeError):
can_be_parsed_as_int = False
return not_none and can_be_parsed_as_int
def _is_symbol_for_any_size(symbol):
return symbol in [None, '.']
_TensorDimSizes = collections.namedtuple(
'_TensorDimSizes',
['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])
@tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that the shapes of a collection of tensors satisfy the given
  constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message` and the first `summarize` entries of
  the first violating tensor are printed, and `InvalidArgumentError` is
  raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*', it stands
  for a variable number of outer dimensions of unspecified size. The
  constraint then applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: dictionary with (`Tensor` to shape) items, or a list of
      (`Tensor`, shape) tuples. A shape must be an iterable.
    data: The tensors to print out if the condition is False.  Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_shapes".

  Raises:
    ValueError:  If static checks determine any shape constraint is violated.
  """
  # The v1 result (an op or no_op) is deliberately discarded.
  assert_shapes(
      shapes, name=name, message=message, summarize=summarize, data=data)
@tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that the shapes of a collection of tensors satisfy the given
  constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.debugging.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message` and the first `summarize` entries of
  the first violating tensor are printed, and `InvalidArgumentError` is
  raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*', it stands
  for a variable number of outer dimensions of unspecified size. The
  constraint then applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
      expected shape of `Tensor`. A dict mapping `Tensor` to `shape` is also
      accepted. Entries whose `shape` is `None` are ignored. See the example
      code above. The `shape` must be an iterable. Each element of the
      iterable can be either a concrete integer value or a string that
      abstractly represents the dimension.
      For example,
        - `('N', 'Q')` specifies a 2D shape wherein the first and second
          dimensions of shape may or may not be equal.
        - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
          dimensions are equal.
        - `(1, 'N')` specifies a 2D shape wherein the first dimension is
          exactly 1 and the second dimension can be any value.
      Note that the abstract dimension letters take effect across different
      tuple elements of the list. For example,
      `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))])` asserts
      that both `x` and `y` are rank-2 tensors and their first dimensions are
      equal (`N`).
      `shape` can also be a `tf.TensorShape`.
    data: The tensors to print out if the condition is False.  Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError:  If static checks determine any shape constraint is violated.
  """
  # If the user manages to assemble a dict containing tensors (possible in
  # Graph mode only), make sure we still accept that.
  if isinstance(shapes, dict):
    shapes = shapes.items()

  message_prefix = _message_prefix(message)
  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
    # Shape specified as None implies no constraint. SparseTensors are kept
    # as-is; everything else is converted to a Tensor.
    shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
                          ops.convert_to_tensor(x), s)
                         for x, s in shapes if s is not None]

    executing_eagerly = context.executing_eagerly()

    def tensor_name(x):
      """Label for `x` in error messages: shape/dtype string or graph name."""
      if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
        return _shape_and_dtype_str(x)
      return x.name

    # Pass 1: validate each shape spec and record, per tensor, its actual
    # sizes alongside the symbolic sizes to check.
    tensor_dim_sizes = []
    for tensor, symbolic_shape in shape_constraints:
      is_iterable = (
          hasattr(symbolic_shape, '__iter__') or
          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
      )
      if not is_iterable:
        raise ValueError(
            '%s'
            'Tensor %s. Specified shape must be an iterable. '
            'An iterable has the attribute `__iter__` or `__getitem__`. '
            'Received specified shape: %s' %
            (message_prefix, tensor_name(tensor), symbolic_shape))

      # We convert this into a tuple to handle strings, lists and numpy arrays
      symbolic_shape_tuple = tuple(symbolic_shape)

      # `...`/`*` is only legal as the first entry; when present, only the
      # inner-most dimensions are constrained.
      tensors_specified_innermost = False
      for i, symbol in enumerate(symbolic_shape_tuple):
        if symbol not in [Ellipsis, '*']:
          continue

        if i != 0:
          raise ValueError(
              '%s'
              'Tensor %s specified shape index %d. '
              'Symbol `...` or `*` for a variable number of '
              'unspecified dimensions is only allowed as the first entry' %
              (message_prefix, tensor_name(tensor), i))

        tensors_specified_innermost = True

      # Only include the size of the specified dimensions since the 0th symbol
      # is either ellipsis or *
      tensor_dim_sizes.append(
          _TensorDimSizes(
              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
              _symbolic_dimension_sizes(
                  symbolic_shape_tuple[1:]
                  if tensors_specified_innermost else symbolic_shape_tuple)))

    # Pass 2: rank assertions. These must run before any dynamic size slicing
    # below (see the control dependencies there).
    rank_assertions = []
    for sizes in tensor_dim_sizes:
      rank = len(sizes.symbolic_sizes)
      rank_zero_or_one = rank in [0, 1]
      if sizes.unspecified_dim:
        if rank_zero_or_one:
          # No assertion of rank needed as `x` only need to have rank at least
          # 0. See elif rank_zero_or_one case comment.
          continue
        assertion = assert_rank_at_least(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      elif rank_zero_or_one:
        # Rank 0 is treated as rank 1 size 1, i.e. there is
        # no distinction between the two in terms of rank.
        # See _dimension_sizes.
        assertion = assert_rank_in(
            x=sizes.x,
            ranks=[0, 1],
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      else:
        assertion = assert_rank(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      rank_assertions.append(assertion)

    # Pass 3: size assertions. The first tensor dimension that uses a given
    # symbol defines its size (recorded in `size_specifications`). Later uses
    # of the symbol, and integer literals, are checked against it: statically
    # when both sizes are known, otherwise with a runtime Assert.
    size_assertions = []
    size_specifications = {}
    for sizes in tensor_dim_sizes:
      for i, size_symbol in enumerate(sizes.symbolic_sizes):

        if _is_symbol_for_any_size(size_symbol):
          # Size specified as any implies no constraint
          continue

        # With a leading `...`/`*`, index from the end (negative dimension).
        if sizes.unspecified_dim:
          tensor_dim = i - len(sizes.symbolic_sizes)
        else:
          tensor_dim = i

        if size_symbol in size_specifications or _has_known_value(size_symbol):
          if _has_known_value(size_symbol):
            specified_size = int(size_symbol)
            size_check_message = 'Specified explicitly'
          else:
            specified_size, specified_by_y, specified_at_dim = (
                size_specifications[size_symbol])
            size_check_message = (
                'Specified by tensor %s dimension %d' %
                (tensor_name(specified_by_y), specified_at_dim))

          # This is extremely subtle. If actual_sizes is dynamic, we must
          # make sure a control dependency is inserted here so that this slice
          # can not execute until the rank is asserted to be enough for the
          # slice to not fail.
          with ops.control_dependencies(rank_assertions):
            actual_size = sizes.actual_sizes[tensor_dim]
          if _has_known_value(actual_size) and _has_known_value(specified_size):
            if int(actual_size) != int(specified_size):
              raise ValueError(
                  '%s%s. Tensor %s dimension %s must have size %d. '
                  'Received size %d, shape %s' %
                  (message_prefix, size_check_message, tensor_name(sizes.x),
                   tensor_dim, specified_size, actual_size,
                   sizes.x.get_shape()))
            # No dynamic assertion needed
            continue

          condition = math_ops.equal(
              ops.convert_to_tensor(actual_size),
              ops.convert_to_tensor(specified_size))
          data_ = data
          if data is None:
            data_ = [
                message_prefix, size_check_message,
                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
                'must have size', specified_size, 'Received shape: ',
                array_ops.shape(sizes.x)
            ]
          size_assertions.append(
              control_flow_assert.Assert(condition, data_, summarize=summarize))
        else:
          # First occurrence of this symbol: record which tensor/dimension
          # defines its size.
          # Not sure if actual_sizes is a constant, but for safety, guard
          # on rank. See explanation above about actual_sizes need for safety.
          with ops.control_dependencies(rank_assertions):
            size = sizes.actual_sizes[tensor_dim]
          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)

  # Ensure both assertions actually occur.
  with ops.control_dependencies(rank_assertions):
    shapes_assertion = control_flow_ops.group(size_assertions)

  return shapes_assertion
# pylint: disable=line-too-long
def _get_results_for_monotonic_comparison(x, compare_op):
  """Applies `compare_op` to every adjacent pair of elements of `x`.

  `x` is flattened in row-major order, and `compare_op(x[1:], x[:-1])` is
  returned. For example, `math_ops.greater_equal` yields `x[i+1] >= x[i]` for
  each `i`.

  Args:
    x: A numeric `Tensor` (or a value `array_ops.reshape` can convert) of any
      shape.
    compare_op: A binary element-wise op. It is expected to return a boolean
      `Tensor`, because the short-input branch returns an empty `bool` tensor.

  Returns:
    A 1-D tensor of pairwise comparison results with `size(x) - 1` entries.
    When `x` has fewer than two elements, an empty boolean tensor.

  Raises:
    TypeError: If `x` is not a numeric tensor.
  """
  x = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(x):
    raise TypeError('Expected x to be numeric, instead found: %s' % x)

  # If x has less than 2 elements, there is nothing to compare. So return [].
  is_shorter_than_two = math_ops.less(array_ops.size(x), 2)
  short_result = lambda: ops.convert_to_tensor([], dtype=bool)

  # With 2 or more elements, return compare_op(x[1:], x[:-1]).
  s_len = array_ops.shape(x) - 1
  diff = lambda: compare_op(
      array_ops.strided_slice(x, [1], [1] + s_len),
      array_ops.strided_slice(x, [0], s_len),
  )
  return cond.cond(is_shorter_than_two, short_result, diff)
@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if the elements of `tensor` are numbers.

  Specifically, returns `True` if the dtype of `tensor` is one of the following:

  * `tf.float16`
  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.uint16`
  * `tf.uint32`
  * `tf.uint64`
  * `tf.qint8`
  * `tf.qint16`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.quint16`
  * `tf.complex64`
  * `tf.complex128`
  * `tf.bfloat16`

  Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not
  a `tf.Tensor` object.
  """
  if not isinstance(tensor, tensor_lib.Tensor):
    # Non-Tensor inputs (including SparseTensor / Python values) are rejected.
    return False
  return tensor.dtype in NUMERIC_TYPES
@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
  is non-decreasing if every adjacent pair satisfies `x[i] <= x[i+1]`.
  A tensor with fewer than two elements is trivially non-decreasing.

  See also: `is_strictly_increasing`

  >>> x1 = tf.constant([1.0, 1.0, 3.0])
  >>> tf.math.is_non_decreasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>

  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_non_decreasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional). Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    # Entry i holds x[i + 1] >= x[i]; it is empty when x has < 2 elements.
    pairwise_ok = _get_results_for_monotonic_comparison(
        x, math_ops.greater_equal)
    # reduce_all over an empty tensor is True, so short inputs pass.
    return math_ops.reduce_all(pairwise_ok)
@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
  is strictly increasing if every adjacent pair satisfies `x[i] < x[i+1]`.
  A tensor with fewer than two elements is trivially strictly increasing.

  See also: `is_non_decreasing`

  >>> x1 = tf.constant([1.0, 2.0, 3.0])
  >>> tf.math.is_strictly_increasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>

  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_strictly_increasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    # Entry i holds x[i + 1] > x[i]; it is empty when x has < 2 elements.
    pairwise_ok = _get_results_for_monotonic_comparison(x, math_ops.greater)
    # reduce_all over an empty tensor is True, so short inputs pass.
    return math_ops.reduce_all(pairwise_ok)
def _assert_same_base_type(items, expected_type=None):
r"""Asserts all items are of the same base type.
Args:
items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
`Operation`, or `IndexedSlices`). Can include `None` elements, which
will be ignored.
expected_type: Expected type. If not specified, assert all items are
of the same base type.
Returns:
Validated type, or none if neither expected_type nor items provided.
Raises:
ValueError: If any types do not match.
"""
original_expected_type = expected_type
mismatch = False
for item in items:
if item is not None:
item_type = item.dtype.base_dtype
if not expected_type:
expected_type = item_type
elif expected_type != item_type:
mismatch = True
break
if mismatch:
# Loop back through and build up an informative error message (this is very
# slow, so we don't do it unless we found an error above).
expected_type = original_expected_type
original_item_str = None
for item in items:
if item is not None:
item_type = item.dtype.base_dtype
if not expected_type:
expected_type = item_type
original_item_str = item.name if hasattr(item, 'name') else str(item)
elif expected_type != item_type:
raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
item.name if hasattr(item, 'name') else str(item),
item_type, expected_type,
(' as %s' % original_item_str) if original_item_str else ''))
return expected_type # Should be unreachable
else:
return expected_type
@tf_export(
    'debugging.assert_same_float_dtype',
    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
  """Validate and return float type based on `tensors` and `dtype`.

  For ops such as matrix multiplication, inputs and weights must be of the
  same float type. This function validates that all `tensors` are the same type,
  validates that type is `dtype` (if supplied), and returns the type. Type must
  be a floating point type. If neither `tensors` nor `dtype` is supplied,
  the function will return `dtypes.float32`.

  Args:
    tensors: Tensors of input values. Can include `None` elements, which will be
      ignored.
    dtype: Expected type.

  Returns:
    Validated type (`dtypes.float32` when nothing is supplied to infer from).

  Raises:
    ValueError: if the items in `tensors` do not all share one base type, if
      that type differs from a supplied `dtype`, or if the resulting type is
      not a floating point type.
  """
  if tensors:
    dtype = _assert_same_base_type(tensors, dtype)
  if not dtype:
    # Nothing to infer the type from: fall back to the documented default.
    dtype = dtypes.float32
  elif not dtype.is_floating:
    raise ValueError('Expected floating point type, got %s.' % dtype)
  return dtype
@tf_export('debugging.assert_scalar', v1=[])
@dispatch.add_dispatch_support
def assert_scalar_v2(tensor, message=None, name=None):
  """Checks that the given `tensor` is a scalar (rank 0).

  A `ValueError` is raised unless the static shape of `tensor` shows it is a
  scalar; in particular, a `tensor` whose shape is unknown also raises.

  The check is purely static, so nothing is returned.

  Args:
    tensor: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_scalar"

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  # Delegates to the v1 endpoint and discards the tensor it returns.
  assert_scalar(tensor=tensor, name=name, message=message)
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
  """Checks that the given `tensor` is a scalar (i.e. zero-dimensional).

  A `ValueError` is raised unless the static shape of `tensor` shows it is a
  scalar; in particular, a `tensor` whose shape is unknown also raises.

  Args:
    tensor: A `Tensor`.
    name: A name for this operation. Defaults to "assert_scalar"
    message: A string to prefix to the default message.

  Returns:
    The input tensor (potentially converted to a `Tensor`).

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
    tensor = ops.convert_to_tensor(tensor, name=name_scope)
    static_shape = tensor.get_shape()
    prefix = _message_prefix(message)
    # An unknown rank yields ndims None, which is not 0 and therefore fails.
    if static_shape.ndims == 0:
      return tensor
    if context.executing_eagerly():
      # Eager tensors have no meaningful `.name`, so omit it from the message.
      raise ValueError('%sExpected scalar shape, saw shape: %s.'
                       % (prefix, static_shape,))
    raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
                     % (prefix, tensor.name, static_shape))
def _message_prefix(message):
if message:
return '%s. ' % message
return ''
@tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None):
  """Updates the shape of a tensor and checks at runtime that the shape holds.

  When executed, this operation asserts that the input tensor `x`'s shape
  is compatible with the `shape` argument.
  See `tf.TensorShape.is_compatible_with` for details.

  >>> x = tf.constant([[1, 2, 3],
  ...                  [4, 5, 6]])
  >>> x = tf.ensure_shape(x, [2, 3])

  Use `None` for unknown dimensions:

  >>> x = tf.ensure_shape(x, [None, 3])
  >>> x = tf.ensure_shape(x, [2, None])

  If the tensor's shape is not compatible with the `shape` argument, an error
  is raised:

  >>> x = tf.ensure_shape(x, [5])
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
    compatible with expected shape [5]. [Op:EnsureShape]

  During graph construction (typically tracing a `tf.function`),
  `tf.ensure_shape` updates the static shape of the **result** tensor by
  merging the two shapes. See `tf.TensorShape.merge_with` for details.

  This is most useful when **you** know a shape that can't be determined
  statically by TensorFlow.

  The following trivial `tf.function` prints the input tensor's
  static shape before and after `ensure_shape` is applied.

  >>> @tf.function
  ... def f(tensor):
  ...   print("Static-shape before:", tensor.shape)
  ...   tensor = tf.ensure_shape(tensor, [None, 3])
  ...   print("Static-shape after:", tensor.shape)
  ...   return tensor

  This lets you see the effect of `tf.ensure_shape` when the function is traced:

  >>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
  Static-shape before: (None, None)
  Static-shape after: (None, 3)

  >>> cf(tf.zeros([3, 3])) # Passes
  >>> cf(tf.constant([1, 2, 3])) # fails
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].

  The above example raises `tf.errors.InvalidArgumentError`, because `x`'s
  shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)`.

  Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
  runtime shapes. This is stricter than `tf.Tensor.set_shape`, which only
  checks the buildtime shape.

  Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape
  of the resulting tensor and enforces it at runtime, raising an error if the
  tensor's runtime shape is incompatible with the specified shape.
  `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it
  at runtime, which may result in inconsistencies between the statically-known
  shape of tensors and the runtime value of tensors.

  For example, when loading images of a known size:

  >>> @tf.function
  ... def decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image = tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  When tracing a function no ops are executed, so shapes may be unknown.
  See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
  for details.

  >>> concrete_decode = decode_image.get_concrete_function(
  ...     tf.TensorSpec([], dtype=tf.string))
  Initial shape:  (None, None, 3)
  Final shape:  (28, 28, 3)

  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
  >>> image = tf.cast(image,tf.uint8)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  >>> print(image2.shape)
  (28, 28, 3)

  >>> image = tf.concat([image,image], axis=0)
  >>> print(image.shape)
  (56, 28, 3)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not
    compatible with expected shape [28,28,3].

  Caution: if you don't use the result of `tf.ensure_shape` the check may not
  run.

  >>> @tf.function
  ... def bad_decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   # BAD: forgot to use the returned tensor.
  ...   tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  >>> image = bad_decode_image(png)
  Initial shape:  (None, None, 3)
  Final shape:  (None, None, 3)
  >>> print(image.shape)
  (56, 28, 3)

  Args:
    x: A `Tensor`.
    shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`.

  Raises:
    tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape
    of `x`.
  """
  # Normalize list/tuple/proto/None specs into a TensorShape before handing
  # off to the underlying op wrapper; TensorShape instances pass through.
  expected_shape = (
      shape if isinstance(shape, tensor_shape.TensorShape)
      else tensor_shape.TensorShape(shape))
  return array_ops.ensure_shape(x, expected_shape, name=name)
@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient function registered for the `EnsureShape` op.

  `EnsureShape` yields a tensor with the same type and contents as its input
  (it only checks/sets the shape), so the upstream gradient passes through
  unchanged.
  """
  _ = op  # The forward op is not needed to compute this gradient.
  return grad