Joye/revise wp pipeline (#3493)

* [CK_TILE] unify double and single lds implementation (#108)

Unify LDS buffer management API for single and double buffering modes

This change consolidates the Local Data Store (LDS) buffer management by:

Merging single and double LDS buffer APIs into a unified interface
Implementing ping-pong address calculation in pipeline when double LDS is enabled
Computing pong buffer addresses dynamically using base address offsets

---------

Co-authored-by: joye <joye@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* update wp_pipeline

* fix a c++17 issue

* update for ci errors

* fix ci issues

* include a header to fix ci errors

* fix some rebase issues

* update with rebase

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
joyeamd
2026-01-06 05:49:26 +08:00
committed by GitHub
parent 1224bc0a82
commit 2b563ad048
13 changed files with 766 additions and 929 deletions

View File

@@ -64,12 +64,17 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
template <typename SrcDataType = void,
typename DstDataType = void,
index_t UnaryOpSize = 8,
typename DstBlockTile,
typename SrcTileWindow,
typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
load_tile(dst_block_tile, dram_tile_window);
load_int4_tile<SrcDataType, DstDataType, UnaryOpSize>(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}
@@ -217,22 +222,17 @@ struct GemmPipelineAgBgCrImplBase
return std::move(a_copy_dram_window);
}
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&,
const array<index_t, 2>& offset = {0, 0}) const
template <typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
// A DRAM tile window for load
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// A LDS tile window for store
auto a_lds_shape = []() {
if constexpr(is_a_load_tr)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0});
auto a_lds_load_tile_distr = []() {
@@ -244,32 +244,73 @@ struct GemmPipelineAgBgCrImplBase
else
return ALdsLoadTileDistr{};
}();
auto a_lds_gemm_window =
make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr);
return make_tuple(std::move(a_copy_lds_window), std::move(a_lds_gemm_window));
}
template <
typename ADramBlockWindowTmp,
typename ALdsTensorView,
typename ALdsLoadTileDistr,
typename std::enable_if_t<!is_detected<is_tuple, ALdsTensorView>::value, bool>* = nullptr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr& a_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// A DRAM tile window for load
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// Create LDS windows
auto [a_copy_lds_window, a_lds_gemm_window] =
MakeALdsWindows(a_lds_block_view, a_lds_load_tile_distr);
return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
std::move(a_lds_gemm_window));
}
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr&,
// Unified GetAWindows that supports 1, 2, or 3 LDS buffers
template <typename ADramBlockWindowTmp,
typename ALdsTensorViewsTuple,
typename ALdsLoadTileDistr,
typename std::enable_if_t<is_detected<is_tuple, ALdsTensorViewsTuple>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorViewsTuple& a_lds_block_views_tuple,
const ALdsLoadTileDistr& a_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// A DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// TODO: Do we really need those two tile windows???
// They're exactly same...
// B LDS tile window for store
// Create LDS windows for each buffer
constexpr index_t num_buffers = ALdsTensorViewsTuple::size();
auto a_lds_windows = generate_tuple(
[&](auto i) {
return MakeALdsWindows(a_lds_block_views_tuple[i], a_lds_load_tile_distr);
},
number<num_buffers>{});
// Return: (dram_window, lds_windows_tuple)
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
return make_tuple(std::move(a_copy_dram_window), std::move(a_lds_windows));
}
template <typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto MakeBLdsWindows(const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr&) const
{
auto b_lds_shape = []() {
if constexpr(is_b_load_tr)
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
using BLdsDataType =
@@ -286,13 +327,61 @@ struct GemmPipelineAgBgCrImplBase
else
return BLdsLoadTileDistr{};
}();
auto b_lds_gemm_window =
make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr);
return make_tuple(std::move(b_copy_lds_window), std::move(b_lds_gemm_window));
}
template <
typename BDramBlockWindowTmp,
typename BLdsTensorView,
typename BLdsLoadTileDistr,
typename std::enable_if_t<!is_detected<is_tuple, BLdsTensorView>::value, bool>* = nullptr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr& b_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// B DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
// Create LDS windows
auto [b_copy_lds_window, b_lds_gemm_window] =
MakeBLdsWindows(b_lds_block_view, b_lds_load_tile_distr);
return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),
std::move(b_lds_gemm_window));
}
// Unified GetBWindows that supports 1, 2, or 3 LDS buffers
template <typename BDramBlockWindowTmp,
typename BLdsTensorViewsTuple,
typename BLdsLoadTileDistr,
typename std::enable_if_t<is_detected<is_tuple, BLdsTensorViewsTuple>::value, bool>* =
nullptr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorViewsTuple& b_lds_block_views_tuple,
const BLdsLoadTileDistr& b_lds_load_tile_distr,
const array<index_t, 2>& offset = {0, 0}) const
{
// B DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
// Create LDS windows for each buffer
constexpr index_t num_buffers = BLdsTensorViewsTuple::size();
auto b_lds_windows = generate_tuple(
[&](auto i) {
return MakeBLdsWindows(b_lds_block_views_tuple[i], b_lds_load_tile_distr);
},
number<num_buffers>{});
// Return: (dram_window, lds_windows_tuple)
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
return make_tuple(std::move(b_copy_dram_window), std::move(b_lds_windows));
}
};
} // namespace ck_tile

