From 5bd8f73a28c0ac8a95a5c5e54159ced9994a12b2 Mon Sep 17 00:00:00 2001 From: juuso-oskari Date: Tue, 12 May 2026 09:41:41 +0000 Subject: [PATCH] Delete CK-UA bs32 variant family The bs32 variants existed because, before the fix, the pipeline required kBlockN <= page_size, so page_size=32 forced a kBlockN=32 kernel family. The multi-page-tile fix (commit 473869aba) lifted that constraint and made kBlockN compile-time-independent of the runtime page size, so the bs32 family is now redundant: every non-bs32 variant is correct for any page_size. This was validated in advance by forcing use_bs32=false in the dispatcher and running the full correctness suite -- 236/240, identical to baseline (the 4 remaining failures are the pre-existing int32-overflow case, orthogonal). Removes: * 16 instances/unified_attention_*_bs32_*.cpp files * unified_attention_decode_bs32_kernel_traits in unified_attention_impl.hpp * 3 _BS32 dispatch macros in unified_attention.cpp * 3 _p32 entries from the KernelVariant enum * 3 dispatch_*_p32 helper functions and their switch cases * the page_blk_size branch in select_config (now a pure tile-tier ladder) Net: 16 fewer compile units (build time -6s on JIT), 78 fewer dispatcher lines, and "which kernel runs?" is now driven purely by Q-tile shape. 
Co-authored-by: Cursor --- ...tention_d64_bf16_mask_gqa8_bs32_decode.cpp | 14 -- ...ntion_d64_bf16_mask_gqa8_bs32_decode_s.cpp | 14 -- ...ntion_d64_bf16_mask_gqa8_bs32_decode_t.cpp | 14 -- ...tention_d64_bf16_mask_gqa8_bs32_narrow.cpp | 14 -- ...ention_d64_bf16_nmask_gqa8_bs32_decode.cpp | 14 -- ...tion_d64_bf16_nmask_gqa8_bs32_decode_s.cpp | 14 -- ...tion_d64_bf16_nmask_gqa8_bs32_decode_t.cpp | 14 -- ...ention_d64_bf16_nmask_gqa8_bs32_narrow.cpp | 14 -- ...tention_d64_fp16_mask_gqa8_bs32_decode.cpp | 14 -- ...ntion_d64_fp16_mask_gqa8_bs32_decode_s.cpp | 14 -- ...ntion_d64_fp16_mask_gqa8_bs32_decode_t.cpp | 14 -- ...tention_d64_fp16_mask_gqa8_bs32_narrow.cpp | 14 -- ...ention_d64_fp16_nmask_gqa8_bs32_decode.cpp | 14 -- ...tion_d64_fp16_nmask_gqa8_bs32_decode_s.cpp | 14 -- ...tion_d64_fp16_nmask_gqa8_bs32_decode_t.cpp | 14 -- ...ention_d64_fp16_nmask_gqa8_bs32_narrow.cpp | 14 -- .../unified_attention.cpp | 122 +++--------------- .../unified_attention_impl.hpp | 62 --------- 18 files changed, 18 insertions(+), 390 deletions(-) delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_t.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_t.cpp delete mode 100644 
example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_t.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_t.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow.cpp diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode.cpp deleted file mode 100644 index 112efe1222..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s.cpp deleted file mode 100644 index ef17fc1971..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_small_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_t.cpp deleted file mode 100644 index 2c6531c835..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_t.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_tiny_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow.cpp deleted file mode 100644 index 204319568f..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_bs32_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode.cpp deleted file mode 100644 index f0c3617a52..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s.cpp deleted file mode 100644 index ead32cf0bf..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_small_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_t.cpp deleted file mode 100644 index cc77cf7726..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_t.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_tiny_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow.cpp deleted file mode 100644 index 27ccae7b06..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_bs32_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode.cpp deleted file mode 100644 index b79fe8eeb2..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s.cpp deleted file mode 100644 index 272439ecb0..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_small_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_t.cpp deleted file mode 100644 index 1420e3fa40..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_t.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_tiny_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow.cpp deleted file mode 100644 index 22d1f71e5b..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_bs32_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode.cpp deleted file mode 100644 index c883749ac2..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s.cpp deleted file mode 100644 index b76f03fe0c..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_small_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_t.cpp deleted file mode 100644 index 134ab386b5..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_t.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_tiny_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow.cpp deleted file mode 100644 index 47a8dd7939..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - unified_attention_decode_bs32_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index 2b8cf4b3c7..2050591c7c 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -26,10 +26,10 @@ std::ostream& operator<<(std::ostream& stream, // 1. KernelVariant + select_config(args) // - KernelVariant is a flat enum of every compiled kernel instance the // module knows about. Each entry fixes the static knobs (kBlockM, -// warp count, MFMA shape, pipeline policy, optional kBlockN override). +// warp count, MFMA shape, pipeline policy). // - select_config() is the ONLY place where shape-based runtime -// decisions live. It reads the problem (hdim, num_queries_per_kv, -// page_blk_size, avg_q, max_seqlen_q) and emits a KernelConfig. +// decisions live. It reads (hdim, num_queries_per_kv, avg_q, +// max_seqlen_q) and emits a KernelConfig. 
// // 2. dispatch_() helpers + the final switch // - Each KernelVariant has a tiny helper that fans out over the @@ -38,12 +38,10 @@ std::ostream& operator<<(std::ostream& stream, // per-variant traits classes are unchanged from before; only the // selection logic moved. // -// Phase-1 note: page-size is currently still a static axis in the enum -// (the _p32 suffix marks the variant with kBlockN=32 that was originally -// required when page_size < 64). The multi-page-tile fix in the pipeline -// removed the underlying constraint, so a follow-up commit deletes the -// _p32 (a.k.a. "bs32") family entirely. Doing it in two steps keeps each -// diff easy to bisect against the test suite. +// page_size is intentionally NOT part of this enum. The multi-page-tile +// fix in the pipeline made the compile-time tile-N (kBlockN) independent +// of the runtime page_blk_size, so every variant is correct for any page +// size. Selection is driven purely by Q-tile shape. // ============================================================================= enum class KernelVariant { @@ -54,10 +52,7 @@ enum class KernelVariant { // d=64 GQA-8 (num_queries_per_kv = 8) prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma - decode_d64_gqa8_m128_p32, // kBlockM=128, 4 warps, 32x32 mfma, kBlockN=32 decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma - decode_d64_gqa8_m64_p32, // kBlockM=64, 2 warps, 32x32 mfma, kBlockN=32 - decode_d64_gqa8_m32_p32, // kBlockM=32, 2 warps, 16x16 mfma, kBlockN=32 decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma }; @@ -114,28 +109,14 @@ KernelConfig select_config(const unified_attention_args& args) return cfg; } - // d=64 GQA-8 — full tile-tier ladder, with _p32 variants for the legacy - // kBlockN=32 path used when page_blk_size < 64. + // d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here. 
if (args.hdim == 64 && args.num_queries_per_kv == 8) { - const bool p32 = (args.page_blk_size < 64); - switch (select_tile_tier(args)) { - case tile_tier::tiny: - // p32: 2-warp 16x16 (kBlockM=32) -- avoids the 1-warp+p32 race. - // p64: 1-warp 16x16 (kBlockM=16) -- matches Triton BLOCK_M=16. - cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m32_p32 - : KernelVariant::decode_d64_gqa8_m16; - break; - case tile_tier::small: - cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m64_p32 - : KernelVariant::decode_d64_gqa8_m64; - break; - case tile_tier::medium: - cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m128_p32 - : KernelVariant::decode_d64_gqa8_m128; - break; + case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break; + case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break; + case tile_tier::medium: cfg.variant = KernelVariant::decode_d64_gqa8_m128; break; } return cfg; } @@ -150,7 +131,7 @@ KernelConfig select_config(const unified_attention_args& args) // Each DISPATCH_* macro instantiates one (traits, dtype, mask, ...) combo and // returns. The per-variant helpers below pick the right macro family and fan // out over (dtype, mask). They look repetitive on purpose: a follow-up commit -// will collapse the 5 traits classes into one templated `kernel_traits`, +// will collapse the 4 traits classes into one templated `kernel_traits`, // at which point these helpers become one-liners. // ----------------------------------------------------------------------------- @@ -180,25 +161,6 @@ KernelConfig select_config(const unified_attention_args& args) return unified_attention_kernel_dispatch_decode(args, config); \ } -// block_size=32 dispatch macros (6th template arg = 32). 
-#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType, IsMask, HSize, BM, NQPKV) \ - { \ - using kernel_traits = unified_attention_decode_kernel_traits; \ - return unified_attention_kernel_dispatch(args, config); \ - } - -#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType, IsMask, HSize, BM, NQPKV) \ - { \ - using kernel_traits = unified_attention_decode_small_kernel_traits; \ - return unified_attention_kernel_dispatch_decode(args, config); \ - } - -#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \ - { \ - using kernel_traits = unified_attention_decode_bs32_kernel_traits; \ - return unified_attention_kernel_dispatch_decode(args, config); \ - } - namespace { using DType = unified_attention_args::data_type_enum; @@ -259,20 +221,6 @@ std::pair dispatch_decode_d64_gqa8_m128( return {false, -1.f}; } -std::pair dispatch_decode_d64_gqa8_m128_p32( - const unified_attention_args& args, const stream_config& config) -{ - const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::fp16, false, 64, 128, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::fp16, true, 64, 128, 8) - } else if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::bf16, false, 64, 128, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::bf16, true, 64, 128, 8) - } - return {false, -1.f}; -} - std::pair dispatch_decode_d64_gqa8_m64( const unified_attention_args& args, const stream_config& config) { @@ -287,34 +235,6 @@ std::pair dispatch_decode_d64_gqa8_m64( return {false, -1.f}; } -std::pair dispatch_decode_d64_gqa8_m64_p32( - const unified_attention_args& args, const stream_config& config) -{ - const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) 
DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::fp16, false, 64, 64, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::fp16, true, 64, 64, 8) - } else if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::bf16, false, 64, 64, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::bf16, true, 64, 64, 8) - } - return {false, -1.f}; -} - -std::pair dispatch_decode_d64_gqa8_m32_p32( - const unified_attention_args& args, const stream_config& config) -{ - const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::fp16, false, 64, 32, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::fp16, true, 64, 32, 8) - } else if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::bf16, false, 64, 32, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::bf16, true, 64, 32, 8) - } - return {false, -1.f}; -} - std::pair dispatch_decode_d64_gqa8_m16( const unified_attention_args& args, const stream_config& config) { @@ -347,22 +267,16 @@ std::pair unified_attention(const unified_attention_args& args, switch (cfg.variant) { - case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config); - case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config); - case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config); - case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config); - case KernelVariant::decode_d64_gqa8_m128_p32: return dispatch_decode_d64_gqa8_m128_p32(args, config); - case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config); - case KernelVariant::decode_d64_gqa8_m64_p32: return dispatch_decode_d64_gqa8_m64_p32(args, config); - case 
KernelVariant::decode_d64_gqa8_m32_p32: return dispatch_decode_d64_gqa8_m32_p32(args, config); - case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config); + case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config); + case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config); + case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config); + case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config); + case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config); + case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config); } return std::make_pair(false, -1.f); } -#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW -#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32 -#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32 #undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY #undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL #undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 0793e0695a..55b2f39216 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -316,68 +316,6 @@ struct unified_attention_decode_tiny_kernel_traits using kernel = UnifiedAttentionKernel; }; -// bs32 decode traits: 2 warps, 16x16 MFMA, kBlockM=32, kBlockQ=4 for GQA-8. -// Used for block_size=32 decode: avoids the 1-warp pipeline race condition -// and reduces query waste from 87.5% (small tier kBlockQ=8) to 75% (kBlockQ=4). 
-template -struct unified_attention_decode_bs32_kernel_traits -{ - static constexpr auto date_type = DataType; - static constexpr bool is_masking = IsMasking; - - static constexpr index_t kBlockM = BlockM_; - static constexpr index_t HEAD_SIZE = HeadSize_; - static constexpr index_t BLOCK_SIZE = BlockSize_; - - static constexpr index_t num_queries_per_kv = NumQPerKV_; - static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; - - using unified_attention_block_tile = sequence; - using unified_attention_warp_gemm_shape = sequence<16, 16, 32>; - using unified_attention_block_warps = sequence<2, 1, 1>; - - using unified_attention_shape = TileUnifiedAttentionShape; - - using unified_attention_traits = TileUnifiedAttentionTraits; - using unified_attention_mask = GenericAttentionMask; - - using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::lse_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - unified_attention_shape, - unified_attention_mask, - unified_attention_traits>; - - using unified_attention_pipeline = - UnifiedAttentionPipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - true, true, true>>; - - using kernel = UnifiedAttentionKernel; -}; - template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config)