591 lines
18 KiB
Python
591 lines
18 KiB
Python
import math
|
|
|
|
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
|
|
|
|
import torch
|
|
import torch._prims as prims
|
|
import torch._prims_common as utils
|
|
from torch._decomp import register_decomposition
|
|
from torch._prims_common import DimsType, ShapeType, TensorLikeType
|
|
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
|
|
|
|
# Names exported by this module: the FFT transform references first,
# followed by the frequency-shift helpers.
__all__ = [
    # Transforms
    "fft",
    "fft2",
    "fftn",
    "hfft",
    "hfft2",
    "hfftn",
    "rfft",
    "rfft2",
    "rfftn",
    "ifft",
    "ifft2",
    "ifftn",
    "ihfft",
    "ihfft2",
    "ihfftn",
    "irfft",
    "irfft2",
    "irfftn",
    # Helpers
    "fftshift",
    "ifftshift",
]
|
|
|
|
# Type of the ``norm`` argument accepted by the transforms in this module.
NormType = Optional[Literal["forward", "backward", "ortho"]]
# Runtime counterpart of NormType, used to validate ``norm`` in _apply_norm.
_NORM_VALUES = {None, "forward", "backward", "ortho"}
# Shorthand for the aten operator namespace used by the decomposition decorators.
aten = torch._ops.ops.aten
|
|
|
|
|
|
def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Scale the raw (un-normalized) transform result ``x`` according to ``norm``.

    * ``"ortho"``: multiply by ``1 / sqrt(signal_numel)`` in both directions.
    * forward transforms: multiply by ``1 / signal_numel`` only for ``"forward"``.
    * inverse transforms: multiply by ``1 / signal_numel`` for ``None`` or
      ``"backward"``.

    Any other combination returns ``x`` unchanged.  An unknown ``norm`` value
    fails the ``torch._check``.
    """
    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        ortho_scale = 1 / math.sqrt(signal_numel)
        return x * ortho_scale

    if forward:
        needs_scaling = norm == "forward"
    else:
        needs_scaling = norm is None or norm == "backward"

    if not needs_scaling:
        return x
    return x * (1 / signal_numel)
|
|
|
|
|
|
def _promote_type_fft(
|
|
dtype: torch.dtype, require_complex: bool, device: torch.device
|
|
) -> torch.dtype:
|
|
"""Helper to promote a dtype to one supported by the FFT primitives"""
|
|
if dtype.is_complex:
|
|
return dtype
|
|
|
|
# Promote integral to default float type
|
|
if not dtype.is_floating_point:
|
|
dtype = torch.get_default_dtype()
|
|
|
|
allowed_types = [torch.float32, torch.float64]
|
|
maybe_support_half = device.type in ["cuda", "meta"]
|
|
|
|
if maybe_support_half:
|
|
allowed_types.append(torch.float16)
|
|
torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")
|
|
|
|
if require_complex:
|
|
dtype = utils.corresponding_complex_dtype(dtype)
|
|
|
|
return dtype
|
|
|
|
|
|
def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Convert ``t`` to the dtype chosen by ``_promote_type_fft`` for its dtype and device."""
    target_dtype = _promote_type_fft(t.dtype, require_complex, t.device)
    return _maybe_convert_to_dtype(t, target_dtype)  # type: ignore[return-value]
|
|
|
|
|
|
def _resize_fft_input(
    x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
) -> TensorLikeType:
    """
    Make ``x.size(dims[i]) == sizes[i]`` for every ``i``.

    Dimensions that are too long are truncated (keeping the leading
    elements); dimensions that are too short are zero-padded at the end.
    A requested size of -1 leaves that dimension alone.
    """
    assert len(dims) == len(sizes)
    original_shape = x.shape
    rank = len(original_shape)
    # constant_pad_nd takes (before, after) pairs starting from the LAST dimension
    padding = [0] * (2 * rank)
    needs_padding = False

    for axis, target in zip(dims, sizes):
        if target == -1:
            continue

        current = original_shape[axis]
        if current < target:
            needs_padding = True
            # slot holding the "after" padding of `axis`
            padding[2 * (rank - axis) - 1] = target - current

        if current > target:
            x = x.narrow(axis, 0, target)

    if not needs_padding:
        return x
    return torch.constant_pad_nd(x, padding)
