mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-28 02:11:50 +00:00
qwen3next: add fused delta-net op and wire model path
This commit is contained in:
@@ -4443,6 +4443,81 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
|
||||
return {core_attn_out, state};
|
||||
};
|
||||
|
||||
auto build_delta_net_fused = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
|
||||
ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
|
||||
int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
|
||||
const int64_t S_k = q->ne[0];
|
||||
const int64_t H_k = q->ne[1];
|
||||
const int64_t n_tokens = q->ne[2];
|
||||
const int64_t n_seqs = q->ne[3];
|
||||
|
||||
const int64_t S_v = v->ne[0];
|
||||
const int64_t H_v = v->ne[1];
|
||||
|
||||
GGML_ASSERT(v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(k->ne[2] == n_tokens);
|
||||
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
|
||||
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
|
||||
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
|
||||
GGML_ASSERT(H_k == H_v);
|
||||
|
||||
const float eps_norm = hparams.f_norm_rms_eps;
|
||||
q = ggml_l2_norm(ctx0, q, eps_norm);
|
||||
k = ggml_l2_norm(ctx0, k, eps_norm);
|
||||
|
||||
const float scale = 1.0f / sqrtf(S_v);
|
||||
q = ggml_scale(ctx0, q, scale);
|
||||
beta = ggml_sigmoid(ctx0, beta);
|
||||
|
||||
cb(q, "q_in", il);
|
||||
cb(k, "k_in", il);
|
||||
cb(v, "v_in", il);
|
||||
cb(beta, "beta_in", il);
|
||||
cb(g, "g_in", il);
|
||||
cb(state,"state_in", il);
|
||||
|
||||
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
|
||||
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
|
||||
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
|
||||
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 1, 3, 0, 2), n_tokens, 1, H_k, n_seqs);
|
||||
beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 1, 2, 0, 3), 1, n_tokens, H_k, n_seqs);
|
||||
|
||||
ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs);
|
||||
if (!ggml_is_contiguous(state_flat)) {
|
||||
state_flat = ggml_cont_4d(ctx0, state_flat, S_v, S_v * H_v, 1, n_seqs);
|
||||
}
|
||||
|
||||
cb(q, "q_fused", il);
|
||||
cb(k, "k_fused", il);
|
||||
cb(v, "v_fused", il);
|
||||
cb(g, "g_fused", il);
|
||||
cb(beta, "beta_fused", il);
|
||||
cb(state_flat,"state_fused", il);
|
||||
|
||||
ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat);
|
||||
cb(fused_result, "delta_net_fused_raw", il);
|
||||
|
||||
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
|
||||
const int64_t state_size = S_v * S_v * H_v * n_seqs;
|
||||
|
||||
ggml_tensor * output_tokens = ggml_view_4d(ctx0, fused_result,
|
||||
S_v, H_v, n_tokens, n_seqs,
|
||||
ggml_row_size(fused_result->type, S_v),
|
||||
ggml_row_size(fused_result->type, S_v * H_v),
|
||||
ggml_row_size(fused_result->type, S_v * H_v * n_tokens),
|
||||
0);
|
||||
output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size,
|
||||
output_size * ggml_element_size(fused_result));
|
||||
ggml_tensor * new_state = ggml_reshape_4d(ctx0, new_state_flat, S_v, S_v, H_v, n_seqs);
|
||||
|
||||
cb(output_tokens, "output_tokens", il);
|
||||
cb(new_state, "new_state", il);
|
||||
|
||||
return {output_tokens, new_state};
|
||||
};
|
||||
|
||||
auto build_qkvz = [&](ggml_tensor * input, int il) -> std::pair<ggml_tensor *, ggml_tensor *> {
|
||||
const int64_t n_tok = input->ne[1];
|
||||
if (model.layers[il].wqkv) {
|
||||
@@ -4734,10 +4809,13 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
|
||||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
const bool use_fused_delta_net = true;
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out =
|
||||
n_tok == 1
|
||||
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
|
||||
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
|
||||
use_fused_delta_net
|
||||
? build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il)
|
||||
: (n_tok == 1
|
||||
? build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il)
|
||||
: build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il));
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
|
||||
Reference in New Issue
Block a user