feat: add flash_attention 3 kernel for diffusers pipelines #287
johannaSommer merged 10 commits into main

Conversation
bugbot run
sharpenb
left a comment
Thanks for the PR :)
- I would be interested in a small benchmark (base, torch compile, fa3, fa3+torch.compile) for Wan (if that is the main model targeted by this update)
- Was it tested with EvaluationAgent after compression?
- Could we have a test notebook, since this integrates a new compression algorithm?
bool
    True if the model is a valid model for the algorithm, False otherwise.
"""
if Version(diffusers_version) >= Version("0.35.0.dev0"):
I am not sure about these conditions. It might be able to work with transformers.

I don't understand this comment, can you elaborate? Because of the refactoring in the upcoming diffusers version, we have to apply separate logic.

Sorry for the confusion. I meant that the model checks focus on diffusers, while the algorithm might also work for transformers architectures. My understanding now is that it was not tested for transformers, but it could be worth a quick check.
imported_packages = self.import_algorithm_packages()

# in the new version of diffusers, we can use the modular attention backend to inject flash_attn3
if Version(diffusers_version) >= Version("0.35.0.dev0"):
Do we need to support both versions for certain model use-cases? It feels to me that the code could be more readable if we stated that we only support from version xyz.

Let's add a reminder to remove support for the older diffusers version once a couple of diffusers releases have been made :)
]:
    component.set_attention_backend("flash_attn3_pruna")

else:
I was confused by some redundant-looking functions, but I guess they are designed for the two diffusers versions. Could we mark this, e.g. in their docstrings, or separate them into files?
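For the newer-diffusers path, applying the registered backend to every eligible pipeline component could be sketched roughly as follows. This is a hedged illustration only, not the PR's actual code; `apply_attention_backend` is a hypothetical helper, and only `set_attention_backend` is taken from the diff above:

```python
def apply_attention_backend(components, backend="flash_attn3_pruna"):
    """Set `backend` on every pipeline component that exposes diffusers'
    `set_attention_backend` hook (available from diffusers >= 0.35).

    Returns the names of the components that were switched over.
    """
    applied = []
    for name, component in components.items():
        # Only transformer-style components expose the hook; skip e.g. the VAE.
        if hasattr(component, "set_attention_backend"):
            component.set_attention_backend(backend)
            applied.append(name)
    return applied
```

In real code, `components` would be the diffusers pipeline's `components` dict.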
Dict[str, Any]
    The algorithm packages.
"""
flash_attention_3 = get_kernel("kernels-community/flash-attn3")
This is designed only for fa3, but kernels_hub has many more kernels. I feel that extending it to all kernels directly would be nice. Probably not for this PR, but it is worth spending a bit of time on now if the logic is similar.

I don't think this makes sense here - at least currently, the other kernels available on kernel hub tackle other functionality, i.e. not the attention function. I think it is better to think of this get_kernel as an aggregation of building, registering and importing a certain kernel package; which functions you use and how you use them will differ from kernel to kernel. If other interesting attention kernels come along soon, we will of course consolidate them.

I think that aligns with my point. Here, would it be possible to have a streamlined way to integrate other kernels, e.g. by defining a mapping from each kernel to the functions we use? I would expect that if we implement another kernel from the hub, we would go over similar implementation steps, e.g. registering, identifying functions, and wrapping calls. One potential idea is to take inspiration from their blog, which suggests using use_kernel_forward_from_hub. Again, no need to refactor, but I wanted to highlight this since it might be relevant for the future ;)
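The mapping idea could be sketched like this. Everything here is hypothetical (the registry name, the helper, and the assumed `flash_attn_func` entry point); `get_kernel` is injected so the sketch stays independent of the kernels package, whereas real code would use `kernels.get_kernel`:

```python
# Hypothetical registry: which functions we pull out of each hub kernel repo.
KERNEL_FUNCTIONS = {
    "kernels-community/flash-attn3": ("flash_attn_func",),
}

def load_kernel_functions(repo, get_kernel):
    """Load a hub kernel and return only the functions we intend to wrap.

    `get_kernel` builds/registers/imports the kernel package; the mapping
    above tells us which of its attributes are relevant for this algorithm.
    """
    kernel = get_kernel(repo)
    return {name: getattr(kernel, name) for name in KERNEL_FUNCTIONS[repo]}
```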
| "Rerouting to native attention. Check the following criteria in algorithms/kernels/flash_attn3.py: " | ||
| f"attn_mask_pass: {attn_mask is not None}, dropout_p_pass: {dropout_p != 0.0}, " | ||
| f"dtype_pass: {dtype_pass}, num_heads_pass: {num_heads_pass}, head_dim_pass: {head_dim_pass}" | ||
| ) |
Bug: Flash Attention 3 Debug Logic Inverted
The debug message in src/pruna/algorithms/kernels/flash_attn3.py for rerouting to native attention displays attn_mask_pass and dropout_p_pass with inverted logic. These expressions evaluate to True when Flash Attention 3 conditions fail (e.g., attn_mask is not None), which is inconsistent with their _pass suffix and the correct 'pass' logic used in FlashAttention3Context.
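A consistent version of these two flags (a sketch, not the file's actual code) would make each `_pass` variable True when its check passes; the preconditions shown are assumptions inferred from the debug message above:

```python
def fa3_mask_dropout_flags(attn_mask, dropout_p):
    """Return (attn_mask_pass, dropout_p_pass) with non-inverted semantics:
    True means the Flash Attention 3 precondition is satisfied."""
    attn_mask_pass = attn_mask is None   # FA3 path assumed to take no mask
    dropout_p_pass = dropout_p == 0.0    # FA3 path assumed to need p == 0
    return attn_mask_pass, dropout_p_pass
```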
sharpenb
left a comment
Let's go! I left some final thoughts
johnrachwan123
left a comment
Thanks a lot for this cool contribution!
Description
In this PR, we make use of HuggingFace's kernel hub to utilize flash attention 3 in diffusers pipelines where possible. Importantly, diffusers is refactoring their attention handling right now, hence we have two different implementations.
In the current diffusers version, we wrap the call of a pipeline to intercept each call to torch.nn.functional.scaled_dot_product_attention. If the shapes, dtype and keyword arguments can be supported by flash attention, we reroute the computation to the fa3 kernel. In the newer diffusers version (0.35.0, unreleased so far), we can register a new attention backend that calls native attention or fa3 based on the same criteria.
Overall, the speedup is minor for T2I pipelines but is substantial for Video Gen pipelines, in particular Wan.
Related Issue
None.
Type of Change
How Has This Been Tested?
Added algorithm tests with Flux and Wan and added combination tests.
Checklist
Additional Notes
Flash Attention 3 on WAN gives around a 1.4x speedup, with no warmup and of course no quality degradation.
Minimal Script:
Output during Generation:

100%|██████████| 10/10 [01:04<00:00, 6.45s/it]

Generation times on the prompt / size above on 1 H100:

- flash_attn3: 6.46 s/it (no warmup!)
- torch.compile: 6.38 s/it
- flash_attn3 + torch.compile: 5.15 s/it