From 5c052119599faa7f123742ce39387700d528a539 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 7 Jul 2025 18:07:20 +0000 Subject: [PATCH] Merge commit 'f240ae32487219b4dd9d3152b816f87166e20feb' into develop --- CHANGELOG.md | 1 + example/ck_tile/03_gemm/gemm_utils.hpp | 1 - example/ck_tile/36_copy/test_copy.cpp | 11 +-- example/ck_tile/36_copy/test_copy.hpp | 37 ++++++--- .../core/arch/amd_buffer_addressing.hpp | 76 ++++++++++++++----- .../arch/amd_buffer_addressing_builtins.hpp | 75 +++++++++++++----- include/ck_tile/core/tensor/buffer_view.hpp | 4 +- include/ck_tile/core/tensor/load_tile.hpp | 13 ++++ include/ck_tile/core/tensor/tensor_view.hpp | 6 +- include/ck_tile/core/tensor/tile_window.hpp | 62 ++++++--------- .../core/tensor/tile_window_linear.hpp | 73 +++++++----------- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 9 +++ 12 files changed, 225 insertions(+), 143 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 86a426e321..17f9455feb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166) * Added Vectorize Transpose optimization for CK Tile (#2131) +* Added the asynchronous copy for gfx950 (#2425) ### Fixes diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 5f767d56aa..2157397f1d 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -15,7 +15,6 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -// temporary workaround to get k_warp_tile based on PrecType and gfx950 or not template constexpr ck_tile::index_t get_k_warp_tile() { diff --git a/example/ck_tile/36_copy/test_copy.cpp b/example/ck_tile/36_copy/test_copy.cpp index 81ea5255fc..4123408453 100644 --- a/example/ck_tile/36_copy/test_copy.cpp +++ b/example/ck_tile/36_copy/test_copy.cpp @@ -53,16 +53,17 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); - using BlockWaves = ck_tile::sequence<2, 1>; - using BlockTile = ck_tile::sequence<64, 8>; - using WaveTile = ck_tile::sequence<64, 8>; - using Vector = ck_tile::sequence<1, 4>; + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = ck_tile::sequence<64, 8>; + using WaveTile = ck_tile::sequence<64, 8>; + using Vector = ck_tile::sequence<1, 2>; + constexpr bool AsyncCopy = true; ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); std::cout << "grid size " << kGridSize << std::endl; using Shape = ck_tile::TileCopyShape; - using Problem = ck_tile::TileCopyProblem; + using Problem = ck_tile::TileCopyProblem; using Kernel = ck_tile::TileCopy; constexpr ck_tile::index_t kBlockSize = 128; diff --git a/example/ck_tile/36_copy/test_copy.hpp b/example/ck_tile/36_copy/test_copy.hpp index 8fed22a3d0..0b3c87d472 100644 --- a/example/ck_tile/36_copy/test_copy.hpp +++ b/example/ck_tile/36_copy/test_copy.hpp @@ -50,11 +50,12 @@ struct TileCopyShape static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!"); }; -template +template struct TileCopyProblem { - using XDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; + using XDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + static constexpr bool AsyncCopy = AsyncCopy_; }; template @@ -63,6 +64,8 @@ struct TileCopy using Problem = ck_tile::remove_cvref_t; using XDataType = typename Problem::XDataType; + static constexpr bool AsyncCopy = Problem::AsyncCopy; + template CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() { @@ -156,17 +159,29 @@ struct TileCopy if(my_id == warp_id) { - // load from DRAM to registers - load_tile(dram_tile, x_block_window); + if constexpr(AsyncCopy) + { + async_load_tile(x_block_lds_window_no_dist, x_block_window); - // store in lds - store_tile(x_block_lds_window_no_dist, dram_tile); + load_tile(dram_tile, x_block_lds_window); - // read from lds to registers - load_tile(dram_tile, x_block_lds_window); + // store from registers to DRAM + store_tile(y_block_window, dram_tile); + } + else + { + // load from DRAM to registers + load_tile(dram_tile, x_block_window); - // store from registers to DRAM - store_tile(y_block_window, dram_tile); + // store in lds + store_tile(x_block_lds_window_no_dist, dram_tile); + + // read from lds to registers + load_tile(dram_tile, x_block_lds_window); + + // store from registers to DRAM + store_tile(y_block_window, dram_tile); + } } __syncthreads(); move_tile_window(x_block_window, {0, S::Block_N}); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 12f49aa4e3..aafc6c0a85 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/ignore.hpp" // This attribute gives a hint to the compiler that a branch is likely to be taken. // Then, the compiler should remove if possible the associated s_cbranch_execz branch that would @@ -23,6 +24,8 @@ #define LIKELY(x) (__builtin_expect(!!(x), 1)) #endif +using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*; + namespace ck_tile { // 128 bit SGPRs to supply buffer resource in buffer instructions @@ -1270,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, // Direct loads from global to LDS. CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, - __attribute__((address_space(3))) uint32_t* lds_ptr, + as3_uint32_ptr lds_ptr, index_t size, index_t voffset, index_t soffset, @@ -1749,7 +1752,7 @@ template -CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, +CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, @@ -1779,29 +1782,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, index_t flag = 0, bool_constant = {}) { - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - + constexpr index_t bytes = sizeof(T) * N; +#if defined(__gfx950__) + static_assert(bytes == 4 || bytes == 12 || bytes == 16, + "wrong! only support in dword, dwordx3, dwordx4"); + ignore = src_wave_addr_offset; + ignore = src_immediate_addr_offset; if constexpr(oob_conditional_check) { index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - smem, - sizeof(uint32_t), - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + 0, + 0, + static_cast(coherence)); } else { - llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - smem, - sizeof(uint32_t), - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + src_thread_addr_offset, + 0, + 0, + static_cast(coherence)); } +#else + static_assert(bytes == 4, "wrong! not implemented vector size"); + if constexpr(oob_conditional_check) + { + index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + static_cast(coherence)); + } + else + { + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + static_cast(coherence)); + } +#endif } template ( - reinterpret_cast(lds_base_ptr + lds_offset)); + as3_uint32_ptr lds_ptr = + reinterpret_cast(reinterpret_cast(lds_base_ptr + lds_offset)); llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 306d2cdac3..6ada83aa0e 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -14,6 +14,8 @@ #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" +using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*; + namespace ck_tile { // 128 bit SGPRs to supply buffer resource in buffer instructions @@ -1138,7 +1140,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, // Direct loads from global to LDS. CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, - __attribute__((address_space(3))) uint32_t* lds_ptr, + as3_uint32_ptr lds_ptr, index_t size, index_t voffset, index_t soffset, @@ -1549,29 +1551,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, index_t flag = 0, bool_constant = {}) { - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - + constexpr index_t bytes = sizeof(T) * N; +#if defined(__gfx950__) + static_assert(bytes == 4 || bytes == 12 || bytes == 16, + "wrong! only support in dword, dwordx3, dwordx4"); + ignore = src_wave_addr_offset; + ignore = src_immediate_addr_offset; if constexpr(oob_conditional_check) { - index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - smem, - sizeof(uint32_t), - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); + index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + 0, + 0, + static_cast(coherence)); } else { - llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - smem, - sizeof(uint32_t), - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + src_thread_addr_offset, + 0, + 0, + static_cast(coherence)); } +#else + static_assert(bytes == 4, "wrong! not implemented vector size"); + if constexpr(oob_conditional_check) + { + index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + static_cast(coherence)); + } + else + { + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + static_cast(coherence)); + } +#endif } template ( - reinterpret_cast(lds_base_ptr + lds_offset)); + as3_uint32_ptr lds_ptr = + reinterpret_cast(reinterpret_cast(lds_base_ptr + lds_offset)); llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 231a2c832b..5cae332007 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -452,10 +452,12 @@ struct buffer_view, t_per_x, Coherence>( smem, - cached_buf_res_, + src_wave_buffer_resource, i, linear_offset, is_valid_element, diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 4601261197..8b7541bf23 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -89,6 +89,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, tile, number{}, bool_constant{}, bool_constant{}); } +template +CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + number = {}, + bool_constant = {}) +{ + return tile_window.async_load( + lds_tile, number{}, bool_constant{}); +} + template * smem, const TensorCoord& coord, - index_t linear_offset) const + index_t linear_offset, + bool_constant = {}) const { return buf_.template async_get( smem, @@ -181,7 +182,8 @@ struct tensor_view async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t* smem, const TensorCoord& coord, index_t linear_offset, - bool is_valid_element) const + bool is_valid_element, + bool_constant = {}) const { return buf_.template async_get(smem, coord.get_offset() / PackedSize, diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 6027668c8e..ad5902f16e 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -344,64 +344,52 @@ struct tile_window_with_static_distribution { using LdsTileWindow = remove_cvref_t; using LdsDataType = typename LdsTileWindow::DataType; - - // issues * warps * lanes - static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded - - // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out - // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to - // check?) - constexpr index_t size_per_buf = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<0>{}, number<0>{}, number<0>{})); - - constexpr index_t size_per_wave = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<0>{}, number<1>{}, number<0>{})) - - size_per_buf; - - constexpr index_t size_per_issue = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<1>{}, number<0>{}, number<0>{})) - - size_per_buf; - - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - - using Traits = typename Base::Traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - // TODO: we force CK_TILE_LDS_ADDR - CK_TILE_LDS_ADDR LdsDataType* smem = - lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; + // Precompute invariant values outside loops + const auto window_origin = lds_tile.get_window_origin(); + const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view(); + const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor(); + auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_; - // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { - /// TODO: use structure binding (to be captured later) if compiled in C++20 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; - // read from bottom tensor - this->get_bottom_tensor_view().template async_get_vectorized_elements( - smem, bottom_tensor_thread_coord, 0, bool_constant{}); + // Use precomputed window origin + auto lds_bottom_tensor_thread_idx = + window_origin + window_adaptor_thread_coord.get_bottom_index(); - // move thread coordinate + // Use precomputed tensor descriptor + const auto lds_coord = + make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); + + // Calculate SMEM address using base pointer + CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset(); + + // Write into bottom tensor + this->get_bottom_tensor_view().template async_get_vectorized_elements( + smem, + bottom_tensor_thread_coord, + number<0>{}, + bool_constant{}); + + // Move thread coordinate if not last access if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { - constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - - smem += size_per_issue; // Note we manually increase the per-issue offset } }); }); diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 596584f3cc..c4b24fba93 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -186,7 +186,7 @@ struct tile_window_linear const typename Base::WindowLengths& window_lengths, const typename Base::BottomTensorIndex& window_origin, const typename Base::TileDstr& tile_distribution) - : cached_coords_{}, cached_flags_{} + : cached_coords_{}, cached_window_adaptor_coords_{}, cached_flags_{} { this->bottom_tensor_view_ = bottom_tensor_view; this->window_lengths_ = window_lengths; @@ -214,7 +214,8 @@ struct tile_window_linear if constexpr(need_save_non_linear_coord) { - cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp; + cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp; + cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp; } // TODO: need pad_tensor_view to check which dim need use flag to check @@ -554,61 +555,42 @@ struct tile_window_linear { using LdsTileWindow = remove_cvref_t; using LdsDataType = typename LdsTileWindow::DataType; + using vector_t = typename traits::vector_t; - // currently we only support everything is non linear dim - // actually it's not performant if we have linear dim(e.g. fast changing) - static_assert(NumAccess_NonLinear == NumAccess); + static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration"); static_assert(Base::BottomTensorView::buffer_view::get_address_space() == - address_space_enum::global); + address_space_enum::global, + "Requires global memory"); - // issues * warps * lanes - static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + // Precompute invariant values outside the lambda + const auto window_origin = lds_tile.get_window_origin(); + const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view(); + const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor(); + auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_; - // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out - // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to - // check?) - constexpr index_t size_per_buf = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<0>{}, number<0>{}, number<0>{})); - - constexpr index_t size_per_wave = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<0>{}, number<1>{}, number<0>{})) - - size_per_buf; - - constexpr index_t size_per_issue = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<1>{}, number<0>{}, number<0>{})) - - size_per_buf; - - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - - using vector_t = typename Base::Traits::vector_t; - - // TODO: we force CK_TILE_LDS_ADDR - CK_TILE_LDS_ADDR LdsDataType* smem = - lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; - - // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { - constexpr auto IAccess = number{}; - constexpr auto non_linear_id = number{}; + constexpr auto IAccess = number{}; + constexpr auto non_linear_id = number{}; + + // Use precomputed values auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id]; auto bottom_tensor_flag = cached_flags_[IAccess]; - // read from bottom tensor + auto lds_bottom_tensor_thread_idx = + window_origin + window_adaptor_coord.get_bottom_index(); + const auto lds_coord = + make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); + + CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset(); + + // Read from bottom tensor this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, bool_constant{}); - - // move thread coordinate - if constexpr(i_access_ != (NumAccess - 1)) - { - smem += size_per_issue; // Note we manually increase the per-issue offset - } }; WINDOW_DISPATCH_ISSUE(); @@ -928,7 +910,8 @@ struct tile_window_linear if constexpr(need_save_non_linear_coord) { - cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp; + cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp; + cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp; } if constexpr(i_access != (NumAccess - 1)) @@ -948,6 +931,8 @@ struct tile_window_linear // this contains: array cached_coords_; + array + cached_window_adaptor_coords_; array cached_flags_; }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 07bfb33252..6861adb153 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -32,6 +32,15 @@ struct GemmPipelineAgBgCrImplBase move_tile_window(dram_tile_window, dram_tile_window_step); } + template + CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window, + SrcTileWindow& dram_tile_window, + const DramTileWindowStep& dram_tile_window_step) const + { + async_load_tile(dst_block_window, dram_tile_window); + move_tile_window(dram_tile_window, dram_tile_window_step); + } + template CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, const SrcBlockTile& src_block_tile,