mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
[ck_tile] Add get_partition_index_v2 which uses warp_id in vgpr and to be used by tile_windows on lds-based tensor_view
This commit is contained in:
@@ -68,6 +68,13 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
|
||||
// Use these instead
|
||||
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id_to_vgpr() { return threadIdx.x / get_warp_size(); }
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id_to_sgpr()
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id()
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
|
||||
@@ -23,6 +23,12 @@ CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
|
||||
{
|
||||
return Distribution::_get_partition_index();
|
||||
}
|
||||
|
||||
template <typename Distribution>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index_v2(Distribution)
|
||||
{
|
||||
return Distribution::_get_partition_index_v2();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// distributed span
|
||||
@@ -102,7 +108,22 @@ struct tile_distribution
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
return array<index_t, 2>{get_warp_id(), get_lane_id()};
|
||||
return array<index_t, 2>{get_warp_id_to_sgpr(), get_lane_id()};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto _get_partition_index_v2()
|
||||
{
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
return array<index_t, 1>{get_lane_id()};
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
return array<index_t, 2>{get_warp_id_to_vgpr(), get_lane_id()};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -179,10 +179,18 @@ struct tile_window_with_static_distribution
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
|
||||
const auto partition_index = [&]() {
|
||||
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::lds)
|
||||
return detail::get_partition_index_v2(tile_dstr_);
|
||||
else
|
||||
return detail::get_partition_index(tile_dstr_);
|
||||
}();
|
||||
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_distribution),
|
||||
array<index_t, NDimY>{0}));
|
||||
container_concat(partition_index, array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
@@ -911,11 +919,19 @@ struct tile_window_with_static_distribution
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
const auto partition_index = [&]() {
|
||||
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::lds)
|
||||
return detail::get_partition_index_v2(tile_dstr_);
|
||||
else
|
||||
return detail::get_partition_index(tile_dstr_);
|
||||
}();
|
||||
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
|
||||
container_concat(partition_index, array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
|
||||
Reference in New Issue
Block a user