From 4bc2360f760496b8a171fa521f00aee272a4ac27 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 22 Oct 2025 07:51:53 +0300 Subject: [PATCH] Macro to easily enable/disable fusion --- ggml/src/ggml-cuda.cu | 18 +++++++++++------- ggml/src/ggml-cuda/argsort.cu | 20 ++++++++++---------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index a8c1e894..0ec4abe4 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3100,6 +3100,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); } +#define ENABLE_FUSION false + #if IK_PRINT_TIMING int64_t tim1 = ggml_time_us(); #endif @@ -3129,7 +3131,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_dup(ctx, dst); break; case GGML_OP_ADD: - if (i + 2 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 2 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_ADD && cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM && ggml_is_contiguous(dst->src[0]) && @@ -3143,7 +3145,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]); i += 2; } - else if (false && i + 1 < cgraph->n_nodes && + else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM && ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && @@ -3197,7 +3199,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_relu(ctx, dst); break; case GGML_UNARY_OP_SIGMOID: - if (i + 5 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 5 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_ARGSORT && @@ -3206,7 +3208,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]); i += 5; } - else if (i + 4 < cgraph->n_nodes && + else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && @@ -3323,7 +3325,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_diag_mask_inf(ctx, dst); break; case GGML_OP_SOFT_MAX: - if (i + 4 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ARGSORT && cgraph->nodes[i+3]->op == GGML_OP_VIEW && @@ -3357,7 +3359,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_pool2d(ctx, dst); break; case GGML_OP_SUM_ROWS: - if (i + 1 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 1 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_DIV && cgraph->nodes[i+1]->src[1] == dst && cgraph->nodes[i+1]->src[0] == dst->src[0]) { @@ -3368,7 +3370,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg } break; case GGML_OP_ARGSORT: - if (i + 5 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 5 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_VIEW && cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS && cgraph->nodes[i+3]->op == GGML_OP_RESHAPE && @@ -3404,6 +3406,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1)); #endif +#undef ENABLE_FUSION + return true; } diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 14d5d93b..e9672100 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -59,7 +59,7 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n // int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -101,7 +101,7 @@ static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float size_t nb_ids) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -129,7 +129,7 @@ static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float size_t nb_ids) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -158,7 +158,7 @@ static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, in size_t nb_ids) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -204,7 +204,7 @@ template static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -275,7 +275,7 @@ static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, co const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = ncols_pad * sizeof(int); // FIXME: this limit could be raised by ~2-4x on Ampere or newer @@ -306,7 +306,7 @@ static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, fl const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = ncols_pad * sizeof(int); // FIXME: this limit could be raised by ~2-4x on Ampere or newer @@ -329,7 +329,7 @@ static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias, const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float)); // FIXME: this limit could be raised by ~2-4x on Ampere or newer @@ -352,7 +352,7 @@ static void argsort_openai_f32_f32_i32_cuda(const float * x, float * weights, in const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = (ncols_pad + ncols_pad > WARP_SIZE ? WARP_SIZE : 0) * sizeof(int); // FIXME: this limit could be raised by ~2-4x on Ampere or newer @@ -415,7 +415,7 @@ static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = (ncols_pad + WARP_SIZE) * sizeof(int); GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);