Flash MLA (CPU only) (#240)

* FlashMLA - it finally works (on the CPU)

* FlashMLA: allow for f16 and bf16 cache in addition to q8_0

* It works with ggml FA, not with iqk FA

* WIP

* FlashMLA: it now works with iqk

I had forgotten to divide the Q stride by sizeof(float) and
that's why, very cobfusingly, it was working for TG but not for PP.

* WIP

* FlashMLA: that should be it for now

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-03-03 15:17:51 +02:00
committed by GitHub
parent a89adaa78f
commit a87e54db6e
4 changed files with 183 additions and 79 deletions

View File

@@ -3168,7 +3168,7 @@ static bool llama_kv_cache_init(
// DeepSeek MLA
cache.kv_l.reserve(n_layer);
if (cparams.mla_attn == 1) {
if (cparams.mla_attn == 1 && !cparams.flash_attn) {
cache.kvt_l.reserve(n_layer);
}
@@ -3201,14 +3201,20 @@ static bool llama_kv_cache_init(
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v;
ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size);
ggml_format_name(kv, "cache_kv_l%d", i);
cache.kv_l.push_back(kv);
if (cparams.mla_attn == 1) {
ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
ggml_format_name(kvt, "cache_kvt_l%d", i);
cache.kvt_l.push_back(kvt);
if (cparams.flash_attn) {
ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
ggml_format_name(kv, "cache_kv_l%d", i);
cache.kv_l.push_back(kv);
} else {
auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v;
ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size);
ggml_format_name(kv, "cache_kv_l%d", i);
cache.kv_l.push_back(kv);
if (cparams.mla_attn == 1) {
ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
ggml_format_name(kvt, "cache_kvt_l%d", i);
cache.kvt_l.push_back(kvt);
}
}
n_mla++;
}
@@ -13588,7 +13594,7 @@ struct llm_build_context {
ggml_tensor * kv_cache_trans;
if (lctx.cparams.mla_attn == 1) {
if (lctx.cparams.mla_attn == 1 && !lctx.cparams.flash_attn) {
ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head));
cb(kv_cache_trans_view, "kv_cache_trans_view", il);
@@ -13630,70 +13636,88 @@ struct llm_build_context {
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
cb(q, "q", il);
if (lctx.cparams.mla_attn > 1) {
ggml_tensor * kqv_compressed;
if (lctx.cparams.flash_attn) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
cb(kv_cache, "kv_cache_lora", il);
cb(kv_cache_lora, "kv_cache_lora", il);
kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
cb(kv_cache_trans, "kv_cache_trans", il);
//ggml_tensor * v = ggml_cont(ctx0, kv_cache_lora);
//kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
cb(kqv_compressed, "kqv_compressed", il);
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
else {
if (lctx.cparams.mla_attn > 1) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
cb(kv_cache, "kv_cache_lora", il);
ggml_tensor * kqv_compressed;
auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
cb(kv_cache_trans, "kv_cache_trans", il);
}
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
}
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
if (!pp_opt) {
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
cb(kq, "kq_soft_max_ext_perm", il);
}
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
} else {
int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
n_step = std::min(n_step, int(q->ne[2]));
int n_per_step = (q->ne[2] + n_step - 1)/n_step;
//printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
if (i_head == 0) {
kqv_compressed = kqv_i;
} else {
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
}
ggml_build_forward_expand(gf, kqv_compressed);
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
}
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
if (!pp_opt) {
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
cb(kq, "kq_soft_max_ext_perm", il);
}
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
} else {
int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
n_step = std::min(n_step, int(q->ne[2]));
int n_per_step = (q->ne[2] + n_step - 1)/n_step;
//printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
if (i_head == 0) {
kqv_compressed = kqv_i;
} else {
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
}
ggml_build_forward_expand(gf, kqv_compressed);
}
cb(kqv_compressed, "kqv_compressed", il);
}
cb(kqv_compressed, "kqv_compressed", il);
}
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
@@ -18226,7 +18250,7 @@ struct llama_context * llama_new_context_with_model(
}
if (memory_size_kv + memory_size_kvt > 0) {
if (cparams.mla_attn == 1) {
if (cparams.mla_attn == 1 && !cparams.flash_attn) {
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),