Preserve raw async tile-load semantics on gfx12

On gfx12 the async global-to-LDS load is unavailable, so the code falls back to a synchronous path: each vector is loaded into VGPRs and then explicitly stored to the LDS location the async instruction would have targeted. Because the async raw API ignores tensor-coordinate validity and relies on raw buffer out-of-bounds (OOB) behavior instead, the synchronous fallback must follow that same raw-load contract.
This commit is contained in:
Aaryaman Vasishta
2026-05-06 05:06:49 +09:00
parent 41cb5058b7
commit a2b161d552
2 changed files with 70 additions and 0 deletions

View File

@@ -501,6 +501,7 @@ struct tile_window_with_static_distribution
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
#if !defined(__gfx12__)
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
@@ -523,13 +524,24 @@ struct tile_window_with_static_distribution
size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
m0_set_with_memory(
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
#endif
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
#if defined(__gfx12__)
// gfx12 does not support the direct global-to-LDS async buffer load. Preserve the
// raw LDS issue/warp/lane layout by loading through VGPRs and explicitly storing
// each vector to the LDS coordinate the async instruction would have targeted.
auto lds_bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& lds_tensor_desc = lds_bottom_tensor_view.get_tensor_descriptor();
const auto lds_lane_id = get_lane_id();
const auto lds_warp_id = get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
#else
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
#endif
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
@@ -546,9 +558,31 @@ struct tile_window_with_static_distribution
return bool_constant<false>{};
}();
#if defined(__gfx12__)
vector_t vec_value;
// async_get_vectorized_elements_raw ignores tensor-coordinate validity; it relies
// on raw buffer OOB behavior instead. Keep the synchronous gfx12 fallback aligned
// with that raw-load contract.
this->get_bottom_tensor_view()
.template get_vectorized_elements_raw<vector_t, false>(
vec_value,
bottom_tensor_thread_coord,
0,
bool_constant<false>{},
pre_nop_);
const typename LdsTileWindow::BottomTensorIndex lds_thread_idx{
static_cast<index_t>(iCoord * NumAccessPerCoord + iCoordAccess),
lds_warp_id,
lds_lane_id};
const auto lds_coord = make_tensor_coordinate(lds_tensor_desc, lds_thread_idx);
lds_bottom_tensor_view.template set_vectorized_elements<vector_t, false>(
lds_coord, 0, true, vec_value);
#else
// read from bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, 0, pre_nop_);
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -562,7 +596,9 @@ struct tile_window_with_static_distribution
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
#if !defined(__gfx12__)
m0_inc_with_memory(size_per_issue);
#endif
}
});
});

View File

@@ -499,6 +499,7 @@ struct tile_window_linear
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
#if !defined(__gfx12__)
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
@@ -519,10 +520,21 @@ struct tile_window_linear
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
#endif
using vector_t = typename Base::Traits::vector_t;
#if defined(__gfx12__)
// gfx12 does not support the direct global-to-LDS async buffer load. Preserve the
// raw LDS issue/warp/lane layout by loading through VGPRs and explicitly storing
// each vector to the LDS coordinate the async instruction would have targeted.
auto lds_bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& lds_tensor_desc = lds_bottom_tensor_view.get_tensor_descriptor();
const auto lds_lane_id = get_lane_id();
const auto lds_warp_id = get_warp_id();
#else
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
#endif
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
@@ -538,14 +550,36 @@ struct tile_window_linear
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway
#if defined(__gfx12__)
vector_t vec_value;
// async_get_vectorized_elements_raw ignores tensor-coordinate validity; it relies
// on raw buffer OOB behavior instead. Keep the synchronous gfx12 fallback aligned
// with that raw-load contract.
this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t, false>(
vec_value,
bottom_tensor_thread_coord,
0,
bottom_tensor_flag,
bool_constant<false>{},
pre_nop_);
const typename LdsTileWindow::BottomTensorIndex lds_thread_idx{
static_cast<index_t>(i_access_), lds_warp_id, lds_lane_id};
const auto lds_coord = make_tensor_coordinate(lds_tensor_desc, lds_thread_idx);
lds_bottom_tensor_view.template set_vectorized_elements<vector_t, false>(
lds_coord, 0, true, vec_value);
#else
// read from bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
#endif
// move thread coordinate
if constexpr(i_access_ != (NumAccess - 1))
{
#if !defined(__gfx12__)
m0_inc_with_memory(size_per_issue);
#endif
}
};