From c2ab44bf2fef3619e7db8021061586a9df2a0aba Mon Sep 17 00:00:00 2001 From: ThomasNing Date: Sun, 29 Jun 2025 01:02:47 -0500 Subject: [PATCH] Finished the feature --- .../core/arch/amd_buffer_addressing.hpp | 8 +-- include/ck_tile/core/tensor/load_tile.hpp | 13 +++++ include/ck_tile/core/tensor/tensor_view.hpp | 6 +- include/ck_tile/core/tensor/tile_window.hpp | 40 ++++++++++++- .../core/tensor/tile_window_linear.hpp | 57 +++++-------------- 5 files changed, 73 insertions(+), 51 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 25d22e922e..a30503379f 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1797,14 +1797,14 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, reinterpret_cast(smem)), sizeof(uint32_t), v_offset, - 0, - 0, + src_wave_addr_offset, + src_immediate_addr_offset, static_cast(coherence)); } else { llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( ), + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(reinterpret_cast(smem)), bytes, src_thread_addr_offset, src_wave_addr_offset, @@ -2798,7 +2798,7 @@ template __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) { - static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), + static_assert(__has_builtin(__builtin_amdgcn_ds_read_tr16_b64_v4f16), "We need to have the compatible compiler version to build this instruction"); if constexpr(std::is_same_v, ck_tile::half_t>) { diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 4601261197..a919fafea4 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -89,6 +89,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, tile, number{}, bool_constant{}, bool_constant{}); } +template +CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + number = {}, + bool_constant = {}) +{ + return tile_window.async_load( + lds_tile, number{}, bool_constant{}); +} + template * smem, const TensorCoord& coord, - index_t linear_offset) const + index_t linear_offset, + bool_constant = {}) const { return buf_.template async_get( smem, @@ -181,7 +182,8 @@ struct tensor_view async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t* smem, const TensorCoord& coord, index_t linear_offset, - bool is_valid_element) const + bool is_valid_element, + bool_constant = {}) const { return buf_.template async_get(smem, coord.get_offset() / PackedSize, diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 6027668c8e..d9ea18cc78 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -345,8 +345,44 @@ struct tile_window_with_static_distribution using LdsTileWindow = remove_cvref_t; using LdsDataType = typename LdsTileWindow::DataType; - // issues * warps * lanes - static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + using Traits = typename Base::Traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + auto lds_bottom_tensor_thread_idx = + lds_tile.get_window_origin() + window_adaptor_thread_coord.get_bottom_index(); + + const auto lds_coord = make_tensor_coordinate( + lds_tile.get_bottom_tensor_view().get_tensor_descriptor(), + lds_bottom_tensor_thread_idx); + CK_TILE_LDS_ADDR LdsDataType* smem = + lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + + lds_coord.get_offset(); + // write into bottom tensor + this->get_bottom_tensor_view().template async_get_vectorized_elements( + smem, bottom_tensor_thread_coord, 0, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 596584f3cc..db8fb20d50 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -186,7 +186,7 @@ struct tile_window_linear const typename Base::WindowLengths& window_lengths, const typename Base::BottomTensorIndex& window_origin, const typename Base::TileDstr& tile_distribution) - : cached_coords_{}, cached_flags_{} + : cached_coords_{}, cached_window_adaptor_coords_{}, cached_flags_{} { this->bottom_tensor_view_ = bottom_tensor_view; this->window_lengths_ = window_lengths; @@ -554,63 +554,32 @@ struct tile_window_linear { using LdsTileWindow = remove_cvref_t; using LdsDataType = typename LdsTileWindow::DataType; - - // currently we only support everything is non linear dim - // actually it's not performant if we have linear dim(e.g. fast changing) - static_assert(NumAccess_NonLinear == NumAccess); - static_assert(Base::BottomTensorView::buffer_view::get_address_space() == - address_space_enum::global); - - // issues * warps * lanes - static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded - - // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out - // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to - // check?) - constexpr index_t size_per_buf = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<0>{}, number<0>{}, number<0>{})); - - constexpr index_t size_per_wave = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<0>{}, number<1>{}, number<0>{})) - - size_per_buf; - - constexpr index_t size_per_issue = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<1>{}, number<0>{}, number<0>{})) - - size_per_buf; - - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - using vector_t = typename Base::Traits::vector_t; - // TODO: we force CK_TILE_LDS_ADDR - CK_TILE_LDS_ADDR LdsDataType* smem = - lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; - - // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { constexpr auto IAccess = number{}; constexpr auto non_linear_id = number{}; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id]; auto bottom_tensor_flag = cached_flags_[IAccess]; - // read from bottom tensor + auto lds_bottom_tensor_thread_idx = + lds_tile.get_window_origin() + window_adaptor_coord.get_bottom_index(); + + const auto lds_coord = + make_tensor_coordinate(lds_tile.get_bottom_tensor_view().get_tensor_descriptor(), + lds_bottom_tensor_thread_idx); + CK_TILE_LDS_ADDR LdsDataType* smem = + lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + + lds_coord.get_offset(); + this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, bool_constant{}); - - // move thread coordinate - if constexpr(i_access_ != (NumAccess - 1)) - { - smem += size_per_issue; // Note we manually increase the per-issue offset - } }; - WINDOW_DISPATCH_ISSUE(); } @@ -929,6 +898,7 @@ struct tile_window_linear if constexpr(need_save_non_linear_coord) { cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp; + cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp; } if constexpr(i_access != (NumAccess - 1)) @@ -948,6 +918,7 @@ struct tile_window_linear // this contains: array cached_coords_; + array cached_window_adaptor_coords_; array cached_flags_; };