diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp
new file mode 100644
index 0000000000..8659f68a7d
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_decode.cpp
@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+using kernel_traits =
+    unified_attention_decode_kernel_traits<bf16_t, true, 128, 128, 1>;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp
new file mode 100644
index 0000000000..2505832331
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_decode.cpp
@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+using kernel_traits =
+    unified_attention_decode_kernel_traits<bf16_t, false, 128, 128, 1>;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp
new file mode 100644
index 0000000000..8e1fb0d1f8
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_decode.cpp
@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+using kernel_traits =
+    unified_attention_decode_kernel_traits<fp16_t, true, 128, 128, 1>;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp
new file mode 100644
index 0000000000..a9d6b17211
--- /dev/null
+++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_decode.cpp
@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+using kernel_traits =
+    unified_attention_decode_kernel_traits<fp16_t, false, 128, 128, 1>;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile
diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp
index 9de8a48459..2b8cf4b3c7 100644
--- a/example/ck_tile/42_unified_attention/unified_attention.cpp
+++ b/example/ck_tile/42_unified_attention/unified_attention.cpp
@@ -49,6 +49,7 @@ std::ostream& operator<<(std::ostream& stream,
 enum class KernelVariant {
     // d=128 MHA (num_queries_per_kv = 1)
     prefill_d128_mha,     // kBlockM=256, 8 warps, 32x32 mfma
+    decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
 
     // d=64 GQA-8 (num_queries_per_kv = 8)
     prefill_d64_gqa8,     // kBlockM=256, 8 warps, 32x32 mfma
@@ -95,12 +96,21 @@ KernelConfig select_config(const unified_attention_args& args)
 {
     KernelConfig cfg;
 
-    // d=128 MHA — only the 8-warp prefill kernel exists today. A dedicated
-    // d=128 decode variant is the next commit; until then all d=128 traffic
-    // takes the prefill kernel (Q-tile waste for short Q, but correct).
+    // d=128 MHA — two variants today:
+    //   * decode_d128_mha_m128 : kBlockM=128, 4 warps. Used when avg/max Q
+    //     both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for
+    //     short-Q workloads.
+    //   * prefill_d128_mha     : kBlockM=256, 8 warps. Everything else.
     if (args.hdim == 128 && args.num_queries_per_kv == 1) {
-        cfg.variant = KernelVariant::prefill_d128_mha;
+        const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
+                                                : args.num_tokens;
+        const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
+
+        if (avg_q <= 128 && max_q <= 128)
+            cfg.variant = KernelVariant::decode_d128_mha_m128;
+        else
+            cfg.variant = KernelVariant::prefill_d128_mha;
         return cfg;
     }
@@ -207,6 +217,20 @@ std::pair<bool, float> dispatch_prefill_d128_mha(
     const unified_attention_args& args, const stream_config& config)
 {
     return {false, -1.f};
 }
 
+std::pair<bool, float> dispatch_decode_d128_mha_m128(
+    const unified_attention_args& args, const stream_config& config)
+{
+    const bool is_mask = (args.mask_type != static_cast<index_t>(mask_enum::no_mask));
+    if (args.data_type == DType::fp16) {
+        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 128, 128, 1)
+        else          DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 128, 128, 1)
+    } else if (args.data_type == DType::bf16) {
+        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 128, 128, 1)
+        else          DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 128, 128, 1)
+    }
+    return {false, -1.f};
+}
+
 std::pair<bool, float> dispatch_prefill_d64_gqa8(
     const unified_attention_args& args, const stream_config& config)
 {
@@ -324,6 +348,7 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
     switch (cfg.variant) {
     case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
+    case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
    case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
     case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
     case KernelVariant::decode_d64_gqa8_m128_p32: return dispatch_decode_d64_gqa8_m128_p32(args, config);
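Reviewer note: below is a minimal standalone sketch of the new routing rule in select_config(), for poking at the thresholds without building the example. Workload, pick_d128_mha_variant, and main are hypothetical stand-ins for unified_attention_args and the real dispatcher, not code from this PR.

// Standalone sketch of the d=128 MHA variant-selection heuristic added in
// select_config() above. The struct and values are hypothetical stand-ins,
// kept just small enough to compile and run on their own.
#include <cstdint>
#include <cstdio>

using index_t = std::int64_t;

struct Workload {           // hypothetical subset of unified_attention_args
    index_t num_tokens;     // total Q tokens across the batch
    index_t num_seqs;       // number of sequences in the batch
    index_t max_seqlen_q;   // longest per-sequence Q length (0 = unknown)
};

// Mirrors the new rule: route to the kBlockM=128 decode kernel only when
// both the average and the maximum Q length fit in a single 128-row tile.
const char* pick_d128_mha_variant(const Workload& w)
{
    const index_t avg_q = w.num_seqs > 0 ? w.num_tokens / w.num_seqs
                                         : w.num_tokens;
    const index_t max_q = w.max_seqlen_q > 0 ? w.max_seqlen_q : avg_q;
    return (avg_q <= 128 && max_q <= 128) ? "decode_d128_mha_m128"
                                          : "prefill_d128_mha";
}

int main()
{
    // Pure decode: 64 seqs, 1 token each -> avg_q = max_q = 1 -> decode kernel.
    std::printf("%s\n", pick_d128_mha_variant({64, 64, 1}));
    // Mixed batch with one long prompt: max_q = 512 overrides the small
    // average, so the 8-warp prefill kernel is chosen.
    std::printf("%s\n", pick_d128_mha_variant({640, 64, 512}));
}

Checking max_seqlen_q in addition to the batch average means a single long prompt in an otherwise decode-heavy batch still routes to the 8-warp prefill kernel, consistent with the in-diff comment that the m128 variant only pays off when every Q fits in one 128-row tile.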