Add command line option to enable grouped expert routing

This commit is contained in:
Iwan Kawrakow
2025-10-16 10:54:14 +03:00
parent c30c35b007
commit 3683e50660
7 changed files with 16 additions and 1 deletions

View File

@@ -1012,6 +1012,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.fused_moe_up_gate = true;
return true;
}
if (arg == "-ger" || arg == "--grouped-expert-routing") {
params.grouped_expert_routing = true;
return true;
}
if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
params.fused_up_gate = false;
return true;
@@ -1800,6 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
@@ -2755,6 +2760,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
cparams.grouped_expert_routing = params.grouped_expert_routing;
cparams.fused_up_gate = params.fused_up_gate;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
@@ -3871,6 +3877,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

View File

@@ -235,6 +235,7 @@ struct gpt_params {
int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
bool fused_up_gate = true; // fused up*unary(gate) op
bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch)
int min_experts = -1;
float thresh_experts = 0;

View File

@@ -420,6 +420,7 @@ extern "C" {
int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
bool fused_moe_up_gate; // whether to use fused MoE up/gate op
bool grouped_expert_routing; // whether to use grouped expert routing (BailingMoeV2 arch)
bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
int min_experts;
float thresh_experts;

View File

@@ -48,6 +48,7 @@ llm_build_context::llm_build_context(
mla_attn (cparams.mla_attn),
attn_max_batch (cparams.attn_max_batch),
fused_moe_up_gate(cparams.fused_moe_up_gate),
grouped_expert_routing(cparams.grouped_expert_routing),
fused_up_gate (cparams.fused_up_gate),
min_experts (cparams.min_experts),
thresh_experts (cparams.thresh_experts),
@@ -822,7 +823,7 @@ llm_expert_gating_func_type gating_op,
// select experts
ggml_tensor * selected_experts;
if (true && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
if (lctx.cparams.grouped_expert_routing && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
auto& hparams = lctx.model.hparams;
selected_experts = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2, n_expert_used);
} else {

View File

@@ -78,6 +78,7 @@ struct llm_build_context {
const int mla_attn;
const int attn_max_batch;
const bool fused_moe_up_gate;
const bool grouped_expert_routing;
const bool fused_up_gate;
const int min_experts;
const float thresh_experts;

View File

@@ -31,6 +31,7 @@ struct llama_cparams {
int mla_attn;
int attn_max_batch;
bool fused_moe_up_gate;
bool grouped_expert_routing;
bool fused_up_gate;
int min_experts;
float thresh_experts;

View File

@@ -3754,6 +3754,7 @@ struct llama_context_params llama_context_default_params() {
/*.mla_attn =*/ 0,
/*.attn_max_batch =*/ 0,
/*.fused_moe_up_gate =*/ false,
/*.grouped_expert_routing =*/ false,
/*.fused_up_gate =*/ true,
/*.min_experts =*/ -1,
/*.thtesh_experts =*/ 0.0f,
@@ -3963,6 +3964,7 @@ struct llama_context * llama_new_context_with_model(
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
cparams.grouped_expert_routing = params.grouped_expert_routing;
cparams.fused_up_gate = params.fused_up_gate;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
@@ -4043,6 +4045,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing);
LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);