From f32ef6ed176b2f7659d453116ee6f9926a0fea38 Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Mon, 27 Oct 2025 21:11:13 -0500 Subject: [PATCH] Ck tile engine gemm (#2982) * Partial Progress : CK Tile Engine GEMM * Partial Progress : CK Tile Engine GEMM * Partial Progress : Working GEMM Code * Partial Progress : Working GEMM Code * Changing jenkins to remove preshuffle * Partial Progress : CK TILE ENGINE GEMM Debugging * Partial Progress : Removing changes that are not GEMM * Partial Progress : Validation of full block size in GEMM * Changes in Jenkins to run only fp16 and bf16 * Addressing Review Comments * Partial Progress : Addressing CI issues * Partial Progress - Running GEMM for fp16,bf16 and rcr * Clang * Adding fp8 and bf8 * Adding fp8 and bf8 * Adding additional architecture * Limited datatypes and layouts * Adding k_block_per_cu in test config * Changes to failing CI errors * Changes to failing CI errors * Validation for GEMM * Adding Layout support * Adding Validations * Adding layout in jenkins * Update on Jenkins * Distribution validation for GEMM * Resolving merge conflicts * Solving merge conflicts [ROCm/composable_kernel commit: 7fc0a38e903e265374a82d02a19db697290d48fb] --- Jenkinsfile | 1 - tile_engine/ops/gemm/CMakeLists.txt | 14 +- tile_engine/ops/gemm/README.md | 2 +- .../ops/gemm/{ => commons}/test_benchmark.sh | 0 .../ops/gemm/{ => commons}/test_validation.py | 0 .../gemm/{ => commons}/validation_utils.py | 250 ++++++++ tile_engine/ops/gemm/configs/benchmark.json | 105 ---- .../ops/gemm/configs/custom_ci_config.json | 88 --- .../ops/gemm/configs/default_config.json | 5 +- .../ops/gemm/configs/gfx120x_config.json | 102 ---- .../gemm/configs/user_provided_config.json | 30 +- ...{benchmark_gemm.hpp => gemm_benchmark.hpp} | 0 tile_engine/ops/gemm/gemm_benchmark.py | 42 -- ...m_single.cpp => gemm_benchmark_single.cpp} | 8 +- tile_engine/ops/gemm/gemm_common.hpp | 47 +- tile_engine/ops/gemm/gemm_instance_builder.py | 564 
+++++++----------- tile_engine/ops/gemm/gemm_profiler.hpp | 2 +- tile_engine/ops/gemm/json_config.py | 231 ------- 18 files changed, 504 insertions(+), 987 deletions(-) rename tile_engine/ops/gemm/{ => commons}/test_benchmark.sh (100%) rename tile_engine/ops/gemm/{ => commons}/test_validation.py (100%) rename tile_engine/ops/gemm/{ => commons}/validation_utils.py (60%) delete mode 100644 tile_engine/ops/gemm/configs/benchmark.json delete mode 100644 tile_engine/ops/gemm/configs/custom_ci_config.json delete mode 100644 tile_engine/ops/gemm/configs/gfx120x_config.json rename tile_engine/ops/gemm/{benchmark_gemm.hpp => gemm_benchmark.hpp} (100%) rename tile_engine/ops/gemm/{benchmark_gemm_single.cpp => gemm_benchmark_single.cpp} (96%) delete mode 100644 tile_engine/ops/gemm/json_config.py diff --git a/Jenkinsfile b/Jenkinsfile index 9acbbeeca2..c642e2d3b1 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1574,7 +1574,6 @@ pipeline { -D GPU_TARGETS="gfx1201" \ -D GEMM_DATATYPE="fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ - -DGEMM_CONFIG_FILE=gfx120x_config.json \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . 
--problem-sizes "1024,1024,1024" \ diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index 1eb49c0c7f..3c18fc4952 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -65,7 +65,7 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ # Create the executable add_executable(${target_name} EXCLUDE_FROM_ALL - ${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp + ${GEMM_SOURCE_DIR}/gemm_benchmark_single.cpp ${instance_header} ) @@ -103,9 +103,9 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ list(GET trait_parts 1 epilogue) list(GET trait_parts 2 scheduler) - add_dependencies(benchmark_gemm_${pipeline} ${target_name}) - add_dependencies(benchmark_gemm_${epilogue} ${target_name}) - add_dependencies(benchmark_gemm_${scheduler} ${target_name}) + add_dependencies(benchmark_gemm_${pipeline}_pipeline ${target_name}) + add_dependencies(benchmark_gemm_${epilogue}_epilogue ${target_name}) + add_dependencies(benchmark_gemm_${scheduler}_scheduler ${target_name}) endfunction() # Function to build individual GEMM targets @@ -286,15 +286,15 @@ else() set(GEMM_SCHEDULERS "intrawave;interwave") foreach(pipeline IN LISTS GEMM_PIPELINES) - add_custom_target(benchmark_gemm_${pipeline}) + add_custom_target(benchmark_gemm_${pipeline}_pipeline) endforeach() foreach(epilogue IN LISTS GEMM_EPILOGUES) - add_custom_target(benchmark_gemm_${epilogue}) + add_custom_target(benchmark_gemm_${epilogue}_epilogue) endforeach() foreach(scheduler IN LISTS GEMM_SCHEDULERS) - add_custom_target(benchmark_gemm_${scheduler}) + add_custom_target(benchmark_gemm_${scheduler}_scheduler) endforeach() # Build individual targets for each datatype/layout combination diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index 01ffbb6da7..ce62f8dca5 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -187,7 +187,7 @@ python 
gemm_instance_builder.py \ --datatype fp16 \ --layout rcr \ --config_json configs/user_provided_config.json \ - --gen_individual + --gen_all_individual ``` #### gemm_instance_builder_parallel.py diff --git a/tile_engine/ops/gemm/test_benchmark.sh b/tile_engine/ops/gemm/commons/test_benchmark.sh similarity index 100% rename from tile_engine/ops/gemm/test_benchmark.sh rename to tile_engine/ops/gemm/commons/test_benchmark.sh diff --git a/tile_engine/ops/gemm/test_validation.py b/tile_engine/ops/gemm/commons/test_validation.py similarity index 100% rename from tile_engine/ops/gemm/test_validation.py rename to tile_engine/ops/gemm/commons/test_validation.py diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/commons/validation_utils.py similarity index 60% rename from tile_engine/ops/gemm/validation_utils.py rename to tile_engine/ops/gemm/commons/validation_utils.py index c71f0e8a09..3077ae4ba0 100644 --- a/tile_engine/ops/gemm/validation_utils.py +++ b/tile_engine/ops/gemm/commons/validation_utils.py @@ -23,6 +23,31 @@ ELEMENT_SIZE_MAP = { "fp64": 8, } +WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx942": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx950": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx1201": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1], + ], +} + +# [TODO] Handle this while moving code to commons # Supported warp tile combinations for different GPU architectures and data types WARP_TILE_SUPPORTED_COMBINATIONS = { "gfx90a": { @@ -290,6 +315,7 @@ def is_tile_config_valid( b_datatype: str, c_datatype: str, pipeline: str, + layout: str, gpu_target: str, trait_name: str = None, ) -> bool: @@ -348,6 +374,24 @@ def is_tile_config_valid( logging.debug(f"LDS validation failed: {lds_error}") return False + # Validate whole workgroup cover configuration + wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration( + tile_m, + tile_n, + tile_k, + warp_m, + 
warp_n, + warp_k, + layout, + a_datatype, + b_datatype, + ) + if not wr_cover_valid: + logging.debug( + f"Whole workgroup cover configuration validation failed: {wg_cover_error}" + ) + return False + # Validate warp tile combination warp_tile_valid, warp_tile_error = validate_warp_tile_combination( warp_tile_m, @@ -363,3 +407,209 @@ def is_tile_config_valid( return False return True + + +# [TODO] Handle this while moving code to commons Add more datatype to this function if needed +def get_dtype_string(datatype: str) -> str: + """Get C++ type string for datatype""" + dtype_map = { + "fp16": "ck_tile::fp16_t", + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + } + return dtype_map.get(datatype, "float") + + +LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", +} + + +def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: + """ + Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. 
+ """ + code = str(layout_code).strip().lower() + + a_layout = LAYOUT_MAP[code[0]] + b_layout = LAYOUT_MAP[code[1]] + c_layout = LAYOUT_MAP[code[2]] + return a_layout, b_layout, c_layout + + +def validate_whole_wg_cover_configuration( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + layout, + a_datatype, + b_datatype, +) -> Tuple[bool, str]: + # Validate whole workgroup cover configuration + + warp_size = 64 + NumWarps = warp_m * warp_n * warp_k + BlockSize = NumWarps * warp_size + + XPerTile = 0 + YPerTile = 0 + vector_load_size = 0 + + # A matrix validation + if layout[0] == "r": + XPerTile = tile_k + YPerTile = tile_m + + vector_load_size = get_global_vector_load_size( + BlockSize, tile_k, a_datatype, tile_m, tile_k + ) + + elif layout[0] == "c": + vector_load_size = get_global_vector_load_size( + BlockSize, tile_k, a_datatype, tile_m, tile_m + ) + + # Validate distribution + XPerTile = tile_k + YPerTile = tile_m + + wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( + XPerTile, YPerTile, BlockSize, vector_load_size, warp_size + ) + + if not wg_cover_core_valid: + print("I am here 1") + logging.debug( + f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}" + ) + return False, wg_cover_core_error + + XPerTile = tile_m + YPerTile = tile_k + + wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( + XPerTile, YPerTile, BlockSize, vector_load_size, warp_size + ) + + if not wg_cover_core_valid: + logging.debug( + f"whole workgroup cover failed for Matrix A: {wg_cover_core_error}" + ) + return False, wg_cover_core_error + + # B matrix validation + if layout[1] == "r": + vector_load_size = get_global_vector_load_size( + BlockSize, tile_k, b_datatype, tile_n, tile_n + ) + + # Validate distribution + XPerTile = tile_k + YPerTile = tile_n + + wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( + XPerTile, YPerTile, BlockSize, vector_load_size, warp_size + ) + + if not 
wg_cover_core_valid: + print("I am here 3") + logging.debug( + f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}" + ) + return False, wg_cover_core_error + + XPerTile = tile_n + YPerTile = tile_k + + elif layout[1] == "c": + XPerTile = tile_k + YPerTile = tile_n + + vector_load_size = get_global_vector_load_size( + BlockSize, tile_k, b_datatype, tile_n, tile_k + ) + + wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( + XPerTile, YPerTile, BlockSize, vector_load_size, warp_size + ) + if not wg_cover_core_valid: + print("I am here 4") + logging.debug( + f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}" + ) + return False, wg_cover_core_error + + return True, "" + + +def wg_cover_core_validation( + XPerTile: int, + YPerTile: int, + BlockSize: int, + vector_load_size: int, + warp_size: int, +) -> Tuple[bool, str]: + if XPerTile % vector_load_size != 0: + return False + + num_warps = BlockSize / warp_size + LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size) + + X1 = LargestVec if vector_load_size > LargestVec else vector_load_size + X0 = XPerTile / X1 + Y1 = warp_size // X0 + + if X0 * Y1 != warp_size: + return False, "" + + return True, "" + + +def get_global_vector_load_size( + BlockSize: int, + KPerBlock: int, + DataType: str, + MNPerBlock: int, + XPerTile: int, +) -> int: + elements_per_thread = MNPerBlock * KPerBlock / BlockSize + PackedSize = 1 + + if ( + XPerTile % (PackedSize * 32 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0 + and PackedSize == 2 + ): + return PackedSize * 32 / element_size(DataType) + elif ( + XPerTile % (PackedSize * 16 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 16 / element_size(DataType)) == 0 + ): + return int(PackedSize * 16 / element_size(DataType)) + + elif ( + XPerTile % (PackedSize * 8 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 8 / 
element_size(DataType)) == 0 + ): + return int(PackedSize * 8 / element_size(DataType)) + elif ( + element_size(DataType) >= PackedSize * 4 + and XPerTile % (PackedSize * 4 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 4 / element_size(DataType)) == 0 + ): + return int(PackedSize * 4 / element_size(DataType)) + elif ( + element_size(DataType) >= PackedSize * 2 + and XPerTile % (PackedSize * 2 / element_size(DataType)) == 0 + and elements_per_thread % (PackedSize * 2 / element_size(DataType)) == 0 + ): + return int(PackedSize * 2 / element_size(DataType)) + else: + return PackedSize diff --git a/tile_engine/ops/gemm/configs/benchmark.json b/tile_engine/ops/gemm/configs/benchmark.json deleted file mode 100644 index b15b587147..0000000000 --- a/tile_engine/ops/gemm/configs/benchmark.json +++ /dev/null @@ -1,105 +0,0 @@ -{ - "problem": { - }, - "tile_config": { - "tile_m": { - "max": 256, - "min": 64, - "step": 64 - }, - "tile_n": { - "max": 256, - "min": 64, - "step": 64 - }, - "tile_k": { - "max": 256, - "min": 64, - "step": 64 - }, - "warp_m": { - "values": [ - 4, - 2, - 1 - ] - }, - "warp_n": { - "values": [ - 4, - 2, - 1 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 4, - 16, - 32 - ] - }, - "warp_tile_n": { - "values": [ - 16, - 32, - 64 - ] - }, - "warp_tile_k": { - "values": [ - 8, - 16, - 32, - 64, - 128 - ] - } - }, - "trait_config": { - "pipeline": { - "values": [ - "compv3", - "compv4", - "mem" - ] - }, - "scheduler": { - "values": [ - "intrawave", - "interwave" - ] - }, - "epilogue": { - "values": [ - "cshuffle", - "default" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - }, - "persistent": { - "values": [ - false, - true - ] - } - } -} \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/custom_ci_config.json b/tile_engine/ops/gemm/configs/custom_ci_config.json deleted file mode 100644 
index ca6c7230fd..0000000000 --- a/tile_engine/ops/gemm/configs/custom_ci_config.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "problem": { - }, - "tile_config": { - "tile_m": { - "values": [ - 128 ] - }, - "tile_n": { - "values": [ - 128 - ] - }, - "tile_k": { - "values": [ - 128 - ] - }, - "warp_m": { - "values": [ - 2 - ] - }, - "warp_n": { - "values": [ - 2 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 32 - ] - }, - "warp_tile_n": { - "values": [ - 32 - ] - }, - "warp_tile_k": { - "values": [ - 16 - ] - } - }, - "trait_config": { - "pipeline": { - "values": [ - "compv3" - ] - }, - "scheduler": { - "values": [ - "intrawave" - ] - }, - "epilogue": { - "values": [ - "default" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - }, - "persistent": { - "values": [ - false, - true - ] - } - } -} \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index b245c3167f..2447428158 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -1,6 +1,4 @@ { - "problem": { - }, "tile_config": { "tile_m": { "max": 256, @@ -101,5 +99,6 @@ true ] } - } + }, + "k_block_per_cu": 1 } diff --git a/tile_engine/ops/gemm/configs/gfx120x_config.json b/tile_engine/ops/gemm/configs/gfx120x_config.json deleted file mode 100644 index 6c4a5d0ec0..0000000000 --- a/tile_engine/ops/gemm/configs/gfx120x_config.json +++ /dev/null @@ -1,102 +0,0 @@ -{ - "problem": { - }, - "tile_config": { - "tile_m": { - "values": [ - 256, - 128, - 64 - ] - }, - "tile_n": { - "values": [ - 256, - 128, - 64 - ] - }, - "tile_k": { - "values": [ - 256, - 128, - 64 - ] - }, - "warp_m": { - "values": [ - 4, - 2, - 1 - ] - }, - "warp_n": { - "values": [ - 4, - 2, - 1 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 16 - ] 
- }, - "warp_tile_n": { - "values": [ - 16 - ] - }, - "warp_tile_k": { - "values": [ - 16 - ] - } - }, - "trait_config": { - "pipeline": { - "values": [ - "compv3", - "mem" - ] - }, - "scheduler": { - "values": [ - "intrawave", - "interwave" - ] - }, - "epilogue": { - "values": [ - "cshuffle", - "default" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - }, - "persistent": { - "values": [ - false, - true - ] - } - } -} diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/configs/user_provided_config.json index 76e194f6b9..40a7dda6cc 100644 --- a/tile_engine/ops/gemm/configs/user_provided_config.json +++ b/tile_engine/ops/gemm/configs/user_provided_config.json @@ -1,31 +1,28 @@ { - "problem": { - }, "tile_config": { "tile_m": { "values": [ - 128, - 256 + 64 ] }, "tile_n": { "values": [ - 128 + 192 ] }, "tile_k": { "values": [ - 128 + 64 ] }, "warp_m": { "values": [ - 4 + 2 ] }, "warp_n": { "values": [ - 1 + 2 ] }, "warp_k": { @@ -35,36 +32,33 @@ }, "warp_tile_m": { "values": [ - 16, 32 + 32 ] }, "warp_tile_n": { "values": [ - 16, 32 + 32 ] }, "warp_tile_k": { "values": [ - 16, 32 + 8 ] } }, "trait_config": { "pipeline": { "values": [ - "compv3", - "mem" + "compv4" ] }, "scheduler": { "values": [ - "intrawave", - "interwave" + "intrawave" ] }, "epilogue": { "values": [ - "default", "cshuffle" ] }, @@ -85,9 +79,9 @@ }, "persistent": { "values": [ - false, true ] } - } + }, + "k_block_per_cu": 1 } \ No newline at end of file diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/gemm_benchmark.hpp similarity index 100% rename from tile_engine/ops/gemm/benchmark_gemm.hpp rename to tile_engine/ops/gemm/gemm_benchmark.hpp diff --git a/tile_engine/ops/gemm/gemm_benchmark.py b/tile_engine/ops/gemm/gemm_benchmark.py index 3b0f0e619d..9f323f2640 100755 --- a/tile_engine/ops/gemm/gemm_benchmark.py +++ 
b/tile_engine/ops/gemm/gemm_benchmark.py @@ -273,48 +273,6 @@ class GemmBenchmark: print(f"Error reading JSON file {json_file}: {e}") return None - def parse_benchmark_output(self, output: str) -> Optional[Dict]: - """Parse the benchmark output format - extract JSON directly""" - try: - # Find JSON block between asterisk markers - lines = output.split("\n") - json_start = -1 - json_end = -1 - - for i, line in enumerate(lines): - if line.strip().startswith("{"): - json_start = i - elif line.strip().endswith("}") and json_start != -1: - json_end = i - break - - if json_start != -1 and json_end != -1: - json_text = "\n".join(lines[json_start : json_end + 1]) - data = json.loads(json_text) - - # Return the complete JSON data as-is, just add some convenience fields - result = data.copy() - if "perf_result" in data: - perf = data["perf_result"] - # Add convenience fields for backward compatibility - result["time_ms"] = perf.get("latency(ms)", 0) - result["tflops"] = perf.get("tflops(TFlops)", 0) - result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0) - - return result - - return None - - except json.JSONDecodeError as e: - if self.verbose: - print(f"Failed to parse JSON: {e}") - print(f"Output was: {output[:200]}...") - return None - except Exception as e: - if self.verbose: - print(f"Error parsing output: {e}") - return None - def benchmark_problem_size( self, kernels: List[Path], diff --git a/tile_engine/ops/gemm/benchmark_gemm_single.cpp b/tile_engine/ops/gemm/gemm_benchmark_single.cpp similarity index 96% rename from tile_engine/ops/gemm/benchmark_gemm_single.cpp rename to tile_engine/ops/gemm/gemm_benchmark_single.cpp index 58532ffbe8..bbcc6eb505 100644 --- a/tile_engine/ops/gemm/benchmark_gemm_single.cpp +++ b/tile_engine/ops/gemm/gemm_benchmark_single.cpp @@ -30,9 +30,9 @@ inline auto create_args(int argc, char* argv[]) .insert("stride_c", "0", "The stride value for tensor C. Default is 0.") .insert("split_k", "1", "The split value for k dimension. 
Default is 1.") .insert("verify", - "0", + "2", "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " - "for validation on GPU. Default is 0, no validation.") + "for validation on GPU. Default is 2, GPU validation.") .insert("log", "false", "Whether output kernel instance information or not. Possible values are true or " @@ -75,7 +75,7 @@ inline auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser) +void benchmark_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType @@ -149,7 +149,7 @@ int main(int argc, char* argv[]) if(!result) return EXIT_FAILURE; - benchmark_gemm_single(parser); + benchmark_single(parser); return 0; } catch(const std::exception& e) diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp index 179aeb7307..4732f2a1ba 100644 --- a/tile_engine/ops/gemm/gemm_common.hpp +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" +//[TODO] This can be moved to commons // DataTypeTraits for all supported types template struct DataTypeTraits; @@ -97,49 +98,3 @@ struct KernelTraits { } }; - -// Helper to extract traits from kernel name -inline KernelTraits extract_traits_from_name(const std::string& kernel_name) -{ - KernelTraits traits; - - // Extract pipeline - if(kernel_name.find("compv3") != std::string::npos) - { - traits.pipeline = "compv3"; - } - else if(kernel_name.find("compv4") != std::string::npos) - { - traits.pipeline = "compv4"; - } - else if(kernel_name.find("mem") != std::string::npos) - { - traits.pipeline = "mem"; - } - - // Extract scheduler - if(kernel_name.find("interwave") != std::string::npos) - { - traits.scheduler = "interwave"; - } - else 
- { - traits.scheduler = "intrawave"; - } - - // Extract epilogue - if(kernel_name.find("default") != std::string::npos && - kernel_name.find("default_") == std::string::npos) - { - traits.epilogue = "default"; - } - else - { - traits.epilogue = "cshuffle"; - } - - // Padding flags would need to be extracted from the kernel configuration - // For now, we'll leave them as false - - return traits; -} diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index ae9e5a7728..81b25e592f 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -8,8 +8,12 @@ import multiprocessing import concurrent.futures from pathlib import Path import logging -from typing import Optional -from validation_utils import is_tile_config_valid, is_trait_combination_valid +from commons.validation_utils import ( + is_tile_config_valid, + is_trait_combination_valid, + get_dtype_string, + get_abc_layouts, +) logging.basicConfig(level=logging.INFO) @@ -29,149 +33,150 @@ class GemmKernelBuilder: if config_json and os.path.exists(config_json): with open(config_json, "r") as f: self.config = json.load(f) - else: - self.config = self._get_default_config() - def _get_default_config(self): - """Return default configuration if no config file is provided""" - # Define base tile configurations that work for all layouts - base_fp16_configs = [ - { - "tile_m": 256, - "tile_n": 256, - "tile_k": 32, - "warp_m": 2, - "warp_n": 2, - "warp_k": 1, - "warp_tile_m": 32, - "warp_tile_n": 32, - "warp_tile_k": 32, - }, - { - "tile_m": 256, - "tile_n": 128, - "tile_k": 32, - "warp_m": 2, - "warp_n": 2, - "warp_k": 1, - "warp_tile_m": 32, - "warp_tile_n": 32, - "warp_tile_k": 16, - }, - ] + def write_kernel_list(self): + """Write kernel list to file for CMake to read (with comprehensive validation)""" + # Get configurations using comprehensive validation + tile_configs = self._get_tile_configs(fast_mode=False) + 
trait_combos = self._generate_trait_combinations() - base_fp8_configs = [ - { - "tile_m": 256, - "tile_n": 256, - "tile_k": 32, - "warp_m": 4, - "warp_n": 1, - "warp_k": 1, - "warp_tile_m": 32, - "warp_tile_n": 32, - "warp_tile_k": 32, - }, - { - "tile_m": 256, - "tile_n": 128, - "tile_k": 32, - "warp_m": 1, - "warp_n": 4, - "warp_k": 1, - "warp_tile_m": 16, - "warp_tile_n": 16, - "warp_tile_k": 32, - }, - ] + kernel_list = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo - # Create configurations for all supported layouts - all_layouts = ["rcr", "rrr", "ccr", "crr"] - tile_configs = {} + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" - for datatype, base_configs in [ - ("fp16", base_fp16_configs), - ("fp8", base_fp8_configs), - ]: - tile_configs[datatype] = {} - for layout in all_layouts: - tile_configs[datatype][layout] = base_configs + # Create tile configuration string + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - return { - "tile_configs": tile_configs, - "traits": { - "pipelines": ["mem", "compv3", "compv4"], - "epilogues": ["default", "cshuffle"], - "schedulers": ["intrawave", "interwave"], - }, - "structured_sparsity": ["false"], - "padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]}, - "persistent": ["false"], - } + kernel_name += f"_{tile_str}" + + kernel_list.append( + { + "name": kernel_name, + "tile_config": tile_config, + "trait_combo": trait_combo, + } + ) + + # Write 
kernel count + with open(self.working_path / "gemm_kernel_count.txt", "w") as f: + f.write(str(len(kernel_list))) + + # Write kernel list + with open(self.working_path / "gemm_kernel_list.txt", "w") as f: + for kernel in kernel_list: + # Format: kernel_name|tile_config|trait_combo + tile_config = kernel["tile_config"] + trait_combo = kernel["trait_combo"] + + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = ( + f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" + + "_".join(str(x) for x in trait_combo[3:]) + ) + + f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") + + print(f"Listed {len(kernel_list)} kernel configurations") def _get_tile_configs(self, fast_mode=False): """Get tile configurations for the current datatype and layout""" - if "tile_configs" in self.config: - # Old format - return ( - self.config["tile_configs"].get(self.datatype, {}).get(self.layout, []) + tile_config = self.config["tile_config"] + + # Generate values in the config if default range is given + if tile_config.get("tile_m").get("values") is None: + tile_config.get("tile_m")["values"] = self._generate_values( + tile_config.get("tile_m").get("min"), + tile_config.get("tile_m").get("max"), + tile_config.get("tile_m").get("step"), + ) + if tile_config.get("tile_n").get("values") is None: + tile_config.get("tile_n")["values"] = self._generate_values( + tile_config.get("tile_n").get("min"), + tile_config.get("tile_n").get("max"), + tile_config.get("tile_n").get("step"), + ) + if tile_config.get("tile_k").get("values") is None: + tile_config.get("tile_k")["values"] = self._generate_values( + tile_config.get("tile_k").get("min"), + tile_config.get("tile_k").get("max"), + tile_config.get("tile_k").get("step"), ) - elif "tile_config" in 
self.config: - # New format - generate combinations from individual parameter values - tile_config = self.config["tile_config"] - # Get all possible values for each parameter - tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) - tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) - tile_k_values = tile_config.get("tile_k", {}).get("values", [32]) - warp_m_values = tile_config.get("warp_m", {}).get("values", [2]) - warp_n_values = tile_config.get("warp_n", {}).get("values", [2]) - warp_k_values = tile_config.get("warp_k", {}).get("values", [1]) - warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32]) - warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32]) - warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32]) + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m").get("values") + tile_n_values = tile_config.get("tile_n").get("values") + tile_k_values = tile_config.get("tile_k").get("values") + warp_m_values = tile_config.get("warp_m").get("values") + warp_n_values = tile_config.get("warp_n").get("values") + warp_k_values = tile_config.get("warp_k").get("values") + warp_tile_m_values = tile_config.get("warp_tile_m").get("values") + warp_tile_n_values = tile_config.get("warp_tile_n").get("values") + warp_tile_k_values = tile_config.get("warp_tile_k").get("values") - # Generate all combinations - configs = [] - for tile_m in tile_m_values: - for tile_n in tile_n_values: - for tile_k in tile_k_values: - for warp_m in warp_m_values: - for warp_n in warp_n_values: - for warp_k in warp_k_values: - for warp_tile_m in warp_tile_m_values: - for warp_tile_n in warp_tile_n_values: - for warp_tile_k in warp_tile_k_values: - # Validate configuration - if self._validate_tile_config( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - fast_mode=fast_mode, - ): - configs.append( - { - "tile_m": 
tile_m, - "tile_n": tile_n, - "tile_k": tile_k, - "warp_m": warp_m, - "warp_n": warp_n, - "warp_k": warp_k, - "warp_tile_m": warp_tile_m, - "warp_tile_n": warp_tile_n, - "warp_tile_k": warp_tile_k, - } - ) - return configs - else: - # Fallback to default - return [] + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs + + def _generate_values(self, min_val, max_val, step): + """Generate a list of values from min to max with the given step""" + values = [] + val = min_val + while val <= max_val: + values.append(val) + val += step + return values def _validate_tile_config( self, @@ -184,7 +189,7 @@ class GemmKernelBuilder: warp_tile_m, warp_tile_n, warp_tile_k, - pipeline="mem", # Default pipeline for validation + pipeline="compv4", # Default pipeline for validation fast_mode=False, # Add fast mode option ): """Validate that tile configuration is reasonable""" @@ -213,6 +218,8 @@ class GemmKernelBuilder: b_datatype = self.datatype c_datatype = self.datatype + layout = self.layout + # Special handling for certain data types if self.datatype in ["fp8", "bf8"]: c_datatype = "fp16" @@ -232,125 +239,50 @@ class GemmKernelBuilder: b_datatype, c_datatype, pipeline, + layout, self.gpu_target, ) def _generate_trait_combinations(self): 
"""Generate all combinations of traits""" - if "traits" in self.config: - # Old format - traits = self.config["traits"] - pipelines = traits["pipelines"] - epilogues = traits["epilogues"] - schedulers = traits["schedulers"] - padding = self.config["padding"] - persistent = self.config["persistent"] + trait_config = self.config["trait_config"] - all_combinations = list( - itertools.product( - pipelines, - epilogues, - schedulers, - padding["pad_m"], - padding["pad_n"], - padding["pad_k"], - persistent, + pipelines = trait_config.get("pipeline").get("values") + epilogues = trait_config.get("epilogue").get("values") + schedulers = trait_config.get("scheduler").get("values") + pad_m_values = trait_config.get("pad_m").get("values") + pad_n_values = trait_config.get("pad_n").get("values") + pad_k_values = trait_config.get("pad_k").get("values") + persistent_values = trait_config.get("persistent").get("values") + + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + pad_m_values, + pad_n_values, + pad_k_values, + persistent_values, + ) + ) + + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler = combo[:3] + if is_trait_combination_valid(pipeline, epilogue, scheduler): + combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" ) - ) - - # Filter out unsupported trait combinations - combinations = [] - for combo in all_combinations: - pipeline, epilogue, scheduler = combo[:3] - if is_trait_combination_valid(pipeline, epilogue, scheduler): - combinations.append(combo) - else: - logging.debug( - f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" - ) - - elif "trait_config" in self.config: - # New format - trait_config = self.config["trait_config"] - - pipelines = trait_config.get("pipeline", {}).get("values", ["mem"]) - epilogues = trait_config.get("epilogue", 
{}).get("values", ["default"]) - schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"]) - pad_m_values = trait_config.get("pad_m", {}).get("values", [False]) - pad_n_values = trait_config.get("pad_n", {}).get("values", [False]) - pad_k_values = trait_config.get("pad_k", {}).get("values", [False]) - persistent_values = trait_config.get("persistent", {}).get( - "values", [False] - ) - - all_combinations = list( - itertools.product( - pipelines, - epilogues, - schedulers, - pad_m_values, - pad_n_values, - pad_k_values, - persistent_values, - ) - ) - - # Filter out unsupported trait combinations - combinations = [] - for combo in all_combinations: - pipeline, epilogue, scheduler = combo[:3] - if is_trait_combination_valid(pipeline, epilogue, scheduler): - combinations.append(combo) - else: - logging.debug( - f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" - ) - else: - # Fallback to minimal default - combinations = [("mem", "default", "intrawave", False, False, False, False)] - return combinations - def _get_dtype_string(self): - """Get C++ type string for datatype""" - dtype_map = { - "fp16": "ck_tile::fp16_t", - "fp8": "ck_tile::fp8_t", - "bf16": "ck_tile::bf16_t", - "fp32": "float", - "fp64": "double", - } - return dtype_map.get(self.datatype, "float") - - _LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", - } - - def _get_abc_layouts(self, layout_code: Optional[str] = None): - """ - Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. - If layout_code is None, use self.layout. - """ - if layout_code is None: - # fall back to the instance field - layout_code = getattr(self, "layout", "") - - code = str(layout_code).strip().lower() - - if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code): - raise ValueError( - f"Invalid layout '{layout_code}'. 
" - "Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)." - ) - - a_layout = self._LAYOUT_MAP[code[0]] - b_layout = self._LAYOUT_MAP[code[1]] - c_layout = self._LAYOUT_MAP[code[2]] - return a_layout, b_layout, c_layout - - def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True): + def _generate_kernel_instance( + self, tile_config, trait_combo, k_block_per_cu, is_header=True + ): """Generate a single kernel instance""" ( pipeline, @@ -383,6 +315,13 @@ class GemmKernelBuilder: "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", } + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", + "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", + } + # Map scheduler names to the correct enum values scheduler_type_map = { "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", @@ -392,23 +331,14 @@ class GemmKernelBuilder: # Determine accumulator type based on datatype acc_type = "float" - if self.datatype in ["int8", "int4"]: - acc_type = "ck_tile::int32_t" # Determine output type - c_type = self._get_dtype_string() + c_type = self.datatype if self.datatype in ["fp8", "bf8"]: - c_type = "ck_tile::fp16_t" + c_type = "fp16" # Determine layouts based on self.layout - a_layout, b_layout, c_layout = self._get_abc_layouts() - - # Map pipeline names to base pipeline for hot loop detection - base_pipeline_map = { - "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", - "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", - "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", - } + a_layout, b_layout, c_layout = get_abc_layouts(self.layout) # Generate kernel instance code using the correct API pragma_line = "#pragma once\n" if is_header else "" @@ -425,10 +355,10 @@ class GemmKernelBuilder: #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" -using ADataType = 
{self._get_dtype_string()}; -using BDataType = {self._get_dtype_string()}; +using ADataType = {get_dtype_string(self.datatype)}; +using BDataType = {get_dtype_string(self.datatype)}; using AccDataType = {acc_type}; -using CDataType = {c_type}; +using CDataType = {get_dtype_string(c_type)}; using ALayout = {a_layout}; using BLayout = {b_layout}; @@ -484,7 +414,7 @@ struct SelectedKernel {{ Traits>; // Base pipeline for hot loop detection - using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseGemmPipelineAgBgCrMem")}; + using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}; static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ const ck_tile::index_t k_grain = args.k_batch * TileK; @@ -498,7 +428,7 @@ struct SelectedKernel {{ const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Intrawave")}; + constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< @@ -514,7 +444,7 @@ struct SelectedKernel {{ has_hot_loop_v, tail_number_v>; - using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}; + using GemmPipeline = {pipeline_impl_map.get(pipeline)}; // Epilogue """ @@ -589,7 +519,7 @@ struct SelectedKernel {{ }} // Launch kernel - constexpr int kBlockPerCu = 1; + constexpr int kBlockPerCu = {k_block_per_cu}; ave_time = ck_tile::launch_kernel( stream, ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); @@ -616,9 +546,13 @@ struct SelectedKernel {{ }} }}; """ - return kernel_name, instance_code + def run(self, num_workers=None): + """Run the builder to generate individual 
kernel files""" + # Generate individual kernel files + self.generate_individual(num_workers) + def generate_individual(self, num_workers=None): """Generate individual kernel files for separate compilation with parallel processing""" if num_workers is None: @@ -628,6 +562,7 @@ struct SelectedKernel {{ tile_configs = self._get_tile_configs() trait_combos = self._generate_trait_combinations() + k_block_per_cu = self.config.get("k_block_per_cu") # Prepare work items for parallel processing work_items = [] @@ -637,6 +572,7 @@ struct SelectedKernel {{ ( tile_config, trait_combo, + k_block_per_cu, self.working_path, self.datatype, self.layout, @@ -723,83 +659,17 @@ struct SelectedKernel {{ with open(self.working_path / "gemm_individual_targets.cmake", "w") as f: f.write(cmake_code) - def write_kernel_list(self): - """Write kernel list to file for CMake to read (with comprehensive validation)""" - # Get configurations using comprehensive validation - tile_configs = self._get_tile_configs(fast_mode=False) - trait_combos = self._generate_trait_combinations() - - kernel_list = [] - for tile_config in tile_configs: - for trait_combo in trait_combos: - ( - pipeline, - epilogue, - scheduler, - pad_m, - pad_n, - pad_k, - persistent, - ) = trait_combo - - # Create kernel name with proper boolean capitalization - kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" - - # Create tile configuration string - tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - kernel_name += f"_{tile_str}" - - kernel_list.append( - { - "name": kernel_name, - "tile_config": tile_config, - "trait_combo": trait_combo, - } - ) - - # Write 
kernel count - with open(self.working_path / "gemm_kernel_count.txt", "w") as f: - f.write(str(len(kernel_list))) - - # Write kernel list - with open(self.working_path / "gemm_kernel_list.txt", "w") as f: - for kernel in kernel_list: - # Format: kernel_name|tile_config|trait_combo - tile_config = kernel["tile_config"] - trait_combo = kernel["trait_combo"] - - tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - trait_str = ( - f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" - + "_".join(str(x) for x in trait_combo[3:]) - ) - - f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") - - print(f"Listed {len(kernel_list)} kernel configurations") - - def run(self, num_workers=None): - """Run the builder to generate individual kernel files""" - # Generate individual kernel files - self.generate_individual(num_workers) - def _generate_single_kernel_individual(work_item): """Worker function to generate a single individual kernel file""" - tile_config, trait_combo, working_path, datatype, layout = work_item + tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item # Create a temporary builder instance for this worker builder = GemmKernelBuilder(working_path, datatype, layout) try: kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo + tile_config, trait_combo, k_block_per_cu ) # Create simplified filename without the "gemm_" prefix @@ -832,7 +702,7 @@ def main(): parser.add_argument( "--datatype", required=True, - choices=["fp16", "fp8", "bf16", "fp32", "fp64"], + choices=["fp16", "fp8", "bf16", "bf8"], help="Data type", ) parser.add_argument( @@ -846,7 +716,9 @@ def main(): "--num_workers", type=int, help="Number of parallel workers (default: auto)" ) parser.add_argument( - 
"--gen_individual", action="store_true", help="Generate individual kernel files" + "--gen_all_individual", + action="store_true", + help="Generate individual kernel files", ) parser.add_argument( "--gen_single", action="store_true", help="Generate a single kernel file" @@ -866,13 +738,27 @@ def main(): args = parser.parse_args() + assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], ( + f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])" + ) + + layout_parts = args.layout.lower() + assert len(layout_parts) == 3, ( + f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], ( + f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)" + ) + assert layout_parts[2] == "r", ( + f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" + ) + # Create builder builder = GemmKernelBuilder( args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json ) if args.list_kernels: - # Fast listing mode - just write kernel list without generating files builder.write_kernel_list() elif args.gen_single: # Generate a single kernel file @@ -911,9 +797,11 @@ def main(): trait_parts[6] == "True", # persistent ) + k_block_per_cu = builder.config.get("k_block_per_cu") + # Generate the kernel kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo + tile_config, trait_combo, k_block_per_cu ) # Write the file @@ -927,12 +815,12 @@ def main(): print(f"Generated {header_file}") - elif args.gen_individual: + elif args.gen_all_individual: # Generate all individual kernel files builder.run(args.num_workers) else: parser.error( - "Must specify one of: --list_kernels, --gen_individual, or 
--gen_single" + "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" ) diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 1298c78d18..575e5240a8 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -9,7 +9,7 @@ #include "ck_tile/host/device_prop.hpp" #include "ck_tile/ops/gemm.hpp" -#include "benchmark_gemm.hpp" +#include "gemm_benchmark.hpp" class GemmProfiler { diff --git a/tile_engine/ops/gemm/json_config.py b/tile_engine/ops/gemm/json_config.py deleted file mode 100644 index 04f2dd4890..0000000000 --- a/tile_engine/ops/gemm/json_config.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -# -*- coding: utf-8 -*- - -""" -Handles loading, parsing, and validation of JSON configuration parameters. -""" - -from pathlib import Path -from dataclasses import dataclass -from typing import List, Optional, Union, Tuple, Type, Dict -import json - - -@dataclass -class EnumConfigParam: - """Represents an enumeration-type configuration parameter""" - - values: List[Union[int, str, bool]] - - -@dataclass -class RangeConfigParam: - """Represents a numeric range-type configuration parameter""" - - min: int - max: int - step: int - exclude: Optional[List[int]] - - def generate_candidates(self) -> List[int]: - """Generates valid candidates after applying range constraints""" - - if self.min > self.max: - raise ValueError(f"Invalid range: min({self.min}) > max({self.max})") - if self.step <= 0: - raise ValueError(f"Step must be positive, got {self.step}") - - candidates = list(range(self.min, self.max + 1, self.step)) - - if hasattr(self, "exclude") and self.exclude: - if not isinstance(self.exclude, list): - raise TypeError("exclude must be list type") - exclude_set = set(self.exclude) - candidates = [x for x in candidates if x not in exclude_set] - - if not candidates: - 
raise ValueError( - f"No valid candidates for range [{self.min}-{self.max}] " - f"with step {self.step} and excludes {self.exclude}" - ) - - return candidates - - -@dataclass -class ProblemConfig: - """configuration class for problem parameter.""" - - datatypes: Tuple[EnumConfigParam, ...] - layouts: Tuple[EnumConfigParam, ...] - - @property - def datatype_map(self) -> Dict[str, str]: - """Get datatype as a key-value map.""" - return { - "matrix_a": self.datatypes[0].values[0], - "matrix_b": self.datatypes[1].values[0], - "matrix_c": self.datatypes[2].values[0], - } - - @property - def layout_map(self) -> Dict[str, str]: - """Get layout as a key-value map.""" - return { - "matrix_a": self.layouts[0].values[0], - "matrix_b": self.layouts[1].values[0], - "matrix_c": self.layouts[2].values[0], - } - - -@dataclass -class TileConfig: - """Configuration class for tile parameter.""" - - tile_m: Union[EnumConfigParam, RangeConfigParam] - tile_n: Union[EnumConfigParam, RangeConfigParam] - tile_k: Union[EnumConfigParam, RangeConfigParam] - - warp_m: Union[EnumConfigParam, RangeConfigParam] - warp_n: Union[EnumConfigParam, RangeConfigParam] - warp_k: Union[EnumConfigParam, RangeConfigParam] - - warp_tile_m: Union[EnumConfigParam, RangeConfigParam] - warp_tile_n: Union[EnumConfigParam, RangeConfigParam] - warp_tile_k: Union[EnumConfigParam, RangeConfigParam] - - -@dataclass -class TraitConfig: - """Configuration class for kernel traits.""" - - pipeline: EnumConfigParam - scheduler: EnumConfigParam - epilogue: EnumConfigParam - pad_m: EnumConfigParam - pad_n: EnumConfigParam - pad_k: EnumConfigParam - persistent: EnumConfigParam - - -@dataclass -class GemmConfig: - """Main configuration class for GEMM operations""" - - problem: ProblemConfig - tile_config: TileConfig - trait_config: TraitConfig - - @classmethod - def from_json( - cls: Type["GemmConfig"], filepath: str, datatype: str, layout: str - ) -> "GemmConfig": - """JSON configuration loader with validation controls""" - 
config_path = Path(filepath) - - try: - if not config_path.exists(): - raise FileNotFoundError(f"Config file {filepath} not found") - - with config_path.open("r") as f: - config_dict = json.load(f) - - a_type = datatype - b_type = datatype - c_type = datatype - if b_type == "int4": - a_type = "fp16" - if b_type in ["bf8", "fp8", "int4"]: - c_type = "fp16" - - layout_parts = layout.lower() - assert len(layout_parts) == 3, ( - f"Invalid layout string: {layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" - ) - assert layout_parts[0] in ("r", "c"), ( - f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)" - ) - assert layout_parts[1] in ("r", "c"), ( - f"Invalid matrix_a layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)" - ) - assert layout_parts[2] == "r", ( - f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" - ) - a_layout = layout_parts[0] - b_layout = layout_parts[1] - c_layout = layout_parts[2] - - # Parse problem config - # TODO: Not reading datatype information from json file. 
- problem = ProblemConfig( - datatypes=( - EnumConfigParam(values=[a_type]), - EnumConfigParam(values=[b_type]), - EnumConfigParam(values=[c_type]), - ), - layouts=( - EnumConfigParam(values=[a_layout]), - EnumConfigParam(values=[b_layout]), - EnumConfigParam(values=[c_layout]), - ), - ) - - # Parse tile config - def create_param(param_dict): - if "values" in param_dict: - return EnumConfigParam(values=param_dict["values"]) - else: - return RangeConfigParam( - min=param_dict["min"], - max=param_dict["max"], - step=param_dict["step"], - exclude=param_dict.get("exclude", []), - ) - - tile_config = TileConfig( - tile_m=create_param(config_dict["tile_config"]["tile_m"]), - tile_n=create_param(config_dict["tile_config"]["tile_n"]), - tile_k=create_param(config_dict["tile_config"]["tile_k"]), - warp_m=create_param(config_dict["tile_config"]["warp_m"]), - warp_n=create_param(config_dict["tile_config"]["warp_n"]), - warp_k=create_param(config_dict["tile_config"]["warp_k"]), - warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]), - warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]), - warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]), - ) - - # Parse trait config - trait_config = TraitConfig( - pipeline=EnumConfigParam( - values=config_dict["trait_config"]["pipeline"]["values"] - ), - scheduler=EnumConfigParam( - values=config_dict["trait_config"]["scheduler"]["values"] - ), - epilogue=EnumConfigParam( - values=config_dict["trait_config"]["epilogue"]["values"] - ), - pad_m=EnumConfigParam( - values=config_dict["trait_config"]["pad_m"]["values"] - ), - pad_n=EnumConfigParam( - values=config_dict["trait_config"]["pad_n"]["values"] - ), - pad_k=EnumConfigParam( - values=config_dict["trait_config"]["pad_k"]["values"] - ), - persistent=EnumConfigParam( - values=config_dict["trait_config"]["persistent"]["values"] - ), - ) - - return cls( - problem=problem, tile_config=tile_config, trait_config=trait_config - ) - - except 
json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON format: {str(e)}") - except KeyError as e: - raise KeyError(f"Missing required configuration field: {str(e)}")