From addd8994cdc206ebaac774b1ee421185b5a3d670 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 28 Feb 2025 14:26:47 +0200 Subject: [PATCH] This reduces compute buffer size for MLA --- common/common.cpp | 8 ++++ common/common.h | 3 +- include/llama.h | 1 + src/llama.cpp | 110 ++++++++++++++++++++++++++++++++++++---------- 4 files changed, 97 insertions(+), 25 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6359426f..5c9070da 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -855,6 +855,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.mla_attn = std::stoi(argv[i]); return true; } + if (arg == "-amb" || arg == "--attention-max-batch") { + CHECK_ARG + params.attn_max_batch = std::stoi(argv[i]); + return true; + } if (arg == "-fmoe" || arg == "--fused-moe") { params.fused_moe_up_gate = true; return true; @@ -1516,6 +1521,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); + options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch}); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -2360,6 +2366,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate = params.fused_moe_up_gate; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); @@ -3359,6 +3366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); + fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index ef5175f3..f35f3558 100644 --- a/common/common.h +++ b/common/common.h @@ -175,7 +175,8 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention - int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache + int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache + int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false) bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix diff --git a/include/llama.h b/include/llama.h index 2b33701c..bb43aebc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -384,6 +384,7 @@ extern "C" { bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] int mla_attn; // whether to use MLA attention [EXPERIMENTAL] + int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] // Abort callback diff --git a/src/llama.cpp b/src/llama.cpp index f2c5f9d4..1f6026d6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2511,6 +2511,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; int mla_attn; + int attn_max_batch; bool fused_moe_up_gate; enum llama_pooling_type pooling_type; @@ -8924,6 +8925,7 @@ struct llm_build_context { const bool flash_attn; const int mla_attn; + const int attn_max_batch; const bool fused_moe_up_gate; const enum llama_pooling_type pooling_type; @@ -8976,6 +8978,7 @@ struct llm_build_context { n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), mla_attn (cparams.mla_attn), + attn_max_batch (cparams.attn_max_batch), fused_moe_up_gate(cparams.fused_moe_up_gate), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -13572,25 +13575,6 @@ struct llm_build_context { ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - if (!pp_opt) { - q = ggml_permute(ctx0, q, 0, 2, 1, 3); - cb(q, "q_perm", il); - } - ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); - cb(kq, "kq", il); - - if (!pp_opt) { - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); - } - - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - - if (!pp_opt) { - kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); - cb(kq, "kq_soft_max_ext_perm", il); - } if (lctx.cparams.mla_attn > 1) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], @@ -13602,12 +13586,87 @@ struct llm_build_context { cb(kv_cache_trans, "kv_cache_trans", il); } - struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); - cb(kqv_compressed, "kqv_compressed", il); + ggml_tensor * kqv_compressed; + if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kv_cache->ne[1]) { + if (!pp_opt) { + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_perm", il); + } - if (!pp_opt) { - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - cb(kqv_compressed, "kqv_compressed_perm", il); + ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + cb(kq, "kq", il); + + //printf("kq (%ld x %ld x %ld x %ld) = kv_cache (%ld x %ld x %ld x %ld) * q (%ld x %ld x %ld x %ld)\n", kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3], + // kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2], kv_cache->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } + + kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); + + //printf("kqv (%ld x %ld x %ld x %ld) = kv_cache_trans (%ld x %ld x %ld x %ld) * kq (%ld x %ld x %ld x %ld)\n", + // kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2], kqv_compressed->ne[3], + // kv_cache_trans->ne[0], kv_cache_trans->ne[1], kv_cache_trans->ne[2], kv_cache_trans->ne[3], kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3]); + + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + + } else { + + int n_step = (q->ne[1] + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; + + //kqv_compressed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_cache_trans->ne[1], q->ne[1], q->ne[2]); + //printf("q->ne[1] = %ld -> need %d steps\n", q->ne[1], n_step); + //printf("Created kqv_compressed = %ld x %ld x %ld\n", kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2]); + + for (int i_head = 0; i_head < q->ne[2]; ++i_head) { + ggml_tensor * q_i = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i_head); + ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); + kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); + if (i_head == 0) { + kqv_compressed = ggml_view_3d(ctx0, kqv_i, kqv_i->ne[0], kqv_i->ne[1], 1, kqv_i->nb[1], kqv_i->nb[2], 0); + } else { + kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + } + ggml_build_forward_expand(gf, kqv_compressed); + //ggml_tensor * kqv_compressed_i = ggml_view_1d(ctx0, kqv_compressed, ggml_nelements(kqv_i), kqv_compressed->nb[2]*i_head); + //ggml_tensor * head_i = ggml_cpy(ctx0, kqv_i, kqv_compressed_i); + //ggml_build_forward_expand(gf, head_i); + } + + //for (int i_step = 0; i_step < n_step; ++i_step) { + // int i_start = i_step * lctx.cparams.attn_max_batch; + // int this_batch = i_start + lctx.cparams.attn_max_batch <= q->ne[1] ? lctx.cparams.attn_max_batch : q->ne[1] - i_start; + // ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], this_batch, q->ne[2], q->nb[1], q->nb[2], i_start*q->nb[1]); + // cb(q_i, "q_i", il); + // ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); + // cb(kq_i, "kq_i", il); + // ggml_tensor * mask_i = ggml_view_2d(ctx0, KQ_mask, KQ_mask->ne[0], this_batch, KQ_mask->nb[1], i_start*KQ_mask->nb[1]); + // kq_i = ggml_soft_max_ext(ctx0, kq_i, mask_i, kq_scale, hparams.f_max_alibi_bias); + // cb(kq_i, "kq_i_softmwax", il); + // ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); + // cb(kqv_i, "kqv_i", il); + // ggml_tensor * kqv_compressed_i = ggml_view_3d(ctx0, kqv_compressed, kqv_compressed->ne[0], this_batch, kqv_compressed->ne[2], + // kqv_compressed->nb[1], kqv_compressed->nb[2], i_start*kqv_compressed->nb[1]); + // printf("step %d (%d tokens): kqv_i = %ld x %ld x %ld, kqv_compressed_i = %ld x %ld x %ld\n", i_step, this_batch, + // kqv_i->ne[0], kqv_i->ne[1], kqv_i->ne[2], kqv_compressed_i->ne[0], kqv_compressed_i->ne[1], kqv_compressed_i->ne[2]); + // ggml_cpy(ctx0, kqv_i, kqv_compressed_i); + //} + cb(kqv_compressed, "kqv_compressed", il); } struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, @@ -17644,6 +17703,7 @@ struct llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.mla_attn =*/ 0, + /*.attn_max_batch =*/ 0, /*.fused_moe_up_gate =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -17844,6 +17904,7 @@ struct llama_context * llama_new_context_with_model( cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate= params.fused_moe_up_gate; cparams.pooling_type = params.pooling_type; @@ -17912,6 +17973,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); + LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);