mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Add LDS offset to async buffer load
This commit is contained in:
@@ -98,9 +98,11 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
bool_constant<pre_nop> = {},
|
||||
index_t lds_offset = 0)
|
||||
{
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
lds_offset,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
|
||||
@@ -251,6 +251,33 @@ struct tensor_view
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t coord_extra_offset,
|
||||
bool is_valid_element,
|
||||
index_t linear_offset,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem,
|
||||
(coord.get_offset() + coord_extra_offset) / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element && coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
return buf_.template async_get_raw<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
|
||||
@@ -410,6 +410,7 @@ struct tile_scatter_gather
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
|
||||
index_t lds_offset = 0,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
@@ -439,7 +440,7 @@ struct tile_scatter_gather
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset;
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using Traits = load_store_traits;
|
||||
|
||||
@@ -259,6 +259,7 @@ struct tile_window_with_static_distribution
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
|
||||
index_t lds_offset = 0,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
@@ -288,7 +289,7 @@ struct tile_window_with_static_distribution
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset;
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
@@ -478,6 +478,7 @@ struct tile_window_linear
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
|
||||
index_t lds_offset = 0,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
@@ -511,7 +512,7 @@ struct tile_window_linear
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id() + lds_offset;
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
Reference in New Issue
Block a user