# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
from functools import partial
import operator
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lib.mlir.dialects import hlo
_max = builtins.max
Array = Any
DType = Any
Shape = core.Shape
class ConvDimensionNumbers(NamedTuple):
"""Describes batch, spatial, and feature dimensions of a convolution.
Args:
lhs_spec: a tuple of nonnegative integer dimension numbers containing
`(batch dimension, feature dimension, spatial dimensions...)`.
rhs_spec: a tuple of nonnegative integer dimension numbers containing
`(out feature dimension, in feature dimension, spatial dimensions...)`.
out_spec: a tuple of nonnegative integer dimension numbers containing
`(batch dimension, feature dimension, spatial dimensions...)`.
"""
lhs_spec: Sequence[int]
rhs_spec: Sequence[int]
out_spec: Sequence[int]
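# For example, the string form ('NCHW', 'OIHW', 'NCHW') corresponds to the
# canonical ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3),
# out_spec=(0, 1, 2, 3)): batch/output-feature first, feature/input-feature
# second, then the spatial dimensions.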
ConvGeneralDilatedDimensionNumbers = Union[
None, ConvDimensionNumbers, Tuple[str, str, str]]
def conv_general_dilated(
lhs: Array, rhs: Array, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
lhs_dilation: Optional[Sequence[int]] = None,
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
feature_group_count: int = 1, batch_group_count: int = 1,
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
"""General n-dimensional convolution operator, with optional dilation.
Wraps XLA's `Conv
<https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
operator.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the strings `'SAME'`, `'SAME_LOWER'`, or `'VALID'`, or a
sequence of `n` `(low, high)` integer pairs that give the padding to apply
before and after each spatial dimension. `'SAME'` and `'SAME_LOWER'` add
padding to produce the same output size as the input. The padding is split
between the two sides as equally as possible. When the total padding is odd,
the extra padding is added at the end for `'SAME'` and at the beginning for
`'SAME_LOWER'`.
lhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
factor to apply in each spatial dimension of `lhs`. LHS dilation is also
known as transposed convolution.
rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
factor to apply in each spatial dimension of `rhs`. RHS dilation is also
known as atrous convolution.
dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or a
3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a string
of length `n+2`.
feature_group_count: integer, default 1. See XLA HLO docs.
batch_group_count: integer, default 1. See XLA HLO docs.
precision: Optional. Either ``None``, which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``), a
string (e.g. 'highest' or 'fastest', see the
``jax.default_matmul_precision`` context manager), or a tuple of two
:class:`~jax.lax.Precision` enums or strings indicating precision of
``lhs`` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
An array containing the convolution result.
In the string case of ``dimension_numbers``, each character identifies by
position:
- the batch dimensions in ``lhs``, ``rhs``, and the output with the character
'N',
- the feature dimensions in `lhs` and the output with the character 'C',
- the input and output feature dimensions in rhs with the characters 'I'
and 'O' respectively, and
- spatial dimension correspondences between lhs, rhs, and the output using
any distinct characters.
For example, to indicate dimension numbers consistent with the ``conv``
function with two spatial dimensions, one could use ``('NCHW', 'OIHW',
'NCHW')``. As another example, to indicate dimension numbers consistent with
the TensorFlow Conv2D operation, one could use ``('NHWC', 'HWIO', 'NHWC')``.
When using the latter form of convolution dimension specification, window
strides are associated with spatial dimension character labels according to
the order in which the labels appear in the ``rhs_spec`` string, so that
``window_strides[0]`` is matched with the dimension corresponding to the first
character appearing in rhs_spec that is not ``'I'`` or ``'O'``.
If ``dimension_numbers`` is ``None``, the default is ``('NCHW', 'OIHW',
'NCHW')`` (for a 2D convolution).
"""
dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
if lhs_dilation is None:
lhs_dilation = (1,) * (lhs.ndim - 2)
elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
raise ValueError(
"String padding is not implemented for transposed convolution "
"using this op. Please either exactly specify the required padding or "
"use conv_transpose.")
if rhs_dilation is None:
rhs_dilation = (1,) * (rhs.ndim - 2)
if isinstance(padding, str):
lhs_perm, rhs_perm, _ = dnums
rhs_shape = np.take(rhs.shape, rhs_perm)[2:] # type: ignore[index]
effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
padding = lax.padtype_to_pads(
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, # type: ignore[index]
window_strides, padding)
else:
try:
padding = tuple((operator.index(lo), operator.index(hi))
for lo, hi in padding)
except (ValueError, TypeError) as e:
raise ValueError(
"padding argument to conv_general_dilated should be a string or a "
f"sequence of (low, high) pairs, got {padding}") from e
preferred_element_type = (
None if preferred_element_type is None else
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
return conv_general_dilated_p.bind(
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
dimension_numbers=dnums,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=lax.canonicalize_precision(precision),
preferred_element_type=preferred_element_type)
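# A minimal illustrative sketch (the `_example_*` helper below and its shapes
# are made up for this example, not part of the library): a 3x3, stride-1,
# 'SAME'-padded convolution in NHWC/HWIO layout keeps the 8x8 spatial size and
# maps 3 input features to 16 output features.
def _example_conv_general_dilated():
  lhs = np.ones((1, 8, 8, 3), np.float32)   # N=1, H=W=8, C=3
  rhs = np.ones((3, 3, 3, 16), np.float32)  # 3x3 kernel, I=3, O=16
  out = conv_general_dilated(lhs, rhs, window_strides=(1, 1), padding='SAME',
                             dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
  assert out.shape == (1, 8, 8, 16)
  return out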
### convenience wrappers around traceables
def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
padding: str, precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
"""Convenience wrapper around `conv_general_dilated`.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'` or the string `'VALID'`.
precision: Optional. Either ``None``, which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
An array containing the convolution result.
"""
return conv_general_dilated(lhs, rhs, window_strides, padding,
precision=precision,
preferred_element_type=preferred_element_type)
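# Illustrative sketch (arbitrary shapes, hypothetical `_example_*` helper): with
# the default dimension numbers, `conv` takes an NCHW input and an OIHW kernel.
def _example_conv():
  lhs = np.ones((1, 3, 8, 8), np.float32)   # N=1, C=3, H=W=8
  rhs = np.ones((16, 3, 3, 3), np.float32)  # O=16, I=3, 3x3 kernel
  out = conv(lhs, rhs, window_strides=(1, 1), padding='SAME')
  assert out.shape == (1, 16, 8, 8)
  return out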
def conv_with_general_padding(lhs: Array, rhs: Array,
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
lhs_dilation: Optional[Sequence[int]],
rhs_dilation: Optional[Sequence[int]],
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
"""Convenience wrapper around `conv_general_dilated`.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
`n` `(low, high)` integer pairs that give the padding to apply before and
after each spatial dimension.
lhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
is also known as transposed convolution.
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
is also known as atrous convolution.
precision: Optional. Either ``None``, which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
An array containing the convolution result.
"""
return conv_general_dilated(
lhs, rhs, window_strides, padding, lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation, precision=precision,
preferred_element_type=preferred_element_type)
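# Illustrative 1D sketch (arbitrary shapes, hypothetical `_example_*` helper):
# explicit (low, high) padding combined with RHS (atrous) dilation. A length-3
# kernel dilated by 2 has effective size 5, so a length-10 input padded by
# (1, 1) yields 8 output positions.
def _example_conv_with_general_padding():
  lhs = np.ones((1, 3, 10), np.float32)  # N=1, C=3, W=10
  rhs = np.ones((4, 3, 3), np.float32)   # O=4, I=3, length-3 kernel
  out = conv_with_general_padding(lhs, rhs, window_strides=(1,),
                                  padding=[(1, 1)], lhs_dilation=None,
                                  rhs_dilation=(2,))
  assert out.shape == (1, 4, 8)
  return out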
def _conv_transpose_padding(k, s, padding):
"""Calculate before and after padding for a dim of transposed convolution.
Args:
k: int: kernel dimension.
s: int: dimension stride value.
padding: 'same' or 'valid' padding mode for original forward conv.
Returns:
2-tuple: ints: before and after padding for transposed convolution.
"""
if padding == 'SAME':
pad_len = k + s - 2
if s > k - 1:
pad_a = k - 1
else:
pad_a = int(np.ceil(pad_len / 2))
elif padding == 'VALID':
pad_len = k + s - 2 + _max(k - s, 0)
pad_a = k - 1
else:
raise ValueError('Padding mode must be `SAME` or `VALID`.')
pad_b = pad_len - pad_a
return pad_a, pad_b
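# Worked values: for kernel size k=3 and stride s=2,
#   'SAME'  -> pad_len = 3, and since s > k - 1 is false, (pad_a, pad_b) = (2, 1)
#   'VALID' -> pad_len = 3 + max(k - s, 0) = 4 and pad_a = k - 1 = 2, so (2, 2)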
def _flip_axes(x, axes):
"""Flip ndarray 'x' along each axis specified in axes tuple."""
for axis in axes:
x = np.flip(x, axis)
return x
def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
transpose_kernel: bool = False,
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
"""Convenience wrapper for calculating the N-d convolution "transpose".
This function directly calculates a fractionally strided conv rather than
indirectly calculating the gradient (transpose) of a forward convolution.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
strides: sequence of `n` integers, sets fractional stride.
padding: either 'SAME' or 'VALID', interpreted as the transpose of the padding
of the corresponding forward conv, or a sequence of `n` integer 2-tuples
describing before-and-after padding for each of the `n` spatial dimensions.
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
is also known as atrous convolution.
dimension_numbers: tuple of dimension descriptors as in
lax.conv_general_dilated. Defaults to tensorflow convention.
transpose_kernel: if True flips spatial axes and swaps the input/output
channel axes of the kernel. This makes the output of this function identical
to the gradient-derived functions like keras.layers.Conv2DTranspose
applied to the same kernel. For typical use in neural nets this is completely
pointless and just makes input/output channel specification confusing.
precision: Optional. Either ``None``, which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
Transposed N-d convolution, with output padding following the conventions of
keras.layers.Conv2DTranspose.
"""
assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
ndims = len(lhs.shape)
one = (1,) * (ndims - 2)
# Set dimensional layout defaults if not specified.
if dimension_numbers is None:
if ndims == 2:
dimension_numbers = ('NC', 'IO', 'NC')
elif ndims == 3:
dimension_numbers = ('NHC', 'HIO', 'NHC')
elif ndims == 4:
dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
elif ndims == 5:
dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
else:
raise ValueError('No dimension_numbers defaults for convolutions with more '
'than 3 spatial dimensions; please specify dimension_numbers.')
dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
k_shape = np.take(rhs.shape, dn.rhs_spec)
k_sdims = k_shape[2:] # type: ignore[index]
# Calculate correct output shape given padding and strides.
pads: Union[str, Sequence[Tuple[int, int]]]
if isinstance(padding, str) and padding in {'SAME', 'VALID'}:
if rhs_dilation is None:
rhs_dilation = (1,) * (rhs.ndim - 2)
effective_k_size = map(lambda k, r: (k-1) * r + 1, k_sdims, rhs_dilation)
pads = [_conv_transpose_padding(k, s, padding)
for k,s in zip(effective_k_size, strides)]
else:
pads = padding
if transpose_kernel:
# flip spatial dims and swap input / output channel axes
rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn,
precision=precision,
preferred_element_type=preferred_element_type)
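# Illustrative sketch (arbitrary shapes, hypothetical `_example_*` helper): with
# 'SAME' padding and stride 2, the transposed convolution doubles the spatial
# size, matching the keras.layers.Conv2DTranspose convention noted above.
def _example_conv_transpose():
  lhs = np.ones((1, 8, 8, 16), np.float32)   # NHWC input
  rhs = np.ones((3, 3, 16, 32), np.float32)  # HWIO kernel: 3x3, I=16, O=32
  out = conv_transpose(lhs, rhs, strides=(2, 2), padding='SAME')
  assert out.shape == (1, 16, 16, 32)
  return out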
def _conv_general_dilated_shape_rule(
lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
assert type(dimension_numbers) is ConvDimensionNumbers
if len(lhs.shape) != len(rhs.shape):
msg = ("conv_general_dilated lhs and rhs must have the same number of "
"dimensions, but got {} and {}.")
raise ValueError(msg.format(lhs.shape, rhs.shape))
if not feature_group_count > 0:
msg = ("conv_general_dilated feature_group_count "
"must be a positive integer, got {}.")
raise ValueError(msg.format(feature_group_count))
lhs_feature_count = lhs.shape[dimension_numbers.lhs_spec[1]]
quot, rem = divmod(lhs_feature_count, feature_group_count)
if rem:
msg = ("conv_general_dilated feature_group_count must divide lhs feature "
"dimension size, but {} does not divide {}.")
raise ValueError(msg.format(feature_group_count, lhs_feature_count))
if not core.symbolic_equal_dim(quot, rhs.shape[dimension_numbers.rhs_spec[1]]):
msg = ("conv_general_dilated lhs feature dimension size divided by "
"feature_group_count must equal the rhs input feature dimension "
"size, but {} // {} != {}.")
raise ValueError(msg.format(lhs_feature_count, feature_group_count,
rhs.shape[dimension_numbers.rhs_spec[1]]))
if rhs.shape[dimension_numbers.rhs_spec[0]] % feature_group_count:
msg = ("conv_general_dilated rhs output feature dimension size must be a "
"multiple of feature_group_count, but {} is not a multiple of {}.")
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
feature_group_count))
if not batch_group_count > 0:
msg = ("conv_general_dilated batch_group_count "
"must be a positive integer, got {}.")
raise ValueError(msg.format(batch_group_count))
lhs_batch_count = lhs.shape[dimension_numbers.lhs_spec[0]]
if batch_group_count > 1 and lhs_batch_count % batch_group_count != 0:
msg = ("conv_general_dilated batch_group_count must divide lhs batch "
"dimension size, but {} does not divide {}.")
raise ValueError(msg.format(batch_group_count, lhs_batch_count))
if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
msg = ("conv_general_dilated rhs output feature dimension size must be a "
"multiple of batch_group_count, but {} is not a multiple of {}.")
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
batch_group_count))
if batch_group_count > 1 and feature_group_count > 1:
msg = ("At most one of batch_group_count and feature_group_count may be > "
"1, got batch_group_count={} and feature_group_count={}")
raise ValueError(msg.format(batch_group_count, feature_group_count))
if len(_conv_sdims(dimension_numbers.rhs_spec)) != len(window_strides):
msg = ("conv_general_dilated window and window_strides must have "
"the same number of dimensions, but got {} and {}")
raise ValueError(
msg.format(len(_conv_sdims(dimension_numbers.rhs_spec)), len(window_strides)))
lhs_perm, rhs_perm, out_perm = dimension_numbers
lhs_trans = lax._dilate_shape(np.take(lhs.shape, lhs_perm), lhs_dilation)
rhs_trans = lax._dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
batch_group_count)
return tuple(np.take(out_trans, np.argsort(out_perm))) # type: ignore[arg-type]
def _conv_general_dilated_dtype_rule(
lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, preferred_element_type, **unused_kwargs):
result_dtype = lax.naryop_dtype_rule(lax._input_dtype, [lax._any, lax._any],
'conv_general_dilated', lhs, rhs)
if preferred_element_type is None:
return result_dtype
lax._validate_preferred_element_type(result_dtype, preferred_element_type)
return preferred_element_type
_conv_spec_transpose = lambda spec: (spec[1], spec[0]) + spec[2:]
_conv_sdims = lambda spec: spec[2:]
# Understanding the convolution transpose rules:
# Ignoring the spatial dimensions, let m = batch, j = input feature,
# k = output feature.
#
# Convolution computes the following contraction:
# Forward: [m, j] [j, k] -> [m, k]
#
# The transposes are similar to the rules for transposing a matmul:
# LHS transpose: [m, k] [k, j] -> [m, j]
# RHS transpose: [j, m] [m, k] -> [j, k]
#
# With feature grouping, we have the following signatures:
# Forward: [m, gj] [j, gk] -> [m, gk]
# LHS transpose: [m, gk] [k, gj] -> [m, gj]
# --> implemented as feature grouping after transposing the group from the
# kernel input features to the kernel output features.
# RHS transpose: [gj, m] [m, gk] -> [j, gk]
# --> which is batch grouping.
#
# With batch grouping, we have the following signatures:
# Forward: [gm, j] [j, gk] -> [m, gk]
# LHS transpose: [m, gk] [gk, j] -> [gm, j]
# --> implemented as feature grouping with transposing the group on the kernel
# and the output.
# RHS transpose: [j, gm] [m, gk] -> [j, gk]
# --> which is feature grouping.
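# As a concrete instance of the feature-grouping signature above: with g = 2
# groups, j = 4 input features per group and k = 16 output features per group,
# the forward conv contracts an lhs with 2*4 = 8 feature channels against an
# rhs with input feature dim 4 and output feature dim 2*16 = 32, producing 32
# output features.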
def _conv_general_dilated_transpose_lhs(
g, lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
precision, preferred_element_type):
assert type(dimension_numbers) is ConvDimensionNumbers
assert batch_group_count == 1 or feature_group_count == 1
rhs_shape = rhs.shape
lhs_shape = lhs.aval.shape
lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
lhs_spec, rhs_spec, out_spec = dimension_numbers
t_rhs_spec = _conv_spec_transpose(rhs_spec)
if feature_group_count > 1:
# in addition to switching the dims in the spec, need to move the feature
# group axis into the transposed rhs's output feature dim
rhs = _reshape_axis_out_of(rhs_spec[0], feature_group_count, rhs)
rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
elif batch_group_count > 1:
rhs = _reshape_axis_out_of(rhs_spec[0], batch_group_count, rhs)
rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
feature_group_count = batch_group_count
trans_dimension_numbers = ConvDimensionNumbers(out_spec, t_rhs_spec, lhs_spec)
padding = _conv_general_vjp_lhs_padding(
np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims),
window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation,
rhs_dilation)
revd_weights = lax.rev(rhs, rhs_sdims)
out = conv_general_dilated(
g, revd_weights, window_strides=lhs_dilation, padding=padding,
lhs_dilation=window_strides, rhs_dilation=rhs_dilation,
dimension_numbers=trans_dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=1, precision=precision,
preferred_element_type=preferred_element_type)
if batch_group_count > 1:
out = _reshape_axis_out_of(lhs_spec[1], batch_group_count, out)
out = _reshape_axis_into(lhs_spec[1], lhs_spec[0], out)
return out
def _conv_general_dilated_transpose_rhs(
g, lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers: ConvDimensionNumbers, feature_group_count: int,
batch_group_count: int, precision, preferred_element_type):
assert type(dimension_numbers) is ConvDimensionNumbers
if np.size(g) == 0:
# Avoids forming degenerate convolutions where the RHS has spatial size 0.
return ad.Zero(rhs.aval)
lhs_shape = lhs.shape
rhs_shape = rhs.aval.shape
lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
lhs_trans, rhs_trans, out_trans = map(_conv_spec_transpose, dimension_numbers)
assert batch_group_count == 1 or feature_group_count == 1
if batch_group_count > 1:
feature_group_count = batch_group_count
batch_group_count = 1
elif feature_group_count > 1:
batch_group_count = feature_group_count
feature_group_count = 1
trans_dimension_numbers = ConvDimensionNumbers(lhs_trans, out_trans, rhs_trans)
padding = _conv_general_vjp_rhs_padding(
np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims),
window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation,
rhs_dilation)
return conv_general_dilated(
lhs, g, window_strides=rhs_dilation, padding=padding,
lhs_dilation=lhs_dilation, rhs_dilation=window_strides,
dimension_numbers=trans_dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count, precision=precision,
preferred_element_type=preferred_element_type)
def _conv_general_dilated_batch_rule(
batched_args, batch_dims, *, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count, precision,
preferred_element_type, **unused_kwargs):
assert batch_group_count == 1 or feature_group_count == 1
lhs, rhs = batched_args
lhs_bdim, rhs_bdim = batch_dims
lhs_spec, rhs_spec, out_spec = dimension_numbers
# Some of the cases that reshape into batch or feature dimensions do not work
# with size 0 batch dimensions. The best fix would be to extend HLO to support
# multiple batch dimensions.
if ((lhs_bdim is not None and lhs.shape[lhs_bdim] == 0) or
(rhs_bdim is not None and rhs.shape[rhs_bdim] == 0)):
lhs_shape_unbatched, rhs_shape_unbatched = list(lhs.shape), list(rhs.shape)
if lhs_bdim is not None:
lhs_shape_unbatched.pop(lhs_bdim)
if rhs_bdim is not None:
rhs_shape_unbatched.pop(rhs_bdim)
shape = _conv_general_dilated_shape_rule(
core.ShapedArray(lhs_shape_unbatched, lhs.dtype),
core.ShapedArray(rhs_shape_unbatched, rhs.dtype),
window_strides=window_strides, padding=padding, lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
return lax.full(
(0,) + shape, 0,
dtype=lhs.dtype if preferred_element_type is None
else preferred_element_type), 0
if lhs_bdim is not None and rhs_bdim is not None:
assert lhs.shape[lhs_bdim] == rhs.shape[rhs_bdim]
if batch_group_count > 1:
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
batch_group_count *= lhs.shape[lhs_bdim]
else:
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[1], lhs)
feature_group_count *= lhs.shape[lhs_bdim]
new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
out = conv_general_dilated(
new_lhs, new_rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count=feature_group_count,
batch_group_count=batch_group_count, precision=precision,
preferred_element_type=preferred_element_type)
out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out)
return out, out_spec[1]
elif lhs_bdim is not None:
if batch_group_count == 1:
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, precision=precision,
preferred_element_type=preferred_element_type)
out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
return out, out_spec[0]
else:
new_lhs = _reshape_axis_out_of(lhs_spec[0] + int(lhs_bdim <= lhs_spec[0]),
batch_group_count, lhs)
new_lhs = _reshape_axis_into(lhs_bdim + int(lhs_spec[0] < lhs_bdim),
lhs_spec[0] + 1,
new_lhs)
new_lhs = _reshape_axis_into(lhs_spec[0], lhs_spec[0], new_lhs)
out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision=precision,
preferred_element_type=preferred_element_type)
out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
return out, out_spec[0]
elif rhs_bdim is not None:
if feature_group_count == 1 and batch_group_count == 1:
new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision=precision,
preferred_element_type=preferred_element_type)
out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out)
return out, out_spec[1]
else:
# groups need to be outermost, so we need to factor them out of the
# rhs output feature dim, then factor the batch dim into the remaining rhs
# output feature dim, then put groups back in. We do something
# similar on the output. An alternative which would require more FLOPs but
# fewer reshapes would be to broadcast lhs.
group_count = (feature_group_count if feature_group_count > 1
else batch_group_count)
new_rhs = _reshape_axis_out_of(rhs_spec[0] + int(rhs_bdim <= rhs_spec[0]),
group_count, rhs)
new_rhs = _reshape_axis_into(rhs_bdim + int(rhs_spec[0] < rhs_bdim),
rhs_spec[0] + 1, new_rhs)
new_rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[0], new_rhs)
out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision=precision,
preferred_element_type=preferred_element_type)
out = _reshape_axis_out_of(out_spec[1], group_count, out)
out = _reshape_axis_out_of(out_spec[1] + 1, rhs.shape[rhs_bdim], out)
out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
return out, out_spec[1]
conv_general_dilated_p = lax.standard_primitive(
_conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
'conv_general_dilated')
ad.defbilinear(conv_general_dilated_p,
_conv_general_dilated_transpose_lhs,
_conv_general_dilated_transpose_rhs)
batching.primitive_batchers[conv_general_dilated_p] = \
_conv_general_dilated_batch_rule
def _complex_mul(mul, x, y):
# We use a trick for complex multiplication sometimes attributed to Gauss,
# which uses three multiplications and five additions instead of the naive
# method's four multiplications and two additions.
# https://en.wikipedia.org/wiki/Multiplication_algorithm#Complex_multiplication_algorithm
#
# This performance win comes with a trade-off in accuracy, especially in cases
# where the real and imaginary parts differ greatly in magnitude. The relative
# error bound (e.g. 2**-24 for float32) is then relative to the maximum of the
# real and imaginary parts of the result, rather than holding for the real and
# imaginary parts independently.
x_re, x_im = lax.real(x), lax.imag(x)
y_re, y_im = lax.real(y), lax.imag(y)
k1 = mul(lax.add(x_re, x_im), y_re)
k2 = mul(x_re, lax.sub(y_im, y_re))
k3 = mul(x_im, lax.add(y_re, y_im))
return lax.complex(lax.sub(k1, k3), lax.add(k1, k2))
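# Correctness of the identity above: writing x = a + b*i and y = c + d*i,
#   k1 = (a + b)*c,  k2 = a*(d - c),  k3 = b*(c + d)
#   k1 - k3 = a*c - b*d   (real part)
#   k1 + k2 = a*d + b*c   (imaginary part)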
_real_dtype = lambda dtype: np.finfo(dtype).dtype
def _conv_general_dilated_lower(
ctx, lhs, rhs, *, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, precision, preferred_element_type,
expand_complex_convolutions=False, **unused_kwargs):
lhs_aval, rhs_aval = ctx.avals_in
aval_out, = ctx.avals_out
assert isinstance(dimension_numbers, ConvDimensionNumbers)
dtype = lhs_aval.dtype
if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
if preferred_element_type is not None:
# Convert complex dtype to types used for real and imaginary parts
assert np.issubdtype(preferred_element_type, np.complexfloating)
preferred_element_type = _real_dtype(preferred_element_type)
complex_conv = mlir.lower_fun(
partial(
_complex_mul,
partial(conv_general_dilated, window_strides=window_strides,
padding=padding, lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count, precision=precision,
preferred_element_type=preferred_element_type)),
multiple_results=False)
return complex_conv(ctx, lhs, rhs)
lhs_spec, rhs_spec, out_spec = dimension_numbers
dnums = hlo.ConvDimensionNumbers.get(
input_batch_dimension=lhs_spec[0],
input_feature_dimension=lhs_spec[1],
input_spatial_dimensions=list(lhs_spec[2:]),
kernel_output_feature_dimension=rhs_spec[0],
kernel_input_feature_dimension=rhs_spec[1],
kernel_spatial_dimensions=list(rhs_spec[2:]),
output_batch_dimension=out_spec[0],
output_feature_dimension=out_spec[1],
output_spatial_dimensions=list(out_spec[2:]))
num_spatial_dims = len(rhs_spec) - 2
if len(padding) == 0:
padding = np.zeros((0, 2), dtype=np.int64)
window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
if (not core.is_constant_shape(window_strides) or
not core.is_constant_shape(lhs_dilation) or
not core.is_constant_shape(rhs_dilation) or
not core.is_constant_dim(feature_group_count) or
not core.is_constant_dim(batch_group_count)):
# TODO(https://github.com/openxla/stablehlo/issues/1268)
raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count")
if all(core.is_constant_shape(p) for p in padding):
return [
hlo.ConvolutionOp(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,
dimension_numbers=dnums,
feature_group_count=mlir.i64_attr(feature_group_count),
batch_group_count=mlir.i64_attr(batch_group_count),
window_strides=mlir.dense_int_elements(window_strides),
padding=mlir.dense_int_elements(padding),
lhs_dilation=mlir.dense_int_elements(lhs_dilation),
rhs_dilation=mlir.dense_int_elements(rhs_dilation),
window_reversal=window_reversal,
precision_config=lax.precision_attr(precision)).result
]
else:
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
# spatial dimension.
int2d = mlir.aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
def prep_one_pad(pad_lo_hi: Tuple[core.DimSize, core.DimSize]):
pad1 = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, pad_lo_hi)) # i32[2]
return hlo.ReshapeOp(int2d, pad1)
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
mlir.i64_attr(0))
return [
hlo.DynamicConvOp(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,
d_padding,
dimension_numbers=dnums,
feature_group_count=mlir.i64_attr(feature_group_count),
batch_group_count=mlir.i64_attr(batch_group_count),
window_strides=mlir.dense_int_elements(window_strides),
lhs_dilation=mlir.dense_int_elements(lhs_dilation),
rhs_dilation=mlir.dense_int_elements(rhs_dilation),
window_reversal=window_reversal,
precision_config=lax.precision_attr(precision)).result
]
mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower)
# TODO(b/161124619, b/161126248): XLA does not support complex convolution on
# GPU, and on CPU it uses a slow loop-based implementation;
# on these backends, lower complex convolutions away.
mlir.register_lowering(
conv_general_dilated_p,
partial(_conv_general_dilated_lower, expand_complex_convolutions=True),
platform='cpu')
mlir.register_lowering(
conv_general_dilated_p,
partial(_conv_general_dilated_lower, expand_complex_convolutions=True),
platform='gpu')
def _reshape_axis_into(src, dst, x):
# NB: `dst` is the number of the dimension that we should reshape into
# *after* `src` is removed from `x`'s list of dimensions. For example, if
# `src` is an added batch dimension, `dst` might name a target dimension in
# the unbatched list of dimensions.
perm = [i for i in range(x.ndim) if i != src]
perm.insert(dst, src)
new_shape = list(np.delete(x.shape, src))
new_shape[dst] *= x.shape[src]
return lax.reshape(x, new_shape, perm)
def _reshape_axis_out_of(src, size1, x):
shape = list(x.shape)
size2, ragged = divmod(shape[src], size1)
assert not ragged
shape[src:src+1] = [size1, size2]
return lax.reshape(x, shape)
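# Worked example for the two helpers above: for x of shape (4, 2, 3),
# _reshape_axis_into(src=0, dst=1, x) merges axis 0 into what becomes axis 1,
# giving shape (2, 12); applying _reshape_axis_out_of(src=1, size1=4, ...) to a
# (2, 12) array splits that axis back out, giving shape (2, 4, 3).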
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
"""Compute the shape tuple of a conv given input shapes in canonical order."""
if isinstance(pads, str):
pads = lax.padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
if len(pads) != len(lhs_shape) - 2:
msg = "Wrong number of explicit pads for convolution: expected {}, got {}."
raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))
lhs_padded = np.add(lhs_shape[2:], np.sum(np.array(pads).reshape(-1, 2),
axis=1))
if np.any(lhs_padded < 0):
raise ValueError("Negative padding is larger than the size of the corresponding dimension: "
f"got padding={pads} for lhs_shape[2:]={lhs_shape[2:]}")
out_space = core.stride_shape(lhs_padded, rhs_shape[2:], strides)
out_space = [d if core.greater_equal_dim(d, 0) else 0
for d in out_space]
if batch_group_count > 1:
assert lhs_shape[0] % batch_group_count == 0
out_shape_0 = lhs_shape[0] // batch_group_count
else:
out_shape_0 = lhs_shape[0]
out_shape = (out_shape_0, rhs_shape[0])
return tuple(out_shape + tuple(out_space))
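# For example, conv_shape_tuple((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), 'SAME')
# returns (1, 16, 8, 8): 'SAME' padding preserves the 8x8 spatial extent, the
# batch size stays 1, and the output feature count is rhs_shape[0] = 16.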
def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
dimension_numbers):
lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
lhs_trans = np.take(lhs_shape, lhs_perm)
rhs_trans = np.take(rhs_shape, rhs_perm)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
return tuple(np.take(out_trans, np.argsort(out_perm)))
def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
dimension_numbers):
lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
lhs_trans = np.take(lhs_shape, lhs_perm)
rhs_trans = np.take(rhs_shape, rhs_perm)
if isinstance(padding, str):
padding = [_conv_transpose_padding(k, s, padding)
for k,s in zip(rhs_trans[2:], window_strides)]
padding = list(map(np.sum, padding))
unpad_out_space = [(i-1) * s - k + 2
for i, k, s in zip(lhs_trans[2:],
rhs_trans[2:],
window_strides)]
out_space = np.sum([unpad_out_space, padding], axis=0).tolist()
out_trans = tuple((lhs_trans[0], rhs_trans[0]) + tuple(out_space))
return tuple(np.take(out_trans, np.argsort(out_perm)))
def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers
) -> ConvDimensionNumbers:
"""Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`.
Args:
lhs_shape: tuple of nonnegative integers, shape of the convolution input.
rhs_shape: tuple of nonnegative integers, shape of the convolution kernel.
dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers
object following the convolution dimension number specification format in
xla_client.py.
Returns:
A `ConvDimensionNumbers` object that represents `dimension_numbers` in the
canonical form used by lax functions.
"""
if isinstance(dimension_numbers, ConvDimensionNumbers):
return dimension_numbers
if len(lhs_shape) != len(rhs_shape):
msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}."
raise TypeError(msg.format(len(lhs_shape), len(rhs_shape)))
if dimension_numbers is None:
iota = tuple(range(len(lhs_shape)))
return ConvDimensionNumbers(iota, iota, iota)
elif isinstance(dimension_numbers, (list, tuple)):
if len(dimension_numbers) != 3:
msg = "convolution dimension_numbers list/tuple must be length 3, got {}."
raise TypeError(msg.format(len(dimension_numbers)))
if not all(isinstance(elt, str) for elt in dimension_numbers):
msg = "convolution dimension_numbers elements must be strings, got {}."
raise TypeError(msg.format(tuple(map(type, dimension_numbers))))
msg = ("convolution dimension_numbers[{}] must have len equal to the ndim "
"of lhs and rhs, got {} for lhs and rhs shapes {} and {}.")
for i, elt in enumerate(dimension_numbers):
if len(elt) != len(lhs_shape):
raise TypeError(msg.format(i, len(elt), lhs_shape, rhs_shape))
lhs_spec, rhs_spec, out_spec = conv_general_permutations(dimension_numbers)
return ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
else:
msg = "convolution dimension_numbers must be tuple/list or None, got {}."
raise TypeError(msg.format(type(dimension_numbers)))
def conv_general_permutations(dimension_numbers):
"""Utility for convolution dimension permutations relative to Conv HLO."""
lhs_spec, rhs_spec, out_spec = dimension_numbers
lhs_char, rhs_char, out_char = charpairs = ("N", "C"), ("O", "I"), ("N", "C")
for i, (a, b) in enumerate(charpairs):
if not dimension_numbers[i].count(a) == dimension_numbers[i].count(b) == 1:
msg = ("convolution dimension_numbers[{}] must contain the characters "
"'{}' and '{}' exactly once, got {}.")
raise TypeError(msg.format(i, a, b, dimension_numbers[i]))
if len(dimension_numbers[i]) != len(set(dimension_numbers[i])):
msg = ("convolution dimension_numbers[{}] cannot have duplicate "
"characters, got {}.")
raise TypeError(msg.format(i, dimension_numbers[i]))
if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) ==
set(out_spec) - set(out_char)):
msg = ("convolution dimension_numbers elements must each have the same "
"set of spatial characters, got {}.")
raise TypeError(msg.format(dimension_numbers))
def getperm(spec, charpair):
spatial = (i for i, c in enumerate(spec) if c not in charpair)
if spec is not rhs_spec:
spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i]))
return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial)
lhs_perm, rhs_perm, out_perm = map(getperm, dimension_numbers, charpairs)
return lhs_perm, rhs_perm, out_perm
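# For example, conv_general_permutations(('NHWC', 'HWIO', 'NHWC')) returns
# ((0, 3, 1, 2), (3, 2, 0, 1), (0, 3, 1, 2)): each permutation lists the batch
# (or kernel output feature) dimension first, the feature (or kernel input
# feature) dimension second, and then the spatial dimensions in rhs_spec order.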
def _conv_general_vjp_lhs_padding(
in_shape, window_dimensions, window_strides, out_shape, padding,
lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
pad_before = np.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1
pad_after = (np.add(lhs_dilated_shape, rhs_dilated_shape) - 1
- out_dilated_shape - pad_before)
return util.safe_zip(pad_before, pad_after)
def _conv_general_vjp_rhs_padding(
in_shape, window_dimensions, window_strides, out_shape, padding,
lhs_dilation, rhs_dilation):
if len(in_shape) == 0: # 0D conv
return []
lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
pads_lo, _ = util.unzip2(padding)
pads_from_lhs = core.diff_shape(out_dilated_shape, lhs_dilated_shape)
pads_from_rhs = core.diff_shape(core.diff_shape(rhs_dilated_shape, pads_lo),
(1,) * len(pads_lo))
pads_hi = core.sum_shapes(pads_from_lhs, pads_from_rhs)
return list(zip(pads_lo, pads_hi))