1797 lines
64 KiB
Python
1797 lines
64 KiB
Python
|
||
import warnings
|
||
|
||
# A workaround to support both TorchScript and MyPy:
|
||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
|
||
|
||
import torch
|
||
from torch import Tensor
|
||
from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor
|
||
from . import _docs
|
||
from torch._prims_common import corresponding_real_dtype
|
||
from torch import sym_float
|
||
|
||
if TYPE_CHECKING:
    # Static type checkers get the precise aliases.
    from torch.types import _dtype as DType

    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    # At runtime the aliases are deliberately simplified stand-ins.
    DType = int
    DimOrDims = Optional[Tuple[int]]
|
||
|
||
|
||
__all__: List[str] = []
|
||
|
||
# All masked reduction/normalization operations have the same
|
||
# signatures. Here we introduce docstring templates that are applied
|
||
# to docstrings of reduction/normalization functions via
|
||
# _apply_docstring_templates decorator.
|
||
|
||
|
||
def _apply_docstring_templates(func):
    """Decorator: attach the pre-generated docstring to ``func`` and export it.

    The docstring is looked up on the ``_docs`` module under the attribute
    name ``<func.__name__>_docstring``. When no such attribute exists (or it
    is ``None``), a warning is issued and ``func.__doc__`` is left untouched.
    In both cases the function name is appended to ``__all__`` and the very
    same function object is returned.
    """
    name = func.__name__
    generated = getattr(_docs, f"{name}_docstring", None)

    if generated is not None:
        func.__doc__ = generated
    else:
        warnings.warn(
            f"No documentation string available for {name}."
            " PyTorch team should run `python tools/update_masked_docs.py`"
            " to generate the missing docstrings."
        )

    # Expose function as public symbol
    __all__.append(name)

    return func
