Add 2-warp decode kernel with kBlockM=64 for minimal tile waste

Introduce UnifiedAttentionPipelineDecodePolicy with NumWarpPerGroup=2,
enabling sequence<2,1,1> (2 warps, 1D layout along M). This gives
kBlockM=64, kBlockQ=8 for GQA-8, reducing Q tile padding waste from
15/16 (kBlockM=128) to 7/8 for decode workloads.

Key insight: instead of fighting with 2D warp layouts (which break the
permlane32_swap softmax reduction), use fewer warps with a smaller
NumWarpPerGroup. The 1D warp layout is preserved so no reduction changes
are needed.

Benchmark (64-seq decode, d64 GQA-8):
  kBlockM=128 (prev): 0.03406ms
  kBlockM=64  (this): 0.03247ms (~4.7% faster)
  Total vs baseline:  0.06177ms -> 0.03247ms (1.90x speedup)

Made-with: Cursor
This commit is contained in:
Amir Ghamarian
2026-03-28 10:57:10 +00:00
parent 8d396d29f0
commit ae1d09f545
7 changed files with 77 additions and 11 deletions

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -26,10 +26,10 @@ std::ostream& operator<<(std::ostream& stream,
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
// Helper macro for decode-tuned dispatch (4 warps, kBlockM=128).
// Helper macro for decode dispatch (2 warps, kBlockM=64).
#define DISPATCH_UNIFIED_ATTENTION_DECODE(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
@@ -65,16 +65,16 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
{
if(use_decode)
{
// Decode-tuned: 4 warps, kBlockM=128 (kBlockQ=16)
// Small decode: 2 warps, kBlockM=64 (kBlockQ=8)
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
}
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
}
}
else

View File

@@ -182,7 +182,8 @@ struct unified_attention_decode_kernel_traits
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
};
// Aggressive decode traits: 4 warps (2x2 layout), kBlockM=64 for maximum decode throughput.
// Small decode traits: 2 warps, kBlockM=64, decode policy (NumWarpPerGroup=2).
// Uses 1D warp layout (sequence<2,1,1>) so no softmax reduction changes needed.
template <unified_attention_args::data_type_enum DataType,
bool IsMasking,
index_t HeadSize_ = 64,
@@ -202,8 +203,8 @@ struct unified_attention_decode_small_kernel_traits
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
// 2x2 warp layout: 4 warps total, kBlockM=2*32=64, N split=2*32=64
using unified_attention_block_warps = sequence<2, 2, 1>;
// 2 warps along M: kBlockM=2*32=64, kBlockSize=128, NumWarpGroups=1
using unified_attention_block_warps = sequence<2, 1, 1>;
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
unified_attention_block_warps,
@@ -230,7 +231,9 @@ struct unified_attention_decode_small_kernel_traits
unified_attention_mask,
unified_attention_traits>;
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
using unified_attention_pipeline =
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
UnifiedAttentionPipelineDecodePolicy>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,

View File

@@ -596,4 +596,11 @@ struct UnifiedAttentionPipelineDefaultPolicy
}
};
struct UnifiedAttentionPipelineDecodePolicy : UnifiedAttentionPipelineDefaultPolicy
{
static constexpr ck_tile::index_t NumWarpPerGroup = 2;
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
NumWarpPerGroup * ck_tile::get_warp_size();
};
} // namespace ck_tile