Macro to easily enable/disable fusion

This commit is contained in:
Iwan Kawrakow
2025-10-22 07:51:53 +03:00
parent c291fc056c
commit 4bc2360f76
2 changed files with 21 additions and 17 deletions

View File

@@ -3100,6 +3100,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
}
#define ENABLE_FUSION false
#if IK_PRINT_TIMING
int64_t tim1 = ggml_time_us();
#endif
@@ -3129,7 +3131,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_dup(ctx, dst);
break;
case GGML_OP_ADD:
if (i + 2 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_ADD &&
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
ggml_is_contiguous(dst->src[0]) &&
@@ -3143,7 +3145,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
}
else if (false && i + 1 < cgraph->n_nodes &&
else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
ggml_is_contiguous(dst->src[0]) &&
ggml_is_contiguous(dst->src[1]) &&
@@ -3197,7 +3199,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_relu(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
if (i + 5 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
@@ -3206,7 +3208,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
i += 5;
}
else if (i + 4 < cgraph->n_nodes &&
else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
@@ -3323,7 +3325,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_diag_mask_inf(ctx, dst);
break;
case GGML_OP_SOFT_MAX:
if (i + 4 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
@@ -3357,7 +3359,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_pool2d(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
if (i + 1 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_DIV &&
cgraph->nodes[i+1]->src[1] == dst &&
cgraph->nodes[i+1]->src[0] == dst->src[0]) {
@@ -3368,7 +3370,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
}
break;
case GGML_OP_ARGSORT:
if (i + 5 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
@@ -3404,6 +3406,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1));
#endif
#undef ENABLE_FUSION
return true;
}

View File

@@ -59,7 +59,7 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n
// int min_experts, float thresh_experts) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -101,7 +101,7 @@ static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -129,7 +129,7 @@ static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -158,7 +158,7 @@ static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, in
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -204,7 +204,7 @@ template<ggml_sort_order order>
static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -275,7 +275,7 @@ static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, co
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
@@ -306,7 +306,7 @@ static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, fl
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
@@ -329,7 +329,7 @@ static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias,
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float));
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
@@ -352,7 +352,7 @@ static void argsort_openai_f32_f32_i32_cuda(const float * x, float * weights, in
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = (ncols_pad + ncols_pad > WARP_SIZE ? WARP_SIZE : 0) * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
@@ -415,7 +415,7 @@ static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float *
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = (ncols_pad + WARP_SIZE) * sizeof(int);
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);