|
||
|
||
|
||
def _generate_docstring(func):
|
||
"""A utility function called from tools/update_masked_docs.py
|
||
script to update the module torch.masked._docs.py
|
||
"""
|
||
docstring_templates = dict(
|
||
reduction_signature="""\
|
||
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
|
||
reduction_descr="""\
|
||
Returns {operation name} of all the elements in the :attr:`input`
|
||
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
|
||
elements are masked out according to the boolean tensor
|
||
:attr:`mask`.""",
|
||
reduction_args="""\
|
||
If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
|
||
size 1. Otherwise, :attr:`dim` is squeezed (see
|
||
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
|
||
``len(dim)``) fewer dimension(s).
|
||
|
||
The boolean tensor :attr:`mask` defines the "validity" of
|
||
:attr:`input` tensor elements: if :attr:`mask` element is True
|
||
then the corresponding element in :attr:`input` tensor will be
|
||
included in {operation name} computation, otherwise the element is
|
||
ignored.
|
||
|
||
When all elements of :attr:`input` along the given dimension
|
||
:attr:`dim` are ignored (fully masked-out), the corresponding element
|
||
of the output tensor will have undefined value: it may or may not
|
||
correspond to the identity value of {operation name} operation; the
|
||
choice may correspond to the value that leads to the most efficient
|
||
storage of :attr:`output` tensor.
|
||
|
||
The mask of the output tensor can be computed as
|
||
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
|
||
dtype=torch.bool)``.
|
||
|
||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||
don't need to match, but they must be :ref:`broadcastable
|
||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||
tensor must not be greater than of the :attr:`input` tensor.
|
||
|
||
Args:
|
||
input (Tensor): the input tensor
|
||
{args_declarations}
|
||
|
||
Keyword args:
|
||
{kwargs_declarations}""",
|
||
reduction_example="""\
|
||
Example::
|
||
|
||
>>> input = {example_input}
|
||
>>> input
|
||
{indent_example_input}
|
||
>>> mask = {example_mask}
|
||
>>> mask
|
||
{indent_example_mask}
|
||
>>> {full_function_name}(input, {example_args}, mask=mask)
|
||
{indent_example_output}
|
||
""",
|
||
reduction_identity="""\
|
||
The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""",
|
||
reduction_identity_dtype="""\
|
||
The identity value of {operation name} operation, which is used to start the
|
||
reduction, depends on input dtype. For instance, for float32, uint8,
|
||
and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""",
|
||
normalization_signature="""\
|
||
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
|
||
normalization_descr="""\
|
||
Returns {operation name} of all the slices in the :attr:`input` tensor
|
||
along :attr:`dim` while the :attr:`input` elements are masked out
|
||
according to the boolean tensor :attr:`mask`.
|
||
|
||
{definition}""",
|
||
normalization_args="""\
|
||
The boolean tensor :attr:`mask` defines the "validity" of
|
||
:attr:`input` tensor elements: if :attr:`mask` element is True then
|
||
the corresponding element in :attr:`input` tensor will be included in
|
||
{operation name} computation, otherwise the element is ignored.
|
||
|
||
The values of masked-out elements of the output tensor have undefined
|
||
value: it may or may not be set to zero or nan; the choice may correspond to
|
||
the value that leads to the most efficient storage of :attr:`output`
|
||
tensor.
|
||
|
||
The mask of the {operation name} output tensor can be computed as
|
||
``torch.broadcast_to(mask, input.shape)``.
|
||
|
||
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
|
||
don't need to match, but they must be :ref:`broadcastable
|
||
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
|
||
tensor must not be greater than of the :attr:`input` tensor.
|
||
|
||
Args:
|
||
input (Tensor): the input tensor
|
||
{args_declarations}
|
||
|
||
Keyword args:
|
||
{kwargs_declarations}""",
|
||
normalization_example="""\
|
||
Example::
|
||
|
||
>>> input = {example_input}
|
||
>>> input
|
||
{indent_example_input}
|
||
>>> mask = {example_mask}
|
||
>>> mask
|
||
{indent_example_mask}
|
||
>>> {full_function_name}(input, {example_args}, mask=mask)
|
||
{indent_example_output}
|
||
""",
|
||
)
|
||
|
||
args_and_kwargs = dict(
|
||
# argument name sufficies separated by double underscore will
|
||
# be removed in the final documentation string.
|
||
sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
|
||
cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
|
||
amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
norm=(
|
||
(
|
||
"ord",
|
||
"dim",
|
||
),
|
||
("keepdim=False", "dtype=None", "mask=None"),
|
||
),
|
||
var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
|
||
std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
|
||
logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
|
||
softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
|
||
log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
|
||
softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
|
||
normalize=(
|
||
(
|
||
"ord__required",
|
||
"dim__as_int",
|
||
),
|
||
("eps=1e-12", "dtype=None", "mask=None"),
|
||
),
|
||
)
|
||
|
||
argument_declarations = dict(
|
||
dim="""\
|
||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||
Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
|
||
dim__as_int="""\
|
||
dim (int): the dimension along which {operation name} is computed.""",
|
||
ord="""\
|
||
ord (int, float, optional): the order of vector norm. Default: 2.
|
||
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
|
||
ord__required="""\
|
||
ord (int, float): the order of vector norm. Default: 2.
|
||
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
|
||
unbiased="""\
|
||
unbiased (bool): when True, use Bessel’s correction, otherwise, compute
|
||
the uncorrected sample variance.""",
|
||
eps="""\
|
||
eps (float, optional): small value to avoid division by zero. Default: {default}.""",
|
||
keepdim="""\
|
||
keepdim (bool, optional): whether the output tensor has
|
||
:attr:`dim` retained or not. Default: {default}.""",
|
||
dtype="""\
|
||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||
of returned tensor. If specified, the input tensor is
|
||
casted to :attr:`dtype` before the operation is
|
||
performed. Default: {default}.""",
|
||
mask="""\
|
||
mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||
containing the binary mask of validity of input tensor
|
||
elements.
|
||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
|
||
)
|
||
|
||
definitions = dict(
|
||
softmax="""\
|
||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
|
||
defined as ``exp(x[i])/sum(exp(x))``.""",
|
||
log_softmax="""\
|
||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
|
||
defined as ``log(exp(x[i])/sum(exp(x)))``.""",
|
||
softmin="""\
|
||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
|
||
defined as ``exp(-x[i])/sum(exp(-x))``.""",
|
||
normalize="""\
|
||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
|
||
defined as ``x[i]/max(norm(x, p), eps)``.""",
|
||
cumsum="""\
|
||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
|
||
defined as ``sum(x[:i])``.""",
|
||
cumprod="""\
|
||
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
|
||
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
|
||
defined as ``prod(x[:i])``.""",
|
||
)
|
||
|
||
reduction_names = dict(
|
||
sum="sum",
|
||
prod="product",
|
||
amax="maximum",
|
||
amin="minimum",
|
||
argmax="argmax",
|
||
argmin="argmin",
|
||
mean="mean",
|
||
median="median",
|
||
norm="norm",
|
||
var="variance",
|
||
std="standard_deviation",
|
||
logsumexp="logsumexp",
|
||
)
|
||
|
||
normalization_names = dict(
|
||
softmax="softmax",
|
||
log_softmax="log_softmax",
|
||
softmin="softmin",
|
||
normalize="normalize",
|
||
cumsum="cumulative_sum",
|
||
cumprod="cumulative_prod",
|
||
)
|
||
|
||
operation_names = {}
|
||
operation_names.update(reduction_names)
|
||
operation_names.update(normalization_names)
|
||
|
||
# Default example data:
|
||
example_dim = 1
|
||
example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
|
||
example_mask = torch.tensor([[True, False, True], [False, False, False]])
|
||
example_args: Tuple[Any, ...]
|
||
if func.__name__ in {"norm", "normalize"}:
|
||
example_args = (2.0, example_dim)
|
||
example_input = example_input.to(dtype=torch.float32)
|
||
elif func.__name__ in {"var", "std"}:
|
||
example_args = (example_dim, False)
|
||
elif func.__name__ == "median":
|
||
example_args = (example_dim,)
|
||
example_input = example_input.to(dtype=torch.float32)
|
||
else:
|
||
example_args = (example_dim,)
|
||
|
||
operation_args: Tuple[str, ...]
|
||
operation_kwargs: Tuple[str, ...]
|
||
operation_args, operation_kwargs = args_and_kwargs[func.__name__]
|
||
arg_declarations = [
|
||
"\n ".join(
|
||
argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines()
|
||
)
|
||
for a in operation_args
|
||
]
|
||
kwarg_declarations = [
|
||
"\n ".join(
|
||
argument_declarations.get(
|
||
a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.'
|
||
)
|
||
.format(default=a.split("=", 1)[1])
|
||
.splitlines()
|
||
)
|
||
for a in operation_kwargs
|
||
]
|
||
|
||
if func.__name__ in reduction_names:
|
||
op_kind = "reduction"
|
||
doc_sections = ["signature", "descr", "identity", "args", "example"]
|
||
elif func.__name__ in normalization_names:
|
||
op_kind = "normalization"
|
||
doc_sections = ["signature", "descr", "args", "example"]
|
||
example_input = example_input.to(dtype=torch.float32)
|
||
else:
|
||
assert 0 # add function name to operation names dictionaries
|
||
example_output = func(example_input, *example_args, mask=example_mask)
|
||
|
||
template_data = {
|
||
"function_name": func.__name__,
|
||
"full_function_name": func.__module__ + "." + func.__name__,
|
||
"operation name": operation_names[func.__name__],
|
||
"operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args),
|
||
"operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs),
|
||
# one-line representation of a tensor:
|
||
"example_input": " ".join(str(example_input).split()),
|
||
"example_args": ", ".join(map(str, example_args)),
|
||
"example_mask": " ".join(str(example_mask).split()),
|
||
# multi-line representation of a tensor with indent
|
||
"indent_example_input": ("\n ").join(str(example_input).splitlines()),
|
||
"indent_example_mask": ("\n ").join(str(example_mask).splitlines()),
|
||
"indent_example_output": ("\n ").join(str(example_output).splitlines()),
|
||
}
|
||
|
||
if func.__name__ in reduction_names:
|
||
template_data.update(
|
||
identity_uint8=_reduction_identity(
|
||
func.__name__, torch.tensor(0, dtype=torch.uint8)
|
||
),
|
||
identity_int32=_reduction_identity(
|
||
func.__name__, torch.tensor(0, dtype=torch.int32)
|
||
),
|
||
identity_float32=_reduction_identity(
|
||
func.__name__, torch.tensor(0, dtype=torch.float32)
|
||
),
|
||
)
|
||
if func.__name__ == "norm":
|
||
template_data.update(
|
||
identity_ord_ninf=_reduction_identity(
|
||
func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf")
|
||
)
|
||
)
|
||
elif func.__name__ in normalization_names:
|
||
template_data.update(definition=definitions[func.__name__])
|
||
else:
|
||
assert 0 # add function name to operation names dictionaries
|
||
template_data.update(
|
||
args_declarations=("\n ".join(arg_declarations)).format_map(template_data)
|
||
)
|
||
template_data.update(
|
||
kwargs_declarations=("\n ".join(kwarg_declarations)).format_map(
|
||
template_data
|
||
)
|
||
)
|
||
|
||
# Apply function name info to docstring templates:
|
||
templates = {
|
||
k: v.format_map(template_data)
|
||
for k, v in docstring_templates.items()
|
||
if k.startswith(op_kind)
|
||
}
|
||
templates.update(
|
||
(k, v.format_map(template_data) if isinstance(v, str) else v)
|
||
for k, v in template_data.items()
|
||
)
|
||
|
||
# Apply docstring templates to function doctring:
|
||
if func.__doc__ is None:
|
||
doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections])
|
||
else:
|
||
doc_template = func.__doc__
|
||
return doc_template.format_map(templates)
|
||
|
||
|
||
def _reduction_identity(op_name: str, input: Tensor, *args):
|
||
"""Return identity value as scalar tensor of a reduction operation on
|
||
given input, or None, if the identity value cannot be uniquely
|
||
defined for the given input.
|
||
|
||
The identity value of the operation is defined as the initial
|
||
value to reduction operation that has a property ``op(op_identity,
|
||
value) == value`` for any value in the domain of the operation.
|
||
Or put it another way, including or excluding the identity value in
|
||
a list of operands will not change the reduction result.
|
||
|
||
See https://github.com/pytorch/rfcs/pull/27 for more information.
|
||
|
||
"""
|
||
dtype: DType = input.dtype
|
||
device = input.device
|
||
op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present
|
||
if op_name in {"sum", "cumsum"}:
|
||
return torch.tensor(0, dtype=dtype, device=device)
|
||
elif op_name in {"prod", "cumprod"}:
|
||
return torch.tensor(1, dtype=dtype, device=device)
|
||
elif op_name in {"amax", "argmax", "logsumexp"}:
|
||
if torch.is_floating_point(input):
|
||
return torch.tensor(-torch.inf, dtype=dtype, device=device)
|
||
elif torch.is_signed(input) or dtype == torch.uint8:
|
||
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
|
||
elif op_name in {"amin", "argmin"}:
|
||
if torch.is_floating_point(input):
|
||
return torch.tensor(torch.inf, dtype=dtype, device=device)
|
||
elif torch.is_signed(input) or dtype == torch.uint8:
|
||
return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
|
||
elif op_name == "mean":
|
||
# Strictly speaking, the identity value of the mean operation
|
||
# is the mean of the input. Since the mean value depends on
|
||
# the dim argument and it may be a non-scalar tensor, we
|
||
# consider the identity value of the mean operation ambiguous.
|
||
# Moreover, the mean value of empty input is undefined.
|
||
return None
|
||
elif op_name == "norm":
|
||
ord = args[0] if args else 2
|
||
if ord == float("-inf"):
|
||
assert torch.is_floating_point(input), input.dtype
|
||
return torch.tensor(torch.inf, dtype=dtype, device=device)
|
||
return torch.tensor(0, dtype=dtype, device=device)
|
||
elif op_name == "median":
|
||
# We use NaN for now because the implementation is currently using torch.nanmedian
|
||
# and NaN is the identity for that function since it gets ignored
|
||
dtype = input.dtype if torch.is_floating_point(input) else torch.float
|
||
return torch.tensor(torch.nan, dtype=dtype, device=device)
|
||
elif op_name in {"var", "std"}:
|
||
return None
|
||
raise NotImplementedError(f"identity of {op_name} on {dtype} input")
|
||
|
||
|
||
def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
|
||
"""Return dim argument as a tuple of sorted dim values."""
|
||
dims: List[int] = []
|
||
if dim == ():
|
||
# Currently, `dim=()` in reductions operations means "reduce
|
||
# over all dimensions" while in future, it will read "no
|
||
# reduce". See https://github.com/pytorch/pytorch/issues/29137
|
||
# When gh-29137 is resolved, this if-block must be deleted.
|
||
dim = None
|
||
if dim is None:
|
||
return tuple(range(ndim))
|
||
ndim = max(ndim, 1)
|
||
dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
|
||
for d in dim_:
|
||
if d in dims:
|
||
raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
|
||
if d >= ndim or d < -ndim:
|
||
raise IndexError(
|
||
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})"
|
||
)
|
||
dims.append(d % ndim)
|
||
return tuple(sorted(dims))
|
||
|
||
|
||
def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
|
||
# Flatted N-D indices to 1-D indices
|
||
flat_indices = indices.new_zeros(indices.size(1))
|
||
for d, sz in enumerate(shape):
|
||
flat_indices.mul_(sz)
|
||
flat_indices.add_(indices[d])
|
||
return flat_indices
|
||
|
||
|
||
def _any(input: Tensor, dim: tuple, keepdim: bool):
|
||
# Support torch.any with tuple dim argument.
|
||
# Workaround of https://github.com/pytorch/pytorch/issues/56586
|
||
r = input
|
||
for d in reversed(dim):
|
||
r = r.any(dim=d, keepdim=keepdim)
|
||
return r
|
||
|
||
|
||
def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.

    _sparse_coo_where implements the following invariant:

      _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Returns a sparse COO tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.

    Args:
        mask: boolean sparse COO tensor with the same shape and number
            of dense dimensions as ``input``. ``mask.indices()`` and
            ``mask.values()`` are read directly, which requires a
            coalesced tensor.
        input: sparse COO tensor; it is coalesced internally.
        fill_value: replacement for masked-out elements in the dense
            part of hybrid tensors.
    """

    assert input.layout == torch.sparse_coo
    assert mask.layout == input.layout
    assert mask.shape == input.shape
    assert mask.dense_dim() == input.dense_dim()  # TODO: eliminate this restriction

    input = input.coalesce()

    # For set operations on sparse tensor indices, we'll convert
    # multi-dimensional indices to 1-D indices for efficiency.
    input_flat_indices = _sparse_coo_flatten_indices(
        input.indices(), input.shape[: input.sparse_dim()]
    )
    mask_flat_indices = _sparse_coo_flatten_indices(
        mask.indices(), mask.shape[: mask.sparse_dim()]
    )

    # the set of mask flat indices that define masked-in elements:
    if mask.dense_dim() > 0:
        # mask.values() has shape (nse, *dense_shape): a sparse element is
        # masked-in when any entry of its dense block is True, so reduce
        # over the dense axes 1..dense_dim (not over the sparse dim count).
        mask_values = _any(
            mask.values(), tuple(range(1, mask.dense_dim() + 1)), False
        )
    else:
        mask_values = mask.values()
    maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]

    def intersection(i1, i2):
        # Returns the sorted union together with the positions (as a
        # torch.where tuple) of the entries that occur more than once in
        # the concatenation of i1 and i2.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return union, torch.where(counts.gt(1))

    def minus(i1, i2):
        # Same return convention as `intersection`; the first operand
        # passed on consists of the entries occurring exactly once in the
        # concatenation of i1 and i2.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return intersection(union[torch.where(counts.eq(1))], i1)

    def _apply(a):
        # Materialize an (object, where-tuple) pair into the selection.
        obj, w = a
        return obj[w]

    # the set of input flat indices of specified and masked-in elements:
    maskin_input_flat_indices = _apply(
        intersection(maskin_flat_indices, input_flat_indices)
    )
    _, w = intersection(input_flat_indices, maskin_input_flat_indices)

    # the indices and values of masked-in elements
    where_input_indices = input.indices()[(slice(None),) + w]
    where_input_values = input.values()[w]

    if mask.dense_dim() > 0:
        # apply mask to the dense part of the input values:
        _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
        where_mask_values = mask.values()[w1]
        where_input_values = torch.where(
            where_mask_values, where_input_values, fill_value
        )

    # the set of flat indices of unspecified input and masked-in elements:
    maskin_zero_flat_indices = _apply(
        minus(maskin_flat_indices, maskin_input_flat_indices)
    )

    # the indices of masked-in zero elements
    _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
    where_zero_indices = mask.indices()[(slice(None),) + w]

    # construct result
    n = where_zero_indices.size(1)
    if n == 0:
        # the input is coalesced, hence input_flat_indices are ordered
        # and the result is guaranteed to be coalesced:
        result = torch.sparse_coo_tensor(
            where_input_indices, where_input_values, input.shape
        )
        return result._coalesced_(True)

    where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
    where_values = torch.cat(
        [
            where_input_values,
            where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
        ]
    )
    result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)

    # appending zero elements leads to uncoalesced sparse tensor
    return result.coalesce()
|
||
|
||
|
||
def _sparse_coo_scatter_reduction_helper(
|
||
op,
|
||
mask_input: Tensor,
|
||
dims: Tuple[int, ...],
|
||
keepdim: bool,
|
||
dtype: Optional[DType] = None,
|
||
) -> Tensor:
|
||
reduce = op.__name__
|
||
valid_reductions = ["sum", "prod", "amax", "amin"]
|
||
if reduce not in valid_reductions:
|
||
raise ValueError(
|
||
f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
|
||
)
|
||
|
||
output_dtype = dtype
|
||
values, indices = mask_input._values(), mask_input._indices()
|
||
input_dims = mask_input.dim()
|
||
num_sparse_dims = mask_input.sparse_dim()
|
||
reduced_sparse_dims = []
|
||
retained_sparse_dims = []
|
||
reduced_dense_dims = []
|
||
|
||
# promote dtype if specified
|
||
if values.dtype != output_dtype:
|
||
values = values.to(output_dtype)
|
||
|
||
if keepdim:
|
||
output_shape = tuple(
|
||
1 if i in dims else si for (i, si) in enumerate(mask_input.shape)
|
||
)
|
||
else:
|
||
output_shape = tuple(
|
||
si for (i, si) in enumerate(mask_input.shape) if i not in dims
|
||
)
|
||
|
||
for d in dims:
|
||
if d >= input_dims:
|
||
continue
|
||
|
||
if d < num_sparse_dims:
|
||
reduced_sparse_dims.append(d)
|
||
else:
|
||
reduced_dense_dims.append(d + 1 - num_sparse_dims)
|
||
|
||
# Reduce dense dimensions
|
||
if len(reduced_dense_dims) > 0:
|
||
if reduce == "sum":
|
||
new_values = values
|
||
new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim))
|
||
else:
|
||
# FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
|
||
return NotImplemented
|
||
else:
|
||
new_values = values.clone()
|
||
|
||
# Reduce sparse dimensions
|
||
if len(reduced_sparse_dims) == num_sparse_dims:
|
||
if reduce in {"amax", "amin"} and new_values.size(0) == 0:
|
||
# IndexError: amax(): Expected reduction dim 0 to have non-zero size.
|
||
# sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
|
||
# See https://github.com/pytorch/pytorch/issues/61901
|
||
new_values = _reduction_identity(reduce, new_values)
|
||
else:
|
||
new_values = op(new_values, dim=0)
|
||
if keepdim:
|
||
for _ in range(num_sparse_dims):
|
||
new_values = new_values.unsqueeze(0)
|
||
return new_values.to(dtype=output_dtype).to_sparse()
|
||
else:
|
||
new_indices = indices.clone()
|
||
if keepdim:
|
||
# zero out reduced sparse dimensions if keepdim = True
|
||
# ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
|
||
new_indices[reduced_sparse_dims, :] = 0
|
||
else:
|
||
# remove reduced sparse dimensions if keepdim = False
|
||
if len(reduced_sparse_dims) > 0:
|
||
retained_sparse_dims = [
|
||
i
|
||
for i in range(num_sparse_dims)
|
||
if i not in set(reduced_sparse_dims)
|
||
]
|
||
new_indices = new_indices.index_select(
|
||
0, torch.tensor(retained_sparse_dims).to(mask_input.device)
|
||
)
|
||
|
||
# Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
|
||
if new_indices.numel() > 0:
|
||
# lexsort indices and get index tensor for scatter reduction
|
||
new_indices, inverse_indices = torch.unique(
|
||
new_indices, return_inverse=True, dim=1
|
||
)
|
||
out_shape = list(new_values.shape)
|
||
out_shape[0] = new_indices.shape[1]
|
||
for _ in range(new_values.ndim - 1):
|
||
inverse_indices = inverse_indices.unsqueeze(-1)
|
||
scatter_indices = inverse_indices.expand(new_values.shape)
|
||
# FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
|
||
if output_dtype in {torch.bfloat16, torch.float16}:
|
||
new_values = new_values.to(torch.float)
|
||
out = new_values.new_empty(out_shape)
|
||
new_values = out.scatter_reduce_(
|
||
0, scatter_indices, new_values, reduce=reduce, include_self=False
|
||
)
|
||
new_values = new_values.to(dtype=output_dtype)
|
||
else:
|
||
out = new_values.new_empty(out_shape)
|
||
new_values = out.scatter_reduce_(
|
||
0, scatter_indices, new_values, reduce=reduce, include_self=False
|
||
)
|
||
|
||
return torch.sparse_coo_tensor(
|
||
new_indices,
|
||
new_values,
|
||
output_shape,
|
||
dtype=output_dtype,
|
||
device=mask_input.device,
|
||
)
|
||
|
||
|
||
def _sparse_csr_segment_reduction_helper(
|
||
op,
|
||
mask_input: Tensor,
|
||
dims: Tuple[int, ...],
|
||
keepdim: bool,
|
||
dtype: Optional[DType] = None,
|
||
) -> Tensor:
|
||
# Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
|
||
# FIXME: when dense dimensions are implemented for CSR tensors
|
||
assert (
|
||
keepdim
|
||
), "reduction operations on CSR tensors with keepdim=False is unsupported"
|
||
reduce = op.__name__
|
||
valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
|
||
if reduce not in valid_reductions:
|
||
raise ValueError(
|
||
f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
|
||
)
|
||
device = mask_input.device
|
||
output_dtype = dtype
|
||
values, crow_indices, col_indices = (
|
||
mask_input.values(),
|
||
mask_input.crow_indices(),
|
||
mask_input.col_indices(),
|
||
)
|
||
|
||
# promote dtype if specified
|
||
if values.dtype != output_dtype:
|
||
values = values.to(output_dtype)
|
||
|
||
if len(dims) == 0:
|
||
return mask_input
|
||
if len(dims) == 1:
|
||
if dims[0] == 0:
|
||
new_col_indices, scatter_indices = torch.unique(
|
||
col_indices, return_inverse=True
|
||
)
|
||
new_nnz = new_col_indices.shape[0]
|
||
new_crow_indices = torch.tensor([0, new_nnz])
|
||
new_values = values.new_empty(new_col_indices.shape)
|
||
new_values.scatter_reduce_(
|
||
0, scatter_indices, values, reduce, include_self=False
|
||
)
|
||
new_shape = [1, mask_input.size(1)]
|
||
else:
|
||
assert (
|
||
dims[0] == 1
|
||
), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
|
||
# all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
|
||
# except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
|
||
new_crow_indices = torch.cat(
|
||
(
|
||
crow_indices.new_zeros(1),
|
||
torch.cumsum(torch.diff(crow_indices) != 0, 0),
|
||
),
|
||
0,
|
||
)
|
||
new_nnz = new_crow_indices[-1]
|
||
new_col_indices = col_indices.new_zeros(new_nnz)
|
||
new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined]
|
||
new_shape = [mask_input.size(0), 1]
|
||
else:
|
||
assert len(dims) == 2
|
||
nnz = min(1, values.numel())
|
||
if nnz == 1:
|
||
op_kwargs = {"keepdim": True, "dtype": output_dtype}
|
||
# amax and amin do not support dtype kwarg
|
||
if reduce in ["amax", "amin"]:
|
||
del op_kwargs["dtype"]
|
||
new_values = op(values, 0, **op_kwargs)
|
||
else:
|
||
new_values = torch.empty(0, dtype=output_dtype)
|
||
new_col_indices = col_indices.new_zeros(nnz)
|
||
new_crow_indices = torch.tensor([0, nnz])
|
||
new_shape = [1, nnz]
|
||
|
||
return torch.sparse_csr_tensor(
|
||
new_crow_indices,
|
||
new_col_indices,
|
||
new_values,
|
||
new_shape,
|
||
dtype=output_dtype,
|
||
device=device,
|
||
)
|
||
|
||
|
||
def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse CSR tensors.

    Both operands are converted to sparse COO, combined with
    ``_sparse_coo_where`` and the result is converted back to CSR.
    """
    # TODO: implement sparse CSR specific where operator for efficiency
    coo_mask = mask.to_sparse_coo()
    coo_input = input.to_sparse_coo()
    coo_result = _sparse_coo_where(coo_mask, coo_input, fill_value)
    return coo_result.to_sparse_csr()
