From 84ba4164c910d3629559fcd75c3752d09246a4da Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Wed, 4 Jun 2025 11:50:21 +0300 Subject: [PATCH] [CK_TILE] Move GEMM pipeline tail handling logic to pipelines (#2222) * Add TailHandler for V3, V4 and Mem pipelines * Adapt examples and tests to use TailHandler * move tail-handling logic to pipeline in persistent grouped gemm * Fix Mem pipeline dispatching, add CompV4 dispatching * Use a macro for handling the many tails of Mem pipeline * Fix formatting again * Use const-ref RunFunction, remove unnecessary try_run [ROCm/composable_kernel commit: 7ea1508b59a0e8f89540d8d5f7eb3e7da9a50a62] --- example/ck_tile/03_gemm/universal_gemm.cpp | 103 +------------- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 132 +----------------- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 116 +-------------- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 61 +------- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 80 +++++++++++ .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 65 +++++++++ .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 78 ++++++++++- .../batched_gemm/test_batched_gemm_util.hpp | 27 +--- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 98 +------------ .../grouped_gemm/test_grouped_gemm_util.hpp | 27 +--- 10 files changed, 234 insertions(+), 553 deletions(-) diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 645263d26d..3a7cc93df8 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -13,19 +13,6 @@ #include "gemm_utils.hpp" #include "run_gemm_example.inc" -template -void try_run(ck_tile::TailNumber tn) -{ - if constexpr(Pipeline::PrefetchStages > static_cast(TN)) - { - if(tn == TN) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } -} - template {}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "For compute pipeline tail number should always be Full, but have \"" << tail_num - << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - if(tail_num == ck_tile::TailNumber::One) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - auto check_tail = [&](auto... TNs) { - (try_run(tail_num), ...); - }; - - check_tail(ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}); - -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } -#endif - } - else - { - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "Num K loop must be larger than number of prefetech stages." - << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 68ad1106ce..c5c86b1952 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -183,137 +183,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre } }; - if(has_hot_loop) - { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got " - << tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Tail pipeline One to Seven - if(tail_num == ck_tile::TailNumber::One) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } -#endif - } - else - { - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - std::ostringstream err; - err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " - "got " - << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 067319b3f9..2a72c6325e 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -197,121 +197,7 @@ float grouped_gemm(const std::vector& gemm_descs, } }; - if(has_hot_loop) - { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got " - << tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Tail pipeline One to Seven - if(tail_num == ck_tile::TailNumber::One) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } -#endif - } - else - { - std::ostringstream err; - err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " - << "got " << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); return ave_time; } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index d0ad97c800..f57600d7a5 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -252,60 +252,13 @@ struct GroupedGemmKernel : public GemmKernel( - c_block_window, c_block_tile, smem_ptr_0); - }; - - if constexpr(is_specialization_of::value) - { - // Run the specific implementation with hotloop+tailnum config - using PipelineImpl = - typename GemmPipeline::template PipelineImpl; - const auto PassThrough = [](const auto& a) { return a; }; - if(has_hot_loop && tail_num == TailNumber::Full) - { - const auto& c_block_tile = - PipelineImpl{}.template operator()(a_block_window, - PassThrough, - b_block_window, - PassThrough, - num_loop, - smem_ptr_0); - RunEpilogue(c_block_tile); - } - else if(has_hot_loop && tail_num == TailNumber::Odd) - { - const auto& c_block_tile = - PipelineImpl{}.template operator()(a_block_window, - PassThrough, - b_block_window, - PassThrough, - num_loop, - smem_ptr_0); - RunEpilogue(c_block_tile); - } - else if(has_hot_loop && tail_num == TailNumber::Even) - { - const auto& c_block_tile = - PipelineImpl{}.template operator()(a_block_window, - PassThrough, - b_block_window, - PassThrough, - num_loop, - smem_ptr_0); - RunEpilogue(c_block_tile); - } - } - else - { - ignore = a_block_window; - ignore = b_block_window; - static_assert(false, "GemmPipeline specialization not supported!"); - } + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I2); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr_0); } CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 90cd22429e..a6267e4c89 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -50,6 +50,50 @@ struct BaseGemmPipelineAgBgCrCompV3 } } } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Handle all the valid cases. + if(has_hot_loop) + { + if(tail_number == TailNumber::Full) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + else + { + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } +#if defined(__HIP_DEVICE_COMPILE__) + // This path should be unreachable in device code if tail_number is valid. + __builtin_unreachable(); +#else + // If execution reaches here, it's an invalid combination of arguments. + if(has_hot_loop) + { + throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must " + "be TailNumber::Full."); + } + else + { + throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must " + "be TailNumber::Odd or TailNumber::Even."); + } +#endif + } }; // Compute optimized pipeline @@ -556,6 +600,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 p_smem); } + /** + * @brief This function runs the pipeline by wrapping it with the tail handler. + * + * @note This is used by the persistent gemm kernel variants that don't determine + * hot loop and tail number on the host side, e.g. grouped gemm kernel. + */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + constexpr bool hot_loop = hot_loop_.value; + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](const auto& x) { return x; }; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_dram_block_window_tmp, + PassThrough, + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + + /** + * @brief This function runs the pipeline using compile-time known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ template CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 6535f612f1..6fc6ba2ba2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -34,6 +34,46 @@ struct BaseGemmPipelineAgBgCrCompV4 return TailNumber::Two; } } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Handle all the valid cases. + if(has_hot_loop) + { + if(tail_number == TailNumber::Three) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Two) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + else + { + if(tail_number == TailNumber::Three) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Two) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + // If execution reaches here, it's an invalid tail_number because it wasn't handled above. +#if defined(__HIP_DEVICE_COMPILE__) + __builtin_unreachable(); +#else + throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than " + "PrefetchStages are supported."); +#endif + } }; /** @@ -572,5 +612,30 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 p_smem_0, p_smem_1); } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + constexpr bool hot_loop = hot_loop_.value; + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](const auto& x) { return x; }; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_dram_block_window_tmp, + PassThrough, + num_loop, + p_smem_0, + p_smem_1); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index abf5b617ee..f7b5f9b3cb 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -52,13 +52,14 @@ struct BaseGemmPipelineAgBgCrMem static constexpr index_t LocalPrefillStages = 1; static constexpr index_t GlobalBufferNum = PrefetchStages; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; - CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; } - CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { if(num_loop % PrefetchStages == 1) { @@ -93,6 +94,56 @@ struct BaseGemmPipelineAgBgCrMem return TailNumber::Full; } } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Wrap the hot_loop dispatch first. + auto tail_dispatch = [&](auto tail_num_constant) { + if(has_hot_loop) + { + return run_func(bool_constant{}, tail_num_constant); + } + else + { + return run_func(bool_constant{}, tail_num_constant); + } + }; + +#define CHECK_TAIL_NUMBER(TAIL_NUMBER, PREFETCH_VALUE) \ + else if(tail_number == TailNumber::TAIL_NUMBER) \ + { \ + if constexpr(PrefetchStages > PREFETCH_VALUE) \ + { \ + return tail_dispatch(integral_constant{}); \ + } \ + } + // Handle all the valid cases. + if(tail_number == TailNumber::One) + { + return tail_dispatch(integral_constant{}); + } + else if(tail_number == TailNumber::Full) + { + return tail_dispatch(integral_constant{}); + } + CHECK_TAIL_NUMBER(Two, 2) + CHECK_TAIL_NUMBER(Three, 3) + CHECK_TAIL_NUMBER(Four, 4) + CHECK_TAIL_NUMBER(Five, 5) + CHECK_TAIL_NUMBER(Six, 6) + CHECK_TAIL_NUMBER(Seven, 7) +#undef CHECK_TAIL_NUMBER + + // We shouldn't get here unless we have a tail number larger than the prefetch stages. +#if defined(__HIP_DEVICE_COMPILE__) + __builtin_unreachable(); +#else + throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than " + "PrefetchStages are supported."); +#endif + } }; // Maximum Global Memory throughput pipeline with >=32KB data in fly @@ -749,6 +800,29 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem p_smem); } + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + constexpr bool hot_loop = hot_loop_.value; + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](const auto& x) { return x; }; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_dram_block_window_tmp, + PassThrough, + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + template CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 4633f23ded..cffa81d1c5 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -159,32 +159,7 @@ class TestCkTileBatchedGemm : public ::testing::Test } }; - if(has_hot_loop) - { - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "For compute pipeline tail number should always be Full, but have \"" - << tail_num << "\" which is not supported! PrefetchStages: " - << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - else - { - std::ostringstream err; - err << "Num K loop must be larger than number of prefetech stages." - << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } public: diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 1892aa0e31..b3146b5f8e 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -63,19 +63,6 @@ struct GemmPipelineTypeSelector using pipeline = ck_tile::GemmPipelineAgBgCrCompV4; }; -template -void try_run(ck_tile::TailNumber tn) -{ - if constexpr(Pipeline::PrefetchStages > static_cast(TN)) - { - if(tn == TN) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } -} - template class TestCkTileGemmPipeline : public ::testing::Test { @@ -240,90 +227,7 @@ class TestCkTileGemmPipeline : public ::testing::Test } }; - if(has_hot_loop) - { - if constexpr(PipelineType == GemmPipelineType::CompV3) - { - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "For compute pipeline tail number should always be Full, but have \"" - << tail_num << "\" which is not supported! PrefetchStages: " - << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - - if constexpr(PipelineType == GemmPipelineType::Mem) - { - // Tail pipeline One to Seven - if(tail_num == ck_tile::TailNumber::One) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - auto check_tail = [&](auto... TNs) { - (try_run(tail_num), ...); - }; - - check_tail( - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}); - } - - if constexpr(PipelineType == GemmPipelineType::CompV4) - { - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - } - else - { - // Tail number always Full - #PrefetchStages - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "When there's no hot loop, this tail number \"" << tail_num - << "\" is not supported! " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } public: diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index cdc2e4f090..382a32a7d9 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -192,32 +192,7 @@ class TestCkTileGroupedGemm : public ::testing::Test } }; - if(has_hot_loop) - { - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "For compute pipeline tail number should always be Full, but have \"" - << tail_num << "\" which is not supported! PrefetchStages: " - << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - else - { - std::ostringstream err; - err << "Num K loop must be larger than number of prefetech stages." - << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } template