"""Defines bias subclasses that work with scaled_dot_product_attention"""
|
||
|
from enum import auto, IntEnum
|
||
|
from typing import Optional
|
||
|
from warnings import warn
|
||
|
|
||
|
import torch
|
||
|
from torch.backends.cuda import (
|
||
|
can_use_efficient_attention,
|
||
|
can_use_flash_attention,
|
||
|
SDPAParams,
|
||
|
)
|
||
|
from torch.nn.attention import _raise_kernel_warnings
|
||
|
from torch.nn.attention._utils import (
|
||
|
_calculate_scale,
|
||
|
_input_requires_grad,
|
||
|
_postprocess_flash_output,
|
||
|
_validate_sdpa_input,
|
||
|
)
|
||
|
from torch.nn.functional import scaled_dot_product_attention
|
||
|
|
||
|
__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]
|
||
|
|
||
|
|
||
|
torch._dynamo.allow_in_graph(can_use_flash_attention)
|
||
|
torch._dynamo.allow_in_graph(can_use_efficient_attention)
|
||
|
torch._dynamo.allow_in_graph(SDPAParams)
|
||
|
|
||
|
|
||
|
class CausalVariant(IntEnum):
|
||
|
r"""
|
||
|
Enum for causal variants used in attention mechanisms.
|
||
|
|
||
|
Defines two types of causal biases:
|
||
|
|
||
|
`UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention.
|
||
|
The equivalent pytorch code for constructing this bias is:
|
||
|
|
||
|
.. code-block:: python
|
||
|
|
||
|
torch.tril(torch.ones(size, dtype=torch.bool))
|
||
|
|
||
|
For instance, with `shape=(3,4)`, the materialized bias tensor will be:
|
||
|
|
||
|
.. code-block:: text
|
||
|
|
||
|
[[1, 0, 0, 0],
|
||
|
[1, 1, 0, 0],
|
||
|
[1, 1, 1, 0]]
|
||
|
|
||
|
|
||
|
    `LOWER_RIGHT`: Represents lower-right triangular bias; the included values are aligned to the lower-right
    corner of the matrix.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Note that these variants are equivalent to each other when the sequence lengths of the query and key/value
    tensors are equal since the triangular matrix is square.
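
    For example (an illustrative check, not part of the public API), with a square
    `shape=(3,3)` both constructions produce the same mask because the diagonal
    offset is zero:

    .. code-block:: python

        size = (3, 3)
        upper_left = torch.tril(torch.ones(size, dtype=torch.bool))
        lower_right = torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=size[1] - size[0],  # 0 when the mask is square
        )
        assert torch.equal(upper_left, lower_right)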

    .. warning:: This enum is a prototype and subject to change.
    """

    UPPER_LEFT = auto()
    LOWER_RIGHT = auto()


class CausalBias(torch.Tensor):
    """
    A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.

    This class is used for defining causal (triangular) attention biases. For constructing the bias, there exist
    two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.

    Example:

    .. code-block:: python

        import torch
        import torch.nn.functional as F

        from torch.nn.attention.bias import causal_lower_right

        bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

        # Create a lower-right causal bias
        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

        out = F.scaled_dot_product_attention(q, k, v, attn_bias)

    .. warning:: This class is a prototype and subject to change.
    """

    def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
        """
        Initializes the CausalBias instance with a specified variant and sequence lengths.

        Args:
            variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
            seq_len_q (int): The sequence length of the query tensor.
            seq_len_kv (int): The sequence length of the key/value tensor.

        Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
        """
        assert isinstance(variant, CausalVariant)
        self.variant = variant
        self.seq_len_q = seq_len_q
        self.seq_len_kv = seq_len_kv
        if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
            warn(
                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

    def _upper_left(self, device: torch.device) -> torch.Tensor:
        """Upper left causal bias"""
        return torch.tril(
            torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
        )

    def _lower_right(self, device: torch.device) -> torch.Tensor:
        """Lower right causal bias"""
        diagonal_offset = self.seq_len_kv - self.seq_len_q
        return torch.tril(
            torch.ones(
                self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
            ),
            diagonal=diagonal_offset,
        )

    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Materializes the causal bias into a tensor form.

        Depending on the variant, this method generates either an upper-left or lower-right
        triangular matrix to represent the causal bias.

        Args:
            device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.

        Returns:
            torch.Tensor: The materialized bias tensor.
        """
        if device is None:
            device = torch.device("cpu")
        if self.variant == CausalVariant.UPPER_LEFT:
            return self._upper_left(device)
        elif self.variant == CausalVariant.LOWER_RIGHT:
            return self._lower_right(device)

    @staticmethod
    def _dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: "CausalBias",
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        r"""
        Handles the logic for computing attention with the specified causal bias.

        Args:
            query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
            key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
            value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
            attn_mask (CausalBias): The type of causal attention to apply.
                A boolean mask where a value of True indicates that the element *should* take part in attention.
                A float mask of the same type as query, key, value that is added to the attention score.
            dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
            is_causal (bool): If true, assumes upper-left causal attention masking; an error is raised if both
                attn_mask and is_causal are set.
            scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                to :math:`\frac{1}{\sqrt{E}}`.

        Returns:
            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.

        Raises:
            ValueError: If the causal bias variant is not a CausalVariant type.

        """
        if is_causal:
            raise ValueError("CausalBias should not be used with causal=True")
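
        # When the mask is square or explicitly UPPER_LEFT, the bias is identical to
        # standard causal masking, so dispatch to scaled_dot_product_attention with
        # is_causal=True.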
        if (
            attn_mask.seq_len_q == attn_mask.seq_len_kv
            or attn_mask.variant == CausalVariant.UPPER_LEFT
        ):
            return scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
            )
        elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
            _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
            sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
            if can_use_flash_attention(sdpa_params):
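                # The flash kernel expects the head dimension to be a multiple of 8, so
                # pad the inputs if needed; keep the original head size so the scale and
                # the final output slice use the unpadded value.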
                needs_padding = query.size(-1) % 8 != 0
                og_head_size = query.size(-1)
                og_scale = _calculate_scale(og_head_size, scale)
                if needs_padding:
                    query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
                    key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
                    value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
                out = torch.ops.aten._scaled_dot_product_flash_attention(
                    query,
                    key,
                    value,
                    dropout_p,
                    is_causal=True,  # TODO: Flash accepts causal = True and for this particular op it means lower right
                    return_debug_mask=False,
                    scale=og_scale,
                )[0]
                return _postprocess_flash_output(out, og_head_size)
            if can_use_efficient_attention(sdpa_params):
                compute_log_sumexp = False
                if _input_requires_grad(query, key, value):
                    compute_log_sumexp = True
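                # _efficient_attention_forward expects (batch, seq, heads, head_dim)
                # inputs, so transpose in and transpose the result back; the causal
                # variant maps directly onto the kernel's custom_mask_type.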
                return torch.ops.aten._efficient_attention_forward(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    bias=None,
                    cu_seqlens_q=None,
                    cu_seqlens_k=None,
                    max_seqlen_q=None,
                    max_seqlen_k=None,
                    dropout_p=dropout_p,
                    custom_mask_type=int(attn_mask.variant),
                    compute_log_sumexp=compute_log_sumexp,
                    scale=scale,
                    causal_diagonal=None,
                    seqlen_k=None,
                )[0].transpose(1, 2)
            else:
                _raise_kernel_warnings(sdpa_params)
                # We can't use efficient attention; the only remaining support for
                # lower-right causal masking is to materialize the bias.
                return scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask._materialize(query.device),
                    dropout_p=dropout_p,
                    is_causal=False,
                    scale=scale,
                )
        else:
            raise ValueError(
                f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is a :class:`CausalBias`"""
        if kwargs is None:
            kwargs = {}
        if func != torch.nn.functional.scaled_dot_product_attention:
            raise NotImplementedError(
                "CausalBias only supports scaled_dot_product_attention"
            )
        return cls._dispatch(*args, **kwargs)

    def __repr__(self):
        return self._materialize().__repr__()


def causal_upper_left(*size) -> CausalBias:
    """
    Creates an upper-left triangular causal bias.

    This function generates an upper-left triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the upper-left corner of the matrix.
    This is equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The UPPER_LEFT triangular causal bias variant.
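
    Example (a minimal usage sketch, mirroring the :class:`CausalBias` example):

    .. code-block:: python

        import torch
        import torch.nn.functional as F

        from torch.nn.attention.bias import causal_upper_left

        seqlen_q, seqlen_kv = 4, 12
        attn_bias = causal_upper_left(seqlen_q, seqlen_kv)

        q = torch.randn(2, 8, seqlen_q, 64, device="cuda", dtype=torch.float16)
        k = torch.randn(2, 8, seqlen_kv, 64, device="cuda", dtype=torch.float16)
        v = torch.randn(2, 8, seqlen_kv, 64, device="cuda", dtype=torch.float16)

        out = F.scaled_dot_product_attention(q, k, v, attn_bias)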
    """
    assert len(size) == 2, "causal_upper_left only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)


def causal_lower_right(*size) -> CausalBias:
    """
    Creates a lower-right triangular causal bias.

    This function generates a lower-right triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the lower-right corner of the matrix.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)`, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The LOWER_RIGHT triangular causal bias variant.
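
    Example (illustrative; the materialized mask aligns queries to the end of the
    key/value sequence):

    .. code-block:: python

        from torch.nn.attention.bias import causal_lower_right

        attn_bias = causal_lower_right(2, 5)

        # Materialized, the mask is:
        # [[1, 1, 1, 1, 0],
        #  [1, 1, 1, 1, 1]]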
    """
    assert len(size) == 2, "causal_lower_right only supports 2D tensors"
    seq_len_q, seq_len_kv = size
    return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)