CK-UA: fix no-mask multi-Q-block path — was reading too-short K prefix

The kernel's `_max_seq_prefix_len` computation unconditionally applied a
causal upper bound on the KV-tile loop:

    _max_seq_prefix_len = context_len
                        + q_block_local_idx * kBlockQ_dyn
                        + (kBlockQ_dyn - 1) + 1

Under causal masking this is the correct optimisation — a Q-block whose
largest row index is R only needs to read K[0..R] because rows beyond R
are softmax-masked to zero. Under `mask_type=0` (no mask) every Q row
must attend to all K rows, so this truncation is incorrect: every
Q-block other than the last one ends up reading a too-short prefix of K
and the resulting softmax / weighted-sum is over the wrong support.

Symptoms at sq=sk=512, hq=hk=5, d=128, bf16, no-mask:
  Q-block 0 (rows 0..255):  max diff vs fp32 attention_ref ≈ 0.25
  Q-block 1 (rows 256..511): max diff vs fp32 attention_ref ≈ 1e-3 (ULP)

The bug never showed up in the cross-impl sweeps because Triton-UA
asserts causal=True (its only supported mode) and sweep_fp8.sh forwards
that default through.

Fix: gate the truncation behind kHasMask. When kHasMask == false the
loop bound is simply `seq_len`, matching the math.

Validated against `aiter.test_mha_common.attention_ref` across:
  - MHA d={64,128} sq=sk∈{256..2048}  bf16/fp16  no-mask & causal
  - GQA-8 d=128   sq=sk∈{256..1024}   bf16       no-mask & causal
22/22 stages PASS within bf16/fp16 ULP. sweep_fp8.sh (causal) timings
unchanged — the truncation still fires for the causal kernels.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-05-19 14:29:48 +00:00
parent c9bc5350c8
commit 4aff2fa016

View File

@@ -384,10 +384,27 @@ struct UnifiedAttentionKernel
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
index_t _max_seq_prefix_len = amd_wave_read_first_lane(
(context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1));
if(seq_len < _max_seq_prefix_len)
// Causal optimisation: a Q-block whose largest row index is R only
// needs to read K[0..R] (rows beyond R are masked out anyway), so
// we can truncate the KV-tile loop to ceil((R+1) / kPageBlockSize)
// tiles. Under no-mask every Q row attends to all K rows, so the
// truncation MUST NOT fire — gate it on kHasMask. Previously the
// truncation fired unconditionally and made every Q-block other
// than the last one read a too-short prefix of K under
// mask_type=0, producing incorrect output. (Triton-UA only
// supports causal so this never showed up in the cross-impl
// sweeps.)
index_t _max_seq_prefix_len;
if constexpr(kHasMask)
{
_max_seq_prefix_len = amd_wave_read_first_lane(
(context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1));
if(seq_len < _max_seq_prefix_len)
{
_max_seq_prefix_len = seq_len;
}
}
else
{
_max_seq_prefix_len = seq_len;
}