Fusing MoE up * unary(gate): CUDA

We get ~13% speedup for PP-512 and ~2% for TG-128
for DeepSeek-Lite
This commit is contained in:
Iwan Kawrakow
2025-02-22 19:01:18 +02:00
parent da64673670
commit 001abccf73
3 changed files with 207 additions and 12 deletions

View File

@@ -2195,6 +2195,176 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0_1 = dst->src[0];
const ggml_tensor * src0_2 = dst->src[1];
const ggml_tensor * src0 = src0_1;
const ggml_tensor * src1 = dst->src[2];
const ggml_tensor * ids = dst->src[3];
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers");
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_2->buffer) && "mul_mat_id does not support split buffers");
cudaStream_t stream = ctx.stream();
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_1_row = *src0_1;
ggml_tensor src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
char * src0_1_original = (char *) src0_1->data;
char * src0_2_original = (char *) src0_2->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;
src0_1_row.ne[2] = 1;
src0_1_row.ne[3] = 1;
src0_1_row.nb[3] = nb02;
src0_2_row.ne[2] = 1;
src0_2_row.ne[3] = 1;
src0_2_row.nb[3] = nb02;
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
if (ne12 == 1) {
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
const int64_t i1 = id;
const int64_t i2 = i12;
src0_1_row.data = src0_1_original + i02*nb02;
src0_2_row.data = src0_2_original + i02*nb02;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
//dst_row.data = dst_original + i1*nb1 + i2*nb2;
dst_row.data = dst_up_contiguous.get();
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
CUDA_CHECK(cudaGetLastError());
dst_row.data = dst_gate_contiguous.get();
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
CUDA_CHECK(cudaGetLastError());
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
CUDA_CHECK(cudaGetLastError());
}
}
} else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
src1_row.data = src1_contiguous.get();
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i != i02) {
continue;
}
num_src1_rows++;
}
}
if (num_src1_rows == 0) {
continue;
}
ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
{
dim3 block_dims(std::min((unsigned int)ne10, 768u));
dim3 grid_dims(ids->ne[1], n_ids);
k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
src1_original, src1_contiguous.get(),
dev_cur_src1_row.get(), dev_row_mapping.get(),
ids_dev, i02, ids->nb[1], ids->nb[0],
ne11, ne10,
nb11, nb12);
CUDA_CHECK(cudaGetLastError());
}
src0_1_row.data = src0_1_original + i02*nb02;
src0_2_row.data = src0_2_original + i02*nb02;
GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;
dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
dst_row.data = dst_up_contiguous.get();
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
CUDA_CHECK(cudaGetLastError());
dst_row.data = dst_gate_contiguous.get();
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
CUDA_CHECK(cudaGetLastError());
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
CUDA_CHECK(cudaGetLastError());
{
dim3 block_dims(std::min((unsigned int)ne0, 768u));
dim3 grid_dims(num_src1_rows);
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
dst_original, dst_gate_contiguous.get(),
dev_row_mapping.get(),
ne0,
nb1, nb2);
CUDA_CHECK(cudaGetLastError());
}
}
}
}
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
// why is this here instead of mul_mat?
if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
@@ -2309,6 +2479,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_MUL_MAT_ID:
ggml_cuda_mul_mat_id(ctx, dst);
break;
case GGML_OP_MOE_FUSED_UP_GATE:
ggml_cuda_up_gate_unary(ctx, dst);
break;
case GGML_OP_SCALE:
ggml_cuda_op_scale(ctx, dst);
break;
@@ -2595,7 +2768,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID) {
if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_MOE_FUSED_UP_GATE) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
@@ -2809,9 +2982,13 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]);
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_MOE_FUSED_UP_GATE:
{
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
struct ggml_tensor * b = op->op == GGML_OP_MOE_FUSED_UP_GATE ? op->src[2] : op->src[1];
if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
return false;
}
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}

View File

@@ -297,6 +297,19 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream);
}
void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
int64_t nelements, const float * src0_d, const float * src1_d, float * dst_d) {
cudaStream_t stream = ctx.stream();
switch (op) {
case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break;
case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break;
case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break;
default: GGML_ASSERT(false);
}
}
void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -304,19 +317,22 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor *
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
cudaStream_t stream = ctx.stream();
ggml_unary_op op = (ggml_unary_op)dst->op_params[0];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data);
switch (op) {
case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
default: GGML_ASSERT(false);
}
//cudaStream_t stream = ctx.stream();
//const float * src0_d = (const float *)src0->data;
//const float * src1_d = (const float *)src1->data;
//float * dst_d = (float *)dst->data;
//switch (op) {
// case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
// case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
// case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
// default: GGML_ASSERT(false);
//}
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@@ -36,5 +36,7 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
int64_t nelements, const float * x, const float * y, float * z);
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);