|
|
|
|
|
|
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation of the 1-D complex-to-real transforms (irfft and hfft).

    ``n`` is the length of the real output along ``dim``; when omitted it is
    ``2 * (input.shape[dim] - 1)``.  With ``forward=True`` the input is
    conjugated before the c2r primitive is applied.
    """
    spectrum = _maybe_promote_tensor_fft(input, require_complex=True)
    axis = utils.canonicalize_dim(spectrum.ndim, dim, wrap_scalar=False)

    if n is None:
        last_dim_size = 2 * (spectrum.shape[axis] - 1)
    else:
        last_dim_size = n
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        # A onesided spectrum for n real samples holds n // 2 + 1 entries
        spectrum = _resize_fft_input(
            spectrum, dims=(axis,), sizes=(last_dim_size // 2 + 1,)
        )

    if forward:
        spectrum = torch.conj(spectrum)

    real_out = prims.fft_c2r(spectrum, dim=(axis,), last_dim_size=last_dim_size)
    return _apply_norm(
        real_out, norm=norm, signal_numel=last_dim_size, forward=forward
    )
|
|
|
|
|
|
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Shared implementation of the 1-D real-to-complex transforms (rfft and ihfft).

    Rejects complex inputs, optionally resizes ``dim`` to ``n`` samples, runs
    the r2c primitive and normalizes.  For ``forward=False`` the normalized
    result is conjugated before being returned.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    signal = _maybe_promote_tensor_fft(input)
    axes = (utils.canonicalize_dim(signal.ndim, dim, wrap_scalar=False),)

    if n is None:
        dim_size = signal.shape[dim]
    else:
        dim_size = n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        signal = _resize_fft_input(signal, axes, (n,))

    spectrum = prims.fft_r2c(signal, dim=axes, onesided=onesided)
    spectrum = _apply_norm(spectrum, norm, dim_size, forward)
    if forward:
        return spectrum
    return torch.conj(spectrum)
|
|
|
|
|
|
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation of the 1-D complex-to-complex transforms (fft and ifft).

    Requires a complex ``input``; optionally resizes ``dim`` to ``n`` samples
    before running the c2c primitive and normalizing the result.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    axes = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)

    if n is None:
        dim_size = input.shape[dim]
    else:
        dim_size = n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    signal = input if n is None else _resize_fft_input(input, axes, (n,))

    transformed = prims.fft_c2c(signal, dim=axes, forward=forward)
    return _apply_norm(transformed, norm, dim_size, forward)
|
|
|
|
|
|
@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Forward 1-D transform along ``dim``.

    Complex inputs use the c2c path; any other input uses a two-sided
    (``onesided=False``) r2c transform.
    """
    if not input.dtype.is_complex:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)
    return _fft_c2c("fft", input, n, dim, norm, forward=True)
