mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Enable persistent kernel and tail handler in tile_engine (#2300)
* Enable persistent kernel in tile_engine and use tail handler * Fix formatting * Add persistent to default_config.json * Remove extra newlines and add persistent also to user config * Reduce instances from default_config.json * add persistent to benchmark.json and custom_ci_config.json * changed the config file to have few instances --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: ThomasNing <thomasning@amd.com>
This commit is contained in:
@@ -65,93 +65,6 @@ CSHUFFLE_EPILOGUE = """
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
"""
|
||||
HOT_LOOP_FALSE = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
|
||||
}
|
||||
"""
|
||||
RUN_MEM = """
|
||||
// Handle One and Full cases directly
|
||||
if (tail_num == ck_tile::TailNumber::One) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
} else if (tail_num == ck_tile::TailNumber::Full) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
([&]{
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > static_cast<int>(decltype(TNs)::value)) {
|
||||
if(tail_num == decltype(TNs)::value) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, decltype(TNs)::value>{});
|
||||
}
|
||||
}
|
||||
}(), ...);
|
||||
};
|
||||
|
||||
check_tail(
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
|
||||
);
|
||||
"""
|
||||
|
||||
RUN_COMPV3 = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
|
||||
}
|
||||
"""
|
||||
|
||||
RUN_COMPV4 = """
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
|
||||
@@ -172,8 +85,6 @@ SCHEDULER_MAP = {
|
||||
|
||||
EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
HOT_LOOP_TRUE = {"mem": RUN_MEM, "compv3": RUN_COMPV3, "compv4": RUN_COMPV4}
|
||||
|
||||
|
||||
def BOOL_MAP(b_):
|
||||
return {True: "true", False: "false"}[bool(b_)]
|
||||
|
||||
@@ -96,6 +96,12 @@
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -77,6 +77,12 @@
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,6 +95,11 @@
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +82,12 @@
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -144,7 +144,8 @@ inline auto create_args(int argc, char* argv[])
|
||||
.insert("pad_k",
|
||||
"false",
|
||||
"Whether pad or not in k direction. Possible values are true or false. Default is "
|
||||
"false.");
|
||||
"false.")
|
||||
.insert("persistent", "false", "Whether to use persistent kernel. Default is false.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -208,12 +209,13 @@ void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
KernelTraits trait;
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.pad_m = arg_parser.get_bool("pad_m");
|
||||
trait.pad_n = arg_parser.get_bool("pad_n");
|
||||
trait.pad_k = arg_parser.get_bool("pad_k");
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.pad_m = arg_parser.get_bool("pad_m");
|
||||
trait.pad_n = arg_parser.get_bool("pad_n");
|
||||
trait.pad_k = arg_parser.get_bool("pad_k");
|
||||
trait.persistent = arg_parser.get_bool("persistent");
|
||||
|
||||
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
|
||||
|
||||
|
||||
@@ -15,16 +15,9 @@ from json_config import GemmConfig, RangeConfigParam
|
||||
from codegen_utils import (
|
||||
DATA_TYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
DEFAULT_EPILOGUE,
|
||||
CSHUFFLE_EPILOGUE,
|
||||
HOT_LOOP_FALSE,
|
||||
RUN_MEM,
|
||||
RUN_COMPV3,
|
||||
RUN_COMPV4,
|
||||
PIPELINE_MAP,
|
||||
SCHEDULER_MAP,
|
||||
EPILOGUE_MAP,
|
||||
HOT_LOOP_TRUE,
|
||||
BOOL_MAP,
|
||||
warp_tile_supported_combinations,
|
||||
trait_unsupported_combinations,
|
||||
@@ -114,7 +107,7 @@ class GemmCodeGenerator:
|
||||
|
||||
def _generate_all_traits(self):
|
||||
"""Generate all possible kernel traits names."""
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k", "persistent"]
|
||||
|
||||
# Generate all unique_combinations
|
||||
_unique = set(
|
||||
@@ -124,13 +117,14 @@ class GemmCodeGenerator:
|
||||
)
|
||||
|
||||
for combo in _unique:
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = combo
|
||||
current_combination = (pipeline, epilogue, scheduler)
|
||||
|
||||
if current_combination not in trait_unsupported_combinations:
|
||||
trait_name = (
|
||||
f"{pipeline}_{epilogue}_{scheduler}_"
|
||||
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
|
||||
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}_"
|
||||
f"{BOOL_MAP(persistent)}"
|
||||
)
|
||||
self.valid_trait_names.append(trait_name)
|
||||
else:
|
||||
@@ -189,7 +183,7 @@ using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]};
|
||||
|
||||
def _generate_trait_file(self, trait: str):
|
||||
"""Generate a trait with all tile/warp combinations."""
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_")
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = trait.split("_")
|
||||
filename = f"gemm_{trait}.hpp"
|
||||
|
||||
content = f"""// SPDX-License-Identifier: MIT
|
||||
@@ -206,8 +200,7 @@ namespace {trait} {{
|
||||
"""
|
||||
# Add template struct with configuration
|
||||
content += self._generate_kernel_struct(
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
|
||||
)
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent)
|
||||
|
||||
content += f"\n}} // namespace {trait}\n"
|
||||
(self.output_dir / filename).write_text(content)
|
||||
@@ -220,6 +213,7 @@ namespace {trait} {{
|
||||
pad_m: str,
|
||||
pad_n: str,
|
||||
pad_k: str,
|
||||
persistent: str,
|
||||
) -> str:
|
||||
"""Generate the code block of kernel struct"""
|
||||
return f"""
|
||||
@@ -229,9 +223,10 @@ template <int TileM, int TileN, int TileK,
|
||||
int WarpTileM, int WarpTileN, int WarpTileK,
|
||||
bool structured_sparsity>
|
||||
struct GemmKernel {{
|
||||
static constexpr bool kPadM = {pad_m};
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
static constexpr bool kPadM = {pad_m};
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
static constexpr bool kPersistent = {persistent};
|
||||
|
||||
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
static constexpr bool permuteA = false;
|
||||
@@ -250,7 +245,6 @@ struct GemmKernel {{
|
||||
permuteA,
|
||||
permuteB>;
|
||||
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
TileParitionerGroupNum,
|
||||
@@ -261,7 +255,8 @@ struct GemmKernel {{
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
|
||||
ALayout, BLayout, CLayout, TransposeC,
|
||||
structured_sparsity, kPersistent>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
@@ -297,14 +292,14 @@ struct GemmKernel {{
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'};
|
||||
|
||||
if(stream.log_level_ > 0)
|
||||
{{
|
||||
std::cout << "Launching kernel with args:"
|
||||
@@ -377,11 +372,7 @@ struct GemmKernel {{
|
||||
}}
|
||||
}};
|
||||
|
||||
if(has_hot_loop) {{
|
||||
{HOT_LOOP_TRUE[pipeline]}
|
||||
}} else {{
|
||||
{HOT_LOOP_FALSE}
|
||||
}}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
@@ -395,7 +386,8 @@ struct GemmKernel {{
|
||||
"{pad_k}" + "_" +
|
||||
"{pipeline}" + "_" +
|
||||
"{epilogue}" + "_" +
|
||||
"{scheduler}";
|
||||
"{scheduler}" + "_" +
|
||||
"{persistent}";
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
@@ -673,6 +665,8 @@ struct KernelTraits
|
||||
bool pad_n;
|
||||
/// @brief Indicates whether padding is applied to the K dimension.
|
||||
bool pad_k;
|
||||
/// @brief Indicates whether the kernel is persistent.
|
||||
bool persistent;
|
||||
};
|
||||
|
||||
struct GemmDispatcher {
|
||||
@@ -773,7 +767,8 @@ private:
|
||||
trait.scheduler + "_" +
|
||||
(trait.pad_m ? "true" : "false") + "_" +
|
||||
(trait.pad_n ? "true" : "false") + "_" +
|
||||
(trait.pad_k ? "true" : "false");
|
||||
(trait.pad_k ? "true" : "false") + "_" +
|
||||
(trait.persistent ? "true" : "false");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -107,6 +107,7 @@ class TraitConfig:
|
||||
pad_m: EnumConfigParam
|
||||
pad_n: EnumConfigParam
|
||||
pad_k: EnumConfigParam
|
||||
persistent: EnumConfigParam
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -215,6 +216,9 @@ class GemmConfig:
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_k"]["values"]
|
||||
),
|
||||
persistent=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["persistent"]["values"]
|
||||
),
|
||||
)
|
||||
|
||||
return cls(
|
||||
|
||||
Reference in New Issue
Block a user