🔀 #32 - Zen4 Flash Attention
| Author | ikawrakow |
|---|---|
| State | ❌ Closed |
| Created | 2024-09-01 |
| Updated | 2024-09-01 |
Description
TL;DR
This PR adds a flash attention (FA) implementation optimized for the Zen4 architecture as part of the quest to improve CPU inference for long contexts (#25, #26).
Limitations
- It is Zen4-only for now. Strictly speaking, the implementation requires only a much smaller subset of the AVX512 specification (just `AVX512F` and `AVX512DQ`) compared to what Zen4 provides, but I didn't want to have too many variants, so I decided to enable it for Zen4 only (see the sketch after this list).
- It is not implemented for ALiBi or unmasked attention. It would be trivial to add these, but I didn't want to clutter the implementation with branches that are mostly irrelevant.
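
As a rough illustration of the first limitation, here is a minimal sketch of how a kernel could be gated on the AVX512 subset it actually needs. This is not the PR's build logic; the `IQK_FA_SUPPORTED` macro is hypothetical, and only `__AVX512F__`/`__AVX512DQ__` are standard compiler-defined feature macros.

```cpp
// Minimal sketch (not the PR's actual build logic): gate the FA kernel on the
// strictly sufficient AVX512 subset. IQK_FA_SUPPORTED is a hypothetical name.
#include <cstdio>

#if defined(__AVX512F__) && defined(__AVX512DQ__)
#  define IQK_FA_SUPPORTED 1   // the subset the implementation really needs
#else
#  define IQK_FA_SUPPORTED 0
#endif

int main() {
    std::printf("FA kernel available: %s\n", IQK_FA_SUPPORTED ? "yes" : "no");
    return 0;
}
```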
Performance comparisons
The following graph compares the prompt processing (PP) performance of mainline llama.cpp (build: a47667cf - 3650) without FA (green symbols) and with FA (blue symbols) to PP performance in this repository for Q4_K_S-quantized LLaMA-3.1-8B running on a Ryzen-7950X CPU, where
- Black symbols are without FA
- Brown symbols are with the FA inherited from llama.cpp
- Magenta symbols are with the new FA implementation in this PR
We observe that the original FA implementation results in a significant performance degradation in mainline llama.cpp and also here, with a much stronger effect in this repository. This is because the K*Q and V*softmax(K*Q) matrix multiplications are much faster here thanks to iqk_mul_mat, so the performance hit is larger when they are replaced with the original llama.cpp FA CPU kernel. The new FA implementation improves performance, and the improvement grows with context length, reaching about 24% at 32k tokens.
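
For readers unfamiliar with the technique, here is a minimal scalar sketch of the flash-attention idea referenced above: the K*Q logits and the V*softmax(K*Q) accumulation are done in one streaming pass with an online softmax, so the full score matrix is never materialized. This is not the vectorized Zen4 kernel from this PR; the function name and layout are illustrative, and masking/soft-capping are omitted for brevity.

```cpp
// Scalar flash-attention sketch for one query row against n_kv key/value rows.
// q: [d], K: [n_kv, d] row-major, V: [n_kv, d] row-major, out: [d].
#include <cmath>
#include <cstddef>
#include <vector>

void flash_attn_row(const float* q, const float* K, const float* V,
                    std::size_t n_kv, std::size_t d, float scale, float* out) {
    float m = -INFINITY;              // running max of the logits
    float s = 0.0f;                   // running sum of exp(logit - m)
    std::vector<float> acc(d, 0.0f);  // running weighted sum of V rows

    for (std::size_t j = 0; j < n_kv; ++j) {
        // logit = scale * dot(q, K[j])
        float x = 0.0f;
        for (std::size_t k = 0; k < d; ++k) x += q[k] * K[j*d + k];
        x *= scale;

        // Online softmax update: rescale earlier partial sums if the max grows.
        float m_new = x > m ? x : m;
        float c = std::exp(m - m_new);   // correction factor for old terms
        float p = std::exp(x - m_new);   // weight of the new term
        s = s * c + p;
        for (std::size_t k = 0; k < d; ++k) acc[k] = acc[k] * c + p * V[j*d + k];
        m = m_new;
    }
    for (std::size_t k = 0; k < d; ++k) out[k] = acc[k] / s;
}
```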
The next graph shows results for Q4_K_S-quantized Gemma-2-2b. Symbol colors are the same as above.
In this case the original FA kernel improves performance in mainline llama.cpp. The difference in behavior compared to LLaMA-3.1-8B is easily explained by the fact that the Gemma-2 series of models use "soft-capping" in their attention layers, where softcap(x) = c * tanh(x/c) (c is a model-defined constant). This is implemented as 3 separate operations in llama.cpp. When FA is enabled, these 3 operations, along with the softmax, are fused into a single kernel, which improves mainline llama.cpp performance even for short contexts. But when the original FA kernel is used in our version, where "soft-capping" is already handled by a dedicated fused operation, we get a massive drop in performance, just like in the LLaMA-3.1-8B case above. The new implementation in this PR performs much better, and performance improves again, reaching 11% at 8k tokens, which is the maximum training context length of Gemma-2-2b.
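
To make the fusion argument concrete, here is a hedged sketch of how soft-capping could be folded into the streaming pass from the previous sketch instead of being run as three separate tensor operations over the full logit matrix. The names are illustrative and not taken from the actual kernel.

```cpp
// Soft-capping fused into the attention loop (illustrative, not the PR's kernel).
#include <cmath>

// softcap(x) = c * tanh(x / c), with c a model-defined constant (as in Gemma-2).
inline float softcap(float x, float c) { return c * std::tanh(x / c); }

// Inside the flash_attn_row loop above, the only change would be:
//   float x = scale * dot(q, K[j]);
//   x = softcap(x, cap);            // capped at the point the logit is produced
//   ... online softmax update as before ...
// instead of applying scale, tanh, and scale again as separate passes over the
// whole score matrix before the softmax.
```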