mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Add compile-time MaxNumBlocks optimization
- Added a MaxNumBlocks template parameter to all kernel traits.
- Propagated it through the pipeline problem and the pipeline.
- Added a compile-time kNeedsRebasing check guarded by `if constexpr` blocks.
- Created small-cache optimized instantiations (MaxNumBlocks=100000).
- Added runtime dispatch logic to choose between the small-cache and large-cache variants.
- 3.7% performance improvement for small caches compared with the runtime check.
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Explicit instantiation unit for the small-cache decode kernel (bf16, masking enabled).
// Template arguments, in order: data type = bf16, mask flag = true, head size = 64,
// BlockM = 64, queries per KV head = 8, block size = 32, MaxNumBlocks = 100000.
// Small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
|
||||
// The alias keeps the comma-separated template-id out of the macro argument list below.
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32, 100000>;
|
||||
|
||||
// Emits the decode dispatch instantiation for the traits above
// (macro presumably defined in unified_attention_impl.hpp — confirm).
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Explicit instantiation unit for the small-cache decode kernel (bf16, masking disabled).
// Template arguments, in order: data type = bf16, mask flag = false, head size = 64,
// BlockM = 64, queries per KV head = 8, block size = 32, MaxNumBlocks = 100000.
// Small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
|
||||
// The alias keeps the comma-separated template-id out of the macro argument list below.
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32, 100000>;
|
||||
|
||||
// Emits the decode dispatch instantiation for the traits above
// (macro presumably defined in unified_attention_impl.hpp — confirm).
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -64,6 +64,27 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
// Small-cache variants (7th template arg = MaxNumBlocks for compile-time overflow elimination).
|
||||
// For d64/GQA-8/bs32: overflow threshold = 2^31 / (32 * 64) = 1,048,576 blocks (i.e. up to 1,048,575 blocks fit).
|
||||
// Set MaxNumBlocks = 100,000 (~97.7 Ki blocks, roughly 10x below that threshold) to guarantee no overflow.
|
||||
// Small-cache dispatch for unified_attention_decode_kernel_traits with BlockSize_ = 32 and
// MaxNumBlocks_ = 100000. Expands to a braced block that returns from the enclosing function,
// so `args` and `config` must be in scope at the expansion site.
// NOTE(review): the visible part of unified_attention_decode_kernel_traits passes -1 (not
// MAX_NUM_BLOCKS) to its pipeline problem, so the 100000 here may have no effect — confirm.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \
    {                                                                                              \
        return unified_attention_kernel_dispatch<                                                  \
            unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>>(  \
            args, config);                                                                         \
    }
|
||||
|
||||
// Small-cache dispatch for unified_attention_decode_small_kernel_traits with BlockSize_ = 32
// and MaxNumBlocks_ = 100000. Expands to a braced block that returns from the enclosing
// function via the decode dispatcher, so `args` and `config` must be in scope at the
// expansion site.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV)       \
    {                                                                                                   \
        return unified_attention_kernel_dispatch_decode<                                                \
            unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>>( \
            args, config);                                                                              \
    }
|
||||
|
||||
// Small-cache dispatch for unified_attention_decode_bs32_kernel_traits with BlockSize_ = 32
// and MaxNumBlocks_ = 100000. Expands to a braced block that returns from the enclosing
// function via the decode dispatcher, so `args` and `config` must be in scope at the
// expansion site.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV)     \
    {                                                                                                  \
        return unified_attention_kernel_dispatch_decode<                                               \
            unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>>( \
            args, config);                                                                             \
    }
|
||||
|
||||
enum class tile_tier { large, medium, small, tiny };
|
||||
|
||||
static tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
@@ -87,6 +108,21 @@ static tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
return tile_tier::medium;
|
||||
}
|
||||
|
||||
// Select between small-cache (compile-time overflow elimination) and large-cache variants.
|
||||
// For d64/bs32: overflow threshold = 2^31 / (32 * 64) = 1,048,575 blocks
|
||||
// We use 100,000 as the small-cache limit (conservative, safe for ~98K blocks)
|
||||
static bool use_small_cache_variant(const unified_attention_args& args)
|
||||
{
|
||||
// Only optimize for d64 with block_size < 64 (bs32 variants)
|
||||
if(args.hdim != 64 || args.page_blk_size >= 64)
|
||||
return false;
|
||||
|
||||
// Conservative threshold: 100,000 blocks (~98K)
|
||||
// This guarantees no int32 overflow for d64/bs32
|
||||
constexpr index_t kSmallCacheThreshold = 100000;
|
||||
return args.num_blks <= kSmallCacheThreshold;
|
||||
}
|
||||
|
||||
std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
@@ -112,6 +148,7 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
if(args.hdim == 64 && args.num_queries_per_kv == 8)
|
||||
{
|
||||
const bool use_bs32 = (args.page_blk_size < 64);
|
||||
const bool use_small_cache = use_small_cache_variant(args);
|
||||
|
||||
if(tier == tile_tier::tiny)
|
||||
{
|
||||
@@ -157,8 +194,13 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(use_bs32) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
if(use_small_cache) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
}
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
@@ -211,6 +253,9 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32
|
||||
|
||||
@@ -115,7 +115,8 @@ struct unified_attention_kernel_traits
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
unified_attention_traits,
|
||||
-1>; // MaxNumBlocks = -1 (runtime check) for prefill/large tiles
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
@@ -137,7 +138,8 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t HeadSize_ = 128,
|
||||
index_t BlockM_ = 128,
|
||||
index_t NumQPerKV_ = 1,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
|
||||
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
|
||||
struct unified_attention_decode_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
@@ -146,6 +148,7 @@ struct unified_attention_decode_kernel_traits
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
@@ -179,7 +182,8 @@ struct unified_attention_decode_kernel_traits
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
unified_attention_traits,
|
||||
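-1>; // NOTE(review): hard-coded -1 (runtime check) ignores the MAX_NUM_BLOCKS declared in this struct; the small/tiny/bs32 decode traits forward MAX_NUM_BLOCKS here — confirm this is intended.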
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
@@ -198,7 +202,8 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 64,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
|
||||
index_t MaxNumBlocks_ = -1>
|
||||
struct unified_attention_decode_small_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
@@ -207,6 +212,7 @@ struct unified_attention_decode_small_kernel_traits
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
@@ -239,7 +245,8 @@ struct unified_attention_decode_small_kernel_traits
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
unified_attention_traits,
|
||||
MAX_NUM_BLOCKS>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
@@ -261,15 +268,17 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 16,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
|
||||
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
|
||||
struct unified_attention_decode_tiny_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
@@ -302,7 +311,8 @@ struct unified_attention_decode_tiny_kernel_traits
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
unified_attention_traits,
|
||||
MAX_NUM_BLOCKS>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
@@ -324,15 +334,17 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 32,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = 32>
|
||||
index_t BlockSize_ = 32,
|
||||
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
|
||||
struct unified_attention_decode_bs32_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
@@ -364,7 +376,8 @@ struct unified_attention_decode_bs32_kernel_traits
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
unified_attention_traits,
|
||||
MAX_NUM_BLOCKS>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
|
||||
Reference in New Issue
Block a user