From e9ea1c5b2c3422ce9858c4dba90447ed6d4af905 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 6 Oct 2025 10:47:12 +0530
Subject: [PATCH 1/2] up

---
 src/diffusers/models/attention_dispatch.py | 42 ++++++++++++++++++++--
 src/diffusers/utils/kernels_utils.py       | 17 ++++++---
 2 files changed, 52 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..6729a21af055 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -83,12 +83,16 @@
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
     flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+
+    sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
+    sage_attn_func_hub = sage_interface_hub.sageattn
 else:
     flash_attn_3_func_hub = None
+    sage_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -190,6 +194,7 @@ class AttentionBackendName(str, Enum):
 
     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
     SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
@@ -1756,6 +1761,39 @@ def _sage_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    if _parallel_config is None:
+        out = sage_attn_func_hub(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="NHD",
+            is_causal=is_causal,
+            sm_scale=scale,
+            return_lse=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        raise NotImplementedError("SAGE attention doesn't yet support parallelism.")
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 26d6e3972fb7..61201b847b74 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -6,18 +6,25 @@
 
 
 _DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
+_KERNEL_REVISION = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    _DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
+    _DEFAULT_HUB_ID_SAGE: None,
+}
 
 
-def _get_fa3_from_hub():
+def _get_kernel_from_hub(kernel_id):
     if not is_kernels_available():
         return None
     else:
         from kernels import get_kernel
 
         try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
+            if kernel_id not in _KERNEL_REVISION:
+                raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
+            kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
+            return kernel_hub
         except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
             raise

From d3441340b9bb88d06a1eca5b60cd3ab7ef0f4683 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 7 Oct 2025 18:40:04 +0530
Subject: [PATCH 2/2] support automatic dispatch.

---
 src/diffusers/models/attention_dispatch.py |  23 ++--
 src/diffusers/utils/kernels_utils.py       |   2 +-
 src/diffusers/utils/sage_utils.py          | 137 +++++++++++++++++++++
 3 files changed, 146 insertions(+), 16 deletions(-)
 create mode 100644 src/diffusers/utils/sage_utils.py

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 6729a21af055..447d8a7b1783 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,7 +17,8 @@
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -84,12 +85,16 @@
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
     from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
+    from ..utils.sage_utils import _get_sage_attn_fn_for_device
 
     flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
     flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
 
     sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
-    sage_attn_func_hub = sage_interface_hub.sageattn
+    sage_fn_with_kwargs = _get_sage_attn_fn_for_device()
+    sage_attn_func_hub = getattr(sage_interface_hub, sage_fn_with_kwargs["func"])
+    sage_attn_func_hub = partial(sage_attn_func_hub, **sage_fn_with_kwargs["kwargs"])
+
 else:
     flash_attn_3_func_hub = None
     sage_attn_func_hub = None
@@ -166,10 +171,6 @@ def wrap(func):
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 
-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-
 
 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
@@ -1777,15 +1778,7 @@ def _sage_attention_hub(
 ) -> torch.Tensor:
     lse = None
     if _parallel_config is None:
-        out = sage_attn_func_hub(
-            q=query,
-            k=key,
-            v=value,
-            tensor_layout="NHD",
-            is_causal=is_causal,
-            sm_scale=scale,
-            return_lse=return_lse,
-        )
+        out = sage_attn_func_hub(q=query, k=key, v=value)
         if return_lse:
             out, lse, *_ = out
     else:
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 61201b847b74..3470692cca09 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -10,7 +10,7 @@
 _KERNEL_REVISION = {
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     _DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
-    _DEFAULT_HUB_ID_SAGE: None,
+    _DEFAULT_HUB_ID_SAGE: "compile",
 }
 
 
diff --git a/src/diffusers/utils/sage_utils.py b/src/diffusers/utils/sage_utils.py
new file mode 100644
index 000000000000..28e4e17941eb
--- /dev/null
+++ b/src/diffusers/utils/sage_utils.py
@@ -0,0 +1,137 @@
+"""
+Copyright (c) 2024 by SageAttention, The HuggingFace team.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+"""
+
+"""
+Modified from
+https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py
+"""
+
+
+import torch  # noqa
+
+
+SAGE_ATTENTION_DISPATCH = {
+    "sm80": {
+        "func": "sageattn_qk_int8_pv_fp16_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32",
+        },
+    },
+    "sm89": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp16",
+        },
+    },
+    "sm90": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp32",
+        },
+    },
+    "sm120": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "qk_quant_gran": "per_warp",
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp16",
+        },
+    },
+}
+
+
+def get_cuda_version():
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        return major, minor
+    else:
+        raise EnvironmentError("CUDA not found.")
+
+
+def get_cuda_arch_versions():
+    if not torch.cuda.is_available():
+        raise EnvironmentError("CUDA not found.")
+    cuda_archs = []
+    for i in range(torch.cuda.device_count()):
+        major, minor = torch.cuda.get_device_capability(i)
+        cuda_archs.append(f"sm{major}{minor}")
+    return cuda_archs
+
+
+# Unlike the actual implementation, we just maintain function names rather than actual
+# implementations.
+def _get_sage_attn_fn_for_device():
+    """
+    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute
+    capability. This helper itself takes no arguments and returns the chosen function name together with its default
+    kwargs; the parameter and return descriptions below document the dispatched ``sageattn_*`` kernels.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        The query tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    k : torch.Tensor
+        The key tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    v : torch.Tensor
+        The value tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    tensor_layout : str
+        The tensor layout, either "HND" or "NHD". Default: "HND".
+
+    is_causal : bool
+        Whether to apply causal mask to the attention matrix.
+        Only applicable when qo_len == kv_len. Default: False.
+
+    sm_scale : Optional[float]
+        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
+
+    return_lse : bool
+        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
+        Default: False.
+
+    Returns
+    -------
+    torch.Tensor
+        The output tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    torch.Tensor
+        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape:
+        ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.
+
+    Note
+    ----
+    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
+    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+    - All tensors must be on the same cuda device.
+    """
+    device_index = torch.cuda.current_device()
+    arch = get_cuda_arch_versions()[device_index]
+    return SAGE_ATTENTION_DISPATCH[arch]
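
Reviewer note (not part of the patch): a minimal smoke-test sketch of the new dispatch path. It uses only symbols introduced in this diff plus `torch` and `functools.partial`; the tensor shapes, dtype, and the idea of calling the Hub kernel directly (rather than through the registered `sage_hub` attention backend) are illustrative assumptions.

    # Assumes a CUDA GPU, `pip install kernels`, and this PR's branch installed.
    from functools import partial

    import torch

    from diffusers.utils.kernels_utils import _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
    from diffusers.utils.sage_utils import _get_sage_attn_fn_for_device

    # Pick the SageAttention variant and default kwargs for the current GPU,
    # e.g. {"func": "sageattn_qk_int8_pv_fp8_cuda_sm90", "kwargs": {...}} on sm90.
    entry = _get_sage_attn_fn_for_device()
    print(entry["func"], entry["kwargs"])

    # Fetch the Hub kernel and bind the per-arch defaults, mirroring what
    # attention_dispatch.py now does at import time.
    sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
    sage_attn = partial(getattr(sage_interface_hub, entry["func"]), **entry["kwargs"])

    # "NHD" layout: [batch, seq_len, num_heads, head_dim]; bf16/fp16 on CUDA is required.
    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = sage_attn(q=q, k=k, v=v)
    print(out.shape)  # (1, 128, 8, 64)

Binding the per-arch kwargs with `functools.partial` is what lets `_sage_attention_hub` in the second commit call `sage_attn_func_hub(q=query, k=key, v=value)` without re-passing the layout or accumulation options.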