# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

import os
import json
from pathlib import Path
import importlib.util
import itertools
import logging


def _import_validation_utils():
    """Load ``gemm/gemm_validation_utils.py`` as a module via importlib.

    The file is looked up in the ``gemm`` directory that is a sibling of the
    directory containing this file.

    Returns:
        The executed module object.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(current_dir)

    # Load the module dynamically
    spec = importlib.util.spec_from_file_location(
        "validation_utils",
        os.path.join(parent_dir, "gemm", "gemm_validation_utils.py"),
    )
    validation_utils = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(validation_utils)

    return validation_utils


# Import validation functions
_validation_utils = _import_validation_utils()

is_tile_config_valid = _validation_utils.is_tile_config_valid
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
get_abc_layouts = _validation_utils.get_abc_layouts
get_abcd_layouts = _validation_utils.get_abcd_layouts
get_dtype_string = _validation_utils.get_dtype_string


class GemmKernelBuilder:
    """Enumerate, sample and emit CK Tile GEMM kernel instances.

    The builder reads a JSON config (``tile_config`` / ``trait_config``),
    enumerates valid (tile, trait) combinations, optionally down-samples them,
    and writes a kernel list, per-kernel C++ headers and a CMake include file
    into ``working_path``. ``kernel_name_prefix`` selects the GEMM family
    (e.g. ``gemm_universal``, ``batched_gemm``, ``mx_gemm``) and therefore which
    code templates are emitted.
    """

    # Tile-config keys in the order used for enumeration, for building the
    # tile-config dicts, and for the positional arguments of
    # ``_validate_tile_config``.
    _TILE_CONFIG_KEYS = (
        "tile_m",
        "tile_n",
        "tile_k",
        "warp_m",
        "warp_n",
        "warp_k",
        "warp_tile_m",
        "warp_tile_n",
        "warp_tile_k",
    )

    def __init__(
        self,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json=None,
        max_instances=None,
        seed=None,
        tier=None,
        manifest_path=None,
    ):
        """Store build parameters, create ``working_path`` and load the config.

        Args:
            kernel_name_prefix: GEMM family name; also the prefix of every
                generated kernel and file name.
            working_path: Output directory (created if missing).
            gpu_target: GPU architecture string, forwarded to tile validation
                and to the sampling seed.
            datatype: Input datatype key (e.g. ``fp16``, ``fp8``).
            layout: Layout string passed to ``get_abc_layouts`` /
                ``get_abcd_layouts``.
            config_json: Path to the JSON config. If it does not exist,
                ``self.config`` stays empty and a warning is logged.
            max_instances: Sampling budget; ``None`` disables sampling.
            seed: Base seed for sampling.
            tier: Tier label written to the sampling manifest.
            manifest_path: Where to write the sampling manifest (optional).
        """
        self.kernel_name_prefix = kernel_name_prefix
        self.working_path = Path(working_path)
        self.gpu_target = gpu_target
        self.datatype = datatype
        self.layout = layout
        self.config_json = config_json
        self.max_instances = max_instances
        self.seed = seed
        self.tier = tier
        self.manifest_path = manifest_path

        # Create working directory if it doesn't exist
        self.working_path.mkdir(parents=True, exist_ok=True)

        # Always define self.config so that later lookups fail with a KeyError
        # naming the missing section instead of an AttributeError.
        self.config = {}
        if config_json and os.path.exists(config_json):
            with open(config_json, "r") as f:
                self.config = json.load(f)
        elif config_json:
            logging.warning("Config file not found: %s", config_json)

    def _apply_sampling(self, kernel_list):
        """Apply RFC Sobol+LHS+maximin sampling.

        Returns ``kernel_list`` unchanged when no budget is set or when the list
        already fits within ``max_instances``; otherwise returns the sampled
        subset (in the order of the indices chosen by the sampler). If
        ``manifest_path`` is set, a manifest describing the selection is
        written as well.
        """
        if self.max_instances is None or len(kernel_list) <= self.max_instances:
            return kernel_list

        import sys

        # The ``sampling`` package lives two directories above this file.
        sampling_parent = os.path.join(os.path.dirname(__file__), "..", "..")
        if sampling_parent not in sys.path:
            sys.path.insert(0, sampling_parent)

        from sampling.sampler import sample_feasible_set
        from sampling.seed import make_seed
        from sampling.feasible_set import GEMM_AXES

        effective_seed = make_seed(
            self.seed, self.gpu_target, self.datatype, self.layout
        )

        # Flatten each (tile_config, trait_combo) pair into a single dict of
        # axis values for the sampler.
        flat_items = []
        for k in kernel_list:
            flat = dict(k["tile_config"])
            pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = (
                self._normalize_trait_combo(k["trait_combo"])
            )
            flat.update(
                {
                    "pipeline": pipeline,
                    "epilogue": epilogue,
                    "scheduler": scheduler,
                    "pad_m": pad_m,
                    "pad_n": pad_n,
                    "pad_k": pad_k,
                }
            )
            if self._uses_persistent_trait():
                flat["persistent"] = persistent
            flat_items.append(flat)

        selected, method, selected_indices = sample_feasible_set(
            flat_items,
            self.max_instances,
            effective_seed,
            GEMM_AXES,
        )

        kernel_list = [kernel_list[i] for i in selected_indices]

        if self.manifest_path:
            from sampling.manifest import write_manifest

            write_manifest(
                selected,
                self.manifest_path,
                self.kernel_name_prefix,
                self.datatype,
                self.layout,
                self.gpu_target,
                effective_seed,
                self.tier or "daily",
                method,
            )

        print(
            f"Sampled {len(kernel_list)} from feasible set "
            f"(budget={self.max_instances}, seed={effective_seed}, method={method})"
        )
        return kernel_list

    def _get_sampled_kernel_list(self):
        """Enumerate all valid (tile_config, trait_combo) pairs and apply sampling.

        Returns a list of dicts with keys: name, tile_config, trait_combo.
        Both _list_kernels and _generate_all_individual should use this to
        guarantee identical enumeration and sampling.

        Kernel names are produced by ``_format_kernel_name``, the same helper
        ``_generate_kernel_instance`` uses, so names in the kernel list match
        the generated headers. This includes prefixes such as ``batched_gemm``
        that leave the persistent flag out of the name.
        """
        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()

        kernel_list = []
        for tile_config in tile_configs:
            tile_args = [tile_config[key] for key in self._TILE_CONFIG_KEYS]
            for trait_combo in trait_combos:
                pipeline = self._normalize_trait_combo(trait_combo)[0]

                # Skip if this tile config is not valid for this specific pipeline
                if not self._validate_tile_config(*tile_args, pipeline):
                    continue

                kernel_list.append(
                    {
                        "name": self._format_kernel_name(trait_combo, tile_config),
                        "tile_config": tile_config,
                        "trait_combo": trait_combo,
                    }
                )

        return self._apply_sampling(kernel_list)

    def _uses_persistent_trait(self):
        """Return True if the persistent flag is part of names / trait keys."""
        return self.kernel_name_prefix != "batched_gemm"

    def _normalize_trait_combo(self, trait_combo):
        """Return ``trait_combo`` as a 7-tuple, appending persistent=False to 6-tuples.

        Raises:
            ValueError: if the combination has any other length.
        """
        if len(trait_combo) == 7:
            return trait_combo
        if len(trait_combo) == 6:
            return (*trait_combo, False)
        raise ValueError(f"Unexpected trait combination: {trait_combo}")

    def _format_tile_config_string(self, tile_config):
        """Format a tile config as ``MxNxK_WMxWNxWK_WTMxWTNxWTK``."""
        tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
        tile_str += (
            f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
        )
        tile_str += (
            f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
        )
        return tile_str

    def _format_trait_combo_string(self, trait_combo):
        """Join trait values with ``_`` (persistent only if the prefix uses it)."""
        pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = (
            self._normalize_trait_combo(trait_combo)
        )
        trait_parts = [
            pipeline,
            epilogue,
            scheduler,
            str(pad_m),
            str(pad_n),
            str(pad_k),
        ]
        if self._uses_persistent_trait():
            trait_parts.append(str(persistent))
        return "_".join(trait_parts)

    def _format_kernel_name(self, trait_combo, tile_config=None):
        """Build the canonical kernel name from traits and (optionally) tile config.

        Boolean-like trait values are capitalized (``True`` / ``False``).
        """
        pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = (
            self._normalize_trait_combo(trait_combo)
        )
        kernel_name = (
            f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_"
            f"{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_"
            f"{str(pad_k).capitalize()}"
        )
        if self._uses_persistent_trait():
            kernel_name += f"_{str(persistent).capitalize()}"
        if tile_config is not None:
            kernel_name += f"_{self._format_tile_config_string(tile_config)}"
        return kernel_name

    def _list_kernels(self):
        """Write kernel list to file for CMake to read (with comprehensive validation)

        Writes ``<prefix>_kernel_count.txt`` (the count) and
        ``<prefix>_kernel_list.txt`` (one ``name|tile_str|trait_str`` per line).
        """
        kernel_list = self._get_sampled_kernel_list()

        # Write kernel count
        with open(
            self.working_path / f"{self.kernel_name_prefix}_kernel_count.txt", "w"
        ) as f:
            f.write(str(len(kernel_list)))

        # Write kernel list
        with open(
            self.working_path / f"{self.kernel_name_prefix}_kernel_list.txt", "w"
        ) as f:
            for kernel in kernel_list:
                # Format: kernel_name|tile_config|trait_combo
                tile_str = self._format_tile_config_string(kernel["tile_config"])
                trait_str = self._format_trait_combo_string(kernel["trait_combo"])
                f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")

        print(f"Listed {len(kernel_list)} kernel configurations")

    def _get_tile_configs(self):
        """Get tile configurations for the current datatype and layout.

        ``tile_m`` / ``tile_n`` / ``tile_k`` entries without an explicit
        ``values`` list are expanded in place from ``min``/``max``/``step``.
        A tile configuration is kept if it validates for at least one pipeline.

        Returns:
            A list of dicts keyed by ``_TILE_CONFIG_KEYS``, in cartesian-product
            order (last key varies fastest).
        """
        tile_config = self.config["tile_config"]

        # Generate values in the config if default range is given
        for axis in ("tile_m", "tile_n", "tile_k"):
            axis_config = tile_config.get(axis)
            if axis_config.get("values") is None:
                axis_config["values"] = self._generate_values(
                    axis_config.get("min"),
                    axis_config.get("max"),
                    axis_config.get("step"),
                )

        # Get all possible values for each parameter
        axis_values = [
            tile_config.get(key).get("values") for key in self._TILE_CONFIG_KEYS
        ]

        pipelines = self.config["trait_config"].get("pipeline", {}).get("values", [])
        if not pipelines:
            # Fall back to each family's default pipeline.
            if self.kernel_name_prefix == "gemm_preshuffle":
                pipelines = ["preshufflev2"]
            elif self.kernel_name_prefix == "mx_gemm":
                pipelines = ["comp_async"]
            elif self.kernel_name_prefix in [
                "grouped_gemm_rowcolquant",
                "grouped_gemm_tensorquant",
                "gemm_rowcolquant",
                "gemm_tensor_quant",
            ]:
                pipelines = ["compv3"]
            elif self.kernel_name_prefix in [
                "gemm_universal",
                "gemm_multi_d",
                "gemm_multi_abd",
                "grouped_gemm",
                "batched_contraction",
                "batched_gemm",
            ]:
                pipelines = ["compv4"]

        # Generate all combinations; accept a tile if valid for any pipeline.
        configs = []
        for combo in itertools.product(*axis_values):
            if any(
                self._validate_tile_config(*combo, pipeline) for pipeline in pipelines
            ):
                configs.append(dict(zip(self._TILE_CONFIG_KEYS, combo)))

        return configs

    def _generate_values(self, min_val, max_val, step):
        """Generate a list of values from min to max (inclusive) with the given step.

        Raises:
            ValueError: if ``step`` is not positive while ``min_val <= max_val``
                (the loop would otherwise never terminate).
        """
        if min_val <= max_val and step <= 0:
            raise ValueError(
                f"step must be positive for range [{min_val}, {max_val}], got {step!r}"
            )
        values = []
        val = min_val
        while val <= max_val:
            values.append(val)
            val += step
        return values

    def _validate_tile_config(
        self,
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
        pipeline,
    ):
        """Validate that tile configuration is reasonable for ``pipeline``.

        Applies the preshuffle ``permute_n`` constraint when enabled in the
        config, then delegates to ``is_tile_config_valid``.
        """
        # Validate preshuffle specific constraints: with permute_n, the number
        # of warp tiles along N per warp must be even.
        if self.config.get("permute_n") is True:
            valid = (tile_n / warp_tile_n / warp_n) % 2 == 0
            if not valid:
                return False

        # Determine data types for validation
        a_datatype = self.datatype
        b_datatype = self.datatype
        c_datatype = self.datatype
        layout = self.layout

        # Special handling for certain data types: low-precision inputs
        # produce fp16 output.
        if self.datatype in ["fp4", "fp8", "bf8"]:
            c_datatype = "fp16"

        # Use the comprehensive validation function
        return is_tile_config_valid(
            tile_m,
            tile_n,
            tile_k,
            warp_m,
            warp_n,
            warp_k,
            warp_tile_m,
            warp_tile_n,
            warp_tile_k,
            a_datatype,
            b_datatype,
            c_datatype,
            pipeline,
            layout,
            self.gpu_target,
            self.kernel_name_prefix,
        )

    def _generate_trait_combinations(self):
        """Generate all supported trait combinations as 7-tuples.

        Each tuple is (pipeline, epilogue, scheduler, pad_m, pad_n, pad_k,
        persistent). Combinations rejected by ``is_trait_combination_valid``
        are dropped and logged at debug level.
        """
        trait_config = self.config["trait_config"]

        pipelines = trait_config.get("pipeline").get("values")
        epilogues = trait_config.get("epilogue").get("values")
        schedulers = trait_config.get("scheduler").get("values")
        pad_m_values = trait_config.get("pad_m").get("values")
        pad_n_values = trait_config.get("pad_n").get("values")
        pad_k_values = trait_config.get("pad_k").get("values")
        if self.kernel_name_prefix in ["gemm_rowcolquant", "batched_gemm"]:
            # Force disable persistent where it is unsupported or not part of
            # the trait key
            persistent_values = [False]
        else:
            persistent_values = trait_config.get("persistent").get("values")

        all_combinations = itertools.product(
            pipelines,
            epilogues,
            schedulers,
            pad_m_values,
            pad_n_values,
            pad_k_values,
            persistent_values,
        )

        # Filter out unsupported trait combinations
        combinations = []
        for combo in all_combinations:
            pipeline, epilogue, scheduler = combo[:3]
            if is_trait_combination_valid(
                pipeline, epilogue, scheduler, self.kernel_name_prefix
            ):
                combinations.append(combo)
            else:
                logging.debug(
                    "Skipping unsupported trait combination: %s-%s-%s",
                    pipeline,
                    epilogue,
                    scheduler,
                )

        return combinations

    def _generate_kernel_instance(self, tile_config, trait_combo):
        """Generate a single kernel instance header.

        Writes ``<prefix>_single_<name without prefix>.hpp`` into
        ``working_path``.

        Returns:
            ``(kernel_name, instance_code)``.

        Raises:
            ValueError: if ``kernel_name_prefix`` has no pipeline mapping here.
        """
        k_block_per_cu = self.config.get("k_block_per_cu", 1)

        (
            pipeline,
            epilogue,
            scheduler,
            pad_m,
            pad_n,
            pad_k,
            persistent,
        ) = self._normalize_trait_combo(trait_combo)

        kernel_name = self._format_kernel_name(trait_combo, tile_config)

        if self.kernel_name_prefix in [
            "gemm_universal",
            "gemm_multi_d",
            "gemm_multi_abd",
            "grouped_gemm",
            "batched_contraction",
            "batched_gemm",
        ]:
            # Map pipeline names to the correct pipeline implementation
            pipeline_impl_map = {
                "mem": "ck_tile::GemmPipelineAgBgCrMem",
                "compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
                "compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
            }
            # Map pipeline names to base pipeline for hot loop detection
            base_pipeline_map = {
                "mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
                "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
                "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
            }
        elif self.kernel_name_prefix == "gemm_preshuffle":
            # Map pipeline names to the correct pipeline implementation
            pipeline_impl_map = {
                "preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2",
            }
            # Map pipeline names to base pipeline for hot loop detection
            base_pipeline_map = {
                "preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
            }
        elif self.kernel_name_prefix == "mx_gemm":
            pipeline_impl_map = {
                "comp_async": "ck_tile::MXGemmPipelineAgBgCrCompAsync",
            }
            base_pipeline_map = {}
        else:
            raise ValueError(
                f"Unsupported kernel_name_prefix: {self.kernel_name_prefix}"
            )

        scheduler_type_map = {
            "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
            "interwave": "ck_tile::GemmPipelineScheduler::Interwave",
            "default": "ck_tile::GemmPipelineScheduler::Default",
        }

        instance_code = self.populate_kernel_header(kernel_name)
        instance_code += self.populate_kernel_dtype_layout()
        instance_code += self.populate_strut_begin(kernel_name)
        instance_code += self.populate_tile_config(tile_config)
        instance_code += self.populate_trait_config(trait_combo)
        instance_code += self.populate_initialization(base_pipeline_map, pipeline)
        instance_code += self.populate_launch(
            scheduler_type_map,
            scheduler,
            pipeline_impl_map,
            pipeline,
            epilogue,
            k_block_per_cu,
            persistent,
        )

        # Write into a file named after the kernel without its family prefix
        simplified_name = kernel_name
        if simplified_name.startswith(f"{self.kernel_name_prefix}_"):
            simplified_name = simplified_name[len(self.kernel_name_prefix) + 1 :]

        header_file = (
            self.working_path / f"{self.kernel_name_prefix}_single_{simplified_name}.hpp"
        )

        with open(header_file, "w") as f:
            f.write(instance_code)

        print(f"Generated {header_file}")
        return kernel_name, instance_code

    # NOTE(review): in this copy of the file, several C++ snippets below contain
    # bare `#include` lines and templates with no argument list (e.g.
    # `ck_tile::sequence,`, `ck_tile::GemmKernel;`). The `<...>` text appears to
    # have been lost; restore it from upstream before relying on the output.
    def populate_kernel_header(self, kernel_name):
        """Return the header preamble (pragma + includes) for ``kernel_name``."""
        instance_code = f"""// Generated kernel instance for {kernel_name}
