From ae1d09f545ff81375ce6ab3383616bfdd14e8a9a Mon Sep 17 00:00:00 2001 From: Amir Ghamarian Date: Sat, 28 Mar 2026 10:57:10 +0000 Subject: [PATCH] Add 2-warp decode kernel with kBlockM=64 for minimal tile waste Introduce UnifiedAttentionPipelineDecodePolicy with NumWarpPerGroup=2, enabling sequence<2,1,1> (2 warps, 1D layout along M). This gives kBlockM=64, kBlockQ=8 for GQA-8, reducing Q tile padding waste from 15/16 (kBlockM=128) to 7/8 for decode workloads. Key insight: instead of fighting with 2D warp layouts (which break the permlane32_swap softmax reduction), use fewer warps with a smaller NumWarpPerGroup. The 1D warp layout is preserved so no reduction changes are needed. Benchmark (64-seq decode, d64 GQA-8): kBlockM=128 (prev): 0.03406ms kBlockM=64 (this): 0.03247ms (~4.7% faster) Total vs baseline: 0.06177ms -> 0.03247ms (1.90x speedup) Made-with: Cursor --- ...ified_attention_d64_bf16_mask_gqa8_decode_s.cpp | 14 ++++++++++++++ ...fied_attention_d64_bf16_nmask_gqa8_decode_s.cpp | 14 ++++++++++++++ ...ified_attention_d64_fp16_mask_gqa8_decode_s.cpp | 14 ++++++++++++++ ...fied_attention_d64_fp16_nmask_gqa8_decode_s.cpp | 14 ++++++++++++++ .../42_unified_attention/unified_attention.cpp | 14 +++++++------- .../unified_attention_impl.hpp | 11 +++++++---- .../unified_attention_pipeline_default_policy.hpp | 7 +++++++ 7 files changed, 77 insertions(+), 11 deletions(-) create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp new file mode 100644 index 0000000000..f0bb3e3fbe --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_decode_small_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp new file mode 100644 index 0000000000..fea554456e --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_decode_s.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_decode_small_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp new file mode 100644 index 0000000000..abaab70201 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_decode_s.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_decode_small_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp new file mode 100644 index 0000000000..b413773e90 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_decode_s.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_decode_small_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index a3e470b546..85a40d9a5e 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -26,10 +26,10 @@ std::ostream& operator<<(std::ostream& stream, return unified_attention_kernel_dispatch(args, config); \ } -// Helper macro for decode-tuned dispatch (4 warps, kBlockM=128). +// Helper macro for decode dispatch (2 warps, kBlockM=64). #define DISPATCH_UNIFIED_ATTENTION_DECODE(DType, IsMask, HSize, BM, NQPKV) \ { \ - using kernel_traits = unified_attention_decode_kernel_traits; \ + using kernel_traits = unified_attention_decode_small_kernel_traits; \ return unified_attention_kernel_dispatch(args, config); \ } @@ -65,16 +65,16 @@ std::pair unified_attention(const unified_attention_args& args, { if(use_decode) { - // Decode-tuned: 4 warps, kBlockM=128 (kBlockQ=16) + // Small decode: 2 warps, kBlockM=64 (kBlockQ=8) if(args.data_type == unified_attention_args::data_type_enum::fp16) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8) + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8) } else if(args.data_type == unified_attention_args::data_type_enum::bf16) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8) + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8) } } else diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 12881e3d86..2f9cdd7a21 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -182,7 +182,8 @@ struct unified_attention_decode_kernel_traits using kernel = UnifiedAttentionKernel; }; -// Aggressive decode traits: 4 warps (2x2 layout), kBlockM=64 for maximum decode throughput. +// Small decode traits: 2 warps, kBlockM=64, decode policy (NumWarpPerGroup=2). +// Uses 1D warp layout (sequence<2,1,1>) so no softmax reduction changes needed. template ; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; - // 2x2 warp layout: 4 warps total, kBlockM=2*32=64, N split=2*32=64 - using unified_attention_block_warps = sequence<2, 2, 1>; + // 2 warps along M: kBlockM=2*32=64, kBlockSize=128, NumWarpGroups=1 + using unified_attention_block_warps = sequence<2, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape; - using unified_attention_pipeline = UnifiedAttentionPipeline; + using unified_attention_pipeline = + UnifiedAttentionPipeline; using epilogue = Default2DEpilogue< Default2DEpilogueProblem::acc_dtype, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index b0f8b26af6..adc6bb3271 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -596,4 +596,11 @@ struct UnifiedAttentionPipelineDefaultPolicy } }; +struct UnifiedAttentionPipelineDecodePolicy : UnifiedAttentionPipelineDefaultPolicy +{ + static constexpr ck_tile::index_t NumWarpPerGroup = 2; + static constexpr ck_tile::index_t NumThreadPerWarpGroup = + NumWarpPerGroup * ck_tile::get_warp_size(); +}; + } // namespace ck_tile