[ck_tile] Add get_partition_index_v2 which uses warp_id in vgpr and to be used by tile_windows on lds-based tensor_view

This commit is contained in:
Qianfeng Zhang
2025-08-06 14:40:38 +00:00
parent ae05715998
commit 40261225e8
3 changed files with 48 additions and 4 deletions

View File

@@ -68,6 +68,13 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
// Use these instead
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
CK_TILE_DEVICE index_t get_warp_id_to_vgpr() { return threadIdx.x / get_warp_size(); }
CK_TILE_DEVICE index_t get_warp_id_to_sgpr()
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}
CK_TILE_DEVICE index_t get_warp_id()
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());

View File

@@ -23,6 +23,12 @@ CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
return Distribution::_get_partition_index();
}
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index_v2(Distribution)
{
return Distribution::_get_partition_index_v2();
}
} // namespace detail
// distributed span
@@ -102,7 +108,22 @@ struct tile_distribution
}
else if constexpr(NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
return array<index_t, 2>{get_warp_id_to_sgpr(), get_lane_id()};
}
}
CK_TILE_HOST_DEVICE static auto _get_partition_index_v2()
{
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
if constexpr(NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(NDimP == 2)
{
return array<index_t, 2>{get_warp_id_to_vgpr(), get_lane_id()};
}
}

View File

@@ -179,10 +179,18 @@ struct tile_window_with_static_distribution
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto partition_index = [&]() {
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::lds)
return detail::get_partition_index_v2(tile_dstr_);
else
return detail::get_partition_index(tile_dstr_);
}();
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, NDimY>{0}));
container_concat(partition_index, array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
@@ -911,11 +919,19 @@ struct tile_window_with_static_distribution
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
const auto partition_index = [&]() {
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::lds)
return detail::get_partition_index_v2(tile_dstr_);
else
return detail::get_partition_index(tile_dstr_);
}();
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
container_concat(partition_index, array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =