diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp
new file mode 100644
index 0000000000..804d8a1761
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_s.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, true)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp
new file mode 100644
index 0000000000..dfaa9b3dad
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode_t.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, true)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp
new file mode 100644
index 0000000000..21301cc083
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_s.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, false)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp
new file mode 100644
index 0000000000..da7c91915d
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode_t.cpp
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, false)
+
+} // namespace ck_tile
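Each of these instance translation units (the four bf16 ones above and the four fp16 counterparts below) pins exactly one (variant, dtype, masked) combination, so every heavy kernel-template instantiation is compiled once, in its own TU, and the TUs build in parallel. The real INST_UNIFIED_ATTENTION_DISPATCH macro is defined in unified_attention_impl.hpp and is not part of this diff; the sketch below is only a plausible shape for it, where run_variant and its signature are assumptions rather than the repo's API:

    // HYPOTHETICAL shape of INST_UNIFIED_ATTENTION_DISPATCH; the real macro
    // is defined in unified_attention_impl.hpp and is not shown in this diff.
    // run_variant is an invented name standing in for whatever the macro
    // actually instantiates explicitly for one (variant, dtype, masked) combo.
    #define INST_UNIFIED_ATTENTION_DISPATCH(variant_, dtype_, masked_)       \
        template void run_variant<KernelVariant::variant_, dtype_, masked_>( \
            const unified_attention_args&, const KernelConfig&);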
+ +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp new file mode 100644 index 0000000000..e22ac838c3 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp new file mode 100644 index 0000000000..be137ee375 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_s.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp new file mode 100644 index 0000000000..abb86554ee --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode_t.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index 123ffa7238..ea89ddeede 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -81,18 +81,22 @@ KernelConfig select_config(const unified_attention_args& args) { KernelConfig cfg; - // d=128 MHA — two variants today: - // * decode_d128_mha_m128 : kBlockM=128, 4 warps. Used when avg/max Q - // both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for - // short-Q workloads. - // * prefill_d128_mha : kBlockM=256, 8 warps. Everything else. + // d=128 MHA — tile-tier ladder by (avg_q, max_q): + // * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 mfma (tiny-decode) + // * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 mfma (tiny-decode) + // * decode_d128_mha_m128 : kBlockM=128, 4 warps, 32x32 mfma (default) + // * prefill_d128_mha : kBlockM=256, 8 warps, 32x32 mfma if(args.hdim == 128 && args.num_queries_per_kv == 1) { const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens; const index_t max_q = args.max_seqlen_q > 0 ? 
diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp
index 6fd845278e..20c63318c7 100644
--- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp
+++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp
@@ -39,6 +39,8 @@ enum class KernelVariant
     // d=128 MHA (num_queries_per_kv = 1)
     prefill_d128_mha,     // kBlockM=256, 8 warps, 32x32 mfma
     decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
+    decode_d128_mha_m32,  // kBlockM=32, 1 warp, 32x32 mfma (tiny-decode policy)
+    decode_d128_mha_m16,  // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
 
     // d=64 GQA-8 (num_queries_per_kv = 8)
     prefill_d64_gqa8,     // kBlockM=256, 8 warps, 32x32 mfma
@@ -117,6 +119,34 @@ struct variant_config<KernelVariant::decode_d128_mha_m128>
     static constexpr bool kUseDecodeGrid = false;
 };
 
+template <>
+struct variant_config<KernelVariant::decode_d128_mha_m32>
+{
+    static constexpr index_t HeadSize  = 128;
+    static constexpr index_t BlockM    = 32;
+    static constexpr index_t NumQPerKV = 1;
+    static constexpr index_t BlockSize = 32;
+    using BlockWarps    = sequence<1, 1, 1>;
+    using WarpGemmShape = sequence<32, 32, 16>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem>;
+    static constexpr bool kUseDecodeGrid = true;
+};
+
+template <>
+struct variant_config<KernelVariant::decode_d128_mha_m16>
+{
+    static constexpr index_t HeadSize  = 128;
+    static constexpr index_t BlockM    = 16;
+    static constexpr index_t NumQPerKV = 1;
+    static constexpr index_t BlockSize = 32;
+    using BlockWarps    = sequence<1, 1, 1>;
+    using WarpGemmShape = sequence<16, 16, 32>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem>;
+    static constexpr bool kUseDecodeGrid = true;
+};
+
 template <>
 struct variant_config<KernelVariant::prefill_d64_gqa8>
 {
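One invariant worth noting in the two new variant_config specializations: both tiny tiers run a single warp (BlockWarps <1, 1, 1>), and the warp-gemm shape is chosen so that the lone warp covers kBlockM in a single pass along M; m32 keeps the 32x32x16 warp gemm of the larger tiers, while m16 drops to 16x16x32, trading M extent for a deeper K per instruction. A self-contained restatement of that check, using local stand-in structs rather than the CK Tile types, and taking the first BlockWarps entry as the M-direction warp count (an assumption about the convention):

    // Local stand-ins for the configs above (not the CK Tile types). WarpsM
    // mirrors the first entry of BlockWarps, MfmaM the first entry of
    // WarpGemmShape.
    struct TierM32 { static constexpr int BlockM = 32, WarpsM = 1, MfmaM = 32; };
    struct TierM16 { static constexpr int BlockM = 16, WarpsM = 1, MfmaM = 16; };

    // MFMA passes along M needed to cover the block tile.
    template <class Cfg>
    constexpr int m_repeats() { return Cfg::BlockM / (Cfg::WarpsM * Cfg::MfmaM); }

    static_assert(m_repeats<TierM32>() == 1); // 1 warp x 32 rows covers kBlockM=32
    static_assert(m_repeats<TierM16>() == 1); // 1 warp x 16 rows covers kBlockM=16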