This reduces compute buffer size for MLA

Commit: addd8994cd
Parent: b762db7c92
Author: Iwan Kawrakow
Date:   2025-02-28 14:26:47 +02:00

4 changed files with 97 additions and 25 deletions
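The new -amb / --attention-max-batch option lets the MLA attention products be computed in pieces (one head at a time in the code below) instead of all heads at once, so the intermediate K*Q tensor never has to be materialized in full. A back-of-the-envelope sketch of the saving; the dimensions are hypothetical DeepSeek-like values, not taken from the commit:

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes: 128 heads, a 512-token ubatch, 8192 cached KV entries.
    const int64_t n_head = 128, n_tokens = 512, n_kv = 8192;
    // Single pass: KQ is one f32 tensor of shape [n_kv, n_tokens, n_head].
    const int64_t kq_full = n_kv * n_tokens * n_head * (int64_t)sizeof(float);
    // Per-head loop: only one [n_kv, n_tokens] slice is live at a time.
    const int64_t kq_head = n_kv * n_tokens * (int64_t)sizeof(float);
    printf("KQ, all heads at once:  %lld MiB\n", (long long)(kq_full >> 20)); // 2048 MiB
    printf("KQ, one head at a time: %lld MiB\n", (long long)(kq_head >> 20)); //   16 MiB
    return 0;
}

With these made-up sizes the peak intermediate drops from 2 GiB to 16 MiB; the real saving depends on context length, ubatch size, and head count.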

View File

@@ -855,6 +855,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.mla_attn = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "-amb" || arg == "--attention-max-batch") {
+        CHECK_ARG
+        params.attn_max_batch = std::stoi(argv[i]);
+        return true;
+    }
     if (arg == "-fmoe" || arg == "--fused-moe") {
         params.fused_moe_up_gate = true;
         return true;
@@ -1516,6 +1521,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
@@ -2360,6 +2366,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
     cparams.mla_attn = params.mla_attn;
+    cparams.attn_max_batch = params.attn_max_batch;
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
@@ -3359,6 +3366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
     fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
+    fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
     fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

View File

@@ -175,7 +175,8 @@ struct gpt_params {
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
-    int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
+    int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
+    int attn_max_batch = 0; // max batch size to use when computing attention (only applicable if flash_attn = false)
     bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
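The zero default keeps the feature opt-in. A minimal sketch of the effective gate, modeled on the check this commit adds to llama.cpp further down (the helper name is invented; n_kv stands for the kv_cache->ne[1] it is compared against there):

#include <cstdint>

// Invented helper: a cap of 0 (the default), or one at least as large as the
// number of cached KV entries, leaves the original single-pass path in effect.
static bool attention_is_batched(int attn_max_batch, int64_t n_kv) {
    return attn_max_batch > 0 && attn_max_batch < n_kv;
}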

View File

@@ -384,6 +384,7 @@ extern "C" {
     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
     bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
     int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
+    int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
     bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL]
     // Abort callback
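For programmatic users, a minimal sketch of setting the new field through the public API shown above (values illustrative):

#include "llama.h"

int main() {
    llama_context_params cparams = llama_context_default_params();
    cparams.mla_attn       = 2;   // MLA with just the K cache
    cparams.attn_max_batch = 64;  // cap attention batches; 0 (the default) = no cap
    // ... pass cparams to llama_new_context_with_model() as usual ...
    return 0;
}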

View File

@@ -2511,6 +2511,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     int mla_attn;
+    int attn_max_batch;
     bool fused_moe_up_gate;
     enum llama_pooling_type pooling_type;
@@ -8924,6 +8925,7 @@ struct llm_build_context {
     const bool flash_attn;
     const int mla_attn;
+    const int attn_max_batch;
     const bool fused_moe_up_gate;
     const enum llama_pooling_type pooling_type;
@@ -8976,6 +8978,7 @@ struct llm_build_context {
     n_ctx_orig (cparams.n_ctx_orig_yarn),
     flash_attn (cparams.flash_attn),
     mla_attn (cparams.mla_attn),
+    attn_max_batch (cparams.attn_max_batch),
     fused_moe_up_gate(cparams.fused_moe_up_gate),
     pooling_type (cparams.pooling_type),
     rope_type (hparams.rope_type),
@@ -13572,25 +13575,6 @@ struct llm_build_context {
         ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
         cb(q, "q", il);
-        if (!pp_opt) {
-            q = ggml_permute(ctx0, q, 0, 2, 1, 3);
-            cb(q, "q_perm", il);
-        }
-        ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
-        cb(kq, "kq", il);
-        if (!pp_opt) {
-            kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
-            cb(kq, "kq_perm", il);
-        }
-        kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
-        cb(kq, "kq_soft_max_ext", il);
-        if (!pp_opt) {
-            kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
-            cb(kq, "kq_soft_max_ext_perm", il);
-        }
         if (lctx.cparams.mla_attn > 1) {
             ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
@@ -13602,12 +13586,87 @@ struct llm_build_context {
             cb(kv_cache_trans, "kv_cache_trans", il);
         }
-        struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
-        cb(kqv_compressed, "kqv_compressed", il);
-        if (!pp_opt) {
-            kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
-            cb(kqv_compressed, "kqv_compressed_perm", il);
-        }
+        ggml_tensor * kqv_compressed;
+        if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kv_cache->ne[1]) {
+            if (!pp_opt) {
+                q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+                cb(q, "q_perm", il);
+            }
+            ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
+            cb(kq, "kq", il);
+            //printf("kq (%ld x %ld x %ld x %ld) = kv_cache (%ld x %ld x %ld x %ld) * q (%ld x %ld x %ld x %ld)\n", kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3],
+            //        kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2], kv_cache->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
+            if (!pp_opt) {
+                kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+                cb(kq, "kq_perm", il);
+            }
+            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+            cb(kq, "kq_soft_max_ext", il);
+            if (!pp_opt) {
+                kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+                cb(kq, "kq_soft_max_ext_perm", il);
+            }
+            kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
+            cb(kqv_compressed, "kqv_compressed", il);
+            //printf("kqv (%ld x %ld x %ld x %ld) = kv_cache_trans (%ld x %ld x %ld x %ld) * kq (%ld x %ld x %ld x %ld)\n",
+            //        kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2], kqv_compressed->ne[3],
+            //        kv_cache_trans->ne[0], kv_cache_trans->ne[1], kv_cache_trans->ne[2], kv_cache_trans->ne[3], kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3]);
+            if (!pp_opt) {
+                kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+                cb(kqv_compressed, "kqv_compressed_perm", il);
+            }
+        } else {
+            int n_step = (q->ne[1] + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
+            //kqv_compressed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_cache_trans->ne[1], q->ne[1], q->ne[2]);
+            //printf("q->ne[1] = %ld -> need %d steps\n", q->ne[1], n_step);
+            //printf("Created kqv_compressed = %ld x %ld x %ld\n", kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2]);
+            for (int i_head = 0; i_head < q->ne[2]; ++i_head) {
+                ggml_tensor * q_i = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i_head);
+                ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
+                kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+                ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
+                if (i_head == 0) {
+                    kqv_compressed = ggml_view_3d(ctx0, kqv_i, kqv_i->ne[0], kqv_i->ne[1], 1, kqv_i->nb[1], kqv_i->nb[2], 0);
+                } else {
+                    kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
+                }
+                ggml_build_forward_expand(gf, kqv_compressed);
+                //ggml_tensor * kqv_compressed_i = ggml_view_1d(ctx0, kqv_compressed, ggml_nelements(kqv_i), kqv_compressed->nb[2]*i_head);
+                //ggml_tensor * head_i = ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
+                //ggml_build_forward_expand(gf, head_i);
+            }
+            //for (int i_step = 0; i_step < n_step; ++i_step) {
+            //    int i_start = i_step * lctx.cparams.attn_max_batch;
+            //    int this_batch = i_start + lctx.cparams.attn_max_batch <= q->ne[1] ? lctx.cparams.attn_max_batch : q->ne[1] - i_start;
+            //    ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], this_batch, q->ne[2], q->nb[1], q->nb[2], i_start*q->nb[1]);
+            //    cb(q_i, "q_i", il);
+            //    ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
+            //    cb(kq_i, "kq_i", il);
+            //    ggml_tensor * mask_i = ggml_view_2d(ctx0, KQ_mask, KQ_mask->ne[0], this_batch, KQ_mask->nb[1], i_start*KQ_mask->nb[1]);
+            //    kq_i = ggml_soft_max_ext(ctx0, kq_i, mask_i, kq_scale, hparams.f_max_alibi_bias);
+            //    cb(kq_i, "kq_i_softmax", il);
+            //    ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
+            //    cb(kqv_i, "kqv_i", il);
+            //    ggml_tensor * kqv_compressed_i = ggml_view_3d(ctx0, kqv_compressed, kqv_compressed->ne[0], this_batch, kqv_compressed->ne[2],
+            //            kqv_compressed->nb[1], kqv_compressed->nb[2], i_start*kqv_compressed->nb[1]);
+            //    printf("step %d (%d tokens): kqv_i = %ld x %ld x %ld, kqv_compressed_i = %ld x %ld x %ld\n", i_step, this_batch,
+            //            kqv_i->ne[0], kqv_i->ne[1], kqv_i->ne[2], kqv_compressed_i->ne[0], kqv_compressed_i->ne[1], kqv_compressed_i->ne[2]);
+            //    ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
+            //}
+            cb(kqv_compressed, "kqv_compressed", il);
+        }
         struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
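Two things are worth noting about the batched path above. First, it splits the work per attention head rather than per group of attn_max_batch tokens; n_step is only used by the commented-out per-token-batch variant, which is apparently kept for reference. Second, calling ggml_build_forward_expand inside the loop pins the evaluation order, which presumably lets the graph allocator reuse one head's kq_i/kqv_i buffers for the next head; that reuse is where the compute-buffer saving comes from.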
@@ -17644,6 +17703,7 @@ struct llama_context_params llama_context_default_params() {
         /*.offload_kqv =*/ true,
         /*.flash_attn =*/ false,
         /*.mla_attn =*/ 0,
+        /*.attn_max_batch =*/ 0,
         /*.fused_moe_up_gate =*/ false,
         /*.abort_callback =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
@@ -17844,6 +17904,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.offload_kqv = params.offload_kqv;
     cparams.flash_attn = params.flash_attn;
     cparams.mla_attn = params.mla_attn;
+    cparams.attn_max_batch = params.attn_max_batch;
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;
     cparams.pooling_type = params.pooling_type;
@@ -17912,6 +17973,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);