From 32b68ff8862e87565bf55fc5186390820a9be892 Mon Sep 17 00:00:00 2001 From: ThomasNing Date: Sun, 6 Jul 2025 17:01:11 +0000 Subject: [PATCH] Comment Addressed --- CHANGELOG.md | 1 + example/ck_tile/36_copy/test_copy.hpp | 20 +++++------ .../core/arch/amd_buffer_addressing.hpp | 21 +++++------- .../arch/amd_buffer_addressing_builtins.hpp | 21 +++++------- include/ck_tile/core/tensor/tile_window.hpp | 33 +++++++++++------- .../core/tensor/tile_window_linear.hpp | 34 +++++++++++-------- 6 files changed, 68 insertions(+), 62 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f04935b8d..eb4c0d0a87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166) * Added Vectorize Transpose optimization for CK Tile (#2131) +* Added the asynchronous copy for gfx950 (#2425) ### Fixes diff --git a/example/ck_tile/36_copy/test_copy.hpp b/example/ck_tile/36_copy/test_copy.hpp index 1368872f6d..0b3c87d472 100644 --- a/example/ck_tile/36_copy/test_copy.hpp +++ b/example/ck_tile/36_copy/test_copy.hpp @@ -159,7 +159,16 @@ struct TileCopy if(my_id == warp_id) { - if constexpr(AsyncCopy == false) + if constexpr(AsyncCopy) + { + async_load_tile(x_block_lds_window_no_dist, x_block_window); + + load_tile(dram_tile, x_block_lds_window); + + // store from registers to DRAM + store_tile(y_block_window, dram_tile); + } + else { // load from DRAM to registers load_tile(dram_tile, x_block_window); @@ -170,15 +179,6 @@ struct TileCopy // read from lds to registers load_tile(dram_tile, x_block_lds_window); - // store from registers to DRAM - store_tile(y_block_window, dram_tile); - } - else - { - async_load_tile(x_block_lds_window_no_dist, x_block_window); - - load_tile(dram_tile, x_block_lds_window); - // store from registers to DRAM store_tile(y_block_window, dram_tile); } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 18318c5f56..aafc6c0a85 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -24,6 +24,8 @@ #define LIKELY(x) (__builtin_expect(!!(x), 1)) #endif +using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*; + namespace ck_tile { // 128 bit SGPRs to supply buffer resource in buffer instructions @@ -1271,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, // Direct loads from global to LDS. CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, - __attribute__((address_space(3))) uint32_t* lds_ptr, + as3_uint32_ptr lds_ptr, index_t size, index_t voffset, index_t soffset, @@ -1791,8 +1793,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; llvm_amdgcn_raw_buffer_load_lds( src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(smem)), + reinterpret_cast(reinterpret_cast(smem)), bytes, v_offset, 0, @@ -1803,8 +1804,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, { llvm_amdgcn_raw_buffer_load_lds( src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(smem)), + reinterpret_cast(reinterpret_cast(smem)), bytes, src_thread_addr_offset, 0, @@ -1818,8 +1818,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; llvm_amdgcn_raw_buffer_load_lds( src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(smem)), + reinterpret_cast(reinterpret_cast(smem)), bytes, v_offset, src_wave_addr_offset, @@ -1830,8 +1829,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, { llvm_amdgcn_raw_buffer_load_lds( src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(smem)), + reinterpret_cast(reinterpret_cast(smem)), bytes, src_thread_addr_offset, src_wave_addr_offset, @@ -2812,9 +2810,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, : "memory"); #else // LDS pointer must be attributed with the LDS address space. - __attribute__((address_space(3))) uint32_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(lds_base_ptr + lds_offset)); + as3_uint32_ptr lds_ptr = + reinterpret_cast(reinterpret_cast(lds_base_ptr + lds_offset)); llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 74781f24d1..6ada83aa0e 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -14,6 +14,8 @@ #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" +using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*; + namespace ck_tile { // 128 bit SGPRs to supply buffer resource in buffer instructions @@ -1138,7 +1140,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, // Direct loads from global to LDS. CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, - __attribute__((address_space(3))) uint32_t* lds_ptr, + as3_uint32_ptr lds_ptr, index_t size, index_t voffset, index_t soffset, @@ -1560,8 +1562,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; llvm_amdgcn_raw_buffer_load_lds( src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(smem)), + reinterpret_cast(reinterpret_cast(smem)), bytes, v_offset, 0, @@ -1572,8 +1573,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, { llvm_amdgcn_raw_buffer_load_lds( src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(smem)), + reinterpret_cast(reinterpret_cast(smem)), bytes, src_thread_addr_offset, 0, @@ -1587,8 +1587,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; llvm_amdgcn_raw_buffer_load_lds( src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(smem)), + reinterpret_cast(reinterpret_cast(smem)), bytes, v_offset, src_wave_addr_offset, @@ -1599,8 +1598,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, { llvm_amdgcn_raw_buffer_load_lds( src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(smem)), + reinterpret_cast(reinterpret_cast(smem)), bytes, src_thread_addr_offset, src_wave_addr_offset, @@ -2581,9 +2579,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, : "memory"); #else // LDS pointer must be attributed with the LDS address space. - __attribute__((address_space(3))) uint32_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(lds_base_ptr + lds_offset)); + as3_uint32_ptr lds_ptr = + reinterpret_cast(reinterpret_cast(lds_base_ptr + lds_offset)); llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index a001cdc3c3..ad5902f16e 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -349,34 +349,41 @@ struct tile_window_with_static_distribution using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - // loop over thread tensor space [y0, y1, ...] + // Precompute invariant values outside loops + const auto window_origin = lds_tile.get_window_origin(); + const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view(); + const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor(); + auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_; + static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; - auto lds_bottom_tensor_thread_idx = - lds_tile.get_window_origin() + window_adaptor_thread_coord.get_bottom_index(); - const auto lds_coord = make_tensor_coordinate( - lds_tile.get_bottom_tensor_view().get_tensor_descriptor(), - lds_bottom_tensor_thread_idx); - CK_TILE_LDS_ADDR LdsDataType* smem = - lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + - lds_coord.get_offset(); - // write into bottom tensor + // Use precomputed window origin + auto lds_bottom_tensor_thread_idx = + window_origin + window_adaptor_thread_coord.get_bottom_index(); + + // Use precomputed tensor descriptor + const auto lds_coord = + make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); + + // Calculate SMEM address using base pointer + CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset(); + + // Write into bottom tensor this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, number<0>{}, bool_constant{}); - // move thread coordinate + // Move thread coordinate if not last access if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { - constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 9f012a1e9d..c4b24fba93 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -557,31 +557,34 @@ struct tile_window_linear using LdsDataType = typename LdsTileWindow::DataType; using vector_t = typename traits::vector_t; - // currently we only support everything is non linear dim - // actually it's not performant if we have linear dim(e.g. fast changing) - static_assert(NumAccess_NonLinear == NumAccess); + static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration"); static_assert(Base::BottomTensorView::buffer_view::get_address_space() == - address_space_enum::global); + address_space_enum::global, + "Requires global memory"); + + // Precompute invariant values outside the lambda + const auto window_origin = lds_tile.get_window_origin(); + const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view(); + const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor(); + auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_; - // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { - constexpr auto IAccess = number{}; - constexpr auto non_linear_id = number{}; + constexpr auto IAccess = number{}; + constexpr auto non_linear_id = number{}; + + // Use precomputed values auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id]; auto bottom_tensor_flag = cached_flags_[IAccess]; auto lds_bottom_tensor_thread_idx = - lds_tile.get_window_origin() + window_adaptor_coord.get_bottom_index(); - + window_origin + window_adaptor_coord.get_bottom_index(); const auto lds_coord = - make_tensor_coordinate(lds_tile.get_bottom_tensor_view().get_tensor_descriptor(), - lds_bottom_tensor_thread_idx); - CK_TILE_LDS_ADDR LdsDataType* smem = - lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + - lds_coord.get_offset(); + make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); - // read from bottom tensor + CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset(); + + // Read from bottom tensor this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, @@ -589,6 +592,7 @@ struct tile_window_linear bottom_tensor_flag, bool_constant{}); }; + WINDOW_DISPATCH_ISSUE(); }