|
||
|
||
|
||
def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
|
||
"""torch.where with sparse inputs support.
|
||
|
||
_where implements the following invariant:
|
||
|
||
_where(mask, input, fill_value).to_dense(fill_value) ==
|
||
torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))
|
||
|
||
where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
|
||
tensor, and `to_dense(fill_value)` is like `to_dense()` except
|
||
that the unspecified elements are mapped to `fill_value` rather
|
||
than to `0`.
|
||
|
||
Returns a sparse tensor with the following features:
|
||
|
||
- all specified elements correspond to masked-in elements that
|
||
have the values of the input tensor. If there exists a masked-in
|
||
element (as specified by mask) that is not specified in the
|
||
input, in the result tensor, the corresponding element has value
|
||
0. In the dense part of the sparse tensor, the masked-out
|
||
elements are replaced with fill_value.
|
||
|
||
- all unspecified elements correspond to masked-out elements.
|
||
"""
|
||
if mask.layout == torch.strided:
|
||
return torch.where(mask, input, fill_value)
|
||
elif mask.layout == torch.sparse_coo:
|
||
return _sparse_coo_where(mask, input, fill_value)
|
||
elif mask.layout == torch.sparse_csr:
|
||
return _sparse_csr_where(mask, input, fill_value)
|
||
else:
|
||
raise ValueError(
|
||
f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}"
|
||
)
|
||
|
||
|
||
def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor:
|
||
"""Return canonical input mask.
|
||
|
||
A canonical input mask is defined as a boolean mask tensor that
|
||
shape and layout matches with the shape and the layout of the
|
||
input.
|
||
|
||
The canonical input mask is computed from the :attr:`mask` tensor
|
||
content to meet the following criteria:
|
||
|
||
1. The shape of the canonical input mask is the same as the shape
|
||
of :attr:`input` tensor. If the mask tensor has a smaller shape
|
||
than the shape of the :attr:`input`, broadcasting rules will be
|
||
applied. Downcasting of mask is not supported.
|
||
|
||
2. The layout of the canonical input mask is the same as the
|
||
layout of the :attr:`input` tensor. If the mask has different
|
||
layout, it will be converted to the expected layout. In the
|
||
case of sparse COO layout, the canonical input mask will be
|
||
coalesced.
|
||
|
||
3. The dtype of the canonical input mask is torch.bool. If the
|
||
mask dtype is not bool then it will be converted to bool dtype
|
||
using `.to(dtype=bool)` method call.
|
||
|
||
4. The elements of the canonical input mask have boolean values
|
||
copied from the content of the :attr:`mask` tensor (after
|
||
possible broadcasting and dtype conversion transforms). In
|
||
general, the sparsity pattern of the sparse canonical input
|
||
mask need not to be the same as the sparsity pattern of the
|
||
sparse :attr:`input` tensor.
|
||
|
||
"""
|
||
if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
|
||
raise ValueError(
|
||
f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}"
|
||
)
|
||
|
||
mask = kwargs.get("mask")
|
||
|
||
# default mask
|
||
if mask is None:
|
||
raise ValueError("_input_mask requires explicit mask")
|
||
|
||
# mask shape must match with input shape
|
||
if mask.shape != input.shape:
|
||
if mask.ndim > input.ndim:
|
||
raise IndexError(
|
||
"_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)"
|
||
)
|
||
if mask.layout == torch.strided:
|
||
mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool)
|
||
elif mask.layout == torch.sparse_coo:
|
||
mask = torch._sparse_broadcast_to(mask, input.shape)
|
||
else:
|
||
assert mask.layout == torch.sparse_csr
|
||
# Broadcasting of CSR tensors is not implemented. Working
|
||
# around by using COO layout.
|
||
mask = torch._sparse_broadcast_to(
|
||
mask.to_sparse(), input.shape
|
||
).to_sparse_csr()
|
||
|
||
# mask layout must match with input layout
|
||
if mask.layout != input.layout:
|
||
if input.layout == torch.strided:
|
||
mask = mask.to_dense()
|
||
elif input.layout == torch.sparse_coo:
|
||
if mask.layout == torch.strided:
|
||
mask = mask.to_sparse(input.sparse_dim())
|
||
else:
|
||
mask = mask.to_sparse()
|
||
else:
|
||
assert input.layout == torch.sparse_csr
|
||
mask = mask.to_sparse_csr()
|
||
|
||
# sparse mask must be coalesced
|
||
if mask.layout == torch.sparse_coo:
|
||
mask = mask.coalesce()
|
||
|
||
# mask is a boolean tensor
|
||
mask = mask.to(dtype=torch.bool)
|
||
|
||
return mask
|
||
|
||
|
||
def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
|
||
"""Return output mask of masked operation applied to given arguments."""
|
||
if callable(op):
|
||
is_reduction = op.__name__ in {
|
||
"sum",
|
||
"prod",
|
||
"amax",
|
||
"amin",
|
||
"argmax",
|
||
"argmin",
|
||
"mean",
|
||
"median",
|
||
"norm",
|
||
"var",
|
||
"std",
|
||
"logsumexp",
|
||
}
|
||
is_normalization = op.__name__ in {
|
||
"softmax",
|
||
"log_softmax",
|
||
"softmin",
|
||
"normalize",
|
||
"cumsum",
|
||
"cumprod",
|
||
}
|
||
if is_reduction:
|
||
if op.__name__ == "norm":
|
||
if args:
|
||
args = args[1:] # lstrip ord argument
|
||
dim = args[0] if args else kwargs.get("dim")
|
||
outmask = _input_mask(input, *args, **kwargs)
|
||
keepdim = kwargs.get("keepdim", False)
|
||
dim_ = _canonical_dim(dim, input.ndim)
|
||
return _any(outmask, dim_, bool(keepdim))
|
||
elif is_normalization:
|
||
return _input_mask(input, *args, **kwargs)
|
||
else:
|
||
raise ValueError(
|
||
f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
|
||
)
|
||
else:
|
||
raise ValueError(
|
||
f"_output_mask expected masked operation (got {type(op).__name__} object)"
|
||
)
|
||
|
||
|
||
def _combine_input_and_mask(
|
||
op, input: Union[MaskedTensor, Tensor], mask, *args
|
||
) -> Tensor:
|
||
def helper(input, mask):
|
||
if mask is None:
|
||
return input
|
||
canonical_mask = _input_mask(input, mask=mask)
|
||
if callable(op):
|
||
fill_value = _reduction_identity(op.__name__, input, *args)
|
||
return _where(canonical_mask, input, fill_value)
|
||
else:
|
||
raise ValueError(
|
||
f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
|
||
)
|
||
|
||
class Combine(torch.autograd.Function):
|
||
@staticmethod
|
||
def forward(ctx, input, mask):
|
||
"""Return input with masked-out elements eliminated for the given operations."""
|
||
ctx.save_for_backward(mask)
|
||
|
||
if mask is not None:
|
||
ctx.mark_non_differentiable(mask)
|
||
|
||
return helper(input, mask)
|
||
|
||
@staticmethod
|
||
def backward(ctx, grad_output):
|
||
(mask,) = ctx.saved_tensors
|
||
grad_data = (
|
||
grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
|
||
)
|
||
result = as_masked_tensor(grad_data, mask)
|
||
return result, None
|
||
|
||
return (
|
||
Combine.apply(input.get_data(), input.get_mask()) # type: ignore[union-attr]
|
||
if is_masked_tensor(input)
|
||
else helper(input, mask)
|
||
)
|
||
|
||
|
||
@_apply_docstring_templates
def sum(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        promote_to_int64 = input.dtype in {
            torch.uint8,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
        }
        if input.layout == torch.sparse_csr:
            if promote_to_int64:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = torch.int64 if promote_to_int64 else input.dtype

    dim_ = _canonical_dim(dim, input.ndim)
    # Masked-out elements are replaced with the identity value of sum.
    mask_input = _combine_input_and_mask(sum, input, mask)

    # Dispatch on the layout of the combined input.
    if mask_input.layout == torch.strided:
        return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
    if mask_input.layout == torch.sparse_coo:
        return _sparse_coo_scatter_reduction_helper(
            torch.sum, mask_input, dim_, bool(keepdim), dtype
        )
    if mask_input.layout == torch.sparse_csr:
        return torch._sparse_csr_sum(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    raise ValueError(
        f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
    )
|
||
|
||
|
||
@_apply_docstring_templates
def prod(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        promote_to_int64 = input.dtype in {
            torch.uint8,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
        }
        if input.layout == torch.sparse_csr:
            if promote_to_int64:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = torch.int64 if promote_to_int64 else input.dtype

    dim_ = _canonical_dim(dim, input.ndim)
    # Masked-out elements are replaced with the identity value of prod.
    mask_input = _combine_input_and_mask(prod, input, mask)

    if mask_input.layout == torch.strided:
        # Workaround https://github.com/pytorch/pytorch/issues/56586
        # Reduce one dimension at a time, last dimension first.
        result = mask_input.to(dtype=dtype)
        for d in reversed(dim_):
            result = result.prod(dim=d, keepdim=bool(keepdim))
        return result

    if mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
            raise ValueError(
                "masked prod expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.prod, mask_input, dim_, bool(keepdim), dtype
        )

    if mask_input.layout == torch.sparse_csr:
        if mask is None:
            # mask is None corresponds to all-True mask. The
            # unspecified elements in the CSR tensor correspond to
            # zero values. Hence, the prod reduction result is
            # automatically zero unless all elements are specified.
            # A semi-optimal way to take this into account is to use:
            #
            #   masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
            #
            # but that requires implementing `all` and `nonzero`
            # support for sparse csr tensors.
            raise ValueError(
                "masked prod expects explicit mask for sparse_csr tensor input"
            )
        return torch._sparse_csr_prod(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )

    raise ValueError(
        f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
    )
|
||
|
||
|
||
@_apply_docstring_templates
def cumsum(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No docstring here on purpose: __doc__ is handled by the
    # _apply_docstring_templates decorator.
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are replaced with the identity value of sum.
    mask_input = _combine_input_and_mask(sum, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
        )
    running_total = torch.cumsum(mask_input, dim_, dtype=dtype)
    return running_total.to(dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No docstring here on purpose: __doc__ is handled by the
    # _apply_docstring_templates decorator.
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are replaced with the identity value of prod.
    mask_input = _combine_input_and_mask(prod, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
        )
    running_product = torch.cumprod(mask_input, dim_, dtype=dtype)
    return running_product.to(dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def amax(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    # Masked-out elements are replaced with the identity value of amax.
    mask_input = _combine_input_and_mask(amax, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)

    if mask_input.layout == torch.strided:
        # The result is cast to the requested dtype after the reduction.
        reduced = torch.amax(mask_input, dim_, bool(keepdim))
        return reduced.to(dtype=dtype)

    if mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amax expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )

    if mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amax expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )

    raise ValueError(
        f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
    )
|
||
|
||
|
||
@_apply_docstring_templates
def amin(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    # Masked-out elements are replaced with the identity value of amin.
    mask_input = _combine_input_and_mask(amin, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        # The result is cast to the requested dtype after the reduction.
        return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            # NOTE: this message used to name "amax" (copy-paste slip).
            raise ValueError(
                "masked amin expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amin expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
|
||
|
||
|
||
@_apply_docstring_templates
def argmax(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    # Masked-out elements are replaced with the identity value of argmax.
    mask_input = _combine_input_and_mask(argmax, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
        )
    # The index result is cast to dtype (the input dtype by default).
    indices = torch.argmax(mask_input, dim, bool(keepdim))
    return indices.to(dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def argmin(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    # Masked-out elements are replaced with the identity value of argmin.
    mask_input = _combine_input_and_mask(argmin, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
        )
    # The index result is cast to dtype (the input dtype by default).
    indices = torch.argmin(mask_input, dim, bool(keepdim))
    return indices.to(dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def mean(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined.  Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    if input.layout == torch.strided:
        # mean = (masked sum of the values) / (masked count of the values),
        # where the count is the masked sum of a tensor of ones.
        if mask is None:
            # TODO: compute count analytically
            ones = torch.ones(input.shape, dtype=torch.int64, device=input.device)
            count = sum(ones, dim, keepdim=keepdim)
            total = sum(input, dim, keepdim=keepdim, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            ones = inmask.new_ones(input.shape, dtype=torch.int64)
            count = sum(ones, dim, keepdim=keepdim, mask=inmask)
            total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
        return total / count

    if input.layout == torch.sparse_csr:
        mask_input = _combine_input_and_mask(mean, input, mask)
        dim_ = _canonical_dim(dim, mask_input.ndim)
        if mask is None:
            raise ValueError(
                "masked mean expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.mean, mask_input, dim_, bool(keepdim), dtype
        )

    raise ValueError(
        f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
    )
|
||
|
||
|
||
@_apply_docstring_templates
def median(
    input: Union[Tensor, MaskedTensor],
    dim: int = -1,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:

    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a median operation is the median
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
median is undefined.  Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]

    # Non-floating-point inputs are computed in float and cast back at the end.
    is_float = torch.is_floating_point(input)
    if not is_float:
        input = input.to(dtype=torch.float)

    mask_input = _combine_input_and_mask(median, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked median expects strided tensor (got {mask_input.layout} tensor)"
        )

    output = torch.nanmedian(mask_input, dim_, keepdim).values
    if is_float:
        return output
    # A nan in the result of a non-float input cannot be cast back to dtype.
    if torch.isnan(output).any():
        raise ValueError(
            "masked median expects no fully masked out rows if dtype is not floating point"
        )
    return output.to(dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def logsumexp(
    input: Tensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No docstring here on purpose: __doc__ is handled by the
    # _apply_docstring_templates decorator.
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)
    # Masked-out elements are replaced with the identity value of logsumexp.
    mask_input = _combine_input_and_mask(logsumexp, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
        )
    reduced = torch.logsumexp(mask_input, dim_, keepdim=keepdim)
    return reduced.to(dtype=dtype)
|
||
|
||
|
||
# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations
|
||
def logaddexp(
    input: Union[Tensor, MaskedTensor],
    other: Union[Tensor, MaskedTensor],
    *,
    dtype: Optional[DType] = None,
    input_mask: Optional[Tensor] = None,
    other_mask: Optional[Tensor] = None,
) -> Tensor:
    """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor

Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
tensor. The :attr:`input` elements are masked out according to the boolean tensor
:attr:`input_mask` and the :attr:`other` elements are masked out according to the
boolean tensor :attr:`other_mask`.

The shapes of a mask tensor and the tensor to be masked
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the mask
tensor must not be greater than that of the tensor to be masked.

Args:
    input (Tensor): the input tensor
    other (Tensor): the second input tensor

Keyword args:
    dtype (:class:`torch.dtype`, optional): the desired data type
      of returned tensor.  If specified, the output tensor is
      casted to :attr:`dtype` after the operation is
      performed. Default: None.
    input_mask (:class:`torch.Tensor`, optional): the boolean tensor
      containing the binary mask of validity of :attr:`input` tensor elements.
      Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
    other_mask (:class:`torch.Tensor`, optional): the boolean tensor
      containing the binary mask of validity of :attr:`other` tensor elements.
      Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.

Example::

    >>> input = torch.tensor([-100.0, -200, -300])
    >>> input
    tensor([-100., -200., -300.])
    >>> other = torch.tensor([-1.0, -2, -3])
    >>> other
    tensor([-1., -2., -3.])
    >>> mask = torch.tensor([True, False, True])
    >>> mask
    tensor([ True, False,  True])
    >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
    tensor([-1., -inf, -3.])
"""
    if dtype is None:
        dtype = input.dtype
    # Both operands must be strided tensors.
    both_strided = input.layout == torch.strided and other.layout == torch.strided
    if not both_strided:
        raise ValueError(
            f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
        )
    # Each operand is filled with the identity value of logsumexp where it is
    # masked out.
    mask_input = _combine_input_and_mask(logsumexp, input, input_mask)
    mask_other = _combine_input_and_mask(logsumexp, other, other_mask)
    return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def norm(
    input: Union[Tensor, MaskedTensor],
    ord: Optional[float] = 2.0,
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

The identity value of norm operation, which is used to start the
reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
``{identity_ord_ninf}``.

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    # ord is forwarded so that the fill value can depend on the norm order.
    mask_input = _combine_input_and_mask(norm, input, mask, ord)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked norm expects strided tensor (got {mask_input.layout} tensor)"
        )
    dim_ = _canonical_dim(dim, input.ndim)
    return torch.linalg.vector_norm(
        mask_input, ord, dim_, bool(keepdim), dtype=dtype
    )
|
||
|
||
|
||
def _std_var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims,
    unbiased: Optional[bool],
    *,
    correction_opt: Optional[Union[int, float]],
    keepdim: Optional[bool],
    dtype: Optional[DType],
    mask: Optional[Tensor],
    take_sqrt: Optional[bool],
) -> Tensor:
    """Shared implementation of the masked ``var`` and ``std`` reductions.

    Computes the masked sum of ``|x - mean|^2`` divided by
    ``max(count - correction, 0)``, where ``count`` is the number of
    masked-in elements, and optionally takes the square root.

    Args:
        input: the tensor to reduce; only the strided layout is supported.
        dim: the dimension(s) to reduce.
        unbiased: legacy switch; True means correction 1, False means 0.
        correction_opt: explicit correction; mutually exclusive with
            ``unbiased``. If both are None the correction is 1.
        keepdim: whether the reduced dimensions are retained.
        dtype: the output dtype. When None, the input dtype is used,
            falling back to float32 for non-floating, non-complex inputs.
        mask: optional validity mask of the input elements.
        take_sqrt: when true, return the square root of the variance.

    Raises:
        ValueError: if ``input`` is not a strided tensor.
    """
    assert (unbiased is None or correction_opt is None), "Only one of unbiased and correction may be given"
    # An explicit correction takes precedence over the legacy unbiased flag;
    # with neither given the correction defaults to 1.
    if correction_opt is not None:
        correction = sym_float(correction_opt)
    elif unbiased is not None:
        correction = 1.0 if unbiased else 0.0
    else:
        correction = 1.0

    if dtype is None:
        dtype = input.dtype
        if not (dtype.is_floating_point or dtype.is_complex):
            dtype = torch.float32
    # The squared deviations are always accumulated in a floating point or
    # complex dtype, even when an integer output dtype was requested.
    if dtype.is_floating_point or dtype.is_complex:
        compute_dtype = dtype
    else:
        compute_dtype = torch.float32

    if input.layout != torch.strided:
        raise ValueError(
            f"masked std/var expects strided tensor (got {input.layout} tensor)"
        )

    # count and sample_total keep the reduced dimensions so that the mean
    # broadcasts against the input.
    if mask is None:
        inmask = None
        # TODO: compute count analytically
        ones = torch.ones(input.shape, dtype=torch.int64, device=input.device)
        count = sum(ones, dim, keepdim=True)
        sample_total = sum(input, dim, keepdim=True, dtype=dtype)
    else:
        inmask = _input_mask(input, mask=mask)
        ones = inmask.new_ones(input.shape, dtype=torch.int64)
        count = sum(ones, dim, keepdim=True, mask=inmask)
        sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)

    # TODO: replace torch.subtract/divide/square/maximum with
    # masked subtract/divide/square/maximum when these will be
    # available.
    sample_mean = torch.divide(sample_total, count)
    deviation = torch.subtract(input, sample_mean)
    # mask=None is the default of the masked sum, so a single call covers
    # both the masked and the unmasked case.
    total = sum(
        deviation * deviation.conj(),
        dim,
        keepdim=keepdim,
        dtype=compute_dtype,
        mask=inmask,
    )
    if not keepdim:
        count = count.reshape(total.shape)
    if correction != 0:
        if compute_dtype.is_complex:
            real_dtype = corresponding_real_dtype(compute_dtype)
        else:
            real_dtype = compute_dtype
        count = count.to(real_dtype)
        count = torch.subtract(count, correction)
        # The corrected count is clamped from below at zero.
        count = torch.maximum(count, count.new_zeros([]))
    output = torch.divide(total, count).to(dtype=dtype)
    if take_sqrt:
        output = torch.sqrt(output)
    return output
|
||
|
||
|
||
@_apply_docstring_templates
def var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample variance operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # Variance is the shared std/var computation without the final sqrt.
    return _std_var(
        input,
        dim,
        unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=False,
    )
|
||
|
||
|
||
@_apply_docstring_templates
def std(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample standard deviation operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # Standard deviation is the shared std/var computation followed by sqrt.
    # `correction` is annotated like in `var`: _std_var accepts int or float
    # for its correction_opt parameter.
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=True,
    )
|
||
|
||
|
||
@_apply_docstring_templates
def softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No docstring here on purpose: __doc__ is handled by the
    # _apply_docstring_templates decorator.
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled with the identity value of amax.
    mask_input = _combine_input_and_mask(amax, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked softmax expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def log_softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No docstring here on purpose: __doc__ is handled by the
    # _apply_docstring_templates decorator.
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled with the identity value of amax.
    mask_input = _combine_input_and_mask(amax, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def softmin(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No docstring here on purpose: __doc__ is handled by the
    # _apply_docstring_templates decorator.
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled with the identity value of amin.
    mask_input = _combine_input_and_mask(amin, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked softmin expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
|
||
|
||
|
||
@_apply_docstring_templates
def normalize(
    input: Union[Tensor, MaskedTensor],
    ord: float,
    dim: int,
    *,
    eps: float = 1e-12,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No docstring here on purpose: __doc__ is handled by the
    # _apply_docstring_templates decorator.
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # TODO: eliminate mask_input as unnecessary when using masked divide.
    mask_input = _combine_input_and_mask(sum, input, mask)
    # Only the strided layout is supported.
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked normalize expects strided tensor (got {mask_input.layout} tensor)"
        )
    # The masked norm keeps the reduced dimension so that it broadcasts
    # against the input in the division below.
    nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
    # TODO: replace torch.maximum with masked maximum when available.
    # The denominator is bounded from below by eps.
    eps_tensor = nrm_.new_full([], eps)
    denom = torch.maximum(nrm_, eps_tensor)
    # TODO: replace torch.divide with masked divide when available.
    return torch.divide(mask_input, denom)
|