From 8f5f93e6b11d039e24b7253b1883b13bfc75f41f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 18 Oct 2025 10:09:32 +0300 Subject: [PATCH] Fuse sigmoid+add+grouped_topk+get_rows (CPU) --- ggml/src/ggml.c | 12 +++- ggml/src/iqk/iqk_cpu_ops.cpp | 115 +++++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_cpu_ops.h | 2 + src/llama-build-context.cpp | 5 +- 4 files changed, 131 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 63c0b995..5ca73ae5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -22611,7 +22611,17 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_UNARY: { - ggml_compute_forward_unary(params, tensor); + const enum ggml_unary_op unary_op = ggml_get_unary_op(tensor); + if (unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+2]->op == GGML_OP_ADD && + cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && + cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) { + iqk_bailingmoev2_experts(cgraph->nodes[i+4], cgraph->nodes[i+3], params->ith, params->nth); + i += 4; + } else { + ggml_compute_forward_unary(params, tensor); + } } break; case GGML_OP_GLU: { diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 115f25ce..5d0adcba 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -5,6 +5,7 @@ // #include "iqk_cpu_ops.h" +#include "iqk_utils.h" #include "ggml.h" #include @@ -39,6 +40,49 @@ inline std::vector> & get_work_buffer(size_t size) { return buffer; } +#ifdef __ARM_NEON +inline float32x4_t v_biased_sigmoid(float32x4_t x, float32x4_t b) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + const float32x4_t neg_x = vsubq_f32(zero, x); + const float32x4_t exp_neg_x = v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vaddq_f32(b, vdivq_f32(one, one_plus_exp_neg_x)); +} +#endif +#ifdef __AVX2__ +inline __m256 v_biased_sigmoid(__m256 x, __m256 b) { + const __m256 one = _mm256_set1_ps(1); + const __m256 zero = _mm256_setzero_ps(); + const __m256 neg_x = _mm256_sub_ps(zero, x); + const __m256 exp_neg_x = v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_add_ps(b, _mm256_div_ps(one, one_plus_exp_neg_x)); +} +#endif +#if defined __AVX512F__ && defined __AVX512DQ__ +inline __m512 v_biased_sigmoid(__m512 x, __m512 b) { + const __m512 one = _mm512_set1_ps(1); + const __m512 zero = _mm512_setzero_ps(); + const __m512 neg_x = _mm512_sub_ps(zero, x); + const __m512 exp_neg_x = v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_add_ps(b, _mm512_div_ps(one, one_plus_exp_neg_x)); +} +#endif +inline void biased_sigmoid(int n, const float * x, const float * bias, float * y) { + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_biased_sigmoid(_mm512_loadu_ps(x + i), _mm512_loadu_ps(bias + i))); +#endif +#if defined __AVX2__ && defined __FMA__ + for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_biased_sigmoid(_mm256_loadu_ps(x + i), _mm256_loadu_ps(bias + i))); +#endif +#ifdef __ARM_NEON + for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_biased_sigmoid(vld1q_f32(x + i), vld1q_f32(bias + i))); +#endif + for (; i < n; ++i) y[i] = 1/(1 + expf(-x[i])) + bias[i]; +} } void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) { @@ -143,3 +187,74 @@ void iqk_argsort(ggml_tensor * dst, int ith, int nth) { } +void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * topk, int ith, int nth) { + auto topk_src = topk->src[0]; + auto probs = topk_src->src[0]->src[0]; + auto t_bias = topk_src->src[1]; + + auto nrows = ggml_nrows(probs); + auto npt = (nrows + nth - 1)/nth; + auto first = npt*ith; + auto last = std::min(first + npt, nrows); + if (last <= first) return; + + int n_groups = topk->op_params[0]; + int n_top_groups = topk->op_params[1]; + int nk = topk->op_params[2]; + + int ne00 = probs->ne[0]; + int ne0 = topk->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(t_bias->ne[1] == 1); + GGML_ASSERT(t_bias->ne[0] == probs->ne[0]); + GGML_ASSERT(ne0 == dst->ne[1]); + GGML_ASSERT(ne0 <= ne00); + GGML_ASSERT(ne00%n_groups == 0); + int n_per_group = ne00/n_groups; + GGML_ASSERT(nk <= n_per_group); + GGML_ASSERT(n_top_groups <= n_groups); + + size_t work_size = n_groups + n_per_group*n_top_groups + (ne00 + 1)/2; + auto& aux = get_work_buffer(work_size); + + auto groups = aux.data() + n_per_group*n_top_groups; + auto values = (float *)(groups + n_groups); + + auto bias = (const float *)t_bias->data; + + for (int ir = first; ir < last; ++ir) { + auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]); + biased_sigmoid(ne00, data, bias, values); + //for (int j = 0; j < ne00; ++j) values[j] = 1/(1 + expf(-data[j])) + bias[j]; + auto weights = (float *)((char *)dst->data + ir*dst->nb[2]); + auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]); + if (ne0 > n_per_group*n_top_groups) { + for (int j = 0; j < ne0; ++j) { + weights[j] = values[j]; + ids[j] = j; + } + continue; + } + if (n_top_groups < n_groups) { + for (int ig = 0; ig < n_groups; ++ig) { + groups[ig] = { group_score(n_per_group, nk, values + ig*n_per_group, (float *)aux.data()), ig }; + } + std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater>{}); + + for (int ig = 0; ig < n_top_groups; ++ig) { + int i0 = n_per_group * ig; + int j0 = n_per_group * groups[ig].second; + for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { values[j0 + j], j0 + j }; + } + } else { + for (int j = 0; j < ne00; ++j) aux[j] = { values[j], j }; + } + std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + n_top_groups*n_per_group, std::greater>{}); + for (int j = 0; j < ne0; ++j) { + weights[j] = aux[j].first; + ids[j] = aux[j].second; + } + + } +} + diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h index c83d8061..81c14fd5 100644 --- a/ggml/src/iqk/iqk_cpu_ops.h +++ b/ggml/src/iqk/iqk_cpu_ops.h @@ -18,6 +18,8 @@ void iqk_grouped_top_k(struct ggml_tensor * dst, int ith, int nth); void iqk_argsort(struct ggml_tensor * dst, int ith, int nth); +void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * topk, int ith, int nth); + #ifdef __cplusplus } #endif diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index b3208dcd..e2fccfca 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -827,8 +827,9 @@ llm_expert_gating_func_type gating_op, auto& hparams = lctx.model.hparams; selected_experts = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2, n_expert_used); } else { - selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, - lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] + //selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, + // lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] + selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] } cb(selected_experts, "ffn_moe_topk", il); ggml_tensor * weights = ggml_get_rows(ctx,