From 40f0d01e2920349ed05d1ea0499fb0de5e0f0a2a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 7 Aug 2024 09:29:55 +0000 Subject: [PATCH] Allow transit tile_window to another page-block --- .../block/block_tile_window_navigator.hpp | 43 +++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp index fb73a780ed..19c72f86cc 100644 --- a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp +++ b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp @@ -73,7 +73,7 @@ struct PagedTileWindowNavigator static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window"); using WindowOrigin = multi_index<2>; - CK_TILE_DEVICE constexpr PagedTileWindowNavigator( + CK_TILE_HOST_DEVICE constexpr PagedTileWindowNavigator( copy_const_t* physical_blocks_, long_index_t block_stride_, long_index_t fixed_offset_, @@ -90,7 +90,7 @@ struct PagedTileWindowNavigator } template - CK_TILE_DEVICE auto + CK_TILE_HOST_DEVICE auto make_tile_window(const tile_window_with_static_lengths& tile_window, const WindowOrigin& window_origin) const { @@ -104,7 +104,7 @@ struct PagedTileWindowNavigator } template - CK_TILE_DEVICE auto + CK_TILE_HOST_DEVICE auto make_tile_window(const tile_window_with_static_lengths& tile_window, const WindowOrigin& window_origin, const TileDistribution& tile_distribution) const @@ -120,7 +120,7 @@ struct PagedTileWindowNavigator } template - CK_TILE_DEVICE index_t + CK_TILE_HOST_DEVICE index_t move_tile_window(index_t block_index, TileWindow& tile_window, const typename remove_cvref_t::BottomTensorIndex& step) const @@ -139,18 +139,45 @@ struct PagedTileWindowNavigator return new_block_index; } - CK_TILE_DEVICE + template + CK_TILE_HOST_DEVICE bool is_closs_block(const TileWindow& tile_window) const + { + return page_block_size < (tile_window.get_window_origin().at(number{}) + + tile_window.get_window_lengths().at(number{})); + } + + template + CK_TILE_HOST_DEVICE void + move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const + { + const multi_index<2> step = [&]() { + const index_t origin_diff = (block_index - new_block_index) * page_block_size; + if constexpr(VirtualDim == 0) + { + return make_multi_index(origin_diff, 0); + } + else + { + return make_multi_index(0, origin_diff); + } + }(); + + tile_window.set_window_origin(tile_window.get_window_origin() + step); + tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + } + + CK_TILE_HOST_DEVICE DataType* get_block_ptr(index_t block_index) const { return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset; } - CK_TILE_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const + CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const { return integer_divide_floor(global_window_origin.at(number{}), page_block_size); } - CK_TILE_DEVICE WindowOrigin + CK_TILE_HOST_DEVICE WindowOrigin to_local_window_origin(const WindowOrigin& global_window_origin) const { if constexpr(VirtualDim == 0) @@ -169,7 +196,7 @@ struct PagedTileWindowNavigator } } - CK_TILE_DEVICE WindowOrigin + CK_TILE_HOST_DEVICE WindowOrigin to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const { if constexpr(VirtualDim == 0)