mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[rocm-libraries] ROCm/rocm-libraries#4999 (commit 45f6624)
[CK] Fix 32-bit overflow in batch prefill kernel for >4GB KV cache (#4999)

Use SRD rebasing for page_block_size >= kN0: move the SRD base pointer to the page start via 48-bit arithmetic, and encode only the within-page offset in voffset. The original code path is preserved for ps1/ps16 via constexpr-if.

## Motivation
<!-- Explain the purpose of this PR and the goals it aims to achieve. -->

## Technical Details
<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan
<!-- Explain any relevant testing done to verify this PR. -->

## Test Result
<!-- Briefly summarize test outcomes. -->

## Submission Checklist
- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
147210ac72
commit
6e558658ea
@@ -484,6 +484,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
// Check that the largest K/V page offset, (num_total_pages - 1) * batch_stride, fits in index_t (int32).
|
||||
if constexpr(kPageBlockSize < FmhaPipeline::kN0)
|
||||
{
|
||||
if(num_total_pages > 1)
|
||||
{
|
||||
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
|
||||
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
|
||||
"KV cache K offset overflow: exceed int32 max");
|
||||
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
|
||||
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
|
||||
"KV cache V offset overflow: exceed int32 max");
|
||||
}
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -637,6 +651,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
// Check that the largest K/V page offset, (num_total_pages - 1) * batch_stride, fits in index_t (int32).
|
||||
if constexpr(kPageBlockSize < FmhaPipeline::kN0)
|
||||
{
|
||||
if(num_total_pages > 1)
|
||||
{
|
||||
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
|
||||
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
|
||||
"KV cache K offset overflow: exceed int32 max");
|
||||
assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
|
||||
static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
|
||||
"KV cache V offset overflow: exceed int32 max");
|
||||
}
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user