[CK_Tile] Simplified Mem pipeline (#2159)

* simplify code

* compiled the code

* Simplified example and codegen for mem pipeline

* Reveting config and universal gemm example

* clang formatted

* remove comments

* clang formatted

* Add memory operation changes for defualt pipeline

* fix config file

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-05-07 18:37:31 -07:00
committed by GitHub
parent cb07ad84d5
commit c7b8e86e34
4 changed files with 107 additions and 165 deletions

View File

@@ -37,7 +37,9 @@ DEFAULT_EPILOGUE = """
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC>>;
UniversalGemmProblem::TransposeC,
true,
memory_operation>>;
"""
CSHUFFLE_EPILOGUE = """
@@ -55,22 +57,23 @@ CSHUFFLE_EPILOGUE = """
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC>>;
UniversalGemmProblem::TransposeC,
memory_operation>>;
"""
HOT_LOOP_FALSE = """
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<false>{},
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<false>{},
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<false>{},
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
@@ -79,68 +82,43 @@ HOT_LOOP_FALSE = """
}
"""
RUN_MEM = """
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
// Handle One and Full cases directly
if (tail_num == ck_tile::TailNumber::One) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
} else if (tail_num == ck_tile::TailNumber::Full) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
// Variadic call using fold expression
auto check_tail = [&](auto... TNs) {
(try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
};
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
throw std::runtime_error("The tile number is wrong! It should not exceed the prefetch stage numbers");
}
check_tail(
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
);
"""
RUN_COMPV3 = """
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
@@ -152,12 +130,12 @@ RUN_COMPV3 = """
RUN_COMPV4 = """
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
"""
@@ -347,6 +325,15 @@ namespace {group_name} {{
kPadM: bool, kPadN: bool, kPadK: bool) -> str:
"""Generate kernel struct template"""
return f"""
template <typename Pipeline, ck_tile::TailNumber TN>
void try_run(ck_tile::TailNumber tn) {{
if constexpr (Pipeline::PrefetchStages > static_cast<int>(TN)) {{
if (tn == TN) {{
RunSplitk(ck_tile::bool_constant<true>{{}},
ck_tile::integral_constant<ck_tile::TailNumber, TN>{{}});
}}
}}
}}
template <int TileM, int TileN, int TileK,
int WarpM, int WarpN, int WarpK,
int WarpTileM, int WarpTileN, int WarpTileK,
@@ -355,7 +342,7 @@ struct GemmKernel {{
static constexpr bool kPadM = {BOOL_MAP(kPadM)};
static constexpr bool kPadN = {BOOL_MAP(kPadN)};
static constexpr bool kPadK = {BOOL_MAP(kPadK)};
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
@@ -399,10 +386,11 @@ struct GemmKernel {{
float ave_time{{0}};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {SCHEDULER_MAP[scheduler]};
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -442,6 +430,20 @@ struct GemmKernel {{
}};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
if(args.k_batch == 1) {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}};
if(has_hot_loop) {{
{HOT_LOOP_TRUE[pipeline]}
}} else {{
@@ -450,6 +452,7 @@ struct GemmKernel {{
return ave_time;
}}
static std::string get_name() {{
return std::string("GemmKernel<Bllktile: ") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + ", " +
"WaveMap: " + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + ", " +