Various fused ops around expert selection (#840)

* Fuse sigmoid+add+grouped_topk+get_rows (CPU)

* Fix CPU + CUDA

but CUDA is somehow not 100% correct, as I get a slightly different
(lower!) perplexity (PPL)

* Minor

* Fuse sigmoid+add+topk+get_rows (CUDA)

* Fuse sigmoid+add+topk+get_rows (CPU)

* Fuse topk+view+get_rows+reshape+softmax (CPU)

* Fuse topk+view+get_rows+reshape+softmax (CUDA)

* cpu: turn off the openai topk fusing for now

Something is not right and I don't see the bug.
On the CPU one doesn't gain much if anything, so not a big loss.

* Also fuse sum_rows and div
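
For orientation, here is a minimal scalar sketch of what the op chains named above compute for a single token before fusion. It is not the fused kernel from this commit. The names `ExpertChoice`, `top_k_indices`, `route_sigmoid_bias` and `route_topk_softmax` are hypothetical, and reading `add` as a per-expert selection bias added to the sigmoid output is an assumption. The sketch also assumes 1 <= k <= number of experts.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Routing result for one token: chosen expert ids and their mixing weights.
// (Hypothetical type, for illustration only.)
struct ExpertChoice {
    std::vector<int>   ids;
    std::vector<float> weights;
};

// Indices of the k largest entries of v, ordered by descending value.
std::vector<int> top_k_indices(const std::vector<float> & v, int k) {
    std::vector<int> idx(v.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&v](int a, int b) { return v[a] > v[b]; });
    idx.resize(k);
    return idx;
}

// sigmoid + add + topk + get_rows, optionally followed by sum_rows + div.
// Assumption: experts are selected on the biased scores, while the weights
// are gathered from the unbiased sigmoid probabilities.
ExpertChoice route_sigmoid_bias(const std::vector<float> & logits,
                                const std::vector<float> & bias,
                                int k, bool normalize) {
    const std::size_t n = logits.size();
    std::vector<float> probs(n), biased(n);
    for (std::size_t i = 0; i < n; ++i) {
        probs[i]  = 1.0f / (1.0f + std::exp(-logits[i])); // sigmoid
        biased[i] = probs[i] + bias[i];                    // add
    }
    ExpertChoice out;
    out.ids = top_k_indices(biased, k);                    // topk
    float sum = 0.0f;
    for (int id : out.ids) {
        out.weights.push_back(probs[id]);                  // get_rows
        sum += probs[id];                                  // sum_rows
    }
    if (normalize) {
        for (float & w : out.weights) w /= sum;            // div
    }
    return out;
}

// topk + get_rows + softmax: softmax taken over the selected logits only.
ExpertChoice route_topk_softmax(const std::vector<float> & logits, int k) {
    ExpertChoice out;
    out.ids = top_k_indices(logits, k);
    const float max_logit = logits[out.ids[0]]; // ids are sorted descending
    float sum = 0.0f;
    for (int id : out.ids) {
        const float e = std::exp(logits[id] - max_logit);
        out.weights.push_back(e);
        sum += e;
    }
    for (float & w : out.weights) w /= sum;
    return out;
}
```

A fused implementation can produce the same ids and weights in one pass per token without materializing the intermediate tensors, which is presumably where the gain comes from.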

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Kawrakow
2025-10-19 19:02:46 +03:00
committed by GitHub
parent 1dcc044134
commit 7a41b3b1f5
9 changed files with 715 additions and 66 deletions


@@ -827,18 +827,15 @@ llm_expert_gating_func_type gating_op,
auto& hparams = lctx.model.hparams;
selected_experts = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2, n_expert_used);
} else {
-selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
-lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
+//selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
+// lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
+selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
}
cb(selected_experts, "ffn_moe_topk", il);
ggml_tensor * weights = ggml_get_rows(ctx,
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);
-if (graph) {
-ggml_build_forward_expand(graph, weights);
-}
if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
@@ -846,6 +843,10 @@ llm_expert_gating_func_type gating_op,
cb(weights, "ffn_moe_weights_softmax", il);
}
+if (graph) {
+ggml_build_forward_expand(graph, weights);
+}
if (norm_w) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);