FlashMLA-2: eliminate intermediate f32 tensors

This works on the CPU. PP (prompt processing) performance is ~13% better for 16k tokens,
and the compute buffer is quite a bit smaller.

Author: Iwan Kawrakow
Date:   2025-03-12 10:45:36 +02:00
Parent: 3f23ed68f1
Commit: f05484d9a3

5 changed files with 205 additions and 44 deletions
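
For context, the change in graph-building terms: the old code cast the entire per-layer KV cache to f32, assembled K from f32 intermediates (repeat, permute, concat), and cast the result back to the cache type; the new code views the cache in place and casts only the small slices that actually need it, assembling K directly in the cache type. Below is a minimal standalone sketch of the new pattern, not the actual llama.cpp code: the names (cache, wkv_b) and the DeepSeek-like sizes are illustrative stand-ins, and the program only constructs graph nodes to show shapes and memory, without computing them. Stock ggml still requires the second ggml_mul_mat operand in f32, so the sketch casts just the latent slice; per the commit message, the fork's CPU mul_mat can consume the cache-typed view directly, eliminating even that cast.

#include "ggml.h"
#include <stdio.h>

int main() {
    // Enough context memory for the toy tensors below.
    ggml_init_params ip = { /*mem_size*/ 512u*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(ip);

    // Illustrative DeepSeek-like sizes, not read from a real model.
    const int64_t kv_lora_rank = 512, n_rope = 64, n_nope = 128, n_v = 128;
    const int64_t n_head = 16, n_kv = 4096;

    // Stand-in for one layer's MLA cache: each row is [latent | rope part].
    ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, kv_lora_rank + n_rope, n_kv);
    // Stand-in for wkv_b: decompresses the latent into per-head k_nope and v.
    ggml_tensor * wkv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv_lora_rank, n_head*(n_nope + n_v));

    // View only the latent columns; stock ggml_mul_mat wants src1 in f32, so
    // cast just this slice. (The fork's mul_mat takes the f16 view directly.)
    ggml_tensor * latent = ggml_view_2d(ctx, cache, kv_lora_rank, n_kv, cache->nb[1], 0);
    ggml_tensor * kv_f32 = ggml_mul_mat(ctx, wkv_b, ggml_cast(ctx, latent, GGML_TYPE_F32));

    // Slice k_nope out of the mul_mat result and cast it straight back to the
    // cache type; a full f32 copy of K is never created.
    ggml_tensor * k_nope_f32 = ggml_view_3d(ctx, kv_f32, n_nope, n_kv, n_head,
            ggml_row_size(kv_f32->type, n_head*(n_nope + n_v)),
            ggml_row_size(kv_f32->type, n_nope + n_v), 0);
    ggml_tensor * k_nope = ggml_cast(ctx, k_nope_f32, cache->type);

    // The shared rope slice is viewed directly from the cache and broadcast
    // across heads; it never passes through f32 at all. The shape tensor only
    // serves as a repeat target (the in-tree code avoids even this allocation).
    ggml_tensor * rope = ggml_view_3d(ctx, cache, n_rope, n_kv, 1,
            cache->nb[1], cache->nb[2], ggml_row_size(cache->type, kv_lora_rank));
    ggml_tensor * shape = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_rope, n_kv, n_head);
    ggml_tensor * k_rope = ggml_repeat(ctx, rope, shape);

    // K is assembled in the cache type, not in f32.
    ggml_tensor * k = ggml_concat(ctx, k_nope, k_rope, 0);

    printf("K: [%lld, %lld, %lld], %zu bytes, ctx used %zu bytes\n",
           (long long)k->ne[0], (long long)k->ne[1], (long long)k->ne[2],
           ggml_nbytes(k), ggml_used_mem(ctx));
    ggml_free(ctx);
    return 0;
}

The tensors that vanish relative to the old pattern are the full-cache f32 copy and the f32-typed k_rope/k concat intermediates, which is where the compute-buffer saving comes from.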

@@ -13630,37 +13630,12 @@ struct llm_build_context {
if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) {
// Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not
// provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
// multiplication, which *must* be f32.
auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
cb(kv_cache_view_f32, "kv_cache_view_f32", il);
// The no- and rotational position encoding portions of the KV cache
auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
GGML_ASSERT(hparams.n_embd_head_v == n_embd_head_qk_nope);
auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
cb(kv_f32, "kv_f32", il);
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
ggml_tensor repeater;
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
cb(k_rope_f32, "k_rope_f32", il);
auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
cb(k_f32, "k_f32", il);
auto k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
cb(k, "k", il);
auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
@@ -13670,6 +13645,64 @@ struct llm_build_context {
auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
cb(v, "v", il);
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
cb(k_nope, "k_nope", il);
auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
ggml_tensor repeater;
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
cb(k_rope, "k_rope", il);
auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
cb(k, "k", il);
//// Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not
//// provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
//// multiplication, which *must* be f32.
//auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
//auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
//cb(kv_cache_view_f32, "kv_cache_view_f32", il);
//// The no- and rotational position encoding portions of the KV cache
//auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
//auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
// kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
//auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
//cb(kv_f32, "kv_f32", il);
//auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
// ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
// ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
//cb(k_nope_f32, "k_nope_f32", il);
//ggml_tensor repeater;
//repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
//auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
//cb(k_rope_f32, "k_rope_f32", il);
//auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
//cb(k_f32, "k_f32", il);
//auto k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
//cb(k, "k", il);
//auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
// ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
// ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
// ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
//cb(v_f32, "v_f32", il);
//auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
//cb(v, "v", il);
auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_concat", il);
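
One more detail worth calling out: the broadcast layout for the shared RoPE slice changed. The old path viewed it as [n_embd_head_qk_rope, 1, n_kv], repeated across heads in dim 1, and then permuted, leaving a non-contiguous f32 tensor; the new path views it as [n_embd_head_qk_rope, n_kv, 1] and repeats straight into [n_embd_head_qk_rope, n_kv, n_head], which is contiguous and lines up with k_nope for the dim-0 concat. A standalone, f32-only sketch with toy sizes, using plain shape tensors where the in-tree code fills in a stack-allocated ggml_tensor named repeater (of which ggml_repeat only reads the ne fields):

#include "ggml.h"
#include <stdio.h>

int main() {
    ggml_init_params ip = { /*mem_size*/ 16u*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(ip);

    const int64_t n_rope = 4, n_kv = 3, n_head = 2; // toy sizes

    // Old layout: positions in dim 2, repeat across heads in dim 1, then
    // permute dims 1 and 2 -> a non-contiguous result.
    ggml_tensor * src_old   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_rope, 1, n_kv);
    ggml_tensor * shape_old = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_rope, n_head, n_kv);
    ggml_tensor * old_k     = ggml_permute(ctx, ggml_repeat(ctx, src_old, shape_old), 0, 2, 1, 3);

    // New layout: positions in dim 1, heads in dim 2; ggml_repeat broadcasts
    // dim 2 directly, so no permute is needed and the result is contiguous.
    ggml_tensor * src_new   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_rope, n_kv, 1);
    ggml_tensor * shape_new = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_rope, n_kv, n_head);
    ggml_tensor * new_k     = ggml_repeat(ctx, src_new, shape_new);

    printf("old: [%lld, %lld, %lld] contiguous=%d\n", (long long)old_k->ne[0],
           (long long)old_k->ne[1], (long long)old_k->ne[2], ggml_is_contiguous(old_k));
    printf("new: [%lld, %lld, %lld] contiguous=%d\n", (long long)new_k->ne[0],
           (long long)new_k->ne[1], (long long)new_k->ne[2], ggml_is_contiguous(new_k));

    ggml_free(ctx);
    return 0;
}

Both forms produce the shape [4, 3, 2], but only the new one reports contiguous=1, which is what lets K be concatenated and handed to the attention op without further copies.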