From 40261225e80da554113d912b0732408b74ab3cec Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 6 Aug 2025 14:40:38 +0000 Subject: [PATCH] [ck_tile] Add get_partition_index_v2 which uses warp_id in vgpr and to be used by tile_windows on lds-based tensor_view --- include/ck_tile/core/arch/arch.hpp | 7 ++++++ .../ck_tile/core/tensor/tile_distribution.hpp | 23 ++++++++++++++++++- include/ck_tile/core/tensor/tile_window.hpp | 22 +++++++++++++++--- 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 09de5f325f..751ad8f4ab 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -68,6 +68,13 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; } // Use these instead CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); } +CK_TILE_DEVICE index_t get_warp_id_to_vgpr() { return threadIdx.x / get_warp_size(); } + +CK_TILE_DEVICE index_t get_warp_id_to_sgpr() +{ + return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); +} + CK_TILE_DEVICE index_t get_warp_id() { return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 7761be492d..9fb30c3609 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -23,6 +23,12 @@ CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) { return Distribution::_get_partition_index(); } + +template +CK_TILE_HOST_DEVICE auto get_partition_index_v2(Distribution) +{ + return Distribution::_get_partition_index_v2(); +} } // namespace detail // distributed span @@ -102,7 +108,22 @@ struct tile_distribution } else if constexpr(NDimP == 2) { - return array{get_warp_id(), get_lane_id()}; + return array{get_warp_id_to_sgpr(), get_lane_id()}; + } + } + + CK_TILE_HOST_DEVICE static auto _get_partition_index_v2() + { + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + if constexpr(NDimP == 1) + { + return array{get_lane_id()}; + } + else if constexpr(NDimP == 2) + { + return array{get_warp_id_to_vgpr(), get_lane_id()}; } } diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 3bb728df23..befb02a65a 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -179,10 +179,18 @@ struct tile_window_with_static_distribution #else // TODO: this use less register for FA, but more register for GEMM // need investigation + + const auto partition_index = [&]() { + if constexpr(BottomTensorView::buffer_view::get_address_space() == + address_space_enum::lds) + return detail::get_partition_index_v2(tile_dstr_); + else + return detail::get_partition_index(tile_dstr_); + }(); + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat(partition_index, array{0})); #endif BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = @@ -911,11 +919,19 @@ struct tile_window_with_static_distribution AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); } #else + const auto partition_index = [&]() { + if constexpr(BottomTensorView::buffer_view::get_address_space() == + address_space_enum::lds) + return detail::get_partition_index_v2(tile_dstr_); + else + return detail::get_partition_index(tile_dstr_); + }(); + // TODO: this use less register for FA, but more register for GEMM // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_dstr_), array{0})); + container_concat(partition_index, array{0})); #endif BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =