From 5a6c4e8da557c4388007f50c1a6532bb26b8a9f6 Mon Sep 17 00:00:00 2001
From: yurko
Date: Sat, 7 Feb 2026 14:00:09 -0800
Subject: [PATCH] qwen3next: keep recurrent state in 4d layout through delta path

---
 src/llama-build-context.cpp | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index b55795e9..5e9a241d 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -4199,7 +4199,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
         GGML_ASSERT(k->ne[2] == n_tokens);
         GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
         GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-        GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+        GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
         GGML_ASSERT(H_k == H_v);
 
         const float eps_norm = hparams.f_norm_rms_eps;
@@ -4223,8 +4223,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
         g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
         beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
 
-        state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
         cb(q, "q_perm", il);
         cb(k, "k_perm", il);
         cb(v, "v_perm", il);
@@ -4400,7 +4398,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
         GGML_ASSERT(n_tokens == 1);
         GGML_ASSERT(n_seqs == 1);
         GGML_ASSERT(H_k == H_v);
-        GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+        GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
 
         const float eps_norm = hparams.f_norm_rms_eps;
         q = ggml_l2_norm(ctx0, q, eps_norm);
@@ -4417,8 +4415,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
         cb(beta, "beta_in", il);
         cb(g, "g_in", il);
 
-        state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
         ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
         ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
 
@@ -4697,7 +4693,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
                 conv_state_dim * ggml_element_size(state_f32));
 
         ggml_tensor * conv_states = ggml_reshape_3d(ctx0, conv_state_flat, hparams.ssm_d_conv - 1, conv_dim, 1);
-        ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim * num_v_heads, 1, 1);
+        ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim, num_v_heads, 1);
 
         cb(conv_states, "conv_states", il);
         cb(state, "state_predelta", il);
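
Note (not part of the patch): the reshapes removed above were zero-copy views, so creating the predelta state as (S_v, S_v, H_v, n_seqs) up front and tightening the shape asserts should be behavior-preserving. The standalone sketch below illustrates that with the same ggml calls the patch touches; the sizes S_v = 128, H_v = 16, n_seqs = 1 are made-up examples, not values taken from the model.

    // Illustrative only -- not part of llama.cpp. Sizes are hypothetical examples.
    #include "ggml.h"
    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t S_v = 128, H_v = 16, n_seqs = 1;

        ggml_init_params params = { /*.mem_size   =*/ 32u*1024*1024,
                                    /*.mem_buffer =*/ nullptr,
                                    /*.no_alloc   =*/ false };
        ggml_context * ctx = ggml_init(params);

        // old layout: flat (S_v, S_v*H_v, 1, n_seqs) state plus an in-graph reshape to 4d
        ggml_tensor * flat = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, S_v*H_v, 1, n_seqs);
        ggml_tensor * re   = ggml_reshape_4d(ctx, flat, S_v, S_v, H_v, n_seqs);

        // new layout: the state tensor is created as (S_v, S_v, H_v, n_seqs) to begin with
        ggml_tensor * st4d = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, S_v, H_v, n_seqs);

        // the reshape is only a view over the same buffer with the same element count,
        // which is why it can be dropped once the state already arrives in 4d
        assert(re->data == flat->data);
        assert(ggml_nelements(re) == ggml_nelements(st4d));
        assert(re->ne[0] == S_v && re->ne[1] == S_v && re->ne[2] == H_v && re->ne[3] == n_seqs);

        ggml_free(ctx);
        return 0;
    }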