CUDA: also fuse the MoE down projection, down * (up * unary(gate))

Perform the down GEMM inside the fused up/gate/unary op whenever the
MUL_MAT_ID op for the down experts is the next op in the graph.
commit c229183737
parent 001abccf73
Author: Iwan Kawrakow
Date:   2025-02-23 09:47:01 +02:00

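For orientation, the computation being fused is the gated MoE FFN applied per selected expert. A minimal CPU reference of one expert's forward pass, assuming SiLU as the unary activation (the actual activation is taken from dst->op_params[0] in the code below); names and shapes here are illustrative, not the ggml API:

#include <cmath>
#include <vector>

// One expert of a gated MoE FFN: y = W_down * (silu(W_gate * x) ⊙ (W_up * x)).
// Before this commit the CUDA path fused only up, gate and the unary op;
// the down multiplication ran as a separate MUL_MAT_ID node.
static std::vector<float> expert_ffn(
        const std::vector<float> & W_up,   // n_ff   x n_embd, row-major
        const std::vector<float> & W_gate, // n_ff   x n_embd, row-major
        const std::vector<float> & W_down, // n_embd x n_ff,   row-major
        const std::vector<float> & x,      // n_embd
        int n_embd, int n_ff) {
    auto matvec = [](const std::vector<float> & W, const std::vector<float> & v,
                     int rows, int cols) {
        std::vector<float> out(rows, 0.0f);
        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < cols; ++c)
                out[r] += W[r*cols + c]*v[c];
        return out;
    };
    std::vector<float> up   = matvec(W_up,   x, n_ff, n_embd);
    std::vector<float> gate = matvec(W_gate, x, n_ff, n_embd);
    std::vector<float> act(n_ff);
    for (int i = 0; i < n_ff; ++i) {
        act[i] = gate[i]/(1.0f + std::exp(-gate[i])) * up[i]; // silu(gate) * up
    }
    return matvec(W_down, act, n_embd, n_ff);
}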

@@ -2195,7 +2195,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     }
 }
 
-static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
     const ggml_tensor * src0_1 = dst->src[0];
     const ggml_tensor * src0_2 = dst->src[1];
     const ggml_tensor * src0 = src0_1;
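A sketch of the intended contract of the widened signature (the comment wording is mine, inferred from how the return value is used further down):

// Returns true when `next` (a fusable MUL_MAT_ID, i.e. the MoE down
// projection) was also executed as part of this call; the caller must then
// skip `next` when walking the graph. Returns false when `next` is null or
// not fusable, in which case only up * unary(gate) is computed, as before.
static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next);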
@@ -2221,6 +2221,8 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
     ggml_tensor src0_2_row = *src0_2;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
+    ggml_tensor final_dst;
+    ggml_tensor final_src;
 
     char * src0_1_original = (char *) src0_1->data;
     char * src0_2_original = (char *) src0_2->data;
@@ -2246,9 +2248,27 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
     dst_row.nb[2] = nb1;
     dst_row.nb[3] = nb1;
 
+    bool fuse_down = false;
+    if (next && next->op == GGML_OP_MUL_MAT_ID) {
+        //printf("Fusing MoE down gemm\n");
+        fuse_down = true;
+        final_dst = *next;
+        final_dst.ne[1] = final_dst.ne[2] = final_dst.ne[3] = 1;
+        final_dst.nb[2] = final_dst.nb[3] = final_dst.nb[1];
+        final_src = *next->src[0];
+        //printf("next->src[0]: %s, %d x %d x %d x %d and %d x %d x %d x %d\n", ggml_type_name(next->src[0]->type),
+        //        (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3],
+        //        (int)next->src[0]->nb[0], (int)next->src[0]->nb[1], (int)next->src[0]->nb[2], (int)next->src[0]->nb[3]);
+        final_src.ne[2] = final_src.ne[3] = 1;
+        final_src.nb[3] = final_src.nb[2];
+    }
+
     if (ne12 == 1) {
         ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
         ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
+        if (fuse_down) {
+            final_dst.src[1] = &dst_row;
+        }
 
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
             for (int64_t id = 0; id < n_ids; id++) {
                 const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
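final_dst and final_src are shallow copies of the next op's tensors, cut down to single-matrix views by collapsing the higher dimensions and repairing the byte strides; their data pointers are set later, per selected expert, inside the loops. The same trick in isolation, on a hypothetical simplified tensor struct rather than ggml_tensor:

#include <cstddef>
#include <cstdint>

struct view4d {
    int64_t ne[4]; // elements per dimension
    size_t  nb[4]; // byte stride per dimension
    char *  data;
};

// Collapse a per-expert weight tensor (one matrix per expert along ne[2])
// to the matrix of a single expert: ne[2] and ne[3] become 1 and nb[3] is
// pinned to nb[2], so stride arithmetic stays consistent for code that
// walks all four dimensions.
static view4d expert_matrix_view(const view4d & t, int64_t expert) {
    view4d v = t;
    v.ne[2] = v.ne[3] = 1;
    v.nb[3] = v.nb[2];
    v.data  = t.data + expert*t.nb[2];
    return v;
}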
@@ -2274,18 +2294,39 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
                 ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
                 CUDA_CHECK(cudaGetLastError());
 
-                ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
-                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
-                CUDA_CHECK(cudaGetLastError());
+                if (fuse_down) {
+                    ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
+                            (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
+                    CUDA_CHECK(cudaGetLastError());
+                    final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
+                    final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2];
+                    ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
+                    CUDA_CHECK(cudaGetLastError());
+                } else {
+                    ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
+                            (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
+                    CUDA_CHECK(cudaGetLastError());
+                }
             }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
         ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+        ggml_cuda_pool_alloc<char> final_dst_contiguous(ctx.pool());
+        if (fuse_down) {
+            final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next));
+            final_dst.src[1] = &dst_row;
+        }
 
         src1_row.data = src1_contiguous.get();
+        bool first = false; //true;
 
         for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
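In the ne12 == 1 path the fused activation is now written back into dst_gate_contiguous (which dst_row aliases) instead of this op's own output, and the down GEMM writes straight into next->data. Per (token, expert-slot) iteration the fused path amounts to the following sketch; the helper notation stands in for the ggml CUDA calls, it is not real API:

// i02  = ids[token][slot];          // selected expert
// up   = W_up[i02]   * x;           // -> dst_up_contiguous
// gate = W_gate[i02] * x;           // -> dst_gate_contiguous
// act  = unary(gate) ⊙ up;          // in place, into dst_gate_contiguous
// y    = W_down[i02] * act;         // straight into next->data
//
// The unfused (else) branch instead stores `act` to this op's dst tensor and
// leaves y to the separate MUL_MAT_ID node, exactly as before this commit.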
@@ -2351,7 +2392,39 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
                     (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
             CUDA_CHECK(cudaGetLastError());
 
-            {
+            if (fuse_down) {
+                final_dst.ne[1] = num_src1_rows;
+                final_dst.nb[1] = final_dst.ne[0]*sizeof(float);
+                final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1];
+                final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
+                if (first) {
+                    printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows,
+                            (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3],
+                            (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3],
+                            (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]);
+                    printf("    using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n",
+                            (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3],
+                            (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3],
+                            (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]);
+                    first = false;
+                }
+                ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
+                //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst);
+                CUDA_CHECK(cudaGetLastError());
+
+                dim3 block_dims(std::min((unsigned int)next->ne[0], 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        (char *)next->data, final_dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        next->ne[0],
+                        next->nb[1], next->nb[2]);
+                CUDA_CHECK(cudaGetLastError());
+            }
+            else {
                 dim3 block_dims(std::min((unsigned int)ne0, 768u));
                 dim3 grid_dims(num_src1_rows);
                 k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
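(The else branch, whose remaining arguments fall outside this hunk, launches the same copy for the unfused output, using ne0 and the op's own dst.) Because the batched path gathers token rows per expert, the fused down result in final_dst_contiguous is in expert order and must be scattered back to token order in next->data through the row mapping. A simplified sketch of a kernel in the spirit of k_copy_dst_from_contiguous; the mapping struct and its field names are assumptions:

#include <cstdint>

// Each gathered row remembers where it came from; field names assumed.
struct row_map { int32_t i1; int32_t i2; };

// Block blockIdx.x copies one contiguous row of the fused result back to its
// original (token, slot) position, with the block's threads striding over the
// columns -- the same pattern the kernel launches above rely on.
static __global__ void scatter_rows(char * __restrict__ dst, const char * __restrict__ src,
                                    const row_map * __restrict__ mapping,
                                    int64_t ne0, int64_t nb1, int64_t nb2) {
    const int64_t row = blockIdx.x;
    const float * s = (const float *)(src + row*ne0*(int64_t)sizeof(float));
    float       * d = (float       *)(dst + mapping[row].i1*nb1 + mapping[row].i2*nb2);
    for (int64_t j = threadIdx.x; j < ne0; j += blockDim.x) {
        d[j] = s[j];
    }
}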
@@ -2363,9 +2436,11 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             }
         }
     }
+
+    return fuse_down;
 }
 
-static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
@@ -2480,7 +2555,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_mul_mat_id(ctx, dst);
             break;
         case GGML_OP_MOE_FUSED_UP_GATE:
-            ggml_cuda_up_gate_unary(ctx, dst);
+            skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);
            break;
         case GGML_OP_SCALE:
             ggml_cuda_op_scale(ctx, dst);
@@ -2839,6 +2914,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
     if (!use_cuda_graph || cuda_graph_update_required) {
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
+            ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr;
 
             if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
                 continue;
@@ -2853,11 +2929,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             }
 #endif
 
-            bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+            bool skip_next = false;
+            bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
             if (!ok) {
                 GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
             }
             GGML_ASSERT(ok);
+            if (skip_next) ++i;
         }
     }
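Net effect on the graph runner: it peeks one node ahead, hands the successor to the op handler, and advances past it when the handler reports that the fused path already produced its result. A condensed, self-contained sketch of that loop shape (stub types, not the real backend):

#include <cassert>
#include <cstddef>
#include <vector>

struct node_t { }; // stand-in for ggml_tensor

// Stand-in for ggml_cuda_compute_forward: returns false for unsupported ops
// and sets skip_next when it fused `next` into the current node's execution.
static bool compute_forward(node_t * /*node*/, node_t * /*next*/, bool & skip_next) {
    skip_next = false; // the real backend sets this from the fused MoE path
    return true;
}

static void run_graph(std::vector<node_t *> & nodes) {
    for (size_t i = 0; i < nodes.size(); ++i) {
        node_t * next = i + 1 < nodes.size() ? nodes[i + 1] : nullptr;
        bool skip_next = false;
        const bool ok = compute_forward(nodes[i], next, skip_next);
        assert(ok);
        if (skip_next) {
            ++i; // the successor's result was already produced by the fusion
        }
    }
}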