#pragma once

#include 
#include 
#include 
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
"""
        if self.kernel_name_prefix == "grouped_gemm":
            instance_code += """#include 
#include 
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
"""
        elif self.kernel_name_prefix == "mx_gemm":
            instance_code += """#include "ck_tile/ops/gemm_mx.hpp"
"""
        return instance_code

    def populate_kernel_dtype_layout(self):
        """Return the datatype and layout ``using`` aliases.

        Raises:
            ValueError: if ``kernel_name_prefix`` has no layout mapping here.
        """
        # Determine accumulator type based on datatype
        acc_type = "float"

        # Determine output type
        c_type = self.datatype
        if self.datatype in ["fp4", "fp8", "bf8"]:
            c_type = "fp16"

        # Assign layouts based on self.layout
        if self.kernel_name_prefix == "gemm_multi_d":
            a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
        elif self.kernel_name_prefix in [
            "gemm_universal",
            "gemm_preshuffle",
            "grouped_gemm",
            "mx_gemm",
            "batched_gemm",
        ]:
            a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
        else:
            raise ValueError(
                f"Unsupported kernel_name_prefix for layouts: {self.kernel_name_prefix}"
            )

        instance_code = f"""
using ADataType = {get_dtype_string(self.datatype)};
using BDataType = {get_dtype_string(self.datatype)};
using AccDataType = {acc_type};
using CDataType = {get_dtype_string(c_type)};"""

        if self.kernel_name_prefix == "mx_gemm":
            instance_code += """
