From 047f6e448043b1936bb33a90d8dfa052da6acf6a Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Thu, 15 May 2025 11:16:13 +0000 Subject: [PATCH] python format --- tile_engine/ops/gemm/codegen_utils.py | 47 ++++++++++--------- tile_engine/ops/gemm/gemm_instance_builder.py | 12 +++-- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 03617a7f59..1480300b68 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -159,25 +159,25 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)] # To Do: add some more supported combinations warp_tile_supported_combinations = { - "gfx90a": { - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]], - 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]] - }, - "gfx942": { - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] - }, - "gfx950": { - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], - 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] - } - } + "gfx90a": { + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]] + }, + "gfx942": { + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] + }, + "gfx950": { + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] + } +} # To Do: remove some unsupported combinations trait_unsupported_combinations = { @@ -200,6 +200,7 @@ def element_size(data_type: str) -> float: else: raise ValueError(f"Unsupported data type: {data_type}") + @lru_cache(maxsize=1) def get_gpu_name_by_id(gpu_id: int = 0) -> str: """Retrieve GPU name (e.g. gfx90a) by device ID""" @@ -212,15 +213,15 @@ def get_gpu_name_by_id(gpu_id: int = 0) -> str: text=True, check=True ) - + arch_pattern = r'gfx\d{3,4}[a-z]?' match = re.search(arch_pattern, result.stdout.lower()) return match.group() if match else "" - + except (FileNotFoundError, subprocess.CalledProcessError) as e: print(f"System Error: {str(e)}, when get the name of gpu:{gpu_id}") return "" except Exception as e: - print(f"Runtime Exception: {str(e)}, when get the name of gpu:{gpu_id}") + print( + f"Runtime Exception: {str(e)}, when get the name of gpu:{gpu_id}") return "" - diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index bc3ba7e81a..31e4a2401c 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -376,18 +376,20 @@ struct GemmKernel {{ # Warp combination validation warp_tile_key = f"{self.config.problem.datatype_map['matrix_a']}_{self.config.problem.datatype_map['matrix_b']}_{self.config.problem.datatype_map['matrix_c']}" current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] - + gpu_name = get_gpu_name_by_id(0) gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {}) if not gpu_warp_tile_key: - logging.warning(f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.") + logging.warning( + f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.") return True - + allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, []) if not allowed_combinations: - logging.warning(f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.") + logging.warning( + f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.") return True - + if current_combination not in allowed_combinations: logging.warning( f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. "