mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-01 11:51:53 +00:00
This reduces compute buffer size for MLA
This commit is contained in:
@@ -855,6 +855,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.mla_attn = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "-amb" || arg == "--attention-max-batch") {
|
||||
CHECK_ARG
|
||||
params.attn_max_batch = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "-fmoe" || arg == "--fused-moe") {
|
||||
params.fused_moe_up_gate = true;
|
||||
return true;
|
||||
@@ -1516,6 +1521,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
|
||||
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
|
||||
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
|
||||
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
|
||||
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
|
||||
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
|
||||
"in conversation mode, this will be used as system prompt\n"
|
||||
@@ -2360,6 +2366,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||
cparams.offload_kqv = !params.no_kv_offload;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.mla_attn = params.mla_attn;
|
||||
cparams.attn_max_batch = params.attn_max_batch;
|
||||
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
|
||||
|
||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
||||
@@ -3359,6 +3366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
||||
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
|
||||
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
|
||||
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
|
||||
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
|
||||
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
|
||||
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
|
||||
|
||||
|
||||
@@ -175,7 +175,8 @@ struct gpt_params {
|
||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||
bool flash_attn = false; // flash attention
|
||||
int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
|
||||
int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
|
||||
int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
|
||||
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
|
||||
@@ -384,6 +384,7 @@ extern "C" {
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
|
||||
int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
|
||||
bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL]
|
||||
|
||||
// Abort callback
|
||||
|
||||
110
src/llama.cpp
110
src/llama.cpp
@@ -2511,6 +2511,7 @@ struct llama_cparams {
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
int mla_attn;
|
||||
int attn_max_batch;
|
||||
bool fused_moe_up_gate;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
@@ -8924,6 +8925,7 @@ struct llm_build_context {
|
||||
|
||||
const bool flash_attn;
|
||||
const int mla_attn;
|
||||
const int attn_max_batch;
|
||||
const bool fused_moe_up_gate;
|
||||
|
||||
const enum llama_pooling_type pooling_type;
|
||||
@@ -8976,6 +8978,7 @@ struct llm_build_context {
|
||||
n_ctx_orig (cparams.n_ctx_orig_yarn),
|
||||
flash_attn (cparams.flash_attn),
|
||||
mla_attn (cparams.mla_attn),
|
||||
attn_max_batch (cparams.attn_max_batch),
|
||||
fused_moe_up_gate(cparams.fused_moe_up_gate),
|
||||
pooling_type (cparams.pooling_type),
|
||||
rope_type (hparams.rope_type),
|
||||
@@ -13572,25 +13575,6 @@ struct llm_build_context {
|
||||
|
||||
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
|
||||
cb(q, "q", il);
|
||||
if (!pp_opt) {
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
cb(q, "q_perm", il);
|
||||
}
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
if (!pp_opt) {
|
||||
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
|
||||
cb(kq, "kq_perm", il);
|
||||
}
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
if (!pp_opt) {
|
||||
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
|
||||
cb(kq, "kq_soft_max_ext_perm", il);
|
||||
}
|
||||
|
||||
if (lctx.cparams.mla_attn > 1) {
|
||||
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
|
||||
@@ -13602,12 +13586,87 @@ struct llm_build_context {
|
||||
cb(kv_cache_trans, "kv_cache_trans", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
|
||||
cb(kqv_compressed, "kqv_compressed", il);
|
||||
ggml_tensor * kqv_compressed;
|
||||
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kv_cache->ne[1]) {
|
||||
if (!pp_opt) {
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
cb(q, "q_perm", il);
|
||||
}
|
||||
|
||||
if (!pp_opt) {
|
||||
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
|
||||
cb(kqv_compressed, "kqv_compressed_perm", il);
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
//printf("kq (%ld x %ld x %ld x %ld) = kv_cache (%ld x %ld x %ld x %ld) * q (%ld x %ld x %ld x %ld)\n", kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3],
|
||||
// kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2], kv_cache->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
|
||||
|
||||
if (!pp_opt) {
|
||||
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
|
||||
cb(kq, "kq_perm", il);
|
||||
}
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
if (!pp_opt) {
|
||||
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
|
||||
cb(kq, "kq_soft_max_ext_perm", il);
|
||||
}
|
||||
|
||||
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
|
||||
cb(kqv_compressed, "kqv_compressed", il);
|
||||
|
||||
//printf("kqv (%ld x %ld x %ld x %ld) = kv_cache_trans (%ld x %ld x %ld x %ld) * kq (%ld x %ld x %ld x %ld)\n",
|
||||
// kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2], kqv_compressed->ne[3],
|
||||
// kv_cache_trans->ne[0], kv_cache_trans->ne[1], kv_cache_trans->ne[2], kv_cache_trans->ne[3], kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3]);
|
||||
|
||||
if (!pp_opt) {
|
||||
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
|
||||
cb(kqv_compressed, "kqv_compressed_perm", il);
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
int n_step = (q->ne[1] + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
|
||||
|
||||
//kqv_compressed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_cache_trans->ne[1], q->ne[1], q->ne[2]);
|
||||
//printf("q->ne[1] = %ld -> need %d steps\n", q->ne[1], n_step);
|
||||
//printf("Created kqv_compressed = %ld x %ld x %ld\n", kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2]);
|
||||
|
||||
for (int i_head = 0; i_head < q->ne[2]; ++i_head) {
|
||||
ggml_tensor * q_i = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i_head);
|
||||
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
|
||||
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
|
||||
if (i_head == 0) {
|
||||
kqv_compressed = ggml_view_3d(ctx0, kqv_i, kqv_i->ne[0], kqv_i->ne[1], 1, kqv_i->nb[1], kqv_i->nb[2], 0);
|
||||
} else {
|
||||
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
|
||||
}
|
||||
ggml_build_forward_expand(gf, kqv_compressed);
|
||||
//ggml_tensor * kqv_compressed_i = ggml_view_1d(ctx0, kqv_compressed, ggml_nelements(kqv_i), kqv_compressed->nb[2]*i_head);
|
||||
//ggml_tensor * head_i = ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
|
||||
//ggml_build_forward_expand(gf, head_i);
|
||||
}
|
||||
|
||||
//for (int i_step = 0; i_step < n_step; ++i_step) {
|
||||
// int i_start = i_step * lctx.cparams.attn_max_batch;
|
||||
// int this_batch = i_start + lctx.cparams.attn_max_batch <= q->ne[1] ? lctx.cparams.attn_max_batch : q->ne[1] - i_start;
|
||||
// ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], this_batch, q->ne[2], q->nb[1], q->nb[2], i_start*q->nb[1]);
|
||||
// cb(q_i, "q_i", il);
|
||||
// ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
|
||||
// cb(kq_i, "kq_i", il);
|
||||
// ggml_tensor * mask_i = ggml_view_2d(ctx0, KQ_mask, KQ_mask->ne[0], this_batch, KQ_mask->nb[1], i_start*KQ_mask->nb[1]);
|
||||
// kq_i = ggml_soft_max_ext(ctx0, kq_i, mask_i, kq_scale, hparams.f_max_alibi_bias);
|
||||
// cb(kq_i, "kq_i_softmwax", il);
|
||||
// ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
|
||||
// cb(kqv_i, "kqv_i", il);
|
||||
// ggml_tensor * kqv_compressed_i = ggml_view_3d(ctx0, kqv_compressed, kqv_compressed->ne[0], this_batch, kqv_compressed->ne[2],
|
||||
// kqv_compressed->nb[1], kqv_compressed->nb[2], i_start*kqv_compressed->nb[1]);
|
||||
// printf("step %d (%d tokens): kqv_i = %ld x %ld x %ld, kqv_compressed_i = %ld x %ld x %ld\n", i_step, this_batch,
|
||||
// kqv_i->ne[0], kqv_i->ne[1], kqv_i->ne[2], kqv_compressed_i->ne[0], kqv_compressed_i->ne[1], kqv_compressed_i->ne[2]);
|
||||
// ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
|
||||
//}
|
||||
cb(kqv_compressed, "kqv_compressed", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
|
||||
@@ -17644,6 +17703,7 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.offload_kqv =*/ true,
|
||||
/*.flash_attn =*/ false,
|
||||
/*.mla_attn =*/ 0,
|
||||
/*.attn_max_batch =*/ 0,
|
||||
/*.fused_moe_up_gate =*/ false,
|
||||
/*.abort_callback =*/ nullptr,
|
||||
/*.abort_callback_data =*/ nullptr,
|
||||
@@ -17844,6 +17904,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.mla_attn = params.mla_attn;
|
||||
cparams.attn_max_batch = params.attn_max_batch;
|
||||
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
||||
@@ -17912,6 +17973,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
||||
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
|
||||
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
|
||||
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
Reference in New Issue
Block a user