mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
Add compile-time MaxNumBlocks optimization
- Added MaxNumBlocks template parameter to all kernel traits
- Propagated through pipeline problem and pipeline
- Added compile-time kNeedsRebasing check with if constexpr blocks
- Created small-cache optimized instantiations (MaxNumBlocks=100000)
- Added runtime dispatch logic for small vs large cache
- 3.7% performance improvement for small caches vs runtime check
This commit is contained in:
@@ -66,6 +66,7 @@ struct UnifiedAttentionPipeline
|
||||
static constexpr ck_tile::index_t kPageBlockSize = UnifiedAttentionShape::kPageBlockSize;
|
||||
static constexpr ck_tile::index_t kHeadDim = UnifiedAttentionShape::kHeadDim;
|
||||
static constexpr ck_tile::index_t kHeadDimPadded = UnifiedAttentionShape::kHeadDimPadded;
|
||||
static constexpr ck_tile::index_t kMaxNumBlocks = Problem::kMaxNumBlocks;
|
||||
|
||||
static_assert(kHeadDimPadded <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -364,50 +365,85 @@ struct UnifiedAttentionPipeline
|
||||
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
|
||||
|
||||
// Use pointer rebasing to avoid int32 overflow in tensor_coordinate::get_offset()
|
||||
// for large KV pools (>131K blocks for d64/GQA-8).
|
||||
// Only enabled when row strides are provided (from kernel) and for hdim <= 64 configs.
|
||||
const bool use_ptr_rebase = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
|
||||
// Overflow happens when: row_index * stride > INT32_MAX
|
||||
// Example for d64/GQA-8: max_row=4,799,968, stride=512, offset=2,457,583,616 > INT32_MAX
|
||||
//
|
||||
// Calculate overflow threshold using compile-time constants where possible
|
||||
// Assumption: kv_page_size_in_blocks is typically 1 (page_size == kPageBlockSize)
|
||||
// For configurations where this isn't true, we use runtime PageSize
|
||||
//
|
||||
// Compile-time threshold calculation (assuming page_size_in_blocks == 1):
|
||||
// threshold = INT32_MAX / (kPageBlockSize * kHeadDim)
|
||||
// For d64, block_size=32: threshold = 2147483647 / (32 * 64) = 1,048,575 blocks
|
||||
//
|
||||
// Only enabled when:
|
||||
// 1. Row strides provided from kernel (indicates we have stride info) - runtime
|
||||
// 2. Cache size exceeds overflow threshold - compile-time if kMaxNumBlocks != -1
|
||||
// 3. hdim <= 64 - compile-time (hdim=128 has different buffer layout)
|
||||
constexpr long_index_t kOverflowThresholdBlocks =
|
||||
(kHeadDim <= 64) ? (2147483647L / (kPageBlockSize * kHeadDim)) : 2147483647L;
|
||||
|
||||
// Get views and save original base pointers
|
||||
auto k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
auto v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
auto* k_base_ptr = k_view.buf_.p_data_;
|
||||
auto* v_base_ptr = v_view.buf_.p_data_;
|
||||
const auto k_buf_size_orig = k_view.buf_.buffer_size_;
|
||||
const auto v_buf_size_orig = v_view.buf_.buffer_size_;
|
||||
// Compile-time overflow detection when kMaxNumBlocks is specified
|
||||
constexpr bool kNeedsRebasing = (kMaxNumBlocks != -1) && (kHeadDim <= 64) &&
|
||||
(static_cast<long_index_t>(kMaxNumBlocks) > kOverflowThresholdBlocks);
|
||||
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
// Rebase pointers to avoid int32 overflow in window origin coordinates
|
||||
long_index_t k_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
|
||||
k_view.buf_.p_data_ = k_base_ptr + k_off;
|
||||
auto new_k = k_buf_size_orig - k_off;
|
||||
k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
|
||||
const bool need_overflow_check = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
|
||||
const bool use_ptr_rebase = kNeedsRebasing ||
|
||||
(need_overflow_check && (kMaxNumBlocks == -1) &&
|
||||
(static_cast<long_index_t>(num_blocks) > kOverflowThresholdBlocks));
|
||||
|
||||
long_index_t v_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
|
||||
v_view.buf_.p_data_ = v_base_ptr + v_off;
|
||||
auto new_v = v_buf_size_orig - v_off;
|
||||
v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
|
||||
}
|
||||
|
||||
const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_view,
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{init_origin, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
// Fast path: Create windows directly for small caches (no overflow risk)
|
||||
// Slow path: Use rebased pointers for large caches (overflow risk)
|
||||
auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_blk_idx_initial * PageSize, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
k_dram_window.init_raw();
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_view,
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{init_origin, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_blk_idx_initial * PageSize, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
v_dram_window.init_raw();
|
||||
|
||||
// Variables for rebasing (only used if rebasing is possible)
|
||||
// When kMaxNumBlocks != -1 and kNeedsRebasing == false, compiler will eliminate this entirely
|
||||
using KPtrType = remove_cvref_t<decltype(k_dram_window.bottom_tensor_view_.buf_.p_data_)>;
|
||||
using VPtrType = remove_cvref_t<decltype(v_dram_window.bottom_tensor_view_.buf_.p_data_)>;
|
||||
[[maybe_unused]] KPtrType k_base_ptr = nullptr;
|
||||
[[maybe_unused]] VPtrType v_base_ptr = nullptr;
|
||||
[[maybe_unused]] long_index_t k_buf_size_orig = 0;
|
||||
[[maybe_unused]] long_index_t v_buf_size_orig = 0;
|
||||
|
||||
if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1))
|
||||
{
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
// Save original pointers and sizes for lazy rebasing
|
||||
k_base_ptr = k_dram_window.bottom_tensor_view_.buf_.p_data_;
|
||||
v_base_ptr = v_dram_window.bottom_tensor_view_.buf_.p_data_;
|
||||
k_buf_size_orig = k_dram_window.bottom_tensor_view_.buf_.buffer_size_;
|
||||
v_buf_size_orig = v_dram_window.bottom_tensor_view_.buf_.buffer_size_;
|
||||
|
||||
// Initial rebase to first block
|
||||
long_index_t k_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
|
||||
k_dram_window.bottom_tensor_view_.buf_.p_data_ = k_base_ptr + k_off;
|
||||
auto new_k = k_buf_size_orig - k_off;
|
||||
k_dram_window.bottom_tensor_view_.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
|
||||
k_dram_window.init_raw();
|
||||
k_dram_window.set_window_origin({0, 0});
|
||||
|
||||
long_index_t v_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
|
||||
v_dram_window.bottom_tensor_view_.buf_.p_data_ = v_base_ptr + v_off;
|
||||
auto new_v = v_buf_size_orig - v_off;
|
||||
v_dram_window.bottom_tensor_view_.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
|
||||
v_dram_window.init_raw();
|
||||
v_dram_window.set_window_origin({0, 0});
|
||||
}
|
||||
}
|
||||
|
||||
// prefetch K tile
|
||||
constexpr index_t k0_loops = 1;
|
||||
constexpr index_t k1_loops = 1;
|
||||
@@ -497,15 +533,33 @@ struct UnifiedAttentionPipeline
|
||||
constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
// Helper lambda to rebase window pointer (avoids int32 overflow)
|
||||
// This is expensive (calls init_raw), so we minimize calls via lazy rebasing
|
||||
auto rebase_window = [](auto& window, auto* base_ptr, long_index_t elem_offset,
|
||||
auto buf_size_orig) {
|
||||
window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset;
|
||||
auto new_size = buf_size_orig - elem_offset;
|
||||
window.bottom_tensor_view_.buf_.buffer_size_ = new_size > 0 ? new_size : kPageBlockSize * kHeadDim;
|
||||
window.init_raw();
|
||||
window.init_raw(); // Expensive: rebuilds AMD buffer resource descriptor
|
||||
window.set_window_origin({0, 0});
|
||||
};
|
||||
|
||||
// Lazy rebasing: track which block we're currently rebased to
|
||||
// Only call rebase_window (expensive init_raw) when we drift too far from base
|
||||
// Threshold: rebase when offset from base would exceed 1 billion (roughly half of INT32_MAX)
|
||||
// For d64, block_size=32: threshold = 1B / (32 * 64) = ~488,281 blocks
|
||||
// This is compile-time constant, allowing compiler to optimize
|
||||
constexpr long_index_t kRebaseThreshold = 1000000000L / (kPageBlockSize * kHeadDim);
|
||||
[[maybe_unused]] index_t k_base_block = 0;
|
||||
[[maybe_unused]] index_t v_base_block = 0;
|
||||
if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1))
|
||||
{
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
k_base_block = kv_blk_idx_initial;
|
||||
v_base_block = kv_blk_idx_initial;
|
||||
}
|
||||
}
|
||||
|
||||
// Page block index tracking
|
||||
// const index_t kv_page_size_in_blocks =
|
||||
// PageSize / kPageBlockSize;
|
||||
@@ -518,20 +572,42 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
index_t k_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
|
||||
if(use_ptr_rebase)
|
||||
|
||||
if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1))
|
||||
{
|
||||
long_index_t k_row =
|
||||
static_cast<long_index_t>(k_page_blk_idx) * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
k_dram_window.set_window_origin(
|
||||
{k_page_blk_idx * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
// Lazy rebasing: only call expensive rebase_window when drifting too far from base
|
||||
long_index_t offset_from_base = static_cast<long_index_t>(k_page_blk_idx) - k_base_block;
|
||||
if(offset_from_base < 0) offset_from_base = -offset_from_base; // abs value
|
||||
|
||||
if(offset_from_base > kRebaseThreshold)
|
||||
{
|
||||
// Too far from base, rebase to current block (expensive: calls init_raw)
|
||||
k_base_block = k_page_blk_idx;
|
||||
long_index_t k_row =
|
||||
static_cast<long_index_t>(k_page_blk_idx) * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Close to base, just update window origin (cheap: no init_raw)
|
||||
long_index_t k_row =
|
||||
static_cast<long_index_t>(k_page_blk_idx) * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
long_index_t base_row = static_cast<long_index_t>(k_base_block) * PageSize;
|
||||
k_dram_window.set_window_origin({static_cast<index_t>(k_row - base_row), 0});
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path when rebasing not needed (kMaxNumBlocks is small)
|
||||
k_dram_window.set_window_origin(
|
||||
{k_page_blk_idx * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
};
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
@@ -540,20 +616,42 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
index_t v_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
|
||||
if(use_ptr_rebase)
|
||||
|
||||
if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1))
|
||||
{
|
||||
long_index_t v_row =
|
||||
static_cast<long_index_t>(v_page_blk_idx) * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_dram_window.set_window_origin(
|
||||
{v_page_blk_idx * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
// Lazy rebasing: only call expensive rebase_window when drifting too far from base
|
||||
long_index_t offset_from_base = static_cast<long_index_t>(v_page_blk_idx) - v_base_block;
|
||||
if(offset_from_base < 0) offset_from_base = -offset_from_base; // abs value
|
||||
|
||||
if(offset_from_base > kRebaseThreshold)
|
||||
{
|
||||
// Too far from base, rebase to current block (expensive: calls init_raw)
|
||||
v_base_block = v_page_blk_idx;
|
||||
long_index_t v_row =
|
||||
static_cast<long_index_t>(v_page_blk_idx) * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Close to base, just update window origin (cheap: no init_raw)
|
||||
long_index_t v_row =
|
||||
static_cast<long_index_t>(v_page_blk_idx) * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
long_index_t base_row = static_cast<long_index_t>(v_base_block) * PageSize;
|
||||
v_dram_window.set_window_origin({static_cast<index_t>(v_row - base_row), 0});
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path when rebasing not needed (kMaxNumBlocks is small)
|
||||
v_dram_window.set_window_origin(
|
||||
{v_page_blk_idx * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
};
|
||||
|
||||
auto K_lds_load = [&](auto k_lds_read_idx) {
|
||||
|
||||
@@ -19,7 +19,8 @@ template <typename QDataType_,
|
||||
typename ODataType_,
|
||||
typename UnifiedAttentionShape_,
|
||||
typename FmhaMask_,
|
||||
typename Traits_>
|
||||
typename Traits_,
|
||||
index_t MaxNumBlocks_ = -1>
|
||||
struct UnifiedAttentionPipelineProblem
|
||||
{
|
||||
// TODO kM0 and KN1??
|
||||
@@ -41,6 +42,7 @@ struct UnifiedAttentionPipelineProblem
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
|
||||
static constexpr index_t kMaxNumBlocks = MaxNumBlocks_;
|
||||
static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size();
|
||||
|
||||
Reference in New Issue
Block a user