Attempt to fix #974 (#983)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

Author: Kawrakow
Date: 2025-11-19 15:48:39 +01:00
Committed by: GitHub
Parent: d764edd652
Commit: 232050b473

@@ -111,6 +111,44 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void simple_moe_cuda(const float * logits,
                                                                    float * weights,
                                                                    int32_t * ids,
                                                                    const int n_rows,
                                                                    const int n_experts) {
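    // one warp per row: threadIdx.x is the lane within the warp, threadIdx.y selects the row handled by this block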
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }
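    // advance all three pointers to this row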
    logits  += n_experts * row;
    weights += n_experts * row;
    ids     += n_experts * row;
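    // pass 1: warp-strided maximum over the row (for a numerically stable softmax);
    // every expert is used here, so ids is simply the identity permutation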
    float max_val = -INFINITY;
#pragma unroll
    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
        max_val = max(max_val, logits[i]);
        ids[i] = i;
    }
    max_val = warp_reduce_max(max_val);
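    // pass 2: exponentiate and accumulate the softmax denominator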
    float sum = 0;
#pragma unroll
    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
        weights[i] = expf(logits[i] - max_val);
        sum += weights[i];
    }
    sum = warp_reduce_sum(sum);
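    // pass 3: scale so the n_experts weights of each row sum to 1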
    const float norm = 1.0f/sum;
#pragma unroll
    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
        weights[i] *= norm;
    }
}
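
The kernel relies on warp_reduce_max and warp_reduce_sum, which are defined in the CUDA backend's common headers rather than in this diff. A minimal sketch of the usual butterfly-shuffle implementation (assuming a full active mask and WARP_SIZE == 32, as elsewhere in ggml's CUDA code; not the literal definitions from this repository):

// Sketch only: the real helpers live in the backend's common headers.
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE));
    }
    return x; // every lane ends up holding the warp-wide maximum
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
    }
    return x; // every lane ends up holding the warp-wide sum
}

Because the butterfly variant leaves the result in every lane, the kernel needs no extra broadcast between its three passes.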
template <bool normalize>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float * logits,
@@ -124,6 +162,11 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();
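    // all experts used: no top-k selection is needed, so dispatch to the softmax-only kernel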
    if (n_expert_used == n_expert) {
        simple_moe_cuda<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert);
        return;
    }
    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
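
For the n_expert_used == n_expert case handled above, the kernel reduces to a plain row-wise softmax with ids set to the identity permutation. A hypothetical host-side reference for validating its output (simple_moe_ref is an illustrative name, not part of this commit):

#include <cmath>
#include <cstddef>
#include <cstdint>

// Hypothetical CPU reference for the all-experts-used path.
static void simple_moe_ref(const float * logits, float * weights, int32_t * ids,
                           int n_rows, int n_experts) {
    for (int r = 0; r < n_rows; ++r) {
        const float * x  = logits  + (size_t)r * n_experts;
        float       * w  = weights + (size_t)r * n_experts;
        int32_t     * id = ids     + (size_t)r * n_experts;
        float max_val = -INFINITY;
        for (int i = 0; i < n_experts; ++i) {
            max_val = fmaxf(max_val, x[i]);
            id[i] = i; // all experts used: ids is 0 .. n_experts-1
        }
        float sum = 0.0f;
        for (int i = 0; i < n_experts; ++i) {
            w[i] = expf(x[i] - max_val); // stable softmax numerator
            sum += w[i];
        }
        for (int i = 0; i < n_experts; ++i) {
            w[i] /= sum; // each row's weights now sum to 1, matching the kernel
        }
    }
}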