From 18d9f4fc4dfe3b400e9a94e17e9fb45b7d045f95 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 10:03:51 +0300 Subject: [PATCH] Fuse sigmoid+add+topk+get_rows (CPU) --- ggml/src/ggml-cuda.cu | 8 ---- ggml/src/ggml.c | 11 +++++- ggml/src/iqk/iqk_cpu_ops.cpp | 77 +++++++++++++++++++++++++++++++++++- ggml/src/iqk/iqk_cpu_ops.h | 2 + 4 files changed, 88 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 4372d99f..4e3b0413 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3182,14 +3182,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]); i += 5; } - //else if (i + 5 < cgraph->n_nodes) { - // printf("sigmoid(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s)\n", dst->name, - // ggml_op_name(cgraph->nodes[i+1]->op), cgraph->nodes[i+1]->name, - // ggml_op_name(cgraph->nodes[i+2]->op), cgraph->nodes[i+2]->name, - // ggml_op_name(cgraph->nodes[i+3]->op), cgraph->nodes[i+3]->name, - // ggml_op_name(cgraph->nodes[i+4]->op), cgraph->nodes[i+4]->name, - // ggml_op_name(cgraph->nodes[i+5]->op), cgraph->nodes[i+5]->name); - //} else if (i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5ca73ae5..a556b2a4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -22612,7 +22612,16 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml case GGML_OP_UNARY: { const enum ggml_unary_op unary_op = ggml_get_unary_op(tensor); - if (unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes && + if (unary_op == GGML_UNARY_OP_SIGMOID && i + 5 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+2]->op == GGML_OP_ADD && + cgraph->nodes[i+3]->op == GGML_OP_ARGSORT && + cgraph->nodes[i+4]->op == GGML_OP_VIEW && + cgraph->nodes[i+5]->op == 
GGML_OP_GET_ROWS) { + iqk_glm45moe_experts(cgraph->nodes[i+5], cgraph->nodes[i+4], params->ith, params->nth); + i += 5; + } + else if (unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index f823d282..51cdfcc8 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -87,13 +87,41 @@ inline void biased_sigmoid(int n, const float * x, const float * bias, float * y } #endif #ifdef __ARM_NEON - for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_biased_sigmoid(vld1q_f32(x + i), vld1q_f32(bias + i))); + for (; i + 3 < n; i += 4) { + auto v = v_sigmoid(vld1q_f32(x + i)); + vst1q_f32(y + i, vaddq_f32(v, vld1q_f32(bias + i))); + vst1q_f32(z + i, v); + } #endif for (; i < n; ++i) { z[i] = 1/(1 + expf(-x[i])); y[i] = z[i] + bias[i]; } } +inline void biased_sigmoid(int n, const float * x, const float * bias, float * y) { + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + for (; i + 15 < n; i += 16) { + auto v = v_sigmoid(_mm512_loadu_ps(x + i)); + _mm512_storeu_ps(y + i, _mm512_add_ps(v, _mm512_loadu_ps(bias + i))); + } +#endif +#if defined __AVX2__ && defined __FMA__ + for (; i + 7 < n; i += 8) { + auto v = v_sigmoid(_mm256_loadu_ps(x + i)); + _mm256_storeu_ps(y + i, _mm256_add_ps(v, _mm256_loadu_ps(bias + i))); + } +#endif +#ifdef __ARM_NEON + for (; i + 3 < n; i += 4) { + auto v = v_sigmoid(vld1q_f32(x + i)); + vst1q_f32(y + i, vaddq_f32(v, vld1q_f32(bias + i))); + } +#endif + for (; i < n; ++i) { + y[i] = 1/(1 + expf(-x[i])) + bias[i]; + } +} } void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) { @@ -270,3 +298,50 @@ void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * top } } +void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_view, int ith, int nth) { + 
GGML_ASSERT(topk_view->op == GGML_OP_VIEW); + auto topk = topk_view->src[0]; + auto topk_src = topk->src[0]; + auto probs = topk_src->src[0]->src[0]; + auto t_bias = topk_src->src[1]; + + auto nrows = ggml_nrows(probs); + auto npt = (nrows + nth - 1)/nth; + auto first = npt*ith; + auto last = std::min(first + npt, nrows); + if (last <= first) return; + + int ne00 = probs->ne[0]; + int ne0 = topk_view->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(t_bias->ne[1] == 1); + GGML_ASSERT(t_bias->ne[0] == probs->ne[0]); + GGML_ASSERT(ne0 == dst->ne[1]); + GGML_ASSERT(ne0 <= ne00); + + size_t work_size = 2*ne00; + auto& aux = get_work_buffer(work_size); + + auto biased_values = (float *)(aux.data() + ne00); + //auto values = biased_values + ne00; + + auto bias = (const float *)t_bias->data; + + for (int ir = first; ir < last; ++ir) { + auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]); + //biased_sigmoid(ne00, data, bias, biased_values, values); + biased_sigmoid(ne00, data, bias, biased_values); + auto weights = (float *)((char *)dst->data + ir*dst->nb[2]); + auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]); + for (int j = 0; j < ne00; ++j) aux[j] = { biased_values[j], j }; + if (ne0 < ne00) { + std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + ne00, std::greater<std::pair<float,int>>{}); + } else { + std::sort(aux.begin(), aux.begin() + ne00, std::greater<std::pair<float,int>>{}); + } + for (int j = 0; j < ne0; ++j) { + weights[j] = 1/(1 + expf(-data[aux[j].second])); + ids[j] = aux[j].second; + } + } +} diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h index 81c14fd5..2de3a5cb 100644 --- a/ggml/src/iqk/iqk_cpu_ops.h +++ b/ggml/src/iqk/iqk_cpu_ops.h @@ -20,6 +20,8 @@ void iqk_argsort(struct ggml_tensor * dst, int ith, int nth); void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * topk, int ith, int nth); +void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_view, int ith, int nth); + 
#ifdef __cplusplus } #endif