mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-01 01:24:08 +00:00
qwen3next: add runtime switch for fused delta-net path
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <unordered_set>
|
||||
|
||||
llm_build_context::llm_build_context(
|
||||
@@ -4177,6 +4178,25 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
|
||||
|
||||
const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
|
||||
|
||||
// Default to fused DeltaNet path; set LLAMA_QWEN3NEXT_FUSED_DELTA=0 to force legacy graph path.
|
||||
const bool use_fused_delta_net = []() {
|
||||
const char * env = std::getenv("LLAMA_QWEN3NEXT_FUSED_DELTA");
|
||||
if (env == nullptr || env[0] == '\0') {
|
||||
return true;
|
||||
}
|
||||
|
||||
switch (env[0]) {
|
||||
case '0':
|
||||
case 'n':
|
||||
case 'N':
|
||||
case 'f':
|
||||
case 'F':
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}();
|
||||
|
||||
auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * {
|
||||
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
|
||||
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
|
||||
@@ -4443,6 +4463,13 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
|
||||
return {core_attn_out, state};
|
||||
};
|
||||
|
||||
// Fused DeltaNet path.
|
||||
// Input convention in this builder is [S, H, T, B] (GGML order), and ggml_delta_net expects:
|
||||
// q/k: [S, T, H, B]
|
||||
// v: [S, T, H, B]
|
||||
// g: [T, 1, H, B]
|
||||
// beta: [1, T, H, B]
|
||||
// state: [S, S*H, 1, B]
|
||||
auto build_delta_net_fused = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
|
||||
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
|
||||
int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
|
||||
@@ -4454,6 +4481,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
|
||||
const int64_t S_v = v->ne[0];
|
||||
const int64_t H_v = v->ne[1];
|
||||
|
||||
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
|
||||
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
|
||||
GGML_ASSERT(v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(k->ne[2] == n_tokens);
|
||||
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
|
||||
@@ -4809,13 +4838,18 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
|
||||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
const bool use_fused_delta_net = true;
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out =
|
||||
use_fused_delta_net
|
||||
? build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il)
|
||||
: (n_tok == 1
|
||||
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
|
||||
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il));
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
|
||||
if (use_fused_delta_net) {
|
||||
attn_out = build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
} else {
|
||||
GGML_ASSERT(causal_mask != nullptr);
|
||||
GGML_ASSERT(identity != nullptr);
|
||||
GGML_ASSERT(diag_mask != nullptr);
|
||||
|
||||
attn_out = n_tok == 1
|
||||
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
|
||||
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
|
||||
}
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
@@ -4884,14 +4918,19 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
|
||||
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
|
||||
ggml_set_input(lctx.inp_s_seq_qnext);
|
||||
|
||||
ggml_tensor * causal_mask =
|
||||
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
|
||||
ggml_tensor * causal_mask = nullptr;
|
||||
ggml_tensor * identity = nullptr;
|
||||
ggml_tensor * diag_mask = nullptr;
|
||||
if (!use_fused_delta_net) {
|
||||
causal_mask = ggml_tri(ctx0,
|
||||
ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
|
||||
GGML_TRI_TYPE_LOWER);
|
||||
ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
|
||||
ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
|
||||
ggml_build_forward_expand(gf, causal_mask);
|
||||
ggml_build_forward_expand(gf, identity);
|
||||
ggml_build_forward_expand(gf, diag_mask);
|
||||
identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
|
||||
diag_mask = ggml_add(ctx0, causal_mask, identity);
|
||||
ggml_build_forward_expand(gf, causal_mask);
|
||||
ggml_build_forward_expand(gf, identity);
|
||||
ggml_build_forward_expand(gf, diag_mask);
|
||||
}
|
||||
|
||||
ggml_tensor * cur = nullptr;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user