mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
[ck_tile] Merge get_partition_index() and get_partition_index_v2() into a single get_partition_index() with a bool_constant parameter
This commit is contained in:
@@ -68,16 +68,13 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
|
||||
// Use these instead
|
||||
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id_to_vgpr() { return threadIdx.x / get_warp_size(); }
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id_to_sgpr()
|
||||
template <bool save_warp_id_in_sgpr = true>
|
||||
CK_TILE_DEVICE index_t get_warp_id(bool_constant<save_warp_id_in_sgpr> = {})
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id()
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
if constexpr(save_warp_id_in_sgpr)
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
else
|
||||
return threadIdx.x / get_warp_size();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
@@ -18,16 +18,10 @@
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
template <typename Distribution>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
|
||||
template <typename Distribution, bool save_warp_id_in_sgpr = true>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution, bool_constant<save_warp_id_in_sgpr> = {})
|
||||
{
|
||||
return Distribution::_get_partition_index();
|
||||
}
|
||||
|
||||
template <typename Distribution>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index_v2(Distribution)
|
||||
{
|
||||
return Distribution::_get_partition_index_v2();
|
||||
return Distribution::_get_partition_index(bool_constant<save_warp_id_in_sgpr>{});
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
@@ -97,7 +91,8 @@ struct tile_distribution
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto _get_partition_index()
|
||||
template <bool save_warp_id_in_sgpr = true>
|
||||
CK_TILE_HOST_DEVICE static auto _get_partition_index(bool_constant<save_warp_id_in_sgpr> = {})
|
||||
{
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
@@ -108,22 +103,8 @@ struct tile_distribution
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
return array<index_t, 2>{get_warp_id_to_sgpr(), get_lane_id()};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto _get_partition_index_v2()
|
||||
{
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
return array<index_t, 1>{get_lane_id()};
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
return array<index_t, 2>{get_warp_id_to_vgpr(), get_lane_id()};
|
||||
return array<index_t, 2>{get_warp_id(bool_constant<save_warp_id_in_sgpr>{}),
|
||||
get_lane_id()};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -183,9 +183,9 @@ struct tile_window_with_static_distribution
|
||||
const auto partition_index = [&]() {
|
||||
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::lds)
|
||||
return detail::get_partition_index_v2(tile_dstr_);
|
||||
return detail::get_partition_index(tile_dstr_, bool_constant<false>{});
|
||||
else
|
||||
return detail::get_partition_index(tile_dstr_);
|
||||
return detail::get_partition_index(tile_dstr_, bool_constant<true>{});
|
||||
}();
|
||||
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
@@ -922,9 +922,9 @@ struct tile_window_with_static_distribution
|
||||
const auto partition_index = [&]() {
|
||||
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::lds)
|
||||
return detail::get_partition_index_v2(tile_dstr_);
|
||||
return detail::get_partition_index(tile_dstr_, bool_constant<false>{});
|
||||
else
|
||||
return detail::get_partition_index(tile_dstr_);
|
||||
return detail::get_partition_index(tile_dstr_, bool_constant<true>{});
|
||||
}();
|
||||
|
||||
// TODO: this uses fewer registers for FA, but more registers for GEMM
|
||||
|
||||
Reference in New Issue
Block a user