diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 2b953a87..0463fc9e 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -4415,6 +4415,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() { auto build_delta_net_autoregressive = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, int il) -> std::pair { + const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; const int64_t n_tokens = q->ne[2]; const int64_t n_seqs = q->ne[3]; @@ -4431,38 +4432,42 @@ ggml_cgraph * llm_build_context::build_qwen3next() { q = ggml_l2_norm(ctx0, q, eps_norm); k = ggml_l2_norm(ctx0, k, eps_norm); - const float scale = 1.0f / sqrtf(S_v); + const float scale = 1.0f / sqrtf(S_k); q = ggml_scale(ctx0, q, scale); beta = ggml_sigmoid(ctx0, beta); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 1, 2, 0, 3); + cb(q, "q_in", il); cb(k, "k_in", il); cb(v, "v_in", il); cb(beta, "beta_in", il); cb(g, "g_in", il); - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_v, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_v, n_seqs); g_t = ggml_exp(ctx0, g_t); state = ggml_mul(ctx0, state, g_t); - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + state = ggml_cont(ctx0, ggml_transpose(ctx0, state)); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k); + kv_mem = ggml_sum_rows(ctx0, kv_mem); - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); + ggml_tensor * v_diff = ggml_sub(ctx0, v, kv_mem); ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + ggml_tensor * state_q = ggml_mul(ctx0, state, q); + ggml_tensor * core_attn_out = ggml_sum_rows(ctx0, state_q); + + core_attn_out = ggml_transpose(ctx0, core_attn_out); + state = ggml_transpose(ctx0, state); cb(core_attn_out, "output_tokens", il); cb(state, "new_state", il);