[CK_TILE] Move GEMM pipeline tail handling logic to pipelines (#2222)

* Add TailHandler for V3, V4 and Mem pipelines

* Adapt examples and tests to use TailHandler

* Move tail-handling logic into the pipeline in persistent grouped GEMM

* Fix Mem pipeline dispatching, add CompV4 dispatching

* Use a macro for handling the many tails of the Mem pipeline

* Fix formatting again

* Use const-ref RunFunction, remove unnecessary try_run
This commit is contained in:
Sami Remes
2025-06-04 11:50:21 +03:00
committed by GitHub
parent ffb52783d0
commit 7ea1508b59
10 changed files with 234 additions and 553 deletions

View File

@@ -252,60 +252,13 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
const auto RunEpilogue = [&](auto& c_block_tile) {
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I2);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, smem_ptr_0);
};
if constexpr(is_specialization_of<GemmPipeline, GemmPipelineAgBgCrCompV3>::value)
{
// Run the specific implementation with hotloop+tailnum config
using PipelineImpl =
typename GemmPipeline::template PipelineImpl<GemmPipeline::Scheduler>;
const auto PassThrough = [](const auto& a) { return a; };
if(has_hot_loop && tail_num == TailNumber::Full)
{
const auto& c_block_tile =
PipelineImpl{}.template operator()<true, TailNumber::Full>(a_block_window,
PassThrough,
b_block_window,
PassThrough,
num_loop,
smem_ptr_0);
RunEpilogue(c_block_tile);
}
else if(has_hot_loop && tail_num == TailNumber::Odd)
{
const auto& c_block_tile =
PipelineImpl{}.template operator()<true, TailNumber::Odd>(a_block_window,
PassThrough,
b_block_window,
PassThrough,
num_loop,
smem_ptr_0);
RunEpilogue(c_block_tile);
}
else if(has_hot_loop && tail_num == TailNumber::Even)
{
const auto& c_block_tile =
PipelineImpl{}.template operator()<true, TailNumber::Even>(a_block_window,
PassThrough,
b_block_window,
PassThrough,
num_loop,
smem_ptr_0);
RunEpilogue(c_block_tile);
}
}
else
{
ignore = a_block_window;
ignore = b_block_window;
static_assert(false, "GemmPipeline specialization not supported!");
}
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I2);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, smem_ptr_0);
}
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,