[CK_TILE] Move GEMM pipeline tail handling logic to pipelines (#2222)

* Add TailHandler for V3, V4 and Mem pipelines

* Adapt examples and tests to use TailHandler

* move tail-handling logic to pipeline in persistent grouped gemm

* Fix Mem pipeline dispatching, add CompV4 dispatching

* Use a macro for handling the many tails of Mem pipeline

* Fix formatting again

* Use const-ref RunFunction, remove unnecessary try_run
This commit is contained in:
Sami Remes
2025-06-04 11:50:21 +03:00
committed by GitHub
parent ffb52783d0
commit 7ea1508b59
10 changed files with 234 additions and 553 deletions

View File

@@ -192,32 +192,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
}
};
if(has_hot_loop)
{
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \""
<< tail_num << "\" which is not supported! PrefetchStages: "
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
else
{
std::ostringstream err;
err << "Num K loop must be larger than number of prefetech stages."
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename ALayout, typename BLayout, typename CLayout>