diff --git a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp index 19c72f86cc..b1b67d3ba4 100644 --- a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp +++ b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp @@ -140,10 +140,12 @@ struct PagedTileWindowNavigator } template - CK_TILE_HOST_DEVICE bool is_closs_block(const TileWindow& tile_window) const + CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, + const TileWindow& tile_window) const { - return page_block_size < (tile_window.get_window_origin().at(number{}) + - tile_window.get_window_lengths().at(number{})); + const index_t origin = tile_window.get_window_origin().at(number{}); + const index_t length = tile_window.get_window_lengths().at(number{}); + return (block_index < num_blocks - 1) && (page_block_size < origin + length); } template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index c9d8ef9a8f..a04e4101b0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -151,9 +151,8 @@ struct BlockFmhaFwdAppendKVPipeline if constexpr(kIsPagedKV) { store_tile(k_dram_block_window, knew_tile); - // write tile to another block if nesscary - if(k_tile_navigator.is_closs_block(k_dram_block_window)) + if(k_tile_navigator.is_cross_block(i_block0, k_dram_block_window)) { k_tile_navigator.move_to_block(i_block0, k_dram_block_window, i_block0 + 1); store_tile(k_dram_block_window, knew_tile); @@ -176,9 +175,8 @@ struct BlockFmhaFwdAppendKVPipeline if constexpr(kIsPagedKV) { store_tile(v_dram_block_window, vnew_tile); - // write tile to another block if nesscary - if(v_tile_navigator.is_closs_block(v_dram_block_window)) + if(v_tile_navigator.is_cross_block(i_block1, v_dram_block_window)) { v_tile_navigator.move_to_block(i_block1, v_dram_block_window, i_block1 + 1); store_tile(v_dram_block_window, vnew_tile);