# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes for tensor shape inference."""
import functools
import operator
from typing import Optional, Sequence, Type

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.function import trace_type
from tensorflow.python import tf2
from tensorflow.python.eager import monitoring
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import trace
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls

_TENSORSHAPE_V2_OVERRIDE = None

_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/v2_tensorshape",
    "Whether tensor_shape.enable_v2_tensorshape() is called.")


@tf_export(v1=["enable_v2_tensorshape"])
def enable_v2_tensorshape():
  """In TensorFlow 2.0, iterating over a TensorShape instance returns values.

  This enables the new behavior.

  Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but in V2
  it returns either an integer or None.

  Examples:

  ```
  #######################
  # If you had this in V1:
  value = tensor_shape[i].value

  # Do this in V2 instead:
  value = tensor_shape[i]

  #######################
  # If you had this in V1:
  for dim in tensor_shape:
    value = dim.value
    print(value)

  # Do this in V2 instead:
  for value in tensor_shape:
    print(value)

  #######################
  # If you had this in V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # Do this in V2 instead:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # The V2 suggestion above is more explicit, which will save you from
  # the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be
  # reflected in `tensor_shape[i]`, but they would not be.
  ```
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = True
  logging.vlog(1, "Enabling v2 tensorshape")
  _api_usage_gauge.get_cell().set(True)


@tf_export(v1=["disable_v2_tensorshape"])
def disable_v2_tensorshape():
  """Disables the V2 TensorShape behavior and reverts to V1 behavior.

  See docstring for `enable_v2_tensorshape` for details about the new behavior.
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = False
  logging.vlog(1, "Disabling v2 tensorshape")
  _api_usage_gauge.get_cell().set(False)


@tf_export(
    "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"])
def dimension_value(dimension):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.
  When accessing the value of a TensorShape dimension, use this utility,
  like this:

  ```
  # If you had this in your V1 code:
  value = tensor_shape[i].value

  # Use `dimension_value` as direct replacement compatible with both V1 & V2:
  value = dimension_value(tensor_shape[i])

  # This would be the V2 equivalent:
  value = tensor_shape[i]  # Warning: this will return the dim value in V2!
  ```

  Args:
    dimension: Either a `Dimension` instance, an integer, or None.

  Returns:
    A plain value, i.e. an integer or None.
  """
  if isinstance(dimension, Dimension):
    return dimension.value
  return dimension


@tf_export(
    "compat.dimension_at_index",
    v1=["dimension_at_index", "compat.dimension_at_index"])
def dimension_at_index(shape, index):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  If you want to retrieve the Dimension instance corresponding to a certain
  index in a TensorShape instance, use this utility, like this:

  ```
  # If you had this in your V1 code:
  dim = tensor_shape[i]

  # Use `dimension_at_index` as direct replacement compatible with both V1 & V2:
  dim = dimension_at_index(tensor_shape, i)

  # Another possibility would be this, but WARNING: it only works if the
  # tensor_shape instance has a defined rank.
  dim = tensor_shape.dims[i]  # `dims` may be None if the rank is undefined!

  # In native V2 code, we recommend instead being more explicit:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]

  # Being more explicit will save you from the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be
  # reflected in `tensor_shape[i]`, but they would not be (as the Dimension
  # object was instantiated on the fly).
  ```

  Args:
    shape: A TensorShape instance.
    index: An integer index.

  Returns:
    A dimension object.
  """
  assert isinstance(shape, TensorShape)
  if shape.rank is None:
    return Dimension(None)
  else:
    return shape.dims[index]


@tf_export(v1=["Dimension"])
class Dimension(object):
  """Represents the value of one dimension in a TensorShape.

  @compatibility(TF2)
  In TF2, members of a `TensorShape` object are integers. The `Dimension` class
  is not part of TF2's data model.

  Please refer to the [TensorShape section of the migration guide]
  (https://www.tensorflow.org/guide/migrate/index#tensorshape) on common code
  patterns adapting Dimension objects to a TF2 syntax.
  @end_compatibility
  """
  __slots__ = ["_value"]

  def __init__(self, value):
    """Creates a new Dimension with the given value."""
    if isinstance(value, int):  # Most common case.
      if value < 0:
        raise ValueError("Dimension %d must be >= 0" % value)
      self._value = value
    elif value is None:
      self._value = None
    elif isinstance(value, Dimension):
      self._value = value._value
    else:
      try:
        # int(...) compensates for the int/long dichotomy on Python 2.X.
        # TODO(b/143206389): Remove once we fully migrate to 3.X.
        self._value = int(value.__index__())
      except AttributeError:
        raise TypeError(
            "Dimension value must be integer or None or have "
            "an __index__ method, got value '{0!r}' with type '{1!r}'".format(
                value, type(value))) from None
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns True if `other` has the same known value as this Dimension.

    Returns None if either value is unknown.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns True if `other` has a different known value from `self`.

    Returns None if either value is unknown.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  def __bool__(self):
    """Equivalent to `bool(self.value)`."""
    return bool(self._value)

  def __int__(self):
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  # TODO(b/225058047): Reconsider semantics.
  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" %
                       (self, other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.compat.v1.Dimension(n).merge_with(tf.compat.v1.Dimension(n)) ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(n).merge_with(tf.compat.v1.Dimension(None)) ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n)) ==
    tf.compat.v1.Dimension(n)
    # equivalent to tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None))

    # raises ValueError for n != m
    tf.compat.v1.Dimension(n).merge_with(tf.compat.v1.Dimension(m))
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(n) ==
    tf.compat.v1.Dimension(m + n)
    tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
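
    For example, following the rules above:

    >>> tf.compat.v1.Dimension(2) + tf.compat.v1.Dimension(3)
    Dimension(5)
    >>> tf.compat.v1.Dimension(2) + tf.compat.v1.Dimension(None)
    Dimension(None)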
""" try: other = as_dimension(other) except (TypeError, ValueError): return NotImplemented if self._value is None or other.value is None: return Dimension(None) else: return Dimension(self._value + other.value) def __radd__(self, other): """Returns the sum of `other` and `self`. Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is the sum of `self` and `other`. """ return self + other def __sub__(self, other): """Returns the subtraction of `other` from `self`. Dimensions are subtracted as follows: ```python tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(n) == tf.compat.v1.Dimension(m - n) tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(None) # equiv. to tf.compat.v1.Dimension(None) tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n) # equiv. to tf.compat.v1.Dimension(None) tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None) # equiv. to tf.compat.v1.Dimension(None) ``` Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is the subtraction of `other` from `self`. """ try: other = as_dimension(other) except (TypeError, ValueError): return NotImplemented if self._value is None or other.value is None: return Dimension(None) else: return Dimension(self._value - other.value) def __rsub__(self, other): """Returns the subtraction of `self` from `other`. Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is the subtraction of `self` from `other`. """ other = as_dimension(other) if self._value is None or other.value is None: return Dimension(None) else: return Dimension(other.value - self._value) def __mul__(self, other): """Returns the product of `self` and `other`. Dimensions are summed as follows: ```python tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(n) == tf.compat.v1.Dimension(m * n) tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(None) # equiv. to tf.compat.v1.Dimension(None) tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n) # equiv. to tf.compat.v1.Dimension(None) tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None) # equiv. to tf.compat.v1.Dimension(None) ``` Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is the product of `self` and `other`. """ try: other = as_dimension(other) except (TypeError, ValueError): return NotImplemented if self._value is None or other.value is None: return Dimension(None) else: return Dimension(self._value * other.value) def __rmul__(self, other): """Returns the product of `self` and `other`. Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is the product of `self` and `other`. """ return self * other def __floordiv__(self, other): """Returns the quotient of `self` and `other` rounded down. Dimensions are divided as follows: ```python tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(n) == tf.compat.v1.Dimension(m // n) tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(None) # equiv. to tf.compat.v1.Dimension(None) tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n) # equiv. to tf.compat.v1.Dimension(None) tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None) # equiv. to tf.compat.v1.Dimension(None) ``` Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A `Dimension` whose value is the integer quotient of `self` and `other`. 
""" try: other = as_dimension(other) except (TypeError, ValueError): return NotImplemented if self._value is None or other.value is None: return Dimension(None) else: return Dimension(self._value // other.value) def __rfloordiv__(self, other): """Returns the quotient of `other` and `self` rounded down. Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A `Dimension` whose value is the integer quotient of `self` and `other`. """ other = as_dimension(other) if self._value is None or other.value is None: return Dimension(None) else: return Dimension(other.value // self._value) def __div__(self, other): """DEPRECATED: Use `__floordiv__` via `x // y` instead. This function exists only for backwards compatibility purposes; new code should use `__floordiv__` via the syntax `x // y`. Using `x // y` communicates clearly that the result rounds down, and is forward compatible to Python 3. Args: other: Another `Dimension`. Returns: A `Dimension` whose value is the integer quotient of `self` and `other`. """ return self // other def __rdiv__(self, other): """Use `__floordiv__` via `x // y` instead. This function exists only to have a better error message. Instead of: `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`, this function will explicitly call for usage of `//` instead. Args: other: Another `Dimension`. Raises: TypeError. """ raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', " "please use // instead".format(type(other).__name__)) def __truediv__(self, other): """Use `__floordiv__` via `x // y` instead. This function exists only to have a better error message. Instead of: `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`, this function will explicitly call for usage of `//` instead. Args: other: Another `Dimension`. Raises: TypeError. """ raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', " "please use // instead".format(type(other).__name__)) def __rtruediv__(self, other): """Use `__floordiv__` via `x // y` instead. This function exists only to have a better error message. Instead of: `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`, this function will explicitly call for usage of `//` instead. Args: other: Another `Dimension`. Raises: TypeError. """ raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', " "please use // instead".format(type(other).__name__)) def __mod__(self, other): """Returns `self` modulo `other`. Dimension modulo are computed as follows: ```python tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(n) == tf.compat.v1.Dimension(m % n) tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(None) # equiv. to tf.compat.v1.Dimension(None) tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n) # equiv. to tf.compat.v1.Dimension(None) tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None) # equiv. to tf.compat.v1.Dimension(None) ``` Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is `self` modulo `other`. """ other = as_dimension(other) if self._value is None or other.value is None: return Dimension(None) else: return Dimension(self._value % other.value) def __rmod__(self, other): """Returns `other` modulo `self`. Args: other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is `other` modulo `self`. 
""" other = as_dimension(other) return other % self def __lt__(self, other): """Returns True if `self` is known to be less than `other`. Dimensions are compared as follows: ```python (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(n)) == (m < n) (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(None)) == None (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n)) == None (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None ``` Args: other: Another Dimension. Returns: The value of `self.value < other.value` if both are known, otherwise None. """ other = as_dimension(other) if self._value is None or other.value is None: return None else: return self._value < other.value def __le__(self, other): """Returns True if `self` is known to be less than or equal to `other`. Dimensions are compared as follows: ```python (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(n)) == (m <= n) (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(None)) == None (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n)) == None (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None ``` Args: other: Another Dimension. Returns: The value of `self.value <= other.value` if both are known, otherwise None. """ other = as_dimension(other) if self._value is None or other.value is None: return None else: return self._value <= other.value def __gt__(self, other): """Returns True if `self` is known to be greater than `other`. Dimensions are compared as follows: ```python (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(n)) == (m > n) (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(None)) == None (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n)) == None (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None ``` Args: other: Another Dimension. Returns: The value of `self.value > other.value` if both are known, otherwise None. """ other = as_dimension(other) if self._value is None or other.value is None: return None else: return self._value > other.value def __ge__(self, other): """Returns True if `self` is known to be greater than or equal to `other`. Dimensions are compared as follows: ```python (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(n)) == (m >= n) (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(None)) == None (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n)) == None (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None ``` Args: other: Another Dimension. Returns: The value of `self.value >= other.value` if both are known, otherwise None. """ other = as_dimension(other) if self._value is None or other.value is None: return None else: return self._value >= other.value def __reduce__(self): return Dimension, (self._value,) def as_dimension(value): """Converts the given value to a Dimension. A Dimension input will be returned unmodified. An input of `None` will be converted to an unknown Dimension. An integer input will be converted to a Dimension with that value. Args: value: The value to be converted. Returns: A Dimension corresponding to the given value. """ if isinstance(value, Dimension): return value else: return Dimension(value) @tf_export("TensorShape") class TensorShape(trace.TraceType, trace_type.Serializable): """Represents the shape of a `Tensor`. >>> t = tf.constant([[1,2,3],[4,5,6]]) >>> t.shape TensorShape([2, 3]) `TensorShape` is the *static* shape representation of a Tensor. 
  During eager execution a Tensor always has a fully specified shape, but when
  tracing a `tf.function` it may be one of the following:

  * *Fully-known shape:* has a known number of dimensions and a known size for
    each dimension. e.g. `TensorShape([16, 256])`
  * *Partially-known shape:* has a known number of dimensions, and an unknown
    size for one or more dimensions. e.g. `TensorShape([None, 256])`
  * *Unknown shape:* has an unknown number of dimensions, and an unknown size
    in all dimensions. e.g. `TensorShape(None)`

  During function tracing `t.shape` will return a `TensorShape` object
  representing the shape of the Tensor as it is known during tracing.
  This static representation will be partially defined in cases where the
  exact shape depends on the values within the tensors. To get the *dynamic*
  representation, please use `tf.shape(t)`, which will return a Tensor
  representing the fully defined shape of `t`. This way, you can express logic
  that manipulates the shapes of tensors by building other tensors that depend
  on the dynamic shape of `t`.

  Note: `tf.RaggedTensor.shape` also returns a `tf.TensorShape`; the lengths of
  any ragged dimensions are unknown (`None`).

  For example, this function prints the `TensorShape` (`t.shape`) when you
  trace the function, and returns a tensor `tf.shape(t)` for a given input `t`:

  >>> @tf.function
  ... def get_dynamic_shape(t):
  ...   print("tracing...")
  ...   print(f"static shape is {t.shape}")
  ...   return tf.shape(t)

  Just calling the function traces it with a fully-specified static shape:

  >>> result = get_dynamic_shape(tf.constant([[1, 1, 1], [0, 0, 0]]))
  tracing...
  static shape is (2, 3)
  >>> result.numpy()
  array([2, 3], dtype=int32)

  But `tf.function` can also trace the function with a partially specified
  (or even unspecified) shape:

  >>> cf1 = get_dynamic_shape.get_concrete_function(tf.TensorSpec(
  ...   shape=[None, 2]))
  tracing...
  static shape is (None, 2)
  >>> cf1(tf.constant([[1., 0],[1, 0],[1, 0]])).numpy()
  array([3, 2], dtype=int32)

  >>> cf2 = get_dynamic_shape.get_concrete_function(tf.TensorSpec(shape=None))
  tracing...
  static shape is <unknown>
  >>> cf2(tf.constant([[[[[1., 0]]]]])).numpy()
  array([1, 1, 1, 1, 2], dtype=int32)

  If a tensor is produced by an operation of type `"Foo"`, its shape may be
  inferred if there is a registered shape function for `"Foo"`. See
  [Shape functions](https://www.tensorflow.org/guide/create_op#shape_functions_in_c)
  for details of shape functions and how to register them. Alternatively,
  you may set the shape explicitly using `tf.Tensor.ensure_shape`.
  """
  __slots__ = ["_dims"]

  def __init__(self, dims):
    """Creates a new TensorShape with the given dimensions.

    Args:
      dims: A list of Dimensions, or None if the shape is unspecified.

    Raises:
      TypeError: If dims cannot be converted to a list of dimensions.
    """
    if isinstance(dims, (tuple, list)):  # Most common case.
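      # Normalize each entry to a plain int (or None) via `as_dimension`.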
      self._dims = tuple(as_dimension(d).value for d in dims)
    elif dims is None:
      self._dims = None
    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
      if dims.unknown_rank:
        self._dims = None
      else:
        self._dims = tuple(
            # Protos store variable-size dimensions as -1
            dim.size if dim.size != -1 else None
            for dim in dims.dim
            )
    elif isinstance(dims, TensorShape):
      self._dims = dims._dims
    else:
      try:
        dims_iter = iter(dims)
      except TypeError:
        # Treat as a singleton dimension
        self._dims = (as_dimension(dims).value,)
      else:
        self._dims = []
        for d in dims_iter:
          try:
            self._dims.append(as_dimension(d).value)
          except TypeError as e:
            raise TypeError(
                "Failed to convert '{0!r}' to a shape: '{1!r}' "
                "could not be converted to a dimension. A shape should "
                "either be single dimension (e.g. 10), or an iterable of "
                "dimensions (e.g. [1, 10, None]).".format(dims, d)) from e
        self._dims = tuple(self._dims)

  @property
  def _v2_behavior(self):
    if _TENSORSHAPE_V2_OVERRIDE is None:
      return tf2.enabled()
    return _TENSORSHAPE_V2_OVERRIDE

  def __repr__(self):
    if self._v2_behavior:
      if self._dims is not None:
        return f"TensorShape({list(self._dims)})"
      else:
        return "TensorShape(None)"
    else:
      return f"TensorShape({self.dims})"

  def __str__(self):
    if self.rank is None:
      return "<unknown>"
    elif self.rank == 1:
      if self._v2_behavior:
        return "(%s,)" % self._dims[0]
      else:
        return "(%s,)" % self.dims[0]
    else:
      if self._v2_behavior:
        return "(%s)" % ", ".join(str(d) for d in self._dims)
      else:
        return "(%s)" % ", ".join(str(d) for d in self.dims)

  @property
  def rank(self):
    """Returns the rank of this shape, or None if it is unspecified."""
    if self._dims is not None:
      return len(self._dims)
    return None

  @property
  def dims(self):
    """Deprecated. Returns list of dimensions for this shape.

    Suggest `TensorShape.as_list` instead.

    Returns:
      A list containing `tf.compat.v1.Dimension`s, or None if the shape is
      unspecified.
    """
    if self._dims is None:
      return None
    return [as_dimension(d) for d in self._dims]

  @property
  def ndims(self):
    """Deprecated accessor for `rank`."""
    return self.rank

  def __len__(self):
    """Returns the rank of this shape, or raises ValueError if unspecified."""
    if self._dims is None:
      raise ValueError("Cannot take the length of shape with unknown rank.")
    return len(self._dims)

  def __bool__(self):
    """Returns True if this shape contains non-zero information."""
    return self._dims is not None

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
  __nonzero__ = __bool__

  def __iter__(self):
    """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
    if self._dims is None:
      raise ValueError("Cannot iterate over a shape with unknown rank.")
    else:
      if self._v2_behavior:
        return iter(d for d in self._dims)
      else:
        return iter(d for d in self.dims)

  def __getitem__(self, key):
    """Returns the value of a dimension or a shape, depending on the key.

    Args:
      key: If `key` is an integer, returns the dimension at that index;
        otherwise if `key` is a slice, returns a TensorShape whose dimensions
        are those selected by the slice from `self`.

    Returns:
      An integer if `key` is an integer, or a `TensorShape` if `key` is a
      slice.

    Raises:
      ValueError: If `key` is a slice and `self` is completely unknown and
        the step is set.
    """
    if self._dims is not None:
      if isinstance(key, slice):
        return TensorShape(self._dims[key])
      else:
        if self._v2_behavior:
          return self._dims[key]
        else:
          return self.dims[key]
    else:
      if isinstance(key, slice):
        start = key.start if key.start is not None else 0
        stop = key.stop

        if key.step is not None:
          # TODO(mrry): Handle these maybe.
          raise ValueError("Steps are not yet handled")
        if stop is None:
          # NOTE(mrry): This implies that TensorShape(None) is compatible with
          # TensorShape(None)[1:], which is obviously not true. It would be
          # possible to track the number of dimensions symbolically,
          # and perhaps we should do that.
          return unknown_shape()
        elif start < 0 or stop < 0:
          # TODO(mrry): Handle this better, as it will be useful for handling
          # suffixes of otherwise unknown shapes.
          return unknown_shape()
        else:
          return unknown_shape(rank=stop - start)
      else:
        if self._v2_behavior:
          return None
        else:
          return Dimension(None)

  def num_elements(self):
    """Returns the total number of elements, or None for incomplete shapes."""
    if self.is_fully_defined():
      return functools.reduce(operator.mul, self.as_list(), 1)
    else:
      return None

  def merge_with(self, other):
    """Returns a `TensorShape` combining the information in `self` and `other`.

    The dimensions in `self` and `other` are merged element-wise, according to
    the rules below:

    ```python
    Dimension(n).merge_with(Dimension(None)) == Dimension(n)
    Dimension(None).merge_with(Dimension(n)) == Dimension(n)
    Dimension(None).merge_with(Dimension(None)) == Dimension(None)
    # raises ValueError for n != m
    Dimension(n).merge_with(Dimension(m))
    ```

    >>> ts = tf.TensorShape([1,2])
    >>> ot1 = tf.TensorShape([1,2])
    >>> ts.merge_with(ot1).as_list()
    [1, 2]

    >>> ot2 = tf.TensorShape([1,None])
    >>> ts.merge_with(ot2).as_list()
    [1, 2]

    >>> ot3 = tf.TensorShape([None, None])
    >>> ot3.merge_with(ot2).as_list()
    [1, None]

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible.
    """
    other = as_shape(other)
    if self.dims is None:
      return other
    if other.dims is None:
      return self
    else:
      try:
        self.assert_same_rank(other)
        new_dims = [
            dim.merge_with(other_dim)
            for dim, other_dim in zip(self.dims, other.dims)
        ]
        return TensorShape(new_dims)
      except ValueError:
        raise ValueError("Shapes %s and %s are not compatible" % (self, other))

  def __add__(self, other):
    return self.concatenate(other)

  def __radd__(self, other):
    if not isinstance(other, TensorShape):
      other = TensorShape(other)
    return other.concatenate(self)

  def concatenate(self, other):
    """Returns the concatenation of the dimensions in `self` and `other`.

    *N.B.* If either `self` or `other` is completely unknown, concatenation
    will discard information about the other shape. In future, we might support
    concatenation that preserves this information for use with slicing.

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` whose dimensions are the concatenation of the dimensions
      in `self` and `other`.
    """
    # TODO(mrry): Handle the case where we concatenate a known shape with a
    # completely unknown shape, so that we can use the partial information.
    other = as_shape(other)
    if self.dims is None or other.dims is None:
      return unknown_shape()
    else:
      return TensorShape(self.dims + other.dims)

  def assert_same_rank(self, other):
    """Raises an exception if `self` and `other` do not have compatible ranks.

    Args:
      other: Another `TensorShape`.

    Raises:
      ValueError: If `self` and `other` do not represent shapes with the
        same rank.
    """
    other = as_shape(other)
    if self.rank is not None and other.rank is not None:
      if self.rank != other.rank:
        raise ValueError("Shapes %s and %s must have the same rank" %
                         (self, other))

  def assert_has_rank(self, rank):
    """Raises an exception if `self` is not compatible with the given `rank`.

    Args:
      rank: An integer.

    Raises:
      ValueError: If `self` does not represent a shape with the given `rank`.
    """
    if self.rank not in (None, rank):
      raise ValueError("Shape %s must have rank %d" % (self, rank))

  def with_rank(self, rank):
    """Returns a shape based on `self` with the given rank.

    This method promotes a completely unknown shape to one with a known rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with the given rank.

    Raises:
      ValueError: If `self` does not represent a shape with the given `rank`.
    """
    try:
      return self.merge_with(unknown_shape(rank=rank))
    except ValueError:
      raise ValueError("Shape %s must have rank %d" % (self, rank))

  def with_rank_at_least(self, rank):
    """Returns a shape based on `self` with at least the given rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with at least the given
      rank.

    Raises:
      ValueError: If `self` does not represent a shape with at least the given
        `rank`.
    """
    if self.rank is not None and self.rank < rank:
      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
    else:
      return self

  def with_rank_at_most(self, rank):
    """Returns a shape based on `self` with at most the given rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with at most the given
      rank.

    Raises:
      ValueError: If `self` does not represent a shape with at most the given
        `rank`.
    """
    if self.rank is not None and self.rank > rank:
      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
    else:
      return self

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True iff `self` is a subtype of `other`.

    Shape A is a subtype of shape B if shape B can successfully represent it:

    * A `TensorShape` of any rank is a subtype of `TensorShape(None)`.

    * TensorShapes of equal ranks are covariant, i.e.
      `TensorShape([A1, A2, ..])` is a subtype of `TensorShape([B1, B2, ..])`
      iff An is a subtype of Bn. An is a subtype of Bn iff An == Bn or Bn is
      None.

    * TensorShapes of different defined ranks have no subtyping relation.

    The subtyping relation is reflexive and transitive, but not symmetric.

    Some examples:

    * `TensorShape([32, 784])` is a subtype of `TensorShape(None)`, and
      `TensorShape([4, 4])` is also a subtype of `TensorShape(None)` but
      `TensorShape([32, 784])` and `TensorShape([4, 4])` are not subtypes of
      each other.

    * All two-dimensional shapes are subtypes of `TensorShape([None, None])`,
      such as `TensorShape([32, 784])`. There is no subtype relationship with,
      for example, `TensorShape([None])` or `TensorShape([None, None, None])`.

    * `TensorShape([32, None])` is also a subtype of
      `TensorShape([None, None])` and `TensorShape(None)`. It is not a subtype
      of, for example, `TensorShape([32])`, `TensorShape([32, None, 1])`,
      `TensorShape([64, None])` or `TensorShape([None, 32])`.

    * `TensorShape([32, 784])` is a subtype of itself, and also
      `TensorShape([32, None])`, `TensorShape([None, 784])`,
      `TensorShape([None, None])` and `TensorShape(None)`. It has no subtype
      relation with, for example, `TensorShape([32, 1, 784])` or
      `TensorShape([None])`.

    Args:
      other: Another `TensorShape`.

    Returns:
      True iff `self` is a subtype of `other`.
    """
    if not isinstance(other, TensorShape):
      return False

    # All Tensors are subtypes of a Tensor with no shape.
    if other.rank is None:
      return True

    # A Tensor with a defined shape can only be a subtype of another with a
    # defined shape if they have the same number of dimensions.
    if self.rank != other.rank:
      return False

    # A Tensor is a subtype if each corresponding dimension is a subtype.
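    # (A known size in `other` matches only that size; `None` matches any.)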
    return all(o is None or s == o for s, o in zip(self._dims, other._dims))  # pylint: disable=protected-access

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["TensorShape"]:
    """Returns the most specific supertype `TensorShape` of self and others.

    * `TensorShape([None, 1])` is the most specific `TensorShape` supertyping
      both `TensorShape([2, 1])` and `TensorShape([5, 1])`. Note that
      `TensorShape(None)` is also a supertype but it is not "most specific".

    * `TensorShape([1, 2, 3])` is the most specific `TensorShape` supertyping
      both `TensorShape([1, 2, 3])` and `TensorShape([1, 2, 3])`. There are
      other, less specific, TensorShapes that supertype the above-mentioned
      TensorShapes, e.g. `TensorShape([1, 2, None])`, `TensorShape(None)`.

    * `TensorShape([None, None])` is the most specific `TensorShape`
      supertyping both `TensorShape([2, None])` and `TensorShape([None, 3])`.
      As always, `TensorShape(None)` is also a supertype but not the most
      specific one.

    * `TensorShape(None)` is the only `TensorShape` supertyping both
      `TensorShape([1, 2, 3])` and `TensorShape([1, 2])`. In general, any two
      shapes that have different ranks will only have `TensorShape(None)`
      as a common supertype.

    * `TensorShape(None)` is the only `TensorShape` supertyping both
      `TensorShape([1, 2, 3])` and `TensorShape(None)`. In general, the common
      supertype of any shape with `TensorShape(None)` is `TensorShape(None)`.

    Args:
      others: Sequence of `TensorShape`.

    Returns:
      A `TensorShape` which is the most specific supertype shape of `self`
      and `others`. None if it does not exist.
    """
    if any(not isinstance(other, TensorShape) for other in others):
      return None

    # A rankless TensorShape is already a global supertype so we return another
    # instance of it.
    if self.rank is None:
      return unknown_shape()

    # A rankless TensorShape is the most specific supertype for shapes whose
    # ranks do not match.
    if any(other.dims is None or self.rank != other.rank for other in others):
      return unknown_shape()

    # Retain the integer dimension if it is the same across all others, else
    # use an undefined dimension.
    dims = [
        dim if all(dim == other._dims[i] for other in others) else None
        for i, dim in enumerate(self._dims)
    ]
    return TensorShape(dims)

  @doc_controls.do_not_doc_inheritable
  def placeholder_value(self, placeholder_context=None):
    raise NotImplementedError("A graph placeholder is not currently supported "
                              "for an object of type: TensorShape.")

  @classmethod
  def experimental_type_proto(cls) -> Type[tensor_shape_pb2.TensorShapeProto]:
    """Returns the type of proto associated with TensorShape serialization."""
    return tensor_shape_pb2.TensorShapeProto

  @classmethod
  def experimental_from_proto(
      cls, proto: tensor_shape_pb2.TensorShapeProto) -> "TensorShape":
    """Returns a TensorShape instance based on the serialized proto."""
    return TensorShape(proto)

  def experimental_as_proto(self) -> tensor_shape_pb2.TensorShapeProto:
    """Returns a proto representation of the TensorShape instance."""
    return self.as_proto()

  # TODO(b/216206374): Consider deprecation at TraceType release.
  def is_compatible_with(self, other):
    """Returns True iff `self` is compatible with `other`.

    Two possibly-partially-defined shapes are compatible if there
    exists a fully-defined shape that both shapes can represent. Thus,
    compatibility allows the shape inference code to reason about
    partially-defined shapes. For example:

    * TensorShape(None) is compatible with all shapes.
    * TensorShape([None, None]) is compatible with all two-dimensional
      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It
      is not compatible with, for example, TensorShape([None]) or
      TensorShape([None, None, None]).

    * TensorShape([32, None]) is compatible with all two-dimensional shapes
      with size 32 in the 0th dimension, and also TensorShape([None, None])
      and TensorShape(None). It is not compatible with, for example,
      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).

    * TensorShape([32, 784]) is compatible with itself, and also
      TensorShape([32, None]), TensorShape([None, 784]),
      TensorShape([None, None]) and TensorShape(None). It is not compatible
      with, for example, TensorShape([32, 1, 784]) or TensorShape([None]).

    The compatibility relation is reflexive and symmetric, but not
    transitive. For example, TensorShape([32, 784]) is compatible with
    TensorShape(None), and TensorShape(None) is compatible with
    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
    TensorShape([4, 4]).

    Args:
      other: Another TensorShape.

    Returns:
      True iff `self` is compatible with `other`.
    """
    other = as_shape(other)
    if self.dims is not None and other.dims is not None:
      if self.rank != other.rank:
        return False
      for x_dim, y_dim in zip(self.dims, other.dims):
        if not x_dim.is_compatible_with(y_dim):
          return False
    return True

  def assert_is_compatible_with(self, other):
    """Raises an exception if `self` and `other` do not represent the same shape.

    This method can be used to assert that there exists a shape that both
    `self` and `other` represent.

    Args:
      other: Another TensorShape.

    Raises:
      ValueError: If `self` and `other` do not represent the same shape.
    """
    if not self.is_compatible_with(other):
      raise ValueError("Shapes %s and %s are incompatible" % (self, other))

  def most_specific_compatible_shape(self, other):
    """Returns the most specific TensorShape compatible with `self` and `other`.

    * TensorShape([None, 1]) is the most specific TensorShape compatible with
      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
      TensorShape(None) is also compatible with the above-mentioned
      TensorShapes.

    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are other,
      less specific, TensorShapes compatible with the above-mentioned
      TensorShapes, e.g. TensorShape([1, 2, None]), TensorShape(None).

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` which is the most specific compatible shape of `self`
      and `other`.
    """
    other = as_shape(other)
    if self.dims is None or other.dims is None or self.rank != other.rank:
      return unknown_shape()

    dims = [
        d1 if d1 is not None and d2 is not None and d1 == d2 else None
        for d1, d2 in zip(self.dims, other.dims)
    ]
    return TensorShape(dims)

  def is_fully_defined(self):
    """Returns True iff `self` is fully defined in every dimension."""
    return (self._dims is not None and
            all(dim is not None for dim in self._dims))

  def assert_is_fully_defined(self):
    """Raises an exception if `self` is not fully defined in every dimension.

    Raises:
      ValueError: If `self` does not have a known value for every dimension.
    """
    if not self.is_fully_defined():
      raise ValueError("Shape %s is not fully defined" % self)

  def as_list(self):
    """Returns a list of integers or `None` for each dimension.

    Returns:
      A list of integers or `None` for each dimension.

    Raises:
      ValueError: If `self` is an unknown shape with an unknown rank.
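
    For example:

    >>> tf.TensorShape([16, None, 256]).as_list()
    [16, None, 256]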
""" if self._dims is None: raise ValueError("as_list() is not defined on an unknown TensorShape.") return list(self._dims) def as_proto(self): """Returns this shape as a `TensorShapeProto`.""" if self._dims is None: return tensor_shape_pb2.TensorShapeProto(unknown_rank=True) else: return tensor_shape_pb2.TensorShapeProto(dim=[ tensor_shape_pb2.TensorShapeProto.Dim( size=-1 if d is None else d) for d in self._dims ]) def __eq__(self, other): """Returns True if `self` is equivalent to `other`. It first tries to convert `other` to `TensorShape`. `TypeError` is thrown when the conversion fails. Otherwise, it compares each element in the TensorShape dimensions. * Two *Fully known* shapes, return True iff each element is equal. >>> t_a = tf.TensorShape([1,2]) >>> a = [1, 2] >>> t_b = tf.TensorShape([1,2]) >>> t_c = tf.TensorShape([1,2,3]) >>> t_a.__eq__(a) True >>> t_a.__eq__(t_b) True >>> t_a.__eq__(t_c) False * Two *Partially-known* shapes, return True iff each element is equal. >>> p_a = tf.TensorShape([1,None]) >>> p_b = tf.TensorShape([1,None]) >>> p_c = tf.TensorShape([2,None]) >>> p_a.__eq__(p_b) True >>> t_a.__eq__(p_a) False >>> p_a.__eq__(p_c) False * Two *Unknown shape*, return True. >>> unk_a = tf.TensorShape(None) >>> unk_b = tf.TensorShape(None) >>> unk_a.__eq__(unk_b) True >>> unk_a.__eq__(t_a) False Args: other: A `TensorShape` or type that can be converted to `TensorShape`. Returns: True if the dimensions are all equal. Raises: TypeError if `other` can not be converted to `TensorShape`. """ try: other = as_shape(other) except TypeError: return NotImplemented return self._dims == other._dims def __hash__(self): return hash(self._dims) def __reduce__(self): return TensorShape, (self.dims,) def __concat__(self, other): return self.concatenate(other) trace_type.register_serializable(TensorShape) def as_shape(shape): """Converts the given object to a TensorShape.""" if isinstance(shape, TensorShape): return shape else: return TensorShape(shape) def unknown_shape(rank=None, **kwargs): """Returns an unknown TensorShape, optionally with a known rank. Args: rank: (Optional) If specified, the number of dimensions in the shape. **kwargs: For backwards compatibility. Returns: An unknown TensorShape. Raises: TypeError: In case of invalid arguments. """ if rank is None and "ndims" in kwargs: rank = kwargs.pop("ndims") if kwargs: raise TypeError("Unknown argument: %s" % kwargs) if rank is None: return TensorShape(None) else: return TensorShape([Dimension(None)] * rank)