mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK TILE ENGINE] Code changes to finding GPU id from TARGET (#3055)
* Reading gpuname from target for gemm in ck tile engine * Reading gpuname from target for gemm preshuffle in ck tile engine * Reading gpuname from target for gemm preshuffle in ck tile engine * Get GPU changes for GEMM Muti D in TILE ENGINE * Addressing errors for gpu name in cktileengine
This commit is contained in:
committed by
GitHub
parent
f18b79f328
commit
9f77061094
@@ -7,9 +7,6 @@ Validation utilities for GEMM kernel generation.
|
||||
Extracted from tile_engine_develop for consistency.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
@@ -123,34 +120,6 @@ def element_size(data_type: str) -> float:
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
except FileNotFoundError:
|
||||
logging.debug("ROCm tools not installed (requires rocminfo)")
|
||||
except subprocess.TimeoutExpired:
|
||||
logging.debug("GPU query timeout (5s)")
|
||||
except Exception as e:
|
||||
logging.debug(f"GPU detection error: {str(e)}")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
|
||||
"""Check if a trait combination is valid."""
|
||||
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
@@ -225,11 +194,9 @@ def validate_warp_tile_combination(
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
gpu_name: str = None,
|
||||
gpu_name: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate warp tile combination against GPU-specific supported combinations."""
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
# Construct the key for looking up supported combinations
|
||||
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
|
||||
@@ -276,6 +243,7 @@ def is_tile_config_valid(
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -335,7 +303,13 @@ def is_tile_config_valid(
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
|
||||
warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
if not warp_tile_valid:
|
||||
logging.debug(f"Warp tile validation failed: {warp_tile_error}")
|
||||
|
||||
Reference in New Issue
Block a user