From 232050b47359730d6c827b431b2613afd73397a6 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Wed, 19 Nov 2025 15:48:39 +0100
Subject: [PATCH] Attempt to fix #974 (#983)

Co-authored-by: Iwan Kawrakow
---
 ggml/src/ggml-cuda/topk-moe.cu | 43 ++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index 59b74b06..c71a1f64 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -111,6 +111,44 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 }
 
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void simple_moe_cuda(const float * logits,
+                                                                    float *       weights,
+                                                                    int32_t *     ids,
+                                                                    const int     n_rows,
+                                                                    const int     n_experts) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    logits  += n_experts * row;
+    weights += n_experts * row;
+    ids     += n_experts * row;
+
+    float max_val = -INFINITY;
+#pragma unroll
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        max_val = max(max_val, logits[i]);
+        ids[i] = i;
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0;
+#pragma unroll
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        weights[i] = expf(logits[i] - max_val);
+        sum += weights[i];
+    }
+
+    sum = warp_reduce_sum(sum);
+    float norm = 1/sum;
+#pragma unroll
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        weights[i] *= norm;
+    }
+}
+
 template <bool normalize>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
                                  float *       weights,
@@ -124,6 +162,11 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
+    if (n_expert_used == n_expert) {
+        simple_moe_cuda<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert);
+        return;
+    }
+
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
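
For reference (not part of the patch itself), below is a minimal CPU sketch of what the new simple_moe_cuda kernel computes per row when n_expert_used == n_expert: a numerically stable softmax over all expert logits, with ids filled with the identity permutation 0..n_experts-1 instead of a top-k selection. The function name simple_moe_reference is a hypothetical label for this sketch and does not exist in the ggml sources; only the logits/weights/ids/n_rows/n_experts arguments mirror the kernel.

// Hedged CPU reference sketch for simple_moe_cuda; simple_moe_reference is a
// made-up name, not a ggml function.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

static void simple_moe_reference(const float * logits, float * weights, int32_t * ids,
                                 int n_rows, int n_experts) {
    for (int row = 0; row < n_rows; ++row) {
        const float * x  = logits  + (size_t) row * n_experts;
        float *       w  = weights + (size_t) row * n_experts;
        int32_t *     id = ids     + (size_t) row * n_experts;

        // Row maximum for numerical stability (the kernel's warp_reduce_max step);
        // ids become the identity permutation since every expert is used.
        float max_val = -INFINITY;
        for (int i = 0; i < n_experts; ++i) {
            max_val = std::max(max_val, x[i]);
            id[i] = i;
        }

        // Exponentiate and accumulate the normalizer (the warp_reduce_sum step).
        float sum = 0.0f;
        for (int i = 0; i < n_experts; ++i) {
            w[i] = std::exp(x[i] - max_val);
            sum += w[i];
        }

        // Normalize so the per-row weights sum to 1.
        const float norm = 1.0f / sum;
        for (int i = 0; i < n_experts; ++i) {
            w[i] *= norm;
        }
    }
}

This mirrors the fast path added to launch_topk_moe_cuda: when every expert is used there is nothing to select, so the top-k machinery is skipped and only the softmax over the logits remains.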