Partial Progress : Generate Single Kernel until trait config

This commit is contained in:
ThruptiRajLakshmanaGowda
2025-11-20 22:36:59 +00:00
parent d6db805e82
commit 2c4d0dd289
2 changed files with 417 additions and 53 deletions

View File

@@ -26,6 +26,9 @@ def _import_validation_utils():
_validation_utils = _import_validation_utils()
is_tile_config_valid = _validation_utils.is_tile_config_valid
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
get_abc_layouts = _validation_utils.get_abc_layouts
get_abcd_layouts = _validation_utils.get_abcd_layouts
get_dtype_string = _validation_utils.get_dtype_string
class GemmKernelBuilder:
@@ -44,7 +47,7 @@ class GemmKernelBuilder:
with open(config_json, "r") as f:
self.config = json.load(f)
def list_kernels(self, kernel_name_prefix):
def _list_kernels(self, kernel_name_prefix):
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
# Get configurations using comprehensive validation
tile_configs = self._get_tile_configs(kernel_name_prefix)
@@ -313,3 +316,376 @@ class GemmKernelBuilder:
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
)
return combinations
def _generate_kernel_instance(self, kernel_name_prefix, tile_config, trait_combo):
"""Generate a single kernel instance"""
k_block_per_cu = self.config.get("k_block_per_cu")
if k_block_per_cu is None:
k_block_per_cu = 1
(
pipeline,
epilogue,
scheduler,
pad_m,
pad_n,
pad_k,
persistent,
) = trait_combo
# Create kernel name with proper boolean capitalization
kernel_name = f"{kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
# Create tile configuration string
tile_str = (
f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
)
tile_str += (
f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
)
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
kernel_name += f"_{tile_str}"
"""
if (
kernel_name_prefix == "gemm_universal"
or kernel_name_prefix == "gemm_multi_d"
):
# Map pipeline names to the correct pipeline implementation
pipeline_impl_map = {
"mem": "ck_tile::GemmPipelineAgBgCrMem",
"compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
}
# Map pipeline names to base pipeline for hot loop detection
base_pipeline_map = {
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
}
# Map scheduler names to the correct enum values
elif kernel_name_prefix == "gemm_preshuffle":
permute_n = self.config.get("permute_n")
# Map pipeline names to the correct pipeline implementation
pipeline_impl_map = {
"preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2",
}
# Map pipeline names to base pipeline for hot loop detection
base_pipeline_map = {
"preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
}
scheduler_type_map = {
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
"default": "ck_tile::GemmPipelineScheduler::Default",
}
"""
instance_code = self.populate_kernel_header(kernel_name)
instance_code += self.populate_kernel_dtype_layout(kernel_name_prefix)
instance_code += self.populate_strut_begin(kernel_name)
instance_code += self.populate_tile_config(tile_config)
instance_code += self.populate_trait_config(trait_combo)
# Write into a file
simplified_name = kernel_name
if simplified_name.startswith(f"{kernel_name_prefix}_"):
simplified_name = simplified_name[len(kernel_name_prefix) + 1 :]
header_file = (
self.working_path / f"{kernel_name_prefix}_single_{simplified_name}.hpp"
)
with open(header_file, "w") as f:
f.write(instance_code)
print(f"Generated {header_file}")
def populate_kernel_header(self, kernel_name):
instance_code = f"""// Generated kernel instance for {kernel_name}
#pragma once
#include <cstdint>
#include <utility>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
"""
return instance_code
def populate_kernel_dtype_layout(self, kernel_name_prefix):
# Determine accumulator type based on datatype
acc_type = "float"
# Determine output type
c_type = self.datatype
if self.datatype in ["fp8", "bf8"]:
c_type = "fp16"
# Determine layouts based on self.layout
if kernel_name_prefix == "gemm_multi_d":
a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
elif (
kernel_name_prefix == "gemm_universal"
or kernel_name_prefix == "gemm_preshuffle"
):
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
instance_code = f"""
using ADataType = {get_dtype_string(self.datatype)};
using BDataType = {get_dtype_string(self.datatype)};
using AccDataType = {acc_type};
using CDataType = {get_dtype_string(c_type)};"""
if kernel_name_prefix == "gemm_multi_d":
instance_code += f"""
using D0DataType = {get_dtype_string(self.datatype)};
using D1DataType = {get_dtype_string(self.datatype)};
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;"""
instance_code += f"""
using ALayout = {a_layout};
using BLayout = {b_layout};
using CLayout = {c_layout};
"""
if kernel_name_prefix == "gemm_multi_d":
instance_code += f"""
using D0Layout = {ds_layout[0]};
using D1Layout = {ds_layout[1]};
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};"""
return instance_code
def populate_strut_begin(self, kernel_name):
instance_code = f"""
// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";
// Wrapper for simplified launch interface
struct SelectedKernel {{
"""
return instance_code
def populate_tile_config(self, tile_config):
instance_code = f"""// Tile configuration
static constexpr ck_tile::index_t BlockSize = 256;
static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};"""
return instance_code
def populate_trait_config(self, trait_combo):
(
pipeline,
epilogue,
scheduler,
pad_m,
pad_n,
pad_k,
persistent,
) = trait_combo
instance_code = f"""
// Traits configurations
static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"};
static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"};
static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"};
"""
return instance_code
#########################################
'''
static constexpr bool TransposeC = false;
static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"};
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Preshuffle = false;
static constexpr ck_tile::index_t NumWaveGroups = 1;
// Tile shape
using TileShape = ck_tile::TileGemmShape<
ck_tile::sequence<TileM, TileN, TileK>,
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>,
false, false>;
// Tile partitioner
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
// Traits
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, NumWaveGroups>;
// Pipeline problem
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
Traits>;
// Base pipeline for hot loop detection
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
const ck_tile::index_t k_grain = args.k_batch * TileK;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{{0}};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
// Epilogue
"""
# Add epilogue configuration based on type
if epilogue == "cshuffle":
instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock, // kM_
TilePartitioner::NPerBlock, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
memory_operation, // MemoryOperation_
NumWaveGroups>; // kNumWaveGroups_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
"""
else: # default epilogue
instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock, // kM_
TilePartitioner::NPerBlock, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;
"""
instance_code += f"""
// Kernel type
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Make kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
if(args.k_batch == 1) {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}}
}};
"""
'''

