Merge commit 'd996bc78befb15ee0405ff78d0ad0da00f8550f3' into develop

This commit is contained in:
assistant-librarian[bot]
2025-06-16 10:09:10 +00:00
parent fffaadf44e
commit 4501568487
5 changed files with 155 additions and 24 deletions

View File

@@ -49,9 +49,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,

View File

@@ -33,6 +33,7 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
index_t YsGatherDim = 0>
@@ -42,6 +43,7 @@ struct tile_scatter_gather
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
using ValidArray = remove_cvref_t<StaticValidArray_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
@@ -152,12 +154,14 @@ struct tile_scatter_gather
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution,
const PageIdxArray& page_idx)
const PageIdxArray& page_idx,
const ValidArray& valids)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
page_idx_{page_idx},
valids_{valids},
pre_computed_coords_{}
{
#if 0 // debug
@@ -336,12 +340,25 @@ struct tile_scatter_gather
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
bool_constant<oob_conditional_check>{});
const vector_t vec_value = [&]() {
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
bool_constant<oob_conditional_check>{});
}
else
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
bool_constant<oob_conditional_check>{});
}
}();
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
@@ -451,9 +468,23 @@ struct tile_scatter_gather
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
}
else
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem,
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
0,
pre_nop_);
}
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -529,11 +560,24 @@ struct tile_scatter_gather
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
vec_value,
bool_constant<oob_conditional_check>{});
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
vec_value,
bool_constant<oob_conditional_check>{});
}
else
{
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
vec_value,
bool_constant<oob_conditional_check>{});
}
// printf("coord_offset:%d, scatter_offset:%d \n",
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -570,14 +614,23 @@ struct tile_scatter_gather
});
}
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
{
page_idx_ = new_idx;
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
// static_for<0, 2, 1>{}([&](auto k0) {
// printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]);
// });
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
{
if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
{
valids_ = new_valids;
}
}
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx,
const ValidArray& new_valids)
{
update_page_idx(new_idx);
update_valids(new_valids);
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
@@ -657,6 +710,7 @@ struct tile_scatter_gather
TileDstr tile_dstr_;
PageIdxArray page_idx_;
ValidArray valids_;
// this contains:
// per-thread coordinate for window adaptor
@@ -684,9 +738,10 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
std::nullptr_t,
HsGatherDim,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution, page_idx};
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
}
template <typename TensorView,
@@ -728,4 +783,76 @@ CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
number<HsGatherDim>{});
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
const StaticValidArray_& valids,
number<HsGatherDim> = {},
number<NumCoord> = {})
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
remove_cvref_t<StaticValidArray_>,
HsGatherDim,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename StaticPageIndexArray,
typename StaticValidArray,
index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
const StaticValidArray& valids,
number<HsGatherDim> = {})
{
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
origin,
tile_distribution,
page_idx,
valids,
number<HsGatherDim>{});
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename StaticPageIndexArray,
typename StaticValidArray,
index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
const StaticValidArray& valids,
number<HsGatherDim> = {})
{
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
page_idx,
valids,
number<HsGatherDim>{});
}
} // namespace ck_tile

View File

@@ -447,6 +447,7 @@ struct FlatmmKernel
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
a_block_window, b_flat_block_window, num_loop, smem_ptr);
@@ -454,7 +455,7 @@ struct FlatmmKernel
auto& c_block_window = gemm_tile_windows.at(I2);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, smem_ptr);
c_block_window, c_block_tile, d_block_window, smem_ptr);
}
CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const

View File

@@ -31,8 +31,8 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"

View File

@@ -278,13 +278,13 @@ def main():
shapes = tuples(filename)
all_results = []
from tqdm import tqdm
from functools import partial
from os import path
profiler_bin = path.join(args["build_dir"], "bin", "ckProfiler")
for s in tqdm(shapes):
total = len(shapes)
for idx, s in enumerate(shapes, 1):
run_shape_stdout_lines = run_shape(
s, profiler_bin, args["op_name"], args["dtype"], args["layout"]
)