3RNN/Lib/site-packages/tensorflow/python/framework/common_shapes.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library of common shape functions."""
import itertools

from tensorflow.python.framework import tensor_shape

def _broadcast_shape_helper(shape_x, shape_y):
"""Helper functions for is_broadcast_compatible and broadcast_shape.
Args:
shape_x: A `TensorShape`
shape_y: A `TensorShape`
Returns:
Returns None if the shapes are not broadcast compatible,
a list of the broadcast dimensions otherwise.
"""
  # To compute the broadcasted dimensions, we zip together shape_x and shape_y,
  # and pad with 1 to make them the same length.
  broadcasted_dims = reversed(
      list(
          itertools.zip_longest(
              reversed(shape_x.dims),
              reversed(shape_y.dims),
              fillvalue=tensor_shape.Dimension(1))))
  # Next we combine the dimensions according to the numpy broadcasting rules.
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  return_dims = []
  for (dim_x, dim_y) in broadcasted_dims:
    if dim_x.value is None or dim_y.value is None:
      # One or both dimensions is unknown. If either dimension is greater than
      # 1, we assume that the program is correct, and the other dimension will
      # be broadcast to match it.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if dim_x.value is not None and dim_x.value > 1:
        return_dims.append(dim_x)
      elif dim_y.value is not None and dim_y.value > 1:
        return_dims.append(dim_y)
      else:
        return_dims.append(None)
    elif dim_x.value == 1:
      # We will broadcast dim_x to dim_y.
      return_dims.append(dim_y)
    elif dim_y.value == 1:
      # We will broadcast dim_y to dim_x.
      return_dims.append(dim_x)
    elif dim_x.value == dim_y.value:
      # The dimensions are compatible, so output is the same size in that
      # dimension.
      return_dims.append(dim_x.merge_with(dim_y))
    else:
      return None
  return return_dims
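
# A quick sketch of the helper's behavior (illustrative only, not part of the
# original module). Broadcasting [5, 1, 3] against [2, 3] pads the shorter
# shape with a leading 1 and combines aligned dimension pairs per the numpy
# rules referenced above:
#
#   _broadcast_shape_helper(tensor_shape.TensorShape([5, 1, 3]),
#                           tensor_shape.TensorShape([2, 3]))
#   # -> [Dimension(5), Dimension(2), Dimension(3)]
#
# An unknown dimension paired with a 1 stays unknown (None) in the result.
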
def is_broadcast_compatible(shape_x, shape_y):
"""Returns True if `shape_x` and `shape_y` are broadcast compatible.
Args:
shape_x: A `TensorShape`
shape_y: A `TensorShape`
Returns:
True if a shape exists that both `shape_x` and `shape_y` can be broadcasted
to. False otherwise.
"""
if shape_x.ndims is None or shape_y.ndims is None:
return False
return _broadcast_shape_helper(shape_x, shape_y) is not None
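
# For illustration (not part of the original module): two shapes are broadcast
# compatible when every aligned dimension pair is equal, contains a 1, or
# involves an unknown dimension. For example:
#
#   is_broadcast_compatible(tensor_shape.TensorShape([None, 3]),
#                           tensor_shape.TensorShape([3]))   # -> True
#   is_broadcast_compatible(tensor_shape.TensorShape([2]),
#                           tensor_shape.TensorShape([3]))   # -> False
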
def broadcast_shape(shape_x, shape_y):
"""Returns the broadcasted shape between `shape_x` and `shape_y`.
Args:
shape_x: A `TensorShape`
shape_y: A `TensorShape`
Returns:
A `TensorShape` representing the broadcasted shape.
Raises:
ValueError: If the two shapes can not be broadcasted.
"""
  if shape_x.ndims is None or shape_y.ndims is None:
    return tensor_shape.unknown_shape()
  return_dims = _broadcast_shape_helper(shape_x, shape_y)
  if return_dims is None:
    raise ValueError('Incompatible shapes for broadcasting. Two shapes are '
                     'compatible if for each dimension pair they are either '
                     'equal or one of them is 1. '
                     f'Received: {shape_x} and {shape_y}.')
  return tensor_shape.TensorShape(return_dims)
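
# Example usage (a minimal sketch for illustration; the `_x`/`_y` demo names
# below are not part of the original module):
if __name__ == '__main__':
  _x = tensor_shape.TensorShape([None, 1, 3])
  _y = tensor_shape.TensorShape([5, 3])
  print(is_broadcast_compatible(_x, _y))  # Expected: True
  print(broadcast_shape(_x, _y))  # Expected: a shape equal to (None, 5, 3)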