qwen3next: harden seq-state flow and support optional dense FFN layers

Author: yurko
Date:   2026-02-07 13:12:26 -08:00
parent 6db8dc86ca
commit fffd27e3c8
6 changed files with 246 additions and 111 deletions


@@ -6,6 +6,8 @@
#include "ggml.h"
#include <unordered_set>
llm_build_context::llm_build_context(
llama_context & lctx,
const llama_batch & batch,
@@ -162,21 +164,34 @@ ggml_cgraph * llm_build_context::build_k_shift() {
ggml_cgraph * llm_build_context::build_s_copy() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
GGML_ASSERT(kv_self.recurrent);
const bool has_qnext_state = std::any_of(kv_self.s_l.begin(), kv_self.s_l.end(), [](const ggml_tensor * t) {
return t != nullptr;
});
GGML_ASSERT(kv_self.recurrent || has_qnext_state);
struct ggml_tensor * state_copy = build_inp_s_copy();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
if (kv_self.recurrent) {
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);
conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);
// TODO: name the intermediate tensors with cb()
// TODO: name the intermediate tensors with cb()
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
}
if (kv_self.s_l.size() > (size_t) il && kv_self.s_l[il] != nullptr) {
struct ggml_tensor * qnext_states_all = ggml_reshape_2d(ctx0, kv_self.s_l[il], hparams.n_embd_v_s(), kv_self.s_l[il]->ne[1]);
struct ggml_tensor * qnext_state_copy = ggml_view_1d(ctx0, state_copy, qnext_states_all->ne[1], 0);
struct ggml_tensor * qnext_states = ggml_get_rows(ctx0, qnext_states_all, qnext_state_copy);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, qnext_states, kv_self.s_l[il]));
}
}
return gf;
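
The copy pass above moves whole per-sequence state rows: ggml_get_rows gathers rows according to state_copy, and ggml_cpy writes them back into the cache tensor. A minimal standalone sketch of that gather-then-writeback pattern (hypothetical sizes and names, not the cache code itself):

// Illustrative sketch only (not part of the patch): shuffle rows of a state
// buffer according to a per-slot source index, the role state_copy plays above.
#include "ggml.h"

static void shuffle_state_rows_demo() {
    ggml_init_params params = { 16u*1024*1024, nullptr, false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_state = 8; // stand-in for hparams.n_embd_v_s()
    const int64_t n_slots = 4; // stand-in for the number of sequence slots

    ggml_tensor * states  = ggml_set_f32(ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_state, n_slots), 0.0f);
    ggml_tensor * src_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_slots);

    const int32_t src[4] = { 0, 0, 2, 3 }; // slot 1 takes its state from slot 0
    for (int i = 0; i < 4; ++i) {
        ggml_set_i32_1d(src_ids, i, src[i]);
    }

    ggml_tensor * gathered  = ggml_get_rows(ctx, states, src_ids); // [n_state, n_slots]
    ggml_tensor * writeback = ggml_cpy(ctx, gathered, states);     // back into the same buffer

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, writeback);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    ggml_free(ctx);
}
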
@@ -4101,13 +4116,32 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
GGML_ASSERT(batch.n_tokens > 0);
llama_seq_id seq_id = kv_head;
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
GGML_ASSERT(batch.n_seq_id[0] == 1);
seq_id = batch.seq_id[0][0];
for (int i = 1; i < batch.n_tokens; ++i) {
GGML_ASSERT(batch.n_seq_id[i] == 1);
GGML_ASSERT(batch.seq_id[i][0] == seq_id);
const bool has_explicit_seq_info = batch.n_seq_id != nullptr && batch.seq_id != nullptr;
std::vector<llama_seq_id> token_seq_ids(batch.n_tokens, 0);
for (int i = 0; i < batch.n_tokens; ++i) {
if (has_explicit_seq_info) {
GGML_ASSERT(batch.n_seq_id[i] > 0 && "qwen3next expects each token to belong to at least one sequence");
GGML_ASSERT(batch.n_seq_id[i] == 1 && "qwen3next does not support multi-sequence tokens yet");
token_seq_ids[i] = batch.seq_id[i][0];
} else {
token_seq_ids[i] = 0;
}
}
const llama_seq_id seq_id = token_seq_ids[0];
const bool all_same_seq = std::all_of(token_seq_ids.begin(), token_seq_ids.end(), [&](llama_seq_id s) {
return s == seq_id;
});
bool has_unique_seq_ids = true;
if (!all_same_seq) {
std::unordered_set<llama_seq_id> seen;
seen.reserve(token_seq_ids.size());
for (llama_seq_id s : token_seq_ids) {
if (!seen.insert(s).second) {
has_unique_seq_ids = false;
break;
}
}
}
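
The sequence-ID handling above reduces to three batch shapes: all tokens on one sequence (the fast path), one token per distinct sequence (per-token state slots), or anything else (unsupported). A standalone restatement with illustrative names that are not part of the patch:

// Illustrative helper, not from the patch: classify a batch the way the
// qwen3next builder does; seq_ids[i] is the single sequence of token i.
#include <cstdint>
#include <unordered_set>
#include <vector>

enum class qnext_batch_kind { single_seq, unique_per_token, unsupported };

static qnext_batch_kind classify_batch(const std::vector<int32_t> & seq_ids) {
    if (seq_ids.empty()) {
        return qnext_batch_kind::unsupported;
    }
    bool all_same = true;
    for (int32_t s : seq_ids) {
        all_same = all_same && (s == seq_ids[0]);
    }
    if (all_same) {
        return qnext_batch_kind::single_seq;      // one shared state slot
    }
    std::unordered_set<int32_t> seen;
    for (int32_t s : seq_ids) {
        if (!seen.insert(s).second) {
            return qnext_batch_kind::unsupported; // repeated seq id in a mixed batch
        }
    }
    return qnext_batch_kind::unique_per_token;    // one state slot per token
}
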
@@ -4134,11 +4168,12 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
GGML_ASSERT(hparams.n_embd_v_s() == (uint32_t) state_dim);
// Reserve-graph builds may not carry explicit sequence IDs, in which case
// seq_id falls back to kv_head and can exceed the recurrent slot count.
const uint32_t state_seq_id = (batch.n_seq_id != nullptr && batch.seq_id != nullptr)
? (uint32_t) seq_id
: 0u;
GGML_ASSERT(state_seq_id < qnext_state_slots);
// the fallback sequence slot is 0.
const uint32_t state_seq_id = (uint32_t) seq_id;
for (llama_seq_id s : token_seq_ids) {
GGML_ASSERT(s >= 0);
GGML_ASSERT((uint32_t) s < qnext_state_slots);
}
const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
@@ -4415,9 +4450,10 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
};
auto build_qkvz = [&](ggml_tensor * input, int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t n_tok = input->ne[1];
if (model.layers[il].wqkv) {
ggml_tensor * qkv_mixed = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, input);
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_tokens, 1);
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_tok, 1);
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
ggml_tensor * z = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, input);
@@ -4430,7 +4466,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
const int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, 1);
ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tok, 1);
int64_t split_sizes_qkvz[4] = {
head_k_dim,
@@ -4439,33 +4475,33 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
head_v_dim * num_v_heads / num_k_heads
};
ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, 1,
ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
cb(query, "q", il);
ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, 1,
ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
cb(key, "k", il);
ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, 1,
ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
cb(value, "v", il);
ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, 1,
ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
z = ggml_cont(ctx0, z);
cb(z, "z", il);
ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_tokens, 1);
ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_tok, 1);
cb(query_flat, "query_flat", il);
ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_tokens, 1);
ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_tok, 1);
cb(key_flat, "key_flat", il);
ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_tokens, 1);
ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_tok, 1);
cb(value_flat, "value_flat", il);
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
@@ -4539,43 +4575,68 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
};
auto build_layer_ffn = [&](ggml_tensor * cur, int il) -> ggml_tensor * {
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0f, LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il, gf, false);
cb(moe_out, "ffn_moe_out", il);
const bool has_moe = model.layers[il].ffn_gate_inp != nullptr;
const bool has_dense = model.layers[il].ffn_gate != nullptr && model.layers[il].ffn_up != nullptr && model.layers[il].ffn_down != nullptr;
ggml_tensor * ffn_shexp =
llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shexp, "ffn_shexp", il);
if (has_moe) {
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0f, LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il, gf, false);
cb(moe_out, "ffn_moe_out", il);
ggml_tensor * shared_gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
cb(shared_gate, "shared_expert_gate", il);
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "shared_expert_gate_sigmoid", il);
const bool has_shexp = model.layers[il].ffn_up_shexp != nullptr &&
model.layers[il].ffn_gate_shexp != nullptr &&
model.layers[il].ffn_down_shexp != nullptr &&
model.layers[il].ffn_gate_inp_shexp != nullptr;
if (has_shexp) {
ggml_tensor * ffn_shexp =
llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shexp, "ffn_shexp", il);
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
cb(ffn_shexp, "ffn_shexp_gated", il);
ggml_tensor * shared_gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
cb(shared_gate, "shared_expert_gate", il);
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "shared_expert_gate_sigmoid", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
cb(ffn_shexp, "ffn_shexp_gated", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
} else {
cur = moe_out;
}
cb(cur, "ffn_out", il);
return cur;
}
GGML_ASSERT(has_dense);
cur = llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
return cur;
};
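
In the MoE branch above, the shared expert is blended in with a per-token sigmoid gate, roughly out = moe(x) + sigmoid(w_g · x) * shexp(x). A scalar sketch of just that blend (hypothetical helper, not the ggml graph code):

// Illustrative scalar version of the shared-expert blend.
// moe_out/shexp_out: per-channel FFN outputs for one token; gate_logit: w_g · x.
#include <cmath>
#include <vector>

static std::vector<float> blend_shared_expert(const std::vector<float> & moe_out,
                                              const std::vector<float> & shexp_out,
                                              float gate_logit) {
    const float gate = 1.0f / (1.0f + std::exp(-gate_logit)); // sigmoid
    std::vector<float> out(moe_out.size());
    for (size_t i = 0; i < moe_out.size(); ++i) {
        out[i] = moe_out[i] + gate * shexp_out[i];
    }
    return out;
}
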
auto build_layer_attn_linear = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il) -> ggml_tensor * {
auto build_layer_attn_linear_core = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, ggml_tensor * inp_s_seq_qnext,
uint32_t state_seq_id_local, bool reset_state_local, int il) -> ggml_tensor * {
const int64_t n_tok = cur->ne[1];
auto qkvz = build_qkvz(cur, il);
ggml_tensor * qkv_mixed = qkvz.first;
ggml_tensor * z = qkvz.second;
@@ -4584,24 +4645,24 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(mixed_ba, "linear_attn_mixed_ba", il);
int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, 1);
ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tok, 1);
int64_t split_sizes_ba[2] = {
num_v_heads / num_k_heads,
num_v_heads / num_k_heads
};
ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, 1,
ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tok, 1,
mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
cb(b, "b", il);
ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, 1,
ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tok, 1,
mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
cb(a, "a", il);
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tokens, 1);
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tokens, 1);
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1);
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1);
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
@@ -4615,8 +4676,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * state_storage = kv_self.s_l[il];
GGML_ASSERT(state_storage->type == GGML_TYPE_F32);
GGML_ASSERT(state_storage->ne[0] >= state_dim);
GGML_ASSERT(state_storage->ne[1] >= qnext_state_slots);
state_row_size = state_storage->nb[1];
state_row_size = state_storage->ne[1] >= qnext_state_slots ? state_storage->nb[1] : ggml_row_size(state_storage->type, state_dim);
GGML_ASSERT(ggml_nbytes(state_storage) >= state_row_size * qnext_state_slots);
state_all = ggml_view_2d(ctx0, state_storage, state_dim, qnext_state_slots, state_row_size, 0);
} else {
const size_t state_offs = (size_t) ggml_element_size(kv_self.v_l[il]) * hparams.n_embd_v_gqa(il) * kv_self.size;
@@ -4624,12 +4685,12 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
state_all = ggml_view_2d(ctx0, kv_self.v_l[il], state_dim, qnext_state_slots, state_row_size, state_offs);
}
ggml_tensor * state_dst = ggml_view_2d(ctx0, state_all, state_dim, 1, state_row_size, state_seq_id * state_row_size);
ggml_tensor * state_dst = ggml_view_2d(ctx0, state_all, state_dim, 1, state_row_size, state_seq_id_local * state_row_size);
ggml_tensor * state_f32 = state_dst;
if (state_f32->type != GGML_TYPE_F32) {
state_f32 = ggml_cast(ctx0, state_f32, GGML_TYPE_F32);
}
if (reset_state) {
if (reset_state_local) {
state_f32 = ggml_scale(ctx0, state_f32, 0.0f);
}
@@ -4642,36 +4703,35 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(conv_states, "conv_states", il);
cb(state, "state_predelta", il);
GGML_ASSERT(lctx.inp_s_seq_qnext != nullptr);
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, lctx.inp_s_seq_qnext);
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext);
cb(conv_output_raw, "conv_output_raw", il);
ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tokens, conv_dim * ggml_element_size(conv_output_raw), 0);
ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0);
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output);
cb(conv_output_silu, "conv_output_silu", il);
ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tokens, conv_output_silu->nb[1], 0);
ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tokens, conv_output_silu->nb[1],
ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1], 0);
ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1],
key_dim * ggml_element_size(conv_output_silu));
ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_output_silu, value_dim, n_tokens, conv_output_silu->nb[1],
ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_output_silu, value_dim, n_tok, conv_output_silu->nb[1],
2 * key_dim * ggml_element_size(conv_output_silu));
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tokens, 1);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, 1);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, 1);
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tok, 1);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tok, 1);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tok, 1);
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
const int64_t repeat_factor = num_v_heads / num_k_heads;
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tokens);
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tokens);
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tok);
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tok);
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tokens, 1);
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tokens, 1);
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tokens, 1);
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tokens, 1);
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
}
cb(q_conv, "q_conv_predelta", il);
@@ -4679,7 +4739,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(v_conv, "v_conv_predelta", il);
std::pair<ggml_tensor *, ggml_tensor *> attn_out =
n_tokens == 1
n_tok == 1
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
ggml_tensor * output = attn_out.first;
@@ -4689,7 +4749,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * new_conv_states = ggml_view_2d(ctx0, conv_output_raw, hparams.ssm_d_conv - 1, conv_dim,
hparams.ssm_d_conv * ggml_element_size(conv_output_raw),
(1 + conv_dim * n_tokens) * ggml_element_size(conv_output_raw));
(1 + conv_dim * n_tok) * ggml_element_size(conv_output_raw));
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, ggml_cont(ctx0, new_conv_states), conv_state_dim, 1);
ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, ggml_cont(ctx0, new_state), ssm_state_dim, 1);
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
@@ -4700,20 +4760,45 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_update, state_dst));
ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tokens);
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_tokens);
ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tok);
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_tok);
ggml_tensor * attn_out_norm = llm_build_norm(ctx0, attn_out_2d, hparams, model.layers[il].ssm_norm, nullptr, LLM_NORM_RMS, cb, il);
ggml_tensor * gated_silu = ggml_silu(ctx0, z_2d);
attn_out_norm = ggml_mul(ctx0, attn_out_norm, gated_silu);
ggml_tensor * final_output = ggml_reshape_2d(ctx0, attn_out_norm, value_dim, n_tokens);
ggml_tensor * final_output = ggml_reshape_2d(ctx0, attn_out_norm, value_dim, n_tok);
cb(final_output, "final_output", il);
ggml_tensor * out = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, final_output);
cb(out, "linear_attn_out", il);
return ggml_cont_2d(ctx0, out, n_embd, n_tokens);
return ggml_cont_2d(ctx0, out, n_embd, n_tok);
};
auto build_layer_attn_linear = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il) -> ggml_tensor * {
GGML_ASSERT(lctx.inp_s_seq_qnext != nullptr);
if (all_same_seq) {
return build_layer_attn_linear_core(cur, causal_mask, identity, diag_mask, lctx.inp_s_seq_qnext, state_seq_id, reset_state, il);
}
GGML_ASSERT(has_unique_seq_ids && "qwen3next mixed-sequence batches require unique sequence IDs per token");
ggml_tensor * out = nullptr;
for (int64_t i = 0; i < n_tokens; ++i) {
ggml_tensor * cur_i = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (size_t) i * cur->nb[1]);
ggml_tensor * inp_s_seq_qnext_i = ggml_view_2d(ctx0, lctx.inp_s_seq_qnext, 1, 1, lctx.inp_s_seq_qnext->nb[1], (size_t) i * lctx.inp_s_seq_qnext->nb[1]);
const bool reset_state_i = batch.pos != nullptr && batch.pos[i] == 0;
const uint32_t state_seq_id_i = (uint32_t) token_seq_ids[i];
ggml_tensor * out_i = build_layer_attn_linear_core(cur_i, causal_mask, identity, diag_mask, inp_s_seq_qnext_i, state_seq_id_i, reset_state_i, il);
out = out == nullptr ? out_i : ggml_concat(ctx0, out, out_i, 1);
}
return out;
};
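
For mixed-sequence batches the wrapper above runs the linear-attention core once per token on a single-column view and concatenates the outputs back along the token dimension. A minimal view/concat sketch using the same ggml_view_2d/ggml_concat calls as the patch, with the per-token attention call elided:

// Illustrative sketch: slice a [n_embd, n_tokens] activation into per-token
// columns and re-concatenate along dim 1, mirroring the mixed-sequence path.
#include "ggml.h"

static ggml_tensor * per_token_passthrough(ggml_context * ctx, ggml_tensor * cur) {
    const int64_t n_tokens = cur->ne[1];
    ggml_tensor * out = nullptr;
    for (int64_t i = 0; i < n_tokens; ++i) {
        // one column: same row extent, one token, offset by i strides of nb[1]
        ggml_tensor * cur_i = ggml_view_2d(ctx, cur, cur->ne[0], 1, cur->nb[1], (size_t) i * cur->nb[1]);
        // a real caller would run the per-token attention core on cur_i here
        out = out == nullptr ? cur_i : ggml_concat(ctx, out, cur_i, 1);
    }
    return out;
}
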
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -4741,14 +4826,17 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
GGML_ASSERT(model.layers[il].attn_norm != nullptr);
GGML_ASSERT(model.layers[il].attn_post_norm != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr);
GGML_ASSERT(model.layers[il].ffn_up_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_down_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_inp_shexp != nullptr);
GGML_ASSERT(model.layers[il].ffn_up_shexp != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_shexp != nullptr);
GGML_ASSERT(model.layers[il].ffn_down_shexp != nullptr);
const bool has_moe = model.layers[il].ffn_gate_inp != nullptr;
const bool has_dense = model.layers[il].ffn_gate != nullptr &&
model.layers[il].ffn_up != nullptr &&
model.layers[il].ffn_down != nullptr;
GGML_ASSERT(has_moe || has_dense);
if (has_moe) {
GGML_ASSERT(model.layers[il].ffn_up_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_down_exps != nullptr);
}
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);


@@ -471,7 +471,10 @@ void llm_load_hparams(
hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0);
}
model.type = e_model::MODEL_UNKNOWN;
switch (hparams.n_layer) {
case 48: model.type = e_model::MODEL_80B_A3B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3VLMOE:
{


@@ -1189,10 +1189,8 @@ bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) {
}
}
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for QWEN3NEXT");
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for QWEN3NEXT");
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
const bool has_moe_hparams = n_expert > 0 && n_expert_used > 0;
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : (has_moe_hparams ? n_ff / n_expert_used : n_ff);
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp;
const int64_t head_k_dim = hparams.ssm_d_state;
@@ -1240,14 +1238,32 @@ bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) {
}
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
use_mmap_buffer &= !create_std_ffn_exps(n_embd, tn, i);
// Dense FFN path (optional, e.g. mlp_only_layers)
layer.ffn_gate = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_gate_inp_shexp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp});
layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp});
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd});
// MoE path (optional per-layer)
layer.ffn_gate_inp = nullptr;
if (n_expert > 0) {
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
if (layer.ffn_gate_inp != nullptr) {
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0 when QWEN3NEXT MoE tensors are present");
}
use_mmap_buffer &= !create_std_ffn_exps(n_embd, tn, i, llama_model_loader::TENSOR_NOT_REQUIRED, n_ff_exp);
}
// Shared expert path (optional per-layer)
layer.ffn_gate_inp_shexp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (layer.ffn_gate_inp_shexp != nullptr) {
layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
}
return use_mmap_buffer;
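
The loader changes lean on llama_model_loader::TENSOR_NOT_REQUIRED: create_tensor returns nullptr for tensors absent from the model file, and the graph builder later branches on which pointers are set. A simplified sketch of that convention (demo types, not the real layer struct):

// Illustrative pattern only: optional tensors come back as nullptr and the
// caller decides which FFN variant the layer supports.
#include "ggml.h"

struct demo_layer {
    ggml_tensor * ffn_gate_inp = nullptr; // MoE router, optional
    ggml_tensor * ffn_gate     = nullptr; // dense FFN, optional
    ggml_tensor * ffn_up       = nullptr;
    ggml_tensor * ffn_down     = nullptr;
};

static bool layer_is_moe  (const demo_layer & l) { return l.ffn_gate_inp != nullptr; }
static bool layer_is_dense(const demo_layer & l) { return l.ffn_gate && l.ffn_up && l.ffn_down; }
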


@@ -1589,6 +1589,7 @@ const char * llama_model_type_name(e_model type) {
case MODEL_16B_A1B: return "16B.A1B";
case MODEL_21B_A3B: return "21B.A3B";
case MODEL_30B_A3B: return "30B.A3B";
case MODEL_80B_A3B: return "80B.A3B";
case MODEL_80B_A13B: return "80B.A13B";
case MODEL_100B_A6B: return "100B.A6B";
case MODEL_106B_A12B: return "106B.A12B";


@@ -107,6 +107,7 @@ enum e_model {
MODEL_16B_A1B,
MODEL_21B_A3B, // Ernie MoE small
MODEL_30B_A3B,
MODEL_80B_A3B, // Qwen3-Next
MODEL_80B_A13B,
MODEL_100B_A6B,
MODEL_106B_A12B,


@@ -660,6 +660,12 @@ static inline uint32_t llama_qwen3next_state_slots(const llama_cparams & cparams
return std::max<uint32_t>(1, cparams.n_seq_max);
}
static inline bool llama_kv_has_qnext_state_storage(const llama_kv_cache & cache) {
return std::any_of(cache.s_l.begin(), cache.s_l.end(), [](const ggml_tensor * t) {
return t != nullptr;
});
}
static bool llama_kv_cache_init(
struct llama_kv_cache & cache,
const llama_context * ctx,
@@ -690,7 +696,7 @@ static bool llama_kv_cache_init(
cache.cells.clear();
cache.cells.resize(kv_size);
if (cache.recurrent) {
if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT) {
// init state copy sources
for (uint32_t i = 0; i < cache.size; ++i) {
cache.cells[i].src = i;
@@ -814,8 +820,7 @@ static bool llama_kv_cache_init(
int64_t v_ne = int64_t(n_embd_v_row)*kv_size;
v = ggml_new_tensor_1d(ctx, type_v, v_ne);
if (qnext_recurrent) {
const int64_t s_ne = int64_t(hparams.n_embd_v_s())*qnext_state_slots;
s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, s_ne);
s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_embd_v_s(), qnext_state_slots);
}
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
auto v_name = std::string{"cache_v_l"} + std::to_string(i);
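
Allocating the qnext state as a 2D tensor makes each sequence slot one row, so a slot can be addressed through the row stride nb[1] instead of a hand-computed byte offset. A small sketch of that addressing, assuming the tensor was created as above:

// Illustrative sketch: per-slot row view into a 2D state tensor
// s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_v_s, n_slots).
#include "ggml.h"
#include <cstdint>

static ggml_tensor * state_slot_view(ggml_context * ctx, ggml_tensor * s, uint32_t slot) {
    // s->nb[1] is the byte stride between consecutive sequence slots
    return ggml_view_2d(ctx, s, s->ne[0], 1, s->nb[1], (size_t) slot * s->nb[1]);
}
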
@@ -1051,6 +1056,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
cache.cells[i].pos = -1;
cache.cells[i].src = i;
cache.cells[i].seq_id.clear();
}
cache.head = 0;
@@ -1090,6 +1096,8 @@ static bool llama_kv_cache_seq_rm(
}
}
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
if (seq_id < 0) {
@@ -1104,6 +1112,9 @@ static bool llama_kv_cache_seq_rm(
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
if (has_qnext_state) {
cache.cells[i].src = i;
}
if (new_head == cache.size) new_head = i;
}
}
@@ -1145,6 +1156,17 @@ static void llama_kv_cache_seq_cp(
}
return;
}
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
if (has_qnext_state && (uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
seq_id_src = cache.cells[seq_id_src].src;
GGML_ASSERT((uint32_t) seq_id_src < cache.size);
cache.cells[seq_id_dst].src = seq_id_src;
cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
cache.do_copy = true;
}
// otherwise, this is the KV cache of a Transformer-like model
cache.head = 0;
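
For recurrent/qnext state the copy is deferred: seq_cp only records the source slot in cells[dst].src and raises do_copy, and build_s_copy() (first file above) materializes the row copies on the next graph build, after which each cell becomes its own source again. A simplified bookkeeping sketch, assuming a stripped-down cell type rather than the real llama_kv_cache:

// Illustrative bookkeeping only: "dst takes its state from src", with the
// actual data movement left to the next state-copy graph.
#include <cstdint>
#include <vector>

struct demo_cell { int32_t src; };

struct demo_state_cache {
    std::vector<demo_cell> cells;
    bool do_copy = false;

    void seq_cp(int32_t seq_src, int32_t seq_dst) {
        seq_src = cells[seq_src].src;   // follow an existing redirection first
        cells[seq_dst].src = seq_src;
        do_copy = true;                 // picked up by the next build_s_copy()
    }

    void after_s_copy() {
        for (size_t i = 0; i < cells.size(); ++i) {
            cells[i].src = (int32_t) i; // every slot is its own source again
        }
        do_copy = false;
    }
};
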
@@ -1158,11 +1180,15 @@ static void llama_kv_cache_seq_cp(
static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
uint32_t new_head = cache.size;
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
if (has_qnext_state) {
cache.cells[i].src = i;
}
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
} else {
@@ -3823,7 +3849,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
}
}
if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
if ((lctx.kv_self.recurrent || llama_kv_has_qnext_state_storage(lctx.kv_self)) && lctx.kv_self.do_copy) {
{
lctx.reset_scheduler();