From 8fe2bb927a481e32af9b8f56d21fcdb7bc8ab163 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 09:13:34 +0300 Subject: [PATCH] Fuse sigmoid+add+topk+get_rows (CUDA) --- ggml/src/ggml-cuda.cu | 21 ++++++++- ggml/src/ggml-cuda/argsort.cu | 78 ++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/argsort.cuh | 2 + 3 files changed, 99 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 12da791f..4372d99f 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3173,12 +3173,29 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_relu(ctx, dst); break; case GGML_UNARY_OP_SIGMOID: - if (i + 4 < cgraph->n_nodes && + if (i + 5 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+2]->op == GGML_OP_ADD && + cgraph->nodes[i+3]->op == GGML_OP_ARGSORT && + cgraph->nodes[i+4]->op == GGML_OP_VIEW && + cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) { + cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]); + i += 5; + } + //else if (i + 5 < cgraph->n_nodes) { + // printf("sigmoid(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s)\n", dst->name, + // ggml_op_name(cgraph->nodes[i+1]->op), cgraph->nodes[i+1]->name, + // ggml_op_name(cgraph->nodes[i+2]->op), cgraph->nodes[i+2]->name, + // ggml_op_name(cgraph->nodes[i+3]->op), cgraph->nodes[i+3]->name, + // ggml_op_name(cgraph->nodes[i+4]->op), cgraph->nodes[i+4]->name, + // ggml_op_name(cgraph->nodes[i+5]->op), cgraph->nodes[i+5]->name); + //} + else if (i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) { - cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]); + cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+4]); i += 4; } else { ggml_cuda_op_sigmoid(ctx, dst); diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 3db45b3b..7c3c5e66 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -124,6 +124,35 @@ static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float } } +template +static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float * bias, float * weights, int * ids, const int ncols, int ncols_pad, int ntop, + size_t nb_ids) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols_pad) { + return; + } + + extern __shared__ int dst_row[]; + auto x_row = (float *)(dst_row + ncols_pad); + + // initialize indices + dst_row[col] = col; + x_row[col] = col < ncols ? 1/(1 + expf(-x[row*ncols + col])) + bias[col] : -INFINITY; + + __syncthreads(); + + sort(ncols_pad, ncols, col, x_row, dst_row); + + if (col < ntop) { + weights[row * ntop + col] = 1/(1 + expf(-x[row * ncols + dst_row[col]])); + auto row_ids = (int *)((char *)ids + row*nb_ids); + row_ids[col] = dst_row[col]; + } +} + template static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) { // bitonic sort @@ -247,6 +276,29 @@ static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, fl } } +static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias, float * weights, int * ids, const int ncols, const int nrows, int ntop, + size_t nb_ids, ggml_sort_order order, cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + const dim3 block_dims(ncols_pad, 1, 1); + const dim3 block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float)); + + // FIXME: this limit could be raised by ~2-4x on Ampere or newer + GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + + if (order == GGML_SORT_ORDER_ASC) { + k_argsort_biased_f32_f32_i32<<>>(x, bias, weights, ids, + ncols, ncols_pad, ntop, nb_ids); + } else if (order == GGML_SORT_ORDER_DESC) { + k_argsort_biased_f32_f32_i32<<>>(x, bias, weights, ids, + ncols, ncols_pad, ntop, nb_ids); + } else { + GGML_ABORT("fatal error"); + } +} + void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -395,3 +447,29 @@ void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * ds ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); } + +void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk_view) { + GGML_ASSERT(topk_view->op == GGML_OP_VIEW); + auto topk = topk_view->src[0]; + auto topk_src = topk->src[0]; + auto probs = topk_src->src[0]->src[0]; + auto bias = topk_src->src[1]; + + auto nrows = ggml_nrows(probs); + + int ne00 = probs->ne[0]; + int ne0 = topk_view->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(bias->ne[1] == 1); + GGML_ASSERT(bias->ne[0] == probs->ne[0]); + GGML_ASSERT(ne0 == dst->ne[1]); + GGML_ASSERT(ne0 <= ne00); + + //printf("probs: %ld x %ld x %ld x %ld. topk: %ld x %ld x %ld x %ld. dst: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n", + // probs->ne[0], probs->ne[1], probs->ne[2], probs->ne[3], topk->ne[0], topk->ne[1], topk->ne[2], topk->ne[3], + // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]); + + argsort_biased_f32_f32_i32_cuda((const float *)probs->data, (const float *)bias->data, (float *)dst->data, (int *)topk->data, + ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); + +} diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index e467abf0..43987fbb 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -13,3 +13,5 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk); + +void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk);