mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
rename to tile_scatter_gather
This commit is contained in:
@@ -35,7 +35,7 @@ template <typename BottomTensorView_,
|
||||
typename StaticPageIndexArray_,
|
||||
index_t PageIndexDim = 0,
|
||||
index_t NumCoord = 1>
|
||||
struct page_tile_with_static_distribution
|
||||
struct tile_scatter_gather
|
||||
{
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
@@ -80,7 +80,7 @@ struct page_tile_with_static_distribution
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
page_tile_with_static_distribution::
|
||||
tile_scatter_gather::
|
||||
get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
@@ -146,9 +146,9 @@ struct page_tile_with_static_distribution
|
||||
|
||||
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
|
||||
|
||||
CK_TILE_DEVICE constexpr page_tile_with_static_distribution() = default;
|
||||
CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr page_tile_with_static_distribution(
|
||||
CK_TILE_DEVICE constexpr tile_scatter_gather(
|
||||
const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
@@ -585,7 +585,7 @@ template <typename TensorView_,
|
||||
index_t PageIndexDim = 0,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window_paged(const TensorView_& tensor_view,
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
@@ -593,7 +593,7 @@ make_tile_window_paged(const TensorView_& tensor_view,
|
||||
number<PageIndexDim> = {},
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
return page_tile_with_static_distribution<remove_cvref_t<TensorView_>,
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
@@ -614,7 +614,7 @@ make_tile_window_paged(const TensorView_& tensor_view,
|
||||
// const StaticTileDistribution_& tile_distribution,
|
||||
// number<NumCoord> = {})
|
||||
// {
|
||||
// auto w = page_tile_with_static_distribution<remove_cvref_t<TensorView_>,
|
||||
// auto w = tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
// remove_cvref_t<WindowLengths_>,
|
||||
// remove_cvref_t<StaticTileDistribution_>,
|
||||
// NumCoord>{
|
||||
@@ -628,11 +628,11 @@ make_tile_window_paged(const TensorView_& tensor_view,
|
||||
// typename StaticTileDistribution_,
|
||||
// index_t NumCoord>
|
||||
// CK_TILE_DEVICE void move_tile_window(
|
||||
// page_tile_with_static_distribution<TensorView_,
|
||||
// tile_scatter_gather<TensorView_,
|
||||
// WindowLengths_,
|
||||
// StaticTileDistribution_,
|
||||
// NumCoord>& window,
|
||||
// const typename page_tile_with_static_distribution<TensorView_,
|
||||
// const typename tile_scatter_gather<TensorView_,
|
||||
// WindowLengths_,
|
||||
// StaticTileDistribution_,
|
||||
// NumCoord>::BottomTensorIndex& step)
|
||||
@@ -643,13 +643,13 @@ make_tile_window_paged(const TensorView_& tensor_view,
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution, typename StaticPageIndexArray, index_t PageIndexDim>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window_paged(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
make_tile_scatter_gather(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
number<PageIndexDim> = {})
|
||||
{
|
||||
return make_tile_window_paged(tile_window.get_bottom_tensor_view(),
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
origin,
|
||||
tile_distribution,
|
||||
@@ -659,11 +659,11 @@ make_tile_window_paged(const tile_window_with_static_lengths<TensorView, WindowL
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution, typename StaticPageIndexArray, index_t PageIndexDim>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window_paged(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
make_tile_scatter_gather(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution, const StaticPageIndexArray& page_idx,
|
||||
number<PageIndexDim> = {})
|
||||
{
|
||||
return make_tile_window_paged(tile_window.get_bottom_tensor_view(),
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution,
|
||||
@@ -676,7 +676,7 @@ make_tile_window_paged(const tile_window_with_static_lengths<TensorView, WindowL
|
||||
// make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
// const StaticTileDistribution& tile_distribution)
|
||||
// {
|
||||
// auto w = make_tile_window_paged(tile_window.get_bottom_tensor_view(),
|
||||
// auto w = make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
// tile_window.get_window_lengths(),
|
||||
// tile_window.get_window_origin(),
|
||||
// tile_distribution);
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_paged.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -214,7 +214,7 @@ struct CShuffleEpilogue
|
||||
{
|
||||
|
||||
|
||||
auto tile_window = make_tile_window_paged(out_dram_window.get_bottom_tensor_view(),
|
||||
auto tile_window = make_tile_scatter_gather(out_dram_window.get_bottom_tensor_view(),
|
||||
out_dram_window.get_window_lengths(),
|
||||
out_dram_window.get_window_origin(),
|
||||
dram_tile_distribution,
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_paged.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -289,7 +289,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = page_idx[i_total_loops * kN0 + c_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
|
||||
});
|
||||
auto k_dram_window = make_tile_window_paged(
|
||||
auto k_dram_window = make_tile_scatter_gather(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
|
||||
Reference in New Issue
Block a user