Adding Ling/Ring (a.k.a., Bailing-MoE2) support (#833)

* Adding Ling/Ring (a.k.a., Bailing-MoE2)

* Add expert group selection (not working, so turned off)

* BailingMoE2 conversion

* WIP

* Bits and pieces

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-10-15 14:20:40 +03:00
committed by GitHub
parent ba9fefb73d
commit f7adde1043
25 changed files with 1295 additions and 56 deletions

View File

@@ -820,6 +820,38 @@ llm_expert_gating_func_type gating_op,
selection_probs = logits;
}
if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
auto& hparams = lctx.model.hparams;
const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
// organize experts into n_expert_groups
ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups]
#if 0
ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups]
group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]
// get top n_group_used expert groups
group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
#else
// Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead
ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups]
group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups]
// get top n_group_used expert groups
group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1]
#endif
ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1]
cb(expert_groups->src[0], "ffn_moe_group_argsort", il);
cb(expert_groups, "ffn_moe_group_topk", il);
// mask out the other groups
selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert]
selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens]
cb(selection_probs, "ffn_moe_probs_masked", il);
}
// select experts
ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
@@ -846,6 +878,11 @@ llm_expert_gating_func_type gating_op,
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il);
if (lctx.model.arch == LLM_ARCH_BAILINGMOE2) {
weights_sum = ggml_scale_bias(ctx, weights_sum, 1.0, 1e-20);
cb(weights_sum, "ffn_moe_weights_sum_biased", il);
}
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_norm", il);
@@ -8192,6 +8229,139 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
return gf;
}
ggml_cgraph * llm_build_context::build_bailingmoe2() {
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
//auto * inp_attn = build_attn_inp_kv();
ggml_tensor * KQ_mask = build_inp_KQ_mask();
//const int64_t n_embd_head = hparams.n_embd_head_v;
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
ggml_tensor * inp_out_ids = build_inp_out_ids();
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self_attention
{
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
//ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA);
cb(sa_out, "sa_out", il);
// MoE branch
cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
} else {
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llm_expert_gating_func_type) hparams.expert_gating_func,
cb, il, gf);
cb(moe_out, "ffn_moe_out", il);
ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, sa_out);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
ggml_cgraph * llm_build_context::llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
llama_batch dummy;
dummy.n_tokens = 0;
@@ -8513,6 +8683,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_openai_moe();
} break;
case LLM_ARCH_BAILINGMOE2:
{
result = llm.build_bailingmoe2();
} break;
default:
GGML_ABORT("fatal error");
}