qwen3next: align autoregressive delta-net decode layout

This commit is contained in:
yurko
2026-02-08 19:53:33 -08:00
parent 48e0e351ce
commit 9241164a5e

View File

@@ -4415,6 +4415,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
auto build_delta_net_autoregressive = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
@@ -4431,38 +4432,42 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
const float scale = 1.0f / sqrtf(S_k);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 1, 2, 0, 3);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_v, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_v, n_seqs);
g_t = ggml_exp(ctx0, g_t);
state = ggml_mul(ctx0, state, g_t);
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
state = ggml_cont(ctx0, ggml_transpose(ctx0, state));
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k);
kv_mem = ggml_sum_rows(ctx0, kv_mem);
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);
ggml_tensor * v_diff = ggml_sub(ctx0, v, kv_mem);
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
ggml_tensor * core_attn_out =
ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
ggml_tensor * state_q = ggml_mul(ctx0, state, q);
ggml_tensor * core_attn_out = ggml_sum_rows(ctx0, state_q);
core_attn_out = ggml_transpose(ctx0, core_attn_out);
state = ggml_transpose(ctx0, state);
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);