This commit is contained in:
Kawrakow
2026-02-18 10:36:38 +00:00
parent 19817d884b
commit 5a22dca980
3 changed files with 91 additions and 77 deletions

View File

@@ -4311,69 +4311,11 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
GGML_ASSERT(batch.n_tokens > 0);
const bool has_explicit_seq_info = batch.n_seq_id != nullptr && batch.seq_id != nullptr;
std::vector<llama_seq_id> token_seq_ids(batch.n_tokens, 0);
for (int i = 0; i < batch.n_tokens; ++i) {
if (has_explicit_seq_info) {
GGML_ASSERT(batch.n_seq_id[i] > 0 && "qwen3next expects each token to belong to at least one sequence");
GGML_ASSERT(batch.n_seq_id[i] == 1 && "qwen3next does not support multi-sequence tokens yet");
token_seq_ids[i] = batch.seq_id[i][0];
} else {
token_seq_ids[i] = 0;
}
}
const llama_seq_id seq_id = token_seq_ids[0];
const bool all_same_seq = std::all_of(token_seq_ids.begin(), token_seq_ids.end(), [&](llama_seq_id s) {
return s == seq_id;
});
bool has_unique_seq_ids = true;
if (!all_same_seq) {
std::unordered_set<llama_seq_id> seen;
seen.reserve(token_seq_ids.size());
for (llama_seq_id s : token_seq_ids) {
if (!seen.insert(s).second) {
has_unique_seq_ids = false;
break;
}
}
}
GGML_ASSERT(hparams.ssm_n_group > 0);
GGML_ASSERT(hparams.ssm_dt_rank > 0);
GGML_ASSERT(hparams.ssm_d_conv > 0);
GGML_ASSERT(hparams.ssm_d_inner % hparams.ssm_dt_rank == 0);
delta_net delta(lctx, batch);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads;
const int64_t key_dim = head_k_dim * num_k_heads;
const int64_t value_dim = head_v_dim * num_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
const int64_t conv_state_dim = (hparams.ssm_d_conv - 1) * conv_dim;
const int64_t ssm_state_dim = head_v_dim * head_v_dim * num_v_heads;
const int64_t state_dim = conv_state_dim + ssm_state_dim;
const uint32_t qnext_state_slots = llama_kv_qnext_state_slots(kv_self);
GGML_ASSERT(qnext_state_slots > 0);
GGML_ASSERT(hparams.n_embd_v_s() == (uint32_t) state_dim);
// Reserve-graph builds may not carry explicit sequence IDs, in which case
// the fallback sequence slot is 0.
const uint32_t state_seq_id = (uint32_t) seq_id;
for (llama_seq_id s : token_seq_ids) {
GGML_ASSERT(s >= 0);
GGML_ASSERT((uint32_t) s < qnext_state_slots);
}
const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * {
ggml_tensor * Qcur_full = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur_full, "Qcur_full", il);
@@ -4522,8 +4464,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * cur = nullptr;
delta_net delta(lctx);
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
@@ -4545,15 +4485,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(cur, "attn_norm", il);
if (hparams.is_recurrent(il)) {
GGML_ASSERT(model.layers[il].ssm_conv1d != nullptr);
GGML_ASSERT(model.layers[il].ssm_dt != nullptr);
GGML_ASSERT(model.layers[il].ssm_a != nullptr);
GGML_ASSERT(model.layers[il].ssm_beta_alpha != nullptr);
GGML_ASSERT(model.layers[il].ssm_norm != nullptr);
GGML_ASSERT(model.layers[il].ssm_out != nullptr);
GGML_ASSERT(model.layers[il].wqkv != nullptr || model.layers[il].ssm_in != nullptr);
GGML_ASSERT(model.layers[il].wqkv_gate != nullptr || model.layers[il].ssm_in != nullptr);
cur = delta.build_layer_attn_linear(ctx0, gf, batch, token_seq_ids, all_same_seq, has_unique_seq_ids, reset_state, cur, causal_mask, identity, diag_mask, il, cb);
cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb);
} else {
GGML_ASSERT(model.layers[il].wq != nullptr);
GGML_ASSERT(model.layers[il].wk != nullptr);

View File

@@ -6,9 +6,73 @@
#include "ggml.h"
#include <algorithm>
#include <unordered_set>
#define QWEN3NEXT_CHUNK_SIZE 64
delta_net::delta_net(llama_context & _lctx) : lctx(_lctx) {}
// Batch-bound constructor: validates the qwen3next SSM hyper-parameters,
// derives the per-layer recurrent-state layout, and precomputes the
// per-token sequence-id bookkeeping (token_seq_ids, all_same_seq,
// has_unique_seq_ids) consumed later by build_layer_attn_linear().
// NOTE(review): `batch` is stored as a reference member, so this object must
// not outlive the llama_batch it was constructed with -- confirm at call sites.
delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_lctx), batch(_batch) {
auto & model = lctx.model;
auto & hparams = model.hparams;
GGML_ASSERT(batch.n_tokens > 0);
// Sanity-check the hyper-parameters the layout arithmetic below divides
// by / multiplies with.
GGML_ASSERT(hparams.ssm_n_group > 0);
GGML_ASSERT(hparams.ssm_dt_rank > 0);
GGML_ASSERT(hparams.ssm_d_conv > 0);
GGML_ASSERT(hparams.ssm_d_inner % hparams.ssm_dt_rank == 0);
// Derive the state layout. NOTE(review): the generic ssm_* hparams are
// repurposed here (ssm_d_state as the K head dim, ssm_n_group as the K head
// count, ssm_dt_rank as the V head count) -- presumably fixed by the model
// converter; verify against the loader.
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t num_k_heads = hparams.ssm_n_group;
const int64_t num_v_heads = hparams.ssm_dt_rank;
const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads;
const int64_t key_dim = head_k_dim * num_k_heads;
const int64_t value_dim = head_v_dim * num_v_heads;
// Per-head square state of head_v_dim x head_v_dim, one per V head.
const int64_t ssm_state_dim = head_v_dim * head_v_dim * num_v_heads;
// Convolution state keeps (d_conv - 1) past columns over the concatenated
// projection channels (two key-sized blocks plus one value-sized block).
const int64_t conv_dim = key_dim * 2 + value_dim;
const int64_t conv_state_dim = (hparams.ssm_d_conv - 1) * conv_dim;
const int64_t state_dim = conv_state_dim + ssm_state_dim;
// The KV cache must have been sized for exactly this combined state.
GGML_ASSERT(hparams.n_embd_v_s() == (uint32_t) state_dim);
// Reserve-graph builds may pass a batch without per-token sequence info;
// every token then falls back to sequence 0.
const bool has_explicit_seq_info = batch.n_seq_id != nullptr && batch.seq_id != nullptr;
token_seq_ids.resize(batch.n_tokens, 0);
for (int i = 0; i < batch.n_tokens; ++i) {
if (has_explicit_seq_info) {
GGML_ASSERT(batch.n_seq_id[i] > 0 && "qwen3next expects each token to belong to at least one sequence");
GGML_ASSERT(batch.n_seq_id[i] == 1 && "qwen3next does not support multi-sequence tokens yet");
token_seq_ids[i] = batch.seq_id[i][0];
} else {
token_seq_ids[i] = 0;
}
}
// all_same_seq: every token belongs to the same sequence as token 0.
auto seq_id = token_seq_ids[0];
all_same_seq = std::all_of(token_seq_ids.begin(), token_seq_ids.end(), [seq_id](llama_seq_id s) { return s == seq_id; });
// has_unique_seq_ids: in a mixed-sequence batch, require each token to be
// the sole token of its sequence. The duplicate check is skipped when all
// tokens share one sequence, so the flag stays true in that case.
has_unique_seq_ids = true;
if (!all_same_seq) {
std::unordered_set<llama_seq_id> seen;
seen.reserve(token_seq_ids.size());
for (auto s : token_seq_ids) {
if (!seen.insert(s).second) {
has_unique_seq_ids = false;
break;
}
}
}
// Every referenced sequence id must map to a valid recurrent-state slot in
// the KV cache.
const uint32_t qnext_state_slots = llm_build_context::llama_kv_qnext_state_slots(lctx.kv_self);
GGML_ASSERT(qnext_state_slots > 0);
// Reserve-graph builds may not carry explicit sequence IDs, in which case
// the fallback sequence slot is 0.
for (llama_seq_id s : token_seq_ids) {
GGML_ASSERT(s >= 0);
GGML_ASSERT((uint32_t) s < qnext_state_slots);
}
}
delta_net::~delta_net() = default;
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_delta_net_chunking(ggml_context * ctx0,
ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
@@ -569,13 +633,26 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
}
ggml_tensor * delta_net::build_layer_attn_linear(ggml_context * ctx0, ggml_cgraph * gf, const llama_batch & batch, const std::vector<llama_seq_id> & token_seq_ids,
bool all_same_seq, bool has_unique_seq_ids, bool reset_state,
ggml_tensor * delta_net::build_layer_attn_linear(ggml_context * ctx0, ggml_cgraph * gf,
ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il, const llm_build_cb & cb) const {
GGML_ASSERT(lctx.inp_s_seq_qnext != nullptr);
auto & model = lctx.model;
auto & hparams = model.hparams;
GGML_ASSERT(hparams.is_recurrent(il));
GGML_ASSERT(model.layers[il].ssm_conv1d != nullptr);
GGML_ASSERT(model.layers[il].ssm_dt != nullptr);
GGML_ASSERT(model.layers[il].ssm_a != nullptr);
GGML_ASSERT(model.layers[il].ssm_beta_alpha != nullptr);
GGML_ASSERT(model.layers[il].ssm_norm != nullptr);
GGML_ASSERT(model.layers[il].ssm_out != nullptr);
GGML_ASSERT(model.layers[il].wqkv != nullptr || model.layers[il].ssm_in != nullptr);
GGML_ASSERT(model.layers[il].wqkv_gate != nullptr || model.layers[il].ssm_in != nullptr);
if (all_same_seq) {
bool reset_state = batch.pos != nullptr && batch.pos[0] == 0;
return build_layer_attn_linear_core(ctx0, gf, cur, causal_mask, identity, diag_mask, lctx.inp_s_seq_qnext, token_seq_ids.front(), reset_state, il, cb);
}

View File

@@ -5,7 +5,8 @@
#include <utility>
struct delta_net {
delta_net(llama_context & lctx);
delta_net(llama_context & lctx, const llama_batch & batch);
~delta_net();
static std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(ggml_context * ctx0,
ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
@@ -25,12 +26,16 @@ struct delta_net {
ggml_tensor * diag_mask, ggml_tensor * inp_s_seq_qnext,
uint32_t state_seq_id_local, bool reset_state_local, int il, const llm_build_cb & cb) const;
ggml_tensor * build_layer_attn_linear(ggml_context * ctx0, ggml_cgraph * gf, const llama_batch & batch, const std::vector<llama_seq_id> & token_seq_ids,
bool all_same_seq, bool has_unique_seq_ids, bool reset_state,
ggml_tensor * build_layer_attn_linear(ggml_context * ctx0, ggml_cgraph * gf,
ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity,
ggml_tensor * diag_mask, int il, const llm_build_cb & cb) const;
private:
llama_context & lctx;
llama_context & lctx;
const llama_batch & batch;
std::vector<llama_seq_id> token_seq_ids;
bool all_same_seq;
bool has_unique_seq_ids;
};