qwen3next: add runtime switch for fused delta-net path

This commit is contained in:
yurko
2026-02-07 17:31:17 -08:00
parent ed0565f801
commit b33cef68ad

View File

@@ -6,6 +6,7 @@
#include "ggml.h"
#include <cstdlib>
#include <unordered_set>
llm_build_context::llm_build_context(
@@ -4177,6 +4178,25 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
// Default to fused DeltaNet path; set LLAMA_QWEN3NEXT_FUSED_DELTA=0 to force legacy graph path.
const bool use_fused_delta_net = []() {
const char * env = std::getenv("LLAMA_QWEN3NEXT_FUSED_DELTA");
if (env == nullptr || env[0] == '\0') {
return true;
}
switch (env[0]) {
case '0':
case 'n':
case 'N':
case 'f':
case 'F':
return false;
default:
return true;
}
}();
auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * {
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
@@ -4443,6 +4463,13 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
return {core_attn_out, state};
};
// Fused DeltaNet path.
// Input convention in this builder is [S, H, T, B] (GGML order), and ggml_delta_net expects:
// q/k: [S, T, H, B]
// v: [S, T, H, B]
// g: [T, 1, H, B]
// beta: [1, T, H, B]
// state: [S, S*H, 1, B]
auto build_delta_net_fused = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
@@ -4454,6 +4481,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
@@ -4809,13 +4838,18 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
const bool use_fused_delta_net = true;
std::pair<ggml_tensor *, ggml_tensor *> attn_out =
use_fused_delta_net
? build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il)
: (n_tok == 1
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il));
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
if (use_fused_delta_net) {
attn_out = build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {
GGML_ASSERT(causal_mask != nullptr);
GGML_ASSERT(identity != nullptr);
GGML_ASSERT(diag_mask != nullptr);
attn_out = n_tok == 1
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
}
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
@@ -4884,14 +4918,19 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
ggml_set_input(lctx.inp_s_seq_qnext);
ggml_tensor * causal_mask =
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
ggml_tensor * causal_mask = nullptr;
ggml_tensor * identity = nullptr;
ggml_tensor * diag_mask = nullptr;
if (!use_fused_delta_net) {
causal_mask = ggml_tri(ctx0,
ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, diag_mask);
identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
diag_mask = ggml_add(ctx0, causal_mask, identity);
ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, diag_mask);
}
ggml_tensor * cur = nullptr;