Also use ggml_cuda_mul_mat_q_id in the fused up+gate op

This commit is contained in:
Iwan Kawrakow
2025-11-08 12:48:01 +02:00
parent 675f36787d
commit 8f0fad3109

View File

@@ -2519,9 +2519,6 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
src1_padded_num_cols, src0->type, stream);
src1_row.nb[1] = src1_padded_row_size;
src1_row.nb[2] = src1_row.nb[3] = src1_row.nb[1]*num_src1_rows;
//ggml_cuda_op_mul_mat_q(ctx, &src0_row, &src1_row, &dst_row, (const char *)src0_row.data, nullptr,
// src1_quantized.get(), (float *)dst_row.data,
// 0, src0_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
ggml_cuda_mul_mat_q_id(ctx, &src0_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get());
CUDA_CHECK(cudaGetLastError());
@@ -2797,15 +2794,11 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
bool fuse_down = false;
if (next && next->op == GGML_OP_MUL_MAT_ID) {
//printf("Fusing MoE down gemm\n");
fuse_down = true;
final_dst = *next;
final_dst.ne[1] = final_dst.ne[2] = final_dst.ne[3] = 1;
final_dst.nb[2] = final_dst.nb[3] = final_dst.nb[1];
final_src = *next->src[0];
//printf("next->src[0]: %s, %d x %d x %d x %d and %d x %d x %d x %d\n", ggml_type_name(next->src[0]->type),
// (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3],
// (int)next->src[0]->nb[0], (int)next->src[0]->nb[1], (int)next->src[0]->nb[2], (int)next->src[0]->nb[3]);
final_src.ne[2] = final_src.ne[3] = 1;
final_src.nb[3] = final_src.nb[2];
}
@@ -2836,8 +2829,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
src1_row.data = src1_contiguous.get();
bool first = false; //true;
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
std::vector<int> moe_counts, cum_moe_counts;
@@ -2889,8 +2880,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
dst_row.data = dst_up_contiguous.get();
if (use_quantized_src1) {
ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
ggml_cuda_mul_mat_q_id(ctx, &src0_1_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get());
} else {
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row, nullptr, 0);
}
@@ -2906,8 +2896,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
dst_row.data = dst_gate_contiguous.get();
if (use_quantized_src1) {
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
ggml_cuda_mul_mat_q_id(ctx, &src0_2_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get());
} else {
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0);
}
@@ -2939,18 +2928,12 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
final_dst.nb[1] = final_dst.ne[0]*sizeof(float);
final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1];
final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
if (first) {
printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows,
(int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3],
(int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3],
(int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]);
printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n",
(int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3],
(int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3],
(int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]);
first = false;
if (ggml_is_quantized(next->src[0]->type) &&
ggml_cuda_should_use_mmq(final_src.type, ggml_cuda_info().devices[ctx.device].cc, dst_row.ne[1])) {
ggml_cuda_mul_mat_q_id(ctx, &final_src, &dst_row, nullptr, &final_dst, nullptr, nullptr);
} else {
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst, nullptr, 0);
}
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst, nullptr, 0);
CUDA_CHECK(cudaGetLastError());
dim3 block_dims(std::min((unsigned int)next->ne[0], 768u));