Merge ffn_up and ffn_gate experts tensors (#1137)

* WIP - not working

* WIP - not working

* WIP - GPT-OSS working

However, it is extremely stupid: the only way I could correctly repack the
up/gate experts is to copy up and gate into host buffers, repack them
into another host buffer, and copy the result back into the
ffn_up_gate_exps tensor. This is going to be very slow for giant 500 GB models.

My attempts to do this via a compute graph on the backend holding
the tensors were unsuccessful.
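
For the record, a minimal host-side sketch of the repack, assuming the
layout the fused path expects (per expert, the up rows followed by the
gate rows); repack_up_gate and its arguments are illustrative names,
not the actual helpers:

#include <stdint.h>
#include <string.h>

// Hypothetical sketch: interleave n_expert slices of up and gate rows
// into one merged buffer laid out as [up rows | gate rows] per expert.
static void repack_up_gate(const uint8_t * up, const uint8_t * gate, uint8_t * merged,
                           int64_t n_expert, int64_t n_rows, size_t row_size) {
    const size_t slice = (size_t)n_rows * row_size; // one expert's worth of rows
    for (int64_t e = 0; e < n_expert; ++e) {
        memcpy(merged + 2*e*slice,         up   + e*slice, slice); // up rows first
        memcpy(merged + 2*e*slice + slice, gate + e*slice, slice); // gate rows second
    }
}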

For GPT-OSS-20B I see ~6-7% better PP when using the original
ik_llama.cpp fused_up_gate CUDA implementation, and ~10% better when
using the small batch size implementation.

Other models are not working yet on CUDA as I need to fix the
fused mul-unary implementation.
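
For reference, what the merged fused mul-unary has to compute, matching
the kernels added below (each input row packs the up and gate results
side by side, ne0 being the output row size):

// x_row = x + 2*row*ne0  ->  up half in x_row[0..ne0), gate half in x_row[ne0..2*ne0)
// dst[row*ne0 + j] = x_row[j] * unary(x_row[j + ne0])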

* WIP

* WIP - Qwen3-MoE (and hopefully all others) working

But when I say "working" here and in the previous commit,
I mean that PP is working. TG is still broken.

* WIP: TG seems to be working

* Minor

* Add command line option to merge experts up/gate

* Add merge up/gate command line parameter to llama-bench

* Turn off merge_up_gate_exps when split mode is graph

It is not yet implemented for that mode.

* When no bias, allow merging up/gate with tensor overrides

* Arghh, we need to increase the context size again

* Cleanup
Commit c03c2d7cc6 (parent bf0c6c57bb)
Author: Kawrakow
Date: 2026-01-12 18:30:53 +02:00
Committed by: GitHub
16 changed files with 505 additions and 134 deletions

View File

@@ -1,4 +1,5 @@
//
// Copyright (C) 2023-2024 The ggml authors
// Copyright (C) 2024 Iwan Kawrakow
// MIT license
@@ -2487,23 +2488,21 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
ggml_is_quantized(src0_1->type) &&
- ggml_is_quantized(src0_2->type) &&
+ (!src0_2 || ggml_is_quantized(src0_2->type)) &&
ggml_backend_buffer_is_cuda(src0_1->buffer) &&
- ggml_backend_buffer_is_cuda(src0_2->buffer) &&
+ (!src0_2 || ggml_backend_buffer_is_cuda(src0_2->buffer)) &&
ggml_backend_buffer_is_cuda(src1->buffer) &&
ggml_backend_buffer_is_cuda(dst->buffer) &&
src1->type == GGML_TYPE_F32) {
int device_id = ctx.device;
ggml_backend_cuda_buffer_context * src0_1_ctx = (ggml_backend_cuda_buffer_context *) src0_1->buffer->context;
- ggml_backend_cuda_buffer_context * src0_2_ctx = (ggml_backend_cuda_buffer_context *) src0_2->buffer->context;
+ ggml_backend_cuda_buffer_context * src0_2_ctx = src0_2 ? (ggml_backend_cuda_buffer_context *) src0_2->buffer->context : nullptr;
ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
if (src0_1_ctx->device == device_id &&
- src0_2_ctx->device == device_id &&
+ (!src0_2_ctx || src0_2_ctx->device == device_id) &&
src1_ctx->device == device_id &&
dst_ctx->device == device_id) {
- //printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name,
- // src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]);
// Fast TG path
const int64_t n_ids = ids->ne[0];
auto stream = ctx.stream(device_id, 0);
@@ -2518,7 +2517,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
- if (ggml_is_quantized(src0_1->type) || ggml_is_quantized(src0_2->type)) {
+ if (ggml_is_quantized(src0_1->type) || (src0_2 && ggml_is_quantized(src0_2->type))) {
GGML_ASSERT(src1->ne[0] % QK8_1 == 0);
auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1;
local_src1.data = src1_quantized.alloc(src_1_ddq_size);
@@ -2538,10 +2537,36 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id;
auto unary_op = (ggml_unary_op)dst->op_params[0];
- ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
- dst->src[4], dst->src[5],
- (const char *)src0_1->data, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(),
- (float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, stream);
+ if (src0_2) {
+ ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
+ dst->src[4], dst->src[5],
+ (const char *)src0_1->data, src0_2 ? (const char *)src0_2->data : nullptr,
+ (const float *)src1->data, src1_quantized.get(),
+ (float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, stream);
+ } else {
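+ // Merged tensor: each expert slice stores the up rows in its first half and
+ // the gate rows in its second half, so build two half-height views of src0_1.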
+ auto local_src0_1 = *src0_1;
+ local_src0_1.ne[1] /= 2;
+ auto local_src0_2 = local_src0_1;
+ local_src0_2.data = (char *)local_src0_1.data + local_src0_1.ne[1]*local_src0_1.nb[1];
+ if (!dst->src[4]) {
+ ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst,
+ nullptr, nullptr,
+ (const char *)local_src0_1.data, (const char *)local_src0_2.data,
+ (const float *)src1->data, src1_quantized.get(),
+ (float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, stream);
+ } else {
+ GGML_ASSERT(!dst->src[5]);
+ auto local_bias_1 = *dst->src[4];
+ local_bias_1.ne[0] /= 2;
+ auto local_bias_2 = local_bias_1;
+ local_bias_2.data = (char *)local_bias_1.data + local_bias_1.ne[0]*local_bias_1.nb[0];
+ ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst,
+ &local_bias_1, &local_bias_2,
+ (const char *)local_src0_1.data, (const char *)local_src0_2.data,
+ (const float *)src1->data, src1_quantized.get(),
+ (float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, stream);
+ }
+ }
CUDA_CHECK(cudaGetLastError());
if (!fuse_next) return i;
@@ -2608,7 +2633,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
// looks like it really depends just on the total number of experts.
// TODO: verify with more models, or perhaps make the magic constant '32' to be defined via a compile time define.
if (src1->ne[2] <= ctx.mmq_id_thresh*src0->ne[2] &&
- ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 &&
+ ggml_is_quantized(src0_1->type) && (!src0_2 || src0_1->type == src0_2->type) && src1->ne[1] == 1 && src1->ne[3] == 1 &&
ggml_cuda_can_use_mmq_id(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
const int64_t ne_get_rows = ne12 * n_ids;
@@ -2631,6 +2656,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1,
ne10_padded, ne11_flat, 1, 1, stream);
+ if (src0_2) {
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
@@ -2662,6 +2688,34 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst->data);
}
+ } else {
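+ // No separate gate tensor: a single mat-mul over the merged tensor yields rows
+ // of 2*ne0 floats (up results first, then gate), fused down to ne0 below.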
+ ggml_cuda_pool_alloc<char> dst_up_gate_contiguous(ctx.pool(), 2*sizeof(float)*ggml_nelements(dst));
+ ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+ dst_row.ne[0] *= 2;
+ dst_row.nb[1] *= 2;
+ dst_row.nb[2] *= 2;
+ dst_row.nb[3] *= 2;
+ dst_row.data = dst_up_gate_contiguous.get();
+ ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
+ if (dst->src[4]) {
+ GGML_ASSERT(!dst->src[5]);
+ ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data,
+ (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
+ dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
+ CUDA_CHECK(cudaGetLastError());
+ }
+ auto unary_op = (ggml_unary_op)dst->op_params[0];
+ if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
+ ggml_swiglu_oai_cuda_f32((const float *)dst_up_gate_contiguous.get() + dst->ne[0], (const float *)dst_up_gate_contiguous.get(),
+ (float *)dst->data, ggml_nelements(dst), dst->ne[0], src0_1->ne[1], src0_1->ne[1],
+ 1.702f, 7.0f, stream);
+ } else {
+ ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), dst->ne[0],
+ (const float *)dst_up_gate_contiguous.get(), (float *)dst->data);
+ }
+ }
CUDA_CHECK(cudaGetLastError());
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
@@ -2680,22 +2734,24 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_1_row = *src0_1;
- ggml_tensor src0_2_row = *src0_2;
+ ggml_tensor src0_2_row; if (src0_2) src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
ggml_tensor final_dst;
ggml_tensor final_src;
char * src0_1_original = (char *) src0_1->data;
- char * src0_2_original = (char *) src0_2->data;
+ char * src0_2_original = src0_2 ? (char *) src0_2->data : nullptr;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;
src0_1_row.ne[2] = 1;
src0_1_row.ne[3] = 1;
src0_1_row.nb[3] = nb02;
- src0_2_row.ne[2] = 1;
- src0_2_row.ne[3] = 1;
- src0_2_row.nb[3] = nb02;
+ if (src0_2) {
+ src0_2_row.ne[2] = 1;
+ src0_2_row.ne[3] = 1;
+ src0_2_row.nb[3] = nb02;
+ }
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
@@ -2723,7 +2779,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
bool use_quantized_src1 = false;
int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
- if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) {
+ if (ggml_is_quantized(src0_1->type) && (!src0_2 || src0_1->type == src0_2->type) && src1->ne[1] == 1 && src1->ne[3] == 1) {
if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1);
@@ -2736,8 +2792,14 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
if (!use_quantized_src1) {
src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1));
}
- ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
- ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+ ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool()), dst_gate_contiguous(ctx.pool());
+ if (src0_2) {
+ dst_up_contiguous.alloc(sizeof(float)*ggml_nelements(dst));
+ dst_gate_contiguous.alloc(sizeof(float)*ggml_nelements(dst));
+ } else {
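+ // Merged case: a single double-size buffer receives both halves of the
+ // mat-mul output; the second buffer holds the activated result.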
+ dst_up_contiguous.alloc(2*sizeof(float)*ggml_nelements(dst));
+ dst_gate_contiguous.alloc(sizeof(float)*ggml_nelements(dst));
+ }
ggml_cuda_pool_alloc<char> final_dst_contiguous(ctx.pool());
if (fuse_down) {
final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next));
@@ -2780,20 +2842,26 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
}
src0_1_row.data = src0_1_original + i02*nb02;
- src0_2_row.data = src0_2_original + i02*nb02;
+ if (src0_2_original) src0_2_row.data = src0_2_original + i02*nb02;
GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);
+ auto nb1l = nb1;
+ if (!src0_2) {
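+ // Merged case: output rows are twice as wide until the activation halves them.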
+ nb1l = nb1*2;
+ dst_row.ne[0] = dst->ne[0] * 2;
+ }
src1_row.ne[1] = num_src1_rows;
src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11;
src1_row.nb[2] = num_src1_rows*src1_row.nb[1];
src1_row.nb[3] = num_src1_rows*src1_row.nb[1];
dst_row.ne[1] = num_src1_rows;
- dst_row.nb[1] = nb1;
- dst_row.nb[2] = num_src1_rows*nb1;
- dst_row.nb[3] = num_src1_rows*nb1;
+ dst_row.nb[1] = nb1l;
+ dst_row.nb[2] = num_src1_rows*nb1l;
+ dst_row.nb[3] = num_src1_rows*nb1l;
dst_row.data = dst_up_contiguous.get();
if (use_quantized_src1) {
@@ -2804,6 +2872,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
CUDA_CHECK(cudaGetLastError());
if (dst->src[4]) {
+ GGML_ASSERT(dst_row.ne[0] == dst->src[4]->ne[0]);
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
dim3 grid_dims(num_src1_rows);
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
@@ -2811,31 +2880,46 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
CUDA_CHECK(cudaGetLastError());
}
dst_row.data = dst_gate_contiguous.get();
if (use_quantized_src1) {
ggml_cuda_mul_mat_q_id(ctx, &src0_2_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get());
} else {
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0);
}
CUDA_CHECK(cudaGetLastError());
if (dst->src[5]) {
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
dim3 grid_dims(num_src1_rows);
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
(const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data);
CUDA_CHECK(cudaGetLastError());
}
auto unary_op = (ggml_unary_op)dst->op_params[0];
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
if (src0_2) {
dst_row.data = dst_gate_contiguous.get();
if (use_quantized_src1) {
ggml_cuda_mul_mat_q_id(ctx, &src0_2_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get());
} else {
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0);
}
CUDA_CHECK(cudaGetLastError());
if (dst->src[5]) {
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
dim3 grid_dims(num_src1_rows);
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
(const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data);
CUDA_CHECK(cudaGetLastError());
}
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
}
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_up_contiguous.get() + dst->ne[0], (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row)/2, dst->ne[0], src0_1->ne[1], src0_1->ne[1],
1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row)/2, dst->ne[0],
(const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
}
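// The activation collapsed each row from 2*ne0 to ne0 floats; point dst_row
// at the activated buffer and halve its sizes back accordingly.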
dst_row.data = dst_gate_contiguous.get();
dst_row.ne[0] /= 2;
dst_row.nb[1] /= 2;
dst_row.nb[2] /= 2;
dst_row.nb[3] /= 2;
}
CUDA_CHECK(cudaGetLastError());
@@ -3603,7 +3687,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
auto src0_2 = node->src[1];
auto src1 = node->src[2];
if (src1->ne[1] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || src1->type != GGML_TYPE_F32 ||
- !ggml_is_quantized(src0_1->type) || !ggml_is_quantized(src0_2->type)) {
+ !ggml_is_quantized(src0_1->type) || (src0_2 && !ggml_is_quantized(src0_2->type))) {
use_cuda_graph = false;
} else {
if (i < cgraph->n_nodes-1) {
@@ -3967,8 +4051,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
bool is_fused_up_gate = op->op == GGML_OP_MOE_FUSED_UP_GATE || op->op == GGML_OP_FUSED_UP_GATE;
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = is_fused_up_gate ? op->src[2] : op->src[1];
- if (is_fused_up_gate && a->type != op->src[1]->type) {
- printf("%s: returning false for GGML_OP_MOE_FUSED_UP_GATE because src0->type != src1->type\n", __func__);
+ if (is_fused_up_gate && op->src[1] && a->type != op->src[1]->type) {
+ fprintf(stderr, "%s: returning false for GGML_OP_MOE_FUSED_UP_GATE because src0->type != src1->type\n", __func__);
return false;
}
//==================================================================

View File

@@ -61,6 +61,18 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
}
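+ // Single-tensor variants: x packs the up and gate results in one row of
+ // 2*ne0 floats, up in [0, ne0) and gate in [ne0, 2*ne0).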
+ static __global__ void fused_mul_silu_f32(const float * x, float * dst, const int k, const int ne0) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ int row = i / ne0;
+ int j = i % ne0;
+ auto x_row = x + 2*row*ne0;
+ dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x_row[j + ne0]));
+ }
static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -70,6 +82,18 @@ static __global__ void fused_mul_relu_f32(const float * x, const float * y, floa
dst[i] = fmaxf(x[i], 0) * y[i];
}
+ static __global__ void fused_mul_relu_f32(const float * x, float * dst, const int k, const int ne0) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ int row = i / ne0;
+ int j = i % ne0;
+ auto x_row = x + 2*row*ne0;
+ dst[i] = fmaxf(x_row[j + ne0], 0) * x_row[j];
+ }
static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) {
constexpr float GELU_COEF_A = 0.044715f;
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -82,6 +106,21 @@ static __global__ void fused_mul_gelu_f32(const float * x, const float * y, floa
dst[i] = 0.5f*xi*y[i]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}
+ static __global__ void fused_mul_gelu_f32(const float * x, float * dst, const int k, const int ne0) {
+ constexpr float GELU_COEF_A = 0.044715f;
+ constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ int row = i / ne0;
+ int j = i % ne0;
+ auto x_row = x + 2*row*ne0;
+ float xi = x_row[j + ne0];
+ dst[i] = 0.5f*xi*x_row[j]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+ }
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -199,6 +238,21 @@ static void fused_mul_gelu_f32_cuda(const float * x, const float * y, float * ds
fused_mul_gelu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}
+ static void fused_mul_silu_f32_cuda(const float * x, float * dst, const int k, const int ne0, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+ fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0);
+ }
+ static void fused_mul_relu_f32_cuda(const float * x, float * dst, const int k, const int ne0, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+ fused_mul_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0);
+ }
+ static void fused_mul_gelu_f32_cuda(const float * x, float * dst, const int k, const int ne0, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+ fused_mul_gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0);
+ }
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -302,29 +356,33 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
}
}
+ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
+ int64_t nelements, int64_t ne0, const float * src0_d, float * dst_d) {
+ cudaStream_t stream = ctx.stream();
+ switch (op) {
+ case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, dst_d, nelements, ne0, stream); break;
+ case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, dst_d, nelements, ne0, stream); break;
+ case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, dst_d, nelements, ne0, stream); break;
+ default: GGML_ASSERT(false);
+ }
+ }
void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
ggml_unary_op op = (ggml_unary_op)dst->op_params[0];
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data);
//cudaStream_t stream = ctx.stream();
//const float * src0_d = (const float *)src0->data;
//const float * src1_d = (const float *)src1->data;
//float * dst_d = (float *)dst->data;
//switch (op) {
// case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
// case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
// case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
// default: GGML_ASSERT(false);
//}
if (src1) {
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data);
} else {
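// src1 == NULL means src0 holds merged rows of width 2*ne0 (up and gate
// halves); dispatch to the single-tensor variant.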
GGML_ASSERT(src0->ne[0] == 2*dst->ne[0] && src0->ne[1] == dst->ne[1] && src0->ne[2] == dst->ne[2] && src0->ne[3] == dst->ne[3]);
ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), dst->ne[0], (const float *)src0->data, (float *)dst->data);
}
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@@ -89,4 +89,7 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor *
void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
int64_t nelements, const float * x, const float * y, float * z);
+ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
+ int64_t nelements, int64_t ne0, const float * x, float * z);
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -7628,13 +7628,13 @@ struct ggml_tensor * ggml_moe_up_gate(
struct ggml_tensor * b,
struct ggml_tensor * ids,
enum ggml_unary_op op) {
- if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
+ if (as_gate && (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate))) {
struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids);
return ggml_fused_mul_unary(ctx, result_gate, result_up, op);
}
GGML_ASSERT(!ggml_is_transposed(as_up));
- GGML_ASSERT(!ggml_is_transposed(as_gate));
+ GGML_ASSERT(!as_gate || !ggml_is_transposed(as_gate));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert)
@@ -7646,11 +7646,11 @@ struct ggml_tensor * ggml_moe_up_gate(
bool is_node = false;
- if (as_up->grad || as_gate->grad || b->grad) {
+ if (as_up->grad || (as_gate && as_gate->grad) || b->grad) {
is_node = true;
}
- const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
+ const int64_t ne[4] = { as_gate ? as_up->ne[1] : as_up->ne[1]/2, ids->ne[0], b->ne[2], 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_MOE_FUSED_UP_GATE;
@@ -7681,7 +7681,7 @@ struct ggml_tensor * ggml_moe_up_gate_ext(
return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op);
}
- if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
+ if (as_gate && (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate))) {
struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
if (as_up_b) {
result_up = ggml_add_id(ctx, result_up, as_up_b, ids);
@@ -7694,7 +7694,7 @@ struct ggml_tensor * ggml_moe_up_gate_ext(
}
GGML_ASSERT(!ggml_is_transposed(as_up));
- GGML_ASSERT(!ggml_is_transposed(as_gate));
+ GGML_ASSERT(!as_gate || !ggml_is_transposed(as_gate));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert)
@@ -7705,10 +7705,10 @@ struct ggml_tensor * ggml_moe_up_gate_ext(
GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]);
- GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]);
+ GGML_ASSERT(!as_gate || as_gate->ne[1] == as_gate_b->ne[0]);
bool is_node = false;
- const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
+ const int64_t ne[4] = { as_gate ? as_up->ne[1] : as_up->ne[1]/2, ids->ne[0], b->ne[2], 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_MOE_FUSED_UP_GATE;
@@ -16571,8 +16571,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
- GGML_ASSERT(dst->src[0]->type == dst->src[1]->type);
- GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst->src[1]));
+ GGML_ASSERT(!dst->src[1] || dst->src[0]->type == dst->src[1]->type);
+ GGML_ASSERT(!dst->src[1] || ggml_are_same_shape(dst->src[0], dst->src[1]));
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const struct ggml_tensor * src1 = dst->src[2];
@@ -16604,7 +16604,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
GGML_ASSERT(ne13 == 1);
const size_t nb41 = up_b ? up_b->nb[1] : 0;
- const size_t nb51 = up_b ? gate_b->nb[1] : 0;
+ const size_t nb51 = up_b && gate_b ? gate_b->nb[1] : 0;
// row groups
const int n_ids = ids->ne[0]; // n_expert_used
@@ -16692,16 +16692,20 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
}
const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
- const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
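+ // Merged tensor: the gate rows start halfway into the expert's slice.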
+ const char * src0_2_cur = src0_2 ? (const char *) src0_2->data + cur_a*nb02 : src0_1_cur + nb02/2;
const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
+ if (up_b_cur && !gate_b_cur) {
+ gate_b_cur = up_b_cur + nb41/2;
+ }
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
- const int64_t nr0 = ne01; // src0 rows
+ const int64_t nr0 = src0_2 ? ne01 : ne01/2; // src0 rows
const int64_t nr1 = cne1; // src1 rows
//
//if (ith == 0) printf("Calling iqk_moe_fused_up_gate with nr0 = %d, nr1 = %d, ne00 = %d, ne11 = %d\n", (int)nr0, (int)nr1, (int)ne00, (int)ne11);
if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
type, src0_1_cur, src0_2_cur, nb01,
vec_dot_type, (const char *)wdata, row_size,
@@ -16709,27 +16713,6 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
(float *)dst->data, nb1, nb2,
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
- // if (nth%2 == 0) {
- // const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur;
- // void * dst_d = ith%2 == 0 ? dst1->data : dst2->data;
- // if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
- // type, src0_d, nb01,
- // vec_dot_type, (const char *)wdata, row_size,
- // (float *)dst_d, nb1, nb2,
- // matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error");
- //
- // } else {
- // if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
- // src0_1->type, (const char *)src0_1_cur, nb01,
- // vec_dot_type, (const char *)wdata, row_size,
- // (float *)dst1->data, nb1, nb2,
- // matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
- // if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
- // src0_2->type, (const char *)src0_2_cur, nb01,
- // vec_dot_type, (const char *)wdata, row_size,
- // (float *)dst2->data, nb1, nb2,
- // matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
- // }
}
#undef MMID_MATRIX_ROW
@@ -25193,10 +25176,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
{
cur = 0;
const struct ggml_tensor * src0 = node->src[0];
+ const struct ggml_tensor * src1 = node->src[1];
const struct ggml_tensor * src2 = node->src[2];
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
- if (src2->type != vec_dot_type) {
- cur += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]);
+ if (src1 && src1->type != vec_dot_type) {
+ cur += ggml_row_size(vec_dot_type, src2->ne[0]) * ggml_nrows(src2);
}
const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align