From fb0d729fbb7bdc9e6204617e18f2b95cfc22169f Mon Sep 17 00:00:00 2001 From: juuso-oskari Date: Tue, 12 May 2026 10:35:15 +0000 Subject: [PATCH] Collapse CK-UA traits into single kernel_traits template Replace 4 near-identical *_kernel_traits classes (~400 lines of repeated shape/policy plumbing) with one templated `unified_attention_kernel_traits` parameterized by `KernelVariant V`. The 6 dispatch_ helpers in unified_attention.cpp collapse into a single `dispatch_variant` function template that fans out over (dtype, mask). Per-variant compile-time knobs (BlockM, BlockSize, warp count, MFMA shape, pipeline policy, decode-grid flag) now live in one variant_config specialization each. "What's different between variants" is readable top-to-bottom in a single block of code, and each instance .cpp shrinks to a one-line `INST_UNIFIED_ATTENTION_DISPATCH(V, DTYPE, IS_MASK)` macro. No behavior change. Correctness suite: 236/240 (same 4 known num_blocks=32768 + d=128 MHA int32-overflow failures as baseline). Co-authored-by: Cursor --- .../unified_attention_d128_bf16_mask.cpp | 5 +- ...nified_attention_d128_bf16_mask_decode.cpp | 5 +- .../unified_attention_d128_bf16_nmask.cpp | 5 +- ...ified_attention_d128_bf16_nmask_decode.cpp | 5 +- .../unified_attention_d128_fp16_mask.cpp | 5 +- ...nified_attention_d128_fp16_mask_decode.cpp | 5 +- .../unified_attention_d128_fp16_nmask.cpp | 5 +- ...ified_attention_d128_fp16_nmask_decode.cpp | 5 +- .../unified_attention_d64_bf16_mask_gqa8.cpp | 5 +- ...ed_attention_d64_bf16_mask_gqa8_decode.cpp | 5 +- ..._attention_d64_bf16_mask_gqa8_decode_s.cpp | 5 +- ..._attention_d64_bf16_mask_gqa8_decode_t.cpp | 5 +- .../unified_attention_d64_bf16_nmask_gqa8.cpp | 5 +- ...d_attention_d64_bf16_nmask_gqa8_decode.cpp | 5 +- ...attention_d64_bf16_nmask_gqa8_decode_s.cpp | 5 +- ...attention_d64_bf16_nmask_gqa8_decode_t.cpp | 5 +- .../unified_attention_d64_fp16_mask_gqa8.cpp | 5 +- ...ed_attention_d64_fp16_mask_gqa8_decode.cpp | 5 +- ..._attention_d64_fp16_mask_gqa8_decode_s.cpp | 5 +- ..._attention_d64_fp16_mask_gqa8_decode_t.cpp | 5 +- .../unified_attention_d64_fp16_nmask_gqa8.cpp | 5 +- ...d_attention_d64_fp16_nmask_gqa8_decode.cpp | 5 +- ...attention_d64_fp16_nmask_gqa8_decode_s.cpp | 5 +- ...attention_d64_fp16_nmask_gqa8_decode_t.cpp | 5 +- .../unified_attention.cpp | 236 +++------ .../unified_attention_impl.hpp | 463 ++++++++---------- 26 files changed, 299 insertions(+), 520 deletions(-) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp index 72717026bc..3c80909df3 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp index 8659f68a7d..9af6ef61b0 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using 
kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp index 391103891a..d5e88ba319 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp index 2505832331..ed3de85a5e 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp index f2cc00f835..846cda09ad 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp index 8e1fb0d1f8..9c5a317b51 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp index 6a2a9984d1..e993d6c572 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp 
b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp index a9d6b17211..d73f2dbe5f 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp index 0b6be68278..e7cad98691 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp index a01ff6a23f..b0a8f44eec 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp index 8f36b20ffb..793e65047d 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_small_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp index 31de799776..5f34f0319d 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_tiny_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp index 6bd3dd6f58..80d1d5d45a 100644 --- 
a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode.cpp index 15ad7a8565..1cb0db0896 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp index 6f76a63874..7b666acb97 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_small_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_t.cpp index d4d4306638..2dc0db7b42 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_t.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_tiny_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp index 28ff9f22b1..e23dc6673c 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode.cpp index f9087e4147..c691b65da3 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode.cpp +++ 
b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp index ca688180e0..718866ca7e 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_small_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_t.cpp index c3a4c1f850..4b5d95a193 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_t.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_tiny_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp index f4d83a06a0..6d104b2a0f 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode.cpp index 06ae388ba7..b3484577f6 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp index 27fb381887..0650df3cca 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp @@ -6,9 +6,6 @@ 
namespace ck_tile { -using kernel_traits = - unified_attention_decode_small_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_t.cpp index 7c127cc130..24f2ff30af 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_t.cpp @@ -6,9 +6,6 @@ namespace ck_tile { -using kernel_traits = - unified_attention_decode_tiny_kernel_traits; - -INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index 2050591c7c..123ffa7238 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -23,50 +23,40 @@ std::ostream& operator<<(std::ostream& stream, // // The job is split in two halves so each is small enough to read in one sitting: // -// 1. KernelVariant + select_config(args) -// - KernelVariant is a flat enum of every compiled kernel instance the -// module knows about. Each entry fixes the static knobs (kBlockM, -// warp count, MFMA shape, pipeline policy). -// - select_config() is the ONLY place where shape-based runtime -// decisions live. It reads (hdim, num_queries_per_kv, avg_q, -// max_seqlen_q) and emits a KernelConfig. +// 1. select_config(args) +// - Reads shape (hdim, num_queries_per_kv, avg_q, max_seqlen_q) and +// picks one of the KernelVariants defined in unified_attention_impl.hpp. +// KernelVariant is the only place where compile-time knobs live — +// changing a knob means adding/editing a variant_config. // -// 2. dispatch_() helpers + the final switch -// - Each KernelVariant has a tiny helper that fans out over the -// (dtype, mask) cross-product and calls into the existing -// DISPATCH_UNIFIED_ATTENTION_* macros. The macros and the -// per-variant traits classes are unchanged from before; only the -// selection logic moved. +// 2. dispatch_variant() + the final switch +// - dispatch_variant() is a single function template that fans out +// over (dtype, mask) and forwards into the per-instance dispatch +// function generated by INST_UNIFIED_ATTENTION_DISPATCH. +// - The final switch maps KernelVariant -> dispatch_variant. // -// page_size is intentionally NOT part of this enum. The multi-page-tile -// fix in the pipeline made the compile-time tile-N (kBlockN) independent -// of the runtime page_blk_size, so every variant is correct for any page -// size. Selection is driven purely by Q-tile shape. +// page_size is intentionally NOT part of the config — the multi-page-tile +// pipeline fix made kBlockN independent of runtime page_blk_size, so every +// variant is correct for any page size. 
// ============================================================================= -enum class KernelVariant { - // d=128 MHA (num_queries_per_kv = 1) - prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma - decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128) - - // d=64 GQA-8 (num_queries_per_kv = 8) - prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma - decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma - decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma - decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma -}; - -struct KernelConfig { +struct KernelConfig +{ KernelVariant variant; bool unsupported = false; }; namespace { -// Internal tier classification — used only by select_config. The tier name is -// just shorthand for a kBlockM choice; with num_queries_per_kv=8 the tiers -// correspond to kBlockQ thresholds {2, 8, 16}. -enum class tile_tier { medium, small, tiny }; +// Internal tile-tier classification — used only by select_config. The tier +// name is shorthand for a kBlockM choice; with num_queries_per_kv=8 the +// tiers correspond to kBlockQ thresholds {2, 8, 16}. +enum class tile_tier +{ + medium, + small, + tiny +}; tile_tier select_tile_tier(const unified_attention_args& args) { @@ -80,8 +70,8 @@ tile_tier select_tile_tier(const unified_attention_args& args) // many more tokens, fall back to the medium tier (1D grid with Q iteration). const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q; - if (avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny; - if (avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small; + if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny; + if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small; return tile_tier::medium; } @@ -96,13 +86,13 @@ KernelConfig select_config(const unified_attention_args& args) // both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for // short-Q workloads. // * prefill_d128_mha : kBlockM=256, 8 warps. Everything else. - if (args.hdim == 128 && args.num_queries_per_kv == 1) + if(args.hdim == 128 && args.num_queries_per_kv == 1) { const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens; const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q; - if (avg_q <= 128 && max_q <= 128) + if(avg_q <= 128 && max_q <= 128) cfg.variant = KernelVariant::decode_d128_mha_m128; else cfg.variant = KernelVariant::prefill_d128_mha; @@ -110,9 +100,9 @@ KernelConfig select_config(const unified_attention_args& args) } // d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here. - if (args.hdim == 64 && args.num_queries_per_kv == 8) + if(args.hdim == 64 && args.num_queries_per_kv == 8) { - switch (select_tile_tier(args)) + switch(select_tile_tier(args)) { case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break; case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break; @@ -126,125 +116,36 @@ KernelConfig select_config(const unified_attention_args& args) } // ----------------------------------------------------------------------------- -// Dispatch macros and per-variant dispatch helpers. +// dispatch_variant // -// Each DISPATCH_* macro instantiates one (traits, dtype, mask, ...) combo and -// returns. The per-variant helpers below pick the right macro family and fan -// out over (dtype, mask). 
They look repetitive on purpose: a follow-up commit -// will collapse the 4 traits classes into one templated `kernel_traits`, -// at which point these helpers become one-liners. +// One function template. Fans out over (dtype, mask) and forwards into the +// per-instance dispatch generated by INST_UNIFIED_ATTENTION_DISPATCH. No +// per-variant boilerplate. // ----------------------------------------------------------------------------- - -// Helper macro: dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV. -#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV) \ - { \ - using kernel_traits = unified_attention_kernel_traits; \ - return unified_attention_kernel_dispatch(args, config); \ - } - -// Dispatch macros for three tile tiers (default block_size). -#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \ - { \ - using kernel_traits = unified_attention_decode_kernel_traits; \ - return unified_attention_kernel_dispatch(args, config); \ - } - -#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType, IsMask, HSize, BM, NQPKV) \ - { \ - using kernel_traits = unified_attention_decode_small_kernel_traits; \ - return unified_attention_kernel_dispatch_decode(args, config); \ - } - -#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType, IsMask, HSize, BM, NQPKV) \ - { \ - using kernel_traits = unified_attention_decode_tiny_kernel_traits; \ - return unified_attention_kernel_dispatch_decode(args, config); \ - } - namespace { -using DType = unified_attention_args::data_type_enum; - -std::pair dispatch_prefill_d128_mha( - const unified_attention_args& args, const stream_config& config) +template +std::pair dispatch_variant(const unified_attention_args& args, + const stream_config& config) { + using DT = unified_attention_args::data_type_enum; const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::fp16, false, 128, 256, 1) - else DISPATCH_UNIFIED_ATTENTION(DType::fp16, true, 128, 256, 1) - } else if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::bf16, false, 128, 256, 1) - else DISPATCH_UNIFIED_ATTENTION(DType::bf16, true, 128, 256, 1) + + if(args.data_type == DT::fp16) + { + if(is_mask) + return unified_attention_kernel_dispatch< + unified_attention_kernel_traits>(args, config); + return unified_attention_kernel_dispatch< + unified_attention_kernel_traits>(args, config); } - return {false, -1.f}; -} - -std::pair dispatch_decode_d128_mha_m128( - const unified_attention_args& args, const stream_config& config) -{ - const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 128, 128, 1) - else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 128, 128, 1) - } else if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 128, 128, 1) - else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 128, 128, 1) - } - return {false, -1.f}; -} - -std::pair dispatch_prefill_d64_gqa8( - const unified_attention_args& args, const stream_config& config) -{ - const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::fp16, false, 64, 256, 8) - else DISPATCH_UNIFIED_ATTENTION(DType::fp16, true, 64, 256, 8) - } else 
if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::bf16, false, 64, 256, 8) - else DISPATCH_UNIFIED_ATTENTION(DType::bf16, true, 64, 256, 8) - } - return {false, -1.f}; -} - -std::pair dispatch_decode_d64_gqa8_m128( - const unified_attention_args& args, const stream_config& config) -{ - const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 64, 128, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 64, 128, 8) - } else if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 64, 128, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 64, 128, 8) - } - return {false, -1.f}; -} - -std::pair dispatch_decode_d64_gqa8_m64( - const unified_attention_args& args, const stream_config& config) -{ - const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::fp16, false, 64, 64, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::fp16, true, 64, 64, 8) - } else if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::bf16, false, 64, 64, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::bf16, true, 64, 64, 8) - } - return {false, -1.f}; -} - -std::pair dispatch_decode_d64_gqa8_m16( - const unified_attention_args& args, const stream_config& config) -{ - const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - if (args.data_type == DType::fp16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::fp16, false, 64, 16, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::fp16, true, 64, 16, 8) - } else if (args.data_type == DType::bf16) { - if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::bf16, false, 64, 16, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::bf16, true, 64, 16, 8) + if(args.data_type == DT::bf16) + { + if(is_mask) + return unified_attention_kernel_dispatch< + unified_attention_kernel_traits>(args, config); + return unified_attention_kernel_dispatch< + unified_attention_kernel_traits>(args, config); } return {false, -1.f}; } @@ -256,30 +157,31 @@ std::pair unified_attention(const unified_attention_args& args, { const auto cfg = select_config(args); - if (cfg.unsupported) + if(cfg.unsupported) { std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim << " num_queries_per_kv=" << args.num_queries_per_kv - << " data_type=" << args.data_type - << " mask_type=" << args.mask_type << std::endl; + << " data_type=" << args.data_type << " mask_type=" << args.mask_type + << std::endl; return std::make_pair(false, -1.f); } - switch (cfg.variant) + switch(cfg.variant) { - case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config); - case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config); - case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config); - case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config); - case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config); - case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config); + case KernelVariant::prefill_d128_mha: + return dispatch_variant(args, config); 
+ case KernelVariant::decode_d128_mha_m128: + return dispatch_variant(args, config); + case KernelVariant::prefill_d64_gqa8: + return dispatch_variant(args, config); + case KernelVariant::decode_d64_gqa8_m128: + return dispatch_variant(args, config); + case KernelVariant::decode_d64_gqa8_m64: + return dispatch_variant(args, config); + case KernelVariant::decode_d64_gqa8_m16: + return dispatch_variant(args, config); } return std::make_pair(false, -1.f); } -#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY -#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL -#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM -#undef DISPATCH_UNIFIED_ATTENTION - } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 55b2f39216..6fd845278e 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -20,26 +20,36 @@ #include "unified_attention.hpp" #include "mask.hpp" -#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \ - template <> \ - std::pair unified_attention_kernel_dispatch( \ - const unified_attention_args& args, const stream_config& config) \ - { \ - return std::make_pair( \ - true, unified_attention_kernel_launch(args, config)); \ - } - -#define INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) \ - template <> \ - std::pair unified_attention_kernel_dispatch_decode( \ - const unified_attention_args& args, const stream_config& config) \ - { \ - return std::make_pair( \ - true, unified_attention_kernel_launch(args, config)); \ - } - namespace ck_tile { +// ============================================================================= +// KernelVariant +// +// Flat enum of every compiled kernel instance. Each variant fixes +// (kBlockM, warp count, MFMA shape, pipeline policy) via a variant_config +// specialization below. This is the single source of truth for "what knobs +// differ between kernel instances". +// +// page_size is intentionally NOT part of this enum. The multi-page-tile fix +// in the pipeline decoupled kBlockN from page_blk_size, so every variant is +// correct for any page size. +// ============================================================================= +enum class KernelVariant +{ + // d=128 MHA (num_queries_per_kv = 1) + prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma + decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128) + + // d=64 GQA-8 (num_queries_per_kv = 8) + prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma + decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma + decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy) + decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy) +}; + +// ----------------------------------------------------------------------------- +// Per-DataType problem element types. +// ----------------------------------------------------------------------------- template struct unified_attention_problem_traits; @@ -61,261 +71,182 @@ struct unified_attention_problem_traits +// ============================================================================= +// variant_config +// +// One specialization per KernelVariant. 
Each exposes the static knobs that +// distinguish that variant from the others: +// +// HeadSize : head dimension (compile-time) +// BlockM : Q-tile size along the M (token) axis +// NumQPerKV : 1 for MHA, 8 for GQA-8 +// BlockSize : kBlockN — KV-tile size along the N axis +// BlockWarps : warp layout, sequence +// WarpGemmShape : MFMA tile shape, sequence +// Pipeline

: pipeline template (default vs decode vs tiny-decode policy) +// kUseDecodeGrid : selects 2D-by-seq grid (true) vs Q-block grid (false) +// ============================================================================= +template +struct variant_config; + +template <> +struct variant_config +{ + static constexpr index_t HeadSize = 128; + static constexpr index_t BlockM = 256; + static constexpr index_t NumQPerKV = 1; + static constexpr index_t BlockSize = 32; + using BlockWarps = sequence<8, 1, 1>; + using WarpGemmShape = sequence<32, 32, 16>; + template + using Pipeline = UnifiedAttentionPipeline; + static constexpr bool kUseDecodeGrid = false; +}; + +template <> +struct variant_config +{ + static constexpr index_t HeadSize = 128; + static constexpr index_t BlockM = 128; + static constexpr index_t NumQPerKV = 1; + static constexpr index_t BlockSize = 32; + using BlockWarps = sequence<4, 1, 1>; + using WarpGemmShape = sequence<32, 32, 16>; + template + using Pipeline = UnifiedAttentionPipeline; + static constexpr bool kUseDecodeGrid = false; +}; + +template <> +struct variant_config +{ + static constexpr index_t HeadSize = 64; + static constexpr index_t BlockM = 256; + static constexpr index_t NumQPerKV = 8; + static constexpr index_t BlockSize = 64; + using BlockWarps = sequence<8, 1, 1>; + using WarpGemmShape = sequence<32, 32, 16>; + template + using Pipeline = UnifiedAttentionPipeline; + static constexpr bool kUseDecodeGrid = false; +}; + +template <> +struct variant_config +{ + static constexpr index_t HeadSize = 64; + static constexpr index_t BlockM = 128; + static constexpr index_t NumQPerKV = 8; + static constexpr index_t BlockSize = 64; + using BlockWarps = sequence<4, 1, 1>; + using WarpGemmShape = sequence<32, 32, 16>; + template + using Pipeline = UnifiedAttentionPipeline; + static constexpr bool kUseDecodeGrid = false; +}; + +template <> +struct variant_config +{ + static constexpr index_t HeadSize = 64; + static constexpr index_t BlockM = 64; + static constexpr index_t NumQPerKV = 8; + static constexpr index_t BlockSize = 64; + using BlockWarps = sequence<2, 1, 1>; + using WarpGemmShape = sequence<32, 32, 16>; + template + using Pipeline = UnifiedAttentionPipeline; + static constexpr bool kUseDecodeGrid = true; +}; + +template <> +struct variant_config +{ + static constexpr index_t HeadSize = 64; + static constexpr index_t BlockM = 16; + static constexpr index_t NumQPerKV = 8; + static constexpr index_t BlockSize = 64; + using BlockWarps = sequence<1, 1, 1>; + using WarpGemmShape = sequence<16, 16, 32>; + template + using Pipeline = UnifiedAttentionPipeline; + static constexpr bool kUseDecodeGrid = true; +}; + +// ============================================================================= +// unified_attention_kernel_traits +// +// Single templated trait. Pulls per-variant knobs from variant_config and +// per-dtype element types from unified_attention_problem_traits. 
+// ============================================================================= +template struct unified_attention_kernel_traits { - static constexpr auto date_type = DataType; - static constexpr bool is_masking = IsMasking; + using cfg = variant_config; + using dt = unified_attention_problem_traits; - static constexpr index_t kBlockM = BlockM_; - static constexpr index_t HEAD_SIZE = HeadSize_; - static constexpr index_t BLOCK_SIZE = BlockSize_; + static constexpr auto date_type = DataType; + static constexpr bool is_masking = IsMasking; + static constexpr KernelVariant variant = V; - static constexpr index_t num_queries_per_kv = NumQPerKV_; + static constexpr index_t HEAD_SIZE = cfg::HeadSize; + static constexpr index_t kBlockM = cfg::BlockM; + static constexpr index_t BLOCK_SIZE = cfg::BlockSize; + static constexpr index_t num_queries_per_kv = cfg::NumQPerKV; static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; + static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid; - // kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence; - - using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; - // 8 warps for warp specialization; kBlockM must be 8 * 32 = 256 - using unified_attention_block_warps = sequence<8, 1, 1>; + using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape; + using unified_attention_block_warps = typename cfg::BlockWarps; using unified_attention_shape = TileUnifiedAttentionShape; + true>; // IsVLayoutRowMajor using unified_attention_traits = TileUnifiedAttentionTraits; + -1>; // kBlockPerCu + using unified_attention_mask = GenericAttentionMask; - using unified_attention_mask = GenericAttentionMask; - - using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::lse_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - unified_attention_shape, - unified_attention_mask, - unified_attention_traits>; - - using unified_attention_pipeline = UnifiedAttentionPipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - true, // kPadM - true, // kPadM - true // UseRawStore - >>; - - using kernel = UnifiedAttentionKernel; -}; - -// Decode-tuned traits: 4 warps (1 warp group), kBlockM=128, serial pipeline. -// Uses the single-warp-group path in UnifiedAttentionPipeline. 
-template -struct unified_attention_decode_kernel_traits -{ - static constexpr auto date_type = DataType; - static constexpr bool is_masking = IsMasking; - - static constexpr index_t kBlockM = BlockM_; - static constexpr index_t HEAD_SIZE = HeadSize_; - static constexpr index_t BLOCK_SIZE = BlockSize_; - - static constexpr index_t num_queries_per_kv = NumQPerKV_; - static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; - - // kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE - using unified_attention_block_tile = sequence; - using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; - // 4 warps -> kBlockSize = 256 threads -> NumWarpGroups = 1 - using unified_attention_block_warps = sequence<4, 1, 1>; - - using unified_attention_shape = TileUnifiedAttentionShape; - - using unified_attention_traits = TileUnifiedAttentionTraits; - using unified_attention_mask = GenericAttentionMask; - - using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::lse_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - unified_attention_shape, - unified_attention_mask, - unified_attention_traits>; - - using unified_attention_pipeline = UnifiedAttentionPipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - true, true, true>>; - - using kernel = UnifiedAttentionKernel; -}; - -// Small decode traits: 2 warps, kBlockM=64, decode policy (NumWarpPerGroup=2). -// Uses 1D warp layout (sequence<2,1,1>) so no softmax reduction changes needed. 
-template -struct unified_attention_decode_small_kernel_traits -{ - static constexpr auto date_type = DataType; - static constexpr bool is_masking = IsMasking; - - static constexpr index_t kBlockM = BlockM_; - static constexpr index_t HEAD_SIZE = HeadSize_; - static constexpr index_t BLOCK_SIZE = BlockSize_; - - static constexpr index_t num_queries_per_kv = NumQPerKV_; - static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; - - using unified_attention_block_tile = sequence; - using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; - // 2 warps along M: kBlockM=2*32=64, kBlockSize=128, NumWarpGroups=1 - using unified_attention_block_warps = sequence<2, 1, 1>; - - using unified_attention_shape = TileUnifiedAttentionShape; - - using unified_attention_traits = TileUnifiedAttentionTraits; - using unified_attention_mask = GenericAttentionMask; - - using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::lse_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - unified_attention_shape, - unified_attention_mask, - unified_attention_traits>; + using unified_attention_pipeline_problem = + UnifiedAttentionPipelineProblem; using unified_attention_pipeline = - UnifiedAttentionPipeline; + typename cfg::template Pipeline; - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - true, true, true>>; - - using kernel = UnifiedAttentionKernel; -}; - -// Tiny decode traits: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2 for GQA-8. -// Matches Triton's BLOCK_M=16 / BLOCK_Q=2 decode configuration. -// Uses block_tile_reduce_sync instead of permlane32_swap for 16x16 MFMA. 
-template -struct unified_attention_decode_tiny_kernel_traits -{ - static constexpr auto date_type = DataType; - static constexpr bool is_masking = IsMasking; - - static constexpr index_t kBlockM = BlockM_; - static constexpr index_t HEAD_SIZE = HeadSize_; - static constexpr index_t BLOCK_SIZE = BlockSize_; - - static constexpr index_t num_queries_per_kv = NumQPerKV_; - static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; - - using unified_attention_block_tile = sequence; - using unified_attention_warp_gemm_shape = sequence<16, 16, 32>; - // 1 warp: kBlockM=1*16=16, kBlockSize=64, NumWarpGroups=1 - using unified_attention_block_warps = sequence<1, 1, 1>; - - using unified_attention_shape = TileUnifiedAttentionShape; - - using unified_attention_traits = TileUnifiedAttentionTraits; - using unified_attention_mask = GenericAttentionMask; - - using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::lse_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - unified_attention_shape, - unified_attention_mask, - unified_attention_traits>; - - using unified_attention_pipeline = - UnifiedAttentionPipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - true, true, true>>; + using epilogue = + Default2DEpilogue>; using kernel = UnifiedAttentionKernel; }; +// ============================================================================= +// Kernel launch — common helper. Picks the grid layout from +// Traits::kUseDecodeGrid; all other launch args are identical across variants. +// ============================================================================= template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) @@ -380,15 +311,33 @@ float unified_attention_kernel_launch(const unified_attention_args& args, return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); } -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -template +// ============================================================================= +// Per-instance dispatch. Each instance .cpp specializes this for its +// (V, DataType, IsMasking) tuple via INST_UNIFIED_ATTENTION_DISPATCH. +// +// Return: (launched?, elapsed_ms). elapsed_ms is valid only when launched. +// ============================================================================= +template std::pair unified_attention_kernel_dispatch(const unified_attention_args& args, const stream_config& config); -template -std::pair unified_attention_kernel_dispatch_decode(const unified_attention_args& args, - const stream_config& config); - } // namespace ck_tile + +// One-line instantiation per (V, DataType, IsMasking) combination. Each +// instance .cpp consists of exactly one of these calls. 
+#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_)                         \
+    template <>                                                                              \
+    std::pair<bool, float> unified_attention_kernel_dispatch<                                \
+        unified_attention_kernel_traits<KernelVariant::VARIANT_,                             \
+                                        unified_attention_args::data_type_enum::DTYPE_,      \
+                                        IS_MASK_>>(const unified_attention_args& args,       \
+                                                   const stream_config& config)              \
+    {                                                                                         \
+        using Traits = unified_attention_kernel_traits<                                       \
+            KernelVariant::VARIANT_,                                                          \
+            unified_attention_args::data_type_enum::DTYPE_,                                   \
+            IS_MASK_>;                                                                        \
+        return std::make_pair(true,                                                           \
+                              unified_attention_kernel_launch<Traits>(args, config));         \
+    }
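
As an illustration of the extension path this patch sets up (hypothetical values, not code from this change): wiring up an additional variant, say a d=64 GQA-8 prefill tile with kBlockM=128, would take one enum entry, one variant_config specialization, a select_config branch, and a one-line instance file. The variant name, warp layout, and tile numbers below are placeholders, the Pipeline alias is left to mirror whichever existing prefill variant it copies, and the dispatch_variant call is assumed to take the KernelVariant as its template argument, as in the switch above.

    // --- unified_attention_impl.hpp: one new enum entry + one variant_config ---
    // (appended to KernelVariant)
    //   prefill_d64_gqa8_m128,   // hypothetical: kBlockM=128, 4 warps, 32x32 mfma

    template <>
    struct variant_config<KernelVariant::prefill_d64_gqa8_m128>
    {
        static constexpr index_t HeadSize  = 64;
        static constexpr index_t BlockM    = 128;   // placeholder tile choice
        static constexpr index_t NumQPerKV = 8;
        static constexpr index_t BlockSize = 64;
        using BlockWarps    = sequence<4, 1, 1>;    // placeholder warp layout
        using WarpGemmShape = sequence<32, 32, 16>;
        // Pipeline alias omitted here: reuse the same form as an existing prefill variant.
        static constexpr bool kUseDecodeGrid = false;
    };

    // --- one new instance .cpp per (dtype, mask) combination to compile ---
    INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8_m128, bf16, true)

    // --- unified_attention.cpp: route to it ---
    //   select_config(): return KernelVariant::prefill_d64_gqa8_m128 for the target shape
    //   final switch   : case KernelVariant::prefill_d64_gqa8_m128:
    //                        return dispatch_variant<KernelVariant::prefill_d64_gqa8_m128>(args, config);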