From f5beedb2e918fa3221061ef536412504af8aca0d Mon Sep 17 00:00:00 2001 From: juuso-oskari Date: Tue, 12 May 2026 11:48:19 +0000 Subject: [PATCH] Add CK-UA decode_d128_mha_m32 / _m16 small-Q tiers For pure-decode workloads (sq=1) at d=128 the m128 tile wastes most of its 128 query rows, capping CK below Triton on most batch sizes in our sweep (4..256). Add two small-Q tiers that mirror the d=64 GQA-8 ladder: * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 MFMA (tiny-decode) * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 MFMA (tiny-decode) select_config now ladders by (avg_q, max_q): m16 -> m32 -> m128 -> prefill. d=128 MHA, hq=16/hk=16, sq=1, sk=120k, num_blocks=60k: batch before after CK BW 4 ~0.95x 0.98x 4.76 TB/s 8 ~0.85x 1.29x 5.00 TB/s 32 ~0.85x 1.14x 5.29 TB/s 64 ~0.75x 0.93x 5.35 TB/s 128 ~1.00x 1.09x 5.39 TB/s 256 ~1.03x 1.02x 5.41 TB/s Correctness suite stays at 241/245 (same 4 known int32-overflow failures in the prefill path). Co-authored-by: Cursor --- ...fied_attention_d128_bf16_mask_decode_s.cpp | 11 +++++++ ...fied_attention_d128_bf16_mask_decode_t.cpp | 11 +++++++ ...ied_attention_d128_bf16_nmask_decode_s.cpp | 11 +++++++ ...ied_attention_d128_bf16_nmask_decode_t.cpp | 11 +++++++ ...fied_attention_d128_fp16_mask_decode_s.cpp | 11 +++++++ ...fied_attention_d128_fp16_mask_decode_t.cpp | 11 +++++++ ...ied_attention_d128_fp16_nmask_decode_s.cpp | 11 +++++++ ...ied_attention_d128_fp16_nmask_decode_t.cpp | 11 +++++++ .../unified_attention.cpp | 20 +++++++++---- .../unified_attention_impl.hpp | 30 +++++++++++++++++++ 10 files changed, 132 insertions(+), 6 deletions(-) create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp create mode 100644 
example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp new file mode 100644 index 0000000000..804d8a1761 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp new file mode 100644 index 0000000000..dfaa9b3dad --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp new file mode 100644 index 0000000000..21301cc083 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp new file mode 100644 index 0000000000..da7c91915d --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_s.cpp new file mode 100644 index 0000000000..d5cbc67bfa --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_s.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp new file mode 100644 index 0000000000..e22ac838c3 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp new file mode 100644 index 0000000000..be137ee375 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp new file mode 100644 index 0000000000..abb86554ee --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index 123ffa7238..ea89ddeede 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -81,18 +81,22 @@ KernelConfig select_config(const unified_attention_args& args) { KernelConfig cfg; - // d=128 MHA — two variants today: - // * decode_d128_mha_m128 : kBlockM=128, 4 warps. Used when avg/max Q - // both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for - // short-Q workloads. - // * prefill_d128_mha : kBlockM=256, 8 warps. Everything else. + // d=128 MHA — tile-tier ladder by (avg_q, max_q): + // * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 mfma (tiny-decode) + // * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 mfma (tiny-decode) + // * decode_d128_mha_m128 : kBlockM=128, 4 warps, 32x32 mfma (default) + // * prefill_d128_mha : kBlockM=256, 8 warps, 32x32 mfma if(args.hdim == 128 && args.num_queries_per_kv == 1) { const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens; const index_t max_q = args.max_seqlen_q > 0 ? 
args.max_seqlen_q : avg_q; - if(avg_q <= 128 && max_q <= 128) + if(avg_q <= 16 && max_q <= 16) + cfg.variant = KernelVariant::decode_d128_mha_m16; + else if(avg_q <= 32 && max_q <= 32) + cfg.variant = KernelVariant::decode_d128_mha_m32; + else if(avg_q <= 128 && max_q <= 128) cfg.variant = KernelVariant::decode_d128_mha_m128; else cfg.variant = KernelVariant::prefill_d128_mha; @@ -172,6 +176,10 @@ std::pair unified_attention(const unified_attention_args& args, return dispatch_variant(args, config); case KernelVariant::decode_d128_mha_m128: return dispatch_variant(args, config); + case KernelVariant::decode_d128_mha_m32: + return dispatch_variant(args, config); + case KernelVariant::decode_d128_mha_m16: + return dispatch_variant(args, config); case KernelVariant::prefill_d64_gqa8: return dispatch_variant(args, config); case KernelVariant::decode_d64_gqa8_m128: diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 6fd845278e..20c63318c7 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -39,6 +39,8 @@ enum class KernelVariant // d=128 MHA (num_queries_per_kv = 1) prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128) + decode_d128_mha_m32, // kBlockM=32, 1 warp, 32x32 mfma (tiny-decode policy) + decode_d128_mha_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy) // d=64 GQA-8 (num_queries_per_kv = 8) prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma @@ -117,6 +119,34 @@ struct variant_config static constexpr bool kUseDecodeGrid = false; }; +template <> +struct variant_config +{ + static constexpr index_t HeadSize = 128; + static constexpr index_t BlockM = 32; + static constexpr index_t NumQPerKV = 1; + static constexpr index_t BlockSize = 32; + using BlockWarps = sequence<1, 1, 1>; + using 
WarpGemmShape = sequence<32, 32, 16>; + template + using Pipeline = UnifiedAttentionPipeline; + static constexpr bool kUseDecodeGrid = true; +}; + +template <> +struct variant_config +{ + static constexpr index_t HeadSize = 128; + static constexpr index_t BlockM = 16; + static constexpr index_t NumQPerKV = 1; + static constexpr index_t BlockSize = 32; + using BlockWarps = sequence<1, 1, 1>; + using WarpGemmShape = sequence<16, 16, 32>; + template + using Pipeline = UnifiedAttentionPipeline; + static constexpr bool kUseDecodeGrid = true; +}; + template <> struct variant_config {