diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index be17805e1a..4f1bf17fcb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -325,6 +325,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS k_dram_block_window, Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + index_t seqlen_k_curr_offset = aligned_physical_seqlen_k_start; + using k_tile_type = decltype(load_tile(k_dram_window)); statically_indexed_array k_tiles; @@ -450,8 +452,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS } else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) { - const auto k_origin = k_page_block_navigator.to_global_window_origin( - i_page_block_k, k_dram_block_window.get_window_origin()); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); s_acc = tile_elementwise_in(s_acc_element_func, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { @@ -460,7 +460,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + const auto col = seqlen_k_curr_offset + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); s_acc(i_j_idx) *= scale_s; @@ -506,33 +506,29 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) { - const auto k_origin = k_page_block_navigator.to_global_window_origin( - i_page_block_k, k_dram_block_window.get_window_origin()); - set_tile_if( - s_acc, - -numeric::infinity(), - [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, - physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - if constexpr(kIsPagedKV) - { - return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col; - } - else - { - return physical_seqlen_k_end_ <= col; - } - }); + set_tile_if(s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = seqlen_k_curr_offset + tile_idx.at(number<1>{}); + if constexpr(kIsPagedKV) + { + return col < physical_seqlen_k_start_ || + physical_seqlen_k_end_ <= col; + } + else + { + return physical_seqlen_k_end_ <= col; + } + }); } if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = k_page_block_navigator.to_global_window_origin( - i_page_block_k, k_dram_block_window.get_window_origin()); // mask accept only logical coordinates, do conversion here bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}) - kv_l2p_offset, + seqlen_k_curr_offset - kv_l2p_offset, number{}, number{}); if(need_perpixel_check) @@ -540,7 +536,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS set_tile_if( s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + const auto col = seqlen_k_curr_offset + tile_idx.at(number<1>{}); return mask.IsOutOfBound(row, col - kv_l2p_offset); }); } @@ -562,8 +558,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS i_page_block_k, k_dram_block_window, {kN0, 0}); k_dram_window = make_tile_window( - k_dram_block_window, - Policy::template MakeKDramTileDistribution()); // K DRAM tile window + k_dram_block_window, Policy::template MakeKDramTileDistribution()); + + seqlen_k_curr_offset += kN0; } __builtin_amdgcn_sched_barrier(0);