gpt-oss: WIP llama

Model loads and runs (CPU only), but PPL is much too high
(~1500 for 1st batch vs ~200 in mainline).
Is it because of SWA, because of vocab, or did I introduce a bug somewhere?
This commit is contained in:
Iwan Kawrakow
2025-08-10 10:09:42 +03:00
parent e24a1d3eda
commit c69d04f324
2 changed files with 463 additions and 157 deletions

View File

@@ -22067,6 +22067,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_SWIGLU_OAI:
{
n_tasks = n_threads;
} break;

View File

@@ -254,6 +254,7 @@ enum llm_arch {
LLM_ARCH_COHERE2,
LLM_ARCH_DOTS1,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_OPENAI_MOE,
LLM_ARCH_UNKNOWN,
};
@@ -313,6 +314,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_COHERE2, "cohere2" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -577,6 +579,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_OUT_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
@@ -1732,6 +1735,25 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_OPENAI_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@@ -1778,6 +1800,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_OPENAI_MOE,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
@@ -1817,6 +1840,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
{ "bitnet", LLM_CHAT_TEMPLATE_BITNET },
};
@@ -2723,14 +2747,17 @@ static const size_t MiB = 1024*kiB;
static const size_t GiB = 1024*MiB;
// Gating function applied to the MoE router logits when building the expert FFN.
enum llm_expert_gating_func_type {
    LLM_EXPERT_GATING_FUNC_TYPE_NONE           = 0, // no named gating function; reported as "unknown" below
    LLM_EXPERT_GATING_FUNC_SOFTMAX             = 1,
    LLM_EXPERT_GATING_FUNC_SIGMOID             = 2,
    LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // softmax is taken over the selected experts' weights only
};

