mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -111,6 +111,44 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
|
||||
}
|
||||
|
||||
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void simple_moe_cuda(const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_experts) {
|
||||
const int row = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
logits += n_experts * row;
|
||||
weights += n_experts * row;
|
||||
ids += n_experts * row;
|
||||
|
||||
float max_val = -INFINITY;
|
||||
#pragma unroll
|
||||
for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
|
||||
max_val = max(max_val, logits[i]);
|
||||
ids[i] = i;
|
||||
}
|
||||
|
||||
max_val = warp_reduce_max(max_val);
|
||||
|
||||
float sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
|
||||
weights[i] = expf(logits[i] - max_val);
|
||||
sum += weights[i];
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
float norm = 1/sum;
|
||||
#pragma unroll
|
||||
for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
|
||||
weights[i] *= norm;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool normalize>
|
||||
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
const float * logits,
|
||||
@@ -124,6 +162,11 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
if (n_expert_used == n_expert) {
|
||||
simple_moe_cuda<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (n_expert) {
|
||||
case 1:
|
||||
topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
|
||||
Reference in New Issue
Block a user