|
|
|
|
|
|
@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Inverse 1-D transform along ``dim``.

    Complex inputs use the c2c path; any other input uses a two-sided
    (``onesided=False``) r2c transform with ``forward=False``.
    """
    if not input.dtype.is_complex:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)
    return _fft_c2c("ifft", input, n, dim, norm, forward=False)
|
|
|
|
|
|
@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Forward 1-D transform of a real input, returning the onesided spectrum."""
    return _fft_r2c(
        func_name="rfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=True,
        onesided=True,
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Inverse of ``rfft``: complex-to-real transform along ``dim`` (``forward=False``)."""
    return _fft_c2r(
        func_name="irfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=False,
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Complex-to-real transform along ``dim`` run in the forward direction (``forward=True``)."""
    return _fft_c2r(
        func_name="hfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=True,
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Inverse of ``hfft``: onesided real-to-complex transform with ``forward=False``."""
    return _fft_r2c(
        func_name="ihfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=False,
        onesided=True,
    )
|
|
|
|
|
|
class _ShapeAndDims(NamedTuple):
    """Result of _canonicalize_fft_shape_and_dim_args: fully resolved shape and dims."""

    # Signal size along each transformed dimension (same length as ``dims``)
    shape: Tuple[int, ...]
    # Dimensions to transform
    dims: Tuple[int, ...]
|
|
|
|
|
|
def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Resolve the optional ``shape`` and ``dim`` arguments into concrete tuples.

    * ``dim`` given: it is canonicalized and must not contain duplicates.
    * ``shape`` given without ``dim``: the last ``len(shape)`` dimensions are used.
    * neither given: every dimension is transformed at its current size.

    Entries of ``shape`` equal to -1 are replaced by the input's size along the
    matching dimension.  Every resolved size must be positive.
    """
    rank = input.ndim
    sizes = input.shape
    dim_given = dim is not None

    if dim_given:
        # Accept a bare int as a one-element sequence
        dim = dim if isinstance(dim, Sequence) else (dim,)
        canonical_dims = utils.canonicalize_dims(rank, dim, wrap_scalar=False)

        # Check dims are unique
        torch._check(
            len(canonical_dims) == len(set(canonical_dims)),
            lambda: "FFT dims must be unique",
        )

    if shape is None:
        if not dim_given:
            # No shape, no dim: transform everything
            canonical_dims = tuple(range(rank))
            canonical_shape = tuple(sizes)
        else:
            # No shape, has dim: keep the current sizes of the chosen dims
            canonical_shape = tuple(sizes[d] for d in canonical_dims)  # type: ignore[possibly-undefined]
    else:
        # Has shape, might have dim
        shape = shape if isinstance(shape, Sequence) else (shape,)

        torch._check(
            not dim_given or len(dim) == len(shape),  # type: ignore[arg-type]
            lambda: "When given, dim and shape arguments must have the same length",
        )
        n_transform = len(shape)

        torch._check(
            n_transform <= rank,
            lambda: f"Got shape with {n_transform} values but input tensor "
            f"only has {rank} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if not dim_given:
            canonical_dims = tuple(range(rank - n_transform, rank))

        # Translate any -1 values in shape to the default length
        canonical_shape = tuple(
            s if s != -1 else sizes[d] for (s, d) in zip(shape, canonical_dims)  # type: ignore[possibly-undefined]
        )

    for extent in canonical_shape:
        torch._check(
            extent > 0, lambda: f"Invalid number of data points ({extent}) specified"
        )

    return _ShapeAndDims(shape=canonical_shape, dims=canonical_dims)  # type: ignore[possibly-undefined]
|
|
|
|
|
|
def _prod(xs: Iterable[int]) -> int:
    """Return the product of all values in ``xs`` (1 for an empty iterable)."""
    result = 1
    for factor in xs:
        result = result * factor
    return result
|
|
|
|
|
|
def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation of the n-D complex-to-complex transforms (fftn and ifftn).

    ``input`` must already be complex.  It is resized to ``shape`` along
    ``dim``, transformed, and normalized by the product of ``shape``.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    resized = _resize_fft_input(input, dim, shape)
    transformed = prims.fft_c2c(resized, dim=dim, forward=forward)
    signal_numel = _prod(shape)
    return _apply_norm(
        transformed, norm=norm, signal_numel=signal_numel, forward=forward
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Forward n-D transform: the input is promoted to complex and run through c2c."""
    canon = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    complex_input = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c(
        "fftn", complex_input, canon.shape, canon.dims, norm, forward=True
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Inverse n-D transform: the input is promoted to complex and run through c2c."""
    canon = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    complex_input = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c(
        "ifftn", complex_input, canon.shape, canon.dims, norm, forward=False
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Forward n-D transform of a real input, returning a onesided spectrum.

    Complex inputs are rejected.  Normalization uses the product of the
    resolved signal shape.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    canon = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    signal = _maybe_promote_tensor_fft(input, require_complex=False)
    signal = _resize_fft_input(signal, canon.dims, canon.shape)
    spectrum = prims.fft_r2c(signal, dim=canon.dims, onesided=True)
    signal_numel = _prod(canon.shape)
    return _apply_norm(spectrum, norm=norm, signal_numel=signal_numel, forward=True)
|
|
|
|
|
|
@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Inverse n-D Hermitian transform of a real input.

    The last transformed dimension goes through a onesided r2c transform.
    With a single dimension, the normalized result is conjugated and returned.
    Otherwise the r2c result is conjugated and the remaining dimensions go
    through an inverse c2c transform before normalization.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    canon = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    sizes, axes = canon.shape, canon.dims
    torch._check(len(sizes) > 0, lambda: "ihfftn must transform at least one axis")

    signal = _maybe_promote_tensor_fft(input, require_complex=False)
    signal = _resize_fft_input(signal, axes, sizes)

    half_spectrum = prims.fft_r2c(signal, dim=axes[-1:], onesided=True)

    if len(axes) == 1:
        half_spectrum = _apply_norm(
            half_spectrum, norm=norm, signal_numel=sizes[0], forward=False
        )
        return prims.conj(half_spectrum)

    conjugated = prims.conj_physical(half_spectrum)
    full = prims.fft_c2c(conjugated, dim=axes[:-1], forward=False)
    return _apply_norm(full, norm=norm, signal_numel=_prod(sizes), forward=False)
|
|
|
|
|
|
class _CanonicalizeC2rReturn(NamedTuple):
    """Result of _canonicalize_fft_c2r_shape_and_dim_args."""

    # Size the (onesided) input is resized to along each transformed dimension;
    # the last entry is ``last_dim_size // 2 + 1``
    shape: Tuple[int, ...]
    # Dimensions to transform
    dim: Tuple[int, ...]
    # Size of the real output along the last transformed dimension
    last_dim_size: int
|
|
|
|
|
|
def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms.

    Args:
        fname: name of the calling transform, used in error messages.
        input: the (onesided) input tensor.
        s: requested output signal shape, or None.  A bare int is treated as a
            one-element shape, matching _canonicalize_fft_shape_and_dim_args.
        dim: dimensions to transform, or None.

    Returns:
        A ``_CanonicalizeC2rReturn`` holding the shape the input must be
        resized to (its last entry is ``last_dim_size // 2 + 1``), the
        canonical dims, and ``last_dim_size``, the size of the real output
        along the last transformed dimension.

    The checks fail when no axis is transformed or when ``last_dim_size < 1``.
    """
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    # The canonical `shape` has already had -1 substituted, so look at the
    # caller's `s` to learn whether the last size was left to default.
    # `s` may be a bare int; indexing it directly would raise TypeError.
    if s is None:
        requested_last = -1
    elif isinstance(s, Sequence):
        requested_last = s[-1]
    else:
        requested_last = s

    if requested_last == -1:
        # Default: a onesided input of length m corresponds to 2 * (m - 1) real samples
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    # The input only needs last_dim_size // 2 + 1 entries along the last dim
    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )
|
|
|
|
|
|
@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Inverse of ``rfftn``: n-D complex-to-real transform.

    Normalization uses the product of the output sizes along the transformed
    dimensions.
    """
    canon = _canonicalize_fft_c2r_shape_and_dim_args("irfftn", input, s, dim)
    spectrum = _maybe_promote_tensor_fft(input, require_complex=True)
    spectrum = _resize_fft_input(spectrum, canon.dim, canon.shape)
    result = prims.fft_c2r(
        spectrum, dim=canon.dim, last_dim_size=canon.last_dim_size
    )
    signal_numel = _prod(result.shape[d] for d in canon.dim)
    return _apply_norm(result, norm, signal_numel, forward=False)
|
|
|
|
|
|
@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """n-D complex-to-real transform run in the forward direction.

    All transformed dimensions except the last go through a forward c2c
    transform (skipped when only one dimension is transformed).  The result is
    conjugated and the last dimension is handled by the c2r primitive.
    Normalization is applied in two steps, one per stage.
    """
    canon = _canonicalize_fft_c2r_shape_and_dim_args("hfftn", input, s, dim)
    sizes, axes, last_dim_size = canon.shape, canon.dim, canon.last_dim_size

    spectrum = _maybe_promote_tensor_fft(input, require_complex=True)
    spectrum = _resize_fft_input(spectrum, axes, sizes)

    if len(axes) > 1:
        partial = prims.fft_c2c(spectrum, dim=axes[:-1], forward=True)
    else:
        partial = spectrum
    # First stage is normalized by the sizes of the leading transformed dims only
    partial = _apply_norm(partial, norm, _prod(sizes[:-1]), forward=True)
    partial = prims.conj_physical(partial)

    real_out = prims.fft_c2r(partial, dim=axes[-1:], last_dim_size=last_dim_size)
    return _apply_norm(real_out, norm, last_dim_size, forward=True)
|
|
|
|
|
|
@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """``torch.fft.fftn`` with ``dim`` defaulting to the last two dimensions."""
    result = torch.fft.fftn(input, s=s, dim=dim, norm=norm)
    return result
|
|
|
|
|
|
@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """``torch.fft.ifftn`` with ``dim`` defaulting to the last two dimensions."""
    result = torch.fft.ifftn(input, s=s, dim=dim, norm=norm)
    return result
|
|
|
|
|
|
@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """``torch.fft.rfftn`` with ``dim`` defaulting to the last two dimensions."""
    result = torch.fft.rfftn(input, s=s, dim=dim, norm=norm)
    return result
|
|
|
|
|
|
@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """``torch.fft.irfftn`` with ``dim`` defaulting to the last two dimensions."""
    result = torch.fft.irfftn(input, s=s, dim=dim, norm=norm)
    return result
|
|
|
|
|
|
@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """``torch.fft.hfftn`` with ``dim`` defaulting to the last two dimensions."""
    result = torch.fft.hfftn(input, s=s, dim=dim, norm=norm)
    return result
|
|
|
|
|
|
@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """``torch.fft.ihfftn`` with ``dim`` defaulting to the last two dimensions."""
    result = torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
    return result
|
|
|
|
|
|
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
    """Normalize ``dim`` to a list of ints.

    ``None`` selects every dimension of ``x``; a single value becomes a
    one-element list; a sequence is copied into a new list.
    """
    if dim is None:
        return list(range(x.ndim))
    if isinstance(dim, Sequence):
        return list(dim)
    return [dim]
|
|
|
|
|
|
@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll ``input`` by ``size // 2`` along each selected dimension (all dimensions by default)."""
    axes = _default_alldims(dim, input)
    offsets = []
    for axis in axes:
        offsets.append(input.shape[axis] // 2)
    return torch.roll(input, offsets, axes)
|
|
|
|
|
|
@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll ``input`` by ``(size + 1) // 2`` along each selected dimension (all dimensions by default)."""
    axes = _default_alldims(dim, input)
    offsets = []
    for axis in axes:
        offsets.append((input.shape[axis] + 1) // 2)
    return torch.roll(input, offsets, axes)
|