From 1031eebe979f12af8e7cc73ff239e460f8cd2f81 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 11 Jul 2025 11:06:48 +0000 Subject: [PATCH] Merge commit '45904b8fd7cde71dfc3741970325b3d552b06d27' into develop --- .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 6 ++- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 54 ++++++++++++------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index e56d518634..d8cd006c60 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -1122,7 +1122,8 @@ struct FmhaFwdPagedKVKernel const index_t num_blocks = integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); - const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_k; + const long_index_t fixed_offset = + static_cast(i_nhead_) * kargs.nhead_stride_k; return make_page_block_navigator( kargs.k_ptr, @@ -1152,7 +1153,8 @@ struct FmhaFwdPagedKVKernel const index_t num_blocks = integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); - const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_v; + const long_index_t fixed_offset = + static_cast(i_nhead_) * kargs.nhead_stride_v; return make_page_block_navigator( kargs.v_ptr, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 6ad5844b69..9d267e1cee 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -441,28 +441,46 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } } move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { const auto k_origin = k_page_block_navigator.to_global_window_origin( i_page_block_k, k_dram_block_window.get_window_origin()); - // mask accept only logical coordinates, do conversion here - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}) - kv_l2p_offset, - number{}, - number{}); - if(need_perpixel_check) + + if constexpr(kIsPagedKV) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col - kv_l2p_offset, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }); + // check columns in [aligned_physical_seqlen_k_start, physical_seqlen_k_end) + if(kv_l2p_offset > 0) + { + set_tile_if( + s_acc, + -numeric::infinity(), + [&, physical_seqlen_k_start_ = physical_seqlen_k_start](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return col < physical_seqlen_k_start_; + }); + }; + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + // mask accept only logical coordinates, do conversion here + bool need_perpixel_check = + mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}) - kv_l2p_offset, + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col - kv_l2p_offset); + }); + } } }