diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index 59b74b06..c71a1f64 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -111,6 +111,58 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
 }
 
+// Softmax over every expert of a row, used when all experts are selected so no
+// top-k search is needed.
+//
+// Launch layout (see launch_topk_moe_cuda): the WARP_SIZE threads along x share one
+// row and stride over its experts; blockIdx.x * blockDim.y + threadIdx.y picks the row.
+// Assumes warp_reduce_max / warp_reduce_sum hand the warp-wide result to every lane.
+//
+// logits  : n_rows x n_experts input scores, row stride n_experts (read only)
+// weights : n_rows x n_experts output, exp(logit - row max) / row sum
+// ids     : n_rows x n_experts output, ids[i] = i (experts stay in index order)
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void simple_moe_cuda(const float * logits,
+                                                                    float *       weights,
+                                                                    int32_t *     ids,
+                                                                    const int     n_rows,
+                                                                    const int     n_experts) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    // Widen before multiplying: n_experts * row in int can overflow for large inputs.
+    const int64_t row_offset = (int64_t) n_experts * row;
+    logits  += row_offset;
+    weights += row_offset;
+    ids     += row_offset;
+
+    // Per-thread partial maximum; ids are filled in during the same pass.
+    float max_val = -INFINITY;
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        max_val = max(max_val, logits[i]);
+        ids[i]  = i;
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    // Store the unnormalized exponentials and accumulate this thread's share of the sum.
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        const float e = expf(logits[i] - max_val);
+        weights[i]    = e;
+        sum += e;
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    // Each thread rescales exactly the entries it wrote above.
+    const float norm = 1.0f / sum;
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        weights[i] *= norm;
+    }
+}
+
 template
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
@@ -124,6 +176,13 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
+    // Every expert is selected: run the plain softmax kernel instead of the top-k
+    // one. `normalize` is not consulted on this path.
+    if (n_expert_used == n_expert) {
+        simple_moe_cuda<<>>(logits, weights, ids, n_rows, n_expert);
+        return;
+    }
+
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, normalize><<>>(logits, weights, ids, n_rows, n_expert_used);