109 lines
3.9 KiB
Python
109 lines
3.9 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""A library of common shape functions."""
|
|
import itertools
|
|
|
|
from tensorflow.python.framework import tensor_shape
|
|
|
|
|
|
def _broadcast_shape_helper(shape_x, shape_y):
  """Merges two shapes axis-by-axis under numpy broadcasting rules.

  This is the shared implementation behind `is_broadcast_compatible` and
  `broadcast_shape`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    None if the shapes are not broadcast compatible. Otherwise, a list with one
    entry per broadcast dimension. Each entry is a dimension taken from (or
    merged from) the inputs, or None when its size cannot be determined
    statically.
  """
  # Line the two shapes up on their trailing axes. Walking both in reverse and
  # padding the shorter one with size-1 dimensions gives equal-length
  # sequences; the list is then flipped back to leading-axis-first order.
  pairs = list(
      itertools.zip_longest(
          reversed(shape_x.dims),
          reversed(shape_y.dims),
          fillvalue=tensor_shape.Dimension(1)))
  pairs.reverse()

  # Combine each aligned pair following the numpy broadcasting rules:
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  result = []
  for left, right in pairs:
    left_size = left.value
    right_size = right.value
    if left_size is None or right_size is None:
      # At least one side is statically unknown. Take the first known size
      # greater than 1 (checking `left` before `right`), trusting that the
      # program is correct and the unknown side will be broadcast to match it.
      # Otherwise the output size is unknown (None).
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      result.append(
          next((dim for dim in (left, right)
                if dim.value is not None and dim.value > 1), None))
      continue
    if left_size == 1:
      result.append(right)  # `left` is stretched to match `right`.
    elif right_size == 1:
      result.append(left)  # `right` is stretched to match `left`.
    elif left_size == right_size:
      # Equal known sizes: the output keeps that size.
      result.append(left.merge_with(right))
    else:
      # Two different known sizes, neither of which is 1.
      return None
  return result
|
|
|
|
|
|
def is_broadcast_compatible(shape_x, shape_y):
  """Returns whether `shape_x` and `shape_y` are broadcast compatible.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if a shape exists that both `shape_x` and `shape_y` can be broadcasted
    to. False otherwise. In particular, the result is False whenever either
    shape has an unknown rank (`ndims is None`), because no per-dimension
    check is attempted in that case.
  """
  both_ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  return (both_ranks_known and
          _broadcast_shape_helper(shape_x, shape_y) is not None)
|
|
|
|
|
|
def broadcast_shape(shape_x, shape_y):
  """Computes the shape produced by broadcasting `shape_x` against `shape_y`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape. If either input has
    `ndims is None`, `tensor_shape.unknown_shape()` is returned instead.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  if shape_x.ndims is not None and shape_y.ndims is not None:
    merged_dims = _broadcast_shape_helper(shape_x, shape_y)
    if merged_dims is not None:
      return tensor_shape.TensorShape(merged_dims)
    # The helper signals incompatibility with None.
    raise ValueError('Incompatible shapes for broadcasting. Two shapes are '
                     'compatible if for each dimension pair they are either '
                     'equal or one of them is 1. '
                     f'Received: {shape_x} and {shape_y}.')
  return tensor_shape.unknown_shape()
|