From fd25f5df05dafc52ffa8bfb472d6dd1e7bb05485 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 8 Aug 2025 06:20:38 +0000 Subject: [PATCH] [ck_tile] Merge get_partition_index() and get_partition_index_v2() to get_partition_index() with bool_constant parameter --- include/ck_tile/core/arch/arch.hpp | 15 ++++----- .../ck_tile/core/tensor/tile_distribution.hpp | 33 ++++--------------- include/ck_tile/core/tensor/tile_window.hpp | 8 ++--- 3 files changed, 17 insertions(+), 39 deletions(-) diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 751ad8f4ab..687de0d86b 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -68,16 +68,13 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; } // Use these instead CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); } -CK_TILE_DEVICE index_t get_warp_id_to_vgpr() { return threadIdx.x / get_warp_size(); } - -CK_TILE_DEVICE index_t get_warp_id_to_sgpr() +template +CK_TILE_DEVICE index_t get_warp_id(bool_constant = {}) { - return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); -} - -CK_TILE_DEVICE index_t get_warp_id() -{ - return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); + if constexpr(save_warp_id_in_sgpr) + return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); + else + return threadIdx.x / get_warp_size(); } CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 9fb30c3609..499897e3da 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -18,16 +18,10 @@ namespace ck_tile { namespace detail { -template -CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) +template +CK_TILE_HOST_DEVICE auto get_partition_index(Distribution, bool_constant = {}) { - return Distribution::_get_partition_index(); -} - -template -CK_TILE_HOST_DEVICE auto get_partition_index_v2(Distribution) -{ - return Distribution::_get_partition_index_v2(); + return Distribution::_get_partition_index(bool_constant{}); } } // namespace detail @@ -97,7 +91,8 @@ struct tile_distribution CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; } - CK_TILE_HOST_DEVICE static auto _get_partition_index() + template + CK_TILE_HOST_DEVICE static auto _get_partition_index(bool_constant = {}) { // only support warp-tile and block-tile static_assert(NDimP == 1 or NDimP == 2, "wrong!"); @@ -108,22 +103,8 @@ struct tile_distribution } else if constexpr(NDimP == 2) { - return array{get_warp_id_to_sgpr(), get_lane_id()}; - } - } - - CK_TILE_HOST_DEVICE static auto _get_partition_index_v2() - { - // only support warp-tile and block-tile - static_assert(NDimP == 1 or NDimP == 2, "wrong!"); - - if constexpr(NDimP == 1) - { - return array{get_lane_id()}; - } - else if constexpr(NDimP == 2) - { - return array{get_warp_id_to_vgpr(), get_lane_id()}; + return array{get_warp_id(bool_constant{}), + get_lane_id()}; } } diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index befb02a65a..945d310518 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -183,9 +183,9 @@ struct tile_window_with_static_distribution const auto partition_index = [&]() { if constexpr(BottomTensorView::buffer_view::get_address_space() == address_space_enum::lds) - return detail::get_partition_index_v2(tile_dstr_); + return detail::get_partition_index(tile_dstr_, bool_constant{}); else - return detail::get_partition_index(tile_dstr_); + return detail::get_partition_index(tile_dstr_, bool_constant{}); }(); const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( @@ -922,9 +922,9 @@ struct tile_window_with_static_distribution const auto partition_index = [&]() { if constexpr(BottomTensorView::buffer_view::get_address_space() == address_space_enum::lds) - return detail::get_partition_index_v2(tile_dstr_); + return detail::get_partition_index(tile_dstr_, bool_constant{}); else - return detail::get_partition_index(tile_dstr_); + return detail::get_partition_index(tile_dstr_, bool_constant{}); }(); // TODO: this use less register for FA, but more register for GEMM