From 109686af6f7ca4441adb556569dafcf9fa235478 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 25 Jan 2026 14:38:54 +0000 Subject: [PATCH] Faster hybrid inference when shared experts --- src/llama-build-context.cpp | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index a1586849..7594fbff 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1207,6 +1207,19 @@ llm_expert_gating_func_type gating_op, GGML_ASSERT(!split_up_b_shexp || split_up_b_shexp->n_device == split_up_shexp->n_device); GGML_ASSERT(!split_gate_b_shexp || split_gate_b_shexp->n_device == split_up_shexp->n_device); GGML_ASSERT(!split_down_b_shexp || split_down_b_shexp->n_device == split_up_shexp->n_device); + bool down_bias_added = false; + int id_add_routed = -1; + if (split_up_shexp->splits[lctx.model.main_gpu]) { + id_add_routed = lctx.model.main_gpu; + } else { + for (int id = 0; id < split_up_shexp->n_device; ++id) { + if (split_up_shexp->splits[id]) { + id_add_routed = id; + break; + } + } + } + GGML_ASSERT(id_add_routed >= 0); for (int id = 0; id < split_up_shexp->n_device; ++id) { int il_cb = 1000*id + il; GGML_ASSERT((split_up_shexp->splits[id] && split_gate_shexp->splits[id] && split_down_shexp->splits[id]) || @@ -1216,32 +1229,18 @@ llm_expert_gating_func_type gating_op, auto shared_out = llm_build_ffn(ctx, lctx, the_ffn_norm, input, split_up_shexp->splits[id], split_up_b_shexp ? split_up_b_shexp->splits[id] : nullptr, nullptr, split_gate_shexp->splits[id], split_gate_b_shexp ? split_gate_b_shexp->splits[id] : nullptr, nullptr, - split_down_shexp->splits[id], split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr, - nullptr, type_op_shexp, LLM_FFN_PAR, cb, il); + split_down_shexp->splits[id], !down_bias_added && split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr, + nullptr, type_op_shexp, LLM_FFN_PAR, cb, il, graph, false, false, + id == id_add_routed ? routed_out : nullptr); cb(shared_out, "ffn_shexp_out", il_cb); if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) { shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type); } + down_bias_added = true; results.push_back(shared_out); } GGML_ASSERT(!results.empty()); - if (results.size() == 1) { - cur = results.front(); - } else { - cur = ggml_add(ctx, results[0], results[1]); - cur->op_params[0] = 0xff; - cb(cur, "ffn_shared_combined", il); - for (int id = 2; id < int(results.size()); ++id) { - cur = ggml_add(ctx, cur, results[id]); - cb(cur, "ffn_shared_combined", il); - } - } - if (routed_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) { - auto routed_out_f16 = ggml_cast(ctx, routed_out, lctx.cparams.reduce_type); - cur = ggml_add(ctx, routed_out_f16, cur); - } else { - cur = ggml_add(ctx, routed_out, cur); - } + cur = ggml_reduce(ctx, results.data(), split_up_shexp->n_device, GGML_OP_ADD); cb(cur, "ffn_out", il); } else { auto shared_out = llm_build_ffn(ctx, lctx, nullptr, cur,