mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Preserve raw async tile-load semantics on gfx12
gfx12 falls back from async global-to-LDS loads to sync VGPR loads plus LDS stores. The async raw API relies on buffer OOB behavior instead of tensor-coordinate validity, so keep the sync fallback aligned with that raw-load contract.
This commit is contained in:
@@ -501,6 +501,7 @@ struct tile_window_with_static_distribution
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
#if !defined(__gfx12__)
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
@@ -523,13 +524,24 @@ struct tile_window_with_static_distribution
|
||||
size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
|
||||
m0_set_with_memory(
|
||||
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
|
||||
#endif
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
#if defined(__gfx12__)
|
||||
// gfx12 does not support the direct global-to-LDS async buffer load. Preserve the
|
||||
// raw LDS issue/warp/lane layout by loading through VGPRs and explicitly storing
|
||||
// each vector to the LDS coordinate the async instruction would have targeted.
|
||||
auto lds_bottom_tensor_view = lds_tile.get_bottom_tensor_view();
|
||||
const auto& lds_tensor_desc = lds_bottom_tensor_view.get_tensor_descriptor();
|
||||
const auto lds_lane_id = get_lane_id();
|
||||
const auto lds_warp_id = get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
|
||||
#else
|
||||
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
#endif
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -546,9 +558,31 @@ struct tile_window_with_static_distribution
|
||||
return bool_constant<false>{};
|
||||
}();
|
||||
|
||||
#if defined(__gfx12__)
|
||||
vector_t vec_value;
|
||||
// async_get_vectorized_elements_raw ignores tensor-coordinate validity; it relies
|
||||
// on raw buffer OOB behavior instead. Keep the synchronous gfx12 fallback aligned
|
||||
// with that raw-load contract.
|
||||
this->get_bottom_tensor_view()
|
||||
.template get_vectorized_elements_raw<vector_t, false>(
|
||||
vec_value,
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
bool_constant<false>{},
|
||||
pre_nop_);
|
||||
|
||||
const typename LdsTileWindow::BottomTensorIndex lds_thread_idx{
|
||||
static_cast<index_t>(iCoord * NumAccessPerCoord + iCoordAccess),
|
||||
lds_warp_id,
|
||||
lds_lane_id};
|
||||
const auto lds_coord = make_tensor_coordinate(lds_tensor_desc, lds_thread_idx);
|
||||
lds_bottom_tensor_view.template set_vectorized_elements<vector_t, false>(
|
||||
lds_coord, 0, true, vec_value);
|
||||
#else
|
||||
// read from bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, pre_nop_);
|
||||
#endif
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
@@ -562,7 +596,9 @@ struct tile_window_with_static_distribution
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
#if !defined(__gfx12__)
|
||||
m0_inc_with_memory(size_per_issue);
|
||||
#endif
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -499,6 +499,7 @@ struct tile_window_linear
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
#if !defined(__gfx12__)
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
@@ -519,10 +520,21 @@ struct tile_window_linear
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(
|
||||
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
|
||||
#endif
|
||||
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
#if defined(__gfx12__)
|
||||
// gfx12 does not support the direct global-to-LDS async buffer load. Preserve the
|
||||
// raw LDS issue/warp/lane layout by loading through VGPRs and explicitly storing
|
||||
// each vector to the LDS coordinate the async instruction would have targeted.
|
||||
auto lds_bottom_tensor_view = lds_tile.get_bottom_tensor_view();
|
||||
const auto& lds_tensor_desc = lds_bottom_tensor_view.get_tensor_descriptor();
|
||||
const auto lds_lane_id = get_lane_id();
|
||||
const auto lds_warp_id = get_warp_id();
|
||||
#else
|
||||
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
#endif
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
@@ -538,14 +550,36 @@ struct tile_window_linear
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway
|
||||
|
||||
#if defined(__gfx12__)
|
||||
vector_t vec_value;
|
||||
// async_get_vectorized_elements_raw ignores tensor-coordinate validity; it relies
|
||||
// on raw buffer OOB behavior instead. Keep the synchronous gfx12 fallback aligned
|
||||
// with that raw-load contract.
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t, false>(
|
||||
vec_value,
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<false>{},
|
||||
pre_nop_);
|
||||
|
||||
const typename LdsTileWindow::BottomTensorIndex lds_thread_idx{
|
||||
static_cast<index_t>(i_access_), lds_warp_id, lds_lane_id};
|
||||
const auto lds_coord = make_tensor_coordinate(lds_tensor_desc, lds_thread_idx);
|
||||
lds_bottom_tensor_view.template set_vectorized_elements<vector_t, false>(
|
||||
lds_coord, 0, true, vec_value);
|
||||
#else
|
||||
// read from bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
|
||||
#endif
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(i_access_ != (NumAccess - 1))
|
||||
{
|
||||
#if !defined(__gfx12__)
|
||||
m0_inc_with_memory(size_per_issue);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user