From d77f0bea63d1bc2cf041d1353137771b9653dc35 Mon Sep 17 00:00:00 2001 From: juuso-oskari Date: Tue, 12 May 2026 12:15:55 +0000 Subject: [PATCH] CK-UA: collapse MHA/GQA variants -- one binary per (head_dim, kBlockM) After moving kBlockQ to runtime in the previous commit, the static NumQPerKV in `variant_config` and the runtime-vs-static assert in the kernel became the only things still tying a compiled binary to a specific num_queries_per_kv. Drop both and the existing instances now serve every num_qpkv that divides kBlockM evenly. Concretely: * `variant_config` -- remove the NumQPerKV field from every specialization. * `unified_attention_kernel_traits` -- remove the `num_queries_per_kv` / `kBlockQ = kBlockM / num_qpkv` derivation. The BlockTile's 2nd entry (the static `kBlockQ` exposed via UnifiedAttentionShape) is anchored at kBlockM so it describes the "num_qpkv == 1" fallback; the actual kBlockQ is always the runtime value. * `unified_attention_kernel_launch` -- recompute kBlockQ at host time from `args.num_queries_per_kv` for the total_num_q_blocks math. * `unified_attention_kernel.hpp` -- drop the `assert(kBlockQ_dyn == kBlockQ)` (it enforced the very coupling we just removed). * `unified_attention.cpp::select_config` -- collapse the two per-num_qpkv code paths into a single (head_dim, avg_rows, max_rows) ladder, where avg_rows = avg_q * num_qpkv. Variant renames (8 variants): prefill_d128_mha -> prefill_d128 decode_d128_mha_m128 -> decode_d128_m128 decode_d128_mha_m32 -> decode_d128_m32 decode_d128_mha_m16 -> decode_d128_m16 prefill_d64_gqa8 -> prefill_d64 decode_d64_gqa8_m128 -> decode_d64_m128 decode_d64_gqa8_m64 -> decode_d64_m64 decode_d64_gqa8_m16 -> decode_d64_m16 The 16 d=64 instance files lose their `_gqa8` infix to match the d=128 naming (file count unchanged: 16 dtypes x mask combos per head_dim). Validation: * Correctness suite: 241/245 (same 4 pre-existing int32-overflow failures in the prefill rebased-pointer path). * d=128 GQA-8 (a NEW combo we never had a binary for) -- runs correctly on the existing decode_d128_m* binaries with num_qpkv=8 at runtime. max abs diff <= 1e-2 vs the torch reference at ql in {1, 4, 16}. * d=64 MHA (also a new combo) -- runs correctly on the existing decode_d64_m* binaries with num_qpkv=1. Same tolerance. * Perf sweep (b=4..256, sk=120000, MI300): d=64 GQA-8: speedups 1.28x..1.84x vs Triton (within 0.6% of baseline). d=128 MHA: speedups 0.98x..1.14x vs Triton (within 0.3% of baseline). Unlocked: adding new (head_dim, num_qpkv) combos no longer requires new kernel binaries -- just a host-side heuristic update mapping the combo to the appropriate (kBlockM, BlockWarps) ladder. Co-authored-by: Cursor --- .../unified_attention_d128_bf16_mask.cpp | 2 +- ...nified_attention_d128_bf16_mask_decode.cpp | 2 +- ...fied_attention_d128_bf16_mask_decode_s.cpp | 2 +- ...fied_attention_d128_bf16_mask_decode_t.cpp | 2 +- .../unified_attention_d128_bf16_nmask.cpp | 2 +- ...ified_attention_d128_bf16_nmask_decode.cpp | 2 +- ...ied_attention_d128_bf16_nmask_decode_s.cpp | 2 +- ...ied_attention_d128_bf16_nmask_decode_t.cpp | 2 +- .../unified_attention_d128_fp16_mask.cpp | 2 +- ...nified_attention_d128_fp16_mask_decode.cpp | 2 +- ...fied_attention_d128_fp16_mask_decode_s.cpp | 2 +- ...fied_attention_d128_fp16_mask_decode_t.cpp | 2 +- .../unified_attention_d128_fp16_nmask.cpp | 2 +- ...ified_attention_d128_fp16_nmask_decode.cpp | 2 +- ...ied_attention_d128_fp16_nmask_decode_s.cpp | 2 +- ...ied_attention_d128_fp16_nmask_decode_t.cpp | 2 +- ...pp => unified_attention_d64_bf16_mask.cpp} | 2 +- ...nified_attention_d64_bf16_mask_decode.cpp} | 2 +- ...fied_attention_d64_bf16_mask_decode_s.cpp} | 2 +- ...fied_attention_d64_bf16_mask_decode_t.cpp} | 2 +- ...ed_attention_d64_bf16_mask_gqa8_decode.cpp | 11 -- ..._attention_d64_bf16_mask_gqa8_decode_s.cpp | 11 -- ..._attention_d64_bf16_mask_gqa8_decode_t.cpp | 11 -- .../unified_attention_d64_bf16_nmask.cpp | 11 ++ ...nified_attention_d64_bf16_nmask_decode.cpp | 11 ++ ...fied_attention_d64_bf16_nmask_decode_s.cpp | 11 ++ ...fied_attention_d64_bf16_nmask_decode_t.cpp | 11 ++ ...d_attention_d64_bf16_nmask_gqa8_decode.cpp | 11 -- ...attention_d64_bf16_nmask_gqa8_decode_s.cpp | 11 -- ...attention_d64_bf16_nmask_gqa8_decode_t.cpp | 11 -- .../unified_attention_d64_fp16_mask.cpp | 11 ++ ...unified_attention_d64_fp16_mask_decode.cpp | 11 ++ ...ified_attention_d64_fp16_mask_decode_s.cpp | 11 ++ ...ified_attention_d64_fp16_mask_decode_t.cpp | 11 ++ ...ed_attention_d64_fp16_mask_gqa8_decode.cpp | 11 -- ..._attention_d64_fp16_mask_gqa8_decode_s.cpp | 11 -- ..._attention_d64_fp16_mask_gqa8_decode_t.cpp | 11 -- .../unified_attention_d64_fp16_nmask.cpp | 11 ++ ...nified_attention_d64_fp16_nmask_decode.cpp | 11 ++ ...fied_attention_d64_fp16_nmask_decode_s.cpp | 11 ++ ...fied_attention_d64_fp16_nmask_decode_t.cpp | 11 ++ ...d_attention_d64_fp16_nmask_gqa8_decode.cpp | 11 -- ...attention_d64_fp16_nmask_gqa8_decode_s.cpp | 11 -- ...attention_d64_fp16_nmask_gqa8_decode_t.cpp | 11 -- .../unified_attention.cpp | 122 ++++++++---------- .../unified_attention_impl.hpp | 80 ++++++------ .../kernel/unified_attention_kernel.hpp | 13 +- 47 files changed, 254 insertions(+), 265 deletions(-) rename example/ck_tile/42_unified_attention/instances/{unified_attention_d64_bf16_mask_gqa8.cpp => unified_attention_d64_bf16_mask.cpp} (78%) rename example/ck_tile/42_unified_attention/instances/{unified_attention_d64_fp16_mask_gqa8.cpp => unified_attention_d64_bf16_mask_decode.cpp} (78%) rename example/ck_tile/42_unified_attention/instances/{unified_attention_d64_bf16_nmask_gqa8.cpp => unified_attention_d64_bf16_mask_decode_s.cpp} (78%) rename example/ck_tile/42_unified_attention/instances/{unified_attention_d64_fp16_nmask_gqa8.cpp => unified_attention_d64_bf16_mask_decode_t.cpp} (78%) delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode_t.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_t.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode_t.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_t.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode_t.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp delete mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_t.cpp diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp index 3c80909df3..7a3f084981 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, true) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp index 9af6ef61b0..2f2dc5c65b 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, true) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp index 804d8a1761..0b60f16f14 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, true) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp index dfaa9b3dad..0967f11506 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, true) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp index d5e88ba319..443da0e955 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, false) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp index ed3de85a5e..688ddd1616 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, false) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp index 21301cc083..8873d678c4 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, false) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp index da7c91915d..8fdc025c70 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, false) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, bf16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp index 846cda09ad..fa82996675 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, true) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp index 9c5a317b51..d84e775cc1 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, true) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_s.cpp index d5cbc67bfa..f55a0dfad5 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_s.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, true) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp index e22ac838c3..c827153f17 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, true) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, fp16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp index e993d6c572..6ac0562c74 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, false) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp index d73f2dbe5f..da55f834e4 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, false) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp index be137ee375..e210181341 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, false) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp index abb86554ee..8d50290fdf 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, false) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, fp16, false) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask.cpp similarity index 78% rename from example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp rename to example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask.cpp index e7cad98691..4bf2f2eb9b 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, true) +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode.cpp similarity index 78% rename from example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp rename to example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode.cpp index e23dc6673c..46d920f7a3 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, true) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode_s.cpp similarity index 78% rename from example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp rename to example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode_s.cpp index 80d1d5d45a..2cba78f144 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode_s.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, false) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode_t.cpp similarity index 78% rename from example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp rename to example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode_t.cpp index 6d104b2a0f..e92c77fee2 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_decode_t.cpp @@ -6,6 +6,6 @@ namespace ck_tile { -INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, false) +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, bf16, true) } // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp deleted file mode 100644 index b0a8f44eec..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, true) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp deleted file mode 100644 index 793e65047d..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, true) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp deleted file mode 100644 index 5f34f0319d..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, true) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask.cpp new file mode 100644 index 0000000000..d91f13c7b4 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, bf16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode.cpp new file mode 100644 index 0000000000..4567ccd8d0 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, bf16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode_s.cpp new file mode 100644 index 0000000000..b93e8b10ce --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode_s.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, bf16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode_t.cpp new file mode 100644 index 0000000000..a6fc4cfa92 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, bf16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode.cpp deleted file mode 100644 index 1cb0db0896..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, false) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp deleted file mode 100644 index 7b666acb97..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, false) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_t.cpp deleted file mode 100644 index 2dc0db7b42..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_t.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, false) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask.cpp new file mode 100644 index 0000000000..e2c0573d75 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode.cpp new file mode 100644 index 0000000000..11102dba7a --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode_s.cpp new file mode 100644 index 0000000000..13b4105dc1 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode_s.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode_t.cpp new file mode 100644 index 0000000000..c7b8a2943b --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode.cpp deleted file mode 100644 index c691b65da3..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, true) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp deleted file mode 100644 index 718866ca7e..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, true) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_t.cpp deleted file mode 100644 index 4b5d95a193..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_t.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, true) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask.cpp new file mode 100644 index 0000000000..a63d6885b4 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode.cpp new file mode 100644 index 0000000000..f31cf7c189 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode_s.cpp new file mode 100644 index 0000000000..0608ef32af --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode_s.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode_t.cpp new file mode 100644 index 0000000000..d546b23a5d --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode.cpp deleted file mode 100644 index b3484577f6..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, false) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp deleted file mode 100644 index 0650df3cca..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, false) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_t.cpp deleted file mode 100644 index 24f2ff30af..0000000000 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_t.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "unified_attention.hpp" -#include "unified_attention_impl.hpp" - -namespace ck_tile { - -INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, false) - -} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index ea89ddeede..b563e26643 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -46,72 +46,52 @@ struct KernelConfig bool unsupported = false; }; -namespace { - -// Internal tile-tier classification — used only by select_config. The tier -// name is shorthand for a kBlockM choice; with num_queries_per_kv=8 the -// tiers correspond to kBlockQ thresholds {2, 8, 16}. -enum class tile_tier -{ - medium, - small, - tiny -}; - -tile_tier select_tile_tier(const unified_attention_args& args) -{ - const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs - : args.num_tokens; - const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv; - const index_t kBlockQ_small = 64 / args.num_queries_per_kv; - - // Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each - // seq has at most kBlockQ tokens. For mixed batches where some seqs have - // many more tokens, fall back to the medium tier (1D grid with Q iteration). - const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q; - - if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny; - if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small; - return tile_tier::medium; -} - -} // anonymous namespace - KernelConfig select_config(const unified_attention_args& args) { KernelConfig cfg; - // d=128 MHA — tile-tier ladder by (avg_q, max_q): - // * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 mfma (tiny-decode) - // * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 mfma (tiny-decode) - // * decode_d128_mha_m128 : kBlockM=128, 4 warps, 32x32 mfma (default) - // * prefill_d128_mha : kBlockM=256, 8 warps, 32x32 mfma - if(args.hdim == 128 && args.num_queries_per_kv == 1) - { - const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs + // The variants are now num_queries_per_kv-agnostic (kBlockQ is runtime + // inside the kernel) -- we just have to pick a kBlockM that holds enough + // rows for `num_qpkv * max_q` and that num_qpkv divides cleanly. + // + // `avg_q * num_qpkv` is the *effective* per-CTA tile occupancy; e.g. + // GQA-8 with sq=1 produces 8 rows per Q tile, the same as MHA with sq=8. + // Tiering on that quantity lets one variant ladder serve both regimes. + const index_t num_qpkv = args.num_queries_per_kv; + const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens; - const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q; + const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q; + const index_t avg_rows = avg_q * num_qpkv; // effective rows per Q tile + const index_t max_rows = max_q * num_qpkv; - if(avg_q <= 16 && max_q <= 16) - cfg.variant = KernelVariant::decode_d128_mha_m16; - else if(avg_q <= 32 && max_q <= 32) - cfg.variant = KernelVariant::decode_d128_mha_m32; - else if(avg_q <= 128 && max_q <= 128) - cfg.variant = KernelVariant::decode_d128_mha_m128; + if(args.hdim == 128) + { + // d=128 ladder: m16 / m32 / m128 / prefill (m256). Requires + // num_qpkv to divide the chosen kBlockM, which is automatic for + // num_qpkv in {1, 2, 4, 8, 16} and any of the kBlockM's below. + if(avg_rows <= 16 && max_rows <= 16) + cfg.variant = KernelVariant::decode_d128_m16; + else if(avg_rows <= 32 && max_rows <= 32) + cfg.variant = KernelVariant::decode_d128_m32; + else if(avg_rows <= 128 && max_rows <= 128) + cfg.variant = KernelVariant::decode_d128_m128; else - cfg.variant = KernelVariant::prefill_d128_mha; + cfg.variant = KernelVariant::prefill_d128; return cfg; } - // d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here. - if(args.hdim == 64 && args.num_queries_per_kv == 8) + if(args.hdim == 64) { - switch(select_tile_tier(args)) - { - case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break; - case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break; - case tile_tier::medium: cfg.variant = KernelVariant::decode_d64_gqa8_m128; break; - } + // d=64 ladder: m16 / m64 / m128 / prefill (m256). Same shape + // selection logic as d=128; the variant's kBlockN is just bigger. + if(avg_rows <= 16 && max_rows <= 16) + cfg.variant = KernelVariant::decode_d64_m16; + else if(avg_rows <= 64 && max_rows <= 64) + cfg.variant = KernelVariant::decode_d64_m64; + else if(avg_rows <= 128 && max_rows <= 128) + cfg.variant = KernelVariant::decode_d64_m128; + else + cfg.variant = KernelVariant::prefill_d64; return cfg; } @@ -172,22 +152,22 @@ std::pair unified_attention(const unified_attention_args& args, switch(cfg.variant) { - case KernelVariant::prefill_d128_mha: - return dispatch_variant(args, config); - case KernelVariant::decode_d128_mha_m128: - return dispatch_variant(args, config); - case KernelVariant::decode_d128_mha_m32: - return dispatch_variant(args, config); - case KernelVariant::decode_d128_mha_m16: - return dispatch_variant(args, config); - case KernelVariant::prefill_d64_gqa8: - return dispatch_variant(args, config); - case KernelVariant::decode_d64_gqa8_m128: - return dispatch_variant(args, config); - case KernelVariant::decode_d64_gqa8_m64: - return dispatch_variant(args, config); - case KernelVariant::decode_d64_gqa8_m16: - return dispatch_variant(args, config); + case KernelVariant::prefill_d128: + return dispatch_variant(args, config); + case KernelVariant::decode_d128_m128: + return dispatch_variant(args, config); + case KernelVariant::decode_d128_m32: + return dispatch_variant(args, config); + case KernelVariant::decode_d128_m16: + return dispatch_variant(args, config); + case KernelVariant::prefill_d64: + return dispatch_variant(args, config); + case KernelVariant::decode_d64_m128: + return dispatch_variant(args, config); + case KernelVariant::decode_d64_m64: + return dispatch_variant(args, config); + case KernelVariant::decode_d64_m16: + return dispatch_variant(args, config); } return std::make_pair(false, -1.f); } diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 20c63318c7..8f736dfe01 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -36,17 +36,20 @@ namespace ck_tile { // ============================================================================= enum class KernelVariant { - // d=128 MHA (num_queries_per_kv = 1) - prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma - decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128) - decode_d128_mha_m32, // kBlockM=32, 1 warp, 32x32 mfma (tiny-decode policy) - decode_d128_mha_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy) + // d=128 (num_queries_per_kv chosen at *runtime* — same binary serves both + // MHA and GQA-N as long as num_qpkv divides kBlockM). kBlockM is the only + // structural compile-time knob; pick the tier by max_q after multiplying + // by num_qpkv in select_config. + prefill_d128, // kBlockM=256, 8 warps, 32x32 mfma + decode_d128_m128, // kBlockM=128, 4 warps, 32x32 mfma + decode_d128_m32, // kBlockM=32, 1 warp, 32x32 mfma (tiny-decode policy) + decode_d128_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy) - // d=64 GQA-8 (num_queries_per_kv = 8) - prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma - decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma - decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy) - decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy) + // d=64. + prefill_d64, // kBlockM=256, 8 warps, 32x32 mfma + decode_d64_m128, // kBlockM=128, 4 warps, 32x32 mfma + decode_d64_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy) + decode_d64_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy) }; // ----------------------------------------------------------------------------- @@ -81,22 +84,25 @@ struct unified_attention_problem_traits // WarpGemmShape : MFMA tile shape, sequence // Pipeline

