Fuse sum_rows and div with topk-moe (#984)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-11-19 13:44:09 +01:00
committed by GitHub
parent 047a519771
commit d764edd652
3 changed files with 45 additions and 21 deletions

View File

@@ -3336,8 +3336,16 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4]) &&
ops_are_same_device(cgraph, i, i+4)) {
ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
i += 4;
if (i + 7 < cgraph->n_nodes &&
cgraph->nodes[i+5]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+6]->op == GGML_OP_SUM_ROWS &&
cgraph->nodes[i+7]->op == GGML_OP_DIV) {
ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+7], cgraph->nodes[i+3]);
i += 7;
} else {
ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
i += 4;
}
} else {
ggml_cuda_op_soft_max(ctx, dst);
}

View File

@@ -10,7 +10,7 @@
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
template <size_t n_experts>
template <size_t n_experts, bool normalize>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
@@ -58,7 +58,6 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
tmp = warp_reduce_sum(tmp);
const float inv_sum = 1.0f / tmp;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
@@ -68,6 +67,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
//we do the argmax reduce over n_expert_used, each time marking
//the expert weight as -inf to exclude from the next iteration
[[maybe_unused]] float sum_selected = 0;
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
int max_expert = threadIdx.x;
@@ -91,6 +91,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}
sum_selected += max_val;
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
@@ -98,8 +99,19 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
ids[k] = max_expert;
}
}
if (!normalize) return;
__syncthreads();
float norm = 1/sum_selected;
for (int k = threadIdx.x; k < n_expert_used; k += WARP_SIZE) {
weights[k] *= norm;
}
}
template <bool normalize>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
@@ -114,34 +126,34 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) {
case 1:
topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 2:
topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 4:
topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 8:
topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 16:
topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 32:
topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 64:
topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 128:
topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 256:
topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 512:
topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
default:
GGML_ASSERT(false && "fatal error");
@@ -168,9 +180,13 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
cudaStream_t stream = ctx.stream();
const int n_expert_used = weights->ne[1];
launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
if (weights->op == GGML_OP_DIV) {
const int n_expert_used = weights->ne[0];
launch_topk_moe_cuda<true >(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
} else {
const int n_expert_used = weights->ne[1];
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {