mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4999 (commit 45f6624)
[CK] Fix 32-bit overflow in batch prefill kernel for >4GB KV cache (#4999) Use SRD rebasing for page_block_size >= kN0: move SRD base pointer to page start via 48-bit arithmetic, encode only within-page offset in voffset. Original code path preserved for ps1/ps16 via constexpr-if. ## Motivation <!-- Explain the purpose of this PR and the goals it aims to achieve. --> ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
147210ac72
commit
6e558658ea
@@ -484,6 +484,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
// Check that the maximum offset won't overflow.
|
||||
if constexpr(kPageBlockSize < FmhaPipeline::kN0)
|
||||
{
|
||||
if(num_total_pages > 1)
|
||||
{
|
||||
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
|
||||
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
|
||||
"KV cache K offset overflow: exceed int32 max");
|
||||
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
|
||||
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
|
||||
"KV cache V offset overflow: exceed int32 max");
|
||||
}
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -637,6 +651,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
// Check that the maximum offset won't overflow.
|
||||
if constexpr(kPageBlockSize < FmhaPipeline::kN0)
|
||||
{
|
||||
if(num_total_pages > 1)
|
||||
{
|
||||
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
|
||||
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
|
||||
"KV cache K offset overflow: exceed int32 max");
|
||||
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
|
||||
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
|
||||
"KV cache V offset overflow: exceed int32 max");
|
||||
}
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
|
||||
@@ -160,18 +160,27 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
|
||||
{
|
||||
// K cache: per-token lookup
|
||||
// Each token may be on a different page, so we use physical_pages[k0] for each.
|
||||
// Offset = physical_page * stride_page_block + token_idx_in_page * stride_token
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
|
||||
kv_offset_vec[k0] = static_cast<long_index_t>(physical_page) * stride_page_block +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
if constexpr(kPageBlockSize >= kN0)
|
||||
{
|
||||
// SRD rebasing mode: within-page offset only.
|
||||
// The full page base is handled by rebasing the SRD pointer.
|
||||
kv_offset_vec[k0] = token_idx_in_page * stride_token;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Full global offset (original code path for ps1, ps16, etc.)
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
kv_offset_vec[k0] =
|
||||
physical_page * stride_page_block + token_idx_in_page * stride_token;
|
||||
}
|
||||
});
|
||||
}
|
||||
else // !kVTileCrossesPages
|
||||
else // V cache
|
||||
{
|
||||
// V cache: use physical_pages[k0] for each token
|
||||
// physical_pages was already populated correctly by load_physical_pages(), handling:
|
||||
@@ -182,31 +191,43 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(physical_page) * stride_page_block;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
if constexpr(kPageBlockSize >= kN0)
|
||||
{
|
||||
// Vectorized layout offset calculation:
|
||||
// Layout: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
|
||||
// Offset = page_base + (token/kVectorSize) * (head_dim * kVectorSize) +
|
||||
// (token % kVectorSize)
|
||||
const long_index_t token_offset =
|
||||
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
|
||||
(stride_token * kVectorSize)) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
|
||||
kv_offset_vec[k0] = page_base_offset + token_offset;
|
||||
// SRD rebasing mode: within-page offset only.
|
||||
// The full page base is handled by rebasing the SRD pointer.
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
const index_t token_offset =
|
||||
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
kv_offset_vec[k0] = token_offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
kv_offset_vec[k0] = token_idx_in_page * stride_token;
|
||||
}
|
||||
}
|
||||
else // LINEAR_LAYOUT
|
||||
else
|
||||
{
|
||||
// Linear layout: [page, token_in_page, head_dim]
|
||||
// Offset = page_base + token_idx_in_page * stride_token
|
||||
kv_offset_vec[k0] =
|
||||
page_base_offset + static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
// Full global offset (original code path for ps1, ps16, etc.)
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(physical_page) * stride_page_block;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
const index_t token_offset =
|
||||
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
kv_offset_vec[k0] = page_base_offset + token_offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -561,6 +582,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto k_coord = k_dist.calculate_index();
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
|
||||
// kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window)
|
||||
// kPageBlockSize < kN0: global offset, must fit int32
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
index_t current_seq_k = seqlen_k_start;
|
||||
|
||||
@@ -597,6 +620,40 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
k_dist,
|
||||
k_offsets); // K DRAM tile window for
|
||||
k_dram_window.init_raw();
|
||||
|
||||
// SRD rebasing: move the buffer descriptor base pointer to each page's start address
|
||||
// using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
|
||||
// Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
|
||||
auto rebase_k_window = [&](auto& window, index_t physical_page) {
|
||||
if constexpr(kPageBlockSize >= kN0)
|
||||
{
|
||||
// readfirstlane: make physical_page provably wave-uniform so the
|
||||
// resulting SRD lands in SGPRs (required by buffer load instructions).
|
||||
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
|
||||
const auto* base_ptr = k_dram_block_window.get_bottom_tensor_view().buf_.p_data_;
|
||||
const auto* page_ptr =
|
||||
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
|
||||
window.set_bottom_tensor_view_data_ptr(page_ptr);
|
||||
window.init_raw();
|
||||
}
|
||||
};
|
||||
|
||||
auto rebase_v_window = [&](auto& window, index_t physical_page) {
|
||||
if constexpr(kPageBlockSize >= kN0)
|
||||
{
|
||||
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
|
||||
const auto* base_ptr =
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
|
||||
const auto* page_ptr =
|
||||
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
|
||||
window.set_bottom_tensor_view_data_ptr(page_ptr);
|
||||
window.init_raw();
|
||||
}
|
||||
};
|
||||
|
||||
// Initial K SRD rebase
|
||||
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
|
||||
|
||||
constexpr auto k_oob_ck = bool_constant<true>{};
|
||||
constexpr auto k_pre_np = [&]() {
|
||||
if constexpr(kPadSeqLenK &&
|
||||
@@ -727,6 +784,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static_assert(decltype(VPageIndexYDims)::at(0) < VDstrEncode::NDimY,
|
||||
"V page-index Y dim must be valid");
|
||||
|
||||
// kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_v_window)
|
||||
// kPageBlockSize < kN0: global offset, must fit int32
|
||||
statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
|
||||
// V physical pages array for use with kv_offset_array_transform
|
||||
// For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner
|
||||
@@ -843,6 +902,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
number<1>{}, // NumCoord
|
||||
VPageIndexYDims);
|
||||
|
||||
// Initial V SRD rebase
|
||||
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
|
||||
|
||||
// prefetch K tile
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
|
||||
@@ -946,6 +1008,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// V physical pages already prefetched before GEMM0
|
||||
update_v_offsets(number<kK1>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
|
||||
|
||||
// KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
@@ -1124,6 +1187,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
update_v_offsets(number<2 * kK1>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -1283,6 +1347,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// Update V offsets using previously prefetched physical pages
|
||||
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
|
||||
}
|
||||
|
||||
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
|
||||
@@ -1361,6 +1426,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kVectorSize>(
|
||||
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
Reference in New Issue
Block a user