From b33cef68ad3f670b10ae10cfae438023ff19eb1b Mon Sep 17 00:00:00 2001
From: yurko
Date: Sat, 7 Feb 2026 17:31:17 -0800
Subject: [PATCH] qwen3next: add runtime switch for fused delta-net path

---
 src/llama-build-context.cpp | 67 +++++++++++++++++++++++++++++--------
 1 file changed, 53 insertions(+), 14 deletions(-)

diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index edae6a27..73a5667a 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -6,6 +6,7 @@
 #include "ggml.h"

+#include <cstdlib>
 #include

 llm_build_context::llm_build_context(
@@ -4177,6 +4178,25 @@ ggml_cgraph * llm_build_context::build_qwen3next() {

     const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;

+    // Default to fused DeltaNet path; set LLAMA_QWEN3NEXT_FUSED_DELTA=0 to force legacy graph path.
+    const bool use_fused_delta_net = []() {
+        const char * env = std::getenv("LLAMA_QWEN3NEXT_FUSED_DELTA");
+        if (env == nullptr || env[0] == '\0') {
+            return true;
+        }
+
+        switch (env[0]) {
+            case '0':
+            case 'n':
+            case 'N':
+            case 'f':
+            case 'F':
+                return false;
+            default:
+                return true;
+        }
+    }();
+
     auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * {
         return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
                             t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
@@ -4443,6 +4463,13 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
         return {core_attn_out, state};
     };

+    // Fused DeltaNet path.
+    // Input convention in this builder is [S, H, T, B] (GGML order), and ggml_delta_net expects:
+    //   q/k:   [S, T, H, B]
+    //   v:     [S, T, H, B]
+    //   g:     [T, 1, H, B]
+    //   beta:  [1, T, H, B]
+    //   state: [S, S*H, 1, B]
     auto build_delta_net_fused = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
                                      ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
                                      int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
@@ -4454,6 +4481,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
         const int64_t S_v = v->ne[0];
         const int64_t H_v = v->ne[1];

+        GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+        GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
         GGML_ASSERT(v->ne[2] == n_tokens);
         GGML_ASSERT(k->ne[2] == n_tokens);
         GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
@@ -4809,13 +4838,18 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
         cb(k_conv, "k_conv_predelta", il);
         cb(v_conv, "v_conv_predelta", il);

-        const bool use_fused_delta_net = true;
-        std::pair<ggml_tensor *, ggml_tensor *> attn_out =
-            use_fused_delta_net
-                ? build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il)
-                : (n_tok == 1
-                    ? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
-                    : build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il));
+        std::pair<ggml_tensor *, ggml_tensor *> attn_out;
+        if (use_fused_delta_net) {
+            attn_out = build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il);
+        } else {
+            GGML_ASSERT(causal_mask != nullptr);
+            GGML_ASSERT(identity != nullptr);
+            GGML_ASSERT(diag_mask != nullptr);
+
+            attn_out = n_tok == 1
+                ? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
+                : build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
+        }
         ggml_tensor * output = attn_out.first;
         ggml_tensor * new_state = attn_out.second;
         cb(output, "attn_output", il);
@@ -4884,14 +4918,19 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
     cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
     ggml_set_input(lctx.inp_s_seq_qnext);

-    ggml_tensor * causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
+    ggml_tensor * causal_mask = nullptr;
+    ggml_tensor * identity = nullptr;
+    ggml_tensor * diag_mask = nullptr;
+    if (!use_fused_delta_net) {
+        causal_mask = ggml_tri(ctx0,
+            ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
             GGML_TRI_TYPE_LOWER);
-    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
-    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
-    ggml_build_forward_expand(gf, causal_mask);
-    ggml_build_forward_expand(gf, identity);
-    ggml_build_forward_expand(gf, diag_mask);
+        identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
+        diag_mask = ggml_add(ctx0, causal_mask, identity);
+        ggml_build_forward_expand(gf, causal_mask);
+        ggml_build_forward_expand(gf, identity);
+        ggml_build_forward_expand(gf, diag_mask);
+    }

     ggml_tensor * cur = nullptr;
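
Note for reviewers: a minimal standalone sketch of the truthiness rule applied by the new LLAMA_QWEN3NEXT_FUSED_DELTA switch above. The helper name use_fused_delta_net_from_env() and the main() harness are illustrative only and not part of the patch; the body mirrors the lambda added in build_qwen3next().

    // Standalone sketch (not part of the patch): reproduces the first-character
    // check used by the new switch so the opt-out values are easy to see.
    #include <cstdio>
    #include <cstdlib>

    // Unset or empty -> fused path (the default); a value starting with
    // '0', 'n'/'N' or 'f'/'F' -> legacy graph path.
    static bool use_fused_delta_net_from_env() {
        const char * env = std::getenv("LLAMA_QWEN3NEXT_FUSED_DELTA");
        if (env == nullptr || env[0] == '\0') {
            return true;
        }
        switch (env[0]) {
            case '0': case 'n': case 'N': case 'f': case 'F':
                return false;
            default:
                return true;
        }
    }

    int main() {
        std::printf("fused delta-net path: %s\n",
                    use_fused_delta_net_from_env() ? "enabled" : "disabled (legacy)");
        return 0;
    }

Because only the first character is inspected, values such as "0", "no", "false", "n" or "F" select the legacy autoregressive/chunking path, while "1", "true", an unset or empty variable, and strings like "off" keep the fused ggml_delta_net path.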