From c9b62968b975d3401cf41977818811095f91b7d1 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 25 Sep 2025 14:17:30 +0530
Subject: [PATCH 1/5] add a lightweight test suite for attention backends.

---
 tests/others/test_attention_backends.py | 121 ++++++++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 tests/others/test_attention_backends.py

diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
new file mode 100644
index 000000000000..4b47d37c44ad
--- /dev/null
+++ b/tests/others/test_attention_backends.py
@@ -0,0 +1,121 @@
+"""
+This test suite exists for the maintainers currently. It's not run in our CI at the moment.
+
+Once attention backends become more mature, we can consider including this in our CI.
+
+To run this test suite:
+
+```
+export RUN_ATTENTION_BACKEND_TESTS=yes
+export DIFFUSERS_ENABLE_HUB_KERNELS=yes
+
+pytest tests/others/test_attention_backends.py
+```
+"""
+
+import os
+
+import pytest
+import torch
+
+
+pytestmark = pytest.mark.skipif(
+    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "true", reason="Feature not mature enough."
+)
+
+from pytest import mark as parameterize  # noqa: E402
+from torch._dynamo import config as dynamo_config  # noqa: E402
+
+from diffusers import FluxPipeline  # noqa: E402
+
+
+FORWARD_CASES = [
+    ("flash_hub", None),
+    ("_flash_3_hub", None),
+    ("native", None),
+    ("_native_cudnn", None),
+]
+
+COMPILE_CASES = [
+    ("flash_hub", None, True),
+    ("_flash_3_hub", None, True),
+    ("native", None, True),
+    ("_native_cudnn", None, True),
+    ("native", None, True),
+]
+
+INFER_KW = {
+    "prompt": "dance doggo dance",
+    "height": 256,
+    "width": 256,
+    "num_inference_steps": 2,
+    "guidance_scale": 3.5,
+    "max_sequence_length": 128,
+    "output_type": "pt",
+}
+
+
+def _backend_is_probably_supported(pipe, name: str) -> bool:
+    try:
+        pipe.transformer.set_attention_backend(name)
+        return True
+    except (NotImplementedError, RuntimeError, ValueError):
+        return False
+
+
+def _check_if_slices_match(output, expected_slice):
+    img = output.images
+    generated_slice = img.flatten()
+    generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+    assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
+
+
+@pytest.fixture(scope="session")
+def device():
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for these tests.")
+    return torch.device("cuda:0")
+
+
+@pytest.fixture(scope="session")
+def pipe(device):
+    torch.set_grad_enabled(False)
+    model_id = "black-forest-labs/FLUX.1-dev"
+    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
+    pipe.set_progress_bar_config(disable=True)
+    pipe.transformer.eval()
+    return pipe
+
+
+@parameterize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
+def test_forward(pipe, backend_name, expected_slice):
+    if not _backend_is_probably_supported(pipe, backend_name):
+        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
+
+    out = pipe(
+        "a tiny toy cat in a box",
+        **INFER_KW,
+        generator=torch.manual_seed(0),
+    )
+    _check_if_slices_match(out, expected_slice)
+
+
+@parameterize(
+    "backend_name,expected_slice,error_on_recompile",
+    COMPILE_CASES,
+    ids=[c[0] for c in COMPILE_CASES],
+)
+def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
+    if not _backend_is_probably_supported(pipe, backend_name):
+        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
+
+    pipe.transformer.compile(fullgraph=True)
+    with dynamo_config.patch(error_on_recompile=bool(error_on_recompile)):
+        torch.manual_seed(0)
+        out = pipe(
+            "a tiny toy cat in a box",
+            **INFER_KW,
+            generator=torch.manual_seed(0),
+        )
+
+    _check_if_slices_match(out, expected_slice)

From b62fec0c83c61f551e2cad58db1776413693afb9 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 25 Sep 2025 14:18:16 +0530
Subject: [PATCH 2/5] up

---
 tests/others/test_attention_backends.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index 4b47d37c44ad..4a67e2d4445e 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -67,6 +67,7 @@ def _check_if_slices_match(output, expected_slice):
     img = output.images
     generated_slice = img.flatten()
     generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+    print(f"{generated_slice=}")
     assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
 
 

From c51e5f2fe3ee0b8485e985c1c050538eb4a6004a Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 25 Sep 2025 17:02:47 +0530
Subject: [PATCH 3/5] up

---
 tests/others/test_attention_backends.py | 212 ++++++++++++++++++++----
 1 file changed, 177 insertions(+), 35 deletions(-)

diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index 4a67e2d4445e..53e4e95531bd 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -5,12 +5,15 @@
 
 To run this test suite:
 
