From 3683e50660461bf8000962a072a53b47d4f8f53f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 16 Oct 2025 10:54:14 +0300 Subject: [PATCH] Add command line option to enable grouped expert routing --- common/common.cpp | 7 +++++++ common/common.h | 1 + include/llama.h | 1 + src/llama-build-context.cpp | 3 ++- src/llama-build-context.h | 1 + src/llama-cparams.h | 1 + src/llama.cpp | 3 +++ 7 files changed, 16 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 287ec58f..e7ade95d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1012,6 +1012,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.fused_moe_up_gate = true; return true; } + if (arg == "-ger" || arg == "--grouped-expert-routing") { + params.grouped_expert_routing = true; + return true; + } if (arg == "-no-fug" || arg == "--no-fused-up-gate") { params.fused_up_gate = false; return true; @@ -1800,6 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch}); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); + options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" }); options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" @@ -2755,6 +2760,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate = params.fused_moe_up_gate; + cparams.grouped_expert_routing = params.grouped_expert_routing; cparams.fused_up_gate = params.fused_up_gate; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; @@ -3871,6 +3877,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); + fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false"); fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false"); fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index 2b4d1540..ddd50755 100644 --- a/common/common.h +++ b/common/common.h @@ -235,6 +235,7 @@ struct gpt_params { int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false) bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool fused_up_gate = true; // fused up*unary(gate) op + bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch) int min_experts = -1; float thresh_experts = 0; diff --git a/include/llama.h b/include/llama.h index 4f9fc9c8..d24de230 100644 --- a/include/llama.h +++ b/include/llama.h @@ -420,6 +420,7 @@ extern "C" { int mla_attn; // whether to use MLA attention [EXPERIMENTAL] int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/gate op + bool grouped_expert_routing; // whether to use grouped expert routing (BailingMoeV2 arch) bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL] int min_experts; float thresh_experts; diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index cfe99e6b..b3208dcd 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -48,6 +48,7 @@ llm_build_context::llm_build_context( mla_attn (cparams.mla_attn), attn_max_batch (cparams.attn_max_batch), fused_moe_up_gate(cparams.fused_moe_up_gate), + grouped_expert_routing(cparams.grouped_expert_routing), fused_up_gate (cparams.fused_up_gate), min_experts (cparams.min_experts), thresh_experts (cparams.thresh_experts), @@ -822,7 +823,7 @@ llm_expert_gating_func_type gating_op, // select experts ggml_tensor * selected_experts; - if (true && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { + if (lctx.cparams.grouped_expert_routing && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { auto& hparams = lctx.model.hparams; selected_experts = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2, n_expert_used); } else { diff --git a/src/llama-build-context.h b/src/llama-build-context.h index a1f0b8ae..2381a656 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -78,6 +78,7 @@ struct llm_build_context { const int mla_attn; const int attn_max_batch; const bool fused_moe_up_gate; + const bool grouped_expert_routing; const bool fused_up_gate; const int min_experts; const float thresh_experts; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index e8ec0f74..cbfb4949 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -31,6 +31,7 @@ struct llama_cparams { int mla_attn; int attn_max_batch; bool fused_moe_up_gate; + bool grouped_expert_routing; bool fused_up_gate; int min_experts; float thresh_experts; diff --git a/src/llama.cpp b/src/llama.cpp index 5dc11e47..928d66b0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3754,6 +3754,7 @@ struct llama_context_params llama_context_default_params() { /*.mla_attn =*/ 0, /*.attn_max_batch =*/ 0, /*.fused_moe_up_gate =*/ false, + /*.grouped_expert_routing =*/ false, /*.fused_up_gate =*/ true, /*.min_experts =*/ -1, /*.thtesh_experts =*/ 0.0f, @@ -3963,6 +3964,7 @@ struct llama_context * llama_new_context_with_model( cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate= params.fused_moe_up_gate; + cparams.grouped_expert_routing = params.grouped_expert_routing; cparams.fused_up_gate = params.fused_up_gate; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; @@ -4043,6 +4045,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); + LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing); LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate); LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);