From d4d57fcbdf3214484e436d6146bc4f5d1ebd54a8 Mon Sep 17 00:00:00 2001 From: zanzhang Date: Fri, 27 Jun 2025 14:25:38 +0800 Subject: [PATCH] add lds offset to async buffer load --- include/ck_tile/core/tensor/load_tile.hpp | 4 ++- include/ck_tile/core/tensor/tensor_view.hpp | 27 +++++++++++++++++++ .../core/tensor/tile_scatter_gather.hpp | 3 ++- include/ck_tile/core/tensor/tile_window.hpp | 3 ++- .../core/tensor/tile_window_linear.hpp | 3 ++- 5 files changed, 36 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 4601261197..f00277cfc3 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -98,9 +98,11 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, const TileWindow_& tile_window, number = {}, bool_constant = {}, - bool_constant = {}) + bool_constant = {}, + index_t lds_offset = 0) { return tile_window.async_load_raw(lds_tile, + lds_offset, number{}, bool_constant{}, bool_constant{}); diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 656ce8d20d..08d06462e9 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -251,6 +251,33 @@ struct tensor_view bool_constant{}); } + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + async_get_vectorized_elements_raw(remove_cvref_t* smem, + const TensorCoord& coord, + index_t coord_extra_offset, + bool is_valid_element, + index_t linear_offset, + bool_constant = {}) const + { + return buf_.template async_get_raw( + smem, + (coord.get_offset() + coord_extra_offset) / PackedSize, + linear_offset / PackedSize, + is_valid_element && coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + return buf_.template async_get_raw(smem, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, + is_valid_element, + bool_constant{}); + } + // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X template CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + index_t lds_offset = 0, number = {}, bool_constant = {}, bool_constant = {}) const @@ -439,7 +440,7 @@ struct tile_scatter_gather sizeof(LdsDataType) - size_per_buf; - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset; m0_set_with_memory(m0_init_value); // This should be wave independent using Traits = load_store_traits; diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index d8a5c14f9b..b94ee14599 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -259,6 +259,7 @@ struct tile_window_with_static_distribution bool oob_conditional_check = true, bool pre_nop = false> CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + index_t lds_offset = 0, number = {}, bool_constant = {}, bool_constant = {}) const @@ -288,7 +289,7 @@ struct tile_window_with_static_distribution sizeof(LdsDataType) - size_per_buf; - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset; m0_set_with_memory(m0_init_value); // This should be wave independent using Traits = typename Base::Traits; diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index f11610d658..743b0bd3d9 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -478,6 +478,7 @@ struct tile_window_linear bool oob_conditional_check = true, bool pre_nop = false> CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + index_t lds_offset = 0, number = {}, bool_constant = {}, bool_constant = {}) const @@ -511,7 +512,7 @@ struct tile_window_linear sizeof(LdsDataType) - size_per_buf; - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset; m0_set_with_memory(m0_init_value); // This should be wave independent using vector_t = typename Base::Traits::vector_t;