diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index a6106a8d..b3436460 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -4266,12 +4266,11 @@ ggml_cgraph * llm_build_context::build_qwen3next() { ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); cb(g_cumsum, "g_cumsum", il); - ggml_tensor * gcs_i = g_cumsum; ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, g_cumsum); cb(decay_mask, "decay_mask", il); decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); @@ -4285,7 +4284,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() { cb(attn, "attn_pre_solve", il); ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + ggml_tensor * lhs = ggml_neg(ctx0, ggml_sub(ctx0, attn_lower, identity)); ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); attn = ggml_mul(ctx0, lin_solve, causal_mask); @@ -4380,7 +4379,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() { ggml_row_size(core_attn_out->type, S_v), ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks), ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); cb(output_tokens, "output_tokens", il); output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);