From 1ccecf9a11ef32574bb21138634fff414e7d1dfb Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Wed, 7 May 2025 10:59:36 +0000 Subject: [PATCH] add default config --- tile_engine/ops/gemm/CMakeLists.txt | 19 ++- tile_engine/ops/gemm/README.md | 10 +- tile_engine/ops/gemm/benchmark_gemm.hpp | 2 +- .../ops/gemm/configs/default_config.json | 81 +++++++++ .../ops/gemm/configs/gemm_problem.json | 22 +++ .../gemm/configs/instance_combination.json | 60 ------- .../ops/gemm/configs/user_provide_config.json | 81 +++++++++ tile_engine/ops/gemm/gemm_host_api.hpp | 78 ++++----- tile_engine/ops/gemm/gemm_instance_builder.py | 161 +++++++++++------- 9 files changed, 344 insertions(+), 170 deletions(-) create mode 100644 tile_engine/ops/gemm/configs/default_config.json create mode 100644 tile_engine/ops/gemm/configs/gemm_problem.json delete mode 100644 tile_engine/ops/gemm/configs/instance_combination.json create mode 100644 tile_engine/ops/gemm/configs/user_provide_config.json diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index bc613a931e..eb9a4c5c46 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -4,30 +4,39 @@ execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${CMAKE_CURRENT_BINARY_DIR} - --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + --problem_json ${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json + --use_default_config + --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json --list_blobs RESULT_VARIABLE ret ) set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + ${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json + ${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json + ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json ) if(ret AND NOT ret EQUAL 0) - 
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}") + message( FATAL_ERROR "Fail to list kernels via Python. ${ret}") endif() file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS) +message("test--------------------------------------------------") add_custom_command( OUTPUT ${GEMM_CODEGEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${CMAKE_CURRENT_BINARY_DIR} - --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + --problem_json ${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json + --use_default_config + --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json --gen_blobs DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt - ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + ${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json + ${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json + ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json ) set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm") diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index 0ac3629ba6..f30be86d33 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -1,10 +1,16 @@ # GEMM Matrix Multiplication -Use the files in this folder to generate and build applications that run Matrix multiplications using ck_tile programming based on the kernel parameters mentioned in the config file `./configs/instance_combination.json`. +Use the files in this folder to generate and build applications that run Matrix multiplications using ck_tile programming based on the kernel parameters mentioned in the config file. + +# Gemm Problem + +User needs to provide gemm problem such as datatype, layout in the config file. For reference please see `./configs/gemm_problem.json`. 
+ # Kernel Configurations -User needs to provide kernel configuration such as datatype, layout, tile size, warp size, padding, pipeline, scheduler and epilogue in the config file. For reference please see `./configs/instance_combination.json` +Users can provide kernel configurations such as tile size, warp size, padding, pipeline, scheduler and epilogue in the config file. For reference please see `./configs/user_provide_config.json`. The Tile engine also has a default kernel configuration, saved in `./configs/default_config.json`, which expands the range of kernel configurations. + ## Build ``` diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/benchmark_gemm.hpp index 72fca244f6..e9b8f08f10 100644 --- a/tile_engine/ops/gemm/benchmark_gemm.hpp +++ b/tile_engine/ops/gemm/benchmark_gemm.hpp @@ -57,7 +57,7 @@ class Profiler BLayout::name, CLayout::name}; - KernelInstance kernel_instance{environment_, description, problem, {-1.0f, -1.0f, -1.0f}}; + KernelInstance kernel_instance{description, problem, {-1.0f, -1.0f, -1.0f}}; float avg_time = Kernel::launch(args, s); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json new file mode 100644 index 0000000000..d726ba6add --- /dev/null +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -0,0 +1,81 @@ +{ + "tile_m": { + "values": [ + 256 + ] + }, + "tile_n": { + "values": [ + 256 + ] + }, + "tile_k": { + "values": [ + 64, + 32 + ] + }, + "warp_m": { + "values": [ + 2 + ] + }, + "warp_n": { + "values": [ + 2 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 32 + ] + }, + "warp_tile_n": { + "values": [ + 32 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + }, + "kPadM": { + "values": [ + false + ] + }, + "kPadN": { + "values": [ + false + ] + }, + "kPadK": { + "values": [ + false + ] + }, + "pipeline": { + "values": [ + "compv3", + "mem" + ] + }, + 
"scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "default", + "cshuffle" + ] + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/gemm_problem.json b/tile_engine/ops/gemm/configs/gemm_problem.json new file mode 100644 index 0000000000..6c9139bdca --- /dev/null +++ b/tile_engine/ops/gemm/configs/gemm_problem.json @@ -0,0 +1,22 @@ +{ + "layout_a": { + "values": [ + "r" + ] + }, + "layout_b": { + "values": [ + "c" + ] + }, + "layout_c": { + "values": [ + "r" + ] + }, + "datatype": { + "values": [ + "fp16" + ] + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json deleted file mode 100644 index e23df11500..0000000000 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ /dev/null @@ -1,60 +0,0 @@ -{ - - "layout_a": { - "values": ["r"] - }, - "layout_b": { - "values": ["c"] - }, - "layout_c": { - "values": ["r"] - }, - "datatype": { - "values": ["fp16"] - }, - "tile_m": { - "values": [256] - }, - "tile_n": { - "values": [256] - }, - "tile_k": { - "values": [64, 32] - }, - "warp_m": { - "values": [2] - }, - "warp_n": { - "values": [2] - }, - "warp_k": { - "values": [1] - }, - "warp_tile_m": { - "values": [32] - }, - "warp_tile_n": { - "values": [32] - }, - "warp_tile_k": { - "values": [16] - }, - "kPadM": { - "values": [false] - }, - "kPadN": { - "values": [false] - }, - "kPadK": { - "values": [false] - }, - "pipeline": { - "values": ["compv3", "mem"] - }, - "scheduler": { - "values": ["intrawave", "interwave"] - }, - "epilogue": { - "values": ["default", "cshuffle"] - } -} diff --git a/tile_engine/ops/gemm/configs/user_provide_config.json b/tile_engine/ops/gemm/configs/user_provide_config.json new file mode 100644 index 0000000000..76c2acb5c9 --- /dev/null +++ b/tile_engine/ops/gemm/configs/user_provide_config.json @@ -0,0 +1,81 @@ +{ + "tile_m": { + "values": [ + 256 + ] + }, + 
"tile_n": { + "values": [ + 256 + ] + }, + "tile_k": { + "values": [ + 64, + 32 + ] + }, + "warp_m": { + "values": [ + 2 + ] + }, + "warp_n": { + "values": [ + 2 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 32 + ] + }, + "warp_tile_n": { + "values": [ + 32 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + }, + "kPadM": { + "values": [ + false + ] + }, + "kPadN": { + "values": [ + false + ] + }, + "kPadK": { + "values": [ + false + ] + }, + "pipeline": { + "values": [ + "compv3", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "default", + "cshuffle" + ] + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp index 356a42bab7..7f1c518eba 100644 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -39,26 +39,25 @@ struct GemmProblem std::string dtype_a, dtype_b, dtype_acc, dtype_c; std::string layout_a, layout_b, layout_c; - std::string serialize() const + friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem) { - std::ostringstream oss; - oss << "{" - << "\"split_k\":" << split_k << "," - << "\"m\":" << m << "," - << "\"n\":" << n << "," - << "\"k\":" << k << "," - << "\"stride_a\":" << stride_a << "," - << "\"stride_b\":" << stride_b << "," - << "\"stride_c\":" << stride_c << "," - << "\"dtype_a\":\"" << dtype_a << "\"," - << "\"dtype_b\":\"" << dtype_b << "\"," - << "\"dtype_acc\":\"" << dtype_acc << "\"," - << "\"dtype_c\":\"" << dtype_c << "\"," - << "\"layout_a\":\"" << layout_a << "\"," - << "\"layout_b\":\"" << layout_b << "\"," - << "\"layout_c\":\"" << layout_c << "\"" - << "}"; - return oss.str(); + os << "{\n" + << " \"split_k\":" << problem.split_k << ",\n" + << " \"m\":" << problem.m << ",\n" + << " \"n\":" << problem.n << ",\n" + << " \"k\":" << problem.k << ",\n" + << " \"stride_a\":" << problem.stride_a << 
",\n" + << " \"stride_b\":" << problem.stride_b << ",\n" + << " \"stride_c\":" << problem.stride_c << ",\n" + << " \"dtype_a\":\"" << problem.dtype_a << "\",\n" + << " \"dtype_b\":\"" << problem.dtype_b << "\",\n" + << " \"dtype_acc\":\"" << problem.dtype_acc << "\",\n" + << " \"dtype_c\":\"" << problem.dtype_c << "\",\n" + << " \"layout_a\":\"" << problem.layout_a << "\",\n" + << " \"layout_b\":\"" << problem.layout_b << "\",\n" + << " \"layout_c\":\"" << problem.layout_c << "\"\n" + << "}"; + return os; } }; @@ -80,14 +79,15 @@ struct PerformanceResult return false; } - std::string serialize() const + friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result) { - std::ostringstream oss; - oss << "{" - << "\"latency(ms)\":" << latency << "," - << "\"tflops(TFlops)\":" << tflops << "," - << "\"bandwidth(GB/s)\":" << bandwidth << "}"; - return oss.str(); + os << "{\n" + << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency + << ",\n" + << " \"tflops(TFlops)\": " << result.tflops << ",\n" + << " \"bandwidth(GB/s)\": " << result.bandwidth << "\n" + << "}"; + return os; } }; @@ -96,20 +96,18 @@ struct Environment std::string rocm_version; std::string device_name; - std::string serialize() const + friend std::ostream& operator<<(std::ostream& os, const Environment& env) { - std::ostringstream oss; - oss << "{" - << "\"rocm_version\":\"" << rocm_version << "\"," - << "\"device_name\":\"" << device_name << "\"" - << "}"; - return oss.str(); + os << "{\n" + << " \"rocm_version\": \"" << env.rocm_version << "\",\n" + << " \"device_name\": \"" << env.device_name << "\"\n" + << "}"; + return os; } }; struct KernelInstance { - Environment env; std::string name; GemmProblem problem; PerformanceResult perf_result; @@ -122,11 +120,12 @@ struct KernelInstance friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) { - os << "{" - << "\"env\":" << obj.env.serialize() << "," - << "\"name\":\"" << obj.name << 
"\"," - << "\"problem\":" << obj.problem.serialize() << "," - << "\"perf_result\":" << obj.perf_result.serialize() << "}"; + os << "{\n" + << " \"name\": \"" << "{\n" + << obj.name << "\n}" << "\",\n" + << " \"problem\": \"" << obj.problem << "\",\n" + << " \"perf_result\": " << obj.perf_result << "\n" + << "}"; return os; } }; @@ -257,6 +256,7 @@ inline auto create_args(int argc, char* argv[]) "compv3", "The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.") .insert("scheduler", + "intrawave", "The type of pipeline. Possible values are compv3, compv4 or mem. Default is " "compv3.") .insert( diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 4a82dbe9be..5eedb91b2d 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -183,55 +183,76 @@ def BOOL_MAP(b_) -> str: return 'false' @dataclass -class GemmConfig: +class GemmProblem: def __init__(self, config_data): - self.matrix_cfg : Dict[str, Any] = {} - self.impl_cfg : Dict[str, Any] = {} + self.data : Dict[str, Any] = {} for key, value in config_data.items(): - if key in ["datatype", "layout_a", "layout_b", "layout_c"]: - self.matrix_cfg[key] = value - else: - self.impl_cfg[key] = value - + self.data[key] = value + @property def datatype(self) -> str: - return self.matrix_cfg["datatype"]["values"][0] + return self.data["datatype"]["values"][0] @property def layouts(self) -> List[str]: return [ - self.matrix_cfg["layout_a"]["values"][0], - self.matrix_cfg["layout_b"]["values"][0], - self.matrix_cfg["layout_c"]["values"][0] + self.data["layout_a"]["values"][0], + self.data["layout_b"]["values"][0], + self.data["layout_c"]["values"][0] ] -class GemmCodeGenerator: - def __init__(self, output_dir: str, config: GemmConfig): - self.output_dir = Path(output_dir) - if not self.output_dir.exists(): - self.output_dir.mkdir() +@dataclass +class GemmConfig: + def __init__(self, 
config_data): + self.data : Dict[str, Any] = {} + for key, value in config_data.items(): + self.data[key] = value + + +class GemmCodeGenerator: + def __init__(self, output_dir: str, problem: GemmProblem, use_default_config: bool, user_provide_config: Optional[GemmConfig] = None): + self.output_dir = Path(output_dir) + self.problem = problem + + self.config = [] + if use_default_config: + config_path = Path(__file__).resolve().parent / "configs" / "default_config.json" + with open(config_path, 'r') as json_file: + config_data = json.load(json_file) + default_config = GemmConfig(config_data) + self.config.append(default_config) + if user_provide_config is not None: + if not isinstance(user_provide_config, GemmConfig): + raise TypeError("user_provide_config must be a GemmConfig instance") + self.config.append(user_provide_config) + else: + if user_provide_config is None: + raise ValueError("user_provide_config must be provided when use_default_config=False") + if not isinstance(user_provide_config, GemmConfig): + raise TypeError("user_provide_config must be a GemmConfig instance") + self.config.append(user_provide_config) - self.config = config self.all_kernels = [] self.unique_configs = [] # Validate configurations - self._validate_config() + self._check_validate() - def _validate_config(self): - """Validate matrix and implementation configurations""" - # Matrix config validation + def _check_validate(self): + """Validate matrix problem and kernel configurations""" + # Matrix problem validation for param in ["datatype", "layout_a", "layout_b", "layout_c"]: - if len(self.config.matrix_cfg[param]["values"]) != 1: - raise ValueError(f"Matrix config {param} must have exactly one value") + if len(self.problem.data[param]["values"]) != 1: + raise ValueError(f"Matrix problem {param} must have exactly one value") - # Implementation traits validation + # kernel config validation required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k", "warp_tile_m", 
"warp_tile_n", "warp_tile_k", "pipeline", "epilogue", "scheduler", "kPadM", "kPadN", "kPadK"] - for param in required_params: - if not self.config.impl_cfg.get(param, {}).get("values"): - raise ValueError(f"Missing implementation parameter: {param}") + for config in self.config: + for param in required_params: + if not config.data.get(param, {}).get("values"): + raise ValueError(f"Missing kernel parameter: {param}") def list_all(self): """List all possible kernel configurations""" @@ -258,7 +279,12 @@ class GemmCodeGenerator: ] # Generate all unique_combinations - _unique = set(itertools.product(*[self.config.impl_cfg[p]["values"] for (p, _) in params])) + _unique = set( + itertools.product(*[ + [value for config in self.config for value in config.data[p]["values"]] + for (p, _) in params + ]) + ) for combo in _unique: config = {name: value for (_, name), value in zip(params, combo)} pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = config.values() @@ -280,12 +306,12 @@ class GemmCodeGenerator: def _generate_common_header(self): """Generate common header with datatypes and layout""" - ctype = self.config.datatype - atype = self.config.datatype - btype = self.config.datatype - if self.config.datatype in ['fp8', 'bf8']: + ctype = self.problem.datatype + atype = self.problem.datatype + btype = self.problem.datatype + if self.problem.datatype in ['fp8', 'bf8']: ctype = 'fp16' - elif self.config.datatype in ['int4']: + elif self.problem.datatype in ['int4']: atype = 'fp16' ctype = 'fp16' @@ -302,9 +328,9 @@ using AccDataType = float; using CDataType = {DATA_TYPE_MAP[ctype]}; // Layout configurations -using ALayout = {LAYOUT_MAP[self.config.layouts[0]]}; -using BLayout = {LAYOUT_MAP[self.config.layouts[1]]}; -using CLayout = {LAYOUT_MAP[self.config.layouts[2]]}; +using ALayout = {LAYOUT_MAP[self.problem.layouts[0]]}; +using BLayout = {LAYOUT_MAP[self.problem.layouts[1]]}; +using CLayout = {LAYOUT_MAP[self.problem.layouts[2]]}; """ @@ -504,19 +530,18 @@ struct 
GemmDispatcher { auto& kernel_map = get_kernel_map(); if(!kernel_map.empty()) return; \n""" - # Add tile/warp instantiations - tile_params = set(itertools.product( - self.config.impl_cfg["tile_m"]["values"], - self.config.impl_cfg["tile_n"]["values"], - self.config.impl_cfg["tile_k"]["values"], - self.config.impl_cfg["warp_m"]["values"], - self.config.impl_cfg["warp_n"]["values"], - self.config.impl_cfg["warp_k"]["values"], - self.config.impl_cfg["warp_tile_m"]["values"], - self.config.impl_cfg["warp_tile_n"]["values"], - self.config.impl_cfg["warp_tile_k"]["values"] - )) - + # Add tile/warp instantiations + tile_params = set( + itertools.product(*[ + [value for config in self.config + for value in config.data[param]["values"]] + for param in [ + "tile_m", "tile_n", "tile_k", + "warp_m", "warp_n", "warp_k", + "warp_tile_m", "warp_tile_n", "warp_tile_k" + ] + ]) + ) for group in self.all_kernels: content += f""" kernel_map["{group}"] = [](Profiler& profiler, @@ -576,27 +601,31 @@ private: (self.output_dir / "gemm_dispatcher.hpp").write_text(content) -def do_list_blobs(args, gemm_config): - generator = GemmCodeGenerator(args.working_path, gemm_config) +def do_list_blobs(args, gemm_problem, user_provide_config): + generator = GemmCodeGenerator(args.working_path, gemm_problem, args.use_default_config, user_provide_config) generator.list_all() -def do_gen_blobs(args, gemm_config): - generator = GemmCodeGenerator(args.working_path, gemm_config) +def do_gen_blobs(args, gemm_problem, user_provide_config): + generator = GemmCodeGenerator(args.working_path, gemm_problem, args.use_default_config, user_provide_config) generator.generate_all() def main(args): - # Read json file - with open(args.json, 'r') as json_file: + # Read problem json file + with open(args.problem_json, 'r') as json_file: + config_data = json.load(json_file) + gemm_problem = GemmProblem(config_data) + + # Read user provide json file + with open(args.config_json, 'r') as json_file: config_data = 
json.load(json_file) - gemm_config = GemmConfig(config_data) if args.list_blobs: - do_list_blobs(args, gemm_config) + do_list_blobs(args, gemm_problem, gemm_config) elif args.gen_blobs: - do_gen_blobs(args, gemm_config) + do_gen_blobs(args, gemm_problem, gemm_config) else: # If neither was specified, either do nothing or default to gen_blobs print("No mode specified (use --list_blobs or --gen_blobs). Generating by default...") @@ -610,16 +639,22 @@ if __name__ == "__main__": description="gen API for CK gemm kernel", ) parser.add_argument( - "-w", "--working_path", default="./", required=False, help="the path where all the blobs are going to be generated" + "-w", "--working_path", default="./", required=False, help="The path where all the blobs are going to be generated" ) parser.add_argument( - "-j", "--json", required=True, help="Path to the json which contains the kernel configurations" + "-pj", "--problem_json", required=True, help="Path to the json which defines gemm problem" ) parser.add_argument( - "-l", "--list_blobs", action = 'store_true', help="List all kernel to file" + "-u", "--use_default_config", action = 'store_true', help="Wether use default config json file to generate kernel instance or not" ) parser.add_argument( - "-g", "--gen_blobs", action = 'store_true', help="Generate all kernels into different files" + "-cj", "--config_json", required=True, help="Path to the json which contains the kernel configurations that user provide" + ) + parser.add_argument( + "-l", "--list_blobs", action = 'store_true', help="List all kernel instance to file" + ) + parser.add_argument( + "-g", "--gen_blobs", action = 'store_true', help="Generate all kernels instance into different files" ) args = parser.parse_args()