From 8737d3d924b0f071c6516c253cc477ec8006307f Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 19 Feb 2026 06:58:36 +0000 Subject: [PATCH] WIP: loads and runs, but not correct Very high PPL, empty TG. --- src/llama-arch.cpp | 2 + src/llama-arch.h | 5 +- src/llama-build-context.cpp | 145 +++++++++++++++++++++++++++++++++++- src/llama-build-context.h | 2 + src/llama-delta-net.cpp | 55 ++++++++------ src/llama-hparams.cpp | 35 ++++++++- src/llama-load-tensors.cpp | 78 +++++++++++++++++++ src/llama-model.cpp | 37 ++++++++- src/llama-model.h | 3 + src/llama-vocab.cpp | 11 +++ src/llama-vocab.h | 1 + src/llama.cpp | 15 ++-- 12 files changed, 353 insertions(+), 36 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index f18741ae..fd748a57 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -30,6 +30,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, + { LLM_ARCH_QWEN35MOE, "qwen35moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PLAMO, "plamo" }, @@ -159,6 +160,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, + { LLM_KV_FULL_ATTENTION_INTERVAL, "%s.full_attention_interval" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 21715a57..e447261a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -29,6 +29,7 @@ enum llm_arch { LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, + LLM_ARCH_QWEN35MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PLAMO, @@ -152,7 +153,7 @@ enum llm_kv { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, - + LLM_KV_FULL_ATTENTION_INTERVAL, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT_PER_LAYER, @@ -285,6 +286,8 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_BETA, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index d5903097..5031a1be 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -142,7 +142,7 @@ ggml_cgraph * llm_build_context::build_k_shift() { ggml_set_input(lctx.inp_K_shift); for (int il = 0; il < n_layer; ++il) { - if (model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il)) { + if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il)) { continue; } if (kv_self.k_l[il] == nullptr) { @@ -241,7 +241,7 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector & ids) } for (int il = 0; il < n_layer; ++il) { - if (model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il)) { + if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il)) { continue; } if (kv_self.k_l[il] == nullptr) { @@ -4478,6 +4478,143 @@ ggml_cgraph * llm_build_context::build_qwen3next() { return gf; } +ggml_cgraph * llm_build_context::build_qwen35moe() { + static constexpr int QWEN3NEXT_CHUNK_SIZE = 64; + + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + delta_net delta(lctx, batch); + + const int64_t n_embd_head = hparams.n_embd_head_v; + 
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * { + + auto Qaux = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + auto Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + auto Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Qaux, "Qaux", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + ggml_build_forward_expand(gf, Qaux); + ggml_build_forward_expand(gf, Kcur); + ggml_build_forward_expand(gf, Vcur); + + Qaux = ggml_reshape_3d(ctx0, Qaux, n_embd_head * 2, n_head, n_tokens); + auto Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], 0)); + auto gate = ggml_cont_2d(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], n_embd_head*ggml_element_size(Qaux)), n_embd_head*n_head, n_tokens); + cb(Qcur, "Qcur", il); + cb(gate, "gate", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur_roped", il); + cb(Kcur, "Kcur_roped", il); + + ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, + hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale, cb, il); + cb(attn, "attn_pregate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "gate_sigmoid", il); + attn = ggml_mul(ctx0, attn, gate); + cb(attn, "attn_gated", il); + + attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn); + cb(attn, "attn_output", il); + + return attn; + + }; + + ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = n_tokens > 1 ? 
build_inp_out_ids() : nullptr; + ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + lctx.inp_s_seq_qnext = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 1, n_tokens); + cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1); + ggml_set_input(lctx.inp_s_seq_qnext); + + ggml_tensor * causal_mask = nullptr; + ggml_tensor * identity = nullptr; + ggml_tensor * diag_mask = nullptr; + causal_mask = ggml_tri(ctx0, + ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f), + GGML_TRI_TYPE_LOWER); + identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f)); + diag_mask = ggml_add(ctx0, causal_mask, identity); + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + ggml_build_forward_expand(gf, diag_mask); + + ggml_tensor * cur = nullptr; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + if (hparams.is_recurrent(il)) { + cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb); + } else { + cur = build_layer_attn(cur, inp_pos, KQ_mask, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur, + model.layers[il].ffn_gate_inp, nullptr, + model.layers[il].ffn_up_exps, nullptr, + model.layers[il].ffn_gate_exps, nullptr, + model.layers[il].ffn_down_exps, nullptr, + nullptr, + model.layers[il].ffn_up_shexp, nullptr, // we don't have shared expert biases? 
+ model.layers[il].ffn_gate_shexp, nullptr, + model.layers[il].ffn_down_shexp, nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, false, 0.0f, + LLM_EXPERT_GATING_FUNC_SOFTMAX, + LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps, nullptr, model.layers[il].ffn_gate_inp_shexp); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = build_output(lctx, ctx0, inpL, model.output, model.output_norm, cb); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; +} + ggml_cgraph * llm_build_context::build_qwen3vl() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -9508,6 +9645,10 @@ ggml_cgraph * llm_build_context::llama_build_graph( { result = llm.build_qwen3next(); } break; + case LLM_ARCH_QWEN35MOE: + { + result = llm.build_qwen35moe(); + } break; case LLM_ARCH_QWEN3VL: { result = llm.build_qwen3vl(); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 67ce81e5..9508c5c6 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -206,6 +206,8 @@ struct llm_build_context { ggml_cgraph * build_qwen3next(); + ggml_cgraph * build_qwen35moe(); + ggml_cgraph * build_phi2(); ggml_cgraph * build_phi3(); diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp index 3afa0c91..2259b34a 100644 --- a/src/llama-delta-net.cpp +++ b/src/llama-delta-net.cpp @@ -92,6 +92,9 @@ std::pair delta_net::build_delta_net_chunking(ggml GGML_ASSERT(v->ne[2] == n_tokens); GGML_ASSERT(k->ne[2] == n_tokens); GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + if (beta->ne[0] != H_v || beta->ne[2] != n_tokens || beta->ne[3] != n_seqs) { + printf("beta: %ld x %ld x %ld, expected %ld x %ld x %ld\n", beta->ne[0], beta->ne[2], beta->ne[3], H_v, n_tokens, n_seqs); + } GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); GGML_ASSERT(H_k == H_v); @@ -320,10 +323,6 @@ std::pair delta_net::build_delta_net_autoregressiv GGML_ASSERT(H_k == H_v); GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); - //const float eps_norm = hparams.f_norm_rms_eps; - //q = ggml_l2_norm(ctx0, q, eps_norm); - //k = ggml_l2_norm(ctx0, k, eps_norm); - const float scale = 1.0f / sqrtf(S_v); q = ggml_scale(ctx0, q, scale); @@ -464,35 +463,45 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ const uint32_t qnext_state_slots = llm_build_context::llama_kv_qnext_state_slots(kv_self); GGML_ASSERT(qnext_state_slots > 0); - const int64_t n_tok = cur->ne[1]; + const int64_t n_seqs = 1; + const int64_t n_seq_tokens = n_tok; auto qkvz = build_qkvz(ctx0, cur, il, cb); ggml_tensor * qkv_mixed = qkvz.first; ggml_tensor * z = qkvz.second; - ggml_tensor * mixed_ba = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta_alpha, cur); - cb(mixed_ba, "linear_attn_mixed_ba", il); + ggml_tensor *alpha, *beta; + if (model.layers[il].ssm_beta_alpha) { + ggml_tensor * mixed_ba = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta_alpha, cur); + cb(mixed_ba, "linear_attn_mixed_ba", il); - int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; - ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tok, 1); + int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; + ggml_tensor * 
mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tok, 1); - int64_t split_sizes_ba[2] = { - num_v_heads / num_k_heads, - num_v_heads / num_k_heads - }; + int64_t split_sizes_ba[2] = { + num_v_heads / num_k_heads, + num_v_heads / num_k_heads + }; - ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tok, 1, - mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0); - cb(b, "b", il); + ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tok, 1, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0); + cb(b, "b", il); - ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tok, 1, - mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], - split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); - cb(a, "a", il); + ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tok, 1, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], + split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); + cb(a, "a", il); - ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1); - ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1); + beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1); + alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1); + } else { + beta = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta, cur); + beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_tok, 1); + alpha = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_alpha, cur); + // Why??? + alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + } cb(beta, "beta", il); cb(alpha, "alpha", il); @@ -645,7 +654,7 @@ ggml_tensor * delta_net::build_layer_attn_linear(ggml_context * ctx0, ggml_cgrap GGML_ASSERT(model.layers[il].ssm_conv1d != nullptr); GGML_ASSERT(model.layers[il].ssm_dt != nullptr); GGML_ASSERT(model.layers[il].ssm_a != nullptr); - GGML_ASSERT(model.layers[il].ssm_beta_alpha != nullptr); + GGML_ASSERT(model.layers[il].ssm_beta_alpha != nullptr || (model.layers[il].ssm_alpha != nullptr && model.layers[il].ssm_beta != nullptr)); GGML_ASSERT(model.layers[il].ssm_norm != nullptr); GGML_ASSERT(model.layers[il].ssm_out != nullptr); GGML_ASSERT(model.layers[il].wqkv != nullptr || model.layers[il].ssm_in != nullptr); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 217f0390..fa5a34c9 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -85,8 +85,8 @@ void llm_load_hparams( std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), false); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; @@ -476,6 +476,37 @@ void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN35MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, 
false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_80B_A3B; break; + case 48: model.type = e_model::MODEL_80B_A3B; break; + case 60: model.type = e_model::MODEL_397B_A17B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_QWEN3VLMOE: { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index 8b8151e0..b4c64b31 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -75,6 +75,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { bool create_qwen3next_tensors(const LLM_TN & tn); + bool create_qwen35moe_tensors(const LLM_TN & tn); + bool create_phi2_tensors(const LLM_TN & tn); bool create_phi3_tensors(const LLM_TN & tn); @@ -1387,6 +1389,80 @@ bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) { return use_mmap_buffer; } +bool create_tensors_helper::create_qwen35moe_tensors(const LLM_TN & tn) { + LOADING_PRELUDE + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm = layer.attn_post_norm; + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared experts + const int64_t n_ff_shexp = hparams.n_ff_shexp ? 
hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + + } + + return use_mmap_buffer; +} + bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) { LOADING_PRELUDE @@ -3319,6 +3395,8 @@ bool create_tensors_helper::create_tensors() { use_mmap_buffer = create_qwen3_moe_tensors(tn); break; case LLM_ARCH_QWEN3NEXT: use_mmap_buffer = create_qwen3next_tensors(tn); break; + case LLM_ARCH_QWEN35MOE: + use_mmap_buffer = create_qwen35moe_tensors(tn); break; case LLM_ARCH_PHI2: use_mmap_buffer = create_phi2_tensors(tn); break; case LLM_ARCH_PHI3: diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c2eeb85d..7436684f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -462,6 +462,40 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_QWEN35MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" }, + //{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_QWEN3VL, { @@ -1690,9 +1724,10 @@ const char * llama_model_type_name(e_model type) { case MODEL_310B_A15B: return "310B.A15B"; case MODEL_300B_A47B: return "300B.A47B"; case MODEL_355B_A32B: return "355B.A32B"; + case MODEL_397B_A17B: return "397B.A17B"; case MODEL_744B_A40B: return "744B.A40B"; case MODEL_E2B: return "E2B"; case MODEL_E4B: return "E4B"; - default: return "?B"; + default: return "?B"; } } diff --git a/src/llama-model.h b/src/llama-model.h index a0727282..53600194 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -116,6 +116,7 @@ enum e_model { MODEL_310B_A15B, MODEL_300B_A47B, // Ernie MoE big MODEL_355B_A32B, + MODEL_397B_A17B, // Qwen-3.5-MoE MODEL_744B_A40B, MODEL_E2B, MODEL_E4B, @@ -292,6 +293,8 @@ struct llama_layer { struct ggml_tensor * ssm_out = nullptr; struct 
ggml_tensor * ssm_norm = nullptr; struct ggml_tensor * ssm_beta_alpha = nullptr; + struct ggml_tensor * ssm_alpha = nullptr; + struct ggml_tensor * ssm_beta = nullptr; // mamba struct ggml_tensor * ssm_conv1d = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 575fcc3b..918a8ea4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -360,6 +360,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_QWEN35: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_PORO: case LLAMA_VOCAB_PRE_TYPE_BLOOM: case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: @@ -1885,6 +1892,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "deepseek-r1-qwen") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; + } else if ( + tokenizer_pre == "qwen35") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN35; + clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 6f77b38f..6f064bb0 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -50,6 +50,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 40, LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 41, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, }; struct LLM_KV; diff --git a/src/llama.cpp b/src/llama.cpp index 0be9df8a..28ffdc4b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -569,7 +569,7 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) { bool llama_context::update_cache_copies() { int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2; auto layer_has_attention_kv = [&](int il) { - return !(model.arch == LLM_ARCH_QWEN3NEXT && model.hparams.is_recurrent(il)); + return !((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && model.hparams.is_recurrent(il)); }; if ((int)kv_self.k_l.size() != n_layer) return false; if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false; @@ -656,7 +656,7 @@ static inline bool llama_qwen3next_is_recurrent_layer( const llama_model & model, const llama_hparams & hparams, uint32_t il) { - return model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il); + return (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && hparams.is_recurrent(il); } static inline uint32_t llama_kv_v_row_embd( @@ -665,7 +665,7 @@ static inline uint32_t llama_kv_v_row_embd( uint32_t il) { // qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence), // so per-token V rows include only attention values. - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) { return hparams.n_embd_v_gqa(il); } @@ -724,7 +724,7 @@ static bool llama_kv_cache_init( cache.recurrent = model.arch == LLM_ARCH_MAMBA; // qwen3next uses hybrid recurrent+attention cache semantics. 
Keep V rows in // standard layout to match the mainline hybrid path when flash attention is off. - cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT; + cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT && model.arch != LLM_ARCH_QWEN35MOE; cache.head = 0; cache.size = kv_size; @@ -736,7 +736,7 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(kv_size); - if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT) { + if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) { // init state copy sources for (uint32_t i = 0; i < cache.size; ++i) { cache.cells[i].src = i; @@ -821,7 +821,7 @@ static bool llama_kv_cache_init( std::vector mem_split(model.splits.size(), 0); const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size); - if (model.arch == LLM_ARCH_QWEN3NEXT && qnext_state_slots < std::max(1, cparams.n_seq_max)) { + if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && qnext_state_slots < std::max(1, cparams.n_seq_max)) { LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n", __func__, std::max(1, cparams.n_seq_max), qnext_state_slots); } @@ -3144,7 +3144,7 @@ static int llama_decode_internal( auto tim1 = ggml_time_us(); #endif uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); - if (model.arch == LLM_ARCH_QWEN3NEXT && + if ((model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35MOE) && n_tokens > 1 && batch_all.n_seq_id != nullptr && batch_all.seq_id != nullptr) { @@ -5227,6 +5227,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_QWEN35MOE: return LLAMA_ROPE_TYPE_IMROPE; // all model arches should be listed explicitly here
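
---

Note (not part of the patch): a minimal standalone sketch of how the `qwen35moe.full_attention_interval` key loaded in `llm_load_hparams()` above maps onto the per-layer recurrent/full-attention pattern. The predicate is copied from the patch; the layer count of 12 is a hypothetical value chosen only for illustration.

```cpp
// Sketch only: with the default interval of 4, layers 3, 7, 11, ... (0-based)
// use full attention and every other layer uses the linear (delta-net) path.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_layer            = 12; // hypothetical layer count
    const uint32_t full_attn_interval = 4;  // default used in the patch

    std::vector<bool> recurrent(n_layer);
    for (uint32_t i = 0; i < n_layer; ++i) {
        // same predicate as in llm_load_hparams() for LLM_ARCH_QWEN35MOE
        recurrent[i] = ((i + 1) % full_attn_interval != 0);
    }

    for (uint32_t i = 0; i < n_layer; ++i) {
        std::printf("layer %2u: %s\n", i,
                    recurrent[i] ? "linear attention (delta net)" : "full attention");
    }
    return 0;
}
```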
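Note (not part of the patch): a sketch of the fused query/gate layout in `build_layer_attn()` as I read the `ggml_reshape_3d`/`ggml_view_3d` offsets above: the Q projection emits `2 * head_dim` values per head per token, with the first `head_dim` of each head's block being the query and the second `head_dim` the gate that is later passed through a sigmoid and multiplied into the attention output. The names `QGateSplit` and `split_q_and_gate` are illustrative only.

```cpp
// Sketch only: splits one token's fused Q/gate projection, assuming the
// head-major layout [q_h0 | gate_h0 | q_h1 | gate_h1 | ...] implied by the
// views in build_layer_attn().
#include <cassert>
#include <cstddef>
#include <vector>

struct QGateSplit {
    std::vector<float> q;    // head_dim * n_head query values
    std::vector<float> gate; // head_dim * n_head gate values
};

static QGateSplit split_q_and_gate(const std::vector<float> & fused,
                                   size_t head_dim, size_t n_head) {
    QGateSplit out;
    out.q.reserve(head_dim * n_head);
    out.gate.reserve(head_dim * n_head);
    for (size_t h = 0; h < n_head; ++h) {
        const float * blk = fused.data() + h * 2 * head_dim;
        out.q.insert(out.q.end(), blk, blk + head_dim);                // first half: query
        out.gate.insert(out.gate.end(), blk + head_dim, blk + 2 * head_dim); // second half: gate
    }
    return out;
}

int main() {
    const size_t head_dim = 4, n_head = 2;
    std::vector<float> fused(2 * head_dim * n_head);
    for (size_t i = 0; i < fused.size(); ++i) fused[i] = float(i);

    QGateSplit s = split_q_and_gate(fused, head_dim, n_head);
    // head 0: q = {0,1,2,3}, gate = {4,5,6,7}; head 1: q = {8..11}, gate = {12..15}
    assert(s.q[0] == 0.0f && s.gate[0] == 4.0f);
    assert(s.q[4] == 8.0f && s.gate[4] == 12.0f);
    return 0;
}
```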