mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-06 03:50:08 +00:00
On CUDA, also fuse MoE down * (up * unary(gate))
when the MUL_MAT_ID op for the down experts is the next op in the graph.
This commit is contained in:
@@ -2195,7 +2195,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
|
||||
const ggml_tensor * src0_1 = dst->src[0];
|
||||
const ggml_tensor * src0_2 = dst->src[1];
|
||||
const ggml_tensor * src0 = src0_1;
|
||||
@@ -2221,6 +2221,8 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
ggml_tensor src0_2_row = *src0_2;
|
||||
ggml_tensor src1_row = *src1;
|
||||
ggml_tensor dst_row = *dst;
|
||||
ggml_tensor final_dst;
|
||||
ggml_tensor final_src;
|
||||
|
||||
char * src0_1_original = (char *) src0_1->data;
|
||||
char * src0_2_original = (char *) src0_2->data;
|
||||
@@ -2246,9 +2248,27 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
dst_row.nb[2] = nb1;
|
||||
dst_row.nb[3] = nb1;
|
||||
|
||||
bool fuse_down = false;
|
||||
if (next && next->op == GGML_OP_MUL_MAT_ID) {
|
||||
//printf("Fusing MoE down gemm\n");
|
||||
fuse_down = true;
|
||||
final_dst = *next;
|
||||
final_dst.ne[1] = final_dst.ne[2] = final_dst.ne[3] = 1;
|
||||
final_dst.nb[2] = final_dst.nb[3] = final_dst.nb[1];
|
||||
final_src = *next->src[0];
|
||||
//printf("next->src[0]: %s, %d x %d x %d x %d and %d x %d x %d x %d\n", ggml_type_name(next->src[0]->type),
|
||||
// (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3],
|
||||
// (int)next->src[0]->nb[0], (int)next->src[0]->nb[1], (int)next->src[0]->nb[2], (int)next->src[0]->nb[3]);
|
||||
final_src.ne[2] = final_src.ne[3] = 1;
|
||||
final_src.nb[3] = final_src.nb[2];
|
||||
}
|
||||
|
||||
if (ne12 == 1) {
|
||||
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
|
||||
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
|
||||
if (fuse_down) {
|
||||
final_dst.src[1] = &dst_row;
|
||||
}
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
@@ -2274,18 +2294,39 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
if (fuse_down) {
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
|
||||
final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2];
|
||||
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
} else {
|
||||
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
|
||||
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
|
||||
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
|
||||
ggml_cuda_pool_alloc<char> final_dst_contiguous(ctx.pool());
|
||||
if (fuse_down) {
|
||||
final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next));
|
||||
final_dst.src[1] = &dst_row;
|
||||
}
|
||||
|
||||
src1_row.data = src1_contiguous.get();
|
||||
|
||||
bool first = false; //true;
|
||||
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
int64_t num_src1_rows = 0;
|
||||
|
||||
@@ -2351,7 +2392,39 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
{
|
||||
if (fuse_down) {
|
||||
|
||||
final_dst.ne[1] = num_src1_rows;
|
||||
final_dst.nb[1] = final_dst.ne[0]*sizeof(float);
|
||||
final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1];
|
||||
final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
|
||||
if (first) {
|
||||
printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows,
|
||||
(int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3],
|
||||
(int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3],
|
||||
(int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]);
|
||||
printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n",
|
||||
(int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3],
|
||||
(int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3],
|
||||
(int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]);
|
||||
first = false;
|
||||
}
|
||||
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
|
||||
//ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
dim3 block_dims(std::min((unsigned int)next->ne[0], 768u));
|
||||
dim3 grid_dims(num_src1_rows);
|
||||
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
|
||||
(char *)next->data, final_dst_contiguous.get(),
|
||||
dev_row_mapping.get(),
|
||||
next->ne[0],
|
||||
next->nb[1], next->nb[2]);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
dim3 block_dims(std::min((unsigned int)ne0, 768u));
|
||||
dim3 grid_dims(num_src1_rows);
|
||||
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
|
||||
@@ -2363,9 +2436,11 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fuse_down;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
||||
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) {
|
||||
// why is this here instead of mul_mat?
|
||||
if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
|
||||
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
|
||||
@@ -2480,7 +2555,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_mul_mat_id(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MOE_FUSED_UP_GATE:
|
||||
ggml_cuda_up_gate_unary(ctx, dst);
|
||||
skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);
|
||||
break;
|
||||
case GGML_OP_SCALE:
|
||||
ggml_cuda_op_scale(ctx, dst);
|
||||
@@ -2839,6 +2914,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
||||
if (!use_cuda_graph || cuda_graph_update_required) {
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr;
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
continue;
|
||||
@@ -2853,11 +2929,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
||||
}
|
||||
#endif
|
||||
|
||||
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
||||
bool skip_next = false;
|
||||
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
|
||||
if (!ok) {
|
||||
GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||
}
|
||||
GGML_ASSERT(ok);
|
||||
if (skip_next) ++i;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user