[CK_TILE] Enable persistent kernel and tail handler in tile_engine (#2300)

* Enable persistent kernel in tile_engine and use tail handler

* Fix formatting

* Add persistent to default_config.json

* Remove extra newlines and add persistent also to user config

* Reduce instances from default_config.json

* add persistent to benchmark.json and custom_ci_config.json

* changed the config file to have few instances

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: ThomasNing <thomasning@amd.com>

[ROCm/composable_kernel commit: 3c9400471d]
This commit is contained in:
Sami Remes
2025-08-08 02:03:49 +03:00
committed by GitHub
parent ed2293a87b
commit bc1c2b6e7b
8 changed files with 60 additions and 125 deletions

View File

@@ -15,16 +15,9 @@ from json_config import GemmConfig, RangeConfigParam
from codegen_utils import (
DATA_TYPE_MAP,
LAYOUT_MAP,
DEFAULT_EPILOGUE,
CSHUFFLE_EPILOGUE,
HOT_LOOP_FALSE,
RUN_MEM,
RUN_COMPV3,
RUN_COMPV4,
PIPELINE_MAP,
SCHEDULER_MAP,
EPILOGUE_MAP,
HOT_LOOP_TRUE,
BOOL_MAP,
warp_tile_supported_combinations,
trait_unsupported_combinations,
@@ -114,7 +107,7 @@ class GemmCodeGenerator:
def _generate_all_traits(self):
"""Generate all possible kernel traits names."""
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k", "persistent"]
# Generate all unique_combinations
_unique = set(
@@ -124,13 +117,14 @@ class GemmCodeGenerator:
)
for combo in _unique:
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = combo
current_combination = (pipeline, epilogue, scheduler)
if current_combination not in trait_unsupported_combinations:
trait_name = (
f"{pipeline}_{epilogue}_{scheduler}_"
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}_"
f"{BOOL_MAP(persistent)}"
)
self.valid_trait_names.append(trait_name)
else:
@@ -189,7 +183,7 @@ using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]};
def _generate_trait_file(self, trait: str):
"""Generate a trait with all tile/warp combinations."""
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_")
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = trait.split("_")
filename = f"gemm_{trait}.hpp"
content = f"""// SPDX-License-Identifier: MIT
@@ -206,8 +200,7 @@ namespace {trait} {{
"""
# Add template struct with configuration
content += self._generate_kernel_struct(
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
)
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent)
content += f"\n}} // namespace {trait}\n"
(self.output_dir / filename).write_text(content)
@@ -220,6 +213,7 @@ namespace {trait} {{
pad_m: str,
pad_n: str,
pad_k: str,
persistent: str,
) -> str:
"""Generate the code block of kernel struct"""
return f"""
@@ -229,9 +223,10 @@ template <int TileM, int TileN, int TileK,
int WarpTileM, int WarpTileN, int WarpTileK,
bool structured_sparsity>
struct GemmKernel {{
static constexpr bool kPadM = {pad_m};
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static constexpr bool kPadM = {pad_m};
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static constexpr bool kPersistent = {persistent};
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
static constexpr bool permuteA = false;
@@ -250,7 +245,6 @@ struct GemmKernel {{
permuteA,
permuteB>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
@@ -261,7 +255,8 @@ struct GemmKernel {{
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
ALayout, BLayout, CLayout, TransposeC,
structured_sparsity, kPersistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -297,14 +292,14 @@ struct GemmKernel {{
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'};
if(stream.log_level_ > 0)
{{
std::cout << "Launching kernel with args:"
@@ -377,11 +372,7 @@ struct GemmKernel {{
}}
}};
if(has_hot_loop) {{
{HOT_LOOP_TRUE[pipeline]}
}} else {{
{HOT_LOOP_FALSE}
}}
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}}
@@ -395,7 +386,8 @@ struct GemmKernel {{
"{pad_k}" + "_" +
"{pipeline}" + "_" +
"{epilogue}" + "_" +
"{scheduler}";
"{scheduler}" + "_" +
"{persistent}";
}}
}};
"""
@@ -673,6 +665,8 @@ struct KernelTraits
bool pad_n;
/// @brief Indicates whether padding is applied to the K dimension.
bool pad_k;
/// @brief Indicates whether the kernel is persistent.
bool persistent;
};
struct GemmDispatcher {
@@ -773,7 +767,8 @@ private:
trait.scheduler + "_" +
(trait.pad_m ? "true" : "false") + "_" +
(trait.pad_n ? "true" : "false") + "_" +
(trait.pad_k ? "true" : "false");
(trait.pad_k ? "true" : "false") + "_" +
(trait.persistent ? "true" : "false");
}
};