WIP: Qwen3-MoE works with graph, layer still broken

This commit is contained in:
Kawrakow
2025-12-21 15:19:16 +00:00
parent 993cb00a34
commit 8f6ff1fa76

View File

@@ -620,6 +620,20 @@ ggml_tensor * llm_build_context::llm_build_norm(
return cur;
}
static ggml_tensor * get_input_tensor_sm_graph(ggml_tensor * input, int id) {
auto cur = input;
if (input->op == GGML_OP_REDUCE) {
auto view_src = input->view_src;
GGML_ASSERT(view_src);
cur = input->src[id];
if (cur == view_src) {
//printf("%s: Setting input to %s for id = %d\n", __func__, view_src->name, id);
cur = input;
}
}
return cur;
}
ggml_tensor * llm_build_context::llm_build_ffn(
ggml_context * ctx,
llama_context & lctx,
@@ -658,18 +672,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
auto split_d = d->splits[id];
GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
if (!split_u) continue;
//auto cur = input->op == GGML_OP_REDUCE ? ggml_fake_cpy(ctx, input->src[id], input) : input;
auto cur = input;
if (input->op == GGML_OP_REDUCE) {
auto view_src = input->view_src;
GGML_ASSERT(view_src);
cur = input->src[id];
if (cur == view_src) {
//printf("%s: Setting input to %s for id = %d\n", __func__, view_src->name, id);
cur = input;
}
}
//auto cur = input->op == GGML_OP_REDUCE ? input->src[id] : input;
auto cur = get_input_tensor_sm_graph(input, id);
if (ffn_norm && ffn_norm->extra) {
auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
GGML_ASSERT(norm->splits[id]);
@@ -1225,14 +1228,14 @@ llm_expert_gating_func_type gating_op,
(!split_up_exps->splits[id] && !split_gate_exps->splits[id] && !split_down_exps->splits[id]));
if (!split_up_exps->splits[id]) continue;
int il_cb = 1000*(id + 1) + il;
auto cur = input;
auto cur = get_input_tensor_sm_graph(input, id);
if (ffn_norm) {
auto split_ffn_norm = (ggml_split_tensor_t *)ffn_norm->extra;
GGML_ASSERT(split_ffn_norm && split_ffn_norm->n_device == split_up_exps->n_device);
cur = llm_build_norm(ctx, input, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
cur = llm_build_norm(ctx, cur, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_inp_normed", il_cb);
}
else if (cur->type != GGML_TYPE_F32) {
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
}
auto routed_out = llm_build_moe_ffn(ctx, lctx, cur,
@@ -1243,7 +1246,7 @@ llm_expert_gating_func_type gating_op,
split_exp_probs_b ? split_exp_probs_b->splits[id] : nullptr,
n_expert, n_expert_used,
type_op, norm_w, scale_w, w_scale,
gating_op, cb, il, graph, add_input);
gating_op, cb, il, graph, false);
cb(routed_out, "routed_out", il_cb);
if (split_up_shexp) {
@@ -1266,19 +1269,20 @@ llm_expert_gating_func_type gating_op,
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
cb(cur, "ffn_out_f16", il_cb);
}
ggml_build_forward_expand(graph, routed_out);
ggml_build_forward_expand(graph, cur);
results.push_back(cur);
}
GGML_ASSERT(!results.empty());
if (add_input) {
results.back() = ggml_add(ctx, results.back(), input);
cb(results.back(), "ffn_inp_added", il);
}
if (results.size() == 1) return results.front();
auto cur = ggml_add(ctx, results[0], results[1]);
cur->op_params[0] = 0xff;
cb(cur, "ffn_combined", il);
for (int id = 2; id < int(results.size()); ++id) {
cur = ggml_add(ctx, cur, results[id]);
cb(cur, "ffn_combined", il);
}
auto cur = ggml_reduce(ctx, results.data(), split_up_exps->n_device, GGML_OP_ADD);
cb(cur, "moe_ffn_combined", il);
ggml_build_forward_expand(graph, cur);
return cur;
}
@@ -3979,23 +3983,26 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
//struct ggml_tensor * inpSA = inpL;
// norm
//cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
//cb(cur, "attn_norm", il);
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il);
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0,
il, true, false, true);
//printf("%s: attn = %s(%s)\n", __func__, cur->name, ggml_op_name(cur->op));
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
//inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
auto ffn_inp = cur;
//struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
//cb(ffn_inp, "ffn_inp", il);
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
model.layers[il].ffn_gate_inp, nullptr,
@@ -4009,9 +4016,11 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
n_expert, n_expert_used,
LLM_FFN_SILU, true, false, 0.0f,
LLM_EXPERT_GATING_FUNC_SOFTMAX,
LLM_FFN_SILU, cb, il, gf);
LLM_FFN_SILU, cb, il, gf, true);
cur = ggml_add(ctx0, cur, ffn_inp);
//printf("%s: ffn = %s(%s)\n", __func__, cur->name, ggml_op_name(cur->op));
//cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
@@ -9401,18 +9410,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
if (!split_wq) continue;
//auto cur = input = input->op == GGML_OP_REDUCE ? ggml_fake_cpy(ctx0, input->src[id], input) : input;
auto cur = input;
if (input->op == GGML_OP_REDUCE) {
auto view_src = input->view_src;
GGML_ASSERT(view_src);
cur = input->src[id];
if (cur == view_src) {
//printf("%s: Setting input to %s for id = %d\n", __func__, view_src->name, id);
cur = input;
}
}
//auto cur = input = input->op == GGML_OP_REDUCE ? input->src[id] : input;
auto cur = get_input_tensor_sm_graph(input, id);
if (attn_norm) {
auto split_norm = attn_norm->splits[id];
cur = llm_build_norm(ctx0, cur, hparams, split_norm, NULL, LLM_NORM_RMS, cb, il);