diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index d27a0f4573..f8a6abd59a 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -384,10 +384,27 @@ struct UnifiedAttentionKernel const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - index_t _max_seq_prefix_len = amd_wave_read_first_lane( - (context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1)); - - if(seq_len < _max_seq_prefix_len) + // Causal optimisation: a Q-block whose largest row index is R only + // needs to read K[0..R] (rows beyond R are masked out anyway), so + // we can truncate the KV-tile loop to ceil((R+1) / kPageBlockSize) + // tiles. Under no-mask every Q row attends to all K rows, so the + // truncation MUST NOT fire — gate it on kHasMask. Previously the + // truncation fired unconditionally and made every Q-block other + // than the last one read a too-short prefix of K under + // mask_type=0, producing incorrect output. (Triton-UA only + // supports causal so this never showed up in the cross-impl + // sweeps.) + index_t _max_seq_prefix_len; + if constexpr(kHasMask) + { + _max_seq_prefix_len = amd_wave_read_first_lane( + (context_len + q_block_local_idx * kBlockQ_dyn + (kBlockQ_dyn - 1) + 1)); + if(seq_len < _max_seq_prefix_len) + { + _max_seq_prefix_len = seq_len; + } + } + else { _max_seq_prefix_len = seq_len; }