mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-01-26 17:20:01 +00:00)

cuda: fused top_k+softmax as used in most MoE models (#789)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
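
For context, a minimal sketch of the unfused routing pipeline this commit targets, written against the public ggml API (the tensor names, the n_expert/n_expert_used variables, and the exact sort-order enum spelling are assumptions for illustration, not taken from the commit). The CUDA backend below pattern-matches the SOFT_MAX -> RESHAPE -> ARGSORT -> VIEW -> GET_ROWS node chain that a builder sequence like this produces, and replaces it with the single fused kernel added in ggml/src/ggml-cuda/topk-moe.cu:

    // logits: [n_expert, n_tokens] router output (hedged sketch, not commit code)
    ggml_tensor * probs  = ggml_soft_max(ctx, logits);                     // softmax over experts
    ggml_tensor * sorted = ggml_argsort(ctx, probs, GGML_SORT_ORDER_DESC); // best expert first
    ggml_tensor * selected_experts = ggml_view_2d(ctx, sorted,             // keep top n_expert_used ids
            n_expert_used, n_tokens, sorted->nb[1], 0);
    ggml_tensor * weights = ggml_get_rows(ctx,                             // gather their probabilities
            ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts);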
ggml/src/ggml-cuda.cu
@@ -41,6 +41,7 @@
 #include "ggml-cuda/graph.cuh"
 #include "ggml-cuda/mmq_id.cuh"
 #include "ggml-cuda/quantize_id.cuh"
+#include "ggml-cuda/topk-moe.cuh"
 
 #include <algorithm>
 #include <array>
@@ -3030,7 +3031,8 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
 }
 
-static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) {
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next,
+                                      const ggml_cgraph * cgraph, int & i) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
@@ -3152,10 +3154,10 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             }
             break;
         case GGML_OP_MUL_MAT_ID:
-            skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
+            if (ggml_cuda_mul_mat_id(ctx, dst, next)) ++i;
             break;
         case GGML_OP_MOE_FUSED_UP_GATE:
-            skip_next = ggml_cuda_moe_up_gate_unary(ctx, dst, next);
+            if (ggml_cuda_moe_up_gate_unary(ctx, dst, next)) ++i;
             break;
         case GGML_OP_FUSED_UP_GATE:
             ggml_cuda_up_gate_unary(ctx, dst);
@@ -3185,7 +3187,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            ggml_cuda_op_soft_max(ctx, dst);
+            if (i + 4 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
+                cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
+                ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4])) {
+                ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
+                i += 4;
+            } else {
+                ggml_cuda_op_soft_max(ctx, dst);
+            }
             break;
         case GGML_OP_SOFT_CAP_MAX:
             ggml_cuda_op_soft_cap_max(ctx, dst);
@@ -3592,13 +3604,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         GGML_UNUSED(integrated);
 #endif // NDEBUG
 
-        bool skip_next = false;
-        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
+        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, cgraph, i);
         if (!ok) {
             GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
         }
         GGML_ASSERT(ok);
-        if (skip_next) ++i;
     }
 }
 #ifdef USE_CUDA_GRAPH

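A hedged restatement of the dispatch change above: ggml_cuda_compute_forward now receives the whole graph plus the node index by reference, so any case that fuses downstream nodes (GGML_OP_MUL_MAT_ID, GGML_OP_MOE_FUSED_UP_GATE, and the new GGML_OP_SOFT_MAX top-k path) advances i past the consumed nodes itself, which makes the caller's separate skip_next flag redundant. A simplified sketch of the resulting loop contract (illustrative only; the real loop in evaluate_and_capture_cuda_graph also handles graph capture and computes next differently):

    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];
        ggml_tensor * next = i + 1 < cgraph->n_nodes ? cgraph->nodes[i+1] : nullptr;
        // a fused case may consume nodes i+1..i+4 and advance i so they are not revisited
        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, cgraph, i);
        GGML_ASSERT(ok);
    }
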
ggml/src/ggml-cuda/topk-moe.cu (new file, 203 lines)
@@ -0,0 +1,203 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "topk-moe.cuh"
+
+/*
+    This kernel does the following:
+    1. softmax over the logits per token [n_experts, n_tokens]
+    2. argmax reduce over the top-k (n_experts_used) logits
+    3. write weights + ids to global memory
+
+    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
+*/
+template <size_t n_experts>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+                                                                  float *       weights,
+                                                                  int32_t *     ids,
+                                                                  const int     n_rows,
+                                                                  const int     n_expert_used) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    logits  += n_experts * row;
+    weights += n_expert_used * row;
+    ids     += n_experts * row;
+
+    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+    float logits_r[experts_per_thread];
+
+#pragma unroll
+    for (int i = 0; i < n_experts; i += WARP_SIZE) {
+        const int expert        = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
+    }
+
+    float max_val = logits_r[0];
+
+#pragma unroll
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val         = max(val, max_val);
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float wt[experts_per_thread];
+    float tmp = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i]           = expf(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = wt[i] * inv_sum;
+    }
+
+    //at this point, each thread holds a portion of softmax,
+    //we do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude from the next iteration
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val    = wt[0];
+        int   max_expert = threadIdx.x;
+
+#pragma unroll
+        for (int i = 1; i < experts_per_thread; i++) {
+            const int expert = threadIdx.x + i * WARP_SIZE;
+            if (expert < n_experts && wt[i] > max_val) {
+                max_val    = wt[i];
+                max_expert = expert;
+            }
+        }
+
+#pragma unroll
+        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val,    mask, WARP_SIZE);
+            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+            if (val > max_val) {
+                max_val    = val;
+                max_expert = expert;
+            }
+        }
+
+        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
+
+            weights[k] = max_val;
+            ids[k]     = max_expert;
+        }
+    }
+}
+
+static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
+                                 const float *               logits,
+                                 float *                     weights,
+                                 int32_t *                   ids,
+                                 const int                   n_rows,
+                                 const int                   n_expert,
+                                 const int                   n_expert_used) {
+    const int    rows_per_block = 4;
+    dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
+    dim3         block_dims(WARP_SIZE, rows_per_block, 1);
+    cudaStream_t stream = ctx.stream();
+
+    switch (n_expert) {
+        case 1:
+            topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 2:
+            topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 4:
+            topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 8:
+            topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 16:
+            topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 32:
+            topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 64:
+            topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 128:
+            topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 256:
+            topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 512:
+            topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        default:
+            GGML_ASSERT(false && "fatal error");
+            break;
+    }
+}
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor *         logits,
+                           ggml_tensor *               weights,
+                           ggml_tensor *               ids) {
+    GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const int n_experts = logits->ne[0];
+    const int n_rows    = logits->ne[1];
+
+    const float * logits_d  = (const float *) logits->src[0]->data;
+    float *       weights_d = (float *) weights->data;
+    int32_t *     ids_d     = (int32_t *) ids->data;
+
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int n_expert_used = weights->ne[1];
+
+    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+}
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) softmax->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
+    if (scale != 1.0f || max_bias != 0.0f) {
+        return false;
+    }
+
+    // don't fuse when masks or sinks are present
+    if (softmax->src[1] || softmax->src[2]) {
+        return false;
+    }
+
+    const int n_expert = softmax->ne[0];
+    // n_expert must be a power of 2
+    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
+        return false;
+    }
+
+    return true;
+}

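Side note on the top-k loop in the kernel above (a standalone sketch, not part of the commit): each of the n_expert_used iterations runs a warp-wide butterfly argmax over the per-lane register slices, after which every lane holds the same winning weight and expert index; the one lane that owns the winner then masks it to -INFINITY and writes weights[k]/ids[k]. Isolated, the reduction looks like this (assuming a 32-lane warp):

    // minimal illustration of the butterfly argmax used in topk_moe_cuda
    __device__ void warp_argmax(float & val, int & idx) {
    #pragma unroll
        for (int mask = 32 / 2; mask > 0; mask /= 2) {
            const float other_val = __shfl_xor_sync(0xFFFFFFFF, val, mask, 32);
            const int   other_idx = __shfl_xor_sync(0xFFFFFFFF, idx, mask, 32);
            if (other_val > val) {  // after log2(32) exchanges, all lanes agree
                val = other_val;
                idx = other_idx;
            }
        }
    }
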
ggml/src/ggml-cuda/topk-moe.cuh (new file, 8 lines)
@@ -0,0 +1,8 @@
+#include "common.cuh"
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor * logits,
+                           ggml_tensor *       weights,
+                           ggml_tensor *       top_k);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);

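To make the kernel's per-token math concrete before the graph-builder changes below, here is a hedged host-side re-derivation (plain C++, illustrative values, not commit code): softmax over 4 expert logits followed by top-2 selection via repeated argmax with -INFINITY masking, the same sequence topk_moe_cuda performs per row.

    #include <cmath>
    #include <cstdio>

    int main() {
        const float logits[4] = {1.0f, 3.0f, 2.0f, 0.0f};
        float m = logits[0];
        for (int i = 1; i < 4; i++) m = fmaxf(m, logits[i]);
        float w[4], sum = 0.0f;
        for (int i = 0; i < 4; i++) { w[i] = expf(logits[i] - m); sum += w[i]; }
        for (int i = 0; i < 4; i++) w[i] /= sum;        // softmax weights
        for (int k = 0; k < 2; k++) {                   // top-2 by repeated argmax
            int best = 0;
            for (int i = 1; i < 4; i++) if (w[i] > w[best]) best = i;
            printf("k=%d id=%d weight=%.4f\n", k, best, w[best]); // ids 1, then 2
            w[best] = -INFINITY;                        // exclude from next round
        }
        return 0;
    }
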
src/llama.cpp
@@ -7950,6 +7950,10 @@ llm_expert_gating_func_type gating_op,
             ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
+    if (graph) {
+        ggml_build_forward_expand(graph, weights);
+    }
+
     if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
         weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
         weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
@@ -8960,7 +8964,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, false,
                     false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SIGMOID,
-                    cb, il);
+                    cb, il, gf);
 
         // Shared experts
         ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed,
@@ -8991,7 +8995,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true,
                     false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);
         }
 
@@ -9648,7 +9652,7 @@ struct llm_build_context {
                     LLM_FFN_GELU, true,
                     false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             // Grok
@@ -9791,7 +9795,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true,
                     false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -10923,7 +10927,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, false,
                     false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             // FFN shared expert
@@ -11188,7 +11192,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true,
                     false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -13451,7 +13455,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true,
                     false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_out);
@@ -13940,7 +13944,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, hparams.expert_weights_norm,
                     true, hparams.expert_weights_scale,
                     (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                    cb, il);
+                    cb, il, gf);
             cb(moe_out, "ffn_moe_out", il);
 
             // FFN shared expert
@@ -14116,7 +14120,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, hparams.expert_weights_norm,
                     true, hparams.expert_weights_scale,
                     (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                    cb, il);
+                    cb, il, gf);
             cb(routed_out, "routed_out", il);
 
             {
@@ -15377,7 +15381,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, hparams.expert_weights_norm,
                     true, hparams.expert_weights_scale,
                     (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                    cb, il);
+                    cb, il, gf);
             cb(moe_out, "ffn_moe_out", il);
 
             {
@@ -15670,7 +15674,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true,
                     false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(moe_out, "ffn_moe_out", il);
 
             // Shared expert (if present)
@@ -15835,7 +15839,7 @@ struct llm_build_context {
                     0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
                     cb,
-                    il);
+                    il, gf);
             cb(cur_moe, "ffn_moe_out", il);
 
             ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);