diff --git a/tile_engine/ops_new/gemm/gemm_instance_builder.py b/tile_engine/ops_new/gemm/gemm_instance_builder.py index 0f4f6184a3..6a5452f329 100644 --- a/tile_engine/ops_new/gemm/gemm_instance_builder.py +++ b/tile_engine/ops_new/gemm/gemm_instance_builder.py @@ -26,6 +26,9 @@ def _import_validation_utils(): _validation_utils = _import_validation_utils() is_tile_config_valid = _validation_utils.is_tile_config_valid is_trait_combination_valid = _validation_utils.is_trait_combination_valid +get_abc_layouts = _validation_utils.get_abc_layouts +get_abcd_layouts = _validation_utils.get_abcd_layouts +get_dtype_string = _validation_utils.get_dtype_string class GemmKernelBuilder: @@ -44,7 +47,7 @@ class GemmKernelBuilder: with open(config_json, "r") as f: self.config = json.load(f) - def list_kernels(self, kernel_name_prefix): + def _list_kernels(self, kernel_name_prefix): """Write kernel list to file for CMake to read (with comprehensive validation)""" # Get configurations using comprehensive validation tile_configs = self._get_tile_configs(kernel_name_prefix) @@ -313,3 +316,376 @@ class GemmKernelBuilder: f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" ) return combinations + + def _generate_kernel_instance(self, kernel_name_prefix, tile_config, trait_combo): + """Generate a single kernel instance""" + + k_block_per_cu = self.config.get("k_block_per_cu") + if k_block_per_cu is None: + k_block_per_cu = 1 + + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"{kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = ( + f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + ) + tile_str += ( + f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + ) + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + """ + + if ( + kernel_name_prefix == "gemm_universal" + or kernel_name_prefix == "gemm_multi_d" + ): + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "mem": "ck_tile::GemmPipelineAgBgCrMem", + "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", + } + + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", + "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", + } + + # Map scheduler names to the correct enum values + + elif kernel_name_prefix == "gemm_preshuffle": + permute_n = self.config.get("permute_n") + + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2", + } + + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2", + } + + scheduler_type_map = { + "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", + "interwave": "ck_tile::GemmPipelineScheduler::Interwave", + "default": "ck_tile::GemmPipelineScheduler::Default", + } + + """ + + instance_code = self.populate_kernel_header(kernel_name) + instance_code += self.populate_kernel_dtype_layout(kernel_name_prefix) + instance_code += self.populate_strut_begin(kernel_name) + instance_code += self.populate_tile_config(tile_config) + instance_code += self.populate_trait_config(trait_combo) + + # Write into a file + simplified_name = kernel_name + if simplified_name.startswith(f"{kernel_name_prefix}_"): + simplified_name = simplified_name[len(kernel_name_prefix) + 1 :] + + header_file = ( + self.working_path / f"{kernel_name_prefix}_single_{simplified_name}.hpp" + ) + with open(header_file, "w") as f: + f.write(instance_code) + + print(f"Generated {header_file}") + + def populate_kernel_header(self, kernel_name): + instance_code = f"""// Generated kernel instance for {kernel_name} +#pragma once + +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" +""" + return instance_code + + def populate_kernel_dtype_layout(self, kernel_name_prefix): + # Determine accumulator type based on datatype + acc_type = "float" + + # Determine output type + c_type = self.datatype + if self.datatype in ["fp8", "bf8"]: + c_type = "fp16" + + # Determine layouts based on self.layout + if kernel_name_prefix == "gemm_multi_d": + a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout) + elif ( + kernel_name_prefix == "gemm_universal" + or kernel_name_prefix == "gemm_preshuffle" + ): + a_layout, b_layout, c_layout = get_abc_layouts(self.layout) + + instance_code = f""" +using ADataType = {get_dtype_string(self.datatype)}; +using BDataType = {get_dtype_string(self.datatype)}; +using AccDataType = {acc_type}; +using CDataType = {get_dtype_string(c_type)};""" + + if kernel_name_prefix == "gemm_multi_d": + instance_code += f""" +using D0DataType = {get_dtype_string(self.datatype)}; +using D1DataType = {get_dtype_string(self.datatype)}; +using DsDataType = ck_tile::tuple;""" + + instance_code += f""" +using ALayout = {a_layout}; +using BLayout = {b_layout}; +using CLayout = {c_layout}; +""" + if kernel_name_prefix == "gemm_multi_d": + instance_code += f""" +using D0Layout = {ds_layout[0]}; +using D1Layout = {ds_layout[1]}; +using DsLayout = ck_tile::tuple; + +using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};""" + + return instance_code + + def populate_strut_begin(self, kernel_name): + instance_code = f""" +// Kernel name for display +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Wrapper for simplified launch interface +struct SelectedKernel {{ + """ + return instance_code + + def populate_tile_config(self, tile_config): + instance_code = f"""// Tile configuration + static constexpr ck_tile::index_t BlockSize = 256; + static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; + static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; + static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; + static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; + static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; + static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; + static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; + static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; + static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};""" + return instance_code + + def populate_trait_config(self, trait_combo): + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + instance_code = f""" + + // Traits configurations + static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"}; + static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"}; + static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"}; + + """ + return instance_code + + +######################################### + +''' + + + + + + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"}; + static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Preshuffle = false; + static constexpr ck_tile::index_t NumWaveGroups = 1; + + // Tile shape + using TileShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + false, false>; + + // Tile partitioner + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + + // Traits + using Traits = ck_tile::TileGemmTraits; + + // Pipeline problem + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + Traits>; + + // Base pipeline for hot loop detection + using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}; + + static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; + [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v>; + + using GemmPipeline = {pipeline_impl_map.get(pipeline)}; + + // Epilogue +""" + + # Add epilogue configuration based on type + if epilogue == "cshuffle": + instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation, // MemoryOperation_ + NumWaveGroups>; // kNumWaveGroups_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue; +""" + else: # default epilogue + instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; +""" + + instance_code += f""" + + // Kernel type + using GemmKernel = ck_tile::GemmKernel; + + // Make kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + // Get grid and block sizes + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' + << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }} + + // Launch kernel + constexpr int kBlockPerCu = {k_block_per_cu}; + ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ + if(args.k_batch == 1) {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} else {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} + }}; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; + }} +}}; +""" + +''' diff --git a/tile_engine/ops_new/gemm/gemm_universal/gemm_universal_instance_builder.py b/tile_engine/ops_new/gemm/gemm_universal/gemm_universal_instance_builder.py index 1bddbd1f70..819c1b94dd 100644 --- a/tile_engine/ops_new/gemm/gemm_universal/gemm_universal_instance_builder.py +++ b/tile_engine/ops_new/gemm/gemm_universal/gemm_universal_instance_builder.py @@ -95,65 +95,53 @@ def main(): args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json ) + kernel_name_prefix = "gemm_universal" if args.list_kernels: - builder.list_kernels("gemm_universal") + builder._list_kernels(kernel_name_prefix) elif args.gen_single: - # # Generate a single kernel file - # if not args.kernel_name or not args.tile_config or not args.trait_combo: - # parser.error( - # "--gen_single requires --kernel_name, --tile_config, and --trait_combo" - # ) + # Generate a single kernel file input validation + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) - # # Parse tile config - # tile_parts = args.tile_config.split("_") - # tile_dims = tile_parts[0].split("x") - # warp_dims = tile_parts[1].split("x") - # warp_tile_dims = tile_parts[2].split("x") + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") - # tile_config = { - # "tile_m": int(tile_dims[0]), - # "tile_n": int(tile_dims[1]), - # "tile_k": int(tile_dims[2]), - # "warp_m": int(warp_dims[0]), - # "warp_n": int(warp_dims[1]), - # "warp_k": int(warp_dims[2]), - # "warp_tile_m": int(warp_tile_dims[0]), - # "warp_tile_n": int(warp_tile_dims[1]), - # "warp_tile_k": int(warp_tile_dims[2]), - # } + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } - # # Parse trait combo - # trait_parts = args.trait_combo.split("_") - # trait_combo = ( - # trait_parts[0], # pipeline - # trait_parts[1], # epilogue - # trait_parts[2], # scheduler - # trait_parts[3] == "True", # pad_m - # trait_parts[4] == "True", # pad_n - # trait_parts[5] == "True", # pad_k - # trait_parts[6] == "True", # persistent - # ) + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3] == "True", # pad_m + trait_parts[4] == "True", # pad_n + trait_parts[5] == "True", # pad_k + trait_parts[6] == "True", # persistent + ) - # k_block_per_cu = builder.config.get("k_block_per_cu") - # if k_block_per_cu is None: - # k_block_per_cu = 1 + # Generate the kernel + builder._generate_kernel_instance( + kernel_name_prefix, + tile_config, + trait_combo, + ) - # # Generate the kernel - # kernel_name, instance_code = builder._generate_kernel_instance( - # tile_config, trait_combo, k_block_per_cu - # ) - - # # Write the file - # simplified_name = kernel_name - # if simplified_name.startswith("gemm_"): - # simplified_name = simplified_name[5:] - - # header_file = builder.working_path / f"gemm_single_{simplified_name}.hpp" - # with open(header_file, "w") as f: - # f.write(instance_code) - - # print(f"Generated {header_file}") - pass elif args.gen_all_individual: # Generate all individual kernel files # builder.run(args.num_workers)