mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-07 20:40:02 +00:00
This reduces compute buffer size for MLA
This commit is contained in:
110
src/llama.cpp
110
src/llama.cpp
@@ -2511,6 +2511,7 @@ struct llama_cparams {
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
int mla_attn;
|
||||
int attn_max_batch;
|
||||
bool fused_moe_up_gate;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
@@ -8924,6 +8925,7 @@ struct llm_build_context {
|
||||
|
||||
const bool flash_attn;
|
||||
const int mla_attn;
|
||||
const int attn_max_batch;
|
||||
const bool fused_moe_up_gate;
|
||||
|
||||
const enum llama_pooling_type pooling_type;
|
||||
@@ -8976,6 +8978,7 @@ struct llm_build_context {
|
||||
n_ctx_orig (cparams.n_ctx_orig_yarn),
|
||||
flash_attn (cparams.flash_attn),
|
||||
mla_attn (cparams.mla_attn),
|
||||
attn_max_batch (cparams.attn_max_batch),
|
||||
fused_moe_up_gate(cparams.fused_moe_up_gate),
|
||||
pooling_type (cparams.pooling_type),
|
||||
rope_type (hparams.rope_type),
|
||||
@@ -13572,25 +13575,6 @@ struct llm_build_context {
|
||||
|
||||
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
|
||||
cb(q, "q", il);
|
||||
if (!pp_opt) {
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
cb(q, "q_perm", il);
|
||||
}
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
if (!pp_opt) {
|
||||
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
|
||||
cb(kq, "kq_perm", il);
|
||||
}
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
if (!pp_opt) {
|
||||
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
|
||||
cb(kq, "kq_soft_max_ext_perm", il);
|
||||
}
|
||||
|
||||
if (lctx.cparams.mla_attn > 1) {
|
||||
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
|
||||
@@ -13602,12 +13586,87 @@ struct llm_build_context {
|
||||
cb(kv_cache_trans, "kv_cache_trans", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
|
||||
cb(kqv_compressed, "kqv_compressed", il);
|
||||
ggml_tensor * kqv_compressed;
|
||||
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kv_cache->ne[1]) {
|
||||
if (!pp_opt) {
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
cb(q, "q_perm", il);
|
||||
}
|
||||
|
||||
if (!pp_opt) {
|
||||
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
|
||||
cb(kqv_compressed, "kqv_compressed_perm", il);
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
//printf("kq (%ld x %ld x %ld x %ld) = kv_cache (%ld x %ld x %ld x %ld) * q (%ld x %ld x %ld x %ld)\n", kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3],
|
||||
// kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2], kv_cache->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
|
||||
|
||||
if (!pp_opt) {
|
||||
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
|
||||
cb(kq, "kq_perm", il);
|
||||
}
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
if (!pp_opt) {
|
||||
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
|
||||
cb(kq, "kq_soft_max_ext_perm", il);
|
||||
}
|
||||
|
||||
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
|
||||
cb(kqv_compressed, "kqv_compressed", il);
|
||||
|
||||
//printf("kqv (%ld x %ld x %ld x %ld) = kv_cache_trans (%ld x %ld x %ld x %ld) * kq (%ld x %ld x %ld x %ld)\n",
|
||||
// kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2], kqv_compressed->ne[3],
|
||||
// kv_cache_trans->ne[0], kv_cache_trans->ne[1], kv_cache_trans->ne[2], kv_cache_trans->ne[3], kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3]);
|
||||
|
||||
if (!pp_opt) {
|
||||
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
|
||||
cb(kqv_compressed, "kqv_compressed_perm", il);
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
int n_step = (q->ne[1] + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
|
||||
|
||||
//kqv_compressed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_cache_trans->ne[1], q->ne[1], q->ne[2]);
|
||||
//printf("q->ne[1] = %ld -> need %d steps\n", q->ne[1], n_step);
|
||||
//printf("Created kqv_compressed = %ld x %ld x %ld\n", kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2]);
|
||||
|
||||
for (int i_head = 0; i_head < q->ne[2]; ++i_head) {
|
||||
ggml_tensor * q_i = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i_head);
|
||||
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
|
||||
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
|
||||
if (i_head == 0) {
|
||||
kqv_compressed = ggml_view_3d(ctx0, kqv_i, kqv_i->ne[0], kqv_i->ne[1], 1, kqv_i->nb[1], kqv_i->nb[2], 0);
|
||||
} else {
|
||||
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
|
||||
}
|
||||
ggml_build_forward_expand(gf, kqv_compressed);
|
||||
//ggml_tensor * kqv_compressed_i = ggml_view_1d(ctx0, kqv_compressed, ggml_nelements(kqv_i), kqv_compressed->nb[2]*i_head);
|
||||
//ggml_tensor * head_i = ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
|
||||
//ggml_build_forward_expand(gf, head_i);
|
||||
}
|
||||
|
||||
//for (int i_step = 0; i_step < n_step; ++i_step) {
|
||||
// int i_start = i_step * lctx.cparams.attn_max_batch;
|
||||
// int this_batch = i_start + lctx.cparams.attn_max_batch <= q->ne[1] ? lctx.cparams.attn_max_batch : q->ne[1] - i_start;
|
||||
// ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], this_batch, q->ne[2], q->nb[1], q->nb[2], i_start*q->nb[1]);
|
||||
// cb(q_i, "q_i", il);
|
||||
// ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
|
||||
// cb(kq_i, "kq_i", il);
|
||||
// ggml_tensor * mask_i = ggml_view_2d(ctx0, KQ_mask, KQ_mask->ne[0], this_batch, KQ_mask->nb[1], i_start*KQ_mask->nb[1]);
|
||||
// kq_i = ggml_soft_max_ext(ctx0, kq_i, mask_i, kq_scale, hparams.f_max_alibi_bias);
|
||||
// cb(kq_i, "kq_i_softmwax", il);
|
||||
// ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
|
||||
// cb(kqv_i, "kqv_i", il);
|
||||
// ggml_tensor * kqv_compressed_i = ggml_view_3d(ctx0, kqv_compressed, kqv_compressed->ne[0], this_batch, kqv_compressed->ne[2],
|
||||
// kqv_compressed->nb[1], kqv_compressed->nb[2], i_start*kqv_compressed->nb[1]);
|
||||
// printf("step %d (%d tokens): kqv_i = %ld x %ld x %ld, kqv_compressed_i = %ld x %ld x %ld\n", i_step, this_batch,
|
||||
// kqv_i->ne[0], kqv_i->ne[1], kqv_i->ne[2], kqv_compressed_i->ne[0], kqv_compressed_i->ne[1], kqv_compressed_i->ne[2]);
|
||||
// ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
|
||||
//}
|
||||
cb(kqv_compressed, "kqv_compressed", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
|
||||
@@ -17644,6 +17703,7 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.offload_kqv =*/ true,
|
||||
/*.flash_attn =*/ false,
|
||||
/*.mla_attn =*/ 0,
|
||||
/*.attn_max_batch =*/ 0,
|
||||
/*.fused_moe_up_gate =*/ false,
|
||||
/*.abort_callback =*/ nullptr,
|
||||
/*.abort_callback_data =*/ nullptr,
|
||||
@@ -17844,6 +17904,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.mla_attn = params.mla_attn;
|
||||
cparams.attn_max_batch = params.attn_max_batch;
|
||||
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
||||
@@ -17912,6 +17973,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
||||
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
|
||||
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
|
||||
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
Reference in New Issue
Block a user