From 4aff2fa016fa2a03337c9b66b85f9f87a733ec5d Mon Sep 17 00:00:00 2001 From: juuso-oskari Date: Tue, 19 May 2026 14:29:48 +0000 Subject: [PATCH] =?UTF-8?q?CK-UA:=20fix=20no-mask=20multi-Q-block=20path?= =?UTF-8?q?=20=E2=80=94=20was=20reading=20too-short=20K=20prefix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The kernel's `_max_seq_prefix_len` computation unconditionally applied a causal upper bound on the KV-tile loop: _max_seq_prefix_len = context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1 Under causal masking this is the correct optimisation — a Q-block whose largest row index is R only needs to read K[0..R] because rows beyond R are softmax-masked to zero. Under `mask_type=0` (no mask) every Q row must attend to all K rows, so this truncation is incorrect: every Q-block other than the last one ends up reading a too-short prefix of K and the resulting softmax / weighted-sum is over the wrong support. Symptoms at sq=sk=512, hq=hk=5, d=128, bf16, no-mask: Q-block 0 (rows 0..255): max diff vs fp32 attention_ref ≈ 0.25 Q-block 1 (rows 256..511): max diff vs fp32 attention_ref ≈ 1e-3 (ULP) The bug never showed up in the cross-impl sweeps because Triton-UA asserts causal=True (its only supported mode) and sweep_fp8.sh forwards that default through. Fix: gate the truncation behind kHasMask. When kHasMask == false the loop bound is simply `seq_len`, matching the math. Validated against `aiter.test_mha_common.attention_ref` across: - MHA d={64,128} sq=sk∈{256..2048} bf16/fp16 no-mask & causal - GQA-8 d=128 sq=sk∈{256..1024} bf16 no-mask & causal 22/22 stages PASS within bf16/fp16 ULP. sweep_fp8.sh (causal) timings unchanged — the truncation still fires for the causal kernels. Co-authored-by: Cursor --- .../kernel/unified_attention_kernel.hpp | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index d27a0f4573..f8a6abd59a 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -384,10 +384,27 @@ struct UnifiedAttentionKernel const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - index_t _max_seq_prefix_len = amd_wave_read_first_lane( - (context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1)); - - if(seq_len < _max_seq_prefix_len) + // Causal optimisation: a Q-block whose largest row index is R only + // needs to read K[0..R] (rows beyond R are masked out anyway), so + // we can truncate the KV-tile loop to ceil((R+1) / kPageBlockSize) + // tiles. Under no-mask every Q row attends to all K rows, so the + // truncation MUST NOT fire — gate it on kHasMask. Previously the + // truncation fired unconditionally and made every Q-block other + // than the last one read a too-short prefix of K under + // mask_type=0, producing incorrect output. (Triton-UA only + // supports causal so this never showed up in the cross-impl + // sweeps.) + index_t _max_seq_prefix_len; + if constexpr(kHasMask) + { + _max_seq_prefix_len = amd_wave_read_first_lane( + (context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1)); + if(seq_len < _max_seq_prefix_len) + { + _max_seq_prefix_len = seq_len; + } + } + else { _max_seq_prefix_len = seq_len; }