[ck_tile] Merge get_partition_index() and get_partition_index_v2() to get_partition_index() with bool_constant parameter

This commit is contained in:
Qianfeng Zhang
2025-08-08 06:20:38 +00:00
parent 40261225e8
commit fd25f5df05
3 changed files with 17 additions and 39 deletions

View File

@@ -68,16 +68,13 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
// Use these instead
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
CK_TILE_DEVICE index_t get_warp_id_to_vgpr() { return threadIdx.x / get_warp_size(); }
CK_TILE_DEVICE index_t get_warp_id_to_sgpr()
template <bool save_warp_id_in_sgpr = true>
CK_TILE_DEVICE index_t get_warp_id(bool_constant<save_warp_id_in_sgpr> = {})
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}
CK_TILE_DEVICE index_t get_warp_id()
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
if constexpr(save_warp_id_in_sgpr)
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
else
return threadIdx.x / get_warp_size();
}
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }

View File

@@ -18,16 +18,10 @@
namespace ck_tile {
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
template <typename Distribution, bool save_warp_id_in_sgpr = true>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution, bool_constant<save_warp_id_in_sgpr> = {})
{
return Distribution::_get_partition_index();
}
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index_v2(Distribution)
{
return Distribution::_get_partition_index_v2();
return Distribution::_get_partition_index(bool_constant<save_warp_id_in_sgpr>{});
}
} // namespace detail
@@ -97,7 +91,8 @@ struct tile_distribution
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
CK_TILE_HOST_DEVICE static auto _get_partition_index()
template <bool save_warp_id_in_sgpr = true>
CK_TILE_HOST_DEVICE static auto _get_partition_index(bool_constant<save_warp_id_in_sgpr> = {})
{
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
@@ -108,22 +103,8 @@ struct tile_distribution
}
else if constexpr(NDimP == 2)
{
return array<index_t, 2>{get_warp_id_to_sgpr(), get_lane_id()};
}
}
CK_TILE_HOST_DEVICE static auto _get_partition_index_v2()
{
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
if constexpr(NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(NDimP == 2)
{
return array<index_t, 2>{get_warp_id_to_vgpr(), get_lane_id()};
return array<index_t, 2>{get_warp_id(bool_constant<save_warp_id_in_sgpr>{}),
get_lane_id()};
}
}

View File

@@ -183,9 +183,9 @@ struct tile_window_with_static_distribution
const auto partition_index = [&]() {
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::lds)
return detail::get_partition_index_v2(tile_dstr_);
return detail::get_partition_index(tile_dstr_, bool_constant<false>{});
else
return detail::get_partition_index(tile_dstr_);
return detail::get_partition_index(tile_dstr_, bool_constant<true>{});
}();
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
@@ -922,9 +922,9 @@ struct tile_window_with_static_distribution
const auto partition_index = [&]() {
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::lds)
return detail::get_partition_index_v2(tile_dstr_);
return detail::get_partition_index(tile_dstr_, bool_constant<false>{});
else
return detail::get_partition_index(tile_dstr_);
return detail::get_partition_index(tile_dstr_, bool_constant<true>{});
}();
// TODO: this use less register for FA, but more register for GEMM