mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK TILE ENGINE] Code changes to find the GPU ID from TARGET (#3055)
* Reading gpuname from target for gemm in ck tile engine * Reading gpuname from target for gemm preshuffle in ck tile engine * Reading gpuname from target for gemm preshuffle in ck tile engine * Get GPU changes for GEMM Multi D in TILE ENGINE * Addressing errors for gpu name in ck tile engine
This commit is contained in:
committed by
GitHub
parent
f18b79f328
commit
9f77061094
@@ -43,6 +43,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout)
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json ${json_blob}
|
||||
--list_blobs
|
||||
--gpu_target ${GEMM_GPU_TARGETS}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(NOT ret EQUAL 0)
|
||||
@@ -62,6 +63,7 @@ function(build_gemm_multi_d_for_datatype_layout datatype layout)
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json "${json_blob}"
|
||||
--gen_blobs
|
||||
--gpu_target ${GEMM_GPU_TARGETS}
|
||||
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
|
||||
)
|
||||
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})
|
||||
|
||||
@@ -7,10 +7,6 @@
|
||||
Mappings and utility functions for kernel code generation.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
DATA_TYPE_MAP = {
|
||||
"fp32": "float",
|
||||
"fp16": "ck_tile::half_t",
|
||||
@@ -198,31 +194,3 @@ def element_size(data_type: str) -> float:
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
# Matches lines such as "Name: gfx90a" in `rocminfo` output and captures the
# architecture token (gfx + digits + optional word characters).
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")


@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
    """Return the GPU architecture name (e.g. ``gfx90a``) at index *gpu_id*.

    Runs ``rocminfo`` (5 second timeout), collects every match of
    ``GPU_NAME_PATTERN`` in output order, and indexes that list with
    ``gpu_id``.

    Args:
        gpu_id: Zero-based position in the list of ``gfx*`` names found in
            the ``rocminfo`` output.

    Returns:
        The matched name, or ``""`` when ``gpu_id`` is negative or out of
        range, or when the query fails for any reason (tool missing,
        non-zero exit, timeout, unexpected error). Failures are reported
        with ``print`` and never raised.

    Note:
        ``lru_cache(maxsize=1)`` keeps only the most recent ``gpu_id``
        result, and that includes an empty string from a failed query.
    """
    try:
        # A negative index would otherwise pass the bounds check below and
        # silently wrap around to a GPU at the end of the list.
        if gpu_id < 0:
            return ""

        output = subprocess.check_output(
            ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
        )
        # With a single capture group, findall yields the captured names
        # directly; an empty list simply fails the bounds check.
        gpu_list = GPU_NAME_PATTERN.findall(output)
        return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""

    except subprocess.CalledProcessError as e:
        # stderr is captured via PIPE above; guard against None anyway so the
        # diagnostic itself can never raise.
        print(f"GPU query failed (exit {e.returncode}): {(e.stderr or '').strip()}")
    except FileNotFoundError:
        print("ROCm tools not installed (requires rocminfo)")
    except subprocess.TimeoutExpired:
        print("GPU query timeout (5s)")
    except Exception as e:
        # Deliberate best-effort boundary: detection problems must not abort
        # the caller, so report and fall through to the empty result.
        print(f"GPU detection error: {str(e)}")

    return ""
|
||||
|
||||
@@ -22,7 +22,6 @@ from gemm_multi_d_codegen_utils import (
|
||||
warp_tile_supported_combinations,
|
||||
trait_unsupported_combinations,
|
||||
element_size,
|
||||
get_gpu_name_by_id,
|
||||
)
|
||||
import logging
|
||||
|
||||
@@ -40,6 +39,8 @@ class GemmMultiDCodeGenerator:
|
||||
self.output_dir = Path(args.working_path)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.gpu_target = args.gpu_target
|
||||
|
||||
if user_provided_config is not None:
|
||||
self.config = user_provided_config
|
||||
else:
|
||||
@@ -261,7 +262,7 @@ class GemmMultiDCodeGenerator:
|
||||
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
gpu_name = self.gpu_target
|
||||
|
||||
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
|
||||
if not gpu_warp_tile_key:
|
||||
@@ -713,6 +714,11 @@ if __name__ == "__main__":
|
||||
required=False,
|
||||
help="The path where all the blobs are going to be generated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu_target",
|
||||
required=True,
|
||||
help="GPU target architecture",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--config_json",
|
||||
|
||||
Reference in New Issue
Block a user