mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE] Move GEMM pipeline tail handling logic to pipelines (#2222)
* Add TailHandler for V3, V4 and Mem pipelines
* Adapt examples and tests to use TailHandler
* Move tail-handling logic to pipeline in persistent grouped gemm
* Fix Mem pipeline dispatching, add CompV4 dispatching
* Use a macro for handling the many tails of Mem pipeline
* Fix formatting again
* Use const-ref RunFunction, remove unnecessary try_run
This commit is contained in:
@@ -252,60 +252,13 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunEpilogue = [&](auto& c_block_tile) {
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I2);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
};
|
||||
|
||||
if constexpr(is_specialization_of<GemmPipeline, GemmPipelineAgBgCrCompV3>::value)
|
||||
{
|
||||
// Run the specific implementation with hotloop+tailnum config
|
||||
using PipelineImpl =
|
||||
typename GemmPipeline::template PipelineImpl<GemmPipeline::Scheduler>;
|
||||
const auto PassThrough = [](const auto& a) { return a; };
|
||||
if(has_hot_loop && tail_num == TailNumber::Full)
|
||||
{
|
||||
const auto& c_block_tile =
|
||||
PipelineImpl{}.template operator()<true, TailNumber::Full>(a_block_window,
|
||||
PassThrough,
|
||||
b_block_window,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
smem_ptr_0);
|
||||
RunEpilogue(c_block_tile);
|
||||
}
|
||||
else if(has_hot_loop && tail_num == TailNumber::Odd)
|
||||
{
|
||||
const auto& c_block_tile =
|
||||
PipelineImpl{}.template operator()<true, TailNumber::Odd>(a_block_window,
|
||||
PassThrough,
|
||||
b_block_window,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
smem_ptr_0);
|
||||
RunEpilogue(c_block_tile);
|
||||
}
|
||||
else if(has_hot_loop && tail_num == TailNumber::Even)
|
||||
{
|
||||
const auto& c_block_tile =
|
||||
PipelineImpl{}.template operator()<true, TailNumber::Even>(a_block_window,
|
||||
PassThrough,
|
||||
b_block_window,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
smem_ptr_0);
|
||||
RunEpilogue(c_block_tile);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ignore = a_block_window;
|
||||
ignore = b_block_window;
|
||||
static_assert(false, "GemmPipeline specialization not supported!");
|
||||
}
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I2);
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
|
||||
|
||||
Reference in New Issue
Block a user