Attention mask tweaks for better long context performance (#825)

* Parallelize mask

We see non-negligible PP (prompt processing) gains for long contexts.
More importantly, the strange drop in performance
observed for GPT-OSS at contexts >= 32k tokens is gone.
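
A rough sketch of the idea (not the actual ik_llama.cpp code; the helper, its
arguments, and the fill logic are illustrative): split the n_tokens rows of the
KQ mask across worker threads, each filling its own slice independently.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical helper: fill the causal KQ mask rows in parallel.
static void build_kq_mask_parallel(float * mask, int64_t n_kv, int64_t n_tokens, int n_threads,
                                   const std::vector<int64_t> & tok_pos,    // position of each new token
                                   const std::vector<int64_t> & cell_pos) { // position stored in each KV cell
    auto fill_rows = [&](int64_t first, int64_t last) {
        for (int64_t i = first; i < last; ++i) {
            for (int64_t j = 0; j < n_kv; ++j) {
                // causal mask: token i may only attend to KV cells at earlier or equal positions
                mask[i*n_kv + j] = cell_pos[j] <= tok_pos[i] ? 0.0f : -INFINITY;
            }
        }
    };
    std::vector<std::thread> workers;
    const int64_t chunk = (n_tokens + n_threads - 1)/n_threads;
    for (int t = 0; t < n_threads; ++t) {
        const int64_t first = t*chunk;
        const int64_t last  = std::min(first + chunk, n_tokens);
        if (first < last) workers.emplace_back(fill_rows, first, last);
    }
    for (auto & w : workers) w.join();
}
```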

* With FA on, create the mask as f16 directly
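
For context: the flash-attention kernel consumes an f16 mask, and the usual graph
(as in mainline llama.cpp; this is a hedged sketch, not the exact removed code)
builds an f32 mask input and inserts a cast node. Allocating the input as f16 up
front drops the cast and the extra f32 buffer:

```cpp
// Sketch, assuming the old graph casted an F32 mask for FA.
// ggml_new_tensor_2d/ggml_cast are real ggml calls; the surrounding names mirror the diff below.

// Before: F32 input, cast to F16 for the flash-attention kernel.
ggml_tensor * mask_f32 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_tensor * mask_fa  = ggml_cast(ctx0, mask_f32, GGML_TYPE_F16);

// After: the input itself is F16, so the host code writes half-precision
// values directly and the cast node disappears from the graph.
ggml_tensor * mask     = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
```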

* WIP

* Reduce KQ mask padding to 16

Why was it 64 in the first place?

I don't observe any issues, while TG (token generation) performance
for long contexts improves by 2-4%.
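
The effect is easiest to see during token generation, where n_tokens is 1 but the
mask is still padded to a multiple of GGML_KQ_MASK_PAD rows of n_kv entries each.
A sketch of the arithmetic (GGML_PAD is the real ggml macro; the definition shown
assumes the power-of-two variant):

```cpp
// Rounds x up to a multiple of n (n must be a power of two).
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

// Token generation: n_tokens == 1, so the mask holds
//   GGML_PAD(1, 64) * n_kv = 64 * n_kv entries with the old padding,
//   GGML_PAD(1, 16) * n_kv = 16 * n_kv entries with the new one,
// i.e. 4x less mask data to fill and upload per generated token,
// which matters more as n_kv (the context) grows.
```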

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Commit 4e24d48e63 (parent 21a0bfb1c0)
Author: Kawrakow (committed via GitHub)
Date:   2025-10-13 14:01:11 +03:00
3 changed files with 277 additions and 25 deletions


@@ -276,6 +276,12 @@ ggml_tensor * llm_build_context::build_inp_out_ids() {
 }

 ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
+    if (causal && flash_attn) {
+        lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
+        ggml_set_input(lctx.inp_KQ_mask);
+        return lctx.inp_KQ_mask;
+    }
     lctx.inp_KQ_mask = causal
         ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
         : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
@@ -287,6 +293,12 @@ ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
 ggml_tensor * llm_build_context::build_inp_KQ_mask_swa(bool causal) {
     GGML_ASSERT(hparams.n_swa > 0);
+    if (causal && flash_attn) {
+        lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+        return lctx.inp_KQ_mask_swa;
+    }
     lctx.inp_KQ_mask_swa = causal
         ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))