Fused FFN_UP+FFN_GATE op (#741)

* Fused up+gate+unary for regular (not MoE) FFN - CPU (see the graph-level sketch after this list)

* WIP CUDA

* Seems to be working on CUDA

For a dense model we get a 2-3% speedup for PP (prompt processing) and ~0.6% for TG (token generation).

* Add command line option

This time the option is ON by default, and one needs to turn it
off via -no-fug or --no-fused-up-gate
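
For reference, a graph-level sketch of what the fusion replaces (not part of the commit): the unfused parallel FFN builds two matmuls plus a unary op, while the fused path emits a single ggml_fused_up_gate node. The helper names here are placeholders; the argument order follows the ggml_fused_up_gate call in llm_build_ffn below, and SILU stands in for whichever unary op is selected.

    #include "ggml.h"

    // Unfused parallel FFN path: three ops and two intermediate tensors.
    static ggml_tensor * ffn_up_gate_unfused(ggml_context * ctx,
            ggml_tensor * up, ggml_tensor * gate, ggml_tensor * inp) {
        ggml_tensor * u = ggml_mul_mat(ctx, up,   inp);  // up projection
        ggml_tensor * g = ggml_mul_mat(ctx, gate, inp);  // gate projection
        return ggml_mul(ctx, ggml_silu(ctx, g), u);      // silu(gate) * up
    }

    // Fused path: one graph node does both projections and the unary op.
    static ggml_tensor * ffn_up_gate_fused(ggml_context * ctx,
            ggml_tensor * up, ggml_tensor * gate, ggml_tensor * inp) {
        return ggml_fused_up_gate(ctx, up, gate, inp, GGML_UNARY_OP_SILU);
    }

Collapsing three nodes into one presumably saves a pass over the input activations plus some per-op overhead, which is consistent with the PP gain being larger than the TG gain.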

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-08-31 18:16:36 +03:00
Committed by: GitHub
Parent: f22a9ef95a
Commit: b66cecca45
10 changed files with 276 additions and 12 deletions

@@ -2072,6 +2072,7 @@ struct llama_cparams {
     int mla_attn;
     int attn_max_batch;
     bool fused_moe_up_gate;
+    bool fused_up_gate;
     int min_experts;
     float thresh_experts;
@@ -7612,6 +7613,34 @@ static struct ggml_tensor * llm_build_ffn(
         llm_ffn_gate_type type_gate,
         const llm_build_cb & cb,
         int il) {
+    if (lctx.cparams.fused_up_gate &&
+        up && gate && !up_b && !up_s && !gate_b && !gate_s && type_gate == LLM_FFN_PAR &&
+        (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
+        auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
+                        type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
+        cur = ggml_fused_up_gate(ctx, up, gate, cur, unary_op);
+        cb(cur, "ffn_up_gate", il);
+        if (down) {
+            cur = llm_build_lora_mm(lctx, ctx, down, cur);
+            if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
+                // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+                ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+            }
+        }
+        if (down_b) {
+            cb(cur, "ffn_down", il);
+        }
+        if (down_b) {
+            cur = ggml_add(ctx, cur, down_b);
+        }
+        if (down_s) {
+            cur = ggml_mul(ctx, cur, down_s);
+            cb(cur, "ffn_down_s", il);
+        }
+        return cur;
+    }
     struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
     cb(tmp, "ffn_up", il);
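
The fused path above is taken only for the plain parallel case: both up and gate present, no biases or scales on either projection, and a SILU, RELU, or GELU activation without act_scales. As a scalar sketch of what the op is expected to compute per output row (an assumption based on the unfused path it replaces, not the committed CPU/CUDA kernels):

    #include <cmath>

    // out[i] = unary(gate_i . x) * (up_i . x); SILU is shown as the unary op.
    // up/gate are row-major [n_ff, n_embd] weights, x holds one token's activations.
    static void fused_up_gate_reference(const float * up, const float * gate,
            const float * x, float * out, int n_ff, int n_embd) {
        for (int i = 0; i < n_ff; ++i) {
            float u = 0.0f, g = 0.0f;
            for (int j = 0; j < n_embd; ++j) {
                u += up  [i*n_embd + j] * x[j];   // up projection, row i
                g += gate[i*n_embd + j] * x[j];   // gate projection, row i
            }
            out[i] = (g / (1.0f + std::exp(-g))) * u;   // silu(g) * u
        }
    }
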
@@ -8223,6 +8252,7 @@ struct llm_build_context {
     const int mla_attn;
     const int attn_max_batch;
     const bool fused_moe_up_gate;
+    const bool fused_up_gate;
     const int min_experts;
     const float thresh_experts;
@@ -8278,6 +8308,7 @@ struct llm_build_context {
         mla_attn         (cparams.mla_attn),
         attn_max_batch   (cparams.attn_max_batch),
         fused_moe_up_gate(cparams.fused_moe_up_gate),
+        fused_up_gate    (cparams.fused_up_gate),
         min_experts      (cparams.min_experts),
         thresh_experts   (cparams.thresh_experts),
         pooling_type     (cparams.pooling_type),
@@ -18923,6 +18954,7 @@ struct llama_context_params llama_context_default_params() {
         /*.mla_attn          =*/ 0,
         /*.attn_max_batch    =*/ 0,
         /*.fused_moe_up_gate =*/ false,
+        /*.fused_up_gate     =*/ true,
         /*.min_experts       =*/ -1,
         /*.thresh_experts    =*/ 0.0f,
         /*.abort_callback    =*/ nullptr,
@@ -19130,6 +19162,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.mla_attn         = params.mla_attn;
     cparams.attn_max_batch   = params.attn_max_batch;
     cparams.fused_moe_up_gate= params.fused_moe_up_gate;
+    cparams.fused_up_gate    = params.fused_up_gate;
     cparams.min_experts      = params.min_experts;
     cparams.thresh_experts   = params.thresh_experts;
@@ -19209,6 +19242,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
     LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
     LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
+    LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
     LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
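
Since the default in llama_context_default_params() is true, API users who want the old behaviour have to opt out explicitly. A minimal sketch mirroring the -no-fug / --no-fused-up-gate command line flags, assuming a model already loaded elsewhere:

    #include "llama.h"

    // Build a context with the up/gate fusion disabled (it is on by default).
    static llama_context * context_without_fused_up_gate(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.fused_up_gate = false;   // fall back to the unfused up/gate path
        return llama_new_context_with_model(model, cparams);
    }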