Add support for specifying valid flag when fetching elements for tile_scatter_gather (#2332)

* Add support for specifying valid flag when fetching elements for tile_scatter_gather

Add constexpr for operator[] of TrueGenerator

* Use different path when valid is enabled

[ROCm/composable_kernel commit: b34c234f51]
This commit is contained in:
ruanjm
2025-06-16 17:17:03 +08:00
committed by GitHub
parent 370dd01230
commit 1fdac8b8fe

View File

@@ -33,6 +33,7 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
index_t YsGatherDim = 0>
@@ -42,6 +43,7 @@ struct tile_scatter_gather
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
using ValidArray = remove_cvref_t<StaticValidArray_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
@@ -152,12 +154,14 @@ struct tile_scatter_gather
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution,
const PageIdxArray& page_idx)
const PageIdxArray& page_idx,
const ValidArray& valids)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
page_idx_{page_idx},
valids_{valids},
pre_computed_coords_{}
{
#if 0 // debug
@@ -336,12 +340,25 @@ struct tile_scatter_gather
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
bool_constant<oob_conditional_check>{});
const vector_t vec_value = [&]() {
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
bool_constant<oob_conditional_check>{});
}
else
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
bool_constant<oob_conditional_check>{});
}
}();
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
@@ -451,9 +468,23 @@ struct tile_scatter_gather
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
}
else
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem,
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
0,
pre_nop_);
}
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -529,11 +560,24 @@ struct tile_scatter_gather
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
vec_value,
bool_constant<oob_conditional_check>{});
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
vec_value,
bool_constant<oob_conditional_check>{});
}
else
{
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
vec_value,
bool_constant<oob_conditional_check>{});
}
// printf("coord_offset:%d, scatter_offset:%d \n",
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -570,14 +614,23 @@ struct tile_scatter_gather
});
}
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
{
page_idx_ = new_idx;
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
// static_for<0, 2, 1>{}([&](auto k0) {
// printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]);
// });
// Replace the stored valid-flag array with `new_valids`.
// When ValidArray is std::nullptr_t the assignment is compiled out and this
// function is a no-op.
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
{
    constexpr bool has_valid_array = !std::is_same_v<ValidArray, std::nullptr_t>;
    if constexpr(has_valid_array)
    {
        valids_ = new_valids;
    }
}
// Convenience wrapper: refresh the page indices first, then the valid flags,
// by delegating to update_page_idx() and update_valids() in that order.
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx,
                                               const ValidArray& new_valids)
{
    this->update_page_idx(new_idx);
    this->update_valids(new_valids);
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
@@ -657,6 +710,7 @@ struct tile_scatter_gather
TileDstr tile_dstr_;
PageIdxArray page_idx_;
ValidArray valids_;
// this contains:
// per-thread coordinate for window adaptor
@@ -684,9 +738,10 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
HsGatherDim,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution, page_idx};
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}
template <typename TensorView,
@@ -728,4 +783,76 @@ CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
number<HsGatherDim>{});
}
// Build a tile_scatter_gather that carries a valid-flag array next to the
// page-index array.
//
// The resulting window type uses remove_cvref_t<StaticValidArray_> as its
// ValidArray, and `valids` is handed to the constructor together with
// `page_idx`. The trailing number<> tags only exist so that HsGatherDim and
// NumCoord can be deduced from the call; their values are not otherwise used.
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename StaticPageIndexArray_,
          typename StaticValidArray_,
          index_t HsGatherDim = 0,
          index_t NumCoord    = 1>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
                         const WindowLengths_& window_lengths,
                         const multi_index<TensorView_::get_num_of_dimension()>& origin,
                         const StaticTileDistribution_& tile_distribution,
                         const StaticPageIndexArray_& page_idx,
                         const StaticValidArray_& valids,
                         number<HsGatherDim> = {},
                         number<NumCoord>    = {})
{
    // Spell out the concrete window type once so the return statement stays short.
    using scatter_gather_t = tile_scatter_gather<remove_cvref_t<TensorView_>,
                                                 remove_cvref_t<WindowLengths_>,
                                                 remove_cvref_t<StaticTileDistribution_>,
                                                 remove_cvref_t<StaticPageIndexArray_>,
                                                 remove_cvref_t<StaticValidArray_>,
                                                 HsGatherDim,
                                                 NumCoord>;

    return scatter_gather_t{
        tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
}
// Build a tile_scatter_gather with valid flags from an existing
// tile_window_with_static_lengths, placing the new window at an explicitly
// supplied `origin`.
//
// The bottom tensor view and the window lengths are taken from `tile_window`;
// everything else is forwarded unchanged to the tensor-view overload.
template <typename TensorView,
          typename WindowLengths,
          typename StaticTileDistribution,
          typename StaticPageIndexArray,
          typename StaticValidArray,
          index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
    const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
    const multi_index<TensorView::get_num_of_dimension()>& origin,
    const StaticTileDistribution& tile_distribution,
    const StaticPageIndexArray& page_idx,
    const StaticValidArray& valids,
    number<HsGatherDim> = {})
{
    const auto& bottom_view    = tile_window.get_bottom_tensor_view();
    const auto& window_extents = tile_window.get_window_lengths();

    return make_tile_scatter_gather(bottom_view,
                                    window_extents,
                                    origin,
                                    tile_distribution,
                                    page_idx,
                                    valids,
                                    number<HsGatherDim>{});
}
// Build a tile_scatter_gather with valid flags from an existing
// tile_window_with_static_lengths, reusing that window's own origin.
//
// The bottom tensor view, window lengths and window origin all come from
// `tile_window`; the distribution, page indices and valid flags are forwarded
// unchanged to the tensor-view overload.
template <typename TensorView,
          typename WindowLengths,
          typename StaticTileDistribution,
          typename StaticPageIndexArray,
          typename StaticValidArray,
          index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
    const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
    const StaticTileDistribution& tile_distribution,
    const StaticPageIndexArray& page_idx,
    const StaticValidArray& valids,
    number<HsGatherDim> = {})
{
    const auto& bottom_view    = tile_window.get_bottom_tensor_view();
    const auto& window_extents = tile_window.get_window_lengths();
    const auto& window_start   = tile_window.get_window_origin();

    return make_tile_scatter_gather(bottom_view,
                                    window_extents,
                                    window_start,
                                    tile_distribution,
                                    page_idx,
                                    valids,
                                    number<HsGatherDim>{});
}
} // namespace ck_tile