// Returns a short, static, human-readable name for a gating function type.
// Any value without a dedicated name (including LLM_EXPERT_GATING_FUNC_TYPE_NONE)
// yields "unknown". The returned pointer refers to a string literal and must not be freed.
static const char * llama_expert_gating_func_name(llm_expert_gating_func_type type) {
    if (type == LLM_EXPERT_GATING_FUNC_SOFTMAX) {
        return "softmax";
    }
    if (type == LLM_EXPERT_GATING_FUNC_SIGMOID) {
        return "sigmoid";
    }
    if (type == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
        return "softmax_weight";
    }
    return "unknown";
}
@@ -2982,107 +3009,114 @@ struct llama_layer_nextn {
// TODO: separate into "llama_layer_enc" and "llama_layer_dec"
struct llama_layer {
// normalization
struct ggml_tensor * attn_norm;
struct ggml_tensor * attn_norm_b;
struct ggml_tensor * attn_norm_2;
struct ggml_tensor * attn_norm_2_b;
struct ggml_tensor * attn_q_norm;
struct ggml_tensor * attn_q_norm_b;
struct ggml_tensor * attn_k_norm;
struct ggml_tensor * attn_k_norm_b;
struct ggml_tensor * attn_out_norm;
struct ggml_tensor * attn_out_norm_b;
struct ggml_tensor * attn_q_a_norm;
struct ggml_tensor * attn_kv_a_norm;
struct ggml_tensor * attn_sub_norm;
struct ggml_tensor * attn_post_norm;
struct ggml_tensor * ffn_sub_norm;
struct ggml_tensor * attn_norm_cross;
struct ggml_tensor * attn_norm_enc;
struct ggml_tensor * attn_norm = nullptr;
struct ggml_tensor * attn_norm_b = nullptr;
struct ggml_tensor * attn_norm_2 = nullptr;
struct ggml_tensor * attn_norm_2_b = nullptr;
struct ggml_tensor * attn_q_norm = nullptr;
struct ggml_tensor * attn_q_norm_b = nullptr;
struct ggml_tensor * attn_k_norm = nullptr;
struct ggml_tensor * attn_k_norm_b = nullptr;
struct ggml_tensor * attn_out_norm = nullptr;
struct ggml_tensor * attn_out_norm_b = nullptr;
struct ggml_tensor * attn_q_a_norm = nullptr;
struct ggml_tensor * attn_kv_a_norm = nullptr;
struct ggml_tensor * attn_sub_norm = nullptr;
struct ggml_tensor * attn_post_norm = nullptr;
struct ggml_tensor * ffn_sub_norm = nullptr;
struct ggml_tensor * attn_norm_cross = nullptr;
struct ggml_tensor * attn_norm_enc = nullptr;
// attention
struct ggml_tensor * wq;
struct ggml_tensor * wk;
struct ggml_tensor * wv;
struct ggml_tensor * wo;
struct ggml_tensor * wqkv;
struct ggml_tensor * wq_a;
struct ggml_tensor * wq_b;
struct ggml_tensor * wkv_a_mqa;
struct ggml_tensor * wkv_b;
struct ggml_tensor * wk_b;
struct ggml_tensor * wv_b;
struct ggml_tensor * wq_cross;
struct ggml_tensor * wk_cross;
struct ggml_tensor * wv_cross;
struct ggml_tensor * wo_cross;
struct ggml_tensor * wq_enc;
struct ggml_tensor * wk_enc;
struct ggml_tensor * wv_enc;
struct ggml_tensor * wo_enc;
struct ggml_tensor * wq = nullptr;
struct ggml_tensor * wk = nullptr;
struct ggml_tensor * wv = nullptr;
struct ggml_tensor * wo = nullptr;
struct ggml_tensor * wqkv = nullptr;
struct ggml_tensor * wq_a = nullptr;
struct ggml_tensor * wq_b = nullptr;
struct ggml_tensor * wkv_a_mqa = nullptr;
struct ggml_tensor * wkv_b = nullptr;
struct ggml_tensor * wk_b = nullptr;
struct ggml_tensor * wv_b = nullptr;
struct ggml_tensor * wq_cross = nullptr;
struct ggml_tensor * wk_cross = nullptr;
struct ggml_tensor * wv_cross = nullptr;
struct ggml_tensor * wo_cross = nullptr;
struct ggml_tensor * wq_enc = nullptr;
struct ggml_tensor * wk_enc = nullptr;
struct ggml_tensor * wv_enc = nullptr;
struct ggml_tensor * wo_enc = nullptr;
struct ggml_tensor * attn_sinks = nullptr;
// attention bias
struct ggml_tensor * bq;
struct ggml_tensor * bk;
struct ggml_tensor * bv;
struct ggml_tensor * bo;
struct ggml_tensor * bqkv;
struct ggml_tensor * bq = nullptr;
struct ggml_tensor * bk = nullptr;
struct ggml_tensor * bv = nullptr;
struct ggml_tensor * bo = nullptr;
struct ggml_tensor * bqkv = nullptr;
// relative position bias
struct ggml_tensor * attn_rel_b;
struct ggml_tensor * attn_rel_b_enc;
struct ggml_tensor * attn_rel_b_cross;
struct ggml_tensor * attn_rel_b = nullptr;
struct ggml_tensor * attn_rel_b_enc = nullptr;
struct ggml_tensor * attn_rel_b_cross = nullptr;
// normalization
struct ggml_tensor * ffn_norm;
struct ggml_tensor * ffn_norm_b;
struct ggml_tensor * ffn_post_norm;
struct ggml_tensor * layer_out_norm;
struct ggml_tensor * layer_out_norm_b;
struct ggml_tensor * ffn_norm_exps;
struct ggml_tensor * ffn_norm_enc;
struct ggml_tensor * ffn_norm = nullptr;
struct ggml_tensor * ffn_norm_b = nullptr;
struct ggml_tensor * ffn_post_norm = nullptr;
struct ggml_tensor * layer_out_norm = nullptr;
struct ggml_tensor * layer_out_norm_b = nullptr;
struct ggml_tensor * ffn_norm_exps = nullptr;
struct ggml_tensor * ffn_norm_enc = nullptr;
// ff
struct ggml_tensor * ffn_gate; // w1
struct ggml_tensor * ffn_down; // w2
struct ggml_tensor * ffn_up; // w3
struct ggml_tensor * ffn_gate_enc;
struct ggml_tensor * ffn_down_enc;
struct ggml_tensor * ffn_up_enc;
struct ggml_tensor * ffn_gate = nullptr; // w1
struct ggml_tensor * ffn_down = nullptr; // w2
struct ggml_tensor * ffn_up = nullptr; // w3
struct ggml_tensor * ffn_gate_enc = nullptr;
struct ggml_tensor * ffn_down_enc = nullptr;
struct ggml_tensor * ffn_up_enc = nullptr;
// ff MoE
struct ggml_tensor * ffn_gate_inp;
struct ggml_tensor * ffn_gate_exps;
struct ggml_tensor * ffn_down_exps;
struct ggml_tensor * ffn_up_exps ;
struct ggml_tensor * ffn_gate_inp = nullptr;
struct ggml_tensor * ffn_gate_exps = nullptr;
struct ggml_tensor * ffn_down_exps = nullptr;
struct ggml_tensor * ffn_up_exps = nullptr;
// ff MoE bias
struct ggml_tensor * ffn_gate_inp_b = nullptr;
struct ggml_tensor * ffn_gate_exps_b = nullptr;
struct ggml_tensor * ffn_down_exps_b = nullptr;
struct ggml_tensor * ffn_up_exps_b = nullptr;
// ff shared expert (shexp)
struct ggml_tensor * ffn_gate_inp_shexp;
struct ggml_tensor * ffn_gate_shexp;
struct ggml_tensor * ffn_down_shexp;
struct ggml_tensor * ffn_up_shexp;
struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
struct ggml_tensor * ffn_gate_shexp = nullptr;
struct ggml_tensor * ffn_down_shexp = nullptr;
struct ggml_tensor * ffn_up_shexp = nullptr;
// ff bias
struct ggml_tensor * ffn_gate_b = nullptr;
struct ggml_tensor * ffn_down_b = nullptr; // b2
struct ggml_tensor * ffn_up_b = nullptr; // b3
struct ggml_tensor * ffn_act;
struct ggml_tensor * ffn_act = nullptr;
struct ggml_tensor * ffn_exp_probs_b = nullptr;
// mamba proj
struct ggml_tensor * ssm_in;
struct ggml_tensor * ssm_x;
struct ggml_tensor * ssm_dt;
struct ggml_tensor * ssm_out;
struct ggml_tensor * ssm_in = nullptr;
struct ggml_tensor * ssm_x = nullptr;
struct ggml_tensor * ssm_dt = nullptr;
struct ggml_tensor * ssm_out = nullptr;
// mamba
struct ggml_tensor * ssm_conv1d;
struct ggml_tensor * ssm_a;
struct ggml_tensor * ssm_d;
struct ggml_tensor * ssm_conv1d = nullptr;
struct ggml_tensor * ssm_a = nullptr;
struct ggml_tensor * ssm_d = nullptr;
// mamba bias
struct ggml_tensor * ssm_conv1d_b;
struct ggml_tensor * ssm_dt_b;
struct ggml_tensor * ssm_conv1d_b = nullptr;
struct ggml_tensor * ssm_dt_b = nullptr;
// long rope factors
struct ggml_tensor * rope_long = nullptr;
@@ -3090,13 +3124,13 @@ struct llama_layer {
struct ggml_tensor * rope_freqs = nullptr;
// bitnet scale
struct ggml_tensor * wq_scale;
struct ggml_tensor * wk_scale;
struct ggml_tensor * wv_scale;
struct ggml_tensor * wo_scale;
struct ggml_tensor * ffn_gate_scale;
struct ggml_tensor * ffn_up_scale;
struct ggml_tensor * ffn_down_scale;
struct ggml_tensor * wq_scale = nullptr;
struct ggml_tensor * wk_scale = nullptr;
struct ggml_tensor * wv_scale = nullptr;
struct ggml_tensor * wo_scale = nullptr;
struct ggml_tensor * ffn_gate_scale = nullptr;
struct ggml_tensor * ffn_up_scale = nullptr;
struct ggml_tensor * ffn_down_scale = nullptr;
struct llama_layer_nextn nextn;
@@ -6308,6 +6342,19 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_OPENAI_MOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
//TODO OAI_MOE: SWA
//hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
//hparams.set_swa_pattern(2);
// TODO: switch (hparams.n_layer)
} break;
default: (void)0;
}
@@ -6795,6 +6842,13 @@ static void llm_load_vocab(
}
}
// @ngxson : quick hack for gpt-oss, always render these tokens
for (const auto & t : vocab.token_to_id) {
if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") {
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
}
}
// find EOM token: "<|eom_id|>"
//
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
@@ -6848,6 +6902,7 @@ static void llm_load_vocab(
}
// find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
// TODO OAI_MOE: o200k_harmony
if (vocab.special_fim_mid_id == -1) {
if (false
|| t.first == "<|fim_middle|>" // Qwen
@@ -6867,6 +6922,7 @@ static void llm_load_vocab(
}
}
// find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
if (vocab.special_fim_pad_id == -1) {
if (false
@@ -7170,7 +7226,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}
if (model.arch == LLM_ARCH_QWEN3MOE) {
if (model.arch == LLM_ARCH_QWEN3MOE || model.arch == LLM_ARCH_OPENAI_MOE) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
}
@@ -9794,6 +9850,48 @@ static bool llm_load_tensors(
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
}
} break;
case LLM_ARCH_OPENAI_MOE:
{
const int64_t n_ff_exp = hparams.n_ff_exp;
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
layer.attn_sinks = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
// bias
layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0);
layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0);
layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0);
layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
layer.ffn_gate_inp_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
layer.ffn_gate_exps_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
layer.ffn_down_exps_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0);
layer.ffn_up_exps_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -10071,6 +10169,7 @@ enum llm_ffn_op_type {
LLM_FFN_RELU,
LLM_FFN_RELU_SQR,
LLM_FFN_SWIGLU,
LLM_FFN_SWIGLU_OAI_MOE,
};
enum llm_ffn_gate_type {
@@ -10362,6 +10461,8 @@ static struct ggml_tensor * llm_build_ffn(
cur = ggml_swiglu(ctx, cur);
cb(cur, "ffn_swiglu", il);
} break;
default:
GGML_ABORT("fatal error");
}
if (type_gate == LLM_FFN_PAR) {
@@ -10394,15 +10495,15 @@ static struct ggml_tensor * llm_build_ffn(
return cur;
}
static struct ggml_tensor * llm_build_moe_ffn(
struct ggml_context * ctx,
struct llama_context & lctx,
struct ggml_tensor * cur,
struct ggml_tensor * gate_inp,
struct ggml_tensor * up_exps,
struct ggml_tensor * gate_exps,
struct ggml_tensor * down_exps,
struct ggml_tensor * exp_probs_b,
static ggml_tensor * llm_build_moe_ffn(
ggml_context * ctx,
llama_context & lctx,
ggml_tensor * cur,
ggml_tensor * gate_inp, ggml_tensor * gate_inp_b,
ggml_tensor * up_exps, ggml_tensor * up_exps_b,
ggml_tensor * gate_exps, ggml_tensor * gate_exps_b,
ggml_tensor * down_exps, ggml_tensor * down_exps_b,
ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llm_ffn_op_type type_op,
@@ -10419,6 +10520,12 @@ llm_expert_gating_func_type gating_op,
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il);
if (gate_inp_b) {
logits = ggml_add(ctx, logits, gate_inp_b);
cb(logits, "ffn_moe_logits_biased", il);
}
//ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
ggml_tensor * probs = nullptr;
switch (gating_op) {
@@ -10430,6 +10537,10 @@ llm_expert_gating_func_type gating_op,
{
probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
} break;
case LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
{
probs = logits; // [n_expert, n_tokens]
} break;
default:
GGML_ABORT("fatal error");
}
@@ -10459,6 +10570,13 @@ llm_expert_gating_func_type gating_op,
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);
if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
cb(weights, "ffn_moe_weights_softmax", il);
}
if (norm_w) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
@@ -10485,51 +10603,58 @@ llm_expert_gating_func_type gating_op,
cb(cur, "ffn_moe_weighted", il);
}
// For now we don't modify the fused up/gate op to include biases.
// Hence, if we have biases, we cannot use fmoe.
//
bool can_use_fmoe = !up_exps_b && !gate_exps_b && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU);
ggml_tensor * par;
if (lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) {
if (can_use_fmoe && lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) {
par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
} else {
ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
if (up_exps_b) {
up = ggml_add_id(ctx, up, up_exps_b, selected_experts);
cb(up, "ffn_moe_up_biased", il);
}
ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(gate, "ffn_moe_gate", il);
// This is equivalent to the commented out code below
if (gate_exps_b) {
gate = ggml_add_id(ctx, gate, gate_exps_b, selected_experts);
cb(gate, "ffn_moe_gate_biased", il);
}
if (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU) {
par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
} else if (type_op == LLM_FFN_SWIGLU_OAI_MOE) {
constexpr float alpha = 1.702f;
constexpr float limit = 7.0f;
par = ggml_swiglu_oai(ctx, gate, up, alpha, limit);
}
else {
GGML_ABORT("fatal error");
}
}
cb(par, "ffn_moe_gate_par", il);
ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
if (down_exps_b) {
experts = ggml_add_id(ctx, experts, down_exps_b, selected_experts);
cb(experts, "ffn_moe_down_biased", il);
}
if (!weight_before_ffn) {
experts = ggml_mul(ctx, experts, weights);
cb(cur, "ffn_moe_weighted", il);
}
//#ifdef GGML_USE_VULKAN
// // aggregate experts
// ggml_tensor * moe_out = nullptr;
// //ggml_tensor * first_expert = nullptr;
// for (int i = 0; i < n_expert_used; ++i) {
// ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
// experts->nb[2], i*experts->nb[1]);
//
// if (i == 0) {
// moe_out = cur_expert;
// } else {
// moe_out = ggml_add(ctx, moe_out, cur_expert);
// }
// }
//
// if (n_expert_used == 1) {
// // avoid returning a non-contiguous tensor
// moe_out = ggml_cont(ctx, moe_out);
// }
//
// return moe_out;
//#else
if (n_expert_used == 1) {
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
}
@@ -10538,10 +10663,38 @@ llm_expert_gating_func_type gating_op,
ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
}
return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
//#endif
}
// Bias-free convenience overload of llm_build_moe_ffn().
//
// Keeps the pre-existing call signature working: every argument is forwarded
// unchanged to the overload that additionally accepts biases, with a null
// tensor supplied for the router (gate_inp), up, gate and down expert biases.
// Returns whatever that overload returns.
static ggml_tensor * llm_build_moe_ffn(
        struct ggml_context * ctx,
        struct llama_context & lctx,
        struct ggml_tensor * cur,
        struct ggml_tensor * gate_inp,
        struct ggml_tensor * up_exps,
        struct ggml_tensor * gate_exps,
        struct ggml_tensor * down_exps,
        struct ggml_tensor * exp_probs_b,
        int64_t n_expert,
        int64_t n_expert_used,
        llm_ffn_op_type type_op,
        bool norm_w,
        bool scale_w,
        float w_scale,
        llm_expert_gating_func_type gating_op,
        const llm_build_cb & cb,
        int il) {
    // one shared "no bias" placeholder for all four bias slots
    struct ggml_tensor * const no_bias = nullptr;

    return llm_build_moe_ffn(
            ctx, lctx, cur,
            gate_inp,  no_bias,
            up_exps,   no_bias,
            gate_exps, no_bias,
            down_exps, no_bias,
            exp_probs_b,
            n_expert, n_expert_used,
            type_op,
            norm_w, scale_w, w_scale,
            gating_op,
            cb, il);
}
static struct ggml_tensor * llm_build_kqv(
struct ggml_context * ctx,
struct llama_context & lctx,
@@ -10555,7 +10708,8 @@ static struct ggml_tensor * llm_build_kqv(
int32_t n_kv,
float kq_scale,
const llm_build_cb & cb,
int il) {
int il,
ggml_tensor * sinks = nullptr) {
const llama_model & model = lctx.model;
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
@@ -10602,6 +10756,7 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
ggml_flash_attn_ext_add_sinks(cur, sinks);
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
@@ -10625,7 +10780,7 @@ static struct ggml_tensor * llm_build_kqv(
cb(v, "v", il);
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) {
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
@@ -10660,6 +10815,7 @@ static struct ggml_tensor * llm_build_kqv(
1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
} else {
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_soft_max_add_sinks(kq, sinks);
}
cb(kq, "kq_soft_max_ext", il);
@@ -10757,7 +10913,8 @@ static struct ggml_tensor * llm_build_kv(
int32_t n_kv,
float kq_scale,
const llm_build_cb & cb,
int il) {
int il,
ggml_tensor * sinks = nullptr) {
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
@@ -10772,7 +10929,7 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * cur;
cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks);
cb(cur, "kqv_out", il);
return cur;
@@ -17990,6 +18147,137 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_openai_moe() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
//const int64_t n_embd_head = hparams.n_embd_head_v;
const float kq_scale = 1.0f / sqrtf(float(n_rot)); //float(n_embd_head));
//auto * inp_attn = build_attn_inp_kv_unified_iswa();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens);
//Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
//cb(Vcur, "Vcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks);
//cur = build_attn_with_sinks(inp_attn,
// model.layers[il].wo, model.layers[il].bo,
// Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il);
cb(cur, "attn_out", il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = ffn_inp;
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_post_norm", il);
// MoE branch
cur = llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SWIGLU_OAI_MOE, false,
false, 0.0,
LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
cb, il);
cb(cur, "ffn_moe_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
//res->t_embd = cur;
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
//res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
return gf;
}
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -18297,6 +18585,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_hunyuan_moe();
} break;
case LLM_ARCH_OPENAI_MOE:
{
result = llm.build_openai_moe();
} break;
default:
GGML_ABORT("fatal error");
}
@@ -22089,6 +22381,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_CODESHELL:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
case LLM_ARCH_OPENAI_MOE:
return LLAMA_ROPE_TYPE_NEOX;
// all model arches should be listed explicitly here
@@ -23956,6 +24249,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
} else if (tmpl_contains("<|im_middle|>") && tmpl_contains("<|im_end|>")) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
} else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
return LLM_CHAT_TEMPLATE_OPENAI_MOE;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -24416,6 +24711,16 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) {
// OpenAI MoE (based on Harmony chat template)
for (auto message : chat) {
std::string role(message->role);
ss << "<|start|>" << role << "<|message|>" << message->content;
ss << (role == "assistant" ? "<|return|>" : "<|end|>");
}
if (add_ass) {
ss << "<|start|>assistant";
}
} else {
// template not supported
return -1;