The kernel's `_max_seq_prefix_len` computation unconditionally applied a
causal upper bound on the KV-tile loop:
    _max_seq_prefix_len = context_len
                        + q_block_local_idx * kBlockQ_dyn
                        + (kBlockQ_dyn - 1) + 1
Under causal masking this is a valid optimisation. A Q-block whose last
row aligns with key position R (context_len plus the block's largest
local row index) only needs to read K[0..R], because every key beyond R
is masked out and receives zero softmax weight. Under `mask_type=0` (no
mask) every Q row must attend to all K rows, so the truncation is
incorrect: every Q-block except the last one reads a too-short prefix of
K, and its softmax and weighted sum are computed over the wrong support.
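To illustrate why the truncation is exact under causal masking but wrong without a mask, here is a toy Python model of a single query row with scalar values (head dim 1). The helper `attend_row`, the scores and the values are invented for illustration and are not taken from the kernel; `context_len` is assumed to be 0.

```python
import math


def attend_row(scores, values, kv_len, q_pos, causal):
    """Softmax-weighted sum for one query row (hypothetical toy model).

    Reads only keys [0, kv_len). When causal, keys j > q_pos are masked
    out and contribute zero weight.
    """
    keep = [j for j in range(kv_len) if not causal or j <= q_pos]
    m = max(scores[j] for j in keep)  # subtract the max for numerical stability
    weights = [math.exp(scores[j] - m) for j in keep]
    return sum(w * values[j] for w, j in zip(weights, keep)) / sum(weights)


sk = 8             # total number of keys
q_pos = 3          # last row of a hypothetical Q-block (context_len assumed 0)
prefix = q_pos + 1 # causal upper bound on the keys this block reads
scores = [0.1 * j for j in range(sk)]
values = [float(j) for j in range(sk)]

causal_full = attend_row(scores, values, sk, q_pos, causal=True)
causal_trunc = attend_row(scores, values, prefix, q_pos, causal=True)
nomask_full = attend_row(scores, values, sk, q_pos, causal=False)
nomask_trunc = attend_row(scores, values, prefix, q_pos, causal=False)

print(f"causal:  full={causal_full:.6f}  truncated={causal_trunc:.6f}")
print(f"no mask: full={nomask_full:.6f}  truncated={nomask_trunc:.6f}")
```

With the causal mask, the full and truncated results are identical because the skipped keys carry zero weight anyway. Without a mask, the truncated row silently drops keys 4..7 and lands on a different value.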
Symptoms at sq=sk=512, hq=hk=5, d=128, bf16, no-mask:
- Q-block 0 (rows 0..255): max diff vs fp32 attention_ref ≈ 0.25
- Q-block 1 (rows 256..511): max diff vs fp32 attention_ref ≈ 1e-3 (ULP level)
The bug went unnoticed in the cross-implementation sweeps because
Triton-UA asserts causal=True (its only supported mode) and sweep_fp8.sh
passes that default through.
Fix: gate the truncation behind kHasMask. When kHasMask == false, the
loop bound is simply `seq_len`, so every Q-block reads all of K, as
unmasked attention requires.
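A minimal sketch of the gated bound, written as a hypothetical host-side Python model rather than the kernel's actual C++: `kv_loop_bound` is an illustrative name, `has_mask` stands in for the compile-time `kHasMask`, and `context_len = 0` is assumed for the sq=sk=512 case with the 256-row Q-blocks listed in the symptoms.

```python
def kv_loop_bound(has_mask, seq_len, context_len, q_block_local_idx, block_q):
    """Hypothetical model of the KV-tile loop bound after the fix.

    has_mask models kHasMask; block_q models kBlockQ_dyn. Any clamping the
    real kernel may apply is not modelled here.
    """
    if has_mask:
        # Causal: same expression as _max_seq_prefix_len, i.e.
        # context_len + q_block_local_idx * block_q + (block_q - 1) + 1
        return context_len + q_block_local_idx * block_q + (block_q - 1) + 1
    # No mask: every Q row attends to every key.
    return seq_len


for blk in range(2):
    causal = kv_loop_bound(True, 512, 0, blk, 256)
    nomask = kv_loop_bound(False, 512, 0, blk, 256)
    print(f"Q-block {blk}: causal bound={causal}, no-mask bound={nomask}")
```

Under these assumptions Q-block 0 gets a causal bound of 256 but a no-mask bound of 512, which matches the symptom above: before the fix, block 0 read only half of K.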
Validated against `aiter.test_mha_common.attention_ref` across:
- MHA d={64,128} sq=sk∈{256..2048} bf16/fp16 no-mask & causal
- GQA-8 d=128 sq=sk∈{256..1024} bf16 no-mask & causal
22/22 stages PASS within bf16/fp16 ULP tolerance. sweep_fp8.sh (causal)
timings are unchanged, since the truncation still applies to the causal
kernels.
Co-authored-by: Cursor <cursoragent@cursor.com>