FlashMLA-3: the best of both worlds (CPU only) (#273)

* Repack a model with the quantize tool

* WIP

* Fixed various issues

As we don't have a way to tell whether a repacked quant has been modified,
I had to remove the modification, at the cost of a slight decrease in
performance. This affects q8_0_r8, q8_KV_r8 and q8_k_r8 on Zen4, and
q4_0_r8 on ARM. (A sketch of one plausible such modification follows
after this list.)

* Create wk_b and wv_b as Q8_0_R8 if the wkv_b type is interleaved (see the second sketch after this list)

* Fix GCC 13.3 compilation error

* Fix another GCC 13.3 compilation error

* Add missing include

* FlashMLA-3: the best of both worlds - CPU only
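
For the repack note under "Fixed various issues" above, here is a minimal sketch of
one plausible kind of in-place modification (an assumption for illustration; the
commit does not spell out the exact transform): signed int8 quants stored offset
into the unsigned range so that an unsigned x signed dot product (e.g. AVX512-VNNI's
_mm512_dpbusd_epi32 on Zen4) can consume the stored bytes directly. Once such
pre-transformed bytes are written to a GGUF by the quantize tool, a loader cannot
tell them apart from unmodified ones, hence the modification had to go.

#include <cassert>
#include <cstdint>

int main() {
    const int8_t a[4] = {-7, 3, 100, -128};  // original signed quants
    const int8_t b[4] = {5, -2, -1, 4};      // signed activations

    uint8_t a_mod[4];                        // the bytes a modified repack would store
    for (int i = 0; i < 4; ++i) a_mod[i] = (uint8_t)(a[i] + 128);

    int32_t ref = 0, dot_us = 0, sum_b = 0;
    for (int i = 0; i < 4; ++i) {
        ref    += (int32_t)a[i] * b[i];      // the dot product we actually want
        dot_us += (int32_t)a_mod[i] * b[i];  // unsigned x signed, as the fast kernel computes it
        sum_b  += b[i];
    }
    // Undo the offset: sum((a + 128) * b) - 128 * sum(b) == sum(a * b)
    assert(dot_us - 128 * sum_b == ref);
    return 0;
}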
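
And a tiny sketch of the type choice in the wk_b/wv_b bullet; the enum values and
the is_interleaved predicate are hypothetical stand-ins, not ik_llama.cpp's actual
API:

#include <cstdio>

enum my_type { MY_TYPE_Q8_0, MY_TYPE_Q8_0_R8, MY_TYPE_Q8_K_R8 };

// Hypothetical predicate: is this a row-interleaved (repacked) quant type?
static bool is_interleaved(my_type t) {
    return t == MY_TYPE_Q8_0_R8 || t == MY_TYPE_Q8_K_R8;
}

// When splitting wkv_b into wk_b and wv_b, keep the derived tensors on the fast
// interleaved path by creating them as Q8_0_R8 whenever the source is repacked.
static my_type wk_wv_split_type(my_type wkv_b_type) {
    return is_interleaved(wkv_b_type) ? MY_TYPE_Q8_0_R8 : MY_TYPE_Q8_0;
}

int main() {
    printf("%d\n", wk_wv_split_type(MY_TYPE_Q8_K_R8) == MY_TYPE_Q8_0_R8); // prints 1
}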

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

Author: Kawrakow
Date: 2025-03-21 07:24:22 +01:00
Committed by: GitHub
Parent: b8d1fac97b
Commit: ddc8eee10e

@@ -13760,7 +13760,7 @@ struct llm_build_context {
     ggml_tensor * kqv;
-    if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) {
+    if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) { // PP for mla=2,3
         auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
@@ -13869,7 +13869,7 @@ struct llm_build_context {
     ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
     cb(q, "q", il);
-    if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1) {
+    if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3) {
         ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
                 kv_lora_rank, n_kv,
                 ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
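
Taken together, the two hunks above appear to give mla = 3 the "best of both worlds"
dispatch: the mla = 2 prompt-processing path while pp_opt is true, and the mla = 1
flash-attention path otherwise, i.e. during token generation. Below is a minimal
standalone sketch of that routing; the enum, the function name, and the explicit
return value are assumptions made for illustration (the real code builds the graph
inline rather than returning a path).

#include <cstdio>

enum class AttnPath { PP_MLA23, FA_MLA1_LIKE, OTHER };

// Mirror of the two conditions in the hunks above (same operator precedence).
static AttnPath pick_attn_path(int mla_attn, bool flash_attn, bool pp_opt) {
    // "PP for mla=2,3": batch-friendly path used while processing the prompt.
    if (mla_attn > 1 && flash_attn && pp_opt) return AttnPath::PP_MLA23;
    // mla = 1 flash-attention path, which mla = 3 now also takes.
    if ((flash_attn && mla_attn == 1) || mla_attn == 3) return AttnPath::FA_MLA1_LIKE;
    return AttnPath::OTHER;
}

int main() {
    // mla = 3: PP goes through the mla=2 path, TG through the mla=1-like FA path.
    printf("%d\n", (int)pick_attn_path(3, /*flash_attn=*/true, /*pp_opt=*/true));  // 0 = PP_MLA23
    printf("%d\n", (int)pick_attn_path(3, /*flash_attn=*/true, /*pp_opt=*/false)); // 1 = FA_MLA1_LIKE
}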