diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index f6891cac..0db57b08 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3336,8 +3336,16 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
                 ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4]) &&
                 ops_are_same_device(cgraph, i, i+4)) {
-                ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
-                i += 4;
+                if (i + 7 < cgraph->n_nodes &&
+                    cgraph->nodes[i+5]->op == GGML_OP_RESHAPE &&
+                    cgraph->nodes[i+6]->op == GGML_OP_SUM_ROWS &&
+                    cgraph->nodes[i+7]->op == GGML_OP_DIV) {
+                    ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+7], cgraph->nodes[i+3]);
+                    i += 7;
+                } else {
+                    ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
+                    i += 4;
+                }
             } else {
                 ggml_cuda_op_soft_max(ctx, dst);
             }
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index d4c7f30a..59b74b06 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -10,7 +10,7 @@
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <size_t n_experts>
+template <size_t n_experts, bool normalize>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float * weights,
                                                                   int32_t * ids,
@@ -58,7 +58,6 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     tmp = warp_reduce_sum(tmp);
 
     const float inv_sum = 1.0f / tmp;
-
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
         wt[i] = wt[i] * inv_sum;
@@ -68,6 +67,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     //we do the argmax reduce over n_expert_used, each time marking
     //the expert weight as -inf to exclude from the next iteration
+    [[maybe_unused]] float sum_selected = 0;
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
@@ -91,6 +91,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             }
         }
 
+        sum_selected += max_val;
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
@@ -98,8 +99,19 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             ids[k] = max_expert;
         }
     }
+
+    if (!normalize) return;
+
+    __syncthreads();
+
+    float norm = 1/sum_selected;
+    for (int k = threadIdx.x; k < n_expert_used; k += WARP_SIZE) {
+        weights[k] *= norm;
+    }
+
 }
 
+template <bool normalize>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
                                  float * weights,
@@ -114,34 +126,34 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
         case 8:
-            topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 64:
-            topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
-            topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
-            topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
-            topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -168,9 +180,13 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     cudaStream_t stream = ctx.stream();
 
-    const int n_expert_used = weights->ne[1];
-
-    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    if (weights->op == GGML_OP_DIV) {
+        const int n_expert_used = weights->ne[0];
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    } else {
+        const int n_expert_used = weights->ne[1];
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    }
 }
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index a3391a4c..9dc72999 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -865,10 +865,6 @@ llm_expert_gating_func_type gating_op,
         cb(weights, "ffn_moe_weights_softmax", il);
     }
 
-    if (graph) {
-        ggml_build_forward_expand(graph, weights);
-    }
-
     if (norm_w) {
         weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
 
@@ -890,6 +886,10 @@ llm_expert_gating_func_type gating_op,
         cb(weights, "ffn_moe_weights_scaled", il);
     }
 
+    if (graph) {
+        ggml_build_forward_expand(graph, weights);
+    }
+
     cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
 
     if (weight_before_ffn) {
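
For context, the node chain that the extended check in ggml-cuda.cu now matches is the optional weight normalization that the norm_w branch in llama-build-context.cpp emits after the softmax/top-k/get_rows selection; moving ggml_build_forward_expand(graph, weights) below that branch presumably keeps the chain contiguous in the graph so the backend can fuse it. A minimal sketch of that graph against the public ggml API (the wrapper function and tensor names are illustrative, not taken from the patch):

    // Sketch only: the softmax -> top-k -> get_rows selection plus the
    // normalization tail (RESHAPE -> SUM_ROWS -> DIV) that the fused CUDA
    // top-k MoE kernel now also covers. Node offsets mirror the checks above.
    #include "ggml.h"

    static struct ggml_tensor * moe_topk_weights_normalized(struct ggml_context * ctx,
                                                            struct ggml_tensor  * logits, // [n_expert, n_tokens]
                                                            int n_expert, int n_expert_used, int n_tokens) {
        struct ggml_tensor * probs    = ggml_soft_max(ctx, logits);            // node i   : GGML_OP_SOFT_MAX
        struct ggml_tensor * selected = ggml_top_k(ctx, probs, n_expert_used); // lowered to argsort + view
        struct ggml_tensor * weights  = ggml_get_rows(ctx,
                ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected); // node i+4 : GGML_OP_GET_ROWS

        // the tail handled by the new normalize=true kernel path:
        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);      // node i+5 : GGML_OP_RESHAPE
        struct ggml_tensor * sum = ggml_sum_rows(ctx, weights);                // node i+6 : GGML_OP_SUM_ROWS
        weights = ggml_div(ctx, weights, sum);                                 // node i+7 : GGML_OP_DIV
        return weights;
    }

When the tail is present, ggml_cuda_op_topk_moe receives the GGML_OP_DIV node as the weights tensor, whose ne[0] is n_expert_used; without it, the GET_ROWS output is passed and n_expert_used sits in ne[1], which is why the dispatch reads the dimension differently in the two branches.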