From a180ebf3e72e79e5e4f330be560586cdbacbb8b9 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Fri, 8 Aug 2025 02:03:49 +0300 Subject: [PATCH] [CK_TILE] Enable persistent kernel and tail handler in tile_engine (#2300) * Enable persistent kernel in tile_engine and use tail handler * Fix formatting * Add persistent to default_config.json * Remove extra newlines and add persistent also to user config * Reduce instances from default_config.json * add persistent to benchmark.json and custom_ci_config.json * changed the config file to have few instances --------- Co-authored-by: Thomas Ning Co-authored-by: ThomasNing [ROCm/composable_kernel commit: 3c9400471dcd4b3f55d8f6b88b562bda63b75657] --- tile_engine/ops/gemm/codegen_utils.py | 89 ------------------- tile_engine/ops/gemm/configs/benchmark.json | 6 ++ .../ops/gemm/configs/custom_ci_config.json | 6 ++ .../ops/gemm/configs/default_config.json | 7 +- .../gemm/configs/user_provided_config.json | 6 ++ tile_engine/ops/gemm/gemm_host_api.hpp | 16 ++-- tile_engine/ops/gemm/gemm_instance_builder.py | 51 +++++------ tile_engine/ops/gemm/json_config.py | 4 + 8 files changed, 60 insertions(+), 125 deletions(-) diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 9ff76724cc..4a990f3309 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -65,93 +65,6 @@ CSHUFFLE_EPILOGUE = """ UniversalGemmProblem::TransposeC, memory_operation>>; """ -HOT_LOOP_FALSE = """ - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Num K loop must be larger than number of prefetech stages."); - } -""" -RUN_MEM = """ - // Handle One and Full cases directly - if (tail_num == ck_tile::TailNumber::One) { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } else if (tail_num == ck_tile::TailNumber::Full) { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - auto check_tail = [&](auto... TNs) { - ([&]{ - if constexpr(BaseGemmPipeline::PrefetchStages > static_cast(decltype(TNs)::value)) { - if(tail_num == decltype(TNs)::value) { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - }(), ...); - }; - - check_tail( - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{} - ); -""" - -RUN_COMPV3 = """ - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even."); - } -""" - -RUN_COMPV4 = """ - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } -""" - PIPELINE_MAP = { "mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"], @@ -172,8 +85,6 @@ SCHEDULER_MAP = { EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE} -HOT_LOOP_TRUE = {"mem": RUN_MEM, "compv3": RUN_COMPV3, "compv4": RUN_COMPV4} - def BOOL_MAP(b_): return {True: "true", False: "false"}[bool(b_)] diff --git a/tile_engine/ops/gemm/configs/benchmark.json b/tile_engine/ops/gemm/configs/benchmark.json index 1560698b77..def3ca4453 100644 --- a/tile_engine/ops/gemm/configs/benchmark.json +++ b/tile_engine/ops/gemm/configs/benchmark.json @@ -96,6 +96,12 @@ "values": [ false ] + }, + "persistent": { + "values": [ + false, + true + ] } } } \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/custom_ci_config.json b/tile_engine/ops/gemm/configs/custom_ci_config.json index 9187fb01eb..ca6c7230fd 100644 --- a/tile_engine/ops/gemm/configs/custom_ci_config.json +++ b/tile_engine/ops/gemm/configs/custom_ci_config.json @@ -77,6 +77,12 @@ "values": [ false ] + }, + "persistent": { + "values": [ + false, + true + ] } } } \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index 12a8ddd4b7..5bd51b809a 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -95,6 +95,11 @@ "values": [ false ] + }, + "persistent": { + "values": [ + false + ] } } -} \ No newline at end of file +} diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/configs/user_provided_config.json index 5761b39ada..76e194f6b9 100644 --- a/tile_engine/ops/gemm/configs/user_provided_config.json +++ b/tile_engine/ops/gemm/configs/user_provided_config.json @@ -82,6 +82,12 @@ "values": [ false ] + }, + "persistent": { + "values": [ + false, + true + ] } } } \ No newline at end of file diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp index 2c4af8955f..f28f5dd29c 100644 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -144,7 +144,8 @@ inline auto create_args(int argc, char* argv[]) .insert("pad_k", "false", "Whether pad or not in k direction. Possible values are true or false. Default is " - "false."); + "false.") + .insert("persistent", "false", "Whether to use persistent kernel. Default is false."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -208,12 +209,13 @@ void permute_vectors_i4x4_b(Tensor& tensor) auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser) { KernelTraits trait; - trait.pipeline = arg_parser.get_str("pipeline"); - trait.scheduler = arg_parser.get_str("scheduler"); - trait.epilogue = arg_parser.get_str("epilogue"); - trait.pad_m = arg_parser.get_bool("pad_m"); - trait.pad_n = arg_parser.get_bool("pad_n"); - trait.pad_k = arg_parser.get_bool("pad_k"); + trait.pipeline = arg_parser.get_str("pipeline"); + trait.scheduler = arg_parser.get_str("scheduler"); + trait.epilogue = arg_parser.get_str("epilogue"); + trait.pad_m = arg_parser.get_bool("pad_m"); + trait.pad_n = arg_parser.get_bool("pad_n"); + trait.pad_k = arg_parser.get_bool("pad_k"); + trait.persistent = arg_parser.get_bool("persistent"); bool structured_sparsity = arg_parser.get_bool("structured_sparsity"); diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 4a35a2bcd3..6d713bdcb8 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -15,16 +15,9 @@ from json_config import GemmConfig, RangeConfigParam from codegen_utils import ( DATA_TYPE_MAP, LAYOUT_MAP, - DEFAULT_EPILOGUE, - CSHUFFLE_EPILOGUE, - HOT_LOOP_FALSE, - RUN_MEM, - RUN_COMPV3, - RUN_COMPV4, PIPELINE_MAP, SCHEDULER_MAP, EPILOGUE_MAP, - HOT_LOOP_TRUE, BOOL_MAP, warp_tile_supported_combinations, trait_unsupported_combinations, @@ -114,7 +107,7 @@ class GemmCodeGenerator: def _generate_all_traits(self): """Generate all possible kernel traits names.""" - params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"] + params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k", "persistent"] # Generate all unique_combinations _unique = set( @@ -124,13 +117,14 @@ class GemmCodeGenerator: ) for combo in _unique: - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = combo current_combination = (pipeline, epilogue, scheduler) if current_combination not in trait_unsupported_combinations: trait_name = ( f"{pipeline}_{epilogue}_{scheduler}_" - f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}" + f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}_" + f"{BOOL_MAP(persistent)}" ) self.valid_trait_names.append(trait_name) else: @@ -189,7 +183,7 @@ using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]}; def _generate_trait_file(self, trait: str): """Generate a trait with all tile/warp combinations.""" - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_") + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = trait.split("_") filename = f"gemm_{trait}.hpp" content = f"""// SPDX-License-Identifier: MIT @@ -206,8 +200,7 @@ namespace {trait} {{ """ # Add template struct with configuration content += self._generate_kernel_struct( - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k - ) + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent) content += f"\n}} // namespace {trait}\n" (self.output_dir / filename).write_text(content) @@ -220,6 +213,7 @@ namespace {trait} {{ pad_m: str, pad_n: str, pad_k: str, + persistent: str, ) -> str: """Generate the code block of kernel struct""" return f""" @@ -229,9 +223,10 @@ template struct GemmKernel {{ - static constexpr bool kPadM = {pad_m}; - static constexpr bool kPadN = {pad_n}; - static constexpr bool kPadK = {pad_k}; + static constexpr bool kPadM = {pad_m}; + static constexpr bool kPadN = {pad_n}; + static constexpr bool kPadK = {pad_k}; + static constexpr bool kPersistent = {persistent}; static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ static constexpr bool permuteA = false; @@ -250,7 +245,6 @@ struct GemmKernel {{ permuteA, permuteB>; - using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + ALayout, BLayout, CLayout, TransposeC, + structured_sparsity, kPersistent>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -297,14 +292,14 @@ struct GemmKernel {{ using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) {{ throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); }} + constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'}; + if(stream.log_level_ > 0) {{ std::cout << "Launching kernel with args:" @@ -377,11 +372,7 @@ struct GemmKernel {{ }} }}; - if(has_hot_loop) {{ - {HOT_LOOP_TRUE[pipeline]} - }} else {{ - {HOT_LOOP_FALSE} - }} + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); return ave_time; }} @@ -395,7 +386,8 @@ struct GemmKernel {{ "{pad_k}" + "_" + "{pipeline}" + "_" + "{epilogue}" + "_" + - "{scheduler}"; + "{scheduler}" + "_" + + "{persistent}"; }} }}; """ @@ -673,6 +665,8 @@ struct KernelTraits bool pad_n; /// @brief Indicates whether padding is applied to the K dimension. bool pad_k; + /// @brief Indicates whether the kernel is persistent. + bool persistent; }; struct GemmDispatcher { @@ -773,7 +767,8 @@ private: trait.scheduler + "_" + (trait.pad_m ? "true" : "false") + "_" + (trait.pad_n ? "true" : "false") + "_" + - (trait.pad_k ? "true" : "false"); + (trait.pad_k ? "true" : "false") + "_" + + (trait.persistent ? "true" : "false"); } }; diff --git a/tile_engine/ops/gemm/json_config.py b/tile_engine/ops/gemm/json_config.py index 675a2052ef..04f2dd4890 100644 --- a/tile_engine/ops/gemm/json_config.py +++ b/tile_engine/ops/gemm/json_config.py @@ -107,6 +107,7 @@ class TraitConfig: pad_m: EnumConfigParam pad_n: EnumConfigParam pad_k: EnumConfigParam + persistent: EnumConfigParam @dataclass @@ -215,6 +216,9 @@ class GemmConfig: pad_k=EnumConfigParam( values=config_dict["trait_config"]["pad_k"]["values"] ), + persistent=EnumConfigParam( + values=config_dict["trait_config"]["persistent"]["values"] + ), ) return cls(