: pipeline template (default vs decode vs tiny-decode policy) // kUseDecodeGrid : selects 2D-by-seq grid (true) vs Q-block grid (false) +// +// num_queries_per_kv is *not* a compile-time knob: kBlockQ = kBlockM / +// num_qpkv is computed at runtime inside the kernel and pipeline. The only +// constraint is `kBlockM % num_qpkv == 0` (host-side select_config makes sure +// of this). // ============================================================================= template struct variant_config; template <> -struct variant_config +struct variant_config { static constexpr index_t HeadSize = 128; static constexpr index_t BlockM = 256; - static constexpr index_t NumQPerKV = 1; static constexpr index_t BlockSize = 32; using BlockWarps = sequence<8, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; @@ -106,11 +112,10 @@ struct variant_config }; template <> -struct variant_config +struct variant_config { static constexpr index_t HeadSize = 128; static constexpr index_t BlockM = 128; - static constexpr index_t NumQPerKV = 1; static constexpr index_t BlockSize = 32; using BlockWarps = sequence<4, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; @@ -120,11 +125,10 @@ struct variant_config }; template <> -struct variant_config +struct variant_config { static constexpr index_t HeadSize = 128; static constexpr index_t BlockM = 32; - static constexpr index_t NumQPerKV = 1; static constexpr index_t BlockSize = 32; using BlockWarps = sequence<1, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; @@ -134,11 +138,10 @@ struct variant_config }; template <> -struct variant_config +struct variant_config { static constexpr index_t HeadSize = 128; static constexpr index_t BlockM = 16; - static constexpr index_t NumQPerKV = 1; static constexpr index_t BlockSize = 32; using BlockWarps = sequence<1, 1, 1>; using WarpGemmShape = sequence<16, 16, 32>; @@ -148,11 +151,10 @@ struct variant_config }; template <> -struct variant_config +struct variant_config { static constexpr index_t HeadSize = 64; static constexpr index_t BlockM = 256; - static constexpr index_t NumQPerKV = 8; static constexpr index_t BlockSize = 64; using BlockWarps = sequence<8, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; @@ -162,11 +164,10 @@ struct variant_config }; template <> -struct variant_config +struct variant_config { static constexpr index_t HeadSize = 64; static constexpr index_t BlockM = 128; - static constexpr index_t NumQPerKV = 8; static constexpr index_t BlockSize = 64; using BlockWarps = sequence<4, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; @@ -176,11 +177,10 @@ struct variant_config }; template <> -struct variant_config +struct variant_config { static constexpr index_t HeadSize = 64; static constexpr index_t BlockM = 64; - static constexpr index_t NumQPerKV = 8; static constexpr index_t BlockSize = 64; using BlockWarps = sequence<2, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; @@ -190,11 +190,10 @@ struct variant_config }; template <> -struct variant_config +struct variant_config { static constexpr index_t HeadSize = 64; static constexpr index_t BlockM = 16; - static constexpr index_t NumQPerKV = 8; static constexpr index_t BlockSize = 64; using BlockWarps = sequence<1, 1, 1>; using WarpGemmShape = sequence<16, 16, 32>; @@ -221,14 +220,19 @@ struct unified_attention_kernel_traits static constexpr bool is_masking = IsMasking; static constexpr KernelVariant variant = V; - static constexpr index_t HEAD_SIZE = cfg::HeadSize; - static constexpr index_t kBlockM = cfg::BlockM; - static constexpr index_t BLOCK_SIZE = cfg::BlockSize; - static constexpr index_t num_queries_per_kv = cfg::NumQPerKV; - static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; - static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid; + static constexpr index_t HEAD_SIZE = cfg::HeadSize; + static constexpr index_t kBlockM = cfg::BlockM; + static constexpr index_t BLOCK_SIZE = cfg::BlockSize; + static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid; - using unified_attention_block_tile = sequence; + // The 2nd entry of the BlockTile is the static `kBlockQ` exposed via + // `UnifiedAttentionShape::kBlockQ`. Now that the kernel always reads + // kBlockQ from `args.num_queries_per_kv` at runtime, this static value + // is only the fallback when no num_qpkv was plumbed through (which never + // happens in practice). Anchor it at kBlockM so the static "looks like + // num_qpkv == 1" and any (kBlockM, num_qpkv) such that kBlockM % num_qpkv + // == 0 works without touching this trait. + using unified_attention_block_tile = sequence; using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape; using unified_attention_block_warps = typename cfg::BlockWarps; @@ -281,8 +285,12 @@ template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { - constexpr index_t kBlockQ = Kernel::kBlockQ; - index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs; + // kBlockQ is derived from the runtime num_queries_per_kv now -- the + // static `Kernel::kBlockQ` is anchored at kBlockM and would over-count + // tiles for GQA workloads. We assert kBlockM % num_qpkv == 0 in + // select_config so this integer divide is always exact. + const index_t kBlockQ = Kernel::kBlockM / args.num_queries_per_kv; + const index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index b2593f5f63..856cba024f 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -263,13 +263,14 @@ struct UnifiedAttentionKernel const index_t num_queries_per_kv = kargs.num_queries_per_kv; - // kBlockQ derived at runtime from num_queries_per_kv. For the variants - // we ship today this matches the compile-time `kBlockQ` from the - // pipeline trait (the assert below catches any disagreement); the - // explicit runtime form is what eventually lets a single kernel - // instantiation cover multiple num_queries_per_kv values. + // kBlockQ derived at runtime from num_queries_per_kv. The static + // `kBlockQ` from the pipeline trait is anchored at kBlockM (i.e. it + // describes num_qpkv == 1) so the same compiled binary serves every + // num_qpkv that divides kBlockM evenly -- e.g. the d=128 variants + // can run both MHA and GQA-N at runtime with no recompile. The host + // side (select_config) is responsible for enforcing kBlockM % + // num_queries_per_kv == 0. const index_t kBlockQ_dyn = kBlockM / num_queries_per_kv; - assert(kBlockQ_dyn == kBlockQ); // Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The // split index lives in z — when num_splits == 1 (the only z value)