rename to tile_scatter_gather

This commit is contained in:
coderfeli
2025-04-08 02:50:46 +00:00
parent bfbd28c9d8
commit 8d62ff557a
3 changed files with 18 additions and 18 deletions

View File

@@ -35,7 +35,7 @@ template <typename BottomTensorView_,
typename StaticPageIndexArray_,
index_t PageIndexDim = 0,
index_t NumCoord = 1>
struct page_tile_with_static_distribution
struct tile_scatter_gather
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
@@ -80,7 +80,7 @@ struct page_tile_with_static_distribution
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
page_tile_with_static_distribution::
tile_scatter_gather::
get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
@@ -146,9 +146,9 @@ struct page_tile_with_static_distribution
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
CK_TILE_DEVICE constexpr page_tile_with_static_distribution() = default;
CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
CK_TILE_DEVICE constexpr page_tile_with_static_distribution(
CK_TILE_DEVICE constexpr tile_scatter_gather(
const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
@@ -585,7 +585,7 @@ template <typename TensorView_,
index_t PageIndexDim = 0,
index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_window_paged(const TensorView_& tensor_view,
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
@@ -593,7 +593,7 @@ make_tile_window_paged(const TensorView_& tensor_view,
number<PageIndexDim> = {},
number<NumCoord> = {})
{
return page_tile_with_static_distribution<remove_cvref_t<TensorView_>,
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
@@ -614,7 +614,7 @@ make_tile_window_paged(const TensorView_& tensor_view,
// const StaticTileDistribution_& tile_distribution,
// number<NumCoord> = {})
// {
// auto w = page_tile_with_static_distribution<remove_cvref_t<TensorView_>,
// auto w = tile_scatter_gather<remove_cvref_t<TensorView_>,
// remove_cvref_t<WindowLengths_>,
// remove_cvref_t<StaticTileDistribution_>,
// NumCoord>{
@@ -628,11 +628,11 @@ make_tile_window_paged(const TensorView_& tensor_view,
// typename StaticTileDistribution_,
// index_t NumCoord>
// CK_TILE_DEVICE void move_tile_window(
// page_tile_with_static_distribution<TensorView_,
// tile_scatter_gather<TensorView_,
// WindowLengths_,
// StaticTileDistribution_,
// NumCoord>& window,
// const typename page_tile_with_static_distribution<TensorView_,
// const typename tile_scatter_gather<TensorView_,
// WindowLengths_,
// StaticTileDistribution_,
// NumCoord>::BottomTensorIndex& step)
@@ -643,13 +643,13 @@ make_tile_window_paged(const TensorView_& tensor_view,
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution, typename StaticPageIndexArray, index_t PageIndexDim>
CK_TILE_DEVICE constexpr auto
make_tile_window_paged(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
make_tile_scatter_gather(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
number<PageIndexDim> = {})
{
return make_tile_window_paged(tile_window.get_bottom_tensor_view(),
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
origin,
tile_distribution,
@@ -659,11 +659,11 @@ make_tile_window_paged(const tile_window_with_static_lengths<TensorView, WindowL
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution, typename StaticPageIndexArray, index_t PageIndexDim>
CK_TILE_DEVICE constexpr auto
make_tile_window_paged(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
make_tile_scatter_gather(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution, const StaticPageIndexArray& page_idx,
number<PageIndexDim> = {})
{
return make_tile_window_paged(tile_window.get_bottom_tensor_view(),
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
@@ -676,7 +676,7 @@ make_tile_window_paged(const tile_window_with_static_lengths<TensorView, WindowL
// make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
// const StaticTileDistribution& tile_distribution)
// {
// auto w = make_tile_window_paged(tile_window.get_bottom_tensor_view(),
// auto w = make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
// tile_window.get_window_lengths(),
// tile_window.get_window_origin(),
// tile_distribution);

View File

@@ -6,7 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/core/tensor/tile_window_paged.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
namespace ck_tile {
template <typename ADataType_,
@@ -214,7 +214,7 @@ struct CShuffleEpilogue
{
auto tile_window = make_tile_window_paged(out_dram_window.get_bottom_tensor_view(),
auto tile_window = make_tile_scatter_gather(out_dram_window.get_bottom_tensor_view(),
out_dram_window.get_window_lengths(),
out_dram_window.get_window_origin(),
dram_tile_distribution,

View File

@@ -9,7 +9,7 @@
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_window_paged.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
namespace ck_tile {
@@ -289,7 +289,7 @@ struct BlockFmhaPipelineQRKSVS
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[i_total_loops * kN0 + c_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
auto k_dram_window = make_tile_window_paged(
auto k_dram_window = make_tile_scatter_gather(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),