using ScaleType = ck_tile::e8m0_t;
using ScaleM = ck_tile::MXScalePointer;
using ScaleN = ck_tile::MXScalePointer;
using MxGemmHostArgs = ck_tile::MXGemmKernelArgs;"""

        if self.kernel_name_prefix == "gemm_multi_d":
            instance_code += f"""
using D0DataType = {get_dtype_string(self.datatype)};
using D1DataType = {get_dtype_string(self.datatype)};
using DsDataType = ck_tile::tuple;"""

        instance_code += f"""

using ALayout = {a_layout};
using BLayout = {b_layout};
using CLayout = {c_layout};
"""

        if self.kernel_name_prefix == "gemm_multi_d":
            # NOTE(review): self.elementwise_function is not set in __init__;
            # presumably assigned by the caller — confirm.
            instance_code += f"""
using D0Layout = {ds_layout[0]};
using D1Layout = {ds_layout[1]};
using DsLayout = ck_tile::tuple;
using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};"""

        return instance_code

    def populate_strut_begin(self, kernel_name):
        """Return the KERNEL_NAME constant and the opening of ``SelectedKernel``."""
        instance_code = f"""

// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";

// Wrapper for simplified launch interface
struct SelectedKernel {{
"""
        return instance_code

    def populate_tile_config(self, tile_config):
        """Return the tile-size ``static constexpr`` members."""
        instance_code = f"""// Tile configuration
    static constexpr ck_tile::index_t BlockSize = 256;
    static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
    static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
    static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
    static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
    static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
    static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
    static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
    static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
    static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};"""
        return instance_code

    def populate_trait_config(self, trait_combo):
        """Return the trait ``static constexpr bool`` members.

