diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 453c337143..8202a78e16 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -212,20 +212,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - // make sure the first tile is completely located in page-block - const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] { - if constexpr(kIsPagedKV) - { - return kN0 * integer_divide_floor(seqlen_k_start_, kN0); - } - else - { - return seqlen_k_start_; - } - }(); - const index_t num_total_loop = - integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0); - // check early exit if masked and no work to do. if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) { @@ -250,6 +236,20 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } } + // make sure the first tile is completely located in page-block + const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] { + if constexpr(kIsPagedKV) + { + return kN0 * integer_divide_floor(seqlen_k_start_, kN0); + } + else + { + return seqlen_k_start_; + } + }(); + const index_t num_total_loop = + integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0); + auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( k_dram_block_window_tmp, {adjusted_seqlen_k_start, 0});