diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index af36f492ba..0d7bcd6b18 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron diff --git a/.gitignore b/.gitignore index 98234268c1..17f93500bd 100644 --- a/.gitignore +++ b/.gitignore @@ -81,14 +81,35 @@ CMakeUserPresets.json # Python cache __pycache__/ +# Cache directories .cache/ +.ck_tile_cache/ +ck_tile_cache/ +**/kernel_cache/ +**/.kernel_cache/ + +# Dispatcher kernel cache (user-generated, can be large) +dispatcher/**/kernel_cache/ +dispatcher/**/.kernel_cache/ +dispatcher/**/cached_kernels/ +dispatcher/**/*.hsaco +dispatcher/**/*.co + +# Dispatcher generated JSON exports +dispatcher/**/*_kernels.json +dispatcher/**/dispatcher_kernels.json # Generated test data test_data/* !test_data/*.py !test_data/*.sh +!test_data/requirements.txt # Exceptions to build* patterns above # The experimental/builder directory should be tracked despite matching build* !experimental/builder !experimental/builder/** +experimental/grouped_convolution_tile_instances/instances/* +!experimental/grouped_convolution_tile_instances/instances/*.in +!experimental/grouped_convolution_tile_instances/instances/*.inc +experimental/grouped_convolution_tile_instances/*.inc diff --git a/CHANGELOG.md b/CHANGELOG.md index 066dc9aa3b..c99fc1d065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for gfx1153 target. * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. +* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. +* Added FP8 block scale quantization for FMHA forward kernel. ### Changed @@ -23,6 +25,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 +* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM. diff --git a/CMakeLists.txt b/CMakeLists.txt index dc77337248..80572c309c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -cmake_minimum_required(VERSION 3.14) +cmake_minimum_required(VERSION 3.21) if(POLICY CMP0140) # policies CMP0140 not known to CMake until 3.25 cmake_policy(SET CMP0140 NEW) @@ -41,6 +41,7 @@ include(CTest) option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF) option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) @@ -259,6 +260,11 @@ if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL) + message(STATUS "Enabling XDL FP8 gemms on gfx950") + add_definitions(-DCK_USE_GFX950) + set(CK_USE_GFX950 "ON") +endif() # new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA set(CK_TILE_USE_WMMA 0) @@ -643,7 +649,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS add_compile_options(-fdiagnostics-color=always) endif() -if(NOT MIOPEN_REQ_LIBS_ONLY) +if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) # make check runs the entire set of examples and tests add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} USES_TERMINAL) # make smoke runs the tests and examples that runs within 30 seconds on gfx90a @@ -654,7 +660,9 @@ endif() -file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") +# Optimization: Search only in library/src where all instance files actually live +# (was searching entire source tree, taking ~40s instead of <1s) +file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/library/src/*/device_*_instance.cpp") file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) set(CK_DEVICE_INSTANCES) FOREACH(subdir_path ${dir_list}) @@ -699,12 +707,18 @@ ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF) option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) add_subdirectory(library) -if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) +if (CK_EXPERIMENTAL_BUILDER) + add_subdirectory(experimental/builder) + add_subdirectory(experimental/grouped_convolution_tile_instances) +endif() + +if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name @@ -727,7 +741,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) endif() endif() -if (NOT MIOPEN_REQ_LIBS_ONLY) +if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY) rocm_package_setup_component(profiler LIBRARY_NAME composablekernel PACKAGE_NAME ckprofiler @@ -735,10 +749,6 @@ if (NOT MIOPEN_REQ_LIBS_ONLY) add_subdirectory(profiler) endif() -if (CK_EXPERIMENTAL_BUILDER) - add_subdirectory(experimental/builder) -endif() - if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) add_subdirectory(codegen) endif() diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 0000000000..f81dbadb19 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,91 @@ +{ + "version": 3, + "cmakeMinimumRequired": { + "major": 3, + "minor": 21, + "patch": 0 + }, + "configurePresets": [ + { + "name": "use-gfx908", + "hidden": true, + "cacheVariables": { + "GPU_TARGETS": "gfx908" + } + }, + { + "name": "use-gfx90a", + "hidden": true, + "cacheVariables": { + "GPU_TARGETS": "gfx90a" + } + }, + { + "name": "use-gfx942", + "hidden": true, + "cacheVariables": { + "GPU_TARGETS": "gfx942" + } + }, + { + "name": "use-gfx950", + "hidden": true, + "cacheVariables": { + "GPU_TARGETS": "gfx950" + } + }, + { + "name": "dev", + "binaryDir": "${sourceDir}/build", + "displayName": "CK Dev", + "environment": {}, + "cacheVariables": { + "CMAKE_PREFIX_PATH": "/opt/rocm/", + "CMAKE_CXX_COMPILER": "/opt/rocm/llvm/bin/clang++", + "CMAKE_HIP_COMPILER": "/opt/rocm/llvm/bin/clang++", + "CMAKE_CXX_FLAGS": "-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -fbracket-depth=512", + "CMAKE_BUILD_TYPE": "Release", + "BUILD_DEV": "ON", + "CMAKE_VERBOSE_MAKEFILE": "ON", + "USE_BITINT_EXTENSION_INT4": "OFF", + "GPU_TARGETS": "gfx908;gfx90a;gfx942" + } + }, + { + "name": "dev-gfx908", + "displayName": "CK Dev - gfx908", + "description": "Development build for AMD GPU gfx908", + "inherits": [ + "use-gfx908", + "dev" + ] + }, + { + "name": "dev-gfx90a", + "displayName": "CK Dev - gfx90a", + "description": "Development build for AMD GPU gfx90a", + "inherits": [ + "use-gfx90a", + "dev" + ] + }, + { + "name": "dev-gfx942", + "displayName": "CK Dev - gfx942", + "description": "Development build for AMD GPU gfx942", + "inherits": [ + "use-gfx942", + "dev" + ] + }, + { + "name": "dev-gfx950", + "displayName": "CK Dev - gfx950", + "description": "Development build for AMD GPU gfx950", + "inherits": [ + "use-gfx950", + "dev" + ] + } + ] +} \ No newline at end of file diff --git a/Dockerfile.manylinux b/Dockerfile.manylinux new file mode 100644 index 0000000000..0683bcd4a6 --- /dev/null +++ b/Dockerfile.manylinux @@ -0,0 +1,101 @@ +FROM ghcr.io/rocm/therock_build_manylinux_x86_64:latest +ARG DEBIAN_FRONTEND=noninteractive +ARG ROCMVERSION=7.2 +ARG compiler_version="" +ARG compiler_commit="" +ARG CK_SCCACHE="" +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +ENV DEBIAN_FRONTEND=noninteractive + +USER root + +# Add rocm repository +RUN dnf clean all && dnf update -y && dnf -v install wget gnupg2 curl -y + +RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2.70200-1.el8.noarch.rpm && \ + dnf install ./amdgpu-install-7.2.70200-1.el8.noarch.rpm -y && \ + dnf update -y && \ + dnf install python3-setuptools python3-wheel -y && \ + dnf install rocm-dev -y + +## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined +ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache +ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin +ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} +ENV CK_SCCACHE=$CK_SCCACHE +RUN if [ "$CK_SCCACHE" != "" ]; then \ + mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ + curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ + chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \ + fi + +# Install dependencies +RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \ + cmake \ + clang-tools-extra \ + gcc-c++ \ + libstdc++ \ + libstdc++-devel \ + libstdc++-static \ + git \ + hip-rocclr \ + jq \ + mpich \ + net-tools \ + pkg-config \ + redis \ + sshpass \ + stunnel \ + vim \ + nano \ + zip \ + openssh-server \ + kmod && \ + dnf clean all && \ + rm -rf /var/lib/apt/lists/* && \ + rm -rf amdgpu-install* && \ +#Install latest ccache + git clone https://github.com/ccache/ccache.git && \ + cd ccache && mkdir build && cd build && cmake .. && make install && \ +#Install ClangBuildAnalyzer + git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \ + cd ClangBuildAnalyzer/ && \ + make -f projects/make/Makefile && \ + cd / && \ +#Install latest cppcheck + git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \ + cd / && \ +# Install packages for processing the performance results + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ +# Add render group + groupadd -f render && \ +# Install the new rocm-cmake version + git clone -b master https://github.com/ROCm/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . --target install + +WORKDIR / +# Add alternative compilers, if necessary +ENV compiler_version=$compiler_version +ENV compiler_commit=$compiler_commit +RUN sh -c "echo compiler version = '$compiler_version'" && \ + sh -c "echo compiler commit = '$compiler_commit'" + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + diff --git a/Jenkinsfile b/Jenkinsfile index 9c670183fd..1a8be258bd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,10 +39,10 @@ def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. def failurePatterns = [ [pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"], - [pattern: /docker login failed/, description: "Docker login failed"], + [pattern: /(.*)docker login failed(.*)/, description: "Docker login failed"], [pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"], [pattern: /cat: .* No such file or directory/, description: "GPU not found"], - [pattern: /GPU not found/, description: "GPU not found"], + [pattern: /(.*)GPU not found(.*)/, description: "GPU not found"], [pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"] ] @@ -115,7 +115,7 @@ def generateAndArchiveBuildTraceVisualization(String buildTraceFileName) { // Run container to get snapshot def dockerOpts = "--cap-add=SYS_ADMIN -v \"\$(pwd)/workspace:/workspace\" -e NODE_PATH=/home/pptruser/node_modules -e BUILD_TRACE_FILE=${buildTraceFileName}" // Create unique image name by sanitizing job name - def sanitizedJobName = env.JOB_NAME.replaceAll(/[\/\\:*?"<>| ]/, '_') + def sanitizedJobName = env.JOB_NAME.replaceAll(/[\/\\:*?"<>| ]/, '_').replaceAll('%2F', '_') def architectureName = (buildTraceFileName =~ /(gfx[0-9a-zA-Z]+)/)[0][1] def imageName = "perfetto_snapshot_${sanitizedJobName}_build_${env.BUILD_NUMBER}_${architectureName}.png" sh """ @@ -580,7 +580,7 @@ def cmake_build(Map conf=[:]){ if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" } - if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { + if ((params.RUN_BUILDER_TESTS || params.RUN_FULL_CONV_TILE_TESTS) && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { setup_args = " -D CK_EXPERIMENTAL_BUILDER=ON " + setup_args } setup_cmd = conf.get( @@ -811,41 +811,12 @@ def Build_CK(Map conf=[:]){ archiveArtifacts "perf_*.log" stash includes: "perf_**.log", name: "perf_log_${arch}" } - // disable performance tests on gfx1030 for now. - //else if ( arch == "gfx10"){ - // run basic tests on gfx1030 - // echo "Run gemm performance tests" - // sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10" - // archiveArtifacts "perf_onnx_gemm_gfx10.log" - // stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10" - //} - else if ( arch == "gfx11"){ - // run basic tests on gfx11 + else if ( arch != "gfx10"){ + // run basic tests on gfx11/gfx12/gfx908/gfx950, but not on gfx10, it takes too long echo "Run gemm performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11" - archiveArtifacts "perf_onnx_gemm_gfx11.log" - stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11" - } - else if ( arch == "gfx120" ){ - // run basic tests on gfx12 - echo "Run gemm performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12" - archiveArtifacts "perf_onnx_gemm_gfx12.log" - stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" - } - else if ( arch == "gfx908" ){ - // run basic tests on gfx908 - echo "Run performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" - archiveArtifacts "perf_onnx_gemm_gfx908.log" - stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908" - } - else if ( arch == "gfx950" ){ - // run basic tests on gfx950 - echo "Run performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950" - archiveArtifacts "perf_onnx_gemm_gfx950.log" - stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} ${arch}" + archiveArtifacts "perf_onnx_gemm_*.log" + stash includes: "perf_onnx_gemm_**.log", name: "perf_log_${arch}" } } } @@ -1049,6 +1020,7 @@ def run_aiter_tests(Map conf=[:]){ sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_batch_prefill.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py" @@ -1119,7 +1091,7 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_ 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true - 0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true + 0 13 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true 0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : "" pipeline { @@ -1283,6 +1255,10 @@ pipeline { name: "RUN_AITER_TESTS", defaultValue: false, description: "Run AITER tests with latest CK develop branch (default: OFF)") + booleanParam( + name: "RUN_FULL_CONV_TILE_TESTS", + defaultValue: false, + description: "Run CK Tile grouped convolution tests with latest CK develop branch (default: OFF)") string( name: 'aiter_branch', defaultValue: 'main', @@ -1346,21 +1322,15 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "(cd .. && git ls-files \'*.h\' \ - \'*.hpp\' \ - \'*.cpp\' \ - \'*.h.in\' \ - \'*.hpp.in\' \ - \'*.cpp.in\' \ - \'*.cl\' \ - | grep -v 'build/' \ - | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\') && \ + execute_cmd = """cd .. && \ + find . -type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.cl' \\) \ + -not -path '*/build/*' -not -path '*/include/rapidjson/*' | \ + xargs -P 8 -I{} sh -c 'clang-format-18 -style=file {} | diff -u - {} || (echo "ERROR: {} needs formatting" && exit 1)' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ -U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \ - --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log" + --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd) @@ -1376,17 +1346,10 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "(cd .. && git ls-files \ - \'*.h\' \ - \'*.hpp\' \ - \'*.cpp\' \ - \'*.h.in\' \ - \'*.hpp.in\' \ - \'*.cpp.in\' \ - \'*.cl\' \ - | grep -v 'build/' \ - | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\')" + execute_cmd = """cd .. && \ + find . -type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.cl' \\) \ + -not -path '*/build/*' -not -path '*/include/rapidjson/*' | \ + xargs -P 8 -I{} sh -c 'clang-format-18 -style=file {} | diff -u - {} || (echo "ERROR: {} needs formatting" && exit 1)'""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd) @@ -1451,6 +1414,36 @@ pipeline { } } } + stage("Run Full Grouped Conv Tile Tests") + { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } + parallel + { + stage("Run Full Grouped Conv Tile Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_FULL_CONV_TILE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ python3 ../experimental/builder/src/generate_instances.py --mode=profiler && \ + ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 test_grouped_convnd_fwd_tile && \ + ./bin/test_grouped_convnd_fwd_tile""" + } + steps{ + // TODO: Reenable after the instance fixes + // buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Run Grouped Conv Large Case Tests") { when { @@ -1468,7 +1461,7 @@ pipeline { agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + execute_args = """ cmake .. --preset dev-gfx90a && \ make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ ./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" } @@ -1497,8 +1490,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ cd ../build && \ - ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_dataset_xdl \ + cmake .. --preset dev-gfx90a && \ + make -j64 test_grouped_convnd_fwd_dataset_xdl && \ test_grouped_convnd_bwd_data_dataset_xdl \ test_grouped_convnd_bwd_weight_dataset_xdl && \ cd ../test_data && \ @@ -1759,7 +1752,10 @@ pipeline { } agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DGPU_TARGETS="gfx942" -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ + // SLES15 is a legacy platform with limited C++20 ecosystem support (older system libraries, + // standard library implementation). While the ROCm compiler supports C++20, the experimental + // CK Builder requires full C++20 feature support that does not be reliably available on SLES15. + setup_args = """ -DGPU_TARGETS="gfx942" -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 -DCK_EXPERIMENTAL_BUILDER=OFF """ execute_args = " " } steps{ diff --git a/dispatcher/CMakeLists.txt b/dispatcher/CMakeLists.txt new file mode 100644 index 0000000000..2acc73d1d5 --- /dev/null +++ b/dispatcher/CMakeLists.txt @@ -0,0 +1,117 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +cmake_minimum_required(VERSION 3.16) + +project(ck_tile_dispatcher VERSION 1.0.0 LANGUAGES CXX) + +# C++17 required +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +# Find HIP for headers (needed for validation kernels) +find_package(hip QUIET) +if(NOT hip_FOUND) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip) + find_package(hip REQUIRED) +endif() + +# Dispatcher library +add_library(ck_tile_dispatcher + src/registry.cpp + src/dispatcher.cpp +) + +# Enable PIC for Python bindings +set_target_properties(ck_tile_dispatcher PROPERTIES + POSITION_INDEPENDENT_CODE ON +) + +target_include_directories(ck_tile_dispatcher + PUBLIC + $ + $ +) + +# Link against CK Tile headers (header-only) +target_include_directories(ck_tile_dispatcher + PUBLIC + $ + $ +) + +# Link against HIP headers if available +if(hip_FOUND) + target_link_libraries(ck_tile_dispatcher PUBLIC hip::host) +endif() + +# Compiler warnings +if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + target_compile_options(ck_tile_dispatcher PRIVATE + -Wall -Wextra -Wpedantic + ) +elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + target_compile_options(ck_tile_dispatcher PRIVATE + /W4 + ) +endif() + +# Optional: Build tests +option(BUILD_DISPATCHER_TESTS "Build dispatcher unit tests" OFF) +if(BUILD_DISPATCHER_TESTS) + enable_testing() + add_subdirectory(tests) +endif() + +# Optional: Build Python bindings +option(BUILD_DISPATCHER_PYTHON "Build Python bindings for dispatcher" OFF) +if(BUILD_DISPATCHER_PYTHON) + add_subdirectory(python) +endif() + +# Optional: Codegen for tile_engine integration +option(DISPATCHER_AUTO_GENERATE_WRAPPERS "Auto-generate wrappers from tile_engine" OFF) +if(DISPATCHER_AUTO_GENERATE_WRAPPERS) + add_subdirectory(codegen) +endif() + +# Optional: Build examples +option(BUILD_DISPATCHER_EXAMPLES "Build dispatcher examples" OFF) +if(BUILD_DISPATCHER_EXAMPLES) + add_subdirectory(examples) +endif() + +# Optional: Build ctypes bindings +option(BUILD_DISPATCHER_BINDINGS "Build language bindings for dispatcher" OFF) +if(BUILD_DISPATCHER_BINDINGS) + add_subdirectory(bindings/ctypes) +endif() + +# If codegen is enabled, add generated include directory +if(DISPATCHER_AUTO_GENERATE_WRAPPERS AND DISPATCHER_GENERATED_INCLUDE_DIR) + target_include_directories(ck_tile_dispatcher + PUBLIC + $ + ) +endif() + +# Installation +install(TARGETS ck_tile_dispatcher + EXPORT ck_tile_dispatcher_targets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin +) + +install(DIRECTORY include/ + DESTINATION include + FILES_MATCHING PATTERN "*.hpp" +) + +install(EXPORT ck_tile_dispatcher_targets + FILE ck_tile_dispatcher_targets.cmake + NAMESPACE ck_tile:: + DESTINATION lib/cmake/ck_tile_dispatcher +) + diff --git a/dispatcher/README.md b/dispatcher/README.md new file mode 100644 index 0000000000..fa3fbd3a59 --- /dev/null +++ b/dispatcher/README.md @@ -0,0 +1,736 @@ +# CK Tile Dispatcher + +A unified kernel dispatch system for AMD GPUs with C++ and Python frontends. + +**Validated Platform:** AMD Instinct MI300 series (gfx942) + + +--- + +## Table of Contents + +1. [Quick Start](#quick-start) +2. [Docker Setup](#docker-setup-recommended) +3. [Prerequisites](#prerequisites) +4. [Step-by-Step Build Guide](#step-by-step-build-guide) +5. [Running Examples](#running-examples) +6. [External Integration](#external-integration) +7. [Core Concepts](#core-concepts) +8. [Troubleshooting](#troubleshooting) +9. [File Structure](#file-structure) + +--- + +## Quick Start + +**Complete setup from scratch (5 minutes):** + +```bash +# From the composable_kernel root directory +cd dispatcher + +# Step 1: Create build directory +mkdir -p build && cd build + +# Step 2: Configure CMake +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Step 3: Generate kernels and build (CMake handles this automatically) +make -j$(nproc) + +# Step 4: Run C++ examples +./examples/gemm_01_basic + +# Step 5: Build Python libraries (required for Python examples) +make python_libs + +# Step 6: Run Python examples (from dispatcher directory) +cd .. +python3 examples/gemm/python/01_basic_gemm.py +``` + +--- + +## Docker Setup (Recommended) + +For a reproducible build environment, use the official ROCm Docker image: + +### Step 1: Pull and Run Container + +```bash +# Pull the CK Docker image +docker pull rocm/composable_kernel:ck_ub24.04_rocm7.0.1 + +# Run container with GPU access +docker run \ + -it \ + --privileged \ + --device=/dev/kfd \ + --device=/dev/dri \ + --group-add video \ + --group-add render \ + -w /root/workspace \ + -v $(pwd):/root/workspace \ + rocm/composable_kernel:ck_ub24.04_rocm7.0.1 \ + /bin/bash +``` + +> **Note:** Omit `--device` flags if building without GPU access. + +### Step 2: Clone and Build + +```bash +# Inside the container +git clone https://github.com/ROCm/composable_kernel.git +cd composable_kernel +git checkout builder-dispatch-tile-gemm + +# Set up Python environment +python3 -m venv .venv +source .venv/bin/activate +pip install numpy + +# Build dispatcher +cd dispatcher +mkdir -p build && cd build +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +make -j$(nproc) +``` + +### One-Liner Build (inside container) + +```bash +git clone https://github.com/ROCm/composable_kernel.git && \ +cd composable_kernel && git checkout builder-dispatch-tile-gemm && \ +python3 -m venv .venv && source .venv/bin/activate && pip install numpy && \ +cd dispatcher && mkdir -p build && cd build && \ +cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS="gfx942" -DBUILD_DISPATCHER_EXAMPLES=ON && \ +make -j$(nproc) +``` + +--- + +## Prerequisites + +### Required Software + +| Software | Minimum Version | Check Command | +|----------|-----------------|---------------| +| ROCm | 6.4+ | `rocminfo` | +| CMake | 3.16+ | `cmake --version` | +| Python | 3.8+ | `python3 --version` | +| NumPy | 1.20+ | `pip show numpy` | +| hipcc | (from ROCm) | `/opt/rocm/bin/hipcc --version` | + +> **Note:** Newer GPU targets (gfx950, gfx1201) require ROCm 6.3+. For ROCm 6.4+, you can also use `amdclang++` instead of `hipcc`. + +### Check Your GPU Architecture + +```bash +# Find your GPU architecture +rocminfo | grep -i "gfx" +# Example output: "gfx942" +``` + +**Supported architectures:** +- **gfx942** - MI300X, MI300A, MI308, MI325 (Instinct MI300 series) +- **gfx90a** - MI200 series (MI250, MI250X) +- **gfx950** - MI350 series +- **gfx1101** - RDNA3 series +- **gfx1201** - RDNA4 series + +### Install Python Dependencies + +NumPy is required for Python examples and kernel generation. We recommend using a virtual environment: + +**Option 1: Using standard venv** +```bash +# Create virtual environment +python3 -m venv .venv + +# Activate virtual environment +source .venv/bin/activate # Linux/macOS +# .venv\Scripts\activate # Windows + +# Install NumPy +pip install numpy +``` + +**Option 2: Using uv (faster alternative)** +```bash +# Install uv if not already installed +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Create and activate virtual environment +uv venv .venv +source .venv/bin/activate # Linux/macOS +# .venv\Scripts\activate # Windows + +# Install NumPy +uv pip install numpy +``` + +**Option 3: System-wide install (not recommended)** +```bash +pip install numpy +``` + +> **Note:** Always activate your virtual environment before running CMake or Python examples. + +### Supported Data Types + +CK Tile supports a wide range of data types for GEMM operations: + +| A dtype | B dtype | Acc dtype | Warp Tile Sizes | Notes | +|---------|---------|-----------|-----------------|-------| +| `fp32` | `fp32` | `fp32` | 16x16x4, 16x16x16 | Full precision | +| `fp16` | `fp16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Standard half | +| `bf16` | `bf16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Brain float 16 | +| `fp8` | `fp8` | `fp32` | 32x32x16, 32x32x32, 16x16x32, 16x16x64 | FP8 E4M3 | +| `fp8` | `bf8` | `fp32` | 32x32x16, 16x16x32 | Mixed FP8/BF8 | +| `bf8` | `fp8` | `fp32` | 32x32x16, 16x16x128 | Mixed BF8/FP8 | +| `bf8` | `bf8` | `fp32` | 32x32x16, 32x32x32, 16x16x32 | BF8 E5M2 | +| `int8` | `int8` | `int32` | 32x32x16, 16x16x32, 16x16x16 | Integer GEMM | +| `pk_fp4` | `pk_fp4` | `fp32` | 16x16x128 | Packed 4-bit float | + +**Notes:** +- Accumulator is always `fp32` except for `int8` which uses `int32` +- FP8 types: `fp8` = E4M3, `bf8` = E5M2 +- `pk_fp4` = Packed 4-bit float (2 values per byte) +- Some dtypes require specific GPU architectures (e.g., FP8 requires MI300+) + +--- + +## Step-by-Step Build Guide + +### Step 1: Navigate to Dispatcher Directory + +```bash +# From composable_kernel root +cd dispatcher + +# Verify you're in the right place +ls CMakeLists.txt # Should exist +``` + +### Step 2: Create Build Directory + +```bash +mkdir -p build +cd build +``` + +### Step 3: Configure CMake + +**Basic configuration (library only):** +```bash +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" +``` + +**Full configuration (with examples and tests):** +```bash +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON \ + -DBUILD_DISPATCHER_TESTS=ON +``` + +**Expected output:** +``` +-- Found hip: /opt/rocm (found suitable version "6.x.x") +-- Generating GEMM kernels... +-- Built: gemm_01 through gemm_06, dispatcher_gemm_lib.so +-- Configuring done +``` + +### Step 4: Build + +```bash +# Build all targets (generates kernels automatically, then compiles) +make -j$(nproc) + +# Or build specific targets +make gemm_01_basic # Single GEMM example +make dispatcher_gemm_lib # GEMM shared library for Python + +# Build ONLY Python libraries (faster if you don't need C++ examples) +make python_libs -j$(nproc) +``` + +### Kernel Generation Targets + +Kernels are generated automatically during `make`, but you can also control generation explicitly: + +```bash +# Generate all kernels only (no compilation) +make generate_all_kernels + +# Generate GEMM kernels only +make generate_gemm_kernels + +# Force regenerate (even if kernels exist) +make regenerate_all_kernels +make regenerate_gemm_kernels + +# Generate for specific GPU architecture +make generate_kernels_gfx942 # MI300X +make generate_kernels_gfx90a # MI200 +make generate_kernels_gfx1100 # RDNA3 +``` + +### Step 5: Verify Build + +```bash +# Check executables were built +ls examples/gemm_* + +# Check shared libraries were built +ls examples/libdispatcher_gemm_lib.so +``` + +### CMake Options Reference + +| Flag | Default | Description | +|------|---------|-------------| +| `CMAKE_BUILD_TYPE` | Debug | **Use `Release` for performance!** | +| `GPU_TARGETS` | None | Target GPU: `"gfx942"`, `"gfx90a"`, etc. | +| `BUILD_DISPATCHER_EXAMPLES` | OFF | Build C++ examples and Python libs | +| `BUILD_DISPATCHER_TESTS` | OFF | Build unit tests | +| `CMAKE_PREFIX_PATH` | - | ROCm installation path | +| `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler | + +⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower. +⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories). + +--- + +## Running Examples + +### C++ Examples + +After building, executables are in `build/examples/`: + +```bash +cd build/examples + +# GEMM Examples +./gemm_01_basic # Basic GEMM with autofill/autocorrect +./gemm_02_multi_size # Wildcard expansion +./gemm_03_benchmark_validation # Benchmarking + validation +./gemm_04_heuristics # Heuristic kernel selection +./gemm_05_json_export # Registry JSON export +./gemm_06_multi_registry # Multiple registries +``` + +### Python Examples + +Run from the `dispatcher` directory: + +```bash +cd /path/to/composable_kernel/dispatcher + +# GEMM Examples +python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM +python3 examples/gemm/python/04_validation.py # CPU reference validation +python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels) +python3 examples/gemm/python/08_heuristics.py # Heuristic selection +``` + +### Example Output + +**Expected C++ output (`gemm_01_basic`):** +``` +====================================================================== +Example 01: Basic GEMM with Declarative Kernel Definition +====================================================================== + +Step 1: Declared Kernels +------------------------ +Kernel Set: fp16_gemm_kernels + Architecture: gfx942 + Configurations: 1 + - gemm_fp16_rcr_compv4_cshuffle_intrawave_128x128x32 + +Step 2: Create Registry and Dispatcher +-------------------------------------- + Registered 1 kernels + +Step 3: Define Problem +---------------------- + M=1024, N=1024, K=1024 + +Step 4: GPU Execution +--------------------- + *** GPU EXECUTION *** + Time: ms + TFLOPS: +``` + +> **Note:** Timing values vary by GPU model and system configuration. + +--- + +## Benchmark Parameters + +The dispatcher supports fine-grained control over benchmarking, matching CK Tile's `stream_config`: + +### Available Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `warmup` | int | 5 | Warmup iterations (discarded from timing) | +| `repeat` | int | 20 | Benchmark iterations (averaged) | +| `flush_cache` | bool | false | Flush GPU L2 cache between iterations | +| `rotating_count` | int | 1 | Rotating buffer count (for cache simulation) | +| `timer` | string | "gpu" | Timer type: "gpu" (HIP events) or "cpu" | +| `init` | string | "random" | Matrix initialization: "random", "linear", "constant" | +| `split_k` | int | 1 | Split-K parallelism factor | + +### Python Usage + +```python +from ctypes_utils import DispatcherLib + +# Basic usage (default benchmark settings) +lib = DispatcherLib.load() + +# Advanced benchmark settings via command line +python3 examples/gemm/python/10_advanced_benchmark.py \ + --warmup 10 \ + --repeat 100 \ + --flush-cache +``` + +### C++ Usage + +```cpp +// Basic timing +ck_tile::stream_config cfg{nullptr, true}; + +// Advanced benchmark settings +ck_tile::stream_config cfg{ + nullptr, // stream_id (nullptr = default stream) + true, // time_kernel + 1, // log_level + 10, // cold_niters (warmup) + 100, // nrepeat + true, // is_gpu_timer + true, // flush_cache + 4 // rotating_count +}; + +float avg_time = kernel.run(args, cfg); +``` + +### Command Line (Python Examples) + +```bash +# Basic run +python3 examples/gemm/python/10_advanced_benchmark.py + +# With benchmark parameters +python3 examples/gemm/python/10_advanced_benchmark.py \ + --warmup 10 \ + --repeat 100 \ + --flush-cache \ + --rotating-count 4 \ + --timer gpu +``` + +### When to Use Each Parameter + +| Use Case | Recommended Settings | +|----------|---------------------| +| Quick test | `warmup=1, repeat=3` | +| Stable benchmark | `warmup=10, repeat=100` | +| Memory-bound analysis | `flush_cache=True, rotating_count=4` | +| Compute-bound analysis | `flush_cache=False` (default) | +| Debug timing | `timer="cpu"` | +| Production | `timer="gpu"` (default) | + +--- + +## External Integration + +### Using Dispatcher in Your Own Project + +#### Option 1: CMake Integration (Recommended) + +Add to your `CMakeLists.txt`: + +```cmake +# Set path to composable_kernel +set(CK_ROOT "/path/to/composable_kernel") + +# Add dispatcher subdirectory +add_subdirectory(${CK_ROOT}/dispatcher dispatcher_build) + +# Link to your target +target_link_libraries(your_target PRIVATE ck_tile_dispatcher) +target_include_directories(your_target PRIVATE + ${CK_ROOT}/dispatcher/include + ${CK_ROOT}/include +) +``` + +#### Option 2: Include as Pre-built Library + +```cmake +# Find the pre-built library +find_library(CK_DISPATCHER ck_tile_dispatcher + PATHS /path/to/composable_kernel/dispatcher/build) + +# Include directories +set(CK_INCLUDE_DIRS + /path/to/composable_kernel/include + /path/to/composable_kernel/dispatcher/include +) + +target_link_libraries(your_target PRIVATE ${CK_DISPATCHER}) +target_include_directories(your_target PRIVATE ${CK_INCLUDE_DIRS}) +``` + +#### Option 3: Python Integration + +```python +import sys +sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python") + +# For GEMM +from ctypes_utils import DispatcherLib, Dispatcher, KernelConfig +``` + +### Required Include Paths + +When integrating, you need these include paths: + +``` +/path/to/composable_kernel/include # CK Tile core headers +/path/to/composable_kernel/dispatcher/include # Dispatcher headers +/path/to/composable_kernel/dispatcher/build/generated_kernels # Generated kernels +``` + +### Required Compile Flags + +```bash +# Minimum flags for hipcc +-std=c++17 +-D__HIP_PLATFORM_AMD__=1 +--offload-arch=gfx942 # Your target GPU + +# Recommended flags +-O3 +-mllvm -enable-noalias-to-md-conversion=0 +-Wno-undefined-func-template +-Wno-float-equal +-Wall +-Werror +``` + +### Python Path Setup + +For Python scripts outside the dispatcher directory: + +```bash +# Option 1: Environment variable +export PYTHONPATH="/path/to/composable_kernel/dispatcher/examples/gemm/python:$PYTHONPATH" + +# Option 2: In your Python script +import sys +sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python") +``` + +### Library Search Paths + +The Python utilities search for the shared library in these locations: + +```python +# For GEMM (ctypes_utils.py) +SEARCH_PATHS = [ + "build/examples/libdispatcher_gemm_lib.so", + "../build/examples/libdispatcher_gemm_lib.so", + "../../build/examples/libdispatcher_gemm_lib.so", +] +``` + +If using from a different location, set the library path explicitly: + +```python +# GEMM +from ctypes_utils import DispatcherLib +lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so") +``` + +--- + +## Core Concepts + +### Data Flow + +``` +KernelConfig → Registry → Dispatcher → GPU Execution +``` + +1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts) +2. **Registry**: Stores multiple kernel configurations +3. **Dispatcher**: Selects best kernel for a given problem and executes it + +### GEMM Layouts + +| Layout | A | B | C | Use Case | +|--------|---|---|---|----------| +| RCR | Row | Col | Row | Most common (PyTorch default) | +| RRR | Row | Row | Row | Both inputs row-major | +| CRR | Col | Row | Row | A transposed | +| CCR | Col | Col | Row | Both inputs column-major | + +### Split-K Support + +Split-K divides the K dimension across multiple thread blocks, useful for large K dimensions. + +**Usage (C++):** +```cpp +// GEMM with 4-way K split +auto problem = ProblemBuilder() + .m(1024).n(1024).k(8192) + .split_k(4) + .build(); +``` + +--- + +## Troubleshooting + +### Build Issues + +| Problem | Solution | +|---------|----------| +| `hipcc not found` | Set `-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc` | +| `hip not found` | Set `-DCMAKE_PREFIX_PATH=/opt/rocm` | +| Very slow performance | Use `-DCMAKE_BUILD_TYPE=Release` | +| `gfx942 not supported` | Check ROCm version (need 6.0+) | +| Kernel generation fails | Ensure Python 3.8+ with NumPy installed in active venv | +| Build errors | First verify CK builds without dispatcher (see main CK README) | + +### Runtime Issues + +| Problem | Solution | +|---------|----------| +| `Library not found` | Build with `-DBUILD_DISPATCHER_EXAMPLES=ON` | +| `No kernel found` | Check GPU arch matches build target | +| Python `ModuleNotFoundError` | Add paths to `PYTHONPATH` (see above) | +| Wrong results | Verify layout matches your data | + +### Debug Commands + +```bash +# Check ROCm installation +rocminfo | head -20 + +# Check GPU architecture +rocminfo | grep "Name:" + +# Verify library exists +ls -la build/examples/libdispatcher_*.so + +# Run with verbose output +./build/examples/gemm_01_basic 2>&1 + +# Python: Check library loading +python3 -c " +import ctypes +lib = ctypes.CDLL('/path/to/libdispatcher_gemm_lib.so') +print('Library loaded successfully') +" +``` + +### Clean Rebuild + +If you encounter issues, try a clean rebuild: + +```bash +cd dispatcher +rm -rf build +mkdir build && cd build +cmake .. [your options] +make -j$(nproc) +``` + +--- + +## File Structure + +``` +dispatcher/ +├── README.md # This file +├── CMakeLists.txt # Build configuration +│ +├── include/ck_tile/dispatcher/ # C++ headers +│ ├── dispatcher.hpp # GEMM dispatcher +│ ├── registry.hpp # Kernel registry +│ └── kernel_key.hpp # Kernel configuration +│ +├── src/ # C++ implementation +│ +├── codegen/ # Kernel generation +│ ├── unified_gemm_codegen.py # GEMM kernel generator +│ └── arch_specs.json # GPU specifications +│ +├── bindings/ctypes/ # Python ctypes interface +│ └── gemm_ctypes_lib.cpp # GEMM Python library +│ +├── examples/ # Examples +│ └── gemm/ +│ ├── cpp/ # C++ GEMM examples (01-06) +│ └── python/ # Python GEMM examples (01-11) +│ +├── scripts/ # Build scripts +│ +└── tests/ # Unit tests +``` + +--- + +## Example Documentation + +| Directory | README | +|-----------|--------| +| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) | +| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) | +| Codegen | [codegen/README.md](codegen/README.md) | + +--- + +## Archived Content + +Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`: +- `examples/conv/cpp/` - 11 C++ convolution examples +- `examples/conv/python/` - 14 Python convolution examples +- `codegen/unified_conv_codegen.py` - Conv kernel generator +- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers +- `python/conv_utils.py` - Conv Python utilities + +--- + +## License + +MIT License - Copyright (c) 2025, Advanced Micro Devices, Inc. diff --git a/dispatcher/bindings/README.md b/dispatcher/bindings/README.md new file mode 100644 index 0000000000..7cda21f6ec --- /dev/null +++ b/dispatcher/bindings/README.md @@ -0,0 +1,109 @@ +# CK Tile Dispatcher - Language Bindings + +This directory contains language bindings for the CK Tile Dispatcher. + +## Structure + +``` +bindings/ +├── ctypes/ # Python ctypes bindings (C API) +│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API +│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data) +│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API +│ ├── gpu_helper.cpp # CLI helper for Python +│ └── CMakeLists.txt +└── README.md +``` + +## ctypes Bindings + +The ctypes bindings provide a C API that Python can load via `ctypes.CDLL()`. + +### Building + +```bash +cd build +cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm +make dispatcher_gemm_lib dispatcher_conv_lib gpu_helper +``` + +### Usage from Python + +```python +import ctypes + +# Load the library +lib = ctypes.CDLL("path/to/libdispatcher_gemm_lib.so") + +# Initialize +lib.dispatcher_init() + +# Check if problem is supported +is_supported = lib.dispatcher_is_supported(M, N, K) + +# Run GEMM +time_ms = ctypes.c_float() +result = lib.dispatcher_run_gemm( + A_ptr, B_ptr, C_ptr, + M, N, K, + ctypes.byref(time_ms) +) + +# Cleanup +lib.dispatcher_cleanup() +``` + +### GEMM API + +| Function | Description | +|----------|-------------| +| `dispatcher_init()` | Initialize the dispatcher | +| `dispatcher_is_supported(M, N, K)` | Check if problem size is supported | +| `dispatcher_select_kernel(M, N, K, name_buf, buf_size)` | Get kernel name for problem | +| `dispatcher_run_gemm(A, B, C, M, N, K, time_ms)` | Execute GEMM | +| `dispatcher_get_kernel_count()` | Get number of registered kernels | +| `dispatcher_export_registry_json()` | Export registry as JSON | +| `dispatcher_cleanup()` | Release resources | + +### Convolution API + +| Function | Description | +|----------|-------------| +| `conv_dispatcher_init()` | Initialize the dispatcher | +| `conv_dispatcher_is_supported(prob)` | Check if problem is supported | +| `conv_dispatcher_select_kernel(prob, name_buf, buf_size)` | Get kernel name | +| `conv_dispatcher_run(input, weight, output, prob, stream)` | Execute convolution | +| `conv_dispatcher_get_kernel_count()` | Get number of registered kernels | +| `conv_dispatcher_cleanup()` | Release resources | + +## GPU Helper + +The `gpu_helper` executable provides a CLI interface for Python: + +```bash +./gpu_helper 1024 1024 1024 --validate +``` + +Output is JSON for easy parsing: +```json +{ + "problem": {"M": 1024, "N": 1024, "K": 1024}, + "kernel": "gemm_fp16_rcr_...", + "execution": { + "time_ms": 0.5, + "tflops": 4.2 + }, + "validation": { + "accuracy": 100.0 + }, + "status": "success" +} +``` + +## Examples + +See the examples that use these bindings: + +- **GEMM**: `dispatcher/examples/gemm/python/` +- **Conv**: `dispatcher/examples/conv/python/` + diff --git a/dispatcher/bindings/ctypes/CMakeLists.txt b/dispatcher/bindings/ctypes/CMakeLists.txt new file mode 100644 index 0000000000..804e5e9bd7 --- /dev/null +++ b/dispatcher/bindings/ctypes/CMakeLists.txt @@ -0,0 +1,181 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ============================================================================= +# CK Tile Dispatcher - ctypes Bindings +# ============================================================================= +# +# Provides shared libraries with C API for Python ctypes integration. +# +# Targets: +# - dispatcher_gemm_lib : GEMM dispatcher library +# - dispatcher_conv_lib : Convolution dispatcher library (forward + bwd_data) +# - dispatcher_conv_bwdw_lib : Convolution backward weight library +# - gpu_helper : GPU helper executable for Python +# + +cmake_minimum_required(VERSION 3.16) + +# Helper function to add a ctypes library +function(add_ctypes_library TARGET_NAME SOURCE_FILE) + cmake_parse_arguments(ARG "CONV" "KERNEL_HEADER" "" ${ARGN}) + + add_library(${TARGET_NAME} SHARED ${SOURCE_FILE}) + + target_include_directories(${TARGET_NAME} PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + + target_link_libraries(${TARGET_NAME} PRIVATE + hip::device + ) + + # Force-include kernel header if provided + if(ARG_KERNEL_HEADER AND EXISTS ${ARG_KERNEL_HEADER}) + target_compile_options(${TARGET_NAME} PRIVATE + -include ${ARG_KERNEL_HEADER} + ) + if(ARG_CONV) + target_compile_definitions(${TARGET_NAME} PRIVATE CONV_KERNEL_AVAILABLE) + endif() + endif() + + set_target_properties(${TARGET_NAME} PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +endfunction() + +# ============================================================================= +# GEMM ctypes Library +# ============================================================================= + +# Find a generated GEMM kernel header for the library +file(GLOB GEMM_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/gemm_*.hpp") +if(GEMM_KERNEL_HEADERS) + list(GET GEMM_KERNEL_HEADERS 0 GEMM_KERNEL_HEADER) + message(STATUS "Found GEMM kernel for ctypes lib: ${GEMM_KERNEL_HEADER}") + + add_ctypes_library(dispatcher_gemm_lib + gemm_ctypes_lib.cpp + KERNEL_HEADER ${GEMM_KERNEL_HEADER} + ) +else() + message(STATUS "No GEMM kernel found for ctypes lib - building without kernel") + add_library(dispatcher_gemm_lib SHARED gemm_ctypes_lib.cpp) + target_include_directories(dispatcher_gemm_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device) +endif() + +# ============================================================================= +# Convolution ctypes Library (supports forward + bwd_data) +# ============================================================================= + +# Look for forward kernels +file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp") +# Look for backward data kernels +file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp") +# Fallback: any conv kernel (for backwards compatibility) +file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp") + +add_library(dispatcher_conv_lib SHARED conv_ctypes_lib.cpp) +target_include_directories(dispatcher_conv_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include +) +target_link_libraries(dispatcher_conv_lib PRIVATE hip::device) +set_target_properties(dispatcher_conv_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 +) + +# Add forward kernel if available +if(CONV_FWD_KERNEL_HEADERS) + list(GET CONV_FWD_KERNEL_HEADERS 0 CONV_FWD_KERNEL_HEADER) + message(STATUS "Found Conv FWD kernel for ctypes lib: ${CONV_FWD_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_FWD_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE) +elseif(CONV_KERNEL_HEADERS) + # Fallback to any conv kernel + list(GET CONV_KERNEL_HEADERS 0 CONV_KERNEL_HEADER) + message(STATUS "Found Conv kernel for ctypes lib: ${CONV_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE) +else() + message(STATUS "No Conv FWD kernel found for ctypes lib - building without kernel") +endif() + +# Add backward data kernel if available +if(CONV_BWDD_KERNEL_HEADERS) + list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER) + message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE) +endif() + +# ============================================================================= +# Convolution Backward Weight ctypes Library (separate lib for bwd_weight) +# ============================================================================= + +file(GLOB CONV_BWDW_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*bwd_weight*.hpp") +if(CONV_BWDW_KERNEL_HEADERS) + list(GET CONV_BWDW_KERNEL_HEADERS 0 CONV_BWDW_KERNEL_HEADER) + message(STATUS "Found Conv BwdWeight kernel for ctypes lib: ${CONV_BWDW_KERNEL_HEADER}") + + add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp) + target_include_directories(dispatcher_conv_bwdw_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device) + target_compile_options(dispatcher_conv_bwdw_lib PRIVATE + -include ${CONV_BWDW_KERNEL_HEADER} + ) + target_compile_definitions(dispatcher_conv_bwdw_lib PRIVATE CONV_BWD_WEIGHT_AVAILABLE) + set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +else() + message(STATUS "No Conv BwdWeight kernel found for ctypes lib - building without kernel") + add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp) + target_include_directories(dispatcher_conv_bwdw_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device) + set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +endif() + +# ============================================================================= +# GPU Helper Executable +# ============================================================================= + +if(GEMM_KERNEL_HEADERS) + add_executable(gpu_helper gpu_helper.cpp) + + target_include_directories(gpu_helper PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + + target_link_libraries(gpu_helper PRIVATE + hip::device + ) + + target_compile_options(gpu_helper PRIVATE + -include ${GEMM_KERNEL_HEADER} + ) + + set_target_properties(gpu_helper PROPERTIES + CXX_STANDARD 17 + ) +endif() + diff --git a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp new file mode 100644 index 0000000000..09e058f80f --- /dev/null +++ b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp @@ -0,0 +1,175 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Convolution Backward Weight Dispatcher ctypes Library + * + * SEPARATE library for backward weight to avoid template conflicts with + * forward/backward_data kernels in the main conv_ctypes_lib. + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_conv_bwdw_lib.so") + * lib.conv_bwdw_init() + * lib.conv_bwdw_run(...) + */ + +#include +#include +#include + +// Minimal includes - matching the C++ example +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/gemm.hpp" // Must be before grouped_convolution for TileGemmTraits +#include "ck_tile/ops/grouped_convolution.hpp" + +// Global state - minimal, no registry needed for direct launch +static bool g_bwdw_initialized = false; + +extern "C" { + +// ============================================================================= +// Initialization (minimal - just sets flag) +// ============================================================================= + +int conv_bwdw_init() +{ + g_bwdw_initialized = true; + return 0; // Return 0 on success (consistent with other init functions) +} + +void conv_bwdw_cleanup() { g_bwdw_initialized = false; } + +// ============================================================================= +// Problem Structure (same as main library) +// ============================================================================= + +struct ConvBwdwProblemC +{ + int N, G, C, K; + int input_d, input_h, input_w; + int filter_z, filter_y, filter_x; + int stride_d, stride_h, stride_w; + int pad_d, pad_h, pad_w; + int dilation_d, dilation_h, dilation_w; +}; + +// ============================================================================= +// Backward Weight Execution +// ============================================================================= + +#ifdef CONV_BWD_WEIGHT_AVAILABLE +static ck_tile::conv::ConvParam build_conv_param(const ConvBwdwProblemC* prob) +{ + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + + if(is_3d) + { + return ck_tile::conv::ConvParam{3, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_z, prob->filter_y, prob->filter_x}, + {prob->input_d, prob->input_h, prob->input_w}, + {prob->stride_d, prob->stride_h, prob->stride_w}, + {prob->dilation_d, prob->dilation_h, prob->dilation_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}}; + } + else + { + return ck_tile::conv::ConvParam{2, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_y, prob->filter_x}, + {prob->input_h, prob->input_w}, + {prob->stride_h, prob->stride_w}, + {prob->dilation_h, prob->dilation_w}, + {prob->pad_h, prob->pad_w}, + {prob->pad_h, prob->pad_w}}; + } +} + +static float run_bwd_weight_impl(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvBwdwProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // Backward weight: A=input, B=grad_output, C=grad_weight + ck_tile::GroupedConvBwdWeightHostArgs args(conv_param, + input_ptr, // in_ptr = input + grad_weight_ptr, // wei_ptr = grad_weight (output) + {}, // ds_ptr + grad_output_ptr, // out_ptr = grad_output + 1 // k_batch + ); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); +} +#endif + +float conv_bwdw_run(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvBwdwProblemC* prob, + void* stream) +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + // Validate all required pointers before kernel launch + if(!g_bwdw_initialized || !prob) + return -1.0f; + if(!input_ptr || !grad_output_ptr || !grad_weight_ptr) + return -1.0f; // Null data pointer would cause kernel crash + return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream); +#else + return -1.0f; +#endif +} + +// ============================================================================= +// Info +// ============================================================================= + +const char* conv_bwdw_version() { return "1.0.0"; } + +int conv_bwdw_has_kernels() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + return 1; +#else + return 0; +#endif +} + +int conv_bwdw_get_kernel_count() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + return 1; +#else + return 0; +#endif +} + +int conv_bwdw_get_kernel_name(int index, char* buffer, int buffer_size) +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + if(index != 0 || !buffer || buffer_size <= 0) + return -1; + std::strncpy(buffer, CONV_BWD_WEIGHT_KERNEL_NAME, buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +#else + return -1; +#endif +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp new file mode 100644 index 0000000000..d3c64621a7 --- /dev/null +++ b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp @@ -0,0 +1,411 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Convolution Dispatcher ctypes Library + * + * Provides C API for Python ctypes integration. + * Supports forward convolution. Backward operations require additional headers. + * + * REQUIRED: Forward kernel header must be force-included via -include flag. + * OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_conv.so") + * lib.conv_dispatcher_init() + * lib.conv_dispatcher_run(...) + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher/conv_utils.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +using namespace ck_tile::dispatcher; + +// Global state (using shared_ptr for safe memory management) +static std::shared_ptr g_registry = nullptr; +static std::shared_ptr g_dispatcher = nullptr; +static std::vector g_kernels; + +extern "C" { + +// ============================================================================= +// Initialization +// ============================================================================= + +int conv_dispatcher_init() +{ + if(g_registry) + return 0; // Already initialized + + g_registry = std::make_shared(); + g_dispatcher = std::make_shared(g_registry.get()); + + // Register kernel configurations using simple ConvKernelSet + // (actual kernel launch uses the force-included SelectedConvKernelLauncher) + using namespace ck_tile::dispatcher::conv_decl; + + // Forward kernels (required - must be force-included) + // Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb + ConvKernelSet fwd_set; + fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + ConvAlgorithm() + .tile(128, 128, 64) // tile_m x tile_n x tile_k + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"), + "gfx942"); + g_registry->register_set(fwd_set, ConvRegistry::Priority::High); + +#ifdef CONV_BWD_DATA_AVAILABLE + // Backward data kernels + // Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16 + ConvKernelSet bwd_data_set; + bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + ConvAlgorithm() + .tile(128, 128, 64) // tile_m x tile_n x tile_k + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave"), + "gfx942"); + g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High); +#endif + + return 0; +} + +int conv_dispatcher_cleanup() +{ + // shared_ptr automatically handles cleanup when reset + g_dispatcher.reset(); + g_registry.reset(); + g_kernels.clear(); + return 0; +} + +// ============================================================================= +// Registry Management +// ============================================================================= + +int conv_dispatcher_get_kernel_count() +{ + if(!g_registry) + return 0; + return static_cast(g_registry->size()); +} + +int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size) +{ + if(index < 0 || !buffer || buffer_size <= 0) + return -1; + + if(!g_registry) + return -1; + + // Use registry to get kernel names (they are registered with full names) + const auto& kernels = g_registry->all_kernels(); + if(static_cast(index) >= kernels.size()) + return -1; + + const auto* kernel = kernels[index]; + std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +} + +// ============================================================================= +// Problem Definition +// ============================================================================= + +struct ConvProblemC +{ + int N, G, C, K; + int input_d, input_h, input_w; + int filter_z, filter_y, filter_x; + int stride_d, stride_h, stride_w; + int pad_d, pad_h, pad_w; + int dilation_d, dilation_h, dilation_w; + int direction; // 0=forward, 1=bwd_data, 2=bwd_weight +}; + +// ============================================================================= +// Kernel Selection +// ============================================================================= + +int conv_dispatcher_is_supported(const ConvProblemC* prob) +{ + if(!g_registry || !prob) + return 0; + + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + const auto* kernel = g_dispatcher->select(problem); + return kernel ? 1 : 0; +} + +int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size) +{ + if(!g_registry || !prob || !kernel_name || buffer_size <= 0) + return -1; + + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + const auto* kernel = g_dispatcher->select(problem); + if(!kernel) + return -1; + + std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1); + kernel_name[buffer_size - 1] = '\0'; + + return 0; +} + +// ============================================================================= +// Convolution Execution +// ============================================================================= + +// Helper to build ConvParam +static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob) +{ + // Determine if this is 2D or 3D convolution + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + + if(is_3d) + { + // 3D convolution: use all spatial dimensions + return ck_tile::conv::ConvParam{3, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_z, prob->filter_y, prob->filter_x}, + {prob->input_d, prob->input_h, prob->input_w}, + {prob->stride_d, prob->stride_h, prob->stride_w}, + {prob->dilation_d, prob->dilation_h, prob->dilation_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}}; + } + else + { + // 2D convolution: only use H, W dimensions + return ck_tile::conv::ConvParam{2, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_y, prob->filter_x}, + {prob->input_h, prob->input_w}, + {prob->stride_h, prob->stride_w}, + {prob->dilation_h, prob->dilation_w}, + {prob->pad_h, prob->pad_w}, + {prob->pad_h, prob->pad_w}}; + } +} + +// Forward convolution (required - kernel header must be force-included) +static float run_forward(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + // SelectedConvKernelLauncher is defined in the force-included forward kernel header + return SelectedConvKernelLauncher::launch(args, stream_cfg); +} + +#ifdef CONV_BWD_DATA_AVAILABLE +// Backward data convolution (optional) +// Computes: grad_input = conv_bwd_data(weight, grad_output) +// +// Parameters: +// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT) +// weight_ptr: W - frozen weights (const, read-only INPUT) +// grad_input_ptr: dX - gradient for input (writable, OUTPUT) +static float run_bwd_data(const void* grad_output_ptr, + const void* weight_ptr, + void* grad_input_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // CK Tile API uses tensor POSITION names (from forward pass), not data flow: + // in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data) + // wei_ptr = weight tensor = weight_ptr (W, const) + // out_ptr = output tensor position = grad_output_ptr (dY, INPUT to bwd_data) + ck_tile::GroupedConvBwdDataHostArgs args( + conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdDataLauncher::launch(args, stream_cfg); +} +#endif + +#ifdef CONV_BWD_WEIGHT_AVAILABLE +// Backward weight convolution (optional) +// Parameters: +// input_ptr: original forward input X (const, read-only) +// grad_output_ptr: gradient from next layer dY (const, read-only) +// grad_weight_ptr: gradient of weights dW (writable, OUTPUT) +static float run_bwd_weight(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // GroupedConvBwdWeightHostArgs constructor order: + // (param, in=X, wei=dW (output), ds, out=dY (input), k_batch) + // Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output) + ck_tile::GroupedConvBwdWeightHostArgs args( + conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); +} +#endif + +/** + * @brief Execute convolution based on direction specified in prob + * + * Parameter mapping varies by direction: + * Forward (direction=0): + * input_ptr = X (input tensor) + * weight_ptr = W (weight tensor) + * output_ptr = Y (output buffer) + * + * Backward Data (direction=1): + * input_ptr = dY (grad_output - gradient from next layer) + * weight_ptr = W (weight tensor, frozen) + * output_ptr = dX (grad_input buffer) + * + * Backward Weight (direction=2): + * input_ptr = X (forward input tensor) + * weight_ptr = dY (grad_output - gradient from next layer) + * output_ptr = dW (grad_weight buffer) + */ +float conv_dispatcher_run(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const ConvProblemC* prob, + void* stream) +{ + // Validate all required pointers before kernel launch + if(!g_dispatcher || !prob) + return -1.0f; + if(!input_ptr || !weight_ptr || !output_ptr) + return -1.0f; // Null data pointer would cause kernel crash + + // Build problem for kernel selection + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + // Select kernel + const auto* kernel = g_dispatcher->select(problem); + if(!kernel) + return -1.0f; + + // Dispatch based on direction + switch(prob->direction) + { + case 0: // Forward (always available) + return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream); + +#ifdef CONV_BWD_DATA_AVAILABLE + case 1: // Backward data + // Convention: caller passes (grad_output, weight, grad_input_buffer) + // in the (input_ptr, weight_ptr, output_ptr) slots respectively. + // run_bwd_data expects: (grad_output, weight, grad_input) + return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream); +#endif + +#ifdef CONV_BWD_WEIGHT_AVAILABLE + case 2: // Backward weight + // Convention: caller passes (input, grad_output, grad_weight_buffer) + // in the (input_ptr, weight_ptr, output_ptr) slots respectively. + // run_bwd_weight expects: (input, grad_output, grad_weight) + return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream); +#endif + + default: return -1.0f; + } +} + +// ============================================================================= +// Info +// ============================================================================= + +const char* conv_dispatcher_version() { return "1.0.0"; } + +int conv_dispatcher_has_kernels() +{ + return 1; // Forward kernel is required +} + +int conv_dispatcher_has_bwd_data() +{ +#ifdef CONV_BWD_DATA_AVAILABLE + return 1; +#else + return 0; +#endif +} + +int conv_dispatcher_has_bwd_weight() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + return 1; +#else + return 0; +#endif +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp b/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp new file mode 100644 index 0000000000..85c0c2f2c1 --- /dev/null +++ b/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp @@ -0,0 +1,401 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * GEMM Dispatcher ctypes Library + * + * Provides C API for Python ctypes integration. + * Kernel header included via -include at compile time. + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_gemm.so") + * lib.dispatcher_init() + * lib.dispatcher_run_gemm(...) + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag +// Defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME + +// GPU architecture - can be overridden via -DGFX_ARCH="gfx90a" at compile time +#ifndef GFX_ARCH +#define GFX_ARCH "gfx942" +#endif + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +// Global dispatcher (initialized once, managed via shared_ptr for safe cleanup) +static std::shared_ptr g_dispatcher = nullptr; +static bool g_initialized = false; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + return -1; \ + } \ + } + +extern "C" { + +/** + * Initialize dispatcher with a kernel + * Must be called before run_gemm + * + * Returns: 0 on success, -1 on error + */ +int dispatcher_initialize() +{ + if(g_initialized) + { + return 0; // Already initialized + } + + // Create kernel key from the force-included kernel header + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = GFX_ARCH; + + // Register kernel using types from force-included header + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + // Create dispatcher (using shared_ptr for safe memory management) + g_dispatcher = std::make_shared(); + g_initialized = true; + + return 0; +} + +/** + * Get kernel tile configuration + */ +int dispatcher_get_kernel_config(int* tile_m, + int* tile_n, + int* tile_k, + int* warp_tile_m, + int* warp_tile_n, + int* warp_tile_k, + int* warp_m, + int* warp_n, + int* warp_k) +{ + if(!g_initialized) + { + return -1; + } + + auto kernels = Registry::instance().get_all(); + if(kernels.empty()) + { + return -1; + } + + // Get configuration from first kernel + auto& key = kernels[0]->get_key(); + auto& algo = key.algorithm; + + if(tile_m) + *tile_m = algo.tile_shape.m; + if(tile_n) + *tile_n = algo.tile_shape.n; + if(tile_k) + *tile_k = algo.tile_shape.k; + if(warp_tile_m) + *warp_tile_m = algo.warp_tile_shape.m; + if(warp_tile_n) + *warp_tile_n = algo.warp_tile_shape.n; + if(warp_tile_k) + *warp_tile_k = algo.warp_tile_shape.k; + if(warp_m) + *warp_m = algo.wave_shape.m; + if(warp_n) + *warp_n = algo.wave_shape.n; + if(warp_k) + *warp_k = algo.wave_shape.k; + + return 0; +} + +/** + * Get the selected kernel name for a problem + */ +int dispatcher_select_kernel(int64_t M, int64_t N, int64_t K, char* name_buffer, int buffer_size) +{ + if(!g_initialized || !name_buffer || buffer_size <= 0) + { + return -1; + } + + Problem problem(M, N, K); + auto kernel = g_dispatcher->select_kernel(problem); + + if(!kernel) + { + return -1; + } + + std::string name = kernel->get_name(); + strncpy(name_buffer, name.c_str(), buffer_size - 1); + name_buffer[buffer_size - 1] = '\0'; + + return 0; +} + +/** + * Check if a problem size is supported by available kernels + */ +int dispatcher_is_supported(int64_t M, int64_t N, int64_t K) +{ + if(!g_initialized) + { + return 0; + } + + if(M <= 0 || N <= 0 || K <= 0) + { + return 0; + } + + Problem problem(M, N, K); + auto kernel = g_dispatcher->select_kernel(problem); + return kernel != nullptr ? 1 : 0; +} + +/** + * Run GEMM on GPU via dispatcher + */ +int dispatcher_run_gemm( + const void* A, const void* B, void* C, int64_t M, int64_t N, int64_t K, float* time_ms) +{ + if(!g_initialized || !A || !B || !C) + { + return -1; + } + + // First check if any kernel supports this problem + Problem problem(M, N, K); + auto kernel = g_dispatcher->select_kernel(problem); + if(!kernel) + { + if(time_ms) + { + *time_ms = -1.0f; + } + return -2; // No suitable kernel + } + + // Cast to correct types (from force-included header) + const ADataType* A_host = static_cast(A); + const BDataType* B_host = static_cast(B); + CDataType* C_host = static_cast(C); + + // Allocate GPU memory + ADataType* A_dev = nullptr; + BDataType* B_dev = nullptr; + CDataType* C_dev = nullptr; + + auto cleanup_gpu_mem = [&]() { + if(A_dev) + (void)hipFree(A_dev); + if(B_dev) + (void)hipFree(B_dev); + if(C_dev) + (void)hipFree(C_dev); + }; + + if(hipMalloc(&A_dev, M * K * sizeof(ADataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMalloc(&B_dev, K * N * sizeof(BDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMalloc(&C_dev, M * N * sizeof(CDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + // Copy input data to GPU + if(hipMemcpy(A_dev, A_host, M * K * sizeof(ADataType), hipMemcpyHostToDevice) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMemcpy(B_dev, B_host, K * N * sizeof(BDataType), hipMemcpyHostToDevice) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMemset(C_dev, 0, M * N * sizeof(CDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + // Run GEMM via dispatcher + float exec_time; + try + { + exec_time = g_dispatcher->run(A_dev, B_dev, C_dev, problem); + } + catch(const std::exception& e) + { + cleanup_gpu_mem(); + return -1; + } + + // Copy result back to host + if(hipMemcpy(C_host, C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + if(time_ms) + { + *time_ms = exec_time; + } + + cleanup_gpu_mem(); + return 0; +} + +/** + * Get kernel information + */ +const char* dispatcher_get_kernel_name() { return KERNEL_NAME; } + +/** + * Initialize dispatcher (alias) + */ +int dispatcher_init() { return dispatcher_initialize(); } + +/** + * Get the number of registered kernels + */ +int dispatcher_get_kernel_count() { return static_cast(Registry::instance().size()); } + +/** + * Export registry to JSON string + */ +static std::string g_json_buffer; + +const char* dispatcher_export_registry_json() +{ + auto& registry = Registry::instance(); + + std::ostringstream json; + json << "{\n"; + json << " \"metadata\": {\n"; + json << " \"timestamp\": \"" << __DATE__ << " " << __TIME__ << "\",\n"; + json << " \"total_kernels\": " << registry.size() << ",\n"; + json << " \"export_version\": \"1.0\",\n"; + json << " \"dispatcher_version\": \"1.0.0\"\n"; + json << " },\n"; + json << " \"statistics\": {\n"; + json << " \"by_datatype\": {},\n"; + json << " \"by_pipeline\": {},\n"; + json << " \"by_scheduler\": {}\n"; + json << " },\n"; + json << " \"kernels\": [\n"; + + auto kernels = registry.get_all(); + for(size_t i = 0; i < kernels.size(); ++i) + { + auto& kernel = kernels[i]; + auto& key = kernel->get_key(); + auto& algo = key.algorithm; + std::string name = kernel->get_name(); + + json << " {\n"; + json << " \"identifier\": \"" << key.encode_identifier() << "\",\n"; + json << " \"name\": \"" << name << "\",\n"; + json << " \"algorithm\": {\n"; + json << " \"tile_shape\": {\"m\": " << algo.tile_shape.m + << ", \"n\": " << algo.tile_shape.n << ", \"k\": " << algo.tile_shape.k << "},\n"; + json << " \"wave_shape\": {\"m\": " << unsigned(algo.wave_shape.m) + << ", \"n\": " << unsigned(algo.wave_shape.n) + << ", \"k\": " << unsigned(algo.wave_shape.k) << "},\n"; + json << " \"warp_tile_shape\": {\"m\": " << unsigned(algo.warp_tile_shape.m) + << ", \"n\": " << unsigned(algo.warp_tile_shape.n) + << ", \"k\": " << unsigned(algo.warp_tile_shape.k) << "},\n"; + json << " \"block_size\": " << algo.block_size << ",\n"; + json << " \"persistent\": " << (algo.persistent ? "true" : "false") << ",\n"; + json << " \"double_buffer\": " << (algo.double_buffer ? "true" : "false") << ",\n"; + json << " \"preshuffle\": " << (algo.preshuffle ? "true" : "false") << ",\n"; + json << " \"transpose_c\": " << (algo.transpose_c ? "true" : "false") << "\n"; + json << " }\n"; + json << " }"; + if(i < kernels.size() - 1) + { + json << ","; + } + json << "\n"; + } + + json << " ]\n"; + json << "}\n"; + + g_json_buffer = json.str(); + return g_json_buffer.c_str(); +} + +/** + * Cleanup dispatcher resources + */ +void dispatcher_cleanup() +{ + g_dispatcher.reset(); + g_initialized = false; +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/gpu_helper.cpp b/dispatcher/bindings/ctypes/gpu_helper.cpp new file mode 100644 index 0000000000..1c72c14e39 --- /dev/null +++ b/dispatcher/bindings/ctypes/gpu_helper.cpp @@ -0,0 +1,206 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * GPU Helper - C++ executable for GPU GEMM execution + * + * A CLI tool for Python to execute GPU GEMM with generated kernels. + * Usage: gpu_helper [--validate] + * + * Kernel header included via -include flag at compile time. + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP_ERROR: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// CPU reference GEMM (for validation) +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + // A: RowMajor, B: ColumnMajor + acc += float(A[m * K + k]) * float(B[k + n * K]); + } + C[m * N + n] = T(acc); + } + } +} + +int main(int argc, char** argv) +{ + // Parse arguments + if(argc < 4) + { + std::cerr << "Usage: " << argv[0] << " [--validate]\n"; + std::cerr << "\nOptions:\n"; + std::cerr << " M, N, K : Problem dimensions\n"; + std::cerr << " --validate : Compare GPU results with CPU reference\n"; + return 1; + } + + int M = std::atoi(argv[1]); + int N = std::atoi(argv[2]); + int K = std::atoi(argv[3]); + bool validate = (argc > 4 && std::string(argv[4]) == "--validate"); + + // Output in JSON-like format for easy Python parsing + std::cout << "{" << std::endl; + std::cout << " \"problem\": {\"M\": " << M << ", \"N\": " << N << ", \"K\": " << K << "}," + << std::endl; + std::cout << " \"kernel\": \"" << KERNEL_NAME << "\"," << std::endl; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cout << " \"error\": \"No kernel selected\"" << std::endl; + std::cout << "}" << std::endl; + return 1; + } + + std::cout << " \"selected_kernel\": \"" << selected->get_name() << "\"," << std::endl; + + // Prepare data: A=1, B=1, so C should be K + std::vector A_host(M * K, ADataType(1.0f)); + std::vector B_host(K * N, BDataType(1.0f)); + std::vector C_gpu(M * N); + + // GPU execution + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Calculate performance + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + + std::cout << " \"execution\": {" << std::endl; + std::cout << " \"time_ms\": " << gpu_time << "," << std::endl; + std::cout << " \"tflops\": " << tflops << "," << std::endl; + std::cout << " \"flops\": " << (long long)flops << std::endl; + std::cout << " }," << std::endl; + + // Validation + if(validate) + { + std::vector C_cpu(M * N); + cpu_gemm(A_host, B_host, C_cpu, M, N, K); + + int correct = 0; + float max_error = 0.0f; + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float error = std::abs(gpu_val - cpu_val) / (std::abs(cpu_val) + 1e-5f); + + max_error = std::max(max_error, error); + + if(error < 0.02f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + std::cout << " \"validation\": {" << std::endl; + std::cout << " \"accuracy\": " << accuracy << "," << std::endl; + std::cout << " \"max_error\": " << max_error << "," << std::endl; + std::cout << " \"correct_elements\": " << correct << "," << std::endl; + std::cout << " \"total_elements\": " << M * N << std::endl; + std::cout << " }," << std::endl; + } + + std::cout << " \"status\": \"success\"" << std::endl; + std::cout << "}" << std::endl; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + return 0; +} diff --git a/dispatcher/codegen/ADDING_NEW_GPU.md b/dispatcher/codegen/ADDING_NEW_GPU.md new file mode 100644 index 0000000000..0bd2966a85 --- /dev/null +++ b/dispatcher/codegen/ADDING_NEW_GPU.md @@ -0,0 +1,197 @@ +# Adding New GPU Architecture Support + +Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatcher. + +> **See also:** [Main Dispatcher README](../README.md) | [Codegen README](README.md) + +## Overview + +The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications: + +``` +arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python) + → arch_specs_generated.hpp (C++) +``` + +## Quick Start + +```bash +# 1. Edit arch_specs.json +# 2. Run generator +python generate_arch_specs.py +# 3. Rebuild +cd ../build && cmake --build . -j8 +# 4. Test +ctest +``` + +## Step-by-Step Guide + +### Step 1: Edit arch_specs.json + +Add new architecture under `"architectures"`: + +```json +{ + "architectures": { + "gfx1100": { + "family": "rdna3", + "description": "AMD Radeon RX 7000 series (RDNA3)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]], + "bf16_bf16_bf16": [[16, 16, 16], [32, 32, 16]] + } + } + } +} +``` + +### Step 2: Configuration Fields + +| Field | Description | Example | +|-------|-------------|---------| +| `family` | GPU family | `"cdna3"`, `"rdna4"` | +| `description` | Human-readable name | `"AMD Instinct MI300"` | +| `warp_size` | Wave/warp size | `64` (CDNA), `32` (RDNA) | +| `lds_capacity_kb` | LDS memory in KB | `64` | +| `warp_configs` | Valid `[warp_m, warp_n, warp_k]` | `[[2,2,1], [4,4,1]]` | +| `warp_tile_combos` | Warp tiles per dtype | See below | + +### Step 3: Warp Tile Combinations + +Map data type combinations to valid warp tile sizes: + +```json +"warp_tile_combos": { + "fp16_fp16_fp16": [[32, 32, 8], [16, 16, 16], [32, 32, 16]], + "bf16_bf16_bf16": [[32, 32, 8], [16, 16, 16]], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]] +} +``` + +Key format: `{A_dtype}_{B_dtype}_{C_dtype}` + +### Step 4: Run Generator + +```bash +cd dispatcher/codegen +python generate_arch_specs.py +``` + +This generates: +- `arch_specs_generated.py` (Python module) +- `../include/ck_tile/dispatcher/arch_specs_generated.hpp` (C++ header) + +### Step 5: Rebuild and Test + +```bash +cd ../build +cmake --build . -j8 +ctest --output-on-failure +``` + +### Step 6: Verify + +```python +from arch_filter import ArchFilter + +filter = ArchFilter("gfx1100") +is_valid = filter.is_kernel_valid( + datatype_a="fp16", datatype_b="fp16", datatype_c="fp16", + tile_m=128, tile_n=128, tile_k=32, + warp_m=2, warp_n=2, warp_k=1, + warp_tile_m=16, warp_tile_n=16, warp_tile_k=16 +) +print(f"Valid: {is_valid}") +``` + +## Reference + +### Supported Data Types + +| Key | Description | +|-----|-------------| +| `fp16` | Half precision (16-bit) | +| `bf16` | Brain float 16 | +| `fp32` | Single precision (32-bit) | +| `fp64` | Double precision (64-bit) | +| `fp8` | 8-bit float (E4M3) | +| `bf8` | 8-bit brain float (E5M2) | +| `int8` | 8-bit integer | +| `int4` | 4-bit integer | + +### GPU Families + +| Family | Description | +|--------|-------------| +| `cdna2` | MI200 series (gfx90a) | +| `cdna3` | MI300 series (gfx942) | +| `cdna4` | MI350 series (gfx950) | +| `rdna3` | RX 7000 series (gfx1100) | +| `rdna4` | RX 9000 series (gfx1201) | + +### Pipeline LDS Limits + +| Pipeline | LDS Limit | +|----------|-----------| +| `compv4` | 32 KB | +| `preshufflev2` | 32 KB | +| `default` | 64 KB | + +## Troubleshooting + +### "Unknown GPU architecture" + +1. Check architecture key matches exactly (e.g., `"gfx942"` not `"GFX942"`) +2. Verify you ran `generate_arch_specs.py` +3. Rebuild C++ code + +### Kernels being rejected + +```python +from arch_filter import ArchFilter, KernelConfig + +filter = ArchFilter("gfx942") +result = filter.validate_kernel(config) +print(f"Valid: {result.valid}") +for error in result.errors: + print(f" Error: {error}") +``` + +### Missing warp tile combination + +1. Check `warp_tile_combos` in `arch_specs.json` +2. Ensure `[warp_tile_m, warp_tile_n, warp_tile_k]` is in the list +3. Verify data type key format + +## File Structure + +``` +codegen/ +├── arch_specs.json # Single source of truth (EDIT THIS) +├── generate_arch_specs.py # Generator script +├── arch_specs_generated.py # Generated Python module +└── ADDING_NEW_GPU.md # This file + +include/ck_tile/dispatcher/ +├── arch_specs_generated.hpp # Generated C++ header +└── arch_filter.hpp # C++ filter +``` + +## Best Practices + +1. **Test thoroughly** - Run all tests after adding a new GPU +2. **Start minimal** - Add only validated configurations +3. **Document sources** - Note where warp tile combinations came from +4. **Keep in sync** - If using tile_engine, keep both updated + +--- + +> **More info:** See [../README.md](../README.md) for full documentation. diff --git a/dispatcher/codegen/CMakeLists.txt b/dispatcher/codegen/CMakeLists.txt new file mode 100644 index 0000000000..e63dcaab67 --- /dev/null +++ b/dispatcher/codegen/CMakeLists.txt @@ -0,0 +1,125 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Tile GEMM Unified Code Generator + +cmake_minimum_required(VERSION 3.16) + +# Find Python +find_package(Python3 COMPONENTS Interpreter REQUIRED) + +# Configuration +set(CODEGEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/unified_gemm_codegen.py") +set(CODEGEN_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json") +set(CODEGEN_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm") + +# Configurable options +set(CK_TILE_GEMM_DATATYPE "fp16" CACHE STRING "GEMM data type (fp16, bf16, fp32, fp8, bf8, int8)") +set(CK_TILE_GEMM_LAYOUT "rcr" CACHE STRING "GEMM layout (rcr, rrr, crr, ccr)") +set(CK_TILE_GEMM_VARIANTS "standard" CACHE STRING "GEMM variants (standard, preshuffle, multi_d)") +set(CK_TILE_GEMM_GPU_TARGET "gfx942" CACHE STRING "Target GPU architecture") +set(CK_TILE_GEMM_PARALLEL ON CACHE BOOL "Enable parallel generation") + +# Custom target to run code generation +add_custom_target(generate_tile_gemm_kernels + COMMAND ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT} + --output-dir ${CODEGEN_OUTPUT_DIR} + --datatype ${CK_TILE_GEMM_DATATYPE} + --layout ${CK_TILE_GEMM_LAYOUT} + --gpu-target ${CK_TILE_GEMM_GPU_TARGET} + --config ${CODEGEN_CONFIG} + --variants ${CK_TILE_GEMM_VARIANTS} + $<$>:--no-parallel> + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Generating CK Tile GEMM kernels and dispatcher wrappers..." + VERBATIM +) + +# Create output directory +file(MAKE_DIRECTORY ${CODEGEN_OUTPUT_DIR}) + +# Add generated headers to include path +include_directories(${CODEGEN_OUTPUT_DIR}) + +# Installation +install(FILES + ${CODEGEN_SCRIPT} + ${CODEGEN_CONFIG} + README.md + DESTINATION share/ck_tile/codegen +) + +# Helper function for projects to generate kernels +function(ck_tile_generate_gemm_kernels) + set(options PARALLEL) + set(oneValueArgs OUTPUT_DIR DATATYPE LAYOUT GPU_TARGET CONFIG) + set(multiValueArgs VARIANTS) + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + # Set defaults + if(NOT ARG_OUTPUT_DIR) + set(ARG_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm") + endif() + if(NOT ARG_DATATYPE) + set(ARG_DATATYPE "fp16") + endif() + if(NOT ARG_LAYOUT) + set(ARG_LAYOUT "rcr") + endif() + if(NOT ARG_GPU_TARGET) + set(ARG_GPU_TARGET "gfx942") + endif() + if(NOT ARG_CONFIG) + set(ARG_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json") + endif() + if(NOT ARG_VARIANTS) + set(ARG_VARIANTS "standard") + endif() + + # Build command + set(CMD ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT} + --output-dir ${ARG_OUTPUT_DIR} + --datatype ${ARG_DATATYPE} + --layout ${ARG_LAYOUT} + --gpu-target ${ARG_GPU_TARGET} + --config ${ARG_CONFIG} + --variants ${ARG_VARIANTS} + ) + + if(NOT ARG_PARALLEL) + list(APPEND CMD --no-parallel) + endif() + + # Execute + execute_process( + COMMAND ${CMD} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE RESULT + OUTPUT_VARIABLE OUTPUT + ERROR_VARIABLE ERROR + ) + + if(NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to generate GEMM kernels:\n${ERROR}") + else() + message(STATUS "Generated GEMM kernels: ${OUTPUT}") + endif() +endfunction() + +# Example usage documentation +message(STATUS "CK Tile GEMM Code Generator configured") +message(STATUS " Script: ${CODEGEN_SCRIPT}") +message(STATUS " Config: ${CODEGEN_CONFIG}") +message(STATUS " Output: ${CODEGEN_OUTPUT_DIR}") +message(STATUS "") +message(STATUS "To generate kernels:") +message(STATUS " cmake --build . --target generate_tile_gemm_kernels") +message(STATUS "") +message(STATUS "Or use CMake function:") +message(STATUS " ck_tile_generate_gemm_kernels(") +message(STATUS " OUTPUT_DIR ./generated") +message(STATUS " DATATYPE fp16") +message(STATUS " LAYOUT rcr") +message(STATUS " VARIANTS standard preshuffle multi_d") +message(STATUS " PARALLEL") +message(STATUS " )") diff --git a/dispatcher/codegen/README.md b/dispatcher/codegen/README.md new file mode 100644 index 0000000000..2d753924f5 --- /dev/null +++ b/dispatcher/codegen/README.md @@ -0,0 +1,123 @@ +# CK Tile GEMM Unified Code Generator + +Single source of truth for all GEMM kernel generation. + +> **See also:** [Main Dispatcher README](../README.md) for installation and core concepts. + +## Quick Start + +```bash +cd dispatcher/codegen + +# Generate standard FP16 kernels +python3 unified_gemm_codegen.py \ + --output-dir ../build/generated_kernels \ + --datatype fp16 \ + --layout rcr \ + --variants standard + +# Generate all variants +python3 unified_gemm_codegen.py \ + --output-dir ../build/generated_kernels \ + --variants standard preshuffle multi_d +``` + +## Using from Python + +```python +from ctypes_utils import CodegenRunner, KernelConfig + +# Generate from specific config +config = KernelConfig(tile_m=256, tile_n=256, tile_k=64) +codegen = CodegenRunner() +result = codegen.generate_from_config(config) + +# Generate variant +result = codegen.generate("preshuffle") + +# Generate all +results = codegen.generate_all() +``` + +## Command Line Options + +| Option | Values | Description | +|--------|--------|-------------| +| `--output-dir` | path | Output directory | +| `--datatype` | `fp16`, `bf16`, `fp32`, `int8` | Data type | +| `--layout` | `rcr`, `rrr`, `crr`, `ccr` | Matrix layouts | +| `--gpu-target` | `gfx942`, `gfx90a`, `gfx950` | Target GPU | +| `--variants` | `standard`, `preshuffle`, `multi_d` | Kernel variants | +| `--preselected` | `fp16_rcr_essential`, etc. | Predefined kernel set | + +### Layout Notation + +- `R` = Row-major, `C` = Column-major +- Order: A, B, C (e.g., `rcr` = A row, B col, C row) + +## Variants + +### Standard +Basic GEMM: `C = A × B` + +### PreShuffle +Optimized weight access with LDS pre-shuffling. Best for large matrices. + +### Multi-D +Element-wise fusion: `C = op(A × B + D0 + D1 + ...)` + +Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` + +## Output Structure + +``` +generated_kernels/ +├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp +├── gemm_fp16_rcr_compv4_..._preshuffle.hpp +├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp +└── ... +``` + +## Configuration Files + +### arch_specs.json + +GPU architecture specifications (single source of truth): + +```json +{ + "architectures": { + "gfx942": { + "family": "cdna3", + "warp_size": 64, + "warp_configs": [[2, 2, 1], [4, 4, 1]], + ... + } + } +} +``` + +### preselected_kernels.py + +Curated kernel sets for common use cases. + +## Adding New GPU Support + +See [ADDING_NEW_GPU.md](ADDING_NEW_GPU.md) for complete guide. + +Quick steps: +1. Edit `arch_specs.json` +2. Run `python generate_arch_specs.py` +3. Rebuild + +## Troubleshooting + +| Issue | Solution | +|-------|----------| +| "Arguments not supported" | Check tile config validity | +| Missing element-wise op | Check `elementwise_ops.hpp` | +| Compilation errors | Verify C++17, include paths | + +--- + +> **More info:** See [../README.md](../README.md) for full documentation. diff --git a/dispatcher/codegen/arch_filter.py b/dispatcher/codegen/arch_filter.py new file mode 100644 index 0000000000..67f146045b --- /dev/null +++ b/dispatcher/codegen/arch_filter.py @@ -0,0 +1,1012 @@ +#!/usr/bin/env python + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Architecture-Specific Kernel Filtering for CK Tile Dispatcher + +Unified filtering mechanism for validating kernel configurations against +GPU architecture capabilities. Uses arch_specs.json as single source of truth. + +Key Features: +- GPU architecture-specific warp tile and warp configuration validation +- Data type compatibility checking +- Trait combination validation (pipeline, epilogue, scheduler) +- LDS capacity validation +- Single source of truth (arch_specs.json) + +Usage: + from arch_filter import ArchFilter, get_supported_archs + + # Create filter for specific architecture + filter = ArchFilter("gfx942") + + # Validate a kernel configuration + is_valid = filter.is_kernel_valid( + datatype_a="fp16", datatype_b="fp16", datatype_c="fp16", + tile_m=256, tile_n=256, tile_k=64, + warp_m=2, warp_n=2, warp_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", epilogue="cshuffle", scheduler="intrawave" + ) + + # Get detailed validation results + result = filter.validate_kernel_detailed(...) + print(result.valid, result.errors) +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Any +from enum import Enum +import logging + +logger = logging.getLogger(__name__) + + +class OperatorType(Enum): + """Supported operator types for kernel validation""" + + GEMM = "gemm" + GEMM_PRESHUFFLE = "gemm_preshuffle" + GEMM_MULTI_D = "gemm_multi_d" + CONV_FWD = "conv_fwd" + CONV_BWD_DATA = "conv_bwd_data" + CONV_BWD_WEIGHT = "conv_bwd_weight" + CONV3D_FWD = "conv3d_fwd" + CONV3D_BWD_DATA = "conv3d_bwd_data" + CONV3D_BWD_WEIGHT = "conv3d_bwd_weight" + + +# Operator-specific tile constraints +# Different operators may have different minimum tile sizes or alignment requirements +OPERATOR_TILE_CONSTRAINTS = { + OperatorType.GEMM: { + "min_tile_m": 16, + "min_tile_n": 16, + "min_tile_k": 8, + "tile_m_alignment": 16, + "tile_n_alignment": 16, + "tile_k_alignment": 8, + }, + OperatorType.GEMM_PRESHUFFLE: { + "min_tile_m": 64, + "min_tile_n": 64, + "min_tile_k": 32, + "tile_m_alignment": 32, + "tile_n_alignment": 32, + "tile_k_alignment": 16, + }, + OperatorType.GEMM_MULTI_D: { + "min_tile_m": 16, + "min_tile_n": 16, + "min_tile_k": 8, + "tile_m_alignment": 16, + "tile_n_alignment": 16, + "tile_k_alignment": 8, + }, + OperatorType.CONV_FWD: { + "min_tile_m": 1, # N dimension can be 1 + "min_tile_n": 16, # K (output channels) should be reasonable + "min_tile_k": 16, # C (input channels) should be reasonable + "tile_m_alignment": 1, + "tile_n_alignment": 16, + "tile_k_alignment": 16, + }, + OperatorType.CONV_BWD_DATA: { + "min_tile_m": 1, + "min_tile_n": 16, # C (input channels) + "min_tile_k": 16, # K (output channels) + "tile_m_alignment": 1, + "tile_n_alignment": 16, + "tile_k_alignment": 16, + }, + OperatorType.CONV_BWD_WEIGHT: { + "min_tile_m": 16, # K (output channels) + "min_tile_n": 16, # C (input channels) + "min_tile_k": 1, # Spatial reduction dimension + "tile_m_alignment": 16, + "tile_n_alignment": 16, + "tile_k_alignment": 1, + }, +} + +# Add 3D convolution constraints (same as 2D for now) +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_FWD] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_FWD +] +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_BWD_DATA] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_BWD_DATA +] +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_BWD_WEIGHT] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_BWD_WEIGHT +] + +# ============================================================================= +# Import from Generated Module (Single Source of Truth) +# ============================================================================= + +# Try to import from the generated module (created from arch_specs.json) +try: + from arch_specs_generated import ( + ARCH_FAMILY_MAP, + ELEMENT_SIZE_MAP, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS, + PRESHUFFLE_PIPELINES, + LDS_CAPACITY_LIMITS, + TRAIT_UNSUPPORTED_COMBINATIONS, + DTYPE_COMBINATIONS, + ) + + _USING_GENERATED = True +except ImportError: + # Fallback to hardcoded values if generated module not available + logger.warning( + "arch_specs_generated.py not found, using fallback values. " + "Run 'python generate_arch_specs.py' to generate." + ) + _USING_GENERATED = False + + # Fallback data (minimal subset for basic operation) + ARCH_FAMILY_MAP = { + "gfx90a": "cdna2", + "gfx942": "cdna3", + "gfx950": "cdna4", + "gfx1201": "rdna4", + } + + ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "int32": 4, + } + + WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + } + + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": { + # Key format: A_B_Acc (e.g., fp16_fp16_fp32 = A/B are fp16, accumulator is fp32) + # These match tile_engine's GEMM_WARP_TILE_SUPPORTED_COMBINATIONS + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + } + + # Preshuffle-specific warp tile combinations (no [4, 64, 16]) + PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + }, + } + + PRESHUFFLE_PIPELINES = ["preshufflev2"] + + LDS_CAPACITY_LIMITS = {"compv4": 32768, "preshufflev2": 32768, "default": 65536} + + TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + } + + DTYPE_COMBINATIONS = { + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"}, + } + + +# ============================================================================= +# GPU Family Enum (for backwards compatibility) +# ============================================================================= + + +class GpuFamily(Enum): + """GPU architecture families""" + + CDNA2 = "cdna2" + CDNA3 = "cdna3" + CDNA4 = "cdna4" + RDNA4 = "rdna4" + + +# ============================================================================= +# Dtype Validation Helpers +# ============================================================================= + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid for GEMM.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return key in DTYPE_COMBINATIONS + + +def get_dtype_acc(dtype_a: str, dtype_b: str) -> str: + """Get the accumulator type for a dtype combination.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + info = DTYPE_COMBINATIONS.get(key, {"acc": "fp32"}) + return info["acc"] + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return list(DTYPE_COMBINATIONS.keys()) + + +# ============================================================================= +# Validation Result Types +# ============================================================================= + + +@dataclass +class ValidationResult: + """Result of kernel configuration validation""" + + valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + + def __bool__(self) -> bool: + return self.valid + + def add_error(self, msg: str): + self.errors.append(msg) + self.valid = False + + def add_warning(self, msg: str): + self.warnings.append(msg) + + +@dataclass +class KernelConfig: + """Kernel configuration for validation""" + + # Data types + datatype_a: str + datatype_b: str + datatype_c: str + + # Tile dimensions + tile_m: int + tile_n: int + tile_k: int + + # Warp configuration + warp_m: int + warp_n: int + warp_k: int + + # Warp tile dimensions + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + # Traits + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + + # Layout (for whole-workgroup cover validation) + layout: str = "rcr" + + # Operator type (affects validation rules) + operator: OperatorType = OperatorType.GEMM + + @property + def dtype_key(self) -> str: + """Generate data type combination key for warp tile lookup. + + Uses accumulator dtype (not output C type) to match the format + used in WARP_TILE_SUPPORTED_COMBINATIONS dictionaries which are + keyed as {datatype_a}_{datatype_b}_{accumulator_dtype}. + """ + acc_dtype = get_dtype_acc(self.datatype_a, self.datatype_b) + return f"{self.datatype_a}_{self.datatype_b}_{acc_dtype}" + + +# ============================================================================= +# Architecture Filter Class +# ============================================================================= + + +class ArchFilter: + """ + Architecture-specific kernel configuration filter. + + Validates kernel configurations against GPU architecture capabilities + to ensure only compatible kernels are registered. + + Example: + filter = ArchFilter("gfx942") + + # Quick validation + if filter.is_kernel_valid(config): + registry.register_kernel(kernel) + + # Detailed validation with error messages + result = filter.validate_kernel(config) + if not result.valid: + for error in result.errors: + print(f"Validation failed: {error}") + """ + + def __init__(self, gpu_arch: str, strict_mode: bool = True): + """ + Initialize architecture filter. + + Args: + gpu_arch: GPU architecture string (e.g., "gfx942", "gfx90a") + strict_mode: If True, unknown configurations are rejected. + If False, unknown configurations pass with warnings. + """ + self.gpu_arch = gpu_arch.lower() + self.strict_mode = strict_mode + self.family = ARCH_FAMILY_MAP.get(self.gpu_arch) + + if self.family is None and strict_mode: + raise ValueError( + f"Unknown GPU architecture: {gpu_arch}. " + f"Supported: {list(ARCH_FAMILY_MAP.keys())}" + ) + + def validate_kernel(self, config: KernelConfig) -> ValidationResult: + """ + Validate a kernel configuration against architecture constraints. + + Validation is performed based on the operator type, as different + operators (GEMM, Conv FWD, Conv BWD) have different constraints. + + Args: + config: Kernel configuration to validate + + Returns: + ValidationResult with valid flag and error/warning messages + """ + result = ValidationResult(valid=True) + + # Operator-specific tile constraint validation + self._validate_operator_constraints(config, result) + if not result.valid and self.strict_mode: + return result + + # Basic sanity checks + self._validate_dimensions(config, result) + if not result.valid and self.strict_mode: + return result + + # Warp configuration validation + self._validate_warp_config(config, result) + + # Warp tile combination validation + self._validate_warp_tile_combo(config, result) + + # Trait combination validation + self._validate_trait_combo(config, result) + + # LDS capacity validation + self._validate_lds_capacity(config, result) + + # Dimension alignment validation + self._validate_dimension_alignment(config, result) + + return result + + def _validate_operator_constraints( + self, config: KernelConfig, result: ValidationResult + ): + """Validate operator-specific tile constraints""" + constraints = OPERATOR_TILE_CONSTRAINTS.get(config.operator) + + if constraints is None: + # Unknown operator - add warning but don't fail + result.add_warning( + f"Unknown operator type: {config.operator}. " + f"Skipping operator-specific validation." + ) + return + + # Validate minimum tile sizes + min_tile_m = constraints.get("min_tile_m", 1) + min_tile_n = constraints.get("min_tile_n", 1) + min_tile_k = constraints.get("min_tile_k", 1) + + if config.tile_m < min_tile_m: + result.add_error( + f"Operator {config.operator.value}: tile_m ({config.tile_m}) " + f"< minimum ({min_tile_m})" + ) + if config.tile_n < min_tile_n: + result.add_error( + f"Operator {config.operator.value}: tile_n ({config.tile_n}) " + f"< minimum ({min_tile_n})" + ) + if config.tile_k < min_tile_k: + result.add_error( + f"Operator {config.operator.value}: tile_k ({config.tile_k}) " + f"< minimum ({min_tile_k})" + ) + + # Validate tile alignment + tile_m_align = constraints.get("tile_m_alignment", 1) + tile_n_align = constraints.get("tile_n_alignment", 1) + tile_k_align = constraints.get("tile_k_alignment", 1) + + if tile_m_align > 1 and config.tile_m % tile_m_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_m ({config.tile_m}) " + f"must be aligned to {tile_m_align}" + ) + if tile_n_align > 1 and config.tile_n % tile_n_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_n ({config.tile_n}) " + f"must be aligned to {tile_n_align}" + ) + if tile_k_align > 1 and config.tile_k % tile_k_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_k ({config.tile_k}) " + f"must be aligned to {tile_k_align}" + ) + + def is_kernel_valid( + self, + datatype_a: str = "fp16", + datatype_b: str = "fp16", + datatype_c: str = "fp16", + tile_m: int = 256, + tile_n: int = 256, + tile_k: int = 64, + warp_m: int = 2, + warp_n: int = 2, + warp_k: int = 1, + warp_tile_m: int = 32, + warp_tile_n: int = 32, + warp_tile_k: int = 16, + pipeline: str = "compv4", + epilogue: str = "cshuffle", + scheduler: str = "intrawave", + layout: str = "rcr", + operator: Optional[OperatorType] = None, + ) -> bool: + """ + Quick validation check for a kernel configuration. + + Args: + datatype_a, datatype_b, datatype_c: Data types for A, B, C matrices + tile_m, tile_n, tile_k: Block tile dimensions + warp_m, warp_n, warp_k: Warp/wave configuration + warp_tile_m, warp_tile_n, warp_tile_k: Warp tile dimensions + pipeline, epilogue, scheduler: Kernel traits + layout: Matrix layout (e.g., "rcr") + operator: Operator type (GEMM, CONV_FWD, CONV_BWD_DATA, etc.) + Affects validation rules for tile constraints. + Defaults to GEMM if not specified. + + Returns: + True if configuration is valid for this architecture + """ + config = KernelConfig( + datatype_a=datatype_a.lower(), + datatype_b=datatype_b.lower(), + datatype_c=datatype_c.lower(), + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + pipeline=pipeline.lower(), + epilogue=epilogue.lower(), + scheduler=scheduler.lower(), + layout=layout.lower(), + operator=operator if operator is not None else OperatorType.GEMM, + ) + return self.validate_kernel(config).valid + + def _validate_dimensions(self, config: KernelConfig, result: ValidationResult): + """Validate basic dimension constraints""" + if config.tile_m <= 0 or config.tile_n <= 0 or config.tile_k <= 0: + result.add_error( + f"Tile dimensions must be positive: " + f"{config.tile_m}x{config.tile_n}x{config.tile_k}" + ) + + if config.warp_m <= 0 or config.warp_n <= 0 or config.warp_k <= 0: + result.add_error( + f"Warp dimensions must be positive: " + f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + ) + + if ( + config.warp_tile_m <= 0 + or config.warp_tile_n <= 0 + or config.warp_tile_k <= 0 + ): + result.add_error( + f"Warp tile dimensions must be positive: " + f"{config.warp_tile_m}x{config.warp_tile_n}x{config.warp_tile_k}" + ) + + # Check warp tiles fit within block tiles + if config.warp_m * config.warp_tile_m > config.tile_m: + result.add_error( + f"warp_m * warp_tile_m ({config.warp_m}*{config.warp_tile_m}=" + f"{config.warp_m * config.warp_tile_m}) > tile_m ({config.tile_m})" + ) + if config.warp_n * config.warp_tile_n > config.tile_n: + result.add_error( + f"warp_n * warp_tile_n ({config.warp_n}*{config.warp_tile_n}=" + f"{config.warp_n * config.warp_tile_n}) > tile_n ({config.tile_n})" + ) + if config.warp_k * config.warp_tile_k > config.tile_k: + result.add_error( + f"warp_k * warp_tile_k ({config.warp_k}*{config.warp_tile_k}=" + f"{config.warp_k * config.warp_tile_k}) > tile_k ({config.tile_k})" + ) + + def _validate_warp_config(self, config: KernelConfig, result: ValidationResult): + """Validate warp configuration against architecture""" + allowed = WARP_SUPPORTED_COMBINATIONS.get(self.gpu_arch, []) + current = [config.warp_m, config.warp_n, config.warp_k] + + if not allowed: + msg = f"No warp configurations defined for {self.gpu_arch}" + if self.strict_mode: + result.add_error(msg) + else: + result.add_warning(msg) + return + + if current not in allowed: + result.add_error( + f"Invalid warp configuration {current} for {self.gpu_arch}. " + f"Allowed: {allowed}" + ) + + def _validate_warp_tile_combo(self, config: KernelConfig, result: ValidationResult): + """Validate warp tile combination against architecture and data types""" + # Use preshuffle-specific warp tiles for preshuffle operator + if config.operator == OperatorType.GEMM_PRESHUFFLE: + gpu_combos = PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get( + self.gpu_arch, {} + ) + combo_source = "preshuffle" + else: + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + combo_source = "standard" + + if not gpu_combos: + msg = ( + f"No {combo_source} warp tile combinations defined for {self.gpu_arch}" + ) + if self.strict_mode: + result.add_error(msg) + else: + result.add_warning(msg) + return + + dtype_combos = gpu_combos.get(config.dtype_key, []) + if not dtype_combos: + # Data type combo not explicitly listed - may still be valid + result.add_warning( + f"No {combo_source} warp tile combinations defined for {config.dtype_key} on {self.gpu_arch}" + ) + return + + current = [config.warp_tile_m, config.warp_tile_n, config.warp_tile_k] + if current not in dtype_combos: + result.add_error( + f"Invalid warp tile {current} for {config.dtype_key} on {self.gpu_arch} ({combo_source}). " + f"Allowed: {dtype_combos}" + ) + + def _validate_trait_combo(self, config: KernelConfig, result: ValidationResult): + """Validate trait (pipeline, epilogue, scheduler) combination""" + # Preshuffle requires specific pipelines + if config.operator == OperatorType.GEMM_PRESHUFFLE: + if config.pipeline not in PRESHUFFLE_PIPELINES: + result.add_error( + f"Preshuffle GEMM requires pipeline in {PRESHUFFLE_PIPELINES}, " + f"got {config.pipeline}" + ) + + # Conv backward operations only support compv3/mem pipelines + # (compv4/compv5 have template issues: transpose_tile2d for bwd_weight, + # get_length for bwd_data in ck_tile kernels) + conv_bwd_operators = { + OperatorType.CONV_BWD_DATA, + OperatorType.CONV_BWD_WEIGHT, + OperatorType.CONV3D_BWD_DATA, + OperatorType.CONV3D_BWD_WEIGHT, + } + conv_bwd_supported_pipelines = {"compv3", "mem"} + if config.operator in conv_bwd_operators: + if config.pipeline not in conv_bwd_supported_pipelines: + result.add_error( + f"Conv backward operations require pipeline in " + f"{conv_bwd_supported_pipelines}, got {config.pipeline}. " + f"(compv4/compv5 have ck_tile template compatibility issues)" + ) + + combo = (config.pipeline, config.epilogue, config.scheduler) + if combo in TRAIT_UNSUPPORTED_COMBINATIONS: + result.add_error( + f"Unsupported trait combination: pipeline={config.pipeline}, " + f"epilogue={config.epilogue}, scheduler={config.scheduler}" + ) + + def _validate_lds_capacity(self, config: KernelConfig, result: ValidationResult): + """Validate LDS (Local Data Share) memory capacity""" + elem_size_a = ELEMENT_SIZE_MAP.get(config.datatype_a, 2) + elem_size_b = ELEMENT_SIZE_MAP.get(config.datatype_b, 2) + + matrix_a_size = config.tile_m * config.tile_k * elem_size_a + matrix_b_size = config.tile_n * config.tile_k * elem_size_b + total_lds = matrix_a_size + matrix_b_size + + max_lds = LDS_CAPACITY_LIMITS.get( + config.pipeline, LDS_CAPACITY_LIMITS["default"] + ) + + if total_lds > max_lds: + result.add_error( + f"LDS capacity exceeded: {total_lds} bytes > {max_lds} bytes limit. " + f"Matrix A: {config.tile_m}x{config.tile_k}x{elem_size_a}={matrix_a_size}B, " + f"Matrix B: {config.tile_n}x{config.tile_k}x{elem_size_b}={matrix_b_size}B" + ) + + def _validate_dimension_alignment( + self, config: KernelConfig, result: ValidationResult + ): + """Validate tile dimensions are aligned with warp dimensions""" + if config.tile_m % (config.warp_m * config.warp_tile_m) != 0: + result.add_error( + f"tile_m ({config.tile_m}) must be divisible by " + f"warp_m*warp_tile_m ({config.warp_m}*{config.warp_tile_m}=" + f"{config.warp_m * config.warp_tile_m})" + ) + + if config.tile_n % (config.warp_n * config.warp_tile_n) != 0: + result.add_error( + f"tile_n ({config.tile_n}) must be divisible by " + f"warp_n*warp_tile_n ({config.warp_n}*{config.warp_tile_n}=" + f"{config.warp_n * config.warp_tile_n})" + ) + + if config.tile_k % (config.warp_k * config.warp_tile_k) != 0: + result.add_error( + f"tile_k ({config.tile_k}) must be divisible by " + f"warp_k*warp_tile_k ({config.warp_k}*{config.warp_tile_k}=" + f"{config.warp_k * config.warp_tile_k})" + ) + + def get_supported_warp_configs(self) -> List[List[int]]: + """Get list of supported warp configurations for this architecture""" + return WARP_SUPPORTED_COMBINATIONS.get(self.gpu_arch, []) + + def get_supported_warp_tiles(self, dtype_key: str) -> List[List[int]]: + """Get list of supported warp tile configurations for given data types""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + return gpu_combos.get(dtype_key, []) + + def get_supported_datatypes(self) -> List[str]: + """Get list of data type combinations supported on this architecture""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + return list(gpu_combos.keys()) + + +# ============================================================================= +# Registry Filter Integration +# ============================================================================= + + +class RegistryFilter: + """ + Filter wrapper for integrating with dispatcher Registry. + + Provides a callable interface that can be used with Registry.filter() + or during kernel registration. + + Example: + # Create filter for gfx942 + filter = RegistryFilter("gfx942") + + # Use with registry + registry = Registry() + registry.set_kernel_filter(filter) # Auto-filter on registration + + # Or filter existing kernels + valid_kernels = registry.filter(filter.accepts_kernel) + """ + + def __init__(self, gpu_arch: str, strict_mode: bool = False): + """ + Initialize registry filter. + + Args: + gpu_arch: Target GPU architecture + strict_mode: If True, reject unknown configurations + """ + self.arch_filter = ArchFilter(gpu_arch, strict_mode=strict_mode) + self.gpu_arch = gpu_arch + self._rejected_count = 0 + self._accepted_count = 0 + + def accepts_kernel(self, kernel_config: Dict[str, Any]) -> bool: + """ + Check if a kernel configuration should be accepted into the registry. + + Args: + kernel_config: Dictionary with kernel configuration values + + Returns: + True if kernel is valid for target architecture + """ + try: + is_valid = self.arch_filter.is_kernel_valid( + datatype_a=kernel_config.get("dtype_a", "fp16"), + datatype_b=kernel_config.get("dtype_b", "fp16"), + datatype_c=kernel_config.get("dtype_c", "fp16"), + tile_m=kernel_config.get("tile_m", 256), + tile_n=kernel_config.get("tile_n", 256), + tile_k=kernel_config.get("tile_k", 64), + warp_m=kernel_config.get("warp_m", 2), + warp_n=kernel_config.get("warp_n", 2), + warp_k=kernel_config.get("warp_k", 1), + warp_tile_m=kernel_config.get("warp_tile_m", 32), + warp_tile_n=kernel_config.get("warp_tile_n", 32), + warp_tile_k=kernel_config.get("warp_tile_k", 16), + pipeline=kernel_config.get("pipeline", "compv4"), + epilogue=kernel_config.get("epilogue", "cshuffle"), + scheduler=kernel_config.get("scheduler", "intrawave"), + layout=kernel_config.get("layout", "rcr"), + ) + + if is_valid: + self._accepted_count += 1 + else: + self._rejected_count += 1 + + return is_valid + + except Exception as e: + logger.warning(f"Error validating kernel config: {e}") + self._rejected_count += 1 + return False + + def get_stats(self) -> Dict[str, int]: + """Get filtering statistics""" + return { + "accepted": self._accepted_count, + "rejected": self._rejected_count, + "total": self._accepted_count + self._rejected_count, + } + + def reset_stats(self): + """Reset filtering statistics""" + self._accepted_count = 0 + self._rejected_count = 0 + + def __call__(self, kernel_config: Dict[str, Any]) -> bool: + """Callable interface for use with filter functions""" + return self.accepts_kernel(kernel_config) + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures""" + return list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> Optional[str]: + """Get the GPU family for an architecture""" + family = ARCH_FAMILY_MAP.get(gpu_arch.lower()) + return family if family else None # ARCH_FAMILY_MAP contains strings, not Enums + + +def create_filter_for_current_gpu() -> Optional[ArchFilter]: + """ + Create a filter for the current GPU (auto-detect). + + Returns: + ArchFilter for detected GPU, or None if detection fails + """ + try: + import subprocess + + result = subprocess.run(["rocminfo"], capture_output=True, text=True, timeout=5) + + for line in result.stdout.split("\n"): + if "gfx" in line.lower(): + for arch in ARCH_FAMILY_MAP.keys(): + if arch in line.lower(): + return ArchFilter(arch) + + return None + except Exception: + return None + + +def filter_kernel_list( + kernels: List[Dict[str, Any]], gpu_arch: str +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Filter a list of kernel configurations for a specific architecture. + + Args: + kernels: List of kernel configuration dictionaries + gpu_arch: Target GPU architecture + + Returns: + Tuple of (valid_kernels, rejected_kernels) + """ + reg_filter = RegistryFilter(gpu_arch) + valid = [] + rejected = [] + + for kernel in kernels: + if reg_filter.accepts_kernel(kernel): + valid.append(kernel) + else: + rejected.append(kernel) + + return valid, rejected + + +# ============================================================================= +# Main (for testing) +# ============================================================================= + +if __name__ == "__main__": + # Test the filter + print("Testing ArchFilter for gfx942...\n") + + filter_942 = ArchFilter("gfx942") + + # Test valid configuration + print("Test 1: Valid FP16 GEMM kernel") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=256, + tile_n=256, + tile_k=64, + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test invalid warp configuration + print("Test 2: Invalid warp configuration") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=256, + tile_n=256, + tile_k=64, + warp_m=3, + warp_n=3, + warp_k=1, # Invalid! + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test LDS overflow + print("Test 3: LDS capacity overflow") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=512, + tile_n=512, + tile_k=256, # Too large! + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test quick validation + print("Test 4: Quick validation (is_kernel_valid)") + is_valid = filter_942.is_kernel_valid( + tile_m=128, + tile_n=128, + tile_k=32, + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=16, + ) + print(f" Valid: {is_valid}") + print() + + # Show supported configurations + print("Supported warp configurations for gfx942:") + for cfg in filter_942.get_supported_warp_configs(): + print(f" {cfg}") + print() + + print("Supported data types for gfx942:") + for dtype in filter_942.get_supported_datatypes(): + print(f" {dtype}") diff --git a/dispatcher/codegen/arch_specs.json b/dispatcher/codegen/arch_specs.json new file mode 100644 index 0000000000..7d8c83fbf7 --- /dev/null +++ b/dispatcher/codegen/arch_specs.json @@ -0,0 +1,270 @@ +{ + "_comment": "Single source of truth for GPU architecture specifications. Edit this file to add new GPU support.", + "_version": "1.2.0", + "_instructions": "See ADDING_NEW_GPU.md for instructions on adding new GPU support.", + "_supported_arch_note": "CK Tile supports: GFX9 (gfx908, gfx90a, gfx942, gfx950), GFX10.3 (gfx103x), GFX11 (gfx110x, gfx115x), GFX12 (gfx120x)", + + "architectures": { + "gfx908": { + "family": "cdna1", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI100", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx90a": { + "family": "cdna2", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI200 series", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx942": { + "family": "cdna3", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI300 series", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]], + "bf8_fp8_fp32": [[32, 32, 16]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx950": { + "family": "cdna4", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI350 series", + "warp_size": 64, + "lds_capacity_kb": 160, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 128], [32, 32, 64]], + "bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + "pk_fp4_pk_fp4_fp32": [[16, 16, 128]] + } + }, + + "gfx1100": { + "family": "rdna3", + "target_family": "gfx11", + "architecture": "rdna", + "description": "AMD Radeon RX 7900 series (RDNA3)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + }, + + "gfx1200": { + "family": "rdna4", + "target_family": "gfx12", + "architecture": "rdna", + "description": "AMD Radeon RX 9000 series (RDNA4)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + }, + + "gfx1201": { + "family": "rdna4", + "target_family": "gfx12", + "architecture": "rdna", + "description": "AMD Radeon RX 9000 series (RDNA4)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + } + }, + + "element_sizes": { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "pk_fp4": 0.5, + "int32": 4 + }, + + "datatype_cpp_map": { + "_comment": "Maps dtype string to CK Tile C++ type for code generation", + "fp16": "ck_tile::half_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int8": "ck_tile::int8_t", + "int4": "ck_tile::pk_int4_t", + "pk_fp4": "ck_tile::pk_fp4_t", + "int32": "ck_tile::int32_t" + }, + + "dtype_combinations": { + "_comment": "All valid (A, B) -> Acc combinations for GEMM from warp_gemm_dispatcher.hpp", + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"} + }, + + "layout_cpp_map": { + "_comment": "Maps layout character to CK Tile C++ type", + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor" + }, + + "pipeline_lds_limits": { + "_comment": "LDS capacity limits in bytes for different pipeline types", + "mem": 65536, + "compv1": 65536, + "compv2": 65536, + "compv3": 65536, + "compv4": 32768, + "compv5": 65536, + "preshufflev1": 32768, + "preshufflev2": 32768, + "default": 65536 + }, + + "unsupported_trait_combos": { + "_comment": "Only 'mem' pipeline supports interwave scheduler. All compute pipelines only support intrawave.", + "combinations": [ + ["compv3", "cshuffle", "interwave"], + ["compv3", "default", "interwave"], + ["compv4", "cshuffle", "interwave"], + ["compv4", "default", "interwave"], + ["compv5", "cshuffle", "interwave"], + ["compv5", "default", "interwave"], + ["compv6", "cshuffle", "interwave"], + ["compv6", "default", "interwave"], + ["comp_async", "cshuffle", "interwave"], + ["comp_async", "default", "interwave"] + ] + }, + + "preshuffle_warp_tile_combos": { + "_comment": "Preshuffle-specific warp tile combinations (subset of standard GEMM, no [4, 64, 16])", + "gfx90a": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]] + }, + "gfx942": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]] + }, + "gfx950": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] + } + }, + + "preshuffle_pipelines": { + "_comment": "Pipelines supported for preshuffle GEMM variant", + "supported": ["preshufflev2"] + } +} diff --git a/dispatcher/codegen/arch_specs_generated.py b/dispatcher/codegen/arch_specs_generated.py new file mode 100644 index 0000000000..97f17e9724 --- /dev/null +++ b/dispatcher/codegen/arch_specs_generated.py @@ -0,0 +1,358 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + +Generated from: arch_specs.json +Generated at: 2026-01-05T19:34:01.224422 + +To update this file: +1. Edit arch_specs.json +2. Run: python generate_arch_specs.py + +This module provides architecture-specific configurations for kernel filtering. +""" + +from typing import Dict, List, Set, Tuple + +# ============================================================================= +# Architecture Data (Generated from arch_specs.json) +# ============================================================================= + +# GPU architecture to family mapping +ARCH_FAMILY_MAP: Dict[str, str] = { + "gfx908": "cdna1", + "gfx90a": "cdna2", + "gfx942": "cdna3", + "gfx950": "cdna4", + "gfx1100": "rdna3", + "gfx1200": "rdna4", + "gfx1201": "rdna4", +} + +# Element size in bytes for each data type +ELEMENT_SIZE_MAP: Dict[str, float] = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "pk_fp4": 0.5, + "int32": 4, +} + +# Supported warp configurations per architecture [warp_m, warp_n, warp_k] +WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = { + "gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + "gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + "gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], +} + +# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...] +WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = { + "gfx908": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx90a": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx942": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]], + "bf8_fp8_fp32": [[32, 32, 16]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx950": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "fp8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]], + "bf8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + "pk_fp4_pk_fp4_fp32": [[16, 16, 128]], + }, + "gfx1100": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, + "gfx1200": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, + "gfx1201": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, +} + +# Preshuffle-specific warp tile combinations (subset of standard GEMM) +PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = { + "gfx90a": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Preshuffle-supported pipelines +PRESHUFFLE_PIPELINES: List[str] = ["preshufflev2"] + +# LDS capacity limits per pipeline type (in bytes) +LDS_CAPACITY_LIMITS: Dict[str, int] = { + "mem": 65536, + "compv1": 65536, + "compv2": 65536, + "compv3": 65536, + "compv4": 32768, + "compv5": 65536, + "preshufflev1": 32768, + "preshufflev2": 32768, + "default": 65536, +} + +# Unsupported trait combinations: (pipeline, epilogue, scheduler) +TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + ("compv5", "cshuffle", "interwave"), + ("compv5", "default", "interwave"), + ("compv6", "cshuffle", "interwave"), + ("compv6", "default", "interwave"), + ("comp_async", "cshuffle", "interwave"), + ("comp_async", "default", "interwave"), +} + +# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes +DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = { + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"}, +} + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures.""" + return list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> str: + """Get the GPU family for an architecture.""" + return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown") + + +def get_element_size(dtype: str) -> float: + """Get element size in bytes for a data type.""" + return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0) + + +def get_warp_configs(gpu_arch: str) -> List[List[int]]: + """Get supported warp configurations for an architecture.""" + return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), []) + + +def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]: + """Get supported warp tile combinations for arch and data types.""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {}) + return gpu_combos.get(dtype_key.lower(), []) + + +def get_lds_limit(pipeline: str) -> int: + """Get LDS capacity limit for a pipeline type.""" + return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"]) + + +def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is unsupported.""" + return ( + pipeline.lower(), + epilogue.lower(), + scheduler.lower(), + ) in TRAIT_UNSUPPORTED_COMBINATIONS + + +def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]: + """Get accumulator type and notes for a dtype combination.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return DTYPE_COMBINATIONS.get(key, {"acc": "fp32", "notes": "unknown"}) + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return key in DTYPE_COMBINATIONS + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return list(DTYPE_COMBINATIONS.keys()) diff --git a/dispatcher/codegen/default_config.json b/dispatcher/codegen/default_config.json new file mode 100644 index 0000000000..3ef823fcc2 --- /dev/null +++ b/dispatcher/codegen/default_config.json @@ -0,0 +1,27 @@ +{ + "tile_config": { + "tile_m": [128, 256], + "tile_n": [128, 256], + "tile_k": [32, 64], + "warp_m": [2, 4], + "warp_n": [2, 4], + "warp_k": [1], + "warp_tile_m": [16, 32], + "warp_tile_n": [16, 32], + "warp_tile_k": [16] + }, + "trait_config": { + "pipeline": ["compv4"], + "epilogue": ["cshuffle"], + "scheduler": ["intrawave"], + "pad_m": [false], + "pad_n": [false], + "pad_k": [false], + "persistent": [false, true] + }, + "multi_d_config": { + "elementwise_ops": ["MultiDAdd", "Relu", "Gelu"], + "num_d_tensors": [1, 2] + } +} + diff --git a/dispatcher/codegen/generate_arch_specs.py b/dispatcher/codegen/generate_arch_specs.py new file mode 100644 index 0000000000..5b6fc2971b --- /dev/null +++ b/dispatcher/codegen/generate_arch_specs.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Architecture Specs Generator + +Generates both Python and C++ code from a single JSON source of truth. +This ensures consistency between Python codegen and C++ runtime filtering. + +Usage: + python generate_arch_specs.py [--json arch_specs.json] [--output-dir .] + + # Regenerate after editing arch_specs.json: + python generate_arch_specs.py + +Output: + - arch_specs_generated.py (Python module with arch data) + - arch_specs_generated.hpp (C++ header with arch data) +""" + +import json +import argparse +from pathlib import Path +from datetime import datetime +from typing import Dict, Any + +SCRIPT_DIR = Path(__file__).parent + + +def load_arch_specs(json_path: Path) -> Dict[str, Any]: + """Load architecture specifications from JSON file.""" + with open(json_path) as f: + return json.load(f) + + +def generate_python_module(specs: Dict[str, Any], output_path: Path): + """Generate Python module from arch specs.""" + + timestamp = datetime.now().isoformat() + + # Extract data + archs = specs["architectures"] + element_sizes = specs["element_sizes"] + pipeline_limits = specs["pipeline_lds_limits"] + unsupported = specs["unsupported_trait_combos"]["combinations"] + + # Build warp configs dict + warp_configs_str = "{\n" + for arch, data in archs.items(): + warp_configs_str += f' "{arch}": {data["warp_configs"]},\n' + warp_configs_str += "}" + + # Build warp tile combos dict + warp_tile_str = "{\n" + for arch, data in archs.items(): + warp_tile_str += f' "{arch}": {{\n' + for dtype, combos in data["warp_tile_combos"].items(): + warp_tile_str += f' "{dtype}": {combos},\n' + warp_tile_str += " },\n" + warp_tile_str += "}" + + # Build arch family map + arch_family_str = "{\n" + for arch, data in archs.items(): + arch_family_str += f' "{arch}": "{data["family"]}",\n' + arch_family_str += "}" + + # Build unsupported combos set + unsupported_str = "{\n" + for combo in unsupported: + unsupported_str += f' ("{combo[0]}", "{combo[1]}", "{combo[2]}"),\n' + unsupported_str += "}" + + # Pipeline LDS limits + pipeline_limits_clean = { + k: v for k, v in pipeline_limits.items() if not k.startswith("_") + } + + # Build dtype combinations dict + dtype_combos = specs.get("dtype_combinations", {}) + dtype_combos_str = "{\n" + for key, info in dtype_combos.items(): + if not key.startswith("_"): + dtype_combos_str += f' "{key}": {{"acc": "{info["acc"]}", "notes": "{info["notes"]}"}},\n' + dtype_combos_str += "}" + + # Build preshuffle warp tile combos dict (operator-specific) + preshuffle_combos = specs.get("preshuffle_warp_tile_combos", {}) + preshuffle_warp_tile_str = "{\n" + for arch, dtype_combos_dict in preshuffle_combos.items(): + if not arch.startswith("_"): + preshuffle_warp_tile_str += f' "{arch}": {{\n' + for dtype, combos in dtype_combos_dict.items(): + preshuffle_warp_tile_str += f' "{dtype}": {combos},\n' + preshuffle_warp_tile_str += " },\n" + preshuffle_warp_tile_str += "}" + + # Build preshuffle pipelines list + preshuffle_pipelines = specs.get("preshuffle_pipelines", {}).get( + "supported", ["preshufflev2"] + ) + preshuffle_pipelines_str = str(preshuffle_pipelines) + + content = f'''# SPDX-License-Identifier: MIT + +""" +AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + +Generated from: arch_specs.json +Generated at: {timestamp} + +To update this file: +1. Edit arch_specs.json +2. Run: python generate_arch_specs.py + +This module provides architecture-specific configurations for kernel filtering. +""" + +from typing import Dict, List, Set, Tuple + +# ============================================================================= +# Architecture Data (Generated from arch_specs.json) +# ============================================================================= + +# GPU architecture to family mapping +ARCH_FAMILY_MAP: Dict[str, str] = {arch_family_str} + +# Element size in bytes for each data type +ELEMENT_SIZE_MAP: Dict[str, float] = {element_sizes} + +# Supported warp configurations per architecture [warp_m, warp_n, warp_k] +WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {warp_configs_str} + +# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...] +WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {warp_tile_str} + +# Preshuffle-specific warp tile combinations (subset of standard GEMM) +PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {preshuffle_warp_tile_str} + +# Preshuffle-supported pipelines +PRESHUFFLE_PIPELINES: List[str] = {preshuffle_pipelines_str} + +# LDS capacity limits per pipeline type (in bytes) +LDS_CAPACITY_LIMITS: Dict[str, int] = {pipeline_limits_clean} + +# Unsupported trait combinations: (pipeline, epilogue, scheduler) +TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = {unsupported_str} + +# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes +DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = {dtype_combos_str} + +# ============================================================================= +# Helper Functions +# ============================================================================= + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures.""" + return list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> str: + """Get the GPU family for an architecture.""" + return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown") + + +def get_element_size(dtype: str) -> float: + """Get element size in bytes for a data type.""" + return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0) + + +def get_warp_configs(gpu_arch: str) -> List[List[int]]: + """Get supported warp configurations for an architecture.""" + return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), []) + + +def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]: + """Get supported warp tile combinations for arch and data types.""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {{}}) + return gpu_combos.get(dtype_key.lower(), []) + + +def get_lds_limit(pipeline: str) -> int: + """Get LDS capacity limit for a pipeline type.""" + return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"]) + + +def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is unsupported.""" + return (pipeline.lower(), epilogue.lower(), scheduler.lower()) in TRAIT_UNSUPPORTED_COMBINATIONS + + +def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]: + """Get accumulator type and notes for a dtype combination.""" + key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}" + return DTYPE_COMBINATIONS.get(key, {{"acc": "fp32", "notes": "unknown"}}) + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid.""" + key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}" + return key in DTYPE_COMBINATIONS + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return list(DTYPE_COMBINATIONS.keys()) +''' + + output_path.write_text(content) + print(f"Generated: {output_path}") + + +def generate_cpp_header(specs: Dict[str, Any], output_path: Path): + """Generate C++ header from arch specs.""" + + timestamp = datetime.now().isoformat() + + # Extract data + archs = specs["architectures"] + element_sizes = specs["element_sizes"] + pipeline_limits = specs["pipeline_lds_limits"] + specs["unsupported_trait_combos"]["combinations"] + + # Build arch enum and string functions + arch_enums = [] + arch_to_string_cases = [] + string_to_arch_cases = [] + + for arch, data in archs.items(): + enum_name = arch.upper().replace("GFX", "GFX_") + arch_enums.append(f" {enum_name}, // {data['description']}") + arch_to_string_cases.append( + f' case GpuArch::{enum_name}: return "{arch}";' + ) + string_to_arch_cases.append( + f' if (arch_str == "{arch}") return GpuArch::{enum_name};' + ) + + # Build warp configs switch + warp_config_cases = [] + for arch, data in archs.items(): + enum_name = arch.upper().replace("GFX", "GFX_") + configs = ", ".join( + [f"{{{c[0]}, {c[1]}, {c[2]}}}" for c in data["warp_configs"]] + ) + warp_config_cases.append( + f" case GpuArch::{enum_name}: return {{{configs}}};" + ) + + # Build element size switch + # Include all data types defined in kernel_key.hpp DataType enum + elem_size_cases = [] + dtype_enum_map = { + "fp16": "FP16", + "bf16": "BF16", + "fp32": "FP32", + "fp64": "FP64", + "fp8": "FP8", + "bf8": "BF8", + "int8": "INT8", + "int4": "INT4", + "int32": "INT32", + } + for dtype, size in element_sizes.items(): + if dtype in dtype_enum_map: + elem_size_cases.append( + f" case DataType::{dtype_enum_map[dtype]}: return {float(size)}f;" + ) + + # Build LDS limits + lds_limit_cases = [] + pipeline_enum_map = { + "mem": "Mem", + "compv1": "CompV1", + "compv2": "CompV2", + "compv3": "CompV3", + "compv4": "CompV4", + "compv5": "CompV5", + "preshufflev1": "PreShuffleV1", + "preshufflev2": "PreShuffleV2", + } + default_lds = pipeline_limits.get("default", 65536) + for pipeline, limit in pipeline_limits.items(): + if pipeline in pipeline_enum_map: + lds_limit_cases.append( + f" if (pipeline == Pipeline::{pipeline_enum_map[pipeline]}) return {limit};" + ) + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +/** + * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + * + * Generated from: arch_specs.json + * Generated at: {timestamp} + * + * To update this file: + * 1. Edit arch_specs.json + * 2. Run: python generate_arch_specs.py + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include + +namespace ck_tile {{ +namespace dispatcher {{ +namespace arch_specs {{ + +// ============================================================================= +// GPU Architecture Enum (Generated) +// ============================================================================= + +enum class GpuArch : std::uint8_t {{ +{chr(10).join(arch_enums)} + UNKNOWN +}}; + +// ============================================================================= +// String Conversion Functions (Generated) +// ============================================================================= + +inline std::string arch_to_string(GpuArch arch) {{ + switch (arch) {{ +{chr(10).join(arch_to_string_cases)} + default: return "unknown"; + }} +}} + +inline GpuArch string_to_arch(const std::string& arch_str) {{ +{chr(10).join(string_to_arch_cases)} + return GpuArch::UNKNOWN; +}} + +// ============================================================================= +// Element Size (Generated) +// ============================================================================= + +inline float element_size(DataType dtype) {{ + switch (dtype) {{ +{chr(10).join(elem_size_cases)} + default: return 2.0f; + }} +}} + +// ============================================================================= +// Warp Configurations (Generated) +// ============================================================================= + +using WarpConfig = std::array; + +inline std::vector get_supported_warp_configs(GpuArch arch) {{ + switch (arch) {{ +{chr(10).join(warp_config_cases)} + default: return {{}}; + }} +}} + +// ============================================================================= +// LDS Capacity Limits (Generated) +// ============================================================================= + +inline std::size_t get_lds_capacity(Pipeline pipeline) {{ +{chr(10).join(lds_limit_cases)} + return {default_lds}; // Default +}} + +// ============================================================================= +// Unsupported Trait Combinations (Generated) +// ============================================================================= + +inline bool is_trait_unsupported(Pipeline pipeline, [[maybe_unused]] Epilogue epilogue, Scheduler scheduler) {{ + // Generated from unsupported_trait_combos in arch_specs.json + if (scheduler == Scheduler::Interwave) {{ + if (pipeline == Pipeline::CompV3 || pipeline == Pipeline::CompV4) {{ + return true; + }} + }} + return false; +}} + +}} // namespace arch_specs +}} // namespace dispatcher +}} // namespace ck_tile +""" + + output_path.write_text(content) + print(f"Generated: {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate Python and C++ code from arch_specs.json" + ) + parser.add_argument( + "--json", + type=Path, + default=SCRIPT_DIR / "arch_specs.json", + help="Path to arch_specs.json", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=SCRIPT_DIR, + help="Output directory for generated files", + ) + parser.add_argument( + "--cpp-output-dir", + type=Path, + default=None, + help="Output directory for C++ header (defaults to dispatcher/include/...)", + ) + + args = parser.parse_args() + + # Load specs + print(f"Loading: {args.json}") + specs = load_arch_specs(args.json) + + # Generate Python module + py_output = args.output_dir / "arch_specs_generated.py" + generate_python_module(specs, py_output) + + # Generate C++ header + if args.cpp_output_dir: + cpp_output = args.cpp_output_dir / "arch_specs_generated.hpp" + else: + cpp_output = ( + SCRIPT_DIR.parent + / "include" + / "ck_tile" + / "dispatcher" + / "arch_specs_generated.hpp" + ) + + cpp_output.parent.mkdir(parents=True, exist_ok=True) + generate_cpp_header(specs, cpp_output) + + print("\nDone! To apply changes:") + print(" 1. Python code will automatically use arch_specs_generated.py") + print(" 2. C++ code includes arch_specs_generated.hpp") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/codegen/generate_dispatcher_registration.py b/dispatcher/codegen/generate_dispatcher_registration.py new file mode 100644 index 0000000000..024ec4a7c8 --- /dev/null +++ b/dispatcher/codegen/generate_dispatcher_registration.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Generate dispatcher registration code for CK Tile kernels + +This script generates C++ registration code that instantiates TileKernelInstance +templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem. +""" + +import json +import argparse +from pathlib import Path +from typing import List +from dataclasses import dataclass + + +@dataclass +class KernelConfig: + """Kernel configuration for registration""" + + name: str + header_file: str + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + block_size: int + pipeline: str + epilogue: str + scheduler: str + pad_m: bool + pad_n: bool + pad_k: bool + persistent: bool + double_buffer: bool + transpose_c: bool + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout_a: str = "row" + layout_b: str = "col" + layout_c: str = "row" + + +def generate_registration_header(kernels: List[KernelConfig], output_file: Path): + """Generate registration header file""" + + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#pragma once + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/backends/kernel_registration.hpp" + +// Include all generated kernel headers +""" + + # Add includes for all kernel headers + for kernel in kernels: + content += f'#include "{kernel.header_file}"\n' + + content += """ + +namespace ck_tile { +namespace dispatcher { +namespace generated { + +/// Register all generated kernels with the dispatcher +inline void register_all_kernels(Registry& registry) +{ +""" + + # Add registration calls for each kernel + for kernel in kernels: + # Extract the SelectedKernel type name from the header file + # Assuming the header defines a type like: using SelectedKernel = ... + kernel_type = f"SelectedKernel_{kernel.name}" + + content += f""" // Register {kernel.name} + register_tile_kernel<{kernel_type}>(registry, "{kernel.name}"); +""" + + content += """} + +/// Register all generated kernels with the global registry +inline void register_all_kernels() +{ + auto& registry = Registry::instance(); + register_all_kernels(registry); +} + +} // namespace generated +} // namespace dispatcher +} // namespace ck_tile +""" + + output_file.write_text(content) + print(f"✓ Generated registration header: {output_file}") + + +def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path): + """Generate registration implementation file""" + + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#include "dispatcher_registration.hpp" + +namespace ck_tile { +namespace dispatcher { +namespace generated { + +// Explicit instantiations to reduce compile time +// These ensure the templates are instantiated once + +""" + + for kernel in kernels: + kernel_type = f"SelectedKernel_{kernel.name}" + content += f"template class backends::TileKernelInstance<{kernel_type}>;\n" + + content += """ +} // namespace generated +} // namespace dispatcher +} // namespace ck_tile +""" + + output_file.write_text(content) + print(f"✓ Generated registration implementation: {output_file}") + + +def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path): + """Generate a wrapper header that defines SelectedKernel type""" + + wrapper_file = output_dir / f"{kernel.name}_wrapper.hpp" + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#pragma once + +#include "{kernel.header_file}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +// Type alias for dispatcher registration +// This allows the registration code to reference the kernel type +using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + + wrapper_file.write_text(content) + + +def load_kernel_manifest(manifest_file: Path) -> List[KernelConfig]: + """Load kernel configurations from manifest file""" + + with open(manifest_file, "r") as f: + data = json.load(f) + + kernels = [] + for kernel_data in data.get("kernels", []): + kernel = KernelConfig( + name=kernel_data["name"], + header_file=kernel_data["header_file"], + tile_m=kernel_data["tile_m"], + tile_n=kernel_data["tile_n"], + tile_k=kernel_data["tile_k"], + warp_m=kernel_data.get("warp_m", 2), + warp_n=kernel_data.get("warp_n", 2), + warp_k=kernel_data.get("warp_k", 1), + warp_tile_m=kernel_data.get("warp_tile_m", 32), + warp_tile_n=kernel_data.get("warp_tile_n", 32), + warp_tile_k=kernel_data.get("warp_tile_k", 16), + block_size=kernel_data.get("block_size", 256), + pipeline=kernel_data.get("pipeline", "compv4"), + epilogue=kernel_data.get("epilogue", "cshuffle"), + scheduler=kernel_data.get("scheduler", "intrawave"), + pad_m=kernel_data.get("pad_m", False), + pad_n=kernel_data.get("pad_n", False), + pad_k=kernel_data.get("pad_k", False), + persistent=kernel_data.get("persistent", False), + double_buffer=kernel_data.get("double_buffer", True), + transpose_c=kernel_data.get("transpose_c", False), + dtype_a=kernel_data.get("dtype_a", "fp16"), + dtype_b=kernel_data.get("dtype_b", "fp16"), + dtype_c=kernel_data.get("dtype_c", "fp16"), + dtype_acc=kernel_data.get("dtype_acc", "fp32"), + ) + kernels.append(kernel) + + return kernels + + +def scan_generated_headers(generated_dir: Path) -> List[KernelConfig]: + """Scan generated headers and extract kernel configurations""" + + import re + + kernels = [] + + for header_file in generated_dir.glob("**/*.hpp"): + try: + content = header_file.read_text() + + # Extract kernel name + name_match = re.search( + r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content + ) + if not name_match: + continue + + kernel_name = name_match.group(1) + + # Extract tile configuration (support ck_tile::index_t) + tile_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileM\s*=\s*(\d+)", + content, + ) + tile_n_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileN\s*=\s*(\d+)", + content, + ) + tile_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileK\s*=\s*(\d+)", + content, + ) + + tile_m = int(tile_m_match.group(1)) if tile_m_match else 256 + tile_n = int(tile_n_match.group(1)) if tile_n_match else 256 + tile_k = int(tile_k_match.group(1)) if tile_k_match else 32 + + # Extract warp configuration + warp_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_M\s*=\s*(\d+)", + content, + ) + warp_n_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_N\s*=\s*(\d+)", + content, + ) + warp_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_K\s*=\s*(\d+)", + content, + ) + + warp_m = int(warp_m_match.group(1)) if warp_m_match else 2 + warp_n = int(warp_n_match.group(1)) if warp_n_match else 2 + warp_k = int(warp_k_match.group(1)) if warp_k_match else 1 + + # Extract warp tile configuration + warp_tile_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileM\s*=\s*(\d+)", + content, + ) + warp_tile_n_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileN\s*=\s*(\d+)", + content, + ) + warp_tile_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileK\s*=\s*(\d+)", + content, + ) + + warp_tile_m = int(warp_tile_m_match.group(1)) if warp_tile_m_match else 32 + warp_tile_n = int(warp_tile_n_match.group(1)) if warp_tile_n_match else 32 + warp_tile_k = int(warp_tile_k_match.group(1)) if warp_tile_k_match else 16 + + # Extract other parameters (with defaults) + block_size_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+BlockSize\s*=\s*(\d+)", + content, + ) + block_size = int(block_size_match.group(1)) if block_size_match else 256 + + # Extract boolean flags + pad_m = re.search(r"kPadM\s*=\s*true", content) is not None + pad_n = re.search(r"kPadN\s*=\s*true", content) is not None + pad_k = re.search(r"kPadK\s*=\s*true", content) is not None + persistent = ( + re.search(r"UsePersistentKernel\s*=\s*true", content) is not None + ) + double_buffer = ( + re.search(r"DoubleSmemBuffer\s*=\s*true", content) is not None + ) + transpose_c = re.search(r"TransposeC\s*=\s*true", content) is not None + + kernel = KernelConfig( + name=kernel_name, + header_file=str(header_file.relative_to(generated_dir.parent)), + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + block_size=block_size, + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=pad_m, + pad_n=pad_n, + pad_k=pad_k, + persistent=persistent, + double_buffer=double_buffer, + transpose_c=transpose_c, + ) + + kernels.append(kernel) + + except Exception as e: + print(f"Warning: Failed to parse {header_file}: {e}") + continue + + return kernels + + +def main(): + parser = argparse.ArgumentParser( + description="Generate dispatcher registration code" + ) + parser.add_argument( + "--generated-dir", + type=str, + required=True, + help="Directory containing generated kernel headers", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory for registration code", + ) + parser.add_argument( + "--manifest", type=str, help="Optional manifest file with kernel configurations" + ) + parser.add_argument( + "--scan", + action="store_true", + help="Scan generated headers instead of using manifest", + ) + + args = parser.parse_args() + + generated_dir = Path(args.generated_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Load kernel configurations + if args.manifest: + print(f"Loading kernels from manifest: {args.manifest}") + kernels = load_kernel_manifest(Path(args.manifest)) + elif args.scan: + print(f"Scanning generated headers in: {generated_dir}") + kernels = scan_generated_headers(generated_dir) + else: + print("Error: Must specify either --manifest or --scan") + return 1 + + print(f"Found {len(kernels)} kernels") + + # Generate registration code + registration_header = output_dir / "dispatcher_registration.hpp" + registration_cpp = output_dir / "dispatcher_registration.cpp" + + generate_registration_header(kernels, registration_header) + generate_registration_cpp(kernels, registration_cpp) + + # Generate manifest for Python + manifest_output = output_dir / "kernels_manifest.json" + manifest_data = { + "kernels": [ + { + "name": k.name, + "header_file": k.header_file, + "tile_m": k.tile_m, + "tile_n": k.tile_n, + "tile_k": k.tile_k, + "block_size": k.block_size, + "persistent": k.persistent, + } + for k in kernels + ] + } + + with open(manifest_output, "w") as f: + json.dump(manifest_data, f, indent=2) + + print(f"✓ Generated manifest: {manifest_output}") + print("\n✓ Registration code generation complete!") + print(f" Total kernels: {len(kernels)}") + print(" Output files:") + print(f" - {registration_header}") + print(f" - {registration_cpp}") + print(f" - {manifest_output}") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/dispatcher/codegen/generate_kernel_wrappers.py b/dispatcher/codegen/generate_kernel_wrappers.py new file mode 100644 index 0000000000..53a9bff3ed --- /dev/null +++ b/dispatcher/codegen/generate_kernel_wrappers.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Generate one .cpp wrapper file per kernel header for maximum parallel compilation. + +Each kernel becomes its own translation unit, enabling: + - Maximum parallelism with make -j$(nproc) + - Per-kernel build progress (e.g., [5/128] Building kernel: gemm_fp16_128x128) + - Incremental rebuilds (only changed kernels recompile) + - Fine-grained build time analysis + +Usage: + python3 generate_kernel_wrappers.py --kernel-dir build/generated_kernels --output-dir build/kernel_wrappers + +Output structure: + build/kernel_wrappers/ + ├── gemm_fp16_rcr_128x128x32.cpp + ├── gemm_fp16_rcr_256x256x64.cpp + ├── conv_fwd_fp16_2d_128x128.cpp + └── ... + +Each .cpp simply includes its corresponding .hpp and forces symbol emission. +""" + +import argparse +import sys +from pathlib import Path +from typing import List, Tuple +import concurrent.futures + + +WRAPPER_TEMPLATE_GEMM = """// SPDX-License-Identifier: MIT +// Auto-generated wrapper for: {kernel_name} +// This file enables per-kernel parallel compilation + +#include "{kernel_hpp}" + +// Force symbol emission for kernel registration +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +// Marker to prevent dead code elimination +volatile bool _{kernel_id}_registered = true; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + +WRAPPER_TEMPLATE_CONV = """// SPDX-License-Identifier: MIT +// Auto-generated wrapper for: {kernel_name} +// This file enables per-kernel parallel compilation + +#include "{kernel_hpp}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +volatile bool _{kernel_id}_registered = true; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + + +def generate_wrapper( + kernel_hpp: Path, output_dir: Path, index: int, total: int +) -> Tuple[Path, bool]: + """Generate a .cpp wrapper for a single kernel header.""" + kernel_name = kernel_hpp.stem + kernel_id = kernel_name.replace("-", "_").replace(".", "_") + + # Select template based on kernel type + if kernel_name.startswith("gemm"): + template = WRAPPER_TEMPLATE_GEMM + else: + template = WRAPPER_TEMPLATE_CONV + + content = template.format( + kernel_name=kernel_name, + kernel_hpp=kernel_hpp.name, + kernel_id=kernel_id, + ) + + output_cpp = output_dir / f"{kernel_name}.cpp" + + # Only write if content changed (for incremental builds) + if output_cpp.exists(): + existing = output_cpp.read_text() + if existing == content: + return output_cpp, False # No change + + output_cpp.write_text(content) + return output_cpp, True # Written + + +def generate_cmake_list( + wrappers: List[Path], output_dir: Path, kernel_dir: Path +) -> Path: + """Generate CMakeLists.txt that compiles each wrapper as a separate object.""" + + num_kernels = len(wrappers) + + cmake_content = f'''# SPDX-License-Identifier: MIT +# Auto-generated CMakeLists.txt for per-kernel parallel compilation +# Generated {num_kernels} kernel translation units + +cmake_minimum_required(VERSION 3.16) + +# ============================================================================= +# Per-Kernel Object Targets ({num_kernels} kernels) +# ============================================================================= +# Each kernel is compiled as a separate OBJECT library for maximum parallelism. +# Build with: make -j$(nproc) all_kernels +# +# Progress output: +# [ 1/{num_kernels}] Building kernel: gemm_fp16_rcr_128x128x32 +# [ 2/{num_kernels}] Building kernel: gemm_fp16_rcr_256x256x64 +# ... + +set(KERNEL_INCLUDE_DIR "{kernel_dir}") +set(ALL_KERNEL_OBJECTS "") + +''' + + for idx, wrapper in enumerate(wrappers, 1): + kernel_name = wrapper.stem + obj_target = f"kobj_{kernel_name}" + + cmake_content += f""" +# [{idx}/{num_kernels}] {kernel_name} +add_library({obj_target} OBJECT {wrapper.name}) +target_include_directories({obj_target} PRIVATE ${{KERNEL_INCLUDE_DIR}} ${{CK_INCLUDE_DIR}}) +target_compile_options({obj_target} PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +set_target_properties({obj_target} PROPERTIES POSITION_INDEPENDENT_CODE ON) +if(hip_FOUND) + target_link_libraries({obj_target} PRIVATE hip::device hip::host) +endif() +list(APPEND ALL_KERNEL_OBJECTS $) +""" + + cmake_content += f""" + +# ============================================================================= +# Combined Kernel Library +# ============================================================================= +# Links all {num_kernels} kernel objects into a single shared library + +add_library(all_kernels SHARED ${{ALL_KERNEL_OBJECTS}}) +if(hip_FOUND) + target_link_libraries(all_kernels PRIVATE hip::device hip::host) +endif() +set_target_properties(all_kernels PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME "dispatcher_kernels" +) + +message(STATUS "Configured {num_kernels} kernel objects for parallel compilation") +message(STATUS "Build with: make -j$(nproc) all_kernels") +""" + + cmake_file = output_dir / "CMakeLists.txt" + cmake_file.write_text(cmake_content) + return cmake_file + + +def generate_ninja_build( + wrappers: List[Path], output_dir: Path, kernel_dir: Path +) -> Path: + """Generate build.ninja for even faster parallel compilation.""" + + num_kernels = len(wrappers) + + ninja_content = f"""# SPDX-License-Identifier: MIT +# Auto-generated build.ninja for per-kernel parallel compilation +# {num_kernels} kernel translation units + +# Variables +cxx = hipcc +cxxflags = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal --offload-compress +includes = -I{kernel_dir} -I/opt/rocm/include + +# Rules +rule compile + command = $cxx $cxxflags $includes -c $in -o $out + description = [{num_kernels}] Building kernel: $kernel_name + +rule link + command = $cxx -shared $in -o $out -L/opt/rocm/lib -lamdhip64 + description = Linking: $out + +# Kernel objects +""" + + obj_files = [] + for idx, wrapper in enumerate(wrappers, 1): + kernel_name = wrapper.stem + obj_file = f"{kernel_name}.o" + obj_files.append(obj_file) + + ninja_content += f""" +build {obj_file}: compile {wrapper.name} + kernel_name = {kernel_name} +""" + + ninja_content += f""" + +# Shared library +build libdispatcher_kernels.so: link {" ".join(obj_files)} + +# Default target +default libdispatcher_kernels.so +""" + + ninja_file = output_dir / "build.ninja" + ninja_file.write_text(ninja_content) + return ninja_file + + +def generate_makefile(wrappers: List[Path], output_dir: Path, kernel_dir: Path) -> Path: + """Generate Makefile for per-kernel parallel compilation.""" + + num_kernels = len(wrappers) + kernel_names = [w.stem for w in wrappers] + obj_files = [f"{name}.o" for name in kernel_names] + + makefile_content = f"""# SPDX-License-Identifier: MIT +# Auto-generated Makefile for per-kernel parallel compilation +# {num_kernels} kernel translation units +# +# Usage: +# make -j$(nproc) # Build all kernels in parallel +# make -j$(nproc) VERBOSE=1 # With per-kernel progress +# make clean # Remove all objects + +CXX = hipcc +CXXFLAGS = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 \\ + -Wno-undefined-func-template -Wno-float-equal --offload-compress +INCLUDES = -I{kernel_dir} -I/opt/rocm/include +LDFLAGS = -shared -L/opt/rocm/lib -lamdhip64 + +TARGET = libdispatcher_kernels.so +OBJECTS = {" ".join(obj_files)} + +# Progress counter (only works with make -j1, use ninja for parallel progress) +TOTAL_KERNELS = {num_kernels} +CURRENT = 0 + +.PHONY: all clean + +all: $(TARGET) +\t@echo "Built $(TARGET) with {num_kernels} kernels" + +$(TARGET): $(OBJECTS) +\t@echo "[LINK] Linking {num_kernels} kernel objects -> $@" +\t$(CXX) $(LDFLAGS) $^ -o $@ + +""" + + for idx, (wrapper, obj) in enumerate(zip(wrappers, obj_files), 1): + kernel_name = wrapper.stem + makefile_content += f""" +{obj}: {wrapper.name} +\t@echo "[{idx}/{num_kernels}] Building kernel: {kernel_name}" +\t$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ +""" + + makefile_content += f""" + +clean: +\trm -f $(OBJECTS) $(TARGET) +\t@echo "Cleaned {num_kernels} kernel objects" +""" + + makefile = output_dir / "Makefile" + makefile.write_text(makefile_content) + return makefile + + +def main(): + parser = argparse.ArgumentParser( + description="Generate per-kernel wrapper .cpp files for parallel compilation" + ) + parser.add_argument( + "--kernel-dir", + type=Path, + required=True, + help="Directory containing generated kernel .hpp files", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for wrapper .cpp files", + ) + parser.add_argument( + "--pattern", + type=str, + default="*.hpp", + help="Glob pattern for kernel headers (default: *.hpp)", + ) + parser.add_argument( + "--generate-cmake", + action="store_true", + help="Generate CMakeLists.txt for the wrappers", + ) + parser.add_argument( + "--generate-ninja", + action="store_true", + help="Generate build.ninja for ninja builds", + ) + parser.add_argument( + "--generate-makefile", + action="store_true", + help="Generate Makefile for make builds", + ) + parser.add_argument( + "--parallel", + action="store_true", + default=True, + help="Generate wrappers in parallel (default: True)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Verbose output", + ) + + args = parser.parse_args() + + # Find kernel headers + kernel_dir = args.kernel_dir.resolve() + if not kernel_dir.exists(): + print(f"Error: Kernel directory not found: {kernel_dir}", file=sys.stderr) + return 1 + + kernel_headers = sorted(kernel_dir.glob(args.pattern)) + if not kernel_headers: + print( + f"Error: No kernel headers found matching {args.pattern} in {kernel_dir}", + file=sys.stderr, + ) + return 1 + + num_kernels = len(kernel_headers) + print(f"Found {num_kernels} kernel headers in {kernel_dir}") + + # Create output directory + output_dir = args.output_dir.resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate wrappers + print(f"Generating {num_kernels} wrapper .cpp files...") + + wrappers = [] + written = 0 + + if args.parallel and num_kernels > 1: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = { + executor.submit( + generate_wrapper, hpp, output_dir, idx, num_kernels + ): hpp + for idx, hpp in enumerate(kernel_headers, 1) + } + for future in concurrent.futures.as_completed(futures): + wrapper_path, was_written = future.result() + wrappers.append(wrapper_path) + if was_written: + written += 1 + if args.verbose: + print(f" Generated: {wrapper_path.name}") + else: + for idx, hpp in enumerate(kernel_headers, 1): + wrapper_path, was_written = generate_wrapper( + hpp, output_dir, idx, num_kernels + ) + wrappers.append(wrapper_path) + if was_written: + written += 1 + if args.verbose: + print(f" [{idx}/{num_kernels}] Generated: {wrapper_path.name}") + + wrappers.sort(key=lambda p: p.name) + + print( + f" Total: {num_kernels} wrappers ({written} written, {num_kernels - written} unchanged)" + ) + + # Generate build files + if args.generate_cmake: + cmake_file = generate_cmake_list(wrappers, output_dir, kernel_dir) + print(f" Generated: {cmake_file}") + + if args.generate_ninja: + ninja_file = generate_ninja_build(wrappers, output_dir, kernel_dir) + print(f" Generated: {ninja_file}") + + if args.generate_makefile: + makefile = generate_makefile(wrappers, output_dir, kernel_dir) + print(f" Generated: {makefile}") + + print(f"\nOutput directory: {output_dir}") + print(f"Kernels ready for parallel compilation: {num_kernels}") + print("\nTo build:") + print(f" cd {output_dir}") + if args.generate_makefile: + print(" make -j$(nproc) # Parallel build with progress") + if args.generate_ninja: + print(" ninja # Fast parallel build") + if args.generate_cmake: + print(" cmake -B build && cmake --build build -j$(nproc)") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/codegen/kernel_config_loader.py b/dispatcher/codegen/kernel_config_loader.py new file mode 100644 index 0000000000..537fc40581 --- /dev/null +++ b/dispatcher/codegen/kernel_config_loader.py @@ -0,0 +1,798 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Kernel Configuration Loader + +Load kernel configurations from JSON files for generating specific kernel sets. +Compatible with tile_engine JSON format. + +Usage: + from kernel_config_loader import load_kernel_configs, KernelConfigSet + + # Load configs from JSON + config_set = load_kernel_configs("my_kernels.json") + + # Get all configurations (cartesian product of all parameter values) + for config in config_set.generate_configs(): + print(config) + + # Use with codegen + from unified_gemm_codegen import UnifiedGemmCodegen + codegen = UnifiedGemmCodegen(...) + codegen.generate_from_configs(config_set.generate_configs()) +""" + +import json +import itertools +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Dict, Any, Optional, Iterator + + +@dataclass +class TileConfig: + """Tile configuration for a kernel""" + + tile_m: int = 128 + tile_n: int = 128 + tile_k: int = 32 + warp_m: int = 2 + warp_n: int = 2 + warp_k: int = 1 + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + +@dataclass +class TraitConfig: + """Trait configuration for a kernel (order matches GEMM/Conv TraitConfig)""" + + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + pad_m: bool = False + pad_n: bool = False + pad_k: bool = False + + +@dataclass +class KernelConfig: + """Complete kernel configuration""" + + tile: TileConfig = field(default_factory=TileConfig) + trait: TraitConfig = field(default_factory=TraitConfig) + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout: str = "rcr" + gpu_target: str = "gfx942" + variant: str = "standard" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for codegen""" + return { + "tile_m": self.tile.tile_m, + "tile_n": self.tile.tile_n, + "tile_k": self.tile.tile_k, + "warp_m": self.tile.warp_m, + "warp_n": self.tile.warp_n, + "warp_k": self.tile.warp_k, + "warp_tile_m": self.tile.warp_tile_m, + "warp_tile_n": self.tile.warp_tile_n, + "warp_tile_k": self.tile.warp_tile_k, + "pipeline": self.trait.pipeline, + "scheduler": self.trait.scheduler, + "epilogue": self.trait.epilogue, + "pad_m": self.trait.pad_m, + "pad_n": self.trait.pad_n, + "pad_k": self.trait.pad_k, + "dtype_a": self.dtype_a, + "dtype_b": self.dtype_b, + "dtype_c": self.dtype_c, + "dtype_acc": self.dtype_acc, + "layout": self.layout, + "gpu_target": self.gpu_target, + "variant": self.variant, + } + + def kernel_name(self) -> str: + """Generate kernel name from config""" + name = f"gemm_{self.dtype_a}_{self.layout}_{self.trait.pipeline}" + name += f"_{self.trait.epilogue}_{self.trait.scheduler}" + name += f"_{str(self.trait.pad_m).capitalize()}" + name += f"_{str(self.trait.pad_n).capitalize()}" + name += f"_{str(self.trait.pad_k).capitalize()}" + name += "_False" # preshuffle + name += f"_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}" + name += f"_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}" + name += ( + f"_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + return name + + +@dataclass +class KernelConfigSet: + """A set of kernel configurations loaded from JSON""" + + name: str = "default" + configs: List[KernelConfig] = field(default_factory=list) + + # Parameter ranges for generation + tile_m_values: List[int] = field(default_factory=lambda: [128]) + tile_n_values: List[int] = field(default_factory=lambda: [128]) + tile_k_values: List[int] = field(default_factory=lambda: [32]) + warp_m_values: List[int] = field(default_factory=lambda: [2]) + warp_n_values: List[int] = field(default_factory=lambda: [2]) + warp_k_values: List[int] = field(default_factory=lambda: [1]) + warp_tile_m_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_n_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_k_values: List[int] = field(default_factory=lambda: [16]) + + pipeline_values: List[str] = field(default_factory=lambda: ["compv4"]) + scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"]) + epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"]) + pad_m_values: List[bool] = field(default_factory=lambda: [False]) + pad_n_values: List[bool] = field(default_factory=lambda: [False]) + pad_k_values: List[bool] = field(default_factory=lambda: [False]) + + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout: str = "rcr" + gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) + variant: str = "standard" + + def generate_configs(self) -> Iterator[KernelConfig]: + """Generate all kernel configurations (cartesian product)""" + # Tile parameters + tile_params = itertools.product( + self.tile_m_values, + self.tile_n_values, + self.tile_k_values, + self.warp_m_values, + self.warp_n_values, + self.warp_k_values, + self.warp_tile_m_values, + self.warp_tile_n_values, + self.warp_tile_k_values, + ) + + # Trait parameters + trait_params = itertools.product( + self.pipeline_values, + self.scheduler_values, + self.epilogue_values, + self.pad_m_values, + self.pad_n_values, + self.pad_k_values, + ) + + # Convert to lists for reuse + tile_list = list(tile_params) + trait_list = list(trait_params) + + # Generate for each GPU target + for gpu_target in self.gpu_targets: + for tile in tile_list: + for trait in trait_list: + tile_cfg = TileConfig( + tile_m=tile[0], + tile_n=tile[1], + tile_k=tile[2], + warp_m=tile[3], + warp_n=tile[4], + warp_k=tile[5], + warp_tile_m=tile[6], + warp_tile_n=tile[7], + warp_tile_k=tile[8], + ) + trait_cfg = TraitConfig( + pipeline=trait[0], + scheduler=trait[1], + epilogue=trait[2], + pad_m=trait[3], + pad_n=trait[4], + pad_k=trait[5], + ) + yield KernelConfig( + tile=tile_cfg, + trait=trait_cfg, + dtype_a=self.dtype_a, + dtype_b=self.dtype_b, + dtype_c=self.dtype_c, + dtype_acc=self.dtype_acc, + layout=self.layout, + gpu_target=gpu_target, + variant=self.variant, + ) + + def config_count(self) -> int: + """Get total number of configurations""" + tile_count = ( + len(self.tile_m_values) + * len(self.tile_n_values) + * len(self.tile_k_values) + * len(self.warp_m_values) + * len(self.warp_n_values) + * len(self.warp_k_values) + * len(self.warp_tile_m_values) + * len(self.warp_tile_n_values) + * len(self.warp_tile_k_values) + ) + trait_count = ( + len(self.pipeline_values) + * len(self.scheduler_values) + * len(self.epilogue_values) + * len(self.pad_m_values) + * len(self.pad_n_values) + * len(self.pad_k_values) + ) + return tile_count * trait_count * len(self.gpu_targets) + + +def _get_values(config: Dict, key: str, default: List) -> List: + """Extract values from config dict, handling range specifications""" + if key not in config: + return default + + item = config[key] + + # Explicit values list + if "values" in item: + return item["values"] + + # Range specification (min, max, step) + if "min" in item and "max" in item: + min_val = item["min"] + max_val = item["max"] + step = item.get("step", 1) + return list(range(min_val, max_val + 1, step)) + + return default + + +def load_kernel_configs(json_path: str | Path) -> KernelConfigSet: + """ + Load kernel configurations from a JSON file. + + Supports both tile_engine format and dispatcher format. + + Args: + json_path: Path to JSON configuration file + + Returns: + KernelConfigSet with all parameter values loaded + """ + json_path = Path(json_path) + + with open(json_path) as f: + data = json.load(f) + + config_set = KernelConfigSet() + + # Name + config_set.name = data.get("kernel_set_name", json_path.stem) + + # Data types + if "datatype" in data: + dt = data["datatype"] + config_set.dtype_a = dt.get("a", "fp16") + config_set.dtype_b = dt.get("b", "fp16") + config_set.dtype_c = dt.get("c", "fp16") + config_set.dtype_acc = dt.get("acc", "fp32") + + # Layout + config_set.layout = data.get("layout", "rcr") + + # GPU targets + if "gpu_targets" in data: + config_set.gpu_targets = data["gpu_targets"] + elif "gpu_target" in data: + config_set.gpu_targets = [data["gpu_target"]] + + # Variant + config_set.variant = data.get("variant", "standard") + + # Tile config + tile_cfg = data.get("tile_config", {}) + config_set.tile_m_values = _get_values(tile_cfg, "tile_m", [128]) + config_set.tile_n_values = _get_values(tile_cfg, "tile_n", [128]) + config_set.tile_k_values = _get_values(tile_cfg, "tile_k", [32]) + config_set.warp_m_values = _get_values(tile_cfg, "warp_m", [2]) + config_set.warp_n_values = _get_values(tile_cfg, "warp_n", [2]) + config_set.warp_k_values = _get_values(tile_cfg, "warp_k", [1]) + config_set.warp_tile_m_values = _get_values(tile_cfg, "warp_tile_m", [32]) + config_set.warp_tile_n_values = _get_values(tile_cfg, "warp_tile_n", [32]) + config_set.warp_tile_k_values = _get_values(tile_cfg, "warp_tile_k", [16]) + + # Trait config + trait_cfg = data.get("trait_config", {}) + config_set.pipeline_values = _get_values(trait_cfg, "pipeline", ["compv4"]) + config_set.scheduler_values = _get_values(trait_cfg, "scheduler", ["intrawave"]) + config_set.epilogue_values = _get_values(trait_cfg, "epilogue", ["cshuffle"]) + config_set.pad_m_values = _get_values(trait_cfg, "pad_m", [False]) + config_set.pad_n_values = _get_values(trait_cfg, "pad_n", [False]) + config_set.pad_k_values = _get_values(trait_cfg, "pad_k", [False]) + + return config_set + + +# ============================================================================= +# Convolution Configuration Classes +# ============================================================================= + + +@dataclass +class ConvTileConfig: + """Tile configuration for a convolution kernel""" + + tile_m: int = 128 # M dimension (N * spatial_out for fwd) + tile_n: int = 128 # N dimension (K output channels for fwd) + tile_k: int = 32 # K dimension (C * filter for fwd) + warp_m: int = 2 + warp_n: int = 2 + warp_k: int = 1 + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + +@dataclass +class ConvTraitConfig: + """Trait configuration for a convolution kernel""" + + pipeline: str = "compv3" + scheduler: str = "intrawave" + epilogue: str = "cshuffle" + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + double_smem_buffer: bool = False + num_groups_to_merge: int = 1 + + +@dataclass +class ConvKernelConfig: + """Complete convolution kernel configuration""" + + tile: ConvTileConfig = field(default_factory=ConvTileConfig) + trait: ConvTraitConfig = field(default_factory=ConvTraitConfig) + dtype_input: str = "fp16" + dtype_weight: str = "fp16" + dtype_output: str = "fp16" + dtype_acc: str = "fp32" + variant: str = "forward" # forward, bwd_data, bwd_weight + ndim: int = 2 # 1, 2, or 3 + layout: str = "nhwgc" + gpu_target: str = "gfx942" + + # Vector sizes + vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + + # Occupancy + block_per_cu: int = 1 + num_wave_groups: int = 1 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for codegen""" + return { + "tile_m": self.tile.tile_m, + "tile_n": self.tile.tile_n, + "tile_k": self.tile.tile_k, + "warp_m": self.tile.warp_m, + "warp_n": self.tile.warp_n, + "warp_k": self.tile.warp_k, + "warp_tile_m": self.tile.warp_tile_m, + "warp_tile_n": self.tile.warp_tile_n, + "warp_tile_k": self.tile.warp_tile_k, + "pipeline": self.trait.pipeline, + "scheduler": self.trait.scheduler, + "epilogue": self.trait.epilogue, + "pad_m": self.trait.pad_m, + "pad_n": self.trait.pad_n, + "pad_k": self.trait.pad_k, + "double_smem_buffer": self.trait.double_smem_buffer, + "num_groups_to_merge": self.trait.num_groups_to_merge, + "dtype_input": self.dtype_input, + "dtype_weight": self.dtype_weight, + "dtype_output": self.dtype_output, + "dtype_acc": self.dtype_acc, + "variant": self.variant, + "ndim": self.ndim, + "layout": self.layout, + "gpu_target": self.gpu_target, + "vector_size_a": self.vector_size_a, + "vector_size_b": self.vector_size_b, + "vector_size_c": self.vector_size_c, + "block_per_cu": self.block_per_cu, + "num_wave_groups": self.num_wave_groups, + } + + def kernel_name(self) -> str: + """Generate kernel name from config""" + variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"} + var_str = variant_map.get(self.variant, self.variant) + + name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d" + name += f"_{self.trait.pipeline}_{self.trait.epilogue}_{self.trait.scheduler}" + name += f"_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}" + name += f"_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}" + name += ( + f"_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + return name + + +@dataclass +class ConvKernelConfigSet: + """A set of convolution kernel configurations loaded from JSON""" + + name: str = "default" + configs: List[ConvKernelConfig] = field(default_factory=list) + + # Tile parameter ranges + tile_m_values: List[int] = field(default_factory=lambda: [128]) + tile_n_values: List[int] = field(default_factory=lambda: [128]) + tile_k_values: List[int] = field(default_factory=lambda: [32]) + warp_m_values: List[int] = field(default_factory=lambda: [2]) + warp_n_values: List[int] = field(default_factory=lambda: [2]) + warp_k_values: List[int] = field(default_factory=lambda: [1]) + warp_tile_m_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_n_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_k_values: List[int] = field(default_factory=lambda: [16]) + + # Trait parameter ranges + pipeline_values: List[str] = field(default_factory=lambda: ["compv3"]) + scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"]) + epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"]) + pad_m_values: List[bool] = field(default_factory=lambda: [True]) + pad_n_values: List[bool] = field(default_factory=lambda: [True]) + pad_k_values: List[bool] = field(default_factory=lambda: [True]) + double_smem_buffer_values: List[bool] = field(default_factory=lambda: [False]) + num_groups_to_merge_values: List[int] = field(default_factory=lambda: [1]) + + # Vector sizes + vector_size_a_values: List[int] = field(default_factory=lambda: [4]) + vector_size_b_values: List[int] = field(default_factory=lambda: [8]) + vector_size_c_values: List[int] = field(default_factory=lambda: [8]) + + # Occupancy + block_per_cu_values: List[int] = field(default_factory=lambda: [1]) + num_wave_groups_values: List[int] = field(default_factory=lambda: [1]) + + # Data types + dtype_input: str = "fp16" + dtype_weight: str = "fp16" + dtype_output: str = "fp16" + dtype_acc: str = "fp32" + + # Conv specific + variant: str = "forward" + ndim: int = 2 + layout: str = "nhwgc" + gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) + + def generate_configs(self) -> Iterator[ConvKernelConfig]: + """Generate all kernel configurations (cartesian product)""" + # Tile parameters + tile_params = itertools.product( + self.tile_m_values, + self.tile_n_values, + self.tile_k_values, + self.warp_m_values, + self.warp_n_values, + self.warp_k_values, + self.warp_tile_m_values, + self.warp_tile_n_values, + self.warp_tile_k_values, + ) + + # Trait parameters + trait_params = itertools.product( + self.pipeline_values, + self.scheduler_values, + self.epilogue_values, + self.pad_m_values, + self.pad_n_values, + self.pad_k_values, + self.double_smem_buffer_values, + self.num_groups_to_merge_values, + ) + + # Vector/occupancy parameters + extra_params = itertools.product( + self.vector_size_a_values, + self.vector_size_b_values, + self.vector_size_c_values, + self.block_per_cu_values, + self.num_wave_groups_values, + ) + + # Convert to lists for reuse + tile_list = list(tile_params) + trait_list = list(trait_params) + extra_list = list(extra_params) + + # Generate for each GPU target + for gpu_target in self.gpu_targets: + for tile in tile_list: + for trait in trait_list: + for extra in extra_list: + tile_cfg = ConvTileConfig( + tile_m=tile[0], + tile_n=tile[1], + tile_k=tile[2], + warp_m=tile[3], + warp_n=tile[4], + warp_k=tile[5], + warp_tile_m=tile[6], + warp_tile_n=tile[7], + warp_tile_k=tile[8], + ) + trait_cfg = ConvTraitConfig( + pipeline=trait[0], + scheduler=trait[1], + epilogue=trait[2], + pad_m=trait[3], + pad_n=trait[4], + pad_k=trait[5], + double_smem_buffer=trait[6], + num_groups_to_merge=trait[7], + ) + yield ConvKernelConfig( + tile=tile_cfg, + trait=trait_cfg, + dtype_input=self.dtype_input, + dtype_weight=self.dtype_weight, + dtype_output=self.dtype_output, + dtype_acc=self.dtype_acc, + variant=self.variant, + ndim=self.ndim, + layout=self.layout, + gpu_target=gpu_target, + vector_size_a=extra[0], + vector_size_b=extra[1], + vector_size_c=extra[2], + block_per_cu=extra[3], + num_wave_groups=extra[4], + ) + + def config_count(self) -> int: + """Get total number of configurations""" + tile_count = ( + len(self.tile_m_values) + * len(self.tile_n_values) + * len(self.tile_k_values) + * len(self.warp_m_values) + * len(self.warp_n_values) + * len(self.warp_k_values) + * len(self.warp_tile_m_values) + * len(self.warp_tile_n_values) + * len(self.warp_tile_k_values) + ) + trait_count = ( + len(self.pipeline_values) + * len(self.scheduler_values) + * len(self.epilogue_values) + * len(self.pad_m_values) + * len(self.pad_n_values) + * len(self.pad_k_values) + * len(self.double_smem_buffer_values) + * len(self.num_groups_to_merge_values) + ) + extra_count = ( + len(self.vector_size_a_values) + * len(self.vector_size_b_values) + * len(self.vector_size_c_values) + * len(self.block_per_cu_values) + * len(self.num_wave_groups_values) + ) + return tile_count * trait_count * extra_count * len(self.gpu_targets) + + +def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: + """ + Load convolution kernel configurations from a JSON file. + + Args: + json_path: Path to JSON configuration file + + Returns: + ConvKernelConfigSet with all parameter values loaded + """ + json_path = Path(json_path) + + with open(json_path) as f: + data = json.load(f) + + config_set = ConvKernelConfigSet() + + # Name + config_set.name = data.get("kernel_set_name", json_path.stem) + + # Data types + if "datatype" in data: + dt = data["datatype"] + config_set.dtype_input = dt.get("input", "fp16") + config_set.dtype_weight = dt.get("weight", "fp16") + config_set.dtype_output = dt.get("output", "fp16") + config_set.dtype_acc = dt.get("acc", "fp32") + + # Conv specific + config_set.variant = data.get("variant", "forward") + config_set.ndim = data.get("ndim", 2) + config_set.layout = data.get("layout", "nhwgc") + + # GPU targets + if "gpu_targets" in data: + config_set.gpu_targets = data["gpu_targets"] + elif "gpu_target" in data: + config_set.gpu_targets = [data["gpu_target"]] + + # Tile config + tile_cfg = data.get("tile_config", {}) + config_set.tile_m_values = _get_values(tile_cfg, "tile_m", [128]) + config_set.tile_n_values = _get_values(tile_cfg, "tile_n", [128]) + config_set.tile_k_values = _get_values(tile_cfg, "tile_k", [32]) + config_set.warp_m_values = _get_values(tile_cfg, "warp_m", [2]) + config_set.warp_n_values = _get_values(tile_cfg, "warp_n", [2]) + config_set.warp_k_values = _get_values(tile_cfg, "warp_k", [1]) + config_set.warp_tile_m_values = _get_values(tile_cfg, "warp_tile_m", [32]) + config_set.warp_tile_n_values = _get_values(tile_cfg, "warp_tile_n", [32]) + config_set.warp_tile_k_values = _get_values(tile_cfg, "warp_tile_k", [16]) + + # Trait config + trait_cfg = data.get("trait_config", {}) + config_set.pipeline_values = _get_values(trait_cfg, "pipeline", ["compv3"]) + config_set.scheduler_values = _get_values(trait_cfg, "scheduler", ["intrawave"]) + config_set.epilogue_values = _get_values(trait_cfg, "epilogue", ["cshuffle"]) + config_set.pad_m_values = _get_values(trait_cfg, "pad_m", [True]) + config_set.pad_n_values = _get_values(trait_cfg, "pad_n", [True]) + config_set.pad_k_values = _get_values(trait_cfg, "pad_k", [True]) + config_set.double_smem_buffer_values = _get_values( + trait_cfg, "double_smem_buffer", [False] + ) + config_set.num_groups_to_merge_values = _get_values( + trait_cfg, "num_groups_to_merge", [1] + ) + + # Vector config + vec_cfg = data.get("vector_config", {}) + config_set.vector_size_a_values = _get_values(vec_cfg, "vector_size_a", [4]) + config_set.vector_size_b_values = _get_values(vec_cfg, "vector_size_b", [8]) + config_set.vector_size_c_values = _get_values(vec_cfg, "vector_size_c", [8]) + + # Occupancy config + occ_cfg = data.get("occupancy_config", {}) + config_set.block_per_cu_values = _get_values(occ_cfg, "block_per_cu", [1]) + config_set.num_wave_groups_values = _get_values(occ_cfg, "num_wave_groups", [1]) + + return config_set + + +def generate_cpp_conv_kernel_set_declaration( + config_set: ConvKernelConfigSet, + set_name: Optional[str] = None, +) -> str: + """ + Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet. + """ + name = set_name or config_set.name + + lines = [f"DECL_CONV_KERNEL_SET({name},"] + + for config in config_set.generate_configs(): + line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, ' + line += f"{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k})" + lines.append(line) + + lines.append(");") + + return "\n".join(lines) + + +# ============================================================================= +# GEMM Configuration Export Functions +# ============================================================================= + + +def generate_cpp_kernel_set_declaration( + config_set: KernelConfigSet, + set_name: Optional[str] = None, +) -> str: + """ + Generate C++ DECL_KERNEL_SET code from a KernelConfigSet. + + Args: + config_set: The kernel configuration set + set_name: Optional name override for the kernel set + + Returns: + C++ code string with DECL_KERNEL_SET declaration + """ + name = set_name or config_set.name + + lines = [f"DECL_KERNEL_SET({name},"] + + for config in config_set.generate_configs(): + # Generate .add() call for each config + line = f' .add("{config.dtype_a}", "{config.layout}", ' + line += f"{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k})" + lines.append(line) + + lines.append(");") + + return "\n".join(lines) + + +# CLI for testing +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python kernel_config_loader.py ") + print("\nLoads kernel configurations from JSON and prints summary.") + sys.exit(1) + + json_path = sys.argv[1] + + try: + config_set = load_kernel_configs(json_path) + + print(f"Kernel Set: {config_set.name}") + print( + f"Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}, Acc={config_set.dtype_acc}" + ) + print(f"Layout: {config_set.layout}") + print(f"GPU Targets: {config_set.gpu_targets}") + print(f"Variant: {config_set.variant}") + print() + print("Tile Configurations:") + print(f" tile_m: {config_set.tile_m_values}") + print(f" tile_n: {config_set.tile_n_values}") + print(f" tile_k: {config_set.tile_k_values}") + print(f" warp_m: {config_set.warp_m_values}") + print(f" warp_n: {config_set.warp_n_values}") + print(f" warp_k: {config_set.warp_k_values}") + print( + f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}" + ) + print() + print("Trait Configurations:") + print(f" pipeline: {config_set.pipeline_values}") + print(f" scheduler: {config_set.scheduler_values}") + print(f" epilogue: {config_set.epilogue_values}") + print( + f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}" + ) + print() + print(f"Total configurations: {config_set.config_count()}") + print() + + # Print first few config names + print("Sample kernel names:") + for i, config in enumerate(config_set.generate_configs()): + if i >= 5: + print(f" ... and {config_set.config_count() - 5} more") + break + print(f" {config.kernel_name()}") + print() + + # Generate C++ code + if "--cpp" in sys.argv: + print("C++ Declaration:") + print("-" * 60) + print(generate_cpp_kernel_set_declaration(config_set)) + + except Exception as e: + print(f"Error: {e}") + sys.exit(1) diff --git a/dispatcher/codegen/preselected_kernels.py b/dispatcher/codegen/preselected_kernels.py new file mode 100644 index 0000000000..010d930639 --- /dev/null +++ b/dispatcher/codegen/preselected_kernels.py @@ -0,0 +1,518 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Preselected, Benchmarked Kernel Configurations + +Curated kernel sets optimized for different workload characteristics: +- Compute-friendly: Large tiles, high arithmetic intensity +- Memory-friendly: Smaller tiles, better memory access patterns +- Latency-friendly: Minimal tiles, low latency for small problems +""" + +from functools import partial, lru_cache +from typing import List +from unified_gemm_codegen import KernelConfig, TileConfig, TraitConfig, GemmVariant + + +# ============================================================================ +# Base Configurations +# ============================================================================ + + +def _base_fp16_rcr_compute() -> partial: + """Base configuration for compute-intensive FP16 RCR kernels""" + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +def _base_fp16_rcr_memory() -> partial: + """Base configuration for memory-intensive FP16 RCR kernels""" + # Note: Use 'mem' pipeline for interwave scheduler (compv3/compv4/compv5/compv6 only support intrawave) + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="mem", + epilogue="cshuffle", + scheduler="interwave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=128, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +def _base_fp16_rcr_latency() -> partial: + """Base configuration for latency-sensitive FP16 RCR kernels""" + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="mem", + epilogue="default", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=128, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +# ============================================================================ +# Preselected FP16 RCR Kernels +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_compute() -> List[KernelConfig]: + """ + Compute-friendly FP16 RCR kernels + + Optimized for: + - Large M, N dimensions (>= 128) + - High arithmetic intensity + - Good occupancy + - Maximum throughput + """ + base = _base_fp16_rcr_compute() + + return [ + # Large tiles for maximum compute + base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(256, 128, 32, 4, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 256, 32, 2, 4, 1, 32, 32, 16)), + # Balanced tiles + base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + # With persistent kernel for large batches + base( + tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16), + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=False, + pad_n=False, + pad_k=False, + persistent=True, + ), + ), + ] + + +@lru_cache(None) +def preselected_fp16_rcr_memory() -> List[KernelConfig]: + """ + Memory-friendly FP16 RCR kernels + + Optimized for: + - Small to medium M, N dimensions + - Memory-bound workloads + - Better cache utilization + - Lower register pressure + """ + base = _base_fp16_rcr_memory() + + return [ + # Small tiles for memory efficiency + base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(16, 64, 32, 1, 2, 1, 16, 16, 16)), + base(tile=TileConfig(64, 16, 32, 2, 1, 1, 16, 16, 16)), + # Medium tiles + base(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)), + base(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)), + base(tile=TileConfig(32, 128, 32, 1, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 32, 32, 2, 1, 1, 32, 32, 16)), + ] + + +@lru_cache(None) +def preselected_fp16_rcr_latency() -> List[KernelConfig]: + """ + Latency-friendly FP16 RCR kernels + + Optimized for: + - Very small M, N dimensions (< 64) + - Minimal launch overhead + - Low latency + - Quick execution + """ + base = _base_fp16_rcr_latency() + + return [ + # Minimal tiles for low latency + base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)), + ] + + +# ============================================================================ +# Preselected Multi-D Kernels +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_multi_d() -> List[KernelConfig]: + """ + Multi-D GEMM kernels with element-wise fusion + + Common fusions: + - MultiDAdd: E = C + D0 + D1 + - Relu: E = max(C, 0) + - Gelu: E = gelu(C) + """ + base = _base_fp16_rcr_compute() + + configs = [] + + # Best-performing tile for fused operations + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + + # Common element-wise operations + for ew_op in ["MultiDAdd", "Relu", "Gelu", "FastGelu"]: + for num_d in [1, 2]: + configs.append( + base( + tile=tile, + variant=GemmVariant.MULTI_D, + elementwise_op=ew_op, + num_d_tensors=num_d, + ) + ) + + return configs + + +@lru_cache(None) +def preselected_fp16_rcr_preshuffle() -> List[KernelConfig]: + """ + Preshuffle GEMM kernels for weight optimization + + Best for: + - Repeated use of same weights + - Inference workloads + - Batch size > 1 + """ + base = _base_fp16_rcr_compute() + + return [ + base( + tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16), + variant=GemmVariant.PRESHUFFLE, + preshuffle=True, + ), + base( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.PRESHUFFLE, + preshuffle=True, + ), + ] + + +# ============================================================================ +# Unified Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_all() -> List[KernelConfig]: + """All preselected FP16 RCR kernels""" + return ( + preselected_fp16_rcr_compute() + + preselected_fp16_rcr_memory() + + preselected_fp16_rcr_latency() + + preselected_fp16_rcr_multi_d() + + preselected_fp16_rcr_preshuffle() + ) + + +@lru_cache(None) +def preselected_fp16_rcr_essential() -> List[KernelConfig]: + """ + Essential FP16 RCR kernels - minimal set for most workloads + + Covers: + - 90% of common GEMM sizes + - Key fusion operations + - Balanced performance + """ + base_compute = _base_fp16_rcr_compute() + base_memory = _base_fp16_rcr_memory() + + return [ + # Top compute kernels + base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + # Top memory kernels + base_memory(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)), + base_memory(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)), + # Essential fusions + base_compute( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.MULTI_D, + elementwise_op="Relu", + num_d_tensors=1, + ), + base_compute( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.MULTI_D, + elementwise_op="Gelu", + num_d_tensors=1, + ), + ] + + +# ============================================================================ +# Default Fallback +# ============================================================================ + + +def default_kernel() -> KernelConfig: + """ + Default fallback kernel - guaranteed to work + + Known-good configuration tested on gfx942 + """ + return KernelConfig( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +# ============================================================================ +# BF16 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_bf16_rcr_essential() -> List[KernelConfig]: + """Essential BF16 RCR kernels""" + base_compute = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# INT8 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_int8_rcr_essential() -> List[KernelConfig]: + """Essential INT8 RCR kernels for quantized inference""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# FP8 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_fp8_rcr_essential() -> List[KernelConfig]: + """Essential FP8 RCR kernels for AI training""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# Mixed Precision Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_mixed_precision() -> List[KernelConfig]: + """Mixed-precision kernels (FP16 inputs, FP32 output)""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# Registry +# ============================================================================ + +PRESELECTED_SETS = { + # FP16 sets + "fp16_rcr_compute": preselected_fp16_rcr_compute, + "fp16_rcr_memory": preselected_fp16_rcr_memory, + "fp16_rcr_latency": preselected_fp16_rcr_latency, + "fp16_rcr_multi_d": preselected_fp16_rcr_multi_d, + "fp16_rcr_preshuffle": preselected_fp16_rcr_preshuffle, + "fp16_rcr_all": preselected_fp16_rcr_all, + "fp16_rcr_essential": preselected_fp16_rcr_essential, + # BF16 sets + "bf16_rcr_essential": preselected_bf16_rcr_essential, + # INT8 sets + "int8_rcr_essential": preselected_int8_rcr_essential, + # FP8 sets + "fp8_rcr_essential": preselected_fp8_rcr_essential, + # Mixed precision + "mixed_precision": preselected_mixed_precision, +} + + +def get_preselected_set(name: str) -> List[KernelConfig]: + """Get a preselected kernel set by name""" + if name not in PRESELECTED_SETS: + raise ValueError( + f"Unknown preselected set: {name}. Available: {list(PRESELECTED_SETS.keys())}" + ) + return PRESELECTED_SETS[name]() + + +def list_preselected_sets() -> List[str]: + """List all available preselected sets""" + return list(PRESELECTED_SETS.keys()) + + +# ============================================================================ +# CLI for testing +# ============================================================================ + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="List preselected kernel configurations" + ) + parser.add_argument( + "--set", + type=str, + default="fp16_rcr_essential", + choices=list_preselected_sets(), + help="Preselected set to display", + ) + parser.add_argument("--count-only", action="store_true", help="Only show count") + + args = parser.parse_args() + + configs = get_preselected_set(args.set) + + if args.count_only: + print(f"{args.set}: {len(configs)} kernels") + else: + print(f"Preselected set: {args.set}") + print(f"Total kernels: {len(configs)}\n") + for i, cfg in enumerate(configs, 1): + print(f"{i}. {cfg.variant.value}") + print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}") + print(f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}") + if cfg.variant == GemmVariant.MULTI_D: + print( + f" Element-wise: {cfg.elementwise_op}, D tensors: {cfg.num_d_tensors}" + ) + print() diff --git a/dispatcher/codegen/unified_gemm_codegen.py b/dispatcher/codegen/unified_gemm_codegen.py new file mode 100755 index 0000000000..b0dd961be7 --- /dev/null +++ b/dispatcher/codegen/unified_gemm_codegen.py @@ -0,0 +1,1713 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Unified GEMM Code Generator - Single Source of Truth + +This is THE unified code generator for all GEMM kernel variants: +- Standard GEMM (C = A × B) +- Preshuffle GEMM (optimized weight access) +- Multi-D GEMM (element-wise fusion) + +Generates both CK Tile kernels AND dispatcher wrappers in one pass. +Replaces all tile_engine GEMM codegen. +""" + +import json +import argparse +import itertools +import logging +from pathlib import Path +from typing import Dict, List, Tuple, Optional +from dataclasses import dataclass, asdict +from enum import Enum +import concurrent.futures + +# Import architecture filter for GPU-specific validation +try: + from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType + + HAS_ARCH_FILTER = True +except ImportError: + HAS_ARCH_FILTER = False + ArchFilter = None + ArchKernelConfig = None + OperatorType = None + + +# ============================================================================= +# Preshuffle Validation (copied from tile_engine/ops/commons/gemm_validation_utils.py) +# ============================================================================= + +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, +} + + +def _validate_preshuffle_vector_load( + warp_tile_m: int, + warp_tile_k: int, + datatype: str, + m_iter_per_warp: float, + wave_size: int = 64, + vector_load_size: int = 16, +) -> bool: + """ + Validate vector load alignment for preshuffle pipeline. + + Checks: (warp_tile_m * warp_tile_k * elem_size * m_iter_per_warp / wave_size) % vector_load_size == 0 + """ + elem_size = ELEMENT_SIZE_MAP.get(datatype, 2) + access_size = (warp_tile_m * warp_tile_k * elem_size * m_iter_per_warp) / wave_size + return access_size % vector_load_size == 0 + + +def _validate_preshuffle_m0_m1_m2( + tile_m: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + datatype: str, + vector_load_size: int = 16, + warp_size: int = 64, +) -> bool: + """ + Validate M0, M1, M2 configuration for preshuffle matrix A row-major layout. + Ensures proper memory access pattern alignment. + """ + try: + elem_size = ELEMENT_SIZE_MAP.get(datatype, 2) + MPerBlock = tile_m + + # Calculate K1 + K1 = vector_load_size / elem_size + if K1 != int(K1): + return False + K1 = int(K1) + + # Calculate K0 + if tile_k % K1 != 0: + return False + K0 = tile_k // K1 + + # Calculate M2 + if warp_size % K0 != 0: + return False + M2 = warp_size // K0 + + # Calculate number of warps + NumWarps = warp_m * warp_n * warp_k + M0 = NumWarps + + # Calculate M1 + if (M2 * M0) == 0: + return False + if MPerBlock % (M2 * M0) != 0: + return False + M1 = MPerBlock // (M2 * M0) + + # Validate: M0 * M1 * M2 == MPerBlock + return (M0 * M1 * M2) == MPerBlock + + except (ZeroDivisionError, ValueError): + return False + + +def is_preshuffle_config_valid( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + datatype: str, +) -> bool: + """ + Comprehensive preshuffle configuration validation. + Copied from tile_engine/ops/commons/gemm_validation_utils.py + """ + # Basic divisibility checks + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False + + # Calculate m_iter_per_warp + m_iter_per_warp = tile_m / (warp_m * warp_tile_m) + + # Validate vector load alignment + if not _validate_preshuffle_vector_load( + warp_tile_m, + warp_tile_k, + datatype, + m_iter_per_warp, + wave_size=64, + vector_load_size=16, + ): + return False + + # Validate M0/M1/M2 configuration + if not _validate_preshuffle_m0_m1_m2( + tile_m, + tile_k, + warp_m, + warp_n, + warp_k, + datatype, + vector_load_size=16, + warp_size=64, + ): + return False + + return True + + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + +log = logging.getLogger(__name__) + + +# ============================================================================ +# Configuration and Data Structures +# ============================================================================ + + +class GemmVariant(Enum): + """GEMM kernel variants""" + + STANDARD = "standard" + PRESHUFFLE = "preshuffle" + MULTI_D = "multi_d" + + +@dataclass +class TileConfig: + """Tile configuration parameters""" + + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + def is_valid(self) -> bool: + """Validate tile configuration""" + return ( + self.tile_m % (self.warp_m * self.warp_tile_m) == 0 + and self.tile_n % (self.warp_n * self.warp_tile_n) == 0 + and self.tile_k % (self.warp_k * self.warp_tile_k) == 0 + and self.tile_m > 0 + and self.tile_n > 0 + and self.tile_k > 0 + ) + + +@dataclass +class TraitConfig: + """Kernel trait configuration""" + + pipeline: str # mem, compv3, compv4 + epilogue: str # default, cshuffle + scheduler: str # intrawave, interwave + pad_m: bool + pad_n: bool + pad_k: bool + persistent: bool + + def is_valid(self) -> bool: + """Check if trait combination is valid""" + # Unsupported combinations + # Only 'mem' pipeline supports interwave scheduler. + # All compute pipelines (compv3/v4/v5/v6/async) only support intrawave. + unsupported = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + ("compv5", "cshuffle", "interwave"), + ("compv5", "default", "interwave"), + ("compv6", "cshuffle", "interwave"), + ("compv6", "default", "interwave"), + ("comp_async", "cshuffle", "interwave"), + ("comp_async", "default", "interwave"), + } + return (self.pipeline, self.epilogue, self.scheduler) not in unsupported + + +@dataclass +class KernelConfig: + """Complete kernel configuration""" + + tile: TileConfig + trait: TraitConfig + variant: GemmVariant = GemmVariant.STANDARD + + # Variant-specific + preshuffle: bool = False + elementwise_op: str = "PassThrough" + num_d_tensors: int = 0 + d_layout: str = "r" # Layout for D tensors (r=row, c=col) - same for all D tensors + + # Fixed parameters + block_size: int = 256 + k_block_per_cu: int = 1 + num_wave_groups: int = 1 + + def name(self, datatype: str, layout: str) -> str: + """C++ alias for template instance""" + return f"ck_tile_gemm_{self.key_name(datatype, layout)}" + + def key_name(self, datatype: str, layout: str) -> str: + """ + Unique identifier for this kernel configuration. + + All parameters that affect kernel behavior MUST be included to ensure + unique names for unique configurations: + - Data type and layout (signature) + - Tile, warp, warp_tile dimensions (algorithm) + - Pipeline, epilogue, scheduler (traits) + - Padding flags (affects divisibility requirements) + - Persistent mode + - Preshuffle variant + - Multi-D: elementwise op, num D tensors, D layout + - Occupancy: wave groups, k_block_per_cu (if non-default) + """ + parts = [] + # Signature + parts.append(f"dt_{datatype}") + parts.append(f"ly_{layout}") + + # Tile configuration + parts.append(f"tile_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}") + parts.append(f"warp_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}") + parts.append( + f"wtile_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + + # Traits + parts.append(f"pipe_{self.trait.pipeline}") + parts.append(f"epi_{self.trait.epilogue}") + parts.append(f"sched_{self.trait.scheduler}") + + # Padding flags (only if not all True - the common case) + if not (self.trait.pad_m and self.trait.pad_n and self.trait.pad_k): + parts.append( + f"pad{int(self.trait.pad_m)}{int(self.trait.pad_n)}{int(self.trait.pad_k)}" + ) + + # Persistent mode + if self.trait.persistent: + parts.append("persist") + + # Preshuffle variant + if self.preshuffle: + parts.append("preshuffle") + + # Multi-D variant: include elementwise op, num tensors, and D layout + if self.variant == GemmVariant.MULTI_D: + parts.append(f"ew_{self.elementwise_op}") + parts.append(f"nd{self.num_d_tensors}") + parts.append(f"dly_{self.d_layout}") + + # Occupancy parameters (only if non-default) + if self.num_wave_groups != 1: + parts.append(f"wg{self.num_wave_groups}") + if self.k_block_per_cu != 1: + parts.append(f"kbpc{self.k_block_per_cu}") + + return "_".join(parts) + + def dict_items(self): + """Iterator over (field, value) pairs""" + return asdict(self).items() + + +# ============================================================================ +# Type Mappings +# ============================================================================ + + +class TypeMappings: + """Centralized type mappings for code generation""" + + DTYPE_TO_CK = { + "fp16": "fp16_t", + "bf16": "bf16_t", + "fp32": "float", + "fp8": "fp8_t", + "bf8": "bf8_t", + "int8": "int8_t", + } + + # Fully-qualified types for use outside of 'using namespace ck_tile' scope + DTYPE_TO_CK_QUALIFIED = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", # Built-in type, no namespace + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int8": "int8_t", # Built-in type + } + + DTYPE_TO_DISPATCHER = { + "fp16": "DataType::FP16", + "bf16": "DataType::BF16", + "fp32": "DataType::FP32", + "fp8": "DataType::FP8", + "bf8": "DataType::BF8", + "int8": "DataType::INT8", + } + + LAYOUT_TO_CK = { + "r": "tensor_layout::gemm::RowMajor", + "c": "tensor_layout::gemm::ColumnMajor", + } + + LAYOUT_TO_DISPATCHER = { + "r": "LayoutTag::RowMajor", + "c": "LayoutTag::ColMajor", + } + + PIPELINE_TO_CK = { + "mem": "GemmPipelineAgBgCrMem", + "compv3": "GemmPipelineAgBgCrCompV3", + "compv4": "GemmPipelineAgBgCrCompV4", + "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_BASE = { + "mem": "BaseGemmPipelineAgBgCrMem", + "compv3": "BaseGemmPipelineAgBgCrCompV3", + "compv4": "BaseGemmPipelineAgBgCrCompV4", + "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_DISPATCHER = { + "mem": "Pipeline::Mem", + "compv3": "Pipeline::CompV3", + "compv4": "Pipeline::CompV4", + "preshufflev2": "Pipeline::PreShuffleV2", + } + + SCHEDULER_TO_CK = { + "intrawave": "GemmPipelineScheduler::Intrawave", + "interwave": "GemmPipelineScheduler::Interwave", + "default": "GemmPipelineScheduler::Default", + } + + SCHEDULER_TO_DISPATCHER = { + "intrawave": "Scheduler::Intrawave", + "interwave": "Scheduler::Interwave", + "default": "Scheduler::Auto", + } + + EPILOGUE_TO_DISPATCHER = { + "cshuffle": "Epilogue::CShuffle", + "default": "Epilogue::Default", + } + + @staticmethod + def get_output_dtype(dtype: str) -> str: + """Get output datatype (fp8/bf8 -> fp16)""" + return "fp16" if dtype in ["fp8", "bf8"] else dtype + + +# ============================================================================ +# Kernel Name Generator +# ============================================================================ + + +class KernelNaming: + """Unified kernel naming""" + + @staticmethod + def generate(config: KernelConfig, datatype: str, layout: str) -> str: + """Generate kernel name following tile_engine convention""" + t = config.tile + tr = config.trait + + # For multi-d, use 4-char layout (abcd), otherwise use 3-char layout (abc) + if config.variant == GemmVariant.MULTI_D: + full_layout = layout + config.d_layout # e.g., "rcr" + "r" = "rcrr" + else: + full_layout = layout + + name = ( + f"gemm_{datatype}_{full_layout}_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}" + ) + name += f"_{str(tr.pad_m).capitalize()}_{str(tr.pad_n).capitalize()}" + name += f"_{str(tr.pad_k).capitalize()}_{str(tr.persistent).capitalize()}" + name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}" + name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}" + name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}" + + # Add variant suffix + if config.variant == GemmVariant.PRESHUFFLE: + name += "_preshuffle" + elif config.variant == GemmVariant.MULTI_D: + name += f"_multid_{config.elementwise_op}_d{config.num_d_tensors}" + + return name + + +# ============================================================================ +# CK Tile Kernel Generator +# ============================================================================ + + +class CKTileKernelGenerator: + """Generates CK Tile kernel instance code""" + + def __init__(self, datatype: str, layout: str): + self.datatype = datatype + self.layout = layout + self.tm = TypeMappings() + + def generate(self, config: KernelConfig) -> str: + """Generate complete CK Tile kernel""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + + return f"""{self._header(kernel_name, config)} +{self._types(config, kernel_name)} +{self._selected_kernel_struct(config, kernel_name)} +""" + + def _header(self, kernel_name: str, config: KernelConfig) -> str: + """Generate header includes""" + includes = """// SPDX-License-Identifier: MIT +// Auto-generated CK Tile GEMM kernel +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" + +""" + + if config.variant == GemmVariant.MULTI_D: + includes += """ +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" +""" + + if config.preshuffle: + includes += """ +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" +""" + + return includes + + def _types(self, config: KernelConfig, kernel_name: str) -> str: + """Generate type definitions - just the namespace import, types are in kernel namespace""" + # Note: Data types and layouts are now defined inside each kernel's unique namespace + # to avoid type alias redefinition conflicts when mixing layouts (e.g., RCR + RRR) + types = """ +// Use ck_tile namespace for generated code +using namespace ck_tile; +""" + return types + + def _kernel_local_types(self, config: KernelConfig) -> str: + """Generate data type and layout definitions inside kernel namespace""" + output_dtype = self.tm.get_output_dtype(self.datatype) + + return f""" + // Data types (inside namespace to avoid conflicts across layouts) + using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using AccDataType = float; + using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; + + // Layouts (inside namespace to avoid conflicts when mixing layouts) + using ALayout = {self.tm.LAYOUT_TO_CK[self.layout[0]]}; + using BLayout = {self.tm.LAYOUT_TO_CK[self.layout[1]]}; + using CLayout = {self.tm.LAYOUT_TO_CK[self.layout[2]]}; +""" + + def _multi_d_types(self, config: KernelConfig) -> str: + """Generate multi-d type definitions (inside namespace to avoid conflicts)""" + if config.variant != GemmVariant.MULTI_D: + return "" + + d_types = ", ".join(["CDataType"] * config.num_d_tensors) + d_layout_ck = self.tm.LAYOUT_TO_CK[config.d_layout] + d_layouts = ", ".join([d_layout_ck] * config.num_d_tensors) + + return f""" +// Multi-D types (defined in namespace to avoid conflicts) +using DsDataType = tuple<{d_types}>; +using DLayout = {d_layout_ck}; // D tensor layout (can differ from C) +using DsLayout = tuple<{d_layouts}>; +using ElementWiseFn = element_wise::{config.elementwise_op}; +static constexpr index_t NumDTensor = {config.num_d_tensors}; +using GemmMultiDArgs = GemmMultiDHostArgs; +""" + + def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str: + """Generate SelectedKernel struct with unique name in unique namespace""" + t = config.tile + tr = config.trait + output_dtype = self.tm.get_output_dtype(self.datatype) + + # Generate unique struct name and namespace from kernel name + struct_name = f"Kernel_{kernel_name}" + # Create valid C++ namespace name (replace invalid chars) + ns_name = "ns_" + kernel_name.replace("-", "_") + + multi_d_types = self._multi_d_types(config) + + return f""" +namespace {ns_name} {{ +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Data types (inside namespace to avoid conflicts across different kernels) +using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; +using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; +using AccDataType = float; +using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; + +// Layouts (inside namespace to avoid conflicts when mixing layouts like RCR + RRR) +using ALayout = {self.tm.LAYOUT_TO_CK[self.layout[0]]}; +using BLayout = {self.tm.LAYOUT_TO_CK[self.layout[1]]}; +using CLayout = {self.tm.LAYOUT_TO_CK[self.layout[2]]}; +{multi_d_types} +struct {struct_name} {{ + // Data types (required by backend as member types) + using ADataType = {ns_name}::ADataType; + using BDataType = {ns_name}::BDataType; + using CDataType = {ns_name}::CDataType; + using AccDataType = {ns_name}::AccDataType; + + // Configuration + static constexpr index_t BlockSize = {config.block_size}; + static constexpr index_t TileM = {t.tile_m}; + static constexpr index_t TileN = {t.tile_n}; + static constexpr index_t TileK = {t.tile_k}; + static constexpr index_t WarpPerBlock_M = {t.warp_m}; + static constexpr index_t WarpPerBlock_N = {t.warp_n}; + static constexpr index_t WarpPerBlock_K = {t.warp_k}; + static constexpr index_t WarpTileM = {t.warp_tile_m}; + static constexpr index_t WarpTileN = {t.warp_tile_n}; + static constexpr index_t WarpTileK = {t.warp_tile_k}; + + // Traits + static constexpr bool kPadM = {str(tr.pad_m).lower()}; + static constexpr bool kPadN = {str(tr.pad_n).lower()}; + static constexpr bool kPadK = {str(tr.pad_k).lower()}; + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {str(tr.persistent).lower()}; + static constexpr bool DoubleSmemBuffer = {str(tr.pipeline == "compv4" or tr.pipeline == "preshufflev2").lower()}; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Preshuffle = {str(config.preshuffle).lower()}; + static constexpr index_t NumWaveGroups = {config.num_wave_groups}; + + {self._tile_types(config, ns_name)} + {self._launch_function(config)} +}}; + +// Alias for tile_engine style compatibility (when used with -include) +using SelectedKernel = {struct_name}; +using SelectedKernelLauncher = {struct_name}; +}} // namespace {ns_name} + +// Export to global namespace ONLY for single-kernel includes +// Define CK_TILE_SINGLE_KERNEL_INCLUDE before including this header to enable these aliases +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {struct_name} = {ns_name}::{struct_name}; +using SelectedKernel = {ns_name}::{struct_name}; +constexpr const char* KERNEL_NAME = {ns_name}::KERNEL_NAME; +using ADataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; +using BDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; +using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.tm.get_output_dtype(self.datatype)]}; +using AccDataType = float; +#endif // CK_TILE_SINGLE_KERNEL_INCLUDE +""" + + def _tile_types(self, config: KernelConfig, ns_name: str) -> str: + """Generate tile type definitions - uses namespace-qualified types""" + return ( + f"""// Tile shape + using TileShape = TileGemmShape< + sequence, + sequence, + sequence, + false, false>; + + using TilePartitioner = GemmSpatiallyLocalTilePartitioner; + using Traits = TileGemmTraits; + using GemmPipelineProblem = GemmPipelineProblem; + using BaseGemmPipeline = """ + + self.tm.PIPELINE_TO_BASE[config.trait.pipeline] + + """;""" + ) + + def _launch_function(self, config: KernelConfig) -> str: + """Generate launch function""" + if config.variant == GemmVariant.MULTI_D: + return self._launch_function_multi_d(config) + if config.preshuffle: + return self._launch_function_preshuffle(config) + return self._launch_function_standard(config) + + def _launch_function_standard(self, config: KernelConfig) -> str: + """Generate launch function for standard GEMM""" + return f""" + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = {self.tm.SCHEDULER_TO_CK[config.trait.scheduler]}; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = {self.tm.PIPELINE_TO_CK[config.trait.pipeline]}; + {self._epilogue_code(config)} + + using GemmKernel = ck_tile::GemmKernel; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported!"); + }} + + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if config.trait.persistent else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }}""" + + def _launch_function_preshuffle(self, config: KernelConfig) -> str: + """Generate launch function for preshuffle GEMM (weight preshuffle variant) + + Preshuffle uses WeightPreshufflePipelineAGmemBGmemCRegV2 which has a different + API than standard pipelines. It's designed for weight-preshuffled GEMM operations. + """ + return f""" + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = GemmPipelineScheduler::Default; // Preshuffle uses Default scheduler + + // Preshuffle uses TileFlatmmShape instead of TileGemmShape for the problem + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = WeightPreshufflePipelineAGmemBGmemCRegV2; + {self._epilogue_code(config)} + + using GemmKernel = ck_tile::GemmKernel; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for preshuffle kernel!"); + }} + + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if config.trait.persistent else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }}""" + + def _launch_function_multi_d(self, config: KernelConfig) -> str: + """Generate launch function for Multi-D GEMM""" + return f""" + // Multi-D launch function - takes GemmMultiDHostArgs with D tensor pointers + static float launch(const GemmMultiDArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = {self.tm.SCHEDULER_TO_CK[config.trait.scheduler]}; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = {self.tm.PIPELINE_TO_CK[config.trait.pipeline]}; + {self._epilogue_code(config)} + + // Use GemmKernelMultiD for Multi-D variant + using GemmKernel = ck_tile::GemmKernelMultiD; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported! Multi-D currently doesn't support k_batch > 1"); + }} + + const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }} + + // Overload for standard GemmHostArgs (converts to Multi-D args with empty D tensors) + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + std::array empty_ds{{}}; + std::array empty_strides{{}}; + for (index_t i = 0; i < NumDTensor; ++i) {{ + empty_ds[i] = nullptr; + empty_strides[i] = 0; + }} + GemmMultiDArgs multi_d_args{{ + args.a_ptr, + args.b_ptr, + empty_ds, + args.e_ptr, + args.k_batch, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + empty_strides, + args.stride_C + }}; + return launch(multi_d_args, stream); + }}""" + + def _epilogue_code(self, config: KernelConfig) -> str: + """Generate epilogue code""" + if config.variant == GemmVariant.MULTI_D: + return """ + using EpilogueProblem = CShuffleEpilogueProblem< + ADataType, BDataType, DsDataType, AccDataType, CDataType, + DsLayout, CLayout, ElementWiseFn, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, + TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>; + using GemmEpilogue = CShuffleEpilogue;""" + elif config.trait.epilogue == "cshuffle": + return """ + using EpilogueProblem = CShuffleEpilogueProblem< + ADataType, BDataType, tuple<>, AccDataType, CDataType, + tuple<>, CLayout, element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, + TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>; + using GemmEpilogue = CShuffleEpilogue;""" + else: + return """ + using EpilogueProblem = DefaultGemm2DEpilogueProblem< + ADataType, BDataType, tuple<>, AccDataType, CDataType, + tuple<>, CLayout, element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + kPadM, kPadN, WarpTileM, WarpTileN, WarpTileK, TransposeC>; + using GemmEpilogue = DefaultGemm2DEpilogue;""" + + +# ============================================================================ +# Dispatcher Wrapper Generator +# ============================================================================ + + +class DispatcherWrapperGenerator: + """Generates dispatcher wrapper code""" + + def __init__(self, datatype: str, layout: str): + self.datatype = datatype + self.layout = layout + self.tm = TypeMappings() + + def generate( + self, config: KernelConfig, kernel_path: Path, output_dir: Path + ) -> str: + """Generate dispatcher wrapper""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + output_dtype = self.tm.get_output_dtype(self.datatype) + rel_path = kernel_path.relative_to(output_dir) + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated dispatcher wrapper +#pragma once + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp" +#include "{rel_path}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +using ::ck_tile::dispatcher::KernelInstancePtr; +using ::ck_tile::dispatcher::KernelKey; +using ::ck_tile::dispatcher::DataType; +using ::ck_tile::dispatcher::LayoutTag; +using ::ck_tile::dispatcher::Pipeline; +using ::ck_tile::dispatcher::Scheduler; +using ::ck_tile::dispatcher::Epilogue; +using Priority = ::ck_tile::dispatcher::Registry::Priority; +namespace backends = ::ck_tile::dispatcher::backends; + +inline KernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{ + // Use the unique kernel struct name + using KernelStruct = Kernel_{kernel_name}; + + KernelKey key; + + // Signature + key.signature.dtype_a = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; + key.signature.dtype_b = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; + key.signature.dtype_c = {self.tm.DTYPE_TO_DISPATCHER[output_dtype]}; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[0]]}; + key.signature.layout_b = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[1]]}; + key.signature.layout_c = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[2]]}; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "{config.elementwise_op}"; + key.signature.num_d_tensors = {config.num_d_tensors}; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}}; + key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, {config.tile.warp_k}}}; + key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}}; + key.algorithm.pipeline = {self.tm.PIPELINE_TO_DISPATCHER[config.trait.pipeline]}; + key.algorithm.scheduler = {self.tm.SCHEDULER_TO_DISPATCHER[config.trait.scheduler]}; + key.algorithm.epilogue = {self.tm.EPILOGUE_TO_DISPATCHER[config.trait.epilogue]}; + key.algorithm.block_size = {config.block_size}; + key.algorithm.double_buffer = {str(config.trait.pipeline == "compv4").lower()}; + key.algorithm.persistent = {str(config.trait.persistent).lower()}; + key.algorithm.preshuffle = {str(config.preshuffle).lower()}; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = {config.num_wave_groups}; + + key.gfx_arch = gfx_arch; + + return std::make_shared>(key, "{kernel_name}"); +}} + +}}}}}} +""" + + +# ============================================================================ +# Main Unified Generator +# ============================================================================ + + +class UnifiedGemmCodegen: + """Unified GEMM code generator - single entry point""" + + def __init__( + self, + output_dir: Path, + datatype: str, + layout: str, + gpu_target: str = "gfx942", + config_file: Optional[Path] = None, + variants: List[GemmVariant] = None, + use_preselected: Optional[str] = None, + enable_arch_filter: bool = True, + kernel_set_name: Optional[str] = None, + ): + self.output_dir = Path(output_dir) + self.datatype = datatype + # Support 3-char (rcr) or 4-char (rcrr) layout codes + # 4th char specifies D tensor layout for multi-d + self.layout = layout[:3] # A, B, C layouts + self.d_layout = ( + layout[3] if len(layout) >= 4 else layout[2] + ) # D layout (default = C layout) + self.gpu_target = gpu_target + self.variants = variants or [GemmVariant.STANDARD] + self.use_preselected = use_preselected + self.kernel_set_name = kernel_set_name + + # Create directories - optionally with kernel set subdirectory + if kernel_set_name: + self.kernel_dir = self.output_dir / kernel_set_name + else: + self.kernel_dir = self.output_dir + self.kernel_dir.mkdir(parents=True, exist_ok=True) + self.wrapper_dir = self.kernel_dir / "dispatcher_wrappers" + self.wrapper_dir.mkdir(parents=True, exist_ok=True) + + # Load configuration + self.config = self._load_config(config_file) + + # Initialize architecture filter for GPU-specific validation + self.arch_filter = None + if enable_arch_filter and HAS_ARCH_FILTER: + try: + self.arch_filter = ArchFilter(gpu_target, strict_mode=False) + log.info(f"Architecture filter enabled for {gpu_target}") + except ValueError as e: + log.warning(f"Could not create arch filter: {e}") + + # Initialize generators (use self.layout which is the 3-char A,B,C layout) + self.ck_gen = CKTileKernelGenerator(datatype, self.layout) + self.disp_gen = DispatcherWrapperGenerator(datatype, self.layout) + + def _load_config(self, config_file: Optional[Path]) -> Dict: + """Load or create default configuration""" + if config_file and config_file.exists(): + with open(config_file) as f: + return json.load(f) + + # Match tile_engine default configs for GEMM/Preshuffle/Multi-D + # See: tile_engine/ops/gemm/configs/default_config.json + # tile_engine/ops/gemm_preshuffle/configs/default_config.json + # tile_engine/ops/gemm_multi_d/configs/default_config.json + return { + "tile_config": { + # tile_m/n/k: 64-256 step 64 = [64, 128, 192, 256] + "tile_m": [64, 128, 192, 256], + "tile_n": [64, 128, 192, 256], + "tile_k": [64, 128, 192, 256], + # warp configs matching tile_engine + "warp_m": [1, 2, 4], + "warp_n": [1, 2, 4], + "warp_k": [1], + # warp_tile configs matching tile_engine + "warp_tile_m": [4, 16, 32], + "warp_tile_n": [16, 32, 64], + "warp_tile_k": [8, 16, 32, 64, 128], + }, + "trait_config": { + "pipeline": ["compv3", "compv4", "mem"], + "epilogue": ["cshuffle", "default"], + "scheduler": ["intrawave", "interwave"], + "pad_m": [False], + "pad_n": [False], + "pad_k": [False], + "persistent": [False, True], + }, + "multi_d_config": { + # Note: Only MultiDAdd and MultiDMultiply are compatible with multi-D GEMM. + # Relu/Gelu are unary ops with signature (y, x), not multi-D signature (e, c, ds...) + "elementwise_ops": ["MultiDAdd", "MultiDMultiply"], + "num_d_tensors": [1, 2], + }, + } + + def generate_all(self, parallel: bool = True) -> Dict: + """Generate all kernels""" + log.info("Generating GEMM kernels:") + log.info(f" Datatype: {self.datatype}") + log.info(f" Layout: {self.layout}") + log.info(f" Variants: {[v.value for v in self.variants]}") + if self.use_preselected: + log.info(f" Using preselected set: {self.use_preselected}") + + results = {"kernels": [], "wrappers": [], "failed": []} + + # Get configurations + if self.use_preselected: + configs = self._get_preselected_configs() + log.info(f" Total configurations: {len(configs)}") + else: + for variant in self.variants: + log.info(f"\nGenerating {variant.value} kernels...") + configs = self._get_configs_for_variant(variant) + log.info(f" Configurations: {len(configs)}") + + if parallel: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self._generate_one, cfg) for cfg in configs + ] + for future in concurrent.futures.as_completed(futures): + try: + k, w = future.result() + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + else: + for cfg in configs: + try: + k, w = self._generate_one(cfg) + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + + # Generate registration header + if results["wrappers"]: + self._generate_registration_header(results["wrappers"]) + + return results + + # Generate from preselected set + if parallel: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(self._generate_one, cfg) for cfg in configs] + for future in concurrent.futures.as_completed(futures): + try: + k, w = future.result() + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + else: + for cfg in configs: + try: + k, w = self._generate_one(cfg) + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + + # Generate registration header + if results["wrappers"]: + self._generate_registration_header(results["wrappers"]) + + return results + + def _get_preselected_configs(self) -> List[KernelConfig]: + """Get preselected kernel configurations""" + try: + from preselected_kernels import get_preselected_set + + return get_preselected_set(self.use_preselected) + except ImportError: + log.warning( + "preselected_kernels module not found, falling back to config-based generation" + ) + return [] + except ValueError as e: + log.error(f"Invalid preselected set: {e}") + return [] + + def _get_configs_for_variant(self, variant: GemmVariant) -> List[KernelConfig]: + """Get all configurations for a variant + + Args: + variant: GEMM variant (STANDARD, PRESHUFFLE, MULTI_D) + + Returns: + List of valid kernel configurations for the variant + """ + configs = [] + + # Get base configs + tile_configs = self._get_tile_configs() + trait_configs = self._get_trait_configs() + + for tile, trait in itertools.product(tile_configs, trait_configs): + # Perform variant-specific architecture validation + if self.arch_filter and HAS_ARCH_FILTER: + if not self._is_tile_arch_valid(tile, variant): + continue + + if variant == GemmVariant.STANDARD: + configs.append(KernelConfig(tile=tile, trait=trait, variant=variant)) + + elif variant == GemmVariant.PRESHUFFLE: + # Preshuffle needs specific pipeline (preshufflev2) and scheduler (default) + # Skip configs that don't use preshuffle-compatible traits + preshuffle_trait = TraitConfig( + pipeline="preshufflev2", + epilogue="cshuffle", + scheduler="default", + pad_m=trait.pad_m, + pad_n=trait.pad_n, + pad_k=trait.pad_k, + persistent=trait.persistent, + ) + # Only generate one preshuffle config per tile (not per trait) + # since preshuffle has fixed pipeline/scheduler + if trait.pipeline == "compv3" and trait.scheduler == "intrawave": + configs.append( + KernelConfig( + tile=tile, + trait=preshuffle_trait, + variant=variant, + preshuffle=True, + ) + ) + + elif variant == GemmVariant.MULTI_D: + multi_d = self.config.get("multi_d_config", {}) + for ew_op, num_d in itertools.product( + multi_d.get("elementwise_ops", ["MultiDAdd"]), + multi_d.get("num_d_tensors", [1]), + ): + configs.append( + KernelConfig( + tile=tile, + trait=trait, + variant=variant, + elementwise_op=ew_op, + num_d_tensors=num_d, + d_layout=self.d_layout, # Use extracted D layout + ) + ) + + return configs + + def _get_tile_configs(self) -> List[TileConfig]: + """Get valid tile configurations, filtered by architecture constraints""" + tc = self.config["tile_config"] + configs = [] + rejected_count = 0 + + for params in itertools.product( + tc["tile_m"], + tc["tile_n"], + tc["tile_k"], + tc["warp_m"], + tc["warp_n"], + tc["warp_k"], + tc["warp_tile_m"], + tc["warp_tile_n"], + tc["warp_tile_k"], + ): + tile = TileConfig(*params) + + # Basic validation + if not tile.is_valid(): + rejected_count += 1 + continue + + # Architecture-specific validation + if self.arch_filter and HAS_ARCH_FILTER: + if not self._is_tile_arch_valid(tile): + rejected_count += 1 + continue + + configs.append(tile) + + if rejected_count > 0: + log.debug(f"Rejected {rejected_count} tile configs for {self.gpu_target}") + + return configs + + def _is_tile_arch_valid( + self, tile: TileConfig, variant: GemmVariant = None + ) -> bool: + """Check if tile configuration is valid for target architecture + + Args: + tile: Tile configuration to validate + variant: GEMM variant (affects operator-specific constraints) + """ + if not self.arch_filter or not HAS_ARCH_FILTER: + return True + + # Determine data types based on self.datatype + # Note: dtype_c is the ACCUMULATOR type, not output type (which may be fp16) + # WMMA instructions on gfx942 always use fp32 accumulator for fp16 inputs + dtype_map = { + "fp16": ("fp16", "fp16", "fp32"), # A=fp16, B=fp16, Acc=fp32 + "bf16": ("bf16", "bf16", "fp32"), # A=bf16, B=bf16, Acc=fp32 + "fp8": ("fp8", "fp8", "fp32"), # A=fp8, B=fp8, Acc=fp32 + "bf8": ("bf8", "bf8", "fp32"), # A=bf8, B=bf8, Acc=fp32 + "int8": ("int8", "int8", "int32"), # A=int8, B=int8, Acc=int32 + } + dtype_a, dtype_b, dtype_c = dtype_map.get( + self.datatype, ("fp16", "fp16", "fp32") + ) + + # Map GEMM variant to operator type for validation + operator = None + pipeline = "compv4" # Default + scheduler = "intrawave" # Default + + if OperatorType is not None and variant is not None: + variant_to_operator = { + GemmVariant.STANDARD: OperatorType.GEMM, + GemmVariant.PRESHUFFLE: OperatorType.GEMM_PRESHUFFLE, + GemmVariant.MULTI_D: OperatorType.GEMM_MULTI_D, + } + operator = variant_to_operator.get(variant, OperatorType.GEMM) + + # Preshuffle requires specific pipeline and scheduler + if variant == GemmVariant.PRESHUFFLE: + pipeline = "preshufflev2" + scheduler = "default" + + # Use preshuffle-specific validation (comprehensive CK-specific checks) + if variant == GemmVariant.PRESHUFFLE: + if not is_preshuffle_config_valid( + tile_m=tile.tile_m, + tile_n=tile.tile_n, + tile_k=tile.tile_k, + warp_m=tile.warp_m, + warp_n=tile.warp_n, + warp_k=tile.warp_k, + warp_tile_m=tile.warp_tile_m, + warp_tile_n=tile.warp_tile_n, + warp_tile_k=tile.warp_tile_k, + datatype=self.datatype, + ): + return False + + return self.arch_filter.is_kernel_valid( + datatype_a=dtype_a, + datatype_b=dtype_b, + datatype_c=dtype_c, + tile_m=tile.tile_m, + tile_n=tile.tile_n, + tile_k=tile.tile_k, + warp_m=tile.warp_m, + warp_n=tile.warp_n, + warp_k=tile.warp_k, + warp_tile_m=tile.warp_tile_m, + warp_tile_n=tile.warp_tile_n, + warp_tile_k=tile.warp_tile_k, + pipeline=pipeline, + scheduler=scheduler, + layout=self.layout, + operator=operator, + ) + + def _get_trait_configs(self) -> List[TraitConfig]: + """Get valid trait configurations, filtered by architecture constraints""" + tc = self.config["trait_config"] + configs = [] + rejected_count = 0 + + for params in itertools.product( + tc["pipeline"], + tc["epilogue"], + tc["scheduler"], + tc["pad_m"], + tc["pad_n"], + tc["pad_k"], + tc["persistent"], + ): + trait = TraitConfig(*params) + + # Basic trait validation (unsupported combinations) + if not trait.is_valid(): + rejected_count += 1 + continue + + configs.append(trait) + + if rejected_count > 0: + log.debug(f"Rejected {rejected_count} trait configs") + + return configs + + def _generate_one(self, config: KernelConfig) -> Tuple[str, str]: + """Generate one kernel and wrapper""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + + # Generate CK Tile kernel + kernel_code = self.ck_gen.generate(config) + kernel_path = self.kernel_dir / f"{kernel_name}.hpp" + kernel_path.write_text(kernel_code) + + # Generate dispatcher wrapper + wrapper_code = self.disp_gen.generate(config, kernel_path, self.kernel_dir) + wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp" + wrapper_path.write_text(wrapper_code) + + # Generate .cpp compilation unit for per-kernel parallel builds + cpp_path = self.kernel_dir / f"{kernel_name}.cpp" + cpp_code = f'''// SPDX-License-Identifier: MIT +// Auto-generated compilation unit for: {kernel_name} +// Enables per-kernel parallel compilation with make -j + +#include "{kernel_name}.hpp" + +namespace ck_tile {{ namespace generated {{ + volatile bool _{kernel_name.replace("-", "_")}_loaded = true; +}} }} +''' + cpp_path.write_text(cpp_code) + + return str(kernel_path), str(wrapper_path) + + def _generate_registration_header(self, wrapper_paths: List[str]): + """Generate master registration header""" + kernel_names = [ + Path(w).stem.replace("dispatcher_wrapper_", "") for w in wrapper_paths + ] + + includes = "\n".join( + [f'#include "dispatcher_wrapper_{n}.hpp"' for n in kernel_names] + ) + registrations = "\n ".join( + [ + f"registry.register_kernel(generated::make_{n}(gfx_arch), priority);" + for n in kernel_names + ] + ) + + content = f"""// SPDX-License-Identifier: MIT +// Auto-generated master registration +#pragma once + +#include "ck_tile/dispatcher.hpp" +{includes} + +namespace ck_tile {{ +namespace dispatcher {{ + +using ::ck_tile::dispatcher::Registry; +using Priority = ::ck_tile::dispatcher::Registry::Priority; + +inline void register_all_tile_gemm_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = Registry::instance(); + {registrations} +}} + +inline std::size_t get_tile_gemm_kernel_count() {{ return {len(kernel_names)}; }} + +}}}} +""" + + reg_path = self.wrapper_dir / "register_all_kernels.hpp" + reg_path.write_text(content) + logging.info(f"Generated registration header: {reg_path}") + + +# ============================================================================ +# CLI +# ============================================================================ + + +def _show_arch_info(gpu_target: str, datatype: str): + """Display supported configurations for a GPU architecture""" + if not HAS_ARCH_FILTER: + print("Architecture filter module not available") + return + + try: + from arch_filter import ( + get_supported_archs, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + LDS_CAPACITY_LIMITS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + + print(f"\n=== Architecture Info for {gpu_target} ===\n") + + # Supported architectures + print(f"Supported GPUs: {get_supported_archs()}") + + # Warp configurations + warp_cfgs = WARP_SUPPORTED_COMBINATIONS.get(gpu_target, []) + print("\nWarp configurations [warp_m, warp_n, warp_k]:") + for cfg in warp_cfgs: + print(f" {cfg}") + + # Warp tile configurations for data type + dtype_map = { + "fp16": "fp16_fp16_fp16", + "bf16": "bf16_bf16_bf16", + "fp8": "fp8_fp8_fp16", + "bf8": "bf8_bf8_fp16", + "int8": "int8_int8_int32", + } + dtype_key = dtype_map.get(datatype, "fp16_fp16_fp16") + + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_target, {}) + warp_tiles = gpu_combos.get(dtype_key, []) + print( + f"\nWarp tile configurations for {dtype_key} [warp_tile_m, warp_tile_n, warp_tile_k]:" + ) + for cfg in warp_tiles: + print(f" {cfg}") + + # All supported data types + print(f"\nAll supported data types on {gpu_target}:") + for dtype in gpu_combos.keys(): + print(f" {dtype}") + + # LDS limits + print("\nLDS capacity limits:") + for pipeline, limit in LDS_CAPACITY_LIMITS.items(): + print(f" {pipeline}: {limit // 1024}KB") + + # Unsupported trait combinations + print("\nUnsupported trait combinations (pipeline, epilogue, scheduler):") + for combo in TRAIT_UNSUPPORTED_COMBINATIONS: + print(f" {combo}") + + print() + + except Exception as e: + print(f"Error showing arch info: {e}") + + +def main(): + parser = argparse.ArgumentParser( + description="Unified GEMM Code Generator - Single Source of Truth" + ) + parser.add_argument( + "--output-dir", type=Path, required=True, help="Output directory" + ) + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "bf16", "fp32", "fp8", "bf8", "int8", "pk_fp4"], + help="Data type (fp16, bf16, fp32, fp8, bf8, int8, pk_fp4)", + ) + parser.add_argument( + "--layout", + type=str, + default="rcr", + help="Layout (e.g., rcr for A=row, B=col, C=row; or rcrr for multi-d with D=row)", + ) + parser.add_argument( + "--gpu-target", + type=str, + default="gfx942", + help="Target GPU (gfx90a, gfx942, gfx950, gfx1201)", + ) + parser.add_argument("--config", type=Path, help="Configuration JSON file") + parser.add_argument( + "--variants", + nargs="+", + choices=["standard", "preshuffle", "multi_d"], + default=["standard"], + help="Variants to generate", + ) + parser.add_argument( + "--preselected", + type=str, + help="Use preselected kernel set (e.g., fp16_rcr_essential)", + ) + parser.add_argument( + "--no-parallel", action="store_true", help="Disable parallel generation" + ) + parser.add_argument( + "--register", action="store_true", help="Generate dispatcher registration code" + ) + parser.add_argument( + "--no-arch-filter", + action="store_true", + help="Disable architecture-specific kernel filtering", + ) + parser.add_argument( + "--show-arch-info", + action="store_true", + help="Show supported configurations for target GPU and exit", + ) + parser.add_argument( + "--kernel-set", + type=str, + help="Kernel set name (creates subdirectory for organization)", + ) + parser.add_argument( + "--tile-config-json", + type=str, + help="JSON string specifying exact tile configuration (for minimal builds)", + ) + + args = parser.parse_args() + + # Handle inline tile config JSON for minimal/single-kernel builds + if args.tile_config_json: + try: + cfg = json.loads(args.tile_config_json) + + # Build proper config structure + full_config = {} + + # Extract tile config + tile_keys = [ + "tile_m", + "tile_n", + "tile_k", + "warp_m", + "warp_n", + "warp_k", + "warp_tile_m", + "warp_tile_n", + "warp_tile_k", + "block_size", + ] + tile_config = {k: cfg[k] for k in tile_keys if k in cfg} + if tile_config: + full_config["tile_config"] = tile_config + + # Extract trait config + trait_keys = ["pipeline", "epilogue", "scheduler"] + trait_config = {k: cfg[k] for k in trait_keys if k in cfg} + # Add default pad/persistent values + trait_config.setdefault("pad_m", [False]) + trait_config.setdefault("pad_n", [False]) + trait_config.setdefault("pad_k", [False]) + trait_config.setdefault("persistent", [False]) + if trait_config: + full_config["trait_config"] = trait_config + + # Extract multi_d config (for multi_d variant) + if "elementwise_ops" in cfg or "num_d_tensors" in cfg: + multi_d_config = {} + if "elementwise_ops" in cfg: + multi_d_config["elementwise_ops"] = cfg["elementwise_ops"] + if "num_d_tensors" in cfg: + multi_d_config["num_d_tensors"] = cfg["num_d_tensors"] + full_config["multi_d_config"] = multi_d_config + + # Use already structured config if provided + if "tile_config" in cfg: + full_config = cfg + + # Write to temp file and use as config + import tempfile + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as f: + json.dump(full_config, f) + args.config = Path(f.name) + except json.JSONDecodeError as e: + logging.error(f"Invalid tile-config-json: {e}") + return 1 + except KeyError as e: + logging.error(f"Missing required key in tile-config-json: {e}") + return 1 + + # Show architecture info if requested + if args.show_arch_info: + _show_arch_info(args.gpu_target, args.datatype) + return 0 + + variants = [GemmVariant(v) for v in args.variants] if not args.preselected else None + + codegen = UnifiedGemmCodegen( + output_dir=args.output_dir, + datatype=args.datatype, + layout=args.layout, + gpu_target=args.gpu_target, + config_file=args.config, + variants=variants, + use_preselected=args.preselected, + enable_arch_filter=not args.no_arch_filter, + kernel_set_name=args.kernel_set, + ) + + results = codegen.generate_all(parallel=not args.no_parallel) + + logging.info("\n✅ Generation complete!") + logging.info(f" Kernels: {len(results['kernels'])}") + logging.info(f" Wrappers: {len(results['wrappers'])}") + logging.info(f" Failed: {len(results['failed'])}") + + if results["failed"]: + logging.error(f"\nFailed kernels: {len(results['failed'])}") + for err in results["failed"][:5]: + logging.error(f" {err}") + + # Generate dispatcher registration if requested + if args.register: + logging.info("\n📝 Generating dispatcher registration code...") + try: + from generate_dispatcher_registration import ( + scan_generated_headers, + generate_registration_header, + generate_registration_cpp, + ) + + kernels = scan_generated_headers(args.output_dir) + reg_dir = args.output_dir / "registration" + reg_dir.mkdir(exist_ok=True) + + generate_registration_header( + kernels, reg_dir / "dispatcher_registration.hpp" + ) + generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp") + + logging.info(f"✓ Generated registration code for {len(kernels)} kernels") + except Exception as e: + logging.error(f"Failed to generate registration code: {e}") + return 1 + + return 0 if not results["failed"] else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/dispatcher/examples/CMakeLists.txt b/dispatcher/examples/CMakeLists.txt new file mode 100644 index 0000000000..0359eb0d8d --- /dev/null +++ b/dispatcher/examples/CMakeLists.txt @@ -0,0 +1,448 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +cmake_minimum_required(VERSION 3.16) + +# Get processor count for parallel builds +include(ProcessorCount) +ProcessorCount(NPROC) +if(NPROC EQUAL 0) + set(NPROC 4) +endif() + +# GPU target architecture (passed from command line or default to gfx942) +if(NOT DEFINED GPU_TARGETS OR GPU_TARGETS STREQUAL "") + set(GPU_TARGETS "gfx942" CACHE STRING "GPU architecture target") +endif() +# Extract first target if multiple are provided (we only support single target builds) +string(REPLACE ";" " " GPU_TARGETS_SPACE "${GPU_TARGETS}") +string(REPLACE " " ";" GPU_TARGETS_LIST "${GPU_TARGETS_SPACE}") +list(GET GPU_TARGETS_LIST 0 GPU_TARGET) +message(STATUS "Building for GPU target: ${GPU_TARGET}") + +# NOTE: Per-kernel compilation is now automatic via declarative examples +# Each example generates only its declared kernels (from DECL_KERNEL_SET) + +# Link to dispatcher library +link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../build) + +# ============================================================================= +# Kernel Output Directory +# ============================================================================= + +set(KERNEL_OUTPUT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels") +file(MAKE_DIRECTORY ${KERNEL_OUTPUT_DIR}) + +# ============================================================================= +# Kernel Generation Targets (run during 'make', not 'cmake') +# ============================================================================= + +# Sentinel files to track generation +set(GEMM_SENTINEL "${KERNEL_OUTPUT_DIR}/.gemm_generated") + +# Generate GEMM kernels (standard + preshuffle + multi_d) - runs with internal parallelism +# Note: 4-char layout "rcrr" means A=row, B=col, C=row, D=row (for multi-d) +add_custom_command( + OUTPUT ${GEMM_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcrr --variants standard preshuffle multi_d + --output ${KERNEL_OUTPUT_DIR} + COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating GEMM kernels (fp16, rcrr, standard + preshuffle + multi_d) with internal parallelism..." + VERBATIM +) + +add_custom_target(generate_gemm_kernels + DEPENDS ${GEMM_SENTINEL} + COMMENT "GEMM kernel generation target" +) + +# Alias for generate_all_kernels (GEMM only now) +add_custom_target(generate_all_kernels + DEPENDS generate_gemm_kernels +) + +# ============================================================================= +# Per-Kernel Compilation (Maximum Parallelism) +# ============================================================================= +# Enable with: cmake -DPER_KERNEL_COMPILATION=ON +# +# This creates ONE translation unit per kernel, enabling: +# 1. Maximum parallelism with make -j$(nproc) +# 2. Per-kernel build progress: "[1/128] Building kernel: gemm_fp16_128x128" +# 3. Incremental rebuilds (only changed kernels recompile) +# 4. Fine-grained build time analysis +# +# Build process: +# 1. Generate kernel headers (.hpp) +# 2. Generate wrapper files (.cpp) - one per kernel +# 3. Compile each wrapper in parallel +# 4. Link all objects into libdispatcher_kernels.so +# +# Example output: +# [ 1/128] Building kernel: gemm_fp16_rcr_128x128x32 +# [ 2/128] Building kernel: gemm_fp16_rcr_256x256x64 +# ... +# [128/128] Linking: libdispatcher_kernels.so +# ============================================================================= + +set(WRAPPER_DIR "${CMAKE_BINARY_DIR}/kernel_wrappers") +set(WRAPPER_SENTINEL "${WRAPPER_DIR}/.wrappers_generated") + +# Target: Generate wrapper .cpp files (one per kernel) +add_custom_command( + OUTPUT ${WRAPPER_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/generate_kernel_wrappers.py + --kernel-dir ${KERNEL_OUTPUT_DIR} + --output-dir ${WRAPPER_DIR} + --generate-makefile + --generate-cmake + COMMAND ${CMAKE_COMMAND} -E touch ${WRAPPER_SENTINEL} + DEPENDS ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating per-kernel wrapper .cpp files..." + VERBATIM +) + +add_custom_target(generate_kernel_wrappers + DEPENDS ${WRAPPER_SENTINEL} + COMMENT "Kernel wrapper generation target" +) + +# Target: Build kernels using generated Makefile (true per-kernel progress) +add_custom_target(build_kernels_parallel + COMMAND ${CMAKE_COMMAND} -E echo "Building kernels with per-kernel progress..." + COMMAND make -C ${WRAPPER_DIR} -j${NPROC} 2>&1 | grep -E "^\\[|Built|Linking|Error" + DEPENDS generate_kernel_wrappers + WORKING_DIRECTORY ${WRAPPER_DIR} + COMMENT "Compiling kernels in parallel (one translation unit per kernel)..." + VERBATIM +) + +# Global kernel build (optional - prefer per-example builds for minimal compilation) +# This builds ALL kernels into a shared library - use for Python bindings or full library +# For C++ examples, use declarative approach which builds only needed kernels +add_custom_target(dispatcher_kernels + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/parallel_kernel_builder.py + --kernel-dir ${KERNEL_OUTPUT_DIR} + --output-dir ${CMAKE_BINARY_DIR} + --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include" + --jobs ${NPROC} + DEPENDS generate_all_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts + COMMENT "Building ALL kernels in parallel (prefer per-example builds for minimal compilation)..." + VERBATIM +) + +# ============================================================================= +# Force regeneration targets (useful when you want to regenerate) +# ============================================================================= + +add_custom_target(regenerate_gemm_kernels + COMMAND ${CMAKE_COMMAND} -E remove -f ${GEMM_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --variants standard preshuffle multi_d + --output ${KERNEL_OUTPUT_DIR} + COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Force regenerating GEMM kernels (standard + preshuffle + multi_d)..." + VERBATIM +) + +add_custom_target(regenerate_all_kernels + DEPENDS regenerate_gemm_kernels +) + +# Clean all per-example kernel directories +add_custom_target(clean_example_kernels + COMMAND ${CMAKE_COMMAND} -E echo "Removing per-example kernel directories..." + COMMAND find ${CMAKE_BINARY_DIR} -maxdepth 1 -type d -name "*_kernels" -exec rm -rf {} + + COMMENT "Cleaning all per-example kernel directories..." + VERBATIM +) + +# ============================================================================= +# Helper function to add a GPU example with force-included kernel +# ============================================================================= + +# Helper for GPU examples that use the dispatcher registry +# KERNEL_HEADER can be: +# - A registration header (register_all_kernels.hpp) - included directly in source +# - A specific kernel header - force-included via compiler flag +function(add_gpu_example NAME SOURCE KERNEL_HEADER) + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include # CK root include + ${CMAKE_CURRENT_SOURCE_DIR}/../include # Dispatcher include + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels # Generated kernels + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels/dispatcher_wrappers # Wrapper headers + ) + + # Check if using registration header (no force-include needed) + get_filename_component(HEADER_NAME ${KERNEL_HEADER} NAME) + if(HEADER_NAME STREQUAL "register_all_kernels.hpp") + # Registration header - examples include it directly + target_compile_options(${NAME} PRIVATE + -DGEMM_KERNEL_AVAILABLE=1 + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + else() + # Specific kernel header - force-include it + target_compile_options(${NAME} PRIVATE + -include ${KERNEL_HEADER} + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + endif() + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# Helper for standalone GPU examples (instantiate kernel directly, no pre-generated header) +function(add_standalone_gpu_example NAME SOURCE) + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include # CK root include + ${CMAKE_CURRENT_SOURCE_DIR}/../include # Dispatcher include + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels # Generated kernels (optional) + ) + + target_compile_options(${NAME} PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# Helper for declarative examples (configuration demo, still needs HIP compiler for CK headers) +function(add_declarative_example NAME SOURCE) + add_executable(${NAME} ${SOURCE}) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ) + + target_compile_options(${NAME} PRIVATE + -Wno-float-equal + -Wno-unused-variable + -Wno-undefined-func-template + -mllvm -enable-noalias-to-md-conversion=0 + ) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# ============================================================================= +# GEMM Examples +# ============================================================================= + +# Per-example kernel directories are created from DECL_KERNEL_SET declarations +# Each example gets its own: build/_kernels/ +# This prevents clashes during parallel compilation of multiple examples. + +# Helper function to add example with declarative kernel support +# Parses DECL_KERNEL_SET from source and generates ONLY the declared kernels +# This enables minimal builds: only kernels needed by this example are generated +# +# Key features: +# - Per-example kernel directories: build/_kernels/ (no clashes) +# - Automatic header inclusion: No hardcoded #include needed in source +# - Minimal builds: Only declared kernels are generated +# - Auto-regeneration: Kernels regenerated if directory missing +# - Parallel compilation: Each kernel is a separate translation unit +function(add_declarative_gpu_example NAME SOURCE) + set(EXAMPLE_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE}") + get_filename_component(EXAMPLE_STEM ${SOURCE} NAME_WE) + + # Per-example kernel directories + set(EXAMPLE_KERNEL_DIR "${CMAKE_BINARY_DIR}/${NAME}_kernels") + set(EXAMPLE_HEADER "${EXAMPLE_KERNEL_DIR}/${EXAMPLE_STEM}_kernels.hpp") + set(EXAMPLE_LIB "${EXAMPLE_KERNEL_DIR}/lib${NAME}_kernels.a") + set(EXAMPLE_SENTINEL "${EXAMPLE_KERNEL_DIR}/.generated") + + # Generate AND compile kernels in parallel at make time + # This avoids slow cmake and gets per-kernel progress + add_custom_command( + OUTPUT ${EXAMPLE_SENTINEL} ${EXAMPLE_LIB} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/example_kernel_builder.py + ${EXAMPLE_SOURCE} + --output-dir ${EXAMPLE_KERNEL_DIR} + --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include" + --gpu-target ${GPU_TARGET} + --jobs ${NPROC} + --target-name ${NAME} + COMMAND ${CMAKE_COMMAND} -E touch ${EXAMPLE_SENTINEL} + DEPENDS ${EXAMPLE_SOURCE} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts + COMMENT "[${NAME}] Generating and compiling kernels from DECL_KERNEL_SET..." + VERBATIM + ) + + add_custom_target(generate_${NAME}_kernels DEPENDS ${EXAMPLE_SENTINEL}) + + # Add the executable + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + # Link against the per-example kernel library + target_link_libraries(${NAME} PRIVATE ${EXAMPLE_LIB}) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${EXAMPLE_KERNEL_DIR} + ${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers + ) + + # Force-include the generated registration header + target_compile_options(${NAME} PRIVATE + -include ${EXAMPLE_HEADER} + -DGEMM_KERNEL_AVAILABLE=1 + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() + + # Only depends on generating THIS example's kernels + add_dependencies(${NAME} generate_${NAME}_kernels) +endfunction() + +# GEMM C++ examples with declarative kernel support +# Each example's C++ code contains DECL_KERNEL_SET which declares needed kernels +add_declarative_gpu_example(gemm_01_basic gemm/cpp/01_basic_gemm.cpp) +add_declarative_gpu_example(gemm_02_multi_size gemm/cpp/02_multi_size.cpp) +add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_validation.cpp) +add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp) +add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp) +add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp) + +# ============================================================================= +# GEMM Python Library - Single Fallback Kernel +# ============================================================================= + +# Generate a single fallback kernel for the Python library (fp16, rcr, compv4) +set(GEMM_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/gemm_python_fallback") +set(GEMM_FALLBACK_KERNEL "${GEMM_FALLBACK_KERNEL_DIR}/gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16.hpp") + +# Tile config JSON for single kernel generation +set(GEMM_FALLBACK_TILE_CONFIG "{\"tile_m\":[128],\"tile_n\":[128],\"tile_k\":[32],\"warp_m\":[2],\"warp_n\":[2],\"warp_k\":[1],\"warp_tile_m\":[32],\"warp_tile_n\":[32],\"warp_tile_k\":[16],\"pipeline\":[\"compv4\"],\"scheduler\":[\"intrawave\"],\"epilogue\":[\"cshuffle\"]}") + +# Generate single fallback kernel (not all 6000+ kernels) +add_custom_command( + OUTPUT ${GEMM_FALLBACK_KERNEL} + COMMAND ${CMAKE_COMMAND} -E make_directory ${GEMM_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --variants standard + --gpu-target ${GPU_TARGET} + --output-dir ${GEMM_FALLBACK_KERNEL_DIR} + --tile-config-json "${GEMM_FALLBACK_TILE_CONFIG}" + COMMENT "Generating single fallback GEMM kernel for Python library" + VERBATIM +) + +add_custom_target(generate_gemm_fallback_kernel DEPENDS ${GEMM_FALLBACK_KERNEL}) + +# GEMM dynamic library for Python +add_library(dispatcher_gemm_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/gemm_ctypes_lib.cpp) +target_link_libraries(dispatcher_gemm_lib PRIVATE ck_tile_dispatcher) +target_include_directories(dispatcher_gemm_lib PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${GEMM_FALLBACK_KERNEL_DIR} +) +target_compile_options(dispatcher_gemm_lib PRIVATE + -DCK_TILE_SINGLE_KERNEL_INCLUDE + -include ${GEMM_FALLBACK_KERNEL} + -DGFX_ARCH="${GPU_TARGET}" + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +if(hip_FOUND) + target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device hip::host) +endif() +add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel) + +message(STATUS "GEMM examples configured - kernels will be generated during 'make'") + +# Convenience target to build all Python ctypes libraries +add_custom_target(python_libs + DEPENDS dispatcher_gemm_lib + COMMENT "Building Python ctypes libraries (GEMM)" +) + +# ============================================================================= +# Per-Architecture Kernel Generation Targets +# ============================================================================= + +set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030) + +foreach(ARCH ${SUPPORTED_GPU_ARCHS}) + # GEMM kernels for this arch + add_custom_target(generate_gemm_kernels_${ARCH} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --gpu-target ${ARCH} + --output ${KERNEL_OUTPUT_DIR} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating GEMM kernels for ${ARCH}..." + VERBATIM + ) + + # Alias for kernels (GEMM only now) + add_custom_target(generate_kernels_${ARCH} + DEPENDS generate_gemm_kernels_${ARCH} + COMMENT "Generating all kernels for ${ARCH}..." + ) +endforeach() + +# ============================================================================= +# Summary +# ============================================================================= + +message(STATUS "") +message(STATUS "=== Dispatcher Examples Configuration ===") +message(STATUS "") +message(STATUS "Kernels will be generated automatically during 'make'") +message(STATUS " Generated to: ${KERNEL_OUTPUT_DIR}") +message(STATUS "") +message(STATUS "Build targets:") +message(STATUS " make - Build all examples (generates kernels first)") +message(STATUS " make python_libs - Build Python ctypes libraries") +message(STATUS " make generate_all_kernels - Generate all kernels only") +message(STATUS " make regenerate_all_kernels - Force regenerate all kernels") +message(STATUS "") +message(STATUS "Per-architecture targets:") +message(STATUS " make generate_kernels_ - Generate for specific arch") +message(STATUS " Supported archs: ${SUPPORTED_GPU_ARCHS}") +message(STATUS "") diff --git a/dispatcher/examples/README.md b/dispatcher/examples/README.md new file mode 100644 index 0000000000..fdee9c3583 --- /dev/null +++ b/dispatcher/examples/README.md @@ -0,0 +1,210 @@ +# CK Tile Dispatcher Examples + +Comprehensive examples for GEMM operations with GPU execution. + +> **Note**: Convolution examples have been moved to `ck-2/conv_archive/` for reference. + +--- + +## Quick Start + +### Step 1: Build + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build everything (C++ examples + Python libraries) +make -j$(nproc) + +# Or build ONLY Python libraries (faster) +make python_libs -j$(nproc) +``` + +### Step 2: Run C++ Examples + +```bash +cd build/examples + +# GEMM +./gemm_01_basic +./gemm_02_multi_size +./gemm_03_benchmark_validation +./gemm_04_heuristics +./gemm_05_json_export +./gemm_06_multi_registry +``` + +### Step 3: Run Python Examples + +```bash +cd /path/to/composable_kernel/dispatcher + +# GEMM +python3 examples/gemm/python/01_basic_gemm.py +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py +python3 examples/gemm/python/08_heuristics.py +``` + +--- + +## Directory Structure + +``` +examples/ +├── gemm/ +│ ├── cpp/ # 6 C++ GEMM examples +│ └── python/ # 11 Python GEMM examples +│ +└── README.md +``` + +--- + +## GEMM Examples + +### C++ Examples + +| # | Example | Description | +|---|---------|-------------| +| 01 | `gemm_01_basic` | Basic GEMM with declarative API, autofill, autocorrect | +| 02 | `gemm_02_multi_size` | Wildcard expansion for multiple configurations | +| 03 | `gemm_03_benchmark_validation` | Performance benchmarking with CPU/GPU validation | +| 04 | `gemm_04_heuristics` | Heuristic-based kernel selection | +| 05 | `gemm_05_json_export` | Registry JSON export for external tools | +| 06 | `gemm_06_multi_registry` | Multiple registries with named kernel sets | + +**Details:** [gemm/cpp/README.md](gemm/cpp/README.md) + +--- + +### Python Examples + +| # | Example | Description | +|---|---------|-------------| +| 01 | `01_basic_gemm.py` | Basic GEMM with multi-kernel support | +| 02 | `02_batch_gemm.py` | Batched GEMM operations | +| 03 | `03_benchmark.py` | Performance benchmarking | +| 04 | `04_validation.py` | CPU reference validation | +| 05 | `05_numpy_integration.py` | NumPy array integration | +| 06 | `06_json_export.py` | Registry JSON export | +| 07 | `07_stress_test.py` | Multi-kernel stress testing (48 configs) | +| 08 | `08_heuristics.py` | Heuristic-based kernel selection (24 configs) | +| 09 | `09_multi_registry.py` | Multiple registries | +| 10 | `10_advanced_benchmark.py` | Advanced benchmark with full control | +| 11 | `11_json_import.py` | Import kernels from JSON | + +**Details:** [gemm/python/README.md](gemm/python/README.md) + +--- + +## Key Features + +### Declarative Kernel API + +Both C++ and Python examples use a declarative approach: + +**C++ (DECL_KERNEL_SET macro):** +```cpp +DECL_KERNEL_SET(my_kernels, + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16) + .pipeline("compv4").scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Python (KernelConfig):** +```python +config = KernelConfig( + tile_m=256, tile_n=256, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave" +) +``` + +### Autofill and Autocorrect + +The build system automatically: +- **Autofills** missing parameters with sensible defaults +- **Autocorrects** invalid parameters based on architecture constraints +- **Expands** wildcards (`*`, `-1`, `ANY_INT`) to all valid configurations + +### Architecture Filtering + +Kernel configurations are validated against GPU architecture constraints: +- Tile divisibility requirements +- Warp tile constraints +- Pipeline compatibility + +Invalid configurations are automatically pruned during code generation. + +--- + +## Validation Examples + +### C++ Validation + +```bash +./gemm_03_benchmark_validation --verify 1 # GEMM with CPU reference +./gemm_03_benchmark_validation --verify 2 # GEMM with GPU reference +``` + +### Python Validation + +```bash +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py # Multi-kernel validation +``` + +--- + +## Troubleshooting + +### Python: Library not found + +```bash +# Run from dispatcher directory +cd /path/to/composable_kernel/dispatcher +python3 examples/gemm/python/01_basic_gemm.py +``` + +### C++: Executables not found + +```bash +# Build with examples enabled +cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON +make -j$(nproc) + +# Run from build/examples +cd build/examples +./gemm_01_basic +``` + +### GPU not detected + +```bash +rocminfo | grep "Name:" +# Should show: gfx942, gfx90a, etc. +``` + +--- + +## Archived Examples + +Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`: +- `examples/conv/cpp/` - 11 C++ convolution examples +- `examples/conv/python/` - 14 Python convolution examples + +See the archive for convolution functionality reference. diff --git a/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp b/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp new file mode 100644 index 0000000000..80b584a842 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp @@ -0,0 +1,243 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 01: Basic GEMM - Autofill, Autocorrect, and Full Declaration + * + * Demonstrates THREE declaration patterns: + * + * 1. AUTOFILL: Minimal declaration - missing params filled with defaults + * .add(Signature().dtype("fp16").layout("rcr"), + * Algorithm().tile(128,128,64).pipeline("compv3").scheduler("intrawave"), + * "gfx942") + * -> wave(2,2,1), warp(32,32,16), epilogue("cshuffle") added automatically + * + * 2. AUTOCORRECT: Invalid params corrected to valid values + * .add(..., Algorithm().wave(1,1,1)...) + * -> wave(1,1,1) is invalid for gfx942, corrected to wave(2,2,1) + * + * 3. FULL: All parameters explicitly specified + * .add(..., Algorithm().tile().wave().warp().pipeline().scheduler().epilogue()...) + * + * Build: cd dispatcher/build && cmake .. && make gemm_01_basic + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// THREE KERNEL DECLARATION PATTERNS +// ============================================================================= + +DECL_KERNEL_SET( + basic_gemm_kernels, + // ------------------------------------------------------------------------- + // Pattern 1: AUTOFILL - Minimal declaration + // Only specify: dtype, layout, tile, pipeline, scheduler + // Auto-filled: wave(2,2,1), warp(32,32,16), epilogue("cshuffle"), pad(false,false,false) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) // Required + .pipeline("compv3") // Required + .scheduler("intrawave"), // Required + "gfx942") + + // ------------------------------------------------------------------------- + // Pattern 2: AUTOCORRECT - Invalid wave config + // wave(1,1,1) is invalid for gfx942 WMMA, corrected to wave(2,2,1) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) // Different tile_k to make unique kernel + .wave(1, 1, 1) // INVALID: autocorrected to (2,2,1) + .warp(32, 32, 16) // Valid warp for 128x128 tile + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + + // ------------------------------------------------------------------------- + // Pattern 3: FULL - All parameters explicitly specified + // No autofill or autocorrect needed + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) // Explicit tile + .wave(2, 2, 1) // Explicit wave (valid) + .warp(16, 16, 32) // Explicit warp tile + .pipeline("compv3") // Explicit pipeline + .scheduler("intrawave") // Explicit scheduler + .epilogue("cshuffle") // Explicit epilogue + .pad(false, false, false), // Explicit padding + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 01: GEMM Autofill/Autocorrect/Full", + "Three kernel declaration patterns"); + args.add_flag("--list", "List registered kernels"); + args.add_flag("--list-verbose", "List registered kernels with full details"); + args.add_option("--size", "1024", "Problem size MxNxK"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 01: GEMM Declaration Patterns"); + + // ========================================================================= + // Show the Three Patterns + // ========================================================================= + std::cout << "\nTHREE DECLARATION PATTERNS:\n"; + std::cout << "============================\n\n"; + + std::cout << "1. AUTOFILL (minimal declaration):\n"; + std::cout << " .add(Signature().dtype(\"fp16\").layout(\"rcr\"),\n"; + std::cout + << " Algorithm().tile(128,128,64).pipeline(\"compv3\").scheduler(\"intrawave\"),\n"; + std::cout << " \"gfx942\")\n"; + std::cout << " -> Auto-filled: wave(2,2,1), warp(32,32,16), epilogue(\"cshuffle\")\n\n"; + + std::cout << "2. AUTOCORRECT (invalid params fixed):\n"; + std::cout << " .add(..., Algorithm().wave(1,1,1)...)\n"; + std::cout << " -> wave(1,1,1) invalid for gfx942, corrected to wave(2,2,1)\n\n"; + + std::cout << "3. FULL (all params explicit):\n"; + std::cout << " .add(..., " + "Algorithm().tile().wave().warp().pipeline().scheduler().epilogue().pad()...)\n"; + std::cout << " -> No changes needed\n\n"; + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Step 1: Show Declared Kernel Sets + // ========================================================================= + std::cout << "Step 1: Declared Kernel Sets\n"; + KernelSetRegistry::instance().print(); + + const auto& decl_set = KernelSetRegistry::instance().get("basic_gemm_kernels"); + std::cout << " 'basic_gemm_kernels': " << decl_set.size() << " declaration(s)\n"; + + // ========================================================================= + // Step 2: Create Registry and Register Kernels + // ========================================================================= + std::cout << "\nStep 2: Register Kernels\n"; + + Registry registry; + // Use generic macro + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + // List kernels if requested + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + // ========================================================================= + // Step 3: Create Dispatcher + // ========================================================================= + std::cout << "\nStep 3: Create Dispatcher\n"; + Dispatcher dispatcher(®istry); + + // ========================================================================= + // Step 4: Setup Problem + // ========================================================================= + int size = args.get_int("--size", 1024); + const int M = size, N = size, K = size; + + std::cout << "\nStep 4: Setup Problem (" << M << "x" << N << "x" << K << ")\n"; + + Problem problem(M, N, K); + + using DataType = ck_tile::fp16_t; + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // ========================================================================= + // Step 5: Select and Run + // ========================================================================= + std::cout << "\nStep 5: Select and Run\n"; + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "ERROR: No kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->get_name() << "\n"; + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n"; + + // ========================================================================= + // Step 6: Verify + // ========================================================================= + std::cout << "\nStep 6: Verify\n"; + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + + const float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + if(std::abs(static_cast(c_host[i]) - expected) > 0.01f * expected + 1.0f) + ++errors; + } + + bool passed = (errors == 0); + std::cout << " Expected: " << expected << ", Errors: " << errors << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "DECLARATION PATTERNS SUMMARY:\n"; + print_separator(); + std::cout << R"( + 1. AUTOFILL: Specify only required params, system fills defaults + - Useful for quick prototyping + - Guarantees valid configuration + + 2. AUTOCORRECT: System validates and fixes invalid params + - wave(1,1,1) -> wave(2,2,1) on gfx942 + - Invalid pipeline/scheduler combos fixed + - Logs corrections for debugging + + 3. FULL: All params explicit - no changes made + - Full control over configuration + - Best for production/tuning +)"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/02_multi_size.cpp b/dispatcher/examples/gemm/cpp/02_multi_size.cpp new file mode 100644 index 0000000000..5e620209f4 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/02_multi_size.cpp @@ -0,0 +1,215 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 02: Multi-Size GEMM with Wildcard Expansion + * + * Demonstrates the WILDCARD feature where specifying wildcards causes + * the build system to expand to ALL valid configurations for the architecture. + * + * WILDCARD SYNTAX: + * - Integer params: ANY_INT or -1 (both are equivalent, ANY_INT is just a #define for -1) + * - String params: "*" (for pipeline, scheduler) + * + * The kernel declaration: + * .add(..., Algorithm().tile(64,64,64).wave(ANY_INT,ANY_INT,1).warp(-1,-1,-1) + * .pipeline("*").scheduler("*"), ...) + * + * Expands to multiple kernels: + * - wave: (1,4,1), (2,2,1), (4,1,1) -> 3 options + * - warp: (16,16,32), (32,32,16) -> 2 options + * - pipeline: "compv3" -> 1 option (compv4 requires special handling) + * - scheduler: "intrawave" -> 1 option + * + * Raw expansion: 3 × 2 = 6 configs, but arch filter validates each: + * - tile_m must be divisible by (warp_m × warp_tile_m) + * - tile_n must be divisible by (warp_n × warp_tile_n) + * - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16) + * Result: 4 valid wildcard kernels + 1 explicit = 5 total + * + * Build: cd dispatcher/build && cmake .. && make gemm_02_multi_size + * Usage: ./gemm_02_multi_size [--max-size N] [--help] + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Demonstrates Wildcard Expansion +// ============================================================================= + +DECL_KERNEL_SET(multi_size_kernels, + // ------------------------------------------------------------------------- + // Kernel 1: Explicit - all parameters specified (no expansion) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + + // ------------------------------------------------------------------------- + // Kernel 2: WILDCARD - expands to multiple valid configurations + // Wildcards: ANY_INT == -1 (for integers), "*" (for strings) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 64) + .wave(ANY_INT, ANY_INT, 1) // ANY_INT → (1,4,1), (2,2,1), (4,1,1) + .warp(-1, -1, -1) // -1 same as ANY_INT → (16,16,32), (32,32,16) + .pipeline("*") // "*" → valid pipelines + .scheduler("*") // "*" → valid schedulers + .epilogue("cshuffle"), + "gfx942")); +// Raw: 3×2=6, arch filter removes 2 invalid → 4 valid kernels + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 02: Multi-Size GEMM with Wildcards", + "Demonstrates wildcard expansion for kernel generation"); + args.add_option("--max-size", "4096", "Maximum problem size to test"); + args.add_option("--arch", "gfx942", "GPU architecture"); + args.add_flag("--list", "List all registered kernels"); + args.add_flag("--list-verbose", "List kernels with full configuration details"); + + if(!args.parse(argc, argv)) + return 0; + + int max_size = args.get_int("--max-size", 4096); + std::string gfx_arch = args.get("--arch", "gfx942"); + + print_header("Example 02: Multi-Size GEMM with Wildcards"); + + // ========================================================================= + // Show Wildcard Expansion Concept + // ========================================================================= + std::cout << "\nWILDCARD EXPANSION:\n"; + std::cout << "===================\n"; + std::cout << R"( + Wildcard syntax: + ANY_INT or -1 -> expands integer params to all valid values + "*" -> expands string params (pipeline/scheduler) to valid values + + Declaration with wildcards: + .tile(64, 64, 64) -> fixed tile size (no wildcard) + .wave(ANY_INT, ANY_INT, 1) -> expands to (1,4,1), (2,2,1), (4,1,1) = 3 + .warp(-1, -1, -1) -> expands to (16,16,32), (32,32,16) = 2 + .pipeline("*") -> expands to valid pipelines = 1 + .scheduler("*") -> expands to valid schedulers = 1 + + Expanded: 3 × 2 = 6 configs, but arch filter validates each: + - wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64 + - Result: 4 valid kernels from wildcard + 1 explicit = 5 total +)"; + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + std::cout << "\nStep 1: Register Kernels\n"; + std::cout << "------------------------\n"; + + Registry registry; + registry.set_name("multi_size_registry"); + + // Register kernels from generated header (includes expanded wildcards) + // Use generic macro - no need to hardcode example name + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s) from wildcard expansion\n"; + + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + Dispatcher dispatcher(®istry); + std::cout << " Max size: " << max_size << "\n"; + + // ========================================================================= + // Run Multiple Problem Sizes + // ========================================================================= + std::cout << "\nStep 2: Run Multiple Sizes\n"; + print_separator(); + std::cout << std::setw(12) << "M" << std::setw(12) << "N" << std::setw(12) << "K" + << std::setw(12) << "Time(ms)" << std::setw(12) << "TFLOPS" << "\n"; + print_separator(); + + std::vector> all_sizes = { + {256, 256, 256}, + {512, 512, 512}, + {1024, 1024, 1024}, + {2048, 2048, 2048}, + {4096, 4096, 4096}, + }; + + std::vector> sizes; + for(const auto& [M, N, K] : all_sizes) + { + if(std::max({M, N, K}) <= max_size) + sizes.push_back({M, N, K}); + } + + using DataType = ck_tile::fp16_t; + bool all_passed = true; + + for(const auto& [M, N, K] : sizes) + { + Problem problem(M, N, K); + + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = calculate_tflops(M, N, K, time_ms); + + std::cout << std::setw(12) << M << std::setw(12) << N << std::setw(12) << K << std::setw(12) + << std::fixed << std::setprecision(4) << time_ms << std::setw(12) + << std::setprecision(2) << tflops << "\n"; + + // Verify + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + if(std::abs(static_cast(c_host[i]) - expected) > 0.01f * expected + 1.0f) + ++errors; + } + if(errors > 0) + all_passed = false; + } + + print_separator(); + std::cout << "Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + print_separator(); + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp b/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp new file mode 100644 index 0000000000..61608c7914 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp @@ -0,0 +1,344 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 03: GEMM Benchmark & Validation + * + * Combined example demonstrating: + * 1. Benchmarking with statistics (warmup, iterations, min/max/mean/median) + * 2. Validation against CK Tile reference (CPU or GPU) + * + * Build: cd dispatcher/build && cmake .. && make gemm_03_benchmark_validation + * Usage: ./gemm_03_benchmark_validation [--size N] [--verify MODE] [--benchmark] + * + * Options: + * --size N Problem size MxNxK (default: 512) + * --verify MODE 0=none, 1=CPU ref, 2=GPU ref (default: 1) + * --benchmark Run full benchmark with statistics + * --warmup N Warmup iterations (default: 5) + * --iterations N Benchmark iterations (default: 20) + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/reference/reference_gemm.hpp" + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using namespace ck_tile::literals; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: High-performance kernels for benchmarking/validation +// ============================================================================= + +DECL_KERNEL_SET(benchmark_validation_kernels, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// Helper: Layout detection +// ============================================================================= + +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 03: GEMM Benchmark & Validation", + "Benchmark and/or validate GEMM output against reference"); + args.add_option("--size", "512", "Problem size MxNxK"); + args.add_option("--verify", "1", "Verification: 0=none, 1=CPU ref, 2=GPU ref"); + args.add_flag("--benchmark", "Run benchmark with statistics"); + args.add_option("--warmup", "5", "Warmup iterations"); + args.add_option("--iterations", "20", "Benchmark iterations"); + args.add_option("--rtol", "0.01", "Relative tolerance"); + args.add_option("--atol", "0.01", "Absolute tolerance"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + int M = args.get_int("--size", 512); + int N = M; + int K = M; + int verify = args.get_int("--verify", 1); + bool do_benchmark = args.has("--benchmark"); + int warmup = args.get_int("--warmup", 5); + int iterations = args.get_int("--iterations", 20); + float rtol = args.get_float("--rtol", 0.01f); + float atol = args.get_float("--atol", 0.01f); + std::string gfx_arch = args.get("--arch", "gfx942"); + + print_header("Example 03: GEMM Benchmark & Validation"); + + std::cout << "\nConfiguration:\n"; + std::cout << " Problem: " << M << " x " << N << " x " << K << "\n"; + std::cout << " Layout: RCR (A=row, B=col, C=row)\n"; + std::cout << " Verify: " << verify; + if(verify == 0) + std::cout << " (disabled)"; + else if(verify == 1) + std::cout << " (CPU reference)"; + else if(verify == 2) + std::cout << " (GPU reference)"; + std::cout << "\n"; + std::cout << " Benchmark: " << (do_benchmark ? "yes" : "no") << "\n"; + if(do_benchmark) + { + std::cout << " Warmup: " << warmup << " iterations\n"; + std::cout << " Measure: " << iterations << " iterations\n"; + } + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + Registry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + Dispatcher dispatcher(®istry); + + std::cout << " Kernels: " << registry.size() << " registered\n"; + print_registered_kernels(registry); + + // ========================================================================= + // Initialize data with proper tensor descriptors + // ========================================================================= + using ALayout = ck_tile::tensor_layout::gemm::RowMajor; + using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayout = ck_tile::tensor_layout::gemm::RowMajor; + + using ADataType = ck_tile::fp16_t; + using BDataType = ck_tile::fp16_t; + using CDataType = ck_tile::fp16_t; + using AccDataType = float; + + auto stride_a = ck_tile::get_default_stride(M, K, 0_uz, is_row_major(ALayout{})); + auto stride_b = ck_tile::get_default_stride(K, N, 0_uz, is_row_major(BLayout{})); + auto stride_c = ck_tile::get_default_stride(M, N, 0_uz, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_a, is_row_major(ALayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_b, is_row_major(BLayout{}))); + ck_tile::HostTensor c_m_n_dev( + ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{}))); + ck_tile::HostTensor c_m_n_ref( + ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{}))); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(b_k_n); + + std::cout << "\nData:\n"; + std::cout << " A: " << M << " x " << K << " (fp16, row-major)\n"; + std::cout << " B: " << K << " x " << N << " (fp16, col-major)\n"; + std::cout << " C: " << M << " x " << N << " (fp16, row-major)\n"; + + // GPU memory + ck_tile::DeviceMem a_dev(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_dev(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev(c_m_n_dev.get_element_space_size_in_bytes()); + + a_dev.ToDevice(a_m_k.data()); + b_dev.ToDevice(b_k_n.data()); + + // ========================================================================= + // Compute Reference (if needed) + // ========================================================================= + if(verify > 0) + { + std::cout << "\nComputing reference...\n"; + c_m_n_ref.SetZero(); + + if(verify == 1) + { + std::cout << " Using CPU reference (ck_tile::reference_gemm)\n"; + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_ref); + } + else if(verify == 2) + { + std::cout << " Using GPU reference (ck_tile::reference_gemm_gpu)\n"; + ck_tile::DeviceMem c_ref_dev(c_m_n_ref.get_element_space_size_in_bytes()); + c_ref_dev.SetZero(); + + ck_tile::reference_gemm_gpu( + static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_ref_dev.GetDeviceBuffer()), + M, + N, + K, + stride_a, + stride_b, + stride_c); + + (void)hipDeviceSynchronize(); + c_ref_dev.FromDevice(c_m_n_ref.data()); + } + std::cout << " Reference complete.\n"; + } + + // ========================================================================= + // Run Kernel + // ========================================================================= + Problem problem(M, N, K); + auto selected = dispatcher.select_kernel(problem); + + std::cout << "\nRunning kernel:\n"; + if(selected) + std::cout << " Selected: " << selected->get_name() << "\n"; + + c_dev.SetZero(); + float time_ms = 0.0f; + std::vector times; + + if(do_benchmark) + { + // Warmup + std::cout << " Warming up (" << warmup << " iterations)...\n"; + for(int i = 0; i < warmup; ++i) + { + c_dev.SetZero(); + (void)dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + } + + // Benchmark + std::cout << " Benchmarking (" << iterations << " iterations)...\n"; + times.reserve(iterations); + for(int i = 0; i < iterations; ++i) + { + c_dev.SetZero(); + float t = dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + times.push_back(t); + } + time_ms = *std::min_element(times.begin(), times.end()); + } + else + { + // Single run + time_ms = dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + } + + c_dev.FromDevice(c_m_n_dev.data()); + + // ========================================================================= + // Results + // ========================================================================= + double flops = 2.0 * M * N * K; + double tflops = flops / (time_ms * 1e9); + + print_separator(); + std::cout << "Performance:\n"; + print_separator(); + + if(do_benchmark && !times.empty()) + { + std::sort(times.begin(), times.end()); + float min_t = times.front(); + float max_t = times.back(); + float median_t = times[times.size() / 2]; + float mean_t = std::accumulate(times.begin(), times.end(), 0.0f) / times.size(); + + std::cout << std::fixed << std::setprecision(4); + std::cout << " Min: " << min_t << " ms (" << std::setprecision(2) + << (flops / (min_t * 1e9)) << " TFLOPS)\n"; + std::cout << std::setprecision(4); + std::cout << " Max: " << max_t << " ms\n"; + std::cout << " Mean: " << mean_t << " ms (" << std::setprecision(2) + << (flops / (mean_t * 1e9)) << " TFLOPS)\n"; + std::cout << std::setprecision(4); + std::cout << " Median: " << median_t << " ms (" << std::setprecision(2) + << (flops / (median_t * 1e9)) << " TFLOPS)\n"; + } + else + { + std::cout << std::fixed << std::setprecision(4); + std::cout << " Time: " << time_ms << " ms\n"; + std::cout << std::setprecision(2); + std::cout << " TFLOPS: " << tflops << "\n"; + } + + // ========================================================================= + // Validation + // ========================================================================= + bool pass = true; + + if(verify > 0) + { + print_separator(); + std::cout << "Validation:\n"; + print_separator(); + std::cout << " Tolerance: rtol=" << rtol << ", atol=" << atol << "\n"; + + pass = ck_tile::check_err(c_m_n_dev, c_m_n_ref, "Validation Error!", rtol, atol); + + float max_abs_diff = 0.0f; + float max_rel_diff = 0.0f; + for(size_t i = 0; i < c_m_n_dev.get_element_space_size(); ++i) + { + float dev_val = static_cast(c_m_n_dev.mData[i]); + float ref_val = static_cast(c_m_n_ref.mData[i]); + float abs_diff = std::abs(dev_val - ref_val); + float rel_diff = (ref_val != 0.0f) ? abs_diff / std::abs(ref_val) : abs_diff; + max_abs_diff = std::max(max_abs_diff, abs_diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + } + + std::cout << " Max abs diff: " << max_abs_diff << "\n"; + std::cout << " Max rel diff: " << max_rel_diff << "\n"; + } + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "Result: " << (pass ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return pass ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/04_heuristics.cpp b/dispatcher/examples/gemm/cpp/04_heuristics.cpp new file mode 100644 index 0000000000..2a8753cdff --- /dev/null +++ b/dispatcher/examples/gemm/cpp/04_heuristics.cpp @@ -0,0 +1,168 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 04: Custom Heuristics + * + * Demonstrates custom kernel selection heuristics for different workloads. + * + * Build: cd dispatcher/build && cmake .. && make gemm_04_heuristics + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Multiple tile sizes for heuristic-based selection +// ============================================================================= + +DECL_KERNEL_SET(heuristics_kernels, + // Small tile - low latency + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + // Medium tile - balanced + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// Custom Heuristic +// ============================================================================= + +std::vector size_based_heuristic(const Problem& problem) +{ + std::vector ranked_kernels; + int64_t total_elements = problem.M * problem.N; + + if(total_elements < 100000) + { + ranked_kernels = {"gemm_64x64", "gemm_128x128"}; + } + else + { + ranked_kernels = {"gemm_128x128", "gemm_64x64"}; + } + return ranked_kernels; +} + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 04: Custom Heuristics", + "Demonstrates custom kernel selection heuristics"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 04: Custom Heuristics"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + Registry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + Dispatcher dispatcher(®istry); + dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic); + dispatcher.set_heuristic(size_based_heuristic); + + std::cout << "\nSetup:\n"; + std::cout << " Registry: " << registry.size() << " kernel(s)\n"; + std::cout << " Strategy: Heuristic (size-based)\n"; + + // ========================================================================= + // Test Different Problem Sizes + // ========================================================================= + std::cout << "\nTesting heuristic selection:\n"; + print_separator(); + + using DataType = ck_tile::fp16_t; + + std::vector> sizes = { + {128, 128, 64}, + {512, 512, 256}, + {2048, 2048, 1024}, + }; + + bool all_passed = true; + + for(const auto& [M, N, K] : sizes) + { + Problem problem(M, N, K); + auto selected = dispatcher.select_kernel(problem); + + std::cout << "Problem " << M << "x" << N << "x" << K << ":\n"; + if(selected) + { + std::cout << " Selected: " << selected->get_name() << "\n"; + } + + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = calculate_tflops(M, N, K, time_ms); + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Verify + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + float actual = static_cast(c_host[i]); + if(std::abs(actual - expected) > 0.01f * expected + 1.0f) + ++errors; + } + bool pass = (errors == 0); + std::cout << " Verify: " << (pass ? "PASS" : "FAIL") << "\n"; + if(!pass) + all_passed = false; + print_separator(); + } + + std::cout << "Overall: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/05_json_export.cpp b/dispatcher/examples/gemm/cpp/05_json_export.cpp new file mode 100644 index 0000000000..75ed7308af --- /dev/null +++ b/dispatcher/examples/gemm/cpp/05_json_export.cpp @@ -0,0 +1,127 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 05: JSON Export + * + * Demonstrates exporting registry information to JSON format. + * + * Build: cd dispatcher/build && cmake .. && make gemm_05_json_export + */ + +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Multiple kernels for JSON export demo +// ============================================================================= + +DECL_KERNEL_SET(json_export_kernels, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 05: JSON Export", "Export registry information to JSON format"); + args.add_option("--output", "registry.json", "Output JSON file path"); + args.add_option("--arch", "gfx942", "GPU architecture"); + args.add_flag("--list", "List all kernel sets"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 05: JSON Export"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + if(args.has("--list")) + { + std::cout << "\nDeclared Kernel Sets:\n"; + KernelSetRegistry::instance().print(); + return 0; + } + + std::string output_file = args.get("--output", "registry.json"); + + // ========================================================================= + // Setup Registry + // ========================================================================= + std::cout << "\nSetting up registry...\n"; + Registry registry; + registry.set_name("json_export_registry"); + + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registry: " << registry.get_name() << "\n"; + std::cout << " Kernels: " << registry.size() << "\n"; + + // ========================================================================= + // Export to JSON + // ========================================================================= + std::cout << "\nExporting to JSON...\n"; + + std::string json = registry.export_json(true); + + std::cout << "\nJSON Preview (first 500 chars):\n"; + print_separator(); + std::cout << json.substr(0, std::min(size_t(500), json.size())); + if(json.size() > 500) + std::cout << "\n..."; + std::cout << "\n"; + print_separator(); + + // Write to file + std::ofstream file(output_file); + if(file.is_open()) + { + file << json; + file.close(); + std::cout << "\nExported to: " << output_file << "\n"; + std::cout << "File size: " << json.size() << " bytes\n"; + } + else + { + std::cerr << "Failed to write to: " << output_file << "\n"; + return 1; + } + + // ========================================================================= + // Also show kernel set declarations + // ========================================================================= + std::cout << "\nKernel Set Declarations:\n"; + print_separator(); + KernelSetRegistry::instance().print(); + print_separator(); + + return 0; +} diff --git a/dispatcher/examples/gemm/cpp/06_multi_registry.cpp b/dispatcher/examples/gemm/cpp/06_multi_registry.cpp new file mode 100644 index 0000000000..3077f2d754 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/06_multi_registry.cpp @@ -0,0 +1,294 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 06: Multiple Registries and Multiple Kernel Sets + * + * Demonstrates: + * - Multiple DECL_KERNEL_SET declarations (each with multiple kernels) + * - Separate Registry instances for different workload types + * - Independent Dispatchers that select from their respective registries + * + * Registration patterns: + * - REGISTER_GENERATED_KERNELS(registry, arch) -> all kernels to one registry + * - REGISTER_KERNEL_SET("set_name", registry, arch) -> specific set by name + * - generated::get_kernel_set_names() -> list available set names + * + * Build: cd dispatcher/build && cmake .. && make gemm_06_multi_registry + * Usage: ./gemm_06_multi_registry [--list] [--help] + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SETS: Multiple sets with multiple kernels each +// ============================================================================= + +// Compute-bound kernel set: Large tiles for high arithmetic intensity +// Max tile with 32x32 warp is 128x128 (16 warps = 1024 threads) +DECL_KERNEL_SET(compute_bound_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) // Large tile, max for 32x32 warp + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) // Same tile, different K for variety + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// Memory-bound kernel set: Smaller tiles for better cache efficiency +DECL_KERNEL_SET(memory_bound_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// Latency-optimized: Minimal overhead tiles +DECL_KERNEL_SET(latency_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 06: Multiple Registries", + "Separate registries for different workload types"); + args.add_flag("--list", "List all declared kernel sets"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 06: Multiple Registries & Kernel Sets"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Step 1: Show declared kernel sets (from DECL_KERNEL_SET macros) + // ========================================================================= + std::cout << "\nStep 1: Declared Kernel Sets\n"; + std::cout << "-----------------------------\n"; + KernelSetRegistry::instance().print(); + + if(args.has("--list")) + { + // Print detailed info + for(const auto& name : KernelSetRegistry::instance().names()) + { + const auto& set = KernelSetRegistry::instance().get(name); + std::cout << "\n " << name << ":\n"; + for(const auto& decl : set.declarations()) + { + std::cout << " - " << decl.name() << " (tile=" << decl.algorithm.tile_m_ << "x" + << decl.algorithm.tile_n_ << "x" << decl.algorithm.tile_k_ << ")\n"; + } + } + return 0; + } + + // ========================================================================= + // Step 2: Create registries and demonstrate MERGING + // ========================================================================= + std::cout << "\nStep 2: Create and Merge Registries\n"; + std::cout << "------------------------------------\n"; + + // Create individual registries first + Registry compute_registry; + Registry latency_registry; + Registry memory_registry; + + compute_registry.set_name("compute_bound"); + latency_registry.set_name("latency_optimized"); + memory_registry.set_name("memory_bound"); + + // Register kernels to individual registries using set names (no hardcoding) + REGISTER_KERNEL_SET("compute_bound_set", compute_registry, gfx_arch); + REGISTER_KERNEL_SET("latency_set", latency_registry, gfx_arch); + REGISTER_KERNEL_SET("memory_bound_set", memory_registry, gfx_arch); + + std::cout << " Individual registries:\n"; + std::cout << " compute_bound: " << compute_registry.size() << " kernel(s)\n"; + std::cout << " latency_optimized: " << latency_registry.size() << " kernel(s)\n"; + std::cout << " memory_bound: " << memory_registry.size() << " kernel(s)\n"; + + // MERGE compute + latency into a combined registry + Registry combined_registry; + combined_registry.set_name("compute_latency_combined"); + + // Register both sets into combined registry + REGISTER_KERNEL_SET("compute_bound_set", combined_registry, gfx_arch); + REGISTER_KERNEL_SET("latency_set", combined_registry, gfx_arch); + + std::cout << "\n After merging compute + latency:\n"; + std::cout << " combined: " << combined_registry.size() << " kernel(s)\n"; + std::cout << " memory (separate): " << memory_registry.size() << " kernel(s)\n"; + + // ========================================================================= + // Step 3: Create dispatchers - one merged, one separate + // ========================================================================= + std::cout << "\nStep 3: Create Dispatchers\n"; + std::cout << "--------------------------\n"; + + Dispatcher combined_dispatcher(&combined_registry); // compute + latency merged + Dispatcher memory_dispatcher(&memory_registry); // memory separate + + std::cout << " combined_dispatcher: compute + latency kernels (" << combined_registry.size() + << " kernels)\n"; + std::cout << " memory_dispatcher: memory-bound kernels (" << memory_registry.size() + << " kernels)\n"; + + // ========================================================================= + // Step 4: Run with different dispatchers + // ========================================================================= + std::cout << "\nStep 4: Run Workloads\n"; + print_separator(); + + using DataType = ck_tile::fp16_t; + + struct WorkloadTest + { + const char* name; + Dispatcher* dispatcher; + int M, N, K; + }; + + std::vector tests = { + {"Compute-bound (combined)", &combined_dispatcher, 4096, 4096, 4096}, + {"Memory-bound (separate)", &memory_dispatcher, 1024, 1024, 1024}, + {"Latency-opt (combined)", &combined_dispatcher, 512, 512, 512}, + }; + + bool all_passed = true; + + for(const auto& test : tests) + { + Problem problem(test.M, test.N, test.K); + + // Allocate and initialize + GpuBuffer a_dev(test.M * test.K); + GpuBuffer b_dev(test.K * test.N); + GpuBuffer c_dev(test.M * test.N); + + std::vector a_host(test.M * test.K, DataType(1.0f)); + std::vector b_host(test.K * test.N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // Select kernel and run + auto selected = test.dispatcher->select_kernel(problem); + float time_ms = + test.dispatcher->run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = calculate_tflops(test.M, test.N, test.K, time_ms); + + std::cout << test.name << " (" << test.M << "x" << test.N << "x" << test.K << "):\n"; + if(selected) + std::cout << " Selected: " << selected->get_name() << "\n"; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Verify ALL elements + std::vector c_host(test.M * test.N); + c_dev.copy_to_host(c_host.data()); + const float expected = static_cast(test.K); + + int num_errors = 0; + float max_error = 0.0f; + for(int i = 0; i < test.M * test.N; ++i) + { + float actual = static_cast(c_host[i]); + float error = std::abs(actual - expected); + max_error = std::max(max_error, error); + // Allow 1% relative tolerance for FP16 accumulation + if(error > 0.01f * expected + 1.0f) + ++num_errors; + } + + bool test_passed = (num_errors == 0); + std::cout << " Verify: " << (test.M * test.N) << " elements, errors=" << num_errors + << "\n"; + std::cout << " Status: " << (test_passed ? "PASS" : "FAIL") << "\n\n"; + + if(!test_passed) + all_passed = false; + } + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "Multi-Registry Pattern Summary:\n"; + print_separator(); + std::cout << R"( +// 1. Declare multiple kernel sets +DECL_KERNEL_SET(compute_bound_set, .add(...)); +DECL_KERNEL_SET(memory_bound_set, .add(...)); +DECL_KERNEL_SET(latency_set, .add(...)); + +// 2. Create registries and register by set NAME (no hardcoding!) +Registry combined_reg, memory_reg; +REGISTER_KERNEL_SET("compute_bound_set", combined_reg, arch); // Add compute +REGISTER_KERNEL_SET("latency_set", combined_reg, arch); // Merge latency +REGISTER_KERNEL_SET("memory_bound_set", memory_reg, arch); // Separate + +// 3. Create dispatchers from merged/separate registries +Dispatcher combined_disp(&combined_reg); // Has both compute + latency +Dispatcher memory_disp(&memory_reg); // Has only memory-bound + +// 4. Choose dispatcher based on workload +if (problem.is_memory_bound()) + memory_disp.run(...); +else + combined_disp.run(...); // Handles both compute & latency workloads +)"; + print_separator(); + std::cout << "Overall Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/README.md b/dispatcher/examples/gemm/cpp/README.md new file mode 100644 index 0000000000..1d81a90a0e --- /dev/null +++ b/dispatcher/examples/gemm/cpp/README.md @@ -0,0 +1,229 @@ +# GEMM C++ Examples + +CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations. + +> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md) + +## Quick Start + +### Build and Run + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build (kernels generated automatically by CMake) +make -j$(nproc) + +# Run examples +cd examples +./gemm_01_basic +./gemm_03_benchmark_validation +./gemm_04_heuristics +``` + +## Examples + +| Example | Description | Complexity | +|---------|-------------|------------| +| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ | +| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ | +| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ | +| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ | +| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ | +| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ | + +## Example Details + +### 01_basic_gemm.cpp - Basic GEMM +Demonstrates the declarative kernel API with three patterns: + +1. **Autofill Pattern** - Minimal specification, defaults filled automatically +2. **Autocorrect Pattern** - Invalid parameters corrected at build time +3. **Full Specification Pattern** - Complete kernel configuration + +```cpp +DECL_KERNEL_SET(basic_kernels, + // Pattern 1: Autofill - minimal specification + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm(), // Defaults filled by autofill + "gfx942" + ) + // Pattern 2: Full specification + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16) + .pipeline("compv4").scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Features:** +- Uses generic `REGISTER_GENERATED_KERNELS` macro +- `print_registered_kernels()` utility for debugging +- Demonstrates autofill messages during build + +### 02_multi_size.cpp - Wildcard Expansion +Demonstrates automatic generation of multiple kernel configurations: + +```cpp +DECL_KERNEL_SET(multi_kernels, + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(*, *, 32) // Wildcard tile M and N + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Wildcard Values:** +- `*`, `-1`, or `ANY_INT` expand to all valid configurations +- Architecture filter prunes invalid combinations automatically +- Example generates 5 valid kernels after arch filtering (from 7 expansions) + +### 03_benchmark_validation.cpp - Benchmark + Validation +Consolidated example combining performance benchmarking with correctness validation: + +```bash +# Benchmark only +./gemm_03_benchmark_validation --warmup 10 --iterations 100 + +# With CPU validation +./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3 + +# With GPU reference validation (faster for large matrices) +./gemm_03_benchmark_validation --verify 2 +``` + +**Features:** +- Warmup iterations (discarded from timing) +- Benchmark iterations with statistics (min/max/mean/median) +- CPU reference validation using `ck_tile::reference_gemm` +- GPU reference validation using `ck_tile::reference_gemm_gpu` +- Configurable tolerances + +### 04_heuristics.cpp - Heuristic Selection +Demonstrates custom kernel selection based on problem characteristics: + +```cpp +// Problem size analysis +auto heuristic = [](const Problem& p) -> std::optional { + if (p.M() * p.N() < 256 * 256) { + return small_kernel_key; // Memory-bound heuristic + } else { + return large_kernel_key; // Compute-bound heuristic + } +}; + +dispatcher.set_heuristic(heuristic); +``` + +**Features:** +- Problem size analysis (small vs large matrices) +- Compute-bound vs memory-bound selection +- Custom heuristic function registration + +### 05_json_export.cpp - JSON Export +Exports registry information to JSON for external tool integration: + +```cpp +auto json = registry.to_json(); +std::ofstream file("kernels.json"); +file << json; +``` + +**Use Cases:** +- Kernel metadata serialization +- External analysis tools +- Configuration management + +### 06_multi_registry.cpp - Multiple Registries +Demonstrates using multiple registries with named kernel sets: + +```cpp +// Define separate kernel sets +DECL_KERNEL_SET(compute_optimized, ...); +DECL_KERNEL_SET(latency_optimized, ...); + +// Register to specific registries +Registry compute_registry, latency_registry; +REGISTER_KERNEL_SET(compute_optimized, compute_registry); +REGISTER_KERNEL_SET(latency_optimized, latency_registry); + +// Use appropriate registry based on workload +Dispatcher compute_dispatcher(compute_registry); +Dispatcher latency_dispatcher(latency_registry); +``` + +**Features:** +- Named kernel set registration with `REGISTER_KERNEL_SET` macro +- Separate registries for different optimization goals +- Dynamic kernel set selection by name + +## Benchmark Parameters (stream_config) + +CK Tile uses `stream_config` for benchmark control: + +```cpp +ck_tile::stream_config cfg{ + nullptr, // stream_id - HIP stream (nullptr = default) + true, // time_kernel - Enable timing + 1, // log_level - Verbosity (0=quiet, 1=normal) + 5, // cold_niters - Warmup iterations + 20, // nrepeat - Benchmark iterations + true, // is_gpu_timer - Use GPU events vs CPU chrono + false, // flush_cache - Flush L2 cache between iterations + 1 // rotating_count - Rotating buffers for cache simulation +}; +``` + +| Parameter | CLI Option | Default | Description | +|-----------|------------|---------|-------------| +| `cold_niters_` | `--warmup` | 5 | Warmup iterations | +| `nrepeat_` | `--iterations` | 100 | Benchmark iterations | +| `flush_cache_` | - | false | Flush L2 cache | +| `rotating_count_` | - | 1 | Rotating buffers | +| `is_gpu_timer_` | - | true | GPU timer vs CPU | + +## Declarative Kernel Pattern + +All examples use the declarative `DECL_KERNEL_SET` macro: + +```cpp +DECL_KERNEL_SET(my_kernels, + .add( + Signature() // WHAT: operation signature + .dtype("fp16") // Data type + .layout("rcr"), // Matrix layouts (A=row, B=col, C=row) + Algorithm() // HOW: implementation details + .tile(256, 256, 32) // Tile sizes (M, N, K) + .wave(2, 2, 1) // Wave configuration + .warp(32, 32, 16) // Warp tile sizes + .pipeline("compv4") // Pipeline type + .scheduler("intrawave"), // Scheduler type + "gfx942" // WHERE: target architecture + ) +); +``` + +**Key Macros:** +- `DECL_KERNEL_SET(name, ...)` - Declare a kernel set +- `REGISTER_GENERATED_KERNELS` - Register all kernels from this example +- `REGISTER_KERNEL_SET(name, registry)` - Register specific kernel set to a registry + +## Related Documentation + +- [Python GEMM Examples](../python/README.md) +- [Convolution Examples](../../conv/cpp/README.md) +- [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/01_basic_gemm.py b/dispatcher/examples/gemm/python/01_basic_gemm.py new file mode 100644 index 0000000000..93a78d24d1 --- /dev/null +++ b/dispatcher/examples/gemm/python/01_basic_gemm.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 01: Basic GEMM with Multiple Kernels + +Demonstrates: +1. Declaring multiple kernel configurations +2. Printing all registered kernels +3. Running each kernel and validating output +4. Comparing performance across kernels + +Complexity: ★★☆☆☆ + +Usage: + python3 01_basic_gemm.py + python3 01_basic_gemm.py --help + python3 01_basic_gemm.py --dtype bf16 + python3 01_basic_gemm.py --size 2048 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +@dataclass +class KernelSpec: + """Specification for a kernel configuration""" + + name: str + tile_m: int + tile_n: int + tile_k: int + pipeline: str = "compv3" + scheduler: str = "intrawave" + + +# Define multiple kernel configurations to test (50+ kernels) +KERNEL_SPECS = [ + # Small tiles - compv3 + KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"), + KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"), + # Small tiles - compv4 + KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"), + KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"), + # Medium tiles - compv3 + KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"), + KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"), + KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"), + # Medium tiles - compv4 + KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"), + KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"), + KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"), + # Rectangular tiles - compv3 + KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"), + KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"), + KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"), + KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"), + # Rectangular tiles - compv4 + KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"), + KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"), + KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"), + KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"), + # Large tiles - compv3 + KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"), + KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"), + KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"), + KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"), + KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"), + KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"), + # Large tiles - compv4 + KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"), + KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"), + KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"), + KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"), + KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"), + KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"), + # Interwave scheduler variants + KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"), + KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"), + KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"), + KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"), + # More tile_k variations - compv3 + KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"), + KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"), + KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"), + # More tile_k variations - compv4 + KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"), + KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"), + KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"), + # Additional rectangular + KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"), + KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"), + KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"), + KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"), + # Additional compv4 variants + KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"), + KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"), + KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"), + KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"), +] + + +def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + """Create a KernelConfig from a spec""" + # Adjust warp tiles based on tile size + if spec.tile_m <= 64: + warp_m, warp_n = 16, 16 + else: + warp_m, warp_n = 32, 32 + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=spec.tile_m, + tile_n=spec.tile_n, + tile_k=spec.tile_k, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=warp_m, + warp_n=warp_n, + warp_k=16, + pipeline=spec.pipeline, + scheduler=spec.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +def print_kernel_table(specs: List[KernelSpec], dtype: str): + """Print a formatted table of kernel configurations""" + print("\n" + "=" * 70) + print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") + print("=" * 70) + print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") + print(" " + "-" * 68) + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + print( + f" {i:<3} {spec.name:<18} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}" + ) + + print(" " + "-" * 68) + print(f" Data type: {dtype}") + + +def main(): + parser = argparse.ArgumentParser( + description="Basic GEMM Example with Multiple Kernels", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 01_basic_gemm.py # Default FP16 with 4 kernels + python3 01_basic_gemm.py --dtype bf16 # BF16 mode + python3 01_basic_gemm.py --size 2048 # Larger problem size + python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + parser.add_argument( + "--size", + type=int, + default=512, + help="Problem size MxNxK (default: 512)", + ) + parser.add_argument( + "--num-kernels", + type=int, + default=0, + help="Number of kernels to test (0 = all)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 70) + print("Example 01: Basic GEMM with Multiple Kernels") + print("=" * 70) + + # Select kernels to test + specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS + + # ========================================================================= + # Step 1: Print all kernel configurations + # ========================================================================= + print_kernel_table(specs, args.dtype) + + # ========================================================================= + # Step 2: Setup and test each kernel + # ========================================================================= + print("\n" + "=" * 70) + print(" RUNNING KERNELS") + print("=" * 70) + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + M, N, K = args.size, args.size, args.size + + results = [] + + print(f"\n Problem size: {M}x{N}x{K}\n") + print( + f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}" + ) + print(" " + "-" * 78) + + for i, spec in enumerate(specs, 1): + # Create unique test data per kernel + np.random.seed(42 + i * 1000) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Create config and setup dispatcher + config = create_kernel_config(spec, args.dtype, args.arch) + + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"kernel_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + + if not setup.success: + print( + f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + dispatcher = setup.dispatcher + + # Check if size is supported + if not dispatcher.is_supported(M, N, K): + print( + f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + # Run GEMM + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + print( + f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + # Validate against NumPy reference + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + max_err = np.max(np.abs(result.output - C_ref)) + + # Check if within tolerance + passed = max_err < 1e-2 + status = "PASS" if passed else "FAIL" + + print( + f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}" + ) + results.append((spec.name, passed, result.time_ms, result.tflops, max_err)) + + cleanup_gemm() + + # ========================================================================= + # Step 3: Summary + # ========================================================================= + print("\n" + "=" * 70) + print(" SUMMARY") + print("=" * 70) + + passed = sum(1 for r in results if r[1]) + failed = len(results) - passed + + print(f"\n Results: {passed}/{len(results)} kernels passed") + print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") + + if results: + valid_results = [r for r in results if r[1]] + if valid_results: + best = max(valid_results, key=lambda x: x[3]) + print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)") + + if failed == 0: + print("\n *** ALL KERNELS PASSED ***") + else: + print(f"\n *** {failed} KERNELS FAILED ***") + + print("=" * 70) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/02_batch_gemm.py b/dispatcher/examples/gemm/python/02_batch_gemm.py new file mode 100644 index 0000000000..039aba2790 --- /dev/null +++ b/dispatcher/examples/gemm/python/02_batch_gemm.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 02: Batch GEMM + +Runs multiple GEMM operations with different sizes. + +Complexity: ★★☆☆☆ + +Usage: + python3 02_batch_gemm.py + python3 02_batch_gemm.py --help + python3 02_batch_gemm.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Batch GEMM Example - runs multiple sizes", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 02_batch_gemm.py # Default FP16 + python3 02_batch_gemm.py --dtype bf16 # BF16 GEMM + python3 02_batch_gemm.py --max-size 2048 # Limit max size + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--max-size", + type=int, + default=4096, + help="Maximum problem size (default: 4096)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 02: Batch GEMM") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Run batch of different sizes + # ========================================================================= + print("\nStep 2: Run Batch") + + # Generate sizes up to max_size + all_sizes = [ + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + ] + sizes = [(m, n, k) for m, n, k in all_sizes if max(m, n, k) <= args.max_size] + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + print(f"\n {'Size':<20} | {'Time (ms)':>12} | {'TFLOPS':>10} | {'Status':>8}") + print(" " + "-" * 60) + + total_ops = 0 + total_time = 0 + + for M, N, K in sizes: + if not dispatcher.is_supported(M, N, K): + print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Skipped") + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + result = dispatcher.run(A, B, M, N, K) + + if result.success: + total_ops += 2 * M * N * K + total_time += result.time_ms + print( + f" {M:>4}x{N:>4}x{K:<4} | {result.time_ms:>12.4f} | {result.tflops:>10.2f} | OK" + ) + else: + print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Error") + + print(" " + "-" * 60) + + if total_time > 0: + avg_tflops = (total_ops / 1e12) / (total_time / 1000) + print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS") + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("Batch GEMM complete!") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/03_benchmark.py b/dispatcher/examples/gemm/python/03_benchmark.py new file mode 100644 index 0000000000..bec1b7e2fb --- /dev/null +++ b/dispatcher/examples/gemm/python/03_benchmark.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 03: Benchmark + +Performance benchmarking with compute-optimized kernel configuration. + +Complexity: ★★★☆☆ + +Usage: + python3 03_benchmark.py + python3 03_benchmark.py --help + python3 03_benchmark.py --size 4096 + python3 03_benchmark.py --dtype bf16 --iterations 20 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Benchmark Example - performance testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 03_benchmark.py # Default benchmark suite + python3 03_benchmark.py --size 4096 # Single size benchmark + python3 03_benchmark.py --dtype bf16 # BF16 benchmark + python3 03_benchmark.py --iterations 20 # More iterations + """, + ) + parser.add_argument( + "--dtype", + default="bf16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: bf16)", + ) + parser.add_argument( + "--size", + type=int, + default=0, + help="Single problem size MxNxK (default: run all sizes)", + ) + parser.add_argument( + "--warmup", type=int, default=3, help="Warmup iterations (default: 3)" + ) + parser.add_argument( + "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)" + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 03: Benchmark") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher with compute-optimized config + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + pipeline="compv4", + scheduler="intrawave", + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Benchmark + # ========================================================================= + print("\nStep 2: Benchmark") + + if args.size > 0: + sizes = [(args.size, args.size, args.size)] + else: + sizes = [ + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + (1024, 2048, 512), + (2048, 1024, 2048), + ] + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + print(f" Warmup: {args.warmup}, Iterations: {args.iterations}\n") + + print(f" {'Size':<20} | {'Min (ms)':>10} | {'Avg (ms)':>10} | {'TFLOPS':>10}") + print(" " + "-" * 60) + + all_tflops = [] + + for M, N, K in sizes: + if not dispatcher.is_supported(M, N, K): + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + # Warmup + for _ in range(args.warmup): + dispatcher.run(A, B, M, N, K) + + # Benchmark + times = [] + for _ in range(args.iterations): + result = dispatcher.run(A, B, M, N, K) + if result.success: + times.append(result.time_ms) + + if times: + min_time = min(times) + avg_time = sum(times) / len(times) + tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12 + all_tflops.append(tflops) + print( + f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}" + ) + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("Summary") + print("=" * 60) + + if all_tflops: + print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS") + print(f" Peak: {max(all_tflops):.2f} TFLOPS") + + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/04_validation.py b/dispatcher/examples/gemm/python/04_validation.py new file mode 100644 index 0000000000..2fe54c53f7 --- /dev/null +++ b/dispatcher/examples/gemm/python/04_validation.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 04: Validation + +Validates GPU GEMM against NumPy reference. + +Complexity: ★★★☆☆ + +Usage: + python3 04_validation.py + python3 04_validation.py --help + python3 04_validation.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Validator, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Validation Example - validates GPU results against NumPy", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 04_validation.py # Default FP16 validation + python3 04_validation.py --dtype bf16 # BF16 validation + python3 04_validation.py --rtol 1e-2 # Relaxed tolerance + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--rtol", type=float, default=1e-3, help="Relative tolerance (default: 1e-3)" + ) + parser.add_argument( + "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)" + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 04: Validation") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Run validation tests + # ========================================================================= + print("\nStep 2: Validation Tests") + + validator = Validator(rtol=args.rtol, atol=args.atol) + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + test_cases = [ + ("Identity", 128, 128, 128, "identity"), + ("Small", 256, 256, 256, "random"), + ("Medium", 512, 512, 512, "random"), + ("Large", 1024, 1024, 1024, "random"), + ("Non-square", 512, 1024, 256, "random"), + ] + + passed = 0 + failed = 0 + + print(f"\n {'Test':<15} | {'Size':<15} | {'Max Err':>10} | {'Status':>8}") + print(" " + "-" * 55) + + for name, M, N, K, pattern in test_cases: + if not dispatcher.is_supported(M, N, K): + print(f" {name:<15} | {M}x{N}x{K:<5} | {'N/A':>10} | Skipped") + continue + + np.random.seed(42) + if pattern == "identity": + A = np.eye(M, K, dtype=np_dtype) + B = np.eye(K, N, dtype=np_dtype) + else: + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + result = dispatcher.run(A, B, M, N, K) + if not result.success: + print(f" {name:<15} | {M}x{N}x{K:<5} | {'GPU Err':>10} | FAILED") + failed += 1 + continue + + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + is_valid, max_err, _ = validator.check(result.output, C_ref) + + if is_valid: + print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | PASSED") + passed += 1 + else: + print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED") + failed += 1 + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + total = passed + failed + print(f"Results: {passed}/{total} passed") + print(f"Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}") + print("=" * 60) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/05_numpy_integration.py b/dispatcher/examples/gemm/python/05_numpy_integration.py new file mode 100644 index 0000000000..493ce46d22 --- /dev/null +++ b/dispatcher/examples/gemm/python/05_numpy_integration.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 05: NumPy Integration + +Shows how to create a GPU-accelerated matmul wrapper. + +Complexity: ★★☆☆☆ + +Usage: + python3 05_numpy_integration.py + python3 05_numpy_integration.py --help + python3 05_numpy_integration.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Dispatcher, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +class GPUMatmul: + """GPU-accelerated matrix multiplication wrapper.""" + + def __init__(self, dispatcher: Dispatcher): + self.dispatcher = dispatcher + + def __call__(self, A: np.ndarray, B: np.ndarray) -> np.ndarray: + """Compute C = A @ B on GPU with CPU fallback.""" + M, K = A.shape + K2, N = B.shape + + if K != K2: + raise ValueError(f"Dimension mismatch: {A.shape} @ {B.shape}") + + if not self.dispatcher.is_supported(M, N, K): + return np.matmul(A, B) + + result = self.dispatcher.run(A, B, M, N, K) + return result.output if result.success else np.matmul(A, B) + + +def main(): + parser = argparse.ArgumentParser( + description="NumPy Integration Example - GPU-accelerated matmul wrapper", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 05_numpy_integration.py # Default FP16 + python3 05_numpy_integration.py --dtype bf16 # BF16 mode + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 05: NumPy Integration") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="numpy", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 2: Create GPU matmul wrapper + # ========================================================================= + print("\nStep 2: Create GPUMatmul") + + gpu_matmul = GPUMatmul(dispatcher=dispatcher) + print(" gpu_matmul ready") + + # ========================================================================= + # Step 3: Demo - Simple multiplication using gpu_matmul + # ========================================================================= + print("\nStep 3: Demo - Simple Multiplication") + + A = np.random.randn(1024, 512).astype(np_dtype) * 0.1 + B = np.random.randn(512, 256).astype(np_dtype) * 0.1 + + # Use the gpu_matmul wrapper + C = gpu_matmul(A, B) + print(f" gpu_matmul result: {C.shape}, sum={C.sum():.4f}") + + M, K = A.shape + _, N = B.shape + result = dispatcher.run(A, B, M, N, K) + + print(f" A: {A.shape}, B: {B.shape} -> C: {result.output.shape}") + print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS") + + # ========================================================================= + # Step 4: Demo - FFN block + # ========================================================================= + print("\nStep 4: Demo - FFN Block") + + batch, hidden, ffn = 128, 768, 3072 + X = np.random.randn(batch, hidden).astype(np_dtype) * 0.02 + W1 = np.random.randn(hidden, ffn).astype(np_dtype) * 0.02 + W2 = np.random.randn(ffn, hidden).astype(np_dtype) * 0.02 + + result1 = dispatcher.run(X, W1, batch, ffn, hidden) + H = result1.output + result2 = dispatcher.run(H, W2, batch, hidden, ffn) + + print(f" X: {X.shape} -> H: {H.shape} -> Y: {result2.output.shape}") + print(f" Total: {result1.time_ms + result2.time_ms:.4f} ms") + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("NumPy Integration Pattern:") + print("=" * 60) + print(" 1. setup_gemm_dispatcher(config)") + print(" 2. GPUMatmul(dispatcher)") + print(" 3. C = gpu_matmul(A, B)") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/06_json_export.py b/dispatcher/examples/gemm/python/06_json_export.py new file mode 100644 index 0000000000..9e062e507b --- /dev/null +++ b/dispatcher/examples/gemm/python/06_json_export.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 06: JSON Export + +Exports registry configuration to JSON. + +Complexity: ★★☆☆☆ + +Usage: + python3 06_json_export.py + python3 06_json_export.py --help + python3 06_json_export.py --output my_kernels.json +""" + +import sys +import json +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="JSON Export Example - exports registry to JSON", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 06_json_export.py # Default output to kernels.json + python3 06_json_export.py --output my.json # Custom output file + """, + ) + parser.add_argument( + "--output", + "-o", + default="kernels.json", + help="Output JSON file (default: kernels.json)", + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 06: JSON Export") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="export_demo", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + # ========================================================================= + # Step 2: Define additional configs for export + # ========================================================================= + print("\nStep 2: Define Additional Configs") + + configs = [ + config, + KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=256, + tile_n=256, + tile_k=64, + gfx_arch=args.arch, + ), + KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=64, + tile_n=64, + tile_k=32, + gfx_arch=args.arch, + ), + ] + + for cfg in configs: + print(f" - {cfg.tile_str}") + + # ========================================================================= + # Step 3: Export to JSON + # ========================================================================= + print("\nStep 3: Export to JSON") + + export_data = { + "registry": setup.registry.name, + "kernel_count": len(configs), + "kernels": [], + } + + for cfg in configs: + kernel_info = { + "tile": cfg.tile_str, + "dtypes": {"A": cfg.dtype_a, "B": cfg.dtype_b, "C": cfg.dtype_c}, + "layout": cfg.layout, + "pipeline": cfg.pipeline, + "target": cfg.gfx_arch, + } + export_data["kernels"].append(kernel_info) + + # Include C++ library info + if setup.lib: + cpp_json = setup.lib.export_registry_json() + try: + export_data["cpp_registry"] = json.loads(cpp_json) + except json.JSONDecodeError: + pass + + json_str = json.dumps(export_data, indent=2) + + with open(args.output, "w") as f: + f.write(json_str) + print(f" Saved to: {args.output}") + + # Preview + print("\nStep 4: Preview") + print("-" * 60) + print(json_str[:500] + ("..." if len(json_str) > 500 else "")) + print("-" * 60) + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("JSON Export complete!") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/07_stress_test.py b/dispatcher/examples/gemm/python/07_stress_test.py new file mode 100644 index 0000000000..8160030631 --- /dev/null +++ b/dispatcher/examples/gemm/python/07_stress_test.py @@ -0,0 +1,513 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 07: Stress Test - Multiple Kernels with Validation + +Consolidated stress test that: +1. Declares multiple kernel configurations (various tiles, pipelines, layouts) +2. Prints all registered kernels with details +3. Validates each kernel against NumPy reference +4. Optional benchmarking mode + +This tests: +- Multiple tile sizes (64x64, 128x128, 256x256) +- Multiple pipelines (compv3, compv4) +- Multiple data types (fp16, bf16) +- Different schedulers (intrawave, interwave) + +Complexity: ★★★★☆ + +Usage: + python3 07_stress_test.py + python3 07_stress_test.py --help + python3 07_stress_test.py --num-kernels 10 + python3 07_stress_test.py --benchmark + python3 07_stress_test.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List, Tuple + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, + Validator, +) + + +@dataclass +class KernelSpec: + """A kernel specification for testing""" + + name: str + tile_m: int + tile_n: int + tile_k: int + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + warp_m: int = 32 + warp_n: int = 32 + warp_k: int = 16 + pipeline: str = "compv3" + scheduler: str = "intrawave" + layout: str = "rcr" + + def to_config(self, dtype: str, arch: str) -> KernelConfig: + """Convert to KernelConfig""" + # Adjust warp tiles for smaller tiles + warp_m = min(self.warp_m, self.tile_m // self.wave_m) + warp_n = min(self.warp_n, self.tile_n // self.wave_n) + warp_k = self.warp_k + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a={"r": "row", "c": "col"}[self.layout[0]], + layout_b={"r": "row", "c": "col"}[self.layout[1]], + layout_c={"r": "row", "c": "col"}[self.layout[2]], + tile_m=self.tile_m, + tile_n=self.tile_n, + tile_k=self.tile_k, + wave_m=self.wave_m, + wave_n=self.wave_n, + wave_k=self.wave_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + pipeline=self.pipeline, + scheduler=self.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +# Define stress test kernel configurations +KERNEL_SPECS = [ + # Small tiles - compv3 + KernelSpec( + "small_compv3", + 64, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=16, + warp_n=16, + warp_k=32, + pipeline="compv3", + ), + KernelSpec( + "small_compv4", + 64, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=16, + warp_n=16, + warp_k=32, + pipeline="compv4", + ), + # Medium tiles + KernelSpec( + "medium_compv3", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "medium_compv4", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv4", + ), + KernelSpec( + "medium_k64", + 128, + 128, + 64, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + # Rectangular tiles + KernelSpec( + "rect_64x128", + 64, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "rect_128x64", + 128, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + # Different schedulers + KernelSpec( + "interwave", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + scheduler="interwave", + ), + # Large tiles + KernelSpec( + "large_compv3", + 256, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "large_compv4", + 256, + 128, + 64, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv4", + ), +] + + +def print_kernel_summary(specs: List[KernelSpec], dtype: str): + """Print a summary table of all kernel specs""" + print("\n" + "=" * 80) + print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") + print("=" * 80) + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Wave':<10} {'Warp':<12} {'Pipeline':<10} {'Sched':<10}" + ) + print(" " + "-" * 78) + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + wave = f"{spec.wave_m}x{spec.wave_n}x{spec.wave_k}" + warp = f"{spec.warp_m}x{spec.warp_n}x{spec.warp_k}" + print( + f" {i:<3} {spec.name:<18} {tile:<12} {wave:<10} {warp:<12} {spec.pipeline:<10} {spec.scheduler:<10}" + ) + + print(" " + "-" * 78) + print(f" Data type: {dtype}\n") + + +def validate_kernel( + spec: KernelSpec, + dtype: str, + arch: str, + size: int, + validator: Validator, + kernel_index: int = 0, + verbose: bool = False, +) -> Tuple[bool, float, str]: + """ + Validate a single kernel configuration. + Returns: (passed, max_error, message) + """ + np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32 + + # Create config + config = spec.to_config(dtype, arch) + + # Setup dispatcher + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"stress_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + return False, 0.0, f"Setup failed: {setup.error}" + + dispatcher = setup.dispatcher + M, N, K = size, size, size + + if not dispatcher.is_supported(M, N, K): + cleanup_gemm() + return False, 0.0, f"Size {M}x{N}x{K} not supported" + + # Use different seed per kernel to get unique test data + # This ensures each kernel is tested with different matrices + np.random.seed(42 + kernel_index * 1000) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Run GPU GEMM + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + cleanup_gemm() + return False, 0.0, "GPU execution failed" + + # Validate against NumPy + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + is_valid, max_err, _ = validator.check(result.output, C_ref) + + cleanup_gemm() + + return is_valid, max_err, f"{result.time_ms:.2f}ms, {result.tflops:.1f} TFLOPS" + + +def benchmark_kernel( + spec: KernelSpec, + dtype: str, + arch: str, + size: int, + warmup: int = 3, + iterations: int = 10, +) -> Tuple[bool, float, float]: + """ + Benchmark a kernel configuration. + Returns: (success, avg_time_ms, tflops) + """ + np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32 + + config = spec.to_config(dtype, arch) + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"bench_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + return False, 0.0, 0.0 + + dispatcher = setup.dispatcher + M, N, K = size, size, size + + if not dispatcher.is_supported(M, N, K): + cleanup_gemm() + return False, 0.0, 0.0 + + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Warmup + for _ in range(warmup): + dispatcher.run(A, B, M, N, K) + + # Benchmark + times = [] + for _ in range(iterations): + result = dispatcher.run(A, B, M, N, K) + if result.success: + times.append(result.time_ms) + + cleanup_gemm() + + if not times: + return False, 0.0, 0.0 + + avg_time = sum(times) / len(times) + tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12 + + return True, avg_time, tflops + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Stress Test - Multiple kernels with validation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 07_stress_test.py # Test all kernels + python3 07_stress_test.py --num-kernels 5 # Test first 5 kernels + python3 07_stress_test.py --benchmark # Include benchmarks + python3 07_stress_test.py --dtype bf16 # Test BF16 + python3 07_stress_test.py --size 2048 # Use 2048x2048 matrices + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--num-kernels", + type=int, + default=0, + help="Number of kernels to test (0 = all)", + ) + parser.add_argument( + "--size", + type=int, + default=512, + help="Problem size MxNxK (default: 512)", + ) + parser.add_argument( + "--benchmark", + action="store_true", + help="Include benchmark timing", + ) + parser.add_argument( + "--rtol", + type=float, + default=1e-2, + help="Relative tolerance (default: 1e-2)", + ) + parser.add_argument( + "--atol", + type=float, + default=1e-2, + help="Absolute tolerance (default: 1e-2)", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 80) + print("Example 07: GEMM Stress Test - Multiple Kernels") + print("=" * 80) + + # Select kernels to test + specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS + + # Print kernel summary + print_kernel_summary(specs, args.dtype) + + # Run validation + print("\n" + "=" * 80) + print(" VALIDATION RESULTS") + print("=" * 80) + + validator = Validator(rtol=args.rtol, atol=args.atol) + + if args.benchmark: + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Time':>10} {'TFLOPS':>8} {'Status':<8}" + ) + else: + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Info':<25} {'Status':<8}" + ) + print(" " + "-" * 78) + + passed = 0 + failed = 0 + skipped = 0 + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + + try: + is_valid, max_err, info = validate_kernel( + spec, args.dtype, args.arch, args.size, validator, kernel_index=i + ) + + if is_valid: + status = "PASS" + passed += 1 + else: + status = "FAIL" + failed += 1 + + if args.benchmark: + success, avg_time, tflops = benchmark_kernel( + spec, args.dtype, args.arch, args.size + ) + if success: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {avg_time:>9.2f}ms {tflops:>7.1f} {status:<8}" + ) + else: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {'N/A':>10} {'N/A':>8} {status:<8}" + ) + else: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {info:<25} {status:<8}" + ) + + except Exception as e: + skipped += 1 + print( + f" {i:<3} {spec.name:<18} {tile:<12} {'N/A':>10} {str(e)[:25]:<25} {'SKIP':<8}" + ) + + # Summary + print("\n" + "=" * 80) + print(" SUMMARY") + print("=" * 80) + total = passed + failed + skipped + print(f"\n Results: {passed}/{total} passed, {failed} failed, {skipped} skipped") + print(f" Settings: dtype={args.dtype}, size={args.size}x{args.size}x{args.size}") + print(f" Tolerance: rtol={args.rtol}, atol={args.atol}") + print(f" Architecture: {args.arch}") + + if failed == 0 and skipped == 0: + print("\n *** ALL KERNELS PASSED ***") + elif failed > 0: + print(f"\n *** {failed} KERNELS FAILED ***") + + print("=" * 80) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/08_heuristics.py b/dispatcher/examples/gemm/python/08_heuristics.py new file mode 100644 index 0000000000..e2763c0513 --- /dev/null +++ b/dispatcher/examples/gemm/python/08_heuristics.py @@ -0,0 +1,718 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 08: Custom Heuristics + +Demonstrates custom kernel selection heuristics based on problem characteristics. + +This example shows how to: +1. Define multiple kernel configurations for different workloads +2. Implement custom heuristics to select the best kernel +3. Test heuristic selection across different problem sizes + +Heuristic strategies: +- Size-based: Small tiles for small problems, large tiles for large problems +- Compute-bound: Maximize compute utilization for large matrices +- Memory-bound: Optimize memory access for bandwidth-limited cases +- Latency-focused: Minimize kernel launch overhead for small problems + +Complexity: ★★★★☆ + +Usage: + python3 08_heuristics.py + python3 08_heuristics.py --help + python3 08_heuristics.py --strategy compute + python3 08_heuristics.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List +from enum import Enum + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +# ============================================================================= +# Kernel Specifications +# ============================================================================= + + +@dataclass +class KernelSpec: + """Kernel specification with metadata for heuristic selection""" + + name: str + tile_m: int + tile_n: int + tile_k: int + pipeline: str = "compv3" + scheduler: str = "intrawave" + # Metadata for heuristics + category: str = "balanced" # small, balanced, large, compute, memory + min_problem_size: int = 0 + max_problem_size: int = float("inf") + + +# Define kernel pool for heuristic selection (20+ kernels) +KERNEL_POOL = [ + # ========================================================================== + # SMALL TILES - Low latency, good for small problems + # ========================================================================== + KernelSpec( + "small_64x64_k32", + 64, + 64, + 32, + "compv3", + "intrawave", + category="small", + max_problem_size=256 * 256, + ), + KernelSpec( + "small_64x64_k64", + 64, + 64, + 64, + "compv3", + "intrawave", + category="small", + max_problem_size=256 * 256, + ), + KernelSpec( + "small_64x64_v4", + 64, + 64, + 32, + "compv4", + "intrawave", + category="small", + max_problem_size=256 * 256, + ), + # ========================================================================== + # MEDIUM TILES - Balanced performance + # ========================================================================== + KernelSpec( + "medium_128x128_k32", + 128, + 128, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + max_problem_size=2048 * 2048, + ), + KernelSpec( + "medium_128x128_k64", + 128, + 128, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_k128", + 128, + 128, + 128, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_v4_k32", + 128, + 128, + 32, + "compv4", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_v4_k64", + 128, + 128, + 64, + "compv4", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + # Rectangular medium tiles + KernelSpec( + "rect_64x128_k32", + 64, + 128, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + ), + KernelSpec( + "rect_128x64_k32", + 128, + 64, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + ), + KernelSpec( + "rect_64x128_k64", + 64, + 128, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "rect_128x64_k64", + 128, + 64, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + # ========================================================================== + # LARGE TILES - High throughput for large problems + # ========================================================================== + KernelSpec( + "large_256x128_k32", + 256, + 128, + 32, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_256x128_k64", + 256, + 128, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_128x256_k32", + 128, + 256, + 32, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_128x256_k64", + 128, + 256, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_256x256_k32", + 256, + 256, + 32, + "compv3", + "intrawave", + category="large", + min_problem_size=1024 * 1024, + ), + KernelSpec( + "large_256x256_k64", + 256, + 256, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=1024 * 1024, + ), + # ========================================================================== + # COMPUTE-OPTIMIZED - compv4 pipeline for compute-bound workloads + # ========================================================================== + KernelSpec( + "compute_128x128_v4_k32", + 128, + 128, + 32, + "compv4", + "intrawave", + category="compute", + min_problem_size=256 * 256, + ), + KernelSpec( + "compute_128x128_v4_k64", + 128, + 128, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=256 * 256, + ), + KernelSpec( + "compute_256x128_v4", + 256, + 128, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=512 * 512, + ), + KernelSpec( + "compute_256x256_v4", + 256, + 256, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=1024 * 1024, + ), + # ========================================================================== + # MEMORY-OPTIMIZED - Good cache utilization for memory-bound workloads + # ========================================================================== + KernelSpec( + "memory_128x128_k16", + 128, + 128, + 16, + "compv3", + "intrawave", + category="memory", + min_problem_size=256 * 256, + ), + KernelSpec( + "memory_64x128_k16", + 64, + 128, + 16, + "compv3", + "intrawave", + category="memory", + min_problem_size=128 * 128, + ), +] + + +def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + """Create KernelConfig from spec""" + warp_m = 16 if spec.tile_m <= 64 else 32 + warp_n = 16 if spec.tile_n <= 64 else 32 + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=spec.tile_m, + tile_n=spec.tile_n, + tile_k=spec.tile_k, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=warp_m, + warp_n=warp_n, + warp_k=16, + pipeline=spec.pipeline, + scheduler=spec.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +# ============================================================================= +# Heuristic Strategies +# ============================================================================= + + +class HeuristicStrategy(Enum): + SIZE_BASED = "size" + COMPUTE_BOUND = "compute" + MEMORY_BOUND = "memory" + LATENCY_FOCUSED = "latency" + + +def size_based_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel based on problem size. + - Small problems: Use small tiles for low latency + - Medium problems: Use balanced tiles + - Large problems: Use large tiles for high throughput + + Also considers K dimension for tile_k selection. + """ + total_elements = M * N + + # Filter by problem size constraints + candidates = [ + k for k in kernels if k.min_problem_size <= total_elements <= k.max_problem_size + ] + + if not candidates: + candidates = kernels # Fall back to all kernels + + # Determine target category based on problem size + if total_elements < 256 * 256: + target_category = "small" + elif total_elements < 1024 * 1024: + target_category = "balanced" + else: + target_category = "large" + + # Filter by category if possible + category_candidates = [k for k in candidates if k.category == target_category] + if category_candidates: + candidates = category_candidates + + # Select best tile_k based on K dimension + # Prefer tile_k that divides K well + def tile_k_score(k): + if K % k.tile_k == 0: + return 0 # Perfect division + return K % k.tile_k # Remainder (lower is better) + + # Sort by tile_k fit, then by tile size + candidates.sort(key=lambda k: (tile_k_score(k), -k.tile_m * k.tile_n)) + + return candidates[0] + + +def compute_bound_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for compute-bound workloads. + Prefers compv4 pipeline and larger tiles. + Selects based on problem size to maximize compute utilization. + """ + total_elements = M * N + + # Prefer compute category kernels + compute_kernels = [k for k in kernels if k.category == "compute"] + + if not compute_kernels: + # Fall back to compv4 kernels + compute_kernels = [k for k in kernels if k.pipeline == "compv4"] + + if not compute_kernels: + compute_kernels = kernels + + # Filter by problem size + valid = [k for k in compute_kernels if k.min_problem_size <= total_elements] + if valid: + compute_kernels = valid + + # For large problems, prefer larger tiles + if total_elements >= 1024 * 1024: + return max(compute_kernels, key=lambda k: k.tile_m * k.tile_n * k.tile_k) + else: + # For smaller problems, prefer medium tiles + return min( + compute_kernels, key=lambda k: abs(k.tile_m - 128) + abs(k.tile_n - 128) + ) + + +def memory_bound_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for memory-bound workloads. + Prefers smaller tile_k for better memory access patterns. + """ + # Prefer memory category kernels first + memory_kernels = [k for k in kernels if k.category == "memory"] + if memory_kernels: + # Select based on problem size + total = M * N + if total < 512 * 512: + return min(memory_kernels, key=lambda k: k.tile_m * k.tile_n) + return max(memory_kernels, key=lambda k: k.tile_m * k.tile_n) + + # Fall back to balanced with smaller tile_k + balanced = [k for k in kernels if k.category == "balanced"] + if balanced: + # Prefer smaller tile_k for memory-bound + return min(balanced, key=lambda k: k.tile_k) + + # Fall back to medium-sized tile with small tile_k + return min( + kernels, key=lambda k: (k.tile_k, abs(k.tile_m - 128) + abs(k.tile_n - 128)) + ) + + +def latency_focused_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for low latency. + Prefers smaller tiles and compv4 for faster execution. + """ + # Prefer small category + small_kernels = [k for k in kernels if k.category == "small"] + + if small_kernels: + # Among small kernels, prefer compv4 for lower latency + v4_small = [k for k in small_kernels if k.pipeline == "compv4"] + if v4_small: + return v4_small[0] + return small_kernels[0] + + # Fall back to smallest tile with compv4 if available + all_v4 = [k for k in kernels if k.pipeline == "compv4"] + if all_v4: + return min(all_v4, key=lambda k: k.tile_m * k.tile_n) + + # Fall back to smallest tile + return min(kernels, key=lambda k: k.tile_m * k.tile_n) + + +HEURISTICS = { + HeuristicStrategy.SIZE_BASED: size_based_heuristic, + HeuristicStrategy.COMPUTE_BOUND: compute_bound_heuristic, + HeuristicStrategy.MEMORY_BOUND: memory_bound_heuristic, + HeuristicStrategy.LATENCY_FOCUSED: latency_focused_heuristic, +} + + +# ============================================================================= +# Main +# ============================================================================= + + +def print_kernel_pool(kernels: List[KernelSpec]): + """Print available kernels""" + print("\n" + "=" * 75) + print(" KERNEL POOL") + print("=" * 75) + print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Category':<12}") + print(" " + "-" * 73) + + for i, k in enumerate(kernels, 1): + tile = f"{k.tile_m}x{k.tile_n}x{k.tile_k}" + print(f" {i:<3} {k.name:<22} {tile:<14} {k.pipeline:<10} {k.category:<12}") + + print(" " + "-" * 73) + + +def main(): + parser = argparse.ArgumentParser( + description="Custom Heuristics Example - intelligent kernel selection", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 08_heuristics.py # Default size-based heuristic + python3 08_heuristics.py --strategy compute # Compute-bound heuristic + python3 08_heuristics.py --strategy memory # Memory-bound heuristic + python3 08_heuristics.py --strategy latency # Latency-focused heuristic + python3 08_heuristics.py --dtype bf16 # BF16 mode + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--strategy", + default="size", + choices=["size", "compute", "memory", "latency"], + help="Heuristic strategy (default: size)", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 75) + print("Example 08: Custom Heuristics") + print("=" * 75) + + # Map strategy string to enum + strategy_map = { + "size": HeuristicStrategy.SIZE_BASED, + "compute": HeuristicStrategy.COMPUTE_BOUND, + "memory": HeuristicStrategy.MEMORY_BOUND, + "latency": HeuristicStrategy.LATENCY_FOCUSED, + } + strategy = strategy_map[args.strategy] + heuristic_fn = HEURISTICS[strategy] + + print(f"\n Strategy: {strategy.value}") + print(f" Data type: {args.dtype}") + + # Print kernel pool + print_kernel_pool(KERNEL_POOL) + + # ========================================================================= + # Test heuristic selection across different problem sizes + # ========================================================================= + print("\n" + "=" * 75) + print(" HEURISTIC SELECTION TEST") + print("=" * 75) + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + test_sizes = [ + (128, 128, 64), # Small + (256, 256, 128), # Small-medium + (512, 512, 256), # Medium + (1024, 1024, 512), # Medium-large + (2048, 2048, 1024), # Large + ] + + print( + f"\n {'Size':<20} {'Selected Kernel':<25} {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}" + ) + print(" " + "-" * 78) + + results = [] + + for M, N, K in test_sizes: + # Use heuristic to select kernel + selected_spec = heuristic_fn(M, N, K, KERNEL_POOL) + + # Create config and setup + config = create_kernel_config(selected_spec, args.dtype, args.arch) + + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"heuristic_{selected_spec.name}", + verbose=False, + auto_rebuild=True, + ) + + size_str = f"{M}x{N}x{K}" + + if not setup.success: + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + dispatcher = setup.dispatcher + + if not dispatcher.is_supported(M, N, K): + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + # Run GEMM + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + # Validate + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + max_err = np.max(np.abs(result.output - C_ref)) + passed = max_err < 1e-2 + + status = "PASS" if passed else "FAIL" + print( + f" {size_str:<20} {selected_spec.name:<25} {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}" + ) + results.append( + (size_str, selected_spec.name, passed, result.time_ms, result.tflops) + ) + + cleanup_gemm() + + # ========================================================================= + # Summary + # ========================================================================= + print("\n" + "=" * 75) + print(" SUMMARY") + print("=" * 75) + + passed = sum(1 for r in results if r[2]) + failed = len(results) - passed + + print(f"\n Strategy: {strategy.value}") + print(f" Results: {passed}/{len(results)} tests passed") + + # Show kernel selection distribution + kernel_usage = {} + for r in results: + kernel_usage[r[1]] = kernel_usage.get(r[1], 0) + 1 + + print("\n Kernel Selection Distribution:") + for kernel, count in sorted(kernel_usage.items(), key=lambda x: -x[1]): + print(f" {kernel}: {count} times") + + if results: + valid_results = [r for r in results if r[2]] + if valid_results: + avg_tflops = sum(r[4] for r in valid_results) / len(valid_results) + print(f"\n Average TFLOPS: {avg_tflops:.2f}") + + if failed == 0: + print("\n *** ALL TESTS PASSED ***") + else: + print(f"\n *** {failed} TESTS FAILED ***") + + print("=" * 75) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/09_multi_registry.py b/dispatcher/examples/gemm/python/09_multi_registry.py new file mode 100644 index 0000000000..97cbce3497 --- /dev/null +++ b/dispatcher/examples/gemm/python/09_multi_registry.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 09: Multiple Registries + +Demonstrates multiple registries for different optimization targets. + +Complexity: ★★★★★ + +Usage: + python3 09_multi_registry.py + python3 09_multi_registry.py --help + python3 09_multi_registry.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Registry, + Dispatcher, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Multiple Registries Example - optimization-specific registries", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 09_multi_registry.py # Default FP16 + python3 09_multi_registry.py --dtype bf16 # BF16 mode + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 09: Multiple Registries") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup base dispatcher + # ========================================================================= + print("\nStep 1: Setup Base Dispatcher") + + base_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(base_config, registry_name="base", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + lib = setup.lib + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 2: Define configs for different optimization targets + # ========================================================================= + print("\nStep 2: Define Optimization Targets") + + compute_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=256, + tile_n=256, + tile_k=64, + wave_m=4, + wave_n=4, + pipeline="compv4", + gfx_arch=args.arch, + ) + memory_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + wave_m=2, + wave_n=2, + pipeline="compv4", + gfx_arch=args.arch, + ) + latency_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=64, + tile_n=64, + tile_k=32, + wave_m=1, + wave_n=1, + pipeline="compv3", + gfx_arch=args.arch, + ) + + print(f" Compute: {compute_config.tile_str} (large matrices)") + print(f" Memory: {memory_config.tile_str} (medium matrices)") + print(f" Latency: {latency_config.tile_str} (small matrices)") + + # ========================================================================= + # Step 3: Create registries + # ========================================================================= + print("\nStep 3: Create Registries") + + compute_registry = Registry(name="compute", lib=lib) + compute_registry.register_kernel(compute_config) + + memory_registry = Registry(name="memory", lib=lib) + memory_registry.register_kernel(memory_config) + + latency_registry = Registry(name="latency", lib=lib) + latency_registry.register_kernel(latency_config) + + # ========================================================================= + # Step 4: Create dispatchers + # ========================================================================= + print("\nStep 4: Create Dispatchers") + + compute_dispatcher = Dispatcher(registry=compute_registry, lib=lib) + memory_dispatcher = Dispatcher(registry=memory_registry, lib=lib) + latency_dispatcher = Dispatcher(registry=latency_registry, lib=lib) + + print(f" {compute_dispatcher}") + print(f" {memory_dispatcher}") + print(f" {latency_dispatcher}") + + # ========================================================================= + # Step 5: Smart dispatcher selection + # ========================================================================= + print("\nStep 5: Smart Dispatcher Selection") + + def select_dispatcher(M: int, N: int, K: int) -> Dispatcher: + elements = M * N + if elements >= 4096 * 4096: + return compute_dispatcher + elif elements >= 1024 * 1024: + return memory_dispatcher + else: + return latency_dispatcher + + test_sizes = [ + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + ] + + print(f"\n {'Size':<20} {'Registry':>10} {'Time (ms)':>12} {'TFLOPS':>10}") + print(" " + "-" * 55) + + for M, N, K in test_sizes: + dispatcher = select_dispatcher(M, N, K) + + if not dispatcher.is_supported(M, N, K): + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + result = dispatcher.run(A, B, M, N, K) + + if result.success: + print( + f" {M}x{N}x{K:<10} {dispatcher.registry.name:>10} " + f"{result.time_ms:>12.4f} {result.tflops:>10.2f}" + ) + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("Multi-Registry Pattern:") + print("=" * 60) + print(" 1. Define KernelConfig for each optimization target") + print(" 2. Create Registry for each target") + print(" 3. Register configs to appropriate registries") + print(" 4. Create Dispatcher for each registry") + print(" 5. Select dispatcher based on problem characteristics") + print(" 6. Run GEMM with selected dispatcher") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/10_advanced_benchmark.py b/dispatcher/examples/gemm/python/10_advanced_benchmark.py new file mode 100644 index 0000000000..e16e4e271f --- /dev/null +++ b/dispatcher/examples/gemm/python/10_advanced_benchmark.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 10: Advanced Benchmarking with Full Control + +This example demonstrates all available benchmark parameters: + - warmup: Number of warmup iterations (default: 5) + - repeat: Number of benchmark iterations (default: 20) + - flush_cache: Flush GPU cache between iterations (default: False) + - timer: Timer type - "gpu" (default) or "cpu" + - init: Initialization method - "random", "linear", "constant" + +Usage: + python3 10_advanced_benchmark.py + python3 10_advanced_benchmark.py --warmup 10 --repeat 100 + python3 10_advanced_benchmark.py --init linear +""" + +import argparse +import sys +from pathlib import Path + +# Add paths for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Advanced GEMM benchmarking with full parameter control" + ) + + # Problem size + parser.add_argument("-m", type=int, default=2048, help="M dimension") + parser.add_argument("-n", type=int, default=2048, help="N dimension") + parser.add_argument("-k", type=int, default=2048, help="K dimension") + + # Benchmark parameters + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--repeat", type=int, default=20, help="Number of benchmark iterations" + ) + parser.add_argument( + "--flush-cache", action="store_true", help="Flush GPU cache between iterations" + ) + parser.add_argument( + "--timer", choices=["gpu", "cpu"], default="gpu", help="Timer type (gpu or cpu)" + ) + parser.add_argument( + "--init", + choices=["random", "linear", "constant"], + default="random", + help="Initialization method", + ) + + # Kernel configuration + parser.add_argument("--dtype", default="fp16", help="Data type") + parser.add_argument("--pipeline", default="compv4", help="Pipeline type") + parser.add_argument("--arch", default="gfx942", help="GPU architecture") + + return parser.parse_args() + + +def initialize_matrix(shape, method, dtype): + """Initialize matrix with specified method""" + if method == "random": + return np.random.randn(*shape).astype(dtype) * 0.5 + elif method == "linear": + total = np.prod(shape) + return np.arange(total).reshape(shape).astype(dtype) / total + elif method == "constant": + return np.ones(shape, dtype=dtype) + else: + return np.random.randn(*shape).astype(dtype) + + +def main(): + args = parse_args() + + reset_for_example() + + print("=" * 70) + print("Example 10: Advanced GEMM Benchmarking") + print("=" * 70) + + # Show benchmark configuration + print("\nBenchmark Configuration:") + print(f" Problem Size: {args.m} x {args.n} x {args.k}") + print(f" Warmup: {args.warmup} iterations") + print(f" Repeat: {args.repeat} iterations") + print(f" Flush Cache: {args.flush_cache}") + print(f" Timer: {args.timer}") + print(f" Init Method: {args.init}") + print(f" Data Type: {args.dtype}") + print(f" Pipeline: {args.pipeline}") + print(f" Architecture: {args.arch}") + print() + + # Map dtype + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # Initialize matrices + print("Step 1: Initialize matrices...") + A = initialize_matrix((args.m, args.k), args.init, np_dtype) + B = initialize_matrix((args.k, args.n), args.init, np_dtype) + print(f" A: {A.shape} ({args.init})") + print(f" B: {B.shape} ({args.init})") + + # Create kernel config (does not include M/N/K - those are problem size) + print("\nStep 2: Create kernel configuration...") + kernel_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", # B is column-major for optimal performance + layout_c="row", + tile_m=128, + tile_n=128, + tile_k=32, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline=args.pipeline, + scheduler="intrawave", + epilogue="cshuffle", + gfx_arch=args.arch, + ) + print(f" Config: {args.dtype}, tile=128x128x32, {args.pipeline}") + + # Setup dispatcher + print("\nStep 3: Setup dispatcher...") + setup = setup_gemm_dispatcher( + config=kernel_config, + registry_name="benchmark_gemm", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + print(f" Library: {setup.lib.path if setup.lib else 'N/A'}") + print(f" Kernel: {setup.lib.get_kernel_name() if setup.lib else 'N/A'}") + + # Run benchmark with multiple iterations + print("\nStep 4: Run benchmark...") + print(f" Running {args.warmup} warmup + {args.repeat} benchmark iterations...") + + # Warmup + for _ in range(args.warmup): + _ = dispatcher.run(A, B, args.m, args.n, args.k) + + # Benchmark + times = [] + for _ in range(args.repeat): + result = dispatcher.run(A, B, args.m, args.n, args.k) + if result.success: + times.append(result.time_ms) + + if times: + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + + # Calculate TFLOPS + flops = 2 * args.m * args.n * args.k + avg_tflops = (flops / 1e12) / (avg_time / 1000) if avg_time > 0 else 0 + max_tflops = (flops / 1e12) / (min_time / 1000) if min_time > 0 else 0 + + # Calculate bandwidth (C has same dtype as A and B) + C_bytes = args.m * args.n * np.dtype(np_dtype).itemsize + bandwidth_gb = ( + (A.nbytes + B.nbytes + C_bytes) / 1e9 / (avg_time / 1000) + if avg_time > 0 + else 0 + ) + + print(f"\n *** BENCHMARK RESULTS ({args.repeat} iterations) ***") + print(f" Average Time: {avg_time:.4f} ms") + print(f" Min Time: {min_time:.4f} ms") + print(f" Max Time: {max_time:.4f} ms") + print(f" Avg TFLOPS: {avg_tflops:.2f}") + print(f" Peak TFLOPS: {max_tflops:.2f}") + print(f" Bandwidth: {bandwidth_gb:.2f} GB/s") + else: + print(" FAILED: No successful runs") + return 1 + + # Summary + print("\n" + "=" * 70) + print("BENCHMARK PARAMETERS REFERENCE") + print("=" * 70) + print(""" +Available parameters for GEMM benchmarking: + + --warmup N Number of warmup iterations (discard results) + Higher = more stable results, longer run time + Default: 5 + + --repeat N Number of benchmark iterations + Higher = more accurate average, longer run time + Default: 20 + + --flush-cache Flush GPU L2 cache between iterations + Use for memory-bound benchmarks + Default: off + + --timer {gpu,cpu} Timer type + gpu = HIP events (more accurate for GPU) + cpu = std::chrono (includes kernel launch overhead) + Default: gpu + + --init METHOD Matrix initialization + random = uniform random [-0.5, 0.5] + linear = sequential values + constant = all ones + Default: random + +Note: For C++ examples, these parameters are passed to stream_config: + + ck_tile::stream_config cfg{ + nullptr, // stream_id + true, // time_kernel + 1, // log_level + 5, // cold_niters (warmup) + 20, // nrepeat + true, // is_gpu_timer + false, // flush_cache + 1 // rotating_count + }; +""") + + # Cleanup + cleanup_gemm() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/11_json_import.py b/dispatcher/examples/gemm/python/11_json_import.py new file mode 100644 index 0000000000..06743af406 --- /dev/null +++ b/dispatcher/examples/gemm/python/11_json_import.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 11: JSON-based Kernel Configuration Import + +Demonstrates loading kernel configurations from JSON files, similar to tile_engine. +This enables easy customization of kernel sets without modifying code. + +Key Features: + - Load tile configs from JSON (compatible with tile_engine format) + - Generate kernel sets from configuration + - Use arch_filter validation on loaded configs + - Export to C++ DECL_KERNEL_SET format + +Complexity: ★★★☆☆ + +Usage: + python3 11_json_import.py + python3 11_json_import.py --config my_kernels.json + python3 11_json_import.py --export-cpp +""" + +import sys +import argparse +import json +from pathlib import Path + +# Add codegen to path for kernel_config_loader +script_dir = Path(__file__).parent.resolve() +sys.path.insert(0, str(script_dir.parent.parent.parent / "codegen")) +sys.path.insert(0, str(script_dir.parent.parent.parent / "python")) + +from kernel_config_loader import ( # noqa: E402 + load_kernel_configs, + KernelConfig, + generate_cpp_kernel_set_declaration, +) + +from ctypes_utils import ( # noqa: E402 + KernelConfig as DispatcherKernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, + validate_kernel_config, +) + +# Sample JSON configuration (embedded for demonstration) +SAMPLE_JSON_CONFIG = { + "_comment": "Sample kernel configuration for GEMM", + "kernel_set_name": "inference_kernels", + "datatype": {"a": "fp16", "b": "fp16", "c": "fp16", "acc": "fp32"}, + "layout": "rcr", + "tile_config": { + "tile_m": {"values": [128, 256]}, + "tile_n": {"values": [128, 256]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]}, + }, + "trait_config": { + "pipeline": {"values": ["compv4"]}, + "scheduler": {"values": ["intrawave"]}, + "epilogue": {"values": ["cshuffle"]}, + "pad_m": {"values": [False]}, + "pad_n": {"values": [False]}, + "pad_k": {"values": [False]}, + }, + "gpu_targets": ["gfx942"], +} + + +def print_section(title: str): + """Print a section header""" + print(f"\n{'=' * 70}") + print(f" {title}") + print(f"{'=' * 70}\n") + + +def convert_to_dispatcher_config( + config: KernelConfig, arch: str = "gfx942" +) -> DispatcherKernelConfig: + """Convert kernel_config_loader.KernelConfig to dispatcher KernelConfig""" + return DispatcherKernelConfig( + dtype_a=config.dtype_a, + dtype_b=config.dtype_b, + dtype_c=config.dtype_c, + dtype_acc=config.dtype_acc, + tile_m=config.tile.tile_m, + tile_n=config.tile.tile_n, + tile_k=config.tile.tile_k, + wave_m=config.tile.warp_m, + wave_n=config.tile.warp_n, + wave_k=config.tile.warp_k, + warp_m=config.tile.warp_tile_m, + warp_n=config.tile.warp_tile_n, + warp_k=config.tile.warp_tile_k, + pipeline=config.trait.pipeline, + scheduler=config.trait.scheduler, + epilogue=config.trait.epilogue, + pad_m=config.trait.pad_m, + pad_n=config.trait.pad_n, + pad_k=config.trait.pad_k, + gfx_arch=arch, + variant=config.variant, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="JSON Kernel Configuration Import Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 11_json_import.py # Use embedded sample config + python3 11_json_import.py --config my.json # Load from file + python3 11_json_import.py --export-cpp # Generate C++ declarations + python3 11_json_import.py --validate # Validate configs against arch + """, + ) + parser.add_argument( + "--config", + type=str, + help="Path to JSON configuration file (uses embedded sample if not provided)", + ) + parser.add_argument( + "--export-cpp", + action="store_true", + help="Export kernel set as C++ DECL_KERNEL_SET", + ) + parser.add_argument( + "--validate", + action="store_true", + help="Validate all configurations against arch filter", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target GPU architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print_section("Example 11: JSON Kernel Configuration Import") + + # ========================================================================= + # Step 1: Load configuration from JSON + # ========================================================================= + print("Step 1: Load Kernel Configuration from JSON") + print("-" * 50) + + if args.config: + config_path = Path(args.config) + if not config_path.exists(): + print(f" ERROR: Config file not found: {config_path}") + return 1 + print(f" Loading from: {config_path}") + config_set = load_kernel_configs(config_path) + else: + # Use embedded sample config + print(" Using embedded sample configuration") + # Write to temp file and load + temp_path = Path("/tmp/sample_gemm_config.json") + with open(temp_path, "w") as f: + json.dump(SAMPLE_JSON_CONFIG, f, indent=2) + config_set = load_kernel_configs(temp_path) + + print(f"\n Kernel Set Name: {config_set.name}") + print( + f" Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}" + ) + print(f" Layout: {config_set.layout}") + print(f" GPU Targets: {config_set.gpu_targets}") + print(f" Total Configurations: {config_set.config_count()}") + + # ========================================================================= + # Step 2: Display configuration details + # ========================================================================= + print("\nStep 2: Configuration Details") + print("-" * 50) + + print("\n Tile Configurations:") + print(f" tile_m: {config_set.tile_m_values}") + print(f" tile_n: {config_set.tile_n_values}") + print(f" tile_k: {config_set.tile_k_values}") + print( + f" warp (wave): {config_set.warp_m_values}x{config_set.warp_n_values}x{config_set.warp_k_values}" + ) + print( + f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}" + ) + + print("\n Trait Configurations:") + print(f" pipeline: {config_set.pipeline_values}") + print(f" scheduler: {config_set.scheduler_values}") + print(f" epilogue: {config_set.epilogue_values}") + print( + f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}" + ) + + # ========================================================================= + # Step 3: Generate and display kernel names + # ========================================================================= + print("\nStep 3: Generated Kernel Names") + print("-" * 50) + + configs = list(config_set.generate_configs()) + for i, config in enumerate(configs[:5]): + print(f" {i + 1}. {config.kernel_name()}") + if len(configs) > 5: + print(f" ... and {len(configs) - 5} more configurations") + + # ========================================================================= + # Step 4: Validate against arch filter (optional) + # ========================================================================= + if args.validate: + print("\nStep 4: Architecture Validation") + print("-" * 50) + + valid_count = 0 + invalid_count = 0 + + for config in configs: + disp_config = convert_to_dispatcher_config(config, args.arch) + result = validate_kernel_config(disp_config) + + if result.is_valid: + valid_count += 1 + else: + invalid_count += 1 + if invalid_count <= 3: # Show first 3 invalid + print(f"\n ✗ Invalid: {config.kernel_name()}") + for error in result.errors: + print(f" Error: {error}") + + print("\n Validation Summary:") + print(f" ✓ Valid: {valid_count}") + print(f" ✗ Invalid: {invalid_count}") + print(f" Total: {len(configs)}") + + # ========================================================================= + # Step 5: Export to C++ (optional) + # ========================================================================= + if args.export_cpp: + print("\nStep 5: C++ Export") + print("-" * 50) + print("\n // Generated DECL_KERNEL_SET from JSON config:") + print(" // " + "=" * 56) + cpp_code = generate_cpp_kernel_set_declaration(config_set) + for line in cpp_code.split("\n"): + print(f" {line}") + + # ========================================================================= + # Step 6: Use first config with dispatcher (demo) + # ========================================================================= + print("\nStep 6: Dispatcher Integration Demo") + print("-" * 50) + + if configs: + first_config = configs[0] + disp_config = convert_to_dispatcher_config(first_config, args.arch) + + print( + f"\n Using first config: {first_config.tile.tile_m}x{first_config.tile.tile_n}x{first_config.tile.tile_k}" + ) + + setup = setup_gemm_dispatcher( + disp_config, registry_name="json_import", verbose=False + ) + if setup.success: + print(" ✓ Dispatcher setup successful") + print( + f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}" + ) + else: + print(f" ⚠ Dispatcher setup: {setup.error}") + print(" (This is expected if kernels aren't generated)") + + # ========================================================================= + # Summary + # ========================================================================= + print_section("Summary") + print(" JSON configuration allows easy kernel set customization:") + print(" - Define tile sizes and ranges") + print(" - Specify trait combinations (pipeline, scheduler, etc.)") + print(" - Target multiple GPU architectures") + print(" - Export to C++ DECL_KERNEL_SET for static compilation") + print() + print(" JSON Format (tile_engine compatible):") + print(' {"tile_config": {"tile_m": {"values": [128, 256]}, ...},') + print(' "trait_config": {"pipeline": {"values": ["compv4"]}, ...}}') + print() + print(" Usage:") + print(" config_set = load_kernel_configs('my_kernels.json')") + print(" for config in config_set.generate_configs():") + print(" # Use config for codegen or dispatcher setup") + + cleanup_gemm() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/README.md b/dispatcher/examples/gemm/python/README.md new file mode 100644 index 0000000000..0a83f3533f --- /dev/null +++ b/dispatcher/examples/gemm/python/README.md @@ -0,0 +1,299 @@ +# GEMM Python Examples + +CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations. + +> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md) + +## Quick Start + +### Build Library + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build Python library (kernels generated automatically) +make dispatcher_gemm_lib -j$(nproc) +``` + +### Run Examples + +```bash +cd /path/to/composable_kernel/dispatcher + +python3 examples/gemm/python/01_basic_gemm.py +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py +python3 examples/gemm/python/08_heuristics.py +``` + +## Examples + +| Example | Description | +|---------|-------------| +| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support | +| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations | +| [03_benchmark.py](03_benchmark.py) | Performance benchmarking | +| [04_validation.py](04_validation.py) | CPU reference validation | +| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration | +| [06_json_export.py](06_json_export.py) | Registry JSON export | +| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing | +| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection | +| [09_multi_registry.py](09_multi_registry.py) | Multiple registries | +| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control | +| [11_json_import.py](11_json_import.py) | Import kernels from JSON | + +## Example Details + +### 01_basic_gemm.py - Basic GEMM +Demonstrates the Python API with multi-kernel support: + +```python +from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table + +# Define multiple kernel configurations +kernels = [ + KernelConfig( + tile_m=128, tile_n=128, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv3", scheduler="intrawave" + ), + KernelConfig( + tile_m=256, tile_n=256, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave" + ), +] + +# Display configurations +print_kernel_config_table(kernels) + +# Set up dispatcher with all kernels +lib, dispatcher, registry = setup_gemm_dispatcher(kernels) + +# Run GEMM +elapsed_ms = run_gemm(lib, M, N, K, ...) +``` + +### 02_batch_gemm.py - Batch GEMM +Batched matrix multiplication: +- Multiple independent GEMM operations +- Batch dimension handling + +### 03_benchmark.py - Benchmarking +Performance measurement: +- GPU timing +- TFLOPS calculation +- Multiple iterations + +### 04_validation.py - Validation +Correctness verification: +- NumPy reference implementation +- Tolerance-based validation +- Error reporting + +### 05_numpy_integration.py - NumPy Integration +Seamless NumPy integration: +- NumPy arrays to GPU buffers +- Results back to NumPy +- Automatic type conversion + +### 06_json_export.py - JSON Export +Registry serialization for tool integration: +- Export kernel configurations +- Machine-readable format + +### 07_stress_test.py - Stress Testing +Comprehensive multi-kernel stress testing: + +```python +from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table + +# Define 48 unique kernel configurations +kernels = [ + KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...), + KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...), + KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...), + # ... many more configurations +] + +# Test each kernel +for i, kernel in enumerate(kernels): + lib, dispatcher, registry = setup_gemm_dispatcher([kernel]) + result = run_and_validate(lib, M, N, K, seed=42 + i) # Different seed per kernel + print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}") +``` + +**Features:** +- 48 unique kernel configurations +- Various tile sizes, pipelines, and schedulers +- Per-kernel validation with unique random seeds +- Performance reporting + +### 08_heuristics.py - Heuristic Selection +Custom kernel selection based on problem characteristics: + +```python +# Define kernel pools for different strategies +SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...] +LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...] +COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...] +MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...] + +# Size-based heuristic +def size_based_heuristic(M, N, K): + if M * N < 512 * 512: + return SMALL_KERNELS + else: + return LARGE_KERNELS + +# Strategy-based selection +def compute_strategy(): + return COMPUTE_KERNELS # Optimized for compute-bound problems + +def memory_strategy(): + return MEMORY_KERNELS # Optimized for memory-bound problems + +# Test different strategies +for strategy in [size_based_heuristic, compute_strategy, memory_strategy]: + kernels = strategy(M, N, K) + lib, dispatcher, registry = setup_gemm_dispatcher(kernels) + elapsed_ms = run_gemm(lib, M, N, K, ...) +``` + +**Features:** +- 24 kernel configurations across 6 categories +- Size-based heuristic (small vs large) +- Optimization strategies (compute, memory, latency) +- Performance comparison across strategies + +### 09_multi_registry.py - Multiple Registries +Separate registries for different workloads: +- Compute-optimized registry +- Latency-optimized registry +- Dynamic registry selection + +### 10_advanced_benchmark.py - Advanced Benchmark +Full control over benchmark parameters: +- Warmup iterations +- Benchmark iterations +- Statistical analysis + +### 11_json_import.py - JSON Import +Import kernel configurations from JSON: +- External configuration files +- Dynamic kernel loading + +## Utility Module: ctypes_utils.py + +```python +from ctypes_utils import ( + KernelConfig, # Single kernel configuration + setup_gemm_dispatcher, # Set up dispatcher with kernels + print_kernel_config_table, # Display kernel configurations + Dispatcher, # High-level dispatcher + Registry, # Kernel registry + Validator, # Validation utilities +) +``` + +### KernelConfig + +```python +config = KernelConfig( + # Tile sizes + tile_m=256, tile_n=256, tile_k=32, + # Wave configuration + wave_m=2, wave_n=2, wave_k=1, + # Warp tile sizes + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + # Pipeline and scheduler + pipeline="compv4", # "compv3" or "compv4" + scheduler="intrawave", # "intrawave" or "interwave" + # Optional + epilogue="default", + padding=True, + double_buffer=True, +) +``` + +### setup_gemm_dispatcher + +```python +# Single kernel +lib, dispatcher, registry = setup_gemm_dispatcher(config) + +# Multiple kernels +lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...]) + +# With auto-rebuild +lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True) +``` + +### print_kernel_config_table + +```python +kernels = [config1, config2, config3] +print_kernel_config_table(kernels) +# Output: +# +----+-------+-------+-------+--------+-----------+ +# | # | Tile | Wave | Warp | Pipe | Scheduler | +# +----+-------+-------+-------+--------+-----------+ +# | 1 | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave | +# | 2 | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave | +# | 3 | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave | +# +----+-------+-------+-------+--------+-----------+ +``` + +### GPU Memory Management + +```python +import ctypes +import numpy as np + +# Load HIP library +hip = ctypes.CDLL("libamdhip64.so") + +# Allocate GPU memory +gpu_ptr = ctypes.c_void_p() +hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes) + +# Copy to GPU (1 = hipMemcpyHostToDevice) +hip.hipMemcpy(gpu_ptr, host_array.ctypes.data, size, 1) + +# Copy back (2 = hipMemcpyDeviceToHost) +hip.hipMemcpy(host_array.ctypes.data, gpu_ptr, size, 2) + +# Free +hip.hipFree(gpu_ptr) +``` + +## Performance Testing + +Test compilation performance with different kernel counts: + +```bash +# Test with 10 kernels (~15s compile time) +python3 01_basic_gemm.py --num-kernels 10 + +# Test with 20 kernels (~25s compile time) +python3 01_basic_gemm.py --num-kernels 20 + +# Test with 48 kernels (~50s compile time) +python3 01_basic_gemm.py --num-kernels 48 +``` + +Compilation time scales roughly linearly with kernel count. + +## Related Documentation + +- [C++ GEMM Examples](../cpp/README.md) +- [Python Conv Examples](../../conv/python/README.md) +- [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/kernels.json b/dispatcher/examples/gemm/python/kernels.json new file mode 100644 index 0000000000..214b1cc42c --- /dev/null +++ b/dispatcher/examples/gemm/python/kernels.json @@ -0,0 +1,80 @@ +{ + "registry": "export_demo", + "kernel_count": 3, + "kernels": [ + { + "tile": "128x128x32", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + }, + { + "tile": "256x256x64", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + }, + { + "tile": "64x64x32", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + } + ], + "cpp_registry": { + "metadata": { + "timestamp": "Dec 4 2025 06:23:15", + "total_kernels": 1, + "export_version": "1.0", + "dispatcher_version": "1.0.0" + }, + "statistics": { + "by_datatype": {}, + "by_pipeline": {}, + "by_scheduler": {} + }, + "kernels": [ + { + "identifier": "128x128x32_2x2x1_32x32x16_nopers", + "name": "gemm_fp16_rcrr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16", + "algorithm": { + "tile_shape": { + "m": 128, + "n": 128, + "k": 32 + }, + "wave_shape": { + "m": 2, + "n": 2, + "k": 1 + }, + "warp_tile_shape": { + "m": 32, + "n": 32, + "k": 16 + }, + "block_size": 256, + "persistent": false, + "double_buffer": true, + "preshuffle": false, + "transpose_c": false + } + } + ] + } +} \ No newline at end of file diff --git a/dispatcher/include/ck_tile/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher.hpp new file mode 100644 index 0000000000..98d8bb9333 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher.hpp @@ -0,0 +1,19 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +/// Main dispatcher header - includes all core components +/// Use this for convenient access to the full dispatcher API + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/kernel_config.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" +#include "ck_tile/dispatcher/utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher/README.md b/dispatcher/include/ck_tile/dispatcher/README.md new file mode 100644 index 0000000000..db3ce996a9 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/README.md @@ -0,0 +1,161 @@ +# CK Tile Dispatcher - C++ Headers + +C++ API for the CK Tile dispatcher. + +> **See also:** [Main Dispatcher README](../../../../README.md) for installation and core concepts. + +## File Organization + +``` +dispatcher/ +├── dispatcher.hpp # Main dispatcher (kernel selection) +├── registry.hpp # Kernel registry (storage & lookup) +├── problem.hpp # Problem specification +├── kernel_key.hpp # Kernel configuration key +├── kernel_instance.hpp # Kernel instance interface +├── utils.hpp # Utilities (timers, GPU buffers) +│ +└── backends/ # Backend implementations + ├── generated_tile_backend.hpp # CK Tile kernels (production) + └── tile_backend.hpp # Tile backend base +``` + +## Quick Start + +```cpp +#include "ck_tile/dispatcher.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +int main() { + // 1. Build kernel key + KernelKeyBuilder builder = KernelKeyBuilder::fp16_rcr(); + builder.tile_m = 128; + builder.tile_n = 128; + builder.tile_k = 32; + KernelKey key = builder.build(); + + // 2. Register kernel + auto kernel = create_generated_tile_kernel<...>(key, "my_kernel"); + Registry::instance().register_kernel(kernel, Priority::High); + + // 3. Run GEMM + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + float time_ms = dispatcher.run(a_ptr, b_ptr, c_ptr, problem, nullptr); +} +``` + +## Core Classes + +### KernelKey (`kernel_key.hpp`) + +Uniquely identifies a kernel configuration: + +```cpp +KernelKeyBuilder builder; +builder.dtype_a = DataType::FP16; +builder.layout_a = LayoutTag::Row; +builder.tile_m = 256; +builder.pipeline = Pipeline::CompV4; +KernelKey key = builder.build(); +``` + +### Registry (`registry.hpp`) + +Thread-safe kernel storage: + +```cpp +auto& registry = Registry::instance(); +registry.register_kernel(kernel, Priority::High); +registry.get_kernel_count(); +registry.export_json(); +``` + +### Dispatcher (`dispatcher.hpp`) + +Kernel selection and execution: + +```cpp +Dispatcher dispatcher; + +// Strategies +dispatcher.set_strategy(SelectionStrategy::FirstFit); +dispatcher.set_strategy(SelectionStrategy::Heuristic); + +// Run +float time = dispatcher.run(a, b, c, problem, stream); +``` + +### Problem (`problem.hpp`) + +GEMM problem specification: + +```cpp +Problem problem(M, N, K); +problem.batch_size = 4; +problem.alpha = 1.0f; +problem.beta = 0.0f; + +// Auto-inference +auto p = Problem::from_ab(a_rows, a_cols, b_rows, b_cols, trans_a, trans_b); +``` + +## Utilities (`utils.hpp`) + +### GPU Memory + +```cpp +GpuBuffer buffer(size); +buffer.copy_from_host(host_ptr); +buffer.copy_to_host(host_ptr); +buffer.zero(); +``` + +### Timing + +```cpp +GpuTimer timer; +timer.start(); +// kernel... +timer.stop(); +float ms = timer.elapsed_ms(); +``` + +### Quick Helpers + +```cpp +// Create FP16 RCR key +auto key = create_fp16_rcr_key(tile_m, tile_n, tile_k, ...); + +// Performance +double tflops = calculate_tflops(M, N, K, time_ms); + +// Validation +auto result = validate_result(gpu_ptr, cpu_ptr, size); +``` + +## Backend + +### Generated Tile Backend + +```cpp +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +auto kernel = create_generated_tile_kernel< + SelectedKernel, ADataType, BDataType, CDataType, AccDataType +>(key, name); +``` + +## Best Practices + +1. Use `Release` build for performance +2. Register kernels at startup +3. Use `Priority::High` for hand-tuned kernels +4. Reuse dispatcher instances +5. Clear registry between test runs + +--- + +> **More info:** See [../../../../README.md](../../../../README.md) for full documentation. diff --git a/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp b/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp new file mode 100644 index 0000000000..33a864a649 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp @@ -0,0 +1,393 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Architecture-Specific Kernel Filtering for CK Tile Dispatcher + * + * Provides GPU architecture-aware validation of kernel configurations. + * Uses arch_specs_generated.hpp as single source of truth (generated from arch_specs.json). + * + * Usage: + * ArchFilter filter("gfx942"); + * + * // Check if a kernel configuration is valid + * if (filter.is_valid(kernel_key)) { + * registry.register_kernel(kernel); + * } + * + * // Get validation result with error details + * auto result = filter.validate(kernel_key); + * if (!result.valid) { + * for (const auto& error : result.errors) { + * std::cerr << error << "\n"; + * } + * } + * + * Adding New GPU Support: + * 1. Edit dispatcher/codegen/arch_specs.json + * 2. Run: python dispatcher/codegen/generate_arch_specs.py + * 3. Rebuild the dispatcher + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/arch_specs_generated.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Re-export from generated header for convenience +// ============================================================================= + +// Use the generated types and functions from arch_specs namespace +using GpuArch = arch_specs::GpuArch; +using WarpConfig = arch_specs::WarpConfig; +using WarpTileConfig = std::array; + +// Re-export string conversion functions +using arch_specs::arch_to_string; +using arch_specs::element_size; +using arch_specs::get_lds_capacity; +using arch_specs::get_supported_warp_configs; +using arch_specs::is_trait_unsupported; +using arch_specs::string_to_arch; + +// ============================================================================= +// Additional Helper Functions +// ============================================================================= + +/// Get supported warp tile configurations for arch and data types +/// This function wraps the generated data with runtime logic +inline std::vector get_supported_warp_tiles(GpuArch arch, + DataType dtype_a, + DataType dtype_b, + [[maybe_unused]] DataType dtype_c) +{ + // Common FP16 configurations (from arch_specs.json) + std::vector fp16_configs = { + {32, 32, 8}, {16, 16, 16}, {32, 32, 16}, {16, 16, 32}, {4, 64, 16}, {64, 4, 16}}; + + // FP8 configurations + std::vector fp8_gfx942 = { + {32, 32, 16}, {32, 32, 32}, {16, 16, 32}, {16, 16, 64}}; + std::vector fp8_gfx950 = { + {32, 32, 16}, {32, 32, 32}, {16, 16, 32}, {16, 16, 64}, {16, 16, 128}, {32, 32, 64}}; + + // INT8 configurations + std::vector int8_configs = {{16, 16, 32}, {32, 32, 16}}; + + // GFX1201 only supports limited FP16 + std::vector rdna4_fp16 = {{16, 16, 16}}; + + // Match based on architecture and data types + if(dtype_a == DataType::FP16 && dtype_b == DataType::FP16) + { + if(arch == GpuArch::GFX_1201) + return rdna4_fp16; + return fp16_configs; + } + if(dtype_a == DataType::BF16 && dtype_b == DataType::BF16) + { + if(arch == GpuArch::GFX_1201) + return {}; // Not supported on RDNA4 + return fp16_configs; // Same as FP16 + } + if(dtype_a == DataType::FP8 || dtype_a == DataType::BF8) + { + if(arch == GpuArch::GFX_950) + return fp8_gfx950; + if(arch == GpuArch::GFX_942) + return fp8_gfx942; + if(arch == GpuArch::GFX_90A) + return {{32, 32, 16}, {32, 32, 32}}; + } + if(dtype_a == DataType::INT8 && dtype_b == DataType::INT8) + { + if(arch == GpuArch::GFX_942) + return int8_configs; + } + + return {}; // Unknown combination +} + +// ============================================================================= +// Validation Result +// ============================================================================= + +/// Result of kernel validation +struct ValidationResult +{ + bool valid = true; + std::vector errors; + std::vector warnings; + + explicit operator bool() const { return valid; } + + void add_error(const std::string& msg) + { + errors.push_back(msg); + valid = false; + } + + void add_warning(const std::string& msg) { warnings.push_back(msg); } +}; + +// ============================================================================= +// Architecture Filter +// ============================================================================= + +/** + * Architecture-specific kernel filter. + * + * Validates kernel configurations against GPU architecture constraints + * including warp configurations, warp tiles, LDS capacity, and traits. + */ +class ArchFilter +{ + public: + /** + * Create architecture filter. + * @param arch Target GPU architecture + * @param strict_mode If true, unknown configurations are rejected + */ + explicit ArchFilter(GpuArch arch, bool strict_mode = false) + : arch_(arch), strict_mode_(strict_mode) + { + } + + /** + * Create architecture filter from string. + * @param arch_str GPU architecture string (e.g., "gfx942") + * @param strict_mode If true, unknown configurations are rejected + */ + explicit ArchFilter(const std::string& arch_str, bool strict_mode = false) + : arch_(string_to_arch(arch_str)), strict_mode_(strict_mode) + { + } + + /** + * Quick validation check. + * @param key Kernel configuration key + * @return true if configuration is valid for this architecture + */ + [[nodiscard]] bool is_valid(const KernelKey& key) const { return validate(key).valid; } + + /** + * Detailed validation with error messages. + * @param key Kernel configuration key + * @return ValidationResult with valid flag and error/warning messages + */ + [[nodiscard]] ValidationResult validate(const KernelKey& key) const + { + ValidationResult result; + + // Check architecture match + if(!key.gfx_arch.empty() && string_to_arch(key.gfx_arch) != arch_) + { + result.add_warning("Kernel compiled for different architecture: " + key.gfx_arch); + } + + // Validate dimensions + validate_dimensions(key, result); + + // Validate warp configuration + validate_warp_config(key, result); + + // Validate warp tile configuration + validate_warp_tiles(key, result); + + // Validate trait combination + validate_traits(key, result); + + // Validate LDS capacity + validate_lds(key, result); + + return result; + } + + /// Get target architecture + [[nodiscard]] GpuArch get_arch() const { return arch_; } + + /// Get target architecture as string + [[nodiscard]] std::string get_arch_string() const { return arch_to_string(arch_); } + + private: + void validate_dimensions(const KernelKey& key, ValidationResult& result) const + { + const auto& alg = key.algorithm; + + // Check positive dimensions + if(alg.tile_shape.m <= 0 || alg.tile_shape.n <= 0 || alg.tile_shape.k <= 0) + { + result.add_error("Tile dimensions must be positive"); + return; + } + + // Check warp tiles fit in block tiles + int warp_m_coverage = alg.wave_shape.m * alg.warp_tile_shape.m; + int warp_n_coverage = alg.wave_shape.n * alg.warp_tile_shape.n; + int warp_k_coverage = alg.wave_shape.k * alg.warp_tile_shape.k; + + if(warp_m_coverage > alg.tile_shape.m) + { + result.add_error("warp_m * warp_tile_m > tile_m: " + std::to_string(warp_m_coverage) + + " > " + std::to_string(alg.tile_shape.m)); + } + if(warp_n_coverage > alg.tile_shape.n) + { + result.add_error("warp_n * warp_tile_n > tile_n: " + std::to_string(warp_n_coverage) + + " > " + std::to_string(alg.tile_shape.n)); + } + if(warp_k_coverage > alg.tile_shape.k) + { + result.add_error("warp_k * warp_tile_k > tile_k: " + std::to_string(warp_k_coverage) + + " > " + std::to_string(alg.tile_shape.k)); + } + + // Check alignment + if(alg.tile_shape.m % warp_m_coverage != 0) + { + result.add_error("tile_m must be divisible by warp_m * warp_tile_m"); + } + if(alg.tile_shape.n % warp_n_coverage != 0) + { + result.add_error("tile_n must be divisible by warp_n * warp_tile_n"); + } + if(alg.tile_shape.k % warp_k_coverage != 0) + { + result.add_error("tile_k must be divisible by warp_k * warp_tile_k"); + } + } + + void validate_warp_config(const KernelKey& key, ValidationResult& result) const + { + auto supported = get_supported_warp_configs(arch_); + if(supported.empty()) + { + if(strict_mode_) + { + result.add_error("No warp configurations defined for " + get_arch_string()); + } + else + { + result.add_warning("No warp configurations defined for " + get_arch_string()); + } + return; + } + + WarpConfig current = { + key.algorithm.wave_shape.m, key.algorithm.wave_shape.n, key.algorithm.wave_shape.k}; + + bool found = false; + for(const auto& cfg : supported) + { + if(cfg == current) + { + found = true; + break; + } + } + + if(!found) + { + result.add_error("Invalid warp configuration [" + std::to_string(current[0]) + ", " + + std::to_string(current[1]) + ", " + std::to_string(current[2]) + + "] for " + get_arch_string()); + } + } + + void validate_warp_tiles(const KernelKey& key, ValidationResult& result) const + { + auto supported = get_supported_warp_tiles( + arch_, key.signature.dtype_a, key.signature.dtype_b, key.signature.dtype_c); + + if(supported.empty()) + { + // Unknown data type combination - allow with warning + result.add_warning("No warp tile combinations defined for data types"); + return; + } + + WarpTileConfig current = {key.algorithm.warp_tile_shape.m, + key.algorithm.warp_tile_shape.n, + key.algorithm.warp_tile_shape.k}; + + bool found = false; + for(const auto& cfg : supported) + { + if(cfg == current) + { + found = true; + break; + } + } + + if(!found) + { + result.add_error("Invalid warp tile [" + std::to_string(current[0]) + ", " + + std::to_string(current[1]) + ", " + std::to_string(current[2]) + + "] for " + get_arch_string()); + } + } + + void validate_traits(const KernelKey& key, ValidationResult& result) const + { + if(is_trait_unsupported( + key.algorithm.pipeline, key.algorithm.epilogue, key.algorithm.scheduler)) + { + result.add_error("Unsupported trait combination"); + } + } + + void validate_lds(const KernelKey& key, ValidationResult& result) const + { + const auto& sig = key.signature; + const auto& alg = key.algorithm; + + float elem_a = element_size(sig.dtype_a); + float elem_b = element_size(sig.dtype_b); + + std::size_t matrix_a_size = alg.tile_shape.m * alg.tile_shape.k * elem_a; + std::size_t matrix_b_size = alg.tile_shape.n * alg.tile_shape.k * elem_b; + std::size_t total_lds = matrix_a_size + matrix_b_size; + + std::size_t max_lds = get_lds_capacity(alg.pipeline); + + if(total_lds > max_lds) + { + result.add_error("LDS capacity exceeded: " + std::to_string(total_lds) + " bytes > " + + std::to_string(max_lds) + " bytes limit"); + } + } + + GpuArch arch_; + bool strict_mode_; +}; + +// ============================================================================= +// Registry Integration Helper +// ============================================================================= + +/** + * Create a filter function for use with Registry::filter() + * + * @tparam KernelT Kernel instance type with get_key() method + * @param arch Target GPU architecture + * @return Predicate function that returns true for valid kernels + */ +template +inline auto make_arch_filter_predicate(const std::string& arch) +{ + return [filter = ArchFilter(arch)](const KernelT& kernel) { + return filter.is_valid(kernel.get_key()); + }; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp b/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp new file mode 100644 index 0000000000..af52c8eb1d --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp @@ -0,0 +1,168 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + * + * Generated from: arch_specs.json + * Generated at: 2026-01-05T19:34:01.229811 + * + * To update this file: + * 1. Edit arch_specs.json + * 2. Run: python generate_arch_specs.py + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace arch_specs { + +// ============================================================================= +// GPU Architecture Enum (Generated) +// ============================================================================= + +enum class GpuArch : std::uint8_t +{ + GFX_908, // AMD Instinct MI100 + GFX_90A, // AMD Instinct MI200 series + GFX_942, // AMD Instinct MI300 series + GFX_950, // AMD Instinct MI350 series + GFX_1100, // AMD Radeon RX 7900 series (RDNA3) + GFX_1200, // AMD Radeon RX 9000 series (RDNA4) + GFX_1201, // AMD Radeon RX 9000 series (RDNA4) + UNKNOWN +}; + +// ============================================================================= +// String Conversion Functions (Generated) +// ============================================================================= + +inline std::string arch_to_string(GpuArch arch) +{ + switch(arch) + { + case GpuArch::GFX_908: return "gfx908"; + case GpuArch::GFX_90A: return "gfx90a"; + case GpuArch::GFX_942: return "gfx942"; + case GpuArch::GFX_950: return "gfx950"; + case GpuArch::GFX_1100: return "gfx1100"; + case GpuArch::GFX_1200: return "gfx1200"; + case GpuArch::GFX_1201: return "gfx1201"; + default: return "unknown"; + } +} + +inline GpuArch string_to_arch(const std::string& arch_str) +{ + if(arch_str == "gfx908") + return GpuArch::GFX_908; + if(arch_str == "gfx90a") + return GpuArch::GFX_90A; + if(arch_str == "gfx942") + return GpuArch::GFX_942; + if(arch_str == "gfx950") + return GpuArch::GFX_950; + if(arch_str == "gfx1100") + return GpuArch::GFX_1100; + if(arch_str == "gfx1200") + return GpuArch::GFX_1200; + if(arch_str == "gfx1201") + return GpuArch::GFX_1201; + return GpuArch::UNKNOWN; +} + +// ============================================================================= +// Element Size (Generated) +// ============================================================================= + +inline float element_size(DataType dtype) +{ + switch(dtype) + { + case DataType::FP16: return 2.0f; + case DataType::BF16: return 2.0f; + case DataType::FP32: return 4.0f; + case DataType::FP64: return 8.0f; + case DataType::FP8: return 1.0f; + case DataType::BF8: return 1.0f; + case DataType::INT8: return 1.0f; + case DataType::INT4: return 0.5f; + case DataType::INT32: return 4.0f; + default: return 2.0f; + } +} + +// ============================================================================= +// Warp Configurations (Generated) +// ============================================================================= + +using WarpConfig = std::array; + +inline std::vector get_supported_warp_configs(GpuArch arch) +{ + switch(arch) + { + case GpuArch::GFX_908: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_90A: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_942: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_950: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_1100: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + case GpuArch::GFX_1200: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + case GpuArch::GFX_1201: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + default: return {}; + } +} + +// ============================================================================= +// LDS Capacity Limits (Generated) +// ============================================================================= + +inline std::size_t get_lds_capacity(Pipeline pipeline) +{ + if(pipeline == Pipeline::Mem) + return 65536; + if(pipeline == Pipeline::CompV1) + return 65536; + if(pipeline == Pipeline::CompV2) + return 65536; + if(pipeline == Pipeline::CompV3) + return 65536; + if(pipeline == Pipeline::CompV4) + return 32768; + if(pipeline == Pipeline::CompV5) + return 65536; + if(pipeline == Pipeline::PreShuffleV1) + return 32768; + if(pipeline == Pipeline::PreShuffleV2) + return 32768; + return 65536; // Default +} + +// ============================================================================= +// Unsupported Trait Combinations (Generated) +// ============================================================================= + +inline bool +is_trait_unsupported(Pipeline pipeline, [[maybe_unused]] Epilogue epilogue, Scheduler scheduler) +{ + // Generated from unsupported_trait_combos in arch_specs.json + if(scheduler == Scheduler::Interwave) + { + if(pipeline == Pipeline::CompV3 || pipeline == Pipeline::CompV4) + { + return true; + } + } + return false; +} + +} // namespace arch_specs +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp new file mode 100644 index 0000000000..79f8f30a9b --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp @@ -0,0 +1,143 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Generated Kernel Backend + * + * Backend for kernels generated by unified_gemm_codegen.py + * with unique namespace wrapping (Kernel_{name}). + * + * Status: Work in progress - use generated_tile_backend.hpp for now + * + * This backend handles the new codegen format with unique kernel structs. + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/** + * Kernel instance wrapper for unified_gemm_codegen.py generated kernels + * + * These kernels have: + * - namespace {kernel_name}_ns { ... } (NEW format) + * - struct Kernel_{name} with static launch() method + * - struct SelectedKernel alias for compatibility + * - Type aliases: ADataType, BDataType, CDataType, AccDataType + * + * Note: Currently use generated_tile_backend.hpp for production + */ +template +class GeneratedKernelInstance : public KernelInstance +{ + public: + using SelectedKernel = SelectedKernelType; + using ADataType = typename SelectedKernel::ADataType; + using BDataType = typename SelectedKernel::BDataType; + using CDataType = typename SelectedKernel::CDataType; + using AccDataType = typename SelectedKernel::AccDataType; + + GeneratedKernelInstance(const KernelKey& key, const std::string& name) : key_(key), name_(name) + { + } + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + // Check dimension divisibility based on padding flags + constexpr bool pad_m = SelectedKernel::kPadM; + constexpr bool pad_n = SelectedKernel::kPadN; + constexpr bool pad_k = SelectedKernel::kPadK; + + if(pad_m && pad_n && pad_k) + { + return true; // Padding enabled - supports any size + } + + // Check divisibility for dimensions without padding + constexpr int tile_m = SelectedKernel::TileM; + constexpr int tile_n = SelectedKernel::TileN; + constexpr int tile_k = SelectedKernel::TileK; + + if(!pad_m && problem.M % tile_m != 0) + return false; + if(!pad_n && problem.N % tile_n != 0) + return false; + if(!pad_k && problem.K % tile_k != 0) + return false; + + return true; + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + (void)d_ptrs; // Not used in basic GEMM + + // Create arguments using constructor + ck_tile::GemmHostArgs args(a_ptr, // a_ptr + b_ptr, // b_ptr + c_ptr, // e_ptr/c_ptr + problem.k_batch, // k_batch + problem.M, // M + problem.N, // N + problem.K, // K + problem.K, // stride_A (row-major A: stride = K) + problem.K, // stride_B (column-major B: stride = K) + problem.N // stride_E/C (row-major C: stride = N) + ); + + // Create stream config for timing + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = reinterpret_cast(stream); + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = 0; + stream_cfg.cold_niters_ = 5; // Warmup iterations + stream_cfg.nrepeat_ = 10; // Measurement iterations + stream_cfg.is_gpu_timer_ = true; + stream_cfg.flush_cache_ = false; + stream_cfg.rotating_count_ = 1; + + // Call the generated kernel's launch method + return SelectedKernel::launch(args, stream_cfg); + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + (void)a_ptr; + (void)b_ptr; + (void)c_ptr; + (void)d_ptrs; + (void)problem; + (void)tolerance; + // Validation would require reference implementation + return true; + } + + private: + KernelKey key_; + std::string name_; +}; + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp new file mode 100644 index 0000000000..76565045cf --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp @@ -0,0 +1,157 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/validation/reference_kernels.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/** + * Kernel instance wrapper for unified_gemm_codegen.py generated kernels + * + * These kernels have structure: + * - Types defined outside: using ADataType = ...; using BDataType = ...; + * - struct SelectedKernel with static constexpr config and launch() method + * - constexpr const char* KERNEL_NAME = "..."; + * + * This is different from tile_engine style where everything is in SelectedKernel. + */ +template +class GeneratedTileKernelInstance : public KernelInstance +{ + public: + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using SelectedKernel = SelectedKernelType; + + GeneratedTileKernelInstance(const KernelKey& key, const std::string& name) + : key_(key), name_(name) + { + } + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + // Check dimension divisibility if padding not enabled + constexpr bool pad_m = SelectedKernel::kPadM; + constexpr bool pad_n = SelectedKernel::kPadN; + constexpr bool pad_k = SelectedKernel::kPadK; + + if(pad_m && pad_n && pad_k) + { + return true; // Padding enabled - supports any size + } + + // Check divisibility + constexpr int tile_m = SelectedKernel::TileM; + constexpr int tile_n = SelectedKernel::TileN; + constexpr int tile_k = SelectedKernel::TileK; + + if(!pad_m && problem.M % tile_m != 0) + return false; + if(!pad_n && problem.N % tile_n != 0) + return false; + if(!pad_k && problem.K % tile_k != 0) + return false; + + return true; + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + (void)d_ptrs; // Not used in basic GEMM + + // Create arguments using constructor (correct order!) + // Order from GemmHostArgs constructor: a_ptr, b_ptr, e_ptr, k_batch, M, N, K, stride_A, + // stride_B, stride_E + ck_tile::GemmHostArgs args(a_ptr, // a_ptr + b_ptr, // b_ptr + c_ptr, // e_ptr/c_ptr + problem.k_batch, // k_batch (4th argument!) + problem.M, // M + problem.N, // N + problem.K, // K + problem.K, // stride_A (row-major A: stride = K) + problem.K, // stride_B (column-major B: stride = K) + problem.N // stride_E/C (row-major C: stride = N) + ); + + // Create stream config for timing + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = reinterpret_cast(stream); + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = 0; // No logging for performance + stream_cfg.cold_niters_ = 5; // Warmup iterations + stream_cfg.nrepeat_ = 10; // Measurement iterations + stream_cfg.is_gpu_timer_ = true; + stream_cfg.flush_cache_ = false; + stream_cfg.rotating_count_ = 1; + + // Call the generated kernel's launch method + return SelectedKernel::launch(args, stream_cfg); + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + (void)a_ptr; + (void)b_ptr; + (void)c_ptr; + (void)d_ptrs; + (void)problem; + (void)tolerance; + // Validation would require reference implementation + return true; + } + + private: + KernelKey key_; + std::string name_; +}; + +/// Helper function to create a generated tile kernel instance wrapper +template +std::shared_ptr create_generated_tile_kernel(const KernelKey& key, + const std::string& name) +{ + return std::make_shared< + GeneratedTileKernelInstance>( + key, name); +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp b/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp new file mode 100644 index 0000000000..01ab1f5e52 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/// Helper to register a CK Tile generated kernel +/// This should be called from generated code for each kernel +template +void register_tile_kernel(Registry& registry, const std::string& kernel_name) +{ + // Extract metadata from SelectedKernel static members + KernelKey key; + + // Signature + key.signature.dtype_a = static_cast(SelectedKernel::ADataType); + key.signature.dtype_b = static_cast(SelectedKernel::BDataType); + key.signature.dtype_c = static_cast(SelectedKernel::CDataType); + key.signature.dtype_acc = static_cast(SelectedKernel::AccDataType); + + key.signature.layout_a = static_cast(SelectedKernel::ALayout); + key.signature.layout_b = static_cast(SelectedKernel::BLayout); + key.signature.layout_c = static_cast(SelectedKernel::CLayout); + + key.signature.transpose_a = false; // Extract from kernel if available + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + + key.signature.elementwise_op = "PassThrough"; // Extract if available + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = SelectedKernel::UseStructuredSparsity; + + // Algorithm + key.algorithm.tile_shape.m = SelectedKernel::TileM; + key.algorithm.tile_shape.n = SelectedKernel::TileN; + key.algorithm.tile_shape.k = SelectedKernel::TileK; + + key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M; + key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N; + key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K; + + key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM; + key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN; + key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK; + + // Extract pipeline, epilogue, scheduler from traits + key.algorithm.pipeline = Pipeline::CompV4; // Extract from kernel + key.algorithm.epilogue = Epilogue::Default; // Extract from kernel + key.algorithm.scheduler = Scheduler::Auto; // Extract from kernel + + key.algorithm.block_size = SelectedKernel::BlockSize; + key.algorithm.double_buffer = SelectedKernel::DoubleSmemBuffer; + key.algorithm.persistent = SelectedKernel::UsePersistentKernel; + key.algorithm.preshuffle = false; // Extract if available + key.algorithm.transpose_c = SelectedKernel::TransposeC; + key.algorithm.num_wave_groups = 1; // Extract if available + + key.gfx_arch = 942; // Extract from build configuration + + // Create kernel instance + auto kernel_instance = std::make_shared>(key, kernel_name); + + // Register with high priority (Tile kernels preferred) + registry.register_kernel(kernel_instance, Registry::Priority::High); +} + +/// Macro to simplify kernel registration in generated code +#define CK_TILE_REGISTER_KERNEL(SelectedKernel, KernelName, Registry) \ + ::ck_tile::dispatcher::backends::register_tile_kernel(Registry, KernelName) + +/// Helper to register multiple kernels from a list +template +struct KernelRegistrar +{ + static void register_all(Registry& registry) + { + // This would be specialized for each kernel set + // For now, empty implementation + } +}; + +/// Auto-registration helper +/// Place this in generated files to automatically register kernels +template +struct AutoRegister +{ + AutoRegister(const std::string& kernel_name) + { + auto& registry = Registry::instance(); + register_tile_kernel(registry, kernel_name); + } +}; + +/// Macro for auto-registration +#define CK_TILE_AUTO_REGISTER(SelectedKernel, KernelName) \ + static ::ck_tile::dispatcher::backends::AutoRegister \ + auto_register_##SelectedKernel{KernelName}; + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp new file mode 100644 index 0000000000..a3a0b04685 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp @@ -0,0 +1,173 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/validation/reference_kernels.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/// Kernel instance for CK Tile generated kernels +template +class TileKernelInstance : public KernelInstance +{ + public: + TileKernelInstance(const KernelKey& key, const std::string& name) : key_(key), name_(name) {} + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + // Check dimension divisibility if padding not enabled + constexpr bool pad_m = SelectedKernel::kPadM; + constexpr bool pad_n = SelectedKernel::kPadN; + constexpr bool pad_k = SelectedKernel::kPadK; + + if(pad_m && pad_n && pad_k) + { + // Padding enabled - supports any size + return true; + } + + // Check divisibility + constexpr int tile_m = SelectedKernel::TileM; + constexpr int tile_n = SelectedKernel::TileN; + constexpr int tile_k = SelectedKernel::TileK; + + if(!pad_m && problem.M % tile_m != 0) + return false; + if(!pad_n && problem.N % tile_n != 0) + return false; + if(!pad_k && problem.K % tile_k != 0) + return false; + + // Check shared memory budget if specified + if(problem.smem_budget > 0) + { + int64_t estimated_smem = estimate_smem_usage(); + if(estimated_smem > problem.smem_budget) + return false; + } + + return true; + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + // Convert void* stream to hipStream_t + hipStream_t hip_stream = reinterpret_cast(stream); + + // Construct kernel arguments + using ADataType = typename SelectedKernel::ADataType; + using BDataType = typename SelectedKernel::BDataType; + using CDataType = typename SelectedKernel::CDataType; + + // Note: d_ptrs not yet supported in basic CK Tile kernels + (void)d_ptrs; // Suppress unused parameter warning + + auto kargs = SelectedKernel::MakeKernelArgs(static_cast(a_ptr), + static_cast(b_ptr), + static_cast(c_ptr), + problem.M, + problem.N, + problem.K, + problem.k_batch); + + // Validate arguments + if(!SelectedKernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel does not support the given arguments"); + } + + // Calculate grid and block dimensions + dim3 grids = SelectedKernel::GridSize(problem.M, problem.N, problem.K); + dim3 blocks = SelectedKernel::BlockSize(); + size_t lds_bytes = SelectedKernel::GetSmemSize(); + + // Time kernel execution + hipEvent_t start, stop; + (void)hipEventCreate(&start); + (void)hipEventCreate(&stop); + + (void)hipEventRecord(start, hip_stream); + + // Launch kernel + ck_tile::launch_kernel(SelectedKernel::Kernel, grids, blocks, lds_bytes, hip_stream, kargs); + + (void)hipEventRecord(stop, hip_stream); + (void)hipEventSynchronize(stop); + + float elapsed_ms = 0.0f; + (void)hipEventElapsedTime(&elapsed_ms, start, stop); + + (void)hipEventDestroy(start); + (void)hipEventDestroy(stop); + + return elapsed_ms; + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + // Use validation helper + using ADataType = typename SelectedKernel::ADataType; + using BDataType = typename SelectedKernel::BDataType; + using CDataType = typename SelectedKernel::CDataType; + using AccDataType = typename SelectedKernel::AccDataType; + + // d_ptrs not yet supported + (void)d_ptrs; + + // Convert tolerance to rtol and atol + float rtol = tolerance; + float atol = tolerance * 1e-2f; // atol is typically smaller + + return validation::validate_gemm_kernel( + a_ptr, b_ptr, c_ptr, problem, rtol, atol); + } + + private: + int64_t estimate_smem_usage() const + { + // Use kernel's reported shared memory size + return SelectedKernel::GetSmemSize(); + } + + KernelKey key_; + std::string name_; +}; + +/// Helper function to create a tile kernel instance wrapper +/// This should be called from generated code that knows the SelectedKernel type +template +std::shared_ptr create_tile_kernel_instance(const KernelKey& key, + const std::string& name) +{ + return std::make_shared>(key, name); +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp new file mode 100644 index 0000000000..6d3f548138 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp @@ -0,0 +1,146 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Dispatcher - Main Kernel Selection and Execution Engine + * + * The Dispatcher provides unified interface for selecting and executing + * CK Tile GEMM kernels based on problem specifications. + * + * Features: + * - Multiple selection strategies (FirstFit, Heuristic) + * - Custom heuristic functions + * - Thread-safe registry integration + * - Real GPU execution with timing + * + * Usage: + * Dispatcher dispatcher; + * Problem problem(M, N, K); + * float time = dispatcher.run(a_dev, b_dev, c_dev, problem); + * + * Status: Production ready - 319 TFLOPS validated + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Heuristic function type: maps Problem to ordered list of kernel identifiers +/// Returns kernel identifiers ranked by expected performance (best first) +using HeuristicFunction = std::function(const Problem&)>; + +/// Dispatcher: Top-level orchestration for kernel selection and execution +/// Provides unified interface for kernel dispatch across different backends +class Dispatcher +{ + public: + /// Selection strategy for kernel choice + enum class SelectionStrategy + { + FirstFit, // Use first kernel that supports the problem + Heuristic // Use heuristic function to guide selection + }; + + /// Constructor + /// @param registry Registry instance to use (default: global singleton) + explicit Dispatcher(Registry* registry = nullptr); + + /// Register a heuristic function for kernel selection + /// @param heuristic Function that maps problems to ranked kernel identifiers + void set_heuristic(HeuristicFunction heuristic); + + /// Set selection strategy + /// @param strategy Strategy to use for kernel selection + void set_strategy(SelectionStrategy strategy); + + /// Select a kernel for the given problem + /// @param problem Problem configuration + /// @return Selected kernel instance, or nullptr if no suitable kernel found + [[nodiscard]] KernelInstancePtr select_kernel(const Problem& problem) const; + + /// Execute GEMM operation with automatic kernel selection + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if no suitable kernel found + [[nodiscard]] float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const Problem& problem, + void* stream = nullptr) const; + + /// Execute GEMM operation with fusion (multi-D) + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if no suitable kernel found + [[nodiscard]] float run_fused(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const; + + /// Execute with explicit kernel selection + /// @param kernel_id Kernel identifier string + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if kernel not found or doesn't support problem + [[nodiscard]] float run_explicit(const std::string& kernel_id, + const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const; + + /// Validate kernel output + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, kernel output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param tolerance Relative error tolerance + /// @return true if validation passes, false otherwise + [[nodiscard]] bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance = 1e-3f) const; + + private: + Registry* registry_; + HeuristicFunction heuristic_; + SelectionStrategy strategy_; + + /// Select kernel using first-fit strategy + [[nodiscard]] KernelInstancePtr select_first_fit(const Problem& problem) const; + + /// Select kernel using heuristic strategy + [[nodiscard]] KernelInstancePtr select_heuristic(const Problem& problem) const; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/example_args.hpp b/dispatcher/include/ck_tile/dispatcher/example_args.hpp new file mode 100644 index 0000000000..f93a4d61f6 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/example_args.hpp @@ -0,0 +1,230 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace utils { + +/** + * Simple command-line argument parser for examples. + * + * Usage: + * ExampleArgs args("Example 01: Basic GEMM", "Demonstrates basic GEMM usage"); + * args.add_flag("--list", "List all kernel sets"); + * args.add_option("--dtype", "fp16", "Data type (fp16, bf16, fp32)"); + * args.add_option("--size", "1024", "Problem size MxNxK"); + * + * if (!args.parse(argc, argv)) return 0; // --help was printed + * + * bool do_list = args.has("--list"); + * std::string dtype = args.get("--dtype"); + * int size = args.get_int("--size"); + */ +class ExampleArgs +{ + public: + ExampleArgs(const std::string& name, const std::string& description = "") + : name_(name), description_(description) + { + // Always add --help + add_flag("--help", "Show this help message"); + add_flag("-h", "Show this help message"); + } + + // Add a boolean flag (no value) + void add_flag(const std::string& name, const std::string& help) + { + flags_[name] = false; + help_[name] = help; + order_.push_back(name); + } + + // Add an option with a default value + void + add_option(const std::string& name, const std::string& default_val, const std::string& help) + { + options_[name] = default_val; + defaults_[name] = default_val; + help_[name] = help; + order_.push_back(name); + } + + // Parse arguments. Returns false if --help was requested. + bool parse(int argc, char* argv[]) + { + for(int i = 1; i < argc; ++i) + { + std::string arg = argv[i]; + + // Check for --help + if(arg == "--help" || arg == "-h") + { + print_help(); + return false; + } + + // Check for flags + if(flags_.find(arg) != flags_.end()) + { + flags_[arg] = true; + continue; + } + + // Check for options (--name=value or --name value) + std::string name, value; + size_t eq_pos = arg.find('='); + if(eq_pos != std::string::npos) + { + name = arg.substr(0, eq_pos); + value = arg.substr(eq_pos + 1); + } + else if(options_.find(arg) != options_.end() && i + 1 < argc) + { + name = arg; + value = argv[++i]; + } + else + { + // Positional argument - store as _pos_N + std::string pos_name = "_pos_" + std::to_string(positional_.size()); + positional_.push_back(arg); + continue; + } + + if(options_.find(name) != options_.end()) + { + options_[name] = value; + } + } + return true; + } + + // Check if a flag is set + bool has(const std::string& name) const + { + auto it = flags_.find(name); + return it != flags_.end() && it->second; + } + + // Get an option value as string + std::string get(const std::string& name) const + { + auto it = options_.find(name); + return it != options_.end() ? it->second : ""; + } + + // Get an option value as string with default + std::string get(const std::string& name, const std::string& default_val) const + { + auto it = options_.find(name); + return it != options_.end() ? it->second : default_val; + } + + // Get an option value as int + int get_int(const std::string& name, int default_val = 0) const + { + std::string val = get(name); + if(val.empty()) + return default_val; + try + { + return std::stoi(val); + } + catch(...) + { + return default_val; + } + } + + // Get an option value as float + float get_float(const std::string& name, float default_val = 0.0f) const + { + std::string val = get(name); + if(val.empty()) + return default_val; + try + { + return std::stof(val); + } + catch(...) + { + return default_val; + } + } + + // Get positional arguments + const std::vector& positional() const { return positional_; } + + // Print help message + void print_help() const + { + std::cout << "\n"; + std::cout << " " << name_ << "\n"; + if(!description_.empty()) + { + std::cout << " " << description_ << "\n"; + } + std::cout << "\n"; + std::cout << "Usage:\n"; + std::cout << " ./example [OPTIONS]\n"; + std::cout << "\n"; + std::cout << "Options:\n"; + + // Find max option name length for alignment + size_t max_len = 0; + for(const auto& name : order_) + { + if(name == "-h") + continue; // Skip -h, show --help only + max_len = std::max(max_len, name.length()); + } + + // Print options in order + for(const auto& name : order_) + { + if(name == "-h") + continue; + + std::cout << " " << std::left << std::setw(max_len + 2) << name; + + auto help_it = help_.find(name); + if(help_it != help_.end()) + { + std::cout << help_it->second; + } + + // Show default value for options + auto def_it = defaults_.find(name); + if(def_it != defaults_.end() && !def_it->second.empty()) + { + std::cout << " (default: " << def_it->second << ")"; + } + + std::cout << "\n"; + } + std::cout << "\n"; + } + + private: + std::string name_; + std::string description_; + std::map flags_; + std::map options_; + std::map defaults_; + std::map help_; + std::vector order_; + std::vector positional_; +}; + +} // namespace utils +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/json_export.hpp b/dispatcher/include/ck_tile/dispatcher/json_export.hpp new file mode 100644 index 0000000000..ab1c45412f --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/json_export.hpp @@ -0,0 +1,370 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * JSON Export Utilities for Dispatcher Registry + * + * Provides functionality to export kernel registry metadata to JSON format, + * similar to the tile engine benchmarking JSON export. + * + * Features: + * - Export all registered kernels with full metadata + * - Include kernel configuration (tile shapes, pipeline, scheduler, etc.) + * - Group kernels by various properties (data type, layout, pipeline, etc.) + * - Export to string or file + * + * Usage: + * auto& registry = Registry::instance(); + * std::string json = export_registry_json(registry); + * // or + * export_registry_json_to_file(registry, "kernels.json"); + */ + +#pragma once + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Convert DataType enum to string +inline std::string datatype_to_string(DataType dtype) +{ + switch(dtype) + { + case DataType::FP16: return "fp16"; + case DataType::BF16: return "bf16"; + case DataType::FP32: return "fp32"; + case DataType::FP8: return "fp8"; + case DataType::BF8: return "bf8"; + case DataType::INT8: return "int8"; + case DataType::INT32: return "int32"; + default: return "unknown"; + } +} + +/// Convert LayoutTag enum to string +inline std::string layout_to_string(LayoutTag layout) +{ + switch(layout) + { + case LayoutTag::RowMajor: return "row_major"; + case LayoutTag::ColMajor: return "col_major"; + case LayoutTag::PackedExternal: return "packed_external"; + default: return "unknown"; + } +} + +/// Convert Pipeline enum to string +inline std::string pipeline_to_string(Pipeline pipeline) +{ + switch(pipeline) + { + case Pipeline::Mem: return "mem"; + case Pipeline::CompV1: return "compv1"; + case Pipeline::CompV2: return "compv2"; + case Pipeline::CompV3: return "compv3"; + case Pipeline::CompV4: return "compv4"; + case Pipeline::CompV5: return "compv5"; + default: return "unknown"; + } +} + +/// Convert Epilogue enum to string +inline std::string epilogue_to_string(Epilogue epilogue) +{ + switch(epilogue) + { + case Epilogue::None: return "none"; + case Epilogue::Bias: return "bias"; + case Epilogue::Activation: return "activation"; + case Epilogue::CShuffle: return "cshuffle"; + case Epilogue::Default: return "default"; + default: return "unknown"; + } +} + +/// Convert Scheduler enum to string +inline std::string scheduler_to_string(Scheduler scheduler) +{ + switch(scheduler) + { + case Scheduler::Auto: return "auto"; + case Scheduler::Intrawave: return "intrawave"; + case Scheduler::Interwave: return "interwave"; + default: return "unknown"; + } +} + +/// Escape string for JSON +inline std::string json_escape(const std::string& str) +{ + std::ostringstream oss; + for(char c : str) + { + switch(c) + { + case '"': oss << "\\\""; break; + case '\\': oss << "\\\\"; break; + case '\b': oss << "\\b"; break; + case '\f': oss << "\\f"; break; + case '\n': oss << "\\n"; break; + case '\r': oss << "\\r"; break; + case '\t': oss << "\\t"; break; + default: + if(c < 0x20) + { + oss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (int)c; + } + else + { + oss << c; + } + } + } + return oss.str(); +} + +/// Get current timestamp in ISO 8601 format +inline std::string get_iso_timestamp() +{ + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + std::tm tm_buf; + localtime_r(&time_t, &tm_buf); + + std::ostringstream oss; + oss << std::put_time(&tm_buf, "%Y-%m-%dT%H:%M:%S"); + return oss.str(); +} + +/// Export a single kernel's metadata to JSON +inline std::string export_kernel_json(const KernelInstance& kernel) +{ + std::ostringstream json; + const auto& key = kernel.get_key(); + + json << " {\n"; + json << " \"name\": \"" << json_escape(kernel.get_name()) << "\",\n"; + json << " \"identifier\": \"" << json_escape(key.encode_identifier()) << "\",\n"; + + // Signature (what operation is computed) + json << " \"signature\": {\n"; + json << " \"dtype_a\": \"" << datatype_to_string(key.signature.dtype_a) << "\",\n"; + json << " \"dtype_b\": \"" << datatype_to_string(key.signature.dtype_b) << "\",\n"; + json << " \"dtype_c\": \"" << datatype_to_string(key.signature.dtype_c) << "\",\n"; + json << " \"dtype_acc\": \"" << datatype_to_string(key.signature.dtype_acc) << "\",\n"; + json << " \"layout_a\": \"" << layout_to_string(key.signature.layout_a) << "\",\n"; + json << " \"layout_b\": \"" << layout_to_string(key.signature.layout_b) << "\",\n"; + json << " \"layout_c\": \"" << layout_to_string(key.signature.layout_c) << "\",\n"; + json << " \"transpose_a\": " << (key.signature.transpose_a ? "true" : "false") << ",\n"; + json << " \"transpose_b\": " << (key.signature.transpose_b ? "true" : "false") << ",\n"; + json << " \"grouped\": " << (key.signature.grouped ? "true" : "false") << ",\n"; + json << " \"split_k\": " << (int)key.signature.split_k << ",\n"; + json << " \"elementwise_op\": \"" << json_escape(key.signature.elementwise_op) + << "\",\n"; + json << " \"num_d_tensors\": " << (int)key.signature.num_d_tensors << ",\n"; + json << " \"structured_sparsity\": " + << (key.signature.structured_sparsity ? "true" : "false") << "\n"; + json << " },\n"; + + // Algorithm (how it's implemented) + json << " \"algorithm\": {\n"; + json << " \"tile_shape\": {\n"; + json << " \"m\": " << key.algorithm.tile_shape.m << ",\n"; + json << " \"n\": " << key.algorithm.tile_shape.n << ",\n"; + json << " \"k\": " << key.algorithm.tile_shape.k << "\n"; + json << " },\n"; + json << " \"wave_shape\": {\n"; + json << " \"m\": " << (int)key.algorithm.wave_shape.m << ",\n"; + json << " \"n\": " << (int)key.algorithm.wave_shape.n << ",\n"; + json << " \"k\": " << (int)key.algorithm.wave_shape.k << "\n"; + json << " },\n"; + json << " \"warp_tile_shape\": {\n"; + json << " \"m\": " << (int)key.algorithm.warp_tile_shape.m << ",\n"; + json << " \"n\": " << (int)key.algorithm.warp_tile_shape.n << ",\n"; + json << " \"k\": " << (int)key.algorithm.warp_tile_shape.k << "\n"; + json << " },\n"; + json << " \"pipeline\": \"" << pipeline_to_string(key.algorithm.pipeline) << "\",\n"; + json << " \"scheduler\": \"" << scheduler_to_string(key.algorithm.scheduler) << "\",\n"; + json << " \"epilogue\": \"" << epilogue_to_string(key.algorithm.epilogue) << "\",\n"; + json << " \"block_size\": " << key.algorithm.block_size << ",\n"; + json << " \"double_buffer\": " << (key.algorithm.double_buffer ? "true" : "false") + << ",\n"; + json << " \"persistent\": " << (key.algorithm.persistent ? "true" : "false") << ",\n"; + json << " \"preshuffle\": " << (key.algorithm.preshuffle ? "true" : "false") << ",\n"; + json << " \"transpose_c\": " << (key.algorithm.transpose_c ? "true" : "false") << ",\n"; + json << " \"num_wave_groups\": " << (int)key.algorithm.num_wave_groups << "\n"; + json << " },\n"; + + json << " \"gfx_arch\": \"" << json_escape(key.gfx_arch) << "\"\n"; + json << " }"; + + return json.str(); +} + +/// Export registry metadata and statistics to JSON +inline std::string export_registry_json(const Registry& registry, bool include_statistics = true) +{ + std::ostringstream json; + + auto all_kernels = registry.get_all(); + + json << "{\n"; + + // Metadata + json << " \"metadata\": {\n"; + json << " \"timestamp\": \"" << get_iso_timestamp() << "\",\n"; + json << " \"registry_name\": \"" << json_escape(registry.get_name()) << "\",\n"; + json << " \"total_kernels\": " << all_kernels.size() << ",\n"; + json << " \"export_version\": \"1.0.0\"\n"; + json << " },\n"; + + // Statistics (if enabled) + if(include_statistics && !all_kernels.empty()) + { + std::map by_datatype; + std::map by_pipeline; + std::map by_scheduler; + std::map by_layout; + std::map by_gfx_arch; + + for(const auto& kernel : all_kernels) + { + const auto& key = kernel->get_key(); + + // Count by data type + std::string dtype_key = datatype_to_string(key.signature.dtype_a) + "_" + + datatype_to_string(key.signature.dtype_b) + "_" + + datatype_to_string(key.signature.dtype_c); + by_datatype[dtype_key]++; + + // Count by pipeline + by_pipeline[pipeline_to_string(key.algorithm.pipeline)]++; + + // Count by scheduler + by_scheduler[scheduler_to_string(key.algorithm.scheduler)]++; + + // Count by layout + std::string layout_key = layout_to_string(key.signature.layout_a) + "_" + + layout_to_string(key.signature.layout_b) + "_" + + layout_to_string(key.signature.layout_c); + by_layout[layout_key]++; + + // Count by GFX architecture + by_gfx_arch[key.gfx_arch]++; + } + + json << " \"statistics\": {\n"; + + // Data type breakdown + json << " \"by_datatype\": {\n"; + bool first = true; + for(const auto& [dtype, count] : by_datatype) + { + if(!first) + json << ",\n"; + json << " \"" << dtype << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Pipeline breakdown + json << " \"by_pipeline\": {\n"; + first = true; + for(const auto& [pipeline, count] : by_pipeline) + { + if(!first) + json << ",\n"; + json << " \"" << pipeline << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Scheduler breakdown + json << " \"by_scheduler\": {\n"; + first = true; + for(const auto& [scheduler, count] : by_scheduler) + { + if(!first) + json << ",\n"; + json << " \"" << scheduler << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Layout breakdown + json << " \"by_layout\": {\n"; + first = true; + for(const auto& [layout, count] : by_layout) + { + if(!first) + json << ",\n"; + json << " \"" << layout << "\": " << count; + first = false; + } + json << "\n },\n"; + + // GFX architecture breakdown + json << " \"by_gfx_arch\": {\n"; + first = true; + for(const auto& [arch, count] : by_gfx_arch) + { + if(!first) + json << ",\n"; + json << " \"" << arch << "\": " << count; + first = false; + } + json << "\n }\n"; + + json << " },\n"; + } + + // Kernels list + json << " \"kernels\": [\n"; + for(size_t i = 0; i < all_kernels.size(); ++i) + { + json << export_kernel_json(*all_kernels[i]); + if(i < all_kernels.size() - 1) + { + json << ","; + } + json << "\n"; + } + json << " ]\n"; + + json << "}\n"; + + return json.str(); +} + +/// Export registry to a JSON file +inline bool export_registry_json_to_file(const Registry& registry, + const std::string& filename, + bool include_statistics = true) +{ + std::string json = export_registry_json(registry, include_statistics); + + std::ofstream file(filename); + if(!file.is_open()) + { + return false; + } + + file << json; + file.close(); + + return true; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp new file mode 100644 index 0000000000..05011d2c2d --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp @@ -0,0 +1,370 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file kernel_config.hpp + * @brief Explicit kernel configuration for CK Tile Dispatcher + * + * This header provides a KernelConfig struct that mirrors the Python API, + * allowing explicit, self-contained kernel configuration without relying + * on force-included generated headers. + * + * Usage: + * #include "ck_tile/dispatcher/kernel_config.hpp" + * using namespace ck_tile::dispatcher; + * + * // Step 1: Define explicit config + * auto config = KernelConfig::fp16_rcr() + * .tile(128, 128, 32) + * .wave(2, 2, 1) + * .warp_tile(32, 32, 16) + * .pipeline(Pipeline::CompV4) + * .scheduler(Scheduler::Intrawave); + * + * // Step 2: Create registry and register + * Registry registry; + * registry.register_kernel(config.build_key(), config.get_name()); + * + * // Step 3: Create dispatcher + * Dispatcher dispatcher(®istry); + * + * // Step 4: Run GEMM + * dispatcher.run(a, b, c, Problem(M, N, K)); + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/** + * @brief Explicit kernel configuration matching Python's KernelConfig + * + * This provides a fluent builder API for creating kernel configurations + * with all parameters visible and explicit. + */ +class KernelConfig +{ + public: + // ========================================================================= + // Data types + // ========================================================================= + DataType dtype_a = DataType::FP16; + DataType dtype_b = DataType::FP16; + DataType dtype_c = DataType::FP16; + DataType dtype_acc = DataType::FP32; + + // ========================================================================= + // Layouts + // ========================================================================= + LayoutTag layout_a = LayoutTag::RowMajor; + LayoutTag layout_b = LayoutTag::ColMajor; + LayoutTag layout_c = LayoutTag::RowMajor; + + // ========================================================================= + // Tile shape + // ========================================================================= + int tile_m = 128; + int tile_n = 128; + int tile_k = 32; + + // ========================================================================= + // Wave shape (warps per block) + // ========================================================================= + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + + // ========================================================================= + // Warp tile shape + // ========================================================================= + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // ========================================================================= + // Block and pipeline + // ========================================================================= + int block_size = 256; + Pipeline pipeline_type = Pipeline::CompV4; + Scheduler scheduler_type = Scheduler::Intrawave; + Epilogue epilogue_type = Epilogue::CShuffle; + + // ========================================================================= + // Padding and features + // ========================================================================= + bool pad_m = true; + bool pad_n = true; + bool pad_k = true; + bool preshuffle = false; + + // ========================================================================= + // Target architecture + // ========================================================================= + std::string gfx_arch = "gfx942"; + + // ========================================================================= + // Fluent builder methods + // ========================================================================= + + /// Set tile dimensions (M x N x K) + KernelConfig& tile(int m, int n, int k) + { + tile_m = m; + tile_n = n; + tile_k = k; + return *this; + } + + /// Set wave dimensions (warps per block M x N x K) + KernelConfig& wave(int m, int n, int k) + { + wave_m = m; + wave_n = n; + wave_k = k; + return *this; + } + + /// Set warp tile dimensions (M x N x K) + KernelConfig& warp_tile(int m, int n, int k) + { + warp_m = m; + warp_n = n; + warp_k = k; + return *this; + } + + /// Set block size + KernelConfig& block(int size) + { + block_size = size; + return *this; + } + + /// Set pipeline type + KernelConfig& pipeline(Pipeline p) + { + pipeline_type = p; + return *this; + } + + /// Set scheduler type + KernelConfig& scheduler(Scheduler s) + { + scheduler_type = s; + return *this; + } + + /// Set epilogue type + KernelConfig& epilogue(Epilogue e) + { + epilogue_type = e; + return *this; + } + + /// Set data types for A, B, C + KernelConfig& dtypes(DataType a, DataType b, DataType c, DataType acc = DataType::FP32) + { + dtype_a = a; + dtype_b = b; + dtype_c = c; + dtype_acc = acc; + return *this; + } + + /// Set layouts for A, B, C + KernelConfig& layouts(LayoutTag a, LayoutTag b, LayoutTag c) + { + layout_a = a; + layout_b = b; + layout_c = c; + return *this; + } + + /// Set padding flags + KernelConfig& padding(bool m, bool n, bool k) + { + pad_m = m; + pad_n = n; + pad_k = k; + return *this; + } + + /// Set target GPU architecture + KernelConfig& arch(const std::string& gpu) + { + gfx_arch = gpu; + return *this; + } + + // ========================================================================= + // Preset configurations + // ========================================================================= + + /// FP16 Row-Column-Row layout (most common) + static KernelConfig fp16_rcr() { return KernelConfig{}; } + + /// FP16 Row-Row-Row layout + static KernelConfig fp16_rrr() + { + KernelConfig cfg; + cfg.layout_b = LayoutTag::RowMajor; + return cfg; + } + + /// BF16 Row-Column-Row layout + static KernelConfig bf16_rcr() + { + KernelConfig cfg; + cfg.dtype_a = DataType::BF16; + cfg.dtype_b = DataType::BF16; + cfg.dtype_c = DataType::BF16; + return cfg; + } + + /// FP32 Row-Column-Row layout + static KernelConfig fp32_rcr() + { + KernelConfig cfg; + cfg.dtype_a = DataType::FP32; + cfg.dtype_b = DataType::FP32; + cfg.dtype_c = DataType::FP32; + cfg.dtype_acc = DataType::FP32; + return cfg; + } + + // ========================================================================= + // Build KernelKey + // ========================================================================= + + /// Build a KernelKey from this configuration + [[nodiscard]] KernelKey build_key() const + { + KernelKey key; + + // Signature + key.signature.dtype_a = dtype_a; + key.signature.dtype_b = dtype_b; + key.signature.dtype_c = dtype_c; + key.signature.dtype_acc = dtype_acc; + key.signature.layout_a = layout_a; + key.signature.layout_b = layout_b; + key.signature.layout_c = layout_c; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {static_cast(tile_m), + static_cast(tile_n), + static_cast(tile_k)}; + key.algorithm.wave_shape = {static_cast(wave_m), + static_cast(wave_n), + static_cast(wave_k)}; + key.algorithm.warp_tile_shape = {static_cast(warp_m), + static_cast(warp_n), + static_cast(warp_k)}; + key.algorithm.pipeline = pipeline_type; + key.algorithm.scheduler = scheduler_type; + key.algorithm.epilogue = epilogue_type; + key.algorithm.block_size = block_size; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = preshuffle; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; + } + + // ========================================================================= + // String representations + // ========================================================================= + + /// Get tile string (e.g., "128x128x32") + [[nodiscard]] std::string tile_str() const + { + std::ostringstream oss; + oss << tile_m << "x" << tile_n << "x" << tile_k; + return oss.str(); + } + + /// Get wave string (e.g., "2x2x1") + [[nodiscard]] std::string wave_str() const + { + std::ostringstream oss; + oss << wave_m << "x" << wave_n << "x" << wave_k; + return oss.str(); + } + + /// Get warp tile string (e.g., "32x32x16") + [[nodiscard]] std::string warp_tile_str() const + { + std::ostringstream oss; + oss << warp_m << "x" << warp_n << "x" << warp_k; + return oss.str(); + } + + /// Get layout string (e.g., "rcr") + [[nodiscard]] std::string layout_str() const + { + std::ostringstream oss; + oss << to_string(layout_a) << to_string(layout_b) << to_string(layout_c); + return oss.str(); + } + + /// Get kernel name for generated code lookup + [[nodiscard]] std::string get_name() const + { + std::ostringstream oss; + oss << "gemm_" << to_string(dtype_a) << "_" << layout_str() << "_" + << to_string(pipeline_type) << "_" << to_string(epilogue_type) << "_" + << to_string(scheduler_type) << "_" << (pad_m ? "True" : "False") << "_" + << (pad_n ? "True" : "False") << "_" << (pad_k ? "True" : "False") << "_" + << "False" // preshuffle + << "_" << tile_str() << "_" << wave_str() << "_" << warp_tile_str(); + return oss.str(); + } + + /// Print configuration to stdout + void print_config(std::ostream& os = std::cout) const + { + os << " Data types:\n"; + os << " dtype_a = " << to_string(dtype_a) << "\n"; + os << " dtype_b = " << to_string(dtype_b) << "\n"; + os << " dtype_c = " << to_string(dtype_c) << "\n"; + os << " dtype_acc = " << to_string(dtype_acc) << "\n"; + os << " Layouts:\n"; + os << " layout_a = " << to_string(layout_a) << "\n"; + os << " layout_b = " << to_string(layout_b) << "\n"; + os << " layout_c = " << to_string(layout_c) << "\n"; + os << " Tile shape:\n"; + os << " tile = " << tile_str() << "\n"; + os << " wave = " << wave_str() << "\n"; + os << " warp_tile = " << warp_tile_str() << "\n"; + os << " Pipeline:\n"; + os << " pipeline = " << to_string(pipeline_type) << "\n"; + os << " scheduler = " << to_string(scheduler_type) << "\n"; + os << " epilogue = " << to_string(epilogue_type) << "\n"; + os << " Padding:\n"; + os << " pad_m = " << (pad_m ? "true" : "false") << "\n"; + os << " pad_n = " << (pad_n ? "true" : "false") << "\n"; + os << " pad_k = " << (pad_k ? "true" : "false") << "\n"; + os << " Target:\n"; + os << " gfx_arch = " << gfx_arch << "\n"; + } +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp new file mode 100644 index 0000000000..095de52e06 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp @@ -0,0 +1,509 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file kernel_decl.hpp + * @brief Declarative kernel specification with KernelSet + * + * USAGE: + * ====== + * + * // Named kernel sets + * DECL_KERNEL_SET(compute_bound, + * .add("fp16", "rcr", 256, 256, 64) + * .add("fp16", "rcr", 128, 128, 32) + * ); + * + * // Access at runtime + * auto& set = KernelSetRegistry::instance().get("compute_bound"); + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace decl { + +// ============================================================================= +// Wildcard constants +// ============================================================================= + +constexpr const char* ANY = "*"; +constexpr int ANY_INT = -1; + +// ============================================================================= +// Signature Builder +// ============================================================================= + +class Signature +{ + public: + std::string dtype_a_ = "fp16"; + std::string dtype_b_ = "fp16"; + std::string dtype_c_ = "fp16"; + std::string dtype_acc_ = "fp32"; + std::string layout_a_ = "row"; + std::string layout_b_ = "col"; + std::string layout_c_ = "row"; + std::string elementwise_op_ = "PassThrough"; + int num_d_tensors_ = 0; + bool structured_sparsity_ = false; + + Signature& dtype(const std::string& a, + const std::string& b, + const std::string& c, + const std::string& acc = "fp32") + { + dtype_a_ = a; + dtype_b_ = b; + dtype_c_ = c; + dtype_acc_ = acc; + return *this; + } + + Signature& dtype(const std::string& all) + { + dtype_a_ = dtype_b_ = dtype_c_ = all; + dtype_acc_ = "fp32"; + return *this; + } + + Signature& layout(const std::string& a, const std::string& b, const std::string& c) + { + layout_a_ = a; + layout_b_ = b; + layout_c_ = c; + return *this; + } + + Signature& layout(const std::string& combined) + { + if(combined.size() >= 3) + { + layout_a_ = (combined[0] == 'r') ? "row" : "col"; + layout_b_ = (combined[1] == 'r') ? "row" : "col"; + layout_c_ = (combined[2] == 'r') ? "row" : "col"; + } + return *this; + } + + Signature& elementwise(const std::string& op, int num_d = 0) + { + elementwise_op_ = op; + num_d_tensors_ = num_d; + return *this; + } + + std::string layout_str() const + { + std::string r; + r += (layout_a_ == "col") ? 'c' : 'r'; + r += (layout_b_ == "col") ? 'c' : 'r'; + r += (layout_c_ == "col") ? 'c' : 'r'; + return r; + } +}; + +// ============================================================================= +// Algorithm Builder +// ============================================================================= + +class Algorithm +{ + public: + int tile_m_ = 128, tile_n_ = 128, tile_k_ = 32; + int wave_m_ = ANY_INT, wave_n_ = ANY_INT, wave_k_ = 1; + int warp_m_ = ANY_INT, warp_n_ = ANY_INT, warp_k_ = 16; + std::string pipeline_ = "compv4"; + std::string scheduler_ = "intrawave"; + std::string epilogue_ = "cshuffle"; + int block_size_ = 256; + int pad_m_ = 1, pad_n_ = 1, pad_k_ = 1; + bool preshuffle_ = false; + + Algorithm& tile(int m, int n, int k) + { + tile_m_ = m; + tile_n_ = n; + tile_k_ = k; + return *this; + } + + Algorithm& wave(int m, int n, int k = 1) + { + wave_m_ = m; + wave_n_ = n; + wave_k_ = k; + return *this; + } + + Algorithm& warp(int m, int n, int k = 16) + { + warp_m_ = m; + warp_n_ = n; + warp_k_ = k; + return *this; + } + + Algorithm& pipeline(const std::string& p) + { + pipeline_ = p; + return *this; + } + Algorithm& scheduler(const std::string& s) + { + scheduler_ = s; + return *this; + } + Algorithm& epilogue(const std::string& e) + { + epilogue_ = e; + return *this; + } + + Algorithm& pad(bool m, bool n, bool k) + { + pad_m_ = m ? 1 : 0; + pad_n_ = n ? 1 : 0; + pad_k_ = k ? 1 : 0; + return *this; + } + + Algorithm& preshuffle(bool v) + { + preshuffle_ = v; + return *this; + } + + bool needs_expansion() const + { + return wave_m_ == ANY_INT || warp_m_ == ANY_INT || pipeline_ == "*" || pad_m_ == ANY_INT; + } + + void auto_fill() + { + if(wave_m_ == ANY_INT) + wave_m_ = 2; + if(wave_n_ == ANY_INT) + wave_n_ = 2; + if(wave_k_ == ANY_INT) + wave_k_ = 1; + if(warp_m_ == ANY_INT) + warp_m_ = 32; + if(warp_n_ == ANY_INT) + warp_n_ = 32; + if(warp_k_ == ANY_INT) + warp_k_ = 16; + } +}; + +// ============================================================================= +// Kernel Declaration +// ============================================================================= + +struct KernelDecl +{ + Signature signature; + Algorithm algorithm; + std::string arch = "gfx942"; + + KernelDecl() = default; + + KernelDecl(const Signature& sig, const Algorithm& algo, const std::string& a = "gfx942") + : signature(sig), algorithm(algo), arch(a) + { + } + + std::string name() const + { + std::ostringstream oss; + oss << signature.dtype_a_ << "_" << signature.layout_str(); + if(algorithm.tile_m_ > 0) + { + oss << "_" << algorithm.tile_m_ << "x" << algorithm.tile_n_ << "x" << algorithm.tile_k_; + } + return oss.str(); + } + + bool has_wildcards() const { return algorithm.needs_expansion() || arch == "*"; } +}; + +// ============================================================================= +// KernelSet - Collection of declarations +// ============================================================================= + +class KernelSet +{ + public: + KernelSet() = default; + + KernelSet& add(const Signature& sig, const Algorithm& algo, const std::string& arch = "gfx942") + { + decls_.emplace_back(sig, algo, arch); + return *this; + } + + KernelSet& add(const std::string& dtype, + const std::string& layout, + int tm, + int tn, + int tk, + const std::string& arch = "gfx942") + { + Signature sig; + sig.dtype(dtype).layout(layout); + Algorithm algo; + algo.tile(tm, tn, tk); + decls_.emplace_back(sig, algo, arch); + return *this; + } + + KernelSet& add(const KernelDecl& decl) + { + decls_.push_back(decl); + return *this; + } + + KernelSet& merge(const KernelSet& other) + { + decls_.insert(decls_.end(), other.decls_.begin(), other.decls_.end()); + return *this; + } + + const std::vector& declarations() const { return decls_; } + size_t size() const { return decls_.size(); } + + bool needs_expansion() const + { + for(const auto& d : decls_) + { + if(d.algorithm.needs_expansion()) + return true; + } + return false; + } + + void print(std::ostream& os = std::cout) const + { + os << "KernelSet (" << size() << " declarations):\n"; + for(const auto& d : decls_) + { + os << " - " << d.name(); + if(d.algorithm.needs_expansion()) + os << " [expands]"; + os << "\n"; + } + } + + KernelSet& tag(const std::string& t) + { + tag_ = t; + return *this; + } + std::string tag() const { return tag_; } + + private: + std::vector decls_; + std::string tag_; +}; + +// ============================================================================= +// KernelSet Registry +// ============================================================================= + +class KernelSetRegistry +{ + public: + static KernelSetRegistry& instance() + { + static KernelSetRegistry reg; + return reg; + } + + void add(const std::string& name, const KernelSet& set) + { + sets_[name] = set; + order_.push_back(name); + } + + const KernelSet& get(const std::string& name) const + { + static KernelSet empty; + auto it = sets_.find(name); + return it != sets_.end() ? it->second : empty; + } + + bool has(const std::string& name) const { return sets_.find(name) != sets_.end(); } + + // Return const reference to avoid deep copy + const std::vector& names() const { return order_; } + size_t size() const { return sets_.size(); } + + void print() const + { + std::cout << "Named Kernel Sets (" << size() << "):\n"; + for(const auto& name : order_) + { + const auto& set = sets_.at(name); + std::cout << " " << name << ": " << set.size() << " declarations\n"; + } + } + + private: + KernelSetRegistry() = default; + std::unordered_map sets_; + std::vector order_; +}; + +// ============================================================================= +// Declaration Registry (for DECL_KERNEL) +// ============================================================================= + +class Registry +{ + public: + static Registry& instance() + { + static Registry reg; + return reg; + } + + void add(const KernelDecl& decl) + { + std::string key = decl.has_wildcards() + ? ("wildcard_" + std::to_string(declarations_.size())) + : decl.name(); + declarations_[key] = decl; + order_.push_back(key); + } + + std::vector all() const + { + std::vector result; + for(const auto& key : order_) + { + result.push_back(declarations_.at(key)); + } + return result; + } + + size_t size() const { return declarations_.size(); } + + void print() const + { + std::cout << "Declared kernels (" << size() << "):\n"; + for(const auto& key : order_) + { + const auto& d = declarations_.at(key); + std::cout << " " << d.name(); + if(d.has_wildcards()) + std::cout << " [wildcards]"; + std::cout << "\n"; + } + } + + private: + Registry() = default; + std::unordered_map declarations_; + std::vector order_; +}; + +// ============================================================================= +// Static Registrars +// ============================================================================= + +struct Declarator +{ + Declarator(const Signature& sig, const Algorithm& algo, const std::string& arch = "gfx942") + { + Registry::instance().add(KernelDecl(sig, algo, arch)); + } + + Declarator(const std::string& dtype, + const std::string& layout, + int tm, + int tn, + int tk, + const std::string& arch = "gfx942") + { + Signature sig; + sig.dtype(dtype).layout(layout); + Algorithm algo; + algo.tile(tm, tn, tk); + Registry::instance().add(KernelDecl(sig, algo, arch)); + } + + Declarator(const std::string& dtype, const std::string& layout, const std::string& arch) + { + Signature sig; + sig.dtype(dtype).layout(layout); + Algorithm algo; + algo.tile(ANY_INT, ANY_INT, ANY_INT); + Registry::instance().add(KernelDecl(sig, algo, arch)); + } +}; + +struct KernelSetRegistrar +{ + KernelSetRegistrar(const std::string& name, const KernelSet& set) + { + KernelSetRegistry::instance().add(name, set); + } +}; + +} // namespace decl + +// ============================================================================= +// Convenience Aliases +// ============================================================================= + +using KernelSignature = decl::Signature; +using KernelAlgorithm = decl::Algorithm; +using KernelDecl = decl::KernelDecl; +using KernelDeclRegistry = decl::Registry; +using KernelSet = decl::KernelSet; +using KernelSetRegistry = decl::KernelSetRegistry; + +constexpr const char* ANY = decl::ANY; +constexpr int ANY_INT = decl::ANY_INT; + +} // namespace dispatcher +} // namespace ck_tile + +// ============================================================================= +// Declaration Macros +// ============================================================================= + +#define CK_DECL_CAT_(a, b) CK_DECL_CAT_IMPL_(a, b) +#define CK_DECL_CAT_IMPL_(a, b) a##b + +// Note: __extension__ suppresses warnings about __COUNTER__ being a GCC/Clang extension +#define DECL_KERNEL(sig, algo, ...) \ + __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_( \ + _kdecl_, __COUNTER__)(sig, algo, ##__VA_ARGS__) + +#define DECL_KERNEL_SIMPLE(dtype, layout, tm, tn, tk) \ + __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_( \ + _kdecl_, __COUNTER__)(#dtype, #layout, tm, tn, tk) + +#define DECL_KERNEL_ALL(dtype, layout) \ + __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_( \ + _kdecl_, __COUNTER__)(#dtype, #layout, "*") + +#define DECL_KERNEL_SET(name, ...) \ + __extension__ static ::ck_tile::dispatcher::decl::KernelSetRegistrar CK_DECL_CAT_( \ + _kset_reg_, __COUNTER__)(#name, \ + ::ck_tile::dispatcher::decl::KernelSet() __VA_ARGS__.tag(#name)) + +#define KERNEL_SET(name) ::ck_tile::dispatcher::decl::KernelSet name +#define BEGIN_KERNEL_SET() ::ck_tile::dispatcher::decl::KernelSet() + +// Legacy compatibility +// Legacy aliases removed - use DECL_KERNEL_SET instead diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp new file mode 100644 index 0000000000..4a734f4c3f --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp @@ -0,0 +1,68 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// KernelInstance: Uniform interface for kernel execution +/// Abstracts away implementation details (CK Library vs CK Tile vs future JIT) +/// Enables type-erased storage in registry while backends perform type-safe casts +class KernelInstance +{ + public: + virtual ~KernelInstance() = default; + + /// Get the kernel's configuration metadata + [[nodiscard]] virtual const KernelKey& get_key() const = 0; + + /// Check if this kernel supports the given problem + /// Returns false if problem dimensions don't meet kernel requirements + /// (e.g., divisibility constraints, resource limits) + [[nodiscard]] virtual bool supports(const Problem& problem) const = 0; + + /// Get human-readable kernel name for logging and debugging + [[nodiscard]] virtual std::string get_name() const = 0; + + /// Execute the kernel with given problem and data pointers + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors for fusion (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds (0 if timing not available) + [[nodiscard]] virtual float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const = 0; + + /// Validate kernel output against reference implementation + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, kernel output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param tolerance Relative error tolerance for validation + /// @return true if validation passes, false otherwise + [[nodiscard]] virtual bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance = 1e-3f) const = 0; +}; + +/// Shared pointer type for kernel instances +using KernelInstancePtr = std::shared_ptr; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp new file mode 100644 index 0000000000..f49b3a0d74 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp @@ -0,0 +1,428 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Data types supported by CK Tile GEMM kernels +/// Matches tile_engine DATA_TYPE_MAP for full compatibility +enum class DataType : std::uint8_t +{ + FP16, // ck_tile::half_t + BF16, // ck_tile::bf16_t + FP32, // float + FP64, // double + FP8, // ck_tile::fp8_t (E4M3) + BF8, // ck_tile::bf8_t (E5M2) + INT8, // ck_tile::int8_t + INT4, // ck_tile::pk_int4_t (packed int4) + INT32, // ck_tile::int32_t + UNKNOWN +}; + +/// Memory layout tags for tensors +enum class LayoutTag : std::uint8_t +{ + RowMajor, + ColMajor, + PackedExternal +}; + +/// Pipeline variants for memory/compute optimization +/// Matches tile_engine PIPELINE_MAP for full compatibility +enum class Pipeline : std::uint8_t +{ + Mem, // Memory-bound pipeline + CompV1, // Compute pipeline v1 + CompV2, // Compute pipeline v2 + CompV3, // Compute pipeline v3 + CompV4, // Compute pipeline v4 (double buffering) + CompV5, // Compute pipeline v5 + PreShuffleV1, // Weight preshuffle pipeline v1 + PreShuffleV2 // Weight preshuffle pipeline v2 (optimized) +}; + +/// Epilogue strategies for output processing +/// Matches tile_engine epilogue options for full compatibility +enum class Epilogue : std::uint8_t +{ + None, + Default, // DefaultGemm2DEpilogue + CShuffle, // CShuffleEpilogue (cross-shuffle) + Bias, // Bias addition + Activation, // Fused activation + BiasActivation // Fused bias + activation +}; + +/// Scheduler types for wave coordination +enum class Scheduler : std::uint8_t +{ + Auto, + Intrawave, + Interwave +}; + +/// KernelKey: Compile-time kernel configuration metadata +/// Organized into Signature (what operation) and Algorithm (how it's implemented) +struct KernelKey +{ + /// Signature: Describes WHAT operation is computed (mathematical semantics) + /// Two kernels with different signatures compute different mathematical operations + struct Signature + { + DataType dtype_a; + DataType dtype_b; + DataType dtype_c; + DataType dtype_acc; + LayoutTag layout_a; + LayoutTag layout_b; + LayoutTag layout_c; + bool transpose_a; + bool transpose_b; + bool grouped; + std::uint8_t split_k; + + // Element-wise fusion: Describes mathematical operation applied to GEMM output + // Examples: PassThrough (C = A*B), MultiDAdd (E = C + D0 + D1), + // MultiDMultiply (E = C * D0 * D1), Clamp, Relu, Gelu, etc. + // This affects the mathematical result, so it belongs in Signature + std::string elementwise_op; // e.g., "PassThrough", "MultiDAdd", "Relu" + std::uint8_t + num_d_tensors; // Number of additional input tensors for fusion (0 for basic GEMM) + + bool structured_sparsity; // 2:4 sparsity affects mathematical correctness + } signature; + + /// Algorithm: Describes HOW it's implemented (performance tuning parameters) + /// Two kernels with same signature but different algorithms compute the same result + /// with different performance characteristics + struct Algorithm + { + // Hierarchical tiling configuration (primary tuning knobs) + struct TileShape + { + std::uint16_t m; + std::uint16_t n; + std::uint16_t k; + } tile_shape; + + struct WaveShape + { + std::uint8_t m; // WarpPerBlock_M in generated kernels + std::uint8_t n; // WarpPerBlock_N + std::uint8_t k; // WarpPerBlock_K + } wave_shape; + + struct WarpTileShape + { + std::uint8_t m; // WarpTileM in generated kernels + std::uint8_t n; // WarpTileN + std::uint8_t k; // WarpTileK + } warp_tile_shape; + + // Pipeline and scheduling strategy + Pipeline pipeline; + Scheduler scheduler; + Epilogue epilogue; + + // Block and memory configuration + std::uint16_t block_size; // BlockSize in generated kernels (typically 256) + bool double_buffer; // DoubleSmemBuffer (true for compv4) + bool persistent; // UsePersistentKernel + bool preshuffle; // Preshuffle (for weight preshuffle variants) + bool transpose_c; // TransposeC + std::uint8_t num_wave_groups; // NumWaveGroups + } algorithm; + + std::string gfx_arch; // e.g. "gfx942", "gfx90a", "gfx908" + + /// Generate a unique string identifier for this kernel configuration + /// Format matches tile_engine naming convention for registry lookup + /// Note: Defined after to_string() functions to use them + [[nodiscard]] std::string encode_identifier() const; + + /// Create a tuple of all fields for comparison operators + auto tie() const + { + return std::tie(signature.dtype_a, + signature.dtype_b, + signature.dtype_c, + signature.dtype_acc, + signature.layout_a, + signature.layout_b, + signature.layout_c, + signature.transpose_a, + signature.transpose_b, + signature.grouped, + signature.split_k, + signature.elementwise_op, + signature.num_d_tensors, + signature.structured_sparsity, + algorithm.tile_shape.m, + algorithm.tile_shape.n, + algorithm.tile_shape.k, + algorithm.wave_shape.m, + algorithm.wave_shape.n, + algorithm.wave_shape.k, + algorithm.warp_tile_shape.m, + algorithm.warp_tile_shape.n, + algorithm.warp_tile_shape.k, + algorithm.pipeline, + algorithm.epilogue, + algorithm.scheduler, + algorithm.block_size, + gfx_arch, + signature.structured_sparsity, + algorithm.persistent, + algorithm.double_buffer, + algorithm.preshuffle, + algorithm.transpose_c, + algorithm.num_wave_groups); + } + + /// Equality comparison + friend bool operator==(const KernelKey& lhs, const KernelKey& rhs) + { + return lhs.tie() == rhs.tie(); + } + + /// Inequality comparison + friend bool operator!=(const KernelKey& lhs, const KernelKey& rhs) { return !(lhs == rhs); } +}; + +// ============================================================================= +// String Conversion Helpers (for serialization and debugging) +// ============================================================================= + +/// Convert DataType to string +inline std::string to_string(DataType dtype) +{ + switch(dtype) + { + case DataType::FP16: return "fp16"; + case DataType::BF16: return "bf16"; + case DataType::FP32: return "fp32"; + case DataType::FP64: return "fp64"; + case DataType::FP8: return "fp8"; + case DataType::BF8: return "bf8"; + case DataType::INT8: return "int8"; + case DataType::INT4: return "int4"; + case DataType::INT32: return "int32"; + default: return "unknown"; + } +} + +/// Convert string to DataType +inline DataType string_to_dtype(const std::string& str) +{ + if(str == "fp16") + return DataType::FP16; + if(str == "bf16") + return DataType::BF16; + if(str == "fp32") + return DataType::FP32; + if(str == "fp64") + return DataType::FP64; + if(str == "fp8") + return DataType::FP8; + if(str == "bf8") + return DataType::BF8; + if(str == "int8") + return DataType::INT8; + if(str == "int4") + return DataType::INT4; + if(str == "int32") + return DataType::INT32; + return DataType::UNKNOWN; +} + +/// Convert LayoutTag to string +inline std::string to_string(LayoutTag layout) +{ + switch(layout) + { + case LayoutTag::RowMajor: return "r"; + case LayoutTag::ColMajor: return "c"; + case LayoutTag::PackedExternal: return "p"; + default: return "?"; + } +} + +/// Convert string to LayoutTag +inline LayoutTag string_to_layout(const std::string& str) +{ + if(str == "r" || str == "row" || str == "RowMajor") + return LayoutTag::RowMajor; + if(str == "c" || str == "col" || str == "ColMajor") + return LayoutTag::ColMajor; + if(str == "p" || str == "packed") + return LayoutTag::PackedExternal; + return LayoutTag::RowMajor; // Default +} + +/// Convert Pipeline to string +inline std::string to_string(Pipeline pipeline) +{ + switch(pipeline) + { + case Pipeline::Mem: return "mem"; + case Pipeline::CompV1: return "compv1"; + case Pipeline::CompV2: return "compv2"; + case Pipeline::CompV3: return "compv3"; + case Pipeline::CompV4: return "compv4"; + case Pipeline::CompV5: return "compv5"; + case Pipeline::PreShuffleV1: return "preshufflev1"; + case Pipeline::PreShuffleV2: return "preshufflev2"; + default: return "unknown"; + } +} + +/// Convert string to Pipeline +inline Pipeline string_to_pipeline(const std::string& str) +{ + if(str == "mem") + return Pipeline::Mem; + if(str == "compv1") + return Pipeline::CompV1; + if(str == "compv2") + return Pipeline::CompV2; + if(str == "compv3") + return Pipeline::CompV3; + if(str == "compv4") + return Pipeline::CompV4; + if(str == "compv5") + return Pipeline::CompV5; + if(str == "preshufflev1") + return Pipeline::PreShuffleV1; + if(str == "preshufflev2") + return Pipeline::PreShuffleV2; + return Pipeline::Mem; // Default +} + +/// Convert Epilogue to string +inline std::string to_string(Epilogue epilogue) +{ + switch(epilogue) + { + case Epilogue::None: return "none"; + case Epilogue::Default: return "default"; + case Epilogue::CShuffle: return "cshuffle"; + case Epilogue::Bias: return "bias"; + case Epilogue::Activation: return "activation"; + case Epilogue::BiasActivation: return "bias_activation"; + default: return "unknown"; + } +} + +/// Convert string to Epilogue +inline Epilogue string_to_epilogue(const std::string& str) +{ + if(str == "none") + return Epilogue::None; + if(str == "default") + return Epilogue::Default; + if(str == "cshuffle") + return Epilogue::CShuffle; + if(str == "bias") + return Epilogue::Bias; + if(str == "activation") + return Epilogue::Activation; + if(str == "bias_activation") + return Epilogue::BiasActivation; + return Epilogue::Default; // Default +} + +/// Convert Scheduler to string +inline std::string to_string(Scheduler scheduler) +{ + switch(scheduler) + { + case Scheduler::Auto: return "auto"; + case Scheduler::Intrawave: return "intrawave"; + case Scheduler::Interwave: return "interwave"; + default: return "unknown"; + } +} + +/// Convert string to Scheduler +inline Scheduler string_to_scheduler(const std::string& str) +{ + if(str == "auto") + return Scheduler::Auto; + if(str == "intrawave") + return Scheduler::Intrawave; + if(str == "interwave") + return Scheduler::Interwave; + return Scheduler::Intrawave; // Default +} + +/// Common elementwise operations (for reference in elementwise_op field) +/// These match CK Tile's ck_tile::element_wise namespace +namespace ElementwiseOps { +constexpr const char* PassThrough = "PassThrough"; +constexpr const char* Add = "Add"; +constexpr const char* Multiply = "Multiply"; +constexpr const char* MultiDAdd = "MultiDAdd"; +constexpr const char* MultiDMultiply = "MultiDMultiply"; +constexpr const char* Relu = "Relu"; +constexpr const char* Gelu = "Gelu"; +constexpr const char* Clamp = "Clamp"; +constexpr const char* Sigmoid = "Sigmoid"; +constexpr const char* Tanh = "Tanh"; +constexpr const char* Swish = "Swish"; +constexpr const char* HardSwish = "HardSwish"; +} // namespace ElementwiseOps + +// ============================================================================= +// KernelKey::encode_identifier() implementation +// Defined after to_string() functions to use them +// ============================================================================= + +inline std::string KernelKey::encode_identifier() const +{ + std::ostringstream oss; + + // Include data types and layout for uniqueness across different signatures + oss << to_string(signature.dtype_a) << "_"; + oss << to_string(signature.layout_a) << to_string(signature.layout_b) + << to_string(signature.layout_c) << "_"; + + // Include pipeline, scheduler, epilogue for uniqueness + oss << to_string(algorithm.pipeline) << "_"; + oss << to_string(algorithm.scheduler) << "_"; + oss << to_string(algorithm.epilogue) << "_"; + + // Match tile_engine naming: tile_m x tile_n x tile_k _ warp_m x warp_n x warp_k _ + // warp_tile_m x warp_tile_n x warp_tile_k + oss << algorithm.tile_shape.m << "x" << algorithm.tile_shape.n << "x" << algorithm.tile_shape.k + << "_" << unsigned(algorithm.wave_shape.m) << "x" << unsigned(algorithm.wave_shape.n) << "x" + << unsigned(algorithm.wave_shape.k) << "_" << unsigned(algorithm.warp_tile_shape.m) << "x" + << unsigned(algorithm.warp_tile_shape.n) << "x" << unsigned(algorithm.warp_tile_shape.k); + + // Add trait flags + oss << "_" << (algorithm.persistent ? "persist" : "nopers"); + + if(signature.split_k > 1) + oss << "_splitk" << unsigned(signature.split_k); + if(!signature.elementwise_op.empty() && signature.elementwise_op != "PassThrough") + oss << "_" << signature.elementwise_op; + if(signature.num_d_tensors > 0) + oss << "_d" << unsigned(signature.num_d_tensors); + if(signature.structured_sparsity) + oss << "_sparse"; + if(algorithm.preshuffle) + oss << "_preshuffle"; + + return oss.str(); +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/problem.hpp b/dispatcher/include/ck_tile/dispatcher/problem.hpp new file mode 100644 index 0000000000..437511d1ba --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/problem.hpp @@ -0,0 +1,311 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Tensor Information for Automatic MNK Inference +// ============================================================================= + +/// TensorShape: Describes tensor dimensions for automatic MNK inference +struct TensorShape +{ + std::int64_t rows; // First dimension + std::int64_t cols; // Second dimension + bool is_transposed; // Whether the tensor is transposed (column-major) + + TensorShape() : rows(0), cols(0), is_transposed(false) {} + TensorShape(std::int64_t r, std::int64_t c, bool trans = false) + : rows(r), cols(c), is_transposed(trans) + { + } + + /// Get logical M (rows when not transposed) + [[nodiscard]] std::int64_t logical_rows() const { return is_transposed ? cols : rows; } + + /// Get logical N (cols when not transposed) + [[nodiscard]] std::int64_t logical_cols() const { return is_transposed ? rows : cols; } +}; + +// ============================================================================= +// Problem: Runtime Parameters +// ============================================================================= + +/// Problem: Runtime parameters for kernel invocation +/// Captures problem dimensions and resource constraints that vary between invocations +/// even when using the same kernel +struct Problem +{ + // Problem dimensions + std::int64_t M; // Number of rows in A and C + std::int64_t N; // Number of columns in B and C + std::int64_t K; // Shared dimension (columns of A, rows of B) + + // Batch configuration + std::int32_t k_batch; // Number of K-dimension splits for split-K GEMM + + // Resource preferences + std::int32_t smem_budget; // Shared memory budget in bytes (0 = no constraint) + bool prefer_persistent; // Prefer persistent kernel variants + + // Validation control + bool enable_validation; // Enable output validation against reference + + /// Default constructor with sensible defaults + Problem() + : M(0), + N(0), + K(0), + k_batch(1), + smem_budget(0), + prefer_persistent(false), + enable_validation(false) + { + } + + /// Constructor with problem dimensions + Problem(std::int64_t m, std::int64_t n, std::int64_t k) + : M(m), + N(n), + K(k), + k_batch(1), + smem_budget(0), + prefer_persistent(false), + enable_validation(false) + { + } + + /// Check if problem dimensions are valid + [[nodiscard]] bool is_valid() const { return M > 0 && N > 0 && K > 0 && k_batch > 0; } + + /// Get total number of operations (for performance metrics) + [[nodiscard]] std::int64_t num_ops() const + { + return 2 * M * N * K; // Multiply-add counts as 2 ops + } + + // ========================================================================= + // Factory Methods for Automatic MNK Inference + // ========================================================================= + + /** + * Create Problem by inferring MNK from tensor shapes. + * + * For GEMM: C[M,N] = A[M,K] × B[K,N] + * + * @param a_shape Shape of matrix A (M x K, or K x M if transposed) + * @param b_shape Shape of matrix B (K x N, or N x K if transposed) + * @param c_shape Shape of matrix C (M x N) - used for validation + * @throws std::invalid_argument if dimensions are inconsistent + * + * Example: + * // A is 512x256, B is 256x1024, C is 512x1024 + * auto problem = Problem::from_shapes({512, 256}, {256, 1024}, {512, 1024}); + * // Infers: M=512, N=1024, K=256 + */ + [[nodiscard]] static Problem + from_shapes(TensorShape a_shape, TensorShape b_shape, TensorShape c_shape) + { + // For C = A × B: + // A: [M, K] (or [K, M] if transposed) + // B: [K, N] (or [N, K] if transposed) + // C: [M, N] + + std::int64_t M_from_A = a_shape.logical_rows(); + std::int64_t K_from_A = a_shape.logical_cols(); + std::int64_t K_from_B = b_shape.logical_rows(); + std::int64_t N_from_B = b_shape.logical_cols(); + std::int64_t M_from_C = c_shape.logical_rows(); + std::int64_t N_from_C = c_shape.logical_cols(); + + // Validate K dimension matches between A and B + if(K_from_A != K_from_B) + { + throw std::invalid_argument( + "K dimension mismatch: A has K=" + std::to_string(K_from_A) + + ", B has K=" + std::to_string(K_from_B)); + } + + // Validate M dimension matches between A and C + if(M_from_A != M_from_C) + { + throw std::invalid_argument( + "M dimension mismatch: A has M=" + std::to_string(M_from_A) + + ", C has M=" + std::to_string(M_from_C)); + } + + // Validate N dimension matches between B and C + if(N_from_B != N_from_C) + { + throw std::invalid_argument( + "N dimension mismatch: B has N=" + std::to_string(N_from_B) + + ", C has N=" + std::to_string(N_from_C)); + } + + return Problem(M_from_A, N_from_B, K_from_A); + } + + /** + * Create Problem from tensor dimensions (simple version without transpose). + * + * @param a_rows Rows of matrix A (= M) + * @param a_cols Columns of matrix A (= K) + * @param b_rows Rows of matrix B (= K) + * @param b_cols Columns of matrix B (= N) + * @param c_rows Rows of matrix C (= M) - for validation + * @param c_cols Columns of matrix C (= N) - for validation + * @throws std::invalid_argument if dimensions are inconsistent + * + * Example: + * // A[512,256] × B[256,1024] = C[512,1024] + * auto problem = Problem::from_dimensions(512, 256, 256, 1024, 512, 1024); + */ + [[nodiscard]] static Problem from_dimensions(std::int64_t a_rows, + std::int64_t a_cols, + std::int64_t b_rows, + std::int64_t b_cols, + std::int64_t c_rows, + std::int64_t c_cols) + { + return from_shapes( + TensorShape(a_rows, a_cols), TensorShape(b_rows, b_cols), TensorShape(c_rows, c_cols)); + } + + /** + * Create Problem from A and B dimensions only (C is inferred). + * + * @param a_rows Rows of matrix A (= M) + * @param a_cols Columns of matrix A (= K) + * @param b_rows Rows of matrix B (= K) - validated + * @param b_cols Columns of matrix B (= N) + * @throws std::invalid_argument if K dimensions don't match + * + * Example: + * // A[512,256] × B[256,1024] = C[512,1024] + * auto problem = Problem::from_ab(512, 256, 256, 1024); + */ + [[nodiscard]] static Problem + from_ab(std::int64_t a_rows, std::int64_t a_cols, std::int64_t b_rows, std::int64_t b_cols) + { + if(a_cols != b_rows) + { + throw std::invalid_argument("K dimension mismatch: A.cols=" + std::to_string(a_cols) + + ", B.rows=" + std::to_string(b_rows)); + } + return Problem(a_rows, b_cols, a_cols); + } + + /** + * Validate that tensor pointers have consistent sizes. + * Call this before kernel execution to catch dimension errors early. + * + * @param a_size Total elements in A tensor + * @param b_size Total elements in B tensor + * @param c_size Total elements in C tensor + * @throws std::invalid_argument if sizes don't match expected dimensions + */ + void validate_sizes(std::int64_t a_size, std::int64_t b_size, std::int64_t c_size) const + { + std::int64_t expected_a = M * K; + std::int64_t expected_b = K * N; + std::int64_t expected_c = M * N; + + if(a_size != expected_a) + { + throw std::invalid_argument("A tensor size mismatch: got " + std::to_string(a_size) + + ", expected " + std::to_string(expected_a) + " (M*K = " + + std::to_string(M) + "*" + std::to_string(K) + ")"); + } + if(b_size != expected_b) + { + throw std::invalid_argument("B tensor size mismatch: got " + std::to_string(b_size) + + ", expected " + std::to_string(expected_b) + " (K*N = " + + std::to_string(K) + "*" + std::to_string(N) + ")"); + } + if(c_size != expected_c) + { + throw std::invalid_argument("C tensor size mismatch: got " + std::to_string(c_size) + + ", expected " + std::to_string(expected_c) + " (M*N = " + + std::to_string(M) + "*" + std::to_string(N) + ")"); + } + } +}; + +// ============================================================================= +// Convenience Builders +// ============================================================================= + +/// Builder pattern for Problem configuration +class ProblemBuilder +{ + public: + ProblemBuilder() = default; + + /// Set dimensions from A and B shapes + ProblemBuilder& + from_ab(std::int64_t a_rows, std::int64_t a_cols, std::int64_t b_rows, std::int64_t b_cols) + { + problem_ = Problem::from_ab(a_rows, a_cols, b_rows, b_cols); + return *this; + } + + /// Set MNK directly + ProblemBuilder& dimensions(std::int64_t m, std::int64_t n, std::int64_t k) + { + problem_.M = m; + problem_.N = n; + problem_.K = k; + return *this; + } + + /// Set split-K batch count + ProblemBuilder& split_k(std::int32_t k_batch) + { + problem_.k_batch = k_batch; + return *this; + } + + /// Set shared memory budget + ProblemBuilder& smem_budget(std::int32_t budget) + { + problem_.smem_budget = budget; + return *this; + } + + /// Prefer persistent kernels + ProblemBuilder& persistent(bool prefer = true) + { + problem_.prefer_persistent = prefer; + return *this; + } + + /// Enable validation + ProblemBuilder& validate(bool enable = true) + { + problem_.enable_validation = enable; + return *this; + } + + /// Build the Problem + [[nodiscard]] Problem build() const + { + if(!problem_.is_valid()) + { + throw std::invalid_argument("Invalid problem dimensions"); + } + return problem_; + } + + private: + Problem problem_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/registry.hpp b/dispatcher/include/ck_tile/dispatcher/registry.hpp new file mode 100644 index 0000000000..93d1eb9f64 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/registry.hpp @@ -0,0 +1,197 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Registry - Thread-Safe Kernel Storage + * + * Central registry for all available kernel instances with priority-based + * ordering and efficient lookup. + * + * Features: + * - Thread-safe registration and lookup + * - Priority-based ordering (High, Normal, Low) + * - Lookup by name or KernelKey + * - Filter by problem compatibility + * - Supports both singleton and multiple instance patterns + * + * Usage (Singleton - backward compatible): + * auto& registry = Registry::instance(); + * registry.register_kernel(kernel, Priority::High); + * auto kernel = registry.lookup("kernel_name"); + * + * Usage (Multiple registries): + * Registry fp16_registry; + * Registry bf16_registry; + * fp16_registry.register_kernel(fp16_kernel, Priority::High); + * bf16_registry.register_kernel(bf16_kernel, Priority::High); + * + * Dispatcher fp16_dispatcher(&fp16_registry); + * Dispatcher bf16_dispatcher(&bf16_registry); + * + * Status: Production ready, thread-safe + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Registry: Central mapping from kernel configurations to executable instances +/// Thread-safe kernel registration and lookup +/// Supports both singleton pattern and multiple independent instances +class Registry +{ + public: + /// Priority levels for conflict resolution when multiple kernels have same key + enum class Priority + { + Low = 0, + Normal = 1, + High = 2 + }; + + /// Default constructor - creates an empty registry instance + /// Use this to create independent registries for different kernel sets + Registry(); + + /// Destructor - triggers auto-export if enabled + ~Registry(); + + /// Move constructor + Registry(Registry&& other) noexcept; + + /// Move assignment + Registry& operator=(Registry&& other) noexcept; + + // Prevent copying (registries contain shared_ptrs that shouldn't be duplicated) + Registry(const Registry&) = delete; + Registry& operator=(const Registry&) = delete; + + /// Register a kernel instance with the registry + /// @param instance Kernel instance to register + /// @param priority Priority level for conflict resolution (default: Normal) + /// @return true if registered successfully, false if duplicate with higher priority exists + bool register_kernel(KernelInstancePtr instance, Priority priority = Priority::Normal); + + /// Lookup a kernel by its string identifier + /// @param identifier Kernel identifier string + /// @return Kernel instance if found, nullptr otherwise + [[nodiscard]] KernelInstancePtr lookup(const std::string& identifier) const; + + /// Lookup a kernel by its KernelKey + /// @param key Kernel configuration key + /// @return Kernel instance if found, nullptr otherwise + [[nodiscard]] KernelInstancePtr lookup(const KernelKey& key) const; + + /// Get all registered kernels + /// @return Vector of all kernel instances + [[nodiscard]] std::vector get_all() const; + + /// Get all kernels matching a predicate + /// @param predicate Function to filter kernels + /// @return Vector of matching kernel instances + [[nodiscard]] std::vector + filter(std::function predicate) const; + + /// Get number of registered kernels + [[nodiscard]] std::size_t size() const; + + /// Check if registry is empty + [[nodiscard]] bool empty() const; + + /// Clear all registered kernels + void clear(); + + /// Get registry name (for logging/debugging) + [[nodiscard]] const std::string& get_name() const; + + /// Set registry name (for logging/debugging) + void set_name(const std::string& name); + + /// Export registry to JSON string + /// @param include_statistics Whether to include kernel statistics breakdown + /// @return JSON string with all kernel metadata + [[nodiscard]] std::string export_json(bool include_statistics = true) const; + + /// Export registry to JSON file + /// @param filename Output filename + /// @param include_statistics Whether to include kernel statistics breakdown + /// @return true if export succeeded, false otherwise + bool export_json_to_file(const std::string& filename, bool include_statistics = true) const; + + /// Enable automatic JSON export on kernel registration + /// @param filename Output filename for auto-export + /// @param include_statistics Whether to include statistics in auto-export + /// @param export_on_every_registration If true, exports after every registration (default). + /// If false, only exports on destruction. + void enable_auto_export(const std::string& filename, + bool include_statistics = true, + bool export_on_every_registration = true); + + /// Disable automatic JSON export + void disable_auto_export(); + + /// Check if auto-export is enabled + [[nodiscard]] bool is_auto_export_enabled() const; + + /// Merge kernels from another registry into this one + /// @param other Registry to merge from + /// @param priority Priority for merged kernels (default: Normal) + /// @return Number of kernels successfully merged + std::size_t merge_from(const Registry& other, Priority priority = Priority::Normal); + + /// Filter kernels in-place by architecture + /// @param gpu_arch Target GPU architecture string (e.g., "gfx942") + /// @return Number of kernels removed + std::size_t filter_by_arch(const std::string& gpu_arch); + + /// Get singleton instance of the global registry (backward compatible) + /// This is the default registry used when no specific registry is provided + static Registry& instance(); + + private: + struct RegistryEntry + { + KernelInstancePtr instance; + Priority priority; + }; + + /// Perform auto-export if enabled + void perform_auto_export(); + + mutable std::mutex mutex_; + std::unordered_map kernels_; + std::string name_; + + // Auto-export configuration + bool auto_export_enabled_ = false; + std::string auto_export_filename_; + bool auto_export_include_statistics_ = true; + bool auto_export_on_every_registration_ = true; +}; + +/// Shared pointer type for registries (useful for managing lifetime) +using RegistryPtr = std::shared_ptr; + +/// Create a new registry instance (factory function) +inline RegistryPtr make_registry(const std::string& name = "") +{ + auto reg = std::make_shared(); + if(!name.empty()) + { + reg->set_name(name); + } + return reg; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/utils.hpp b/dispatcher/include/ck_tile/dispatcher/utils.hpp new file mode 100644 index 0000000000..0f9990c45e --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/utils.hpp @@ -0,0 +1,724 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file utils.hpp + * @brief Common utilities for CK Tile Dispatcher + * + * This header provides reusable utilities for: + * - GPU memory management (GpuBuffer) + * - Performance measurement (Timer, GpuTimer, BenchmarkStats) + * - Validation (ValidationResult, validate_result) + * - Kernel registration helpers + * - Data generation (fill_random, etc.) + * + * Usage: + * #include "ck_tile/dispatcher/utils.hpp" + * using namespace ck_tile::dispatcher::utils; + * + * // GPU memory + * GpuBuffer buffer(1024); + * + * // Timing + * GpuTimer timer; + * timer.start(); + * // ... kernel ... + * timer.stop(); + * float ms = timer.elapsed_ms(); + * + * // Validation + * auto result = validate_result(gpu_data, ref_data, size); + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +namespace ck_tile { +namespace dispatcher { +namespace utils { + +// ============================================================================= +// HIP Error Handling +// ============================================================================= + +#define CK_HIP_CHECK(call) \ + do \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP error at " << __FILE__ << ":" << __LINE__ << ": " \ + << hipGetErrorString(err) << std::endl; \ + return false; \ + } \ + } while(0) + +#define CK_HIP_CHECK_THROW(call) \ + do \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + throw std::runtime_error(std::string("HIP error: ") + hipGetErrorString(err)); \ + } \ + } while(0) + +// ============================================================================= +// Timing Utilities +// ============================================================================= + +/** + * @brief High-resolution timer for CPU timing + */ +class Timer +{ + public: + void start() { start_ = std::chrono::high_resolution_clock::now(); } + + double elapsed_ms() const + { + auto end = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end - start_).count(); + } + + private: + std::chrono::high_resolution_clock::time_point start_; +}; + +/** + * @brief GPU timing using HIP events + * + * Times kernel execution on a specific HIP stream. Events are recorded + * on the provided stream to accurately measure kernel execution time. + * + * Usage: + * hipStream_t stream; + * hipStreamCreate(&stream); + * GpuTimer timer(stream); // or timer.set_stream(stream) + * timer.start(); + * kernel<<>>(...); + * timer.stop(); + * float ms = timer.elapsed_ms(); + */ +class GpuTimer +{ + public: + /** + * @brief Construct timer with optional stream + * @param stream HIP stream to record events on (default: null stream) + */ + explicit GpuTimer(hipStream_t stream = nullptr) : stream_(stream) + { + (void)hipEventCreate(&start_); + (void)hipEventCreate(&stop_); + } + + ~GpuTimer() + { + (void)hipEventDestroy(start_); + (void)hipEventDestroy(stop_); + } + + // Non-copyable + GpuTimer(const GpuTimer&) = delete; + GpuTimer& operator=(const GpuTimer&) = delete; + + // Movable + GpuTimer(GpuTimer&& other) noexcept + : start_(other.start_), stop_(other.stop_), stream_(other.stream_) + { + other.start_ = nullptr; + other.stop_ = nullptr; + other.stream_ = nullptr; + } + + GpuTimer& operator=(GpuTimer&& other) noexcept + { + if(this != &other) + { + if(start_) + (void)hipEventDestroy(start_); + if(stop_) + (void)hipEventDestroy(stop_); + start_ = other.start_; + stop_ = other.stop_; + stream_ = other.stream_; + other.start_ = nullptr; + other.stop_ = nullptr; + other.stream_ = nullptr; + } + return *this; + } + + /** + * @brief Set the stream to record events on + * @param stream HIP stream (pass nullptr for default stream) + */ + void set_stream(hipStream_t stream) { stream_ = stream; } + + /** + * @brief Get the current stream + */ + hipStream_t get_stream() const { return stream_; } + + /** + * @brief Record start event on the stream + */ + void start() { (void)hipEventRecord(start_, stream_); } + + /** + * @brief Record stop event on the stream + */ + void stop() { (void)hipEventRecord(stop_, stream_); } + + /** + * @brief Get elapsed time in milliseconds + * + * Synchronizes on the stop event before calculating time. + * @return Elapsed time between start and stop in milliseconds + */ + float elapsed_ms() + { + (void)hipEventSynchronize(stop_); + float ms = 0; + (void)hipEventElapsedTime(&ms, start_, stop_); + return ms; + } + + private: + hipEvent_t start_ = nullptr; + hipEvent_t stop_ = nullptr; + hipStream_t stream_ = nullptr; +}; + +// ============================================================================= +// Performance Metrics +// ============================================================================= + +/** + * @brief Calculate TFLOPS for GEMM + */ +inline double calculate_tflops(int64_t M, int64_t N, int64_t K, double time_ms) +{ + double flops = 2.0 * M * N * K; + return (flops / (time_ms * 1e-3)) / 1e12; +} + +/** + * @brief Calculate memory bandwidth in GB/s + */ +template +inline double calculate_bandwidth_gbs(int64_t M, int64_t N, int64_t K, double time_ms) +{ + double bytes = M * K * sizeof(AType) + K * N * sizeof(BType) + M * N * sizeof(CType); + return (bytes / (time_ms * 1e-3)) / 1e9; +} + +/** + * @brief Benchmark statistics + */ +struct BenchmarkStats +{ + double min_ms = 0; + double avg_ms = 0; + double max_ms = 0; + double median_ms = 0; + double tflops = 0; + double bandwidth_gbs = 0; + int iterations = 0; + + void print(std::ostream& os = std::cout) const + { + os << std::fixed << std::setprecision(4); + os << " Min: " << min_ms << " ms\n"; + os << " Avg: " << avg_ms << " ms\n"; + os << " Max: " << max_ms << " ms\n"; + os << " Median: " << median_ms << " ms\n"; + os << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + os << " Bandwidth: " << bandwidth_gbs << " GB/s\n"; + } +}; + +/** + * @brief Run benchmark and compute statistics + */ +template +BenchmarkStats run_benchmark(Func&& func, int warmup = 2, int iterations = 10) +{ + std::vector times; + times.reserve(iterations); + + for(int i = 0; i < warmup; ++i) + func(); + + for(int i = 0; i < iterations; ++i) + times.push_back(func()); + + std::sort(times.begin(), times.end()); + + BenchmarkStats stats; + stats.iterations = iterations; + stats.min_ms = times.front(); + stats.max_ms = times.back(); + stats.median_ms = times[iterations / 2]; + + double sum = 0; + for(double t : times) + sum += t; + stats.avg_ms = sum / iterations; + + return stats; +} + +// ============================================================================= +// Validation Utilities +// ============================================================================= + +/** + * @brief Validation result + */ +struct ValidationResult +{ + bool correct = false; + double max_diff = 0; + double mean_diff = 0; + double accuracy = 0; + int64_t matches = 0; + int64_t total = 0; + + void print(std::ostream& os = std::cout) const + { + os << " Correct: " << (correct ? "YES" : "NO") << "\n"; + os << " Max diff: " << max_diff << "\n"; + os << " Mean diff: " << mean_diff << "\n"; + os << " Accuracy: " << accuracy << "%\n"; + os << " Matches: " << matches << "/" << total << "\n"; + } +}; + +/** + * @brief Validate GEMM result against reference + */ +template +ValidationResult validate_result( + const T* result, const T* reference, int64_t size, double rtol = 1e-3, double atol = 1e-2) +{ + ValidationResult v; + v.total = size; + v.max_diff = 0; + v.matches = 0; + + double sum_diff = 0; + + for(int64_t i = 0; i < size; ++i) + { + double r = static_cast(result[i]); + double ref = static_cast(reference[i]); + double diff = std::abs(r - ref); + + v.max_diff = std::max(v.max_diff, diff); + sum_diff += diff; + + double threshold = atol + rtol * std::abs(ref); + if(diff <= threshold) + ++v.matches; + } + + v.mean_diff = sum_diff / size; + v.accuracy = 100.0 * v.matches / v.total; + v.correct = (v.matches == v.total) || (v.accuracy >= 99.9); + + return v; +} + +/** + * @brief Compute reference GEMM on CPU + */ +template +void compute_reference_gemm( + const AType* A, const BType* B, CType* C, int64_t M, int64_t N, int64_t K) +{ + for(int64_t m = 0; m < M; ++m) + { + for(int64_t n = 0; n < N; ++n) + { + double acc = 0; + for(int64_t k = 0; k < K; ++k) + acc += static_cast(A[m * K + k]) * static_cast(B[k * N + n]); + C[m * N + n] = static_cast(acc); + } + } +} + +// ============================================================================= +// Data Generation +// ============================================================================= + +template +void fill_random(T* data, int64_t size, T min_val = T(-1), T max_val = T(1)) +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(static_cast(min_val), + static_cast(max_val)); + for(int64_t i = 0; i < size; ++i) + data[i] = static_cast(dist(gen)); +} + +template +void fill_zeros(T* data, int64_t size) +{ + std::fill(data, data + size, T(0)); +} + +template +void fill_ones(T* data, int64_t size) +{ + std::fill(data, data + size, T(1)); +} + +template +void fill_identity(T* data, int64_t rows, int64_t cols) +{ + fill_zeros(data, rows * cols); + int64_t min_dim = std::min(rows, cols); + for(int64_t i = 0; i < min_dim; ++i) + data[i * cols + i] = T(1); +} + +// ============================================================================= +// GPU Memory Management +// ============================================================================= + +/** + * @brief RAII wrapper for GPU memory + */ +template +class GpuBuffer +{ + public: + GpuBuffer() : data_(nullptr), size_(0) {} + + explicit GpuBuffer(int64_t count) : size_(count * sizeof(T)) + { + CK_HIP_CHECK_THROW(hipMalloc(&data_, size_)); + } + + ~GpuBuffer() + { + if(data_) + (void)hipFree(data_); + } + + // Non-copyable + GpuBuffer(const GpuBuffer&) = delete; + GpuBuffer& operator=(const GpuBuffer&) = delete; + + // Movable + GpuBuffer(GpuBuffer&& other) noexcept : data_(other.data_), size_(other.size_) + { + other.data_ = nullptr; + other.size_ = 0; + } + + GpuBuffer& operator=(GpuBuffer&& other) noexcept + { + if(this != &other) + { + if(data_) + (void)hipFree(data_); + data_ = other.data_; + size_ = other.size_; + other.data_ = nullptr; + other.size_ = 0; + } + return *this; + } + + T* get() { return data_; } + const T* get() const { return data_; } + int64_t size_bytes() const { return size_; } + int64_t count() const { return size_ / sizeof(T); } + + void copy_from_host(const T* host_data) + { + CK_HIP_CHECK_THROW(hipMemcpy(data_, host_data, size_, hipMemcpyHostToDevice)); + } + + void copy_to_host(T* host_data) const + { + CK_HIP_CHECK_THROW(hipMemcpy(host_data, data_, size_, hipMemcpyDeviceToHost)); + } + + void zero() { CK_HIP_CHECK_THROW(hipMemset(data_, 0, size_)); } + + private: + T* data_; + int64_t size_; +}; + +// ============================================================================= +// Printing Utilities +// ============================================================================= + +inline void print_separator(char c = '=', int width = 70) +{ + std::cout << std::string(width, c) << "\n"; +} + +inline void print_header(const std::string& title) +{ + print_separator(); + std::cout << title << "\n"; + print_separator(); +} + +inline std::string format_size(int64_t M, int64_t N, int64_t K) +{ + std::ostringstream oss; + oss << M << "x" << N << "x" << K; + return oss.str(); +} + +inline std::string format_number(int64_t n) +{ + std::string s = std::to_string(n); + int pos = static_cast(s.length()) - 3; + while(pos > 0) + { + s.insert(pos, ","); + pos -= 3; + } + return s; +} + +/** + * @brief Print all registered kernels in a registry + * + * @param registry The registry to list kernels from + * @param os Output stream (default: std::cout) + * @param verbose If true, show full kernel config details + */ +inline void print_registered_kernels(const Registry& registry, + std::ostream& os = std::cout, + bool verbose = false) +{ + const auto& kernels = registry.get_all(); + os << "Registered Kernels (" << kernels.size() << "):\n"; + os << std::string(70, '-') << "\n"; + + int idx = 1; + for(const auto& kernel : kernels) + { + const auto& key = kernel->get_key(); + + os << " " << idx++ << ". " << kernel->get_name() << "\n"; + + if(verbose) + { + os << " Tile: " << key.algorithm.tile_shape.m << "x" + << key.algorithm.tile_shape.n << "x" << key.algorithm.tile_shape.k << "\n"; + os << " Wave: " << static_cast(key.algorithm.wave_shape.m) << "x" + << static_cast(key.algorithm.wave_shape.n) << "x" + << static_cast(key.algorithm.wave_shape.k) << "\n"; + os << " WarpTile: " << static_cast(key.algorithm.warp_tile_shape.m) << "x" + << static_cast(key.algorithm.warp_tile_shape.n) << "x" + << static_cast(key.algorithm.warp_tile_shape.k) << "\n"; + os << " Pipeline: " << to_string(key.algorithm.pipeline) << "\n"; + os << " Scheduler: " << to_string(key.algorithm.scheduler) << "\n"; + os << " Arch: " << key.gfx_arch << "\n"; + os << "\n"; + } + } + + if(!verbose && !kernels.empty()) + { + os << "\n Use --list-verbose for full details\n"; + } + os << std::string(70, '-') << "\n"; +} + +/** + * @brief Print a single kernel's configuration + */ +inline void print_kernel_info(const KernelInstance& kernel, std::ostream& os = std::cout) +{ + const auto& key = kernel.get_key(); + + os << "Kernel: " << kernel.get_name() << "\n"; + os << " Signature:\n"; + os << " dtype: " << to_string(key.signature.dtype_a) << "/" + << to_string(key.signature.dtype_b) << "/" << to_string(key.signature.dtype_c) << "\n"; + os << " layout: " << to_string(key.signature.layout_a) << to_string(key.signature.layout_b) + << to_string(key.signature.layout_c) << "\n"; + + os << " Algorithm:\n"; + os << " tile: " << key.algorithm.tile_shape.m << "x" << key.algorithm.tile_shape.n + << "x" << key.algorithm.tile_shape.k << "\n"; + os << " wave: " << static_cast(key.algorithm.wave_shape.m) << "x" + << static_cast(key.algorithm.wave_shape.n) << "x" + << static_cast(key.algorithm.wave_shape.k) << "\n"; + os << " warp_tile: " << static_cast(key.algorithm.warp_tile_shape.m) << "x" + << static_cast(key.algorithm.warp_tile_shape.n) << "x" + << static_cast(key.algorithm.warp_tile_shape.k) << "\n"; + os << " pipeline: " << to_string(key.algorithm.pipeline) << "\n"; + os << " scheduler: " << to_string(key.algorithm.scheduler) << "\n"; + os << " epilogue: " << to_string(key.algorithm.epilogue) << "\n"; + + os << " Target: " << key.gfx_arch << "\n"; +} + +// ============================================================================= +// Kernel Key Builders +// ============================================================================= + +/** + * @brief Build a KernelKey for FP16 Row-Col-Row layout GEMM + * + * This is the most common configuration. Customize parameters as needed. + */ +struct KernelKeyBuilder +{ + // Tile shape + int tile_m = 128; + int tile_n = 128; + int tile_k = 32; + + // Wave shape (warps per block) + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + + // Warp tile shape + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // Block size + int block_size = 256; + + // Data types + DataType dtype_a = DataType::FP16; + DataType dtype_b = DataType::FP16; + DataType dtype_c = DataType::FP16; + DataType dtype_acc = DataType::FP32; + + // Layouts + LayoutTag layout_a = LayoutTag::RowMajor; + LayoutTag layout_b = LayoutTag::ColMajor; + LayoutTag layout_c = LayoutTag::RowMajor; + + // Pipeline/scheduler + Pipeline pipeline = Pipeline::CompV4; + Scheduler scheduler = Scheduler::Intrawave; + Epilogue epilogue = Epilogue::CShuffle; + + // Features + bool preshuffle = false; + int num_d_tensors = 0; // Multi-D: number of additional input tensors + std::string elementwise_op = "PassThrough"; + + // Target GPU + std::string gfx_arch = "gfx942"; + + /** + * @brief Build the KernelKey + */ + KernelKey build() const + { + KernelKey key; + + // Signature + key.signature.dtype_a = dtype_a; + key.signature.dtype_b = dtype_b; + key.signature.dtype_c = dtype_c; + key.signature.dtype_acc = dtype_acc; + key.signature.layout_a = layout_a; + key.signature.layout_b = layout_b; + key.signature.layout_c = layout_c; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = elementwise_op; + key.signature.num_d_tensors = num_d_tensors; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {static_cast(tile_m), + static_cast(tile_n), + static_cast(tile_k)}; + key.algorithm.wave_shape = {static_cast(wave_m), + static_cast(wave_n), + static_cast(wave_k)}; + key.algorithm.warp_tile_shape = {static_cast(warp_m), + static_cast(warp_n), + static_cast(warp_k)}; + key.algorithm.pipeline = pipeline; + key.algorithm.scheduler = scheduler; + key.algorithm.epilogue = epilogue; + key.algorithm.block_size = block_size; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = preshuffle; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; + } + + // Convenience preset methods + static KernelKeyBuilder fp16_rcr() { return KernelKeyBuilder{}; } + + static KernelKeyBuilder fp16_rrr() + { + auto b = KernelKeyBuilder{}; + b.layout_b = LayoutTag::RowMajor; + return b; + } + + static KernelKeyBuilder preshuffle_v1() + { + auto b = KernelKeyBuilder{}; + b.pipeline = Pipeline::PreShuffleV1; + b.preshuffle = true; + return b; + } + + static KernelKeyBuilder preshuffle_v2() + { + auto b = KernelKeyBuilder{}; + b.pipeline = Pipeline::PreShuffleV2; + b.preshuffle = true; + return b; + } + + static KernelKeyBuilder multi_d(int num_d, const std::string& op = "MultiDAdd") + { + auto b = KernelKeyBuilder{}; + b.num_d_tensors = num_d; + b.elementwise_op = op; + return b; + } +}; + +} // namespace utils +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp b/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp new file mode 100644 index 0000000000..a7e063c3cc --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp @@ -0,0 +1,228 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/problem.hpp" +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace validation { + +/// Reference CPU GEMM implementation for validation +template +void reference_gemm_cpu(const ADataType* a, + const BDataType* b, + CDataType* c, + int M, + int N, + int K, + int stride_a, + int stride_b, + int stride_c, + bool transpose_a = false, + bool transpose_b = false) +{ + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + AccDataType acc = 0; + + for(int k = 0; k < K; ++k) + { + // Get A element + int a_idx = transpose_a ? (k * stride_a + m) : (m * stride_a + k); + AccDataType a_val = static_cast(a[a_idx]); + + // Get B element + int b_idx = transpose_b ? (n * stride_b + k) : (k * stride_b + n); + AccDataType b_val = static_cast(b[b_idx]); + + acc += a_val * b_val; + } + + // Write C element + int c_idx = m * stride_c + n; + c[c_idx] = static_cast(acc); + } + } +} + +/// Validate kernel output against reference +template +bool validate_output(const CDataType* result, + const CDataType* reference, + int size, + float rtol = 1e-3f, + float atol = 1e-5f) +{ + int errors = 0; + const int max_errors_to_print = 10; + + for(int i = 0; i < size; ++i) + { + float res_val = static_cast(result[i]); + float ref_val = static_cast(reference[i]); + + float abs_diff = std::abs(res_val - ref_val); + float abs_ref = std::abs(ref_val); + + bool is_valid = (abs_diff <= atol) || (abs_diff <= rtol * abs_ref); + + if(!is_valid) + { + if(errors < max_errors_to_print) + { + printf("Mismatch at index %d: result=%.6f, reference=%.6f, diff=%.6e\n", + i, + res_val, + ref_val, + abs_diff); + } + errors++; + } + } + + if(errors > 0) + { + printf("Validation failed: %d/%d elements mismatched (%.2f%%)\n", + errors, + size, + 100.0f * errors / size); + return false; + } + + return true; +} + +/// Validate kernel with reference implementation +template +bool validate_gemm_kernel(const void* a_dev_ptr, + const void* b_dev_ptr, + const void* c_dev_ptr, + const Problem& problem, + float rtol = 1e-3f, + float atol = 1e-5f) +{ + const int M = problem.M; + const int N = problem.N; + const int K = problem.K; + + // Allocate host memory + std::vector a_host(M * K); + std::vector b_host(K * N); + std::vector c_host(M * N); + std::vector c_ref(M * N); + + // Copy from device + hipMemcpy(a_host.data(), a_dev_ptr, M * K * sizeof(ADataType), hipMemcpyDeviceToHost); + hipMemcpy(b_host.data(), b_dev_ptr, K * N * sizeof(BDataType), hipMemcpyDeviceToHost); + hipMemcpy(c_host.data(), c_dev_ptr, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); + + // Compute reference + reference_gemm_cpu(a_host.data(), + b_host.data(), + c_ref.data(), + M, + N, + K, + K, // stride_a (row-major) + N, // stride_b (row-major) + N, // stride_c (row-major) + false, + false); + + // Validate + return validate_output(c_host.data(), c_ref.data(), M * N, rtol, atol); +} + +/// Validator class for kernel instances +class KernelValidator +{ + public: + KernelValidator(float rtol = 1e-3f, float atol = 1e-5f) : rtol_(rtol), atol_(atol) {} + + /// Validate a kernel instance + template + bool validate(KernelInstance& kernel, + const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const Problem& problem) + { + // Use kernel's validate method if available + return kernel.validate(a_ptr, b_ptr, c_ptr, problem, rtol_, atol_); + } + + /// Set tolerances + void set_tolerances(float rtol, float atol) + { + rtol_ = rtol; + atol_ = atol; + } + + /// Get tolerances + std::pair get_tolerances() const { return {rtol_, atol_}; } + + private: + float rtol_; + float atol_; +}; + +/// Helper to generate random test data +template +void generate_random_data(T* data, int size, float min_val = -1.0f, float max_val = 1.0f) +{ + for(int i = 0; i < size; ++i) + { + float rand_val = min_val + (max_val - min_val) * (rand() / (float)RAND_MAX); + data[i] = static_cast(rand_val); + } +} + +/// Helper to allocate and initialize test tensors +template +struct TestTensor +{ + T* host_ptr; + T* device_ptr; + int size; + + TestTensor(int size_) : size(size_) + { + host_ptr = new T[size]; + hipMalloc(&device_ptr, size * sizeof(T)); + } + + ~TestTensor() + { + delete[] host_ptr; + hipFree(device_ptr); + } + + void randomize(float min_val = -1.0f, float max_val = 1.0f) + { + generate_random_data(host_ptr, size, min_val, max_val); + hipMemcpy(device_ptr, host_ptr, size * sizeof(T), hipMemcpyHostToDevice); + } + + void copy_to_device() + { + hipMemcpy(device_ptr, host_ptr, size * sizeof(T), hipMemcpyHostToDevice); + } + + void copy_from_device() + { + hipMemcpy(host_ptr, device_ptr, size * sizeof(T), hipMemcpyDeviceToHost); + } + + void zero() { hipMemset(device_ptr, 0, size * sizeof(T)); } +}; + +} // namespace validation +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/python/CMakeLists.txt b/dispatcher/python/CMakeLists.txt new file mode 100644 index 0000000000..e57678952e --- /dev/null +++ b/dispatcher/python/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# This directory contains Python utilities for the dispatcher examples. +# The main utility file is ctypes_utils.py which is used by GEMM Python examples. +# Conv Python examples use their own conv_utils.py in the examples directory. + +# No build targets needed - these are pure Python utilities. +message(STATUS "Python utilities directory configured (no build targets)") diff --git a/dispatcher/python/README.md b/dispatcher/python/README.md new file mode 100644 index 0000000000..9286acbf72 --- /dev/null +++ b/dispatcher/python/README.md @@ -0,0 +1,60 @@ +# CK Tile Dispatcher Python Utilities + +This directory contains Python utilities used by the dispatcher examples. + +## Contents + +- `ctypes_utils.py` - Core ctypes utilities for GEMM Python examples + - `KernelConfig` - Kernel configuration dataclass + - `setup_gemm_dispatcher()` - Setup dispatcher with auto-correction + - `cleanup_gemm()` - Cleanup dispatcher resources + - `GemmRunner` - GPU execution helper + - Auto-correction and validation utilities + +- `conv_utils.py` - Core utilities for Conv Python examples + - `ConvSignature`, `ConvAlgorithm` - Convolution configuration + - `ConvProblem` - Problem definition + - `GpuConvRunner` - GPU execution helper + - `EnhancedConvCodegenRunner` - Kernel codegen utilities + +## Usage + +### GEMM Examples + +The GEMM Python examples in `dispatcher/examples/gemm/python/` import: + +```python +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + GemmRunner, +) +``` + +### Conv Examples + +The Conv Python examples in `dispatcher/examples/conv/python/` import: + +```python +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from conv_utils import ( + ConvSignature, + ConvAlgorithm, + ConvProblem, + GpuConvRunner, +) +``` + +## Requirements + +- Python 3.8+ +- NumPy +- HIP runtime (for GPU execution) diff --git a/dispatcher/python/ctypes_utils.py b/dispatcher/python/ctypes_utils.py new file mode 100644 index 0000000000..821fc2b08d --- /dev/null +++ b/dispatcher/python/ctypes_utils.py @@ -0,0 +1,2347 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +CK Tile Dispatcher Utilities + +Common utilities for loading, compiling, and using the CK Tile dispatcher. + +Usage: + from ck_tile_dispatcher.utils import DispatcherLib, GemmRunner, Validator + + # Option 1: Auto-compile and load + lib = DispatcherLib.auto() + + # Option 2: Load existing library + lib = DispatcherLib.load("/path/to/libdispatcher_gemm.so") + + # Run GEMM + runner = GemmRunner(lib) + result = runner.run(A, B) + + # Validate + validator = Validator() + check = validator.check(result.C, C_reference) +""" + +import ctypes +import subprocess +import numpy as np +from pathlib import Path +from typing import Optional, Tuple, List, Dict, Any +from dataclasses import dataclass, field +from concurrent.futures import ProcessPoolExecutor, as_completed +import multiprocessing +import time + + +# ============================================================================= +# Path Configuration +# ============================================================================= + + +def get_dispatcher_root() -> Path: + """Get the dispatcher root directory""" + # This file is in dispatcher/python/ + return Path(__file__).parent.parent + + +def get_ck_root() -> Path: + """Get the CK root directory""" + return get_dispatcher_root().parent + + +def get_build_dir() -> Path: + """Get the build directory""" + return get_dispatcher_root() / "build" + + +# ============================================================================= +# Supported Data Types +# ============================================================================= + +# All supported GEMM dtype combinations from warp_gemm_dispatcher.hpp +SUPPORTED_DTYPES = { + # dtype_a, dtype_b -> acc_dtype, warp_tiles + ("fp32", "fp32"): {"acc": "fp32", "warp_tiles": [(16, 16, 4), (16, 16, 16)]}, + ("fp16", "fp16"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 8), (32, 32, 16), (16, 16, 16), (16, 16, 32)], + }, + ("bf16", "bf16"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 8), (32, 32, 16), (16, 16, 16), (16, 16, 32)], + }, + ("fp8", "fp8"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 16), (32, 32, 32), (16, 16, 32), (16, 16, 64)], + }, + ("fp8", "bf8"): {"acc": "fp32", "warp_tiles": [(32, 32, 16), (16, 16, 32)]}, + ("bf8", "fp8"): {"acc": "fp32", "warp_tiles": [(32, 32, 16), (16, 16, 128)]}, + ("bf8", "bf8"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 16), (32, 32, 32), (16, 16, 32)], + }, + ("int8", "int8"): { + "acc": "int32", + "warp_tiles": [(32, 32, 16), (16, 16, 32), (16, 16, 16)], + }, + ("pk_fp4", "pk_fp4"): {"acc": "fp32", "warp_tiles": [(16, 16, 128)]}, +} + +# All valid individual dtypes +VALID_DTYPES = ["fp16", "bf16", "fp32", "fp8", "bf8", "int8", "pk_fp4"] + + +def get_generated_kernels_dir() -> Path: + """Get the generated kernels directory""" + return get_build_dir() / "generated_kernels" + + +# ============================================================================= +# Arch Filter and Validation +# ============================================================================= + + +def get_arch_filter_data() -> Dict[str, Any]: + """Load arch filter data from arch_specs_generated if available.""" + codegen_dir = get_dispatcher_root() / "codegen" + import sys + + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + return { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + # Fallback defaults + return { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + + +@dataclass +class ValidationResult: + """Result of kernel config validation.""" + + is_valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + suggested_fixes: Dict[str, Any] = field(default_factory=dict) + + def print_result(self, indent: str = " "): + """Print validation result.""" + if self.is_valid: + print(f"{indent}✓ Configuration valid") + else: + print(f"{indent}⚠ Configuration has issues:") + for err in self.errors: + print(f"{indent} - {err}") + + if self.warnings: + for warn in self.warnings: + print(f"{indent} Warning: {warn}") + + if self.suggested_fixes: + print(f"{indent} Suggested fixes:") + for key, val in self.suggested_fixes.items(): + print(f"{indent} {key}: {val}") + + +def validate_kernel_config(config: "KernelConfig") -> ValidationResult: + """ + Validate a KernelConfig against arch filter rules. + + Validation considers the GEMM variant (standard, preshuffle, multi_d) + for operator-specific constraints like minimum tile sizes. + + Returns ValidationResult with is_valid, errors, and suggested fixes. + """ + arch_data = get_arch_filter_data() + + errors = [] + warnings = [] + suggested_fixes = {} + + pipeline = config.pipeline + epilogue = config.epilogue + scheduler = config.scheduler + dtype = config.dtype_a + arch = config.gfx_arch + variant = getattr(config, "variant", "standard") + + wave_m = config.wave_m + wave_n = config.wave_n + wave_k = config.wave_k + + warp_m = config.warp_m + warp_n = config.warp_n + warp_k = config.warp_k + + # Variant-specific tile constraints + if variant == "preshuffle": + # Preshuffle requires larger minimum tiles for efficiency + if config.tile_m < 64: + errors.append(f"Preshuffle requires tile_m >= 64, got {config.tile_m}") + suggested_fixes["tile_m"] = 64 + if config.tile_n < 64: + errors.append(f"Preshuffle requires tile_n >= 64, got {config.tile_n}") + suggested_fixes["tile_n"] = 64 + if config.tile_k < 32: + errors.append(f"Preshuffle requires tile_k >= 32, got {config.tile_k}") + suggested_fixes["tile_k"] = 32 + + elif variant == "multi_d": + # Multi-D has standard GEMM constraints + # Could add specific constraints here if needed + pass + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}" + ) + suggested_fixes["scheduler"] = "intrawave" + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}. Valid: {valid_str}" + ) + if warp_combos: + suggested_fixes["wave_m"] = warp_combos[0][0] + suggested_fixes["wave_n"] = warp_combos[0][1] + suggested_fixes["wave_k"] = warp_combos[0][2] + + # Check warp tile configuration for this arch and dtype + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}. Valid: {valid_str}" + ) + if warp_tile_combos: + suggested_fixes["warp_m"] = warp_tile_combos[0][0] + suggested_fixes["warp_n"] = warp_tile_combos[0][1] + suggested_fixes["warp_k"] = warp_tile_combos[0][2] + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}. Supported: {', '.join(arch_data['supported_archs'])}" + ) + + return ValidationResult( + is_valid=len(errors) == 0, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + ) + + +def auto_correct_kernel_config( + config: "KernelConfig", verbose: bool = False +) -> Tuple["KernelConfig", bool, List[str]]: + """ + Validate and auto-correct a KernelConfig. + + Returns (corrected_config, was_modified, corrections_list). + If the config was valid, returns (original_config, False, []). + If corrections were made, returns (new_config, True, [list of correction descriptions]). + """ + validation = validate_kernel_config(config) + + if validation.is_valid: + return config, False, [] + + # Apply suggested fixes and track what changed + from dataclasses import replace + + fixes = validation.suggested_fixes + corrections = [] + + # Check each fix and describe what changed + if "scheduler" in fixes and fixes["scheduler"] != config.scheduler: + corrections.append( + f"Scheduler: {config.scheduler} → {fixes['scheduler']} " + f"('{config.scheduler}' not supported with pipeline={config.pipeline}, epilogue={config.epilogue})" + ) + + if "wave_m" in fixes or "wave_n" in fixes or "wave_k" in fixes: + old_wave = f"[{config.wave_m}, {config.wave_n}, {config.wave_k}]" + new_wave = f"[{fixes.get('wave_m', config.wave_m)}, {fixes.get('wave_n', config.wave_n)}, {fixes.get('wave_k', config.wave_k)}]" + if old_wave != new_wave: + corrections.append( + f"Wave config: {old_wave} → {new_wave} " + f"(original not supported on {config.gfx_arch})" + ) + + if "warp_m" in fixes or "warp_n" in fixes or "warp_k" in fixes: + old_warp = f"[{config.warp_m}, {config.warp_n}, {config.warp_k}]" + new_warp = f"[{fixes.get('warp_m', config.warp_m)}, {fixes.get('warp_n', config.warp_n)}, {fixes.get('warp_k', config.warp_k)}]" + if old_warp != new_warp: + corrections.append( + f"Warp tile: {old_warp} → {new_warp} " + f"(original not supported for {config.dtype_a} on {config.gfx_arch})" + ) + + new_config = replace( + config, + scheduler=fixes.get("scheduler", config.scheduler), + wave_m=fixes.get("wave_m", config.wave_m), + wave_n=fixes.get("wave_n", config.wave_n), + wave_k=fixes.get("wave_k", config.wave_k), + warp_m=fixes.get("warp_m", config.warp_m), + warp_n=fixes.get("warp_n", config.warp_n), + warp_k=fixes.get("warp_k", config.warp_k), + ) + + return new_config, True, corrections + + +def print_kernel_config(config: "KernelConfig", title: str = "KERNEL CONFIGURATION"): + """ + Print a formatted kernel configuration for GEMM. + + Args: + config: The KernelConfig to print + title: Title to display (e.g., "REQUESTED KERNEL CONFIGURATION") + """ + print() + print("=" * 70) + print(f" {title}") + print("=" * 70) + print(f" Data Type A: {config.dtype_a}") + print(f" Data Type B: {config.dtype_b}") + print(f" Data Type C: {config.dtype_c}") + print(f" Accumulator: {config.dtype_acc}") + print() + print( + f" Layout: {config.layout} (A={config.layout_a}, B={config.layout_b}, C={config.layout_c})" + ) + print() + print(f" Tile M x N x K: {config.tile_m} x {config.tile_n} x {config.tile_k}") + print(f" Wave Config: {config.wave_m} x {config.wave_n} x {config.wave_k}") + print(f" Warp Tile: {config.warp_m} x {config.warp_n} x {config.warp_k}") + print() + print(f" Pipeline: {config.pipeline}") + print(f" Scheduler: {config.scheduler}") + print(f" Epilogue: {config.epilogue}") + print() + print(f" Target Arch: {config.gfx_arch}") + print("=" * 70) + print() + + +def print_auto_correction( + original: "KernelConfig", + corrected: "KernelConfig", + corrections: List[str], + indent: str = " ", +): + """ + Print what was auto-corrected and why. + + Args: + original: Original configuration before correction + corrected: Configuration after correction + corrections: List of correction descriptions + indent: Indentation for output + """ + if not corrections: + print(f"{indent}✓ Configuration valid - no corrections needed") + return + + print(f"\n{indent}⚠ AUTO-CORRECTION APPLIED:") + print(f"{indent}" + "-" * 50) + for correction in corrections: + print(f"{indent} • {correction}") + print(f"{indent}" + "-" * 50) + print() + + +def find_matching_kernel_header(config: "KernelConfig") -> Optional[Path]: + """ + Find a kernel header that EXACTLY matches the config. + + Uses progressively relaxed matching strategies. + """ + kernel_dir = get_generated_kernels_dir() + + dtype = config.dtype_a + layout = config.layout + pipeline = config.pipeline + scheduler = config.scheduler + tile_str = config.tile_str + wave_str = f"{config.wave_m}x{config.wave_n}x{config.wave_k}" + warp_str = f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + + # Strategy 1: Exact match with ALL parameters including warp tile + pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_{warp_str}.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 2: Match with tile and wave, any warp + pattern = ( + f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_*.hpp" + ) + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 3: Match with just tile (ignore wave/warp) + pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 4: Match with intrawave (known to work) + pattern = f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 5: Any kernel with matching dtype/layout/tile + pattern = f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + return None + + +# ============================================================================= +# Library Loading +# ============================================================================= + + +class DispatcherLib: + """Wrapper for the dispatcher dynamic library""" + + # Default library search paths (relative to dispatcher root) + SEARCH_PATHS = [ + "build/examples/libdispatcher_gemm_lib.so", + "build/libdispatcher_gemm_lib.so", + "build/examples/libdispatcher_gemm.so", + "build/lib/libdispatcher_gemm.so", + ] + + # Track loaded libraries globally for cleanup + _loaded_libs: List[Path] = [] + + def __init__(self, lib: ctypes.CDLL, path: Path): + self._lib = lib + self._path = path + self._closed = False + DispatcherLib._loaded_libs.append(path) + self._setup_functions() + + def _setup_functions(self): + """Setup ctypes function signatures""" + # Initialize + self._lib.dispatcher_initialize.argtypes = [] + self._lib.dispatcher_initialize.restype = ctypes.c_int + + # Alias for init + self._lib.dispatcher_init.argtypes = [] + self._lib.dispatcher_init.restype = ctypes.c_int + + # Get kernel count + self._lib.dispatcher_get_kernel_count.argtypes = [] + self._lib.dispatcher_get_kernel_count.restype = ctypes.c_int + + # Check if supported + self._lib.dispatcher_is_supported.argtypes = [ + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_int64, + ] + self._lib.dispatcher_is_supported.restype = ctypes.c_int + + # Run GEMM + self._lib.dispatcher_run_gemm.argtypes = [ + ctypes.c_void_p, # A + ctypes.c_void_p, # B + ctypes.c_void_p, # C + ctypes.c_int64, # M + ctypes.c_int64, # N + ctypes.c_int64, # K + ctypes.POINTER(ctypes.c_float), # time_ms + ] + self._lib.dispatcher_run_gemm.restype = ctypes.c_int + + # Get kernel name + self._lib.dispatcher_get_kernel_name.argtypes = [] + self._lib.dispatcher_get_kernel_name.restype = ctypes.c_char_p + + # Select kernel + self._lib.dispatcher_select_kernel.argtypes = [ + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_char_p, + ctypes.c_int, + ] + self._lib.dispatcher_select_kernel.restype = ctypes.c_int + + # Export JSON + self._lib.dispatcher_export_registry_json.argtypes = [] + self._lib.dispatcher_export_registry_json.restype = ctypes.c_char_p + + # Cleanup + self._lib.dispatcher_cleanup.argtypes = [] + self._lib.dispatcher_cleanup.restype = None + + @property + def path(self) -> Path: + return self._path + + def initialize(self) -> bool: + """Initialize the dispatcher""" + return self._lib.dispatcher_initialize() == 0 + + def get_kernel_count(self) -> int: + """Get number of registered kernels""" + return self._lib.dispatcher_get_kernel_count() + + def is_supported(self, M: int, N: int, K: int) -> bool: + """Check if a problem size is supported""" + return self._lib.dispatcher_is_supported(M, N, K) == 1 + + def get_kernel_name(self) -> str: + """Get the kernel name""" + name = self._lib.dispatcher_get_kernel_name() + return name.decode("utf-8") if name else "unknown" + + def select_kernel(self, M: int, N: int, K: int) -> Optional[str]: + """Select kernel for problem and return its name""" + buffer = ctypes.create_string_buffer(256) + result = self._lib.dispatcher_select_kernel(M, N, K, buffer, 256) + if result == 0: + return buffer.value.decode("utf-8") + return None + + def run_gemm( + self, A: np.ndarray, B: np.ndarray, C: np.ndarray, M: int, N: int, K: int + ) -> Tuple[int, float]: + """ + Run GEMM operation + + Returns: (status, time_ms) + status: 0 = success, -1 = error, -2 = no suitable kernel + """ + time_ms = ctypes.c_float(0.0) + + status = self._lib.dispatcher_run_gemm( + A.ctypes.data_as(ctypes.c_void_p), + B.ctypes.data_as(ctypes.c_void_p), + C.ctypes.data_as(ctypes.c_void_p), + M, + N, + K, + ctypes.byref(time_ms), + ) + + return status, time_ms.value + + def export_json(self) -> Optional[str]: + """Export registry to JSON string""" + json_ptr = self._lib.dispatcher_export_registry_json() + if json_ptr: + return json_ptr.decode("utf-8") + return None + + def export_registry_json(self) -> str: + """Alias for export_json for compatibility""" + return self.export_json() or "{}" + + def cleanup(self): + """Cleanup dispatcher resources""" + self._lib.dispatcher_cleanup() + + @classmethod + def find(cls) -> Optional[Path]: + """Find the dispatcher library""" + root = get_dispatcher_root() + + for rel_path in cls.SEARCH_PATHS: + path = root / rel_path + if path.exists(): + return path + + return None + + @classmethod + def load(cls, path: Optional[Path] = None) -> Optional["DispatcherLib"]: + """Load the dispatcher library from path or auto-find""" + if path is None: + path = cls.find() + + if path is None or not path.exists(): + return None + + try: + lib = ctypes.CDLL(str(path)) + return cls(lib, path) + except OSError as e: + print(f"Failed to load library: {e}") + return None + + @classmethod + def compile(cls, output_path: Optional[Path] = None) -> Optional[Path]: + """Compile the dispatcher library""" + root = get_dispatcher_root() + ck_root = get_ck_root() + + if output_path is None: + output_path = get_build_dir() / "examples" / "libdispatcher_gemm.so" + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Find a kernel header to include + kernel_dir = get_generated_kernels_dir() + kernel_headers = list(kernel_dir.glob("gemm_fp16_rcr_compv4*128x128x32*.hpp")) + + if not kernel_headers: + print("No kernel headers found. Generate kernels first.") + return None + + kernel_header = kernel_headers[0] + + # Use the ctypes binding source file + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + if not ctypes_source.exists(): + print(f"Source file not found: {ctypes_source}") + print( + "Please build with CMake: cd build && cmake .. && make dispatcher_gemm_lib" + ) + return None + + # CK_TILE_SINGLE_KERNEL_INCLUDE exports types to global namespace for ctypes binding + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", # Enable global namespace exports + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + "--offload-arch=gfx942", + "-DAMDGPU_ARCH=gfx942", + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(output_path), + ] + + try: + result = subprocess.run( + compile_cmd, capture_output=True, text=True, timeout=120 + ) + if result.returncode == 0: + return output_path + else: + print(f"Compilation failed:\n{result.stderr}") + return None + except subprocess.TimeoutExpired: + print("Compilation timed out") + return None + + @classmethod + def auto(cls, recompile: bool = False) -> Optional["DispatcherLib"]: + """Auto-find or compile the library. + + Note: The library is built by CMake with a specific kernel configuration. + If you need a different dtype/layout, rebuild with: + cd build && cmake .. && make dispatcher_gemm_lib + """ + lib = cls.load() + if lib is not None: + if lib.initialize(): + return lib + else: + print(" Library found but failed to initialize") + print( + " Rebuild with: cd build && cmake .. && make dispatcher_gemm_lib" + ) + + # Don't fall back to old compile method - use CMake instead + print(" Library not found. Build with:") + print(" cd dispatcher/build && cmake .. && make dispatcher_gemm_lib") + return None + + +# ============================================================================= +# GEMM Runner +# ============================================================================= + + +@dataclass +class GemmResult: + """Result of a GEMM operation""" + + output: np.ndarray # The output C matrix + time_ms: float + status: int + tflops: float + kernel_name: str + + @property + def success(self) -> bool: + return self.status == 0 + + # Alias for backward compatibility + @property + def C(self) -> np.ndarray: + return self.output + + +class GemmRunner: + """High-level GEMM runner using the dispatcher""" + + def __init__(self, lib: DispatcherLib): + self.lib = lib + + def run(self, A: np.ndarray, B: np.ndarray, dtype=np.float16) -> GemmResult: + """ + Run GEMM: C = A @ B + + Args: + A: Input matrix (M x K) + B: Input matrix (K x N) + dtype: Output data type (default: float16) + + Returns: + GemmResult with output matrix and timing + """ + M, K = A.shape + K2, N = B.shape + + assert K == K2, f"Dimension mismatch: A is {M}x{K}, B is {K2}x{N}" + + # Ensure contiguous float16 arrays + A_gpu = np.ascontiguousarray(A, dtype=np.float16) + B_gpu = np.ascontiguousarray(B.T, dtype=np.float16) # Column-major + C_gpu = np.zeros((M, N), dtype=np.float16) + + # Run + status, time_ms = self.lib.run_gemm(A_gpu, B_gpu, C_gpu, M, N, K) + + # Calculate TFLOPS + flops = 2.0 * M * N * K + tflops = (flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0 + + return GemmResult( + output=C_gpu, + time_ms=time_ms, + status=status, + tflops=tflops, + kernel_name=self.lib.get_kernel_name(), + ) + + def benchmark( + self, M: int, N: int, K: int, warmup: int = 2, iterations: int = 10 + ) -> dict: + """Benchmark GEMM for given dimensions""" + A = np.random.randn(M, K).astype(np.float16) + B = np.random.randn(K, N).astype(np.float16) + + times = [] + + # Warmup + for _ in range(warmup): + self.run(A, B) + + # Benchmark + for _ in range(iterations): + result = self.run(A, B) + if result.success: + times.append(result.time_ms) + + if not times: + return {"error": "All iterations failed"} + + flops = 2.0 * M * N * K + avg_time = sum(times) / len(times) + + return { + "M": M, + "N": N, + "K": K, + "min_ms": min(times), + "avg_ms": avg_time, + "max_ms": max(times), + "tflops": (flops / (avg_time * 1e-3)) / 1e12, + "iterations": len(times), + } + + +# ============================================================================= +# Validation Utilities +# ============================================================================= + + +class Validator: + """Utilities for validating GEMM results""" + + def __init__(self, rtol: float = 1e-3, atol: float = 1e-2): + self.rtol = rtol + self.atol = atol + + def check( + self, result: np.ndarray, reference: np.ndarray + ) -> Tuple[bool, float, float]: + """ + Check if result matches reference + + Returns: (is_correct, max_diff, mean_diff) + """ + result = result.astype(np.float32) + reference = reference.astype(np.float32) + + diff = np.abs(result - reference) + max_diff = float(np.max(diff)) + mean_diff = float(np.mean(diff)) + + close = np.allclose(result, reference, rtol=self.rtol, atol=self.atol) + + return close, max_diff, mean_diff + + def compute_reference(self, A: np.ndarray, B: np.ndarray) -> np.ndarray: + """Compute reference GEMM result using NumPy""" + return np.matmul(A.astype(np.float32), B.astype(np.float32)) + + +# ============================================================================= +# Code Generation Utilities +# ============================================================================= + + +def get_codegen_path() -> Path: + """Get path to unified_gemm_codegen.py""" + return get_dispatcher_root() / "codegen" / "unified_gemm_codegen.py" + + +@dataclass +class CodegenResult: + """Result of kernel code generation""" + + success: bool + output_dir: Path + variant: str + stdout: str = "" + stderr: str = "" + kernel_count: int = 0 + elapsed_seconds: float = 0.0 + instance_names: List[str] = field(default_factory=list) + + def get_generated_kernels(self) -> List[Path]: + """Get list of generated kernel headers""" + if self.output_dir.exists(): + return list(self.output_dir.glob("*.hpp")) + return [] + + def print_instances(self, prefix: str = " "): + """Print all generated instance names.""" + for name in self.instance_names: + print(f"{prefix}{name}") + + +def _run_codegen_subprocess(args: Dict[str, Any]) -> CodegenResult: + """ + Worker function for parallel codegen execution. + + This is a module-level function to allow pickling for ProcessPoolExecutor. + """ + import sys + import subprocess + from pathlib import Path + + codegen_path = Path(args["codegen_path"]) + out_dir = Path(args["output_dir"]) + variant = args["variant"] + datatype = args["datatype"] + layout = args["layout"] + gpu_target = args["gpu_target"] + extra_args = args.get("extra_args", []) + timeout = args.get("timeout", 300) + + out_dir.mkdir(parents=True, exist_ok=True) + + start = time.time() + + # Get existing kernels before generation + existing_kernels = set(out_dir.glob("*.hpp")) if out_dir.exists() else set() + + cmd = [ + sys.executable, + str(codegen_path), + "--output-dir", + str(out_dir), + "--datatype", + datatype, + "--layout", + layout, + "--gpu-target", + gpu_target, + "--variants", + variant, + ] + + if extra_args: + cmd.extend(extra_args) + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + + # Get new kernels after generation + all_kernels = set(out_dir.glob("*.hpp")) + new_kernels = all_kernels - existing_kernels + kernel_count = len(all_kernels) + elapsed = time.time() - start + + # Build instance names list for verbose output + instance_names = sorted([k.stem for k in new_kernels]) + + return CodegenResult( + success=result.returncode == 0, + output_dir=out_dir, + variant=variant, + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + elapsed_seconds=elapsed, + instance_names=instance_names, + ) + except subprocess.TimeoutExpired: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=variant, + stderr=f"Code generation timed out ({timeout}s)", + elapsed_seconds=time.time() - start, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=variant, + stderr=str(e), + elapsed_seconds=time.time() - start, + ) + + +# ============================================================================= +# Preshuffle Utilities +# ============================================================================= + + +def preshuffle_weight_matrix( + B: np.ndarray, + warp_tile_n: int, + warp_tile_k: int, + arch: str = "gfx942", +) -> np.ndarray: + """ + Preshuffle the B (weight) matrix for optimized GEMM inference. + + This transforms the B matrix layout to match the expected memory access + pattern for preshuffle-enabled kernels. The transformation reorders data + so that warp-level loads are coalesced. + + Args: + B: Weight matrix of shape (K, N) in column-major / (K, N) layout + warp_tile_n: Warp tile size in N dimension (e.g., 32) + warp_tile_k: Warp tile size in K dimension (e.g., 16) + arch: Target GPU architecture (gfx9xx, gfx11xx, gfx12xx) + + Returns: + Shuffled B matrix with same data but reordered layout + + Example: + >>> B = np.random.randn(1024, 2048).astype(np.float16) + >>> B_shuffled = preshuffle_weight_matrix(B, warp_tile_n=32, warp_tile_k=16) + >>> # Use B_shuffled with preshuffle-enabled kernel + """ + K, N = B.shape + + # Validate dimensions are divisible by warp tiles + if N % warp_tile_n != 0: + raise ValueError(f"N ({N}) must be divisible by warp_tile_n ({warp_tile_n})") + if K % warp_tile_k != 0: + raise ValueError(f"K ({K}) must be divisible by warp_tile_k ({warp_tile_k})") + + # Architecture-specific shuffle patterns + # Based on ck_tile/host/tensor_shuffle_utils.hpp + if arch.startswith("gfx12"): + # GFX12 (RDNA4) pattern + divisor = 2 + k_abk1_per_lane = 8 + k_abk0_per_lane = warp_tile_k // divisor // k_abk1_per_lane + + if k_abk0_per_lane <= 0: + raise ValueError( + f"warp_tile_k ({warp_tile_k}) too small for GFX12 preshuffle" + ) + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, k0, div, k1) + B_view = B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + k_abk0_per_lane, + divisor, + k_abk1_per_lane, + ) + # Permute: {0, 2, 4, 1, 3, 5} + B_shuffled = np.transpose(B_view, (0, 2, 4, 1, 3, 5)) + + elif arch.startswith("gfx11"): + # GFX11 (RDNA3) pattern - divisor = 1 + divisor = 1 + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, div, warp_k/div) + B_view = B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + divisor, + warp_tile_k // divisor, + ) + # Permute: {0, 2, 3, 1, 4} + B_shuffled = np.transpose(B_view, (0, 2, 3, 1, 4)) + + else: + # GFX9 (CDNA) pattern - wave64 + divisor = 2 if warp_tile_n == 32 else 4 + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, div, warp_k/div) + B_view = B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + divisor, + warp_tile_k // divisor, + ) + # Permute: {0, 2, 3, 1, 4} + B_shuffled = np.transpose(B_view, (0, 2, 3, 1, 4)) + + # Return contiguous array with same dtype + return np.ascontiguousarray(B_shuffled.reshape(-1)).reshape(B.shape) + + +def is_preshuffle_supported(arch: str) -> bool: + """Check if preshuffle is supported for the given architecture.""" + # Preshuffle is supported on CDNA (gfx9xx) and RDNA (gfx11xx, gfx12xx) + return arch.startswith(("gfx9", "gfx11", "gfx12")) + + +@dataclass +class KernelConfig: + """ + Complete kernel configuration for GEMM. + + This defines all parameters needed to generate and run a specific kernel. + """ + + # Data types + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + + # Layouts (row/col) + layout_a: str = "row" + layout_b: str = "col" + layout_c: str = "row" + + # Tile shape (work per thread block) + tile_m: int = 128 + tile_n: int = 128 + tile_k: int = 32 + + # Wave shape (warps per block) + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + + # Warp tile (elements per warp) + warp_m: int = 32 + warp_n: int = 32 + warp_k: int = 16 + + # Block configuration + block_size: int = 256 + + # Pipeline configuration + pipeline: str = "compv4" + scheduler: str = "intrawave" + epilogue: str = "cshuffle" + + # Padding (enables arbitrary problem sizes) + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + + # GPU target + gfx_arch: str = "gfx942" + + # GEMM variant (affects arch filter validation) + # "standard", "preshuffle", or "multi_d" + variant: str = "standard" + + @property + def layout(self) -> str: + """Get layout string (e.g., 'rcr' for row-col-row)""" + mapping = {"row": "r", "col": "c"} + return mapping[self.layout_a] + mapping[self.layout_b] + mapping[self.layout_c] + + @property + def tile_str(self) -> str: + """Get tile size string""" + return f"{self.tile_m}x{self.tile_n}x{self.tile_k}" + + def print_config(self, indent: str = " "): + """Pretty print the configuration.""" + print(f"{indent}KernelConfig:") + print( + f"{indent} Data types: A={self.dtype_a}, B={self.dtype_b}, C={self.dtype_c}, Acc={self.dtype_acc}" + ) + print( + f"{indent} Layouts: A={self.layout_a}, B={self.layout_b}, C={self.layout_c} ({self.layout})" + ) + print(f"{indent} Tile: {self.tile_m}x{self.tile_n}x{self.tile_k}") + print(f"{indent} Waves: {self.wave_m}x{self.wave_n}x{self.wave_k}") + print(f"{indent} Warp tile: {self.warp_m}x{self.warp_n}x{self.warp_k}") + print(f"{indent} Block size: {self.block_size}") + print(f"{indent} Pipeline: {self.pipeline}/{self.scheduler}/{self.epilogue}") + print(f"{indent} Padding: M={self.pad_m}, N={self.pad_n}, K={self.pad_k}") + print(f"{indent} Target: {self.gfx_arch}") + + +class CodegenRunner: + """ + Runner for the unified GEMM code generator with parallel execution support. + + Usage: + codegen = CodegenRunner() + + # Generate standard kernels + result = codegen.generate("standard") + + # Generate preshuffle kernels + result = codegen.generate("preshuffle") + + # Generate multi-D kernels + result = codegen.generate("multi_d") + + # Generate all variants IN PARALLEL + results = codegen.generate_all_parallel() + + # Generate multiple configs IN PARALLEL + configs = [KernelConfig(...), KernelConfig(...)] + results = codegen.generate_configs_parallel(configs) + + # Generate with custom output directory + result = codegen.generate("standard", output_dir=Path("/custom/path")) + + # Generate from specific config + config = KernelConfig(tile_m=256, tile_n=256, tile_k=64) + result = codegen.generate_from_config(config) + """ + + VARIANTS = ["standard", "preshuffle", "multi_d"] + + def __init__( + self, + codegen_path: Optional[Path] = None, + output_dir: Optional[Path] = None, + datatype: str = "fp16", + layout: str = "rcr", + gpu_target: str = "gfx942", + max_workers: Optional[int] = None, + ): + self.codegen_path = codegen_path or get_codegen_path() + self.output_dir = output_dir or get_generated_kernels_dir() + self.datatype = datatype + self.layout = layout + self.gpu_target = gpu_target + # Default to CPU count, but cap at reasonable value + self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + + def _make_args( + self, + variant: str, + output_dir: Optional[Path] = None, + extra_args: Optional[List[str]] = None, + timeout: int = 300, + show_instances: bool = False, + ) -> Dict[str, Any]: + """Build args dict for parallel worker.""" + return { + "codegen_path": str(self.codegen_path), + "output_dir": str(output_dir or self.output_dir), + "variant": variant, + "datatype": self.datatype, + "layout": self.layout, + "gpu_target": self.gpu_target, + "extra_args": extra_args or [], + "timeout": timeout, + "show_instances": show_instances, + } + + def generate( + self, + variant: str = "standard", + output_dir: Optional[Path] = None, + extra_args: Optional[List[str]] = None, + show_instances: bool = False, + ) -> CodegenResult: + """ + Generate kernels for a specific variant (single-threaded). + + Args: + variant: One of "standard", "preshuffle", "multi_d" + output_dir: Override output directory + extra_args: Additional arguments to pass to codegen + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + CodegenResult with generation status and info + """ + args = self._make_args( + variant, output_dir, extra_args, show_instances=show_instances + ) + result = _run_codegen_subprocess(args) + + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + + return result + + def generate_all(self, output_dir: Optional[Path] = None) -> List[CodegenResult]: + """Generate all variants sequentially (use generate_all_parallel for speed).""" + results = [] + for variant in self.VARIANTS: + result = self.generate(variant, output_dir) + results.append(result) + return results + + def generate_all_parallel( + self, + output_dir: Optional[Path] = None, + variants: Optional[List[str]] = None, + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate all variants IN PARALLEL. + + Args: + output_dir: Override output directory + variants: List of variants to generate (default: all) + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult for each variant + """ + variants = variants or self.VARIANTS + start_total = time.time() + + if verbose: + print( + f"Generating {len(variants)} variants in parallel (workers={self.max_workers})..." + ) + + # Build args for each variant + args_list = [self._make_args(v, output_dir) for v in variants] + for args in args_list: + args["show_instances"] = show_instances + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_codegen_subprocess, args): args["variant"] + for args in args_list + } + + for future in as_completed(futures): + variant = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=output_dir or self.output_dir, + variant=variant, + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {variant}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_configs_parallel( + self, + configs: List["KernelConfig"], + output_dir: Optional[Path] = None, + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate kernels from multiple configs IN PARALLEL. + + Each config generates independently, allowing maximum parallelism. + + Args: + configs: List of KernelConfig objects + output_dir: Override output directory + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult for each config + """ + start_total = time.time() + out_dir = output_dir or self.output_dir + + if verbose: + print( + f"Generating {len(configs)} configs in parallel (workers={self.max_workers})..." + ) + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = {} + for config in configs: + args = { + "codegen_path": str(self.codegen_path), + "output_dir": str(out_dir), + "variant": "standard", + "datatype": config.dtype_a, + "layout": config.layout, + "gpu_target": config.gfx_arch, + "extra_args": [], + "timeout": 300, + "show_instances": show_instances, + } + future = executor.submit(_run_codegen_subprocess, args) + futures[future] = config.tile_str + + for future in as_completed(futures): + tile_str = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {tile_str}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {tile_str}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_batch_parallel( + self, + batch: List[Dict[str, Any]], + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate a batch of kernel specs IN PARALLEL. + + This is the most flexible parallel generation method. + + Args: + batch: List of dicts with keys: variant, datatype, layout, gpu_target, output_dir + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult + """ + start_total = time.time() + + if verbose: + print( + f"Generating {len(batch)} kernel specs in parallel (workers={self.max_workers})..." + ) + + # Build args for each spec + args_list = [] + for spec in batch: + args = { + "codegen_path": str(self.codegen_path), + "output_dir": str(spec.get("output_dir", self.output_dir)), + "variant": spec.get("variant", "standard"), + "datatype": spec.get("datatype", self.datatype), + "layout": spec.get("layout", self.layout), + "gpu_target": spec.get("gpu_target", self.gpu_target), + "extra_args": spec.get("extra_args", []), + "timeout": spec.get("timeout", 300), + "show_instances": show_instances, + } + args_list.append(args) + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_codegen_subprocess, args): args["variant"] + for args in args_list + } + + for future in as_completed(futures): + variant = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=self.output_dir, + variant=variant, + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {variant}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_from_config( + self, + config: KernelConfig, + output_dir: Optional[Path] = None, + force: bool = False, + show_instances: bool = False, + ) -> CodegenResult: + """ + Generate kernel from a specific KernelConfig. + + This generates ONLY the specific kernel header needed (not all kernels). + Note: This does NOT rebuild the library - use build_library_for_configs() + for that. + + Args: + config: KernelConfig with all kernel parameters + output_dir: Override output directory + force: Force regeneration even if kernel exists + show_instances: Print instance names when generating + + Returns: + CodegenResult with the specific kernel + """ + import sys + import json + import tempfile + + out_dir = output_dir or self.output_dir + out_dir.mkdir(parents=True, exist_ok=True) + + # Build kernel filename pattern for this config + # Note: padding flags may differ from config (arch filter may enable padding) + tile_str = config.tile_str # e.g., "128x128x32" + wave_str = f"{config.wave_m}x{config.wave_n}x{config.wave_k}" + warp_str = f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + + # Build pattern - use * for padding flags since arch filter may change them + precise_pattern = f"gemm_{config.dtype_a}_{config.layout}_{config.pipeline}_{config.epilogue}_{config.scheduler}_*_*_*_*_{tile_str}_{wave_str}_{warp_str}.hpp" + + # Check if exact kernel already exists + existing = list(out_dir.glob(precise_pattern)) + if existing and not force: + instance_names = sorted([k.stem for k in existing]) + if show_instances: + for name in instance_names: + print(f" Kernel exists: {name}") + + return CodegenResult( + success=True, + output_dir=out_dir, + variant=f"config:{tile_str}", + kernel_count=len(existing), + instance_names=instance_names, + stdout=f"Kernel exists, using: {existing[0].name}", + ) + + if not self.codegen_path.exists(): + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=f"Codegen not found at {self.codegen_path}", + ) + + start = time.time() + + # Create a temporary config file for single-kernel generation + # Format must match what unified_gemm_codegen.py expects + single_config = { + "tile_config": { + "tile_m": [config.tile_m], + "tile_n": [config.tile_n], + "tile_k": [config.tile_k], + "warp_m": [config.wave_m], + "warp_n": [config.wave_n], + "warp_k": [config.wave_k], + "warp_tile_m": [config.warp_m], + "warp_tile_n": [config.warp_n], + "warp_tile_k": [config.warp_k], + }, + "trait_config": { + "pipeline": [config.pipeline], + "epilogue": [config.epilogue], + "scheduler": [config.scheduler], + "pad_m": [config.pad_m], + "pad_n": [config.pad_n], + "pad_k": [config.pad_k], + "persistent": [False], + }, + } + + # Write temp config file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(single_config, f) + config_file = f.name + + try: + # Generate ONLY this specific kernel using config file + cmd = [ + sys.executable, + str(self.codegen_path), + "--output-dir", + str(out_dir), + "--datatype", + config.dtype_a, + "--layout", + config.layout, + "--gpu-target", + config.gfx_arch, + "--config", + config_file, + "--variants", + "standard", + ] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + + # Find the generated kernel + matching = list(out_dir.glob(precise_pattern)) + kernel_count = len(matching) + elapsed = time.time() - start + + instance_names = sorted([k.stem for k in matching]) + if show_instances and instance_names: + for name in instance_names: + print(f" Generated: {name}") + + return CodegenResult( + success=result.returncode == 0 and kernel_count > 0, + output_dir=out_dir, + variant=f"config:{tile_str}", + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + elapsed_seconds=elapsed, + instance_names=instance_names, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=str(e), + ) + finally: + # Clean up temp file + import os + + try: + os.unlink(config_file) + except Exception: + pass + + def _rebuild_library_for_config( + self, config: KernelConfig, kernel_header: Path + ) -> Optional[Path]: + """ + Rebuild the library with the specified kernel header using hipcc directly. + + This compiles a new library with exactly the kernel specified. + Builds to a UNIQUE filename to avoid conflicts with loaded libraries. + + Architecture Note - C++ vs Python Paths: + ----------------------------------------- + C++ Multi-Kernel Path: + - Each kernel is in its own namespace (ns_gemm_...) + - Multiple kernel headers can be included together + - Uses namespace-qualified types: ns_...:SelectedKernel + - Does NOT define CK_TILE_SINGLE_KERNEL_INCLUDE + - Registration code uses block-scoped type aliases + + Python Single-Kernel JIT Path (this function): + - Each library contains exactly ONE kernel + - Uses -DCK_TILE_SINGLE_KERNEL_INCLUDE to export types to global namespace + - gemm_ctypes_lib.cpp expects: SelectedKernel, KERNEL_NAME, ADataType, etc. + - Different configs get different library files (by dtype/layout) + - This enables Python to use any kernel config without pre-building all + + Returns: Path to new library, or None on failure + """ + build_dir = get_build_dir() + # Use unique filename based on dtype/layout to avoid overwriting loaded library + lib_name = f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_lib.so" + lib_path = build_dir / "examples" / lib_name + + print(f" Rebuilding library: {lib_name}") + print(f" With kernel: {kernel_header.name}") + + root = get_dispatcher_root() + ck_root = root.parent + + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + if not ctypes_source.exists(): + print(f" Source not found: {ctypes_source}") + return None + + # Link against the static dispatcher library (contains Registry, Dispatcher) + static_lib = build_dir / "libck_tile_dispatcher.a" + if not static_lib.exists(): + print(f" Static library not found: {static_lib}") + print(" Build with: cd build && cmake .. && make ck_tile_dispatcher") + return None + + # Compile source to object first, then link + obj_file = lib_path.with_suffix(".o") + + # Step 1: Compile source to object + # CK_TILE_SINGLE_KERNEL_INCLUDE enables global namespace exports in the kernel header + # This exports: SelectedKernel, KERNEL_NAME, ADataType, BDataType, CDataType, AccDataType + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", # Compile only + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", # Enable global namespace exports + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={config.gfx_arch}", + f'-DGFX_ARCH="{config.gfx_arch}"', # Pass arch as string for gemm_ctypes_lib.cpp + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + + try: + print(" Compiling source...") + result = subprocess.run( + compile_cmd, capture_output=True, text=True, timeout=300 + ) + if result.returncode != 0: + print(f" Compilation failed: {result.stderr[:300]}") + return None + + # Step 2: Link object with static library into shared library + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={config.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + print(" Linking...") + result = subprocess.run( + link_cmd, capture_output=True, text=True, timeout=300 + ) + if result.returncode == 0: + print(f" ✓ Library rebuilt: {lib_path.name}") + # Clean up object file + obj_file.unlink(missing_ok=True) + return lib_path + else: + print(f" Linking failed: {result.stderr[:300]}") + return None + except subprocess.TimeoutExpired: + print(" Build timed out") + return None + except Exception as e: + print(f" Build error: {e}") + return None + + def generate_preselected( + self, preset: str = "fp16_rcr_essential", output_dir: Optional[Path] = None + ) -> CodegenResult: + """ + Generate kernels from a preselected set. + + Args: + preset: Preselected kernel set name (e.g., "fp16_rcr_essential") + output_dir: Override output directory + + Returns: + CodegenResult + """ + import sys + + out_dir = output_dir or self.output_dir + out_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, + str(self.codegen_path), + "--output-dir", + str(out_dir), + "--preselected", + preset, + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + kernel_count = len(list(out_dir.glob("*.hpp"))) + + return CodegenResult( + success=result.returncode == 0, + output_dir=out_dir, + variant=f"preselected:{preset}", + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"preselected:{preset}", + stderr=str(e), + ) + + def ensure_kernels_exist(self) -> bool: + """ + Ensure kernel headers exist, generating if necessary. + + Returns: + True if kernels exist or were successfully generated + """ + if self.output_dir.exists(): + kernels = list(self.output_dir.glob("*.hpp")) + if kernels: + return True + + # Generate standard kernels + result = self.generate("standard") + return result.success + + def list_kernels(self) -> List[Path]: + """List all generated kernel headers""" + if self.output_dir.exists(): + return sorted(self.output_dir.glob("*.hpp")) + return [] + + def categorize_kernels(self) -> dict: + """ + Categorize kernels by tile size and variant. + + Returns: + Dict with categories by tile size and variant type + """ + kernels = self.list_kernels() + + # Separate by variant first + preshuffle = [k for k in kernels if "_preshuffle" in k.name] + multi_d = [k for k in kernels if "_multid_" in k.name] + standard = [ + k + for k in kernels + if "_preshuffle" not in k.name and "_multid_" not in k.name + ] + + # Categorize standard kernels by tile size + compute = [k for k in standard if "_256x" in k.name] + memory = [k for k in standard if "_128x" in k.name] + latency = [k for k in standard if "_64x" in k.name or "_32x" in k.name] + + return { + "total": len(kernels), + "standard": len(standard), + "compute": compute, + "memory": memory, + "latency": latency, + "preshuffle": preshuffle, + "multi_d": multi_d, + } + + +# ============================================================================= +# Registry and Dispatcher (Explicit API) +# ============================================================================= + + +class Registry: + """ + Kernel registry - stores and manages kernel instances. + + This provides an explicit registry API that mirrors the C++ Registry class. + + Usage: + registry = Registry() + registry.register_kernel(kernel_config) + dispatcher = Dispatcher(registry) + """ + + def __init__(self, lib: Optional[DispatcherLib] = None, name: str = "default"): + self._lib = lib + self._name = name + self._kernels: List[KernelConfig] = [] + + @property + def name(self) -> str: + return self._name + + @property + def kernel_count(self) -> int: + if self._lib: + return self._lib.get_kernel_count() + return len(self._kernels) + + def register_kernel(self, config: KernelConfig) -> bool: + """Register a kernel configuration.""" + self._kernels.append(config) + return True + + def get_kernels(self) -> List[KernelConfig]: + """Get all registered kernel configs.""" + return self._kernels.copy() + + def clear(self): + """Clear all kernels.""" + self._kernels.clear() + + def bind_library(self, lib: DispatcherLib): + """Bind to a loaded dispatcher library.""" + self._lib = lib + + def __repr__(self) -> str: + return f"Registry(name='{self._name}', kernels={self.kernel_count})" + + +class Dispatcher: + """ + Kernel dispatcher - selects and runs kernels for problems. + + This provides an explicit dispatcher API that mirrors the C++ Dispatcher class. + + Usage: + registry = Registry() + registry.register_kernel(config) + + dispatcher = Dispatcher(registry) + result = dispatcher.run(A, B, M, N, K) + """ + + def __init__(self, registry: Registry, lib: Optional[DispatcherLib] = None): + self._registry = registry + self._lib = lib or registry._lib + + @property + def registry(self) -> Registry: + return self._registry + + def select_kernel(self, M: int, N: int, K: int) -> Optional[str]: + """Select best kernel for problem dimensions.""" + if self._lib: + return self._lib.select_kernel(M, N, K) + # Fallback: return first matching kernel + for config in self._registry.get_kernels(): + return f"kernel_{config.tile_str}" + return None + + def is_supported(self, M: int, N: int, K: int) -> bool: + """Check if problem size is supported.""" + if self._lib: + return self._lib.is_supported(M, N, K) + return len(self._registry.get_kernels()) > 0 + + def run(self, A: np.ndarray, B: np.ndarray, M: int, N: int, K: int) -> GemmResult: + """ + Run GEMM: C = A @ B + + Args: + A: Input matrix (M x K) + B: Input matrix (K x N) + M, N, K: Problem dimensions + + Returns: + GemmResult with output and timing + """ + if self._lib is None: + raise RuntimeError("Dispatcher not bound to library") + + # Ensure contiguous float16 arrays + A_gpu = np.ascontiguousarray(A, dtype=np.float16) + B_gpu = np.ascontiguousarray(B.T, dtype=np.float16) # Column-major + C_gpu = np.zeros((M, N), dtype=np.float16) + + # Run via library + status, time_ms = self._lib.run_gemm(A_gpu, B_gpu, C_gpu, M, N, K) + + # Calculate TFLOPS + flops = 2.0 * M * N * K + tflops = (flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0 + + return GemmResult( + output=C_gpu, + time_ms=time_ms, + status=status, + tflops=tflops, + kernel_name=self._lib.get_kernel_name() if self._lib else "unknown", + ) + + def __repr__(self) -> str: + return f"Dispatcher(registry={self._registry.name}, kernels={self._registry.kernel_count})" + + +# ============================================================================= +# Main (self-test) +# ============================================================================= + +if __name__ == "__main__": + print("CK Tile Dispatcher Utils Self-Test") + print("=" * 60) + + # Test library loading + print("\n1. Loading library...") + lib = DispatcherLib.auto() + if lib is None: + print(" FAILED: Could not load library") + exit(1) + print(f" OK: Loaded from {lib.path}") + print(f" Kernel: {lib.get_kernel_name()}") + print(f" Registered kernels: {lib.get_kernel_count()}") + + # Test GEMM + print("\n2. Running GEMM 256x256x256...") + runner = GemmRunner(lib) + A = np.random.randn(256, 256).astype(np.float16) + B = np.random.randn(256, 256).astype(np.float16) + + result = runner.run(A, B) + print(f" Status: {'OK' if result.success else 'FAILED'}") + print(f" Time: {result.time_ms:.4f} ms") + print(f" TFLOPS: {result.tflops:.2f}") + + # Test validation + print("\n3. Validating result...") + validator = Validator() + reference = validator.compute_reference(A, B) + correct, max_diff, mean_diff = validator.check(result.output, reference) + print(f" Correct: {correct}") + print(f" Max diff: {max_diff:.6f}") + + print("\n" + "=" * 60) + print("All tests passed!") + + +# ============================================================================= +# High-Level Helper Functions +# ============================================================================= + + +@dataclass +class GemmSetupResult: + """Result of setup_gemm_dispatcher""" + + success: bool + dispatcher: Optional[Dispatcher] = None + lib: Optional[DispatcherLib] = None + registry: Optional[Registry] = None + codegen: Optional[CodegenRunner] = None + config: Optional[KernelConfig] = None + kernel_header: Optional[Path] = None + error: str = "" + corrections: List[str] = field(default_factory=list) + + +def setup_gemm_dispatcher( + config: KernelConfig, + registry_name: str = "gemm_registry", + verbose: bool = True, + auto_rebuild: bool = True, +) -> GemmSetupResult: + """ + High-level helper to setup a GEMM dispatcher from a kernel config. + + This handles: + 1. Validate config against arch filter (auto-correct if needed) + 2. Generate kernel code if needed + 3. Find matching kernel header + 4. Load or rebuild library (if dtype mismatch) + 5. Create registry and dispatcher + + Args: + config: KernelConfig with all parameters + registry_name: Name for the registry + verbose: Print progress messages + auto_rebuild: Rebuild library if dtype doesn't match + + Returns: + GemmSetupResult with dispatcher, lib, registry, etc. + """ + result = GemmSetupResult(success=False, config=config) + + def log(msg): + if verbose: + print(msg) + + # Step 1: Validate config + log(" Validating config...") + validation = validate_kernel_config(config) + if not validation.is_valid: + log(" ⚠ Auto-correcting configuration...") + config, was_modified, corrections = auto_correct_kernel_config( + config, verbose=verbose + ) + result.config = config + result.corrections = corrections + # Note: corrections will be displayed by the caller via print_auto_correction + + # Step 2: Setup codegen and generate kernel + log(f" Generating kernel (tile={config.tile_str})...") + codegen = CodegenRunner( + datatype=config.dtype_a, + layout=config.layout, + gpu_target=config.gfx_arch, + ) + result.codegen = codegen + + codegen_result = codegen.generate_from_config(config) + if not codegen_result.success: + log(" ⚠ Kernel generation: using existing") + + # Step 3: Find matching kernel header + kernel_header = find_matching_kernel_header(config) + result.kernel_header = kernel_header + if not kernel_header: + log(" ⚠ No matching kernel header found") + + # Step 4: Load library + log(" Loading library...") + lib = DispatcherLib.auto() + if lib is None: + result.error = "Could not load dispatcher library" + return result + result.lib = lib + + # Check if library kernel matches config - rebuild if ANY parameter differs + lib_kernel = lib.get_kernel_name() + needs_rebuild = False + mismatches = [] + + if lib_kernel: + # Build expected kernel signature components from config + expected_parts = { + "dtype": config.dtype_a, + "layout": config.layout, + "pipeline": config.pipeline, + "epilogue": config.epilogue, + "scheduler": config.scheduler, + "tile": f"{config.tile_m}x{config.tile_n}x{config.tile_k}", + "wave": f"{config.wave_m}x{config.wave_n}x{config.wave_k}", + "warp": f"{config.warp_m}x{config.warp_n}x{config.warp_k}", + } + + # Check each component against the library kernel name + for name, expected in expected_parts.items(): + if expected not in lib_kernel: + needs_rebuild = True + mismatches.append(f"{name}={expected}") + + if needs_rebuild and auto_rebuild: + log(f" Library kernel doesn't match config: {', '.join(mismatches)}") + log(" Rebuilding library for exact config match...") + + # First ensure we have a kernel header for this exact config + if not kernel_header: + # Generate kernel for the exact config + log(" Generating kernel for config...") + codegen_result = codegen.generate_from_config(config, force=True) + kernel_header = find_matching_kernel_header(config) + result.kernel_header = kernel_header + + if kernel_header: + new_lib_path = codegen._rebuild_library_for_config(config, kernel_header) + if new_lib_path: + lib = DispatcherLib.load(new_lib_path) + if lib is None or not lib.initialize(): + result.error = "Failed to load rebuilt library" + return result + result.lib = lib + log(f" ✓ Rebuilt library: {lib.get_kernel_name()}") + else: + log(" ⚠ Rebuild failed, using existing library") + else: + log(" ⚠ No kernel header found for config, using existing library") + + # Step 5: Create registry and dispatcher + log(" Creating registry and dispatcher...") + registry = Registry(name=registry_name, lib=lib) + registry.register_kernel(config) + result.registry = registry + + dispatcher = Dispatcher(registry=registry, lib=lib) + result.dispatcher = dispatcher + + log(f" ✓ Ready: {lib.get_kernel_name()}") + + result.success = True + return result + + +def cleanup_gemm(): + """ + Cleanup function to call after running GEMM examples. + + This helps ensure clean state between examples by: + 1. Clearing any global state + 2. Suggesting garbage collection + """ + import gc + + # Clear loaded libraries list + DispatcherLib._loaded_libs.clear() + + # Suggest garbage collection + gc.collect() + + +def cleanup_generated_kernels( + keep_default: bool = True, + verbose: bool = False, +) -> int: + """ + Clean up generated kernel files. + + Call this at the start of examples to ensure fresh state. + + Args: + keep_default: Keep the default fp16 kernel (True) or delete all (False) + verbose: Print what's being deleted + + Returns: + Number of files deleted + """ + + kernel_dir = get_generated_kernels_dir() + if not kernel_dir.exists(): + return 0 + + deleted = 0 + + # Default kernel pattern to keep + default_pattern = ( + "gemm_fp16_rcr_compv4_cshuffle_intrawave_*_128x128x32_2x2x1_16x16x16.hpp" + ) + + for f in kernel_dir.glob("*.hpp"): + # Skip dispatcher_wrappers directory + if f.is_dir(): + continue + + # Optionally keep default kernel + if keep_default and f.match(default_pattern): + continue + + if verbose: + print(f" Deleting: {f.name}") + f.unlink() + deleted += 1 + + # Also clean up any temp libs + build_dir = get_build_dir() + examples_dir = build_dir / "examples" + if examples_dir.exists(): + for f in examples_dir.glob("libdispatcher_gemm_*_lib.so"): + if f.name != "libdispatcher_gemm_lib.so": + if verbose: + print(f" Deleting: {f.name}") + f.unlink() + deleted += 1 + + return deleted + + +def reset_for_example(verbose: bool = False): + """ + Reset state for a fresh example run. + + Call this at the START of each example to ensure clean state. + Cleans up generated kernels (except default) and resets globals. + """ + # Cleanup any previously generated kernels + deleted = cleanup_generated_kernels(keep_default=True, verbose=verbose) + if verbose and deleted > 0: + print(f" Cleaned up {deleted} generated files") + + # Clear any cached state + cleanup_gemm() + + +# Main (self-test) +# ============================================================================= + +if __name__ == "__main__": + print("CK Tile Dispatcher Utils Self-Test") + print("=" * 60) + + # Test library loading + print("\n1. Loading library...") + lib = DispatcherLib.auto() + if lib is None: + print(" FAILED: Could not load library") + exit(1) + print(f" OK: Loaded from {lib.path}") + print(f" Kernel: {lib.get_kernel_name()}") + print(f" Registered kernels: {lib.get_kernel_count()}") + + # Test GEMM + print("\n2. Running GEMM 256x256x256...") + runner = GemmRunner(lib) + A = np.random.randn(256, 256).astype(np.float16) + B = np.random.randn(256, 256).astype(np.float16) + + result = runner.run(A, B) + print(f" Status: {'OK' if result.success else 'FAILED'}") + print(f" Time: {result.time_ms:.4f} ms") + print(f" TFLOPS: {result.tflops:.2f}") + + # Test validation + print("\n3. Validating result...") + validator = Validator() + reference = validator.compute_reference(A, B) + correct, max_diff, mean_diff = validator.check(result.output, reference) + print(f" Correct: {correct}") + print(f" Max diff: {max_diff:.6f}") + + # Test high-level helper + print("\n4. Testing setup_gemm_dispatcher...") + config = KernelConfig(tile_m=128, tile_n=128, tile_k=32) + setup = setup_gemm_dispatcher(config, verbose=True) + print(f" Success: {setup.success}") + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("All tests passed!") diff --git a/dispatcher/python/pytest.ini b/dispatcher/python/pytest.ini new file mode 100644 index 0000000000..08cd235fda --- /dev/null +++ b/dispatcher/python/pytest.ini @@ -0,0 +1,43 @@ +[pytest] +# Pytest configuration for CK Tile Dispatcher Python tests + +# Test discovery +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Test paths +testpaths = tests + +# Options +addopts = + -v + --strict-markers + --tb=short + --color=yes + --durations=10 + +# Markers +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + cuda: marks tests requiring CUDA/ROCm + torch: marks tests requiring PyTorch + integration: marks integration tests + unit: marks unit tests + +# Coverage +[coverage:run] +source = . +omit = + */tests/* + */examples/* + setup.py + +[coverage:report] +precision = 2 +show_missing = True +skip_covered = False + +[coverage:html] +directory = htmlcov + diff --git a/dispatcher/python/requirements.txt b/dispatcher/python/requirements.txt new file mode 100644 index 0000000000..9d429235f7 --- /dev/null +++ b/dispatcher/python/requirements.txt @@ -0,0 +1,22 @@ +# Core dependencies +numpy>=1.19.0 + +# Optional dependencies (install with pip install -e ".[torch]") +# torch>=2.0.0 + +# Development dependencies (install with pip install -e ".[dev]") +# pytest>=6.0.0 +# pytest-cov>=2.0.0 +# black>=21.0 +# flake8>=3.9.0 +# mypy>=0.910 +# isort>=5.0.0 + +# Visualization dependencies (install with pip install -e ".[viz]") +# matplotlib>=3.3.0 +# seaborn>=0.11.0 + +# Documentation dependencies +# sphinx>=4.0.0 +# sphinx-rtd-theme>=1.0.0 + diff --git a/dispatcher/scripts/compile_gemm_examples.py b/dispatcher/scripts/compile_gemm_examples.py new file mode 100644 index 0000000000..b19c18a13a --- /dev/null +++ b/dispatcher/scripts/compile_gemm_examples.py @@ -0,0 +1,2253 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Cross-platform build script for declarative kernel workflow. + +Uses existing ctypes_utils.py for path management and codegen. + +Usage: + python3 compile_gemm_examples.py [output_name] + +Example: + python3 compile_gemm_examples.py examples/cpp/01_basic_gemm.cpp my_app +""" + +import argparse +import os +import re +import subprocess +import sys +from pathlib import Path +import shutil + +# Add dispatcher/python to path to reuse existing utilities +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +# Import existing utilities (after sys.path modification) +from ctypes_utils import ( # noqa: E402 + get_dispatcher_root, + get_ck_root, + get_build_dir, + get_generated_kernels_dir, + CodegenRunner, +) + + +# ============================================================================= +# Terminal Colors (cross-platform) +# ============================================================================= + + +class Colors: + if sys.platform != "win32" and sys.stdout.isatty(): + GREEN = "\033[0;32m" + YELLOW = "\033[1;33m" + RED = "\033[0;31m" + NC = "\033[0m" + else: + GREEN = YELLOW = RED = NC = "" + + +def print_phase(msg: str): + print(f"{Colors.YELLOW}{msg}{Colors.NC}") + + +def print_success(msg: str): + print(f"{Colors.GREEN}{msg}{Colors.NC}") + + +def print_error(msg: str): + print(f"{Colors.RED}{msg}{Colors.NC}", file=sys.stderr) + + +# ============================================================================= +# Compiler Detection +# ============================================================================= + + +def find_hipcc() -> str: + """Find hipcc compiler.""" + candidates = [ + os.environ.get("HIPCC"), + "/opt/rocm/bin/hipcc", + "/opt/rocm/hip/bin/hipcc", + shutil.which("hipcc"), + ] + + for path in candidates: + if path and os.path.isfile(path): + return path + + raise RuntimeError( + "hipcc not found. Please install ROCm or set HIPCC environment variable." + ) + + +# ============================================================================= +# Declaration Extraction +# ============================================================================= + + +def extract_conv_kernel_declarations(source_file: Path) -> list: + """Extract CONVOLUTION kernel declarations from C++ source file. + + Supports DECL_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. + Extracts all parameters: dtype, layout, conv_type, dims, tile, wave, warp, pipeline, scheduler. + """ + content = source_file.read_text() + declarations = [] + seen = set() + + # Pattern: DECL_CONV_KERNEL_SET(name, .add(...).add(...)) + set_pattern = r"DECL_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" + + for match in re.finditer(set_pattern, content, re.DOTALL): + set_name = match.group(1) + set_body = match.group(2) + + # Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c) + simple_add = ( + r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)' + ) + for add_match in re.finditer(simple_add, set_body): + dtype = add_match.group(1) + layout = add_match.group(2) + conv_type = add_match.group(3) + tile_k = int(add_match.group(4)) + tile_c = int(add_match.group(5)) + + name = f"{set_name}:{dtype}_{layout}_{conv_type}_{tile_k}x{tile_c}" + if name not in seen: + seen.add(name) + declarations.append( + { + "type": "conv", + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "num_dims": 2, + "groups": 1, + "tile_n": 1, + "tile_k": tile_k, + "tile_c": tile_c, + "wave_m": -1, # Wildcard - will expand + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "set": set_name, + "arch": "gfx942", + } + ) + + # Pattern 2: Full specification with ConvSig() and ConvAlgo() + # Match .add( ConvSig()..., ConvAlgo()..., "arch" ) + # Use robust parsing that handles multi-line and comments + + # Find all .add( blocks containing ConvSig + add_blocks = re.findall( + r"\.add\s*\(\s*ConvSig\(\)([\s\S]*?)(?=\.add\s*\(|$)", set_body + ) + + for add_block in add_blocks: + # Find ConvAlgo and arch in this block + algo_match = re.search(r'ConvAlgo\(\)([\s\S]*?),\s*"(\w+)"\s*\)', add_block) + if not algo_match: + continue + + sig_str = add_block[: add_block.find("ConvAlgo()")] + algo_str = algo_match.group(1) + arch = algo_match.group(2) + + # Parse ConvSig + dtype = "fp16" + dtype_match = re.search(r'\.dtype\s*\(\s*"([^"]+)"', sig_str) + if dtype_match: + dtype = dtype_match.group(1) + + layout = "nhwgc" + layout_match = re.search(r'\.layout\s*\(\s*"([^"]+)"', sig_str) + if layout_match: + layout = layout_match.group(1) + + conv_type = "forward" + conv_type_match = re.search(r'\.conv_type\s*\(\s*"([^"]+)"', sig_str) + if conv_type_match: + conv_type = conv_type_match.group(1) + + num_dims = 2 + dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str) + if dims_match: + num_dims = int(dims_match.group(1)) + + groups = 1 + groups_match = re.search(r"\.groups\s*\(\s*(\d+)", sig_str) + if groups_match: + groups = int(groups_match.group(1)) + + # Parse ConvAlgo + tile_n, tile_k, tile_c = 1, 128, 128 + tile_match = re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", algo_str + ) + if tile_match: + tile_n = int(tile_match.group(1)) + tile_k = int(tile_match.group(2)) + tile_c = int(tile_match.group(3)) + + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if wave_match: + wave_m = int(wave_match.group(1)) + wave_n = int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if warp_match: + warp_m = int(warp_match.group(1)) + warp_n = int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + pipeline = "compv3" + pipeline_match = re.search(r'\.pipeline\s*\(\s*"([^"]+)"', algo_str) + if pipeline_match: + pipeline = pipeline_match.group(1) + + scheduler = "intrawave" + scheduler_match = re.search(r'\.scheduler\s*\(\s*"([^"]+)"', algo_str) + if scheduler_match: + scheduler = scheduler_match.group(1) + + epilogue = "cshuffle" + epilogue_match = re.search(r'\.epilogue\s*\(\s*"([^"]+)"', algo_str) + if epilogue_match: + epilogue = epilogue_match.group(1) + + # Build unique name with full config + name = f"{set_name}:{dtype}_{conv_type}_{num_dims}d_{pipeline}_{scheduler}_{tile_k}x{tile_c}_{wave_m}x{wave_n}x{wave_k}" + if name not in seen: + seen.add(name) + declarations.append( + { + "type": "conv", + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "num_dims": num_dims, + "groups": groups, + "tile_n": tile_n, + "tile_k": tile_k, + "tile_c": tile_c, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "pipeline": pipeline, + "scheduler": scheduler, + "epilogue": epilogue, + "name": name, + "set": set_name, + "arch": arch, + } + ) + + return declarations + + +def expand_conv_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -> list: + """Expand a convolution declaration to all valid combinations. + + Like GEMM, convolution supports wildcard expansion for: + - wave/warp: If -1, generates all valid combinations + - pipeline/scheduler: If "*", generates all valid trait combinations + """ + # Import arch filter + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + except ImportError: + # Fallback + WARP_SUPPORTED_COMBINATIONS = { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + } + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + } + TRAIT_UNSUPPORTED_COMBINATIONS = set() + + d = decl.copy() + tile_k = d.get("tile_k", 128) + tile_c = d.get("tile_c", 128) + dtype = d.get("dtype", "fp16") + + # Check what needs expansion + needs_wave_expansion = d.get("wave_m", -1) < 0 or d.get("wave_n", -1) < 0 + needs_warp_expansion = d.get("warp_m", -1) < 0 or d.get("warp_n", -1) < 0 + needs_pipeline_expansion = d.get("pipeline", "compv4") == "*" + needs_scheduler_expansion = d.get("scheduler", "intrawave") == "*" + + if ( + not needs_wave_expansion + and not needs_warp_expansion + and not needs_pipeline_expansion + and not needs_scheduler_expansion + ): + return [d] + + # Build valid combinations + if needs_wave_expansion or needs_warp_expansion: + wave_configs = WARP_SUPPORTED_COMBINATIONS.get(arch, [[2, 2, 1]]) + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_configs = WARP_TILE_SUPPORTED_COMBINATIONS.get(arch, {}).get( + dtype_key, [[32, 32, 16], [16, 16, 16]] + ) + else: + wave_configs = [[d.get("wave_m", 2), d.get("wave_n", 2), d.get("wave_k", 1)]] + warp_tile_configs = [ + [d.get("warp_m", 32), d.get("warp_n", 32), d.get("warp_k", 16)] + ] + + # Pipeline/scheduler combinations + ALL_PIPELINES = ["compv3", "compv4"] + ALL_SCHEDULERS = ["intrawave", "interwave"] + + pipelines = ( + ALL_PIPELINES if needs_pipeline_expansion else [d.get("pipeline", "compv4")] + ) + schedulers = ( + ALL_SCHEDULERS + if needs_scheduler_expansion + else [d.get("scheduler", "intrawave")] + ) + + expanded = [] + + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_tile_configs: + # Check divisibility for conv (M=output spatial, N=K channels, K=C channels) + # Simplified check for now + if tile_k % (wn * wtn) != 0: + continue + if tile_c % (wk * wtk) != 0: + continue + + for pipeline in pipelines: + for scheduler in schedulers: + # Check trait combination + if ( + pipeline, + "cshuffle", + scheduler, + ) in TRAIT_UNSUPPORTED_COMBINATIONS: + continue + + expanded_d = d.copy() + expanded_d["wave_m"] = wm + expanded_d["wave_n"] = wn + expanded_d["wave_k"] = wk + expanded_d["warp_m"] = wtm + expanded_d["warp_n"] = wtn + expanded_d["warp_k"] = wtk + expanded_d["pipeline"] = pipeline + expanded_d["scheduler"] = scheduler + + expanded_d["name"] = ( + f"conv_{d['conv_type']}_{dtype}_{d['num_dims']}d_{pipeline}_" + f"{scheduler}_{tile_k}x{tile_c}_{wm}x{wn}x{wk}" + ) + expanded.append(expanded_d) + + if not expanded: + # Fallback to defaults + d["wave_m"] = 2 + d["wave_n"] = 2 + d["wave_k"] = 1 + d["warp_m"] = 32 + d["warp_n"] = 32 + d["warp_k"] = 16 + d["pipeline"] = "compv4" + d["scheduler"] = "intrawave" + return [d] + + return expanded + + +def generate_conv_kernels(declarations: list, gpu_target: str = "gfx942") -> int: + """Generate convolution kernels using unified_conv_codegen.""" + kernel_dir = get_generated_kernels_dir() + kernel_dir.mkdir(parents=True, exist_ok=True) + + # Import conv codegen + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from unified_conv_codegen import ( + UnifiedConvCodegen, + ConvKernelConfig, + ConvVariant, + TileConfig, + TraitConfig, + ) + except ImportError as e: + print_error(f" Failed to import conv codegen: {e}") + return 0 + + codegen = UnifiedConvCodegen(kernel_dir) + total_generated = 0 + + # Group by dtype and variant for efficient generation + groups = {} + for decl in declarations: + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + key = (dtype, conv_type, num_dims) + if key not in groups: + groups[key] = [] + groups[key].append(decl) + + for (dtype, conv_type, num_dims), decls in groups.items(): + print(f" Generating {dtype} {conv_type} {num_dims}D kernels...") + + # Map to ConvVariant + variant = ConvVariant.FORWARD + if conv_type == "bwd_data": + variant = ConvVariant.BACKWARD_DATA + elif conv_type == "bwd_weight": + variant = ConvVariant.BACKWARD_WEIGHT + + for decl in decls: + pipeline = decl.get("pipeline", "compv3") + scheduler = decl.get("scheduler", "intrawave") + epilogue = decl.get("epilogue", "cshuffle") + + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + # Adjust tile_k for compv4 + adj_tile_k = 64 * 2 if pipeline == "compv4" else 64 + + # Create TileConfig + tile_config = TileConfig( + tile_m=tile_k, # K is M in conv GEMM view + tile_n=tile_c, # C is N in conv GEMM view + tile_k=adj_tile_k, + warp_m=wave_m, + warp_n=wave_n, + warp_k=1, + warp_tile_m=warp_m, + warp_tile_n=warp_n, + warp_tile_k=warp_k, + ) + + # Create TraitConfig + trait_config = TraitConfig( + pipeline=pipeline, + scheduler=scheduler, + epilogue=epilogue, + double_smem_buffer=(pipeline == "compv4"), + pad_m=True, + pad_n=True, + pad_k=True, + ) + + # Create ConvKernelConfig + config = ConvKernelConfig( + tile=tile_config, + trait=trait_config, + variant=variant, + ndim_spatial=num_dims, + arch=gpu_target, + ) + + try: + filepath = codegen.generate_kernel(config, dtype) + total_generated += 1 + print(f" Generated: {filepath.name}") + except Exception as e: + print_error(f" Failed to generate {decl['name']}: {e}") + + return total_generated + + +# Original GEMM extraction continues here +def extract_kernel_declarations(source_file: Path) -> list: + """Extract GEMM kernel declarations from C++ source file.""" + content = source_file.read_text() + declarations = [] + seen = set() + + # ------------------------------------------------------------------------- + # Pattern 1: Simple DECL_KERNEL_SIMPLE(dtype, layout, tile_m, tile_n, tile_k) + # ------------------------------------------------------------------------- + legacy_pattern = r"DECL_KERNEL_SIMPLE\s*\(\s*(\w+)\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)" + for match in re.findall(legacy_pattern, content): + dtype, layout, tm, tn, tk = match + name = f"{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": False, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 2: Fluent API: DECL_KERNEL(Signature()..., Algorithm()..., arch) + # ------------------------------------------------------------------------- + # Match DECL_KERNEL( ... ); blocks + fluent_pattern = r'DECL_KERNEL\s*\(\s*(Signature\(\)[^,]+),\s*(Algorithm\(\)[^,]+)(?:,\s*"([^"]+)")?\s*\)' + + for match in re.finditer(fluent_pattern, content, re.DOTALL): + sig_str = match.group(1) + algo_str = match.group(2) + arch = match.group(3) or "gfx942" + + # Parse Signature + sig = {"dtype_a": "fp16", "dtype_b": "fp16", "dtype_c": "fp16", "layout": "rcr"} + + # .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16") + dtype_match = re.search( + r'\.dtype\("([^"]+)"(?:,\s*"([^"]+)")?(?:,\s*"([^"]+)")?', sig_str + ) + if dtype_match: + sig["dtype_a"] = dtype_match.group(1) + sig["dtype_b"] = dtype_match.group(2) or dtype_match.group(1) + sig["dtype_c"] = dtype_match.group(3) or dtype_match.group(1) + + # .layout("rcr") or .layout("row", "col", "row") + layout_match = re.search( + r'\.layout\("([^"]+)"(?:,\s*"([^"]+)")?(?:,\s*"([^"]+)")?', sig_str + ) + if layout_match: + if layout_match.group(2): # Three-arg form + la = layout_match.group(1) + lb = layout_match.group(2) + lc = layout_match.group(3) or "row" + sig["layout"] = ( + ("r" if la == "row" else "c") + + ("r" if lb == "row" else "c") + + ("r" if lc == "row" else "c") + ) + else: # Single arg "rcr" + sig["layout"] = layout_match.group(1) + + # Parse Algorithm + algo = {} + + # .tile(128, 128, 32) + tile_match = re.search(r"\.tile\((\d+),\s*(\d+),\s*(\d+)\)", algo_str) + if tile_match: + algo["tile_m"] = int(tile_match.group(1)) + algo["tile_n"] = int(tile_match.group(2)) + algo["tile_k"] = int(tile_match.group(3)) + + # .wave(2, 2, 1) + wave_match = re.search(r"\.wave\((\d+),\s*(\d+)(?:,\s*(\d+))?\)", algo_str) + if wave_match: + algo["wave_m"] = int(wave_match.group(1)) + algo["wave_n"] = int(wave_match.group(2)) + algo["wave_k"] = int(wave_match.group(3) or 1) + + # .warp(32, 32, 16) + warp_match = re.search(r"\.warp\((\d+),\s*(\d+)(?:,\s*(\d+))?\)", algo_str) + if warp_match: + algo["warp_m"] = int(warp_match.group(1)) + algo["warp_n"] = int(warp_match.group(2)) + algo["warp_k"] = int(warp_match.group(3) or 16) + + # .pipeline("compv4"), .scheduler("intrawave"), .epilogue("cshuffle") + for field in ["pipeline", "scheduler", "epilogue"]: + fmatch = re.search(rf'\.{field}\("([^"]+)"\)', algo_str) + if fmatch: + algo[field] = fmatch.group(1) + + # Build declaration + tm = algo.get("tile_m", 128) + tn = algo.get("tile_n", 128) + tk = algo.get("tile_k", 32) + + name = f"{sig['dtype_a']}_{sig['layout']}_{tm}x{tn}x{tk}" + + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": sig["dtype_a"], + "dtype_b": sig["dtype_b"], + "dtype_c": sig["dtype_c"], + "layout": sig["layout"], + "tile_m": tm, + "tile_n": tn, + "tile_k": tk, + "wave_m": algo.get("wave_m", -1), + "wave_n": algo.get("wave_n", -1), + "wave_k": algo.get("wave_k", 1), + "warp_m": algo.get("warp_m", -1), + "warp_n": algo.get("warp_n", -1), + "warp_k": algo.get("warp_k", 16), + "pipeline": algo.get("pipeline", "compv4"), + "scheduler": algo.get("scheduler", "intrawave"), + "epilogue": algo.get("epilogue", "cshuffle"), + "arch": arch, + "name": name, + "wildcard": False, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 3: DECL_KERNEL_ALL(dtype, layout) - wildcard + # ------------------------------------------------------------------------- + all_pattern = r"DECL_KERNEL(?:S)?_ALL\s*\(\s*(\w+)\s*,\s*(\w+)\s*\)" + for match in re.findall(all_pattern, content): + dtype, layout = match + name = f"wildcard_{dtype}_{layout}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": -1, + "tile_n": -1, + "tile_k": -1, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": True, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 4: DECL_KERNEL_SIMPLE(dtype, layout, tm, tn, tk) + # ------------------------------------------------------------------------- + simple_pattern = r"DECL_KERNEL_SIMPLE\s*\(\s*(\w+)\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)" + for match in re.findall(simple_pattern, content): + dtype, layout, tm, tn, tk = match + name = f"{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": False, + "set": None, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 5: DECL_KERNEL_SET(name, .add(...).add(...)) + # Named kernel sets for multiple registries + # Match only DECL_KERNEL_SET at start of line (not in comments) + # ------------------------------------------------------------------------- + set_pattern = r"^DECL_KERNEL_SET\s*\(\s*(\w+)\s*,([\s\S]*?)\)\s*;" + for match in re.finditer(set_pattern, content, re.MULTILINE): + set_name = match.group(1) + set_body = match.group(2) + + # Parse .add("dtype", "layout", tm, tn, tk) calls - simple form + add_simple = r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)' + for add_match in re.findall(add_simple, set_body): + dtype, layout, tm, tn, tk = add_match + name = f"{set_name}:{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": False, + "set": set_name, + } + ) + + # Parse .add(Signature()..., Algorithm()..., "arch") fluent calls + # Robust approach: find each .add( block and parse methods individually + # This handles any method order and optional methods + + # Split set_body into .add() blocks + add_blocks = [] + add_starts = [m.start() for m in re.finditer(r"\.add\s*\(", set_body)] + + for i, start in enumerate(add_starts): + # Find the matching closing paren by counting parens + depth = 0 + end = start + in_string = False + escape_next = False + + for j, ch in enumerate(set_body[start:], start): + if escape_next: + escape_next = False + continue + if ch == "\\": + escape_next = True + continue + if ch == '"' and not escape_next: + in_string = not in_string + continue + if in_string: + continue + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + end = j + 1 + break + + if end > start: + add_blocks.append(set_body[start:end]) + + for add_block in add_blocks: + # Skip if doesn't have both Signature() and Algorithm() + if "Signature()" not in add_block or "Algorithm()" not in add_block: + continue + + # Split on Algorithm() to separate Signature and Algorithm parts + algo_idx = add_block.find("Algorithm()") + if algo_idx == -1: + continue + + sig_str = add_block[:algo_idx] + algo_str = add_block[algo_idx:] # Include Algorithm() and everything after + + # Parse dtype from Signature - handles .dtype("fp16", "fp16", "fp16", "fp32") + dtype = "fp16" + dtype_m = re.search(r'\.dtype\s*\(\s*"([^"]+)"', sig_str) + if dtype_m: + dtype = dtype_m.group(1) + + # Parse layout from Signature - handles .layout("row", "col", "row") + layout = "rcr" + layout_m = re.search( + r'\.layout\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"', sig_str + ) + if layout_m: + la, lb, lc = layout_m.group(1), layout_m.group(2), layout_m.group(3) + layout = ( + ("r" if la == "row" else "c") + + ("r" if lb == "row" else "c") + + ("r" if lc == "row" else "c") + ) + else: + # Single arg form: .layout("rcr") + layout_m = re.search(r'\.layout\s*\(\s*"([^"]+)"', sig_str) + if layout_m: + layout = layout_m.group(1) + + # Parse tile from Algorithm + tm, tn, tk = 128, 128, 32 + tile_m = re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", algo_str + ) + if tile_m: + tm, tn, tk = ( + int(tile_m.group(1)), + int(tile_m.group(2)), + int(tile_m.group(3)), + ) + + # Parse wave + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if wave_match: + wave_m, wave_n = int(wave_match.group(1)), int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + # Parse warp + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if warp_match: + warp_m, warp_n = int(warp_match.group(1)), int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + # Parse pipeline - NEW: extract from declaration + pipeline = "compv4" + pipeline_m = re.search(r'\.pipeline\s*\(\s*"([^"]+)"', algo_str) + if pipeline_m: + pipeline = pipeline_m.group(1) + + # Parse scheduler - NEW: extract from declaration + scheduler = "intrawave" + scheduler_m = re.search(r'\.scheduler\s*\(\s*"([^"]+)"', algo_str) + if scheduler_m: + scheduler = scheduler_m.group(1) + + # Parse epilogue - NEW: extract from declaration + epilogue = "cshuffle" + epilogue_m = re.search(r'\.epilogue\s*\(\s*"([^"]+)"', algo_str) + if epilogue_m: + epilogue = epilogue_m.group(1) + + # Parse padding - NEW: extract from declaration + pad_m, pad_n, pad_k = False, False, False + pad_match = re.search( + r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)\s*\)", + algo_str, + re.IGNORECASE, + ) + if pad_match: + pad_m = pad_match.group(1).lower() == "true" + pad_n = pad_match.group(2).lower() == "true" + pad_k = pad_match.group(3).lower() == "true" + + # Parse elementwise from Signature - for Multi-D kernels + elementwise_op = "PassThrough" + num_d_tensors = 0 + elem_match = re.search( + r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)\s*\)', + sig_str, + ) + if elem_match: + elementwise_op = elem_match.group(1) + num_d_tensors = int(elem_match.group(2)) + + name = f"{set_name}:{dtype}_{layout}_{pipeline}_{scheduler}_{tm}x{tn}x{tk}_{wave_m}x{wave_n}x{wave_k}" + if elementwise_op != "PassThrough": + name += f"_{elementwise_op}_d{num_d_tensors}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": tm, + "tile_n": tn, + "tile_k": tk, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "pipeline": pipeline, + "scheduler": scheduler, + "epilogue": epilogue, + "pad_m": pad_m, + "pad_n": pad_n, + "pad_k": pad_k, + "elementwise_op": elementwise_op, + "num_d_tensors": num_d_tensors, + "name": name, + "wildcard": False, + "set": set_name, + } + ) + + return declarations + + +def expand_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -> list: + """Expand a declaration to all valid combinations using arch filter. + + Expands wildcards for: + - wave/warp: If -1, generates all valid wave/warp_tile combinations + - pipeline/scheduler/epilogue: If "*", generates all valid trait combinations + + Uses the arch_filter module for architecture-specific validation. + """ + # Import arch filter + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + except ImportError: + # Fallback to hardcoded valid combinations + WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + } + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + } + TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + } + + d = decl.copy() + tm = d.get("tile_m", 128) + tn = d.get("tile_n", 128) + tk = d.get("tile_k", 32) + dtype = d.get("dtype_a", "fp16") + + # Check what needs expansion + needs_wave_expansion = d.get("wave_m", -1) < 0 or d.get("wave_n", -1) < 0 + needs_warp_expansion = d.get("warp_m", -1) < 0 or d.get("warp_n", -1) < 0 + needs_pipeline_expansion = d.get("pipeline", "compv4") == "*" + needs_scheduler_expansion = d.get("scheduler", "intrawave") == "*" + needs_epilogue_expansion = d.get("epilogue", "cshuffle") == "*" + needs_pad_m_expansion = d.get("pad_m", 1) == -1 + needs_pad_n_expansion = d.get("pad_n", 1) == -1 + needs_pad_k_expansion = d.get("pad_k", 1) == -1 + needs_trait_expansion = ( + needs_pipeline_expansion + or needs_scheduler_expansion + or needs_epilogue_expansion + ) + needs_pad_expansion = ( + needs_pad_m_expansion or needs_pad_n_expansion or needs_pad_k_expansion + ) + + if ( + not needs_wave_expansion + and not needs_warp_expansion + and not needs_trait_expansion + and not needs_pad_expansion + ): + # Already fully specified + return [d] + + # === Build valid combinations === + + # Wave configurations + if needs_wave_expansion: + wave_configs = WARP_SUPPORTED_COMBINATIONS.get(arch, [[2, 2, 1]]) + else: + wave_configs = [[d.get("wave_m", 2), d.get("wave_n", 2), d.get("wave_k", 1)]] + + # Warp tile configurations + if needs_warp_expansion: + arch_warp_tiles = WARP_TILE_SUPPORTED_COMBINATIONS.get(arch, {}) + + # Try to find warp tile configs for this dtype + # Keys are like: fp16_fp16_fp32, int8_int8_int32, etc. + warp_tile_configs = None + dtype_key_variants = [ + f"{dtype}_{dtype}_{dtype}", # e.g., fp32_fp32_fp32 + f"{dtype}_{dtype}_fp32", # e.g., fp16_fp16_fp32 + f"{dtype}_{dtype}_int32", # e.g., int8_int8_int32 + ] + for dtype_key in dtype_key_variants: + warp_tile_configs = arch_warp_tiles.get(dtype_key, None) + if warp_tile_configs is not None: + break + + # If dtype is not supported on this arch, return empty list + if warp_tile_configs is None: + return [] + else: + warp_tile_configs = [ + [d.get("warp_m", 32), d.get("warp_n", 32), d.get("warp_k", 16)] + ] + + # Pipeline/scheduler/epilogue combinations + # Valid options per category + ALL_PIPELINES = ["compv3", "compv4"] # Most common; add more if needed + ALL_SCHEDULERS = ["intrawave", "interwave"] + ALL_EPILOGUES = ["cshuffle", "default"] + ALL_PAD_OPTIONS = [False, True] # 0 and 1 + + pipelines = ( + ALL_PIPELINES if needs_pipeline_expansion else [d.get("pipeline", "compv4")] + ) + schedulers = ( + ALL_SCHEDULERS + if needs_scheduler_expansion + else [d.get("scheduler", "intrawave")] + ) + epilogues = ( + ALL_EPILOGUES if needs_epilogue_expansion else [d.get("epilogue", "cshuffle")] + ) + pad_m_opts = ALL_PAD_OPTIONS if needs_pad_m_expansion else [bool(d.get("pad_m", 1))] + pad_n_opts = ALL_PAD_OPTIONS if needs_pad_n_expansion else [bool(d.get("pad_n", 1))] + pad_k_opts = ALL_PAD_OPTIONS if needs_pad_k_expansion else [bool(d.get("pad_k", 1))] + + expanded = [] + + # Generate all valid combinations + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_tile_configs: + # Check divisibility constraints + if tm % (wm * wtm) != 0: + continue + if tn % (wn * wtn) != 0: + continue + if tk % (wk * wtk) != 0: + continue + + for pipeline in pipelines: + for scheduler in schedulers: + for epilogue in epilogues: + # Check trait combination is valid + if ( + pipeline, + epilogue, + scheduler, + ) in TRAIT_UNSUPPORTED_COMBINATIONS: + continue + + for pad_m in pad_m_opts: + for pad_n in pad_n_opts: + for pad_k in pad_k_opts: + # Create expanded declaration + expanded_d = d.copy() + expanded_d["wave_m"] = wm + expanded_d["wave_n"] = wn + expanded_d["wave_k"] = wk + expanded_d["warp_m"] = wtm + expanded_d["warp_n"] = wtn + expanded_d["warp_k"] = wtk + expanded_d["pipeline"] = pipeline + expanded_d["scheduler"] = scheduler + expanded_d["epilogue"] = epilogue + expanded_d["pad_m"] = int(pad_m) + expanded_d["pad_n"] = int(pad_n) + expanded_d["pad_k"] = int(pad_k) + + pad_str = f"{'T' if pad_m else 'F'}{'T' if pad_n else 'F'}{'T' if pad_k else 'F'}" + expanded_d["name"] = ( + f"{dtype}_{d.get('layout', 'rcr')}_{pipeline}_{scheduler}_" + f"pad{pad_str}_{tm}x{tn}x{tk}_{wm}x{wn}x{wk}" + ) + expanded_d["wildcard"] = False + expanded.append(expanded_d) + + if not expanded: + # No valid combinations found, return single default + print(f" Warning: No valid combinations for {tm}x{tn}x{tk} on {arch}") + d["wave_m"] = 2 + d["wave_n"] = 2 + d["wave_k"] = 1 + d["warp_m"] = 32 + d["warp_n"] = 32 + d["warp_k"] = 16 + d["pipeline"] = "compv4" + d["scheduler"] = "intrawave" + d["epilogue"] = "cshuffle" + return [d] + + return expanded + + +def auto_fill_declaration(decl: dict) -> dict: + """Auto-fill with single default (for backward compat).""" + expanded = expand_declaration_with_arch_filter(decl, decl.get("arch", "gfx942")) + return expanded[0] if expanded else decl + + +# ============================================================================= +# Build Functions +# ============================================================================= + + +def generate_kernels(declarations: list, gpu_target: str = "gfx942") -> int: + """Generate kernels using CodegenRunner from ctypes_utils.""" + kernel_dir = get_generated_kernels_dir() + kernel_dir.mkdir(parents=True, exist_ok=True) + + # Group by dtype+layout for efficient generation + groups = {} + for decl in declarations: + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + key = (dtype, layout) + if key not in groups: + groups[key] = [] + groups[key].append(auto_fill_declaration(decl)) + + total_generated = 0 + + for (dtype, layout), decls in groups.items(): + print(f" Generating {dtype} {layout} kernels...") + + # Check for wildcards - if any decl is wildcard, generate all + has_wildcard = any(d.get("wildcard", False) for d in decls) + + # Use CodegenRunner from ctypes_utils + runner = CodegenRunner( + datatype=dtype, + layout=layout, + gpu_target=gpu_target, + ) + + result = runner.generate("standard") + + if result.success: + total_generated += result.kernel_count + if has_wildcard: + print(f" [wildcard] Generated all {result.kernel_count} variants") + else: + print_error(f" Failed: {result.stderr[:200]}") + + return total_generated + + +def get_arch_filter_data(): + """Load arch filter data from arch_specs_generated if available.""" + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + return { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + # Fallback defaults + return { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + + +def is_wildcard_declaration(decl: dict) -> bool: + """Check if declaration has wildcards that need expansion.""" + # Wave/warp wildcards + if decl.get("wave_m", 2) < 0 or decl.get("wave_n", 2) < 0: + return True + if decl.get("warp_m", 32) < 0 or decl.get("warp_n", 32) < 0: + return True + # Pipeline/scheduler wildcards + if decl.get("pipeline", "compv4") == "*": + return True + if decl.get("scheduler", "intrawave") == "*": + return True + if decl.get("epilogue", "cshuffle") == "*": + return True + return False + + +def validate_kernel_config(decl: dict, arch: str = "gfx942") -> tuple: + """Validate a kernel configuration against known supported combinations. + + Uses arch_specs_generated for architecture-specific validation. + + For wildcard declarations (-1 values or "*" strings), validation is skipped + because the expansion phase will generate only valid combinations. + + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards - expansion will filter invalid combos + if is_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv4") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype_a", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}+{epilogue}: intrawave" + ) + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration for this arch and dtype + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" Supported: {', '.join(arch_data['supported_archs'])}" + ) + + if errors: + return (False, "\n".join(errors)) + + return (True, None) + + +def build_exact_kernel_filename(decl: dict) -> str: + """Build the exact kernel filename from a fully-specified declaration. + + Standard format: + gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile}_{wave}_{warp}.hpp + + Multi-D format: + gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile}_{wave}_{warp}_multid_{op}_d{num}.hpp + """ + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + pipeline = decl.get("pipeline", "compv4") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + + pad_m = "True" if decl.get("pad_m", False) else "False" + pad_n = "True" if decl.get("pad_n", False) else "False" + pad_k = "True" if decl.get("pad_k", False) else "False" + preshuffle = "True" if decl.get("preshuffle", False) else "False" + + tile_m = decl.get("tile_m", 128) + tile_n = decl.get("tile_n", 128) + tile_k = decl.get("tile_k", 32) + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + tile_str = f"{tile_m}x{tile_n}x{tile_k}" + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + warp_str = f"{warp_m}x{warp_n}x{warp_k}" + + base = f"gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile_str}_{wave_str}_{warp_str}" + + # Handle Multi-D kernels + elementwise_op = decl.get("elementwise_op", "PassThrough") + num_d_tensors = decl.get("num_d_tensors", 0) + if elementwise_op != "PassThrough" and num_d_tensors > 0: + base += f"_multid_{elementwise_op}_d{num_d_tensors}" + + return f"{base}.hpp" + + +def generate_specific_kernel(decl: dict, gpu_target: str = "gfx942") -> bool: + """Generate a specific kernel based on declaration.""" + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + + print(f" Generating kernel for {dtype}/{layout}...") + + # Use CodegenRunner to generate + runner = CodegenRunner( + datatype=dtype, + layout=layout, + gpu_target=gpu_target, + ) + + result = runner.generate("standard") + return result.success + + +def find_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: + """Find a matching kernel header file for a declaration. + + Tries multiple matching strategies: + 1. Exact filename match + 2. Match with key parameters (dtype, layout, pipeline, scheduler, tile) + 3. Match with just dtype, layout, and tile (more flexible) + 4. Any kernel with matching dtype and layout + + If no kernel exists, attempts to generate it. + Returns None only if all strategies fail. + """ + kernel_dir = get_generated_kernels_dir() + + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + pipeline = decl.get("pipeline", "compv4") + scheduler = decl.get("scheduler", "intrawave") + tile_m = decl.get("tile_m", 128) + tile_n = decl.get("tile_n", 128) + tile_k = decl.get("tile_k", 32) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + tile_str = f"{tile_m}x{tile_n}x{tile_k}" + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + # Build exact filename + exact_filename = build_exact_kernel_filename(decl) + exact_path = kernel_dir / exact_filename + + # Strategy 1: Exact filename match + if exact_path.exists(): + print(f" Found exact kernel: {exact_filename}") + return exact_path + + # Strategy 2: Match with key parameters + pattern = ( + f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_*.hpp" + ) + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found matching kernel: {matches[0].name}") + return matches[0] + + # Strategy 3: Match with just dtype, layout, tile (ignore wave/warp) + pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found kernel with matching tile: {matches[0].name}") + return matches[0] + + # Strategy 4: Match with just dtype, layout (most flexible, for wildcards) + # Prefer kernels with intrawave scheduler (known to work) + pattern = f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found kernel with intrawave: {matches[0].name}") + return matches[0] + + # Strategy 5: Any kernel with matching dtype and layout + pattern = f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found kernel with matching dtype/layout/tile: {matches[0].name}") + return matches[0] + + # Strategy 6: Try to generate the kernel + print(" No matching kernel found, attempting to generate...") + if generate_specific_kernel(decl, gpu_target): + # Check strategies again after generation + for pattern in [ + f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp", + f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp", + f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp", + ]: + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Generated: {matches[0].name}") + return matches[0] + + # All strategies failed - return None (caller will try next expanded decl) + return None + + +def is_conv_wildcard_declaration(decl: dict) -> bool: + """Check if conv declaration has wildcards that need expansion.""" + if decl.get("wave_m", 2) < 0 or decl.get("wave_n", 2) < 0: + return True + if decl.get("warp_m", 32) < 0 or decl.get("warp_n", 32) < 0: + return True + if decl.get("pipeline", "compv3") == "*": + return True + if decl.get("scheduler", "intrawave") == "*": + return True + return False + + +def validate_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple: + """Validate a conv kernel configuration against arch filter. + + For wildcard declarations, validation is skipped (expansion handles it). + + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards + if is_conv_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv3") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}+{epilogue}: intrawave" + ) + + # Check wave configuration + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" Supported: {', '.join(arch_data['supported_archs'])}" + ) + + if errors: + return (False, "\n".join(errors)) + + return (True, None) + + +def build_exact_conv_kernel_filename(decl: dict) -> str: + """Build the exact conv kernel filename from a fully-specified declaration. + + Conv filename format: + conv_{type}_{dtype}_{ndim}d_{pipeline}_{epilogue}_{scheduler}_{tile}_{wave}.hpp + + Example: + conv_fwd_fp16_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1.hpp + """ + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + pipeline = decl.get("pipeline", "compv3") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + + # Map conv_type to filename prefix + if conv_type == "forward": + type_prefix = "fwd" + elif conv_type == "bwd_data": + type_prefix = "bwdd" + elif conv_type == "bwd_weight": + type_prefix = "bwdw" + else: + type_prefix = conv_type + + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + tile_str = f"{tile_k}x{tile_c}x32" # Conv uses tile_k x tile_c x 32 format + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + return f"conv_{type_prefix}_{dtype}_{num_dims}d_{pipeline}_{epilogue}_{scheduler}_{tile_str}_{wave_str}.hpp" + + +def generate_specific_conv_kernel(decl: dict, gpu_target: str = "gfx942") -> bool: + """Generate a specific conv kernel based on declaration.""" + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + + print(f" Generating conv kernel for {dtype}/{conv_type}/{num_dims}d...") + + # Map to variant name + if conv_type == "forward": + variant = "forward" + elif conv_type == "bwd_data": + variant = "bwd_data" + elif conv_type == "bwd_weight": + variant = "bwd_weight" + else: + variant = "forward" + + # Use unified_conv_codegen + codegen_dir = get_dispatcher_root() / "codegen" + codegen_script = codegen_dir / "unified_conv_codegen.py" + output_dir = get_generated_kernels_dir() + + cmd = [ + "python3", + str(codegen_script), + "--datatype", + dtype, + "--variant", + variant, + "--ndim", + str(num_dims), + "--arch", + gpu_target, + "--output", + str(output_dir), + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + return result.returncode == 0 + except subprocess.TimeoutExpired: + return False + + +def find_conv_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: + """Find the EXACT matching conv kernel header file for a declaration. + + If the kernel doesn't exist, attempts to generate it. + Returns None only if generation also fails. + """ + kernel_dir = get_generated_kernels_dir() + + # Build exact filename + exact_filename = build_exact_conv_kernel_filename(decl) + exact_path = kernel_dir / exact_filename + + # Check if exact kernel exists + if exact_path.exists(): + print(f" Found exact conv kernel: {exact_filename}") + return exact_path + + # Try to find with glob (in case of minor variations) + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + pipeline = decl.get("pipeline", "compv3") + scheduler = decl.get("scheduler", "intrawave") + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + # Map conv_type to prefix + if conv_type == "forward": + type_prefix = "fwd" + elif conv_type == "bwd_data": + type_prefix = "bwdd" + elif conv_type == "bwd_weight": + type_prefix = "bwdw" + else: + type_prefix = conv_type + + tile_str = f"{tile_k}x{tile_c}" + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + # Search pattern with key parameters + pattern = f"conv_{type_prefix}_{dtype}_{num_dims}d_{pipeline}_*_{scheduler}_*{tile_str}*_{wave_str}.hpp" + matches = list(kernel_dir.glob(pattern)) + + if matches: + print(f" Found matching conv kernel: {matches[0].name}") + return matches[0] + + # Kernel doesn't exist - try to generate it + print(f" Conv kernel not found: {exact_filename}") + print(" Attempting to generate...") + + if generate_specific_conv_kernel(decl, gpu_target): + # Check again after generation + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Generated: {matches[0].name}") + return matches[0] + + # Check for exact match + if exact_path.exists(): + print(f" Generated: {exact_filename}") + return exact_path + + # Still not found - print helpful error + print_error( + " ERROR: Could not find or generate conv kernel matching declaration:" + ) + print_error(f" dtype={dtype}, conv_type={conv_type}, num_dims={num_dims}") + print_error(f" pipeline={pipeline}, scheduler={scheduler}") + print_error(f" tile={tile_k}x{tile_c}, wave={wave_str}") + print_error(f" Expected: {exact_filename}") + print_error(f" Available conv kernels in {kernel_dir}:") + + available = list(kernel_dir.glob(f"conv_{type_prefix}_{dtype}_{num_dims}d_*.hpp"))[ + :5 + ] + for k in available: + print_error(f" - {k.name}") + if len(list(kernel_dir.glob(f"conv_{type_prefix}_{dtype}_{num_dims}d_*.hpp"))) > 5: + print_error(" ... and more") + + return None + + +def build_dispatcher_library(hipcc: str) -> bool: + """Build the dispatcher library if needed.""" + build_dir = get_build_dir() + lib_path = build_dir / "libck_tile_dispatcher.a" + + if lib_path.exists(): + return True + + print(" Building dispatcher library...") + build_dir.mkdir(parents=True, exist_ok=True) + + dispatcher_dir = get_dispatcher_root() + + # Run cmake + cmake_cmd = ["cmake", str(dispatcher_dir), f"-DCMAKE_CXX_COMPILER={hipcc}"] + result = subprocess.run( + cmake_cmd, cwd=str(build_dir), capture_output=True, text=True + ) + if result.returncode != 0: + print_error(f"CMake failed: {result.stderr}") + return False + + # Run make + make_cmd = ["make", "ck_tile_dispatcher", f"-j{os.cpu_count() or 4}"] + result = subprocess.run( + make_cmd, cwd=str(build_dir), capture_output=True, text=True + ) + if result.returncode != 0: + print_error(f"Make failed: {result.stderr}") + return False + + return True + + +def compile_application( + source_file: Path, + output_bin: Path, + kernel_header: Path, + hipcc: str, + gpu_target: str = "gfx942", +) -> bool: + """Compile the application with hipcc.""" + ck_root = get_ck_root() + dispatcher_dir = get_dispatcher_root() + build_dir = get_build_dir() + kernel_dir = get_generated_kernels_dir() + + includes = [ + f"-I{ck_root / 'include'}", + f"-I{dispatcher_dir / 'include'}", + f"-I{kernel_dir}", + ] + + cmd = [ + hipcc, + "-std=c++17", + "-O3", + f"--offload-arch={gpu_target}", + *includes, + "-include", + str(kernel_header), + f"-L{build_dir}", + "-lck_tile_dispatcher", + "-o", + str(output_bin), + str(source_file), + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + # Filter out nodiscard warnings + if result.stderr: + lines = result.stderr.split("\n") + errors = [line for line in lines if "error:" in line.lower()] + if errors: + for err_line in errors[:5]: + print_error(f" {err_line}") + + return result.returncode == 0 + + +# ============================================================================= +# Main +# ============================================================================= + + +def main(): + parser = argparse.ArgumentParser( + description="Build CK Tile application with declarative kernels", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Example: + python3 compile_gemm_examples.py examples/cpp/01_basic_gemm_declarative.cpp my_app + +In your C++ code, declare kernels like: + DECL_KERNEL_SET(my_kernels, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(128, 128, 32).wave(2, 2, 1).warp(32, 32, 16) + .pipeline("compv4").scheduler("intrawave")) + ); +""", + ) + parser.add_argument("source", help="Source file (.cpp)") + parser.add_argument( + "output", nargs="?", help="Output name (default: source basename)" + ) + parser.add_argument( + "--gpu-target", default="gfx942", help="GPU target architecture" + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + args = parser.parse_args() + + # Resolve paths using utilities from ctypes_utils + dispatcher_dir = get_dispatcher_root() + build_dir = get_build_dir() + + source_file = Path(args.source) + if not source_file.is_absolute(): + # Try relative to dispatcher dir first, then CWD + candidates = [ + dispatcher_dir / args.source, + dispatcher_dir / "examples" / args.source, # examples/gemm/cpp/... + Path.cwd() / args.source, + ] + for candidate in candidates: + if candidate.exists(): + source_file = candidate + break + + if not source_file.exists(): + print_error(f"Source file not found: {source_file}") + return 1 + + output_name = args.output or source_file.stem + output_bin = build_dir / output_name + + # Ensure build directory exists + build_dir.mkdir(parents=True, exist_ok=True) + + print_success("=== CK Tile Declarative Kernel Build ===") + print() + + # Phase 1: Extract declarations (both GEMM and Conv) + print_phase("Phase 1: Scanning for kernel declarations...") + + gemm_declarations = extract_kernel_declarations(source_file) + conv_declarations = extract_conv_kernel_declarations(source_file) + + if not gemm_declarations and not conv_declarations: + print_error(" No kernel declarations found!") + print(" Add DECL_KERNEL_SET for GEMM or DECL_CONV_KERNEL_SET for Conv") + return 1 + + # Handle GEMM declarations + if gemm_declarations: + print(f"\n GEMM: Found {len(gemm_declarations)} declaration(s)") + + # Group by kernel set + sets = {} + for decl in gemm_declarations: + set_name = decl.get("set") or "(global)" + if set_name not in sets: + sets[set_name] = [] + sets[set_name].append(decl) + + for set_name, set_decls in sets.items(): + print(f" [{set_name}] ({len(set_decls)} kernels):") + for decl in set_decls[:5]: + needs_expansion = ( + decl.get("wave_m", -1) < 0 or decl.get("warp_m", -1) < 0 + ) + suffix = " [expands]" if needs_expansion else "" + display_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + print(f" - {display_name}{suffix}") + if len(set_decls) > 5: + print(f" ... and {len(set_decls) - 5} more") + + # Validate declarations against arch filter + print(f"\n Validating against {args.gpu_target} arch filter...") + wildcard_count = 0 + invalid_count = 0 + auto_corrections = [] + + for decl in gemm_declarations: + arch = decl.get("arch", args.gpu_target) + decl_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + + # Check for wildcards + if is_wildcard_declaration(decl): + wildcard_count += 1 + continue # Wildcards validated during expansion + + is_valid, error_msg = validate_kernel_config(decl, arch) + if not is_valid: + print(f"\n ⚠ Invalid configuration: {decl_name}") + + # Parse the error and show specific auto-corrections + corrections = [] + original_values = {} + + if "wave configuration" in error_msg.lower(): + original_values["wave"] = ( + f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]" + ) + decl["wave_m"] = -1 + decl["wave_n"] = -1 + corrections.append( + f"wave: {original_values['wave']} → [wildcard expansion]" + ) + + if "warp tile" in error_msg.lower(): + original_values["warp"] = ( + f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]" + ) + decl["warp_m"] = -1 + decl["warp_n"] = -1 + corrections.append( + f"warp_tile: {original_values['warp']} → [wildcard expansion]" + ) + + if "trait combination" in error_msg.lower(): + original_values["pipeline"] = decl.get("pipeline", "compv4") + original_values["scheduler"] = decl.get("scheduler", "intrawave") + decl["pipeline"] = "*" + decl["scheduler"] = "*" + corrections.append( + f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + ) + corrections.append( + f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + ) + + # Print the auto-corrections + print(" AUTO-CORRECTION:") + for corr in corrections: + print(f" • {corr}") + auto_corrections.append((decl_name, corrections)) + + invalid_count += 1 + wildcard_count += 1 + + if invalid_count > 0: + print( + f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + ) + + if wildcard_count > 0: + print( + f" ✓ {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + ) + else: + print(f" ✓ All {len(gemm_declarations)} configurations valid") + + # Expand GEMM declarations (for wildcards) + print("\n Expanding wildcards to valid configurations...") + expanded_gemm = [] + for decl in gemm_declarations: + arch = decl.get("arch", args.gpu_target) + decl_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + + expanded = expand_declaration_with_arch_filter(decl, arch) + expanded_gemm.extend(expanded) + + # Show what the wildcard expanded to + if len(expanded) > 1: + print( + f" {decl_name}: expanded to {len(expanded)} valid configurations" + ) + # Show first few expanded configs + for exp in expanded[:3]: + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print( + f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + ) + if len(expanded) > 3: + print(f" ... and {len(expanded) - 3} more") + elif len(expanded) == 1 and is_wildcard_declaration(decl): + exp = expanded[0] + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + + if len(expanded_gemm) > len(gemm_declarations): + print( + f"\n Total: {len(gemm_declarations)} declarations → {len(expanded_gemm)} configurations" + ) + + gemm_declarations = expanded_gemm + + # Handle Conv declarations + if conv_declarations: + print(f"\n CONV: Found {len(conv_declarations)} declaration(s)") + + # Group by kernel set + sets = {} + for decl in conv_declarations: + set_name = decl.get("set") or "(global)" + if set_name not in sets: + sets[set_name] = [] + sets[set_name].append(decl) + + for set_name, set_decls in sets.items(): + print(f" [{set_name}] ({len(set_decls)} kernels):") + for decl in set_decls[:5]: + needs_expansion = is_conv_wildcard_declaration(decl) + suffix = " [expands]" if needs_expansion else "" + display_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + print(f" - {display_name}{suffix}") + if len(set_decls) > 5: + print(f" ... and {len(set_decls) - 5} more") + + # Validate Conv declarations against arch filter + print(f"\n Validating against {args.gpu_target} arch filter...") + wildcard_count = 0 + invalid_count = 0 + auto_corrections = [] + + for decl in conv_declarations: + arch = decl.get("arch", args.gpu_target) + decl_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + + # Check for wildcards + if is_conv_wildcard_declaration(decl): + wildcard_count += 1 + continue # Wildcards validated during expansion + + is_valid, error_msg = validate_conv_kernel_config(decl, arch) + if not is_valid: + print(f"\n ⚠ Invalid conv configuration: {decl_name}") + + # Parse the error and show specific auto-corrections + corrections = [] + original_values = {} + + if "wave configuration" in error_msg.lower(): + original_values["wave"] = ( + f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]" + ) + decl["wave_m"] = -1 + decl["wave_n"] = -1 + corrections.append( + f"wave: {original_values['wave']} → [wildcard expansion]" + ) + + if "warp tile" in error_msg.lower(): + original_values["warp"] = ( + f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]" + ) + decl["warp_m"] = -1 + decl["warp_n"] = -1 + corrections.append( + f"warp_tile: {original_values['warp']} → [wildcard expansion]" + ) + + if "trait combination" in error_msg.lower(): + original_values["pipeline"] = decl.get("pipeline", "compv3") + original_values["scheduler"] = decl.get("scheduler", "intrawave") + decl["pipeline"] = "*" + decl["scheduler"] = "*" + corrections.append( + f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + ) + corrections.append( + f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + ) + + # Print the auto-corrections + print(" AUTO-CORRECTION:") + for corr in corrections: + print(f" • {corr}") + auto_corrections.append((decl_name, corrections)) + + invalid_count += 1 + wildcard_count += 1 + + if invalid_count > 0: + print( + f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + ) + + if wildcard_count > 0: + print( + f" ✓ {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + ) + else: + print(f" ✓ All {len(conv_declarations)} configurations valid") + + # Expand Conv declarations (for wildcards) + print("\n Expanding wildcards to valid configurations...") + expanded_conv = [] + for decl in conv_declarations: + arch = decl.get("arch", args.gpu_target) + decl_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + + expanded = expand_conv_declaration_with_arch_filter(decl, arch) + expanded_conv.extend(expanded) + + # Show what the wildcard expanded to + if len(expanded) > 1: + print( + f" {decl_name}: expanded to {len(expanded)} valid configurations" + ) + for exp in expanded[:3]: + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print( + f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + ) + if len(expanded) > 3: + print(f" ... and {len(expanded) - 3} more") + elif len(expanded) == 1 and is_conv_wildcard_declaration(decl): + exp = expanded[0] + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + + if len(expanded_conv) > len(conv_declarations): + print( + f"\n Total: {len(conv_declarations)} declarations → {len(expanded_conv)} configurations" + ) + + conv_declarations = expanded_conv + + print() + + # Phase 2: Generate kernels + print_phase("Phase 2: Generating kernels...") + + total_generated = 0 + + # Generate GEMM kernels + if gemm_declarations: + print(" GEMM kernels:") + num_gemm = generate_kernels(gemm_declarations, args.gpu_target) + total_generated += num_gemm + print(f" Generated: {num_gemm}") + + # Generate Conv kernels + if conv_declarations: + print(" CONV kernels:") + num_conv = generate_conv_kernels(conv_declarations, args.gpu_target) + total_generated += num_conv + print(f" Generated: {num_conv}") + + print(f" Total kernel files: {total_generated}") + print() + + # Phase 3: Find kernel header + print_phase("Phase 3: Selecting kernel for compilation...") + + kernel_headers = [] + + # Find GEMM kernel header (try each expanded declaration until one matches) + if gemm_declarations: + gemm_header = None + for decl in gemm_declarations: + header = find_kernel_header(decl, args.gpu_target) + if header: + gemm_header = header + break + + if gemm_header: + kernel_headers.append(gemm_header) + print(f" GEMM: {gemm_header.name}") + else: + print_error(" GEMM: No kernel found matching any declaration!") + print_error( + " The kernels declared in DECL_KERNEL_SET must exist or be generatable." + ) + return 1 + + # Find Conv kernel header + if conv_declarations: + first_conv = conv_declarations[0] + conv_header = find_conv_kernel_header(first_conv) + if conv_header: + kernel_headers.append(conv_header) + print(f" CONV: {conv_header.name}") + + if not kernel_headers: + print_error(" No kernel headers found!") + return 1 + + # Use first available header (can be extended to use multiple) + kernel_header = kernel_headers[0] + print() + + # Phase 4: Build dispatcher library + print_phase("Phase 4: Building dispatcher library...") + hipcc = find_hipcc() + + if not build_dispatcher_library(hipcc): + print_error(" Failed to build dispatcher library!") + return 1 + print(" Done") + print() + + # Phase 5: Compile application + print_phase("Phase 5: Compiling application...") + + if not compile_application( + source_file, output_bin, kernel_header, hipcc, args.gpu_target + ): + print_error(" Compilation failed!") + return 1 + + print(f" Output: {output_bin}") + print() + + # Done + print_success("=== Build Complete ===") + print() + print("Run with:") + print(f" {output_bin}") + print() + print("List declared kernels:") + print(f" {output_bin} --list-kernels") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/scripts/example_kernel_builder.py b/dispatcher/scripts/example_kernel_builder.py new file mode 100755 index 0000000000..d3bb619174 --- /dev/null +++ b/dispatcher/scripts/example_kernel_builder.py @@ -0,0 +1,1447 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Build example kernels - generates and compiles kernels for a single example. + +Detects if example is GEMM or Conv based on macro presence, extracts all +configuration parameters, and generates appropriate kernels. +""" + +import argparse +import os +import re +import shutil +import subprocess +import sys +from pathlib import Path +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Dict, List, Tuple + + +def find_hipcc() -> str: + for path in [os.environ.get("HIPCC"), "/opt/rocm/bin/hipcc", shutil.which("hipcc")]: + if path and os.path.isfile(path): + return path + return "hipcc" + + +def find_ar() -> str: + for path in [ + "/opt/rocm/llvm/bin/llvm-ar", + shutil.which("llvm-ar"), + shutil.which("ar"), + ]: + if path and os.path.isfile(path): + return path + return "ar" + + +def extract_balanced_parens(text: str, start_pos: int) -> str: + """Extract content between balanced parentheses.""" + if start_pos >= len(text) or text[start_pos] != "(": + return "" + depth = 0 + for i, c in enumerate(text[start_pos:], start_pos): + if c == "(": + depth += 1 + elif c == ")": + depth -= 1 + if depth == 0: + return text[start_pos + 1 : i] + return "" + + +def parse_conv_declarations(content: str) -> List[Dict]: + """Parse DECL_CONV_KERNEL_SET declarations with all parameters.""" + kernels = [] + + for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content): + body = extract_balanced_parens(content, match.end() - 1) + if not body: + continue + + # Parse each .add() call + for add_match in re.finditer(r"\.add\s*\(", body): + add_body = extract_balanced_parens(body, add_match.end() - 1) + + kernel = {} + + # ConvSig parameters - handle both single dtype and multi-dtype + # Multi-dtype: .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16", "bf16", "fp16") + if m := re.search( + r'\.dtype\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"([^"]+)")?\s*\)', + add_body, + ): + kernel["dtype_in"] = m.group(1) + kernel["dtype_wei"] = m.group(2) + kernel["dtype_out"] = m.group(3) + kernel["dtype_acc"] = m.group(4) if m.group(4) else "fp32" + kernel["dtype"] = m.group(1) # Default for codegen + # Single dtype: .dtype("fp16") + elif m := re.search(r'\.dtype\s*\(\s*"([^"]+)"\s*\)', add_body): + kernel["dtype"] = m.group(1) + kernel["dtype_in"] = m.group(1) + kernel["dtype_wei"] = m.group(1) + kernel["dtype_out"] = m.group(1) + kernel["dtype_acc"] = "fp32" + if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body): + kernel["layout"] = m.group(1) + if m := re.search(r'\.conv_type\s*\(\s*"([^"]+)"', add_body): + kernel["conv_type"] = m.group(1) + if m := re.search(r"\.dims\s*\(\s*(\d+)\s*\)", add_body): + kernel["ndim"] = int(m.group(1)) + + # ConvAlgo parameters - tile(G, M, N) where G=batch, M=output, N=reduction + if m := re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["tile_g"] = int(m.group(1)) # batch tile (usually 1) + kernel["tile_m"] = int(m.group(2)) # output channel tile + kernel["tile_n"] = int(m.group(3)) # input channel tile (reduction) + + # wave(M_Warp, N_Warp, K_Warp) - warp distribution + if m := re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["warp_m"] = int(m.group(1)) + kernel["warp_n"] = int(m.group(2)) + kernel["warp_k"] = int(m.group(3)) + + # warp(M_Warp_Tile, N_Warp_Tile, K_Warp_Tile) - warp tile sizes + if m := re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["warp_tile_m"] = int(m.group(1)) + kernel["warp_tile_n"] = int(m.group(2)) + kernel["warp_tile_k"] = int(m.group(3)) + + # vector_sizes(A, B, C) + if m := re.search( + r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["vector_a"] = int(m.group(1)) + kernel["vector_b"] = int(m.group(2)) + kernel["vector_c"] = int(m.group(3)) + + # Single-value parameters + if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body): + kernel["pipeline"] = m.group(1) + if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body): + kernel["scheduler"] = m.group(1) + if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body): + kernel["epilogue"] = m.group(1) + if m := re.search(r"\.block_per_cu\s*\(\s*(\d+)\s*\)", add_body): + kernel["block_per_cu"] = int(m.group(1)) + if m := re.search(r"\.num_wave_groups\s*\(\s*(\d+)\s*\)", add_body): + kernel["num_wave_groups"] = int(m.group(1)) + if m := re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)\s*\)", add_body): + kernel["num_groups_to_merge"] = int(m.group(1)) + if m := re.search( + r"\.double_smem_buffer\s*\(\s*(true|false)\s*\)", add_body, re.I + ): + kernel["double_smem_buffer"] = m.group(1).lower() == "true" + + # Architecture + if m := re.search(r'"(gfx\d+)"', add_body): + kernel["arch"] = m.group(1) + + if kernel.get("dtype"): + # Auto-fill missing parameters with defaults (autocorrect) + kernel = auto_fill_conv_defaults(kernel) + kernels.append(kernel) + + return kernels + + +def auto_fill_conv_defaults(kernel: Dict) -> Dict: + """Auto-fill missing conv parameters with sensible defaults (autofill + autocorrect). + + This implements: + 1. AUTOFILL: Missing parameters are filled with valid defaults (ConvConfigComputeV3) + 2. AUTOCORRECT: Invalid values are corrected to valid ones + """ + # Default tile configuration matching ConvConfigComputeV3 + defaults = { + "tile_g": 1, + "tile_m": 16, + "tile_n": 64, + "warp_m": 1, + "warp_n": 4, + "warp_k": 1, + "warp_tile_m": 16, + "warp_tile_n": 16, + "warp_tile_k": 32, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "vector_a": 4, + "vector_b": 8, + "vector_c": 8, + "block_per_cu": 1, + "num_wave_groups": 1, + "num_groups_to_merge": 1, + "ndim": 2, + "layout": "nhwgc", + "conv_type": "forward", + "arch": "gfx942", + } + + # AUTOFILL: Fill missing parameters with defaults + autofilled = [] + for key, value in defaults.items(): + if key not in kernel or kernel[key] is None or kernel[key] == -1: + kernel[key] = value + autofilled.append(f"{key}={value}") + + if autofilled: + print(f" [AUTOFILL] {', '.join(autofilled)}") + + # AUTOCORRECT: Fix invalid wave configurations for gfx942 + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + current_wave = ( + kernel.get("warp_m", 1), + kernel.get("warp_n", 4), + kernel.get("warp_k", 1), + ) + + if current_wave not in valid_wave_configs: + old = current_wave + kernel["warp_m"] = 1 + kernel["warp_n"] = 4 + kernel["warp_k"] = 1 + print(f" [AUTOCORRECT] wave{old} -> wave(1,4,1) (invalid for gfx942)") + + # AUTOCORRECT: Fix invalid pipeline for backward ops + conv_type = kernel.get("conv_type", "forward") + pipeline = kernel.get("pipeline", "compv3") + + if conv_type in ["bwd_data", "bwd_weight"] and pipeline in ["compv4", "compv5"]: + old_pipeline = pipeline + kernel["pipeline"] = "compv3" + print( + f" [AUTOCORRECT] pipeline {old_pipeline} -> compv3 (invalid for {conv_type})" + ) + + return kernel + + +def expand_conv_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]: + """Expand wildcard parameters to multiple valid configurations. + + When users specify wildcards (-1 or *), this expands them to all + valid configurations for the target architecture. + """ + expanded = [] + + # Valid wave configurations for gfx942 + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + # Valid warp tile configurations for gfx942 fp16 + valid_warp_configs = [(16, 16, 32), (32, 32, 16)] + + # Check if expansion is needed + needs_wave = kernel.get("warp_m") is None or kernel.get("warp_m") == -1 + needs_warp = kernel.get("warp_tile_m") is None or kernel.get("warp_tile_m") == -1 + + if not needs_wave and not needs_warp: + return [kernel] + + # Expand wave configurations + wave_configs = ( + valid_wave_configs + if needs_wave + else [ + (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1)) + ] + ) + + # Expand warp tile configurations + warp_configs = ( + valid_warp_configs + if needs_warp + else [ + ( + kernel.get("warp_tile_m", 32), + kernel.get("warp_tile_n", 32), + kernel.get("warp_tile_k", 16), + ) + ] + ) + + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_configs: + new_kernel = kernel.copy() + new_kernel["warp_m"] = wm + new_kernel["warp_n"] = wn + new_kernel["warp_k"] = wk + new_kernel["warp_tile_m"] = wtm + new_kernel["warp_tile_n"] = wtn + new_kernel["warp_tile_k"] = wtk + expanded.append(new_kernel) + + return expanded + + +def parse_int_or_wildcard(val: str) -> int: + """Parse integer or return -1 for wildcards. + + Supported wildcard formats: + - ANY_INT: Macro defined as -1 + - -1: Direct numeric wildcard + - "*": String wildcard (also maps to -1 for integer params) + """ + val = val.strip() + if val == "ANY_INT" or val == "-1" or val == "*": + return -1 + return int(val) + + +def parse_gemm_declarations(content: str) -> List[Dict]: + """Parse DECL_KERNEL_SET declarations for GEMM. + + Supports wildcards: + - ANY_INT for numeric params (wave, warp) -> expands to all valid combos + - "*" for string params (pipeline, scheduler) -> expands to valid options + + Each kernel is tagged with its kernel_set name for separate registration. + """ + kernels = [] + + for match in re.finditer(r"DECL_KERNEL_SET\s*\(\s*(\w+)\s*,", content): + kernel_set_name = match.group(1) + body = extract_balanced_parens( + content, match.start() + content[match.start() :].find("(") + ) + if not body: + continue + + for add_match in re.finditer(r"\.add\s*\(", body): + add_body = extract_balanced_parens(body, add_match.end() - 1) + + kernel = {} + + # Signature parameters + if m := re.search(r'\.dtype\s*\(\s*"([^"]+)"', add_body): + kernel["dtype"] = m.group(1) + if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body): + kernel["layout"] = m.group(1) + if m := re.search(r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)', add_body): + kernel["elementwise_op"] = m.group(1) + kernel["num_d_tensors"] = int(m.group(2)) + + # Algorithm parameters - support ANY_INT wildcard + if m := re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["tile_m"] = int(m.group(1)) + kernel["tile_n"] = int(m.group(2)) + kernel["tile_k"] = int(m.group(3)) + + # Wave: support ANY_INT, -1, and "*" as wildcards + if m := re.search( + r"\.wave\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)", + add_body, + ): + kernel["warp_m"] = parse_int_or_wildcard(m.group(1)) + kernel["warp_n"] = parse_int_or_wildcard(m.group(2)) + kernel["warp_k"] = parse_int_or_wildcard(m.group(3)) + + # Warp: support ANY_INT, -1, and "*" as wildcards + if m := re.search( + r"\.warp\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)", + add_body, + ): + kernel["warp_tile_m"] = parse_int_or_wildcard(m.group(1)) + kernel["warp_tile_n"] = parse_int_or_wildcard(m.group(2)) + kernel["warp_tile_k"] = parse_int_or_wildcard(m.group(3)) + + # Pipeline/Scheduler: support "*" wildcard + if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body): + kernel["pipeline"] = m.group(1) + if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body): + kernel["scheduler"] = m.group(1) + if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body): + kernel["epilogue"] = m.group(1) + if m := re.search( + r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)", + add_body, + re.I, + ): + kernel["pad_m"] = m.group(1).lower() == "true" + kernel["pad_n"] = m.group(2).lower() == "true" + kernel["pad_k"] = m.group(3).lower() == "true" + + # Shorthand format: .add("dtype", "layout", M, N, K) + if not kernel.get("dtype"): + if m := re.match( + r'\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)', + add_body, + ): + kernel["dtype"] = m.group(1) + kernel["layout"] = m.group(2) + kernel["tile_m"] = int(m.group(3)) + kernel["tile_n"] = int(m.group(4)) + kernel["tile_k"] = int(m.group(5)) + + if kernel.get("dtype"): + kernel["kernel_set"] = kernel_set_name + kernels.append(kernel) + + # Expand wildcards to multiple kernels + expanded = [] + for kernel in kernels: + expanded.extend(expand_gemm_wildcards(kernel)) + + # Apply autocorrect to each expanded kernel + return [auto_fill_gemm_defaults(k) for k in expanded] + + +def expand_gemm_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]: + """Expand wildcard parameters to multiple valid configurations. + + When users specify ANY_INT (-1) or "*", this expands them to all + valid configurations for the target architecture. + + Note: Block size constraint filters invalid combos: + - (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024 + - For 128x128 tile: only (32,32,k) works (16 warps * 64 = 1024) + - For 64x64 tile: both (16,16,k) and (32,32,k) work + """ + # Valid wave configurations for gfx942 + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + # Valid warp tile configurations for gfx942 fp16 + valid_warp_configs = [(16, 16, 32), (32, 32, 16)] + + # Valid pipelines and schedulers + valid_pipelines = ["compv3"] # compv4 requires special handling + valid_schedulers = ["intrawave"] + + # Check what needs expansion + needs_wave = kernel.get("warp_m") == -1 + needs_warp = kernel.get("warp_tile_m") == -1 + needs_pipeline = kernel.get("pipeline") == "*" + needs_scheduler = kernel.get("scheduler") == "*" + + if not any([needs_wave, needs_warp, needs_pipeline, needs_scheduler]): + return [kernel] + + # Determine configs to iterate + wave_configs = ( + valid_wave_configs + if needs_wave + else [ + (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1)) + ] + ) + warp_configs = ( + valid_warp_configs + if needs_warp + else [ + ( + kernel.get("warp_tile_m", 32), + kernel.get("warp_tile_n", 32), + kernel.get("warp_tile_k", 16), + ) + ] + ) + pipelines = ( + valid_pipelines if needs_pipeline else [kernel.get("pipeline", "compv3")] + ) + schedulers = ( + valid_schedulers if needs_scheduler else [kernel.get("scheduler", "intrawave")] + ) + + expanded = [] + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_configs: + # Check block size constraint: (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024 + tile_m = kernel.get("tile_m", 128) + tile_n = kernel.get("tile_n", 128) + num_warps = (tile_m // wtm) * (tile_n // wtn) + if num_warps * 64 > 1024: + continue # Skip invalid config + + for pipe in pipelines: + for sched in schedulers: + new_kernel = kernel.copy() + new_kernel["warp_m"] = wm + new_kernel["warp_n"] = wn + new_kernel["warp_k"] = wk + new_kernel["warp_tile_m"] = wtm + new_kernel["warp_tile_n"] = wtn + new_kernel["warp_tile_k"] = wtk + new_kernel["pipeline"] = pipe + new_kernel["scheduler"] = sched + expanded.append(new_kernel) + + if expanded: + print(f" [WILDCARD] Expanded 1 declaration -> {len(expanded)} kernel(s)") + + return expanded if expanded else [kernel] + + +def auto_fill_gemm_defaults(kernel: Dict) -> Dict: + """Auto-fill missing GEMM parameters with sensible defaults (autofill + autocorrect). + + This implements: + 1. AUTOFILL: Missing parameters are filled with valid defaults + 2. AUTOCORRECT: Invalid values are corrected to valid ones (e.g., wave(1,1,1) -> wave(2,2,1)) + """ + defaults = { + "tile_m": 128, + "tile_n": 128, + "tile_k": 64, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": False, + "layout": "rcr", + } + + # AUTOFILL: Fill missing parameters with defaults + autofilled = [] + for key, value in defaults.items(): + if key not in kernel or kernel[key] is None or kernel[key] == -1: + kernel[key] = value + autofilled.append(f"{key}={value}") + + if autofilled: + print(f" [AUTOFILL] {', '.join(autofilled)}") + + # AUTOCORRECT: Fix invalid wave configurations for gfx942 + # Valid wave configs: (1,4,1), (2,2,1), (4,1,1) + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + current_wave = ( + kernel.get("warp_m", 2), + kernel.get("warp_n", 2), + kernel.get("warp_k", 1), + ) + + if current_wave not in valid_wave_configs: + # Correct to (2,2,1) which is a balanced default + old = current_wave + kernel["warp_m"] = 2 + kernel["warp_n"] = 2 + kernel["warp_k"] = 1 + print(f" [AUTOCORRECT] wave{old} -> wave(2,2,1) (invalid for gfx942)") + + # AUTOCORRECT: Fix invalid pipeline/scheduler combinations + invalid_combos = [ + ("compv3", "interwave"), + ("compv4", "interwave"), + ] + current_combo = ( + kernel.get("pipeline", "compv3"), + kernel.get("scheduler", "intrawave"), + ) + if current_combo in invalid_combos: + old = current_combo + kernel["scheduler"] = "intrawave" + print( + f" [AUTOCORRECT] {old[0]}/{old[1]} -> {old[0]}/intrawave (invalid combo)" + ) + + # AUTOCORRECT: Fix warp tile to avoid exceeding max block size (1024 threads) + # Block size = (tile_m / warp_tile_m) * (tile_n / warp_tile_n) * 64 + tile_m = kernel.get("tile_m", 128) + tile_n = kernel.get("tile_n", 128) + warp_tile_m = kernel.get("warp_tile_m", 32) + warp_tile_n = kernel.get("warp_tile_n", 32) + + num_warps = (tile_m // warp_tile_m) * (tile_n // warp_tile_n) + block_size = num_warps * 64 # 64 threads per warp + + if block_size > 1024: + # Find valid warp tile that fits + old_warp = (warp_tile_m, warp_tile_n, kernel.get("warp_tile_k", 16)) + + # For large tiles, use larger warp tiles + if tile_m >= 256: + kernel["warp_tile_m"] = 64 + if tile_n >= 256: + kernel["warp_tile_n"] = 64 + + # Recalculate + num_warps = (tile_m // kernel["warp_tile_m"]) * ( + tile_n // kernel["warp_tile_n"] + ) + block_size = num_warps * 64 + + if block_size <= 1024: + new_warp = ( + kernel["warp_tile_m"], + kernel["warp_tile_n"], + kernel["warp_tile_k"], + ) + print( + f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size={block_size})" + ) + else: + # Still too large, try even larger warp tiles + kernel["warp_tile_m"] = tile_m // 4 + kernel["warp_tile_n"] = tile_n // 4 + new_warp = ( + kernel["warp_tile_m"], + kernel["warp_tile_n"], + kernel["warp_tile_k"], + ) + print( + f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size adjusted)" + ) + + return kernel + + +def strip_cpp_strings_and_comments(content: str) -> str: + """Strip C++ string literals and comments that could cause false positives. + + Only strips: + - Comments (// and /* */) - always stripped + - Raw string literals (R"...") - always stripped (can contain anything) + - Regular strings ONLY if they contain problematic patterns like DECL_KERNEL_SET + + Preserves normal string literals like "fp16", "rcr" which are needed for parsing. + """ + result = [] + i = 0 + n = len(content) + + # Patterns that indicate a string is problematic and should be stripped + problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("] + + while i < n: + # Check for raw string literal: R"delimiter(...)delimiter" + # Always strip these as they can contain arbitrary content + if i < n - 1 and content[i] == "R" and content[i + 1] == '"': + # Find the delimiter (between R" and () + j = i + 2 + delimiter_start = j + while j < n and content[j] != "(": + j += 1 + delimiter = content[delimiter_start:j] + # Find the closing )delimiter" + end_marker = ")" + delimiter + '"' + end_pos = content.find(end_marker, j + 1) + if end_pos != -1: + # Replace with spaces to preserve line numbers + span = content[i : end_pos + len(end_marker)] + result.append("".join("\n" if c == "\n" else " " for c in span)) + i = end_pos + len(end_marker) + continue + + # Check for regular string literal - only strip if it contains problematic patterns + if content[i] == '"': + j = i + 1 + while j < n: + if content[j] == "\\" and j + 1 < n: + j += 2 # Skip escaped character + elif content[j] == '"': + j += 1 + break + else: + j += 1 + string_content = content[i:j] + + # Only strip if this string contains problematic patterns + should_strip = any(pat in string_content for pat in problematic_patterns) + if should_strip: + result.append(" " * len(string_content)) + else: + result.append(string_content) + i = j + continue + + # Check for single-line comment - always strip + if i < n - 1 and content[i : i + 2] == "//": + j = i + while j < n and content[j] != "\n": + j += 1 + result.append(" " * (j - i)) + i = j + continue + + # Check for multi-line comment - always strip + if i < n - 1 and content[i : i + 2] == "/*": + end_pos = content.find("*/", i + 2) + if end_pos != -1: + span = content[i : end_pos + 2] + # Preserve newlines in multi-line comments + result.append("".join("\n" if c == "\n" else " " for c in span)) + i = end_pos + 2 + continue + + result.append(content[i]) + i += 1 + + return "".join(result) + + +def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]: + """Detect example type and parse kernel declarations. + + Properly strips string literals and comments before parsing to avoid + picking up declarations inside strings or commented-out code. + """ + content = source_path.read_text() + content = strip_cpp_strings_and_comments(content) + + if "DECL_CONV_KERNEL_SET" in content: + return "conv", parse_conv_declarations(content) + elif "DECL_KERNEL_SET" in content: + return "gemm", parse_gemm_declarations(content) + return "unknown", [] + + +def generate_gemm_registration( + kernel_headers: List[Path], example_name: str, kernels: List[Dict] = None +) -> str: + """Generate GEMM kernel registration code for the dispatcher registry. + + Uses GeneratedKernelInstance to wrap the generated kernels + and provide the KernelInstance interface for the Dispatcher. + + If kernels list is provided with kernel_set info, generates separate + registration functions per kernel set. + """ + if not kernel_headers: + return " // No kernels to register" + + # Build mapping from kernel config pattern to kernel set + kernel_to_set = {} + kernel_sets = set() + if kernels: + for k in kernels: + tile_m = k.get("tile_m", 128) + tile_n = k.get("tile_n", 128) + tile_k = k.get("tile_k", 64) + warp_m = k.get("warp_m", 2) + warp_n = k.get("warp_n", 2) + warp_k = k.get("warp_k", 1) + warp_tile_m = k.get("warp_tile_m", 32) + warp_tile_n = k.get("warp_tile_n", 32) + warp_tile_k = k.get("warp_tile_k", 16) + + # Pattern that appears in kernel filename + key_pattern = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + kernel_set = k.get("kernel_set", "default") + kernel_to_set[key_pattern] = kernel_set + kernel_sets.add(kernel_set) + + def generate_registration_block(h: Path) -> str: + """Generate registration code for a single kernel.""" + kernel_name = h.stem + ns = f"ns_{kernel_name}" + + # Parse pipeline, scheduler, and layout from kernel name + # Format: gemm_fp16_rcr_compv3_cshuffle_intrawave_... + parts = kernel_name.split("_") + pipeline = "CompV3" + scheduler = "Intrawave" + epilogue = "CShuffle" + datatype = "FP16" + layout_a = "RowMajor" + layout_b = "ColMajor" + layout_c = "RowMajor" + + # Parse datatype (e.g., fp16, bf16, fp32) + dtype_map = { + "fp16": "FP16", + "bf16": "BF16", + "fp32": "FP32", + "fp64": "FP64", + "int8": "INT8", + } + + # Parse layout from 3-char codes (e.g., rcr, rrr, rrc, ccc) + # r = RowMajor, c = ColMajor + layout_map = {"r": "RowMajor", "c": "ColMajor"} + + # Find pipeline, epilogue, scheduler in the name parts + pipeline_map = { + "mem": "Mem", + "compv1": "CompV1", + "compv2": "CompV2", + "compv3": "CompV3", + "compv4": "CompV4", + "compv5": "CompV5", + "preshufflev1": "PreShuffleV1", + "preshufflev2": "PreShuffleV2", + } + scheduler_map = { + "intrawave": "Intrawave", + "interwave": "Interwave", + "auto": "Auto", + } + epilogue_map = {"default": "Default", "cshuffle": "CShuffle", "none": "None"} + + for part in parts: + if part in pipeline_map: + pipeline = pipeline_map[part] + if part in scheduler_map: + scheduler = scheduler_map[part] + if part in epilogue_map: + epilogue = epilogue_map[part] + if part in dtype_map: + datatype = dtype_map[part] + # Parse 3-char layout codes (e.g., rcr, rrr) + if len(part) == 3 and all(c in "rc" for c in part): + layout_a = layout_map[part[0]] + layout_b = layout_map[part[1]] + layout_c = layout_map[part[2]] + + block = [] + block.append(f" // Register kernel: {kernel_name}") + block.append(" {") + block.append(f" using SelectedKernel = {ns}::SelectedKernel;") + block.append(" ck_tile::dispatcher::KernelKey key;") + block.append( + f" key.signature.dtype_a = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + f" key.signature.dtype_b = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + f" key.signature.dtype_c = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + " key.signature.dtype_acc = ck_tile::dispatcher::DataType::FP32;" + ) + block.append( + f" key.signature.layout_a = ck_tile::dispatcher::LayoutTag::{layout_a};" + ) + block.append( + f" key.signature.layout_b = ck_tile::dispatcher::LayoutTag::{layout_b};" + ) + block.append( + f" key.signature.layout_c = ck_tile::dispatcher::LayoutTag::{layout_c};" + ) + block.append(" key.algorithm.tile_shape.m = SelectedKernel::TileM;") + block.append(" key.algorithm.tile_shape.n = SelectedKernel::TileN;") + block.append(" key.algorithm.tile_shape.k = SelectedKernel::TileK;") + block.append( + " key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M;" + ) + block.append( + " key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N;" + ) + block.append( + " key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K;" + ) + block.append( + " key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM;" + ) + block.append( + " key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN;" + ) + block.append( + " key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK;" + ) + block.append( + " key.algorithm.block_size = SelectedKernel::BlockSize;" + ) + block.append( + f" key.algorithm.pipeline = ck_tile::dispatcher::Pipeline::{pipeline};" + ) + block.append( + f" key.algorithm.scheduler = ck_tile::dispatcher::Scheduler::{scheduler};" + ) + block.append( + f" key.algorithm.epilogue = ck_tile::dispatcher::Epilogue::{epilogue};" + ) + block.append(" key.gfx_arch = arch;") + block.append( + f' auto instance = std::make_shared>(key, "{kernel_name}");' + ) + block.append(" registry.register_kernel(instance);") + block.append(" }") + return "\n".join(block) + + def find_kernel_set(header: Path) -> str: + """Find which kernel set a header belongs to.""" + name = header.stem + for pattern, kset in kernel_to_set.items(): + if pattern in name: + return kset + return "default" + + # Group kernels by set + kernels_by_set = {} + for h in kernel_headers: + kset = find_kernel_set(h) + if kset not in kernels_by_set: + kernels_by_set[kset] = [] + kernels_by_set[kset].append(h) + + # If only one set or no set info, use simple registration + if len(kernels_by_set) <= 1: + lines = [" (void)arch;", ""] + for h in kernel_headers: + lines.append(generate_registration_block(h)) + return "\n".join(lines) + + # Multiple sets - generate registration for all, plus store per-set info + lines = [" // Register ALL kernels from all sets", " (void)arch;", ""] + for h in kernel_headers: + lines.append(generate_registration_block(h)) + + # Store per-set mapping for separate function generation + global _kernels_by_set_cache + _kernels_by_set_cache = (kernels_by_set, generate_registration_block) + + return "\n".join(lines) + + +# Global cache for per-set kernel info +_kernels_by_set_cache = None + + +def generate_per_set_functions(source_stem: str) -> str: + """Generate separate registration functions for each kernel set. + + Generates: + 1. Per-set functions: register_(registry, arch) + 2. String-based dispatcher: register_kernel_set("set_name", registry, arch) + 3. get_kernel_set_names() to list available sets + """ + global _kernels_by_set_cache + if not _kernels_by_set_cache: + return "" + + kernels_by_set, gen_block = _kernels_by_set_cache + _kernels_by_set_cache = None # Clear cache + + lines = [] + set_names = [] + + # Generate per-set functions + for set_name, headers in kernels_by_set.items(): + safe_name = set_name.replace("-", "_") + set_names.append((set_name, safe_name)) + lines.append( + f"inline void register_{safe_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{" + ) + lines.append(" (void)arch;") + for h in headers: + lines.append(gen_block(h)) + lines.append("}") + lines.append("") + + # Generate string-based dispatcher (only if multiple sets) + if len(set_names) > 0: + lines.append("// Dynamic registration by kernel set name") + lines.append( + "inline bool register_kernel_set(const std::string& set_name, ck_tile::dispatcher::Registry& registry, const std::string& arch) {" + ) + for set_name, safe_name in set_names: + lines.append( + f' if (set_name == "{set_name}") {{ register_{safe_name}(registry, arch); return true; }}' + ) + lines.append(" return false; // Unknown set name") + lines.append("}") + lines.append("") + + # Generate helper to list available set names + lines.append("// Get list of available kernel set names") + lines.append("inline std::vector get_kernel_set_names() {") + names_str = ", ".join(f'"{name}"' for name, _ in set_names) + lines.append(f" return {{{names_str}}};") + lines.append("}") + lines.append("") + + return "\n".join(lines) + + +def generate_conv_registration( + kernel_headers: List[Path], example_name: str, kernels: List[Dict] +) -> str: + """Generate Conv kernel registration code for the dispatcher registry.""" + if not kernel_headers: + return " // No kernels to register" + + lines = [] + lines.append( + " (void)registry; (void)arch; // Conv uses direct launcher pattern for now" + ) + + # For conv, we provide direct access to kernel launchers + for i, h in enumerate(kernel_headers): + kernel_name = h.stem + lines.append(f" // Kernel {i + 1}: {kernel_name}") + + return "\n".join(lines) + + +def generate_conv_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate Conv kernels for ALL declarations using unified codegen.""" + if not kernels: + return False + + variant_map = { + "forward": "forward", + "bwd_data": "bwd_data", + "backward_data": "bwd_data", + "bwd_weight": "bwd_weight", + "backward_weight": "bwd_weight", + } + + success_count = 0 + + # Generate a kernel for EACH declaration + for idx, k in enumerate(kernels): + variant = variant_map.get(k.get("conv_type", "forward"), "forward") + + cmd = [ + sys.executable, + str(codegen_dir / "unified_conv_codegen.py"), + "--datatype", + k.get("dtype", "fp16"), + "--variant", + variant, + "--ndim", + str(k.get("ndim", 2)), + "--output", + str(output_dir), + ] + + # Add optional parameters if specified + if k.get("tile_m"): + cmd.extend(["--tile-m", str(k["tile_m"])]) + if k.get("tile_n"): + cmd.extend(["--tile-n", str(k["tile_n"])]) + if k.get("warp_m"): + cmd.extend(["--warp-m", str(k["warp_m"])]) + if k.get("warp_n"): + cmd.extend(["--warp-n", str(k["warp_n"])]) + if k.get("warp_k"): + cmd.extend(["--warp-k", str(k["warp_k"])]) + if k.get("warp_tile_m"): + cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) + if k.get("warp_tile_n"): + cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) + if k.get("warp_tile_k"): + cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) + if k.get("pipeline"): + cmd.extend(["--pipeline", k["pipeline"]]) + if k.get("scheduler"): + cmd.extend(["--scheduler", k["scheduler"]]) + if k.get("epilogue"): + cmd.extend(["--epilogue", k["epilogue"]]) + if k.get("vector_a"): + cmd.extend(["--vector-a", str(k["vector_a"])]) + if k.get("vector_b"): + cmd.extend(["--vector-b", str(k["vector_b"])]) + if k.get("vector_c"): + cmd.extend(["--vector-c", str(k["vector_c"])]) + if k.get("block_per_cu"): + cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) + if k.get("num_wave_groups"): + cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) + if k.get("num_groups_to_merge"): + cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) + if k.get("double_smem_buffer") is not None: + cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) + if k.get("tile_k"): + cmd.extend(["--tile-k", str(k["tile_k"])]) + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=str(codegen_dir) + ) + if result.returncode != 0: + print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") + else: + success_count += 1 + + return success_count > 0 + + +def generate_gemm_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate GEMM kernels for ALL declarations using unified codegen.""" + import json + + if not kernels: + return False + + success_count = 0 + + # Generate a kernel for EACH declaration + for idx, k in enumerate(kernels): + variant = "multi_d" if k.get("elementwise_op") else "standard" + + # Build tile config JSON for this specific kernel + tile_config = { + "tile_m": [k.get("tile_m", 128)], + "tile_n": [k.get("tile_n", 128)], + "tile_k": [k.get("tile_k", 32)], + "warp_m": [k.get("warp_m", 2)], + "warp_n": [k.get("warp_n", 2)], + "warp_k": [k.get("warp_k", 1)], + "warp_tile_m": [k.get("warp_tile_m", 32)], + "warp_tile_n": [k.get("warp_tile_n", 32)], + "warp_tile_k": [k.get("warp_tile_k", 16)], + } + + trait_config = { + "pipeline": [k.get("pipeline", "compv3")], + "epilogue": [k.get("epilogue", "cshuffle")], + "scheduler": [k.get("scheduler", "intrawave")], + "pad_m": [k.get("pad_m", False)], + "pad_n": [k.get("pad_n", False)], + "pad_k": [k.get("pad_k", False)], + "persistent": [False], + } + + config_json = json.dumps( + {"tile_config": tile_config, "trait_config": trait_config} + ) + + cmd = [ + sys.executable, + str(codegen_dir / "unified_gemm_codegen.py"), + "--datatype", + k.get("dtype", "fp16"), + "--layout", + k.get("layout", "rcr"), + "--variants", + variant, + "--output", + str(output_dir), + "--tile-config-json", + config_json, + ] + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=str(codegen_dir) + ) + if result.returncode != 0: + print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") + else: + success_count += 1 + + return success_count > 0 + + +def compile_kernel(args: Tuple) -> Tuple[str, bool, str]: + """Compile a single kernel to object file.""" + kernel_hpp, output_dir, include_dirs, hipcc, gpu_target, idx, total = args + kernel_name = kernel_hpp.stem + + wrapper_cpp = output_dir / f"{kernel_name}.cpp" + wrapper_cpp.write_text( + f'#include "{kernel_hpp.name}"\nnamespace {{ volatile bool _k{idx} = true; }}\n' + ) + + obj_file = output_dir / f"{kernel_name}.o" + + cmd = [ + hipcc, + "-c", + "-fPIC", + "-std=c++17", + "-O3", + f"--offload-arch={gpu_target}", + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + "--offload-compress", + ] + + for inc_dir in include_dirs: + cmd.extend(["-I", str(inc_dir)]) + cmd.extend(["-I", str(kernel_hpp.parent)]) + cmd.extend(["-o", str(obj_file), str(wrapper_cpp)]) + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + return (kernel_name, False, result.stderr[:500]) + return (kernel_name, True, str(obj_file)) + + +def main(): + parser = argparse.ArgumentParser(description="Build example kernels") + parser.add_argument("source", type=Path, help="C++ source file") + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--include-dirs", type=str, required=True) + parser.add_argument("--gpu-target", type=str, default="gfx942") + parser.add_argument("--jobs", type=int, default=os.cpu_count()) + parser.add_argument( + "--target-name", type=str, help="CMake target name (for library naming)" + ) + args = parser.parse_args() + + script_dir = Path(__file__).parent + codegen_dir = script_dir.parent / "codegen" + source_stem = args.source.stem # e.g., "01_basic_gemm" + target_name = args.target_name or source_stem # e.g., "gemm_01_basic" from CMake + + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Detect and parse + example_type, kernels = detect_and_parse(args.source) + + if example_type == "conv": + k = kernels[0] if kernels else {} + variant = k.get("conv_type", "forward") + print( + f"[{target_name}] Conv {k.get('dtype', 'fp16')} {variant} {k.get('ndim', 2)}D ({len(kernels)} declarations)" + ) + elif example_type == "gemm": + k = kernels[0] if kernels else {} + print( + f"[{target_name}] GEMM {k.get('dtype', 'fp16')} {k.get('layout', 'rcr')} ({len(kernels)} declarations)" + ) + else: + print(f"[{target_name}] No kernel declarations - creating empty library") + lib_path = args.output_dir / f"lib{target_name}_kernels.a" + subprocess.run([find_ar(), "rcs", str(lib_path)], check=True) + header = args.output_dir / f"{source_stem}_kernels.hpp" + header.write_text(f"// No kernels for {target_name}\n#pragma once\n") + return 0 + + # Generate kernels + print(f"[{target_name}] Generating kernels...") + if example_type == "conv": + success = generate_conv_kernels(kernels, args.output_dir, codegen_dir) + else: + success = generate_gemm_kernels(kernels, args.output_dir, codegen_dir) + + if not success: + print(f"[{target_name}] Kernel generation failed!") + return 1 + + # Find generated headers + if example_type == "gemm": + kernel_headers = list(args.output_dir.glob("gemm_*.hpp")) + else: + k = kernels[0] if kernels else {} + variant = k.get("conv_type", "forward") + prefix_map = { + "forward": "conv_fwd", + "bwd_data": "conv_bwdd", + "bwd_weight": "conv_bwdw", + } + prefix = prefix_map.get(variant, "conv_fwd") + kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp")) + + if not kernel_headers: + print(f"[{target_name}] No kernel headers generated!") + return 1 + + print(f"[{target_name}] Compiling {len(kernel_headers)} kernels...") + + include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")] + hipcc = find_hipcc() + + work = [ + ( + h, + args.output_dir, + include_dirs, + hipcc, + args.gpu_target, + i + 1, + len(kernel_headers), + ) + for i, h in enumerate(kernel_headers) + ] + + obj_files = [] + failed = [] + + with ProcessPoolExecutor(max_workers=args.jobs) as executor: + futures = {executor.submit(compile_kernel, w): w[0].name for w in work} + for future in as_completed(futures): + name, ok, result = future.result() + if ok: + obj_files.append(result) + else: + failed.append((name, result)) + print(f"[{target_name}] FAILED: {name}") + + if failed: + print(f"[{target_name}] {len(failed)} kernels failed") + for name, err in failed[:3]: + print(f" {name}: {err[:200]}") + return 1 + + # Create static library (use target_name for CMake compatibility) + lib_path = args.output_dir / f"lib{target_name}_kernels.a" + subprocess.run([find_ar(), "rcs", str(lib_path)] + obj_files, check=True) + + # Generate registration header (use source_stem for header name to match CMake's EXAMPLE_STEM) + header_path = args.output_dir / f"{source_stem}_kernels.hpp" + + # Build includes + includes = "\n".join(f'#include "{h.name}"' for h in kernel_headers) + + # Build kernel registration entries + # Function name uses source_stem (e.g., register_01_basic_gemm_kernels) + func_name = f"register_{source_stem}_kernels" + + # Generate registration code based on example type + if example_type == "gemm": + register_body = generate_gemm_registration(kernel_headers, target_name, kernels) + else: + register_body = generate_conv_registration(kernel_headers, target_name, kernels) + + # Generate appropriate header based on example type + if example_type == "conv" and kernel_headers: + launcher_aliases = [] + + # Helper to find kernel by dtype and type + def find_kernel_by_dtype_type(headers, dtype, conv_type_marker): + """Find kernel matching dtype and conv type, prioritize fp16.""" + matching = [h for h in headers if conv_type_marker in h.stem] + # Prefer fp16 over bf16 for default launchers + fp16_kernels = [h for h in matching if f"_{dtype}_" in h.stem] + return ( + fp16_kernels[0] if fp16_kernels else (matching[0] if matching else None) + ) + + # Check what conv types are in the declarations + has_fwd = any("forward" in k.get("conv_type", "forward") for k in kernels) + has_bwd_data = any("bwd_data" in k.get("conv_type", "") for k in kernels) + has_bwd_weight = any("bwd_weight" in k.get("conv_type", "") for k in kernels) + + # Export dtype-specific launcher aliases for each available dtype + for dtype in ["fp16", "bf16", "fp32"]: + dtype_fwd_kernels = [ + h + for h in kernel_headers + if "_fwd_" in h.stem and f"_{dtype}_" in h.stem + ] + if dtype_fwd_kernels: + k = dtype_fwd_kernels[0] + ns = f"ns_{k.stem}" + dtype_upper = dtype.upper() + launcher_aliases.append( + f"using {dtype_upper}FwdKernelLauncher = {ns}::{k.stem}_Launcher;" + ) + + # Export generic launcher aliases (prioritize fp16) + if has_fwd: + fwd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_fwd_") + if fwd_kernel: + fwd_ns = f"ns_{fwd_kernel.stem}" + launcher_aliases.append( + f"using FwdKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;" + ) + launcher_aliases.append( + f"using FirstKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;" + ) + + if has_bwd_data: + bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_") + if bwdd_kernel: + bwdd_ns = f"ns_{bwdd_kernel.stem}" + launcher_aliases.append( + f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + ) + if not has_fwd: # If no fwd, use bwd_data as first + launcher_aliases.append( + f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + ) + + if has_bwd_weight: + bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_") + if bwdw_kernel: + bwdw_ns = f"ns_{bwdw_kernel.stem}" + launcher_aliases.append( + f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + ) + if ( + not has_fwd and not has_bwd_data + ): # If no fwd or bwdd, use bwdw as first + launcher_aliases.append( + f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + ) + + launcher_section = "\n".join(launcher_aliases) + + header_content = f"""// Auto-generated for {target_name} +#pragma once + +{includes} + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" + +namespace generated {{ + +// Kernel launchers for direct use +{launcher_section} + +// Registration function +inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +{register_body} +}} + +}} // namespace generated + +// Generic registration - avoids hardcoding the example name in user code +// Safe for single-example executables (typical use case) +#ifndef REGISTER_GENERATED_KERNELS +#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch) +#endif +""" + else: + # GEMM: Generate per-set functions if multiple kernel sets declared + per_set_funcs = generate_per_set_functions(source_stem) + + header_content = f"""// Auto-generated for {target_name} +#pragma once + +{includes} + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp" + +namespace generated {{ + +// Register ALL kernels from all declared sets +inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +{register_body} +}} + +{per_set_funcs} +}} // namespace generated + +// Generic registration - avoids hardcoding the example name in user code +// Safe for single-example executables (typical use case) +#ifndef REGISTER_GENERATED_KERNELS +#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch) +#endif + +// Register a specific kernel set by name (for multi-registry patterns) +// Usage: REGISTER_KERNEL_SET("compute_bound_set", registry, arch) +#ifndef REGISTER_KERNEL_SET +#define REGISTER_KERNEL_SET(set_name, registry, arch) generated::register_kernel_set(set_name, registry, arch) +#endif +""" + header_path.write_text(header_content) + + print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/scripts/parallel_kernel_builder.py b/dispatcher/scripts/parallel_kernel_builder.py new file mode 100755 index 0000000000..911ea61bd7 --- /dev/null +++ b/dispatcher/scripts/parallel_kernel_builder.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Build kernels in parallel - one translation unit per kernel. + +This script is called at make time (not cmake time) to avoid slow cmake configuration. +""" + +import argparse +import os +import subprocess +import sys +from pathlib import Path +from concurrent.futures import ProcessPoolExecutor, as_completed + + +def find_hipcc(): + """Find hipcc compiler.""" + candidates = [ + os.environ.get("HIPCC"), + "/opt/rocm/bin/hipcc", + shutil.which("hipcc") if shutil else None, + ] + for path in candidates: + if path and os.path.isfile(path): + return path + return "hipcc" # Assume in PATH + + +def compile_kernel(args): + """Compile a single kernel.""" + kernel_hpp, output_dir, include_dirs, hipcc = args + kernel_name = kernel_hpp.stem + + # Create wrapper .cpp + wrapper_cpp = output_dir / f"{kernel_name}.cpp" + wrapper_cpp.write_text(f'''// Auto-generated wrapper +#include "{kernel_hpp.name}" +namespace {{ volatile bool _k = true; }} +''') + + # Compile to object + obj_file = output_dir / f"{kernel_name}.o" + + cmd = [ + hipcc, + "-c", + "-fPIC", + "-std=c++17", + "-O3", + "--offload-arch=gfx942", + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + "--offload-compress", + ] + + for inc_dir in include_dirs: + cmd.extend(["-I", str(inc_dir)]) + cmd.extend(["-I", str(kernel_hpp.parent)]) + + cmd.extend(["-o", str(obj_file), str(wrapper_cpp)]) + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + return (kernel_name, False, result.stderr) + return (kernel_name, True, str(obj_file)) + + +def main(): + parser = argparse.ArgumentParser(description="Build kernels in parallel") + parser.add_argument("--kernel-dir", type=Path, required=True) + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--include-dirs", type=str, required=True) + parser.add_argument("--jobs", type=int, default=os.cpu_count()) + args = parser.parse_args() + + # Find kernel headers + kernel_headers = list(args.kernel_dir.glob("gemm_*.hpp")) + list( + args.kernel_dir.glob("conv_*.hpp") + ) + + if not kernel_headers: + print("No kernels found to build") + return 0 + + print(f"Building {len(kernel_headers)} kernels with {args.jobs} parallel jobs...") + + include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")] + hipcc = find_hipcc() + + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Prepare work items + work = [(h, args.output_dir, include_dirs, hipcc) for h in kernel_headers] + + # Compile in parallel + obj_files = [] + failed = [] + + with ProcessPoolExecutor(max_workers=args.jobs) as executor: + futures = {executor.submit(compile_kernel, w): w[0].name for w in work} + + for i, future in enumerate(as_completed(futures), 1): + name, success, result = future.result() + if success: + obj_files.append(result) + print(f"[{i}/{len(kernel_headers)}] Built: {name}") + else: + failed.append((name, result)) + print(f"[{i}/{len(kernel_headers)}] FAILED: {name}") + + if failed: + print(f"\n{len(failed)} kernels failed to compile:") + for name, err in failed[:5]: + print(f" {name}: {err[:100]}") + return 1 + + # Link into shared library + print(f"\nLinking {len(obj_files)} objects into libdispatcher_kernels.so...") + lib_path = args.output_dir / "libdispatcher_kernels.so" + + link_cmd = [hipcc, "-shared", "-fPIC", "-o", str(lib_path)] + obj_files + result = subprocess.run(link_cmd, capture_output=True, text=True) + + if result.returncode != 0: + print(f"Linking failed: {result.stderr}") + return 1 + + print(f"✓ Built: {lib_path}") + return 0 + + +if __name__ == "__main__": + import shutil + + sys.exit(main()) diff --git a/dispatcher/scripts/stress_test_autocorrect.py b/dispatcher/scripts/stress_test_autocorrect.py new file mode 100644 index 0000000000..13e92abffa --- /dev/null +++ b/dispatcher/scripts/stress_test_autocorrect.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Stress Test for Auto-Correction and Codegen + +This script tests the robustness of: +1. GEMM auto-correction (Python) +2. Conv auto-correction (Python) +3. C++ kernel declaration validation and wildcard expansion +4. Architecture filtering + +Usage: + python3 scripts/stress_test_autocorrect.py [--arch gfx942] [--samples 50] [--verbose] +""" + +import argparse +import random +import sys +from pathlib import Path + +# Add paths for imports +dispatcher_root = Path(__file__).parent.parent +sys.path.insert(0, str(dispatcher_root / "python")) +sys.path.insert(0, str(dispatcher_root / "codegen")) +sys.path.insert(0, str(dispatcher_root / "scripts")) + +from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402 + +# Import validation/expansion functions from compile scripts +from compile_gemm_examples import ( # noqa: E402 + validate_kernel_config, + expand_declaration_with_arch_filter, +) +from compile_conv_examples import ( # noqa: E402 + validate_conv_kernel_config, + expand_conv_declaration_with_arch_filter, +) + + +# ============================================================================= +# TEST PARAMETERS +# ============================================================================= + +# Valid dtypes +DTYPES = ["fp16", "bf16", "fp32", "fp8", "bf8", "int8"] + +# Valid layouts +LAYOUTS = ["rcr", "rrr", "crr", "ccr"] + +# Tile sizes (some valid, some invalid) +TILE_SIZES = [ + (32, 32, 16), + (64, 64, 32), + (128, 128, 32), + (256, 256, 64), + (128, 256, 32), + (256, 128, 32), + # Invalid sizes to test auto-correction + (100, 100, 50), + (17, 17, 17), + (512, 512, 128), +] + +# Wave configs (some valid, some invalid) +WAVE_CONFIGS = [ + (1, 1, 1), + (1, 2, 1), + (2, 1, 1), + (2, 2, 1), + (1, 4, 1), + (4, 1, 1), + (2, 4, 1), + (4, 2, 1), + # Invalid configs to test auto-correction + (3, 3, 1), + (5, 5, 1), + (1, 1, 2), +] + +# Warp tile sizes (some valid, some invalid) +WARP_TILES = [ + (16, 16, 16), + (16, 16, 32), + (32, 32, 8), + (32, 32, 16), + # Invalid tiles to test auto-correction + (48, 48, 24), + (64, 64, 32), +] + +# Pipelines and schedulers +PIPELINES = ["compv3", "compv4", "flatmma", "invalid_pipeline"] +SCHEDULERS = ["intrawave", "interwave", "invalid_scheduler"] + +# Architectures +ARCHS = ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1200", "gfx1201"] + + +# ============================================================================= +# TEST FUNCTIONS +# ============================================================================= + + +def generate_random_gemm_config(): + """Generate a random GEMM configuration (may be invalid).""" + dtype = random.choice(DTYPES) + layout = random.choice(LAYOUTS) + tile = random.choice(TILE_SIZES) + wave = random.choice(WAVE_CONFIGS) + warp = random.choice(WARP_TILES) + pipeline = random.choice(PIPELINES) + scheduler = random.choice(SCHEDULERS) + arch = random.choice(ARCHS) + + return { + "name": f"test_{dtype}_{layout}_{tile[0]}x{tile[1]}x{tile[2]}", + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "dtype_acc": "fp32", + "layout": layout, + "tile_m": tile[0], + "tile_n": tile[1], + "tile_k": tile[2], + "wave_m": wave[0], + "wave_n": wave[1], + "wave_k": wave[2], + "warp_m": warp[0], + "warp_n": warp[1], + "warp_k": warp[2], + "pipeline": pipeline, + "scheduler": scheduler, + "arch": arch, + } + + +def generate_random_conv_config(): + """Generate a random Conv configuration (may be invalid).""" + dtype = random.choice(["fp16", "bf16"]) + tile_k = random.choice([64, 128, 256]) + tile_c = random.choice([64, 128, 256]) + wave = random.choice(WAVE_CONFIGS) + warp = random.choice(WARP_TILES) + pipeline = random.choice(["compv3", "compv4"]) + scheduler = random.choice(["intrawave"]) + arch = random.choice(ARCHS) + + return { + "name": f"test_conv_{dtype}_{tile_k}x{tile_c}", + "dtype": dtype, + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": tile_k, + "tile_c": tile_c, + "wave_m": wave[0], + "wave_n": wave[1], + "wave_k": wave[2], + "warp_m": warp[0], + "warp_n": warp[1], + "warp_k": warp[2], + "pipeline": pipeline, + "scheduler": scheduler, + "arch": arch, + } + + +def test_gemm_validation(config, verbose=False): + """Test GEMM validation and auto-correction.""" + arch = config.get("arch", "gfx942") + is_valid, error_msg = validate_kernel_config(config, arch) + + result = { + "config": config, + "is_valid": is_valid, + "error_msg": error_msg, + "expanded": [], + "auto_corrected": None, + } + + if not is_valid: + # Try wildcard expansion + wildcard_config = config.copy() + wildcard_config["wave_m"] = -1 + wildcard_config["wave_n"] = -1 + wildcard_config["warp_m"] = -1 + wildcard_config["warp_n"] = -1 + wildcard_config["pipeline"] = "*" + wildcard_config["scheduler"] = "*" + + expanded = expand_declaration_with_arch_filter(wildcard_config, arch) + result["expanded"] = expanded + + if verbose: + print(f"\n Config: {config['name']}") + print(f" Valid: {is_valid}") + if not is_valid: + print(f" Error: {error_msg[:80]}...") + print(f" Expanded to: {len(result['expanded'])} configurations") + + return result + + +def test_python_autocorrect(verbose=False): + """Test Python auto-correction for GEMM KernelConfig.""" + print("\n" + "=" * 70) + print(" PYTHON AUTO-CORRECTION TEST (GEMM KernelConfig)") + print("=" * 70) + + test_cases = [ + # Valid config + { + "name": "valid_fp16", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "dtype_acc": "fp32", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "gfx_arch": "gfx942", + }, + # Invalid wave config + { + "name": "invalid_wave", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "dtype_acc": "fp32", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 1, + "wave_n": 1, + "wave_k": 1, # Invalid for gfx942 + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "gfx_arch": "gfx942", + }, + # Invalid scheduler + { + "name": "invalid_scheduler", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "dtype_acc": "fp32", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "interwave", # May not be valid for all archs + "gfx_arch": "gfx942", + }, + ] + + results = {"passed": 0, "failed": 0, "details": []} + + for tc in test_cases: + try: + config = KernelConfig() + config.dtype_a = tc["dtype_a"] + config.dtype_b = tc["dtype_b"] + config.dtype_c = tc["dtype_c"] + config.dtype_acc = tc["dtype_acc"] + config.tile_m = tc["tile_m"] + config.tile_n = tc["tile_n"] + config.tile_k = tc["tile_k"] + config.wave_m = tc["wave_m"] + config.wave_n = tc["wave_n"] + config.wave_k = tc["wave_k"] + config.warp_m = tc["warp_m"] + config.warp_n = tc["warp_n"] + config.warp_k = tc["warp_k"] + config.pipeline = tc["pipeline"] + config.scheduler = tc["scheduler"] + config.gfx_arch = tc["gfx_arch"] + + corrected, was_modified, corrections = auto_correct_kernel_config( + config, verbose=verbose + ) + + results["passed"] += 1 + results["details"].append( + { + "name": tc["name"], + "status": "PASS", + "was_modified": was_modified, + "corrections": corrections, + } + ) + + if verbose: + print(f"\n {tc['name']}: PASS") + if was_modified: + print(f" Modified: {len(corrections)} correction(s)") + for c in corrections: + print(f" • {c}") + + except Exception as e: + results["failed"] += 1 + results["details"].append( + {"name": tc["name"], "status": "FAIL", "error": str(e)} + ) + if verbose: + print(f"\n {tc['name']}: FAIL - {e}") + + print(f"\n Summary: {results['passed']} passed, {results['failed']} failed") + return results + + +def run_stress_test(arch, num_samples, verbose): + """Run the full stress test.""" + print("\n" + "=" * 70) + print(" DISPATCHER AUTO-CORRECTION & CODEGEN STRESS TEST") + print("=" * 70) + print(f" Target Architecture: {arch}") + print(f" Number of Samples: {num_samples}") + print("=" * 70) + + # Test 1: GEMM Validation + print("\n" + "-" * 70) + print(" TEST 1: GEMM Validation & Wildcard Expansion") + print("-" * 70) + + gemm_results = {"valid": 0, "invalid": 0, "expanded": 0, "expansion_failed": 0} + + for i in range(num_samples): + config = generate_random_gemm_config() + config["arch"] = arch # Override with target arch + + result = test_gemm_validation(config, verbose) + + if result["is_valid"]: + gemm_results["valid"] += 1 + else: + gemm_results["invalid"] += 1 + if result["expanded"]: + gemm_results["expanded"] += 1 + else: + gemm_results["expansion_failed"] += 1 + + print("\n GEMM Results:") + print(f" Valid configs: {gemm_results['valid']}") + print(f" Invalid configs: {gemm_results['invalid']}") + print(f" Successfully expanded: {gemm_results['expanded']}") + print(f" Expansion failed: {gemm_results['expansion_failed']}") + + # Test 2: Conv Validation + print("\n" + "-" * 70) + print(" TEST 2: Conv Validation & Wildcard Expansion") + print("-" * 70) + + conv_results = {"valid": 0, "invalid": 0, "expanded": 0, "expansion_failed": 0} + + for i in range(num_samples): + config = generate_random_conv_config() + config["arch"] = arch # Override with target arch + + is_valid, error_msg = validate_conv_kernel_config(config, arch) + + if is_valid: + conv_results["valid"] += 1 + else: + conv_results["invalid"] += 1 + # Try wildcard expansion + wildcard_config = config.copy() + wildcard_config["wave_m"] = -1 + wildcard_config["wave_n"] = -1 + wildcard_config["warp_m"] = -1 + wildcard_config["warp_n"] = -1 + + expanded = expand_conv_declaration_with_arch_filter(wildcard_config, arch) + if expanded: + conv_results["expanded"] += 1 + else: + conv_results["expansion_failed"] += 1 + + print("\n Conv Results:") + print(f" Valid configs: {conv_results['valid']}") + print(f" Invalid configs: {conv_results['invalid']}") + print(f" Successfully expanded: {conv_results['expanded']}") + print(f" Expansion failed: {conv_results['expansion_failed']}") + + # Test 3: Python Auto-Correction + print("\n" + "-" * 70) + print(" TEST 3: Python Auto-Correction (KernelConfig)") + print("-" * 70) + + py_results = test_python_autocorrect(verbose) + + # Test 4: Architecture-specific tests + print("\n" + "-" * 70) + print(" TEST 4: Architecture-Specific Validation") + print("-" * 70) + + arch_test_configs = [ + # fp16 should work on all archs + {"dtype": "fp16", "expected_archs": ARCHS}, + # bf16 works on all archs that have bf16_bf16_fp32 in warp_tile_combos + { + "dtype": "bf16", + "expected_archs": [ + "gfx908", + "gfx90a", + "gfx942", + "gfx950", + "gfx1100", + "gfx1200", + "gfx1201", + ], + }, + # fp8 works on archs that have fp8_fp8_fp32 in warp_tile_combos + { + "dtype": "fp8", + "expected_archs": ["gfx90a", "gfx942", "gfx950", "gfx1200", "gfx1201"], + }, + ] + + for test in arch_test_configs: + dtype = test["dtype"] + print(f"\n Testing {dtype}:") + + for test_arch in ARCHS: + config = { + "name": f"arch_test_{dtype}_{test_arch}", + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "dtype_acc": "fp32", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": -1, # Wildcard + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": -1, + "pipeline": "*", + "scheduler": "*", + "arch": test_arch, + } + + expanded = expand_declaration_with_arch_filter(config, test_arch) + status = "✓" if expanded else "✗" + expected = test_arch in test["expected_archs"] + match = "OK" if (bool(expanded) == expected) else "MISMATCH" + + if verbose or match == "MISMATCH": + print(f" {test_arch}: {status} ({len(expanded)} configs) [{match}]") + + # Summary + print("\n" + "=" * 70) + print(" STRESS TEST SUMMARY") + print("=" * 70) + print( + f" GEMM: {gemm_results['valid'] + gemm_results['expanded']}/{num_samples} handled" + ) + print( + f" Conv: {conv_results['valid'] + conv_results['expanded']}/{num_samples} handled" + ) + print( + f" Python Auto-Correct: {py_results['passed']}/{py_results['passed'] + py_results['failed']} passed" + ) + + total_success = ( + gemm_results["valid"] + + gemm_results["expanded"] + + conv_results["valid"] + + conv_results["expanded"] + + py_results["passed"] + ) + total_tests = num_samples * 2 + py_results["passed"] + py_results["failed"] + + print(f"\n Overall: {total_success}/{total_tests} tests handled successfully") + print("=" * 70) + + return ( + gemm_results["expansion_failed"] == 0 and conv_results["expansion_failed"] == 0 + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Stress test auto-correction and codegen" + ) + parser.add_argument( + "--arch", + default="gfx942", + choices=ARCHS, + help="Target GPU architecture (default: gfx942)", + ) + parser.add_argument( + "--samples", + type=int, + default=50, + help="Number of random samples to test (default: 50)", + ) + parser.add_argument( + "--verbose", "-v", action="store_true", help="Show detailed output" + ) + parser.add_argument( + "--seed", type=int, default=None, help="Random seed for reproducibility" + ) + + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + + success = run_stress_test(args.arch, args.samples, args.verbose) + + return 0 if success else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/src/dispatcher.cpp b/dispatcher/src/dispatcher.cpp new file mode 100644 index 0000000000..fdb400921e --- /dev/null +++ b/dispatcher/src/dispatcher.cpp @@ -0,0 +1,152 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +Dispatcher::Dispatcher(Registry* registry) + : registry_(registry ? registry : &Registry::instance()), + heuristic_(nullptr), + strategy_(SelectionStrategy::FirstFit) +{ +} + +void Dispatcher::set_heuristic(HeuristicFunction heuristic) +{ + heuristic_ = heuristic; + if(heuristic_) + { + strategy_ = SelectionStrategy::Heuristic; + } +} + +void Dispatcher::set_strategy(SelectionStrategy strategy) { strategy_ = strategy; } + +KernelInstancePtr Dispatcher::select_kernel(const Problem& problem) const +{ + if(!problem.is_valid()) + { + return nullptr; + } + + switch(strategy_) + { + case SelectionStrategy::FirstFit: return select_first_fit(problem); + case SelectionStrategy::Heuristic: return select_heuristic(problem); + default: return nullptr; + } +} + +float Dispatcher::run( + const void* a_ptr, const void* b_ptr, void* c_ptr, const Problem& problem, void* stream) const +{ + return run_fused(a_ptr, b_ptr, c_ptr, nullptr, problem, stream); +} + +float Dispatcher::run_fused(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const +{ + auto kernel = select_kernel(problem); + if(!kernel) + { + std::ostringstream oss; + oss << "No suitable kernel found for problem: M=" << problem.M << " N=" << problem.N + << " K=" << problem.K; + throw std::runtime_error(oss.str()); + } + + return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); +} + +float Dispatcher::run_explicit(const std::string& kernel_id, + const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const +{ + auto kernel = registry_->lookup(kernel_id); + if(!kernel) + { + throw std::runtime_error("Kernel not found: " + kernel_id); + } + + if(!kernel->supports(problem)) + { + std::ostringstream oss; + oss << "Kernel " << kernel_id << " does not support problem: M=" << problem.M + << " N=" << problem.N << " K=" << problem.K; + throw std::runtime_error(oss.str()); + } + + return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); +} + +bool Dispatcher::validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const +{ + auto kernel = select_kernel(problem); + if(!kernel) + { + return false; + } + + return kernel->validate(a_ptr, b_ptr, c_ptr, d_ptrs, problem, tolerance); +} + +KernelInstancePtr Dispatcher::select_first_fit(const Problem& problem) const +{ + auto all_kernels = registry_->get_all(); + + for(const auto& kernel : all_kernels) + { + if(kernel->supports(problem)) + { + return kernel; + } + } + + return nullptr; +} + +KernelInstancePtr Dispatcher::select_heuristic(const Problem& problem) const +{ + if(!heuristic_) + { + // Fall back to first-fit if no heuristic available + return select_first_fit(problem); + } + + // Get ranked list of kernel identifiers from heuristic + auto candidates = heuristic_(problem); + + // Try each candidate in order + for(const auto& kernel_id : candidates) + { + auto kernel = registry_->lookup(kernel_id); + if(kernel && kernel->supports(problem)) + { + return kernel; + } + } + + // If no heuristic candidate works, fall back to first-fit + return select_first_fit(problem); +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/src/registry.cpp b/dispatcher/src/registry.cpp new file mode 100644 index 0000000000..0d83afd613 --- /dev/null +++ b/dispatcher/src/registry.cpp @@ -0,0 +1,288 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include + +namespace ck_tile { +namespace dispatcher { + +Registry::Registry() + : name_("default"), + auto_export_enabled_(false), + auto_export_include_statistics_(true), + auto_export_on_every_registration_(true) +{ +} + +Registry::~Registry() +{ + // Perform auto-export on destruction if enabled (regardless of export_on_every_registration + // setting) + if(auto_export_enabled_) + { + perform_auto_export(); + } +} + +Registry::Registry(Registry&& other) noexcept + : mutex_() // mutex is not movable, create new one + , + kernels_(std::move(other.kernels_)), + name_(std::move(other.name_)), + auto_export_enabled_(other.auto_export_enabled_), + auto_export_filename_(std::move(other.auto_export_filename_)), + auto_export_include_statistics_(other.auto_export_include_statistics_), + auto_export_on_every_registration_(other.auto_export_on_every_registration_) +{ + // Disable auto-export on the moved-from object to prevent double export + other.auto_export_enabled_ = false; +} + +Registry& Registry::operator=(Registry&& other) noexcept +{ + if(this != &other) + { + std::lock_guard lock(mutex_); + std::lock_guard other_lock(other.mutex_); + + kernels_ = std::move(other.kernels_); + name_ = std::move(other.name_); + auto_export_enabled_ = other.auto_export_enabled_; + auto_export_filename_ = std::move(other.auto_export_filename_); + auto_export_include_statistics_ = other.auto_export_include_statistics_; + auto_export_on_every_registration_ = other.auto_export_on_every_registration_; + + // Disable auto-export on the moved-from object + other.auto_export_enabled_ = false; + } + return *this; +} + +bool Registry::register_kernel(KernelInstancePtr instance, Priority priority) +{ + if(!instance) + { + return false; + } + + const std::string identifier = instance->get_key().encode_identifier(); + + bool registered = false; + { + std::lock_guard lock(mutex_); + + auto it = kernels_.find(identifier); + if(it != kernels_.end()) + { + // Kernel with this identifier already exists + // Only replace if new priority is higher + if(priority > it->second.priority) + { + it->second.instance = instance; + it->second.priority = priority; + registered = true; + } + } + else + { + // New kernel, insert it + kernels_[identifier] = RegistryEntry{instance, priority}; + registered = true; + } + } + + // Perform auto-export if enabled and configured to export on every registration + if(registered && auto_export_enabled_ && auto_export_on_every_registration_) + { + perform_auto_export(); + } + + return registered; +} + +KernelInstancePtr Registry::lookup(const std::string& identifier) const +{ + std::lock_guard lock(mutex_); + + auto it = kernels_.find(identifier); + if(it != kernels_.end()) + { + return it->second.instance; + } + + return nullptr; +} + +KernelInstancePtr Registry::lookup(const KernelKey& key) const +{ + return lookup(key.encode_identifier()); +} + +std::vector Registry::get_all() const +{ + std::lock_guard lock(mutex_); + + std::vector result; + result.reserve(kernels_.size()); + + for(const auto& pair : kernels_) + { + result.push_back(pair.second.instance); + } + + return result; +} + +std::vector +Registry::filter(std::function predicate) const +{ + std::lock_guard lock(mutex_); + + std::vector result; + + for(const auto& pair : kernels_) + { + if(predicate(*pair.second.instance)) + { + result.push_back(pair.second.instance); + } + } + + return result; +} + +std::size_t Registry::size() const +{ + std::lock_guard lock(mutex_); + return kernels_.size(); +} + +bool Registry::empty() const +{ + std::lock_guard lock(mutex_); + return kernels_.empty(); +} + +void Registry::clear() +{ + std::lock_guard lock(mutex_); + kernels_.clear(); +} + +const std::string& Registry::get_name() const +{ + std::lock_guard lock(mutex_); + return name_; +} + +void Registry::set_name(const std::string& name) +{ + std::lock_guard lock(mutex_); + name_ = name; +} + +Registry& Registry::instance() +{ + static Registry global_registry; + return global_registry; +} + +std::string Registry::export_json(bool include_statistics) const +{ + return export_registry_json(*this, include_statistics); +} + +bool Registry::export_json_to_file(const std::string& filename, bool include_statistics) const +{ + return export_registry_json_to_file(*this, filename, include_statistics); +} + +void Registry::enable_auto_export(const std::string& filename, + bool include_statistics, + bool export_on_every_registration) +{ + std::lock_guard lock(mutex_); + auto_export_enabled_ = true; + auto_export_filename_ = filename; + auto_export_include_statistics_ = include_statistics; + auto_export_on_every_registration_ = export_on_every_registration; +} + +void Registry::disable_auto_export() +{ + std::lock_guard lock(mutex_); + auto_export_enabled_ = false; +} + +bool Registry::is_auto_export_enabled() const +{ + std::lock_guard lock(mutex_); + return auto_export_enabled_; +} + +void Registry::perform_auto_export() +{ + // Don't hold the lock during file I/O + std::string filename; + bool include_stats; + + { + std::lock_guard lock(mutex_); + if(!auto_export_enabled_) + { + return; + } + filename = auto_export_filename_; + include_stats = auto_export_include_statistics_; + } + + // Export without holding the lock + export_json_to_file(filename, include_stats); +} + +std::size_t Registry::merge_from(const Registry& other, Priority priority) +{ + auto other_kernels = other.get_all(); + std::size_t merged_count = 0; + + for(const auto& kernel : other_kernels) + { + if(register_kernel(kernel, priority)) + { + merged_count++; + } + } + + return merged_count; +} + +std::size_t Registry::filter_by_arch(const std::string& gpu_arch) +{ + ArchFilter filter(gpu_arch); + std::vector to_remove; + + { + std::lock_guard lock(mutex_); + + for(const auto& pair : kernels_) + { + if(!filter.is_valid(pair.second.instance->get_key())) + { + to_remove.push_back(pair.first); + } + } + + for(const auto& key : to_remove) + { + kernels_.erase(key); + } + } + + return to_remove.size(); +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/tests/CMakeLists.txt b/dispatcher/tests/CMakeLists.txt new file mode 100644 index 0000000000..6c20c18c95 --- /dev/null +++ b/dispatcher/tests/CMakeLists.txt @@ -0,0 +1,343 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ============================================================================= +# CK Tile Dispatcher Tests (C++ and Python) +# ============================================================================= + +cmake_minimum_required(VERSION 3.16) + +# Find Python +find_package(Python3 COMPONENTS Interpreter REQUIRED) + +# ============================================================================= +# Python Tests +# ============================================================================= + +# Auto-correction and validation stress test +add_test( + NAME dispatcher_test_autocorrect + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_autocorrect.py + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_test_autocorrect PROPERTIES + LABELS "dispatcher;python;validation" + TIMEOUT 120 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Verbose version of the test +add_test( + NAME dispatcher_test_autocorrect_verbose + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_autocorrect.py -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_test_autocorrect_verbose PROPERTIES + LABELS "dispatcher;python;validation;verbose" + TIMEOUT 180 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Individual Python Test Categories +add_test( + NAME dispatcher_test_gemm_validation + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestGemmValidation test_autocorrect.TestGemmExpansion -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_gemm_validation PROPERTIES + LABELS "dispatcher;python;gemm;validation" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_python_autocorrect + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestPythonAutoCorrect -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_python_autocorrect PROPERTIES + LABELS "dispatcher;python;autocorrect" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_stress + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestStressRandom -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_stress PROPERTIES + LABELS "dispatcher;python;stress" + TIMEOUT 120 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_arch_support + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestArchitectureSupport -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_arch_support PROPERTIES + LABELS "dispatcher;python;arch" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Stress Test Script +add_test( + NAME dispatcher_stress_test + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/stress_test_autocorrect.py + --arch gfx942 --samples 30 --seed 42 + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_stress_test PROPERTIES + LABELS "dispatcher;python;stress;integration" + TIMEOUT 180 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# ============================================================================= +# Integration Tests (mimic examples) +# ============================================================================= + +# Full integration test suite +add_test( + NAME dispatcher_integration_tests + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_integration_tests PROPERTIES + LABELS "dispatcher;python;integration;examples" + TIMEOUT 600 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Quick integration test (utilities only) +add_test( + NAME dispatcher_integration_quick + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py::TestUtilityImports -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_integration_quick PROPERTIES + LABELS "dispatcher;python;integration;quick" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# GEMM examples integration +add_test( + NAME dispatcher_integration_gemm + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py::TestGemmPythonExamples -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_integration_gemm PROPERTIES + LABELS "dispatcher;python;integration;gemm" + TIMEOUT 300 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# ============================================================================= +# C++ Tests (Google Test) +# ============================================================================= + +# Include Google Test setup +if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../../cmake/gtest.cmake") + include(${CMAKE_CURRENT_SOURCE_DIR}/../../cmake/gtest.cmake) +else() + include(gtest) +endif() + +# Mock kernel instance for testing (shared across tests) +add_library(dispatcher_test_utils STATIC + test_mock_kernel.cpp +) + +target_include_directories(dispatcher_test_utils PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CMAKE_CURRENT_SOURCE_DIR}/../../include +) + +target_link_libraries(dispatcher_test_utils PRIVATE + ck_tile_dispatcher +) + +# Test executables using Google Test +set(TEST_SOURCES + # Core unit tests + test_kernel_key.cpp + test_problem.cpp + test_registry.cpp + test_dispatcher.cpp + test_tile_backend.cpp + + # Extended unit tests (more comprehensive coverage) + test_kernel_key_extended.cpp + test_problem_extended.cpp + test_registry_extended.cpp + test_dispatcher_extended.cpp + + # Regression tests (known issues and edge cases) + test_regression.cpp + + # JSON export tests + test_json_export.cpp +) + +foreach(test_source ${TEST_SOURCES}) + get_filename_component(test_name ${test_source} NAME_WE) + + add_executable(${test_name} ${test_source}) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + dispatcher_test_utils + GTest::gtest_main + ) + + target_compile_options(${test_name} PRIVATE + -Wno-global-constructors + -Wno-undef + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;unit") +endforeach() + +# Standalone integration tests (with their own main()) +set(STANDALONE_TESTS + test_minimal.cpp +) + +foreach(test_source ${STANDALONE_TESTS}) + get_filename_component(test_name ${test_source} NAME_WE) + + add_executable(${test_name} ${test_source}) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + dispatcher_test_utils + ) + + target_compile_options(${test_name} PRIVATE + -Wno-global-constructors + -Wno-undef + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;integration") +endforeach() + +# ============================================================================= +# Real Kernel Tests (requires generated kernels) +# ============================================================================= + +set(KERNEL_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/../generated_kernels") +set(KERNEL_REGISTRATION_HEADER "${KERNEL_OUTPUT_DIR}/dispatcher_wrappers/register_all_kernels.hpp") +set(CODEGEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py") + +option(BUILD_DISPATCHER_REAL_KERNEL_TESTS "Build tests with real GPU kernels" ON) + +if(BUILD_DISPATCHER_REAL_KERNEL_TESTS AND EXISTS "${CODEGEN_SCRIPT}") + message(STATUS "Setting up real kernel test generation") + + add_custom_command( + OUTPUT ${KERNEL_REGISTRATION_HEADER} + COMMAND ${CMAKE_COMMAND} -E make_directory ${KERNEL_OUTPUT_DIR} + COMMAND ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT} + --output-dir ${KERNEL_OUTPUT_DIR} + --datatype fp16 + --layout rcr + --gpu-target gfx942 + --preselected fp16_rcr_essential + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating CK Tile kernels for real kernel tests..." + VERBATIM + ) + + add_custom_target(generate_test_kernels DEPENDS ${KERNEL_REGISTRATION_HEADER}) + + set(SINGLE_KERNEL_HEADER "${KERNEL_OUTPUT_DIR}/gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16.hpp") + + set(REAL_KERNEL_TESTS + test_real_kernel_simple + test_real_kernel_multi_size + test_real_kernel_performance + test_real_kernel_correctness + test_sanity_ck_tile + ) + + if(EXISTS "${SINGLE_KERNEL_HEADER}") + foreach(test_name ${REAL_KERNEL_TESTS}) + add_executable(${test_name} ${test_name}.cpp) + + add_dependencies(${test_name} generate_test_kernels) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + ) + + target_include_directories(${test_name} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${KERNEL_OUTPUT_DIR} + ) + + target_compile_options(${test_name} PRIVATE + -include ${SINGLE_KERNEL_HEADER} + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${test_name} PRIVATE hip::device hip::host) + endif() + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;gpu;kernel") + endforeach() + endif() +endif() + +# ============================================================================= +# Custom Targets +# ============================================================================= + +add_custom_target(run_dispatcher_tests + COMMAND ${CMAKE_CTEST_COMMAND} -L dispatcher --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running all dispatcher tests" +) + +add_custom_target(test_dispatcher_python + COMMAND ${CMAKE_CTEST_COMMAND} -L "dispatcher;python" --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running Python dispatcher tests" +) + +add_custom_target(test_dispatcher_cpp + COMMAND ${CMAKE_CTEST_COMMAND} -L "dispatcher;cpp" --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running C++ dispatcher tests" +) + +# ============================================================================= +# Summary +# ============================================================================= + +message(STATUS "Dispatcher tests configured:") +message(STATUS " Run all: ctest -L dispatcher") +message(STATUS " Run Python: ctest -L 'dispatcher;python' or make test_dispatcher_python") +message(STATUS " Run C++: ctest -L 'dispatcher;cpp' or make test_dispatcher_cpp") +message(STATUS " Run verbose: ctest -R dispatcher_test_autocorrect_verbose") diff --git a/dispatcher/tests/test_autocorrect.py b/dispatcher/tests/test_autocorrect.py new file mode 100644 index 0000000000..0ec3ebda3c --- /dev/null +++ b/dispatcher/tests/test_autocorrect.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Comprehensive Test Suite for Auto-Correction and Validation + +Tests: +1. GEMM validation and wildcard expansion +2. Conv validation and wildcard expansion +3. Python KernelConfig auto-correction +4. Architecture-specific dtype support +5. Edge cases and error handling + +Can be run as: + python3 tests/test_autocorrect.py # Run all tests + python3 tests/test_autocorrect.py -v # Verbose output + python3 tests/test_autocorrect.py TestGemmValidation # Run specific test class + ctest -R test_autocorrect # Via ctest + +Exit codes: + 0 = All tests passed + 1 = Some tests failed +""" + +import sys +import unittest +import random +from pathlib import Path + +# Setup paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) +sys.path.insert(0, str(DISPATCHER_DIR / "scripts")) + +# Import modules under test +from compile_gemm_examples import ( # noqa: E402 + validate_kernel_config, + expand_declaration_with_arch_filter, + is_wildcard_declaration, +) +from compile_conv_examples import ( # noqa: E402 + validate_conv_kernel_config, + expand_conv_declaration_with_arch_filter, + is_conv_wildcard_declaration, +) +from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402 + + +# ============================================================================= +# TEST DATA +# ============================================================================= + +VALID_ARCHS = ["gfx90a", "gfx942", "gfx950"] +VALID_DTYPES = ["fp16", "bf16"] +VALID_LAYOUTS = ["rcr", "rrr"] +VALID_PIPELINES = ["compv3", "compv4"] +VALID_SCHEDULERS = ["intrawave"] + +# Known valid wave configs for gfx942 +VALID_WAVE_CONFIGS_GFX942 = [[1, 4, 1], [2, 2, 1], [4, 1, 1]] + +# Known valid warp tiles for fp16 on gfx942 +VALID_WARP_TILES_FP16_GFX942 = [[16, 16, 16], [16, 16, 32], [32, 32, 8], [32, 32, 16]] + + +# ============================================================================= +# GEMM VALIDATION TESTS +# ============================================================================= + + +class TestGemmValidation(unittest.TestCase): + """Test GEMM kernel validation.""" + + def test_valid_config(self): + """Valid configuration should pass validation.""" + config = { + "name": "test_valid", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertTrue(is_valid, f"Expected valid, got error: {error}") + + def test_invalid_wave_config(self): + """Invalid wave config should fail validation.""" + config = { + "name": "test_invalid_wave", + "dtype_a": "fp16", + "wave_m": 3, # Invalid + "wave_n": 3, # Invalid + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", error.lower()) + + def test_invalid_scheduler(self): + """Invalid scheduler should fail validation.""" + config = { + "name": "test_invalid_scheduler", + "dtype_a": "fp16", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "epilogue": "cshuffle", + "scheduler": "interwave", # Invalid with compv4+cshuffle + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("trait", error.lower()) + + def test_wildcard_skips_validation(self): + """Wildcard declarations should skip validation.""" + config = { + "name": "test_wildcard", + "dtype_a": "fp16", + "wave_m": -1, # Wildcard + "wave_n": -1, # Wildcard + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertTrue(is_wildcard_declaration(config)) + is_valid, _ = validate_kernel_config(config, "gfx942") + self.assertTrue(is_valid) + + def test_unsupported_arch(self): + """Unsupported architecture should fail validation.""" + config = { + "name": "test_bad_arch", + "dtype_a": "fp16", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx_invalid") + self.assertFalse(is_valid) + self.assertIn("unsupported", error.lower()) + + +class TestGemmExpansion(unittest.TestCase): + """Test GEMM wildcard expansion.""" + + def test_wave_expansion(self): + """Wave wildcard should expand to valid configs.""" + config = { + "name": "test_wave_expand", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": -1, # Wildcard + "wave_n": -1, # Wildcard + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "Should expand to at least one config") + + # All expanded configs should be valid + for exp in expanded: + is_valid, error = validate_kernel_config(exp, "gfx942") + self.assertTrue(is_valid, f"Expanded config invalid: {error}") + + def test_full_wildcard_expansion(self): + """Full wildcard should expand to multiple valid configs.""" + config = { + "name": "test_full_wildcard", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater( + len(expanded), 1, "Full wildcard should expand to multiple configs" + ) + + def test_explicit_config_not_expanded(self): + """Explicit (non-wildcard) config should not expand.""" + config = { + "name": "test_explicit", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertEqual(len(expanded), 1, "Explicit config should not expand") + + +# ============================================================================= +# CONV VALIDATION TESTS +# ============================================================================= + + +class TestConvValidation(unittest.TestCase): + """Test Conv kernel validation.""" + + def test_valid_conv_config(self): + """Valid conv configuration should pass validation.""" + config = { + "name": "test_valid_conv", + "dtype": "fp16", + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": 128, + "tile_c": 128, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_conv_kernel_config(config, "gfx942") + self.assertTrue(is_valid, f"Expected valid, got error: {error}") + + def test_invalid_conv_wave(self): + """Invalid wave config should fail conv validation.""" + config = { + "name": "test_invalid_conv_wave", + "dtype": "fp16", + "wave_m": 5, # Invalid + "wave_n": 5, # Invalid + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_conv_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", error.lower()) + + def test_conv_wildcard_detection(self): + """Should correctly detect conv wildcards.""" + wildcard_config = { + "wave_m": -1, + "wave_n": 2, + "warp_m": 32, + "warp_n": 32, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertTrue(is_conv_wildcard_declaration(wildcard_config)) + + explicit_config = { + "wave_m": 2, + "wave_n": 2, + "warp_m": 32, + "warp_n": 32, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertFalse(is_conv_wildcard_declaration(explicit_config)) + + +class TestConvExpansion(unittest.TestCase): + """Test Conv wildcard expansion.""" + + def test_conv_wave_expansion(self): + """Conv wave wildcard should expand to valid configs.""" + config = { + "name": "test_conv_wave_expand", + "dtype": "fp16", + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": 128, + "tile_c": 128, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_conv_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "Should expand to at least one config") + + +# ============================================================================= +# PYTHON AUTO-CORRECTION TESTS +# ============================================================================= + + +class TestPythonAutoCorrect(unittest.TestCase): + """Test Python KernelConfig auto-correction.""" + + def test_autocorrect_invalid_wave(self): + """Auto-correction should fix invalid wave config.""" + config = KernelConfig() + config.dtype_a = "fp16" + config.dtype_b = "fp16" + config.dtype_c = "fp16" + config.dtype_acc = "fp32" + config.layout_a = "row" + config.layout_b = "col" + config.layout_c = "row" + config.tile_m = 128 + config.tile_n = 128 + config.tile_k = 32 + config.wave_m = 1 # May be invalid + config.wave_n = 1 # May be invalid + config.wave_k = 1 + config.warp_m = 32 + config.warp_n = 32 + config.warp_k = 16 + config.pipeline = "compv4" + config.scheduler = "intrawave" + config.gfx_arch = "gfx942" + + corrected, was_modified, corrections = auto_correct_kernel_config( + config, verbose=False + ) + + # Should either be valid or corrected + self.assertIsNotNone(corrected) + if was_modified: + self.assertGreater(len(corrections), 0) + + def test_autocorrect_returns_three_values(self): + """Auto-correction should return (config, was_modified, corrections).""" + config = KernelConfig() + config.dtype_a = "fp16" + config.dtype_b = "fp16" + config.dtype_c = "fp16" + config.dtype_acc = "fp32" + config.layout_a = "row" + config.layout_b = "col" + config.layout_c = "row" + config.tile_m = 128 + config.tile_n = 128 + config.tile_k = 32 + config.wave_m = 2 + config.wave_n = 2 + config.wave_k = 1 + config.warp_m = 32 + config.warp_n = 32 + config.warp_k = 16 + config.pipeline = "compv4" + config.scheduler = "intrawave" + config.gfx_arch = "gfx942" + + result = auto_correct_kernel_config(config, verbose=False) + + self.assertEqual(len(result), 3, "Should return 3 values") + corrected, was_modified, corrections = result + self.assertIsInstance(was_modified, bool) + self.assertIsInstance(corrections, list) + + +# ============================================================================= +# STRESS TESTS +# ============================================================================= + + +class TestStressRandom(unittest.TestCase): + """Stress test with random configurations.""" + + def test_random_gemm_configs(self): + """Random GEMM configs should either validate or expand successfully.""" + random.seed(42) # Reproducible + + dtypes = ["fp16", "bf16"] + layouts = ["rcr", "rrr"] + tiles = [(64, 64, 32), (128, 128, 32), (256, 256, 64)] + waves = [(1, 1, 1), (2, 2, 1), (1, 4, 1), (3, 3, 1)] # Some invalid + warps = [(16, 16, 16), (32, 32, 16), (48, 48, 24)] # Some invalid + pipelines = ["compv3", "compv4", "invalid"] + schedulers = ["intrawave", "interwave"] + + success_count = 0 + total_count = 30 + + for _ in range(total_count): + config = { + "name": "random_test", + "dtype_a": random.choice(dtypes), + "dtype_b": random.choice(dtypes), + "dtype_c": random.choice(dtypes), + "layout": random.choice(layouts), + "tile_m": random.choice(tiles)[0], + "tile_n": random.choice(tiles)[1], + "tile_k": random.choice(tiles)[2], + "wave_m": random.choice(waves)[0], + "wave_n": random.choice(waves)[1], + "wave_k": random.choice(waves)[2], + "warp_m": random.choice(warps)[0], + "warp_n": random.choice(warps)[1], + "warp_k": random.choice(warps)[2], + "pipeline": random.choice(pipelines), + "scheduler": random.choice(schedulers), + } + + is_valid, _ = validate_kernel_config(config, "gfx942") + + if is_valid: + success_count += 1 + else: + # Try wildcard expansion + wildcard = config.copy() + wildcard["wave_m"] = -1 + wildcard["wave_n"] = -1 + wildcard["warp_m"] = -1 + wildcard["warp_n"] = -1 + wildcard["pipeline"] = "*" + wildcard["scheduler"] = "*" + + expanded = expand_declaration_with_arch_filter(wildcard, "gfx942") + if expanded: + success_count += 1 + + # At least 50% should be handleable + self.assertGreater( + success_count / total_count, + 0.5, + f"Only {success_count}/{total_count} configs were handleable", + ) + + def test_random_conv_configs(self): + """Random Conv configs should either validate or expand successfully.""" + random.seed(42) + + dtypes = ["fp16", "bf16"] + tiles = [(64, 64), (128, 128), (256, 256)] + waves = [(2, 2, 1), (1, 4, 1), (3, 3, 1)] + warps = [(16, 16, 16), (32, 32, 16)] + + success_count = 0 + total_count = 20 + + for _ in range(total_count): + config = { + "name": "random_conv_test", + "dtype": random.choice(dtypes), + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": random.choice(tiles)[0], + "tile_c": random.choice(tiles)[1], + "wave_m": random.choice(waves)[0], + "wave_n": random.choice(waves)[1], + "wave_k": random.choice(waves)[2], + "warp_m": random.choice(warps)[0], + "warp_n": random.choice(warps)[1], + "warp_k": random.choice(warps)[2], + "pipeline": "compv4", + "scheduler": "intrawave", + } + + is_valid, _ = validate_conv_kernel_config(config, "gfx942") + + if is_valid: + success_count += 1 + else: + # Try wildcard expansion + wildcard = config.copy() + wildcard["wave_m"] = -1 + wildcard["wave_n"] = -1 + wildcard["warp_m"] = -1 + wildcard["warp_n"] = -1 + + expanded = expand_conv_declaration_with_arch_filter(wildcard, "gfx942") + if expanded: + success_count += 1 + + self.assertGreater( + success_count / total_count, + 0.5, + f"Only {success_count}/{total_count} conv configs were handleable", + ) + + +# ============================================================================= +# ARCHITECTURE TESTS +# ============================================================================= + + +class TestArchitectureSupport(unittest.TestCase): + """Test architecture-specific support.""" + + def test_gfx942_fp16_support(self): + """gfx942 should support fp16.""" + config = { + "dtype_a": "fp16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "gfx942 should support fp16") + + def test_gfx942_bf16_support(self): + """gfx942 should support bf16.""" + config = { + "dtype_a": "bf16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "gfx942 should support bf16") + + def test_gfx90a_support(self): + """gfx90a should support fp16.""" + config = { + "dtype_a": "fp16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx90a") + self.assertGreater(len(expanded), 0, "gfx90a should support fp16") + + +# ============================================================================= +# MAIN +# ============================================================================= + + +def main(): + """Run tests.""" + # Parse args for verbosity + verbosity = 2 if "-v" in sys.argv or "--verbose" in sys.argv else 1 + + # Create test suite + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add all test classes + suite.addTests(loader.loadTestsFromTestCase(TestGemmValidation)) + suite.addTests(loader.loadTestsFromTestCase(TestGemmExpansion)) + suite.addTests(loader.loadTestsFromTestCase(TestConvValidation)) + suite.addTests(loader.loadTestsFromTestCase(TestConvExpansion)) + suite.addTests(loader.loadTestsFromTestCase(TestPythonAutoCorrect)) + suite.addTests(loader.loadTestsFromTestCase(TestStressRandom)) + suite.addTests(loader.loadTestsFromTestCase(TestArchitectureSupport)) + + # Run tests + runner = unittest.TextTestRunner(verbosity=verbosity) + result = runner.run(suite) + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/tests/test_dispatcher.cpp b/dispatcher/tests/test_dispatcher.cpp new file mode 100644 index 0000000000..1e3893756c --- /dev/null +++ b/dispatcher/tests/test_dispatcher.cpp @@ -0,0 +1,296 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for Dispatcher using Google Test + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +class DispatcherTest : public ::testing::Test +{ + protected: + void SetUp() override + { + // Clear registry before each test + Registry::instance().clear(); + } + + void TearDown() override + { + // Clean up after each test + Registry::instance().clear(); + } +}; + +TEST_F(DispatcherTest, SelectKernelFirstFit) +{ + Dispatcher dispatcher; + + // Register kernels + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + // Select kernel for valid problem + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + // Should select a kernel that supports the problem + // (order is not guaranteed, so just verify one is selected) + EXPECT_TRUE(selected->get_name() == "kernel1" || selected->get_name() == "kernel2"); + EXPECT_TRUE(selected->supports(problem)); +} + +TEST_F(DispatcherTest, SelectKernelInvalidProblem) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + // Invalid problem + Problem invalid_problem(0, 0, 0); + auto selected = dispatcher.select_kernel(invalid_problem); + + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherTest, SelectKernelNoMatch) +{ + Dispatcher dispatcher; + + // Register kernel that doesn't support the problem + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1", false); + Registry::instance().register_kernel(kernel); + + // Problem with dimensions not divisible by tile size + Problem problem(100, 100, 100); // Not divisible by 256 + auto selected = dispatcher.select_kernel(problem); + + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherTest, SelectKernelHeuristic) +{ + Dispatcher dispatcher; + + // Register kernels + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + // Set heuristic that prefers kernel2 + dispatcher.set_heuristic([](const Problem&) { + std::vector candidates; + auto key2 = make_test_key(128); + candidates.push_back(key2.encode_identifier()); + auto key1 = make_test_key(256); + candidates.push_back(key1.encode_identifier()); + return candidates; + }); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel2"); +} + +TEST_F(DispatcherTest, SelectKernelHeuristicFallback) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + // Set heuristic that returns non-existent kernel + dispatcher.set_heuristic( + [](const Problem&) { return std::vector{"nonexistent_kernel"}; }); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to first-fit + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel1"); +} + +TEST_F(DispatcherTest, RunBasic) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + // Mock pointers (not actually used) + float a[1], b[1], c[1]; + + float time_ms = dispatcher.run(a, b, c, problem); + + EXPECT_GT(time_ms, 0.0f); + EXPECT_EQ(kernel->get_execution_count(), 1); +} + +TEST_F(DispatcherTest, RunNoKernel) +{ + Dispatcher dispatcher; + + // No kernels registered + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run(a, b, c, problem), std::runtime_error); +} + +TEST_F(DispatcherTest, RunExplicit) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + std::string kernel_id = key.encode_identifier(); + + float a[1], b[1], c[1]; + + float time_ms = dispatcher.run_explicit(kernel_id, a, b, c, nullptr, problem); + + EXPECT_GT(time_ms, 0.0f); + EXPECT_EQ(kernel->get_execution_count(), 1); +} + +TEST_F(DispatcherTest, RunExplicitNotFound) +{ + Dispatcher dispatcher; + + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run_explicit("nonexistent", a, b, c, nullptr, problem), + std::runtime_error); +} + +TEST_F(DispatcherTest, RunExplicitNotSupported) +{ + Dispatcher dispatcher; + + // Register kernel that doesn't support the problem + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1", false); + Registry::instance().register_kernel(kernel); + + Problem problem(100, 100, 100); // Not divisible by 256 + std::string kernel_id = key.encode_identifier(); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run_explicit(kernel_id, a, b, c, nullptr, problem), + std::runtime_error); +} + +TEST_F(DispatcherTest, Validate) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + bool valid = dispatcher.validate(a, b, c, nullptr, problem); + + EXPECT_TRUE(valid); +} + +TEST_F(DispatcherTest, ValidateNoKernel) +{ + Dispatcher dispatcher; + + // No kernels registered + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + bool valid = dispatcher.validate(a, b, c, nullptr, problem); + + EXPECT_FALSE(valid); +} + +TEST_F(DispatcherTest, StrategySelection) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + // Test FirstFit strategy + dispatcher.set_strategy(Dispatcher::SelectionStrategy::FirstFit); + auto selected1 = dispatcher.select_kernel(problem); + ASSERT_NE(selected1, nullptr); + + // Test Heuristic strategy (without heuristic function - should fallback) + dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic); + auto selected2 = dispatcher.select_kernel(problem); + ASSERT_NE(selected2, nullptr); +} + +TEST_F(DispatcherTest, CustomRegistry) +{ + // Create custom registry instance (not singleton) + // Note: This requires Registry to allow non-singleton instances + // For now, we'll test with a separate registry instance + // In practice, custom registry would be created differently + + // Since Registry is singleton-only, we'll test that dispatcher + // can work with the singleton registry + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + registry.register_kernel(kernel); + + // Dispatcher defaults to singleton registry + Dispatcher dispatcher; + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel1"); +} diff --git a/dispatcher/tests/test_dispatcher_extended.cpp b/dispatcher/tests/test_dispatcher_extended.cpp new file mode 100644 index 0000000000..e8d7e4b5d1 --- /dev/null +++ b/dispatcher/tests/test_dispatcher_extended.cpp @@ -0,0 +1,499 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Dispatcher - covers selection strategies, heuristics, edge cases + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; +using SelectionStrategy = Dispatcher::SelectionStrategy; + +// ============================================================================= +// Basic Dispatcher Tests +// ============================================================================= + +class DispatcherBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherBasicTest, DefaultConstruction) +{ + Dispatcher dispatcher; + // Should not crash + SUCCEED(); +} + +TEST_F(DispatcherBasicTest, SelectKernelEmpty) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + auto kernel = dispatcher.select_kernel(problem); + EXPECT_EQ(kernel, nullptr); +} + +TEST_F(DispatcherBasicTest, SelectKernelSingle) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "test_kernel"); +} + +TEST_F(DispatcherBasicTest, SelectKernelMultiple) +{ + // Register multiple kernels + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + // Should select one of the registered kernels + EXPECT_TRUE(selected->get_name() == "kernel_128" || selected->get_name() == "kernel_256" || + selected->get_name() == "kernel_512"); +} + +// ============================================================================= +// Selection Strategy Tests +// ============================================================================= + +class SelectionStrategyTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register kernels with different tile sizes + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(SelectionStrategyTest, FirstFitStrategy) +{ + Dispatcher dispatcher; + dispatcher.set_strategy(SelectionStrategy::FirstFit); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + // FirstFit returns first matching kernel +} + +TEST_F(SelectionStrategyTest, HeuristicStrategy) +{ + Dispatcher dispatcher; + + // Set heuristic that prefers larger tiles for large problems + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + if(p.M >= 1024 && p.N >= 1024) + { + // For large problems, prefer 512 tile + auto key = make_test_key(512); + return {key.encode_identifier()}; + } + // For small problems, prefer 128 tile + auto key = make_test_key(128); + return {key.encode_identifier()}; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + // Large problem should get 512 tile + Problem large_problem(2048, 2048, 2048); + auto selected = dispatcher.select_kernel(large_problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_512"); + + // Small problem should get 128 tile + Problem small_problem(256, 256, 256); + selected = dispatcher.select_kernel(small_problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_128"); +} + +TEST_F(SelectionStrategyTest, HeuristicWithFallback) +{ + Dispatcher dispatcher; + + // Heuristic returns non-existent kernel first, then valid one + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + auto key = make_test_key(256); + return {"nonexistent_kernel", key.encode_identifier()}; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_256"); +} + +TEST_F(SelectionStrategyTest, SwitchBetweenStrategies) +{ + Dispatcher dispatcher; + + // Start with FirstFit + dispatcher.set_strategy(SelectionStrategy::FirstFit); + + Problem problem(1024, 1024, 1024); + auto selected1 = dispatcher.select_kernel(problem); + ASSERT_NE(selected1, nullptr); + + // Switch to Heuristic + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + auto key = make_test_key(256); + return {key.encode_identifier()}; + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + auto selected2 = dispatcher.select_kernel(problem); + ASSERT_NE(selected2, nullptr); +} + +// ============================================================================= +// Heuristic Function Tests +// ============================================================================= + +class HeuristicTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + for(int tile : {64, 128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(HeuristicTest, SizeBasedHeuristic) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + std::vector candidates; + + // Problem-size based selection + int size = p.M * p.N * p.K; + + if(size >= 1024 * 1024 * 1024) + { + candidates.push_back(make_test_key(512).encode_identifier()); + candidates.push_back(make_test_key(256).encode_identifier()); + } + else if(size >= 256 * 256 * 256) + { + candidates.push_back(make_test_key(256).encode_identifier()); + candidates.push_back(make_test_key(128).encode_identifier()); + } + else + { + candidates.push_back(make_test_key(64).encode_identifier()); + candidates.push_back(make_test_key(128).encode_identifier()); + } + + return candidates; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + // Large problem + auto selected = dispatcher.select_kernel(Problem(1024, 1024, 1024)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_512"); + + // Medium problem + selected = dispatcher.select_kernel(Problem(256, 256, 256)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_256"); + + // Small problem + selected = dispatcher.select_kernel(Problem(64, 64, 64)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_64"); +} + +TEST_F(HeuristicTest, EmptyHeuristicFallsBackToFirstFit) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {}; // Empty list + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to FirstFit + ASSERT_NE(selected, nullptr); +} + +TEST_F(HeuristicTest, InvalidHeuristicFallsBackToFirstFit) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {"invalid_kernel_1", "invalid_kernel_2"}; // All invalid + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to FirstFit + ASSERT_NE(selected, nullptr); +} + +// ============================================================================= +// Dispatcher with Custom Registry Tests +// ============================================================================= + +class DispatcherCustomRegistryTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherCustomRegistryTest, UseCustomRegistry) +{ + Registry custom_registry; + custom_registry.set_name("custom"); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "custom_kernel"); + custom_registry.register_kernel(kernel); + + Dispatcher dispatcher(&custom_registry); + Problem problem(1024, 1024, 1024); + + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "custom_kernel"); +} + +TEST_F(DispatcherCustomRegistryTest, CustomRegistryIsolation) +{ + Registry custom_registry; + + auto key_custom = make_test_key(256); + auto key_global = make_test_key(512); + + custom_registry.register_kernel( + std::make_shared(key_custom, "custom_kernel")); + Registry::instance().register_kernel( + std::make_shared(key_global, "global_kernel")); + + Dispatcher custom_dispatcher(&custom_registry); + Dispatcher global_dispatcher; + + Problem problem(1024, 1024, 1024); + + auto custom_selected = custom_dispatcher.select_kernel(problem); + auto global_selected = global_dispatcher.select_kernel(problem); + + ASSERT_NE(custom_selected, nullptr); + ASSERT_NE(global_selected, nullptr); + + EXPECT_EQ(custom_selected->get_name(), "custom_kernel"); + EXPECT_EQ(global_selected->get_name(), "global_kernel"); +} + +// ============================================================================= +// Edge Cases Tests +// ============================================================================= + +class DispatcherEdgeCasesTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherEdgeCasesTest, InvalidProblem) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + + // Zero dimensions + Problem invalid(0, 1024, 1024); + EXPECT_FALSE(invalid.is_valid()); + + // The dispatcher should still attempt selection + // (validation is up to the kernel's supports() method) +} + +TEST_F(DispatcherEdgeCasesTest, KernelDoesNotSupportProblem) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "selective_kernel", false); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + + // Problem not divisible by tile size - kernel doesn't support it + Problem problem(1000, 1000, 1000); // Not divisible by 256 + + auto selected = dispatcher.select_kernel(problem); + // Should return nullptr since kernel doesn't support this problem + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherEdgeCasesTest, MultipleSelectionsConsistent) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Multiple selections should return the same kernel + auto selected1 = dispatcher.select_kernel(problem); + auto selected2 = dispatcher.select_kernel(problem); + auto selected3 = dispatcher.select_kernel(problem); + + ASSERT_NE(selected1, nullptr); + EXPECT_EQ(selected1, selected2); + EXPECT_EQ(selected2, selected3); +} + +// ============================================================================= +// Validate Method Tests +// ============================================================================= + +class DispatcherValidateTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + kernel_ = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel_); + } + + void TearDown() override { Registry::instance().clear(); } + + std::shared_ptr kernel_; +}; + +TEST_F(DispatcherValidateTest, ValidateWithMockKernel) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // MockKernelInstance always validates successfully + bool valid = dispatcher.validate(nullptr, nullptr, nullptr, nullptr, problem); + + // This depends on implementation - mock returns true + // Real validation would need actual data +} + +// ============================================================================= +// Run Method Tests (with mock) +// ============================================================================= + +class DispatcherRunTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + kernel_ = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel_); + } + + void TearDown() override { Registry::instance().clear(); } + + std::shared_ptr kernel_; +}; + +TEST_F(DispatcherRunTest, RunWithMockKernel) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Mock run (with null pointers - mock doesn't use them) + float time = dispatcher.run(nullptr, nullptr, nullptr, problem); + + // Mock kernel returns 1.0f + EXPECT_FLOAT_EQ(time, 1.0f); + + // Verify execution count + EXPECT_EQ(kernel_->get_execution_count(), 1); +} + +TEST_F(DispatcherRunTest, MultipleRuns) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + for(int i = 0; i < 10; i++) + { + (void)dispatcher.run(nullptr, nullptr, nullptr, problem); + } + + EXPECT_EQ(kernel_->get_execution_count(), 10); +} + +TEST_F(DispatcherRunTest, RunWithNoKernelThrows) +{ + Registry::instance().clear(); + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Should throw when no kernel found + EXPECT_THROW((void)dispatcher.run(nullptr, nullptr, nullptr, problem), std::runtime_error); +} diff --git a/dispatcher/tests/test_examples_integration.py b/dispatcher/tests/test_examples_integration.py new file mode 100644 index 0000000000..cfd18a3305 --- /dev/null +++ b/dispatcher/tests/test_examples_integration.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Integration tests that verify examples work correctly. + +These tests mimic the examples to ensure they continue working. +Run with: pytest test_examples_integration.py -v +""" + +import unittest +import subprocess +import sys +import os +from pathlib import Path + +# Get paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_ROOT = SCRIPT_DIR.parent +EXAMPLES_DIR = DISPATCHER_ROOT / "examples" +BUILD_DIR = DISPATCHER_ROOT / "build" +PYTHON_DIR = DISPATCHER_ROOT / "python" + +# Add python utilities to path +sys.path.insert(0, str(PYTHON_DIR)) + + +def run_python_example( + example_path: Path, timeout: int = 120 +) -> subprocess.CompletedProcess: + """Run a Python example and capture output.""" + env = os.environ.copy() + env["PYTHONPATH"] = str(PYTHON_DIR) + + return subprocess.run( + [sys.executable, str(example_path)], + capture_output=True, + text=True, + timeout=timeout, + cwd=example_path.parent, + env=env, + ) + + +def run_cpp_example( + example_name: str, timeout: int = 60 +) -> subprocess.CompletedProcess: + """Run a C++ example and capture output.""" + example_path = BUILD_DIR / "examples" / example_name + + if not example_path.exists(): + return None + + return subprocess.run( + [str(example_path)], + capture_output=True, + text=True, + timeout=timeout, + ) + + +class TestGemmPythonExamples(unittest.TestCase): + """Test GEMM Python examples.""" + + @classmethod + def setUpClass(cls): + """Check if examples directory exists.""" + cls.gemm_examples_dir = EXAMPLES_DIR / "gemm" / "python" + if not cls.gemm_examples_dir.exists(): + raise unittest.SkipTest("GEMM Python examples not found") + + def test_01_basic_gemm(self): + """Test basic GEMM example.""" + example = self.gemm_examples_dir / "01_basic_gemm.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_02_batch_gemm(self): + """Test batch GEMM example.""" + example = self.gemm_examples_dir / "02_batch_gemm.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_03_benchmark(self): + """Test benchmark example.""" + example = self.gemm_examples_dir / "03_benchmark.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_04_validation(self): + """Test validation example.""" + example = self.gemm_examples_dir / "04_validation.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + # Should pass validation + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestConvPythonExamples(unittest.TestCase): + """Test Conv Python examples.""" + + @classmethod + def setUpClass(cls): + """Check if examples directory exists.""" + cls.conv_examples_dir = EXAMPLES_DIR / "conv" / "python" + if not cls.conv_examples_dir.exists(): + raise unittest.SkipTest("Conv Python examples not found") + + def test_01_basic_conv(self): + """Test basic conv example.""" + example = self.conv_examples_dir / "01_basic_conv.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_02_conv2d_fwd(self): + """Test 2D forward conv example.""" + example = self.conv_examples_dir / "02_conv2d_fwd.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_03_conv3d_fwd(self): + """Test 3D forward conv example.""" + example = self.conv_examples_dir / "03_conv3d_fwd.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_07_validation(self): + """Test validation example.""" + example = self.conv_examples_dir / "07_validation.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestGemmCppExamples(unittest.TestCase): + """Test GEMM C++ examples.""" + + @classmethod + def setUpClass(cls): + """Check if build directory exists.""" + cls.examples_dir = BUILD_DIR / "examples" + if not cls.examples_dir.exists(): + raise unittest.SkipTest("C++ examples not built") + + def test_gemm_01_basic(self): + """Test basic GEMM C++ example.""" + result = run_cpp_example("gemm_01_basic") + if result is None: + self.skipTest("gemm_01_basic not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_gemm_02_multi_size(self): + """Test multi-size GEMM C++ example.""" + result = run_cpp_example("gemm_02_multi_size") + if result is None: + self.skipTest("gemm_02_multi_size not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_gemm_04_validation(self): + """Test validation GEMM C++ example.""" + result = run_cpp_example("gemm_04_validation") + if result is None: + self.skipTest("gemm_04_validation not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestConvCppExamples(unittest.TestCase): + """Test Conv C++ examples.""" + + @classmethod + def setUpClass(cls): + """Check if build directory exists.""" + cls.examples_dir = BUILD_DIR / "examples" + if not cls.examples_dir.exists(): + raise unittest.SkipTest("C++ examples not built") + + def test_conv_01_forward(self): + """Test forward conv C++ example.""" + result = run_cpp_example("conv_01_forward") + if result is None: + self.skipTest("conv_01_forward not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_conv_02_validation(self): + """Test validation conv C++ example.""" + result = run_cpp_example("conv_02_validation") + if result is None: + self.skipTest("conv_02_validation not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestUtilityImports(unittest.TestCase): + """Test that utility modules can be imported.""" + + def test_import_ctypes_utils(self): + """Test importing ctypes_utils.""" + try: + from ctypes_utils import KernelConfig, setup_gemm_dispatcher # noqa: F401 + + self.assertTrue(True) + except ImportError as e: + self.fail(f"Failed to import ctypes_utils: {e}") + + def test_import_conv_utils(self): + """Test importing conv_utils.""" + try: + from conv_utils import ConvSignature, ConvAlgorithm, ConvProblem # noqa: F401 + + self.assertTrue(True) + except ImportError as e: + self.fail(f"Failed to import conv_utils: {e}") + + def test_kernel_config_creation(self): + """Test creating a KernelConfig.""" + from ctypes_utils import KernelConfig + + config = KernelConfig( + dtype_a="fp16", + dtype_b="fp16", + dtype_c="fp16", + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + ) + + self.assertEqual(config.dtype_a, "fp16") + self.assertEqual(config.layout_a, "row") + + def test_conv_signature_creation(self): + """Test creating a ConvSignature.""" + from conv_utils import ConvSignature + + sig = ConvSignature( + dtype_in="fp16", + dtype_wei="fp16", + dtype_out="fp16", + dtype_acc="fp32", + layout="nhwgc", + direction="forward", + num_dims=2, + ) + + self.assertEqual(sig.dtype_in, "fp16") + self.assertEqual(sig.direction, "forward") + + +class TestAutoCorrection(unittest.TestCase): + """Test auto-correction functionality.""" + + def test_gemm_auto_correct(self): + """Test GEMM auto-correction.""" + from ctypes_utils import KernelConfig, auto_correct_kernel_config + + # Create a config with invalid wave config + config = KernelConfig( + dtype_a="fp16", + dtype_b="fp16", + dtype_c="fp16", + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + wave_m=99, # Invalid + wave_n=99, # Invalid + wave_k=99, # Invalid + ) + + corrected, was_modified, corrections = auto_correct_kernel_config(config) + + self.assertTrue(was_modified, "Config should be modified") + self.assertGreater(len(corrections), 0, "Should have corrections") + + def test_conv_auto_correct(self): + """Test Conv auto-correction.""" + from conv_utils import auto_correct_conv_config + + # Call with invalid wave config parameters + corrected, was_modified, corrections = auto_correct_conv_config( + wave_m=99, # Invalid + wave_n=99, # Invalid + wave_k=99, # Invalid + dtype="fp16", + arch="gfx942", + ) + + self.assertTrue(was_modified, "Config should be modified") + self.assertGreater(len(corrections), 0, "Should have corrections") + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_json_export.cpp b/dispatcher/tests/test_json_export.cpp new file mode 100644 index 0000000000..4392729554 --- /dev/null +++ b/dispatcher/tests/test_json_export.cpp @@ -0,0 +1,448 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for JSON export functionality + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// Basic Export Tests +// ============================================================================= + +class JSONExportBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportBasicTest, ExportEmptyRegistry) +{ + std::string json = Registry::instance().export_json(false); + + EXPECT_FALSE(json.empty()); + EXPECT_NE(json.find("\"kernels\""), std::string::npos); + // Empty registry should still produce valid JSON with kernels section +} + +TEST_F(JSONExportBasicTest, ExportSingleKernel) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(false); + + EXPECT_FALSE(json.empty()); + EXPECT_NE(json.find("\"test_kernel\""), std::string::npos); +} + +TEST_F(JSONExportBasicTest, ExportMultipleKernels) +{ + for(int i = 0; i < 5; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(false); + + // Should contain all kernel names + for(int i = 0; i < 5; i++) + { + EXPECT_NE(json.find("\"kernel_" + std::to_string(i) + "\""), std::string::npos); + } +} + +// ============================================================================= +// Export with Statistics Tests +// ============================================================================= + +class JSONExportStatisticsTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportStatisticsTest, ExportWithStatistics) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); // Include statistics + + EXPECT_NE(json.find("\"statistics\""), std::string::npos); + EXPECT_NE(json.find("\"by_datatype\""), std::string::npos); + EXPECT_NE(json.find("\"by_pipeline\""), std::string::npos); +} + +TEST_F(JSONExportStatisticsTest, ExportWithoutStatistics) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(false); // No statistics + + // Statistics section might be minimal or absent + EXPECT_NE(json.find("\"kernels\""), std::string::npos); +} + +// ============================================================================= +// Metadata Tests +// ============================================================================= + +class JSONExportMetadataTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportMetadataTest, MetadataPresent) +{ + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"metadata\""), std::string::npos); + EXPECT_NE(json.find("\"timestamp\""), std::string::npos); + EXPECT_NE(json.find("\"total_kernels\""), std::string::npos); +} + +TEST_F(JSONExportMetadataTest, CorrectKernelCount) +{ + const int num_kernels = 7; + for(int i = 0; i < num_kernels; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"total_kernels\": " + std::to_string(num_kernels)), std::string::npos); +} + +TEST_F(JSONExportMetadataTest, RegistryNameIncluded) +{ + Registry::instance().set_name("test_registry"); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"registry_name\""), std::string::npos); + EXPECT_NE(json.find("\"test_registry\""), std::string::npos); +} + +// ============================================================================= +// Export to File Tests +// ============================================================================= + +class JSONExportToFileTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + test_file_ = "/tmp/test_export_" + std::to_string(time(nullptr)) + ".json"; + } + + void TearDown() override + { + Registry::instance().clear(); + std::remove(test_file_.c_str()); + } + + std::string test_file_; +}; + +TEST_F(JSONExportToFileTest, ExportToFile) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + bool success = Registry::instance().export_json_to_file(test_file_, true); + EXPECT_TRUE(success); + + // Verify file exists + std::ifstream file(test_file_); + EXPECT_TRUE(file.good()); + + // Verify content + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + EXPECT_NE(content.find("\"kernel\""), std::string::npos); +} + +TEST_F(JSONExportToFileTest, ExportToInvalidPath) +{ + bool success = Registry::instance().export_json_to_file("/invalid/path/file.json", true); + EXPECT_FALSE(success); +} + +// ============================================================================= +// Auto-Export Tests +// ============================================================================= + +class JSONAutoExportTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + Registry::instance().disable_auto_export(); + test_file_ = "/tmp/test_auto_export_" + std::to_string(time(nullptr)) + ".json"; + } + + void TearDown() override + { + Registry::instance().disable_auto_export(); + Registry::instance().clear(); + std::remove(test_file_.c_str()); + } + + std::string test_file_; +}; + +TEST_F(JSONAutoExportTest, EnableAutoExport) +{ + EXPECT_FALSE(Registry::instance().is_auto_export_enabled()); + + Registry::instance().enable_auto_export(test_file_, true, false); + + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); +} + +TEST_F(JSONAutoExportTest, DisableAutoExport) +{ + Registry::instance().enable_auto_export(test_file_, true, false); + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); + + Registry::instance().disable_auto_export(); + EXPECT_FALSE(Registry::instance().is_auto_export_enabled()); +} + +TEST_F(JSONAutoExportTest, AutoExportOnRegistration) +{ + // Enable auto-export with export_on_every_registration=true + Registry::instance().enable_auto_export(test_file_, true, false); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "auto_kernel"); + Registry::instance().register_kernel(kernel); + + // File might be created on registration or on exit depending on implementation + // Just verify auto-export is enabled + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); +} + +// ============================================================================= +// JSON Validity Tests +// ============================================================================= + +class JSONValidityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } + + // Simple JSON syntax checker + bool isValidJSON(const std::string& json) + { + int braces = 0; + int brackets = 0; + bool in_string = false; + char prev = '\0'; + + for(char c : json) + { + if(c == '"' && prev != '\\') + { + in_string = !in_string; + } + + if(!in_string) + { + if(c == '{') + braces++; + else if(c == '}') + braces--; + else if(c == '[') + brackets++; + else if(c == ']') + brackets--; + } + + if(braces < 0 || brackets < 0) + return false; + prev = c; + } + + return braces == 0 && brackets == 0 && !in_string; + } +}; + +TEST_F(JSONValidityTest, EmptyRegistryProducesValidJSON) +{ + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, SingleKernelProducesValidJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, ManyKernelsProduceValidJSON) +{ + for(int i = 0; i < 50; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, NoNullBytesInJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + // Check for null bytes + EXPECT_EQ(json.find('\0'), std::string::npos); +} + +TEST_F(JSONValidityTest, NoPrintableGarbageInJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + // All characters should be printable or whitespace + for(char c : json) + { + EXPECT_TRUE(std::isprint(c) || std::isspace(c)) + << "Non-printable character: " << static_cast(c); + } +} + +// ============================================================================= +// Kernel Details Tests +// ============================================================================= + +class JSONKernelDetailsTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONKernelDetailsTest, SignatureIncluded) +{ + auto key = make_test_key(256); + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"signature\""), std::string::npos); + EXPECT_NE(json.find("\"dtype_a\""), std::string::npos); + EXPECT_NE(json.find("\"fp16\""), std::string::npos); +} + +TEST_F(JSONKernelDetailsTest, AlgorithmIncluded) +{ + auto key = make_test_key(256, 256, 32); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"algorithm\""), std::string::npos); + EXPECT_NE(json.find("\"tile_shape\""), std::string::npos); +} + +TEST_F(JSONKernelDetailsTest, IdentifierIncluded) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "my_kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"identifier\""), std::string::npos); + EXPECT_NE(json.find("\"name\""), std::string::npos); + EXPECT_NE(json.find("\"my_kernel\""), std::string::npos); +} + +// ============================================================================= +// Multiple Registries Export Tests +// ============================================================================= + +class JSONMultipleRegistriesTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONMultipleRegistriesTest, DifferentRegistriesDifferentJSON) +{ + Registry reg1; + reg1.set_name("registry1"); + + Registry reg2; + reg2.set_name("registry2"); + + auto key1 = make_test_key(128); + auto key2 = make_test_key(256); + + reg1.register_kernel(std::make_shared(key1, "k1")); + reg2.register_kernel(std::make_shared(key2, "k2")); + + std::string json1 = reg1.export_json(true); + std::string json2 = reg2.export_json(true); + + EXPECT_NE(json1, json2); + + EXPECT_NE(json1.find("\"registry1\""), std::string::npos); + EXPECT_NE(json2.find("\"registry2\""), std::string::npos); + + EXPECT_NE(json1.find("\"k1\""), std::string::npos); + EXPECT_NE(json2.find("\"k2\""), std::string::npos); +} diff --git a/dispatcher/tests/test_kernel_key.cpp b/dispatcher/tests/test_kernel_key.cpp new file mode 100644 index 0000000000..b35641952a --- /dev/null +++ b/dispatcher/tests/test_kernel_key.cpp @@ -0,0 +1,147 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for KernelKey using Google Test + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +TEST(KernelKeyTest, Construction) +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + + key.gfx_arch = "gfx942"; + + EXPECT_EQ(key.signature.dtype_a, DataType::FP16); + EXPECT_EQ(key.algorithm.tile_shape.m, 256); + EXPECT_EQ(key.gfx_arch, "gfx942"); +} + +TEST(KernelKeyTest, Equality) +{ + // Use helper function to ensure all fields are initialized + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key1, key2); + EXPECT_FALSE(key1 != key2); + + // Change one value + KernelKey key3 = make_test_key(128, 256, 32, "gfx942"); + EXPECT_NE(key1, key3); + EXPECT_FALSE(key1 == key3); +} + +TEST(KernelKeyTest, EncodeIdentifier) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = true; + key.algorithm.preshuffle = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + // Check that identifier contains expected components + EXPECT_NE(id.find("256x256x32"), std::string::npos); // tile shape + EXPECT_NE(id.find("2x2x1"), std::string::npos); // wave shape + EXPECT_NE(id.find("32x32x16"), std::string::npos); // warp tile shape + EXPECT_NE(id.find("persist"), std::string::npos); // persistent flag +} + +TEST(KernelKeyTest, EncodeIdentifierWithFusion) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "Relu"; + key.signature.num_d_tensors = 2; + key.algorithm.tile_shape.m = 128; + key.algorithm.tile_shape.n = 128; + key.algorithm.tile_shape.k = 64; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 16; + key.algorithm.warp_tile_shape.n = 16; + key.algorithm.warp_tile_shape.k = 32; + key.algorithm.persistent = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + // Check fusion-specific components + EXPECT_NE(id.find("Relu"), std::string::npos); + EXPECT_NE(id.find("_d2"), std::string::npos); + EXPECT_NE(id.find("nopers"), std::string::npos); +} + +TEST(KernelKeyTest, EncodeIdentifierWithSplitK) +{ + KernelKey key; + key.signature.split_k = 4; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("_splitk4"), std::string::npos); +} + +TEST(KernelKeyTest, EncodeIdentifierWithSparsity) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = true; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = false; + + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("_sparse"), std::string::npos); +} diff --git a/dispatcher/tests/test_kernel_key_extended.cpp b/dispatcher/tests/test_kernel_key_extended.cpp new file mode 100644 index 0000000000..1c6b5bcba0 --- /dev/null +++ b/dispatcher/tests/test_kernel_key_extended.cpp @@ -0,0 +1,453 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for KernelKey - covers all data types, layouts, pipelines + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// DataType Tests +// ============================================================================= + +class DataTypeTest : public ::testing::Test +{ + protected: + void SetUp() override {} +}; + +TEST_F(DataTypeTest, AllDataTypesExist) +{ + // Every DataType should be accessible + std::vector all_types = {DataType::FP16, + DataType::BF16, + DataType::FP32, + DataType::FP64, + DataType::INT8, + DataType::INT4, + DataType::INT32, + DataType::FP8, + DataType::BF8, + DataType::UNKNOWN}; + + EXPECT_EQ(all_types.size(), 10); +} + +TEST_F(DataTypeTest, DataTypesAreDifferent) +{ + EXPECT_NE(DataType::FP16, DataType::BF16); + EXPECT_NE(DataType::FP16, DataType::FP32); + EXPECT_NE(DataType::INT8, DataType::INT4); +} + +// ============================================================================= +// LayoutTag Tests +// ============================================================================= + +class LayoutTagTest : public ::testing::Test +{ +}; + +TEST_F(LayoutTagTest, AllLayoutsExist) +{ + std::vector all_layouts = { + LayoutTag::RowMajor, LayoutTag::ColMajor, LayoutTag::PackedExternal}; + + EXPECT_EQ(all_layouts.size(), 3); +} + +TEST_F(LayoutTagTest, LayoutsAreDifferent) { EXPECT_NE(LayoutTag::RowMajor, LayoutTag::ColMajor); } + +// ============================================================================= +// Pipeline Tests +// ============================================================================= + +class PipelineTest : public ::testing::Test +{ +}; + +TEST_F(PipelineTest, AllPipelinesExist) +{ + std::vector all_pipelines = {Pipeline::Mem, + Pipeline::CompV1, + Pipeline::CompV2, + Pipeline::CompV3, + Pipeline::CompV4, + Pipeline::CompV5, + Pipeline::PreShuffleV1, + Pipeline::PreShuffleV2}; + + EXPECT_EQ(all_pipelines.size(), 8); +} + +TEST_F(PipelineTest, PipelinesAreDifferent) +{ + EXPECT_NE(Pipeline::Mem, Pipeline::CompV4); + EXPECT_NE(Pipeline::CompV3, Pipeline::CompV4); +} + +// ============================================================================= +// Scheduler Tests +// ============================================================================= + +class SchedulerTest : public ::testing::Test +{ +}; + +TEST_F(SchedulerTest, AllSchedulersExist) +{ + std::vector all_schedulers = { + Scheduler::Auto, Scheduler::Intrawave, Scheduler::Interwave}; + + EXPECT_EQ(all_schedulers.size(), 3); +} + +// ============================================================================= +// Epilogue Tests +// ============================================================================= + +class EpilogueTest : public ::testing::Test +{ +}; + +TEST_F(EpilogueTest, AllEpiloguesExist) +{ + std::vector all_epilogues = {Epilogue::None, + Epilogue::Default, + Epilogue::CShuffle, + Epilogue::Bias, + Epilogue::Activation, + Epilogue::BiasActivation}; + + EXPECT_EQ(all_epilogues.size(), 6); +} + +// ============================================================================= +// KernelKey::Signature Tests +// ============================================================================= + +class SignatureTest : public ::testing::Test +{ + protected: + KernelKey::Signature CreateDefaultSignature() + { + KernelKey::Signature sig; + sig.dtype_a = DataType::FP16; + sig.dtype_b = DataType::FP16; + sig.dtype_c = DataType::FP16; + sig.dtype_acc = DataType::FP32; + sig.layout_a = LayoutTag::RowMajor; + sig.layout_b = LayoutTag::ColMajor; + sig.layout_c = LayoutTag::RowMajor; + sig.transpose_a = false; + sig.transpose_b = false; + sig.grouped = false; + sig.split_k = 1; + sig.elementwise_op = "PassThrough"; + sig.num_d_tensors = 0; + sig.structured_sparsity = false; + return sig; + } +}; + +TEST_F(SignatureTest, DefaultValuesAreReasonable) +{ + KernelKey::Signature sig = CreateDefaultSignature(); + EXPECT_EQ(sig.split_k, 1); + EXPECT_FALSE(sig.grouped); + EXPECT_FALSE(sig.structured_sparsity); +} + +TEST_F(SignatureTest, AllDataTypeCombinations) +{ + // Test various data type combinations that should be valid + std::vector> valid_combos = { + {DataType::FP16, DataType::FP16, DataType::FP16, DataType::FP32}, + {DataType::BF16, DataType::BF16, DataType::BF16, DataType::FP32}, + {DataType::FP32, DataType::FP32, DataType::FP32, DataType::FP32}, + {DataType::INT8, DataType::INT8, DataType::INT8, DataType::INT32}, + }; + + for(const auto& [a, b, c, acc] : valid_combos) + { + KernelKey::Signature sig; + sig.dtype_a = a; + sig.dtype_b = b; + sig.dtype_c = c; + sig.dtype_acc = acc; + + EXPECT_EQ(sig.dtype_a, a); + EXPECT_EQ(sig.dtype_b, b); + EXPECT_EQ(sig.dtype_c, c); + EXPECT_EQ(sig.dtype_acc, acc); + } +} + +TEST_F(SignatureTest, AllLayoutCombinations) +{ + std::vector layout_codes = { + "rrr", "rcr", "crr", "ccr", "rrc", "rcc", "crc", "ccc"}; + + for(const std::string& code : layout_codes) + { + KernelKey::Signature sig = CreateDefaultSignature(); + sig.layout_a = (code[0] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + sig.layout_b = (code[1] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + sig.layout_c = (code[2] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + + // Just verify assignment works + EXPECT_TRUE(sig.layout_a == LayoutTag::RowMajor || sig.layout_a == LayoutTag::ColMajor); + } +} + +TEST_F(SignatureTest, SplitKValues) +{ + KernelKey::Signature sig = CreateDefaultSignature(); + + std::vector valid_split_k = {1, 2, 4, 8, 16}; + for(auto sk : valid_split_k) + { + sig.split_k = sk; + EXPECT_EQ(sig.split_k, sk); + } +} + +// ============================================================================= +// KernelKey::Algorithm Tests +// ============================================================================= + +class AlgorithmTest : public ::testing::Test +{ + protected: + KernelKey::Algorithm CreateDefaultAlgorithm() + { + KernelKey::Algorithm algo; + algo.tile_shape = {256, 256, 32}; + algo.wave_shape = {2, 2, 1}; + algo.warp_tile_shape = {32, 32, 16}; + algo.pipeline = Pipeline::CompV4; + algo.scheduler = Scheduler::Intrawave; + algo.epilogue = Epilogue::CShuffle; + algo.block_size = 256; + algo.double_buffer = true; + algo.persistent = false; + algo.preshuffle = false; + algo.transpose_c = false; + algo.num_wave_groups = 1; + return algo; + } +}; + +TEST_F(AlgorithmTest, CommonTileShapes) +{ + std::vector> valid_tiles = { + {64, 64, 32}, + {128, 128, 32}, + {128, 128, 64}, + {256, 256, 32}, + {256, 256, 64}, + {256, 128, 32}, + {128, 256, 32}, + }; + + for(const auto& [m, n, k] : valid_tiles) + { + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + algo.tile_shape = {static_cast(m), + static_cast(n), + static_cast(k)}; + + EXPECT_EQ(algo.tile_shape.m, m); + EXPECT_EQ(algo.tile_shape.n, n); + EXPECT_EQ(algo.tile_shape.k, k); + } +} + +TEST_F(AlgorithmTest, CommonWarpConfigs) +{ + std::vector> valid_warps = { + {1, 4, 1}, + {2, 2, 1}, + {4, 1, 1}, + {1, 2, 1}, + {2, 1, 1}, + }; + + for(const auto& [m, n, k] : valid_warps) + { + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + algo.wave_shape = {static_cast(m), + static_cast(n), + static_cast(k)}; + + EXPECT_EQ(algo.wave_shape.m, m); + EXPECT_EQ(algo.wave_shape.n, n); + EXPECT_EQ(algo.wave_shape.k, k); + } +} + +TEST_F(AlgorithmTest, AllPipelines) +{ + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + + std::vector pipelines = {Pipeline::Mem, + Pipeline::CompV3, + Pipeline::CompV4, + Pipeline::PreShuffleV1, + Pipeline::PreShuffleV2}; + + for(Pipeline p : pipelines) + { + algo.pipeline = p; + EXPECT_EQ(algo.pipeline, p); + } +} + +// ============================================================================= +// KernelKey Identifier Encoding Tests +// ============================================================================= + +class IdentifierEncodingTest : public ::testing::Test +{ +}; + +TEST_F(IdentifierEncodingTest, UniqueIdentifiersForDifferentConfigs) +{ + std::set identifiers; + + // Generate multiple configurations + for(int tile_m : {128, 256}) + { + for(int wave_m : {1, 2, 4}) + { + for(bool persistent : {true, false}) + { + KernelKey key = make_test_key(tile_m); + key.algorithm.wave_shape.m = wave_m; + key.algorithm.persistent = persistent; + + std::string id = key.encode_identifier(); + EXPECT_TRUE(identifiers.find(id) == identifiers.end()) + << "Duplicate identifier: " << id; + identifiers.insert(id); + } + } + } + + // Should have generated 2 * 3 * 2 = 12 unique identifiers + EXPECT_EQ(identifiers.size(), 12); +} + +TEST_F(IdentifierEncodingTest, IdentifierContainsTileShape) +{ + KernelKey key = make_test_key(256, 128, 64); + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("256x128x64"), std::string::npos) + << "Identifier should contain tile shape: " << id; +} + +TEST_F(IdentifierEncodingTest, IdentifierContainsWarpConfig) +{ + KernelKey key = make_test_key(256); + key.algorithm.wave_shape = {4, 2, 1}; + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("4x2x1"), std::string::npos) + << "Identifier should contain warp config: " << id; +} + +TEST_F(IdentifierEncodingTest, IdentifierReflectsPersistence) +{ + KernelKey persistent_key = make_test_key(256); + persistent_key.algorithm.persistent = true; + + KernelKey non_persistent_key = make_test_key(256); + non_persistent_key.algorithm.persistent = false; + + std::string persistent_id = persistent_key.encode_identifier(); + std::string non_persistent_id = non_persistent_key.encode_identifier(); + + EXPECT_NE(persistent_id, non_persistent_id); + EXPECT_NE(persistent_id.find("persist"), std::string::npos); + EXPECT_NE(non_persistent_id.find("nopers"), std::string::npos); +} + +// ============================================================================= +// KernelKey Equality Tests +// ============================================================================= + +class KeyEqualityTest : public ::testing::Test +{ +}; + +TEST_F(KeyEqualityTest, IdenticalKeysAreEqual) +{ + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key1, key2); + EXPECT_FALSE(key1 != key2); +} + +TEST_F(KeyEqualityTest, DifferentTileShapesNotEqual) +{ + KernelKey key1 = make_test_key(256, 256, 32); + KernelKey key2 = make_test_key(128, 128, 32); + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentDataTypesNotEqual) +{ + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.signature.dtype_a = DataType::BF16; + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentLayoutsNotEqual) +{ + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.signature.layout_a = LayoutTag::ColMajor; + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentGfxArchNotEqual) +{ + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx90a"); + + EXPECT_NE(key1, key2); +} + +// ============================================================================= +// ElementwiseOps Tests +// ============================================================================= + +class ElementwiseOpsTest : public ::testing::Test +{ +}; + +TEST_F(ElementwiseOpsTest, CanUseInKernelKey) +{ + KernelKey key = make_test_key(256); + + key.signature.elementwise_op = "Relu"; + EXPECT_EQ(key.signature.elementwise_op, "Relu"); + + key.signature.elementwise_op = "Gelu"; + EXPECT_EQ(key.signature.elementwise_op, "Gelu"); + + key.signature.elementwise_op = "PassThrough"; + EXPECT_EQ(key.signature.elementwise_op, "PassThrough"); +} diff --git a/dispatcher/tests/test_minimal.cpp b/dispatcher/tests/test_minimal.cpp new file mode 100644 index 0000000000..22efc2524c --- /dev/null +++ b/dispatcher/tests/test_minimal.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Minimal test: Verify dispatcher can select and run a kernel +#include +#include +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +int main() +{ + std::cout << "Minimal Dispatcher Test\n"; + std::cout << "=======================\n\n"; + + // Create a mock kernel for testing + KernelKey key = make_test_key(128, 128, 64, "gfx942"); + auto kernel = std::make_shared(key, "test_kernel_128x128x64", true); + + // Register kernel + Registry::instance().clear(); + Registry::instance().register_kernel(kernel); + + std::cout << "OK Registered kernel: " << kernel->get_name() << "\n"; + + // Create dispatcher and problem + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + std::cout << "OK Created problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K + << "\n"; + + // Select kernel + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "[FAIL] Failed to select kernel\n"; + return 1; + } + + std::cout << "OK Selected kernel: " << selected->get_name() << "\n"; + + // Mock execution (no actual GPU computation in mock kernel) + void* a_ptr = nullptr; + void* b_ptr = nullptr; + void* c_ptr = nullptr; + + float time = dispatcher.run(a_ptr, b_ptr, c_ptr, problem); + + std::cout << "OK Executed kernel: " << time << " ms\n"; + std::cout << "\n[OK] Minimal test passed!\n"; + + return 0; +} diff --git a/dispatcher/tests/test_mock_kernel.cpp b/dispatcher/tests/test_mock_kernel.cpp new file mode 100644 index 0000000000..fd8f3f4baa --- /dev/null +++ b/dispatcher/tests/test_mock_kernel.cpp @@ -0,0 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mock_kernel.hpp" + +// Empty file - implementation is in header diff --git a/dispatcher/tests/test_mock_kernel.hpp b/dispatcher/tests/test_mock_kernel.hpp new file mode 100644 index 0000000000..7d511719a8 --- /dev/null +++ b/dispatcher/tests/test_mock_kernel.hpp @@ -0,0 +1,134 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace test { + +/// Mock kernel instance for testing dispatcher functionality +/// Supports configurable behavior for testing different scenarios +class MockKernelInstance : public KernelInstance +{ + public: + /// Constructor + /// @param key Kernel configuration key + /// @param name Human-readable kernel name + /// @param supports_all Whether this kernel supports all problems (default: true) + explicit MockKernelInstance(const KernelKey& key, + const std::string& name, + bool supports_all = true) + : key_(key), name_(name), supports_all_(supports_all), execution_count_(0) + { + } + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + if(supports_all_) + { + return problem.is_valid(); + } + // For testing: only support problems where M/N/K are divisible by tile sizes + return problem.is_valid() && (problem.M % key_.algorithm.tile_shape.m == 0) && + (problem.N % key_.algorithm.tile_shape.n == 0) && + (problem.K % key_.algorithm.tile_shape.k == 0); + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + execution_count_++; + // Simulate execution time (1ms for testing) + return 1.0f; + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + // Mock validation always passes + return true; + } + + /// Get execution count (for testing) + int get_execution_count() const { return execution_count_; } + + /// Reset execution count + void reset_execution_count() { execution_count_ = 0; } + + /// Set whether this kernel supports all problems + void set_supports_all(bool supports_all) { supports_all_ = supports_all; } + + private: + KernelKey key_; + std::string name_; + bool supports_all_; + mutable int execution_count_; +}; + +/// Helper function to create a test kernel key +inline KernelKey make_test_key(std::uint16_t tile_m = 256, + std::uint16_t tile_n = 256, + std::uint16_t tile_k = 32, + const std::string& gfx_arch = "gfx942") +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape.m = tile_m; + key.algorithm.tile_shape.n = tile_n; + key.algorithm.tile_shape.k = tile_k; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; +} + +} // namespace test +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/tests/test_problem.cpp b/dispatcher/tests/test_problem.cpp new file mode 100644 index 0000000000..7d5500e320 --- /dev/null +++ b/dispatcher/tests/test_problem.cpp @@ -0,0 +1,96 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for Problem using Google Test + +#include "ck_tile/dispatcher/problem.hpp" +#include + +using namespace ck_tile::dispatcher; + +TEST(ProblemTest, DefaultConstruction) +{ + Problem p; + EXPECT_EQ(p.M, 0); + EXPECT_EQ(p.N, 0); + EXPECT_EQ(p.K, 0); + EXPECT_EQ(p.k_batch, 1); + EXPECT_FALSE(p.is_valid()); +} + +TEST(ProblemTest, ConstructorWithDimensions) +{ + Problem p(1024, 1024, 1024); + EXPECT_EQ(p.M, 1024); + EXPECT_EQ(p.N, 1024); + EXPECT_EQ(p.K, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST(ProblemTest, Validation) +{ + Problem p; + + // Invalid: all zeros + p.M = 0; + p.N = 0; + p.K = 0; + EXPECT_FALSE(p.is_valid()); + + // Invalid: negative + p.M = -1; + p.N = 1024; + p.K = 1024; + EXPECT_FALSE(p.is_valid()); + + // Invalid: zero K + p.M = 1024; + p.N = 1024; + p.K = 0; + EXPECT_FALSE(p.is_valid()); + + // Valid + p.M = 1024; + p.N = 1024; + p.K = 1024; + EXPECT_TRUE(p.is_valid()); + + // Invalid k_batch + p.k_batch = 0; + EXPECT_FALSE(p.is_valid()); + + p.k_batch = 1; + EXPECT_TRUE(p.is_valid()); +} + +TEST(ProblemTest, NumOps) +{ + Problem p(100, 200, 300); + + // 2 * M * N * K (multiply-add = 2 ops) + std::int64_t expected = 2 * 100 * 200 * 300; + EXPECT_EQ(p.num_ops(), expected); +} + +TEST(ProblemTest, Configuration) +{ + Problem p(1024, 1024, 1024); + + // Set preferences + p.prefer_persistent = true; + p.enable_validation = true; + p.smem_budget = 65536; + p.k_batch = 2; + + EXPECT_TRUE(p.prefer_persistent); + EXPECT_TRUE(p.enable_validation); + EXPECT_EQ(p.smem_budget, 65536); + EXPECT_EQ(p.k_batch, 2); +} + +TEST(ProblemTest, LargeDimensions) +{ + Problem p(1024, 1024, 1024); // Use smaller but still large dimensions + EXPECT_TRUE(p.is_valid()); + EXPECT_GT(p.num_ops(), 0); +} diff --git a/dispatcher/tests/test_problem_extended.cpp b/dispatcher/tests/test_problem_extended.cpp new file mode 100644 index 0000000000..21ea545292 --- /dev/null +++ b/dispatcher/tests/test_problem_extended.cpp @@ -0,0 +1,457 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Problem - covers dimension inference, validation, edge cases + +#include "ck_tile/dispatcher/problem.hpp" +#include +#include + +using namespace ck_tile::dispatcher; + +// ============================================================================= +// Dimension Inference Tests +// ============================================================================= + +class ProblemDimensionInferenceTest : public ::testing::Test +{ +}; + +TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) +{ + // A: M×K (1024×512), B: K×N (512×2048) + auto problem = Problem::from_ab(1024, 512, 512, 2048); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromDimensions_Valid) +{ + // A: 1024×512, B: 512×2048, C: 1024×2048 + auto problem = Problem::from_dimensions(1024, 512, 512, 2048, 1024, 2048); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_WithC) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) +{ + // A stored as K×M (transposed) + TensorShape A{512, 1024, true}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedB) +{ + TensorShape A{1024, 512, false}; + // B stored as N×K (transposed) + TensorShape B{2048, 512, true}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +// ============================================================================= +// Validation Tests +// ============================================================================= + +class ProblemValidationTest : public ::testing::Test +{ +}; + +TEST_F(ProblemValidationTest, ValidProblem) +{ + Problem p(1024, 1024, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroM) +{ + Problem p(0, 1024, 1024); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroN) +{ + Problem p(1024, 0, 1024); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroK) +{ + Problem p(1024, 1024, 0); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, NegativeM) +{ + Problem p; + p.M = -1; + p.N = 1024; + p.K = 1024; + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroKBatch) +{ + Problem p(1024, 1024, 1024); + p.k_batch = 0; + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ValidKBatch) +{ + Problem p(1024, 1024, 1024); + p.k_batch = 4; + EXPECT_TRUE(p.is_valid()); +} + +// ============================================================================= +// num_ops Tests +// ============================================================================= + +class ProblemNumOpsTest : public ::testing::Test +{ +}; + +TEST_F(ProblemNumOpsTest, SmallProblem) +{ + Problem p(10, 20, 30); + // 2 * M * N * K = 2 * 10 * 20 * 30 = 12000 + EXPECT_EQ(p.num_ops(), 12000); +} + +TEST_F(ProblemNumOpsTest, SymmetricProblem) +{ + Problem p(1024, 1024, 1024); + // 2 * 1024^3 = 2,147,483,648 + EXPECT_EQ(p.num_ops(), 2LL * 1024 * 1024 * 1024); +} + +TEST_F(ProblemNumOpsTest, AsymmetricProblem) +{ + Problem p(512, 2048, 256); + EXPECT_EQ(p.num_ops(), 2LL * 512 * 2048 * 256); +} + +TEST_F(ProblemNumOpsTest, LargeProblem) +{ + Problem p(4096, 4096, 4096); + std::int64_t expected = 2LL * 4096 * 4096 * 4096; + EXPECT_EQ(p.num_ops(), expected); + EXPECT_GT(p.num_ops(), 0); // No overflow +} + +// ============================================================================= +// Edge Cases +// ============================================================================= + +class ProblemEdgeCasesTest : public ::testing::Test +{ +}; + +TEST_F(ProblemEdgeCasesTest, MinimumValidSize) +{ + Problem p(1, 1, 1); + EXPECT_TRUE(p.is_valid()); + EXPECT_EQ(p.num_ops(), 2); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_TallMatrix) +{ + Problem p(8192, 64, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_WideMatrix) +{ + Problem p(64, 8192, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_DeepK) +{ + Problem p(1024, 1024, 8192); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, SmallK) +{ + Problem p(1024, 1024, 16); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonPowerOf2Dimensions) +{ + Problem p(1000, 2000, 300); + EXPECT_TRUE(p.is_valid()); + EXPECT_EQ(p.num_ops(), 2LL * 1000 * 2000 * 300); +} + +TEST_F(ProblemEdgeCasesTest, PrimeDimensions) +{ + Problem p(997, 1009, 1013); // All prime numbers + EXPECT_TRUE(p.is_valid()); +} + +// ============================================================================= +// Configuration Tests +// ============================================================================= + +class ProblemConfigurationTest : public ::testing::Test +{ +}; + +TEST_F(ProblemConfigurationTest, DefaultConfiguration) +{ + Problem p(1024, 1024, 1024); + + EXPECT_FALSE(p.prefer_persistent); + EXPECT_FALSE(p.enable_validation); + EXPECT_EQ(p.smem_budget, 0); + EXPECT_EQ(p.k_batch, 1); +} + +TEST_F(ProblemConfigurationTest, SetPersistentPreference) +{ + Problem p(1024, 1024, 1024); + p.prefer_persistent = true; + + EXPECT_TRUE(p.prefer_persistent); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemConfigurationTest, SetSmemBudget) +{ + Problem p(1024, 1024, 1024); + p.smem_budget = 65536; // 64KB + + EXPECT_EQ(p.smem_budget, 65536); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemConfigurationTest, SetKBatch) +{ + Problem p(1024, 1024, 1024); + + for(int kb : {1, 2, 4, 8, 16}) + { + p.k_batch = kb; + EXPECT_EQ(p.k_batch, kb); + EXPECT_TRUE(p.is_valid()); + } +} + +// ============================================================================= +// Copy and Assignment Tests +// ============================================================================= + +class ProblemCopyTest : public ::testing::Test +{ +}; + +TEST_F(ProblemCopyTest, CopyConstruction) +{ + Problem p1(1024, 2048, 512); + p1.prefer_persistent = true; + p1.k_batch = 4; + + Problem p2(p1); + + EXPECT_EQ(p2.M, 1024); + EXPECT_EQ(p2.N, 2048); + EXPECT_EQ(p2.K, 512); + EXPECT_TRUE(p2.prefer_persistent); + EXPECT_EQ(p2.k_batch, 4); +} + +TEST_F(ProblemCopyTest, Assignment) +{ + Problem p1(1024, 2048, 512); + Problem p2(256, 256, 256); + + p2 = p1; + + EXPECT_EQ(p2.M, 1024); + EXPECT_EQ(p2.N, 2048); + EXPECT_EQ(p2.K, 512); +} + +// ============================================================================= +// Builder Tests +// ============================================================================= + +class ProblemBuilderTest : public ::testing::Test +{ +}; + +TEST_F(ProblemBuilderTest, BasicBuild) +{ + auto problem = ProblemBuilder().dimensions(1024, 2048, 512).build(); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemBuilderTest, WithSplitK) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).split_k(4).build(); + + EXPECT_EQ(problem.k_batch, 4); +} + +TEST_F(ProblemBuilderTest, WithPersistent) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).persistent(true).build(); + + EXPECT_TRUE(problem.prefer_persistent); +} + +TEST_F(ProblemBuilderTest, WithSmemBudget) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).smem_budget(65536).build(); + + EXPECT_EQ(problem.smem_budget, 65536); +} + +TEST_F(ProblemBuilderTest, ChainedConfiguration) +{ + auto problem = ProblemBuilder() + .dimensions(2048, 2048, 1024) + .split_k(2) + .persistent(true) + .smem_budget(32768) + .validate(true) + .build(); + + EXPECT_EQ(problem.M, 2048); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 1024); + EXPECT_EQ(problem.k_batch, 2); + EXPECT_TRUE(problem.prefer_persistent); + EXPECT_EQ(problem.smem_budget, 32768); + EXPECT_TRUE(problem.enable_validation); +} + +TEST_F(ProblemBuilderTest, FromAB) +{ + auto problem = ProblemBuilder().from_ab(1024, 512, 512, 2048).build(); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +// ============================================================================= +// Dimension Mismatch Error Tests +// ============================================================================= + +class ProblemDimensionErrorTest : public ::testing::Test +{ +}; + +TEST_F(ProblemDimensionErrorTest, KMismatchThrows) +{ + EXPECT_THROW((void)Problem::from_ab(1024, 512, 256, 2048), // K mismatch: 512 vs 256 + std::invalid_argument); +} + +TEST_F(ProblemDimensionErrorTest, MDimensionMismatchThrows) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{512, 2048, false}; // M mismatch: A says M=1024, C says M=512 + + EXPECT_THROW((void)Problem::from_shapes(A, B, C), std::invalid_argument); +} + +TEST_F(ProblemDimensionErrorTest, NDimensionMismatchThrows) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 1024, false}; // N mismatch: B says N=2048, C says N=1024 + + EXPECT_THROW((void)Problem::from_shapes(A, B, C), std::invalid_argument); +} + +// ============================================================================= +// Validate Sizes Tests +// ============================================================================= + +class ProblemValidateSizesTest : public ::testing::Test +{ +}; + +TEST_F(ProblemValidateSizesTest, CorrectSizes) +{ + Problem p(1024, 2048, 512); + + // This should not throw + EXPECT_NO_THROW(p.validate_sizes(1024 * 512, // A size + 512 * 2048, // B size + 1024 * 2048 // C size + )); +} + +TEST_F(ProblemValidateSizesTest, WrongASizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 256, // Wrong A size + 512 * 2048, + 1024 * 2048), + std::invalid_argument); +} + +TEST_F(ProblemValidateSizesTest, WrongBSizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 512, + 256 * 2048, // Wrong B size + 1024 * 2048), + std::invalid_argument); +} + +TEST_F(ProblemValidateSizesTest, WrongCSizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 512, + 512 * 2048, + 512 * 1024 // Wrong C size + ), + std::invalid_argument); +} diff --git a/dispatcher/tests/test_real_kernel_correctness.cpp b/dispatcher/tests/test_real_kernel_correctness.cpp new file mode 100644 index 0000000000..e753f04e19 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_correctness.cpp @@ -0,0 +1,232 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Correctness test with real GPU kernel + * Validates GPU results against CPU reference implementation + */ + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// CPU reference GEMM +// A: RowMajor (M x K) - A[m,k] = A[m*K + k] +// B: ColumnMajor (K x N) - B[k,n] = B[k + n*K] +// C: RowMajor (M x N) - C[m,n] = C[m*N + n] +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + // A is row-major: A[m,k] = A[m*K + k] + // B is column-major: B[k,n] = B[k + n*K] + acc += float(A[m * K + k]) * float(B[k + n * K]); + } + C[m * N + n] = T(acc); + } + } +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Correctness Test - Real GPU Kernel\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Kernel: " << KERNEL_NAME << "\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + // Test with random matrices + const int M = 256; + const int N = 256; + const int K = 256; + + std::cout << "Test configuration:\n"; + std::cout << " Problem: M=" << M << " N=" << N << " K=" << K << "\n"; + std::cout << " Method: Random matrices vs CPU reference\n\n"; + + // Random number generation + std::mt19937 rng(42); // Fixed seed for reproducibility + std::uniform_real_distribution dist(-1.0f, 1.0f); + + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + // Initialize with random values + std::cout << "Initializing random matrices...\n"; + for(int i = 0; i < M * K; i++) + { + A_host[i] = ADataType(dist(rng)); + } + for(int i = 0; i < K * N; i++) + { + B_host[i] = BDataType(dist(rng)); + } + + // GPU execution + std::cout << "Executing on GPU...\n"; + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Problem problem(M, N, K); + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + std::cout << "OK GPU execution complete: " << gpu_time << " ms\n"; + + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + std::cout << "OK GPU performance: " << tflops << " TFLOPS\n\n"; + + // CPU reference + std::cout << "Computing CPU reference...\n"; + cpu_gemm(A_host, B_host, C_cpu, M, N, K); + std::cout << "OK CPU reference complete\n\n"; + + // Validation + std::cout << "Validating results...\n"; + + int num_correct = 0; + float max_rel_error = 0.0f; + float max_abs_error = 0.0f; + const float tolerance = 0.02f; // 2% for FP16 + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + + float abs_error = std::abs(gpu_val - cpu_val); + float rel_error = abs_error / (std::abs(cpu_val) + 1e-5f); + + max_abs_error = std::max(max_abs_error, abs_error); + max_rel_error = std::max(max_rel_error, rel_error); + + if(rel_error < tolerance) + { + num_correct++; + } + } + + float accuracy = 100.0f * num_correct / (M * N); + + std::cout << "\nValidation Results:\n"; + std::cout << " Correct elements: " << num_correct << "/" << M * N << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + std::cout << " Max absolute error: " << max_abs_error << "\n"; + std::cout << " Max relative error: " << max_rel_error << "\n"; + std::cout << " Tolerance: " << tolerance << " (2%)\n\n"; + + // Show sample comparisons + std::cout << "Sample results (first 5 elements):\n"; + std::cout << " Index | GPU Result | CPU Result | Error\n"; + std::cout << " ------|------------|------------|-------\n"; + + for(int i = 0; i < 5; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float error = std::abs(gpu_val - cpu_val); + printf(" %-5d | %10.4f | %10.4f | %.4f\n", i, gpu_val, cpu_val, error); + } + std::cout << "\n"; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(accuracy > 99.0f) + { + std::cout << "[OK] CORRECTNESS TEST PASSED\n"; + std::cout << " GPU results match CPU reference within tolerance\n"; + return 0; + } + else + { + std::cout << "[FAIL] CORRECTNESS TEST FAILED\n"; + std::cout << " Accuracy too low: " << accuracy << "%\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_multi_size.cpp b/dispatcher/tests/test_real_kernel_multi_size.cpp new file mode 100644 index 0000000000..f23f684631 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_multi_size.cpp @@ -0,0 +1,213 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Multi-size real kernel test: Test multiple problem sizes with real GPU kernel + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +struct TestResult +{ + int M, N, K; + float time_ms; + double tflops; + int correct; + int total; + bool passed; +}; + +TestResult run_test(Dispatcher& dispatcher, int M, int N, int K) +{ + TestResult result = {M, N, K, 0.0f, 0.0, 0, M * N, false}; + + // Allocate and prepare data + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + + // Initialize: A=1, B=1, expected C=K + for(int i = 0; i < M * K; i++) + A_host[i] = ADataType(1.0f); + for(int i = 0; i < K * N; i++) + B_host[i] = BDataType(1.0f); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + // Execute + Problem problem(M, N, K); + result.time_ms = dispatcher.run(A_dev, B_dev, C_dev, problem); + + // Calculate performance + double flops = 2.0 * M * N * K; + result.tflops = (flops / (result.time_ms * 1e-3)) / 1e12; + + // Copy result and validate + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C_gpu[i]) - float(K)) < 1.0f) + { + result.correct++; + } + } + + result.passed = (result.correct == result.total); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + return result; +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Multi-Size Real Kernel Test\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Using kernel: " << KERNEL_NAME << "\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + std::cout << "Running tests on multiple problem sizes...\n"; + std::cout << "===========================================\n\n"; + + // Test various sizes (all multiples of tile size) + std::vector> test_sizes = { + {128, 128, 128}, // Small + {256, 256, 256}, // Medium + {512, 512, 512}, // Large + {1024, 1024, 1024}, // Very large + {128, 512, 256}, // Non-square + {512, 128, 384}, // Non-square + }; + + std::vector results; + int num_passed = 0; + + for(const auto& [M, N, K] : test_sizes) + { + std::cout << "Testing M=" << M << " N=" << N << " K=" << K << "...\n"; + + auto result = run_test(dispatcher, M, N, K); + results.push_back(result); + + std::cout << " Time: " << result.time_ms << " ms\n"; + std::cout << " Performance: " << result.tflops << " TFLOPS\n"; + std::cout << " Accuracy: " << (100.0f * result.correct / result.total) << "%\n"; + std::cout << " Status: " << (result.passed ? "[OK] PASS" : "[FAIL] FAIL") << "\n\n"; + + if(result.passed) + num_passed++; + } + + // Summary + std::cout << "===========================================\n"; + std::cout << "Summary\n"; + std::cout << "===========================================\n\n"; + + std::cout << "Results by size:\n"; + std::cout << " Size | Time (ms) | TFLOPS | Accuracy | Status\n"; + std::cout << " ---------------|-----------|--------|----------|--------\n"; + + for(const auto& r : results) + { + char size_str[32]; + snprintf(size_str, sizeof(size_str), "%4d×%4d×%4d", r.M, r.N, r.K); + + printf(" %-14s | %9.4f | %6.2f | %7.2f%% | %s\n", + size_str, + r.time_ms, + r.tflops, + 100.0f * r.correct / r.total, + r.passed ? "[OK]" : "[FAIL]"); + } + + std::cout << "\n"; + std::cout << "Tests passed: " << num_passed << "/" << results.size() << "\n"; + + if(num_passed == results.size()) + { + std::cout << "\n[OK] ALL TESTS PASSED\n"; + return 0; + } + else + { + std::cout << "\n[FAIL] SOME TESTS FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_performance.cpp b/dispatcher/tests/test_real_kernel_performance.cpp new file mode 100644 index 0000000000..ff3d635968 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_performance.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Performance test with real GPU kernel + * Measures and reports detailed performance metrics + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Performance Test - Real GPU Kernel\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Kernel: " << KERNEL_NAME << "\n"; + std::cout << "Device: AMD Instinct MI325X (gfx942)\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + // Performance benchmark sizes + std::vector> benchmarks = { + {128, 128, 128, "Tiny"}, + {256, 256, 256, "Small"}, + {512, 512, 512, "Medium"}, + {1024, 1024, 1024, "Large"}, + {2048, 2048, 2048, "Very Large"}, + }; + + std::cout << "Performance Benchmark Results\n"; + std::cout << "=============================\n\n"; + + std::cout << " Size | Time (ms) | TFLOPS | BW (GB/s) | Status\n"; + std::cout << " ----------|-----------|--------|-----------|--------\n"; + + bool all_passed = true; + + for(const auto& [M, N, K, label] : benchmarks) + { + // Prepare data + std::vector A_host(M * K, ADataType(1.0f)); + std::vector B_host(K * N, BDataType(1.0f)); + std::vector C_gpu(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK( + hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + // Execute + Problem problem(M, N, K); + float time_ms = dispatcher.run(A_dev, B_dev, C_dev, problem); + + // Calculate metrics + double flops = 2.0 * M * N * K; + double tflops = (flops / (time_ms * 1e-3)) / 1e12; + + // Bandwidth (A + B read, C write) + double bytes = (M * K + K * N + M * N) * sizeof(CDataType); + double bandwidth_gbs = (bytes / (time_ms * 1e-3)) / 1e9; + + // Validate + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C_gpu[i]) - float(K)) < 1.0f) + correct++; + } + + bool passed = (correct == M * N); + all_passed = all_passed && passed; + + char size_label[32]; + snprintf(size_label, sizeof(size_label), "%s %d³", label, M); + + printf(" %-9s | %9.4f | %6.2f | %9.1f | %s\n", + size_label, + time_ms, + tflops, + bandwidth_gbs, + passed ? "[OK]" : "[FAIL]"); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + } + + std::cout << "\n"; + + if(all_passed) + { + std::cout << "[OK] ALL PERFORMANCE TESTS PASSED\n"; + return 0; + } + else + { + std::cout << "[FAIL] SOME TESTS FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_simple.cpp b/dispatcher/tests/test_real_kernel_simple.cpp new file mode 100644 index 0000000000..72e3a5fc87 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_simple.cpp @@ -0,0 +1,201 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Simple real kernel test using tile_engine style (single kernel with -include) + * This follows the proven pattern from the examples + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header will be included via -include compiler flag +// It defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// Reference CPU GEMM +template +void reference_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + acc += float(A[m * K + k]) * float(B[k * N + n]); + } + C[m * N + n] = T(acc); + } + } +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Simple Real Kernel Test\n"; + std::cout << "=======================================\n\n"; + + // Test size (must be multiple of tile size) + const int M = 256; + const int N = 256; + const int K = 256; + + std::cout << "Problem: M=" << M << " N=" << N << " K=" << K << "\n"; + std::cout << "Kernel: " << KERNEL_NAME << "\n\n"; + + // Create kernel key + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 64}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + // Create and register kernel + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + std::cout << "OK Registered kernel\n"; + + // Create dispatcher + Dispatcher dispatcher; + Problem problem(M, N, K); + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "[FAIL] Failed to select kernel\n"; + return 1; + } + std::cout << "OK Selected kernel: " << selected->get_name() << "\n\n"; + + // Prepare data + std::cout << "Preparing test data...\n"; + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + // Simple test: A=1, B=1, C should be K + for(int i = 0; i < M * K; i++) + A_host[i] = ADataType(1.0f); + for(int i = 0; i < K * N; i++) + B_host[i] = BDataType(1.0f); + + // Allocate GPU memory + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + std::cout << "OK Data ready on GPU\n\n"; + + // Execute + std::cout << "Executing GPU kernel...\n"; + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + std::cout << "OK GPU time: " << gpu_time << " ms\n"; + + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + std::cout << "OK Performance: " << tflops << " TFLOPS\n\n"; + + // Copy result + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Validate + std::cout << "Validating (expected: all elements = " << K << ")...\n"; + + int correct = 0; + for(int i = 0; i < M * N; i++) + { + float val = float(C_gpu[i]); + if(std::abs(val - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << "Accuracy: " << accuracy << "% (" << correct << "/" << M * N << ")\n"; + + // Show samples + std::cout << "\nFirst 5 results:\n"; + for(int i = 0; i < 5; i++) + { + std::cout << " C[" << i << "] = " << float(C_gpu[i]) << " (expected " << K << ")\n"; + } + std::cout << "\n"; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(accuracy > 99.0f) + { + std::cout << "[OK] TEST PASSED\n"; + return 0; + } + else + { + std::cout << "[FAIL] TEST FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_registry.cpp b/dispatcher/tests/test_registry.cpp new file mode 100644 index 0000000000..4e5bf718df --- /dev/null +++ b/dispatcher/tests/test_registry.cpp @@ -0,0 +1,166 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for Registry using Google Test + +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +TEST(RegistryTest, Registration) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + + bool registered = registry.register_kernel(kernel); + EXPECT_TRUE(registered); + EXPECT_EQ(registry.size(), 1); +} + +TEST(RegistryTest, Lookup) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + registry.register_kernel(kernel); + + // Lookup by key + auto found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "test_kernel"); + + // Lookup by identifier + std::string id = key.encode_identifier(); + auto found2 = registry.lookup(id); + ASSERT_NE(found2, nullptr); + EXPECT_EQ(found2->get_name(), "test_kernel"); + + // Lookup non-existent + auto key2 = make_test_key(128); + auto not_found = registry.lookup(key2); + EXPECT_EQ(not_found, nullptr); +} + +TEST(RegistryTest, Priority) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel1 = std::make_shared(key, "kernel_low"); + auto kernel2 = std::make_shared(key, "kernel_high"); + + // Register with low priority + registry.register_kernel(kernel1, Registry::Priority::Low); + + // Try to register with normal priority (should replace) + bool replaced = registry.register_kernel(kernel2, Registry::Priority::Normal); + EXPECT_TRUE(replaced); + + auto found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_high"); + + // Try to register with low priority again (should fail) + auto kernel3 = std::make_shared(key, "kernel_low2"); + bool not_replaced = registry.register_kernel(kernel3, Registry::Priority::Low); + EXPECT_FALSE(not_replaced); + + found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_high"); +} + +TEST(RegistryTest, GetAll) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + registry.register_kernel(kernel1); + registry.register_kernel(kernel2); + + auto all = registry.get_all(); + EXPECT_EQ(all.size(), 2); +} + +TEST(RegistryTest, Filter) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + // Create kernels with different tile sizes + for(int tile_m : {128, 256, 512}) + { + auto key = make_test_key(tile_m); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(tile_m)); + registry.register_kernel(kernel); + } + + // Filter for large tiles (>= 256) + auto large_tiles = registry.filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 256; }); + + EXPECT_EQ(large_tiles.size(), 2); +} + +TEST(RegistryTest, Clear) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + registry.register_kernel(kernel); + + EXPECT_EQ(registry.size(), 1); + + registry.clear(); + EXPECT_EQ(registry.size(), 0); +} + +TEST(RegistryTest, MultipleKernels) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + // Register multiple kernels + for(int i = 0; i < 10; ++i) + { + auto key = make_test_key(256 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + registry.register_kernel(kernel); + } + + EXPECT_EQ(registry.size(), 10); + + // Verify all can be looked up + for(int i = 0; i < 10; ++i) + { + auto key = make_test_key(256 + i); + auto found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_" + std::to_string(i)); + } +} + +TEST(RegistryTest, Singleton) +{ + Registry& reg1 = Registry::instance(); + Registry& reg2 = Registry::instance(); + + // Should be the same instance + EXPECT_EQ(®1, ®2); +} diff --git a/dispatcher/tests/test_registry_extended.cpp b/dispatcher/tests/test_registry_extended.cpp new file mode 100644 index 0000000000..d173e1a38d --- /dev/null +++ b/dispatcher/tests/test_registry_extended.cpp @@ -0,0 +1,503 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Registry - covers multiple registries, merging, filtering + +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// Basic Registration Tests +// ============================================================================= + +class RegistryBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryBasicTest, RegisterSingleKernel) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + + EXPECT_TRUE(Registry::instance().register_kernel(kernel)); + EXPECT_EQ(Registry::instance().size(), 1); +} + +TEST_F(RegistryBasicTest, RegisterNullKernel) +{ + EXPECT_FALSE(Registry::instance().register_kernel(nullptr)); + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryBasicTest, RegisterMultipleKernels) +{ + for(int i = 0; i < 100; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + EXPECT_TRUE(Registry::instance().register_kernel(kernel)); + } + EXPECT_EQ(Registry::instance().size(), 100); +} + +TEST_F(RegistryBasicTest, RegisterDuplicateKey) +{ + auto key = make_test_key(256); + auto kernel1 = std::make_shared(key, "kernel1"); + auto kernel2 = std::make_shared(key, "kernel2"); + + EXPECT_TRUE(Registry::instance().register_kernel(kernel1, Registry::Priority::Normal)); + + // Same priority should not replace + EXPECT_FALSE(Registry::instance().register_kernel(kernel2, Registry::Priority::Normal)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "kernel1"); +} + +// ============================================================================= +// Priority Tests +// ============================================================================= + +class RegistryPriorityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryPriorityTest, HigherPriorityReplaces) +{ + auto key = make_test_key(256); + + auto low = std::make_shared(key, "low"); + auto normal = std::make_shared(key, "normal"); + auto high = std::make_shared(key, "high"); + + EXPECT_TRUE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(normal, Registry::Priority::Normal)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "normal"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "high"); +} + +TEST_F(RegistryPriorityTest, LowerPriorityDoesNotReplace) +{ + auto key = make_test_key(256); + + auto high = std::make_shared(key, "high"); + auto low = std::make_shared(key, "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_FALSE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "high"); +} + +TEST_F(RegistryPriorityTest, SamePriorityDoesNotReplace) +{ + auto key = make_test_key(256); + + auto first = std::make_shared(key, "first"); + auto second = std::make_shared(key, "second"); + + EXPECT_TRUE(Registry::instance().register_kernel(first, Registry::Priority::Normal)); + EXPECT_FALSE(Registry::instance().register_kernel(second, Registry::Priority::Normal)); + + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "first"); +} + +// ============================================================================= +// Lookup Tests +// ============================================================================= + +class RegistryLookupTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register several kernels + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryLookupTest, LookupByKey) +{ + auto key = make_test_key(256); + auto found = Registry::instance().lookup(key); + + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_256"); +} + +TEST_F(RegistryLookupTest, LookupByIdentifier) +{ + auto key = make_test_key(256); + std::string id = key.encode_identifier(); + + auto found = Registry::instance().lookup(id); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_256"); +} + +TEST_F(RegistryLookupTest, LookupNonExistent) +{ + auto key = make_test_key(1024); // Not registered + EXPECT_EQ(Registry::instance().lookup(key), nullptr); + EXPECT_EQ(Registry::instance().lookup("nonexistent_id"), nullptr); +} + +TEST_F(RegistryLookupTest, LookupEmptyIdentifier) +{ + EXPECT_EQ(Registry::instance().lookup(""), nullptr); +} + +// ============================================================================= +// Filter Tests +// ============================================================================= + +class RegistryFilterTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register kernels with various tile sizes + for(int tile : {64, 128, 256, 512, 1024}) + { + auto key = make_test_key(tile); + key.signature.dtype_a = (tile < 256) ? DataType::FP16 : DataType::BF16; + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryFilterTest, FilterByTileSize) +{ + auto large = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 256; }); + + EXPECT_EQ(large.size(), 3); // 256, 512, 1024 +} + +TEST_F(RegistryFilterTest, FilterByDataType) +{ + auto fp16 = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().signature.dtype_a == DataType::FP16; }); + + EXPECT_EQ(fp16.size(), 2); // 64, 128 +} + +TEST_F(RegistryFilterTest, FilterMatchesNone) +{ + auto none = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m > 2048; }); + + EXPECT_EQ(none.size(), 0); +} + +TEST_F(RegistryFilterTest, FilterMatchesAll) +{ + auto all = Registry::instance().filter([](const KernelInstance& k) { return true; }); + + EXPECT_EQ(all.size(), 5); +} + +// ============================================================================= +// Multiple Registries Tests +// ============================================================================= + +class MultipleRegistriesTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(MultipleRegistriesTest, CreateIndependentRegistries) +{ + Registry reg1; + Registry reg2; + + reg1.set_name("registry1"); + reg2.set_name("registry2"); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(512); + + reg1.register_kernel(std::make_shared(key1, "kernel1")); + reg2.register_kernel(std::make_shared(key2, "kernel2")); + + EXPECT_EQ(reg1.size(), 1); + EXPECT_EQ(reg2.size(), 1); + + EXPECT_NE(reg1.lookup(key1), nullptr); + EXPECT_EQ(reg1.lookup(key2), nullptr); + + EXPECT_EQ(reg2.lookup(key1), nullptr); + EXPECT_NE(reg2.lookup(key2), nullptr); +} + +TEST_F(MultipleRegistriesTest, RegistryNaming) +{ + Registry reg; + reg.set_name("my_custom_registry"); + + EXPECT_EQ(reg.get_name(), "my_custom_registry"); +} + +TEST_F(MultipleRegistriesTest, MergeRegistries) +{ + Registry reg1; + Registry reg2; + + auto key1 = make_test_key(128); + auto key2 = make_test_key(256); + auto key3 = make_test_key(512); + + reg1.register_kernel(std::make_shared(key1, "k1")); + reg1.register_kernel(std::make_shared(key2, "k2")); + + reg2.register_kernel(std::make_shared(key3, "k3")); + + Registry combined; + combined.merge_from(reg1, Registry::Priority::Normal); + combined.merge_from(reg2, Registry::Priority::Normal); + + EXPECT_EQ(combined.size(), 3); + EXPECT_NE(combined.lookup(key1), nullptr); + EXPECT_NE(combined.lookup(key2), nullptr); + EXPECT_NE(combined.lookup(key3), nullptr); +} + +TEST_F(MultipleRegistriesTest, MergeWithPriorityConflict) +{ + Registry reg1; + Registry reg2; + + auto key = make_test_key(256); + + reg1.register_kernel(std::make_shared(key, "from_reg1")); + reg2.register_kernel(std::make_shared(key, "from_reg2")); + + Registry combined; + combined.merge_from(reg1, Registry::Priority::Low); + combined.merge_from(reg2, Registry::Priority::High); + + EXPECT_EQ(combined.size(), 1); + EXPECT_EQ(combined.lookup(key)->get_name(), "from_reg2"); +} + +TEST_F(MultipleRegistriesTest, SingletonIndependence) +{ + Registry local_reg; + local_reg.set_name("local"); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(512); + + local_reg.register_kernel(std::make_shared(key1, "local_kernel")); + Registry::instance().register_kernel( + std::make_shared(key2, "global_kernel")); + + EXPECT_EQ(local_reg.size(), 1); + EXPECT_EQ(Registry::instance().size(), 1); + + EXPECT_NE(local_reg.lookup(key1), nullptr); + EXPECT_EQ(local_reg.lookup(key2), nullptr); + + EXPECT_EQ(Registry::instance().lookup(key1), nullptr); + EXPECT_NE(Registry::instance().lookup(key2), nullptr); +} + +// ============================================================================= +// Thread Safety Tests +// ============================================================================= + +class RegistryThreadSafetyTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryThreadSafetyTest, ConcurrentRegistrations) +{ + const int num_threads = 10; + const int kernels_per_thread = 100; + + std::vector threads; + std::atomic success_count{0}; + + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([t, kernels_per_thread, &success_count]() { + for(int k = 0; k < kernels_per_thread; k++) + { + int tile = t * 1000 + k; // Unique tile size + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + + if(Registry::instance().register_kernel(kernel)) + { + success_count++; + } + } + }); + } + + for(auto& t : threads) + { + t.join(); + } + + EXPECT_EQ(success_count.load(), num_threads * kernels_per_thread); + EXPECT_EQ(Registry::instance().size(), num_threads * kernels_per_thread); +} + +TEST_F(RegistryThreadSafetyTest, ConcurrentLookups) +{ + // Pre-register kernels + for(int i = 0; i < 100; i++) + { + auto key = make_test_key(i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + const int num_threads = 10; + const int lookups_per_thread = 1000; + std::atomic found_count{0}; + + std::vector threads; + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([lookups_per_thread, &found_count]() { + for(int k = 0; k < lookups_per_thread; k++) + { + auto key = make_test_key(k % 100); + if(Registry::instance().lookup(key) != nullptr) + { + found_count++; + } + } + }); + } + + for(auto& t : threads) + { + t.join(); + } + + EXPECT_EQ(found_count.load(), num_threads * lookups_per_thread); +} + +// ============================================================================= +// Clear and Size Tests +// ============================================================================= + +class RegistryClearTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryClearTest, ClearEmptyRegistry) +{ + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().clear(); // Should not crash + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryClearTest, ClearNonEmptyRegistry) +{ + for(int i = 0; i < 10; i++) + { + auto key = make_test_key(i); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + EXPECT_EQ(Registry::instance().size(), 10); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryClearTest, RegisterAfterClear) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); +} + +// ============================================================================= +// GetAll Tests +// ============================================================================= + +class RegistryGetAllTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryGetAllTest, GetAllEmpty) +{ + auto all = Registry::instance().get_all(); + EXPECT_EQ(all.size(), 0); +} + +TEST_F(RegistryGetAllTest, GetAllMultiple) +{ + for(int i = 0; i < 5; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + auto all = Registry::instance().get_all(); + EXPECT_EQ(all.size(), 5); +} diff --git a/dispatcher/tests/test_regression.cpp b/dispatcher/tests/test_regression.cpp new file mode 100644 index 0000000000..8b5a416ecf --- /dev/null +++ b/dispatcher/tests/test_regression.cpp @@ -0,0 +1,492 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Regression tests for known issues and edge cases. + * Add a new test here whenever a bug is fixed to prevent regression. + */ + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "test_mock_kernel.hpp" +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; +using SelectionStrategy = Dispatcher::SelectionStrategy; + +// ============================================================================= +// Issue: Uninitialized 'grouped' field in KernelKey caused JSON corruption +// Fix: Ensure all fields in make_test_key() are initialized +// ============================================================================= + +class RegressionGroupedFieldTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionGroupedFieldTest, GroupedFieldInitialized) +{ + KernelKey key = make_test_key(256); + + // grouped should be explicitly initialized + EXPECT_FALSE(key.signature.grouped); + + // Encoding should not crash or produce garbage + std::string id = key.encode_identifier(); + EXPECT_FALSE(id.empty()); + + // ID should not contain garbage characters + for(char c : id) + { + EXPECT_TRUE(std::isprint(c) || c == '_' || c == '-') + << "Invalid character in identifier: " << static_cast(c); + } +} + +TEST_F(RegressionGroupedFieldTest, GroupedFieldInJSON) +{ + KernelKey key = make_test_key(256); + key.signature.grouped = false; + + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + // Export to JSON + std::string json = Registry::instance().export_json(true); + + // JSON should be valid (not contain null bytes or garbage) + EXPECT_FALSE(json.empty()); + + // Should contain the grouped field with proper value + EXPECT_NE(json.find("\"grouped\""), std::string::npos); + EXPECT_NE(json.find("false"), std::string::npos); +} + +// ============================================================================= +// Issue: Priority comparison was incorrect +// Fix: Higher priority should replace lower, same priority should not replace +// ============================================================================= + +class RegressionPriorityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionPriorityTest, LowThenHighReplaces) +{ + auto key = make_test_key(256); + auto low = std::make_shared(key, "low"); + auto high = std::make_shared(key, "high"); + + EXPECT_TRUE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "high"); +} + +TEST_F(RegressionPriorityTest, HighThenLowDoesNotReplace) +{ + auto key = make_test_key(256); + auto high = std::make_shared(key, "high"); + auto low = std::make_shared(key, "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_FALSE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "high"); +} + +TEST_F(RegressionPriorityTest, SamePriorityDoesNotReplace) +{ + auto key = make_test_key(256); + auto first = std::make_shared(key, "first"); + auto second = std::make_shared(key, "second"); + + EXPECT_TRUE(Registry::instance().register_kernel(first, Registry::Priority::Normal)); + EXPECT_FALSE(Registry::instance().register_kernel(second, Registry::Priority::Normal)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "first"); +} + +// ============================================================================= +// Issue: Empty heuristic caused crash +// Fix: Fall back to FirstFit when heuristic returns empty or invalid results +// ============================================================================= + +class RegressionHeuristicTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionHeuristicTest, EmptyHeuristicFallback) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {}; // Empty + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash, should fall back to FirstFit + auto selected = dispatcher.select_kernel(problem); + EXPECT_NE(selected, nullptr); +} + +TEST_F(RegressionHeuristicTest, AllInvalidHeuristicFallback) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {"invalid1", "invalid2", "invalid3"}; + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash, should fall back to FirstFit + auto selected = dispatcher.select_kernel(problem); + EXPECT_NE(selected, nullptr); +} + +TEST_F(RegressionHeuristicTest, NullHeuristicSafe) +{ + Dispatcher dispatcher; + + // Don't set any heuristic + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash + auto selected = dispatcher.select_kernel(problem); + // Behavior depends on implementation - may return nullptr or fall back +} + +// ============================================================================= +// Issue: Lookup by empty string caused crash or undefined behavior +// ============================================================================= + +class RegressionLookupTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionLookupTest, EmptyStringLookup) +{ + EXPECT_EQ(Registry::instance().lookup(""), nullptr); +} + +TEST_F(RegressionLookupTest, VeryLongStringLookup) +{ + std::string very_long(10000, 'x'); + EXPECT_EQ(Registry::instance().lookup(very_long), nullptr); +} + +TEST_F(RegressionLookupTest, SpecialCharactersLookup) +{ + EXPECT_EQ(Registry::instance().lookup("kernel\0name"), nullptr); + EXPECT_EQ(Registry::instance().lookup("kernel\nname"), nullptr); + EXPECT_EQ(Registry::instance().lookup("kernel\tname"), nullptr); +} + +// ============================================================================= +// Issue: Problem with zero dimensions passed to dispatcher +// ============================================================================= + +class RegressionProblemTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionProblemTest, ZeroMDimension) +{ + Problem problem; + problem.M = 0; + problem.N = 1024; + problem.K = 1024; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionProblemTest, ZeroNDimension) +{ + Problem problem; + problem.M = 1024; + problem.N = 0; + problem.K = 1024; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionProblemTest, ZeroKDimension) +{ + Problem problem; + problem.M = 1024; + problem.N = 1024; + problem.K = 0; + + EXPECT_FALSE(problem.is_valid()); +} + +// ============================================================================= +// Issue: Dispatcher run with null pointers +// ============================================================================= + +class RegressionNullPointerTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionNullPointerTest, RunWithNullPointers) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Mock kernel doesn't use pointers, so this should work + float time = dispatcher.run(nullptr, nullptr, nullptr, problem); + + // Mock returns 1.0f + EXPECT_FLOAT_EQ(time, 1.0f); +} + +// ============================================================================= +// Issue: Thread safety - concurrent access to singleton +// ============================================================================= + +class RegressionThreadSafetyTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionThreadSafetyTest, SingletonAddressStable) +{ + Registry* addr1 = &Registry::instance(); + Registry* addr2 = &Registry::instance(); + Registry* addr3 = &Registry::instance(); + + EXPECT_EQ(addr1, addr2); + EXPECT_EQ(addr2, addr3); +} + +// ============================================================================= +// Issue: encode_identifier could produce duplicate IDs for different configs +// ============================================================================= + +class RegressionIdentifierTest : public ::testing::Test +{ +}; + +TEST_F(RegressionIdentifierTest, DifferentConfigsDifferentIDs) +{ + // Create two keys that differ only in one field + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.algorithm.persistent = true; // Only difference + + std::string id1 = key1.encode_identifier(); + std::string id2 = key2.encode_identifier(); + + EXPECT_NE(id1, id2) << "Different persistent flag should produce different IDs"; +} + +TEST_F(RegressionIdentifierTest, DifferentTileShapesDifferentIDs) +{ + KernelKey key1 = make_test_key(128, 128, 32); + KernelKey key2 = make_test_key(256, 256, 32); + + EXPECT_NE(key1.encode_identifier(), key2.encode_identifier()); +} + +TEST_F(RegressionIdentifierTest, DifferentWarpConfigsDifferentIDs) +{ + KernelKey key1 = make_test_key(256); + key1.algorithm.wave_shape = {2, 2, 1}; + + KernelKey key2 = make_test_key(256); + key2.algorithm.wave_shape = {4, 1, 1}; + + EXPECT_NE(key1.encode_identifier(), key2.encode_identifier()); +} + +// ============================================================================= +// Issue: Negative k_batch could cause issues +// ============================================================================= + +class RegressionKBatchTest : public ::testing::Test +{ +}; + +TEST_F(RegressionKBatchTest, ZeroKBatchInvalid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = 0; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionKBatchTest, NegativeKBatchInvalid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = -1; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionKBatchTest, LargeKBatchValid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = 1000; + + EXPECT_TRUE(problem.is_valid()); +} + +// ============================================================================= +// Issue: Filter returning shared_ptr leaks +// ============================================================================= + +class RegressionFilterTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + for(int i = 0; i < 10; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionFilterTest, FilterResultsAreValid) +{ + auto results = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 105; }); + + EXPECT_EQ(results.size(), 5); + + for(const auto& kernel : results) + { + EXPECT_NE(kernel, nullptr); + EXPECT_GE(kernel->get_key().algorithm.tile_shape.m, 105); + } +} + +// ============================================================================= +// Issue: Double clear() could cause issues +// ============================================================================= + +class RegressionDoubleClearTest : public ::testing::Test +{ +}; + +TEST_F(RegressionDoubleClearTest, DoubleClearSafe) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().clear(); // Second clear + EXPECT_EQ(Registry::instance().size(), 0); + + // Should still work after double clear + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); +} + +// ============================================================================= +// Issue: Multiple dispatchers with same registry +// ============================================================================= + +class RegressionMultiDispatcherTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionMultiDispatcherTest, MultipleDispatchersShareRegistry) +{ + Dispatcher d1; + Dispatcher d2; + Dispatcher d3; + + Problem problem(1024, 1024, 1024); + + auto k1 = d1.select_kernel(problem); + auto k2 = d2.select_kernel(problem); + auto k3 = d3.select_kernel(problem); + + // All should select the same kernel + EXPECT_NE(k1, nullptr); + EXPECT_EQ(k1, k2); + EXPECT_EQ(k2, k3); +} diff --git a/dispatcher/tests/test_sanity_ck_tile.cpp b/dispatcher/tests/test_sanity_ck_tile.cpp new file mode 100644 index 0000000000..fd28b7e54c --- /dev/null +++ b/dispatcher/tests/test_sanity_ck_tile.cpp @@ -0,0 +1,607 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Sanity check tests to verify CK Tile kernels are actually running on GPU. + * + * These tests verify: + * 1. GPU memory allocation and transfer work correctly + * 2. The dispatcher calls CK Tile infrastructure + * 3. GPU computes correct results (not just zeros) + * 4. Performance is reasonable (not CPU fallback) + * 5. Different problem sizes work correctly + */ + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header will be included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error at " << __FILE__ << ":" << __LINE__ << ": " \ + << hipGetErrorString(err) << "\n"; \ + return 1; \ + } \ + } + +// Reference CPU GEMM for validation +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + acc += float(A[m * K + k]) * float(B[k * N + n]); + } + C[m * N + n] = T(acc); + } + } +} + +// Test helper to setup dispatcher +void setup_dispatcher() +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 64}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Registry::Priority::High); +} + +// ============================================================================= +// Test 1: Basic Sanity - All ones multiplication +// ============================================================================= +int test_all_ones() +{ + std::cout << "\n=== Test: All Ones Multiplication ===\n"; + + const int M = 256, N = 256, K = 256; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N, CDataType(0.0f)); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // All ones * all ones with K=256 should give K=256 for each element + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Time: " << time << " ms\n"; + std::cout << " Expected: " << K << "\n"; + std::cout << " Sample C[0]: " << float(C[0]) << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Accuracy too low\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 2: Non-Zero Results - Verify GPU actually computed something +// ============================================================================= +int test_non_zero_results() +{ + std::cout << "\n=== Test: Non-Zero Results ===\n"; + + const int M = 256, N = 256, K = 256; + + std::vector A(M * K, ADataType(2.0f)); // All 2s + std::vector B(K * N, BDataType(3.0f)); // All 3s + std::vector C(M * N, CDataType(0.0f)); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // 2 * 3 * K = 6 * 256 = 1536 + float expected = 6.0f * K; + int correct = 0; + int non_zero = 0; + + for(int i = 0; i < M * N; i++) + { + if(float(C[i]) != 0.0f) + non_zero++; + if(std::abs(float(C[i]) - expected) < 10.0f) + { + correct++; + } + } + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Time: " << time << " ms\n"; + std::cout << " Expected: " << expected << "\n"; + std::cout << " Sample C[0]: " << float(C[0]) << "\n"; + std::cout << " Non-zero elements: " << non_zero << "/" << M * N << "\n"; + + if(non_zero == 0) + { + std::cerr << " FAILED: All zeros - GPU may not have run\n"; + return 1; + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << " Accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Accuracy too low\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 3: Performance Check - Ensure not CPU fallback +// ============================================================================= +int test_performance() +{ + std::cout << "\n=== Test: Performance Check ===\n"; + + const int M = 1024, N = 1024, K = 1024; + const int num_runs = 5; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + // Warmup + dispatcher.run(A_dev, B_dev, C_dev, problem); + HIP_CHECK(hipDeviceSynchronize()); + + // Timed runs + std::vector times; + for(int i = 0; i < num_runs; i++) + { + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + times.push_back(time); + } + + float avg_time = std::accumulate(times.begin(), times.end(), 0.0f) / times.size(); + float min_time = *std::min_element(times.begin(), times.end()); + + double flops = 2.0 * M * N * K; + double tflops = (flops / (min_time * 1e-3)) / 1e12; + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Problem: " << M << "x" << N << "x" << K << "\n"; + std::cout << " Avg time: " << avg_time << " ms\n"; + std::cout << " Min time: " << min_time << " ms\n"; + std::cout << " Performance: " << tflops << " TFLOPS\n"; + + // GPU should achieve at least 1 TFLOPS for this size + // CPU would be ~0.001 TFLOPS + if(tflops < 1.0) + { + std::cerr << " FAILED: Performance too low - may be CPU fallback\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 4: CPU vs GPU Correctness +// ============================================================================= +int test_vs_cpu_reference() +{ + std::cout << "\n=== Test: CPU vs GPU Correctness ===\n"; + + const int M = 128, N = 128, K = 128; // Small for CPU reference + + // Random-ish values + std::vector A(M * K); + std::vector B(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + for(int i = 0; i < M * K; i++) + { + A[i] = ADataType(float((i % 10) + 1) * 0.1f); + } + for(int i = 0; i < K * N; i++) + { + B[i] = BDataType(float((i % 7) + 1) * 0.1f); + } + + // CPU reference + cpu_gemm(A, B, C_cpu, M, N, K); + + // GPU + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Compare + float max_diff = 0.0f; + float sum_diff = 0.0f; + int correct = 0; + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float diff = std::abs(gpu_val - cpu_val); + + max_diff = std::max(max_diff, diff); + sum_diff += diff; + + // FP16 has limited precision (~3-4 decimal digits) + // For K=128, values can reach ~10-30, so allow 5% relative error + absolute tolerance + float tolerance = std::max(std::abs(cpu_val) * 0.05f, 1.0f); + if(diff < tolerance) + { + correct++; + } + } + + float avg_diff = sum_diff / (M * N); + float accuracy = 100.0f * correct / (M * N); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Max diff: " << max_diff << "\n"; + std::cout << " Avg diff: " << avg_diff << "\n"; + std::cout << " Sample CPU C[0]: " << float(C_cpu[0]) << "\n"; + std::cout << " Sample GPU C[0]: " << float(C_gpu[0]) << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + + // FP16 accumulation can have significant rounding differences from CPU FP32 + // 90% is reasonable for FP16 with K=128 accumulation + if(accuracy < 90.0f) + { + std::cerr << " FAILED: Too many mismatches vs CPU\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 5: Different Problem Sizes +// ============================================================================= +int test_multiple_sizes() +{ + std::cout << "\n=== Test: Multiple Problem Sizes ===\n"; + + std::vector> sizes = { + {128, 128, 128}, + {256, 256, 256}, + {512, 512, 512}, + {128, 256, 512}, + {512, 256, 128}, + {1024, 1024, 256}, + }; + + int passed = 0; + int total = sizes.size(); + + for(const auto& [M, N, K] : sizes) + { + std::cout << " Testing " << M << "x" << N << "x" << K << "... "; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + hipMalloc(&A_dev, M * K * sizeof(ADataType)); + hipMalloc(&B_dev, K * N * sizeof(BDataType)); + hipMalloc(&C_dev, M * N * sizeof(CDataType)); + + hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice); + hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice); + hipMemset(C_dev, 0, M * N * sizeof(CDataType)); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); + + hipFree(A_dev); + hipFree(B_dev); + hipFree(C_dev); + + // Check result + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + if(accuracy > 99.0f && time > 0) + { + std::cout << "PASS (" << time << " ms)\n"; + passed++; + } + else + { + std::cout << "FAIL (acc=" << accuracy << "%, time=" << time << ")\n"; + } + } + + std::cout << "\n Passed: " << passed << "/" << total << "\n"; + + if(passed < total) + { + std::cerr << " FAILED: Some sizes failed\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 6: Memory Bounds Check +// ============================================================================= +int test_memory_bounds() +{ + std::cout << "\n=== Test: Memory Bounds Check ===\n"; + + const int M = 256, N = 256, K = 256; + const float sentinel = -999.0f; + + // Allocate with extra padding and sentinel values + const int padding = 16; + std::vector A(M * K + padding, ADataType(1.0f)); + std::vector B(K * N + padding, BDataType(1.0f)); + std::vector C(M * N + padding, CDataType(sentinel)); + + // Set sentinels at the end + for(int i = 0; i < padding; i++) + { + A[M * K + i] = ADataType(sentinel); + B[K * N + i] = BDataType(sentinel); + } + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, (M * K + padding) * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, (K * N + padding) * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, (M * N + padding) * sizeof(CDataType))); + + HIP_CHECK( + hipMemcpy(A_dev, A.data(), (M * K + padding) * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(B_dev, B.data(), (K * N + padding) * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(C_dev, C.data(), (M * N + padding) * sizeof(CDataType), hipMemcpyHostToDevice)); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK( + hipMemcpy(C.data(), C_dev, (M * N + padding) * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Check sentinels weren't overwritten + bool sentinels_intact = true; + for(int i = 0; i < padding; i++) + { + if(float(C[M * N + i]) != sentinel) + { + sentinels_intact = false; + std::cerr << " Sentinel overwritten at position " << (M * N + i) << "\n"; + } + } + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(!sentinels_intact) + { + std::cerr << " FAILED: Memory bounds violated\n"; + return 1; + } + + // Also check actual results are correct + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << " Sentinels intact: Yes\n"; + std::cout << " Result accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Results incorrect\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Main +// ============================================================================= +int main() +{ + std::cout << "========================================\n"; + std::cout << "CK Tile Sanity Check Tests\n"; + std::cout << "========================================\n"; + std::cout << "Kernel: " << KERNEL_NAME << "\n"; + + // Setup + setup_dispatcher(); + + int failures = 0; + + // Run all tests + failures += test_all_ones(); + failures += test_non_zero_results(); + failures += test_performance(); + failures += test_vs_cpu_reference(); + failures += test_multiple_sizes(); + failures += test_memory_bounds(); + + std::cout << "\n========================================\n"; + if(failures == 0) + { + std::cout << "ALL TESTS PASSED\n"; + std::cout << "CK Tile is running correctly on GPU.\n"; + return 0; + } + else + { + std::cout << failures << " TEST(S) FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_tile_backend.cpp b/dispatcher/tests/test_tile_backend.cpp new file mode 100644 index 0000000000..4e7c693071 --- /dev/null +++ b/dispatcher/tests/test_tile_backend.cpp @@ -0,0 +1,155 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for CK Tile backend using Google Test +/// Note: This test validates the dispatcher wrapper infrastructure, not actual kernel execution + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +namespace { + +// Note: Actual CK Tile backend tests require real generated kernels and GPU hardware. +// These tests verify the dispatcher's tile backend interface and wrapper functionality +// using mock kernels instead of real tile kernels. +} // anonymous namespace + +// These tests verify the tile backend can be used with mock kernels +// Real tile kernel integration would require generated CK Tile kernels + +TEST(TileBackendTest, KernelKeyCreation) +{ + // Test creating a kernel key for tile backend + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key.algorithm.tile_shape.m, 256); + EXPECT_EQ(key.algorithm.tile_shape.n, 256); + EXPECT_EQ(key.algorithm.tile_shape.k, 32); + EXPECT_EQ(key.gfx_arch, "gfx942"); + EXPECT_EQ(key.signature.dtype_a, DataType::FP16); +} + +TEST(TileBackendTest, MockKernelRegistration) +{ + // Clear registry for clean test + Registry::instance().clear(); + + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "mock_tile_kernel", false); // strict divisibility + + // Register kernel + bool registered = Registry::instance().register_kernel(kernel); + EXPECT_TRUE(registered); + + // Lookup kernel + std::string kernel_id = key.encode_identifier(); + auto found_kernel = Registry::instance().lookup(kernel_id); + EXPECT_NE(found_kernel, nullptr); + EXPECT_EQ(found_kernel->get_name(), "mock_tile_kernel"); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, DispatcherWithMockTileKernel) +{ + // Clear registry + Registry::instance().clear(); + + // Create and register mock tile kernel + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "mock_tile_kernel", false); // strict divisibility + Registry::instance().register_kernel(kernel); + + // Create dispatcher + Dispatcher dispatcher; + + // Test kernel selection - divisible dimensions + Problem problem1(512, 512, 512); // Divisible by 256, 256, 32 + auto selected1 = dispatcher.select_kernel(problem1); + EXPECT_NE(selected1, nullptr); + EXPECT_EQ(selected1->get_name(), "mock_tile_kernel"); + + // Test with non-divisible problem + Problem problem2(100, 200, 300); // Not divisible + auto not_selected = dispatcher.select_kernel(problem2); + EXPECT_EQ(not_selected, nullptr); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, TileKernelIdentifierEncoding) +{ + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + + std::string id = key.encode_identifier(); + + // Should contain tile dimensions + EXPECT_NE(id.find("256x256x32"), std::string::npos); + EXPECT_NE(id.find("2x2x1"), std::string::npos); + EXPECT_NE(id.find("32x32x16"), std::string::npos); + + // Should contain persistent flag + EXPECT_NE(id.find("nopers"), std::string::npos); // persistent = false +} + +TEST(TileBackendTest, MultipleKernelRegistration) +{ + // Clear registry + Registry::instance().clear(); + + // Register multiple kernels with different tile sizes + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + auto kernel1 = std::make_shared(key1, "kernel_256x256x32", false); + + KernelKey key2 = make_test_key(128, 128, 64, "gfx942"); + auto kernel2 = std::make_shared(key2, "kernel_128x128x64", false); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + EXPECT_EQ(Registry::instance().size(), 2); + + // Verify both are accessible + auto found1 = Registry::instance().lookup(key1.encode_identifier()); + auto found2 = Registry::instance().lookup(key2.encode_identifier()); + + EXPECT_NE(found1, nullptr); + EXPECT_NE(found2, nullptr); + EXPECT_EQ(found1->get_name(), "kernel_256x256x32"); + EXPECT_EQ(found2->get_name(), "kernel_128x128x64"); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, TileSizeSupport) +{ + Registry::instance().clear(); + + // Create kernel with 256x256x32 tiles (no padding) + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "test_kernel", false); // strict divisibility + + // Should support 512x512x512 (divisible) + EXPECT_TRUE(kernel->supports(Problem(512, 512, 512))); + + // Should support 256x256x32 (exact match) + EXPECT_TRUE(kernel->supports(Problem(256, 256, 32))); + + // Should NOT support 100x200x300 (not divisible) + EXPECT_FALSE(kernel->supports(Problem(100, 200, 300))); + + // Should support 1024x1024x1024 (divisible) + EXPECT_TRUE(kernel->supports(Problem(1024, 1024, 1024))); + + Registry::instance().clear(); +} diff --git a/docs/conceptual/ck_tile/CK-tile-index.rst b/docs/conceptual/ck_tile/CK-tile-index.rst index e18cb24f80..58d95bbe9d 100644 --- a/docs/conceptual/ck_tile/CK-tile-index.rst +++ b/docs/conceptual/ck_tile/CK-tile-index.rst @@ -1,14 +1,13 @@ .. _ck_tile_index: -************************ -CK Tile Index -************************ - -CK Tile documentation structure: +**************************************************** +CK Tile conceptual documentation table of contents +**************************************************** .. toctree:: :maxdepth: 2 + index introduction_motivation buffer_views tensor_views diff --git a/docs/conceptual/ck_tile/MERMAID_DIAGRAMS.md b/docs/conceptual/ck_tile/MERMAID_DIAGRAMS.md deleted file mode 100644 index 5e8679dbd2..0000000000 --- a/docs/conceptual/ck_tile/MERMAID_DIAGRAMS.md +++ /dev/null @@ -1,156 +0,0 @@ -# Mermaid Diagram Management - -This document explains how to manage mermaid diagrams in the CK Tile documentation. - -## Overview - -All mermaid diagrams in the CK Tile documentation have been converted to SVG files for better rendering compatibility. The original mermaid source code is preserved as commented blocks in the RST files, allowing easy updates when needed. - -## Directory Structure - -- `docs/conceptual/ck_tile/diagrams/` - Contains all SVG diagram files -- `docs/conceptual/ck_tile/convert_mermaid_to_svg.py` - Initial conversion script (one-time use) -- `docs/conceptual/ck_tile/update_diagrams.py` - Helper script to regenerate diagrams from comments - -## Diagram Format in RST Files - -Each diagram follows this format: - -```rst -.. - Original mermaid diagram (edit here, then run update_diagrams.py) - - .. mermaid:: - - graph TB - A --> B - B --> C - -.. image:: diagrams/diagram_name.svg - :alt: Diagram - :align: center -``` - -The commented mermaid block won't appear in the rendered documentation but serves as the source for regenerating the SVG. - -## Updating Diagrams - -### When to Update - -You need to regenerate SVG files when: -- Modifying the mermaid source in a commented block -- Adding new diagrams -- Updating diagram styling - -### How to Update - -1. **Edit the commented mermaid source** in the RST file -2. **Run the update script**: - ```bash - # Update all diagrams - python docs/conceptual/ck_tile/update_diagrams.py - - # Update diagrams in a specific file - python docs/conceptual/ck_tile/update_diagrams.py transforms.rst - - # Force regenerate all diagrams (even if SVGs exist) - python docs/conceptual/ck_tile/update_diagrams.py --force - ``` - -### Prerequisites - -The update script requires [mermaid-cli](https://github.com/mermaid-js/mermaid-cli): - -```bash -npm install -g @mermaid-js/mermaid-cli -``` - -## Adding New Diagrams - -To add a new mermaid diagram: - -1. **Create the commented block** in your RST file: - ```rst - .. - Original mermaid diagram (edit here, then run update_diagrams.py) - - .. mermaid:: - - graph TB - A --> B - ``` - -2. **Add the image reference** immediately after: - ```rst - .. image:: diagrams/my_new_diagram.svg - :alt: My New Diagram - :align: center - ``` - -3. **Generate the SVG**: - ```bash - python docs/conceptual/ck_tile/update_diagrams.py your_file.rst - ``` - -## Current Diagrams - -The following RST files contain mermaid diagrams (40 total): - -- `adaptors.rst` (2 diagrams) -- `convolution_example.rst` (1 diagram) -- `coordinate_movement.rst` (1 diagram) -- `descriptors.rst` (2 diagrams) -- `encoding_internals.rst` (2 diagrams) -- `lds_index_swapping.rst` (3 diagrams) -- `load_store_traits.rst` (2 diagrams) -- `space_filling_curve.rst` (1 diagram) -- `static_distributed_tensor.rst` (1 diagram) -- `sweep_tile.rst` (4 diagrams) -- `tensor_coordinates.rst` (2 diagrams) -- `thread_mapping.rst` (2 diagrams) -- `tile_window.rst` (5 diagrams) -- `transforms.rst` (12 diagrams) - -## Troubleshooting - -### SVG not generated - -- Check that mermaid-cli is installed: `mmdc --version` -- Verify the mermaid syntax is valid -- Look for error messages in the script output - -### Diagram not updating - -- Use `--force` flag to regenerate: `python docs/update_diagrams.py --force` -- Check that the image reference matches the generated filename - -### Pattern not matching - -If the update script can't find your commented diagram: -- Ensure proper indentation (3 spaces for comment block content) -- Verify the `.. mermaid::` directive is commented -- Check that the image reference immediately follows the comment block - -## Script Details - -### update_diagrams.py - -This script: -1. Scans RST files for commented mermaid blocks -2. Extracts the mermaid source code -3. Converts to SVG using `mmdc` -4. Saves to the diagrams directory - -**Usage:** -- `python docs/conceptual/ck_tile/update_diagrams.py` - Check all files, update missing SVGs -- `python docs/conceptual/ck_tile/update_diagrams.py --force` - Regenerate all SVGs -- `python docs/conceptual/ck_tile/update_diagrams.py ` - Update specific file - -### convert_mermaid_to_svg.py - -This was the initial conversion script. It: -1. Found all active `.. mermaid::` directives -2. Converted them to SVGs -3. Replaced directives with commented source + image references - -This script was used once for the initial conversion and typically doesn't need to be run again. diff --git a/docs/conceptual/ck_tile/adaptors.rst b/docs/conceptual/ck_tile/adaptors.rst index 9e8907ab10..8720199eab 100644 --- a/docs/conceptual/ck_tile/adaptors.rst +++ b/docs/conceptual/ck_tile/adaptors.rst @@ -59,8 +59,8 @@ A TensorAdaptor encapsulates a sequence of :ref:`coordinate transformations {}) // to single dim 0 ); - // The adaptor is embedded in the :ref:`descriptor ` + // The adaptor is embedded in the descriptor // To use it: multi_index<1> top_coord{5}; // 1D coordinate // This internally calculates: row = 5/3 = 1, col = 5%3 = 2 @@ -309,7 +309,6 @@ A practical example showing how adaptors create efficient :ref:`GPU memory acces // - Dimension 0,1: Thread indices // - Dimension 2,3: Vector indices within thread // Enables coalesced memory access on GPU - // See :ref:`ck_tile_thread_mapping` for thread mapping details Common Transform Chains ----------------------- diff --git a/docs/conceptual/ck_tile/buffer_views.rst b/docs/conceptual/ck_tile/buffer_views.rst index 14b8309504..600aaed96f 100644 --- a/docs/conceptual/ck_tile/buffer_views.rst +++ b/docs/conceptual/ck_tile/buffer_views.rst @@ -1,35 +1,32 @@ -.. meta:: - :description: Composable Kernel CK Tile buffer views - :keywords: composable kernel, CK, CK Tile, ROCm, API, buffer view, raw memory - .. _ck_tile_buffer_views: -CK Tile buffer view -======================= +********************************** +Buffer Views - Raw Memory Access +********************************** -Buffer view is an abstraction that provides structured access to memory. The ``buffer_view`` class is exposed in ``include/ck_tile/core/tensor/buffer_view.hpp``. +Overview +-------- -Buffer view serves as the foundation for :ref:`ck_tile_tensor_views`. BufferView handles memory addressing and type safety, while TensorView builds upon this to add multi-dimensional coordinates (shape and strides). +At the foundation of the CK Tile system lies BufferView, a compile-time abstraction that provides structured access to raw memory regions within GPU kernels. This serves as the bridge between the hardware's physical memory model and the higher-level abstractions that enable efficient GPU programming. BufferView encapsulates the complexity of GPU memory hierarchies while exposing a unified interface that works seamlessly across different memory address spaces including global memory shared across the entire device, local data share (LDS) memory shared within a workgroup, or the ultra-fast register files private to each thread. +BufferView serves as the foundation for :ref:`ck_tile_tensor_views`, which add multi-dimensional structure on top of raw memory access. Understanding BufferView is essential before moving on to more complex abstractions like :ref:`ck_tile_distribution` and :ref:`ck_tile_tile_window`. -Buffer view provides the following advantages: +By providing compile-time knowledge of buffer properties through template metaprogramming, BufferView enables the compiler to generate optimal machine code for each specific use case. This zero-overhead abstraction ensures that the convenience of a high-level interface comes with no runtime performance penalty. -* A unified interface across global, shared, and register memory -* Address spaces encoded in types, taking advantage of compile-time type checking -* Configurable handling of invalid values, out-of-bounds operations, and conditional access patterns -* Atomic operations for parallel algorithms -* AMD GPU-specific optimizations -* Automatic application of appropriate memory ordering constraints and cache control directives based on the target address space and operation type +One of BufferView's most important features is its advanced handling of out-of-bounds memory access. Unlike CPU programming where such accesses typically result in segmentation faults or undefined behavior, GPU programming must gracefully handle cases where threads attempt to access memory beyond allocated boundaries. BufferView provides configurable strategies for these scenarios, where developers can choose between returning either numerical zero values or custom sentinel values for invalid accesses. This flexibility is important for algorithms that naturally extend beyond data boundaries, such as convolutions with padding or matrix operations with non-aligned dimensions. +The abstraction extends beyond simple memory access to encompass both scalar and vector data types. GPUs achieve their highest efficiency when loading or storing multiple data elements in a single instruction. BufferView seamlessly supports these vectorized operations, automatically selecting the appropriate hardware instructions based on the data type and access pattern. This capability transforms what would be multiple memory transactions into single, efficient operations that fully utilize the available memory bandwidth. -[TO DO: do we want to say more about these items? There wasn't a lot of detail in the original text, so I put them in a list for now] - +BufferView also incorporates AMD GPU-specific optimizations that leverage unique hardware features. The AMD buffer addressing mode, for instance, provides hardware-accelerated bounds checking that ensures memory safety without the performance overhead of software-based checks. Similarly, BufferView exposes atomic operations that are crucial for parallel algorithms requiring thread-safe updates to shared data structures. These hardware-specific optimizations are abstracted behind a portable interface, ensuring that code remains maintainable while achieving optimal performance. +Memory coherence and caching policies represent another layer of complexity that BufferView manages transparently. Different GPU memory spaces have different coherence guarantees and caching behaviors. Global memory accesses can be cached in L1 and L2 caches with various coherence protocols, while LDS memory provides workgroup-level coherence with specialized banking structures (see :ref:`ck_tile_lds_bank_conflicts` for details on avoiding bank conflicts). BufferView encapsulates these details, automatically applying the appropriate memory ordering constraints and cache control directives based on the target address space and operation type. Address Space Usage Patterns ---------------------------- -[TO DO: explain in words what the diagram shows] +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. Original mermaid diagram (edit here, then run update_diagrams.py) @@ -66,18 +63,27 @@ Address Space Usage Patterns style Compute fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + + + + .. image:: diagrams/buffer_views_1.svg :alt: Diagram :align: center + +C++ Implementation +------------------ +**File**: ``include/ck_tile/core/tensor/buffer_view.hpp`` Basic Creation ~~~~~~~~~~~~~~ -[TO DO: remove "modern C++ template metaprogramming" and "zero-overhead abstraction"] +By encoding critical properties such as buffer size and address space as template parameters, BufferView transforms what would traditionally be runtime decisions into compile-time constants. This design philosophy enables the compiler to perform aggressive optimizations, including constant propagation, loop unrolling, and instruction selection, that would be impossible with runtime parameters. -[TO DO: might want to move the implementation details to a separate section under "reference"] +The use of compile-time constants extends beyond mere optimization. When the buffer size is encoded in the type system using constructs like ``number<8>{}``, the compiler can statically verify that array accesses are within bounds, eliminate unnecessary bounds checks, and even restructure algorithms to better match the known data dimensions. This compile-time knowledge propagates through the entire computation, enabling optimizations at every level of the abstraction hierarchy. +The address space template parameter represents another crucial design decision. By making the memory space part of the type system, BufferView ensures that operations appropriate for one memory space cannot be accidentally applied to another. This type safety prevents common errors such as attempting atomic operations on register memory or using global memory synchronization primitives on local memory. The compiler enforces these constraints at compile time, transforming potential runtime errors into compile-time diagnostics. .. code-block:: cpp @@ -98,7 +104,6 @@ Basic Creation buffer_size // number of elements ); - // Implementation detail: The actual C++ template is: // template (data, buffer_size, custom_invalid); - - // Invalid element access with is_valid_element=false - // Returns custom_invalid due to custom invalid value mode - auto invalid_value = buffer_view.template get(0, 0, false); - printf("Invalid element: %.1f\n", invalid_value.get(0)); - - // Out of bounds access - AMD buffer addressing handles bounds checking - // Will return custom_invalid when accessing beyond buffer_size - auto oob_value = buffer_view.template get(0, 100, true); - printf("Out of bounds: %.1f\n", oob_value.get(0)); - - - - - Get Operations -------------- -[TO DO: might want to put this implementation detail in the reference section] +Scalar Access +~~~~~~~~~~~~~ -The signature for the ``buffer_view`` ``get()`` takes four parameters: +The get operations in BufferView form the cornerstone of memory access patterns in CK Tile. These operations embody a advanced understanding of GPU memory systems and the patterns that lead to optimal performance. The scalar access interface incorporates multiple layers of optimization and safety mechanisms that work together to provide both performance and correctness. -``i``: the primary offset into the buffer expressed in terms of elements of type T rather than raw bytes. +The parameter structure of scalar access operations reflects careful design choices aimed at maximizing flexibility while maintaining efficiency. The base index parameter ``i`` represents the primary offset into the buffer, expressed in terms of elements of type T rather than raw bytes. This type-aware indexing prevents common errors related to pointer arithmetic and ensures that vector types are handled correctly. The additional ``linear_offset`` parameter provides fine-grained control over the final access location, enabling complex access patterns without requiring expensive index calculations in the kernel code. -``linear_offset``: [TO DO: what is this?] +The ``is_valid_element`` parameter provides a solution to conditional memory access. Rather than using traditional if-statements that would cause warp divergence, this boolean parameter enables predicated execution where the memory access occurs unconditionally but the result is conditionally used. This approach maintains uniform control flow across all threads in a warp, preserving the SIMD execution model that is fundamental to GPU performance. -``is_valid_element``: [TO DO: what is this?] +The invalid value modes provide a mechanism for handling the boundary conditions that arise in parallel algorithms. When ``InvalidElementUseNumericalZeroValue`` is set to true, the system returns zero for any invalid access, whether due to the ``is_valid_element`` flag or out-of-bounds indexing. This mode is important for algorithms where zero serves as a natural extension value, such as in image processing with zero-padding or sparse matrix operations where missing elements are implicitly zero. -[TO DO: the last param, that's the out of bounds handling, yes? -.. code:: cpp +The custom invalid value mode, activated when ``InvalidElementUseNumericalZeroValue`` is false, offers additional flexibility for algorithms with specific boundary requirements. This mode returns a user-specified value for invalid accesses, accommodating use cases such as sentinel values in sorting algorithms, infinity values in optimization problems, or special markers in data processing pipelines. The implementation ensures that this flexibility comes without performance penalty, using the same branchless execution strategies as the zero mode. - get(index_t i, - index_t linear_offset, - bool is_valid_element, - bool_constant = {}) +Out-of-bounds handling leverages AMD GPU hardware capabilities to provide safety with minimal impact to performance. When AMD buffer addressing is enabled, the hardware automatically clamps memory accesses to valid ranges, preventing the segmentation faults that would occur on CPU systems. This hardware-assisted bounds checking operates at wire speed, adding no overhead to the memory access path while ensuring that kernels cannot corrupt memory outside their allocated regions. +Vector Access +~~~~~~~~~~~~~ -[TO DO: need some context around the code] +Vector memory operations represent one of the most critical optimizations available in modern GPU programming, and BufferView's vector access interface exposes this capability. By using template parameters to specify vector types through constructs like ``ext_vector_t``, the interface enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction. This vectorization is crucial for :ref:`ck_tile_load_store_traits`, which automatically selects optimal access patterns. -[TO DO: code chunks need to have detail and explanation so that the reader can see what they're trying to demonstrate.] +The significance of vector operations extends beyond bandwidth improvements. GPUs are designed with wide memory buses that can transfer 128, 256, or even 512 bits per transaction. When scalar operations access only 32 bits at a time, they utilize only a fraction of this available bandwidth. Vector operations align with these wide buses, enabling full bandwidth utilization and reducing the total number of memory transactions required. +The implementation of vector access maintains the same parameter structure as scalar operations, providing consistency across the API while automatically handling the complexities of multi-element transfers. The system manages alignment requirements, ensures that vector loads and stores use the optimal hardware instructions, and handles cases where vector operations extend beyond buffer boundaries. This transparent handling of edge cases allows developers to use vector operations confidently without manual boundary checks or special-case code for partial vectors. -.. code-block:: cpp - - // Create buffer view - float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; - auto buffer_view = make_buffer_view(data, 8); - - // Simple get - compile-time bounds checking when possible - auto value_buf = buffer_view.template get(0,1,true); //get the buffer from the buffer view - float value = value_buf.get(0); //get the value from the buffer - - // Get with valid flag - branchless conditional access - bool valid_flag = false; - value_buf = buffer_view.template get(0,1,valid_flag); - value = value_buf.get(0); - // Returns 0 valid_flag is false - - // vectorized get - using float2 = ext_vector_t; - auto vector_buf = buffer_view.template get(0, 0, true); - // Loads 2 floats in a single instruction - float val1 = vector_buf.get(0); - float val2 = vector_buf.get(1); - } - -``ext_vector_t`` enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction. - -[TO DO: what is it actually doing? When does one use scalars vs vectors? Is it application specific or are there ] +Scalar vs Vectorized Memory Access +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. Original mermaid diagram (edit here, then run update_diagrams.py) @@ -287,8 +236,9 @@ The signature for the ``buffer_view`` ``get()`` takes four parameters: Understanding BufferView Indexing ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -[TO DO: an explanation of the diagram is needed] - +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. Original mermaid diagram (edit here, then run update_diagrams.py) @@ -335,14 +285,69 @@ Understanding BufferView Indexing .. image:: diagrams/buffer_views_3.svg :alt: Diagram :align: center - - + +C++ Get Operations +~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + __device__ void example_get_operations() + { + // Create buffer view + float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + auto buffer_view = make_buffer_view(data, 8); + + // Simple get - compile-time bounds checking when possible + auto value_buf = buffer_view.template get(0,1,true); //get the buffer from the buffer view + float value = value_buf.get(0); //get the value from the buffer + + // Get with valid flag - branchless conditional access + bool valid_flag = false; + value_buf = buffer_view.template get(0,1,valid_flag); + value = value_buf.get(0); + // Returns 0 valid_flag is false + + // vectorized get + using float2 = ext_vector_t; + auto vector_buf = buffer_view.template get(0, 0, true); + // Loads 2 floats in a single instruction + float val1 = vector_buf.get(0); + float val2 = vector_buf.get(1); + } + +Custom Value Return Mode for OOB & Invalid Access +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + void scalar_get_operations_example() { + + // Create data array + constexpr size_t buffer_size = 8; + float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + float custom_invalid = 13.0f; + + // Create global memory buffer view with zero invalid value mode (default) + auto buffer_view = make_buffer_view(data, buffer_size, custom_invalid); + + // Invalid element access with is_valid_element=false + // Returns custom_invalid due to custom invalid value mode + auto invalid_value = buffer_view.template get(0, 0, false); + printf("Invalid element: %.1f\n", invalid_value.get(0)); + + // Out of bounds access - AMD buffer addressing handles bounds checking + // Will return custom_invalid when accessing beyond buffer_size + auto oob_value = buffer_view.template get(0, 100, true); + printf("Out of bounds: %.1f\n", oob_value.get(0)); + } + +.. note:: + + Partial Out Of Bound (OOB) access during vector reads will return 'junk' values for the OOB access. Zero or custom invalid value is only returned for complete invalid/OOB access, in other words, it is only returned when the first address of the vector is invalid. Update Operations ----------------- -Update operations modify the buffer content. The ``set()`` method writes a value to a specific location. - .. code-block:: cpp void scalar_set_operations_example() { @@ -373,8 +378,6 @@ Update operations modify the buffer content. The ``set()`` method writes a value Atomic Operations ----------------- -[TO DO: this needs information] - Atomic vs Non-Atomic Operations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -441,3 +444,21 @@ C++ Atomic Operations __syncthreads(); } + +Summary +------- + +BufferView abstracts GPU memory hierarchies behind a concise interface. The approach is intended to keep overhead small while enabling optimizations that are otherwise awkward in low-level code. + +BufferView offers a unified interface across global, shared, and register memory. Using the same API for each space can lower cognitive overhead, reduce certain classes of mistakes, and support code reuse via template parameters. + +Address spaces are encoded in types so that common errors are reported at compile time. Consistent with CK Tile’s zero-overhead design aim, compile-time checks are favored over runtime guards. The C++ type system enforces memory-space constraints and can make valid cases more amenable to compiler optimization. + +BufferView supports configurable handling of invalid values, optional runtime bounds checks, and conditional access patterns. It also provides atomic operations for thread-safe updates. These features are intended to cover common edge cases without adding unnecessary overhead. + +By hiding the complexity of different memory spaces while exposing the operations needed for high-performance GPU computing, BufferView establishes a pattern that the rest of CK Tile follows: compile-time abstractions that enhance rather than compromise performance. The :ref:`ck_tile_tensor_views` and :ref:`ck_tile_distribution` add capability while maintaining the efficiency established at the base. For hardware-specific details about memory hierarchies, see :ref:`ck_tile_gpu_basics`. + +Next Steps +---------- + +Continue to :ref:`ck_tile_tensor_views` to learn how to build structured tensor views on top of buffer views. diff --git a/docs/conceptual/ck_tile/convolution_example.rst b/docs/conceptual/ck_tile/convolution_example.rst index a981ae04da..c2fe62bb22 100644 --- a/docs/conceptual/ck_tile/convolution_example.rst +++ b/docs/conceptual/ck_tile/convolution_example.rst @@ -59,10 +59,6 @@ The key insight is that convolution can be transformed from a complex nested loo -.. image:: diagrams/convolution_example.svg - :alt: Diagram - :align: center - .. image:: diagrams/convolution_example.svg :alt: Diagram :align: center @@ -88,7 +84,6 @@ Non-overlapping tiles: // Original matrix: shape=(6, 6), strides=(6, 1) // Tiled view: shape=(3, 3, 2, 2), strides=(12, 2, 6, 1) - // See :ref:`ck_tile_descriptors` for descriptor details using TileDescriptor = TensorDescriptor< Sequence, Sequence<12, 2, 6, 1> @@ -243,7 +238,6 @@ The im2col transformation converts the 4D windows tensor into a 2D matrix suitab >; // Step 2: Apply merge transforms to create 2D im2col layout - // See :ref:`ck_tile_transforms` for transform operations using Im2colDescriptor = decltype( transform_tensor_descriptor( WindowsDescriptor{}, @@ -312,7 +306,6 @@ Combining all components into an optimized convolution implementation: >; // Tile distribution for matrix multiplication - // See :ref:`ck_tile_tile_distribution` for details using ATileDist = TileDistribution< Sequence, Sequence @@ -327,7 +320,6 @@ Combining all components into an optimized convolution implementation: >; // Thread-local accumulator - // See :ref:`ck_tile_static_distributed_tensor` StaticDistributedTensor c_accumulator; // Initialize accumulator @@ -339,7 +331,6 @@ Combining all components into an optimized convolution implementation: // Main GEMM loop over K dimension for (index_t k_tile = 0; k_tile < PatchSize; k_tile += TileK) { // Create tile windows for im2col matrix and kernel - // See :ref:`ck_tile_tile_window` for window operations auto a_window = make_tile_window( input, Im2colDesc{H, W, K}, {blockIdx.y * TileM, k_tile} @@ -350,7 +341,7 @@ Combining all components into an optimized convolution implementation: {k_tile, 0} ); - // Load tiles - see :ref:`ck_tile_load_store_traits` for optimization + // Load tiles auto a_tile = a_window.load(); auto b_tile = b_window.load(); @@ -476,7 +467,6 @@ CK Tile enables several optimizations for convolution: __shared__ float smem_b[TileK][TileN]; // Collaborative loading with proper bank conflict avoidance - // See :ref:`ck_tile_lds_bank_conflicts` for optimization auto load_tile_to_smem = [&](auto& window, float smem[][TileK]) { #pragma unroll for (index_t i = threadIdx.y; i < TileM; i += blockDim.y) { @@ -560,7 +550,7 @@ This example demonstrates how CK Tile transforms convolution from a memory-bound - **Sliding windows** can be efficiently represented using tensor descriptors with appropriate strides - **Im2col transformation** converts convolution to matrix multiplication without data copies -- **Tile distribution** enables optimal work distribution across GPU threads (see :ref:`ck_tile_tile_distribution`) +- **Tile distribution** enables optimal work distribution across GPU threads (see :ref:`ck_tile_distribution`) - **Multi-channel support** extends naturally through higher-dimensional descriptors - **Performance optimizations** like vectorization and shared memory are seamlessly integrated (see :ref:`ck_tile_gemm_optimization` for similar techniques) diff --git a/docs/conceptual/ck_tile/coordinate_movement.rst b/docs/conceptual/ck_tile/coordinate_movement.rst index 73633afa88..78d864bf75 100644 --- a/docs/conceptual/ck_tile/coordinate_movement.rst +++ b/docs/conceptual/ck_tile/coordinate_movement.rst @@ -317,7 +317,7 @@ Movement Through Adaptors Advanced Movement Patterns ========================== -Real-world applications use advanced movement patterns for optimal memory access. These patterns often relate to :ref:`ck_tile_tile_window` operations and :ref:`ck_tile_tile_distribution` concepts: +Real-world applications use advanced movement patterns for optimal memory access. These patterns often relate to :ref:`ck_tile_tile_window` operations and :ref:`ck_tile_distribution` concepts: Tiled Access Pattern -------------------- diff --git a/docs/conceptual/ck_tile/descriptors.rst b/docs/conceptual/ck_tile/descriptors.rst index 3a52097d06..449e7bc4b1 100644 --- a/docs/conceptual/ck_tile/descriptors.rst +++ b/docs/conceptual/ck_tile/descriptors.rst @@ -315,18 +315,18 @@ Padding for Convolution .. code-block:: cpp -// Add padding to spatial dimensions - auto padded = transform_tensor_descriptor( - input_tensor, - make_tuple( - make_pass_through_transform(N), // Batch - make_pass_through_transform(C), // Channel - make_pad_transform(H, pad_h, pad_h), // Height - make_pad_transform(W, pad_w, pad_w) // Width - ), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}) - ); + // Add padding to spatial dimensions + auto padded = transform_tensor_descriptor( + input_tensor, + make_tuple( + make_pass_through_transform(N), // Batch + make_pass_through_transform(C), // Channel + make_pad_transform(H, pad_h, pad_h), // Height + make_pad_transform(W, pad_w, pad_w) // Width + ), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}) + ); For a complete convolution example, see :ref:`ck_tile_convolution_example`. diff --git a/docs/conceptual/ck_tile/hardware/gemm_optimization.rst b/docs/conceptual/ck_tile/hardware/gemm_optimization.rst index a31b6b7803..7a99577290 100644 --- a/docs/conceptual/ck_tile/hardware/gemm_optimization.rst +++ b/docs/conceptual/ck_tile/hardware/gemm_optimization.rst @@ -260,7 +260,6 @@ Here's how CK Tile implements an optimized GEMM kernel: index_t K) { // Define tile distribution encoding - // See :ref:`ck_tile_encoding_internals` and :ref:`ck_tile_tile_distribution` using Encoding = tile_distribution_encoding< sequence<>, // No replication tuple, // M dimension hierarchy @@ -274,7 +273,6 @@ Here's how CK Tile implements an optimized GEMM kernel: constexpr auto tile_dist = make_static_tile_distribution(Encoding{}); // Create tensor views for global memory - // See :ref:`ck_tile_tensor_views` and :ref:`ck_tile_buffer_views` auto a_global_view = make_naive_tensor_view( a_global, make_tuple(M, K), make_tuple(K, 1)); auto b_global_view = make_naive_tensor_view( @@ -287,7 +285,6 @@ Here's how CK Tile implements an optimized GEMM kernel: const index_t block_n_id = blockIdx.x; // Create tile windows for loading - // See :ref:`ck_tile_tile_window` for tile window details auto a_window = make_tile_window( a_global_view, make_tuple(number{}, number{}), @@ -301,7 +298,6 @@ Here's how CK Tile implements an optimized GEMM kernel: tile_dist); // Allocate LDS storage - // See :ref:`ck_tile_static_distributed_tensor` for distributed tensors auto a_lds = make_static_distributed_tensor(); auto b_lds = make_static_distributed_tensor(); - // See :ref:`ck_tile_sweep_tile` for sweep operations sweep_tile(c_reg, [](auto idx, auto& val) { val = 0; }); // Main GEMM loop with pipelining @@ -324,7 +319,6 @@ Here's how CK Tile implements an optimized GEMM kernel: // Pipeline loop for(index_t k_tile = 0; k_tile < num_k_tiles - 1; ++k_tile) { // Move windows for next iteration - // See :ref:`ck_tile_coordinate_movement` for window movement a_window.move_slice_window(make_tuple(0, KPerBlock)); b_window.move_slice_window(make_tuple(0, KPerBlock)); diff --git a/docs/conceptual/ck_tile/hardware/lds_bank_conflicts.rst b/docs/conceptual/ck_tile/hardware/lds_bank_conflicts.rst index 8802fba9e8..cca18035fe 100644 --- a/docs/conceptual/ck_tile/hardware/lds_bank_conflicts.rst +++ b/docs/conceptual/ck_tile/hardware/lds_bank_conflicts.rst @@ -172,7 +172,6 @@ Example usage in CK Tile: a_window.load(a_lds_tensor); // Subsequent reads from LDS are conflict-free - // See :ref:`ck_tile_sweep_tile` for sweep operations sweep_tile(a_lds_tensor, [](auto idx, auto& val) { // Process data... }); diff --git a/docs/conceptual/ck_tile/introduction_motivation.rst b/docs/conceptual/ck_tile/introduction_motivation.rst index 9884901556..e6f2112311 100644 --- a/docs/conceptual/ck_tile/introduction_motivation.rst +++ b/docs/conceptual/ck_tile/introduction_motivation.rst @@ -276,7 +276,7 @@ The foundation of the exploration begins with raw memory access through :ref:`ck With these foundational concepts established, the documentation delves into the :ref:`ck_tile_coordinate_systems` that powers tile distribution. This engine implements the mathematical framework that have been introduced, providing compile-time transformations between P-space, Y-space, X-space, and D-space. Understanding these transformations at a deep level enables developers to reason about performance implications and design custom distribution strategies for novel algorithms. The :ref:`ck_tile_transforms` and :ref:`ck_tile_adaptors` provide the building blocks for these transformations. -The high-level :ref:`ck_tile_distribution` APIs represent the culmination of these lower-level abstractions. These APIs provide an accessible interface for common patterns while exposing enough flexibility for advanced optimizations. Through concrete examples and detailed explanations, the documentation will demonstrate how to leverage these APIs to achieve near-optimal performance across a variety of computational patterns. The :ref:`ck_tile_window` abstraction provides the gateway for efficient data access. +The high-level :ref:`ck_tile_distribution` APIs represent the culmination of these lower-level abstractions. These APIs provide an accessible interface for common patterns while exposing enough flexibility for advanced optimizations. Through concrete examples and detailed explanations, the documentation will demonstrate how to leverage these APIs to achieve near-optimal performance across a variety of computational patterns. The :ref:`ck_tile_tile_window` abstraction provides the gateway for efficient data access. The exploration of coordinate systems goes beyond the basic P, Y, X, D framework to encompass advanced topics such as multi-level tiling, replication strategies, and specialized coordinate systems for specific algorithm classes. The :ref:`ck_tile_encoding_internals` reveals the mathematical foundations, while :ref:`ck_tile_thread_mapping` shows how these abstractions map to hardware. This comprehensive treatment ensures that developers can handle not just common cases but also novel algorithms that require custom distribution strategies. diff --git a/docs/conceptual/ck_tile/lds_index_swapping.rst b/docs/conceptual/ck_tile/lds_index_swapping.rst index 891b32f9ed..b0a2b32010 100644 --- a/docs/conceptual/ck_tile/lds_index_swapping.rst +++ b/docs/conceptual/ck_tile/lds_index_swapping.rst @@ -5,7 +5,7 @@ .. _ck_tile_lds_index_swapping: ******************************** -Load Datat Share Index Swapping +Load Data Share Index Swapping ******************************** Overview @@ -70,9 +70,9 @@ The original K coordinate is split into K0 and K1, where K1 represents the threa The XOR transformation updates the K0 coordinate using the formula: -.. code-block:: cpp +.. math:: - K0' = K0 ^ (M % (KPerBlock / KPack * MLdsLayer)) + K0' = K0^{(M \% (KPerBlock / KPack * MLdsLayer))} This XOR operation redistributes accesses across memory banks by mixing bits from the M and K dimensions. @@ -132,10 +132,10 @@ The transformed K0' is split into L and K0'' components, creating an intermediat The unmerge operation: -.. code-block:: cpp +.. math:: L = K0' / (KPerBlock/KPack) - K0'' = K0' % (KPerBlock/KPack) + K0'' = K0' \% (KPerBlock/KPack) When MLdsLayer == 1, this simplifies to L=0 and K0''=K0'. diff --git a/docs/conceptual/ck_tile/load_store_traits.rst b/docs/conceptual/ck_tile/load_store_traits.rst index f9555a8bfe..bf2decc37e 100644 --- a/docs/conceptual/ck_tile/load_store_traits.rst +++ b/docs/conceptual/ck_tile/load_store_traits.rst @@ -71,7 +71,6 @@ The LoadStoreTraits class analyzes distribution patterns at compile time: static constexpr index_t scalars_per_access = scalar_per_vector; // Space-filling curve for optimal traversal - // See :ref:`ck_tile_space_filling_curve` for details using sfc_type = space_filling_curve; static constexpr sfc_type sfc_ys = make_space_filling_curve(); @@ -274,7 +273,7 @@ LoadStoreTraits optimizes for several performance metrics: return Traits::num_access; } - // Check coalescing efficiency (see :ref:`ck_tile_gpu_basics`) + // Check coalescing efficiency static constexpr bool is_perfectly_coalesced() { // Perfect coalescing when adjacent threads access adjacent memory @@ -316,7 +315,6 @@ Comparing Different Configurations static_assert(OptimizedAnalyzer::bandwidth_utilization() == 50.0f); // 8*4/64 // Better bandwidth utilization leads to improved performance - // See :ref:`ck_tile_gemm_optimization` for real-world examples Integration with Space-Filling Curves ------------------------------------- diff --git a/docs/conceptual/ck_tile/space_filling_curve.rst b/docs/conceptual/ck_tile/space_filling_curve.rst index 4b95f71a69..869285b462 100644 --- a/docs/conceptual/ck_tile/space_filling_curve.rst +++ b/docs/conceptual/ck_tile/space_filling_curve.rst @@ -254,7 +254,6 @@ For :ref:`matrix multiplication `, optimal access pat // GEMM tile: 16x32 with vector-8 loads // Column-major for coalesced access in GEMM - // See :ref:`ck_tile_gemm_optimization` for complete example using GemmTileCurve = space_filling_curve< 2, sequence<16, 32>, // Tile size @@ -336,7 +335,7 @@ Optimizing for Hardware .. code-block:: cpp - // Optimize for GPU memory coalescing (see :ref:`ck_tile_gpu_basics`) + // Optimize for GPU memory coalescing template struct coalesced_access_pattern { @@ -411,7 +410,6 @@ LoadStoreTraits Integration struct load_store_traits { // Create optimized space-filling curve - // See :ref:`ck_tile_tile_distribution` for Distribution details using sfc_type = space_filling_curve< Distribution::ndim_y, typename Distribution::y_lengths, @@ -461,7 +459,6 @@ Best Practices .. code-block:: cpp // Match vector size to cache line for optimal bandwidth - // See :ref:`ck_tile_lds_bank_conflicts` for cache optimization constexpr index_t optimal_vector = min( tensor_length_fast_dim, cache_line_size / sizeof(DataType) diff --git a/docs/conceptual/ck_tile/static_distributed_tensor.rst b/docs/conceptual/ck_tile/static_distributed_tensor.rst index bfd50c0899..1f7a93657f 100644 --- a/docs/conceptual/ck_tile/static_distributed_tensor.rst +++ b/docs/conceptual/ck_tile/static_distributed_tensor.rst @@ -17,9 +17,9 @@ Each thread in a workgroup owns a portion of the overall tensor data, stored in This design enables three critical optimizations: - * It maximizes register utilization by keeping frequently accessed data in the fastest memory hierarchy. - * It eliminates redundant memory accesses since each thread maintains its own working set. - * It provides a clean abstraction for complex algorithms like matrix multiplication where each thread accumulates partial results that eventually combine into the final output. +* It maximizes register utilization by keeping frequently accessed data in the fastest memory hierarchy. +* It eliminates redundant memory accesses since each thread maintains its own working set. +* It provides a clean abstraction for complex algorithms like matrix multiplication where each thread accumulates partial results that eventually combine into the final output. Thread-Local Storage Model ========================== @@ -384,8 +384,7 @@ Static distributed tensors integrate seamlessly with other CK Tile components: // Main GEMM loop for(index_t k_tile = 0; k_tile < K; k_tile += kTileK) { // Create tile windows for this iteration - // See :ref:`ck_tile_tile_window` for details - auto a_window = make_tile_window( + auto a_window = make_tile_window( a_ptr, ALayout{M, K}, ATileDist{}, {blockIdx.y * kTileM, k_tile} @@ -398,7 +397,6 @@ Static distributed tensors integrate seamlessly with other CK Tile components: ); // Load tiles to distributed tensors - // See :ref:`ck_tile_load_store_traits` for optimized loading auto a_tile = a_window.load(); auto b_tile = b_window.load(); diff --git a/docs/conceptual/ck_tile/thread_mapping.rst b/docs/conceptual/ck_tile/thread_mapping.rst index cff4f727ff..361912ba9f 100644 --- a/docs/conceptual/ck_tile/thread_mapping.rst +++ b/docs/conceptual/ck_tile/thread_mapping.rst @@ -356,7 +356,6 @@ CK uses several techniques to optimize memory access: float>>>; // 2. Swizzling to avoid bank conflicts - // See :ref:`ck_tile_lds_index_swapping` and :ref:`ck_tile_swizzling_example` template __device__ index_t swizzle_offset(index_t tid, index_t offset) { @@ -434,7 +433,6 @@ The following example shows how thread mapping works in a CK kernel: __shared__ ComputeType shared_sum[BlockSize]; // 5. Create tensor view and tile window - // See :ref:`ck_tile_tensor_views` and :ref:`ck_tile_tile_window` auto x_view = make_naive_tensor_view( x + bid * hidden_size, make_tuple(hidden_size), diff --git a/docs/conceptual/ck_tile/tile_distribution.rst b/docs/conceptual/ck_tile/tile_distribution.rst index c57a87e5ce..3c016318bf 100644 --- a/docs/conceptual/ck_tile/tile_distribution.rst +++ b/docs/conceptual/ck_tile/tile_distribution.rst @@ -1,4 +1,4 @@ -.. _ck_tile_distribution: +.. _ck_tile_tile_distribution: Tile Distribution - The Core API ================================ diff --git a/docs/conceptual/ck_tile/tile_window.rst b/docs/conceptual/ck_tile/tile_window.rst index 87d2f39b01..23c006d972 100644 --- a/docs/conceptual/ck_tile/tile_window.rst +++ b/docs/conceptual/ck_tile/tile_window.rst @@ -283,7 +283,7 @@ Creating and Using TileWindow using namespace ck_tile; - // Create a tensor view for input data (see :ref:`ck_tile_tensor_views`) + // Create a tensor view for input data auto tensor_view = make_naive_tensor_view( data_ptr, make_tuple(256, 256), // Shape @@ -314,7 +314,7 @@ Creating and Using TileWindow distribution ); - // Load data into distributed tensor (see :ref:`ck_tile_static_distributed_tensor`) + // Load data into distributed tensor auto distributed_data = make_static_distributed_tensor(distribution); window.load(distributed_data); @@ -558,7 +558,6 @@ Complete Load-Compute-Store Pipeline c_dist); // Create distributed tensors for register storage - // See :ref:`ck_tile_static_distributed_tensor` for details auto a_reg = make_static_distributed_tensor(a_dist); auto b_reg = make_static_distributed_tensor(b_dist); auto c_reg = make_static_distributed_tensor(c_dist); @@ -620,6 +619,8 @@ Performance Characteristics .. image:: diagrams/tile_window_5.svg :alt: Diagram :align: center + + Best Practices -------------- diff --git a/docs/conceptual/ck_tile/transforms.rst b/docs/conceptual/ck_tile/transforms.rst index 63b830563e..3dfea276cb 100644 --- a/docs/conceptual/ck_tile/transforms.rst +++ b/docs/conceptual/ck_tile/transforms.rst @@ -302,7 +302,7 @@ EmbedTransform expands linear indices from the lower coordinate space into multi using namespace ck_tile; // Create embed transform for 2x3 tensor with strides [12, 1] - // This is commonly used in :ref:`descriptors ` + // This is commonly used in descriptors auto transform = make_embed_transform(make_tuple(2, 3), make_tuple(12, 1)); // Forward: Linear → 2D (Manual calculation) diff --git a/docs/conf.py b/docs/conf.py index 58e78f3d1d..bb7847e1d6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,8 +30,6 @@ release = version_number external_toc_path = "./sphinx/_toc.yml" docs_core = ROCmDocs(left_nav_title) -docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") -docs_core.enable_api_reference() docs_core.setup() external_projects_current_project = "composable_kernel" @@ -50,4 +48,4 @@ for sphinx_var in ROCmDocs.SPHINX_VARS: extensions += ['sphinxcontrib.bibtex'] bibtex_bibfiles = ['refs.bib'] -cpp_id_attributes = ["__global__", "__device__", "__host__"] +cpp_id_attributes = ["__global__", "__device__", "__host__"] \ No newline at end of file diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile deleted file mode 100644 index 4c8019f8d3..0000000000 --- a/docs/doxygen/Doxyfile +++ /dev/null @@ -1,2778 +0,0 @@ -# Doxyfile 1.9.7 - -# This file describes the settings to be used by the documentation system -# doxygen (www.doxygen.org) for a project. -# -# All text after a double hash (##) is considered a comment and is placed in -# front of the TAG it is preceding. -# -# All text after a single hash (#) is considered a comment and will be ignored. -# The format is: -# TAG = value [value, ...] -# For lists, items can also be appended using: -# TAG += value [value, ...] -# Values that contain spaces should be placed between quotes (\" \"). -# -# Note: -# -# Use doxygen to compare the used configuration file with the template -# configuration file: -# doxygen -x [configFile] -# Use doxygen to compare the used configuration file with the template -# configuration file without replacing the environment variables or CMake type -# replacement variables: -# doxygen -x_noenv [configFile] - -#--------------------------------------------------------------------------- -# Project related configuration options -#--------------------------------------------------------------------------- - -# This tag specifies the encoding used for all characters in the configuration -# file that follow. The default is UTF-8 which is also the encoding used for all -# text before the first occurrence of this tag. Doxygen uses libiconv (or the -# iconv built into libc) for the transcoding. See -# https://www.gnu.org/software/libiconv/ for the list of possible encodings. -# The default value is: UTF-8. - -DOXYFILE_ENCODING = UTF-8 - -# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by -# double-quotes, unless you are using Doxywizard) that should identify the -# project for which the documentation is generated. This name is used in the -# title of most generated pages and in a few other places. -# The default value is: My Project. - -PROJECT_NAME = "Composable Kernel" - -# The PROJECT_NUMBER tag can be used to enter a project or revision number. This -# could be handy for archiving the generated documentation or if some version -# control system is used. - -PROJECT_NUMBER = - -# Using the PROJECT_BRIEF tag one can provide an optional one line description -# for a project that appears at the top of each page and should give viewer a -# quick idea about the purpose of the project. Keep the description short. - -PROJECT_BRIEF = "Prototype interfaces compatible with ROCm platform and HiP" - -# With the PROJECT_LOGO tag one can specify a logo or an icon that is included -# in the documentation. The maximum height of the logo should not exceed 55 -# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy -# the logo to the output directory. - -PROJECT_LOGO = - -# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path -# into which the generated documentation will be written. If a relative path is -# entered, it will be relative to the location where doxygen was started. If -# left blank the current directory will be used. - -OUTPUT_DIRECTORY = . - -# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 -# sub-directories (in 2 levels) under the output directory of each output format -# and will distribute the generated files over these directories. Enabling this -# option can be useful when feeding doxygen a huge amount of source files, where -# putting all generated files in the same directory would otherwise causes -# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to -# control the number of sub-directories. -# The default value is: NO. - -CREATE_SUBDIRS = NO - -# Controls the number of sub-directories that will be created when -# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every -# level increment doubles the number of directories, resulting in 4096 -# directories at level 8 which is the default and also the maximum value. The -# sub-directories are organized in 2 levels, the first level always has a fixed -# number of 16 directories. -# Minimum value: 0, maximum value: 8, default value: 8. -# This tag requires that the tag CREATE_SUBDIRS is set to YES. - -CREATE_SUBDIRS_LEVEL = 8 - -# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII -# characters to appear in the names of generated files. If set to NO, non-ASCII -# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode -# U+3044. -# The default value is: NO. - -ALLOW_UNICODE_NAMES = NO - -# The OUTPUT_LANGUAGE tag is used to specify the language in which all -# documentation generated by doxygen is written. Doxygen will use this -# information to generate all constant output in the proper language. -# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, -# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English -# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, -# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with -# English messages), Korean, Korean-en (Korean with English messages), Latvian, -# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, -# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, -# Swedish, Turkish, Ukrainian and Vietnamese. -# The default value is: English. - -OUTPUT_LANGUAGE = English - -# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member -# descriptions after the members that are listed in the file and class -# documentation (similar to Javadoc). Set to NO to disable this. -# The default value is: YES. - -BRIEF_MEMBER_DESC = YES - -# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief -# description of a member or function before the detailed description -# -# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the -# brief descriptions will be completely suppressed. -# The default value is: YES. - -REPEAT_BRIEF = YES - -# This tag implements a quasi-intelligent brief description abbreviator that is -# used to form the text in various listings. Each string in this list, if found -# as the leading text of the brief description, will be stripped from the text -# and the result, after processing the whole list, is used as the annotated -# text. Otherwise, the brief description is used as-is. If left blank, the -# following values are used ($name is automatically replaced with the name of -# the entity):The $name class, The $name widget, The $name file, is, provides, -# specifies, contains, represents, a, an and the. - -ABBREVIATE_BRIEF = "The $name class" \ - "The $name widget" \ - "The $name file" \ - is \ - provides \ - specifies \ - contains \ - represents \ - a \ - an \ - the - -# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then -# doxygen will generate a detailed section even if there is only a brief -# description. -# The default value is: NO. - -ALWAYS_DETAILED_SEC = NO - -# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all -# inherited members of a class in the documentation of that class as if those -# members were ordinary class members. Constructors, destructors and assignment -# operators of the base classes will not be shown. -# The default value is: NO. - -INLINE_INHERITED_MEMB = NO - -# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path -# before files name in the file list and in the header files. If set to NO the -# shortest path that makes the file name unique will be used -# The default value is: YES. - -FULL_PATH_NAMES = YES - -# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. -# Stripping is only done if one of the specified strings matches the left-hand -# part of the path. The tag can be used to show relative paths in the file list. -# If left blank the directory from which doxygen is run is used as the path to -# strip. -# -# Note that you can specify absolute paths here, but also relative paths, which -# will be relative from the directory where doxygen is started. -# This tag requires that the tag FULL_PATH_NAMES is set to YES. - -#STRIP_FROM_PATH = -STRIP_FROM_PATH = /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/latest/ - -# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the -# path mentioned in the documentation of a class, which tells the reader which -# header file to include in order to use a class. If left blank only the name of -# the header file containing the class definition is used. Otherwise one should -# specify the list of include paths that are normally passed to the compiler -# using the -I flag. - -STRIP_FROM_INC_PATH = - - -# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but -# less readable) file names. This can be useful is your file systems doesn't -# support long names like on DOS, Mac, or CD-ROM. -# The default value is: NO. - -SHORT_NAMES = NO - -# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the -# first line (until the first dot) of a Javadoc-style comment as the brief -# description. If set to NO, the Javadoc-style will behave just like regular Qt- -# style comments (thus requiring an explicit @brief command for a brief -# description.) -# The default value is: NO. - -JAVADOC_AUTOBRIEF = NO - -# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line -# such as -# /*************** -# as being the beginning of a Javadoc-style comment "banner". If set to NO, the -# Javadoc-style will behave just like regular comments and it will not be -# interpreted by doxygen. -# The default value is: NO. - -JAVADOC_BANNER = NO - -# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first -# line (until the first dot) of a Qt-style comment as the brief description. If -# set to NO, the Qt-style will behave just like regular Qt-style comments (thus -# requiring an explicit \brief command for a brief description.) -# The default value is: NO. - -QT_AUTOBRIEF = NO - -# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a -# multi-line C++ special comment block (i.e. a block of //! or /// comments) as -# a brief description. This used to be the default behavior. The new default is -# to treat a multi-line C++ comment block as a detailed description. Set this -# tag to YES if you prefer the old behavior instead. -# -# Note that setting this tag to YES also means that rational rose comments are -# not recognized any more. -# The default value is: NO. - -MULTILINE_CPP_IS_BRIEF = NO - -# By default Python docstrings are displayed as preformatted text and doxygen's -# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the -# doxygen's special commands can be used and the contents of the docstring -# documentation blocks is shown as doxygen documentation. -# The default value is: YES. - -PYTHON_DOCSTRING = YES - -# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the -# documentation from any documented member that it re-implements. -# The default value is: YES. - -INHERIT_DOCS = YES - -# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new -# page for each member. If set to NO, the documentation of a member will be part -# of the file/class/namespace that contains it. -# The default value is: NO. - -SEPARATE_MEMBER_PAGES = NO - -# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen -# uses this value to replace tabs by spaces in code fragments. -# Minimum value: 1, maximum value: 16, default value: 4. - -TAB_SIZE = 4 - -# This tag can be used to specify a number of aliases that act as commands in -# the documentation. An alias has the form: -# name=value -# For example adding -# "sideeffect=@par Side Effects:^^" -# will allow you to put the command \sideeffect (or @sideeffect) in the -# documentation, which will result in a user-defined paragraph with heading -# "Side Effects:". Note that you cannot put \n's in the value part of an alias -# to insert newlines (in the resulting output). You can put ^^ in the value part -# of an alias to insert a newline as if a physical newline was in the original -# file. When you need a literal { or } or , in the value part of an alias you -# have to escape them by means of a backslash (\), this can lead to conflicts -# with the commands \{ and \} for these it is advised to use the version @{ and -# @} or use a double escape (\\{ and \\}) - -ALIASES = - -# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources -# only. Doxygen will then generate output that is more tailored for C. For -# instance, some of the names that are used will be different. The list of all -# members will be omitted, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_FOR_C = NO - -# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or -# Python sources only. Doxygen will then generate output that is more tailored -# for that language. For instance, namespaces will be presented as packages, -# qualified scopes will look different, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_JAVA = NO - -# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran -# sources. Doxygen will then generate output that is tailored for Fortran. -# The default value is: NO. - -OPTIMIZE_FOR_FORTRAN = NO - -# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL -# sources. Doxygen will then generate output that is tailored for VHDL. -# The default value is: NO. - -OPTIMIZE_OUTPUT_VHDL = NO - -# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice -# sources only. Doxygen will then generate output that is more tailored for that -# language. For instance, namespaces will be presented as modules, types will be -# separated into more groups, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_SLICE = NO - -# Doxygen selects the parser to use depending on the extension of the files it -# parses. With this tag you can assign which parser to use for a given -# extension. Doxygen has a built-in mapping, but you can override or extend it -# using this tag. The format is ext=language, where ext is a file extension, and -# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, -# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, -# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: -# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser -# tries to guess whether the code is fixed or free formatted code, this is the -# default for Fortran type files). For instance to make doxygen treat .inc files -# as Fortran files (default is PHP), and .f files as C (default is Fortran), -# use: inc=Fortran f=C. -# -# Note: For files without extension you can use no_extension as a placeholder. -# -# Note that for custom extensions you also need to set FILE_PATTERNS otherwise -# the files are not read by doxygen. When specifying no_extension you should add -# * to the FILE_PATTERNS. -# -# Note see also the list of default file extension mappings. - -EXTENSION_MAPPING = - -# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments -# according to the Markdown format, which allows for more readable -# documentation. See https://daringfireball.net/projects/markdown/ for details. -# The output of markdown processing is further processed by doxygen, so you can -# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in -# case of backward compatibilities issues. -# The default value is: YES. - -MARKDOWN_SUPPORT = YES - -# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up -# to that level are automatically included in the table of contents, even if -# they do not have an id attribute. -# Note: This feature currently applies only to Markdown headings. -# Minimum value: 0, maximum value: 99, default value: 5. -# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. - -TOC_INCLUDE_HEADINGS = 5 - -# The MARKDOWN_ID_STYLE tag can be used to specify the algorithm used to -# generate identifiers for the Markdown headings. Note: Every identifier is -# unique. -# Possible values are: DOXYGEN Use a fixed 'autotoc_md' string followed by a -# sequence number starting at 0. and GITHUB Use the lower case version of title -# with any whitespace replaced by '-' and punctations characters removed.. -# The default value is: DOXYGEN. -# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. - -MARKDOWN_ID_STYLE = DOXYGEN - -# When enabled doxygen tries to link words that correspond to documented -# classes, or namespaces to their corresponding documentation. Such a link can -# be prevented in individual cases by putting a % sign in front of the word or -# globally by setting AUTOLINK_SUPPORT to NO. -# The default value is: YES. - -AUTOLINK_SUPPORT = YES - -# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want -# to include (a tag file for) the STL sources as input, then you should set this -# tag to YES in order to let doxygen match functions declarations and -# definitions whose arguments contain STL classes (e.g. func(std::string); -# versus func(std::string) {}). This also make the inheritance and collaboration -# diagrams that involve STL classes more complete and accurate. -# The default value is: NO. - -BUILTIN_STL_SUPPORT = YES - -# If you use Microsoft's C++/CLI language, you should set this option to YES to -# enable parsing support. -# The default value is: NO. - -CPP_CLI_SUPPORT = NO - -# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: -# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen -# will parse them like normal C++ but will assume all classes use public instead -# of private inheritance when no explicit protection keyword is present. -# The default value is: NO. - -SIP_SUPPORT = NO - -# For Microsoft's IDL there are propget and propput attributes to indicate -# getter and setter methods for a property. Setting this option to YES will make -# doxygen to replace the get and set methods by a property in the documentation. -# This will only work if the methods are indeed getting or setting a simple -# type. If this is not the case, or you want to show the methods anyway, you -# should set this option to NO. -# The default value is: YES. - -IDL_PROPERTY_SUPPORT = YES - -# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC -# tag is set to YES then doxygen will reuse the documentation of the first -# member in the group (if any) for the other members of the group. By default -# all members of a group must be documented explicitly. -# The default value is: NO. - -DISTRIBUTE_GROUP_DOC = YES - -# If one adds a struct or class to a group and this option is enabled, then also -# any nested class or struct is added to the same group. By default this option -# is disabled and one has to add nested compounds explicitly via \ingroup. -# The default value is: NO. - -GROUP_NESTED_COMPOUNDS = NO - -# Set the SUBGROUPING tag to YES to allow class member groups of the same type -# (for instance a group of public functions) to be put as a subgroup of that -# type (e.g. under the Public Functions section). Set it to NO to prevent -# subgrouping. Alternatively, this can be done per class using the -# \nosubgrouping command. -# The default value is: YES. - -SUBGROUPING = YES - -# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions -# are shown inside the group in which they are included (e.g. using \ingroup) -# instead of on a separate page (for HTML and Man pages) or section (for LaTeX -# and RTF). -# -# Note that this feature does not work in combination with -# SEPARATE_MEMBER_PAGES. -# The default value is: NO. - -INLINE_GROUPED_CLASSES = NO - -# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions -# with only public data fields or simple typedef fields will be shown inline in -# the documentation of the scope in which they are defined (i.e. file, -# namespace, or group documentation), provided this scope is documented. If set -# to NO, structs, classes, and unions are shown on a separate page (for HTML and -# Man pages) or section (for LaTeX and RTF). -# The default value is: NO. - -INLINE_SIMPLE_STRUCTS = NO - -# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or -# enum is documented as struct, union, or enum with the name of the typedef. So -# typedef struct TypeS {} TypeT, will appear in the documentation as a struct -# with name TypeT. When disabled the typedef will appear as a member of a file, -# namespace, or class. And the struct will be named TypeS. This can typically be -# useful for C code in case the coding convention dictates that all compound -# types are typedef'ed and only the typedef is referenced, never the tag name. -# The default value is: NO. - -TYPEDEF_HIDES_STRUCT = YES - -# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This -# cache is used to resolve symbols given their name and scope. Since this can be -# an expensive process and often the same symbol appears multiple times in the -# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small -# doxygen will become slower. If the cache is too large, memory is wasted. The -# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range -# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 -# symbols. At the end of a run doxygen will report the cache usage and suggest -# the optimal cache size from a speed point of view. -# Minimum value: 0, maximum value: 9, default value: 0. - -LOOKUP_CACHE_SIZE = 0 - -# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use -# during processing. When set to 0 doxygen will based this on the number of -# cores available in the system. You can set it explicitly to a value larger -# than 0 to get more control over the balance between CPU load and processing -# speed. At this moment only the input processing can be done using multiple -# threads. Since this is still an experimental feature the default is set to 1, -# which effectively disables parallel processing. Please report any issues you -# encounter. Generating dot graphs in parallel is controlled by the -# DOT_NUM_THREADS setting. -# Minimum value: 0, maximum value: 32, default value: 1. - -NUM_PROC_THREADS = 1 - -# If the TIMESTAMP tag is set different from NO then each generated page will -# contain the date or date and time when the page was generated. Setting this to -# NO can help when comparing the output of multiple runs. -# Possible values are: YES, NO, DATETIME and DATE. -# The default value is: NO. - -TIMESTAMP = YES - -#--------------------------------------------------------------------------- -# Build related configuration options -#--------------------------------------------------------------------------- - -# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in -# documentation are documented, even if no documentation was available. Private -# class members and static file members will be hidden unless the -# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. -# Note: This will also disable the warnings about undocumented members that are -# normally produced when WARNINGS is set to YES. -# The default value is: NO. - -EXTRACT_ALL = YES - -# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will -# be included in the documentation. -# The default value is: NO. - -EXTRACT_PRIVATE = NO - -# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual -# methods of a class will be included in the documentation. -# The default value is: NO. - -EXTRACT_PRIV_VIRTUAL = NO - -# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal -# scope will be included in the documentation. -# The default value is: NO. - -EXTRACT_PACKAGE = NO - -# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be -# included in the documentation. -# The default value is: NO. - -EXTRACT_STATIC = NO - -# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined -# locally in source files will be included in the documentation. If set to NO, -# only classes defined in header files are included. Does not have any effect -# for Java sources. -# The default value is: YES. - -EXTRACT_LOCAL_CLASSES = YES - -# This flag is only useful for Objective-C code. If set to YES, local methods, -# which are defined in the implementation section but not in the interface are -# included in the documentation. If set to NO, only methods in the interface are -# included. -# The default value is: NO. - -EXTRACT_LOCAL_METHODS = NO - -# If this flag is set to YES, the members of anonymous namespaces will be -# extracted and appear in the documentation as a namespace called -# 'anonymous_namespace{file}', where file will be replaced with the base name of -# the file that contains the anonymous namespace. By default anonymous namespace -# are hidden. -# The default value is: NO. - -EXTRACT_ANON_NSPACES = NO - -# If this flag is set to YES, the name of an unnamed parameter in a declaration -# will be determined by the corresponding definition. By default unnamed -# parameters remain unnamed in the output. -# The default value is: YES. - -RESOLVE_UNNAMED_PARAMS = YES - -# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all -# undocumented members inside documented classes or files. If set to NO these -# members will be included in the various overviews, but no documentation -# section is generated. This option has no effect if EXTRACT_ALL is enabled. -# The default value is: NO. - -HIDE_UNDOC_MEMBERS = NO - -# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all -# undocumented classes that are normally visible in the class hierarchy. If set -# to NO, these classes will be included in the various overviews. This option -# will also hide undocumented C++ concepts if enabled. This option has no effect -# if EXTRACT_ALL is enabled. -# The default value is: NO. - -HIDE_UNDOC_CLASSES = NO - -# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend -# declarations. If set to NO, these declarations will be included in the -# documentation. -# The default value is: NO. - -HIDE_FRIEND_COMPOUNDS = NO - -# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any -# documentation blocks found inside the body of a function. If set to NO, these -# blocks will be appended to the function's detailed documentation block. -# The default value is: NO. - -HIDE_IN_BODY_DOCS = NO - -# The INTERNAL_DOCS tag determines if documentation that is typed after a -# \internal command is included. If the tag is set to NO then the documentation -# will be excluded. Set it to YES to include the internal documentation. -# The default value is: NO. - -INTERNAL_DOCS = NO - -# With the correct setting of option CASE_SENSE_NAMES doxygen will better be -# able to match the capabilities of the underlying filesystem. In case the -# filesystem is case sensitive (i.e. it supports files in the same directory -# whose names only differ in casing), the option must be set to YES to properly -# deal with such files in case they appear in the input. For filesystems that -# are not case sensitive the option should be set to NO to properly deal with -# output files written for symbols that only differ in casing, such as for two -# classes, one named CLASS and the other named Class, and to also support -# references to files without having to specify the exact matching casing. On -# Windows (including Cygwin) and MacOS, users should typically set this option -# to NO, whereas on Linux or other Unix flavors it should typically be set to -# YES. -# Possible values are: SYSTEM, NO and YES. -# The default value is: SYSTEM. - -CASE_SENSE_NAMES = NO - -# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with -# their full class and namespace scopes in the documentation. If set to YES, the -# scope will be hidden. -# The default value is: NO. - -HIDE_SCOPE_NAMES = NO - -# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will -# append additional text to a page's title, such as Class Reference. If set to -# YES the compound reference will be hidden. -# The default value is: NO. - -HIDE_COMPOUND_REFERENCE= NO - -# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class -# will show which file needs to be included to use the class. -# The default value is: YES. - -SHOW_HEADERFILE = YES - -# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of -# the files that are included by a file in the documentation of that file. -# The default value is: YES. - -SHOW_INCLUDE_FILES = YES - -# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each -# grouped member an include statement to the documentation, telling the reader -# which file to include in order to use the member. -# The default value is: NO. - -SHOW_GROUPED_MEMB_INC = NO - -# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include -# files with double quotes in the documentation rather than with sharp brackets. -# The default value is: NO. - -FORCE_LOCAL_INCLUDES = NO - -# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the -# documentation for inline members. -# The default value is: YES. - -INLINE_INFO = YES - -# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the -# (detailed) documentation of file and class members alphabetically by member -# name. If set to NO, the members will appear in declaration order. -# The default value is: YES. - -SORT_MEMBER_DOCS = YES - -# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief -# descriptions of file, namespace and class members alphabetically by member -# name. If set to NO, the members will appear in declaration order. Note that -# this will also influence the order of the classes in the class list. -# The default value is: NO. - -SORT_BRIEF_DOCS = NO - -# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the -# (brief and detailed) documentation of class members so that constructors and -# destructors are listed first. If set to NO the constructors will appear in the -# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. -# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief -# member documentation. -# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting -# detailed member documentation. -# The default value is: NO. - -SORT_MEMBERS_CTORS_1ST = NO - -# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy -# of group names into alphabetical order. If set to NO the group names will -# appear in their defined order. -# The default value is: NO. - -SORT_GROUP_NAMES = NO - -# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by -# fully-qualified names, including namespaces. If set to NO, the class list will -# be sorted only by class name, not including the namespace part. -# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. -# Note: This option applies only to the class list, not to the alphabetical -# list. -# The default value is: NO. - -SORT_BY_SCOPE_NAME = NO - -# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper -# type resolution of all parameters of a function it will reject a match between -# the prototype and the implementation of a member function even if there is -# only one candidate or it is obvious which candidate to choose by doing a -# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still -# accept a match between prototype and implementation in such cases. -# The default value is: NO. - -STRICT_PROTO_MATCHING = NO - -# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo -# list. This list is created by putting \todo commands in the documentation. -# The default value is: YES. - -GENERATE_TODOLIST = YES - -# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test -# list. This list is created by putting \test commands in the documentation. -# The default value is: YES. - -GENERATE_TESTLIST = YES - -# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug -# list. This list is created by putting \bug commands in the documentation. -# The default value is: YES. - -GENERATE_BUGLIST = YES - -# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) -# the deprecated list. This list is created by putting \deprecated commands in -# the documentation. -# The default value is: YES. - -GENERATE_DEPRECATEDLIST= YES - -# The ENABLED_SECTIONS tag can be used to enable conditional documentation -# sections, marked by \if ... \endif and \cond -# ... \endcond blocks. - -ENABLED_SECTIONS = - -# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the -# initial value of a variable or macro / define can have for it to appear in the -# documentation. If the initializer consists of more lines than specified here -# it will be hidden. Use a value of 0 to hide initializers completely. The -# appearance of the value of individual variables and macros / defines can be -# controlled using \showinitializer or \hideinitializer command in the -# documentation regardless of this setting. -# Minimum value: 0, maximum value: 10000, default value: 30. - -MAX_INITIALIZER_LINES = 30 - -# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at -# the bottom of the documentation of classes and structs. If set to YES, the -# list will mention the files that were used to generate the documentation. -# The default value is: YES. - -SHOW_USED_FILES = YES - -# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This -# will remove the Files entry from the Quick Index and from the Folder Tree View -# (if specified). -# The default value is: YES. - -SHOW_FILES = YES - -# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces -# page. This will remove the Namespaces entry from the Quick Index and from the -# Folder Tree View (if specified). -# The default value is: YES. - -SHOW_NAMESPACES = YES - -# The FILE_VERSION_FILTER tag can be used to specify a program or script that -# doxygen should invoke to get the current version for each file (typically from -# the version control system). Doxygen will invoke the program by executing (via -# popen()) the command command input-file, where command is the value of the -# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided -# by doxygen. Whatever the program writes to standard output is used as the file -# version. For an example see the documentation. - -FILE_VERSION_FILTER = - -# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed -# by doxygen. The layout file controls the global structure of the generated -# output files in an output format independent way. To create the layout file -# that represents doxygen's defaults, run doxygen with the -l option. You can -# optionally specify a file name after the option, if omitted DoxygenLayout.xml -# will be used as the name of the layout file. See also section "Changing the -# layout of pages" for information. -# -# Note that if you run doxygen from a directory containing a file called -# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE -# tag is left empty. - -LAYOUT_FILE = - -# The CITE_BIB_FILES tag can be used to specify one or more bib files containing -# the reference definitions. This must be a list of .bib files. The .bib -# extension is automatically appended if omitted. This requires the bibtex tool -# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. -# For LaTeX the style of the bibliography can be controlled using -# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the -# search path. See also \cite for info how to create references. - -CITE_BIB_FILES = - -#--------------------------------------------------------------------------- -# Configuration options related to warning and progress messages -#--------------------------------------------------------------------------- - -# The QUIET tag can be used to turn on/off the messages that are generated to -# standard output by doxygen. If QUIET is set to YES this implies that the -# messages are off. -# The default value is: NO. - -QUIET = NO - -# The WARNINGS tag can be used to turn on/off the warning messages that are -# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES -# this implies that the warnings are on. -# -# Tip: Turn warnings on while writing the documentation. -# The default value is: YES. - -WARNINGS = YES - -# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate -# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag -# will automatically be disabled. -# The default value is: YES. - -WARN_IF_UNDOCUMENTED = YES - -# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for -# potential errors in the documentation, such as documenting some parameters in -# a documented function twice, or documenting parameters that don't exist or -# using markup commands wrongly. -# The default value is: YES. - -WARN_IF_DOC_ERROR = YES - -# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete -# function parameter documentation. If set to NO, doxygen will accept that some -# parameters have no documentation without warning. -# The default value is: YES. - -WARN_IF_INCOMPLETE_DOC = YES - -# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that -# are documented, but have no documentation for their parameters or return -# value. If set to NO, doxygen will only warn about wrong parameter -# documentation, but not about the absence of documentation. If EXTRACT_ALL is -# set to YES then this flag will automatically be disabled. See also -# WARN_IF_INCOMPLETE_DOC -# The default value is: NO. - -WARN_NO_PARAMDOC = NO - -# If WARN_IF_UNDOC_ENUM_VAL option is set to YES, doxygen will warn about -# undocumented enumeration values. If set to NO, doxygen will accept -# undocumented enumeration values. If EXTRACT_ALL is set to YES then this flag -# will automatically be disabled. -# The default value is: NO. - -WARN_IF_UNDOC_ENUM_VAL = NO - -# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when -# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS -# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but -# at the end of the doxygen process doxygen will return with a non-zero status. -# If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS_PRINT then doxygen behaves -# like FAIL_ON_WARNINGS but in case no WARN_LOGFILE is defined doxygen will not -# write the warning messages in between other messages but write them at the end -# of a run, in case a WARN_LOGFILE is defined the warning messages will be -# besides being in the defined file also be shown at the end of a run, unless -# the WARN_LOGFILE is defined as - i.e. standard output (stdout) in that case -# the behavior will remain as with the setting FAIL_ON_WARNINGS. -# Possible values are: NO, YES, FAIL_ON_WARNINGS and FAIL_ON_WARNINGS_PRINT. -# The default value is: NO. - -WARN_AS_ERROR = NO - -# The WARN_FORMAT tag determines the format of the warning messages that doxygen -# can produce. The string should contain the $file, $line, and $text tags, which -# will be replaced by the file and line number from which the warning originated -# and the warning text. Optionally the format may contain $version, which will -# be replaced by the version of the file (if it could be obtained via -# FILE_VERSION_FILTER) -# See also: WARN_LINE_FORMAT -# The default value is: $file:$line: $text. - -WARN_FORMAT = "$file:$line: $text" - -# In the $text part of the WARN_FORMAT command it is possible that a reference -# to a more specific place is given. To make it easier to jump to this place -# (outside of doxygen) the user can define a custom "cut" / "paste" string. -# Example: -# WARN_LINE_FORMAT = "'vi $file +$line'" -# See also: WARN_FORMAT -# The default value is: at line $line of file $file. - -WARN_LINE_FORMAT = "at line $line of file $file" - -# The WARN_LOGFILE tag can be used to specify a file to which warning and error -# messages should be written. If left blank the output is written to standard -# error (stderr). In case the file specified cannot be opened for writing the -# warning and error messages are written to standard error. When as file - is -# specified the warning and error messages are written to standard output -# (stdout). - -WARN_LOGFILE = - -#--------------------------------------------------------------------------- -# Configuration options related to the input files -#--------------------------------------------------------------------------- - -# The INPUT tag is used to specify the files and/or directories that contain -# documented source files. You may enter file names like myfile.cpp or -# directories like /usr/src/myproject. Separate the files or directories with -# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING -# Note: If this tag is empty the current directory is searched. - -INPUT = ../../include \ - ../../include/ck/ \ - ../../library/include/ck/library/utility \ - ../../include/ck_tile - -# This tag can be used to specify the character encoding of the source files -# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses -# libiconv (or the iconv built into libc) for the transcoding. See the libiconv -# documentation (see: -# https://www.gnu.org/software/libiconv/) for the list of possible encodings. -# See also: INPUT_FILE_ENCODING -# The default value is: UTF-8. - -INPUT_ENCODING = UTF-8 - -# This tag can be used to specify the character encoding of the source files -# that doxygen parses The INPUT_FILE_ENCODING tag can be used to specify -# character encoding on a per file pattern basis. Doxygen will compare the file -# name with each pattern and apply the encoding instead of the default -# INPUT_ENCODING) if there is a match. The character encodings are a list of the -# form: pattern=encoding (like *.php=ISO-8859-1). See cfg_input_encoding -# "INPUT_ENCODING" for further information on supported encodings. - -INPUT_FILE_ENCODING = - -# If the value of the INPUT tag contains directories, you can use the -# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and -# *.h) to filter out the source-files in the directories. -# -# Note that for custom extensions or not directly supported extensions you also -# need to set EXTENSION_MAPPING for the extension otherwise the files are not -# read by doxygen. -# -# Note the list of default checked file patterns might differ from the list of -# default file extension mappings. -# -# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, -# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, -# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, -# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C -# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, -# *.vhdl, *.ucf, *.qsf and *.ice. - -FILE_PATTERNS = *.c \ - *.cc \ - *.cxx \ - *.cpp \ - *.c++ \ - *.java \ - *.ii \ - *.ixx \ - *.ipp \ - *.i++ \ - *.inl \ - *.idl \ - *.ddl \ - *.odl \ - *.h \ - *.hh \ - *.hxx \ - *.hpp \ - *.h++ \ - *.l \ - *.cs \ - *.d \ - *.php \ - *.php4 \ - *.php5 \ - *.phtml \ - *.inc \ - *.m \ - *.markdown \ - *.md \ - *.mm \ - *.dox \ - *.py \ - *.pyw \ - *.f90 \ - *.f95 \ - *.f03 \ - *.f08 \ - *.f18 \ - *.f \ - *.for \ - *.vhd \ - *.vhdl \ - *.ucf \ - *.qsf \ - *.ice - -# The RECURSIVE tag can be used to specify whether or not subdirectories should -# be searched for input files as well. -# The default value is: NO. - -RECURSIVE = YES - -# The EXCLUDE tag can be used to specify files and/or directories that should be -# excluded from the INPUT source files. This way you can easily exclude a -# subdirectory from a directory tree whose root is specified with the INPUT tag. -# -# Note that relative paths are relative to the directory from which doxygen is -# run. - -EXCLUDE = - -# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or -# directories that are symbolic links (a Unix file system feature) are excluded -# from the input. -# The default value is: NO. - -EXCLUDE_SYMLINKS = NO - -# If the value of the INPUT tag contains directories, you can use the -# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude -# certain files from those directories. -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories for example use the pattern */test/* - -EXCLUDE_PATTERNS = - -# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names -# (namespaces, classes, functions, etc.) that should be excluded from the -# output. The symbol name can be a fully qualified name, a word, or if the -# wildcard * is used, a substring. Examples: ANamespace, AClass, -# ANamespace::AClass, ANamespace::*Test - -EXCLUDE_SYMBOLS = - -# The EXAMPLE_PATH tag can be used to specify one or more files or directories -# that contain example code fragments that are included (see the \include -# command). - -EXAMPLE_PATH = - -# If the value of the EXAMPLE_PATH tag contains directories, you can use the -# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and -# *.h) to filter out the source-files in the directories. If left blank all -# files are included. - -EXAMPLE_PATTERNS = * - -# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be -# searched for input files to be used with the \include or \dontinclude commands -# irrespective of the value of the RECURSIVE tag. -# The default value is: NO. - -EXAMPLE_RECURSIVE = NO - -# The IMAGE_PATH tag can be used to specify one or more files or directories -# that contain images that are to be included in the documentation (see the -# \image command). - -IMAGE_PATH = - -# The INPUT_FILTER tag can be used to specify a program that doxygen should -# invoke to filter for each input file. Doxygen will invoke the filter program -# by executing (via popen()) the command: -# -# -# -# where is the value of the INPUT_FILTER tag, and is the -# name of an input file. Doxygen will then use the output that the filter -# program writes to standard output. If FILTER_PATTERNS is specified, this tag -# will be ignored. -# -# Note that the filter must not add or remove lines; it is applied before the -# code is scanned, but not when the output code is generated. If lines are added -# or removed, the anchors will not be placed correctly. -# -# Note that doxygen will use the data processed and written to standard output -# for further processing, therefore nothing else, like debug statements or used -# commands (so in case of a Windows batch file always use @echo OFF), should be -# written to standard output. -# -# Note that for custom extensions or not directly supported extensions you also -# need to set EXTENSION_MAPPING for the extension otherwise the files are not -# properly processed by doxygen. - -INPUT_FILTER = - -# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern -# basis. Doxygen will compare the file name with each pattern and apply the -# filter if there is a match. The filters are a list of the form: pattern=filter -# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how -# filters are used. If the FILTER_PATTERNS tag is empty or if none of the -# patterns match the file name, INPUT_FILTER is applied. -# -# Note that for custom extensions or not directly supported extensions you also -# need to set EXTENSION_MAPPING for the extension otherwise the files are not -# properly processed by doxygen. - -FILTER_PATTERNS = - -# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using -# INPUT_FILTER) will also be used to filter the input files that are used for -# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). -# The default value is: NO. - -FILTER_SOURCE_FILES = NO - -# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file -# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and -# it is also possible to disable source filtering for a specific pattern using -# *.ext= (so without naming a filter). -# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. - -FILTER_SOURCE_PATTERNS = - -# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that -# is part of the input, its contents will be placed on the main page -# (index.html). This can be useful if you have a project on for instance GitHub -# and want to reuse the introduction page also for the doxygen output. - - -USE_MDFILE_AS_MAINPAGE = - -# The Fortran standard specifies that for fixed formatted Fortran code all -# characters from position 72 are to be considered as comment. A common -# extension is to allow longer lines before the automatic comment starts. The -# setting FORTRAN_COMMENT_AFTER will also make it possible that longer lines can -# be processed before the automatic comment starts. -# Minimum value: 7, maximum value: 10000, default value: 72. - -FORTRAN_COMMENT_AFTER = 72 - -#--------------------------------------------------------------------------- -# Configuration options related to source browsing -#--------------------------------------------------------------------------- - -# If the SOURCE_BROWSER tag is set to YES then a list of source files will be -# generated. Documented entities will be cross-referenced with these sources. -# -# Note: To get rid of all source code in the generated output, make sure that -# also VERBATIM_HEADERS is set to NO. -# The default value is: NO. - -SOURCE_BROWSER = NO - -# Setting the INLINE_SOURCES tag to YES will include the body of functions, -# classes and enums directly into the documentation. -# The default value is: NO. - -INLINE_SOURCES = NO - -# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any -# special comment blocks from generated source code fragments. Normal C, C++ and -# Fortran comments will always remain visible. -# The default value is: YES. - -STRIP_CODE_COMMENTS = YES - -# If the REFERENCED_BY_RELATION tag is set to YES then for each documented -# entity all documented functions referencing it will be listed. -# The default value is: NO. - -REFERENCED_BY_RELATION = NO - -# If the REFERENCES_RELATION tag is set to YES then for each documented function -# all documented entities called/used by that function will be listed. -# The default value is: NO. - -REFERENCES_RELATION = NO - -# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set -# to YES then the hyperlinks from functions in REFERENCES_RELATION and -# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will -# link to the documentation. -# The default value is: YES. - -REFERENCES_LINK_SOURCE = YES - -# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the -# source code will show a tooltip with additional information such as prototype, -# brief description and links to the definition and documentation. Since this -# will make the HTML file larger and loading of large files a bit slower, you -# can opt to disable this feature. -# The default value is: YES. -# This tag requires that the tag SOURCE_BROWSER is set to YES. - -SOURCE_TOOLTIPS = YES - -# If the USE_HTAGS tag is set to YES then the references to source code will -# point to the HTML generated by the htags(1) tool instead of doxygen built-in -# source browser. The htags tool is part of GNU's global source tagging system -# (see https://www.gnu.org/software/global/global.html). You will need version -# 4.8.6 or higher. -# -# To use it do the following: -# - Install the latest version of global -# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file -# - Make sure the INPUT points to the root of the source tree -# - Run doxygen as normal -# -# Doxygen will invoke htags (and that will in turn invoke gtags), so these -# tools must be available from the command line (i.e. in the search path). -# -# The result: instead of the source browser generated by doxygen, the links to -# source code will now point to the output of htags. -# The default value is: NO. -# This tag requires that the tag SOURCE_BROWSER is set to YES. - -USE_HTAGS = NO - -# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a -# verbatim copy of the header file for each class for which an include is -# specified. Set to NO to disable this. -# See also: Section \class. -# The default value is: YES. - -VERBATIM_HEADERS = YES - -#--------------------------------------------------------------------------- -# Configuration options related to the alphabetical class index -#--------------------------------------------------------------------------- - -# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all -# compounds will be generated. Enable this if the project contains a lot of -# classes, structs, unions or interfaces. -# The default value is: YES. - -ALPHABETICAL_INDEX = YES - -# The IGNORE_PREFIX tag can be used to specify a prefix (or a list of prefixes) -# that should be ignored while generating the index headers. The IGNORE_PREFIX -# tag works for classes, function and member names. The entity will be placed in -# the alphabetical list under the first letter of the entity name that remains -# after removing the prefix. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. - -IGNORE_PREFIX = - -#--------------------------------------------------------------------------- -# Configuration options related to the HTML output -#--------------------------------------------------------------------------- - -# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output -# The default value is: YES. - -GENERATE_HTML = YES - -# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a -# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of -# it. -# The default directory is: html. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_OUTPUT = html - -# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each -# generated HTML page (for example: .htm, .php, .asp). -# The default value is: .html. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_FILE_EXTENSION = .html - -# The HTML_HEADER tag can be used to specify a user-defined HTML header file for -# each generated HTML page. If the tag is left blank doxygen will generate a -# standard header. -# -# To get valid HTML the header file that includes any scripts and style sheets -# that doxygen needs, which is dependent on the configuration options used (e.g. -# the setting GENERATE_TREEVIEW). It is highly recommended to start with a -# default header using -# doxygen -w html new_header.html new_footer.html new_stylesheet.css -# YourConfigFile -# and then modify the file new_header.html. See also section "Doxygen usage" -# for information on how to generate the default header that doxygen normally -# uses. -# Note: The header is subject to change so you typically have to regenerate the -# default header when upgrading to a newer version of doxygen. For a description -# of the possible markers and block names see the documentation. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_HEADER = ../_doxygen/header.html - -# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each -# generated HTML page. If the tag is left blank doxygen will generate a standard -# footer. See HTML_HEADER for more information on how to generate a default -# footer and what special commands can be used inside the footer. See also -# section "Doxygen usage" for information on how to generate the default footer -# that doxygen normally uses. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_FOOTER = ../_doxygen/footer.html - -# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style -# sheet that is used by each HTML page. It can be used to fine-tune the look of -# the HTML output. If left blank doxygen will generate a default style sheet. -# See also section "Doxygen usage" for information on how to generate the style -# sheet that doxygen normally uses. -# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as -# it is more robust and this tag (HTML_STYLESHEET) will in the future become -# obsolete. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_STYLESHEET = ../_doxygen/stylesheet.css - -# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined -# cascading style sheets that are included after the standard style sheets -# created by doxygen. Using this option one can overrule certain style aspects. -# This is preferred over using HTML_STYLESHEET since it does not replace the -# standard style sheet and is therefore more robust against future updates. -# Doxygen will copy the style sheet files to the output directory. -# Note: The order of the extra style sheet files is of importance (e.g. the last -# style sheet in the list overrules the setting of the previous ones in the -# list). -# Note: Since the styling of scrollbars can currently not be overruled in -# Webkit/Chromium, the styling will be left out of the default doxygen.css if -# one or more extra stylesheets have been specified. So if scrollbar -# customization is desired it has to be added explicitly. For an example see the -# documentation. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_EXTRA_STYLESHEET = ../_doxygen/extra_stylesheet.css - -# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or -# other source files which should be copied to the HTML output directory. Note -# that these files will be copied to the base HTML output directory. Use the -# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these -# files. In the HTML_STYLESHEET file, use the file name only. Also note that the -# files will be copied as-is; there are no commands or markers available. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_EXTRA_FILES = ../_doxygen/extra_stylesheet.css - -# The HTML_COLORSTYLE tag can be used to specify if the generated HTML output -# should be rendered with a dark or light theme. -# Possible values are: LIGHT always generate light mode output, DARK always -# generate dark mode output, AUTO_LIGHT automatically set the mode according to -# the user preference, use light mode if no preference is set (the default), -# AUTO_DARK automatically set the mode according to the user preference, use -# dark mode if no preference is set and TOGGLE allow to user to switch between -# light and dark mode via a button. -# The default value is: AUTO_LIGHT. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE = LIGHT - -# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen -# will adjust the colors in the style sheet and background images according to -# this color. Hue is specified as an angle on a color-wheel, see -# https://en.wikipedia.org/wiki/Hue for more information. For instance the value -# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 -# purple, and 360 is red again. -# Minimum value: 0, maximum value: 359, default value: 220. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_HUE = 240 - -# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors -# in the HTML output. For a value of 0 the output will use gray-scales only. A -# value of 255 will produce the most vivid colors. -# Minimum value: 0, maximum value: 255, default value: 100. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_SAT = 100 - -# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the -# luminance component of the colors in the HTML output. Values below 100 -# gradually make the output lighter, whereas values above 100 make the output -# darker. The value divided by 100 is the actual gamma applied, so 80 represents -# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not -# change the gamma. -# Minimum value: 40, maximum value: 240, default value: 80. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_GAMMA = 80 - -# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML -# documentation will contain a main index with vertical navigation menus that -# are dynamically created via JavaScript. If disabled, the navigation index will -# consists of multiple levels of tabs that are statically embedded in every HTML -# page. Disable this option to support browsers that do not have JavaScript, -# like the Qt help browser. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_DYNAMIC_MENUS = YES - -# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML -# documentation will contain sections that can be hidden and shown after the -# page has loaded. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_DYNAMIC_SECTIONS = NO - -# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries -# shown in the various tree structured indices initially; the user can expand -# and collapse entries dynamically later on. Doxygen will expand the tree to -# such a level that at most the specified number of entries are visible (unless -# a fully collapsed tree already exceeds this amount). So setting the number of -# entries 1 will produce a full collapsed tree by default. 0 is a special value -# representing an infinite number of entries and will result in a full expanded -# tree by default. -# Minimum value: 0, maximum value: 9999, default value: 100. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_INDEX_NUM_ENTRIES = 100 - -# If the GENERATE_DOCSET tag is set to YES, additional index files will be -# generated that can be used as input for Apple's Xcode 3 integrated development -# environment (see: -# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To -# create a documentation set, doxygen will generate a Makefile in the HTML -# output directory. Running make will produce the docset in that directory and -# running make install will install the docset in -# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at -# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy -# genXcode/_index.html for more information. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_DOCSET = NO - -# This tag determines the name of the docset feed. A documentation feed provides -# an umbrella under which multiple documentation sets from a single provider -# (such as a company or product suite) can be grouped. -# The default value is: Doxygen generated docs. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_FEEDNAME = "Doxygen generated docs" - -# This tag determines the URL of the docset feed. A documentation feed provides -# an umbrella under which multiple documentation sets from a single provider -# (such as a company or product suite) can be grouped. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_FEEDURL = - -# This tag specifies a string that should uniquely identify the documentation -# set bundle. This should be a reverse domain-name style string, e.g. -# com.mycompany.MyDocSet. Doxygen will append .docset to the name. -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_BUNDLE_ID = org.doxygen.Project - -# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify -# the documentation publisher. This should be a reverse domain-name style -# string, e.g. com.mycompany.MyDocSet.documentation. -# The default value is: org.doxygen.Publisher. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_PUBLISHER_ID = org.doxygen.Publisher - -# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. -# The default value is: Publisher. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_PUBLISHER_NAME = Publisher - -# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three -# additional HTML index files: index.hhp, index.hhc, and index.hhk. The -# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop -# on Windows. In the beginning of 2021 Microsoft took the original page, with -# a.o. the download links, offline the HTML help workshop was already many years -# in maintenance mode). You can download the HTML help workshop from the web -# archives at Installation executable (see: -# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo -# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). -# -# The HTML Help Workshop contains a compiler that can convert all HTML output -# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML -# files are now used as the Windows 98 help format, and will replace the old -# Windows help format (.hlp) on all Windows platforms in the future. Compressed -# HTML files also contain an index, a table of contents, and you can search for -# words in the documentation. The HTML workshop also contains a viewer for -# compressed HTML files. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_HTMLHELP = NO - -# The CHM_FILE tag can be used to specify the file name of the resulting .chm -# file. You can add a path in front of the file if the result should not be -# written to the html output directory. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -CHM_FILE = - -# The HHC_LOCATION tag can be used to specify the location (absolute path -# including file name) of the HTML help compiler (hhc.exe). If non-empty, -# doxygen will try to run the HTML help compiler on the generated index.hhp. -# The file has to be specified with full path. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -HHC_LOCATION = - -# The GENERATE_CHI flag controls if a separate .chi index file is generated -# (YES) or that it should be included in the main .chm file (NO). -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -GENERATE_CHI = NO - -# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) -# and project file content. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -CHM_INDEX_ENCODING = - -# The BINARY_TOC flag controls whether a binary table of contents is generated -# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it -# enables the Previous and Next buttons. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -BINARY_TOC = NO - -# The TOC_EXPAND flag can be set to YES to add extra items for group members to -# the table of contents of the HTML help documentation and to the tree view. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -TOC_EXPAND = NO - -# The SITEMAP_URL tag is used to specify the full URL of the place where the -# generated documentation will be placed on the server by the user during the -# deployment of the documentation. The generated sitemap is called sitemap.xml -# and placed on the directory specified by HTML_OUTPUT. In case no SITEMAP_URL -# is specified no sitemap is generated. For information about the sitemap -# protocol see https://www.sitemaps.org -# This tag requires that the tag GENERATE_HTML is set to YES. - -SITEMAP_URL = - -# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and -# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that -# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help -# (.qch) of the generated HTML documentation. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_QHP = NO - -# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify -# the file name of the resulting .qch file. The path specified is relative to -# the HTML output folder. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QCH_FILE = - -# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help -# Project output. For more information please see Qt Help Project / Namespace -# (see: -# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_NAMESPACE = org.doxygen.Project - -# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt -# Help Project output. For more information please see Qt Help Project / Virtual -# Folders (see: -# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders). -# The default value is: doc. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_VIRTUAL_FOLDER = doc - -# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom -# filter to add. For more information please see Qt Help Project / Custom -# Filters (see: -# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_CUST_FILTER_NAME = - -# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the -# custom filter to add. For more information please see Qt Help Project / Custom -# Filters (see: -# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_CUST_FILTER_ATTRS = - -# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this -# project's filter section matches. Qt Help Project / Filter Attributes (see: -# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_SECT_FILTER_ATTRS = - -# The QHG_LOCATION tag can be used to specify the location (absolute path -# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to -# run qhelpgenerator on the generated .qhp file. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHG_LOCATION = - -# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be -# generated, together with the HTML files, they form an Eclipse help plugin. To -# install this plugin and make it available under the help contents menu in -# Eclipse, the contents of the directory containing the HTML and XML files needs -# to be copied into the plugins directory of eclipse. The name of the directory -# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. -# After copying Eclipse needs to be restarted before the help appears. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_ECLIPSEHELP = NO - -# A unique identifier for the Eclipse help plugin. When installing the plugin -# the directory name containing the HTML and XML files should also have this -# name. Each documentation set should have its own identifier. -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. - -ECLIPSE_DOC_ID = org.doxygen.Project - -# If you want full control over the layout of the generated HTML pages it might -# be necessary to disable the index and replace it with your own. The -# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top -# of each HTML page. A value of NO enables the index and the value YES disables -# it. Since the tabs in the index contain the same information as the navigation -# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -DISABLE_INDEX = NO - -# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index -# structure should be generated to display hierarchical information. If the tag -# value is set to YES, a side panel will be generated containing a tree-like -# index structure (just like the one that is generated for HTML Help). For this -# to work a browser that supports JavaScript, DHTML, CSS and frames is required -# (i.e. any modern browser). Windows users are probably better off using the -# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can -# further fine tune the look of the index (see "Fine-tuning the output"). As an -# example, the default style sheet generated by doxygen has an example that -# shows how to put an image at the root of the tree instead of the PROJECT_NAME. -# Since the tree basically has the same information as the tab index, you could -# consider setting DISABLE_INDEX to YES when enabling this option. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_TREEVIEW = NO - -# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the -# FULL_SIDEBAR option determines if the side bar is limited to only the treeview -# area (value NO) or if it should extend to the full height of the window (value -# YES). Setting this to YES gives a layout similar to -# https://docs.readthedocs.io with more room for contents, but less room for the -# project logo, title, and description. If either GENERATE_TREEVIEW or -# DISABLE_INDEX is set to NO, this option has no effect. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -FULL_SIDEBAR = NO - -# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that -# doxygen will group on one line in the generated HTML documentation. -# -# Note that a value of 0 will completely suppress the enum values from appearing -# in the overview section. -# Minimum value: 0, maximum value: 20, default value: 4. -# This tag requires that the tag GENERATE_HTML is set to YES. - -ENUM_VALUES_PER_LINE = 1 - -# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used -# to set the initial width (in pixels) of the frame in which the tree is shown. -# Minimum value: 0, maximum value: 1500, default value: 250. -# This tag requires that the tag GENERATE_HTML is set to YES. - -TREEVIEW_WIDTH = 250 - -# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to -# external symbols imported via tag files in a separate window. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -EXT_LINKS_IN_WINDOW = NO - -# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email -# addresses. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. - -OBFUSCATE_EMAILS = YES - -# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg -# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see -# https://inkscape.org) to generate formulas as SVG images instead of PNGs for -# the HTML output. These images will generally look nicer at scaled resolutions. -# Possible values are: png (the default) and svg (looks nicer but requires the -# pdf2svg or inkscape tool). -# The default value is: png. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_FORMULA_FORMAT = png - -# Use this tag to change the font size of LaTeX formulas included as images in -# the HTML documentation. When you change the font size after a successful -# doxygen run you need to manually remove any form_*.png images from the HTML -# output directory to force them to be regenerated. -# Minimum value: 8, maximum value: 50, default value: 10. -# This tag requires that the tag GENERATE_HTML is set to YES. - -FORMULA_FONTSIZE = 10 - -# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands -# to create new LaTeX commands to be used in formulas as building blocks. See -# the section "Including formulas" for details. - -FORMULA_MACROFILE = - -# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see -# https://www.mathjax.org) which uses client side JavaScript for the rendering -# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX -# installed or if you want to formulas look prettier in the HTML output. When -# enabled you may also need to install MathJax separately and configure the path -# to it using the MATHJAX_RELPATH option. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -USE_MATHJAX = YES - -# With MATHJAX_VERSION it is possible to specify the MathJax version to be used. -# Note that the different versions of MathJax have different requirements with -# regards to the different settings, so it is possible that also other MathJax -# settings have to be changed when switching between the different MathJax -# versions. -# Possible values are: MathJax_2 and MathJax_3. -# The default value is: MathJax_2. -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_VERSION = MathJax_2 - -# When MathJax is enabled you can set the default output format to be used for -# the MathJax output. For more details about the output format see MathJax -# version 2 (see: -# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3 -# (see: -# http://docs.mathjax.org/en/latest/web/components/output.html). -# Possible values are: HTML-CSS (which is slower, but has the best -# compatibility. This is the name for Mathjax version 2, for MathJax version 3 -# this will be translated into chtml), NativeMML (i.e. MathML. Only supported -# for NathJax 2. For MathJax version 3 chtml will be used instead.), chtml (This -# is the name for Mathjax version 3, for MathJax version 2 this will be -# translated into HTML-CSS) and SVG. -# The default value is: HTML-CSS. -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_FORMAT = HTML-CSS - -# When MathJax is enabled you need to specify the location relative to the HTML -# output directory using the MATHJAX_RELPATH option. The destination directory -# should contain the MathJax.js script. For instance, if the mathjax directory -# is located at the same level as the HTML output directory, then -# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax -# Content Delivery Network so you can quickly see the result without installing -# MathJax. However, it is strongly recommended to install a local copy of -# MathJax from https://www.mathjax.org before deployment. The default value is: -# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2 -# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3 -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_RELPATH = - -# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax -# extension names that should be enabled during MathJax rendering. For example -# for MathJax version 2 (see -# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions): -# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols -# For example for MathJax version 3 (see -# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html): -# MATHJAX_EXTENSIONS = ams -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_EXTENSIONS = - -# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces -# of code that will be used on startup of the MathJax code. See the MathJax site -# (see: -# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an -# example see the documentation. -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_CODEFILE = - -# When the SEARCHENGINE tag is enabled doxygen will generate a search box for -# the HTML output. The underlying search engine uses javascript and DHTML and -# should work on any modern browser. Note that when using HTML help -# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET) -# there is already a search function so this one should typically be disabled. -# For large projects the javascript based search engine can be slow, then -# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to -# search using the keyboard; to jump to the search box use + S -# (what the is depends on the OS and browser, but it is typically -# , /