
🔀 #246 - Faster FlashMLA prompt processing

Author ikawrakow
State Closed
Created 2025-03-08
Updated 2025-03-08

Description

MLA as used in the DeepSeek models is great for token generation (TG), but prompt processing (PP) speed is much lower than with standard attention, even with FA enabled.

This PR improves FlashMLA PP speed by a large margin. FlashMLA is CPU only, but the PR paves the way to perhaps also getting it on CUDA (this is left for a future PR).

The following table compares FlashMLA PP speed for DeepSeek-Lite quantized as IQ4_NL between the main branch and this PR. CPU is Ryzen-7950X, the cache is quantized with Q8_0, fmoe is on.

| model | test | t/s (main) | t/s (PR) | Speedup |
| --- | --- | ---: | ---: | ---: |
| deepseek2 16B IQ4_NL | pp512 | 605.29 ± 4.92 | 681.72 ± 1.12 | 1.126 |
| deepseek2 16B IQ4_NL | pp1024 | 568.79 ± 0.75 | 648.71 ± 1.48 | 1.141 |
| deepseek2 16B IQ4_NL | pp2048 | 509.15 ± 4.38 | 598.99 ± 0.83 | 1.176 |
| deepseek2 16B IQ4_NL | pp4096 | 420.10 ± 0.82 | 514.62 ± 2.68 | 1.225 |
| deepseek2 16B IQ4_NL | pp8192 | 293.24 ± 2.09 | 399.14 ± 5.89 | 1.361 |
| deepseek2 16B IQ4_NL | pp16384 | 170.66 ± 0.76 | 269.01 ± 4.64 | 1.576 |

For reference, here is a comparison between standard attention with FA enabled and FlashMLA with this PR:

| model | test | t/s (standard FA) | t/s (PR) | Speedup |
| --- | --- | ---: | ---: | ---: |
| deepseek2 16B IQ4_NL | pp512 | 675.89 ± 7.49 | 681.72 ± 1.12 | 1.009 |
| deepseek2 16B IQ4_NL | pp1024 | 658.84 ± 1.08 | 648.71 ± 1.48 | 0.985 |
| deepseek2 16B IQ4_NL | pp2048 | 635.75 ± 1.70 | 598.99 ± 0.83 | 0.942 |
| deepseek2 16B IQ4_NL | pp4096 | 591.13 ± 0.06 | 514.62 ± 2.68 | 0.871 |
| deepseek2 16B IQ4_NL | pp8192 | 515.03 ± 2.53 | 399.14 ± 5.89 | 0.775 |
| deepseek2 16B IQ4_NL | pp16384 | 400.24 ± 0.74 | 269.01 ± 4.64 | 0.672 |

I.e., still quite a bit slower than standard attention with FA enabled for long contexts, but much better than the original implementation.

The new functionality is enabled via the `-mla 2 -fa` command line arguments (see the example after the list below). I know, it is getting confusing, so here is a summary of what happens with the different mla and fa combinations:

  • mla = 0, fa = 0: standard attention without FA. Works on the CPU and on CUDA. Large K- and V-cache required. The V cache cannot be quantized.
  • mla = 0, fa = 1: standard attention with FA. Works on the CPU and on CUDA. Large K- and V-cache required. The V cache can be quantized. Best PP performance; TG performance is slightly lower than standard attention without FA.
  • mla = 1, fa = 0: MLA attention. Works on the CPU and on CUDA. Smaller K- and smaller transposed V cache required. The V cache cannot be quantized. Great TG performance, pathetic PP performance.
  • mla = 1, fa = 1: FlashMLA. Works only on the CPU. Only small K cache required. Great TG performance, slightly less pathetic PP performance.
  • mla = 2, fa = 0: FlashMLA. Works on the CPU and on CUDA. Only small K cache required (the transposed V cache is computed on the fly). Great TG performance (but slightly lower than mla = 1 for long contexts), pathetic PP performance.
  • mla = 2, fa = 1: FlashMLA from this PR. Works only on the CPU. Only small K cache required. Great TG performance, more acceptable PP performance.
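
For example, this is roughly how FlashMLA with the new prompt processing path would be enabled for the benchmarked setup (a minimal sketch: the model path is a placeholder, and the exact spellings of the quantized-cache and fused-MoE options are assumptions, so check --help for the real ones):

```
# Minimal sketch: FlashMLA with the faster PP path from this PR (-mla 2 -fa).
# The model path is a placeholder; -ctk q8_0 (quantized K cache) and -fmoe
# (fused MoE) mirror the benchmark setup above, but their spellings are assumptions.
./llama-server -m ./DeepSeek-Lite-IQ4_NL.gguf -mla 2 -fa -ctk q8_0 -fmoe
```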

Background

Let $X$ and $Q$ be the activations and the query after projection with their corresponding MQA tensors and after applying rotary position embeddings (RoPE). In standard attention one computes (apart from scaling factors and masks that I'll omit for simplicity)

$$K = W_k X, \quad\quad V = W_v X, \quad\quad R = V_{\rm cache}\, {\rm softmax}(K_{\rm cache} Q)$$

In practice the $W_k$ and $W_v$ tensors are combined into $W_{kv}$ (the tensor `wkv_b` in llama.cpp), one computes $Y = W_{kv} X$, and the tensors $K$ and $V$ are views into $Y$. The matrix multiplication with $W_{kv}$ is performed only for the tokens in the batch being processed, the results are stored in the cache, and the tensors $V_{\rm cache}$ and $K_{\rm cache}$ are views into the KV cache.

With MLA one computes

$$Q' = W_k^T Q, \quad\quad R = W_v \left[ V_{\rm cache}\, {\rm softmax}(K_{\rm cache} Q') \right]$$

where one stores $X$ directly into the K-cache, and $K_{\rm cache}$ is an appropriate view into the cache. $V_{\rm cache}$ is a transposed version of $K_{\rm cache}$ when FA is not used, or a slightly different view into the K-cache with FA or mla = 2. The benefit of this reordering of the operations is that the cache becomes much smaller. But as these are not square matrices, the amount of multiply-adds (madds in the following) does depend on the order of the matrix multiplications. If we denote the number of madds in the standard attention implementation with $N$, for the DeepSeek models the number of madds with MLA is $(576 + 512)/(192 + 128) \times N = 3.4 \times N$.

Why is TG with MLA faster than with standard attention if one needs to do more computation? The difference comes from the shapes of the matrices involved. TG with standard attention results in the tensor $Q$ having shape $M \times 1 \times L$, so all multiplications are matrix-vector products (a.k.a. GEMV), which are memory bound on basically any modern system (CPU or GPU). With MLA the shape of $Q'$ is $M' \times L$, so the calculation involves matrix-matrix multiplications (a.k.a. GEMM), which are much faster per madd, so one ends up with better performance despite having computed more madds.

But for PP we are dealing with GEMMs in both cases, so the 3.4X more madds makes MLA PP slower. As an example, for 8k tokens with standard attention and FA, about 25% of the time is spent in the flash attention computation. We can therefore estimate MLA PP to be $0.75 + 0.25 \times 3.4 = 1.6$ times slower. From the tables above we see that in practice it is $515/293 = 1.75$ times slower. As there are some other differences in the performed matrix multiplications, this back-of-the-envelope estimate comes quite close to the observed behavior.
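
For completeness, these ratios can be broken down as follows; the per-head dimension split (128 no-RoPE + 64 RoPE dimensions for K and 128 for V in standard attention, and a 512-dimensional latent plus the 64 RoPE dimensions for the MLA cache) is the usual DeepSeek configuration and is assumed here rather than stated above:

$$\frac{576 + 512}{192 + 128} = \frac{(512 + 64) + 512}{(128 + 64) + 128} = \frac{1088}{320} = 3.4, \qquad 0.75 + 0.25 \times 3.4 = 1.6$$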

So, how can we improve? We can rearrange the computation back to standard attention. The only difference: as we are storing $X$ into the cache, we need to multiply $W_{kv}$ with the entire content of the cache. This seems pretty stupid at first glance (I had the idea to rearrange the multiplications quite a while ago but discarded it for exactly that reason), but if one sits down and counts the actual madds that are required, one finds that for DeepSeek this results in $(192 + 3 \times 128)/(192 + 128) \times N = 1.8 \times N$ madds, i.e., still more than standard attention, but significantly fewer than the existing MLA implementation. What about TG? We save the day by applying the rearranged matrix multiplications only if the number of tokens in the batch is greater than 1 (or some suitably chosen threshold). In this way we keep the good TG performance, keep the reduced cache size, and get improved prompt processing speed.
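
Putting the madd counts side by side (standard attention : rearranged MLA from this PR : existing MLA implementation):

$$(192 + 128) : (192 + 3 \times 128) : (576 + 512) = 320 : 576 : 1088 = 1 : 1.8 : 3.4$$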


💬 Conversation

👤 davidsyoung commented on 2025-03-08 at 14:58:12:

Getting a linking error on `iqk_flash_attn_noalibi`:

```
129.5 c++ -std=c++17 -fPIC -O3 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -pthread -fopenmp -march=native -mtune=native -Wno-array-bounds -Wno-format-truncation -Wextra-semi -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_OPENMP -DGGML_USE_IQK_MULMAT -DGGML_USE_LLAMAFILE -DGGML_USE_CUDA -I/usr/local/cuda/include -I/usr/local/cuda/targets/x86_64-linux/include -DGGML_CUDA_USE_GRAPHS -DLLAMA_USE_CURL ggml/src/iqk/iqk_quantize.o ggml/src/iqk/iqk_mul_mat.o ggml/src/llamafile/sgemm.o ggml/src/ggml-cuda.o ggml/src/ggml-cuda/acc.o ggml/src/ggml-cuda/arange.o ggml/src/ggml-cuda/argsort.o ggml/src/ggml-cuda/binbcast.o ggml/src/ggml-cuda/clamp.o ggml/src/ggml-cuda/concat.o ggml/src/ggml-cuda/conv-transpose-1d.o ggml/src/ggml-cuda/convert.o ggml/src/ggml-cuda/cpy.o ggml/src/ggml-cuda/diagmask.o ggml/src/ggml-cuda/dmmv.o ggml/src/ggml-cuda/fattn-tile-f16.o ggml/src/ggml-cuda/fattn-tile-f32.o ggml/src/ggml-cuda/fattn.o ggml/src/ggml-cuda/getrows.o ggml/src/ggml-cuda/im2col.o ggml/src/ggml-cuda/iqk_mmvq.o ggml/src/ggml-cuda/mmq.o ggml/src/ggml-cuda/mmvq.o ggml/src/ggml-cuda/norm.o ggml/src/ggml-cuda/pad.o ggml/src/ggml-cuda/pool2d.o ggml/src/ggml-cuda/quantize.o ggml/src/ggml-cuda/rope.o ggml/src/ggml-cuda/scale.o ggml/src/ggml-cuda/softcap.o ggml/src/ggml-cuda/softmax.o ggml/src/ggml-cuda/sumrows.o ggml/src/ggml-cuda/tsembd.o ggml/src/ggml-cuda/unary.o ggml/src/ggml-cuda/upscale.o ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.o ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.o ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.o ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.o ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.o ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.o ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.o ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.o ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.o ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.o ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.o ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.o ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.o ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.o ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.o ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.o ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.o ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.o ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.o ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.o ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.o ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0.o ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.o ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.o ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.o ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.o ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.o ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.o ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.o ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.o 
ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.o ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.o ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.o ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.o ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.o ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.o ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.o ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.o ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-iq4_nl.o ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-iq4_nl.o ggml/src/ggml.o ggml/src/ggml-alloc.o ggml/src/ggml-backend.o ggml/src/ggml-quants.o ggml/src/ggml-aarch64.o src/llama.o src/llama-vocab.o src/llama-grammar.o src/llama-sampling.o src/unicode.o src/unicode-data.o common/common.o common/console.o common/ngram-cache.o common/sampling.o common/train.o common/grammar-parser.o common/build-info.o common/json-schema-to-grammar.o -Iexamples/server examples/server/server.o -o llama-server -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/usr/lib64 -L/usr/local/cuda/targets/x86_64-linux/lib -L/usr/local/cuda/lib64/stubs -L/usr/lib/wsl/lib -lcurl
129.7 /usr/bin/ld: ggml/src/ggml.o: in function `ggml_compute_forward_flash_attn_ext_f16':
129.7 ggml.c:(.text+0xb96b): undefined reference to `iqk_flash_attn_noalibi'
130.1 collect2: error: ld returned 1 exit status
130.1 make: *** [Makefile:1462: llama-server] Error 1
```


👤 ikawrakow commented on 2025-03-08 at 15:07:14:

Are you using cmake to build? The object file for the new file that I added (`iqk_flash_attn.cpp`) is missing from the link command. It should be automatically added with cmake.
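
(For reference, a cmake build along these lines regenerates the build files and picks up newly added sources automatically; the `-DGGML_CUDA=ON` option name is an assumption and may differ in this repository:)

```
# Sketch of a cmake CUDA build; the -DGGML_CUDA=ON flag name is an assumption.
cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release -j
```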


👤 davidsyoung commented on 2025-03-08 at 15:20:58:

> Are you using cmake to build? The object file for the new file that I added (`iqk_flash_attn.cpp`) is missing from the link command. It should be automatically added with cmake.

Ah, I think that'll fix it. I was using the full-cuda.Dockerfile to run it, and I believe it was still using a make-based build from the previously forked llama.cpp.