From 270c651d3cf4f7ba262ba1103a0a608452b13aa0 Mon Sep 17 00:00:00 2001 From: Brock Hargreaves Date: Tue, 10 Mar 2026 17:11:56 -0600 Subject: [PATCH] [CK] Fix warp tile combination selection in absence of a GPU (#5213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation The `get_gpu_name_by_id()` function in `gemm_streamk_validation_utils.py` relies on `rocminfo` to detect the GPU architecture at runtime. However, __`rocminfo` fails in CI/build environments__ where: - No physical GPU is present - ROCm tools are not installed - The build is running in a container without GPU access In any of these environments, the problem manifests itself in incorrect kernel validation and will generate template instantiations that do not exist: ``` [composable_kernel] FAILED: test/ck_tile/gemm_streamk_tile_engine/CMakeFiles/test_gemm_streamk_tile_engine_fp16_rcr_streamk_atomic_smoke_tests_config_fp16_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.dir/test_gemm_streamk_simple.cpp.o [composable_kernel] /__w/TheRock/TheRock/build/core/clr/dist/lib/llvm/bin/clang++ -DCK_ENABLE_BF16 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_FP8 -DCK_ENABLE_INT8 -DCK_ENABLE_TF32 -DCK_TILE_USE_WMMA=0 -DCK_TIME_KERNEL=1 -DCK_USE_FNUZ_FP8 -DCK_USE_GFX94 -DCK_USE_XDL -DDPP_KERNELS -DGEMM_SINGLE_INSTANCE_HPP=\"/__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/gemm_streamk_single_fp16_rcr_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.hpp\" -DGEMM_TEST_PARAMS_HPP=\"/__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/test_params.hpp\" -DUSE_PROF_API=1 -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1 -D__HIP_ROCclr__=1 -I/__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/profiler/include -I/__w/TheRock/TheRock/rocm-libraries/projects/composablekernel -I/__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/library/include -I/__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include -I/__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/include -I/__w/TheRock/TheRock/build/profiler/rocprofiler-sdk/stage/include -I/__w/TheRock/TheRock/build/profiler/roctracer/stage/include -I/__w/TheRock/TheRock/build/base/half/stage/include -I/__w/TheRock/TheRock/build/third-party/sysdeps/linux/libdrm/build/stage/lib/rocm_sysdeps/include -isystem /__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest/include -isystem /__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest -O3 -DNDEBUG -std=gnu++20 --offload-arch=gfx942 -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Wno-missing-field-initializers -Wno-error=deprecated-declarations -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Weverything -Wno-c++98-compat -Wno-c++98-compat-pedantic -Wno-conversion -Wno-double-promotion -Wno-exit-time-destructors -Wno-extra-semi -Wno-float-conversion -Wno-gnu-anonymous-struct -Wno-gnu-zero-variadic-macro-arguments -Wno-missing-prototypes -Wno-nested-anon-types -Wno-padded -Wno-return-std-move-in-c++11 -Wno-shorten-64-to-32 -Wno-sign-conversion -Wno-unknown-warning-option -Wno-unused-command-line-argument -Wno-weak-vtables -Wno-covered-switch-default -Wno-unsafe-buffer-usage -Wno-unused-lambda-capture -Wno-nvcc-compat -Wno-c++20-compat -Wno-bit-int-extension -Wno-pass-failed -Wno-switch-default -Wno-unique-object-duplication -fbracket-depth=1024 -Wno-nrvo -fno-offload-uniform-block -mllvm --lsr-drop-solution=1 -mllvm -enable-post-misched=0 -mllvm -amdgpu-early-inline-all=true -mllvm -amdgpu-function-calls=false -Werror -Weverything -fcolor-diagnostics -Wno-c++20-extensions -Wno-global-constructors -Wno-undef -Wno-undefined-func-template -Wno-float-equal --offload-compress -include /__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/gemm_streamk_single_fp16_rcr_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.hpp -MD -MT test/ck_tile/gemm_streamk_tile_engine/CMakeFiles/test_gemm_streamk_tile_engine_fp16_rcr_streamk_atomic_smoke_tests_config_fp16_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.dir/test_gemm_streamk_simple.cpp.o -MF test/ck_tile/gemm_streamk_tile_engine/CMakeFiles/test_gemm_streamk_tile_engine_fp16_rcr_streamk_atomic_smoke_tests_config_fp16_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.dir/test_gemm_streamk_simple.cpp.o.d -o test/ck_tile/gemm_streamk_tile_engine/CMakeFiles/test_gemm_streamk_tile_engine_fp16_rcr_streamk_atomic_smoke_tests_config_fp16_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.dir/test_gemm_streamk_simple.cpp.o -x hip -c /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp [composable_kernel] In file included from :2: [composable_kernel] In file included from /__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/gemm_streamk_single_fp16_rcr_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.hpp:9: [composable_kernel] In file included from /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include/ck_tile/ops/gemm.hpp:23: [composable_kernel] In file included from /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp:7: [composable_kernel] In file included from /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp:8: [composable_kernel] /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp:185:1: error: implicit instantiation of undefined template 'ck_tile::impl::warp_gemm_dispatcher::Dispatcher<_Float16, _Float16, float, 16, 16, 8, false, false, false, ck_tile::WGAttrNumAccessEnum::Single, ck_tile::WGAttrNumAccessEnum::Single>' ``` ## Technical Details ### Changes Made: #### 1. __gemm_streamk_validation_utils.py__ - Added module-level storage: `_configured_gpu_targets` - Added `set_gpu_targets(targets: List[str])` to configure fallback GPU targets - Added `get_configured_gpu_targets() -> List[str]` to retrieve configured targets - Enhanced `get_gpu_name_by_id()` to: - First try `rocminfo` (existing behavior) - If `rocminfo` fails, fall back to first configured GPU target - Extract base gfx name (e.g., "gfx90a" from "gfx90a:xnack+") - Log debug messages when using fallback #### 2. __gemm_streamk_instance_builder.py__ - Added `--gpu_targets` command-line argument - Automatically calls `set_gpu_targets()` when `--gpu_targets` is provided - Parses semicolon-separated GPU target list from CMake #### 3. __test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt__ - Modified both `--list_kernels` and `--gen_single` invocations to pass `--gpu_targets "${SUPPORTED_GPU_TARGETS}"` - GPU targets are now automatically wired from CMake to Python scripts ### How It Works: 1. __CMake Configuration__: `SUPPORTED_GPU_TARGETS` is determined from `GPU_TARGETS` or defaults 2. __CMake → Python__: CMake passes targets via `--gpu_targets` argument to Python scripts 3. __Python Configuration__: Scripts call `set_gpu_targets()` to configure the fallback 4. __Fallback Mechanism__: When `rocminfo` fails, `get_gpu_name_by_id()` uses the first configured target 5. __Target Parsing__: Extracts clean gfx name (e.g., "gfx90a" from "gfx90a:xnack+") ## Test Plan Confirm that only the appropriate kernels are selected and that CI passes. ## Test Result 1. Waiting on CI 2. Compilation succeeded locally and the kernel list does not contain the 16x16x8 kernel for gfx942 anymore: ``` (.venv) bhargrea@ctr-cx66-mi300x-02:~/github/TheRock$ cat build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/gemm_kernel_list.txt gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_True_256x256x32_2x2x1_16x16x16|256x256x32_2x2x1_16x16x16|compv3_cshuffle_intrawave_atomic_False_False_False_True gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_False_256x256x32_2x2x1_16x16x16|256x256x32_2x2x1_16x16x16|compv3_cshuffle_intrawave_atomic_False_False_False_False gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_True_256x256x32_2x2x1_16x16x32|256x256x32_2x2x1_16x16x32|compv3_cshuffle_intrawave_atomic_False_False_False_True gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_False_256x256x32_2x2x1_16x16x32|256x256x32_2x2x1_16x16x32|compv3_cshuffle_intrawave_atomic_False_False_False_False gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_True_256x256x32_2x2x1_32x32x8|256x256x32_2x2x1_32x32x8|compv3_cshuffle_intrawave_atomic_False_False_False_True gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_False_256x256x32_2x2x1_32x32x8|256x256x32_2x2x1_32x32x8|compv3_cshuffle_intrawave_atomic_False_False_False_False gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_True_256x256x32_2x2x1_32x32x16|256x256x32_2x2x1_32x32x16|compv3_cshuffle_intrawave_atomic_False_False_False_True gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_False_256x256x32_2x2x1_32x32x16|256x256x32_2x2x1_32x32x16|compv3_cshuffle_intrawave_atomic_False_False_False_False ``` ## Submission Checklist - [ x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../gemm_streamk_tile_engine/CMakeLists.txt | 2 + .../gemm_streamk_instance_builder.py | 11 ++++ .../gemm_streamk_validation_utils.py | 61 +++++++++++++++++-- 3 files changed, 69 insertions(+), 5 deletions(-) diff --git a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt index 8f9bd39886..aa1a2d2d1c 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt @@ -126,6 +126,7 @@ function(build_gemm_test_targets datatype layout config_name configs_dir_path) --layout ${layout} --config_json ${json_blob} --list_kernels + --gpu_targets "${SUPPORTED_GPU_TARGETS}" WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} RESULT_VARIABLE ret OUTPUT_VARIABLE list_output @@ -188,6 +189,7 @@ function(build_gemm_test_targets datatype layout config_name configs_dir_path) --kernel_name "${kernel_name}" --tile_config "${tile_config}" --trait_combo "${trait_combo}" + --gpu_targets "${SUPPORTED_GPU_TARGETS}" WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} RESULT_VARIABLE gen_ret OUTPUT_VARIABLE gen_output diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 4f3992bf99..5c87d6f50c 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -15,6 +15,7 @@ from typing import Optional from gemm_streamk_validation_utils import ( is_tile_config_valid, is_trait_combination_valid, + set_gpu_targets, ) logging.basicConfig(level=logging.INFO) @@ -819,9 +820,19 @@ def main(): action="store_true", help="List kernel configurations without generating files", ) + parser.add_argument( + "--gpu_targets", + help="Semicolon-separated list of GPU targets from CMake (e.g., 'gfx90a;gfx942;gfx950')", + ) args = parser.parse_args() + # Configure GPU targets for fallback if provided + if args.gpu_targets: + targets = [t.strip() for t in args.gpu_targets.split(';') if t.strip()] + set_gpu_targets(targets) + logging.debug(f"Configured GPU targets: {targets}") + # Create builder builder = GemmKernelBuilder( args.working_path, args.datatype, args.layout, args.config_json diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py index bef3cdfe85..d6c76c95b5 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py @@ -11,7 +11,7 @@ import subprocess import re from functools import lru_cache import logging -from typing import Tuple, List +from typing import Tuple, List, Optional # Element size mapping for different data types ELEMENT_SIZE_MAP = { @@ -124,19 +124,57 @@ def element_size(data_type: str) -> float: GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") +# Module-level storage for configured GPU targets (fallback for when rocminfo fails) +_configured_gpu_targets: List[str] = [] + + +def set_gpu_targets(targets: List[str]) -> None: + """ + Set the fallback GPU targets list (from CMake SUPPORTED_GPU_TARGETS). + + This list will be used as a fallback when rocminfo fails to detect GPU. + + Args: + targets: List of GPU target strings (e.g., ["gfx90a", "gfx942:xnack+", "gfx950"]) + """ + global _configured_gpu_targets + _configured_gpu_targets = list(targets) + + +def get_configured_gpu_targets() -> List[str]: + """ + Get the configured GPU targets list. + + Returns: + List of configured GPU target strings + """ + return _configured_gpu_targets + @lru_cache(maxsize=1) def get_gpu_name_by_id(gpu_id: int = 0) -> str: - """Retrieve GPU name (e.g. gfx90a) by device ID""" + """ + Retrieve GPU name (e.g. gfx90a) by device ID. + + First attempts to query the GPU using rocminfo. If that fails, falls back + to using the first supported gfx target from the configured GPU targets list + (set via set_gpu_targets()). + + Args: + gpu_id: Device ID to query (default: 0) + + Returns: + GPU architecture name (e.g., "gfx90a") or empty string if detection fails + """ + # Try rocminfo first try: output = subprocess.check_output( ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 ) if matches := GPU_NAME_PATTERN.finditer(output): gpu_list = [m.group(1) for m in matches] - return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" - - return "" + if gpu_id < len(gpu_list): + return gpu_list[gpu_id] except subprocess.CalledProcessError as e: logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") @@ -147,6 +185,18 @@ def get_gpu_name_by_id(gpu_id: int = 0) -> str: except Exception as e: logging.debug(f"GPU detection error: {str(e)}") + # Fallback to configured GPU targets from CMake + if _configured_gpu_targets: + target = _configured_gpu_targets[0] + # Extract base gfx name (e.g., "gfx90a" from "gfx90a:xnack+") + match = re.match(r'(gfx\d+\w*)', target) + if match: + gpu_name = match.group(1) + logging.debug(f"rocminfo failed, using fallback GPU target: {gpu_name}") + return gpu_name + else: + logging.debug(f"Failed to parse GPU target: {target}") + return "" @@ -234,6 +284,7 @@ def validate_warp_tile_combination( gpu_name: str = None, ) -> Tuple[bool, str]: """Validate warp tile combination against GPU-specific supported combinations.""" + # This is likely going to need to change to support multiple targets, not just a single one: if gpu_name is None: gpu_name = get_gpu_name_by_id(0)