diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index e6a2811918..b60a3b274b 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -12,6 +12,19 @@ #include "ck_tile/host.hpp" #include "gemm_utils.hpp" +template +void try_run(ck_tile::TailNumber tn) +{ + if constexpr(Pipeline::PrefetchStages > static_cast(TN)) + { + if(tn == TN) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +} + template {}, @@ -176,60 +188,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::integral_constant{}); } - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } + auto check_tail = [&](auto... TNs) { + (try_run(tail_num), ...); + }; + + check_tail(ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}); + #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { @@ -259,7 +228,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& else if(tail_num == ck_tile::TailNumber::Even) { RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + ck_tile::integral_constant{}); } else { diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 0329f16416..85742cb3de 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -63,6 +63,19 @@ struct GemmPipelineTypeSelector using pipeline = ck_tile::GemmPipelineAgBgCrCompV4; }; +template +void try_run(ck_tile::TailNumber tn) +{ + if constexpr(Pipeline::PrefetchStages > static_cast(TN)) + { + if(tn == TN) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +} + template class TestCkTileGemmPipeline : public ::testing::Test { @@ -251,60 +264,17 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::TailNumber::Full>{}); } - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } + auto check_tail = [&](auto... TNs) { + (try_run(tail_num), ...); + }; + + check_tail( + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}); } if constexpr(PipelineType == GemmPipelineType::CompV4) diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json index 66dbdafa11..53197ada6c 100644 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -19,7 +19,7 @@ "values": [256] }, "tile_k": { - "values": [64, 32] + "values": [32] }, "warp_m": { "values": [2] diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index a748c35feb..3839523e3d 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -37,7 +37,9 @@ DEFAULT_EPILOGUE = """ WarpTileM, WarpTileN, WarpTileK, - UniversalGemmProblem::TransposeC>>; + UniversalGemmProblem::TransposeC, + true, + memory_operation>>; """ CSHUFFLE_EPILOGUE = """ @@ -55,22 +57,23 @@ CSHUFFLE_EPILOGUE = """ WarpTileM, WarpTileN, WarpTileK, - UniversalGemmProblem::TransposeC>>; + UniversalGemmProblem::TransposeC, + memory_operation>>; """ HOT_LOOP_FALSE = """ if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else @@ -79,68 +82,43 @@ HOT_LOOP_FALSE = """ } """ RUN_MEM = """ - if(tail_num == ck_tile::TailNumber::One) - { - Run(ck_tile::bool_constant{}, + // Handle One and Full cases directly + if (tail_num == ck_tile::TailNumber::One) { + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, + } else if (tail_num == ck_tile::TailNumber::Full) { + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } + // Variadic call using fold expression + auto check_tail = [&](auto... TNs) { + (try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...); + }; - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - throw std::runtime_error("The tile number is wrong! It should not exceed the prefetch stage numbers"); - } + check_tail( + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{} + ); """ RUN_COMPV3 = """ if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else @@ -152,12 +130,12 @@ RUN_COMPV3 = """ RUN_COMPV4 = """ if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } """ @@ -347,6 +325,15 @@ namespace {group_name} {{ kPadM: bool, kPadN: bool, kPadK: bool) -> str: """Generate kernel struct template""" return f""" +template +void try_run(ck_tile::TailNumber tn) {{ + if constexpr (Pipeline::PrefetchStages > static_cast(TN)) {{ + if (tn == TN) {{ + RunSplitk(ck_tile::bool_constant{{}}, + ck_tile::integral_constant{{}}); + }} + }} +}} template {{}}); + }} else {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} + }}; + if(has_hot_loop) {{ {HOT_LOOP_TRUE[pipeline]} }} else {{ @@ -450,6 +452,7 @@ struct GemmKernel {{ return ave_time; }} + static std::string get_name() {{ return std::string("GemmKernel