Files
ik_llama.cpp/ggml/src/ggml-cuda/topk-moe.cu
Kawrakow af10490331 Attempt to fix #974 (#983)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-11-19 15:48:39 +01:00

263 lines
9.1 KiB
Plaintext

#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "topk-moe.cuh"
/*
This kernel does the following:
1. softmax over the logits per token [n_experts, n_tokens]
2. argmax reduce over the top-k (n_experts_used) logits
3. write weights + ids to global memory
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
template <size_t n_experts, bool normalize>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert_used) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) {
return;
}
logits += n_experts * row;
weights += n_expert_used * row;
ids += n_experts * row;
constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
float logits_r[experts_per_thread];
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
}
float max_val = logits_r[0];
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
}
max_val = warp_reduce_max(max_val);
float wt[experts_per_thread];
float tmp = 0.f;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = expf(val - max_val);
tmp += wt[i];
}
tmp = warp_reduce_sum(tmp);
const float inv_sum = 1.0f / tmp;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
}
//at this point, each thread holds a portion of softmax,
//we do the argmax reduce over n_expert_used, each time marking
//the expert weight as -inf to exclude from the next iteration
[[maybe_unused]] float sum_selected = 0;
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
int max_expert = threadIdx.x;
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if (expert < n_experts && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val) {
max_val = val;
max_expert = expert;
}
}
sum_selected += max_val;
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
weights[k] = max_val;
ids[k] = max_expert;
}
}
if (!normalize) return;
__syncthreads();
float norm = 1/sum_selected;
for (int k = threadIdx.x; k < n_expert_used; k += WARP_SIZE) {
weights[k] *= norm;
}
}
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void simple_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
const int n_rows,
const int n_experts) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) {
return;
}
logits += n_experts * row;
weights += n_experts * row;
ids += n_experts * row;
float max_val = -INFINITY;
#pragma unroll
for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
max_val = max(max_val, logits[i]);
ids[i] = i;
}
max_val = warp_reduce_max(max_val);
float sum = 0;
#pragma unroll
for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
weights[i] = expf(logits[i] - max_val);
sum += weights[i];
}
sum = warp_reduce_sum(sum);
float norm = 1/sum;
#pragma unroll
for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
weights[i] *= norm;
}
}
template <bool normalize>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert,
const int n_expert_used) {
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
cudaStream_t stream = ctx.stream();
if (n_expert_used == n_expert) {
simple_moe_cuda<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert);
return;
}
switch (n_expert) {
case 1:
topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 2:
topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 4:
topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 8:
topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 16:
topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 32:
topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 64:
topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 128:
topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 256:
topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 512:
topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
default:
GGML_ASSERT(false && "fatal error");
break;
}
}
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1];
const float * logits_d = (const float *) logits->src[0]->data;
float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data;
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
cudaStream_t stream = ctx.stream();
if (weights->op == GGML_OP_DIV) {
const int n_expert_used = weights->ne[0];
launch_topk_moe_cuda<true >(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
} else {
const int n_expert_used = weights->ne[1];
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;
}
return true;
}