diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index b7415f9975..916f843c36 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,13 +1,22 @@ - +option(USE_CUSTOM_CONFIG "Enable user-provided configuration file" ON) # generate a list of kernels, but not actually emit files at config stage -execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - --working_path ${CMAKE_CURRENT_BINARY_DIR} - --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json - --list_blobs - RESULT_VARIABLE ret -) +if(USE_CUSTOM_CONFIG) + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json + --list_blobs + RESULT_VARIABLE ret + ) +else() + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --list_blobs + RESULT_VARIABLE ret + ) +endif() if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "Fail to list kernels via Python. 
${ret}") @@ -15,13 +24,22 @@ endif() file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS) -add_custom_command( - OUTPUT ${GEMM_CODEGEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - --working_path ${CMAKE_CURRENT_BINARY_DIR} - --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json - --gen_blobs -) +if(USE_CUSTOM_CONFIG) + add_custom_command( + OUTPUT ${GEMM_CODEGEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json + --gen_blobs + ) +else() + add_custom_command( + OUTPUT ${GEMM_CODEGEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --gen_blobs + ) +endif() set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm") message("adding example ${EXECUTABLE_GEMM_INSTANCE}") diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 44e4ff4788..67dfa0570b 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -153,7 +153,8 @@ HOT_LOOP_TRUE = {'mem': RUN_MEM, def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)] -warp_tile_combinations = { +# To Do: add some more supported combinations +warp_tile_supported_combinations = { 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], # last 2 were not supported by MI300 architecture. 
@@ -161,13 +162,21 @@ warp_tile_combinations = { 'bf8_bf8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] } +# (pipeline, epilogue, scheduler) combinations skipped when trait names are built. To Do: remove some unsupported combinations +trait_unsupported_combinations = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave") + } -def size_of(data_type): - if data_type == 'fp16' or data_type == 'bf16': +def element_size(data_type: str) -> float: + data_type = data_type.lower() + if data_type in {'fp16', 'bf16'}: return 2 - elif data_type == 'int8' or data_type == 'fp8' or data_type == 'bf8': + elif data_type in {'int8', 'fp8', 'bf8'}: return 1 - elif data_type == 'int4': # TODO:: needs to confirm + elif data_type == 'int4': return 0.5 else: - return 4 + raise ValueError(f"Unsupported data type: {data_type}") diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index f030397ebd..4305368392 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -34,53 +34,84 @@ "tile_config": { "tile_m": { "max": 512, - "min": 256, - "step": 64 + "min": 64, + "step": 8 }, "tile_n": { - "values": [ - 256 - ] + "max": 512, + "min": 64, + "step": 8 }, "tile_k": { - "values": [ - 32 - ] + "max": 512, + "min": 64, + "step": 8 }, "warp_m": { "values": [ - 2 + 4, + 8, + 16, + 32, + 64, + 128 ] }, "warp_n": { "values": [ - 2 + 4, + 8, + 16, + 32, + 64, + 128 ] }, "warp_k": { "values": [ - 1 + 4, + 8, + 16, + 32, + 64, + 128 ] }, "warp_tile_m": { "values": [ - 32 + 4, + 8, + 16, + 32, + 64, + 128 ] }, "warp_tile_n": { "values": [ - 32 + 4, + 8, + 16, + 32, + 64, + 128 ] }, "warp_tile_k": { "values": [ - 8 + 4, + 8, + 16, + 32, + 64, + 128 ] } }, "trait_config": { "pipeline": { "values": [ + "compv4", "compv3", "mem" ] diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py 
b/tile_engine/ops/gemm/gemm_instance_builder.py index efbd58e575..8589d910c8 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -26,8 +26,9 @@ from codegen_utils import ( EPILOGUE_MAP, HOT_LOOP_TRUE, BOOL_MAP, - warp_tile_combinations, - size_of + warp_tile_supported_combinations, + trait_unsupported_combinations, + element_size ) import logging @@ -73,14 +74,6 @@ class GemmCodeGenerator: "pad_n", "pad_k"] - # To remove some unsupported combinations - unsupported_combinations = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave") - } - # Generate all unique_combinations _unique = set(itertools.product(*[ getattr(self.config.trait_config, param).values @@ -91,7 +84,7 @@ class GemmCodeGenerator: pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo current_combination = (pipeline, epilogue, scheduler) - if current_combination not in unsupported_combinations: + if current_combination not in trait_unsupported_combinations: trait_name = ( f"{pipeline}_{epilogue}_{scheduler}_" f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}" @@ -118,15 +111,15 @@ class GemmCodeGenerator: #include "ck_tile/core.hpp" // Data types -using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[0]]}; -using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[1]]}; +using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_a']]}; +using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_b']]}; using AccDataType = float; -using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_values[2]]}; +using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_c']]}; // Layout configurations -using ALayout = {LAYOUT_MAP[self.config.problem.layout_values[0]]}; -using BLayout = {LAYOUT_MAP[self.config.problem.layout_values[1]]}; -using CLayout = 
{LAYOUT_MAP[self.config.problem.layout_values[2]]}; +using ALayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_a']]}; +using BLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_b']]}; +using CLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_c']]}; """ (self.output_dir / "gemm_common.hpp").write_text(content) @@ -358,15 +351,15 @@ struct GemmKernel {{ logging.warning( f"Dimension alignment failed [{trait}]: {', '.join(alignment_issues)}. " f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " - f"[warp×tile] {warp_m}x{warp_n}x{warp_k} × {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + f"[warpxtile] {warp_m}x{warp_n}x{warp_k} x {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" ) return False # LDS capacity verification matrix_a_size = (tile_m * tile_k) * \ - size_of(self.config.problem.datatype_values[0]) + element_size(self.config.problem.datatype_map['matrix_a']) matrix_b_size = (tile_n * tile_k) * \ - size_of(self.config.problem.datatype_values[1]) + element_size(self.config.problem.datatype_map['matrix_b']) total_tile_in_lds = matrix_a_size + matrix_b_size max_tile_size = 2**16 if pipeline == "compv4" else 2**15 @@ -374,15 +367,15 @@ struct GemmKernel {{ logging.warning( f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > " f"maximum allowed {max_tile_size:,}B ({max_tile_size/1024}KB). 
Breakdown:\n" - f"- Matrix A ({self.config.problem.datatype_values[0]}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" - f"- Matrix B ({self.config.problem.datatype_values[1]}): {tile_n}x{tile_k} = {matrix_b_size:,}B" + f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" + f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B" ) return False # Warp combination validation - warp_tile_key = f"{self.config.problem.datatype_values[0]}_{self.config.problem.datatype_values[1]}_{self.config.problem.datatype_values[2]}" + warp_tile_key = f"{self.config.problem.datatype_map['matrix_a']}_{self.config.problem.datatype_map['matrix_b']}_{self.config.problem.datatype_map['matrix_c']}" current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] - allowed_combinations = warp_tile_combinations.get(warp_tile_key, []) + allowed_combinations = warp_tile_supported_combinations.get(warp_tile_key, []) if current_combination not in allowed_combinations: logging.warning( @@ -455,7 +448,9 @@ struct GemmDispatcher { if(structured_sparsity){{ // SMFMA""" for tile in tile_params: if self.is_tile_valid(tile, trait): - sparse = self.config.problem.datatype_values[0] == 'fp16' and \ + sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \ + self.config.problem.datatype_map['matrix_b'] == 'fp16' and \ + self.config.problem.datatype_map['matrix_c'] == 'fp16' and \ ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" diff --git a/tile_engine/ops/gemm/json_config.py b/tile_engine/ops/gemm/json_config.py index c3e273b307..0408b2dfb5 100644 --- a/tile_engine/ops/gemm/json_config.py +++ b/tile_engine/ops/gemm/json_config.py @@ -228,12 +228,22 @@ class ProblemConfig: ) @property - def datatype_values(self) -> list: - return [p.values[0] for p in self.datatypes] + def datatype_map(self) -> dict[str, str]: + """Get current 
datatype selections as a key-value map.""" + return { + 'matrix_a': self.datatypes[0].values[0], + 'matrix_b': self.datatypes[1].values[0], + 'matrix_c': self.datatypes[2].values[0] + } @property - def layout_values(self) -> list: - return [p.values[0] for p in self.layouts] + def layout_map(self) -> dict[str, str]: + """Get current layout selections as a key-value map.""" + return { + 'matrix_a': self.layouts[0].values[0], + 'matrix_b': self.layouts[1].values[0], + 'matrix_c': self.layouts[2].values[0] + } @dataclass @@ -257,33 +267,33 @@ class TileConfig: warp_m: Union[EnumConfigParam, RangeConfigParam] = Field( default_factory=lambda: EnumConfigParam( - values=[256] + values=[8] ) ) warp_n: Union[EnumConfigParam, RangeConfigParam] = Field( default_factory=lambda: EnumConfigParam( - values=[256] + values=[8] ) ) warp_k: Union[EnumConfigParam, RangeConfigParam] = Field( default_factory=lambda: EnumConfigParam( - values=[256] + values=[8] ) ) warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = Field( default_factory=lambda: EnumConfigParam( - values=[256] + values=[8] ) ) warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = Field( default_factory=lambda: EnumConfigParam( - values=[256] + values=[8] ) ) warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = Field( default_factory=lambda: EnumConfigParam( - values=[256] + values=[8] ) )