qwen3next: make fused delta safe by default and fix fused tensor layout

This commit is contained in:
yurko
2026-02-08 00:06:29 -08:00
parent 143e88ae77
commit 64099e71c0
3 changed files with 54 additions and 21 deletions

View File

@@ -4180,15 +4180,15 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
enum class qwen3next_fused_delta_mode {
off,
on,
tok1_only,
tok_gt1,
all_tokens,
};
// Keep legacy DeltaNet path as default for correctness.
// LLAMA_QWEN3NEXT_FUSED_DELTA values:
// unset / 0 : off
// 1 : fused for all token counts
// 2 : fused only for single-token decode steps
// 1 : fused only for n_tok > 1 (safer; avoids known decode regression)
// 2 : fused for all token counts (experimental)
const qwen3next_fused_delta_mode fused_delta_mode = []() {
const char * env = std::getenv("LLAMA_QWEN3NEXT_FUSED_DELTA");
if (env == nullptr || env[0] == '\0') {
@@ -4201,14 +4201,13 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
case 'Y':
case 't':
case 'T':
return qwen3next_fused_delta_mode::on;
return qwen3next_fused_delta_mode::tok_gt1;
case '2':
return qwen3next_fused_delta_mode::tok1_only;
return qwen3next_fused_delta_mode::all_tokens;
default:
return qwen3next_fused_delta_mode::off;
}
}();
const bool use_fused_delta_net_full = fused_delta_mode == qwen3next_fused_delta_mode::on;
auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * {
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
@@ -4503,14 +4502,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);
cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
@@ -4521,8 +4512,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 1, 3, 0, 2), n_tokens, 1, H_k, n_seqs);
beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 1, 2, 0, 3), 1, n_tokens, H_k, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3), 1, n_tokens, H_k, n_seqs);
ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs);
if (!ggml_is_contiguous(state_flat)) {
@@ -4853,8 +4844,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
const bool use_fused_delta_net =
use_fused_delta_net_full ||
(fused_delta_mode == qwen3next_fused_delta_mode::tok1_only && n_tok == 1);
(fused_delta_mode == qwen3next_fused_delta_mode::tok_gt1 && n_tok > 1) ||
(fused_delta_mode == qwen3next_fused_delta_mode::all_tokens);
if (use_fused_delta_net) {
attn_out = build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il);
@@ -4938,7 +4929,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * causal_mask = nullptr;
ggml_tensor * identity = nullptr;
ggml_tensor * diag_mask = nullptr;
if (!use_fused_delta_net_full) {
if (fused_delta_mode != qwen3next_fused_delta_mode::all_tokens) {
causal_mask = ggml_tri(ctx0,
ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);