mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
CK-UA: fix no-mask multi-Q-block path — was reading too-short K prefix
The kernel's `_max_seq_prefix_len` computation unconditionally applied a
causal upper bound on the KV-tile loop:
    _max_seq_prefix_len = context_len
                        + q_block_local_idx * kBlockQ_dyn
                        + (kBlockQ_dyn - 1) + 1
Under causal masking this is the correct optimisation: a Q-block whose
largest row index is R only needs to read K[0..R], because K rows beyond
R are masked out and receive zero softmax weight. Under `mask_type=0`
(no mask) every Q row must attend to all K rows, so this truncation is
incorrect: every Q-block other than the last one reads a too-short
prefix of K, and the resulting softmax / weighted sum is computed over
the wrong support.
Symptoms at sq=sk=512, hq=hk=5, d=128, bf16, no-mask:
    Q-block 0 (rows 0..255):   max diff vs fp32 attention_ref ≈ 0.25
    Q-block 1 (rows 256..511): max diff vs fp32 attention_ref ≈ 1e-3 (within bf16 ULP)
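
The numbers above are consistent with a quick host-side calculation of the pre-fix bound. The following is a minimal sketch, not kernel code: `old_prefix_len` is a hypothetical helper name, `index_t` is aliased to a 32-bit integer, and a Q-block size of 256 rows is assumed from the row ranges quoted in the symptoms.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

using index_t = std::int32_t; // assumption: stand-in for the kernel's index_t

// Hypothetical host-side model of the pre-fix bound: the causal prefix
// length is always computed, then clamped to seq_len.
constexpr index_t old_prefix_len(index_t seq_len,
                                 index_t query_len,
                                 index_t q_block_local_idx,
                                 index_t block_q)
{
    const index_t context_len = seq_len - query_len;
    const index_t causal_len =
        context_len + q_block_local_idx * block_q + (block_q - 1) + 1;
    return std::min(seq_len, causal_len);
}

// sq = sk = 512 with an assumed Q-block size of 256 rows:
// Q-block 0 is limited to the first 256 of 512 K rows, Q-block 1 sees all 512.
static_assert(old_prefix_len(512, 512, 0, 256) == 256, "Q-block 0 is truncated");
static_assert(old_prefix_len(512, 512, 1, 256) == 512, "Q-block 1 is complete");
```

Under no-mask, Q-block 0 therefore normalises its softmax over half of the keys, which matches the large error reported for rows 0..255 only.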
The bug never showed up in the cross-impl sweeps because Triton-UA
asserts causal=True (its only supported mode) and sweep_fp8.sh forwards
that default through.
Fix: gate the truncation behind kHasMask. When kHasMask == false the
loop bound is simply `seq_len`, matching the math.
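
The gating can be sketched outside the kernel as a small compile-time-dispatched function. This is an illustrative model under stated assumptions: `max_seq_prefix_len` is a hypothetical free function, `index_t` is aliased to a 32-bit integer, and the wave-uniform broadcast (`amd_wave_read_first_lane`) used in the kernel is omitted because it has no host equivalent.

```cpp
#include <cassert>
#include <cstdint>

using index_t = std::int32_t; // assumption: stand-in for the kernel's index_t

// Hypothetical host-side model of the gated bound (requires C++17).
// kHasMask == true : causal prefix length, clamped to seq_len.
// kHasMask == false: the full sequence, since every Q row attends to all K rows.
template <bool kHasMask>
constexpr index_t max_seq_prefix_len(index_t seq_len,
                                     index_t context_len,
                                     index_t q_block_local_idx,
                                     index_t block_q)
{
    if constexpr(kHasMask)
    {
        const index_t causal_len =
            context_len + q_block_local_idx * block_q + (block_q - 1) + 1;
        return causal_len < seq_len ? causal_len : seq_len;
    }
    else
    {
        return seq_len;
    }
}

// Causal kernels keep the truncation; no-mask kernels read the whole of K.
static_assert(max_seq_prefix_len<true>(512, 0, 0, 256) == 256, "causal, Q-block 0");
static_assert(max_seq_prefix_len<false>(512, 0, 0, 256) == 512, "no-mask, Q-block 0");
```

Because the branch is selected with `if constexpr`, the no-mask instantiation contains no trace of the causal arithmetic, so the causal instantiation is unchanged by the fix.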
Validated against `aiter.test_mha_common.attention_ref` across:
- MHA d={64,128} sq=sk∈{256..2048} bf16/fp16 no-mask & causal
- GQA-8 d=128 sq=sk∈{256..1024} bf16 no-mask & causal
22/22 stages PASS within bf16/fp16 ULP. sweep_fp8.sh (causal) timings
unchanged — the truncation still fires for the causal kernels.
Co-authored-by: Cursor <cursoragent@cursor.com>
@@ -384,10 +384,27 @@ struct UnifiedAttentionKernel
         const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
 
-        index_t _max_seq_prefix_len = amd_wave_read_first_lane(
-            (context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1));
-
-        if(seq_len < _max_seq_prefix_len)
+        // Causal optimisation: a Q-block whose largest row index is R only
+        // needs to read K[0..R] (rows beyond R are masked out anyway), so
+        // we can truncate the KV-tile loop to ceil((R+1) / kPageBlockSize)
+        // tiles. Under no-mask every Q row attends to all K rows, so the
+        // truncation MUST NOT fire — gate it on kHasMask. Previously the
+        // truncation fired unconditionally and made every Q-block other
+        // than the last one read a too-short prefix of K under
+        // mask_type=0, producing incorrect output. (Triton-UA only
+        // supports causal so this never showed up in the cross-impl
+        // sweeps.)
+        index_t _max_seq_prefix_len;
+        if constexpr(kHasMask)
+        {
+            _max_seq_prefix_len = amd_wave_read_first_lane(
+                (context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1));
+            if(seq_len < _max_seq_prefix_len)
+            {
+                _max_seq_prefix_len = seq_len;
+            }
+        }
+        else
         {
             _max_seq_prefix_len = seq_len;
         }