qwen3next: add decode-only fused delta mode

This commit is contained in:
yurko
2026-02-07 23:05:19 -08:00
parent 9930f4d961
commit 143e88ae77
2 changed files with 52 additions and 7 deletions

View File

@@ -4178,12 +4178,21 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
// Keep legacy DeltaNet path as the default for correctness; enable fused path explicitly
// with LLAMA_QWEN3NEXT_FUSED_DELTA=1 for controlled testing.
const bool use_fused_delta_net = []() {
enum class qwen3next_fused_delta_mode {
off,
on,
tok1_only,
};
// Keep legacy DeltaNet path as default for correctness.
// LLAMA_QWEN3NEXT_FUSED_DELTA values:
// unset / 0 : off
// 1 : fused for all token counts
// 2 : fused only for single-token decode steps
const qwen3next_fused_delta_mode fused_delta_mode = []() {
const char * env = std::getenv("LLAMA_QWEN3NEXT_FUSED_DELTA");
if (env == nullptr || env[0] == '\0') {
return false;
return qwen3next_fused_delta_mode::off;
}
switch (env[0]) {
@@ -4192,11 +4201,14 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
case 'Y':
case 't':
case 'T':
return true;
return qwen3next_fused_delta_mode::on;
case '2':
return qwen3next_fused_delta_mode::tok1_only;
default:
return false;
return qwen3next_fused_delta_mode::off;
}
}();
const bool use_fused_delta_net_full = fused_delta_mode == qwen3next_fused_delta_mode::on;
auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * {
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
@@ -4840,6 +4852,10 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(v_conv, "v_conv_predelta", il);
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
const bool use_fused_delta_net =
use_fused_delta_net_full ||
(fused_delta_mode == qwen3next_fused_delta_mode::tok1_only && n_tok == 1);
if (use_fused_delta_net) {
attn_out = build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {
@@ -4922,7 +4938,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * causal_mask = nullptr;
ggml_tensor * identity = nullptr;
ggml_tensor * diag_mask = nullptr;
if (!use_fused_delta_net) {
if (!use_fused_delta_net_full) {
causal_mask = ggml_tri(ctx0,
ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);