This reduces compute buffer size for MLA

This commit is contained in:
Iwan Kawrakow
2025-02-28 14:26:47 +02:00
parent b762db7c92
commit addd8994cd
4 changed files with 97 additions and 25 deletions

View File

@@ -2511,6 +2511,7 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
int mla_attn;
int attn_max_batch;
bool fused_moe_up_gate;
enum llama_pooling_type pooling_type;
@@ -8924,6 +8925,7 @@ struct llm_build_context {
const bool flash_attn;
const int mla_attn;
const int attn_max_batch;
const bool fused_moe_up_gate;
const enum llama_pooling_type pooling_type;
@@ -8976,6 +8978,7 @@ struct llm_build_context {
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
mla_attn (cparams.mla_attn),
attn_max_batch (cparams.attn_max_batch),
fused_moe_up_gate(cparams.fused_moe_up_gate),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
@@ -13572,25 +13575,6 @@ struct llm_build_context {
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
cb(q, "q", il);
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
}
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
}
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
if (!pp_opt) {
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
cb(kq, "kq_soft_max_ext_perm", il);
}
if (lctx.cparams.mla_attn > 1) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
@@ -13602,12 +13586,87 @@ struct llm_build_context {
cb(kv_cache_trans, "kv_cache_trans", il);
}
struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
ggml_tensor * kqv_compressed;
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kv_cache->ne[1]) {
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
}
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
//printf("kq (%ld x %ld x %ld x %ld) = kv_cache (%ld x %ld x %ld x %ld) * q (%ld x %ld x %ld x %ld)\n", kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3],
// kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2], kv_cache->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
}
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
if (!pp_opt) {
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
cb(kq, "kq_soft_max_ext_perm", il);
}
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
//printf("kqv (%ld x %ld x %ld x %ld) = kv_cache_trans (%ld x %ld x %ld x %ld) * kq (%ld x %ld x %ld x %ld)\n",
// kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2], kqv_compressed->ne[3],
// kv_cache_trans->ne[0], kv_cache_trans->ne[1], kv_cache_trans->ne[2], kv_cache_trans->ne[3], kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3]);
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
} else {
int n_step = (q->ne[1] + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
//kqv_compressed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_cache_trans->ne[1], q->ne[1], q->ne[2]);
//printf("q->ne[1] = %ld -> need %d steps\n", q->ne[1], n_step);
//printf("Created kqv_compressed = %ld x %ld x %ld\n", kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2]);
for (int i_head = 0; i_head < q->ne[2]; ++i_head) {
ggml_tensor * q_i = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i_head);
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
if (i_head == 0) {
kqv_compressed = ggml_view_3d(ctx0, kqv_i, kqv_i->ne[0], kqv_i->ne[1], 1, kqv_i->nb[1], kqv_i->nb[2], 0);
} else {
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
}
ggml_build_forward_expand(gf, kqv_compressed);
//ggml_tensor * kqv_compressed_i = ggml_view_1d(ctx0, kqv_compressed, ggml_nelements(kqv_i), kqv_compressed->nb[2]*i_head);
//ggml_tensor * head_i = ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
//ggml_build_forward_expand(gf, head_i);
}
//for (int i_step = 0; i_step < n_step; ++i_step) {
// int i_start = i_step * lctx.cparams.attn_max_batch;
// int this_batch = i_start + lctx.cparams.attn_max_batch <= q->ne[1] ? lctx.cparams.attn_max_batch : q->ne[1] - i_start;
// ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], this_batch, q->ne[2], q->nb[1], q->nb[2], i_start*q->nb[1]);
// cb(q_i, "q_i", il);
// ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
// cb(kq_i, "kq_i", il);
// ggml_tensor * mask_i = ggml_view_2d(ctx0, KQ_mask, KQ_mask->ne[0], this_batch, KQ_mask->nb[1], i_start*KQ_mask->nb[1]);
// kq_i = ggml_soft_max_ext(ctx0, kq_i, mask_i, kq_scale, hparams.f_max_alibi_bias);
// cb(kq_i, "kq_i_softmwax", il);
// ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
// cb(kqv_i, "kqv_i", il);
// ggml_tensor * kqv_compressed_i = ggml_view_3d(ctx0, kqv_compressed, kqv_compressed->ne[0], this_batch, kqv_compressed->ne[2],
// kqv_compressed->nb[1], kqv_compressed->nb[2], i_start*kqv_compressed->nb[1]);
// printf("step %d (%d tokens): kqv_i = %ld x %ld x %ld, kqv_compressed_i = %ld x %ld x %ld\n", i_step, this_batch,
// kqv_i->ne[0], kqv_i->ne[1], kqv_i->ne[2], kqv_compressed_i->ne[0], kqv_compressed_i->ne[1], kqv_compressed_i->ne[2]);
// ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
//}
cb(kqv_compressed, "kqv_compressed", il);
}
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
@@ -17644,6 +17703,7 @@ struct llama_context_params llama_context_default_params() {
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.mla_attn =*/ 0,
/*.attn_max_batch =*/ 0,
/*.fused_moe_up_gate =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -17844,6 +17904,7 @@ struct llama_context * llama_new_context_with_model(
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
cparams.pooling_type = params.pooling_type;
@@ -17912,6 +17973,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);