From a2b161d55234b39b6cf3e971e6cfe6637ce3c1fc Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Wed, 6 May 2026 05:06:49 +0900 Subject: [PATCH] Preserve raw async tile-load semantics on gfx12 gfx12 falls back from async global-to-LDS loads to sync VGPR loads plus LDS stores. The async raw API relies on buffer OOB behavior instead of tensor-coordinate validity, so keep the sync fallback aligned with that raw-load contract. --- include/ck_tile/core/tensor/tile_window.hpp | 36 +++++++++++++++++++ .../core/tensor/tile_window_linear.hpp | 34 ++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 3e28544509..33d21737cb 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -501,6 +501,7 @@ struct tile_window_with_static_distribution // issues * warps * lanes static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded +#if !defined(__gfx12__) const index_t size_per_buf = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<0>{}, number<0>{})) * @@ -523,13 +524,24 @@ struct tile_window_with_static_distribution size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant{}); m0_set_with_memory( amd_wave_read_first_lane(m0_init_value)); // This should be wave independent +#endif using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; +#if defined(__gfx12__) + // gfx12 does not support the direct global-to-LDS async buffer load. Preserve the + // raw LDS issue/warp/lane layout by loading through VGPRs and explicitly storing + // each vector to the LDS coordinate the async instruction would have targeted. + auto lds_bottom_tensor_view = lds_tile.get_bottom_tensor_view(); + const auto& lds_tensor_desc = lds_bottom_tensor_view.get_tensor_descriptor(); + const auto lds_lane_id = get_lane_id(); + const auto lds_warp_id = get_warp_id(/*ReturnSgpr=*/bool_constant{}); +#else LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; +#endif // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -546,9 +558,31 @@ struct tile_window_with_static_distribution return bool_constant{}; }(); +#if defined(__gfx12__) + vector_t vec_value; + // async_get_vectorized_elements_raw ignores tensor-coordinate validity; it relies + // on raw buffer OOB behavior instead. Keep the synchronous gfx12 fallback aligned + // with that raw-load contract. + this->get_bottom_tensor_view() + .template get_vectorized_elements_raw( + vec_value, + bottom_tensor_thread_coord, + 0, + bool_constant{}, + pre_nop_); + + const typename LdsTileWindow::BottomTensorIndex lds_thread_idx{ + static_cast(iCoord * NumAccessPerCoord + iCoordAccess), + lds_warp_id, + lds_lane_id}; + const auto lds_coord = make_tensor_coordinate(lds_tensor_desc, lds_thread_idx); + lds_bottom_tensor_view.template set_vectorized_elements( + lds_coord, 0, true, vec_value); +#else // read from bottom tensor this->get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, 0, pre_nop_); +#endif // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -562,7 +596,9 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); +#if !defined(__gfx12__) m0_inc_with_memory(size_per_issue); +#endif } }); }); diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 118223d9f9..905be46966 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -499,6 +499,7 @@ struct tile_window_linear // issues * warps * lanes static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded +#if !defined(__gfx12__) const index_t size_per_buf = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<0>{}, number<0>{})) * @@ -519,10 +520,21 @@ struct tile_window_linear const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); m0_set_with_memory( amd_wave_read_first_lane(m0_init_value)); // This should be wave independent +#endif using vector_t = typename Base::Traits::vector_t; +#if defined(__gfx12__) + // gfx12 does not support the direct global-to-LDS async buffer load. Preserve the + // raw LDS issue/warp/lane layout by loading through VGPRs and explicitly storing + // each vector to the LDS coordinate the async instruction would have targeted. + auto lds_bottom_tensor_view = lds_tile.get_bottom_tensor_view(); + const auto& lds_tensor_desc = lds_bottom_tensor_view.get_tensor_descriptor(); + const auto lds_lane_id = get_lane_id(); + const auto lds_warp_id = get_warp_id(); +#else LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; +#endif // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { @@ -538,14 +550,36 @@ struct tile_window_linear auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway +#if defined(__gfx12__) + vector_t vec_value; + // async_get_vectorized_elements_raw ignores tensor-coordinate validity; it relies + // on raw buffer OOB behavior instead. Keep the synchronous gfx12 fallback aligned + // with that raw-load contract. + this->get_bottom_tensor_view().template get_vectorized_elements_raw( + vec_value, + bottom_tensor_thread_coord, + 0, + bottom_tensor_flag, + bool_constant{}, + pre_nop_); + + const typename LdsTileWindow::BottomTensorIndex lds_thread_idx{ + static_cast(i_access_), lds_warp_id, lds_lane_id}; + const auto lds_coord = make_tensor_coordinate(lds_tensor_desc, lds_thread_idx); + lds_bottom_tensor_view.template set_vectorized_elements( + lds_coord, 0, true, vec_value); +#else // read from bottom tensor this->get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_); +#endif // move thread coordinate if constexpr(i_access_ != (NumAccess - 1)) { +#if !defined(__gfx12__) m0_inc_with_memory(size_per_issue); +#endif } };