gpt-oss: add ability to use -fmoe (only CUDA for now)

This commit is contained in:
Iwan Kawrakow
2025-08-12 09:49:18 +03:00
parent 464b8fc03b
commit 5abc39481b
9 changed files with 222 additions and 20 deletions

View File

@@ -1396,6 +1396,16 @@ extern "C" {
struct ggml_tensor * ids,
enum ggml_unary_op op);
GGML_API struct ggml_tensor * ggml_moe_up_gate_ext(
struct ggml_context * ctx,
struct ggml_tensor * a_up,
struct ggml_tensor * a_gate,
struct ggml_tensor * b,
struct ggml_tensor * ids,
struct ggml_tensor * a_up_b,
struct ggml_tensor * a_gate_b,
enum ggml_unary_op op);
// A: m columns, n rows,
// B: p columns, n rows,
// result is m columns, p rows

View File

@@ -2221,6 +2221,24 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
}
}
//static __global__ void k_quick_add(uint32_t n, uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
//
// for (uint32_t j = threadIdx.x; j < n; j += blockDim.x) {
// dst[j] = src1[j] + src2[j % n_per_row];
// }
//}
// Row-wise bias add: dst[row][j] = src1[row][j] + src2[j].
// Launch layout: grid.x = number of rows (one block per row); the block's
// threads stride over the n_per_row columns, so any blockDim.x works.
// src2 is a single bias row of length n_per_row shared by all rows.
static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
    const uint32_t r = blockIdx.x;
    const float * in_row  = src1 + r*n_per_row;
    float       * out_row = dst  + r*n_per_row;
    for (uint32_t col = threadIdx.x; col < n_per_row; col += blockDim.x) {
        out_row[col] = in_row[col] + src2[col];
    }
}
static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
@@ -2271,7 +2289,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
return is_ser;
}
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];
@@ -2320,7 +2338,25 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
0, src0->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
return;
if (next && next->op == GGML_OP_MUL_MAT_ID && next->src[0]->type == src0->type && src1 == next->src[1] &&
ggml_are_same_shape(src0, next->src[0]) &&
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
ggml_backend_buffer_is_cuda(next->buffer) &&
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer)) {
ggml_backend_cuda_buffer_context * next_src0_ctx = (ggml_backend_cuda_buffer_context *) next->src[0]->buffer->context;
ggml_backend_cuda_buffer_context * next_dst_ctx = (ggml_backend_cuda_buffer_context *) next->buffer->context;
if (next_src0_ctx->device == device_id &&
next_dst_ctx->device == device_id) {
local_dst.data = next->data;
ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst,
(const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data,
0, src0->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
return true;
}
}
return false;
}
}
@@ -2443,6 +2479,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
}
return false;
}
static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
@@ -2471,6 +2508,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
src0_2_ctx->device == device_id &&
src1_ctx->device == device_id &&
dst_ctx->device == device_id) {
//printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name,
// src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]);
// Fast TG path
const int64_t n_ids = ids->ne[0];
auto stream = ctx.stream(device_id, 0);
@@ -2506,12 +2545,26 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
0, src0_1->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
if (dst->src[4]) {
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[4]->data,
(const int32_t *)ids->data, (float *)local_dst.data,
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
local_dst.nb[1], local_dst.nb[2], dst->src[4]->nb[1], ids->nb[2], stream);
}
local_dst.data = dst_gate_contiguous.get();
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
(const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
0, src0_2->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
if (dst->src[5]) {
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[5]->data,
(const int32_t *)ids->data, (float *)local_dst.data,
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream);
}
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
@@ -2519,8 +2572,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
ggml_backend_buffer_is_cuda(next->buffer) &&
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids,
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
auto unary_op = (ggml_unary_op)dst->op_params[0];
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids,
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
}
CUDA_CHECK(cudaGetLastError());
const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
@@ -2556,8 +2616,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
return true;
} else {
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
auto unary_op = (ggml_unary_op)dst->op_params[0];
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst->data, dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, unary_op, ggml_nelements(dst),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
}
CUDA_CHECK(cudaGetLastError());
return false;
}
@@ -2625,7 +2691,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
final_src.nb[3] = final_src.nb[2];
}
if (ne12 == 1) {
if (false && ne12 == 1) {
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
if (fuse_down) {
@@ -2762,6 +2828,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
CUDA_CHECK(cudaGetLastError());
if (dst->src[4]) {
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
dim3 grid_dims(num_src1_rows);
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
(const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data);
CUDA_CHECK(cudaGetLastError());
}
dst_row.data = dst_gate_contiguous.get();
if (use_quantized_src1) {
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
@@ -2771,8 +2845,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
CUDA_CHECK(cudaGetLastError());
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
if (dst->src[5]) {
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
dim3 grid_dims(num_src1_rows);
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
(const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data);
CUDA_CHECK(cudaGetLastError());
}
auto unary_op = (ggml_unary_op)dst->op_params[0];
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
}
CUDA_CHECK(cudaGetLastError());
if (fuse_down) {
@@ -2945,7 +3035,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
}
break;
case GGML_OP_MUL_MAT_ID:
ggml_cuda_mul_mat_id(ctx, dst);
skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
break;
case GGML_OP_MOE_FUSED_UP_GATE:
skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);

View File

@@ -56,3 +56,17 @@ void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
nb21
);
}
// Host-side launcher for add_id_kernel with raw pointers/strides (no ggml_tensor).
// Grid is (ne01, ne02) — per the original annotations: n_experts_used x n_tokens —
// and the block covers the ne00 columns with at most 768 threads.
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
        int64_t ne00, int64_t ne01, int64_t ne02,
        int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream) {
    const int n_threads = ne00 < 768 ? (int)ne00 : 768; // one thread per column, capped
    const dim3 grid(ne01, ne02);
    add_id_kernel<<<grid, n_threads, 0, stream>>>(src0, src1, src2, dst, ne0, ne1, nb01, nb02, nb11, nb21);
}

