mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Partial Progress : Generate Single Kernel until trait config
This commit is contained in:
@@ -26,6 +26,9 @@ def _import_validation_utils():
|
||||
_validation_utils = _import_validation_utils()
|
||||
is_tile_config_valid = _validation_utils.is_tile_config_valid
|
||||
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
|
||||
get_abc_layouts = _validation_utils.get_abc_layouts
|
||||
get_abcd_layouts = _validation_utils.get_abcd_layouts
|
||||
get_dtype_string = _validation_utils.get_dtype_string
|
||||
|
||||
|
||||
class GemmKernelBuilder:
|
||||
@@ -44,7 +47,7 @@ class GemmKernelBuilder:
|
||||
with open(config_json, "r") as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
def list_kernels(self, kernel_name_prefix):
|
||||
def _list_kernels(self, kernel_name_prefix):
|
||||
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
|
||||
# Get configurations using comprehensive validation
|
||||
tile_configs = self._get_tile_configs(kernel_name_prefix)
|
||||
@@ -313,3 +316,376 @@ class GemmKernelBuilder:
|
||||
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
|
||||
)
|
||||
return combinations
|
||||
|
||||
def _generate_kernel_instance(self, kernel_name_prefix, tile_config, trait_combo):
|
||||
"""Generate a single kernel instance"""
|
||||
|
||||
k_block_per_cu = self.config.get("k_block_per_cu")
|
||||
if k_block_per_cu is None:
|
||||
k_block_per_cu = 1
|
||||
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"{kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
# Create tile configuration string
|
||||
tile_str = (
|
||||
f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
)
|
||||
tile_str += (
|
||||
f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
)
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
"""
|
||||
|
||||
if (
|
||||
kernel_name_prefix == "gemm_universal"
|
||||
or kernel_name_prefix == "gemm_multi_d"
|
||||
):
|
||||
# Map pipeline names to the correct pipeline implementation
|
||||
pipeline_impl_map = {
|
||||
"mem": "ck_tile::GemmPipelineAgBgCrMem",
|
||||
"compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
|
||||
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
|
||||
# Map pipeline names to base pipeline for hot loop detection
|
||||
base_pipeline_map = {
|
||||
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
|
||||
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
|
||||
# Map scheduler names to the correct enum values
|
||||
|
||||
elif kernel_name_prefix == "gemm_preshuffle":
|
||||
permute_n = self.config.get("permute_n")
|
||||
|
||||
# Map pipeline names to the correct pipeline implementation
|
||||
pipeline_impl_map = {
|
||||
"preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
# Map pipeline names to base pipeline for hot loop detection
|
||||
base_pipeline_map = {
|
||||
"preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
scheduler_type_map = {
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
||||
"default": "ck_tile::GemmPipelineScheduler::Default",
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
instance_code = self.populate_kernel_header(kernel_name)
|
||||
instance_code += self.populate_kernel_dtype_layout(kernel_name_prefix)
|
||||
instance_code += self.populate_strut_begin(kernel_name)
|
||||
instance_code += self.populate_tile_config(tile_config)
|
||||
instance_code += self.populate_trait_config(trait_combo)
|
||||
|
||||
# Write into a file
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith(f"{kernel_name_prefix}_"):
|
||||
simplified_name = simplified_name[len(kernel_name_prefix) + 1 :]
|
||||
|
||||
header_file = (
|
||||
self.working_path / f"{kernel_name_prefix}_single_{simplified_name}.hpp"
|
||||
)
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
print(f"Generated {header_file}")
|
||||
|
||||
def populate_kernel_header(self, kernel_name):
|
||||
instance_code = f"""// Generated kernel instance for {kernel_name}
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
"""
|
||||
return instance_code
|
||||
|
||||
def populate_kernel_dtype_layout(self, kernel_name_prefix):
|
||||
# Determine accumulator type based on datatype
|
||||
acc_type = "float"
|
||||
|
||||
# Determine output type
|
||||
c_type = self.datatype
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_type = "fp16"
|
||||
|
||||
# Determine layouts based on self.layout
|
||||
if kernel_name_prefix == "gemm_multi_d":
|
||||
a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
|
||||
elif (
|
||||
kernel_name_prefix == "gemm_universal"
|
||||
or kernel_name_prefix == "gemm_preshuffle"
|
||||
):
|
||||
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
|
||||
|
||||
instance_code = f"""
|
||||
using ADataType = {get_dtype_string(self.datatype)};
|
||||
using BDataType = {get_dtype_string(self.datatype)};
|
||||
using AccDataType = {acc_type};
|
||||
using CDataType = {get_dtype_string(c_type)};"""
|
||||
|
||||
if kernel_name_prefix == "gemm_multi_d":
|
||||
instance_code += f"""
|
||||
using D0DataType = {get_dtype_string(self.datatype)};
|
||||
using D1DataType = {get_dtype_string(self.datatype)};
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;"""
|
||||
|
||||
instance_code += f"""
|
||||
using ALayout = {a_layout};
|
||||
using BLayout = {b_layout};
|
||||
using CLayout = {c_layout};
|
||||
"""
|
||||
if kernel_name_prefix == "gemm_multi_d":
|
||||
instance_code += f"""
|
||||
using D0Layout = {ds_layout[0]};
|
||||
using D1Layout = {ds_layout[1]};
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
|
||||
using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};"""
|
||||
|
||||
return instance_code
|
||||
|
||||
def populate_strut_begin(self, kernel_name):
|
||||
instance_code = f"""
|
||||
// Kernel name for display
|
||||
constexpr const char* KERNEL_NAME = "{kernel_name}";
|
||||
|
||||
// Wrapper for simplified launch interface
|
||||
struct SelectedKernel {{
|
||||
"""
|
||||
return instance_code
|
||||
|
||||
def populate_tile_config(self, tile_config):
|
||||
instance_code = f"""// Tile configuration
|
||||
static constexpr ck_tile::index_t BlockSize = 256;
|
||||
static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
|
||||
static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
|
||||
static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
|
||||
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
|
||||
static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
|
||||
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};"""
|
||||
return instance_code
|
||||
|
||||
def populate_trait_config(self, trait_combo):
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
instance_code = f"""
|
||||
|
||||
// Traits configurations
|
||||
static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"};
|
||||
static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"};
|
||||
static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"};
|
||||
|
||||
"""
|
||||
return instance_code
|
||||
|
||||
|
||||
#########################################
|
||||
|
||||
'''
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
|
||||
static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"};
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
|
||||
// Tile shape
|
||||
using TileShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<TileM, TileN, TileK>,
|
||||
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
|
||||
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>,
|
||||
false, false>;
|
||||
|
||||
// Tile partitioner
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
|
||||
|
||||
// Traits
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, NumWaveGroups>;
|
||||
|
||||
// Pipeline problem
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
Traits>;
|
||||
|
||||
// Base pipeline for hot loop detection
|
||||
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
|
||||
|
||||
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
const ck_tile::index_t k_grain = args.k_batch * TileK;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{{0}};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
|
||||
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC,
|
||||
UseStructuredSparsity, UsePersistentKernel,
|
||||
NumWaveGroups, Preshuffle>,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
|
||||
|
||||
// Epilogue
|
||||
"""
|
||||
|
||||
# Add epilogue configuration based on type
|
||||
if epilogue == "cshuffle":
|
||||
instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock, // kM_
|
||||
TilePartitioner::NPerBlock, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
memory_operation, // MemoryOperation_
|
||||
NumWaveGroups>; // kNumWaveGroups_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
|
||||
"""
|
||||
else: # default epilogue
|
||||
instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock, // kM_
|
||||
TilePartitioner::NPerBlock, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;
|
||||
"""
|
||||
|
||||
instance_code += f"""
|
||||
|
||||
// Kernel type
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Make kernel arguments
|
||||
auto kargs = GemmKernel::MakeKernelArgs(args);
|
||||
|
||||
if (!GemmKernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
|
||||
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}
|
||||
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
|
||||
if(args.k_batch == 1) {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
}};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
'''
|
||||
|
||||
@@ -95,65 +95,53 @@ def main():
|
||||
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
|
||||
)
|
||||
|
||||
kernel_name_prefix = "gemm_universal"
|
||||
if args.list_kernels:
|
||||
builder.list_kernels("gemm_universal")
|
||||
builder._list_kernels(kernel_name_prefix)
|
||||
elif args.gen_single:
|
||||
# # Generate a single kernel file
|
||||
# if not args.kernel_name or not args.tile_config or not args.trait_combo:
|
||||
# parser.error(
|
||||
# "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
|
||||
# )
|
||||
# Generate a single kernel file input validation
|
||||
if not args.kernel_name or not args.tile_config or not args.trait_combo:
|
||||
parser.error(
|
||||
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
|
||||
)
|
||||
|
||||
# # Parse tile config
|
||||
# tile_parts = args.tile_config.split("_")
|
||||
# tile_dims = tile_parts[0].split("x")
|
||||
# warp_dims = tile_parts[1].split("x")
|
||||
# warp_tile_dims = tile_parts[2].split("x")
|
||||
# Parse tile config
|
||||
tile_parts = args.tile_config.split("_")
|
||||
tile_dims = tile_parts[0].split("x")
|
||||
warp_dims = tile_parts[1].split("x")
|
||||
warp_tile_dims = tile_parts[2].split("x")
|
||||
|
||||
# tile_config = {
|
||||
# "tile_m": int(tile_dims[0]),
|
||||
# "tile_n": int(tile_dims[1]),
|
||||
# "tile_k": int(tile_dims[2]),
|
||||
# "warp_m": int(warp_dims[0]),
|
||||
# "warp_n": int(warp_dims[1]),
|
||||
# "warp_k": int(warp_dims[2]),
|
||||
# "warp_tile_m": int(warp_tile_dims[0]),
|
||||
# "warp_tile_n": int(warp_tile_dims[1]),
|
||||
# "warp_tile_k": int(warp_tile_dims[2]),
|
||||
# }
|
||||
tile_config = {
|
||||
"tile_m": int(tile_dims[0]),
|
||||
"tile_n": int(tile_dims[1]),
|
||||
"tile_k": int(tile_dims[2]),
|
||||
"warp_m": int(warp_dims[0]),
|
||||
"warp_n": int(warp_dims[1]),
|
||||
"warp_k": int(warp_dims[2]),
|
||||
"warp_tile_m": int(warp_tile_dims[0]),
|
||||
"warp_tile_n": int(warp_tile_dims[1]),
|
||||
"warp_tile_k": int(warp_tile_dims[2]),
|
||||
}
|
||||
|
||||
# # Parse trait combo
|
||||
# trait_parts = args.trait_combo.split("_")
|
||||
# trait_combo = (
|
||||
# trait_parts[0], # pipeline
|
||||
# trait_parts[1], # epilogue
|
||||
# trait_parts[2], # scheduler
|
||||
# trait_parts[3] == "True", # pad_m
|
||||
# trait_parts[4] == "True", # pad_n
|
||||
# trait_parts[5] == "True", # pad_k
|
||||
# trait_parts[6] == "True", # persistent
|
||||
# )
|
||||
# Parse trait combo
|
||||
trait_parts = args.trait_combo.split("_")
|
||||
trait_combo = (
|
||||
trait_parts[0], # pipeline
|
||||
trait_parts[1], # epilogue
|
||||
trait_parts[2], # scheduler
|
||||
trait_parts[3] == "True", # pad_m
|
||||
trait_parts[4] == "True", # pad_n
|
||||
trait_parts[5] == "True", # pad_k
|
||||
trait_parts[6] == "True", # persistent
|
||||
)
|
||||
|
||||
# k_block_per_cu = builder.config.get("k_block_per_cu")
|
||||
# if k_block_per_cu is None:
|
||||
# k_block_per_cu = 1
|
||||
# Generate the kernel
|
||||
builder._generate_kernel_instance(
|
||||
kernel_name_prefix,
|
||||
tile_config,
|
||||
trait_combo,
|
||||
)
|
||||
|
||||
# # Generate the kernel
|
||||
# kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
# tile_config, trait_combo, k_block_per_cu
|
||||
# )
|
||||
|
||||
# # Write the file
|
||||
# simplified_name = kernel_name
|
||||
# if simplified_name.startswith("gemm_"):
|
||||
# simplified_name = simplified_name[5:]
|
||||
|
||||
# header_file = builder.working_path / f"gemm_single_{simplified_name}.hpp"
|
||||
# with open(header_file, "w") as f:
|
||||
# f.write(instance_code)
|
||||
|
||||
# print(f"Generated {header_file}")
|
||||
pass
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
# builder.run(args.num_workers)
|
||||
|
||||
Reference in New Issue
Block a user