### 🔀 [#32](https://github.com/ikawrakow/ik_llama.cpp/pull/32) - Zen4 Flash Attention
| **Author** | `ikawrakow` |
| :--- | :--- |
| **State** | ❌ **Closed** |
| **Created** | 2024-09-01 |
| **Updated** | 2024-09-01 |
---
#### Description
### TL;DR
This PR adds a flash attention (FA) implementation optimized for the Zen4 architecture as part of the quest to improve CPU inference for long contexts (#25, #26).
### Limitations
* It is Zen4-only for now. Strictly speaking, the implementation requires a much smaller subset of the AVX512 specification (just `AVX512F` and `AVX512DQ`) than what Zen4 provides, but I didn't want to have too many variants, so I decided to enable it for Zen4 only.
* It is not implemented for ALiBi or unmasked attention. It is trivial to add these but I didn't want to clutter the implementation with branches that are mostly irrelevant.
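To make the `AVX512F` / `AVX512DQ` requirement concrete, here is a minimal sketch of a check against Linux `/proc/cpuinfo`-style text. The helper name `missing_avx512_flags` and the constant `REQUIRED_FLAGS` are hypothetical and not part of this repository; how the PR actually gates the Zen4 code path is not shown in this description.

```python
# Hypothetical helper, not part of ik_llama.cpp: reports which of the two
# AVX512 subsets named above are absent from /proc/cpuinfo-style text.
REQUIRED_FLAGS = {"avx512f", "avx512dq"}


def missing_avx512_flags(cpuinfo_text):
    """Return the required flags missing from the first 'flags' line."""
    for line in cpuinfo_text.splitlines():
        name, sep, value = line.partition(":")
        if sep and name.strip() == "flags":
            return REQUIRED_FLAGS - set(value.split())
    # No 'flags' line found: treat every required flag as missing.
    return set(REQUIRED_FLAGS)
```

On a Linux machine this could be called as `missing_avx512_flags(open("/proc/cpuinfo").read())`; an empty set means both subsets are reported by the kernel.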
### Performance comparisons
The following graph compares the prompt processing (PP) performance of mainline `llama.cpp` (build: a47667cf - 3650) without (green symbols) and with (blue symbols) FA to PP performance in this repository for `Q4_K_S`-quantized LLaMA-3.1-8B running on a Ryzen-7950X CPU where
* Black symbols are without FA
* Brown symbols are with FA inherited from `llama.cpp`
* Magenta symbols are with the new FA implementation in this PR
![fa](https://github.com/user-attachments/assets/57078b91-cdcf-45b8-ba41-eee97774bc56)
We observe that the original FA implementation results in a significant performance degradation in mainline `llama.cpp` and also here. The effect is much stronger for the version here. This is because the `K*Q` and `V*softmax(K*Q)` matrix multiplications are much faster in this repository thanks to `iqk_mul_mat`, so the performance hit is larger when they are replaced with the original `llama.cpp` FA CPU kernel. The new FA implementation improves performance. The improvement increases with context length, reaching about 24% at 32k tokens.
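For readers unfamiliar with what an FA kernel replaces, the following is a minimal pure-Python sketch of the general idea: instead of materializing all `K*Q` scores and then computing `V*softmax(K*Q)`, keys and values are consumed in blocks while a running maximum, running sum, and running output are maintained ("online softmax"). This is an illustration of the generic technique only, not the kernel in this PR; the function names and the block size are assumptions, a single query row is used, and the attention mask is omitted for brevity.

```python
import math


def naive_attention(q, keys, values):
    """Reference path: compute all scores, apply softmax, then weight V."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def streaming_attention(q, keys, values, block=2):
    """Blockwise path: never holds the full score row in memory."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(keys), block):
        k_block = keys[start:start + block]
        v_block = values[start:start + block]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_block]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_block):
            w = math.exp(s - new_max)
            running_sum += w
            for d in range(dim):
                acc[d] += w * v[d]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Both functions return the same result up to floating-point rounding; the blockwise version trades the large intermediate score matrix for a small amount of rescaling work per block.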
The next graph shows results for `Q4_K_S`-quantized Gemma-2-2b. Symbol colors are the same as above.
![fa_gemma2b](https://github.com/user-attachments/assets/8206ee28-02a0-43b6-be67-f9ea03378eb3)
In this case the original FA kernel improves performance in mainline `llama.cpp`. The difference in behavior compared to LLaMA-3.1-8B is easily explained by the fact that the Gemma-2 series of models use "soft-capping" in their attention layers, where `softcap(x) = c * tanh(x/c)` (`c` is a model-defined constant). This is implemented as 3 different operations in `llama.cpp`. When FA is enabled, these 3 operations, along with `softmax`, are fused into a single kernel, and this results in an improvement of mainline `llama.cpp` performance even for short contexts. But when the original FA kernel is used in our version, where "soft-capping" is already handled by a dedicated fused operation, we get a massive drop in performance just like in the LLaMA-3.1-8B case above. The new implementation in this PR is much better and performance improves again, with the gain reaching 11% at 8k tokens, which is the maximum training context length of Gemma-2-2b.
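The difference between running `softcap(x) = c * tanh(x/c)` as three separate passes and as one fused pass can be sketched in a few lines of Python. This is only an illustration of the arithmetic; the function names are hypothetical, the value of `c` used in the usage note is arbitrary, and none of this mirrors the actual kernels in either code base.

```python
import math


def softcap_three_ops(scores, c):
    """Three passes over the data, one per operation."""
    scaled = [s / c for s in scores]            # op 1: scale by 1/c
    squashed = [math.tanh(s) for s in scaled]   # op 2: tanh
    return [c * s for s in squashed]            # op 3: scale by c


def softcap_fused(scores, c):
    """One pass: the same arithmetic without intermediate arrays."""
    return [c * math.tanh(s / c) for s in scores]


def softcap_softmax(scores, c):
    """Soft-capping followed by a numerically stable softmax."""
    capped = softcap_fused(scores, c)
    top = max(capped)
    weights = [math.exp(x - top) for x in capped]
    total = sum(weights)
    return [w / total for w in weights]
```

For example, `softcap_fused([-1000.0, 0.0, 1000.0], 50.0)` keeps every value within `[-50, 50]`, which is the point of soft-capping: large attention scores are squashed smoothly instead of being clipped.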