diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 53934ebcd3..c6628f66be 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -484,20 +484,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.init_logits_soft_cap(logits_soft_cap); } - // Check that the maximum offset won't overflow. - if constexpr(kPageBlockSize < FmhaPipeline::kN0) - { - if(num_total_pages > 1) - { - assert(static_cast(num_total_pages - 1) * batch_stride_k <= - static_cast(std::numeric_limits::max()) && - "KV cache K offset overflow: exceed int32 max"); - assert(static_cast(num_total_pages - 1) * batch_stride_v <= - static_cast(std::numeric_limits::max()) && - "KV cache V offset overflow: exceed int32 max"); - } - } - return kargs; } @@ -651,20 +637,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.init_logits_soft_cap(logits_soft_cap); } - // Check that the maximum offset won't overflow. - if constexpr(kPageBlockSize < FmhaPipeline::kN0) - { - if(num_total_pages > 1) - { - assert(static_cast(num_total_pages - 1) * batch_stride_k <= - static_cast(std::numeric_limits::max()) && - "KV cache K offset overflow: exceed int32 max"); - assert(static_cast(num_total_pages - 1) * batch_stride_v <= - static_cast(std::numeric_limits::max()) && - "KV cache V offset overflow: exceed int32 max"); - } - } - return kargs; }