[CK-Tile] Refactor base pipeline usage (#3251)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder
This commit is contained in:
Max Podkorytov
2025-12-04 11:45:49 -08:00
committed by GitHub
parent d9d4c9c3df
commit d184eed823
37 changed files with 1012 additions and 1836 deletions

View File

@@ -19,12 +19,12 @@ struct BaseGemmPipelineAgBgCrCompAsync
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop == 1)
{
@@ -158,9 +158,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
@@ -539,14 +537,21 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
void* p_smem_0,
void* p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
public:
@@ -557,14 +562,21 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem_0,
p_smem_1);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem_0,
p_smem_1);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile

View File

@@ -154,10 +154,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool HasHotLoop =
Problem::HasHotLoop; // Base::BlockHasHotloop(Problem::num_loop);
static constexpr auto TailNum =
Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop);
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
@@ -641,13 +637,20 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
/**
@@ -700,13 +703,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
return operator()(a_dram_block_window_tmp,
b_dram_block_window_tmp,
num_loop,
has_hot_loop,
tail_number,
p_smem);
}
template <typename AsDramBlockWindowTmp,

View File

@@ -167,9 +167,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
@@ -685,14 +683,21 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
void* p_smem_0,
void* p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
@@ -706,14 +711,21 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem_0,
p_smem_1);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem_0,
p_smem_1);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,

View File

@@ -92,9 +92,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
@@ -404,13 +402,20 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
index_t num_loop,
void* p_smem_0) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
@@ -423,13 +428,20 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
const index_t num_loop,
void* __restrict__ p_smem_0) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem_0);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem_0);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,

View File

@@ -22,12 +22,12 @@ struct BaseGemmPipelineAgBgCrCompV6
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % HotloopUnroll == 1)
{
@@ -153,9 +153,7 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<BasePImpl::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<BasePImpl::is_b_load_tr>{};
@@ -173,11 +171,9 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
return concat('_', "pipeline_AgBgCrCompV6", BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK),
concat('x', TailNum),
concat('_', KRepeat),
concat('_', DoubleSmemBuffer),
concat('_', Preshuffle),
concat('_', HasHotLoop));
concat('_', Preshuffle));
// clang-format on
}
@@ -725,13 +721,20 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
index_t num_loop,
void* __restrict__ p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
@@ -744,13 +747,20 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
const index_t num_loop,
void* __restrict__ p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,

View File

@@ -206,10 +206,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
@@ -887,13 +884,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename AsDramBlockWindowTmp,
@@ -933,13 +937,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,

View File

@@ -224,8 +224,6 @@ template <typename AsDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename AElementWise_ = ck_tile::element_wise::PassThrough,
typename BElementWise_ = ck_tile::element_wise::PassThrough,
typename ComputeDataType_ = AsDataType_,
@@ -296,8 +294,6 @@ struct UniversalGemmPipelineProblem
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{

View File

@@ -148,7 +148,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
: MIterPerWarp * KIterPerWarp;
static constexpr auto TailNum = Problem::TailNum;
#ifdef __gfx942__
static constexpr index_t mfma_per_wg = 2;
@@ -1042,13 +1041,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()<TailNum>(
a_dram_block_window_tmp[number<0>{}],
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp[number<0>{}],
num_loop,
p_smem_ping,
p_smem_pong);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](const ADataType& a) { return a; };
return operator()<tail_num>(a_dram_block_window_tmp[number<0>{}],
PassThrough,
b_flat_dram_block_window_tmp[number<0>{}],
num_loop,
p_smem_ping,
p_smem_pong);
};
return Base::TailHandler(RunPipeline, true, tail_number);
}
// called from general gemm kernel
@@ -1063,13 +1069,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()<TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](const ADataType& a) { return a; };
return operator()<tail_num>(a_dram_block_window_tmp,
PassThrough,
b_flat_dram_block_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
};
return Base::TailHandler(RunPipeline, true, tail_number);
}
// called from grouped gemm kernel