diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 90a2ba6373..bd6c36603a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -560,7 +560,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor v_host( - USE_PAGED_VCACHE && 0 < page_block_size + 0 < page_block_size ? (is_v_rowmajor ? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_v) : get_lengths(i_perm, max_num_blocks, nhead_k, hdim_v, page_block_size)) @@ -884,9 +884,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(is_v_rowmajor) return i_perm ? hdim_v : nhead_k * hdim_v; else - return USE_PAGED_VCACHE && 0 < page_block_size - ? (i_perm ? page_block_size : nhead_k * page_block_size) - : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size) + : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); @@ -899,13 +898,11 @@ bool run(const ck_tile::ArgParser& arg_parser) : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); const ck_tile::index_t nhead_stride_v = [&]() { if(is_v_rowmajor) - return USE_PAGED_VCACHE && 0 < page_block_size - ? (i_perm ? page_block_size * hdim_v : hdim_v) - : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) + : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); else - return USE_PAGED_VCACHE && 0 < page_block_size - ? (i_perm ? hdim_v * page_block_size : page_block_size) - : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); + return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size) + : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); @@ -920,8 +917,8 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? (nhead_k * page_block_size * hdim_q) : (nhead_k * shape_seqlen_k * hdim_q)); const ck_tile::index_t batch_stride_v = - (USE_PAGED_VCACHE && 0 < page_block_size ? (nhead_k * hdim_v * page_block_size) - : (nhead_k * hdim_v * shape_seqlen_k)); + (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) + : (nhead_k * hdim_v * shape_seqlen_k)); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); @@ -1128,7 +1125,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } #endif - if (USE_PAGED_VCACHE && 0 < page_block_size) { + if (0 < page_block_size) { if (is_v_rowmajor) { if(i_perm) { v_host_ref.ForEach([&](auto& self, auto i) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 61c8832cf9..9982c7a156 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -218,7 +218,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + // make sure the first tile is completely located in page-block + const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] { + if constexpr(kIsPagedKV) + { + return kN0 * integer_divide_floor(seqlen_k_start_, kN0); + } + else + { + return seqlen_k_start_; + } + }(); + const auto num_total_loop = + integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0); // check early exit if masked and no work to do. if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) @@ -242,8 +254,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } } - auto [i_block0, k_dram_block_window] = - k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0}); + auto [i_block0, k_dram_block_window] = k_tile_navigator.make_tile_window( + k_dram_block_window_tmp, {adjusted_seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window( @@ -257,7 +269,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto [i_block1, v_dram_window] = v_tile_navigator.make_tile_window( v_dram_block_window_tmp, - {0, seqlen_k_start}, // TODO: hdim split? + {0, adjusted_seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -384,17 +396,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } move_tile_window(bias_dram_window, {0, kN0}); - /// TODO: only check in last iteration without increasing code size + /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) { const auto k_origin = k_tile_navigator.to_global_window_origin( i_block0, k_dram_block_window.get_window_origin()); set_tile_if(s_acc, -numeric::infinity(), - [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { + [&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end]( + auto tile_idx) { const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return seqlen_k_end_ <= col; + if constexpr(kIsPagedKV) + { + return col < seqlen_k_start_ || seqlen_k_end_ <= col; + } + else + { + return seqlen_k_end_ <= col; + } }); }