diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 3e151ac7..998c4a23 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -41,6 +41,7 @@
 #include "ggml-cuda/graph.cuh"
 #include "ggml-cuda/mmq_id.cuh"
 #include "ggml-cuda/quantize_id.cuh"
+#include "ggml-cuda/topk-moe.cuh"

 #include <algorithm>
 #include <array>
@@ -3030,7 +3031,8 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor

 }

-static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) {
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next,
+        const ggml_cgraph * cgraph, int & i) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
@@ -3152,10 +3154,10 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             }
             break;
         case GGML_OP_MUL_MAT_ID:
-            skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
+            if (ggml_cuda_mul_mat_id(ctx, dst, next)) ++i;
             break;
         case GGML_OP_MOE_FUSED_UP_GATE:
-            skip_next = ggml_cuda_moe_up_gate_unary(ctx, dst, next);
+            if (ggml_cuda_moe_up_gate_unary(ctx, dst, next)) ++i;
             break;
         case GGML_OP_FUSED_UP_GATE:
            ggml_cuda_up_gate_unary(ctx, dst);
@@ -3185,7 +3187,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            ggml_cuda_op_soft_max(ctx, dst);
+            if (i + 4 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
+                cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
+                ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4])) {
+                ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
+                i += 4;
+            } else {
+                ggml_cuda_op_soft_max(ctx, dst);
+            }
             break;
         case GGML_OP_SOFT_CAP_MAX:
             ggml_cuda_op_soft_cap_max(ctx, dst);
@@ -3592,13 +3604,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             GGML_UNUSED(integrated);
 #endif // NDEBUG

-            bool skip_next = false;
-            bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
+            bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, cgraph, i);
             if (!ok) {
                 GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
             }
             GGML_ASSERT(ok);
-            if (skip_next) ++i;
         }
     }
 #ifdef USE_CUDA_GRAPH
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
new file mode 100644
index 00000000..d4c7f30a
--- /dev/null
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -0,0 +1,203 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "topk-moe.cuh"
+
+/*
+    This kernel does the following:
+    1. softmax over the logits per token [n_experts, n_tokens]
+    2. argmax reduce over the top-k (n_expert_used) logits
+    3. write weights + ids to global memory
+
+    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
+*/
+template <size_t n_experts>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+                                                                  float * weights,
+                                                                  int32_t * ids,
+                                                                  const int n_rows,
+                                                                  const int n_expert_used) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    logits += n_experts * row;
+    weights += n_expert_used * row;
+    ids += n_experts * row;
+
+    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+    float logits_r[experts_per_thread];
+
+#pragma unroll
+    for (int i = 0; i < n_experts; i += WARP_SIZE) {
+        const int expert = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
+    }
+
+    float max_val = logits_r[0];
+
+#pragma unroll
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val = max(val, max_val);
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float wt[experts_per_thread];
+    float tmp = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i] = expf(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = wt[i] * inv_sum;
+    }
+
+    //at this point, each thread holds a portion of softmax,
+    //we do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude from the next iteration
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val = wt[0];
+        int max_expert = threadIdx.x;
+
+#pragma unroll
+        for (int i = 1; i < experts_per_thread; i++) {
+            const int expert = threadIdx.x + i * WARP_SIZE;
+            if (expert < n_experts && wt[i] > max_val) {
+                max_val = wt[i];
+                max_expert = expert;
+            }
+        }
+
+#pragma unroll
+        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+            const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+            if (val > max_val) {
+                max_val = val;
+                max_expert = expert;
+            }
+        }
+
+        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
+
+            weights[k] = max_val;
+            ids[k] = max_expert;
+        }
+    }
+}
+
+static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
+                                 const float * logits,
+                                 float * weights,
+                                 int32_t * ids,
+                                 const int n_rows,
+                                 const int n_expert,
+                                 const int n_expert_used) {
+    const int rows_per_block = 4;
+    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
+    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
+    cudaStream_t stream = ctx.stream();
+
+    switch (n_expert) {
+        case 1:
+            topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 2:
+            topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 4:
+            topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 8:
+            topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 16:
+            topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 32:
+            topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 64:
+            topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 128:
+            topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 256:
+            topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 512:
+            topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        default:
+            GGML_ASSERT(false && "fatal error");
+            break;
+    }
+}
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor * logits,
+                           ggml_tensor * weights,
+                           ggml_tensor * ids) {
+    GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const int n_experts = logits->ne[0];
+    const int n_rows = logits->ne[1];
+
+    const float * logits_d = (const float *) logits->src[0]->data;
+    float * weights_d = (float *) weights->data;
+    int32_t * ids_d = (int32_t *) ids->data;
+
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int n_expert_used = weights->ne[1];
+
+    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+}
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
+    if (scale != 1.0f || max_bias != 0.0f) {
+        return false;
+    }
+
+    // don't fuse when masks or sinks are present
+    if (softmax->src[1] || softmax->src[2]) {
+        return false;
+    }
+
+    const int n_expert = softmax->ne[0];
+    // n_expert must be a power of 2
+    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
+        return false;
+    }
+
+    return true;
+}
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
new file mode 100644
index 00000000..03f4ad56
--- /dev/null
+++ b/ggml/src/ggml-cuda/topk-moe.cuh
@@ -0,0 +1,8 @@
+#include "common.cuh"
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor * logits,
+                           ggml_tensor * weights,
+                           ggml_tensor * top_k);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
diff --git a/src/llama.cpp b/src/llama.cpp
index c3401267..f0df3a07 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7950,6 +7950,10 @@ llm_expert_gating_func_type gating_op,
             ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens),
             selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
+    if (graph) {
+        ggml_build_forward_expand(graph, weights);
+    }
+
     if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
         weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
         weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
@@ -8960,7 +8964,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, false, false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SIGMOID,
-                    cb, il);
+                    cb, il, gf);

             // Shared experts
             ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed,
@@ -8991,7 +8995,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true, false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);
         }
@@ -9648,7 +9652,7 @@ struct llm_build_context {
                     LLM_FFN_GELU, true, false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);

             // Grok
@@ -9791,7 +9795,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true, false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);

             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -10923,7 +10927,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, false, false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);

             // FFN shared expert
@@ -11188,7 +11192,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true, false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);

             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -13451,7 +13455,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true, false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(cur, "ffn_moe_out", il);

             cur = ggml_add(ctx0, cur, ffn_out);
@@ -13940,7 +13944,7 @@ struct llm_build_context {
                         LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
                         (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                        cb, il);
+                        cb, il, gf);
                 cb(moe_out, "ffn_moe_out", il);

                 // FFN shared expert
@@ -14116,7 +14120,7 @@ struct llm_build_context {
                         LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
                         (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                        cb, il);
+                        cb, il, gf);
                 cb(routed_out, "routed_out", il);
                 {
@@ -15377,7 +15381,7 @@ struct llm_build_context {
                         LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
                         (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                        cb, il);
+                        cb, il, gf);
                 cb(moe_out, "ffn_moe_out", il);
                 {
@@ -15670,7 +15674,7 @@ struct llm_build_context {
                     LLM_FFN_SILU, true, false, 0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                    cb, il);
+                    cb, il, gf);
             cb(moe_out, "ffn_moe_out", il);

             // Shared expert (if present)
@@ -15835,7 +15839,7 @@ struct llm_build_context {
                     0.0,
                     LLM_EXPERT_GATING_FUNC_SOFTMAX,
                     cb,
-                    il);
+                    il, gf);
             cb(cur_moe, "ffn_moe_out", il);

             ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);
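
Note: the fused path is intended to produce the same weights/ids as the unfused softmax -> top-k selection it replaces. Below is a minimal host-side sketch of those semantics, useful for spot-checking the kernel output. It is not part of the patch; the name topk_moe_ref is illustrative, and it assumes row-major [n_experts, n_tokens] logits with the same output strides as the kernel (weights rows n_expert_used apart, ids rows n_experts apart).

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference semantics of the fused softmax -> top-k -> get_rows path:
// per token row, softmax the expert logits, then emit the n_expert_used
// largest probabilities and their expert indices.
static void topk_moe_ref(const float * logits, float * weights, int32_t * ids,
                         int n_rows, int n_experts, int n_expert_used) {
    std::vector<float> p(n_experts);
    for (int r = 0; r < n_rows; ++r) {
        const float * x = logits + (size_t) r * n_experts;
        // softmax with max-subtraction for numerical stability
        const float max_val = *std::max_element(x, x + n_experts);
        float sum = 0.0f;
        for (int e = 0; e < n_experts; ++e) {
            p[e] = std::exp(x[e] - max_val);
            sum += p[e];
        }
        for (int e = 0; e < n_experts; ++e) {
            p[e] /= sum;
        }
        // iterative argmax, masking out the selected expert each round,
        // mirroring the -INFINITY masking in the kernel
        for (int k = 0; k < n_expert_used; ++k) {
            const int best = (int) (std::max_element(p.begin(), p.end()) - p.begin());
            weights[(size_t) r * n_expert_used + k] = p[best];
            ids[(size_t) r * n_experts + k] = best;
            p[best] = -INFINITY;
        }
    }
}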