qwen3next: optimize broadcast sub and single-seq ssm conv

This commit is contained in:
yurko
2026-02-06 12:50:43 +00:00
parent a7df116441
commit 9fbb50481e
4 changed files with 419 additions and 36 deletions

View File

@@ -4231,8 +4231,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il);
ggml_tensor * gcs_i =
ggml_repeat_4d(ctx0, g_cumsum, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_i = g_cumsum;
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j_broadcast =
@@ -4284,9 +4283,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
cb(g_last_exp, "g_last_exp", il);
ggml_tensor * g_last_repeat =
ggml_repeat_4d(ctx0, g_last, chunk_size, 1, n_chunks, H_v * n_seqs);
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last_repeat));
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
cb(g_diff, "g_diff", il);
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);