[rocm-libraries] ROCm/rocm-libraries#4999 (commit 45f6624)

[CK] Fix 32-bit overflow in batch prefill kernel for >4GB KV
 cache (#4999)

Use SRD rebasing for page_block_size >= kN0: move SRD base pointer to
page start via 48-bit arithmetic, encode only within-page offset in
voffset. Original code path preserved for ps1/ps16 via constexpr-if.

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Jeff Huang
2026-03-05 01:09:12 +00:00
committed by assistant-librarian[bot]
parent 147210ac72
commit 6e558658ea
2 changed files with 120 additions and 26 deletions

View File

@@ -484,6 +484,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.init_logits_soft_cap(logits_soft_cap);
}
// Check that the maximum offset won't overflow.
if constexpr(kPageBlockSize < FmhaPipeline::kN0)
{
if(num_total_pages > 1)
{
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
"KV cache K offset overflow: exceed int32 max");
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
"KV cache V offset overflow: exceed int32 max");
}
}
return kargs;
}
@@ -637,6 +651,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.init_logits_soft_cap(logits_soft_cap);
}
// Check that the maximum offset won't overflow.
if constexpr(kPageBlockSize < FmhaPipeline::kN0)
{
if(num_total_pages > 1)
{
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
"KV cache K offset overflow: exceed int32 max");
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
"KV cache V offset overflow: exceed int32 max");
}
}
return kargs;
}