mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Add 2-warp decode kernel with kBlockM=64 for minimal tile waste
Introduce UnifiedAttentionPipelineDecodePolicy with NumWarpPerGroup=2, enabling sequence<2,1,1> (2 warps, 1D layout along M). This gives kBlockM=64, kBlockQ=8 for GQA-8, reducing Q tile padding waste from 15/16 (kBlockM=128) to 7/8 for decode workloads. Key insight: instead of fighting with 2D warp layouts (which break the permlane32_swap softmax reduction), use fewer warps with a smaller NumWarpPerGroup. The 1D warp layout is preserved so no reduction changes are needed. Benchmark (64-seq decode, d64 GQA-8): kBlockM=128 (prev): 0.03406ms kBlockM=64 (this): 0.03247ms (~4.7% faster) Total vs baseline: 0.06177ms -> 0.03247ms (1.90x speedup) Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
// Instantiation unit for the small (2-warp) decode kernel variant: bf16 data type, masking enabled.
// Template arguments are <data type, IsMasking, head size, BM, NQPKV> = <bf16, true, 64, 64, 8>;
// the dispatch code refers to BM=64 as kBlockM=64.
// NOTE(review): NQPKV=8 presumably means 8 query heads per KV head (GQA-8) -- confirm.
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8>;
|
||||
|
||||
// NOTE(review): macro is defined elsewhere; presumably it emits the
// unified_attention_kernel_dispatch<kernel_traits> instantiation -- confirm.
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
// Instantiation unit for the small (2-warp) decode kernel variant: bf16 data type, masking disabled.
// Template arguments are <data type, IsMasking, head size, BM, NQPKV> = <bf16, false, 64, 64, 8>;
// the dispatch code refers to BM=64 as kBlockM=64.
// NOTE(review): NQPKV=8 presumably means 8 query heads per KV head (GQA-8) -- confirm.
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8>;
|
||||
|
||||
// NOTE(review): macro is defined elsewhere; presumably it emits the
// unified_attention_kernel_dispatch<kernel_traits> instantiation -- confirm.
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
// Instantiation unit for the small (2-warp) decode kernel variant: fp16 data type, masking enabled.
// Template arguments are <data type, IsMasking, head size, BM, NQPKV> = <fp16, true, 64, 64, 8>;
// the dispatch code refers to BM=64 as kBlockM=64.
// NOTE(review): NQPKV=8 presumably means 8 query heads per KV head (GQA-8) -- confirm.
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8>;
|
||||
|
||||
// NOTE(review): macro is defined elsewhere; presumably it emits the
// unified_attention_kernel_dispatch<kernel_traits> instantiation -- confirm.
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
// Instantiation unit for the small (2-warp) decode kernel variant: fp16 data type, masking disabled.
// Template arguments are <data type, IsMasking, head size, BM, NQPKV> = <fp16, false, 64, 64, 8>;
// the dispatch code refers to BM=64 as kBlockM=64.
// NOTE(review): NQPKV=8 presumably means 8 query heads per KV head (GQA-8) -- confirm.
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8>;
|
||||
|
||||
// NOTE(review): macro is defined elsewhere; presumably it emits the
// unified_attention_kernel_dispatch<kernel_traits> instantiation -- confirm.
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -26,10 +26,10 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
// Helper macro for decode-tuned dispatch (4 warps, kBlockM=128).
|
||||
// Helper macro for decode dispatch (2 warps, kBlockM=64).
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
@@ -65,16 +65,16 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
{
|
||||
if(use_decode)
|
||||
{
|
||||
// Decode-tuned: 4 warps, kBlockM=128 (kBlockQ=16)
|
||||
// Small decode: 2 warps, kBlockM=64 (kBlockQ=8)
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -182,7 +182,8 @@ struct unified_attention_decode_kernel_traits
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// Aggressive decode traits: 4 warps (2x2 layout), kBlockM=64 for maximum decode throughput.
|
||||
// Small decode traits: 2 warps, kBlockM=64, decode policy (NumWarpPerGroup=2).
|
||||
// Uses 1D warp layout (sequence<2,1,1>) so no softmax reduction changes needed.
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 64,
|
||||
@@ -202,8 +203,8 @@ struct unified_attention_decode_small_kernel_traits
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// 2x2 warp layout: 4 warps total, kBlockM=2*32=64, N split=2*32=64
|
||||
using unified_attention_block_warps = sequence<2, 2, 1>;
|
||||
// 2 warps along M: kBlockM=2*32=64, kBlockSize=128, NumWarpGroups=1
|
||||
using unified_attention_block_warps = sequence<2, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
@@ -230,7 +231,9 @@ struct unified_attention_decode_small_kernel_traits
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
UnifiedAttentionPipelineDecodePolicy>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
|
||||
@@ -596,4 +596,11 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
}
|
||||
};
|
||||
|
||||
// Pipeline policy used by the small decode kernel traits. It inherits from
// UnifiedAttentionPipelineDefaultPolicy and declares its own warp-group sizing constants.
// NOTE(review): these presumably hide same-named constants in the default policy (the base
// definition is not visible here); verify the pipeline reads them through the Policy type
// rather than through the base class name.
struct UnifiedAttentionPipelineDecodePolicy : UnifiedAttentionPipelineDefaultPolicy
|
||||
{
|
||||
// Two warps per warp group, matching the 2-warp sequence<2, 1, 1> block layout of the
// small decode traits.
static constexpr ck_tile::index_t NumWarpPerGroup = 2;
|
||||
// Threads in one warp group: warps per group times ck_tile::get_warp_size().
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
|
||||
NumWarpPerGroup * ck_tile::get_warp_size();
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user