qwen3next: add architecture support and recurrent-state fixes

This commit is contained in:
yurko
2026-02-06 12:13:09 +00:00
parent a527b5af25
commit a7df116441
28 changed files with 2729 additions and 14 deletions

View File

@@ -27,6 +27,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
@@ -173,6 +174,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@@ -229,4 +231,3 @@ const char * llama_model_arch_name(llm_arch arch) {
}
return it->second;
}

View File

@@ -26,6 +26,7 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
@@ -167,6 +168,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_PRE,
@@ -237,6 +239,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
@@ -264,8 +267,11 @@ enum llm_tensor {
LLM_TENSOR_SSM_X,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_A_NOSCAN,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_SSM_BETA_ALPHA,
LLM_TENSOR_ATTN_Q_A,
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,

View File

@@ -84,6 +84,7 @@ void llm_build_context::init() {
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_s_seq_qnext = nullptr;
lctx.inp_pos_bucket = nullptr;
lctx.inp_embd_enc = nullptr;
lctx.inp_KQ_mask_cross = nullptr;
@@ -4094,6 +4095,720 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
return gf;
}
ggml_cgraph * llm_build_context::build_qwen3next() {
static constexpr int QWEN3NEXT_CHUNK_SIZE = 64;
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
GGML_ASSERT(batch.n_tokens > 0);
llama_seq_id seq_id = kv_head;
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
GGML_ASSERT(batch.n_seq_id[0] == 1);
seq_id = batch.seq_id[0][0];
for (int i = 1; i < batch.n_tokens; ++i) {
GGML_ASSERT(batch.n_seq_id[i] == 1);
GGML_ASSERT(batch.seq_id[i][0] == seq_id);
}
}
GGML_ASSERT(hparams.ssm_n_group > 0);
GGML_ASSERT(hparams.ssm_dt_rank > 0);
GGML_ASSERT(hparams.ssm_d_conv > 0);
GGML_ASSERT(hparams.ssm_d_inner % hparams.ssm_dt_rank == 0);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads;
const int64_t key_dim = head_k_dim * num_k_heads;
const int64_t value_dim = head_v_dim * num_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
const int64_t conv_state_dim = (hparams.ssm_d_conv - 1) * conv_dim;
const int64_t ssm_state_dim = head_v_dim * head_v_dim * num_v_heads;
const int64_t state_dim = conv_state_dim + ssm_state_dim;
const uint32_t qnext_state_slots = std::max<uint32_t>(1, cparams.n_seq_max);
GGML_ASSERT(hparams.n_embd_v_s() == (uint32_t) state_dim);
// Reserve-graph builds may not carry explicit sequence IDs, in which case
// seq_id falls back to kv_head and can exceed the recurrent slot count.
const uint32_t state_seq_id = (batch.n_seq_id != nullptr && batch.seq_id != nullptr)
? (uint32_t) seq_id
: 0u;
GGML_ASSERT(state_seq_id < qnext_state_slots);
const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * {
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
};
auto build_delta_net_chunking = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(n_seqs == 1);
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(g, "g_perm", il);
cb(state,"state_in", il);
const int64_t chunk_size = QWEN3NEXT_CHUNK_SIZE;
const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
const int64_t n_chunks = (n_tokens + pad) / chunk_size;
q = ggml_pad(ctx0, q, 0, pad, 0, 0);
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
cb(q, "q_pad", il);
cb(k, "k_pad", il);
cb(v, "v_pad", il);
cb(beta, "beta_pad", il);
cb(g, "g_pad", il);
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
cb(v_beta, "v_beta", il);
cb(k_beta, "k_beta", il);
q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il);
ggml_tensor * gcs_i =
ggml_repeat_4d(ctx0, g_cumsum, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j_broadcast =
ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
cb(decay_mask, "decay_mask", il);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
cb(attn, "attn_pre_solve", il);
ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_mul(ctx0, lin_solve, causal_mask);
attn = ggml_add(ctx0, attn, identity);
cb(attn, "attn_solved", il);
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
cb(kbeta_gexp, "kbeta_gexp", il);
ggml_tensor * k_cumdecay =
ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
cb(k_cumdecay, "k_cumdecay", il);
ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il);
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
(g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
g_last = ggml_cont(ctx0, g_last);
cb(g_last, "g_last", il);
ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
cb(g_last_exp, "g_last_exp", il);
ggml_tensor * g_last_repeat =
ggml_repeat_4d(ctx0, g_last, chunk_size, 1, n_chunks, H_v * n_seqs);
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last_repeat));
cb(g_diff, "g_diff", il);
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
cb(key_gdiff, "key_gdiff", il);
ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
cb(key_gdiff_t, "key_gdiff_t", il);
ggml_tensor * new_state = state;
cb(new_state, "new_state", il);
ggml_tensor * core_attn_out = nullptr;
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
ggml_tensor * q_chunk = get_slice_2d(q, chunk);
ggml_tensor * v_chunk = get_slice_2d(v, chunk);
ggml_tensor * gexp_chunk = get_slice_2d(gexp, chunk);
ggml_tensor * k_cumdecay_chunk = get_slice_2d(k_cumdecay, chunk);
ggml_tensor * attn_chunk = get_slice_2d(attn_kq, chunk);
cb(attn_chunk, "attn_chunk", il);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
cb(v_prime, "v_prime_chunk", il);
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
cb(v_new, "v_new_chunk", il);
ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
cb(attn_inter, "attn_inter_chunk", il);
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
cb(v_attn, "v_attn_chunk", il);
ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
cb(core_attn_out_chunk, "core_attn_out_chunk", il);
core_attn_out = core_attn_out == nullptr
? core_attn_out_chunk
: ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
ggml_tensor * k_gdiff_t = get_slice_2d(key_gdiff_t, chunk);
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(g_last_exp, chunk));
new_state = ggml_add(ctx0,
ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
}
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
S_v, n_tokens, H_v, n_seqs,
ggml_row_size(core_attn_out->type, S_v),
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks),
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks * H_v), 0);
output_tokens = ggml_cont(ctx0, output_tokens);
cb(output_tokens, "output_tokens", il);
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);
return {output_tokens, new_state};
};
auto build_delta_net_autoregressive = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(n_tokens == 1);
GGML_ASSERT(n_seqs == 1);
GGML_ASSERT(H_k == H_v);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
g_t = ggml_exp(ctx0, g_t);
state = ggml_mul(ctx0, state, g_t);
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
ggml_tensor * core_attn_out =
ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);
return {core_attn_out, state};
};
auto build_qkvz = [&](ggml_tensor * input, int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
if (model.layers[il].wqkv) {
ggml_tensor * qkv_mixed = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, input);
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_tokens, 1);
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
ggml_tensor * z = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, input);
cb(z, "z", il);
return { qkv_mixed, z };
}
ggml_tensor * mixed_qkvz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, input);
cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
const int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, 1);
int64_t split_sizes_qkvz[4] = {
head_k_dim,
head_k_dim,
head_v_dim * num_v_heads / num_k_heads,
head_v_dim * num_v_heads / num_k_heads
};
ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
cb(query, "q", il);
ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
cb(key, "k", il);
ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
cb(value, "v", il);
ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
z = ggml_cont(ctx0, z);
cb(z, "z", il);
ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_tokens, 1);
cb(query_flat, "query_flat", il);
ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_tokens, 1);
cb(key_flat, "key_flat", il);
ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_tokens, 1);
cb(value_flat, "value_flat", il);
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
cb(qkv_mixed, "qkv_mixed", il);
return { qkv_mixed, z };
};
auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * {
ggml_tensor * Qcur_full = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur_full, "Qcur_full", il);
Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur_reshaped", il);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "gate_reshaped", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf,
nullptr, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv,
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale,
cb, il);
cb(attn, "attn_pregate", il);
ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
cb(gate_sigmoid, "gate_sigmoid", il);
attn = ggml_mul(ctx0, attn, gate_sigmoid);
cb(attn, "attn_gated", il);
attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn);
cb(attn, "attn_output", il);
return attn;
};
auto build_layer_ffn = [&](ggml_tensor * cur, int il) -> ggml_tensor * {
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0f, LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il, gf, false);
cb(moe_out, "ffn_moe_out", il);
ggml_tensor * ffn_shexp =
llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shexp, "ffn_shexp", il);
ggml_tensor * shared_gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
cb(shared_gate, "shared_expert_gate", il);
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "shared_expert_gate_sigmoid", il);
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
cb(ffn_shexp, "ffn_shexp_gated", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
return cur;
};
auto build_layer_attn_linear = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il) -> ggml_tensor * {
auto qkvz = build_qkvz(cur, il);
ggml_tensor * qkv_mixed = qkvz.first;
ggml_tensor * z = qkvz.second;
ggml_tensor * mixed_ba = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta_alpha, cur);
cb(mixed_ba, "linear_attn_mixed_ba", il);
int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, 1);
int64_t split_sizes_ba[2] = {
num_v_heads / num_k_heads,
num_v_heads / num_k_heads
};
ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, 1,
mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
cb(b, "b", il);
ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, 1,
mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
cb(a, "a", il);
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tokens, 1);
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tokens, 1);
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
cb(alpha_softplus, "a_softplus", il);
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);
cb(gate, "gate", il);
size_t state_row_size = 0;
ggml_tensor * state_all = nullptr;
if ((size_t) il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr) {
ggml_tensor * state_storage = kv_self.s_l[il];
GGML_ASSERT(state_storage->type == GGML_TYPE_F32);
GGML_ASSERT(state_storage->ne[0] >= state_dim);
GGML_ASSERT(state_storage->ne[1] >= qnext_state_slots);
state_row_size = state_storage->nb[1];
state_all = ggml_view_2d(ctx0, state_storage, state_dim, qnext_state_slots, state_row_size, 0);
} else {
const size_t state_offs = (size_t) ggml_element_size(kv_self.v_l[il]) * hparams.n_embd_v_gqa(il) * kv_self.size;
state_row_size = ggml_row_size(kv_self.v_l[il]->type, state_dim);
state_all = ggml_view_2d(ctx0, kv_self.v_l[il], state_dim, qnext_state_slots, state_row_size, state_offs);
}
ggml_tensor * state_dst = ggml_view_2d(ctx0, state_all, state_dim, 1, state_row_size, state_seq_id * state_row_size);
ggml_tensor * state_f32 = state_dst;
if (state_f32->type != GGML_TYPE_F32) {
state_f32 = ggml_cast(ctx0, state_f32, GGML_TYPE_F32);
}
if (reset_state) {
state_f32 = ggml_scale(ctx0, state_f32, 0.0f);
}
ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_f32, conv_state_dim, 1, state_f32->nb[1], 0);
ggml_tensor * ssm_state_flat = ggml_view_2d(ctx0, state_f32, ssm_state_dim, 1, state_f32->nb[1],
conv_state_dim * ggml_element_size(state_f32));
ggml_tensor * conv_states = ggml_reshape_3d(ctx0, conv_state_flat, hparams.ssm_d_conv - 1, conv_dim, 1);
ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim * num_v_heads, 1, 1);
cb(conv_states, "conv_states", il);
cb(state, "state_predelta", il);
GGML_ASSERT(lctx.inp_s_seq_qnext != nullptr);
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, lctx.inp_s_seq_qnext);
cb(conv_output_raw, "conv_output_raw", il);
ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tokens, conv_dim * ggml_element_size(conv_output_raw), 0);
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output);
cb(conv_output_silu, "conv_output_silu", il);
ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tokens, conv_output_silu->nb[1], 0);
ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tokens, conv_output_silu->nb[1],
key_dim * ggml_element_size(conv_output_silu));
ggml_tensor * v_conv = ggml_view_2d(ctx0, conv_output_silu, value_dim, n_tokens, conv_output_silu->nb[1],
2 * key_dim * ggml_element_size(conv_output_silu));
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tokens, 1);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tokens, 1);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tokens, 1);
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
const int64_t repeat_factor = num_v_heads / num_k_heads;
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tokens);
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tokens);
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tokens, 1);
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tokens, 1);
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tokens, 1);
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tokens, 1);
}
cb(q_conv, "q_conv_predelta", il);
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
std::pair<ggml_tensor *, ggml_tensor *> attn_out =
n_tokens == 1
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);
ggml_tensor * new_conv_states = ggml_view_2d(ctx0, conv_output_raw, hparams.ssm_d_conv - 1, conv_dim,
hparams.ssm_d_conv * ggml_element_size(conv_output_raw),
(1 + conv_dim * n_tokens) * ggml_element_size(conv_output_raw));
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, ggml_cont(ctx0, new_conv_states), conv_state_dim, 1);
ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, ggml_cont(ctx0, new_state), ssm_state_dim, 1);
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
ggml_tensor * state_update = new_state_flat;
if (state_dst->type != GGML_TYPE_F32) {
state_update = ggml_cast(ctx0, state_update, state_dst->type);
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_update, state_dst));
ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tokens);
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_tokens);
ggml_tensor * attn_out_norm = llm_build_norm(ctx0, attn_out_2d, hparams, model.layers[il].ssm_norm, nullptr, LLM_NORM_RMS, cb, il);
ggml_tensor * gated_silu = ggml_silu(ctx0, z_2d);
attn_out_norm = ggml_mul(ctx0, attn_out_norm, gated_silu);
ggml_tensor * final_output = ggml_reshape_2d(ctx0, attn_out_norm, value_dim, n_tokens);
cb(final_output, "final_output", il);
ggml_tensor * out = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, final_output);
cb(out, "linear_attn_out", il);
return ggml_cont_2d(ctx0, out, n_embd, n_tokens);
};
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
ggml_tensor * KQ_mask = build_inp_KQ_mask();
lctx.inp_s_seq_qnext = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 1, n_tokens);
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
ggml_set_input(lctx.inp_s_seq_qnext);
ggml_tensor * causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, diag_mask);
ggml_tensor * cur = nullptr;
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
GGML_ASSERT(model.layers[il].attn_norm != nullptr);
GGML_ASSERT(model.layers[il].attn_post_norm != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr);
GGML_ASSERT(model.layers[il].ffn_up_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_down_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_inp_shexp != nullptr);
GGML_ASSERT(model.layers[il].ffn_up_shexp != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_shexp != nullptr);
GGML_ASSERT(model.layers[il].ffn_down_shexp != nullptr);
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
if (hparams.is_recurrent(il)) {
GGML_ASSERT(model.layers[il].ssm_conv1d != nullptr);
GGML_ASSERT(model.layers[il].ssm_dt != nullptr);
GGML_ASSERT(model.layers[il].ssm_a != nullptr);
GGML_ASSERT(model.layers[il].ssm_beta_alpha != nullptr);
GGML_ASSERT(model.layers[il].ssm_norm != nullptr);
GGML_ASSERT(model.layers[il].ssm_out != nullptr);
GGML_ASSERT(model.layers[il].wqkv != nullptr || model.layers[il].ssm_in != nullptr);
GGML_ASSERT(model.layers[il].wqkv_gate != nullptr || model.layers[il].ssm_in != nullptr);
cur = build_layer_attn_linear(cur, causal_mask, identity, diag_mask, il);
} else {
GGML_ASSERT(model.layers[il].wq != nullptr);
GGML_ASSERT(model.layers[il].wk != nullptr);
GGML_ASSERT(model.layers[il].wv != nullptr);
GGML_ASSERT(model.layers[il].wo != nullptr);
GGML_ASSERT(model.layers[il].attn_q_norm != nullptr);
GGML_ASSERT(model.layers[il].attn_k_norm != nullptr);
cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
ggml_tensor * ffn_residual = cur;
ggml_tensor * attn_post_norm = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(attn_post_norm, "attn_post_norm", il);
cur = build_layer_ffn(attn_post_norm, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
cur = llm_build_norm(ctx0, inpL, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
ggml_cgraph * llm_build_context::build_qwen3vl() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -9126,6 +9841,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_qwen3moe();
} break;
case LLM_ARCH_QWEN3NEXT:
{
result = llm.build_qwen3next();
} break;
case LLM_ARCH_QWEN3VL:
{
result = llm.build_qwen3vl();

View File

@@ -204,6 +204,8 @@ struct llm_build_context {
ggml_cgraph * build_qwen3vlmoe();
ggml_cgraph * build_qwen3next();
ggml_cgraph * build_phi2();
ggml_cgraph * build_phi3();

View File

@@ -56,6 +56,7 @@ struct llama_kv_cache {
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
std::vector<struct ggml_tensor *> s_l; // per layer recurrent state storage (Qwen3Next)
std::vector<llama_split_tensor> split_k_l;
std::vector<llama_split_tensor> split_v_l;
@@ -202,6 +203,7 @@ struct llama_context {
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_s_seq_qnext; // I32 [1, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]

View File

@@ -5,7 +5,7 @@
#include <map>
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
@@ -83,6 +83,7 @@ void llm_load_hparams(
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), false);
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
@@ -453,6 +454,25 @@ void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3NEXT:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Upstream convention: every 4th layer is full attention, others are recurrent.
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0);
}
model.type = e_model::MODEL_UNKNOWN;
} break;
case LLM_ARCH_QWEN3VLMOE:
{
ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);

View File

@@ -85,6 +85,10 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
uint32_t ssm_n_group = 0;
// for hybrid state-space models (e.g. qwen3next)
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
@@ -157,6 +161,8 @@ struct llama_hparams {
if (this->ssm_d_inner != other.ssm_d_inner) return true;
if (this->ssm_d_state != other.ssm_d_state) return true;
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
if (this->ssm_n_group != other.ssm_n_group) return true;
if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
if (this->dec_start_token_id != other.dec_start_token_id) return true;
@@ -234,6 +240,10 @@ struct llama_hparams {
}
uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
if (ssm_n_group > 0) {
// qwen3next keeps all recurrent state in the V-cache tail
return 0;
}
// corresponds to Mamba's conv_states size
// TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
@@ -241,10 +251,30 @@ struct llama_hparams {
}
uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
if (ssm_n_group > 0) {
// qwen3next recurrent state packs:
// 1) conv state: (d_conv - 1) * (2 * key_dim + value_dim)
// 2) delta-net state: head_v_dim * head_v_dim * num_v_heads
const uint32_t key_dim = ssm_d_state * ssm_n_group;
const uint32_t value_dim = ssm_d_inner;
const uint32_t conv_dim = 2 * key_dim + value_dim;
const uint32_t conv_state_dim = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * conv_dim;
const uint32_t head_v_dim = ssm_dt_rank > 0 ? ssm_d_inner / ssm_dt_rank : 0;
const uint32_t ssm_state_dim = head_v_dim * head_v_dim * ssm_dt_rank;
return conv_state_dim + ssm_state_dim;
}
// corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner;
}
bool is_recurrent(uint32_t il) const {
if (il < n_layer) {
return recurrent_layer_arr[il];
}
GGML_ABORT("fatal error");
}
static bool is_float_close(float a, float b, float abs_tol) {
// Check for non-negative tolerance
if (abs_tol < 0.0) {

View File

@@ -73,6 +73,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
bool create_qwen3_moe_tensors(const LLM_TN & tn);
bool create_qwen3next_tensors(const LLM_TN & tn);
bool create_phi2_tensors(const LLM_TN & tn);
bool create_phi3_tensors(const LLM_TN & tn);
@@ -1174,6 +1176,83 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
return use_mmap_buffer;
}
bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (model.output == NULL) {
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for QWEN3NEXT");
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for QWEN3NEXT");
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp;
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads;
const int64_t key_dim = head_k_dim * num_k_heads;
const int64_t value_dim = head_v_dim * num_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
const int64_t ba_dim = num_v_heads * 2;
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
if (!hparams.is_recurrent(i)) {
// Full-attention layer
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head * 2});
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
} else {
// Recurrent linear-attention layer
layer.ssm_in = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, qkvz_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, key_dim * 2 + value_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, value_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {hparams.ssm_d_conv, conv_dim});
layer.ssm_dt = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {hparams.ssm_dt_rank});
layer.ssm_a = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A_NOSCAN, i), {hparams.ssm_dt_rank});
layer.ssm_beta_alpha = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), {n_embd, ba_dim});
layer.ssm_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {head_v_dim});
layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {value_dim, n_embd});
}
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
use_mmap_buffer &= !create_std_ffn_exps(n_embd, tn, i);
layer.ffn_gate_inp_shexp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp});
layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp});
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd});
}
return use_mmap_buffer;
}
bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
@@ -2984,6 +3063,8 @@ bool create_tensors_helper::create_tensors() {
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3VLMOE:
use_mmap_buffer = create_qwen3_moe_tensors(tn); break;
case LLM_ARCH_QWEN3NEXT:
use_mmap_buffer = create_qwen3next_tensors(tn); break;
case LLM_ARCH_PHI2:
use_mmap_buffer = create_phi2_tensors(tn); break;
case LLM_ARCH_PHI3:

View File

@@ -429,6 +429,39 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_QWEN3NEXT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_QWEN3VL,
{

View File

@@ -173,6 +173,7 @@ struct llama_layer {
struct ggml_tensor * wk_enc = nullptr;
struct ggml_tensor * wv_enc = nullptr;
struct ggml_tensor * wo_enc = nullptr;
struct ggml_tensor * wqkv_gate = nullptr;
struct ggml_tensor * attn_sinks = nullptr;
// attention bias
@@ -286,6 +287,8 @@ struct llama_layer {
struct ggml_tensor * ssm_x = nullptr;
struct ggml_tensor * ssm_dt = nullptr;
struct ggml_tensor * ssm_out = nullptr;
struct ggml_tensor * ssm_norm = nullptr;
struct ggml_tensor * ssm_beta_alpha = nullptr;
// mamba
struct ggml_tensor * ssm_conv1d = nullptr;

View File

@@ -636,6 +636,30 @@ llama_context::~llama_context() {
// kv cache helpers
//
static inline bool llama_qwen3next_is_recurrent_layer(
const llama_model & model,
const llama_hparams & hparams,
uint32_t il) {
return model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il);
}
static inline uint32_t llama_kv_v_row_embd(
const llama_model & model,
const llama_hparams & hparams,
uint32_t il) {
// qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence),
// so per-token V rows include only attention values.
if (model.arch == LLM_ARCH_QWEN3NEXT) {
return hparams.n_embd_v_gqa(il);
}
return hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
}
static inline uint32_t llama_qwen3next_state_slots(const llama_cparams & cparams) {
return std::max<uint32_t>(1, cparams.n_seq_max);
}
static bool llama_kv_cache_init(
struct llama_kv_cache & cache,
const llama_context * ctx,
@@ -744,18 +768,23 @@ static bool llama_kv_cache_init(
needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn;
}
if (needs_v_cache) cache.v_l.reserve(n_layer);
cache.s_l.reserve(n_layer);
std::vector<size_t> mem_split(model.splits.size(), 0);
const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams);
int n_mla = 0;
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
const bool qnext_recurrent = llama_qwen3next_is_recurrent_layer(model, hparams, i);
const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
const uint32_t n_embd_head_k= hparams.n_embd_head_k;
struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
ggml_tensor * v;
ggml_tensor * s = nullptr;
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
// DeepSeek MLA
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
@@ -781,11 +810,21 @@ static bool llama_kv_cache_init(
else {
int n_embd_head_v = hparams.n_embd_head_v;
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
int64_t v_ne = int64_t(n_embd_v_row)*kv_size;
v = ggml_new_tensor_1d(ctx, type_v, v_ne);
if (qnext_recurrent) {
const int64_t s_ne = int64_t(hparams.n_embd_v_s())*qnext_state_slots;
s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, s_ne);
}
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
auto v_name = std::string{"cache_v_l"} + std::to_string(i);
auto s_name = std::string{"cache_s_l"} + std::to_string(i);
ggml_set_name(k, k_name.c_str());
ggml_set_name(v, v_name.c_str());
if (s) {
ggml_set_name(s, s_name.c_str());
}
//ggml_format_name(k, "cache_k_l%d", i);
//ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
@@ -836,6 +875,7 @@ static bool llama_kv_cache_init(
//}
}
}
cache.s_l.push_back(s);
}
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
@@ -2756,6 +2796,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (lctx.inp_s_seq_qnext) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq_qnext->buffer));
int32_t * data = (int32_t *) lctx.inp_s_seq_qnext->data;
for (int64_t j = 0; j < n_tokens; ++j) {
// qwen3next linear-attention path uses a single local recurrent state slot.
data[j] = 0;
}
}
if (lctx.inp_pos_bucket) {
const int64_t n_tokens = batch.n_tokens;
@@ -5032,6 +5084,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3NEXT:
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA:
@@ -5625,7 +5678,7 @@ struct llama_data_write {
if (v_state == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
@@ -5647,7 +5700,7 @@ struct llama_data_write {
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
@@ -5985,7 +6038,7 @@ struct llama_data_read {
if (v_state == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
// Read type of value
int32_t v_type_i_ref;
@@ -6018,7 +6071,7 @@ struct llama_data_read {
else if (v_state == 1) {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
// Read type of value
int32_t v_type_i_ref;