WIP: Qwen3Next (#1266)

* qwen3next: add architecture support and recurrent-state fixes

* qwen3next: optimize broadcast sub and single-seq ssm conv

* cuda: build MoE row mapping on device in mul_mat_id

* cuda: add guarded multi-seq fast path for ssm_conv

* docs: update qwen3next perf report for cuda MoE/SSM tuning

* cuda: reduce qwen3next moe/ssm sync overhead and refresh eval

* qwen3next: split cpu/cuda eval builds and tune PP scheduling

* qwen3next: harden seq-state flow and support optional dense FFN layers

* qwen3next: trim delta-net graph overhead in chunking path

* qwen3next: remove redundant v_conv cont in delta path

* qwen3next: avoid extra cont on linear attention output

* qwen3next: drop redundant cont before recurrent state flatten

* qwen3next: keep recurrent state in 4d layout through delta path

* qwen3next: add fused delta-net op and wire model path

* tests: add backend-op coverage for ggml_delta_net

* qwen3next: add runtime switch for fused delta-net path

* docs: refresh qwen3next perf review and benchmark matrix

* qwen3next: default fused delta-net off and document quality checks

* qwen3next: add decode-only fused delta mode

* qwen3next: make fused delta safe by default and fix fused tensor layout

* qwen3next: warn when forcing fused decode mode

* qwen3next: add fused-delta regression runner script

* qwen3next: integrate fused regression into eval harness

* qwen3next: clean up chunked delta-net shape handling

* qwen3next: add absolute sanity guards to fused regression

* qwen3next: add unified regression runner script

* qwen3next: disable flash-attn for cpu-only contexts

* docs: reconcile qwen3next status and remaining upstream gaps

* common: add qwen3next fused-delta runtime flag

* cuda: add qwen3next delta-net kernel dispatch override

* docs: update qwen3next quality and serving baseline findings

* qwen3next: keep fused delta on safe path and remove PR artifacts

* qwen3next: align autoregressive delta-net decode layout

* Revert "qwen3next: align autoregressive delta-net decode layout"

This reverts commit 9241164a5e.

* cuda: port solve-tri fast-paths for qwen3next delta-net

* qwen3next: add fused-delta runtime flag and drop env toggle

* qwen3next: make fused delta single-flag and default on

* Account for GPU arch differences

* Revert "cuda: build MoE row mapping on device in mul_mat_id"

This reverts commit 89e9ecfa84.

* qwen3next: drop non-essential MoE scheduling and split heuristics

* qwen3next: avoid generic ggml_sub broadcast changes

* llama: restore only_active_experts log message

* Remove unnecessary hacks, disable fusion for now.

* qwen3next: port hybrid recurrent state memory semantics

* qwen3next: clean up recurrent state slot plumbing

* qwen3next: fix hybrid V-cache layout plumbing

* qwen3next: guard recurrent state slots against kv capacity

* qwen3next: persist recurrent state in session data

- serialize/restore qwen3next cache.s_l in state/session paths
- bump session and sequence-state file versions for format change
- fallback to single-token chunking for mixed repeated seq_id batches

* qwen3next: drop unused fused-delta builder path

- remove dead build_delta_net_fused lambda
- remove unused llm_build_context::fused_delta member

* qwen3next: remove unused fused-delta CLI/context plumbing

- drop -fd/-no-fd options and related YAML dump field
- remove fused_delta fields from public/internal context params
- remove fused_delta assignment and logging in context init

* ggml: remove unused DELTA_NET operator stack

* Missing include

* Reorder ops/unary ops

So we don't change the enum values of the mul_mat ops again

* Minor

* Discard unnecessary changes in llama-build-context.cpp

* Minor

* Revert "Discard unnecessary changes in llama-build-context.cpp"

This reverts commit edadb80ed6.

* Increase GGML_SCHED_MAX_SPLITS - required for larger u-batches

* Fix CPU concat in the TG case: 7.25 -> 10.5 t/s for Qwen3Next

* Fix CPU sum_rows: 10.5 -> 13.6 t/s for Qwen3Next

It was single-threaded and was taking ~25% of the computation time
during TG. It is now down to 2%.

Strangely enough, I measure 13.6 t/s with llama-bench, but if I
let the model give me an actual response with llama-cli, I get close
to 17 t/s.
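
For reference, a minimal standalone sketch of the row-parallel idea, assuming a plain std::thread pool (names and partitioning are illustrative only, this is not the actual ggml_compute_forward_sum_rows code): each worker reduces its own contiguous block of rows instead of one thread reducing the whole tensor.

    #include <algorithm>
    #include <cstdint>
    #include <thread>
    #include <vector>

    // Reduce each row of an n_rows x row_len matrix to a single float,
    // splitting the rows evenly across n_threads workers.
    static void sum_rows_mt(const float * src, float * dst,
                            int64_t n_rows, int64_t row_len, int n_threads) {
        auto worker = [&](int ith) {
            const int64_t per_thread = (n_rows + n_threads - 1) / n_threads;
            const int64_t first = ith * per_thread;
            const int64_t last  = std::min(n_rows, first + per_thread);
            for (int64_t r = first; r < last; ++r) {
                float acc = 0.0f;
                for (int64_t c = 0; c < row_len; ++c) {
                    acc += src[r*row_len + c];
                }
                dst[r] = acc;
            }
        };
        std::vector<std::thread> pool;
        for (int t = 0; t < n_threads; ++t) pool.emplace_back(worker, t);
        for (auto & th : pool) th.join();
    }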

* Fix CPU scale: 13.6 -> 16.7 t/s for Qwen3Next

For Qwen3Next there is a scale op on a largish tensor (548k elements)
that has a single row for TG, so it was done in a single thread.
We now simply process it in blocks of 1024 elements.
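
A rough sketch of that blocking scheme (hypothetical helper with assumed ith/nth worker indexing, not the upstream ggml code): the single long row is cut into 1024-element blocks and the blocks are handed out round-robin to the worker threads.

    #include <cstdint>

    // Scale a single row of n elements in place; worker ith of nth takes
    // every nth block of 1024 elements, so all threads get useful work
    // even when the tensor has only one row.
    static void scale_row_blocked(float * x, int64_t n, float s, int ith, int nth) {
        constexpr int64_t kBlock = 1024;
        const int64_t n_blocks = (n + kBlock - 1) / kBlock;
        for (int64_t b = ith; b < n_blocks; b += nth) {
            const int64_t i0 = b * kBlock;
            const int64_t i1 = i0 + kBlock < n ? i0 + kBlock : n;
            for (int64_t i = i0; i < i1; ++i) {
                x[i] *= s;
            }
        }
    }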

* Optimize CPU mul: 16.7 -> 17.6 t/s for Qwen3Next

* CPU: fuse transpose -> cont -> sum_rows -> transpose: 17.6 -> 23.1 t/s for Qwen3Next

* Optimize CPU repeat: 176 -> 200 t/s for Qwen3Next PP-512

* Multithreading for OP_SUB

* Don't commit with timing trace on

* Multithread neg and sigmoid

* Be able to turn on/off fusion more easily (CPU)

* Name the mul_mat ops so we know where the time goes

* WIP

* Much better PP on CUDA

* CUDA: fuse transpose -> cont -> sum_rows -> transpose

Needs a non-contiguous variant of sum_rows.
On the CPU this gave a 30+% improvement in TG performance,
on CUDA it is a disappointing 6-7%. I guess this is because
Georgi's cont CPU implementation was so bad that skipping
it made such a big difference.
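
To make the fusion concrete, a scalar CPU-style sketch of what the fused path computes (hypothetical helper, not the actual ggml op or the CUDA kernel): the reduction walks the original row-major data with a column stride, so the transpose, the cont copy and the transpose back never get materialized.

    #include <cstdint>

    // Column sums of an n_rows x n_cols row-major matrix, written directly
    // in the layout that transpose -> cont -> sum_rows -> transpose would
    // have produced, but without any intermediate copies.
    static void sum_cols_strided(const float * src, float * dst,
                                 int64_t n_rows, int64_t n_cols) {
        for (int64_t c = 0; c < n_cols; ++c) {
            float acc = 0.0f;
            for (int64_t r = 0; r < n_rows; ++r) {
                acc += src[r*n_cols + c]; // strided read replaces the explicit cont
            }
            dst[c] = acc;
        }
    }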

* CUDA: faster mul for a special case relevant to Qwen3Next

Worth 1% in TG

* Fix CPU OP_CONT

---------

Co-authored-by: yurko <yurko@local>
Co-authored-by: Yurko <yurko@example.com>
Co-authored-by: yurko <yurko@pop-os.tail5a1a6b.ts.net>
Co-authored-by: Yurko Hoshko <YurkoHoshko@users.noreply.github.com>
Kawrakow
2026-02-16 06:50:28 +01:00
committed by GitHub
parent 528cadb07b
commit e30198a553
35 changed files with 4600 additions and 232 deletions

View File

@@ -27,6 +27,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
@@ -186,6 +187,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@@ -242,4 +244,3 @@ const char * llama_model_arch_name(llm_arch arch) {
}
return it->second;
}

View File

@@ -26,6 +26,7 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
@@ -180,6 +181,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_PRE,
@@ -278,8 +280,11 @@ enum llm_tensor {
LLM_TENSOR_SSM_X,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_A_NOSCAN,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_SSM_BETA_ALPHA,
LLM_TENSOR_ATTN_Q_A,
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,

View File

@@ -6,6 +6,28 @@
#include "ggml.h"
#include <unordered_set>
#include <algorithm>
static inline uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & kv_self) {
uint32_t n_slots = 0;
for (const ggml_tensor * t : kv_self.s_l) {
if (t == nullptr) {
continue;
}
const uint32_t layer_slots = (uint32_t) t->ne[1];
if (n_slots == 0) {
n_slots = layer_slots;
} else {
GGML_ASSERT(n_slots == layer_slots);
}
}
return n_slots;
}
llm_build_context::llm_build_context(
llama_context & lctx,
const llama_batch & batch,
@@ -84,6 +106,7 @@ void llm_build_context::init() {
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_s_seq_qnext = nullptr;
lctx.inp_pos_bucket = nullptr;
lctx.inp_embd_enc = nullptr;
lctx.inp_KQ_mask_cross = nullptr;
@@ -118,6 +141,12 @@ ggml_cgraph * llm_build_context::build_k_shift() {
ggml_set_input(lctx.inp_K_shift);
for (int il = 0; il < n_layer; ++il) {
if (model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il)) {
continue;
}
if (kv_self.k_l[il] == nullptr) {
continue;
}
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
struct ggml_tensor * rope_factors = build_rope_factors(il);
@@ -161,21 +190,34 @@ ggml_cgraph * llm_build_context::build_k_shift() {
ggml_cgraph * llm_build_context::build_s_copy() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
GGML_ASSERT(kv_self.recurrent);
const uint32_t qnext_state_slots = llama_kv_qnext_state_slots(kv_self);
const bool has_qnext_state = qnext_state_slots > 0;
GGML_ASSERT(kv_self.recurrent || has_qnext_state);
struct ggml_tensor * state_copy = build_inp_s_copy();
for (int il = 0; il < n_layer; ++il) {
if (kv_self.recurrent) {
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);
// TODO: name the intermediate tensors with cb()
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
}
if (kv_self.s_l.size() > (size_t) il && kv_self.s_l[il] != nullptr) {
struct ggml_tensor * qnext_states_all = ggml_reshape_2d(ctx0, kv_self.s_l[il], hparams.n_embd_v_s(), kv_self.s_l[il]->ne[1]);
GGML_ASSERT((uint32_t) qnext_states_all->ne[1] == qnext_state_slots);
struct ggml_tensor * qnext_state_copy = ggml_view_1d(ctx0, state_copy, qnext_state_slots, 0);
struct ggml_tensor * qnext_states = ggml_get_rows(ctx0, qnext_states_all, qnext_state_copy);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, qnext_states, kv_self.s_l[il]));
}
}
return gf;
@@ -198,6 +240,12 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
}
for (int il = 0; il < n_layer; ++il) {
if (model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il)) {
continue;
}
if (kv_self.k_l[il] == nullptr) {
continue;
}
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
@@ -214,7 +262,7 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector<uint32_t> & ids)
ggml_tensor * view_v_src = nullptr;
ggml_tensor * view_v_dst = nullptr;
if (kv_self.v_l.size() > il) {
if (kv_self.v_l.size() > il && kv_self.v_l[il] != nullptr) {
// Note: with MLA the V cache may not be present.
if (flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
@@ -509,12 +557,12 @@ void llm_build_context::llm_build_kv_store(
struct ggml_tensor * v_cache_view = nullptr;
if (cparams.flash_attn) {
if (!kv.v_trans) {
v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
(kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
lctx.cache_copies[2*il+1].step = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa);
} else {
// note: the V cache is transposed when not using flash attention
// note: the V cache is transposed for legacy non-FA layouts
v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il]));
@@ -1454,12 +1502,21 @@ static ggml_tensor * llm_build_kqv(
} else {
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
struct ggml_tensor * v;
if (kv.v_trans) {
v = ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv.v_l[il])*n_ctx,
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
0);
} else {
v = ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);
v = ggml_cont(ctx, ggml_transpose(ctx, v));
}
cb(v, "v", il);
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
@@ -4248,6 +4305,822 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
return gf;
}
ggml_cgraph * llm_build_context::build_qwen3next() {
static constexpr int QWEN3NEXT_CHUNK_SIZE = 64;
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
GGML_ASSERT(batch.n_tokens > 0);
const bool has_explicit_seq_info = batch.n_seq_id != nullptr && batch.seq_id != nullptr;
std::vector<llama_seq_id> token_seq_ids(batch.n_tokens, 0);
for (int i = 0; i < batch.n_tokens; ++i) {
if (has_explicit_seq_info) {
GGML_ASSERT(batch.n_seq_id[i] > 0 && "qwen3next expects each token to belong to at least one sequence");
GGML_ASSERT(batch.n_seq_id[i] == 1 && "qwen3next does not support multi-sequence tokens yet");
token_seq_ids[i] = batch.seq_id[i][0];
} else {
token_seq_ids[i] = 0;
}
}
const llama_seq_id seq_id = token_seq_ids[0];
const bool all_same_seq = std::all_of(token_seq_ids.begin(), token_seq_ids.end(), [&](llama_seq_id s) {
return s == seq_id;
});
bool has_unique_seq_ids = true;
if (!all_same_seq) {
std::unordered_set<llama_seq_id> seen;
seen.reserve(token_seq_ids.size());
for (llama_seq_id s : token_seq_ids) {
if (!seen.insert(s).second) {
has_unique_seq_ids = false;
break;
}
}
}
GGML_ASSERT(hparams.ssm_n_group > 0);
GGML_ASSERT(hparams.ssm_dt_rank > 0);
GGML_ASSERT(hparams.ssm_d_conv > 0);
GGML_ASSERT(hparams.ssm_d_inner % hparams.ssm_dt_rank == 0);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads;
const int64_t key_dim = head_k_dim * num_k_heads;
const int64_t value_dim = head_v_dim * num_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
const int64_t conv_state_dim = (hparams.ssm_d_conv - 1) * conv_dim;
const int64_t ssm_state_dim = head_v_dim * head_v_dim * num_v_heads;
const int64_t state_dim = conv_state_dim + ssm_state_dim;
const uint32_t qnext_state_slots = llama_kv_qnext_state_slots(kv_self);
GGML_ASSERT(qnext_state_slots > 0);
GGML_ASSERT(hparams.n_embd_v_s() == (uint32_t) state_dim);
// Reserve-graph builds may not carry explicit sequence IDs, in which case
// the fallback sequence slot is 0.
const uint32_t state_seq_id = (uint32_t) seq_id;
for (llama_seq_id s : token_seq_ids) {
GGML_ASSERT(s >= 0);
GGML_ASSERT((uint32_t) s < qnext_state_slots);
}
const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * {
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
};
auto build_delta_net_chunking = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(n_seqs == 1);
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_v, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(g, "g_perm", il);
cb(state,"state_in", il);
const int64_t chunk_size = QWEN3NEXT_CHUNK_SIZE;
const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
const int64_t n_chunks = (n_tokens + pad) / chunk_size;
q = ggml_pad(ctx0, q, 0, pad, 0, 0);
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
cb(q, "q_pad", il);
cb(k, "k_pad", il);
cb(v, "v_pad", il);
cb(beta, "beta_pad", il);
cb(g, "g_pad", il);
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
ggml_tensor * k_beta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, beta, k->ne[0], beta->ne[1], beta->ne[2], beta->ne[3]), k);
cb(v_beta, "v_beta", il);
cb(k_beta, "k_beta", il);
q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_v * n_seqs);
v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_v * n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il);
ggml_tensor * gcs_i =
ggml_repeat_4d(ctx0, g_cumsum, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j_broadcast =
ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
cb(decay_mask, "decay_mask", il);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
cb(kmulkbeta, "kk_beta", il);
ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
cb(attn, "attn_pre_solve", il);
ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
ggml_tensor * identity_repeat =
ggml_repeat_4d(ctx0, identity, attn_lower->ne[0], attn_lower->ne[1], attn_lower->ne[2], attn_lower->ne[3]);
ggml_tensor * lhs = ggml_neg(ctx0, ggml_sub(ctx0, attn_lower, identity_repeat));
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_mul(ctx0, lin_solve, causal_mask);
attn = ggml_add(ctx0, attn, identity);
cb(attn, "attn_solved", il);
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
cb(v, "v_beta", il);
ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
cb(g_cumsum_t, "g_cumsum_t", il);
ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
cb(gexp, "gexp", il);
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
cb(kbeta_gexp, "kbeta_gexp", il);
auto attn_kbeta = ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)));
cb(attn_kbeta, "attn_kbeta", il);
ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, attn_kbeta));
cb(k_cumdecay, "k_cumdecay", il);
ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
cb(attn_kq, "attn_kq_pre", il);
attn_kq = ggml_mul(ctx0, decay_mask, attn_kq);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il);
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
(g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
g_last = ggml_cont(ctx0, g_last);
cb(g_last, "g_last", il);
ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
cb(g_last_exp, "g_last_exp", il);
ggml_tensor * g_last_repeat =
ggml_repeat_4d(ctx0, g_last, chunk_size, 1, n_chunks, H_v * n_seqs);
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last_repeat));
cb(g_diff, "g_diff", il);
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
ggml_tensor * key_gdiff = ggml_mul(ctx0, ggml_repeat_4d(ctx0, g_diff_exp_t, k->ne[0], g_diff_exp_t->ne[1], g_diff_exp_t->ne[2], g_diff_exp_t->ne[3]), k);
cb(key_gdiff, "key_gdiff", il);
ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
cb(key_gdiff_t, "key_gdiff_t", il);
cb(state, "new_state", il);
ggml_tensor * core_attn_out = nullptr;
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
ggml_tensor * q_chunk = get_slice_2d(q, chunk);
ggml_tensor * v_chunk = get_slice_2d(v, chunk);
ggml_tensor * gexp_chunk = get_slice_2d(gexp, chunk);
ggml_tensor * k_cumdecay_chunk = get_slice_2d(k_cumdecay, chunk);
ggml_tensor * attn_chunk = get_slice_2d(attn_kq, chunk);
cb(attn_chunk, "attn_chunk", il);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
//printf("v_prime_chunk: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", state_t->ne[0], state_t->ne[1], state_t->ne[2], state_t->ne[3], ggml_type_name(state_t->type),
// k_cumdecay_chunk->ne[0], k_cumdecay_chunk->ne[1], k_cumdecay_chunk->ne[2], k_cumdecay_chunk->ne[3], ggml_type_name(k_cumdecay_chunk->type));
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
cb(v_prime, "v_prime_chunk", il);
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
cb(v_new, "v_new_chunk", il);
ggml_tensor * q_g_exp = ggml_mul(ctx0, ggml_repeat_4d(ctx0, gexp_chunk, q_chunk->ne[0], gexp_chunk->ne[1], gexp_chunk->ne[2], gexp_chunk->ne[3]), q_chunk);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
cb(attn_inter, "attn_inter_chunk", il);
//printf("v_attn_chunk: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", v_new_t->ne[0], v_new_t->ne[1], v_new_t->ne[2], v_new_t->ne[3], ggml_type_name(v_new_t->type),
// attn_chunk->ne[0], attn_chunk->ne[1], attn_chunk->ne[2], attn_chunk->ne[3], ggml_type_name(attn_chunk->type));
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
cb(v_attn, "v_attn_chunk", il);
ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
cb(core_attn_out_chunk, "core_attn_out_chunk", il);
core_attn_out = core_attn_out == nullptr
? core_attn_out_chunk
: ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
ggml_tensor * k_gdiff_t = get_slice_2d(key_gdiff_t, chunk);
//printf("kgdmulvnew: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", v_new_t->ne[0], v_new_t->ne[1], v_new_t->ne[2], v_new_t->ne[3], ggml_type_name(v_new_t->type),
// k_gdiff_t->ne[0], k_gdiff_t->ne[1], k_gdiff_t->ne[2], k_gdiff_t->ne[3], ggml_type_name(k_gdiff_t->type));
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
cb(kgdmulvnew, "kgdmulvnew", il);
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(g_last_exp, chunk));
state = ggml_add(ctx0,
ggml_mul(ctx0, state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
}
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
S_v, n_tokens, H_v, n_seqs,
ggml_row_size(core_attn_out->type, S_v),
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks),
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks * H_v), 0);
cb(output_tokens, "output_tokens", il);
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);
return {output_tokens, state};
};
auto build_delta_net_autoregressive = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(n_tokens == 1);
GGML_ASSERT(n_seqs == 1);
GGML_ASSERT(H_k == H_v);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
g_t = ggml_exp(ctx0, g_t);
state = ggml_mul(ctx0, state, g_t);
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
kv_mem = ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem));
cb(kv_mem, "kv_mem_t_cont", il);
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, kv_mem));
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);
cb(v_diff, "v_diff", il);
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
state_q = ggml_cont(ctx0, ggml_transpose(ctx0, state_q));
cb(state_q, "state_q_t_cont", il);
ggml_tensor * core_attn_out = ggml_transpose(ctx0, ggml_sum_rows(ctx0, state_q));
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);
return {core_attn_out, state};
};
auto build_qkvz = [&](ggml_tensor * input, int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t n_tok = input->ne[1];
if (model.layers[il].wqkv) {
ggml_tensor * qkv_mixed = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, input);
cb(qkv_mixed, "qkv_mixed", il);
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_tok, 1);
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
ggml_tensor * z = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, input);
cb(z, "z", il);
return { qkv_mixed, z };
}
ggml_tensor * mixed_qkvz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, input);
cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
const int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tok, 1);
int64_t split_sizes_qkvz[4] = {
head_k_dim,
head_k_dim,
head_v_dim * num_v_heads / num_k_heads,
head_v_dim * num_v_heads / num_k_heads
};
ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
cb(query, "q", il);
ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
cb(key, "k", il);
ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
cb(value, "v", il);
ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tok, 1,
mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
(split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
z = ggml_cont(ctx0, z);
cb(z, "z", il);
ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_tok, 1);
cb(query_flat, "query_flat", il);
ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_tok, 1);
cb(key_flat, "key_flat", il);
ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_tok, 1);
cb(value_flat, "value_flat", il);
ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
cb(qkv_mixed, "qkv_mixed", il);
return { qkv_mixed, z };
};
auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * {
ggml_tensor * Qcur_full = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur_full, "Qcur_full", il);
Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur_reshaped", il);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "gate_reshaped", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf,
nullptr, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv,
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale,
cb, il);
cb(attn, "attn_pregate", il);
ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
cb(gate_sigmoid, "gate_sigmoid", il);
attn = ggml_mul(ctx0, attn, gate_sigmoid);
cb(attn, "attn_gated", il);
attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn);
cb(attn, "attn_output", il);
return attn;
};
auto build_layer_ffn = [&](ggml_tensor * cur, int il) -> ggml_tensor * {
const bool has_moe = model.layers[il].ffn_gate_inp != nullptr;
const bool has_dense = model.layers[il].ffn_gate != nullptr && model.layers[il].ffn_up != nullptr && model.layers[il].ffn_down != nullptr;
if (has_moe) {
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0f, LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il, gf, false);
cb(moe_out, "ffn_moe_out", il);
const bool has_shexp = model.layers[il].ffn_up_shexp != nullptr &&
model.layers[il].ffn_gate_shexp != nullptr &&
model.layers[il].ffn_down_shexp != nullptr &&
model.layers[il].ffn_gate_inp_shexp != nullptr;
if (has_shexp) {
ggml_tensor * ffn_shexp =
llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shexp, "ffn_shexp", il);
ggml_tensor * shared_gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
cb(shared_gate, "shared_expert_gate", il);
if (shared_gate->ne[1] == 1) {
ffn_shexp = ggml_fused_mul_unary(ctx0, shared_gate, ffn_shexp, GGML_UNARY_OP_SIGMOID);
} else {
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "shared_expert_gate_sigmoid", il);
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
}
cb(ffn_shexp, "ffn_shexp_gated", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
} else {
cur = moe_out;
}
cb(cur, "ffn_out", il);
return cur;
}
GGML_ASSERT(has_dense);
cur = llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
return cur;
};
auto build_layer_attn_linear_core = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, ggml_tensor * inp_s_seq_qnext,
uint32_t state_seq_id_local, bool reset_state_local, int il) -> ggml_tensor * {
const int64_t n_tok = cur->ne[1];
auto qkvz = build_qkvz(cur, il);
ggml_tensor * qkv_mixed = qkvz.first;
ggml_tensor * z = qkvz.second;
ggml_tensor * mixed_ba = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta_alpha, cur);
cb(mixed_ba, "linear_attn_mixed_ba", il);
int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tok, 1);
int64_t split_sizes_ba[2] = {
num_v_heads / num_k_heads,
num_v_heads / num_k_heads
};
ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tok, 1,
mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
cb(b, "b", il);
ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tok, 1,
mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
cb(a, "a", il);
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1);
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1);
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
cb(alpha_softplus, "a_softplus", il);
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);
cb(gate, "gate", il);
size_t state_row_size = 0;
ggml_tensor * state_all = nullptr;
GGML_ASSERT((size_t) il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr);
ggml_tensor * state_storage = kv_self.s_l[il];
GGML_ASSERT(state_storage->type == GGML_TYPE_F32);
GGML_ASSERT(state_storage->ne[0] >= state_dim);
GGML_ASSERT((uint32_t) state_storage->ne[1] == qnext_state_slots);
state_row_size = state_storage->nb[1];
GGML_ASSERT(ggml_nbytes(state_storage) >= state_row_size * qnext_state_slots);
state_all = ggml_view_2d(ctx0, state_storage, state_dim, qnext_state_slots, state_row_size, 0);
ggml_tensor * state_dst = ggml_view_2d(ctx0, state_all, state_dim, 1, state_row_size, state_seq_id_local * state_row_size);
ggml_tensor * state_f32 = state_dst;
if (state_f32->type != GGML_TYPE_F32) {
state_f32 = ggml_cast(ctx0, state_f32, GGML_TYPE_F32);
}
if (reset_state_local) {
state_f32 = ggml_scale(ctx0, state_f32, 0.0f);
}
ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_f32, conv_state_dim, 1, state_f32->nb[1], 0);
ggml_tensor * ssm_state_flat = ggml_view_2d(ctx0, state_f32, ssm_state_dim, 1, state_f32->nb[1],
conv_state_dim * ggml_element_size(state_f32));
ggml_tensor * conv_states = ggml_reshape_3d(ctx0, conv_state_flat, hparams.ssm_d_conv - 1, conv_dim, 1);
ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim, num_v_heads, 1);
cb(conv_states, "conv_states", il);
cb(state, "state_predelta", il);
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext);
cb(conv_output_raw, "conv_output_raw", il);
ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0);
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output);
cb(conv_output_silu, "conv_output_silu", il);
ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1], 0);
ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1],
key_dim * ggml_element_size(conv_output_silu));
ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_output_silu, head_v_dim, num_v_heads, n_tok, 1,
ggml_row_size(conv_output_silu->type, head_v_dim),
conv_output_silu->nb[1],
conv_output_silu->nb[1] * n_tok,
2 * key_dim * ggml_element_size(conv_output_silu));
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tok, 1);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tok, 1);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tok, 1);
cb(q_conv, "q_conv_cont", il);
cb(k_conv, "k_conv_cont", il);
cb(v_conv, "v_conv_cont", il);
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
const int64_t repeat_factor = num_v_heads / num_k_heads;
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tok);
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tok);
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
}
cb(q_conv, "q_conv_predelta", il);
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
GGML_ASSERT(causal_mask != nullptr);
GGML_ASSERT(identity != nullptr);
GGML_ASSERT(diag_mask != nullptr);
attn_out = n_tok == 1
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);
ggml_tensor * new_conv_states = ggml_view_2d(ctx0, conv_output_raw, hparams.ssm_d_conv - 1, conv_dim,
hparams.ssm_d_conv * ggml_element_size(conv_output_raw),
(1 + conv_dim * n_tok) * ggml_element_size(conv_output_raw));
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, ggml_cont(ctx0, new_conv_states), conv_state_dim, 1);
ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, new_state, ssm_state_dim, 1);
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
ggml_tensor * state_update = new_state_flat;
if (state_dst->type != GGML_TYPE_F32) {
state_update = ggml_cast(ctx0, state_update, state_dst->type);
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_update, state_dst));
ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tok);
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_tok);
ggml_tensor * attn_out_norm = llm_build_norm(ctx0, attn_out_2d, hparams, model.layers[il].ssm_norm, nullptr, LLM_NORM_RMS, cb, il);
ggml_tensor * gated_silu = ggml_silu(ctx0, z_2d);
attn_out_norm = ggml_mul(ctx0, attn_out_norm, gated_silu);
ggml_tensor * final_output = ggml_reshape_2d(ctx0, attn_out_norm, value_dim, n_tok);
cb(final_output, "final_output", il);
ggml_tensor * out = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, final_output);
cb(out, "linear_attn_out", il);
return ggml_reshape_2d(ctx0, out, n_embd, n_tok);
};
auto build_layer_attn_linear = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il) -> ggml_tensor * {
GGML_ASSERT(lctx.inp_s_seq_qnext != nullptr);
if (all_same_seq) {
return build_layer_attn_linear_core(cur, causal_mask, identity, diag_mask, lctx.inp_s_seq_qnext, state_seq_id, reset_state, il);
}
GGML_ASSERT(has_unique_seq_ids && "qwen3next mixed-sequence batches require unique sequence IDs per token");
ggml_tensor * out = nullptr;
for (int64_t i = 0; i < n_tokens; ++i) {
ggml_tensor * cur_i = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (size_t) i * cur->nb[1]);
ggml_tensor * inp_s_seq_qnext_i = ggml_view_2d(ctx0, lctx.inp_s_seq_qnext, 1, 1, lctx.inp_s_seq_qnext->nb[1], (size_t) i * lctx.inp_s_seq_qnext->nb[1]);
const bool reset_state_i = batch.pos != nullptr && batch.pos[i] == 0;
const uint32_t state_seq_id_i = (uint32_t) token_seq_ids[i];
ggml_tensor * out_i = build_layer_attn_linear_core(cur_i, causal_mask, identity, diag_mask, inp_s_seq_qnext_i, state_seq_id_i, reset_state_i, il);
out = out == nullptr ? out_i : ggml_concat(ctx0, out, out_i, 1);
}
return out;
};
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
ggml_tensor * KQ_mask = build_inp_KQ_mask();
lctx.inp_s_seq_qnext = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 1, n_tokens);
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
ggml_set_input(lctx.inp_s_seq_qnext);
ggml_tensor * causal_mask = nullptr;
ggml_tensor * identity = nullptr;
ggml_tensor * diag_mask = nullptr;
causal_mask = ggml_tri(ctx0,
ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
diag_mask = ggml_add(ctx0, causal_mask, identity);
ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, diag_mask);
ggml_tensor * cur = nullptr;
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
GGML_ASSERT(model.layers[il].attn_norm != nullptr);
GGML_ASSERT(model.layers[il].attn_post_norm != nullptr);
const bool has_moe = model.layers[il].ffn_gate_inp != nullptr;
const bool has_dense = model.layers[il].ffn_gate != nullptr &&
model.layers[il].ffn_up != nullptr &&
model.layers[il].ffn_down != nullptr;
GGML_ASSERT(has_moe || has_dense);
if (has_moe) {
GGML_ASSERT(model.layers[il].ffn_up_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_gate_exps != nullptr);
GGML_ASSERT(model.layers[il].ffn_down_exps != nullptr);
}
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
if (hparams.is_recurrent(il)) {
GGML_ASSERT(model.layers[il].ssm_conv1d != nullptr);
GGML_ASSERT(model.layers[il].ssm_dt != nullptr);
GGML_ASSERT(model.layers[il].ssm_a != nullptr);
GGML_ASSERT(model.layers[il].ssm_beta_alpha != nullptr);
GGML_ASSERT(model.layers[il].ssm_norm != nullptr);
GGML_ASSERT(model.layers[il].ssm_out != nullptr);
GGML_ASSERT(model.layers[il].wqkv != nullptr || model.layers[il].ssm_in != nullptr);
GGML_ASSERT(model.layers[il].wqkv_gate != nullptr || model.layers[il].ssm_in != nullptr);
cur = build_layer_attn_linear(cur, causal_mask, identity, diag_mask, il);
} else {
GGML_ASSERT(model.layers[il].wq != nullptr);
GGML_ASSERT(model.layers[il].wk != nullptr);
GGML_ASSERT(model.layers[il].wv != nullptr);
GGML_ASSERT(model.layers[il].wo != nullptr);
GGML_ASSERT(model.layers[il].attn_q_norm != nullptr);
GGML_ASSERT(model.layers[il].attn_k_norm != nullptr);
cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
ggml_tensor * ffn_residual = cur;
ggml_tensor * attn_post_norm = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(attn_post_norm, "attn_post_norm", il);
cur = build_layer_ffn(attn_post_norm, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
cur = llm_build_norm(ctx0, inpL, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
ggml_cgraph * llm_build_context::build_qwen3vl() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -9273,6 +10146,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_qwen3moe();
} break;
case LLM_ARCH_QWEN3NEXT:
{
result = llm.build_qwen3next();
} break;
case LLM_ARCH_QWEN3VL:
{
result = llm.build_qwen3vl();

View File

@@ -204,6 +204,8 @@ struct llm_build_context {
ggml_cgraph * build_qwen3vlmoe();
ggml_cgraph * build_qwen3next();
ggml_cgraph * build_phi2();
ggml_cgraph * build_phi3();

View File

@@ -56,6 +56,7 @@ struct llama_kv_cache {
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
std::vector<struct ggml_tensor *> s_l; // per layer recurrent state storage (Qwen3Next)
std::vector<llama_split_tensor> split_k_l;
std::vector<llama_split_tensor> split_v_l;
@@ -202,6 +203,7 @@ struct llama_context {
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_s_seq_qnext; // I32 [1, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]

View File

@@ -5,7 +5,7 @@
#include <map>
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
@@ -83,6 +83,7 @@ void llm_load_hparams(
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), false);
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
@@ -453,6 +454,28 @@ void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3NEXT:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Upstream convention: every 4th layer is full attention, others are recurrent.
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0);
}
switch (hparams.n_layer) {
case 48: model.type = e_model::MODEL_80B_A3B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3VLMOE:
{
ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);

View File

@@ -89,6 +89,10 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
uint32_t ssm_n_group = 0;
// for hybrid state-space models (e.g. qwen3next)
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
@@ -169,6 +173,8 @@ struct llama_hparams {
if (this->ssm_d_inner != other.ssm_d_inner) return true;
if (this->ssm_d_state != other.ssm_d_state) return true;
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
if (this->ssm_n_group != other.ssm_n_group) return true;
if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true;
if (this->dec_start_token_id != other.dec_start_token_id) return true;
@@ -246,6 +252,10 @@ struct llama_hparams {
}
uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
if (ssm_n_group > 0) {
// qwen3next keeps all recurrent state in the V-cache tail
return 0;
}
// corresponds to Mamba's conv_states size
// TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
@@ -253,10 +263,26 @@ struct llama_hparams {
}
uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
if (ssm_n_group > 0) {
// qwen3next recurrent state packs:
// 1) conv state: (d_conv - 1) * (2 * key_dim + value_dim)
// 2) delta-net state: head_v_dim * head_v_dim * num_v_heads
const uint32_t key_dim = ssm_d_state * ssm_n_group;
const uint32_t value_dim = ssm_d_inner;
const uint32_t conv_dim = 2 * key_dim + value_dim;
const uint32_t conv_state_dim = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * conv_dim;
const uint32_t head_v_dim = ssm_dt_rank > 0 ? ssm_d_inner / ssm_dt_rank : 0;
const uint32_t ssm_state_dim = head_v_dim * head_v_dim * ssm_dt_rank;
return conv_state_dim + ssm_state_dim;
}
// corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner;
}
bool is_recurrent(uint32_t il) const {
return il < n_layer ? recurrent_layer_arr[il] : false;
}
static bool is_float_close(float a, float b, float abs_tol) {
// Check for non-negative tolerance
if (abs_tol < 0.0) {

View File

@@ -73,6 +73,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
bool create_qwen3_moe_tensors(const LLM_TN & tn);
bool create_qwen3next_tensors(const LLM_TN & tn);
bool create_phi2_tensors(const LLM_TN & tn);
bool create_phi3_tensors(const LLM_TN & tn);
@@ -1291,6 +1293,99 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
return use_mmap_buffer;
}
bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (model.output == NULL) {
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
const bool has_moe_hparams = n_expert > 0 && n_expert_used > 0;
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : (has_moe_hparams ? n_ff / n_expert_used : n_ff);
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp;
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads;
const int64_t key_dim = head_k_dim * num_k_heads;
const int64_t value_dim = head_v_dim * num_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
const int64_t ba_dim = num_v_heads * 2;
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
if (!hparams.is_recurrent(i)) {
// Full-attention layer
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head * 2});
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
} else {
// Recurrent linear-attention layer
layer.ssm_in = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, qkvz_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, key_dim * 2 + value_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, value_dim},
llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {hparams.ssm_d_conv, conv_dim});
layer.ssm_dt = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {hparams.ssm_dt_rank});
layer.ssm_a = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A_NOSCAN, i), {hparams.ssm_dt_rank});
layer.ssm_beta_alpha = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), {n_embd, ba_dim});
layer.ssm_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {head_v_dim});
layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {value_dim, n_embd});
}
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
// Dense FFN path (optional, e.g. mlp_only_layers)
layer.ffn_gate = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
// MoE path (optional per-layer)
layer.ffn_gate_inp = nullptr;
if (n_expert > 0) {
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
if (layer.ffn_gate_inp != nullptr) {
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0 when QWEN3NEXT MoE tensors are present");
}
use_mmap_buffer &= !create_std_ffn_exps(n_embd, tn, i, llama_model_loader::TENSOR_NOT_REQUIRED, n_ff_exp);
}
// Shared expert path (optional per-layer)
layer.ffn_gate_inp_shexp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (layer.ffn_gate_inp_shexp != nullptr) {
layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
}
return use_mmap_buffer;
}
bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
@@ -3221,6 +3316,8 @@ bool create_tensors_helper::create_tensors() {
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3VLMOE:
use_mmap_buffer = create_qwen3_moe_tensors(tn); break;
case LLM_ARCH_QWEN3NEXT:
use_mmap_buffer = create_qwen3next_tensors(tn); break;
case LLM_ARCH_PHI2:
use_mmap_buffer = create_phi2_tensors(tn); break;
case LLM_ARCH_PHI3:

@@ -429,6 +429,39 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_QWEN3NEXT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
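// Name expansion note (illustrative): the loader's tn(...) helper formats the
// patterns above with the layer index plus an optional suffix, e.g. for il = 3:
//   tn(LLM_TENSOR_SSM_IN,         "weight", 3) -> "blk.3.ssm_in.weight"
//   tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", 3) -> "blk.3.ssm_ba.weight"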
{
LLM_ARCH_QWEN3VL,
{
@@ -1648,6 +1681,7 @@ const char * llama_model_type_name(e_model type) {
case MODEL_16B_A1B: return "16B.A1B";
case MODEL_21B_A3B: return "21B.A3B";
case MODEL_30B_A3B: return "30B.A3B";
case MODEL_80B_A3B: return "80B.A3B";
case MODEL_80B_A13B: return "80B.A13B";
case MODEL_100B_A6B: return "100B.A6B";
case MODEL_106B_A12B: return "106B.A12B";

@@ -107,6 +107,7 @@ enum e_model {
MODEL_16B_A1B,
MODEL_21B_A3B, // Ernie MoE small
MODEL_30B_A3B,
MODEL_80B_A3B, // Qwen3-Next
MODEL_80B_A13B,
MODEL_100B_A6B,
MODEL_106B_A12B,
@@ -289,6 +290,8 @@ struct llama_layer {
struct ggml_tensor * ssm_x = nullptr;
struct ggml_tensor * ssm_dt = nullptr;
struct ggml_tensor * ssm_out = nullptr;
struct ggml_tensor * ssm_norm = nullptr;
struct ggml_tensor * ssm_beta_alpha = nullptr;
// mamba
struct ggml_tensor * ssm_conv1d = nullptr;

@@ -568,9 +568,15 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
bool llama_context::update_cache_copies() {
int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
auto layer_has_attention_kv = [&](int il) {
return !(model.arch == LLM_ARCH_QWEN3NEXT && model.hparams.is_recurrent(il));
};
if ((int)kv_self.k_l.size() != n_layer) return false;
if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
for (int il = 0; il < n_layer; ++il) {
if (!layer_has_attention_kv(il) || kv_self.k_l[il] == nullptr) {
continue;
}
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
if (kl) {
GGML_ASSERT(model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN);
@@ -597,6 +603,9 @@ bool llama_context::update_cache_copies() {
}
} else {
for (int il = 0; il < n_layer; ++il) {
if (!layer_has_attention_kv(il) || kv_self.k_l[il] == nullptr) {
continue;
}
auto& c = cache_copies[2*il+0];
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
c.cpy->view_offs = kv_self.head*c.step;
@@ -605,6 +614,9 @@ bool llama_context::update_cache_copies() {
}
if (kv_self.v_l.empty()) return true;
for (int il = 0; il < n_layer; ++il) {
if (!layer_has_attention_kv(il) || kv_self.v_l[il] == nullptr) {
continue;
}
auto& c = cache_copies[2*il+1];
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
c.cpy->view_offs = kv_self.head*c.step;
@@ -640,6 +652,58 @@ llama_context::~llama_context() {
// kv cache helpers
//
static inline bool llama_qwen3next_is_recurrent_layer(
const llama_model & model,
const llama_hparams & hparams,
uint32_t il) {
return model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il);
}
static inline uint32_t llama_kv_v_row_embd(
const llama_model & model,
const llama_hparams & hparams,
uint32_t il) {
// qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence),
// so per-token V rows include only attention values.
if (model.arch == LLM_ARCH_QWEN3NEXT) {
return hparams.n_embd_v_gqa(il);
}
return hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
}
static inline uint32_t llama_qwen3next_state_slots(const llama_cparams & cparams, uint32_t kv_size) {
return std::min<uint32_t>(std::max<uint32_t>(1, cparams.n_seq_max), kv_size);
}
static inline uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & cache) {
uint32_t n_slots = 0;
for (const ggml_tensor * t : cache.s_l) {
if (t == nullptr) {
continue;
}
const uint32_t layer_slots = (uint32_t) t->ne[1];
if (n_slots == 0) {
n_slots = layer_slots;
} else {
GGML_ASSERT(n_slots == layer_slots);
}
}
return n_slots;
}
static inline bool llama_kv_has_qnext_state_storage(const llama_kv_cache & cache) {
return llama_kv_qnext_state_slots(cache) > 0;
}
static inline bool llama_kv_qnext_seq_id_in_range(const llama_kv_cache & cache, llama_seq_id seq_id) {
const uint32_t n_slots = llama_kv_qnext_state_slots(cache);
return n_slots > 0 && seq_id >= 0 && (uint32_t) seq_id < n_slots;
}
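// Worked example for the slot clamp above (assumed values): with
// cparams.n_seq_max = 8 but kv_size = 4, llama_qwen3next_state_slots returns
//   std::min<uint32_t>(std::max<uint32_t>(1, 8), 4) == 4
// which is exactly the case where llama_kv_cache_init below logs its
// slot-reduction warning.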
static bool llama_kv_cache_init(
struct llama_kv_cache & cache,
const llama_context * ctx,
@@ -658,7 +722,9 @@ static bool llama_kv_cache_init(
// TODO: find a nicer way to add other recurrent model architectures
cache.recurrent = model.arch == LLM_ARCH_MAMBA;
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
// qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in
// standard layout to match the mainline hybrid path when flash attention is off.
cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT;
cache.head = 0;
cache.size = kv_size;
@@ -670,7 +736,7 @@ static bool llama_kv_cache_init(
cache.cells.clear();
cache.cells.resize(kv_size);
if (cache.recurrent) {
if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT) {
// init state copy sources
for (uint32_t i = 0; i < cache.size; ++i) {
cache.cells[i].src = i;
@@ -750,18 +816,27 @@ static bool llama_kv_cache_init(
needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn;
}
if (needs_v_cache) cache.v_l.reserve(n_layer);
cache.s_l.reserve(n_layer);
std::vector<size_t> mem_split(model.splits.size(), 0);
const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size);
if (model.arch == LLM_ARCH_QWEN3NEXT && qnext_state_slots < std::max<uint32_t>(1, cparams.n_seq_max)) {
LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n",
__func__, std::max<uint32_t>(1, cparams.n_seq_max), qnext_state_slots);
}
int n_mla = 0;
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
const bool qnext_recurrent = llama_qwen3next_is_recurrent_layer(model, hparams, i);
const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
const uint32_t n_embd_head_k = hparams.n_embd_head_k;
struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
ggml_tensor * v;
ggml_tensor * k = nullptr;
ggml_tensor * v = nullptr;
ggml_tensor * s = nullptr;
if (is_mla_attn && cparams.mla_attn) {
// DeepSeek MLA
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
@@ -792,56 +867,70 @@ static bool llama_kv_cache_init(
ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
split_cache_i = false;
}
int n_embd_head_v = hparams.n_embd_head_v;
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
auto v_name = std::string{"cache_v_l"} + std::to_string(i);
ggml_set_name(k, k_name.c_str());
ggml_set_name(v, v_name.c_str());
//ggml_format_name(k, "cache_k_l%d", i);
//ggml_format_name(v, "cache_v_l%d", i);
if (qnext_recurrent) {
s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_embd_v_s(), qnext_state_slots);
split_cache_i = false;
} else {
int n_embd_head_v = hparams.n_embd_head_v;
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
int64_t v_ne = int64_t(n_embd_v_row)*kv_size;
v = ggml_new_tensor_1d(ctx, type_v, v_ne);
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
auto v_name = std::string{"cache_v_l"} + std::to_string(i);
ggml_set_name(k, k_name.c_str());
ggml_set_name(v, v_name.c_str());
//ggml_format_name(k, "cache_k_l%d", i);
//ggml_format_name(v, "cache_v_l%d", i);
if (split_cache_i) {
bool use_V_for_K = model.layers[i].attn_k_norm && model.layers[i].attn_k_norm->ne[0] == K->ne[1];
auto extra_K = (const ggml_split_tensor_t *)K->extra;
auto extra_V = (const ggml_split_tensor_t *)V->extra;
auto & split_k_l = cache.split_k_l.emplace_back();
auto & split_v_l = cache.split_v_l.emplace_back();
split_k_l.tensor_splits.resize(extra_K->n_device, nullptr);
split_v_l.tensor_splits.resize(extra_V->n_device, nullptr);
for (int is = 0; is < extra_K->n_device; ++is) {
auto split = use_V_for_K ? extra_V->splits[is] : extra_K->splits[is];
if (!split) continue;
int nhead_kv = use_V_for_K ? split->ne[1] / n_embd_head_v : split->ne[1] / n_embd_head_k;
if (use_V_for_K) {
LLAMA_LOG_DEBUG("K_cache(%d, %d): using %d instead of %ld heads\n",
i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k);
}
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size);
auto split_name = k_name + '.' + std::to_string(is);
ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
}
split_k_l.ggml.n_device = extra_K->n_device;
split_k_l.ggml.split_dim = 0;
split_k_l.ggml.splits = split_k_l.tensor_splits.data();
for (int is = 0; is < extra_V->n_device; ++is) {
auto split = extra_V->splits[is];
if (!split) continue;
split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size);
auto split_name = v_name + '.' + std::to_string(is);
ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]);
}
split_v_l.ggml.n_device = extra_V->n_device;
split_v_l.ggml.split_dim = 0;
split_v_l.ggml.splits = split_v_l.tensor_splits.data();
k->extra = (void *)&split_k_l.ggml;
v->extra = (void *)&split_v_l.ggml;
}
}
if (s) {
auto s_name = std::string{"cache_s_l"} + std::to_string(i);
ggml_set_name(s, s_name.c_str());
}
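// Layout note (follows from the allocation above): a recurrent layer gets one
// F32 state tensor of shape { n_embd_v_s(), qnext_state_slots } -- a single
// flattened delta-net state row per sequence slot, independent of kv_size --
// while its k/v entries stay nullptr, which is why downstream consumers now
// null-check k_l/v_l before touching them.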
cache.k_l.push_back(k);
cache.v_l.push_back(v);
if (split_cache_i) {
bool use_V_for_K = model.layers[i].attn_k_norm && model.layers[i].attn_k_norm->ne[0] == K->ne[1];
auto extra_K = (const ggml_split_tensor_t *)K->extra;
auto extra_V = (const ggml_split_tensor_t *)V->extra;
auto & split_k_l = cache.split_k_l.emplace_back();
auto & split_v_l = cache.split_v_l.emplace_back();
split_k_l.tensor_splits.resize(extra_K->n_device, nullptr);
split_v_l.tensor_splits.resize(extra_V->n_device, nullptr);
for (int is = 0; is < extra_K->n_device; ++is) {
auto split = use_V_for_K ? extra_V->splits[is] : extra_K->splits[is];
if (!split) continue;
int nhead_kv = use_V_for_K ? split->ne[1] / n_embd_head_v : split->ne[1] / n_embd_head_k;
if (use_V_for_K) {
LLAMA_LOG_DEBUG("K_cache(%d, %d): using %d instead of %ld heads\n",
i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k);
}
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size);
auto split_name = k_name + '.' + std::to_string(is);
ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
}
split_k_l.ggml.n_device = extra_K->n_device;
split_k_l.ggml.split_dim = 0;
split_k_l.ggml.splits = split_k_l.tensor_splits.data();
for (int is = 0; is < extra_V->n_device; ++is) {
auto split = extra_V->splits[is];
if (!split) continue;
split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size);
auto split_name = v_name + '.' + std::to_string(is);
ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]);
}
split_v_l.ggml.n_device = extra_V->n_device;
split_v_l.ggml.split_dim = 0;
split_v_l.ggml.splits = split_v_l.tensor_splits.data();
k->extra = (void *)&split_k_l.ggml;
v->extra = (void *)&split_v_l.ggml;
}
}
cache.s_l.push_back(s);
}
if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
@@ -1017,6 +1106,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
cache.cells[i].pos = -1;
cache.cells[i].src = i;
cache.cells[i].seq_id.clear();
}
cache.head = 0;
@@ -1056,6 +1146,8 @@ static bool llama_kv_cache_seq_rm(
}
}
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
if (seq_id < 0) {
@@ -1070,6 +1162,9 @@ static bool llama_kv_cache_seq_rm(
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
if (has_qnext_state) {
cache.cells[i].src = i;
}
if (new_head == cache.size) new_head = i;
}
}
@@ -1111,6 +1206,21 @@ static void llama_kv_cache_seq_cp(
}
return;
}
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
if (has_qnext_state &&
llama_kv_qnext_seq_id_in_range(cache, seq_id_dst) &&
llama_kv_qnext_seq_id_in_range(cache, seq_id_src) &&
(uint32_t) seq_id_dst < cache.size &&
(uint32_t) seq_id_src < cache.size) {
seq_id_src = cache.cells[seq_id_src].src;
GGML_ASSERT((uint32_t) seq_id_src < cache.size);
cache.cells[seq_id_dst].src = seq_id_src;
cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
cache.do_copy = true;
}
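// Example (illustrative): with qnext state storage, seq_cp(cache, 1, 2, ...)
// only redirects cells[2].src to cells[1].src and raises do_copy; the actual
// state-row copy is deferred to llama_kv_cache_update_internal, whose
// do_copy handling now also covers this arch (recurrent || has_qnext_state).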
// otherwise, this is the KV cache of a Transformer-like model
cache.head = 0;
@@ -1124,11 +1234,15 @@ static void llama_kv_cache_seq_cp(
static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
uint32_t new_head = cache.size;
const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache);
for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
if (has_qnext_state) {
cache.cells[i].src = i;
}
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
} else {
@@ -2764,6 +2878,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (lctx.inp_s_seq_qnext) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq_qnext->buffer));
int32_t * data = (int32_t *) lctx.inp_s_seq_qnext->data;
for (int64_t j = 0; j < n_tokens; ++j) {
// qwen3next linear-attention path uses a single local recurrent state slot.
data[j] = 0;
}
}
if (lctx.inp_pos_bucket) {
const int64_t n_tokens = batch.n_tokens;
@@ -3012,11 +3138,51 @@ static int llama_decode_internal(
}
}
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
bool warned_qnext_mixed_repeat = false;
for (uint32_t cur_token = 0; cur_token < n_tokens_all; ) {
#if IK_PRINT_TIMING
auto tim1 = ggml_time_us();
#endif
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
if (model.arch == LLM_ARCH_QWEN3NEXT &&
n_tokens > 1 &&
batch_all.n_seq_id != nullptr &&
batch_all.seq_id != nullptr) {
bool can_check = true;
bool any_diff = false;
bool has_dup = false;
llama_seq_id first_seq_id = 0;
std::unordered_set<llama_seq_id> seen_seq_ids;
seen_seq_ids.reserve(n_tokens);
for (uint32_t i = 0; i < n_tokens; ++i) {
const uint32_t idx = cur_token + i;
if (batch_all.n_seq_id[idx] <= 0 || batch_all.seq_id[idx] == nullptr) {
can_check = false;
break;
}
const llama_seq_id seq_id_i = batch_all.seq_id[idx][0];
if (i == 0) {
first_seq_id = seq_id_i;
} else if (seq_id_i != first_seq_id) {
any_diff = true;
}
if (!seen_seq_ids.insert(seq_id_i).second) {
has_dup = true;
}
}
if (can_check && any_diff && has_dup) {
n_tokens = 1;
if (!warned_qnext_mixed_repeat) {
LLAMA_LOG_WARN("%s: qwen3next mixed-sequence batch contains repeated seq_id values; falling back to single-token chunking\n", __func__);
warned_qnext_mixed_repeat = true;
}
}
}
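// Example (illustrative): a chunk whose tokens carry seq_ids {0, 1, 0} shows
// both a mismatch (any_diff) and a repeat (has_dup), so n_tokens collapses to
// 1 and decoding proceeds token by token; {0, 0, 0} (one sequence) and
// {0, 1, 2} (all distinct) both keep the full n_ubatch chunk.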
llama_batch u_batch = {
/* .n_tokens = */ (int32_t) n_tokens,
/* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
@@ -3293,6 +3459,7 @@ static int llama_decode_internal(
#endif
}
n_outputs_prev += lctx.n_outputs;
cur_token += n_tokens;
}
// set to total number of outputs in the batch, for use in llama_get_logits_ith
@@ -3766,7 +3933,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
}
}
if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
if ((lctx.kv_self.recurrent || llama_kv_has_qnext_state_storage(lctx.kv_self)) && lctx.kv_self.do_copy) {
{
lctx.reset_scheduler();
@@ -4787,11 +4954,15 @@ struct llama_context * llama_init_from_model(
size_t memory_size_v = 0;
for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
if (k) {
memory_size_k += ggml_nbytes(k);
}
}
for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
if (v) {
memory_size_v += ggml_nbytes(v);
}
}
if (memory_size_k + memory_size_v > 0) {
@@ -4918,7 +5089,7 @@ struct llama_context * llama_init_from_model(
}
if (params.only_active_experts) {
LLAMA_LOG_INFO("XXXXXXXXXXXXXXXXXXXXX Setting only active experts offload\n");
LLAMA_LOG_INFO("%s: enabling only_active_experts scheduling\n", __func__);
ggml_backend_sched_set_only_active_experts(ctx->sched, true);
}
if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH && (!model->has_tensor_overrides() || cparams.split_mode_graph_scheduling)) {
@@ -5031,6 +5202,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3MOE:
case LLM_ARCH_QWEN3NEXT:
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA:
@@ -5586,7 +5758,7 @@ struct llama_data_write {
}
}
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
const struct llama_kv_cache & kv_self = ctx->kv_self;
const struct llama_hparams & hparams = ctx->model.hparams;
@@ -5599,23 +5771,30 @@ struct llama_data_write {
write(&v_state, sizeof(v_state));
write(&n_layer, sizeof(n_layer));
std::vector<uint8_t> tmp_buf;
// Iterate and write all the keys first, each row is a cell
// Get whole range at a time
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
const bool has_k_cache = kv_self.k_l[il] != nullptr;
// Write key type
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1;
write(&k_type_i, sizeof(k_type_i));
// Write row size of key
const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
const uint64_t k_size_row = has_k_cache
? ((ctx->cparams.mla_attn == 0)
? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)
: ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope))
: 0;
write(&k_size_row, sizeof(k_size_row));
if (!has_k_cache) {
continue;
}
// Read each range of cells of k_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
@@ -5626,16 +5805,21 @@ struct llama_data_write {
if (v_state == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
const bool has_v_cache = kv_self.v_l[il] != nullptr;
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
write(&v_type_i, sizeof(v_type_i));
// Write row size of value
const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
const uint64_t v_size_row = has_v_cache ? ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa) : 0;
write(&v_size_row, sizeof(v_size_row));
if (!has_v_cache) {
continue;
}
// Read each range of cells of v_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
@@ -5648,18 +5832,24 @@ struct llama_data_write {
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
const bool has_v_cache = kv_self.v_l[il] != nullptr;
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1;
write(&v_type_i, sizeof(v_type_i));
// Write element size
const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
const uint32_t v_size_el = has_v_cache ? ggml_type_size(kv_self.v_l[il]->type) : 0;
write(&v_size_el, sizeof(v_size_el));
// Write GQA embedding size
write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
const uint32_t n_embd_v_gqa_write = has_v_cache ? n_embd_v_gqa : 0;
write(&n_embd_v_gqa_write, sizeof(n_embd_v_gqa_write));
if (!has_v_cache) {
continue;
}
// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
@@ -5673,6 +5863,42 @@ struct llama_data_write {
}
}
}
const uint32_t qnext_state = llama_kv_has_qnext_state_storage(kv_self) ? 1 : 0;
write(&qnext_state, sizeof(qnext_state));
if (qnext_state != 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const bool has_s_cache = il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr;
const int32_t s_type_i = has_s_cache ? (int32_t) kv_self.s_l[il]->type : -1;
write(&s_type_i, sizeof(s_type_i));
const uint64_t s_size_row = has_s_cache ? ggml_row_size(kv_self.s_l[il]->type, kv_self.s_l[il]->ne[0]) : 0;
write(&s_size_row, sizeof(s_size_row));
uint32_t s_rows = 0;
size_t s_offset = 0;
if (has_s_cache) {
const uint32_t n_slots = (uint32_t) kv_self.s_l[il]->ne[1];
if (seq_id == -1) {
s_rows = n_slots;
} else if (llama_kv_qnext_seq_id_in_range(kv_self, seq_id) && (uint32_t) seq_id < kv_self.size) {
llama_seq_id src_seq_id = kv_self.cells[seq_id].src;
if (llama_kv_qnext_seq_id_in_range(kv_self, src_seq_id)) {
s_rows = 1;
s_offset = (size_t) src_seq_id * s_size_row;
}
}
}
write(&s_rows, sizeof(s_rows));
if (has_s_cache && s_rows > 0) {
write_tensor_data(kv_self.s_l[il], s_offset, s_rows * s_size_row, il);
}
}
}
}
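// On-disk sketch of the qwen3next block written above, appended after the
// standard K/V sections (field names are descriptive, not from the code):
//   uint32_t qnext_state;                  // 0 when the model has no s_l storage
//   for each layer, when qnext_state != 0:
//     int32_t  s_type;                     // ggml type of s_l[il], or -1 if absent
//     uint64_t s_size_row;                 // bytes per state row, 0 if absent
//     uint32_t s_rows;                     // all slots for full dumps, 1 per-seq
//     uint8_t  data[s_rows * s_size_row];  // raw state rows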
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
@@ -5711,7 +5937,7 @@ struct llama_data_write {
write(&cell_count, sizeof(cell_count));
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
write_kv_cache_data(ctx, cell_ranges);
write_kv_cache_data(ctx, cell_ranges, seq_id);
}
};
@@ -5922,7 +6148,7 @@ struct llama_data_read {
GGML_ASSERT(sum_split_row_size == row_size);
}
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
const struct llama_hparams & hparams = ctx->model.hparams;
struct llama_kv_cache & kv_self = ctx->kv_self;
@@ -5954,20 +6180,35 @@ struct llama_data_read {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
const bool has_k_cache = kv_self.k_l[il] != nullptr;
// Read type of key
int32_t k_type_i_ref;
read_to(&k_type_i_ref, sizeof(k_type_i_ref));
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
return false;
if (!has_k_cache) {
if (k_type_i_ref != -1) {
LLAMA_LOG_ERROR("%s: missing key cache for layer %d\n", __func__, il);
return false;
}
} else {
const int32_t k_type_i = (int32_t) kv_self.k_l[il]->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
return false;
}
}
// Read row size of key
uint64_t k_size_row_ref;
read_to(&k_size_row_ref, sizeof(k_size_row_ref));
if (!has_k_cache) {
if (k_size_row_ref != 0) {
LLAMA_LOG_ERROR("%s: expected empty key row size for layer %d\n", __func__, il);
return false;
}
continue;
}
const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
if (k_size_row != k_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
@@ -5986,20 +6227,35 @@ struct llama_data_read {
if (v_state == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
const bool has_v_cache = kv_self.v_l[il] != nullptr;
// Read type of value
int32_t v_type_i_ref;
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
if (!has_v_cache) {
if (v_type_i_ref != -1) {
LLAMA_LOG_ERROR("%s: missing value cache for layer %d\n", __func__, il);
return false;
}
} else {
const int32_t v_type_i = (int32_t) kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
}
}
// Read row size of value
uint64_t v_size_row_ref;
read_to(&v_size_row_ref, sizeof(v_size_row_ref));
if (!has_v_cache) {
if (v_size_row_ref != 0) {
LLAMA_LOG_ERROR("%s: expected empty value row size for layer %d\n", __func__, il);
return false;
}
continue;
}
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
@@ -6019,35 +6275,58 @@ struct llama_data_read {
else if (v_state == 1) {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il);
const bool has_v_cache = kv_self.v_l[il] != nullptr;
// Read type of value
int32_t v_type_i_ref;
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
if (!has_v_cache) {
if (v_type_i_ref != -1) {
LLAMA_LOG_ERROR("%s: missing transposed value cache for layer %d\n", __func__, il);
return false;
}
} else {
const int32_t v_type_i = (int32_t) kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
}
}
// Read element size of value
uint32_t v_size_el_ref;
read_to(&v_size_el_ref, sizeof(v_size_el_ref));
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
return false;
if (!has_v_cache) {
if (v_size_el_ref != 0) {
LLAMA_LOG_ERROR("%s: expected empty transposed value element size for layer %d\n", __func__, il);
return false;
}
} else {
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
return false;
}
}
// Read GQA embedding size
uint32_t n_embd_v_gqa_ref;
read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
if (!has_v_cache) {
if (n_embd_v_gqa_ref != 0) {
LLAMA_LOG_ERROR("%s: expected empty transposed value rows for layer %d\n", __func__, il);
return false;
}
continue;
}
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
return false;
}
if (cell_count) {
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (kv_self.v_l[il]->extra) {
throw std::runtime_error("Transposed V cache is not supported with split mode 'graph'");
}
@@ -6059,6 +6338,76 @@ struct llama_data_read {
}
}
}
uint32_t qnext_state_ref = 0;
read_to(&qnext_state_ref, sizeof(qnext_state_ref));
const bool has_qnext_state = llama_kv_has_qnext_state_storage(kv_self);
if ((qnext_state_ref != 0) != has_qnext_state) {
LLAMA_LOG_ERROR("%s: incompatible qwen3next state cache presence\n", __func__);
return false;
}
if (qnext_state_ref != 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const bool has_s_cache = il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr;
int32_t s_type_i_ref;
read_to(&s_type_i_ref, sizeof(s_type_i_ref));
if (!has_s_cache) {
if (s_type_i_ref != -1) {
LLAMA_LOG_ERROR("%s: missing qwen3next state cache for layer %d\n", __func__, il);
return false;
}
} else {
const int32_t s_type_i = (int32_t) kv_self.s_l[il]->type;
if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched qwen3next state type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
return false;
}
}
uint64_t s_size_row_ref;
read_to(&s_size_row_ref, sizeof(s_size_row_ref));
const uint64_t s_size_row = has_s_cache ? ggml_row_size(kv_self.s_l[il]->type, kv_self.s_l[il]->ne[0]) : 0;
if (s_size_row != s_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched qwen3next state row size (%zu != %zu, layer %d)\n",
__func__, (size_t) s_size_row, (size_t) s_size_row_ref, il);
return false;
}
uint32_t s_rows_ref;
read_to(&s_rows_ref, sizeof(s_rows_ref));
uint32_t s_rows = 0;
uint32_t s_dst_row = 0;
if (has_s_cache) {
const uint32_t n_slots = (uint32_t) kv_self.s_l[il]->ne[1];
if (seq_id == -1) {
s_rows = n_slots;
} else if (llama_kv_qnext_seq_id_in_range(kv_self, seq_id)) {
s_rows = 1;
s_dst_row = (uint32_t) seq_id;
}
}
if (s_rows_ref != s_rows) {
LLAMA_LOG_ERROR("%s: mismatched qwen3next state row count (%u != %u, layer %d)\n", __func__, s_rows, s_rows_ref, il);
return false;
}
if (s_rows > 0) {
const size_t s_data_size = s_rows * s_size_row;
const size_t s_dst_offset = (size_t) s_dst_row * s_size_row;
if (kv_self.s_l[il]->extra) {
read_kv_cache_data_split(ctx, kv_self.s_l[il], read(s_data_size), s_dst_row, s_size_row, s_rows, il);
} else {
ggml_backend_tensor_set(kv_self.s_l[il], read(s_data_size), s_dst_offset, s_data_size);
}
}
}
}
return true;
}
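// Round-trip sketch (assumption: this fork exposes the mainline
// llama_state_seq_* entry points, which route into the write/read paths
// above; exact names and signatures may differ here):
//   std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, 0));
//   llama_state_seq_get_data(ctx, buf.data(), buf.size(), 0); // K/V + s_l row for seq 0
//   llama_state_seq_set_data(ctx, buf.data(), buf.size(), 0); // restores, incl. delta-net state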
@@ -6066,7 +6415,7 @@ struct llama_data_read {
uint32_t cell_count;
read_to(&cell_count, sizeof(cell_count));
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id);
if (!res) {
if (seq_id == -1) {