add lds offset to async buffer load

This commit is contained in:
zanzhang
2025-06-27 14:25:38 +08:00
parent 4fe591dd3e
commit d4d57fcbdf
5 changed files with 36 additions and 4 deletions

View File

@@ -98,9 +98,11 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
bool_constant<pre_nop> = {},
index_t lds_offset = 0)
{
return tile_window.async_load_raw(lds_tile,
lds_offset,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});

View File

@@ -251,6 +251,33 @@ struct tensor_view
bool_constant<pre_nop>{});
}
// Issue a raw asynchronous vectorized load through buf_ into `smem`, with an
// extra element offset applied on top of the coordinate's own offset.
//
// X is the vector type to load. The enable_if constraint below only admits X
// whose scalar type is the same as the scalar type of DataType.
//
// smem               - destination pointer, forwarded unchanged to
//                      buf_.async_get_raw (presumably LDS/shared memory, going
//                      by the name -- TODO confirm against callers).
// coord              - tensor coordinate; coord.get_offset() supplies the base
//                      offset.
// coord_extra_offset - added to coord.get_offset() before the sum is divided
//                      by PackedSize.
// is_valid_element   - caller-side validity flag; AND-ed with the descriptor's
//                      offset-validity check for `coord`.
// linear_offset      - additional offset, divided by PackedSize before being
//                      forwarded.
//
// NOTE: an earlier revision carried a second, unreachable `return` after this
// call that ignored coord_extra_offset and the descriptor validity check. It
// was dead code and has been dropped so the function has a single, unambiguous
// behavior.
template <typename X,
          bool pre_nop = false,
          typename std::enable_if<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
                                  const TensorCoord& coord,
                                  index_t coord_extra_offset,
                                  bool is_valid_element,
                                  index_t linear_offset,
                                  bool_constant<pre_nop> = {}) const
{
    return buf_.template async_get_raw<X>(
        smem,
        (coord.get_offset() + coord_extra_offset) / PackedSize,
        linear_offset / PackedSize,
        is_valid_element && coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
        bool_constant<pre_nop>{});
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,

View File

@@ -410,6 +410,7 @@ struct tile_scatter_gather
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
index_t lds_offset = 0,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
@@ -439,7 +440,7 @@ struct tile_scatter_gather
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset;
m0_set_with_memory(m0_init_value); // This should be wave independent
using Traits = load_store_traits;

View File

@@ -259,6 +259,7 @@ struct tile_window_with_static_distribution
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
index_t lds_offset = 0,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
@@ -288,7 +289,7 @@ struct tile_window_with_static_distribution
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset;
m0_set_with_memory(m0_init_value); // This should be wave independent
using Traits = typename Base::Traits;

View File

@@ -478,6 +478,7 @@ struct tile_window_linear
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
index_t lds_offset = 0,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
@@ -511,7 +512,7 @@ struct tile_window_linear
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset;
m0_set_with_memory(m0_init_value); // This should be wave independent
using vector_t = typename Base::Traits::vector_t;