View File

@@ -158,6 +158,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
@@ -172,7 +174,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
return 2 * smem_size;
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
@@ -240,8 +243,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
// TODO support multi-ABD
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
@@ -303,8 +305,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
number<BsLayout::size()>{});
// this pipeline has a pair of LDS buffers per logical tile
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
auto&& [a_lds_block1, b_lds_block1] =
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
// set up LDS tile shapes
constexpr auto a_lds_shape = []() {
@@ -534,21 +538,18 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
@@ -559,8 +560,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
@@ -572,8 +572,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);

View File

@@ -172,6 +172,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
@@ -191,7 +193,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
return 2 * smem_size;
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
@@ -281,8 +284,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
@@ -324,8 +326,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// global read 0
////////////// LDS desc, window & register /////////////////
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
auto&& [a_lds_block1, b_lds_block1] =
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v())
@@ -680,8 +684,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
@@ -693,8 +696,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
@@ -708,8 +710,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
@@ -721,8 +722,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
@@ -738,8 +738,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr bool hot_loop = hot_loop_.value;
@@ -751,8 +750,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
b_dram_block_window_tmp,
PassThrough,
num_loop,
p_smem_0,
p_smem_1);
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
@@ -769,16 +767,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
p_smem);
}
template <typename ADramBlockWindowTmp,
@@ -789,14 +785,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem_0,
p_smem_1);
p_smem);
}
template <typename ADramBlockWindowTmp,
@@ -809,16 +803,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
has_hot_loop,
tail_number,
p_smem_0,
p_smem_1);
p_smem);
}
};
} // namespace ck_tile

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
@@ -201,6 +202,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t kNPerBlock = TileShape::kN;
constexpr index_t kKPerBlock = TileShape::kK;
constexpr index_t NIterPerWarp =
kNPerBlock / TileShape::BlockWarps::at(I1) / TileShape::WarpTile::at(I1);
constexpr index_t KIterPerWarp = kKPerBlock / TileShape::WarpTile::at(I2);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
@@ -213,13 +220,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
constexpr index_t KRepeat = KIterPerWarp;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t NRepeat = NIterPerWarp;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
@@ -232,8 +239,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
sequence<1, 2, 1, 2>,
sequence<0, 0, 3, 3>>{});
}
template <typename Problem>
@@ -307,7 +314,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
return BlockWeightPreshuffleASmemBRegCReg<Problem, BlockWeightPreshufflePolicy>{};
}
/**
* @brief Get the vector store size for C tensor.
@@ -325,7 +332,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
using WG_ = typename BlockGemm::WG;
using WG_ = typename BlockGemm::WarpGemm;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;

View File

@@ -32,19 +32,34 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
if(tail_number == TailNumber::Odd)
if(has_hot_loop)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
if(tail_number == TailNumber::Odd)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else // Even tail number
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
else // Even tail number
else
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
if(tail_number == TailNumber::Odd)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else // Even tail number
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
@@ -52,7 +67,8 @@ template <typename Problem, typename PipelinePolicy = UniversalWeightPreshuffleP
struct WeightPreshufflePipelineAGmemBGmemCRegV2
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
{
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, PipelinePolicy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
@@ -75,11 +91,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
using BlockWeightPreshuffle =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
static constexpr auto config =
BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
@@ -95,6 +106,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
@@ -131,12 +144,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MWarp = BlockWarps::at(I0);
static constexpr index_t NWarp = BlockWarps::at(I1);
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
static constexpr index_t WarpTileM = WarpTile::at(I0);
static constexpr index_t WarpTileN = WarpTile::at(I1);
static constexpr index_t WarpTileK = WarpTile::at(I2);
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpTileM);
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpTileN);
static constexpr index_t KIterPerWarp = kKPerBlock / WarpTileK;
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
@@ -154,20 +171,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
#else
static constexpr index_t mfma_per_wg = 1;
#endif
static constexpr index_t dsread_per_wg =
max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
static constexpr index_t dsread_per_wg = max(
index_t(WarpTileM * WarpTileK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
#if defined(__HIP_DEVICE_COMPILE__)
static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
static_assert((WarpTileM * WarpTileK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
Problem::VectorLoadSize ==
0);
#endif
static constexpr index_t dsread_num_perK =
WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize;
static constexpr index_t dsread_num_perK = WarpTileM * WarpTileK * sizeof(ADataType) *
MIterPerWarp / WaveSize / Problem::VectorLoadSize;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
static constexpr index_t Bload_num_perK = kNPerBlock * WarpTileK / NWarp / K1 / WaveSize;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
@@ -187,7 +204,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV2",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', WG::kM, WG::kN, WG::kK),
concat('x', WarpTileM, WarpTileN, WarpTileK),
concat('x', GetVectorSizeA(), GetVectorSizeB()),
concat('x', kPadM, kPadN, kPadK));
@@ -195,14 +212,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr index_t Preshuffle = Problem::Preshuffle;
using Base::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return PipelinePolicy::template GetSmemSize<Problem>();
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
return DoubleSmemBuffer ? 2 * smem_size : smem_size;
}
// dsread_perM: how many LDS reads want to issue in this M-iter
@@ -515,515 +534,184 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// __builtin_amdgcn_sched_barrier(0);
}
template <TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr,
index_t UnaryOpSize_ = 8>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
struct PipelineImpl : public PipelineImplBase
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
"wrong!");
using Base = PipelineImplBase;
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
const index_t iMWarp = get_warp_id() / NWarp;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block_ping =
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_ping;
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Block GEMM
auto block_weight_preshuffle = BlockWeightPreshuffle();
// Acc register tile
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// pingpong buffer for B
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
b_warp_tensor_pong;
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// Prefill A0
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
// preload A00,A10 from lds
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr,
index_t UnaryOpSize_ = 8>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
[[maybe_unused]] const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
"wrong!");
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// A tile in LDS
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
// Prefetch A(2i+2)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
auto a_lds_blocks = generate_tuple(
[&](auto i) {
ADataType* p_a_lds = static_cast<ADataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + smem_size * i.value));
return make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
},
number<2>{});
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(
BlockWeightPreshuffle::MakeABlockDistributionEncode());
auto&& windows_result =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_blocks, a_lds_load_tile_distr);
auto&& a_copy_dram_window = windows_result.template get<0>();
auto&& a_lds_windows = windows_result.template get<1>();
auto a_copy_lds_windows = generate_tuple(
[&](auto i) -> decltype(auto) { return a_lds_windows[i].template at<0>(); },
number<2>{});
// Block GEMM
auto block_weight_preshuffle = BlockWeightPreshuffle();
// Acc register tile
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
auto a_load_windows = generate_tuple(
[&](auto i) -> decltype(auto) {
return block_weight_preshuffle.MakeALoadWindows(a_copy_lds_windows[i]);
},
number<2>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(b_flat_dram_block_window_tmp
.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp * NIterPerWarp>{},
number<flatKPerWarp * KIterPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BFlatBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kflatKPerBlock);
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BBlockTile =
decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
ABlockTile a_global_tile;
BBlockTile b_global_tile[2];
// // Prefetch A0
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
// Prefill A0
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
// Prefetch A1
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
// preload A00,A10 from lds
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
if constexpr(HasHotLoop)
{
index_t i_global_read = amd_wave_read_first_lane(2);
do
{
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
Base::GlobalPrefetch(
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
block_weight_preshuffle(c_block_tile,
a_load_windows[I0],
b_global_tile[0],
b_flat_distribution);
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
HotLoopScheduler();
}
{
block_sync_lds();
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
Base::GlobalPrefetch(
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
block_weight_preshuffle(c_block_tile,
a_load_windows[I1],
b_global_tile[1],
b_flat_distribution);
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
HotLoopScheduler();
}
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
i_global_read += 2;
} while(i_global_read < num_loop);
}
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
// tail
if constexpr(TailNum == TailNumber::Even)
{
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
block_weight_preshuffle(
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
block_sync_lds();
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
Last2ndHotLoopScheduler();
}
{
block_weight_preshuffle(
c_block_tile, a_load_windows[I1], b_global_tile[1], b_flat_distribution);
LastHotLoopScheduler();
}
}
else if constexpr(TailNum == TailNumber::Odd)
{
block_weight_preshuffle(
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
LastHotLoopScheduler();
}
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
// Prefetch A(2i+3)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
iCounter--;
return c_block_tile;
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
// __builtin_amdgcn_sched_barrier(0);
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(loopK)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
// TailHotLoopScheduler();
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
LastHotLoopScheduler();
}
else if constexpr(TailNum == TailNumber::Odd)
{
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
LastHotLoopScheduler();
}
return c_block_tile;
}
};
// called from universal gemm kernel
template <typename ADramBlockWindowTmp,
@@ -1038,23 +726,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
[[maybe_unused]] const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
void* p_smem) const
{
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](const ADataType& a) { return a; };
return operator()<tail_num>(a_dram_block_window_tmp[number<0>{}],
PassThrough,
b_flat_dram_block_window_tmp[number<0>{}],
num_loop,
p_smem_ping,
p_smem_pong);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp[number<0>{}],
a_element_func,
b_flat_dram_block_window_tmp[number<0>{}],
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, true, tail_number);
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
// called from general gemm kernel
@@ -1066,23 +751,21 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
void* p_smem) const
{
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
constexpr auto tail_num = tail_num_.value;
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr auto PassThrough = [](const ADataType& a) { return a; };
return operator()<tail_num>(a_dram_block_window_tmp,
PassThrough,
b_flat_dram_block_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
PassThrough,
b_flat_dram_block_window_tmp,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, true, tail_number);
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
// called from grouped gemm kernel
@@ -1095,21 +778,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
void* __restrict__ p_smem) const
{
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
constexpr auto tail_num = tail_num_.value;
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr auto PassThrough = [](const auto& x) { return x; };
return operator()<tail_num>(a_dram_block_window_tmp,
PassThrough,
b_flat_dram_block_window_tmp,
num_loop,
p_smem_0,
p_smem_1);
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
PassThrough,
b_flat_dram_block_window_tmp,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, true, tail_number);
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};