[CK_TILE] Enable persistent kernel and tail handler in tile_engine (#2300)

* Enable persistent kernel in tile_engine and use tail handler

* Fix formatting

* Add persistent to default_config.json

* Remove extra newlines and add persistent also to user config

* Reduce instances from default_config.json

* add persistent to benchmark.json and custom_ci_config.json

* changed the config file to have few instances

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: ThomasNing <thomasning@amd.com>
This commit is contained in:
Sami Remes
2025-08-08 02:03:49 +03:00
committed by GitHub
parent 5d6d236b25
commit 3c9400471d
8 changed files with 60 additions and 125 deletions

View File

@@ -65,93 +65,6 @@ CSHUFFLE_EPILOGUE = """
UniversalGemmProblem::TransposeC,
memory_operation>>;
"""
HOT_LOOP_FALSE = """
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
}
"""
RUN_MEM = """
// Handle One and Full cases directly
if (tail_num == ck_tile::TailNumber::One) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
} else if (tail_num == ck_tile::TailNumber::Full) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
auto check_tail = [&](auto... TNs) {
([&]{
if constexpr(BaseGemmPipeline::PrefetchStages > static_cast<int>(decltype(TNs)::value)) {
if(tail_num == decltype(TNs)::value) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, decltype(TNs)::value>{});
}
}
}(), ...);
};
check_tail(
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
);
"""
RUN_COMPV3 = """
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
}
"""
RUN_COMPV4 = """
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
"""
PIPELINE_MAP = {
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
@@ -172,8 +85,6 @@ SCHEDULER_MAP = {
EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
HOT_LOOP_TRUE = {"mem": RUN_MEM, "compv3": RUN_COMPV3, "compv4": RUN_COMPV4}
def BOOL_MAP(b_):
return {True: "true", False: "false"}[bool(b_)]

View File

@@ -96,6 +96,12 @@
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
}
}

View File

@@ -77,6 +77,12 @@
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
}
}

View File

@@ -95,6 +95,11 @@
"values": [
false
]
},
"persistent": {
"values": [
false
]
}
}
}
}

View File

@@ -82,6 +82,12 @@
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
}
}

View File

