[CK TILE ENGINE] Code changes to finding GPU id from TARGET (#3055)

* Reading gpuname from target for gemm in ck tile engine

* Reading gpuname from target for gemm preshuffle in ck tile engine

* Reading gpuname from target for gemm preshuffle in ck tile engine

* Get GPU changes for GEMM Muti D in TILE ENGINE

* Addressing errors for gpu name in cktileengine
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-10-20 11:02:18 -05:00
committed by GitHub
parent f18b79f328
commit 9f77061094
12 changed files with 59 additions and 149 deletions

View File

@@ -7,9 +7,6 @@ Validation utilities for GEMM kernel generation.
Extracted from tile_engine_develop for consistency.
"""
import subprocess
import re
from functools import lru_cache
import logging
from typing import Tuple, List
@@ -123,34 +120,6 @@ def element_size(data_type: str) -> float:
return ELEMENT_SIZE_MAP[data_type]
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
return ""
except subprocess.CalledProcessError as e:
logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
except FileNotFoundError:
logging.debug("ROCm tools not installed (requires rocminfo)")
except subprocess.TimeoutExpired:
logging.debug("GPU query timeout (5s)")
except Exception as e:
logging.debug(f"GPU detection error: {str(e)}")
return ""
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
"""Check if a trait combination is valid."""
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
@@ -225,11 +194,9 @@ def validate_warp_tile_combination(
a_datatype: str,
b_datatype: str,
c_datatype: str,
gpu_name: str = None,
gpu_name: str,
) -> Tuple[bool, str]:
"""Validate warp tile combination against GPU-specific supported combinations."""
if gpu_name is None:
gpu_name = get_gpu_name_by_id(0)
# Construct the key for looking up supported combinations
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
@@ -276,6 +243,7 @@ def is_tile_config_valid(
b_datatype: str,
c_datatype: str,
pipeline: str,
gpu_target: str,
trait_name: str = None,
) -> bool:
"""
@@ -335,7 +303,13 @@ def is_tile_config_valid(
# Validate warp tile combination
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
warp_tile_m,
warp_tile_n,
warp_tile_k,
a_datatype,
b_datatype,
c_datatype,
gpu_target,
)
if not warp_tile_valid:
logging.debug(f"Warp tile validation failed: {warp_tile_error}")