
feat: add flash_attention 3 kernel for diffusers pipelines #287

Merged

johannaSommer merged 10 commits into main from feat/fa3 on Aug 8, 2025
Conversation

@johannaSommer (Member) commented Aug 6, 2025

Description

In this PR, we make use of Hugging Face's kernel hub to utilize flash attention 3 in diffusers pipelines where possible. Importantly, diffusers is currently refactoring its attention handling, hence we ship two different implementations.

In the current diffusers version, we wrap the call of a pipeline to intercept each call to torch.nn.functional.scaled_dot_product_attention. If the shapes, dtype, and keyword arguments are supported by flash attention, we reroute the computation to the fa3 kernel.

In the newer diffusers version (0.35.0, unreleased so far), we can register a new attention backend that dispatches to native attention or fa3 based on the same criteria.
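Conceptually, the backend path looks like the toy registry below. This is only an illustration of the dispatch idea: the real diffusers API differs (only `set_attention_backend` usage is confirmed in this PR), and every name here is made up.

```python
# Toy attention-backend registry; illustrative only, not the diffusers API.
_ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(fn):
        _ATTENTION_BACKENDS[name] = fn
        return fn
    return decorator

@register_attention_backend("native")
def native_backend(query_dtype):
    return "native"

@register_attention_backend("flash_attn3_pruna")
def fa3_backend(query_dtype):
    # Same criteria as the wrapper path: fall back to native attention
    # whenever FA3 cannot handle the inputs.
    if query_dtype not in ("float16", "bfloat16"):
        return _ATTENTION_BACKENDS["native"](query_dtype)
    return "flash_attn3"
```

The point is that a single registered backend can make the FA3-or-native decision per call, instead of wrapping the whole pipeline.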

Overall, the speedup is minor for T2I pipelines but is substantial for Video Gen pipelines, in particular Wan.

Related Issue

None.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Added algorithm tests with Flux and Wan and added combination tests.

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

Flash Attention 3 on Wan gives around a 1.4x speedup, with no warmup and no quality degradation.

Minimal Script:

from pruna import SmashConfig, smash
from diffusers import WanPipeline
import torch

pipeline = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16
)
pipeline.to("cuda")

config = SmashConfig()
config._prepare_saving = False
config["kernel"] = "flash_attn3"
pipeline = smash(pipeline, config)

prompt = "A cat is doing an acrobatic dive into a swimming pool at the olympics, from a 10m high diving board, flips and spins"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
output = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=640,
    num_frames=69,
    guidance_scale=5.0,
    num_inference_steps=10,
    generator=torch.Generator(device="cpu").manual_seed(1),
).frames[0]

Output during Generation:
100%|██████████| 10/10 [01:04<00:00, 6.45s/it]


Generation times on the prompt / size above on 1 H100:

  • base: 8.04 s/it
  • flash_attn3: 6.46 s/it (no warmup!)
  • torch.compile: 6.38 s/it
  • flash_attn3 + torch.compile: 5.15 s/it


@johnrachwan123 (Member):

bugbot run


@sharpenb (Member) left a comment

Thanks for the PR :)

  • I would be interested in a small benchmark (base, torch.compile, fa3, fa3 + torch.compile) for Wan (if that is the main model targeted by this update)
  • Was it tested with EvaluationAgent after compression?
  • Could we have a test notebook, since it integrates a new compression algorithm?

bool
True if the model is a valid model for the algorithm, False otherwise.
"""
if Version(diffusers_version) >= Version("0.35.0.dev0"):
Member:

I am not sure about these conditions. It might be able to work with transformers.

Member Author:

I don't understand this comment; can you elaborate? Because of refactoring in the upcoming diffusers version, we have to apply separate logic.

Member:

Sorry for the confusion. I mean that the model checks focus on diffusers, while it might work for transformers architectures. My understanding now is that it was not tested for transformers, but it could be worth a quick check.

Comment thread src/pruna/algorithms/kernels/flash_attn3.py
imported_packages = self.import_algorithm_packages()

# in the new version of diffusers, we can use the modular attention backend to inject flash_attn3
if Version(diffusers_version) >= Version("0.35.0.dev0"):
Member:

Do we need to support both versions for certain model use-cases? It feels to me that the code could be more readable if we said we only support diffusers from version xyz onward.

Member:

Let's have a reminder to remove support for the older version of diffusers once a couple of diffusers releases have been done :)
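The version gate discussed here mirrors the check shown in the diff hunks; comparing against "0.35.0.dev0" makes pre-releases of 0.35.0 take the new code path as well. A minimal sketch (the function name is illustrative):

```python
# Dispatch on the installed diffusers version; ".dev0" ensures that
# 0.35.0 pre-releases already use the modular attention backend.
from packaging.version import Version

def use_modular_attention_backend(diffusers_version: str) -> bool:
    return Version(diffusers_version) >= Version("0.35.0.dev0")
```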

]:
component.set_attention_backend("flash_attn3_pruna")

else:
Member:

I was confused by some redundant functions, but I guess they are designed for the two diffusers versions. Could we mark that, e.g. in their docstrings, or use separate files?

Dict[str, Any]
The algorithm packages.
"""
flash_attention_3 = get_kernel("kernels-community/flash-attn3")
Member:

This is designed only for fa3 but kernels_hub has many more kernels. I feel that extending it to all kernels directly would be nice. Probably not for this PR but worth spending a bit of time now if the logic is similar.

Member Author:

I don't think this makes sense here; at least currently, the other kernels available on the kernel hub tackle other functionality, i.e. not the attention function. I think it is better to think of get_kernel as an aggregation of building, registering, and importing a certain kernel package; which functions you use and how you use them will differ from kernel to kernel. If other interesting attention kernels come along soon, we will of course consolidate them.

Member:

I think that aligns with my point. Here, would it be possible to have a streamlined way to integrate other kernels, e.g. by defining a mapping from each kernel to the functions we use? I would expect that if we implement another kernel from the hub, we would go over similar implementation steps, e.g. registering, identifying functions, and wrapping calls. One potential idea is to take inspiration from their blog, which suggests using use_kernel_forward_from_hub. Again, no need to refactor now, but I wanted to highlight this since it might be relevant for the future ;)
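One way to realize the suggested mapping is a small table from kernel-hub repo id to the callables to wrap, so that adding a new kernel only means adding an entry. This is a hypothetical sketch: the function name "flash_attn_func" and the helper are illustrative, not confirmed API.

```python
# Hypothetical kernel-id -> wrapped-functions table (all names illustrative).
KERNEL_WRAPPED_FUNCTIONS = {
    "kernels-community/flash-attn3": ("flash_attn_func",),
}

def wrapped_functions(kernel_id: str) -> tuple:
    # Unknown kernels simply wrap nothing.
    return KERNEL_WRAPPED_FUNCTIONS.get(kernel_id, ())
```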

"Rerouting to native attention. Check the following criteria in algorithms/kernels/flash_attn3.py: "
f"attn_mask_pass: {attn_mask is not None}, dropout_p_pass: {dropout_p != 0.0}, "
f"dtype_pass: {dtype_pass}, num_heads_pass: {num_heads_pass}, head_dim_pass: {head_dim_pass}"
)
cursor[bot]:


Bug: Flash Attention 3 Debug Logic Inverted

The debug message in src/pruna/algorithms/kernels/flash_attn3.py for rerouting to native attention displays attn_mask_pass and dropout_p_pass with inverted logic. These expressions evaluate to True when Flash Attention 3 conditions fail (e.g., attn_mask is not None), which is inconsistent with their _pass suffix and the correct 'pass' logic used in FlashAttention3Context.

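The fix the bug report implies is that each `_pass` flag should be True when the FA3 condition is satisfied, so the log message and the dispatch logic agree. A minimal sketch (the surrounding code is reconstructed for illustration, not the actual file contents):

```python
# Corrected flag computation: True means "FA3 can handle this input".
def fa3_pass_flags(attn_mask, dropout_p):
    attn_mask_pass = attn_mask is None  # FA3 supports no attention mask
    dropout_p_pass = dropout_p == 0.0   # FA3 supports no dropout
    return attn_mask_pass, dropout_p_pass
```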

@sharpenb (Member) left a comment

Let's go! I left some final thoughts

@johnrachwan123 (Member) left a comment

Thanks a lot for this cool contribution!

@johannaSommer johannaSommer merged commit 806d160 into main Aug 8, 2025
7 checks passed
@johannaSommer johannaSommer deleted the feat/fa3 branch August 8, 2025 08:08