-```
+```bash
 export RUN_ATTENTION_BACKEND_TESTS=yes
 export DIFFUSERS_ENABLE_HUB_KERNELS=yes
 
 pytest tests/others/test_attention_backends.py
 ```
+
+Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
+"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
 """
 
 import os
@@ -20,28 +23,165 @@
 
 
 pytestmark = pytest.mark.skipif(
-    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "true", reason="Feature not mature enough."
+    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
 )
-
-from pytest import mark as parameterize  # noqa: E402
-from torch._dynamo import config as dynamo_config  # noqa: E402
-
 from diffusers import FluxPipeline  # noqa: E402
+from diffusers.utils import is_torch_version  # noqa: E402
 
 
 FORWARD_CASES = [
     ("flash_hub", None),
-    ("_flash_3_hub", None),
-    ("native", None),
-    ("_native_cudnn", None),
+    (
+        "_flash_3_hub",
+        torch.tensor(
+            [
+                0.0820,
+                0.0859,
+                0.0938,
+                0.1016,
+                0.0977,
+                0.0996,
+                0.1016,
+                0.1016,
+                0.2188,
+                0.2246,
+                0.2344,
+                0.2480,
+                0.2539,
+                0.2480,
+                0.2441,
+                0.2715,
+            ],
+            dtype=torch.bfloat16,
+        ),
+    ),
+    (
+        "native",
+        torch.tensor(
+            [
+                0.0820,
+                0.0859,
+                0.0938,
+                0.1016,
+                0.0957,
+                0.0996,
+                0.0996,
+                0.1016,
+                0.2188,
+                0.2266,
+                0.2363,
+                0.2500,
+                0.2539,
+                0.2480,
+                0.2461,
+                0.2734,
+            ],
+            dtype=torch.bfloat16,
+        ),
+    ),
+    (
+        "_native_cudnn",
+        torch.tensor(
+            [
+                0.0781,
+                0.0840,
+                0.0879,
+                0.0957,
+                0.0898,
+                0.0957,
+                0.0957,
+                0.0977,
+                0.2168,
+                0.2246,
+                0.2324,
+                0.2500,
+                0.2539,
+                0.2480,
+                0.2441,
+                0.2695,
+            ],
+            dtype=torch.bfloat16,
+        ),
+    ),
 ]
 
 COMPILE_CASES = [
     ("flash_hub", None, True),
-    ("_flash_3_hub", None, True),
-    ("native", None, True),
-    ("_native_cudnn", None, True),
-    ("native", None, True),
+    (
+        "_flash_3_hub",
+        torch.tensor(
+            [
+                0.0410,
+                0.0410,
+                0.0449,
+                0.0508,
+                0.0508,
+                0.0605,
+                0.0625,
+                0.0605,
+                0.2344,
+                0.2461,
+                0.2578,
+                0.2734,
+                0.2852,
+                0.2812,
+                0.2773,
+                0.3047,
+            ],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
+    (
+        "native",
+        torch.tensor(
+            [
+                0.0410,
+                0.0410,
+                0.0449,
+                0.0508,
+                0.0508,
+                0.0605,
+                0.0605,
+                0.0605,
+                0.2344,
+                0.2461,
+                0.2578,
+                0.2773,
+                0.2871,
+                0.2832,
+                0.2773,
+                0.3066,
+            ],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
+    (
+        "_native_cudnn",
+        torch.tensor(
+            [
+                0.0410,
+                0.0410,
+                0.0430,
+                0.0508,
+                0.0488,
+                0.0586,
+                0.0605,
+                0.0586,
+                0.2344,
+                0.2461,
+                0.2578,
+                0.2773,
+                0.2871,
+                0.2832,
+                0.2793,
+                0.3086,
+            ],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
 ]
 
 INFER_KW = {
@@ -55,19 +195,18 @@
 }
 
 
-def _backend_is_probably_supported(pipe, name: str) -> bool:
+def _backend_is_probably_supported(pipe, name: str):
     try:
         pipe.transformer.set_attention_backend(name)
-        return True
-    except (NotImplementedError, RuntimeError, ValueError):
+        return pipe, True
+    except Exception:
         return False
 
 
 def _check_if_slices_match(output, expected_slice):
-    img = output.images
+    img = output.images.detach().cpu()
     generated_slice = img.flatten()
     generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-    print(f"{generated_slice=}")
     assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
 
 
@@ -88,35 +227,38 @@ def pipe(device):
     return pipe
 
 
-@parameterize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
+@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
 def test_forward(pipe, backend_name, expected_slice):
-    if not _backend_is_probably_supported(pipe, backend_name):
+    out = _backend_is_probably_supported(pipe, backend_name)
+    if isinstance(out, bool):
         pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
 
-    out = pipe(
-        "a tiny toy cat in a box",
-        **INFER_KW,
-        generator=torch.manual_seed(0),
-    )
+    modified_pipe = out[0]
+    out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
     _check_if_slices_match(out, expected_slice)
 
 
-@parameterize(
+@pytest.mark.parametrize(
     "backend_name,expected_slice,error_on_recompile",
     COMPILE_CASES,
     ids=[c[0] for c in COMPILE_CASES],
 )
 def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
-    if not _backend_is_probably_supported(pipe, backend_name):
+    if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
+        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
+
+    out = _backend_is_probably_supported(pipe, backend_name)
+    if isinstance(out, bool):
         pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
 
-    pipe.transformer.compile(fullgraph=True)
-    with dynamo_config.patch(error_on_recompile=bool(error_on_recompile)):
-        torch.manual_seed(0)
-        out = pipe(
-            "a tiny toy cat in a box",
-            **INFER_KW,
-            generator=torch.manual_seed(0),
-        )
+    modified_pipe = out[0]
+    modified_pipe.transformer.compile(fullgraph=True)
+
+    torch.compiler.reset()
+    with (
+        torch._inductor.utils.fresh_inductor_cache(),
+        torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
+    ):
+        out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
 
     _check_if_slices_match(out, expected_slice)

From 20ece4c7cc36fa3a235fd1eb57d8f4c4dbbd8b97 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 7 Oct 2025 13:30:00 +0530
Subject: [PATCH 4/5] Apply suggestions from code review

---
 tests/others/test_attention_backends.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index 53e4e95531bd..a8221e6bbcc6 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -219,11 +219,9 @@ def device():
 
 @pytest.fixture(scope="session")
 def pipe(device):
-    torch.set_grad_enabled(False)
-    model_id = "black-forest-labs/FLUX.1-dev"
-    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
+    repo_id = "black-forest-labs/FLUX.1-dev"
+    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device)
     pipe.set_progress_bar_config(disable=True)
-    pipe.transformer.eval()
     return pipe
 
 

From 301fed023044a52043ba94624af6f14b196925e8 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 7 Oct 2025 14:26:44 +0530
Subject: [PATCH 5/5] formatting

---
 tests/others/test_attention_backends.py | 134 ++----------------------
 1 file changed, 8 insertions(+), 126 deletions(-)

diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index a8221e6bbcc6..42cdcd56f74a 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -29,79 +29,20 @@
 from diffusers.utils import is_torch_version  # noqa: E402
 
 
+# fmt: off
 FORWARD_CASES = [
     ("flash_hub", None),
     (
         "_flash_3_hub",
-        torch.tensor(
-            [
-                0.0820,
-                0.0859,
-                0.0938,
-                0.1016,
-                0.0977,
-                0.0996,
-                0.1016,
-                0.1016,
-                0.2188,
-                0.2246,
-                0.2344,
-                0.2480,
-                0.2539,
-                0.2480,
-                0.2441,
-                0.2715,
-            ],
-            dtype=torch.bfloat16,
-        ),
+        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
     ),
     (
         "native",
-        torch.tensor(
-            [
-                0.0820,
-                0.0859,
-                0.0938,
-                0.1016,
-                0.0957,
-                0.0996,
-                0.0996,
-                0.1016,
-                0.2188,
-                0.2266,
-                0.2363,
-                0.2500,
-                0.2539,
-                0.2480,
-                0.2461,
-                0.2734,
-            ],
-            dtype=torch.bfloat16,
-        ),
+        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
     ),
     (
         "_native_cudnn",
-        torch.tensor(
-            [
-                0.0781,
-                0.0840,
-                0.0879,
-                0.0957,
-                0.0898,
-                0.0957,
-                0.0957,
-                0.0977,
-                0.2168,
-                0.2246,
-                0.2324,
-                0.2500,
-                0.2539,
-                0.2480,
-                0.2441,
-                0.2695,
-            ],
-            dtype=torch.bfloat16,
-        ),
+        torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
     ),
 ]
 
@@ -109,80 +50,21 @@
     ("flash_hub", None, True),
     (
         "_flash_3_hub",
-        torch.tensor(
-            [
-                0.0410,
-                0.0410,
-                0.0449,
-                0.0508,
-                0.0508,
-                0.0605,
-                0.0625,
-                0.0605,
-                0.2344,
-                0.2461,
-                0.2578,
-                0.2734,
-                0.2852,
-                0.2812,
-                0.2773,
-                0.3047,
-            ],
-            dtype=torch.bfloat16,
-        ),
+        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
         True,
     ),
     (
         "native",
-        torch.tensor(
-            [
-                0.0410,
-                0.0410,
-                0.0449,
-                0.0508,
-                0.0508,
-                0.0605,
-                0.0605,
-                0.0605,
-                0.2344,
-                0.2461,
-                0.2578,
-                0.2773,
-                0.2871,
-                0.2832,
-                0.2773,
-                0.3066,
-            ],
-            dtype=torch.bfloat16,
-        ),
+        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
         True,
     ),
     (
         "_native_cudnn",
-        torch.tensor(
-            [
-                0.0410,
-                0.0410,
-                0.0430,
-                0.0508,
-                0.0488,
-                0.0586,
-                0.0605,
-                0.0586,
-                0.2344,
-                0.2461,
-                0.2578,
-                0.2773,
-                0.2871,
-                0.2832,
-                0.2793,
-                0.3086,
-            ],
-            dtype=torch.bfloat16,
-        ),
+        torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
         True,
     ),
 ]
+# fmt: on
 
 INFER_KW = {
     "prompt": "dance doggo dance",