[CK TILE ENGINE] Code changes to finding GPU id from TARGET (#3055)

* Reading gpuname from target for gemm in ck tile engine

* Reading gpuname from target for gemm preshuffle in ck tile engine

* Reading gpuname from target for gemm preshuffle in ck tile engine

* Get GPU changes for GEMM Muti D in TILE ENGINE

* Addressing errors for gpu name in cktileengine
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-10-20 11:02:18 -05:00
committed by GitHub
parent f18b79f328
commit 9f77061094
12 changed files with 59 additions and 149 deletions

View File

@@ -43,6 +43,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout)
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json ${json_blob}
--list_blobs
--gpu_target ${GEMM_GPU_TARGETS}
RESULT_VARIABLE ret
)
if(NOT ret EQUAL 0)
@@ -62,6 +63,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout)
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json "${json_blob}"
--gen_blobs
--gpu_target ${GEMM_GPU_TARGETS}
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
)
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})

View File

@@ -7,10 +7,6 @@
Mappings and utility functions for kernel code generation.
"""
import subprocess
import re
from functools import lru_cache
DATA_TYPE_MAP = {
"fp32": "float",
"fp16": "ck_tile::half_t",
@@ -198,31 +194,3 @@ def element_size(data_type: str) -> float:
if data_type not in ELEMENT_SIZE_MAP:
raise ValueError(f"Unsupported data type: {data_type}")
return ELEMENT_SIZE_MAP[data_type]
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
return ""
except subprocess.CalledProcessError as e:
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
except FileNotFoundError:
print("ROCm tools not installed (requires rocminfo)")
except subprocess.TimeoutExpired:
print("GPU query timeout (5s)")
except Exception as e:
print(f"GPU detection error: {str(e)}")
return ""

View File

@@ -22,7 +22,6 @@ from gemm_multi_d_codegen_utils import (
warp_tile_supported_combinations,
trait_unsupported_combinations,
element_size,
get_gpu_name_by_id,
)
import logging
@@ -40,6 +39,8 @@ class GemmMultiDCodeGenerator:
self.output_dir = Path(args.working_path)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.gpu_target = args.gpu_target
if user_provided_config is not None:
self.config = user_provided_config
else:
@@ -261,7 +262,7 @@ class GemmMultiDCodeGenerator:
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
gpu_name = get_gpu_name_by_id(0)
gpu_name = self.gpu_target
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
if not gpu_warp_tile_key:
@@ -713,6 +714,11 @@ if __name__ == "__main__":
required=False,
help="The path where all the blobs are going to be generated",
)
parser.add_argument(
"--gpu_target",
required=True,
help="GPU target architecture",
)
parser.add_argument(
"-j",
"--config_json",