View File

@@ -1,3 +1,8 @@
#include "common.cuh"
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne00, int64_t ne01, int64_t ne02,
int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream);

View File

@@ -524,7 +524,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
// Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
// so no other implementation works.
//
if (new_mma_available(cc) && Q->ne[0] == 576) {
if (new_mma_available(cc) && (Q->ne[0] == 576 || (Q->ne[0] == 64) && Q->ne[1] >= 128)) {
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
return;
}

View File

@@ -546,3 +546,7 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
// Public f32 entry point that forwards all arguments unchanged to swiglu_oai_cuda,
// exposing the SWIGLU-OAI activation to callers outside this translation unit
// (e.g. the fused MoE up/gate path).
// NOTE(review): parameter semantics inferred from the call sites visible in this
// commit (k = total elements, n = row width, o0/o1 = element strides of x and g,
// alpha/limit = activation scale and clamp) — confirm against swiglu_oai_cuda.
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
swiglu_oai_cuda(x, g, dst, k, n, o0, o1, alpha, limit, stream);
}

View File

@@ -49,3 +49,7 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream);

View File

@@ -7085,6 +7085,66 @@ struct ggml_tensor * ggml_moe_up_gate(
result->src[1] = as_gate;
result->src[2] = b;
result->src[3] = ids;
result->src[4] = NULL;
result->src[5] = NULL;
ggml_set_op_params_i32(result, 0, (int32_t) op);
return result;
}
struct ggml_tensor * ggml_moe_up_gate_ext(
struct ggml_context * ctx,
struct ggml_tensor * as_up,
struct ggml_tensor * as_gate,
struct ggml_tensor * b,
struct ggml_tensor * ids,
struct ggml_tensor * as_up_b,
struct ggml_tensor * as_gate_b,
enum ggml_unary_op op) {
if (!as_up_b && !as_gate_b) {
return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op);
}
if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
if (as_up_b) {
result_up = ggml_add_id(ctx, result_up, as_up_b, ids);
}
struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids);
if (as_gate_b) {
result_gate = ggml_add_id(ctx, result_gate, as_gate_b, ids);
}
return ggml_fused_mul_unary(ctx, result_gate, result_up, op);
}
GGML_ASSERT(!ggml_is_transposed(as_up));
GGML_ASSERT(!ggml_is_transposed(as_gate));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert)
GGML_ASSERT(b->ne[3] == 1); // b is 3d
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat
GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]);
GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]);
bool is_node = false;
const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_MOE_FUSED_UP_GATE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = as_up;
result->src[1] = as_gate;
result->src[2] = b;
result->src[3] = ids;
result->src[4] = as_up_b;
result->src[5] = as_gate_b;
ggml_set_op_params_i32(result, 0, (int32_t) op);

View File

@@ -10512,7 +10512,7 @@ static ggml_tensor * llm_build_moe_ffn(
float w_scale,
llm_expert_gating_func_type gating_op,
const llm_build_cb & cb,
int il) {
int il, struct ggml_cgraph * graph = nullptr) {
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -10606,23 +10606,38 @@ llm_expert_gating_func_type gating_op,
// For now we don't modify the fused up/gate op to include biases.
// Hence, if we have biases, we cannot use fmoe.
//
bool can_use_fmoe = !up_exps_b && !gate_exps_b && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU);
//bool can_use_fmoe = !up_exps_b && !gate_exps_b && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU);
bool can_use_fmoe = type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU || type_op == LLM_FFN_SWIGLU_OAI_MOE;
ggml_tensor * par;
if (can_use_fmoe && lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) {
par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
if (up_exps_b || gate_exps_b) {
par = ggml_moe_up_gate_ext(ctx, up_exps, gate_exps, cur, selected_experts, up_exps_b, gate_exps_b,
type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
type_op == LLM_FFN_GELU ? GGML_UNARY_OP_GELU : GGML_UNARY_OP_SWIGLU_OAI);
} else {
GGML_ASSERT(type_op != LLM_FFN_SWIGLU_OAI_MOE);
par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts,
type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
}
} else {
ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(gate, "ffn_moe_gate", il);
if (graph) {
// So we can potentially fuse the up and gate mul_mat_id
ggml_build_forward_expand(graph, up);
ggml_build_forward_expand(graph, gate);
}
if (up_exps_b) {
up = ggml_add_id(ctx, up, up_exps_b, selected_experts);
cb(up, "ffn_moe_up_biased", il);
}
ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(gate, "ffn_moe_gate", il);
if (gate_exps_b) {
gate = ggml_add_id(ctx, gate, gate_exps_b, selected_experts);
cb(gate, "ffn_moe_gate_biased", il);
@@ -10683,7 +10698,7 @@ static ggml_tensor * llm_build_moe_ffn(
float w_scale,
llm_expert_gating_func_type gating_op,
const llm_build_cb & cb,
int il) {
int il, struct ggml_cgraph * graph = nullptr) {
return llm_build_moe_ffn(ctx, lctx, cur,
gate_inp, nullptr,
up_exps, nullptr,
@@ -10692,7 +10707,7 @@ llm_expert_gating_func_type gating_op,
exp_probs_b,
n_expert, n_expert_used,
type_op, norm_w, scale_w, w_scale,
gating_op, cb, il);
gating_op, cb, il, graph);
}
static struct ggml_tensor * llm_build_kqv(
@@ -18242,7 +18257,7 @@ struct llm_build_context {
LLM_FFN_SWIGLU_OAI_MOE, false,
false, 0.0,
LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
cb, il);
cb, il, gf);
cb(cur, "ffn_moe_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);