        Trait values equal to ``True`` or the string ``"true"`` become ``true``.
        """
        (
            pipeline,
            epilogue,
            scheduler,
            pad_m,
            pad_n,
            pad_k,
            persistent,
        ) = self._normalize_trait_combo(trait_combo)

        instance_code = f"""

    // Traits configurations
    static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"};
    static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"};
    static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"};
    static constexpr bool TransposeC = false;
    static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};"""

        if self.kernel_name_prefix in [
            "gemm_universal",
            "gemm_preshuffle",
            "grouped_gemm",
            "mx_gemm",
            "batched_gemm",
        ]:
            instance_code += f"""
    static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
    static constexpr bool UseStructuredSparsity = false;
    static constexpr ck_tile::index_t NumWaveGroups = 1;"""

        if self.kernel_name_prefix == "gemm_preshuffle":
            instance_code += f"""
    static constexpr bool Preshuffle = true;
    static constexpr bool PermuteN = {"true" if self.config.get("permute_n") else "false"};"""
        else:
            instance_code += """
    static constexpr bool Preshuffle = false;"""

        return instance_code

    def populate_initialization(self, base_pipeline_map, pipeline):
        """Return the tile shape, partitioner, traits and base-pipeline aliases."""
        # Tile Shape
        if self.kernel_name_prefix in ["gemm_multi_d", "batched_gemm"]:
            instance_code = """

    // Tile shape
    using TileShape = ck_tile::TileGemmShape<
        ck_tile::sequence,
        ck_tile::sequence,
        ck_tile::sequence>;"""
        elif self.kernel_name_prefix in [
            "gemm_universal",
            "gemm_preshuffle",
            "grouped_gemm",
            "mx_gemm",
        ]:
            # batched_gemm is handled by the branch above.
            instance_code = """

    // Tile shape
    using TileShape = ck_tile::TileGemmShape<
        ck_tile::sequence,
        ck_tile::sequence,
        ck_tile::sequence,
        false, false>;"""

        # Tile partitioner
        instance_code += """

    // Tile partitioner
    using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner;"""

        # Traits
        if self.kernel_name_prefix == "gemm_multi_d":
            instance_code += """

    // Traits
    using Traits = ck_tile::TileGemmTraits;"""
        elif self.kernel_name_prefix == "gemm_preshuffle":
            instance_code += """

    // Traits
    using Traits = ck_tile::TileGemmTraits;"""

        # Pipeline problem
        if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]:
            instance_code += """

    // Pipeline problem
    using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
        ADataType,
        BDataType,
        AccDataType,
        TileShape,
        Traits>;"""

        # Base pipeline for hot loop detection
        if self.kernel_name_prefix == "gemm_preshuffle":
            instance_code += f"""

    // Base pipeline for hot loop detection
    using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")};"""
        elif self.kernel_name_prefix == "gemm_multi_d":
            instance_code += f"""

    // Base pipeline for hot loop detection
    using BaseGemmPipeline = {base_pipeline_map.get(pipeline)};"""

        return instance_code

    def populate_launch(
        self,
        scheduler_type_map,
        scheduler,
        pipeline_impl_map,
        pipeline,
        epilogue,
        k_block_per_cu,
        persistent,
    ):
        """Return the ``launch`` function body and the closing of ``SelectedKernel``."""
        # Function Signature
        if self.kernel_name_prefix == "gemm_multi_d":
            instance_code = """

    // Launch function
    static float launch(const ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) {"""
        elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
            instance_code = """

    // Launch function
    static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {"""
        elif self.kernel_name_prefix == "batched_gemm":
            instance_code = """

    // Launch function
    static float launch(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& stream) {"""
        elif self.kernel_name_prefix == "grouped_gemm":
            instance_code = """

    // Launch function
    static float launch(const std::vector>& gemm_descs, const ck_tile::stream_config& stream, void* kargs_ptr) {"""
        elif self.kernel_name_prefix == "mx_gemm":
            instance_code = """

    // Launch function
    static float launch(const MxGemmHostArgs& args, const ck_tile::stream_config& stream) {"""

        if self.kernel_name_prefix in [
            "gemm_preshuffle",
            "gemm_multi_d",
            "batched_gemm",
        ]:
            # Scheduler initialization
            instance_code += f"""
        constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""

            # Problem Initialization
            if self.kernel_name_prefix == "gemm_preshuffle":
                instance_code += """
        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
            ADataType,
            BDataType,
            AccDataType,
            TileShape,
            ck_tile::TileGemmUniversalTraits,
            scheduler>;"""
            elif self.kernel_name_prefix == "batched_gemm":
                instance_code += """
        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
            kPadM, kPadN, kPadK, DoubleSmemBuffer,
            ALayout, BLayout, CLayout, TransposeC>;
        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
            ADataType,
            BDataType,
            AccDataType,
            TileShape,
            GemmUniversalTraits,
            scheduler>;"""
            elif self.kernel_name_prefix == "gemm_multi_d":
                instance_code += """
        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
            ADataType,
            BDataType,
            AccDataType,
            TileShape,
            ck_tile::TileGemmUniversalTraits,
            scheduler>;"""

            # GemmPipeline
            instance_code += f"""
        using GemmPipeline = {pipeline_impl_map.get(pipeline)};"""

        if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm", "mx_gemm"]:
            # Scheduler initialization
            instance_code += f"""
        constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""

            # UniversalGemmProblem
            instance_code += """
        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
            ADataType,
            BDataType,
            AccDataType,
            TileShape,
            ck_tile::TileGemmUniversalTraits,
            scheduler>;"""

            # GemmPipeline
            instance_code += f"""
        using GemmPipeline = {pipeline_impl_map.get(pipeline)};"""

        # Epilogue
        instance_code += self.populate_epilogue(epilogue)

        # Kernel type
        if self.kernel_name_prefix == "gemm_multi_d":
            instance_code += """

        // Kernel type
        using GemmKernelMultiD = ck_tile::GemmKernelMultiD;

        // Kernel arguments
        auto kargs = GemmKernelMultiD::MakeKernelArgs(args);

        if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
        }

        // Get grid and block sizes
        const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
        const dim3 blocks = GemmKernelMultiD::BlockSize();

        if(stream.log_level_ > 0) {
            std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }"""

            instance_code += f"""

        // Launch kernel
        constexpr int kBlockPerCu = {k_block_per_cu};
        float ave_time = ck_tile::launch_kernel(
            stream,
            ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));

        return ave_time;
    }}
}};
"""
        elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
            instance_code += f"""

        // Kernel type
        using GemmKernel = ck_tile::GemmKernel;

        // Kernel arguments
        auto kargs = GemmKernel::MakeKernelArgs(args);

        if (!GemmKernel::IsSupportedArgument(kargs)) {{
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
        }}

        // Get grid and block sizes
        const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
        const dim3 blocks = GemmKernel::BlockSize();

        if(stream.log_level_ > 0) {{
            std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
                      << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
                      << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
                      << std::endl;
        }}"""

            instance_code += f"""

        // Launch kernel
        constexpr int kBlockPerCu = {k_block_per_cu};
        float ave_time = ck_tile::launch_kernel(
            stream,
            ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs));

        return ave_time;
    }}
}};
"""
        elif self.kernel_name_prefix == "batched_gemm":
            instance_code += f"""

        // Kernel type
        using GemmKernel = ck_tile::BatchedGemmKernel;

        // Kernel arguments
        auto kargs = GemmKernel::MakeKernelArgs(args);

        if (!GemmKernel::IsSupportedArgument(kargs)) {{
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
        }}

        // Get grid and block sizes
        const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
        const dim3 blocks = GemmKernel::BlockSize();

        if(stream.log_level_ > 0) {{
            std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
                      << "shape: " << TileShape::GetName() << '\\n'
                      << "pipeline: " << GemmPipeline::GetName() << '\\n'
                      << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
                      << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
                      << std::endl;
        }}"""

            instance_code += f"""

        // Launch kernel
        constexpr int kBlockPerCu = {k_block_per_cu};
        float ave_time = ck_tile::launch_kernel(
            stream,
            ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs));

        return ave_time;
    }}
}};
"""
        elif self.kernel_name_prefix == "grouped_gemm":
            instance_code += f"""

        // Kernel type
        using Kernel = ck_tile::GroupedGemmKernel;

        // Kernel arguments
        auto kargs = Kernel::MakeKargs(gemm_descs);
        if(!Kernel::IsSupportedArgument(kargs)) {{
            throw std::runtime_error("Wrong! Arguments not supported! Skipping grouped gemm!");
        }}

        // Get grid and block sizes
        const dim3 grids = {"Kernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "dim3(kargs.empty() ? 0 : kargs.back().block_end, 1, 1)"};
        const dim3 blocks = Kernel::BlockSize();

        HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
                                            kargs.data(),
                                            kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
                                            hipMemcpyHostToDevice,
                                            stream.stream_id_));

        if(stream.log_level_ > 0) {{
            std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
                      << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
                      << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
                      << std::endl;
        }}

        // Launch kernel
        constexpr int kBlockPerCu = {k_block_per_cu};
        float ave_time = ck_tile::launch_kernel(
            stream,
            ck_tile::make_kernel(
                Kernel{{}},
                grids,
                blocks,
                0,
                ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
                kargs.size()));

        return ave_time;
    }}
}};
"""
        elif self.kernel_name_prefix == "mx_gemm":
            instance_code += f"""

        // Kernel type
        using Kernel = ck_tile::MXGemmKernel;

        // Kernel arguments
        auto kargs = args;
        if(!Kernel::Underlying::IsSupportedArgument(kargs)) {{
            throw std::runtime_error("Wrong! Arguments not supported! Skipping mx gemm!");
        }}

        // Get grid and block sizes
        const dim3 grids = Kernel::GridSize(kargs);
        const dim3 blocks = Kernel::BlockSize();

        if(stream.log_level_ > 0) {{
            std::cout << "Launching kernel: " << KERNEL_NAME
                      << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
                      << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
                      << std::endl;
        }}

        // Launch kernel
        constexpr int kBlockPerCu = {k_block_per_cu};
        float ave_time = ck_tile::launch_kernel(
            stream,
            ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs));

        return ave_time;
    }}
}};
"""

        return instance_code

    def populate_epilogue(self, epilogue):
        """Return the epilogue aliases for ``epilogue`` (``cshuffle`` or default).

        Raises:
            ValueError: for ``mx_gemm`` with a non-cshuffle epilogue.
        """
        instance_code = """

        // Epilogue
"""
        if epilogue == "cshuffle":
            if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
                instance_code += self.populate_cshuffle_gemm_universal()
            elif self.kernel_name_prefix == "batched_gemm":
                instance_code += self.populate_cshuffle_batched_gemm()
            elif self.kernel_name_prefix == "gemm_multi_d":
                instance_code += self.populate_cshuffle_gemm_multi_d()
            elif self.kernel_name_prefix == "gemm_preshuffle":
                instance_code += self.populate_cshuffle_gemm_preshuffle()
            elif self.kernel_name_prefix == "mx_gemm":
                instance_code += self.populate_cshuffle_mx_gemm()
        else:  # default epilogue
            if self.kernel_name_prefix in [
                "gemm_universal",
                "grouped_gemm",
                "batched_gemm",
            ]:
                instance_code += self.populate_default_gemm_universal()
            elif self.kernel_name_prefix == "gemm_multi_d":
                instance_code += self.populate_default_gemm_multi_d()
            elif self.kernel_name_prefix == "gemm_preshuffle":
                instance_code += self.populate_default_gemm_preshuffle()
            elif self.kernel_name_prefix == "mx_gemm":
                raise ValueError("MX GEMM currently supports only cshuffle epilogue")

        return instance_code

    def populate_cshuffle_gemm_universal(self):
        """Return the CShuffle epilogue aliases for gemm_universal / grouped_gemm."""
        instance_code = """
        using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
            ADataType,
            BDataType,
            ck_tile::tuple<>,  // DsDataType
            AccDataType,
            CDataType,
            ck_tile::tuple<>,  // DsLayout
            CLayout,
            ck_tile::element_wise::PassThrough,
            TileM,             // kM_
            TileN,             // kN_
            WarpPerBlock_M,    // MWave_
            WarpPerBlock_N,    // NWave_
            WarpTileM,         // MPerXdl_
            WarpTileN,         // NPerXdl_
            WarpTileK,         // KPerXdl_
            TransposeC,        // isCTransposed_
            NumWaveGroups>;    // kNumWaveGroups_

        using GemmEpilogue = ck_tile::CShuffleEpilogue;"""
        return instance_code

    def populate_cshuffle_batched_gemm(self):
        """Return the CShuffle epilogue aliases for batched_gemm."""
        instance_code = """
        using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
            ADataType,
            BDataType,
            ck_tile::tuple<>,  // DsDataType
            AccDataType,
            CDataType,
            ck_tile::tuple<>,  // DsLayout
            CLayout,
            ck_tile::element_wise::PassThrough,
            TilePartitioner::MPerBlock,
            TilePartitioner::NPerBlock,
            WarpPerBlock_M,
            WarpPerBlock_N,
            WarpTileM,
            WarpTileN,
            WarpTileK,
            UniversalGemmProblem::TransposeC>;

        using GemmEpilogue = ck_tile::CShuffleEpilogue;"""
        return instance_code

    def populate_cshuffle_gemm_multi_d(self):
        """Return the CShuffle epilogue aliases for gemm_multi_d."""
        instance_code = """
        using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
            ADataType,
            BDataType,
            DsDataType,
            AccDataType,
            CDataType,
            DsLayout,
            CLayout,
            ElementWiseFn,
            TileM,             // kM_
            TileN,             // kN_
            WarpPerBlock_M,    // MWave_
            WarpPerBlock_N,    // NWave_
            WarpTileM,         // MPerXdl_
            WarpTileN,         // NPerXdl_
            WarpTileK,         // KPerXdl_
            TransposeC>;       // isCTransposed_

        using GemmEpilogue = ck_tile::CShuffleEpilogue;"""
        return instance_code

    def populate_cshuffle_gemm_preshuffle(self):
        """Return the CShuffle epilogue aliases for gemm_preshuffle."""
        instance_code = """
        using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
            ADataType,
            BDataType,
            ck_tile::tuple<>,  // DsDataType
            AccDataType,
            CDataType,
            ck_tile::tuple<>,  // DsLayout
            CLayout,
            ck_tile::element_wise::PassThrough,
            TileM,             // kM_
            TileN,             // kN_
            WarpPerBlock_M,    // MWave_
            WarpPerBlock_N,    // NWave_
            WarpTileM,         // MPerXdl_
            WarpTileN,         // NPerXdl_
            WarpTileK,         // KPerXdl_
            TransposeC,        // isCTransposed_
            NumWaveGroups,     // kNumWaveGroups_
            false,             // FixedVectorSize_
            1,                 // VectorSizeC_
            PermuteN>;         // isPermuteN_

        using GemmEpilogue = ck_tile::CShuffleEpilogue;"""
        return instance_code

    def populate_cshuffle_mx_gemm(self):
        """Return the CShuffle epilogue aliases for mx_gemm."""
        instance_code = """
        using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
            ADataType,
            BDataType,
            ck_tile::tuple<>,  // DsDataType
            AccDataType,
            CDataType,
            ck_tile::tuple<>,  // DsLayout
            CLayout,
            ck_tile::element_wise::PassThrough,
            TilePartitioner::MPerBlock,
            TilePartitioner::NPerBlock,
            WarpPerBlock_M,
            WarpPerBlock_N,
            WarpTileM,
            WarpTileN,
            WarpTileK,
            TransposeC>;

        using GemmEpilogue = ck_tile::CShuffleEpilogue;"""
        return instance_code

    def populate_default_gemm_universal(self):
        """Return the default 2D epilogue aliases for universal/grouped/batched GEMM."""
        instance_code = """
        using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
            ADataType,
            BDataType,
            ck_tile::tuple<>,  // DsDataType
            AccDataType,
            CDataType,
            ck_tile::tuple<>,  // DsLayout
            CLayout,
            ck_tile::element_wise::PassThrough,
            TileM,             // kM_
            TileN,             // kN_
            kPadM,
            kPadN,
            WarpTileM,         // kMPerXdl_
            WarpTileN,         // kNPerXdl_
            WarpTileK,         // kKPerXdl_
            TransposeC>;       // isCTransposed_

        using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;"""
        return instance_code

    def populate_default_gemm_multi_d(self):
        """Return the default 2D epilogue aliases for gemm_multi_d."""
        instance_code = """
        using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
            ADataType,
            BDataType,
            DsDataType,
            AccDataType,
            CDataType,
            DsLayout,
            CLayout,
            ElementWiseFn,
            TileM,             // kM_
            TileN,             // kN_
            kPadM,
            kPadN,
            WarpTileM,         // kMPerXdl_
            WarpTileN,         // kNPerXdl_
            WarpTileK,         // kKPerXdl_
            TransposeC>;       // isCTransposed_

        using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;"""
        return instance_code

    def populate_default_gemm_preshuffle(self):
        """Return the default 2D epilogue aliases for gemm_preshuffle."""
        instance_code = """
        using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
            ADataType,
            BDataType,
            ck_tile::tuple<>,  // DsDataType
            AccDataType,
            CDataType,
            ck_tile::tuple<>,  // DsLayout
            CLayout,
            ck_tile::element_wise::PassThrough,
            TileM,             // kM_
            TileN,             // kN_
            kPadM,
            kPadN,
            WarpTileM,         // kMPerXdl_
            WarpTileN,         // kNPerXdl_
            WarpTileK,         // kKPerXdl_
            TransposeC>;       // isCTransposed_

        using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;"""
        return instance_code

    def _generate_cmake_individual_targets(self, kernel_list):
        """Generate CMake include file that creates individual targets.

        Args:
            kernel_list: Iterable of ``(kernel_name, trait_combo, tile_config)``
                tuples; the kernel name itself is not used.

        Writes ``<prefix>_individual_targets.cmake`` with one
        ``create_individual_<prefix>_target(...)`` call per kernel.
        """
        cmake_code = f"""# Generated CMake file for individual {self.kernel_name_prefix} targets
# Datatype: {self.datatype}, Layout: {self.layout}

"""

        for _kernel_name, trait_combo, tile_config in kernel_list:
            # Format tile config for CMake function
            tile_str = self._format_tile_config_string(tile_config)
            trait_str = self._format_trait_combo_string(trait_combo)

            cmake_code += f'create_individual_{self.kernel_name_prefix}_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'

        # Write CMake include file
        with open(
            self.working_path / f"{self.kernel_name_prefix}_individual_targets.cmake",
            "w",
        ) as f:
            f.write(cmake_code)