Quantized Flash Attention for all supported CPU platforms (#51)

* NEON Flash Attention: add support for Q8_0, Q4_0, Q4_1

* NEON Flash Attention: quantized K*Q for q4_0

I could finally take advantage of the matrix multiplication
templates, and we get quite a bit of speedup that way for q4_0:
for Gemma-2b, using mul_mat_qX_0_q8_0<DequantizerQ40, q_step>
gives PP-2048 = 287 t/s vs 268 t/s when converting the
q4_0 k-cache and Q to fp16 and using fp16 matrix multiplication.
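
Not part of the commit, but here is a scalar sketch of what the quantized K*Q path
amounts to, stripped of the NEON intrinsics and the mul_mat_qX_0_q8_0 template
machinery: Q is quantized to q8_0 once, and each K*Q dot product against the q4_0
k-cache is then block-wise integer arithmetic with one scale multiply per block.
Block layouts follow the usual q4_0/q8_0 scheme (32 weights per block); the float
scales and helper names below are simplifying assumptions, not the actual code.

```cpp
// Illustrative scalar sketch only (not the NEON code in this commit).
// Real ggml blocks store fp16 scales; float is used here for brevity.
#include <algorithm>
#include <cmath>
#include <cstdint>

constexpr int kBlockSize = 32;

struct BlockQ4_0 {               // 32 weights: 16 bytes of packed 4-bit quants + scale
    float   d;
    uint8_t qs[kBlockSize / 2];
};

struct BlockQ8_0 {               // 32 weights: 32 int8 quants + scale
    float  d;
    int8_t qs[kBlockSize];
};

// Quantize one row of Q (the activations) to q8_0.
static void quantize_row_q8_0(const float* x, BlockQ8_0* y, int n) {
    for (int ib = 0; ib < n / kBlockSize; ++ib) {
        float amax = 0.f;
        for (int i = 0; i < kBlockSize; ++i) {
            amax = std::max(amax, std::fabs(x[ib*kBlockSize + i]));
        }
        const float d  = amax / 127.f;
        const float id = d > 0.f ? 1.f/d : 0.f;
        y[ib].d = d;
        for (int i = 0; i < kBlockSize; ++i) {
            y[ib].qs[i] = (int8_t)std::lround(x[ib*kBlockSize + i] * id);
        }
    }
}

// Dot product of one q4_0 k-cache row with one q8_0 Q row: the inner loop is
// pure integer math, with a single float multiply per block for the scales.
static float vec_dot_q4_0_q8_0(const BlockQ4_0* k, const BlockQ8_0* q, int n) {
    float sum = 0.f;
    for (int ib = 0; ib < n / kBlockSize; ++ib) {
        int isum = 0;
        for (int i = 0; i < kBlockSize / 2; ++i) {
            const int lo = (k[ib].qs[i] & 0x0F) - 8;  // weights 0..15 of the block
            const int hi = (k[ib].qs[i] >>   4) - 8;  // weights 16..31 of the block
            isum += lo * q[ib].qs[i] + hi * q[ib].qs[i + kBlockSize/2];
        }
        sum += k[ib].d * q[ib].d * (float)isum;
    }
    return sum;
}
```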

* NEON Flash Attention: quantized K*Q for q4_1

* NEON Flash Attention: quantized K*Q for q8_0

This makes quite a bit of difference:
for Gemma2-2b, PP-8192 is 228 t/s with quantized K*Q vs
178 t/s when converting the k-cache and Q to fp16 and using
fp16 matrix multiplication.
With PP-512 = 307 t/s, PP-8192 now runs at ~75% of the
PP-512 performance (228/307 ≈ 0.74). In contrast, llama.cpp
with a Q8_0 cache retains only 38% of its PP-512 performance
at PP-8192.

* Zen4 Flash Attention: quantized K*Q for q4_0, q4_1, q8_0

* AVX2 Flash Attention: quantized K*Q for q4_0, q4_1, q8_0

* Tidy up FlashMS

* Delete no longer used stuff

Now that quantized matrix multiplications are used for a
quantized k- and/or v-cache, the helper methods that load
entire rows are no longer needed.

* Disallow mixing bf16 with other types for kv caches
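
For illustration, a minimal sketch of the kind of check this implies, assuming a
hypothetical CacheType enum and validate_kv_cache_types helper (the names and the
assert-style handling are not from this commit): if either of the K/V caches is
bf16, the other must be bf16 as well.

```cpp
// Hypothetical sketch of the "no mixing bf16 with other types" rule.
#include <cassert>

enum class CacheType { F16, BF16, Q8_0, Q4_0, Q4_1 };

static void validate_kv_cache_types(CacheType type_k, CacheType type_v) {
    const bool k_is_bf16 = type_k == CacheType::BF16;
    const bool v_is_bf16 = type_v == CacheType::BF16;
    // bf16 may only be combined with bf16.
    assert(k_is_bf16 == v_is_bf16 &&
           "mixing bf16 with other kv-cache types is not supported");
    (void)k_is_bf16; (void)v_is_bf16;
}
```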

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>