@@ -144,7 +144,8 @@ inline auto create_args(int argc, char* argv[])
.insert("pad_k",
"false",
"Whether pad or not in k direction. Possible values are true or false. Default is "
"false.");
"false.")
.insert("persistent", "false", "Whether to use persistent kernel. Default is false.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -208,12 +209,13 @@ void permute_vectors_i4x4_b(Tensor& tensor)
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
{
KernelTraits trait;
trait.pipeline = arg_parser.get_str("pipeline");
trait.scheduler = arg_parser.get_str("scheduler");
trait.epilogue = arg_parser.get_str("epilogue");
trait.pad_m = arg_parser.get_bool("pad_m");
trait.pad_n = arg_parser.get_bool("pad_n");
trait.pad_k = arg_parser.get_bool("pad_k");
trait.pipeline = arg_parser.get_str("pipeline");
trait.scheduler = arg_parser.get_str("scheduler");
trait.epilogue = arg_parser.get_str("epilogue");
trait.pad_m = arg_parser.get_bool("pad_m");
trait.pad_n = arg_parser.get_bool("pad_n");
trait.pad_k = arg_parser.get_bool("pad_k");
trait.persistent = arg_parser.get_bool("persistent");
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");

View File

@@ -15,16 +15,9 @@ from json_config import GemmConfig, RangeConfigParam
from codegen_utils import (
DATA_TYPE_MAP,
LAYOUT_MAP,
DEFAULT_EPILOGUE,
CSHUFFLE_EPILOGUE,
HOT_LOOP_FALSE,
RUN_MEM,
RUN_COMPV3,
RUN_COMPV4,
PIPELINE_MAP,
SCHEDULER_MAP,
EPILOGUE_MAP,
HOT_LOOP_TRUE,
BOOL_MAP,
warp_tile_supported_combinations,
trait_unsupported_combinations,
@@ -114,7 +107,7 @@ class GemmCodeGenerator:
def _generate_all_traits(self):
"""Generate all possible kernel traits names."""
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k", "persistent"]
# Generate all unique_combinations
_unique = set(
@@ -124,13 +117,14 @@ class GemmCodeGenerator:
)
for combo in _unique:
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = combo
current_combination = (pipeline, epilogue, scheduler)
if current_combination not in trait_unsupported_combinations:
trait_name = (
f"{pipeline}_{epilogue}_{scheduler}_"
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}_"
f"{BOOL_MAP(persistent)}"
)
self.valid_trait_names.append(trait_name)
else:
@@ -189,7 +183,7 @@ using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]};
def _generate_trait_file(self, trait: str):
"""Generate a trait with all tile/warp combinations."""
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_")
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = trait.split("_")
filename = f"gemm_{trait}.hpp"
content = f"""// SPDX-License-Identifier: MIT
@@ -206,8 +200,7 @@ namespace {trait} {{
"""
# Add template struct with configuration
content += self._generate_kernel_struct(
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
)
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent)
content += f"\n}} // namespace {trait}\n"
(self.output_dir / filename).write_text(content)
@@ -220,6 +213,7 @@ namespace {trait} {{
pad_m: str,
pad_n: str,
pad_k: str,
persistent: str,
) -> str:
"""Generate the code block of kernel struct"""
return f"""
@@ -229,9 +223,10 @@ template <int TileM, int TileN, int TileK,
int WarpTileM, int WarpTileN, int WarpTileK,
bool structured_sparsity>
struct GemmKernel {{
static constexpr bool kPadM = {pad_m};
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static constexpr bool kPadM = {pad_m};
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static constexpr bool kPersistent = {persistent};
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
static constexpr bool permuteA = false;
@@ -250,7 +245,6 @@ struct GemmKernel {{
permuteA,
permuteB>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
@@ -261,7 +255,8 @@ struct GemmKernel {{
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
ALayout, BLayout, CLayout, TransposeC,
structured_sparsity, kPersistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -297,14 +292,14 @@ struct GemmKernel {{
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'};
if(stream.log_level_ > 0)
{{
std::cout << "Launching kernel with args:"
@@ -377,11 +372,7 @@ struct GemmKernel {{
}}
}};
if(has_hot_loop) {{
{HOT_LOOP_TRUE[pipeline]}
}} else {{
{HOT_LOOP_FALSE}
}}
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}}
@@ -395,7 +386,8 @@ struct GemmKernel {{
"{pad_k}" + "_" +
"{pipeline}" + "_" +
"{epilogue}" + "_" +
"{scheduler}";
"{scheduler}" + "_" +
"{persistent}";
}}
}};
"""
@@ -673,6 +665,8 @@ struct KernelTraits
bool pad_n;
/// @brief Indicates whether padding is applied to the K dimension.
bool pad_k;
/// @brief Indicates whether the kernel is persistent.
bool persistent;
};
struct GemmDispatcher {
@@ -773,7 +767,8 @@ private:
trait.scheduler + "_" +
(trait.pad_m ? "true" : "false") + "_" +
(trait.pad_n ? "true" : "false") + "_" +
(trait.pad_k ? "true" : "false");
(trait.pad_k ? "true" : "false") + "_" +
(trait.persistent ? "true" : "false");
}
};

View File

@@ -107,6 +107,7 @@ class TraitConfig:
pad_m: EnumConfigParam
pad_n: EnumConfigParam
pad_k: EnumConfigParam
persistent: EnumConfigParam
@dataclass
@@ -215,6 +216,9 @@ class GemmConfig:
pad_k=EnumConfigParam(
values=config_dict["trait_config"]["pad_k"]["values"]
),
persistent=EnumConfigParam(
values=config_dict["trait_config"]["persistent"]["values"]
),
)
return cls(