mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
[CK_TILE] Enable persistent kernel and tail handler in tile_engine (#2300)
* Enable persistent kernel in tile_engine and use tail handler
* Fix formatting
* Add persistent to default_config.json
* Remove extra newlines and add persistent also to user config
* Reduce instances from default_config.json
* add persistent to benchmark.json and custom_ci_config.json
* changed the config file to have few instances
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: ThomasNing <thomasning@amd.com>
[ROCm/composable_kernel commit: 3c9400471d]
This commit is contained in:
@@ -15,16 +15,9 @@ from json_config import GemmConfig, RangeConfigParam
|
||||
from codegen_utils import (
|
||||
DATA_TYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
DEFAULT_EPILOGUE,
|
||||
CSHUFFLE_EPILOGUE,
|
||||
HOT_LOOP_FALSE,
|
||||
RUN_MEM,
|
||||
RUN_COMPV3,
|
||||
RUN_COMPV4,
|
||||
PIPELINE_MAP,
|
||||
SCHEDULER_MAP,
|
||||
EPILOGUE_MAP,
|
||||
HOT_LOOP_TRUE,
|
||||
BOOL_MAP,
|
||||
warp_tile_supported_combinations,
|
||||
trait_unsupported_combinations,
|
||||
@@ -114,7 +107,7 @@ class GemmCodeGenerator:
|
||||
|
||||
def _generate_all_traits(self):
|
||||
"""Generate all possible kernel traits names."""
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k", "persistent"]
|
||||
|
||||
# Generate all unique_combinations
|
||||
_unique = set(
|
||||
@@ -124,13 +117,14 @@ class GemmCodeGenerator:
|
||||
)
|
||||
|
||||
for combo in _unique:
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = combo
|
||||
current_combination = (pipeline, epilogue, scheduler)
|
||||
|
||||
if current_combination not in trait_unsupported_combinations:
|
||||
trait_name = (
|
||||
f"{pipeline}_{epilogue}_{scheduler}_"
|
||||
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
|
||||
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}_"
|
||||
f"{BOOL_MAP(persistent)}"
|
||||
)
|
||||
self.valid_trait_names.append(trait_name)
|
||||
else:
|
||||
@@ -189,7 +183,7 @@ using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]};
|
||||
|
||||
def _generate_trait_file(self, trait: str):
|
||||
"""Generate a trait with all tile/warp combinations."""
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_")
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = trait.split("_")
|
||||
filename = f"gemm_{trait}.hpp"
|
||||
|
||||
content = f"""// SPDX-License-Identifier: MIT
|
||||
@@ -206,8 +200,7 @@ namespace {trait} {{
|
||||
"""
|
||||
# Add template struct with configuration
|
||||
content += self._generate_kernel_struct(
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
|
||||
)
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent)
|
||||
|
||||
content += f"\n}} // namespace {trait}\n"
|
||||
(self.output_dir / filename).write_text(content)
|
||||
@@ -220,6 +213,7 @@ namespace {trait} {{
|
||||
pad_m: str,
|
||||
pad_n: str,
|
||||
pad_k: str,
|
||||
persistent: str,
|
||||
) -> str:
|
||||
"""Generate the code block of kernel struct"""
|
||||
return f"""
|
||||
@@ -229,9 +223,10 @@ template <int TileM, int TileN, int TileK,
|
||||
int WarpTileM, int WarpTileN, int WarpTileK,
|
||||
bool structured_sparsity>
|
||||
struct GemmKernel {{
|
||||
static constexpr bool kPadM = {pad_m};
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
static constexpr bool kPadM = {pad_m};
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
static constexpr bool kPersistent = {persistent};
|
||||
|
||||
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
static constexpr bool permuteA = false;
|
||||
@@ -250,7 +245,6 @@ struct GemmKernel {{
|
||||
permuteA,
|
||||
permuteB>;
|
||||
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
TileParitionerGroupNum,
|
||||
@@ -261,7 +255,8 @@ struct GemmKernel {{
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
|
||||
ALayout, BLayout, CLayout, TransposeC,
|
||||
structured_sparsity, kPersistent>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
@@ -297,14 +292,14 @@ struct GemmKernel {{
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'};
|
||||
|
||||
if(stream.log_level_ > 0)
|
||||
{{
|
||||
std::cout << "Launching kernel with args:"
|
||||
@@ -377,11 +372,7 @@ struct GemmKernel {{
|
||||
}}
|
||||
}};
|
||||
|
||||
if(has_hot_loop) {{
|
||||
{HOT_LOOP_TRUE[pipeline]}
|
||||
}} else {{
|
||||
{HOT_LOOP_FALSE}
|
||||
}}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
@@ -395,7 +386,8 @@ struct GemmKernel {{
|
||||
"{pad_k}" + "_" +
|
||||
"{pipeline}" + "_" +
|
||||
"{epilogue}" + "_" +
|
||||
"{scheduler}";
|
||||
"{scheduler}" + "_" +
|
||||
"{persistent}";
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
@@ -673,6 +665,8 @@ struct KernelTraits
|
||||
bool pad_n;
|
||||
/// @brief Indicates whether padding is applied to the K dimension.
|
||||
bool pad_k;
|
||||
/// @brief Indicates whether the kernel is persistent.
|
||||
bool persistent;
|
||||
};
|
||||
|
||||
struct GemmDispatcher {
|
||||
@@ -773,7 +767,8 @@ private:
|
||||
trait.scheduler + "_" +
|
||||
(trait.pad_m ? "true" : "false") + "_" +
|
||||
(trait.pad_n ? "true" : "false") + "_" +
|
||||
(trait.pad_k ? "true" : "false");
|
||||
(trait.pad_k ? "true" : "false") + "_" +
|
||||
(trait.persistent ? "true" : "false");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user