mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Add support for specifying valid flag when fetching elements for tile_scatter_gather (#2332)
* Add support for specifying valid flag when fetching elements for tile_scatter_gather
Add constexpr for operator[] of TrueGenerator
* Use different path when valid is enabled
[ROCm/composable_kernel commit: b34c234f51]
This commit is contained in:
@@ -33,6 +33,7 @@ template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1,
|
||||
index_t YsGatherDim = 0>
|
||||
@@ -42,6 +43,7 @@ struct tile_scatter_gather
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
|
||||
using ValidArray = remove_cvref_t<StaticValidArray_>;
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
|
||||
@@ -152,12 +154,14 @@ struct tile_scatter_gather
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution,
|
||||
const PageIdxArray& page_idx)
|
||||
const PageIdxArray& page_idx,
|
||||
const ValidArray& valids)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
page_idx_{page_idx},
|
||||
valids_{valids},
|
||||
pre_computed_coords_{}
|
||||
{
|
||||
#if 0 // debug
|
||||
@@ -336,12 +340,25 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
const vector_t vec_value = [&]() {
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
}();
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
@@ -451,9 +468,23 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
|
||||
}
|
||||
else
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
0,
|
||||
pre_nop_);
|
||||
}
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
@@ -529,11 +560,24 @@ struct tile_scatter_gather
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// printf("coord_offset:%d, scatter_offset:%d \n",
|
||||
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
@@ -570,14 +614,23 @@ struct tile_scatter_gather
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
|
||||
{
|
||||
page_idx_ = new_idx;
|
||||
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
|
||||
|
||||
// static_for<0, 2, 1>{}([&](auto k0) {
|
||||
// printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]);
|
||||
// });
|
||||
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
|
||||
{
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
|
||||
{
|
||||
valids_ = new_valids;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx,
|
||||
const ValidArray& new_valids)
|
||||
{
|
||||
update_page_idx(new_idx);
|
||||
update_valids(new_valids);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
@@ -657,6 +710,7 @@ struct tile_scatter_gather
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
PageIdxArray page_idx_;
|
||||
ValidArray valids_;
|
||||
|
||||
// this contains:
|
||||
// per-thread coordinate for window adaptor
|
||||
@@ -684,9 +738,10 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
std::nullptr_t,
|
||||
HsGatherDim,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx};
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
@@ -728,4 +783,76 @@ CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
const StaticValidArray_& valids,
|
||||
number<HsGatherDim> = {},
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
remove_cvref_t<StaticValidArray_>,
|
||||
HsGatherDim,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename StaticPageIndexArray,
|
||||
typename StaticValidArray,
|
||||
index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
const StaticValidArray& valids,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
origin,
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
valids,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename StaticPageIndexArray,
|
||||
typename StaticValidArray,
|
||||
index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
const StaticValidArray& valids,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
valids,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user