[rocm-libraries] ROCm/rocm-libraries#4999 (commit 45f6624)

[CK] Fix 32-bit overflow in batch prefill kernel for >4GB KV
 cache (#4999)

Use SRD rebasing for page_block_size >= kN0: move SRD base pointer to
page start via 48-bit arithmetic, encode only within-page offset in
voffset. Original code path preserved for ps1/ps16 via constexpr-if.

## Motivation

With a paged KV cache larger than 4 GB, the 32-bit (`int32`) voffset
arithmetic used by the batch prefill kernel to address K/V pages can
overflow, producing incorrect global-memory offsets. This PR removes that
limitation for configurations where `page_block_size >= kN0`.

## Technical Details

For `page_block_size >= kN0`, the kernel now rebases the buffer resource
descriptor (SRD) base pointer to the start of the current physical page
using 48-bit pointer arithmetic (`rebase_k_window` / `rebase_v_window`,
with `__builtin_amdgcn_readfirstlane` to keep the page index wave-uniform
so the SRD stays in SGPRs). The per-thread voffset then only encodes the
small within-page offset, which always fits in 32 bits. The original
full-global-offset code path is preserved behind a `constexpr`-if for
smaller page sizes (ps1/ps16), where debug `assert`s now verify that the
maximum K/V offset still fits in `int32`.

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Jeff Huang
2026-03-05 01:09:12 +00:00
committed by assistant-librarian[bot]
parent 147210ac72
commit 6e558658ea
2 changed files with 120 additions and 26 deletions

View File

@@ -484,6 +484,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.init_logits_soft_cap(logits_soft_cap);
}
// Check that the maximum offset won't overflow.
if constexpr(kPageBlockSize < FmhaPipeline::kN0)
{
if(num_total_pages > 1)
{
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
"KV cache K offset overflow: exceed int32 max");
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
"KV cache V offset overflow: exceed int32 max");
}
}
return kargs;
}
@@ -637,6 +651,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.init_logits_soft_cap(logits_soft_cap);
}
// Check that the maximum offset won't overflow.
if constexpr(kPageBlockSize < FmhaPipeline::kN0)
{
if(num_total_pages > 1)
{
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
"KV cache K offset overflow: exceed int32 max");
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
"KV cache V offset overflow: exceed int32 max");
}
}
return kargs;
}

View File

@@ -160,18 +160,27 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
{
// K cache: per-token lookup
// Each token may be on a different page, so we use physical_pages[k0] for each.
// Offset = physical_page * stride_page_block + token_idx_in_page * stride_token
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
const index_t physical_page = physical_pages[k0];
kv_offset_vec[k0] = static_cast<long_index_t>(physical_page) * stride_page_block +
static_cast<long_index_t>(token_idx_in_page) * stride_token;
if constexpr(kPageBlockSize >= kN0)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
kv_offset_vec[k0] = token_idx_in_page * stride_token;
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
kv_offset_vec[k0] =
physical_page * stride_page_block + token_idx_in_page * stride_token;
}
});
}
else // !kVTileCrossesPages
else // V cache
{
// V cache: use physical_pages[k0] for each token
// physical_pages was already populated correctly by load_physical_pages(), handling:
@@ -182,31 +191,43 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
const index_t physical_page = physical_pages[k0];
const long_index_t page_base_offset =
static_cast<long_index_t>(physical_page) * stride_page_block;
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
if constexpr(kPageBlockSize >= kN0)
{
// Vectorized layout offset calculation:
// Layout: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
// Offset = page_base + (token/kVectorSize) * (head_dim * kVectorSize) +
// (token % kVectorSize)
const long_index_t token_offset =
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
(stride_token * kVectorSize)) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = page_base_offset + token_offset;
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = token_offset;
}
else
{
kv_offset_vec[k0] = token_idx_in_page * stride_token;
}
}
else // LINEAR_LAYOUT
else
{
// Linear layout: [page, token_in_page, head_dim]
// Offset = page_base + token_idx_in_page * stride_token
kv_offset_vec[k0] =
page_base_offset + static_cast<long_index_t>(token_idx_in_page) * stride_token;
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
const long_index_t page_base_offset =
static_cast<long_index_t>(physical_page) * stride_page_block;
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = page_base_offset + token_offset;
}
else
{
kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
}
}
});
}
@@ -561,6 +582,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto k_coord = k_dist.calculate_index();
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
// kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window)
// kPageBlockSize < kN0: global offset, must fit int32
statically_indexed_array<index_t, NRepeat> k_offsets;
index_t current_seq_k = seqlen_k_start;
@@ -597,6 +620,40 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
k_dist,
k_offsets); // K DRAM tile window for
k_dram_window.init_raw();
// SRD rebasing: move the buffer descriptor base pointer to each page's start address
// using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
// Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
auto rebase_k_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
// readfirstlane: make physical_page provably wave-uniform so the
// resulting SRD lands in SGPRs (required by buffer load instructions).
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
const auto* base_ptr = k_dram_block_window.get_bottom_tensor_view().buf_.p_data_;
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
window.set_bottom_tensor_view_data_ptr(page_ptr);
window.init_raw();
}
};
auto rebase_v_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
const auto* base_ptr =
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
window.set_bottom_tensor_view_data_ptr(page_ptr);
window.init_raw();
}
};
// Initial K SRD rebase
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = [&]() {
if constexpr(kPadSeqLenK &&
@@ -727,6 +784,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static_assert(decltype(VPageIndexYDims)::at(0) < VDstrEncode::NDimY,
"V page-index Y dim must be valid");
// kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_v_window)
// kPageBlockSize < kN0: global offset, must fit int32
statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
// V physical pages array for use with kv_offset_array_transform
// For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner
@@ -843,6 +902,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
number<1>{}, // NumCoord
VPageIndexYDims);
// Initial V SRD rebase
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
@@ -946,6 +1008,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// V physical pages already prefetched before GEMM0
update_v_offsets(number<kK1>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
// KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result)
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
@@ -1124,6 +1187,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
update_v_offsets(number<2 * kK1>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
}
__builtin_amdgcn_sched_barrier(0);
@@ -1283,6 +1347,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Update V offsets using previously prefetched physical pages
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
}
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
@@ -1361,6 +1426,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kVectorSize>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();