diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_local.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_local.cpp
new file mode 100644
index 0000000000..ddfea46828
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_local.cpp
@@ -0,0 +1,23 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+// d64 GQA-8 medium decode tier with BlockSize=32 (kBlockM=128, kBlockQ=16),
+// IsMasking=true, IsLocal=true. Targets GPT-OSS short-prefill SWA shapes
+// (max_seqlen_q in [257,1024], page_blk_size=32, GQA-8).
+using kernel_traits =
+    unified_attention_decode_kernel_traits;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow_local.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow_local.cpp
new file mode 100644
index 0000000000..d8e269eca0
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow_local.cpp
@@ -0,0 +1,23 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+// d64 GQA-8 tiny+bs32 decode tier (kBlockM=32, kBlockQ=4, BlockSize=32),
+// IsMasking=true, IsLocal=true. Targets GPT-OSS decode shapes
+// (q=1, page_blk_size=32, GQA-8) with sliding-window attention.
+using kernel_traits =
+    unified_attention_decode_bs32_kernel_traits;
+
+INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_local.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_local.cpp
new file mode 100644
index 0000000000..fa2c15683d
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_local.cpp
@@ -0,0 +1,22 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+// d64 GQA-8 medium decode tier with BlockSize=32 (kBlockM=128, kBlockQ=16),
+// IsMasking=true, IsLocal=true. fp16 sibling of the bf16 instance.
+using kernel_traits =
+    unified_attention_decode_kernel_traits;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow_local.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow_local.cpp
new file mode 100644
index 0000000000..97ebfd6cde
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow_local.cpp
@@ -0,0 +1,23 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +// d64 GQA-8 tiny+bs32 decode tier (kBlockM=32, kBlockQ=4, BlockSize=32), +// IsMasking=true, IsLocal=true. fp16 sibling of the bf16 instance used by +// the GPT-OSS decode SWA path. +using kernel_traits = + unified_attention_decode_bs32_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/script/edge_test_swa.sh b/example/ck_tile/42_unified_attention/script/edge_test_swa.sh index 221934e858..2568d1cd47 100755 --- a/example/ck_tile/42_unified_attention/script/edge_test_swa.sh +++ b/example/ck_tile/42_unified_attention/script/edge_test_swa.sh @@ -5,10 +5,11 @@ # - window=1 (every Q row attends to its own K position only) # - window > seq_k with right=0 (degenerates to plain causal) # - explicit b:0,0 (alternative spelling of diagonal-only) -# - decode shapes (q=1, kv>>1) — exercises the SWA path on a single-token Q -# (today this is dispatched to the large-tier kernel via the -# `if(is_local) tier = tile_tier::large` hack; correctness should hold even -# though it's wasteful) +# - decode shapes (q=1, kv>>1) — exercises the SWA path on a single-token Q. +# For page_blk_size>=64 (Edge 4) we route to the large-tier kernel, which is +# wasteful but correct. For page_blk_size==32 (Edge 5) we route to the +# decode-tier IsLocal=true instances added in Phase 5 — that's the GPT-OSS +# production path. # # Same convention as smoke_test_swa.sh: every test must pass, exit code is the # number of failures. @@ -35,6 +36,15 @@ BASELINE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=512 -s_k=512 -query_lens=400,256,512, DECODE_A="-d=128 -h_k=8 -nqpkv=1 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512" DECODE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512" +# Decode + page_blk_size=32 fixtures for d=64 GQA-8. These exercise the NEW +# decode-tier IsLocal=true instances added in Phase 5: +# - DECODE_BS32_Q1 (q=1) → tiny+bs32 local (kBlockM=32, kBlockQ=4) +# - DECODE_BS32_QM (q in [256,1024]) → medium+bs32 local (kBlockM=128, kBlockQ=16) +# Use the GPT-OSS-shaped window (left=127, right=0) to mirror the production +# workload that motivated Phase 5. +DECODE_BS32_Q1="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512" +DECODE_BS32_QM="-d=64 -h_k=1 -nqpkv=8 -b=2 -s=1024 -s_k=1024 -query_lens=1024,512 -kv_lens=1024,512" + TESTS=( # Edge 1: window=1 — diagonal-only attention. Smallest non-zero window. # xb:1 decodes to left=0, right=0 via window/2 split. @@ -55,8 +65,17 @@ TESTS=( # Edge 4: decode shapes (single-token query). The SWA mask trims the K range # to a 64-wide window at the bottom-right corner of the (1, 512) # attention matrix, so most of the kv tail is masked out. - "decode q=1 d128 xb:64|$DECODE_A -mask=xb:64" - "decode q=1 d64 xb:64|$DECODE_B -mask=xb:64" + "decode q=1 d128 xb:64 |$DECODE_A -mask=xb:64" + "decode q=1 d64 xb:64 |$DECODE_B -mask=xb:64" + + # Edge 5 (Phase 5): GPT-OSS-shaped d64 GQA-8 SWA on page_blk_size=32. These + # MUST hit the new decode-tier IsLocal=true instances; if a regression takes + # them back to the bs64-only fallback they fail with "no matching kernel + # instance" or wrong numerics. 
+ "decode q=1 d64 bs32 b:127,0 |$DECODE_BS32_Q1 -page_blk_size=32 -mask=b:127,0" + "decode q=1 d64 bs32 xb:128 |$DECODE_BS32_Q1 -page_blk_size=32 -mask=xb:128" + "shortpf d64 bs32 b:127,0 |$DECODE_BS32_QM -page_blk_size=32 -mask=b:127,0" + "shortpf d64 bs32 xb:128 |$DECODE_BS32_QM -page_blk_size=32 -mask=xb:128" ) n_pass=0 @@ -66,7 +85,7 @@ for entry in "${TESTS[@]}"; do name="${entry%%|*}" args="${entry#*|}" - printf '== %-26s :: %s\n' "$name" "$args" + printf '== %-32s :: %s\n' "$name" "$args" set +e "$EXE" $COMMON $args > /tmp/swa_edge_out.$$ 2>&1 ret=$? diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index 91056e7c0d..309c297272 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -34,6 +34,21 @@ std::ostream& operator<<(std::ostream& stream, return unified_attention_kernel_dispatch(args, config); \ } +// SWA-aware decode dispatchers for bs32. These mirror the non-local *_BS32 macros +// but flip IsLocal=true on the 7th template arg, so the kernel uses the +// sliding-window mask AND the per-Q-tile KV-block iteration clipping. +#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(DType, HSize, BM, NQPKV) \ + { \ + using kernel_traits = unified_attention_decode_kernel_traits; \ + return unified_attention_kernel_dispatch(args, config); \ + } + +#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(DType, HSize, BM, NQPKV) \ + { \ + using kernel_traits = unified_attention_decode_bs32_kernel_traits; \ + return unified_attention_kernel_dispatch_decode(args, config); \ + } + // Dispatch macros for three tile tiers (default block_size). #define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \ { \ @@ -108,22 +123,52 @@ std::pair unified_attention(const unified_attention_args& args, const bool is_local = is_mask && (args.window_size_left >= 0 || args.window_size_right > 0); auto tier = select_tile_tier(args); - // SWA instances currently only exist at the large prefill tier (kBlockM=256, - // 8 warps). Each requires args.page_blk_size >= kBlockN of the instance — - // otherwise the kernel hits a device-side `kv_page_size_in_blocks >= 1` - // assertion. When SWA is requested on an unsupported (shape, page_blk_size) - // pair we return {false, 0} so the caller (e.g. _try_ck_unified_attention) - // can fall back to a backend that handles it (Triton). Falling through to - // the IsLocal=false path would silently ignore window_size_left and produce + // SWA-instance availability matrix (IsLocal=true). Anything not listed here + // returns {false, 0} so the caller (e.g. _try_ck_unified_attention) falls + // back to a backend that handles it (Triton). Falling through to the + // IsLocal=false path would silently ignore window_size_left and produce // wrong outputs, so we reject explicitly. + // + // shape | page_blk_size | tier we route to | instance + // ---------------+---------------+-----------------------------+--------------------- + // d128 MHA | >= 32 | tile_tier::large | d128_*_mask_local + // d64 GQA-8 | >= 64 | tile_tier::large | d64_*_mask_gqa8_local + // d64 GQA-8 | == 32, tiny | tile_tier::tiny | d64_*_mask_gqa8_bs32_narrow_local + // d64 GQA-8 | == 32, med | tile_tier::medium | d64_*_mask_gqa8_bs32_decode_local + // (small+bs32 SWA has no instance yet -> Triton fallback; GPT-OSS shows zero + // such shapes in practice. 
Bumping it to medium is plausible but wastes a + // full kBlockM=128 tile when only kBlockM=64 is needed -- revisit if the + // workload changes.) if(is_local) { - const bool d128_mha = (args.hdim == 128 && args.num_queries_per_kv == 1); - const bool d64_gqa8 = (args.hdim == 64 && args.num_queries_per_kv == 8); - const index_t kBN_req = d128_mha ? 32 : (d64_gqa8 ? 64 : 0); - if(kBN_req == 0 || args.page_blk_size < kBN_req) + const bool d128_mha = (args.hdim == 128 && args.num_queries_per_kv == 1); + const bool d64_gqa8 = (args.hdim == 64 && args.num_queries_per_kv == 8); + if(d128_mha) + { + if(args.page_blk_size < 32) return {false, 0.f}; + tier = tile_tier::large; + } + else if(d64_gqa8) + { + if(args.page_blk_size >= 64) + { + tier = tile_tier::large; + } + else if(args.page_blk_size >= 32) + { + if(tier != tile_tier::tiny && tier != tile_tier::medium) + return {false, 0.f}; + // keep selected tier as-is + } + else + { + return {false, 0.f}; + } + } + else + { return {false, 0.f}; - tier = tile_tier::large; + } } // d128, MHA (num_queries_per_kv == 1) @@ -155,13 +200,15 @@ std::pair unified_attention(const unified_attention_args& args, // Avoids 1-warp race condition; 2x less waste than small tier. if(args.data_type == unified_attention_args::data_type_enum::fp16) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8) + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8) + else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 32, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8) } else if(args.data_type == unified_attention_args::data_type_enum::bf16) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8) + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8) + else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 32, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8) } } else { // bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2. 
@@ -205,8 +252,9 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
         if(args.data_type == unified_attention_args::data_type_enum::fp16)
         {
             if(use_bs32) {
-                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
-                else         DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
+                if(!is_mask)      DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
+                else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 128, 8)
+                else              DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
             } else {
                 if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
                 else         DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
@@ -215,8 +263,9 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
         else if(args.data_type == unified_attention_args::data_type_enum::bf16)
         {
             if(use_bs32) {
-                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
-                else         DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
+                if(!is_mask)      DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
+                else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 128, 8)
+                else              DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
             } else {
                 if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
                 else         DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
@@ -248,6 +297,8 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
     return std::make_pair(false, -1.f);
 }
 
+#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL
+#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL
 #undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW
 #undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32
 #undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32
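
Not part of the patch: a condensed, standalone sketch of the IsLocal routing matrix that the dispatcher change above implements. The names Tier, Shape, Route, and route_swa are illustrative stand-ins rather than identifiers from this codebase; only the branch structure mirrors the new is_local block in unified_attention.cpp.

#include <cstdint>
#include <iostream>

// Illustrative stand-ins; not identifiers from the patch.
enum class Tier { tiny, small, medium, large };

struct Shape
{
    std::int32_t hdim;
    std::int32_t num_queries_per_kv;
    std::int32_t page_blk_size;
    Tier preselected; // what a tile-tier selection step would have picked
};

struct Route
{
    bool supported; // false -> caller falls back to another backend (e.g. Triton)
    Tier tier;      // only meaningful when supported == true
};

// Mirrors the availability matrix in the dispatcher comment:
//   d128 MHA,  page_blk_size >= 32           -> large tier
//   d64 GQA-8, page_blk_size >= 64           -> large tier
//   d64 GQA-8, page_blk_size == 32, tiny/med -> keep the preselected tier
//   anything else                            -> unsupported (reject explicitly)
Route route_swa(const Shape& s)
{
    const bool d128_mha = (s.hdim == 128 && s.num_queries_per_kv == 1);
    const bool d64_gqa8 = (s.hdim == 64 && s.num_queries_per_kv == 8);

    if(d128_mha)
    {
        if(s.page_blk_size < 32) return {false, s.preselected};
        return {true, Tier::large};
    }
    if(d64_gqa8)
    {
        if(s.page_blk_size >= 64) return {true, Tier::large};
        if(s.page_blk_size >= 32 &&
           (s.preselected == Tier::tiny || s.preselected == Tier::medium))
            return {true, s.preselected}; // the new bs32 IsLocal=true instances
        return {false, s.preselected};
    }
    return {false, s.preselected};
}

int main()
{
    // GPT-OSS decode shape: q=1, d64 GQA-8, page_blk_size=32, tiny tier preselected.
    const Route r = route_swa({64, 8, 32, Tier::tiny});
    std::cout << (r.supported ? "supported" : "fallback")
              << ", tier=" << static_cast<int>(r.tier) << '\n';
    return 0;
}

For the GPT-OSS decode shape this reports "supported" with the preselected tier kept, which is exactly the case served by the new *_bs32_narrow_local instances; the short-prefill shapes land on the medium tier and the *_bs32_decode_local instances.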
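
Also not part of the patch: the arithmetic behind the "per-Q-tile KV-block iteration clipping" for the Edge 5 decode case. With the GPT-OSS window (left=127, right=0), the single query row of a (1, 512) attention matrix sits at the bottom-right corner, so it can only attend to K positions 384..511, which with page_blk_size=32 touches KV blocks 12..15. The helper below assumes the usual convention that a row at position q_pos attends to K in [q_pos - left, q_pos + right]; the kernel's real indexing is not shown in this diff.

#include <algorithm>
#include <cstdint>
#include <iostream>

struct KvBlockRange
{
    std::int32_t first_blk;
    std::int32_t last_blk;
};

// Which paged-KV blocks a single query row needs under a sliding window.
KvBlockRange kv_blocks_for_row(std::int32_t q_pos, std::int32_t kv_len,
                               std::int32_t left, std::int32_t right,
                               std::int32_t page_blk_size)
{
    const std::int32_t k_lo = std::max<std::int32_t>(0, q_pos - left);
    const std::int32_t k_hi = std::min<std::int32_t>(kv_len - 1, q_pos + right);
    return {k_lo / page_blk_size, k_hi / page_blk_size};
}

int main()
{
    // Edge 5 decode fixture: kv_len=512, mask=b:127,0, page_blk_size=32.
    // The single query row sits at position kv_len - 1 = 511.
    const KvBlockRange r = kv_blocks_for_row(511, 512, 127, 0, 32);
    std::cout << "KV blocks " << r.first_blk << ".." << r.last_blk << '\n'; // prints 12..15
    return 0;
}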