diff --git a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt index 8f9bd39886..aa1a2d2d1c 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt @@ -126,6 +126,7 @@ function(build_gemm_test_targets datatype layout config_name configs_dir_path) --layout ${layout} --config_json ${json_blob} --list_kernels + --gpu_targets "${SUPPORTED_GPU_TARGETS}" WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} RESULT_VARIABLE ret OUTPUT_VARIABLE list_output @@ -188,6 +189,7 @@ function(build_gemm_test_targets datatype layout config_name configs_dir_path) --kernel_name "${kernel_name}" --tile_config "${tile_config}" --trait_combo "${trait_combo}" + --gpu_targets "${SUPPORTED_GPU_TARGETS}" WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} RESULT_VARIABLE gen_ret OUTPUT_VARIABLE gen_output diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 4f3992bf99..5c87d6f50c 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -15,6 +15,7 @@ from typing import Optional from gemm_streamk_validation_utils import ( is_tile_config_valid, is_trait_combination_valid, + set_gpu_targets, ) logging.basicConfig(level=logging.INFO) @@ -819,9 +820,19 @@ def main(): action="store_true", help="List kernel configurations without generating files", ) + parser.add_argument( + "--gpu_targets", + help="Semicolon-separated list of GPU targets from CMake (e.g., 'gfx90a;gfx942;gfx950')", + ) args = parser.parse_args() + # Configure GPU targets for fallback if provided + if args.gpu_targets: + targets = [t.strip() for t in args.gpu_targets.split(';') if t.strip()] + set_gpu_targets(targets) + logging.debug(f"Configured GPU targets: {targets}") + # Create builder builder = GemmKernelBuilder( args.working_path, args.datatype, args.layout, args.config_json diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py index bef3cdfe85..d6c76c95b5 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py @@ -11,7 +11,7 @@ import subprocess import re from functools import lru_cache import logging -from typing import Tuple, List +from typing import Tuple, List, Optional # Element size mapping for different data types ELEMENT_SIZE_MAP = { @@ -124,19 +124,57 @@ def element_size(data_type: str) -> float: GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") +# Module-level storage for configured GPU targets (fallback for when rocminfo fails) +_configured_gpu_targets: List[str] = [] + + +def set_gpu_targets(targets: List[str]) -> None: + """ + Set the fallback GPU targets list (from CMake SUPPORTED_GPU_TARGETS). + + This list will be used as a fallback when rocminfo fails to detect GPU. + + Args: + targets: List of GPU target strings (e.g., ["gfx90a", "gfx942:xnack+", "gfx950"]) + """ + global _configured_gpu_targets + _configured_gpu_targets = list(targets) + + +def get_configured_gpu_targets() -> List[str]: + """ + Get the configured GPU targets list. + + Returns: + List of configured GPU target strings + """ + return _configured_gpu_targets + @lru_cache(maxsize=1) def get_gpu_name_by_id(gpu_id: int = 0) -> str: - """Retrieve GPU name (e.g. gfx90a) by device ID""" + """ + Retrieve GPU name (e.g. gfx90a) by device ID. + + First attempts to query the GPU using rocminfo. If that fails, falls back + to using the first supported gfx target from the configured GPU targets list + (set via set_gpu_targets()). + + Args: + gpu_id: Device ID to query (default: 0) + + Returns: + GPU architecture name (e.g., "gfx90a") or empty string if detection fails + """ + # Try rocminfo first try: output = subprocess.check_output( ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 ) if matches := GPU_NAME_PATTERN.finditer(output): gpu_list = [m.group(1) for m in matches] - return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" - - return "" + if gpu_id < len(gpu_list): + return gpu_list[gpu_id] except subprocess.CalledProcessError as e: logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") @@ -147,6 +185,18 @@ def get_gpu_name_by_id(gpu_id: int = 0) -> str: except Exception as e: logging.debug(f"GPU detection error: {str(e)}") + # Fallback to configured GPU targets from CMake + if _configured_gpu_targets: + target = _configured_gpu_targets[0] + # Extract base gfx name (e.g., "gfx90a" from "gfx90a:xnack+") + match = re.match(r'(gfx\d+\w*)', target) + if match: + gpu_name = match.group(1) + logging.debug(f"rocminfo failed, using fallback GPU target: {gpu_name}") + return gpu_name + else: + logging.debug(f"Failed to parse GPU target: {target}") + return "" @@ -234,6 +284,7 @@ def validate_warp_tile_combination( gpu_name: str = None, ) -> Tuple[bool, str]: """Validate warp tile combination against GPU-specific supported combinations.""" + # This is likely going to need to change to support multiple targets, not just a single one: if gpu_name is None: gpu_name = get_gpu_name_by_id(0)