mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5213 (commit 9f7e62c)
[CK] Fix warp tile combination selection in absence of a GPU (#5213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation The `get_gpu_name_by_id()` function in `gemm_streamk_validation_utils.py` relies on `rocminfo` to detect the GPU architecture at runtime. However, __`rocminfo` fails in CI/build environments__ where: - No physical GPU is present - ROCm tools are not installed - The build is running in a container without GPU access In any of these environments, the problem manifests itself in incorrect kernel validation and will generate template instantiations that do not exist: ``` [composable_kernel] FAILED: test/ck_tile/gemm_streamk_tile_engine/CMakeFiles/test_gemm_streamk_tile_engine_fp16_rcr_streamk_atomic_smoke_tests_config_fp16_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.dir/test_gemm_streamk_simple.cpp.o [composable_kernel] /__w/TheRock/TheRock/build/core/clr/dist/lib/llvm/bin/clang++ -DCK_ENABLE_BF16 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_FP8 -DCK_ENABLE_INT8 -DCK_ENABLE_TF32 -DCK_TILE_USE_WMMA=0 -DCK_TIME_KERNEL=1 -DCK_USE_FNUZ_FP8 -DCK_USE_GFX94 -DCK_USE_XDL -DDPP_KERNELS -DGEMM_SINGLE_INSTANCE_HPP=\"/__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/gemm_streamk_single_fp16_rcr_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.hpp\" -DGEMM_TEST_PARAMS_HPP=\"/__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/test_params.hpp\" -DUSE_PROF_API=1 -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1 -D__HIP_ROCclr__=1 -I/__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/profiler/include -I/__w/TheRock/TheRock/rocm-libraries/projects/composablekernel -I/__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/library/include -I/__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include -I/__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/include -I/__w/TheRock/TheRock/build/profiler/rocprofiler-sdk/stage/include -I/__w/TheRock/TheRock/build/profiler/roctracer/stage/include -I/__w/TheRock/TheRock/build/base/half/stage/include -I/__w/TheRock/TheRock/build/third-party/sysdeps/linux/libdrm/build/stage/lib/rocm_sysdeps/include -isystem /__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest/include -isystem /__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest -O3 -DNDEBUG -std=gnu++20 --offload-arch=gfx942 -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Wno-missing-field-initializers -Wno-error=deprecated-declarations -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Weverything -Wno-c++98-compat -Wno-c++98-compat-pedantic -Wno-conversion -Wno-double-promotion -Wno-exit-time-destructors -Wno-extra-semi -Wno-float-conversion -Wno-gnu-anonymous-struct -Wno-gnu-zero-variadic-macro-arguments -Wno-missing-prototypes -Wno-nested-anon-types -Wno-padded -Wno-return-std-move-in-c++11 -Wno-shorten-64-to-32 -Wno-sign-conversion -Wno-unknown-warning-option -Wno-unused-command-line-argument -Wno-weak-vtables -Wno-covered-switch-default -Wno-unsafe-buffer-usage -Wno-unused-lambda-capture -Wno-nvcc-compat -Wno-c++20-compat -Wno-bit-int-extension -Wno-pass-failed -Wno-switch-default -Wno-unique-object-duplication -fbracket-depth=1024 -Wno-nrvo -fno-offload-uniform-block -mllvm --lsr-drop-solution=1 -mllvm -enable-post-misched=0 -mllvm -amdgpu-early-inline-all=true -mllvm -amdgpu-function-calls=false -Werror -Weverything -fcolor-diagnostics -Wno-c++20-extensions -Wno-global-constructors -Wno-undef -Wno-undefined-func-template -Wno-float-equal --offload-compress -include /__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/gemm_streamk_single_fp16_rcr_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.hpp -MD -MT test/ck_tile/gemm_streamk_tile_engine/CMakeFiles/test_gemm_streamk_tile_engine_fp16_rcr_streamk_atomic_smoke_tests_config_fp16_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.dir/test_gemm_streamk_simple.cpp.o -MF test/ck_tile/gemm_streamk_tile_engine/CMakeFiles/test_gemm_streamk_tile_engine_fp16_rcr_streamk_atomic_smoke_tests_config_fp16_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.dir/test_gemm_streamk_simple.cpp.o.d -o test/ck_tile/gemm_streamk_tile_engine/CMakeFiles/test_gemm_streamk_tile_engine_fp16_rcr_streamk_atomic_smoke_tests_config_fp16_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.dir/test_gemm_streamk_simple.cpp.o -x hip -c /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp [composable_kernel] In file included from <built-in>:2: [composable_kernel] In file included from /__w/TheRock/TheRock/build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/gemm_streamk_single_fp16_rcr_compv3_cshuffle_intrawave_atomic_False_False_False_False_256x256x32_2x2x1_16x16x8.hpp:9: [composable_kernel] In file included from /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include/ck_tile/ops/gemm.hpp:23: [composable_kernel] In file included from /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp:7: [composable_kernel] In file included from /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp:8: [composable_kernel] /__w/TheRock/TheRock/rocm-libraries/projects/composablekernel/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp:185:1: error: implicit instantiation of undefined template 'ck_tile::impl::warp_gemm_dispatcher::Dispatcher<_Float16, _Float16, float, 16, 16, 8, false, false, false, ck_tile::WGAttrNumAccessEnum::Single, ck_tile::WGAttrNumAccessEnum::Single>' ``` ## Technical Details ### Changes Made: #### 1. __gemm_streamk_validation_utils.py__ - Added module-level storage: `_configured_gpu_targets` - Added `set_gpu_targets(targets: List[str])` to configure fallback GPU targets - Added `get_configured_gpu_targets() -> List[str]` to retrieve configured targets - Enhanced `get_gpu_name_by_id()` to: - First try `rocminfo` (existing behavior) - If `rocminfo` fails, fall back to first configured GPU target - Extract base gfx name (e.g., "gfx90a" from "gfx90a:xnack+") - Log debug messages when using fallback #### 2. __gemm_streamk_instance_builder.py__ - Added `--gpu_targets` command-line argument - Automatically calls `set_gpu_targets()` when `--gpu_targets` is provided - Parses semicolon-separated GPU target list from CMake #### 3. __test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt__ - Modified both `--list_kernels` and `--gen_single` invocations to pass `--gpu_targets "${SUPPORTED_GPU_TARGETS}"` - GPU targets are now automatically wired from CMake to Python scripts ### How It Works: 1. __CMake Configuration__: `SUPPORTED_GPU_TARGETS` is determined from `GPU_TARGETS` or defaults 2. __CMake → Python__: CMake passes targets via `--gpu_targets` argument to Python scripts 3. __Python Configuration__: Scripts call `set_gpu_targets()` to configure the fallback 4. __Fallback Mechanism__: When `rocminfo` fails, `get_gpu_name_by_id()` uses the first configured target 5. __Target Parsing__: Extracts clean gfx name (e.g., "gfx90a" from "gfx90a:xnack+") ## Test Plan Confirm that only the appropriate kernels are selected and that CI passes. ## Test Result 1. Waiting on CI 2. Compilation succeeded locally and the kernel list does not contain the 16x16x8 kernel for gfx942 anymore: ``` (.venv) bhargrea@ctr-cx66-mi300x-02:~/github/TheRock$ cat build/ml-libs/composable_kernel/build/test/ck_tile/gemm_streamk_tile_engine/fp16/rcr/streamk_atomic_smoke_tests_config_fp16/gemm_kernel_list.txt gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_True_256x256x32_2x2x1_16x16x16|256x256x32_2x2x1_16x16x16|compv3_cshuffle_intrawave_atomic_False_False_False_True gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_False_256x256x32_2x2x1_16x16x16|256x256x32_2x2x1_16x16x16|compv3_cshuffle_intrawave_atomic_False_False_False_False gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_True_256x256x32_2x2x1_16x16x32|256x256x32_2x2x1_16x16x32|compv3_cshuffle_intrawave_atomic_False_False_False_True gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_False_256x256x32_2x2x1_16x16x32|256x256x32_2x2x1_16x16x32|compv3_cshuffle_intrawave_atomic_False_False_False_False gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_True_256x256x32_2x2x1_32x32x8|256x256x32_2x2x1_32x32x8|compv3_cshuffle_intrawave_atomic_False_False_False_True gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_False_256x256x32_2x2x1_32x32x8|256x256x32_2x2x1_32x32x8|compv3_cshuffle_intrawave_atomic_False_False_False_False gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_True_256x256x32_2x2x1_32x32x16|256x256x32_2x2x1_32x32x16|compv3_cshuffle_intrawave_atomic_False_False_False_True gemm_fp16_rcr_compv3_cshuffle_intrawave_Atomic_False_False_False_False_256x256x32_2x2x1_32x32x16|256x256x32_2x2x1_32x32x16|compv3_cshuffle_intrawave_atomic_False_False_False_False ``` ## Submission Checklist - [ x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
8f27f65d44
commit
26d29374e5
@@ -15,6 +15,7 @@ from typing import Optional
|
||||
from gemm_streamk_validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
set_gpu_targets,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -819,9 +820,19 @@ def main():
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu_targets",
|
||||
help="Semicolon-separated list of GPU targets from CMake (e.g., 'gfx90a;gfx942;gfx950')",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure GPU targets for fallback if provided
|
||||
if args.gpu_targets:
|
||||
targets = [t.strip() for t in args.gpu_targets.split(';') if t.strip()]
|
||||
set_gpu_targets(targets)
|
||||
logging.debug(f"Configured GPU targets: {targets}")
|
||||
|
||||
# Create builder
|
||||
builder = GemmKernelBuilder(
|
||||
args.working_path, args.datatype, args.layout, args.config_json
|
||||
|
||||
@@ -11,7 +11,7 @@ import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
# Element size mapping for different data types
|
||||
ELEMENT_SIZE_MAP = {
|
||||
@@ -124,19 +124,57 @@ def element_size(data_type: str) -> float:
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
# Module-level storage for configured GPU targets (fallback for when rocminfo fails)
|
||||
_configured_gpu_targets: List[str] = []
|
||||
|
||||
|
||||
def set_gpu_targets(targets: List[str]) -> None:
|
||||
"""
|
||||
Set the fallback GPU targets list (from CMake SUPPORTED_GPU_TARGETS).
|
||||
|
||||
This list will be used as a fallback when rocminfo fails to detect GPU.
|
||||
|
||||
Args:
|
||||
targets: List of GPU target strings (e.g., ["gfx90a", "gfx942:xnack+", "gfx950"])
|
||||
"""
|
||||
global _configured_gpu_targets
|
||||
_configured_gpu_targets = list(targets)
|
||||
|
||||
|
||||
def get_configured_gpu_targets() -> List[str]:
|
||||
"""
|
||||
Get the configured GPU targets list.
|
||||
|
||||
Returns:
|
||||
List of configured GPU target strings
|
||||
"""
|
||||
return _configured_gpu_targets
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
"""
|
||||
Retrieve GPU name (e.g. gfx90a) by device ID.
|
||||
|
||||
First attempts to query the GPU using rocminfo. If that fails, falls back
|
||||
to using the first supported gfx target from the configured GPU targets list
|
||||
(set via set_gpu_targets()).
|
||||
|
||||
Args:
|
||||
gpu_id: Device ID to query (default: 0)
|
||||
|
||||
Returns:
|
||||
GPU architecture name (e.g., "gfx90a") or empty string if detection fails
|
||||
"""
|
||||
# Try rocminfo first
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
if gpu_id < len(gpu_list):
|
||||
return gpu_list[gpu_id]
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
@@ -147,6 +185,18 @@ def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
except Exception as e:
|
||||
logging.debug(f"GPU detection error: {str(e)}")
|
||||
|
||||
# Fallback to configured GPU targets from CMake
|
||||
if _configured_gpu_targets:
|
||||
target = _configured_gpu_targets[0]
|
||||
# Extract base gfx name (e.g., "gfx90a" from "gfx90a:xnack+")
|
||||
match = re.match(r'(gfx\d+\w*)', target)
|
||||
if match:
|
||||
gpu_name = match.group(1)
|
||||
logging.debug(f"rocminfo failed, using fallback GPU target: {gpu_name}")
|
||||
return gpu_name
|
||||
else:
|
||||
logging.debug(f"Failed to parse GPU target: {target}")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@@ -234,6 +284,7 @@ def validate_warp_tile_combination(
|
||||
gpu_name: str = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate warp tile combination against GPU-specific supported combinations."""
|
||||
# This is likely going to need to change to support multiple targets, not just a single one:
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user