# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by convolution layers."""

import itertools

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend


def convert_data_format(data_format, ndim):
    if data_format == "channels_last":
        if ndim == 3:
            return "NWC"
        elif ndim == 4:
            return "NHWC"
        elif ndim == 5:
            return "NDHWC"
        else:
            raise ValueError(
                f"Input rank not supported: {ndim}. "
                "Expected values are [3, 4, 5]"
            )
    elif data_format == "channels_first":
        if ndim == 3:
            return "NCW"
        elif ndim == 4:
            return "NCHW"
        elif ndim == 5:
            return "NCDHW"
        else:
            raise ValueError(
                f"Input rank not supported: {ndim}. "
                "Expected values are [3, 4, 5]"
            )
    else:
        raise ValueError(
            f"Invalid data_format: {data_format}. "
            'Expected values are ["channels_first", "channels_last"]'
        )

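# Illustrative usage (added example, not part of the original module):
# `convert_data_format` maps a Keras data-format string and an input rank to
# the corresponding TensorFlow layout string.
#
#   >>> convert_data_format("channels_last", ndim=4)
#   'NHWC'
#   >>> convert_data_format("channels_first", ndim=3)
#   'NCW'

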
def normalize_tuple(value, n, name, allow_zero=False):
    """Transforms non-negative/positive integer/integers into an integer tuple.

    Args:
        value: The value to validate and convert. Could be an int, or any
            iterable of ints.
        n: The size of the tuple to be returned.
        name: The name of the argument being validated, e.g. "strides" or
            "kernel_size". This is only used to format error messages.
        allow_zero: Defaults to False. A ValueError will be raised if zero is
            received and this param is False.

    Returns:
        A tuple of n integers.

    Raises:
        ValueError: If something other than an int/long or an iterable thereof
            is passed, or if a negative value is received.
    """
    error_msg = (
        f"The `{name}` argument must be a tuple of {n} "
        f"integers. Received: {value}"
    )

    if isinstance(value, int):
        value_tuple = (value,) * n
    else:
        try:
            value_tuple = tuple(value)
        except TypeError:
            raise ValueError(error_msg)
        if len(value_tuple) != n:
            raise ValueError(error_msg)
        for single_value in value_tuple:
            try:
                int(single_value)
            except (ValueError, TypeError):
                error_msg += (
                    f" including element {single_value} of "
                    f"type {type(single_value)}"
                )
                raise ValueError(error_msg)

    if allow_zero:
        unqualified_values = {v for v in value_tuple if v < 0}
        req_msg = ">= 0"
    else:
        unqualified_values = {v for v in value_tuple if v <= 0}
        req_msg = "> 0"

    if unqualified_values:
        error_msg += (
            f" including {unqualified_values}"
            f" that does not satisfy the requirement `{req_msg}`."
        )
        raise ValueError(error_msg)

    return value_tuple

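# Illustrative usage (added example, not part of the original module): a
# scalar is broadcast to an n-tuple, while an iterable is validated and
# returned as a tuple.
#
#   >>> normalize_tuple(3, n=2, name="strides")
#   (3, 3)
#   >>> normalize_tuple((1, 2), n=2, name="kernel_size")
#   (1, 2)

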
def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
    """Determines output length of a convolution given input length.

    Args:
        input_length: integer.
        filter_size: integer.
        padding: one of "same", "valid", "full", "causal".
        stride: integer.
        dilation: dilation rate, integer.

    Returns:
        The output length (integer).
    """
    if input_length is None:
        return None
    assert padding in {"same", "valid", "full", "causal"}
    dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
    if padding in ["same", "causal"]:
        output_length = input_length
    elif padding == "valid":
        output_length = input_length - dilated_filter_size + 1
    elif padding == "full":
        output_length = input_length + dilated_filter_size - 1
    return (output_length + stride - 1) // stride

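# Illustrative usage (added example, not part of the original module): a
# "valid" convolution with a dilated kernel shrinks the length, while "same"
# padding only divides it by the stride (rounding up).
#
#   >>> conv_output_length(10, filter_size=3, padding="valid", stride=1,
#   ...                    dilation=2)
#   6
#   >>> conv_output_length(10, filter_size=3, padding="same", stride=2)
#   5

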
def conv_input_length(output_length, filter_size, padding, stride):
    """Determines input length of a convolution given output length.

    Args:
        output_length: integer.
        filter_size: integer.
        padding: one of "same", "valid", "full".
        stride: integer.

    Returns:
        The input length (integer).
    """
    if output_length is None:
        return None
    assert padding in {"same", "valid", "full"}
    if padding == "same":
        pad = filter_size // 2
    elif padding == "valid":
        pad = 0
    elif padding == "full":
        pad = filter_size - 1
    return (output_length - 1) * stride - 2 * pad + filter_size

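# Illustrative usage (added example, not part of the original module):
# `conv_input_length` inverts `conv_output_length` for a stride-1 "valid"
# convolution.
#
#   >>> conv_input_length(output_length=6, filter_size=3, padding="valid",
#   ...                   stride=1)
#   8

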
def deconv_output_length(
    input_length,
    filter_size,
    padding,
    output_padding=None,
    stride=0,
    dilation=1,
):
    """Determines output length of a transposed convolution given input length.

    Args:
        input_length: Integer.
        filter_size: Integer.
        padding: one of `"same"`, `"valid"`, `"full"`.
        output_padding: Integer, amount of padding along the output dimension.
            Can be set to `None` in which case the output length is inferred.
        stride: Integer.
        dilation: Integer.

    Returns:
        The output length (integer).
    """
    assert padding in {"same", "valid", "full"}
    if input_length is None:
        return None

    # Get the dilated kernel size
    filter_size = filter_size + (filter_size - 1) * (dilation - 1)

    # Infer length if output padding is None, else compute the exact length
    if output_padding is None:
        if padding == "valid":
            length = input_length * stride + max(filter_size - stride, 0)
        elif padding == "full":
            length = input_length * stride - (stride + filter_size - 2)
        elif padding == "same":
            length = input_length * stride

    else:
        if padding == "same":
            pad = filter_size // 2
        elif padding == "valid":
            pad = 0
        elif padding == "full":
            pad = filter_size - 1

        length = (
            (input_length - 1) * stride + filter_size - 2 * pad + output_padding
        )
    return length

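# Illustrative usage (added example, not part of the original module): for a
# transposed convolution, "valid" padding grows the length by the kernel
# overhang, while "same" padding simply multiplies it by the stride.
#
#   >>> deconv_output_length(4, filter_size=3, padding="valid", stride=2)
#   9
#   >>> deconv_output_length(4, filter_size=3, padding="same", stride=2)
#   8

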
def normalize_data_format(value):
    if value is None:
        value = backend.image_data_format()
    data_format = value.lower()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(
            "The `data_format` argument must be one of "
            f'"channels_first", "channels_last". Received: {value}'
        )
    return data_format

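# Illustrative usage (added example, not part of the original module): the
# value is lowercased and validated; `None` falls back to the global Keras
# image data format.
#
#   >>> normalize_data_format("Channels_Last")
#   'channels_last'

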
def normalize_padding(value):
    if isinstance(value, (list, tuple)):
        return value
    padding = value.lower()
    if padding not in {"valid", "same", "causal"}:
        raise ValueError(
            "The `padding` argument must be a list/tuple or one of "
            '"valid", "same" (or "causal", only for `Conv1D`). '
            f"Received: {padding}"
        )
    return padding

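# Illustrative usage (added example, not part of the original module): string
# paddings are lowercased and validated, while explicit list/tuple paddings
# are passed through unchanged.
#
#   >>> normalize_padding("SAME")
#   'same'
#   >>> normalize_padding(((1, 1), (2, 2)))
#   ((1, 1), (2, 2))

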
def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
    """Compute a mask representing the connectivity of a convolution operation.

    Assume a convolution with given parameters is applied to an input having N
    spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
    output with shape `(d_out1, ..., d_outN)`. This method returns a boolean
    array of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True`
    entries indicating pairs of input and output locations that are connected
    by a weight.

    Example:

    >>> input_shape = (4,)
    >>> kernel_shape = (2,)
    >>> strides = (1,)
    >>> padding = "valid"
    >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
    array([[ True, False, False],
           [ True,  True, False],
           [False,  True,  True],
           [False, False,  True]])

    where rows and columns correspond to inputs and outputs respectively.

    Args:
        input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of
            the input.
        kernel_shape: tuple of size N, spatial shape of the convolutional
            kernel / receptive field.
        strides: tuple of size N, strides along each spatial dimension.
        padding: type of padding, string `"same"` or `"valid"`.
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.

    Returns:
        A boolean 2N-D `np.ndarray` of shape
        `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ...,
        d_outN)` is the spatial shape of the output. `True` entries in the
        mask represent pairs of input-output locations that are connected by
        a weight.

    Raises:
        ValueError: if `input_shape`, `kernel_shape` and `strides` don't have
            the same number of dimensions.
        NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
    """
    if padding not in {"same", "valid"}:
        raise NotImplementedError(
            f"Padding type {padding} not supported. "
            'Only "valid" and "same" are implemented.'
        )

    in_dims = len(input_shape)
    if isinstance(kernel_shape, int):
        kernel_shape = (kernel_shape,) * in_dims
    if isinstance(strides, int):
        strides = (strides,) * in_dims

    kernel_dims = len(kernel_shape)
    stride_dims = len(strides)
    if kernel_dims != in_dims or stride_dims != in_dims:
        raise ValueError(
            "Number of strides, input and kernel dimensions must all "
            f"match. Received: stride_dims={stride_dims}, "
            f"in_dims={in_dims}, kernel_dims={kernel_dims}"
        )

    output_shape = conv_output_shape(
        input_shape, kernel_shape, strides, padding
    )

    mask_shape = input_shape + output_shape
    mask = np.zeros(mask_shape, bool)

    output_axes_ticks = [range(dim) for dim in output_shape]
    for output_position in itertools.product(*output_axes_ticks):
        input_axes_ticks = conv_connected_inputs(
            input_shape, kernel_shape, output_position, strides, padding
        )
        for input_position in itertools.product(*input_axes_ticks):
            mask[input_position + output_position] = True

    return mask

def conv_kernel_idxs(
    input_shape,
    kernel_shape,
    strides,
    padding,
    filters_in,
    filters_out,
    data_format,
):
    """Yields output-input tuples of indices in a CNN layer.

    The generator iterates over all `(output_idx, input_idx)` tuples, where
    `output_idx` is an integer index in a flattened tensor representing a
    single output image of a convolutional layer that is connected (via the
    layer weights) to the respective single input image at `input_idx`.

    Example:

    >>> input_shape = (2, 2)
    >>> kernel_shape = (2, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> filters_in = 1
    >>> filters_out = 1
    >>> data_format = "channels_last"
    >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
    ...                       filters_in, filters_out, data_format))
    [(0, 0), (0, 2), (1, 1), (1, 3)]

    Args:
        input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of
            the input.
        kernel_shape: tuple of size N, spatial shape of the convolutional
            kernel / receptive field.
        strides: tuple of size N, strides along each spatial dimension.
        padding: type of padding, string `"same"` or `"valid"`.
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.
        filters_in: `int`, number of filters in the input to the layer.
        filters_out: `int`, number of filters in the output of the layer.
        data_format: string, "channels_first" or "channels_last".

    Yields:
        The next tuple `(output_idx, input_idx)`, where `output_idx` is an
        integer index in a flattened tensor representing a single output image
        of a convolutional layer that is connected (via the layer weights) to
        the respective single input image at `input_idx`.

    Raises:
        ValueError: if `data_format` is neither `"channels_last"` nor
            `"channels_first"`, or if number of strides, input, and kernel
            number of dimensions do not match.

        NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.
    """
    if padding not in ("same", "valid"):
        raise NotImplementedError(
            f"Padding type {padding} not supported. "
            'Only "valid" and "same" are implemented.'
        )

    in_dims = len(input_shape)
    if isinstance(kernel_shape, int):
        kernel_shape = (kernel_shape,) * in_dims
    if isinstance(strides, int):
        strides = (strides,) * in_dims

    kernel_dims = len(kernel_shape)
    stride_dims = len(strides)
    if kernel_dims != in_dims or stride_dims != in_dims:
        raise ValueError(
            "Number of strides, input and kernel dimensions must all "
            f"match. Received: stride_dims={stride_dims}, "
            f"in_dims={in_dims}, kernel_dims={kernel_dims}"
        )

    output_shape = conv_output_shape(
        input_shape, kernel_shape, strides, padding
    )
    output_axes_ticks = [range(dim) for dim in output_shape]

    if data_format == "channels_first":
        concat_idxs = (
            lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx
        )
    elif data_format == "channels_last":
        concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (
            filter_idx,
        )
    else:
        raise ValueError(
            f"Data format `{data_format}` not recognized. "
            '`data_format` must be "channels_first" or "channels_last".'
        )

    for output_position in itertools.product(*output_axes_ticks):
        input_axes_ticks = conv_connected_inputs(
            input_shape, kernel_shape, output_position, strides, padding
        )
        for input_position in itertools.product(*input_axes_ticks):
            for f_in in range(filters_in):
                for f_out in range(filters_out):
                    out_idx = np.ravel_multi_index(
                        multi_index=concat_idxs(output_position, f_out),
                        dims=concat_idxs(output_shape, filters_out),
                    )
                    in_idx = np.ravel_multi_index(
                        multi_index=concat_idxs(input_position, f_in),
                        dims=concat_idxs(input_shape, filters_in),
                    )
                    yield (out_idx, in_idx)

def conv_connected_inputs(
    input_shape, kernel_shape, output_position, strides, padding
):
    """Return locations of the input connected to an output position.

    Assume a convolution with given parameters is applied to an input having N
    spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method
    returns N ranges specifying the input region that was convolved with the
    kernel to produce the output at position
    `output_position = (p_out1, ..., p_outN)`.

    Example:

    >>> input_shape = (4, 4)
    >>> kernel_shape = (2, 1)
    >>> output_position = (1, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
    ...                       strides, padding)
    [range(1, 3), range(1, 2)]

    Args:
        input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of
            the input.
        kernel_shape: tuple of size N, spatial shape of the convolutional
            kernel / receptive field.
        output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single
            position in the output of the convolution.
        strides: tuple of size N, strides along each spatial dimension.
        padding: type of padding, string `"same"` or `"valid"`.
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.

    Returns:
        N ranges `[[p_in_left1, ..., p_in_right1], ...,
        [p_in_leftN, ..., p_in_rightN]]` specifying the region in the
        input connected to output_position.
    """
    ranges = []

    ndims = len(input_shape)
    for d in range(ndims):
        left_shift = int(kernel_shape[d] / 2)
        right_shift = kernel_shape[d] - left_shift

        center = output_position[d] * strides[d]

        if padding == "valid":
            center += left_shift

        start = max(0, center - left_shift)
        end = min(input_shape[d], center + right_shift)

        ranges.append(range(start, end))

    return ranges

def conv_output_shape(input_shape, kernel_shape, strides, padding):
    """Return the output shape of an N-D convolution.

    Forces dimensions where input is empty (size 0) to remain empty.

    Args:
        input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of
            the input.
        kernel_shape: tuple of size N, spatial shape of the convolutional
            kernel / receptive field.
        strides: tuple of size N, strides along each spatial dimension.
        padding: type of padding, string `"same"` or `"valid"`.
            `"valid"` means no padding. `"same"` results in padding evenly to
            the left/right or up/down of the input such that output has the
            same height/width dimension as the input.

    Returns:
        tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
    """
    dims = range(len(kernel_shape))
    output_shape = [
        conv_output_length(input_shape[d], kernel_shape[d], padding, strides[d])
        for d in dims
    ]
    output_shape = tuple(
        [0 if input_shape[d] == 0 else output_shape[d] for d in dims]
    )
    return output_shape

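# Illustrative usage (added example, not part of the original module): the
# spatial output shape is computed per dimension via `conv_output_length`.
#
#   >>> conv_output_shape((5, 5), kernel_shape=(3, 3), strides=(1, 1),
#   ...                   padding="valid")
#   (3, 3)

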
def squeeze_batch_dims(inp, op, inner_rank):
    """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

    Where `squeeze_batch` reshapes `inp` to shape
    `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`
    and `unsqueeze_batch` does the reverse reshape but on the output.

    Args:
        inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
            is length `inner_rank`.
        op: A callable that takes a single input tensor and returns a single
            output tensor.
        inner_rank: A python integer.

    Returns:
        `unsqueeze_batch_op(squeeze_batch(inp))`.
    """
    with tf.name_scope("squeeze_batch_dims"):
        shape = inp.shape

        inner_shape = shape[-inner_rank:]
        if not inner_shape.is_fully_defined():
            inner_shape = tf.shape(inp)[-inner_rank:]

        batch_shape = shape[:-inner_rank]
        if not batch_shape.is_fully_defined():
            batch_shape = tf.shape(inp)[:-inner_rank]

        if isinstance(inner_shape, tf.TensorShape):
            inp_reshaped = tf.reshape(inp, [-1] + inner_shape.as_list())
        else:
            inp_reshaped = tf.reshape(
                inp, tf.concat(([-1], inner_shape), axis=-1)
            )

        out_reshaped = op(inp_reshaped)

        out_inner_shape = out_reshaped.shape[-inner_rank:]
        if not out_inner_shape.is_fully_defined():
            out_inner_shape = tf.shape(out_reshaped)[-inner_rank:]

        out = tf.reshape(
            out_reshaped, tf.concat((batch_shape, out_inner_shape), axis=-1)
        )

        out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
        return out
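

# Illustrative usage (added sketch, not part of the original module): apply an
# op that only understands a single batch dimension to a tensor with two
# leading batch dimensions. The shapes below are assumptions for the example.
#
#   >>> x = tf.ones([2, 3, 5, 4])  # batch dims (2, 3) + inner shape (5, 4)
#   >>> y = squeeze_batch_dims(
#   ...     x, lambda t: tf.matmul(t, tf.ones([4, 7])), inner_rank=2
#   ... )
#   >>> y.shape
#   TensorShape([2, 3, 5, 7])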