qwen3next: trim delta-net graph overhead in chunking path

This commit is contained in:
yurko
2026-02-07 13:21:02 -08:00
parent fffd27e3c8
commit a1163d0b68

View File

@@ -4266,12 +4266,11 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il);
ggml_tensor * gcs_i = g_cumsum;
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j_broadcast =
ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, g_cumsum);
cb(decay_mask, "decay_mask", il);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
@@ -4285,7 +4284,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(attn, "attn_pre_solve", il);
ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
ggml_tensor * lhs = ggml_neg(ctx0, ggml_sub(ctx0, attn_lower, identity));
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_mul(ctx0, lin_solve, causal_mask);
@@ -4380,7 +4379,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_row_size(core_attn_out->type, S_v),
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks),
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks * H_v), 0);
output_tokens = ggml_cont(ctx0, output_tokens);
cb(output_tokens, "output_tokens", il);
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);