From 1f77d58ae965545c864ccdf512231eee76af7358 Mon Sep 17 00:00:00 2001 From: ruanjm Date: Mon, 16 Jun 2025 17:17:03 +0800 Subject: [PATCH] Add support for specifying valid flag when fetching elements for tile_scatter_gather (#2332) * Add support for specifying valid flag when fetching elements for tile_scatter_gather Add constexpr for operator[] of TrueGenerator * Use different path when valid is enabled [ROCm/composable_kernel commit: b34c234f5144d4ebd16ca04a379c907854d087ff] --- .../core/tensor/tile_scatter_gather.hpp | 167 +++++++++++++++--- 1 file changed, 147 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 351737d4d9..c7811133d6 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -33,6 +33,7 @@ template @@ -42,6 +43,7 @@ struct tile_scatter_gather using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; using PageIdxArray = remove_cvref_t; + using ValidArray = remove_cvref_t; using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; using BottomTensorDesc = typename BottomTensorView::TensorDesc; @@ -152,12 +154,14 @@ struct tile_scatter_gather const WindowLengths& window_lengths, const BottomTensorIndex& window_origin, const TileDstr& tile_distribution, - const PageIdxArray& page_idx) + const PageIdxArray& page_idx, + const ValidArray& valids) : bottom_tensor_view_{bottom_tensor_view}, window_lengths_{window_lengths}, window_origin_{window_origin}, tile_dstr_{tile_distribution}, page_idx_{page_idx}, + valids_{valids}, pre_computed_coords_{} { #if 0 // debug @@ -336,12 +340,25 @@ struct tile_scatter_gather constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_gather = idx_ys_start[number{}]; const auto page_offset = page_idx_[idx_gather]; + // read from bottom tensor - const vector_t vec_value = - get_bottom_tensor_view().template get_vectorized_elements( - bottom_tensor_thread_coord, - page_offset, - bool_constant{}); + const vector_t vec_value = [&]() { + if constexpr(std::is_same_v) + { + return get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + bool_constant{}); + } + else + { + return get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + valids_[idx_gather], + bool_constant{}); + } + }(); #if 1 // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { @@ -451,9 +468,23 @@ struct tile_scatter_gather constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_gather = idx_ys_start[number{}]; const auto page_offset = page_idx_[idx_gather]; + // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements_raw( - smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); + if constexpr(std::is_same_v) + { + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); + } + else + { + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, + bottom_tensor_thread_coord, + page_offset, + valids_[idx_gather], + 0, + pre_nop_); + } // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -529,11 +560,24 @@ struct tile_scatter_gather // const vector_t vec_value = vec.template get_as().template at<0>(); // write into bottom tensor - get_bottom_tensor_view().template set_vectorized_elements( - bottom_tensor_thread_coord, - page_offset, - vec_value, - bool_constant{}); + if constexpr(std::is_same_v) + { + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + vec_value, + bool_constant{}); + } + else + { + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + valids_[idx_gather], + vec_value, + bool_constant{}); + } + // printf("coord_offset:%d, scatter_offset:%d \n", // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -570,14 +614,23 @@ struct tile_scatter_gather }); } - CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) - { - page_idx_ = new_idx; + CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; } - // static_for<0, 2, 1>{}([&](auto k0) { - // printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]); - // }); + CK_TILE_DEVICE void update_valids(const ValidArray& new_valids) + { + if constexpr(std::is_same_v == false) + { + valids_ = new_valids; + } } + + CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx, + const ValidArray& new_valids) + { + update_page_idx(new_idx); + update_valids(new_valids); + } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) { window_origin_ = new_window_origin; @@ -657,6 +710,7 @@ struct tile_scatter_gather TileDstr tile_dstr_; PageIdxArray page_idx_; + ValidArray valids_; // this contains: // per-thread coordinate for window adaptor @@ -684,9 +738,10 @@ make_tile_scatter_gather(const TensorView_& tensor_view, remove_cvref_t, remove_cvref_t, remove_cvref_t, + std::nullptr_t, HsGatherDim, NumCoord>{ - tensor_view, window_lengths, origin, tile_distribution, page_idx}; + tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; } template {}); } +template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + const StaticValidArray_& valids, + number = {}, + number = {}) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + HsGatherDim, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx, valids}; +} + +template +CK_TILE_DEVICE constexpr auto make_tile_scatter_gather( + const tile_window_with_static_lengths& tile_window, + const multi_index& origin, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + const StaticValidArray& valids, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + origin, + tile_distribution, + page_idx, + valids, + number{}); +} + +template +CK_TILE_DEVICE constexpr auto make_tile_scatter_gather( + const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + const StaticValidArray& valids, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + page_idx, + valids, + number{}); +} + } // namespace ck_tile