View File

@@ -95,65 +95,53 @@ def main():
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
)
kernel_name_prefix = "gemm_universal"
if args.list_kernels:
builder.list_kernels("gemm_universal")
builder._list_kernels(kernel_name_prefix)
elif args.gen_single:
# # Generate a single kernel file
# if not args.kernel_name or not args.tile_config or not args.trait_combo:
# parser.error(
# "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
# )
# Generate a single kernel file input validation
if not args.kernel_name or not args.tile_config or not args.trait_combo:
parser.error(
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
)
# # Parse tile config
# tile_parts = args.tile_config.split("_")
# tile_dims = tile_parts[0].split("x")
# warp_dims = tile_parts[1].split("x")
# warp_tile_dims = tile_parts[2].split("x")
# Parse tile config
tile_parts = args.tile_config.split("_")
tile_dims = tile_parts[0].split("x")
warp_dims = tile_parts[1].split("x")
warp_tile_dims = tile_parts[2].split("x")
# tile_config = {
# "tile_m": int(tile_dims[0]),
# "tile_n": int(tile_dims[1]),
# "tile_k": int(tile_dims[2]),
# "warp_m": int(warp_dims[0]),
# "warp_n": int(warp_dims[1]),
# "warp_k": int(warp_dims[2]),
# "warp_tile_m": int(warp_tile_dims[0]),
# "warp_tile_n": int(warp_tile_dims[1]),
# "warp_tile_k": int(warp_tile_dims[2]),
# }
tile_config = {
"tile_m": int(tile_dims[0]),
"tile_n": int(tile_dims[1]),
"tile_k": int(tile_dims[2]),
"warp_m": int(warp_dims[0]),
"warp_n": int(warp_dims[1]),
"warp_k": int(warp_dims[2]),
"warp_tile_m": int(warp_tile_dims[0]),
"warp_tile_n": int(warp_tile_dims[1]),
"warp_tile_k": int(warp_tile_dims[2]),
}
# # Parse trait combo
# trait_parts = args.trait_combo.split("_")
# trait_combo = (
# trait_parts[0], # pipeline
# trait_parts[1], # epilogue
# trait_parts[2], # scheduler
# trait_parts[3] == "True", # pad_m
# trait_parts[4] == "True", # pad_n
# trait_parts[5] == "True", # pad_k
# trait_parts[6] == "True", # persistent
# )
# Parse trait combo
trait_parts = args.trait_combo.split("_")
trait_combo = (
trait_parts[0], # pipeline
trait_parts[1], # epilogue
trait_parts[2], # scheduler
trait_parts[3] == "True", # pad_m
trait_parts[4] == "True", # pad_n
trait_parts[5] == "True", # pad_k
trait_parts[6] == "True", # persistent
)
# k_block_per_cu = builder.config.get("k_block_per_cu")
# if k_block_per_cu is None:
# k_block_per_cu = 1
# Generate the kernel
builder._generate_kernel_instance(
kernel_name_prefix,
tile_config,
trait_combo,
)
# # Generate the kernel
# kernel_name, instance_code = builder._generate_kernel_instance(
# tile_config, trait_combo, k_block_per_cu
# )
# # Write the file
# simplified_name = kernel_name
# if simplified_name.startswith("gemm_"):
# simplified_name = simplified_name[5:]
# header_file = builder.working_path / f"gemm_single_{simplified_name}.hpp"
# with open(header_file, "w") as f:
# f.write(instance_code)
# print(f"Generated {header_file}")
pass
elif args.gen_all_individual:
# Generate all individual kernel files
# builder.run(args.num_workers)