Fused mul + multi_add op (#858)

* Add fused mul+multi_add op and CPU implementation

* fused mul+multi_add: CUDA

* fused mul+multi_add: command-line argument to disable it

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Kawrakow
2025-10-24 07:40:35 +03:00
committed by GitHub
parent 483cea527d
commit db3ba4999f
15 changed files with 211 additions and 38 deletions
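Before the diffs, a note on semantics: in the MoE FFN, the routed experts' outputs are scaled by their routing weights and then summed into a single activation. Previously this took a ggml_mul followed by a separate multi-add reduction; the new fused op does both in one kernel, with CPU and CUDA implementations. A minimal scalar sketch of the assumed semantics follows — the names, layout, and signature are illustrative, not the actual ggml kernel:

    static void mul_multi_add_ref(
            const float * experts,   // [n_used][n_embd]: outputs of the activated experts
            const float * weights,   // [n_used]: routing weights
            float       * dst,       // [n_embd]: combined FFN activation
            int n_used, int n_embd) {
        // Single pass: multiply and accumulate, never materializing
        // the intermediate experts*weights tensor.
        for (int i = 0; i < n_embd; ++i) {
            float sum = 0.0f;
            for (int e = 0; e < n_used; ++e) {
                sum += experts[(size_t) e * n_embd + i] * weights[e];
            }
            dst[i] = sum;
        }
    }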

View File

@@ -50,6 +50,7 @@ llm_build_context::llm_build_context(
     fused_moe_up_gate(cparams.fused_moe_up_gate),
     grouped_expert_routing(cparams.grouped_expert_routing),
     fused_up_gate (cparams.fused_up_gate),
+    fused_mmad (cparams.fused_mmad),
     min_experts (cparams.min_experts),
     thresh_experts (cparams.thresh_experts),
     pooling_type (cparams.pooling_type),
@@ -941,6 +942,11 @@ llm_expert_gating_func_type gating_op,
     }
     if (!weight_before_ffn) {
+        if (lctx.cparams.fused_mmad) {
+            experts = ggml_mul_multi_add(ctx, experts, weights);
+            cb(experts, "ffn_moe_weighted", il);
+            return experts;
+        }
         experts = ggml_mul(ctx, experts, weights);
         cb(experts, "ffn_moe_weighted", il);
     }
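Note the early return in the fused branch: ggml_mul_multi_add presumably folds the per-expert reduction, which the unfused path adds as a later graph node, into the multiply itself. Side by side (the unfused reduction step is inferred from the op's name; it lies outside this hunk):

    // Unfused: materializes experts*weights; the reduction over the
    // expert dimension is added to the graph further down.
    experts = ggml_mul(ctx, experts, weights);

    // Fused (cparams.fused_mmad): one node that multiplies and reduces,
    // so the builder can return the result immediately.
    experts = ggml_mul_multi_add(ctx, experts, weights);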

View File

@@ -80,6 +80,7 @@ struct llm_build_context {
     const bool fused_moe_up_gate;
     const bool grouped_expert_routing;
     const bool fused_up_gate;
+    const bool fused_mmad;
     const int min_experts;
     const float thresh_experts;

View File

@@ -33,6 +33,7 @@ struct llama_cparams {
     bool fused_moe_up_gate;
     bool grouped_expert_routing;
     bool fused_up_gate;
+    bool fused_mmad;
     int min_experts;
     float thresh_experts;

View File

@@ -3756,6 +3756,7 @@ struct llama_context_params llama_context_default_params() {
     /*.fused_moe_up_gate =*/ false,
     /*.grouped_expert_routing =*/ false,
     /*.fused_up_gate =*/ true,
+    /*.fused_mmad =*/ true,
     /*.min_experts =*/ -1,
     /*.thresh_experts =*/ 0.0f,
     /*.only_active_experts =*/ false,
@@ -3966,6 +3967,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;
     cparams.grouped_expert_routing = params.grouped_expert_routing;
     cparams.fused_up_gate = params.fused_up_gate;
+    cparams.fused_mmad = params.fused_mmad;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
@@ -4047,6 +4049,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
     LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing);
     LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
+    LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad);
     LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);