From 8f0fad3109bb04575fe2467ed064d022de16d3cc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 8 Nov 2025 12:48:01 +0200 Subject: [PATCH] Also use it in the fused up+gate op --- ggml/src/ggml-cuda.cu | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 4bb46ff2..42a76b63 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2519,9 +2519,6 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * src1_padded_num_cols, src0->type, stream); src1_row.nb[1] = src1_padded_row_size; src1_row.nb[2] = src1_row.nb[3] = src1_row.nb[1]*num_src1_rows; - //ggml_cuda_op_mul_mat_q(ctx, &src0_row, &src1_row, &dst_row, (const char *)src0_row.data, nullptr, - // src1_quantized.get(), (float *)dst_row.data, - // 0, src0_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); ggml_cuda_mul_mat_q_id(ctx, &src0_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get()); CUDA_CHECK(cudaGetLastError()); @@ -2797,15 +2794,11 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten bool fuse_down = false; if (next && next->op == GGML_OP_MUL_MAT_ID) { - //printf("Fusing MoE down gemm\n"); fuse_down = true; final_dst = *next; final_dst.ne[1] = final_dst.ne[2] = final_dst.ne[3] = 1; final_dst.nb[2] = final_dst.nb[3] = final_dst.nb[1]; final_src = *next->src[0]; - //printf("next->src[0]: %s, %d x %d x %d x %d and %d x %d x %d x %d\n", ggml_type_name(next->src[0]->type), - // (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], - // (int)next->src[0]->nb[0], (int)next->src[0]->nb[1], (int)next->src[0]->nb[2], (int)next->src[0]->nb[3]); final_src.ne[2] = final_src.ne[3] = 1; final_src.nb[3] = final_src.nb[2]; } @@ -2836,8 +2829,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten src1_row.data = src1_contiguous.get(); - bool first = false; //true; - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); std::vector moe_counts, cum_moe_counts; @@ -2889,8 +2880,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten dst_row.data = dst_up_contiguous.get(); if (use_quantized_src1) { - ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, - 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + ggml_cuda_mul_mat_q_id(ctx, &src0_1_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get()); } else { ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row, nullptr, 0); } @@ -2906,8 +2896,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten dst_row.data = dst_gate_contiguous.get(); if (use_quantized_src1) { - ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, - 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + ggml_cuda_mul_mat_q_id(ctx, &src0_2_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get()); } else { ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0); } @@ -2939,18 +2928,12 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten final_dst.nb[1] = final_dst.ne[0]*sizeof(float); final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - if (first) { - printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, - (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], - (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], - (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); - printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", - (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], - (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], - (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); - first = false; + if (ggml_is_quantized(next->src[0]->type) && + ggml_cuda_should_use_mmq(final_src.type, ggml_cuda_info().devices[ctx.device].cc, dst_row.ne[1])) { + ggml_cuda_mul_mat_q_id(ctx, &final_src, &dst_row, nullptr, &final_dst, nullptr, nullptr); + } else { + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst, nullptr, 0); } - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst, nullptr, 0); CUDA_CHECK(cudaGetLastError()); dim3 block_dims(std::min((unsigned int)next->ne[0], 768u));