diff --git a/.gitignore b/.gitignore
index 17f93500bd..04ac34466f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -112,4 +112,30 @@ test_data/*
experimental/grouped_convolution_tile_instances/instances/*
!experimental/grouped_convolution_tile_instances/instances/*.in
!experimental/grouped_convolution_tile_instances/instances/*.inc
+!experimental/grouped_convolution_tile_instances/instances/*.hpp
experimental/grouped_convolution_tile_instances/*.inc
+# Heuristics: benchmark data (never in git)
+dispatcher/heuristics/data/
+
+# Heuristics: experimental/training artifacts (exclude from git)
+dispatcher/heuristics/models/**/oof_predictions.parquet
+dispatcher/heuristics/models/**/cv_metrics_*.json
+dispatcher/heuristics/models/**/eval_report.json
+dispatcher/heuristics/models/**/feature_importances_*.json
+dispatcher/heuristics/models/**/model_tflops_ihem.lgbm
+dispatcher/heuristics/models/**/model_tflops_log.lgbm
+dispatcher/heuristics/models/**/model_tflops_log_big.lgbm
+
+# Heuristics: keep in git (production model files):
+# models/{op}_{dtype}_{arch}/model_tflops.lgbm
+# models/{op}_{dtype}_{arch}/model_latency.lgbm
+# models/{op}_{dtype}_{arch}/model_bandwidth.lgbm
+# models/{op}_{dtype}_{arch}/feature_spec.json
+# models/{op}_{dtype}_{arch}/train_manifest.json
+
+# Heuristics: logs and caches
+dispatcher/heuristics/*.log
+dispatcher/heuristics/__pycache__/
+dispatcher/heuristics/tests/__pycache__/
+dispatcher/heuristics/.pytest_cache/
+
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index b3299fa4e8..50fa167b41 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -4,13 +4,13 @@ version: 2
sphinx:
- configuration: docs/conf.py
+ configuration: projects/composablekernel/docs/conf.py
formats: [htmlzip, pdf, epub]
python:
install:
- - requirements: docs/sphinx/requirements.txt
+ - requirements: projects/composablekernel/docs/sphinx/requirements.txt
build:
os: ubuntu-22.04
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 370e9e4243..f6812a8520 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added FP8 block scale quantization for FMHA forward kernel.
* Added gfx11 support for FMHA.
* Added microscaling (MX) FP8/FP4 support on gfx950 for FMHA forward kernel ("qr" pipeline only).
+* Added FP8 per-tensor quantization support for FMHA forward V3 pipeline on gfx950.
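The `.gitignore` rules for `dispatcher/heuristics/` above encode a simple policy: training by-products stay out of git, while a fixed set of production model files per `models/{op}_{dtype}_{arch}/` directory is kept. Purely as an illustration of that policy (none of this code exists in the repository; the file names simply restate the keep/exclude lists in the patch), a small Python check might look like:

```python
# Illustrative restatement of the heuristics .gitignore policy above.
# The file names mirror the keep/exclude lists in the patch; nothing here
# is part of the repository itself.
import fnmatch
from pathlib import Path

KEEP = {
    "model_tflops.lgbm",
    "model_latency.lgbm",
    "model_bandwidth.lgbm",
    "feature_spec.json",
    "train_manifest.json",
}
EXCLUDE = [
    "oof_predictions.parquet",
    "cv_metrics_*.json",
    "eval_report.json",
    "feature_importances_*.json",
    "model_tflops_ihem.lgbm",
    "model_tflops_log.lgbm",
    "model_tflops_log_big.lgbm",
]

def tracked_in_git(path: str) -> bool:
    """True if a file under models/{op}_{dtype}_{arch}/ is meant to be committed."""
    name = Path(path).name
    if name in KEEP:
        return True
    return not any(fnmatch.fnmatch(name, pattern) for pattern in EXCLUDE)
```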
### Changed diff --git a/Dockerfile b/Dockerfile index f19bc69362..de129d0703 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,22 +2,33 @@ FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive ARG ROCMVERSION=7.1.1 +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +ARG TARBALL_URL=https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx90X-dcgpu-7.12.0a20260218.tar.gz ARG compiler_version="" ARG compiler_commit="" -ARG CK_SCCACHE="" -ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn ENV DEBIAN_FRONTEND=noninteractive +ENV PATH=$PATH:/opt/rocm/bin +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib +ENV HIP_PLATFORM=amd # Add rocm repository RUN set -xe && \ apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl -RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \ - apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \ - apt update && \ - apt install python3-setuptools python3-wheel -y && \ - apt install rocm-dev -y +RUN if [ "$compiler_version" = "therock" ]; then \ + rm -rf /opt/rocm && mkdir /opt/rocm && \ + echo "Downloading ROCm tarball from $TARBALL_URL..." && \ + wget -q -O /tmp/rocm.tar.gz "$TARBALL_URL" && \ + echo "Extracting tarball to /opt/rocm..." && \ + tar -xzf /tmp/rocm.tar.gz -C /opt/rocm --strip-components=1 ; \ + else echo "using the release compiler" && \ + wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \ + apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \ + apt update && \ + apt install python3-setuptools python3-wheel -y && \ + apt install rocm-dev -y; \ + fi # Install SCCACHE ENV SCCACHE_VERSION="0.14.0" @@ -34,7 +45,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- build-essential \ cmake \ git \ - hip-rocclr \ iputils-ping \ jq \ libelf-dev \ @@ -44,8 +54,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- net-tools \ pkg-config \ python3-full \ + python3-pip \ redis \ - rocm-llvm-dev \ sshpass \ stunnel \ software-properties-common \ @@ -88,26 +98,3 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- git clone -b master https://github.com/ROCm/rocm-cmake.git && \ cd rocm-cmake && mkdir build && cd build && \ cmake .. && cmake --build . && cmake --build . 
--target install - -WORKDIR / -# Add alternative compilers, if necessary -ENV compiler_version=$compiler_version -ENV compiler_commit=$compiler_commit -RUN sh -c "echo compiler version = '$compiler_version'" && \ - sh -c "echo compiler commit = '$compiler_commit'" - -RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ - git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ - cd llvm-project && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ - make -j 8 ; \ - else echo "using the release compiler"; \ - fi - -RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ - git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ - cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ - make -j 8 ; \ - else echo "using the release compiler"; \ - fi diff --git a/Dockerfile.aiter b/Dockerfile.aiter index a5a3f81fca..ebfef41643 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -4,7 +4,7 @@ ARG AITER_BRANCH="main" ARG CK_AITER_BRANCH="develop" # CK_FROM_ROCM_LIBRARIES - 1: CK from rocm-libraries sparse-checkout; 0: direct clone from ROCm/composable_kernel ARG CK_FROM_ROCM_LIBRARIES=1 -RUN pip install pandas zmq einops ninja tabulate && \ +RUN pip install pandas zmq einops ninja tabulate vcs_versioning && \ pip install numpy==1.26.2 && \ sudo mkdir /home/jenkins && \ sudo mkdir /home/jenkins/workspace && \ diff --git a/Dockerfile.compiler b/Dockerfile.compiler index 8f5503d79e..9d1e54106e 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -9,18 +9,70 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" && \ sh -c "echo compiler commit = '$compiler_commit'" -RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ +RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ - cd llvm-project && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ - make -j 16 ; \ + cd llvm-project && git log -1 && mkdir build && cd build && \ + cmake -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_PROJECTS="clang;lld;clang-tools-extra;flang;mlir" \ + -DLLVM_LIT_ARGS="-vv --show-unsupported --show-xfail -j 32" \ + -DPACKAGE_VENDOR="AMD" \ + -DCMAKE_INSTALL_PREFIX=/home/$USER/rocm/pure_llvm_1.0 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_BUILD_DOCS=ON \ + -DLLVM_TARGETS_TO_BUILD=all \ + -DLIBOMPTARGET_ENABLE_DEBUG=ON \ + -DOFFLOAD_ENABLE_EMISSARY_APIS=OFF \ + -DCLANG_DEFAULT_LINKER=lld \ + -DCLANG_DEFAULT_PIE_ON_LINUX=0 \ + 
-DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi;openmp;compiler-rt;libunwind;flang-rt" \ + -DLIBCXX_ENABLE_SHARED=OFF \ + -DLIBCXX_ENABLE_STATIC=ON \ + -DLIBCXX_INSTALL_LIBRARY=OFF \ + -DLIBCXX_INSTALL_HEADERS=OFF \ + -DLIBCXXABI_ENABLE_SHARED=OFF \ + -DLIBCXXABI_ENABLE_STATIC=ON \ + -DLIBCXXABI_INSTALL_STATIC_LIBRARY=OFF \ + -DLLVM_ENABLE_ASSERTIONS=1 \ + -DLLVM_ENABLE_Z3_SOLVER=OFF \ + -DLLVM_ENABLE_ZLIB=ON \ + -DLLVM_LINK_LLVM_DYLIB=OFF \ + -DCLANG_LINK_CLANG_DYLIB=OFF \ + ../llvm && \ + ninja -j16 ; \ else echo "using the release compiler"; \ fi -RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ +RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ - make -j 16 ; \ + cmake -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_PROJECTS="clang;lld;clang-tools-extra;flang;mlir" \ + -DLLVM_LIT_ARGS="-vv --show-unsupported --show-xfail -j 32" \ + -DPACKAGE_VENDOR="AMD" \ + -DCMAKE_INSTALL_PREFIX=/home/$USER/rocm/pure_llvm_1.0 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_BUILD_DOCS=ON \ + -DLLVM_TARGETS_TO_BUILD=all \ + -DLIBOMPTARGET_ENABLE_DEBUG=ON \ + -DOFFLOAD_ENABLE_EMISSARY_APIS=OFF \ + -DCLANG_DEFAULT_LINKER=lld \ + -DCLANG_DEFAULT_PIE_ON_LINUX=0 \ + -DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi;openmp;compiler-rt;libunwind;flang-rt" \ + -DLIBCXX_ENABLE_SHARED=OFF \ + -DLIBCXX_ENABLE_STATIC=ON \ + -DLIBCXX_INSTALL_LIBRARY=OFF \ + -DLIBCXX_INSTALL_HEADERS=OFF \ + -DLIBCXXABI_ENABLE_SHARED=OFF \ + -DLIBCXXABI_ENABLE_STATIC=ON \ + -DLIBCXXABI_INSTALL_STATIC_LIBRARY=OFF \ + -DLLVM_ENABLE_ASSERTIONS=1 \ + -DLLVM_ENABLE_Z3_SOLVER=OFF \ + -DLLVM_ENABLE_ZLIB=ON \ + -DLLVM_LINK_LLVM_DYLIB=OFF \ + -DCLANG_LINK_CLANG_DYLIB=OFF \ + ../llvm && \ + ninja -j16 ; \ else echo "using the release compiler"; \ fi diff --git a/Dockerfile.fa b/Dockerfile.fa new file mode 100644 index 0000000000..c5cbacfc16 --- /dev/null +++ b/Dockerfile.fa @@ -0,0 +1,43 @@ +ARG BASE_DOCKER="rocm/pytorch:latest" +FROM $BASE_DOCKER +ARG FA_ORIGIN="ROCm" +ARG FA_BRANCH="tridao" +ARG CK_FA_ORIGIN="ROCm" +ARG CK_FA_BRANCH="develop" +# CK_FROM_ROCM_LIBRARIES - 1: CK from rocm-libraries sparse-checkout; 0: direct clone from ROCm/composable_kernel +ARG CK_FROM_ROCM_LIBRARIES=1 +ARG GPU_ARCHS="gfx90a;gfx942;gfx950" +RUN set -x ; \ + sudo mkdir /home/jenkins && \ + sudo mkdir /home/jenkins/workspace && \ + cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \ + if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \ + git clone --depth 1 -b "$CK_FA_BRANCH" --no-checkout --filter=blob:none https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \ + cd rocm-libraries && \ + git sparse-checkout init --cone && \ + git sparse-checkout set projects/composablekernel && \ + git checkout "$CK_FA_BRANCH" && \ + ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \ + mv projects/composablekernel ../ck && \ + cd ../ck && rm -rf ../rocm-libraries && \ + git init && \ + git config user.name "assistant-librarian[bot]" && \ + git config user.email 
"assistant-librarian[bot]@users.noreply.github.com" && \ + git branch -m "$CK_FA_BRANCH" && git add -A && \ + git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" > /dev/null ; \ + else \ + git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck ; \ + fi && \ + cd /home/jenkins/workspace && rm -rf flash-attention && \ + git clone --depth 1 -b "$FA_BRANCH" --recursive "https://github.com/$FA_ORIGIN/flash-attention.git" && \ + cd flash-attention && \ + rm -rf csrc/composable_kernel/ && \ + git clone -b "$CK_FA_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \ + MAX_JOBS=$(nproc) GPU_ARCHS="$GPU_ARCHS" /opt/venv/bin/python3 -u -m pip install --no-build-isolation -v . && \ + groupadd -g 1001 jenkins && \ + useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \ + chown -R jenkins:jenkins /home/jenkins && \ + chmod -R a+rwx /home/jenkins && \ + chown -R jenkins:jenkins /tmp && \ + chmod -R a+rwx /tmp && \ + sudo usermod -aG irc jenkins diff --git a/Dockerfile.manylinux b/Dockerfile.manylinux index 2c0bec2840..bfbe847b1d 100644 --- a/Dockerfile.manylinux +++ b/Dockerfile.manylinux @@ -3,7 +3,6 @@ ARG DEBIAN_FRONTEND=noninteractive ARG ROCMVERSION=7.2 ARG compiler_version="" ARG compiler_commit="" -ARG CK_SCCACHE="" ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn ENV DEBIAN_FRONTEND=noninteractive @@ -19,16 +18,15 @@ RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2 dnf install python3-setuptools python3-wheel -y && \ dnf install rocm-dev -y -## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined -ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache +# Install SCCACHE +ENV SCCACHE_VERSION="0.14.0" ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} -ENV CK_SCCACHE=$CK_SCCACHE -RUN if [ "$CK_SCCACHE" != "" ]; then \ - mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ - curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ - chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \ - fi +RUN set -x && \ + mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ + wget -qO sccache.tar.gz https://github.com/mozilla/sccache/releases/latest/download/sccache-v$SCCACHE_VERSION-x86_64-unknown-linux-musl.tar.gz && \ + tar -xzf sccache.tar.gz --strip-components=1 -C ${SCCACHE_INSTALL_LOCATION} && \ + chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache # Install dependencies RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \ @@ -83,19 +81,71 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" && \ sh -c "echo compiler commit = '$compiler_commit'" -RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ +RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ - make -j 8 ; \ + cmake -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + 
-DLLVM_ENABLE_PROJECTS="clang;lld;clang-tools-extra;flang;mlir" \ + -DLLVM_LIT_ARGS="-vv --show-unsupported --show-xfail -j 32" \ + -DPACKAGE_VENDOR="AMD" \ + -DCMAKE_INSTALL_PREFIX=/home/$USER/rocm/pure_llvm_1.0 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_BUILD_DOCS=ON \ + -DLLVM_TARGETS_TO_BUILD=all \ + -DLIBOMPTARGET_ENABLE_DEBUG=ON \ + -DOFFLOAD_ENABLE_EMISSARY_APIS=OFF \ + -DCLANG_DEFAULT_LINKER=lld \ + -DCLANG_DEFAULT_PIE_ON_LINUX=0 \ + -DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi;openmp;compiler-rt;libunwind;flang-rt" \ + -DLIBCXX_ENABLE_SHARED=OFF \ + -DLIBCXX_ENABLE_STATIC=ON \ + -DLIBCXX_INSTALL_LIBRARY=OFF \ + -DLIBCXX_INSTALL_HEADERS=OFF \ + -DLIBCXXABI_ENABLE_SHARED=OFF \ + -DLIBCXXABI_ENABLE_STATIC=ON \ + -DLIBCXXABI_INSTALL_STATIC_LIBRARY=OFF \ + -DLLVM_ENABLE_ASSERTIONS=1 \ + -DLLVM_ENABLE_Z3_SOLVER=OFF \ + -DLLVM_ENABLE_ZLIB=ON \ + -DLLVM_LINK_LLVM_DYLIB=OFF \ + -DCLANG_LINK_CLANG_DYLIB=OFF \ + ../llvm && \ + ninja -j16 ; \ else echo "using the release compiler"; \ fi -RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ +RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ - make -j 8 ; \ + cmake -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_PROJECTS="clang;lld;clang-tools-extra;flang;mlir" \ + -DLLVM_LIT_ARGS="-vv --show-unsupported --show-xfail -j 32" \ + -DPACKAGE_VENDOR="AMD" \ + -DCMAKE_INSTALL_PREFIX=/home/$USER/rocm/pure_llvm_1.0 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_BUILD_DOCS=ON \ + -DLLVM_TARGETS_TO_BUILD=all \ + -DLIBOMPTARGET_ENABLE_DEBUG=ON \ + -DOFFLOAD_ENABLE_EMISSARY_APIS=OFF \ + -DCLANG_DEFAULT_LINKER=lld \ + -DCLANG_DEFAULT_PIE_ON_LINUX=0 \ + -DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi;openmp;compiler-rt;libunwind;flang-rt" \ + -DLIBCXX_ENABLE_SHARED=OFF \ + -DLIBCXX_ENABLE_STATIC=ON \ + -DLIBCXX_INSTALL_LIBRARY=OFF \ + -DLIBCXX_INSTALL_HEADERS=OFF \ + -DLIBCXXABI_ENABLE_SHARED=OFF \ + -DLIBCXXABI_ENABLE_STATIC=ON \ + -DLIBCXXABI_INSTALL_STATIC_LIBRARY=OFF \ + -DLLVM_ENABLE_ASSERTIONS=1 \ + -DLLVM_ENABLE_Z3_SOLVER=OFF \ + -DLLVM_ENABLE_ZLIB=ON \ + -DLLVM_LINK_LLVM_DYLIB=OFF \ + -DCLANG_LINK_CLANG_DYLIB=OFF \ + ../llvm && \ + ninja -j16 ; \ else echo "using the release compiler"; \ fi diff --git a/Jenkinsfile b/Jenkinsfile index 22709f414a..42ca1756c0 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -81,71 +81,6 @@ def checkoutComposableKernel() checkout scm } -// Given a pattern, check if the log contains the pattern and return the context. -def checkForPattern(pattern, log) { - def lines = log.split('\n') - for (int i = 0; i < lines.size(); i++) { - if (lines[i] =~ pattern) { - echo "Found pattern match in log for ${pattern}" - - // Get the two lines before and after failure. 
- def contextStart = Math.max(0, i - 2) - def contextEnd = Math.min(lines.size() - 1, i + 2) - def contextLines = [] - for (int j = contextStart; j <= contextEnd; j++) { - contextLines.add(lines[j]) - } - - return [found: true, matchedLine: lines[i], context: contextLines.join('\n')] - } - } - echo "No pattern match found in log for ${pattern}" - return [found: false, matchedLine: "", context: ""] -} - -// Scan the build logs for failures and send notifications. -def sendFailureNotifications() { - // Error patterns to scan build logs for specific failure types and send detailed notifications. - def failurePatterns = [ - [pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"], - [pattern: /.*docker login failed.*/, description: "Docker login failed"], - [pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"], - [pattern: /cat: .* No such file or directory/, description: "GPU not found"], - [pattern: /.*GPU not found.*/, description: "GPU not found"], - [pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"], - [pattern: /.*unauthorized: your account must log in with a Personal Access Token.*/, description: "Docker login failed"], - [pattern: /.*sccache: error: Server startup failed: Address in use.*/, description: "Sccache Error"] - ] - - // Get the build log. - def buildLog = sh(script: 'wget -q --no-check-certificate -O - ' + BUILD_URL + 'consoleText', returnStdout: true) - echo "Checking for failure patterns..." - // Check for patterns in the log. - // def foundPatterns = [] - // for (patternMap in failurePatterns) { - // def result = checkForPattern(patternMap.pattern, buildLog) - // if (result.found) { - // foundPatterns.add([ - // description: patternMap.description, - // matchedLine: result.matchedLine, - // context: result.context - // ]) - // } - // } - echo "Done checking for failure patterns..." - // Send a notification for each matched failure pattern. - for (patternMap in foundPatterns) { - withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) { - sh ''' - curl -X POST "${WEBHOOK_URL}" \ - -H 'Content-Type: application/json' \ - -d '{"text": "\\n\\n**Build Failed**\\n\\n**Issues detected:** ''' + patternMap.description + '''\\n\\n**Log context:**\\n```\\n''' + patternMap.context.replace("'", "\\'") + '''\\n```\\n\\n**Job:** ''' + env.JOB_NAME + '''\\n\\n**Build:** #''' + env.BUILD_NUMBER + '''\\n\\n**URL:** ''' + env.RUN_DISPLAY_URL + '''"}' - ''' - } - } - echo "Done failure pattern checking and notifications" -} - def generateAndArchiveBuildTraceVisualization(String buildTraceFileName) { try { checkoutComposableKernel() @@ -479,51 +414,86 @@ def getDockerImage(Map conf=[:]){ return [retimage, image] } -def buildDocker(install_prefix){ +// Build and push a docker image, capturing its digest into the specified env var. +// If forceBuild is false, will skip building if the image already exists in the registry. 
+def buildAndPushDockerImage(String install_prefix, String image_name, String dockerExtraArgs, boolean forceBuild){ show_node_info() env.DOCKER_BUILDKIT=1 checkoutComposableKernel() - def image_name = getDockerImageName() - def base_image_name = getBaseDockerImageName() - echo "Building Docker for ${image_name}" - def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if(params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ - dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f projects/composablekernel/Dockerfile.compiler . " - } - else if(params.RUN_AITER_TESTS){ - image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter" - dockerArgs = dockerArgs + " --no-cache -f projects/composablekernel/Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . " - } - else if(params.RUN_PYTORCH_TESTS){ - image_name = "${env.CK_DOCKERHUB}:ck_pytorch" - dockerArgs = dockerArgs + " --no-cache -f projects/composablekernel/Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . " - } - else{ - dockerArgs = dockerArgs + " -f projects/composablekernel/Dockerfile . " - } - echo "Build Args: ${dockerArgs}" - try{ - if(params.BUILD_DOCKER || params.RUN_AITER_TESTS || params.RUN_PYTORCH_TESTS){ - //force building the new docker if that parameter is true - echo "Building image: ${image_name}" - retimage = docker.build("${image_name}", dockerArgs) - withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { - retimage.push() - } - sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' - } - else{ + def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + dockerArgs += " " + dockerExtraArgs + + if(!forceBuild){ + try{ echo "Checking for image: ${image_name}" sh "docker manifest inspect --insecure ${image_name}" echo "Image: ${image_name} found! Skipping building image" + return image_name + } + catch(Exception ex){ + echo "Unable to locate image: ${image_name}. Will attempt to build image now." } } - catch(Exception ex){ - echo "Unable to locate image: ${image_name}. Building image now" - retimage = docker.build("${image_name}", dockerArgs) - withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { - retimage.push() - } + + echo "Building image: ${image_name} with args: ${dockerArgs}" + def retimage = docker.build("${image_name}", dockerArgs) + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { + retimage.push() + } + def digest = sh(returnStdout: true, script: "docker inspect --format='{{index .RepoDigests 0}}' ${image_name}").trim() + echo "Built image digest: ${digest}" + echo "Pruning dangling Docker images to free disk space on CI agent" + sh "docker image prune -f --filter 'dangling=true' || true" + return digest +} + +def buildDockerBase(install_prefix){ + def image_name = getDockerImageName() + def base_image_name = getBaseDockerImageName() + echo "Building Docker for ${image_name}" + def dockerExtraArgs = " -f projects/composablekernel/Dockerfile . 
" + if(params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_COMMIT != ""){ + dockerExtraArgs = " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f projects/composablekernel/Dockerfile.compiler . " + } + else if(params.COMPILER_VERSION == "therock"){ + dockerExtraArgs = " --no-cache -f projects/composablekernel/Dockerfile . " + } + env.CK_BASE_IMAGE = buildAndPushDockerImage(install_prefix, image_name, dockerExtraArgs, params.BUILD_DOCKER.toBoolean()) +} + +def buildDockerPytorch(install_prefix){ + def image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_pytorch" + def dockerExtraArgs = " --no-cache -f projects/composablekernel/Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . " + env.CK_PYTORCH_IMAGE = buildAndPushDockerImage(install_prefix, image_name, dockerExtraArgs, true) +} + +def buildDockerAiter(install_prefix){ + def image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter" + def dockerExtraArgs = " --no-cache -f projects/composablekernel/Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . " + env.CK_AITER_IMAGE = buildAndPushDockerImage(install_prefix, image_name, dockerExtraArgs, true) +} + +def buildDockerFa(install_prefix){ + def image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_fa" + def dockerExtraArgs = " --no-cache -f projects/composablekernel/Dockerfile.fa" + dockerExtraArgs += " --build-arg BASE_DOCKER='${params.fa_base_docker}'" + dockerExtraArgs += " --build-arg FA_BRANCH='${params.fa_branch}'" + dockerExtraArgs += " --build-arg CK_FA_BRANCH='${params.ck_fa_branch}'" + dockerExtraArgs += " --build-arg GPU_ARCHS='gfx942;gfx950'" + dockerExtraArgs += " . " + env.CK_FA_IMAGE = buildAndPushDockerImage(install_prefix, image_name, dockerExtraArgs, true) +} + +def buildDocker(install_prefix){ + buildDockerBase(install_prefix) + if (params.RUN_PYTORCH_TESTS.toBoolean()) { + buildDockerPytorch(install_prefix) + } + if (params.RUN_AITER_TESTS.toBoolean()) { + buildDockerAiter(install_prefix) + } + if (params.RUN_FA_TESTS.toBoolean()) { + buildDockerFa(install_prefix) } } @@ -535,10 +505,10 @@ def get_docker_options(){ else{ //only add kfd and dri paths if you actually going to run somthing on GPUs dockerOpts = "--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" } - if (params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "therock" || params.COMPILER_COMMIT != ""){ // the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with // newer clang22 compilers and running with older hip runtima libraries - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 " + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 --env HIP_PLATFORM=amd " } // on some machines the group ids for video and render groups may not be the same as in the docker image! 
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') @@ -1148,99 +1118,73 @@ def process_results(Map conf=[:]){ } } -def run_aiter_tests(Map conf=[:]){ +def run_downstream_tests(Map conf=[:]){ show_node_info() checkoutComposableKernel() - //use the latest pytorch image - def image = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter" - def dockerOpts=get_docker_options() + ' --group-add irc ' + def dockerOpts = get_docker_options() + ' --group-add irc ' gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') { try { - echo "Pulling image: ${image}" - retimage = docker.image("${image}") + echo "Pulling image: ${conf.image}" + retimage = docker.image("${conf.image}") withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.pull() } } catch(Exception ex) { - error "Unable to locate image: ${image}" + error "Unable to locate image: ${conf.image}" } } - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'HOURS'){ + withDockerContainer(image: conf.image, args: dockerOpts) { + timeout(time: conf.get("timeoutHours", 2), unit: 'HOURS'){ try{ sh "rocminfo" sh "python3 --version" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_batch_prefill.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_ep.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_sorting.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_sorting_mxfp4.py" - sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_tkw1.py" + for (cmd in conf.execute_cmds) { + sh "${cmd}" + } } catch(e){ - echo "Throwing error exception while running AITER tests" + echo "Throwing error exception while running ${env.STAGE_NAME}" echo 'Exception occurred: ' + e.toString() throw e } finally{ - echo "Finished running AITER tests" + echo "Finished running ${env.STAGE_NAME}" } } } } - -def run_pytorch_tests(Map conf=[:]){ - show_node_info() - checkoutComposableKernel() - //use the latest pytorch-nightly image - def image = "${env.CK_DOCKERHUB}:ck_pytorch" - def dockerOpts=get_docker_options() + ' --group-add irc ' - - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') { - try - { - echo "Pulling image: ${image}" - retimage = docker.image("${image}") - withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { - retimage.pull() - } - } - catch(Exception ex) - { - error "Unable to locate image: ${image}" - } - } - - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 2, unit: 'HOURS'){ - try{ - sh "rocminfo" - sh "python3 --version" - sh "python3 /tmp/pytorch/tools/amd_build/build_amd.py" - sh "USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python /tmp/pytorch/setup.py develop" - } - catch(e){ - echo "Throwing error exception while building Pytorch" - echo 'Exception occurred: ' + e.toString() - throw e - } - finally{ - echo "Finished building Pytorch" - } - } - } 
+def getPytorchTestsCmds() { + return [ + "python3 /tmp/pytorch/tools/amd_build/build_amd.py", + "USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python /tmp/pytorch/setup.py develop" + ] +} +def getAiterTestsCmds() { + return [ + "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_batch_prefill.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_ep.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_sorting.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_sorting_mxfp4.py", + "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_tkw1.py" + ] +} +def getFaTestsCmds() { + return [ + "python3 -u -m pytest /home/jenkins/workspace/flash-attention/tests/test_flash_attn_ck.py" + ] } //launch develop branch daily jobs @@ -1248,15 +1192,20 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_ 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_BASIC_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=false;BUILD_GFX908=false;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=develop;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true - 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true - 0 13 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true - 0 11 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : "" + 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=therock;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 15 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true + 0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;RUN_FA_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true + 0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : "" +CURRENT_BRANCH_NAME = env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME + +POLL_SPEC = BRANCH_NAME == "develop" ? 
'H H/6 * * *' : '' pipeline { agent none triggers { parameterizedCron(CRON_SETTINGS) + pollSCM(POLL_SPEC) } options { skipDefaultCheckout() @@ -1278,7 +1227,7 @@ pipeline { string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, develop, amd-mainline, or leave blank (default).') + description: 'Specify which version of compiler to use: develop, amd-staging, therock, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', @@ -1381,8 +1330,8 @@ pipeline { description: "Build CK and run tests on gfx12 (default: ON)") booleanParam( name: "NINJA_BUILD_TRACE", - defaultValue: false, - description: "Generate a ninja build trace (default: OFF)") + defaultValue: true, + description: "Generate a ninja build trace (default: ON)") booleanParam( name: "NINJA_FTIME_TRACE", defaultValue: false, @@ -1409,8 +1358,8 @@ pipeline { description: "Try building PYTORCH with latest CK develop branch (default: OFF)") string( name: 'ck_pytorch_branch', - defaultValue: 'develop', - description: 'Specify which branch of CK to test with Pytorch (default: develop)') + defaultValue: CURRENT_BRANCH_NAME, + description: 'Specify which branch of CK to test with Pytorch (default: current branch)') booleanParam( name: "RUN_AITER_TESTS", defaultValue: false, @@ -1425,8 +1374,24 @@ pipeline { description: 'Specify which branch of AITER to use (default: main)') string( name: 'ck_aiter_branch', - defaultValue: 'develop', - description: 'Specify which branch of CK to test with AITER (default: develop)') + defaultValue: CURRENT_BRANCH_NAME, + description: 'Specify which branch of CK to test with AITER (default: current branch)') + booleanParam( + name: "RUN_FA_TESTS", + defaultValue: false, + description: "Run Flash Attention tests with latest CK develop branch (default: OFF)") + string( + name: 'fa_base_docker', + defaultValue: 'rocm/pytorch:rocm7.1.1_ubuntu24.04_py3.12_pytorch_release_2.9.1', + description: 'Specify which base docker image to use for flash-attention tests') + string( + name: 'fa_branch', + defaultValue: 'ck_improve_main', + description: 'Specify which branch of flash-attention to use (default: ck_improve_main)') + string( + name: 'ck_fa_branch', + defaultValue: CURRENT_BRANCH_NAME, + description: 'Specify which branch of CK to test with flash-attention (default: current branch)') booleanParam( name: "FORCE_CI", defaultValue: false, @@ -1519,7 +1484,7 @@ pipeline { } } } - stage("Run Pytorch Tests") + stage("Run Downstream Tests") { when { beforeAgent true @@ -1535,20 +1500,10 @@ pipeline { } agent{ label rocmnode("gfx942")} steps{ - run_pytorch_tests() + run_downstream_tests(image: "${env.CK_PYTORCH_IMAGE}", timeoutHours: 2, execute_cmds: getPytorchTestsCmds()) cleanWs() } } - } - } - stage("Run AITER Tests") - { - when { - beforeAgent true - expression { env.SHOULD_RUN_CI.toBoolean() } - } - parallel - { stage("Run AITER Tests on gfx942") { when { @@ -1557,7 +1512,7 @@ pipeline { } agent{ label rocmnode("gfx942")} steps{ - run_aiter_tests() + run_downstream_tests(image: "${env.CK_AITER_IMAGE}", timeoutHours: 5, execute_cmds: getAiterTestsCmds()) cleanWs() } } @@ -1569,7 +1524,31 @@ pipeline { } agent{ label rocmnode("gfx950")} steps{ - run_aiter_tests() + run_downstream_tests(image: "${env.CK_AITER_IMAGE}", timeoutHours: 5, execute_cmds: getAiterTestsCmds()) + cleanWs() + } + } + stage("Run FA Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_FA_TESTS.toBoolean() } + } + agent{ label 
rocmnode("gfx942")} + steps{ + run_downstream_tests(image: "${env.CK_FA_IMAGE}", timeoutHours: 5, execute_cmds: getFaTestsCmds()) + cleanWs() + } + } + stage("Run FA Tests on gfx950") + { + when { + beforeAgent true + expression { params.RUN_FA_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx950")} + steps{ + run_downstream_tests(image: "${env.CK_FA_IMAGE}", timeoutHours: 5, execute_cmds: getFaTestsCmds()) cleanWs() } } @@ -2109,7 +2088,10 @@ pipeline { description: 'Some checks have failed' node(rocmnode("nogpu")) { script { - sendFailureNotifications() + checkoutComposableKernel() + } + withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) { + sh 'bash projects/composablekernel/script/infra_helper/send_failure_notifications.sh' } } } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp index 863501cd0a..9895ed7e54 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -198,10 +198,6 @@ struct Epilogue input_left_pads, input_right_pads); - // auto res = rtc::from_gpu(out_dev); - // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); - // assert(pass); - // Simple check: this checks that the output from each instance matches the output from the // first instance CHECK(report(solution, check(rtc::from_gpu(out_dev)))); diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp index e748a29743..617c2318d5 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -198,10 +198,6 @@ struct Epilogue input_left_pads, input_right_pads); - // auto res = rtc::from_gpu(out_dev); - // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); - // assert(pass); - // Simple check: this checks that the output from each instance matches the output from the // first instance CHECK(report(solution, check(rtc::from_gpu(out_dev)))); diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp index a68fb53cba..84516b2577 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -198,10 +198,6 @@ struct Epilogue input_left_pads, input_right_pads); - // auto res = rtc::from_gpu(out_dev); - // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); - // assert(pass); - // Simple check: this checks that the output from each instance matches the output from the // first instance CHECK(report(solution, check(rtc::from_gpu(out_dev)))); diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp index 0262319c39..3490c38f6a 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -198,10 +198,6 @@ struct Epilogue input_left_pads, input_right_pads); - // auto res = rtc::from_gpu(out_dev); - // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); - // assert(pass); - // Simple check: this checks that the output from each instance matches the output from the // first instance CHECK(report(solution, check(rtc::from_gpu(out_dev)))); diff --git a/dispatcher/README.md b/dispatcher/README.md index d1ca299d78..dc864f7c62 100644 --- a/dispatcher/README.md +++ b/dispatcher/README.md @@ -1,6 +1,6 @@ # CK Tile 
Dispatcher
-A unified kernel dispatch system for AMD GPUs with C++ and Python frontends.
+A unified kernel dispatch system for AMD GPUs with C++ and Python frontends, supporting GEMM and Grouped Convolution operations.
**Validated Platform:** AMD Instinct MI300 series (gfx942)
@@ -154,6 +154,8 @@ rocminfo | grep -i "gfx"
### Install Python Dependencies
+#### Core Dependencies (Required)
+
NumPy is required for Python examples and kernel generation. We recommend using a virtual environment:
**Option 1: Using standard venv**
@@ -165,8 +167,8 @@ python3 -m venv .venv
source .venv/bin/activate # Linux/macOS
# .venv\Scripts\activate # Windows
-# Install NumPy
-pip install numpy
+# Install core dependencies
+pip install -r python/requirements.txt
```
**Option 2: Using uv (faster alternative)**
@@ -179,17 +181,38 @@ uv venv .venv
source .venv/bin/activate # Linux/macOS
# .venv\Scripts\activate # Windows
-# Install NumPy
-uv pip install numpy
+# Install core dependencies
+uv pip install -r python/requirements.txt
```
**Option 3: System-wide install (not recommended)**
```bash
-pip install numpy
+pip install -r python/requirements.txt
```
> **Note:** Always activate your virtual environment before running CMake or Python examples.
+#### ML Heuristics Dependencies (Optional)
+
+For ML-based kernel selection (examples 09-11), install additional dependencies:
+
+```bash
+# Activate your virtual environment first
+source .venv/bin/activate
+
+# Install ML dependencies (LightGBM, pandas, pyarrow, scikit-learn)
+pip install -r requirements-ml.txt
+```
+
+**Why separate?** ML dependencies are large (especially pyarrow) and not needed for basic dispatcher usage. Install only if you need:
+- ML-based kernel selection (`examples/gemm/python/09_ml_heuristic.py`)
+- Model training (`heuristics/train.py`)
+- Model evaluation (`heuristics/evaluate.py`)
+- Automated benchmark analysis
+
+**Core dependencies:** ~50 MB (NumPy only)
+**With ML dependencies:** ~500 MB (includes LightGBM, pandas, pyarrow, scikit-learn)
+
### Supported Data Types
CK Tile supports a wide range of data types for GEMM operations:
@@ -319,8 +342,8 @@ ls examples/libdispatcher_gemm_lib.so
| `CMAKE_PREFIX_PATH` | - | ROCm installation path |
| `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler |
-⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
-⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
+WARNING: **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
+WARNING: **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
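The single-target warning above pairs with the `rocminfo | grep -i "gfx"` check earlier in this README. As a hedged illustration (not part of the patch), the sketch below shows one way a build script might pick exactly one gfx target before configuring CMake; the `GPU_TARGETS` option name and the reliance on `rocminfo` output are assumptions, so adapt them to whatever the build actually uses.

```python
# Illustrative sketch only: pick a single gfx target for the dispatcher build,
# since the README above says multi-target builds are not supported.
# Assumes `rocminfo` is on PATH and that the CMake option is named GPU_TARGETS.
import re
import subprocess

def detect_single_gfx_target() -> str:
    """Return one gfx target (e.g. 'gfx942') reported by rocminfo."""
    out = subprocess.run(["rocminfo"], capture_output=True, text=True, check=True).stdout
    targets = sorted(set(re.findall(r"gfx[0-9a-f]+", out)))
    if not targets:
        raise RuntimeError("No gfx target found; is a ROCm-capable GPU visible?")
    if len(targets) > 1:
        print(f"Multiple targets found {targets}; using {targets[0]} "
              "(build the others in separate build directories).")
    return targets[0]

if __name__ == "__main__":
    print(f"-DGPU_TARGETS={detect_single_gfx_target()}")
```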
---
@@ -340,6 +363,15 @@ cd build/examples
./gemm_04_heuristics # Heuristic kernel selection
./gemm_05_json_export # Registry JSON export
./gemm_06_multi_registry # Multiple registries
+
+# Grouped Convolution Examples
+./grouped_conv_01_basic # Declaration patterns + GPU execution
+./grouped_conv_02_all_dirs # Forward/BwdData/BwdWeight with GPU
+./grouped_conv_03_bench_val # Benchmark + CPU reference validation
+./grouped_conv_04_registry_json # Heuristic selection + JSON export
+./grouped_conv_05_bwd_data # Backward data + CPU validation
+./grouped_conv_06_bwd_weight # Backward weight + CPU validation
+./grouped_conv_07_benchmark # Multi-tile ResNet benchmark
```
### Python Examples
@@ -352,8 +384,16 @@ cd /path/to/composable_kernel/dispatcher
# GEMM Examples
python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM
python3 examples/gemm/python/04_validation.py # CPU reference validation
-python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels)
+python3 examples/gemm/python/07_stress_test.py # Stress test
python3 examples/gemm/python/08_heuristics.py # Heuristic selection
+
+# Grouped Convolution Examples
+python3 examples/grouped_conv/python/01_basic_grouped_conv.py # Config patterns + registry + GPU
+python3 examples/grouped_conv/python/02_forward.py # Forward 2D/3D + CPU ref
+python3 examples/grouped_conv/python/03_bwd_data.py # Backward data + CPU ref
+python3 examples/grouped_conv/python/04_bwd_weight.py # Backward weight + CPU ref
+python3 examples/grouped_conv/python/05_benchmark.py # Multi-problem benchmark
+python3 examples/grouped_conv/python/06_registry_json.py # Heuristic selection + JSON
```
### Example Output
@@ -470,6 +510,42 @@ python3 examples/gemm/python/10_advanced_benchmark.py \
---
+## ML-Based Kernel Selection (Optional)
+
+The dispatcher includes ML heuristics for automated kernel selection using trained LightGBM models.
+
+**Prerequisites:** Install ML dependencies first:
+
+```bash
+pip install -r requirements-ml.txt # ~500 MB (LightGBM, pandas, pyarrow, scikit-learn)
+```
+
+**Documentation:** See [heuristics/README.md](heuristics/README.md) for:
+- Training and evaluating models
+- Feature engineering (72 features)
+- Using pre-trained models
+- Python API reference
+
+**Examples:**
+```bash
+python3 examples/gemm/python/09_ml_heuristic.py # ML-based kernel selection
+python3 examples/gemm/python/10_rank_kernels.py # Kernel ranking
+```
+
+**Model Compression:** Trained models are stored in compressed `.lgbm.gz` format to save space (~67% size reduction). Python tools automatically decompress models on first use. For C++ examples, decompress manually:
+
+```bash
+# If you have compressed models
+cd heuristics/models/gemm_universal_fp16_gfx950
+gunzip model_tflops.lgbm.gz
+
+# Then use in C++ example
+cd ../../../build
+./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
+```
+
+---
+
## External Integration
### Using Dispatcher in Your Own Project
@@ -588,7 +664,7 @@ lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so")
### Data Flow
```
-KernelConfig → Registry → Dispatcher → GPU Execution
+KernelConfig -> Registry -> Dispatcher -> GPU Execution
```
1.
**KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts) @@ -784,31 +860,49 @@ make -j$(nproc) ``` dispatcher/ -├── README.md # This file -├── CMakeLists.txt # Build configuration -│ -├── include/ck_tile/dispatcher/ # C++ headers -│ ├── dispatcher.hpp # GEMM dispatcher -│ ├── registry.hpp # Kernel registry -│ └── kernel_key.hpp # Kernel configuration -│ -├── src/ # C++ implementation -│ -├── codegen/ # Kernel generation -│ ├── unified_gemm_codegen.py # GEMM kernel generator -│ └── arch_specs.json # GPU specifications -│ -├── bindings/ctypes/ # Python ctypes interface -│ └── gemm_ctypes_lib.cpp # GEMM Python library -│ -├── examples/ # Examples -│ └── gemm/ -│ ├── cpp/ # C++ GEMM examples (01-06) -│ └── python/ # Python GEMM examples (01-11) -│ -├── scripts/ # Build scripts -│ -└── tests/ # Unit tests +|---- README.md # This file +|---- CMakeLists.txt # Build configuration +| +|---- include/ck_tile/dispatcher/ # C++ headers +| |---- dispatcher.hpp # Main dispatcher include +| |---- registry.hpp # GEMM kernel registry +| |---- kernel_key.hpp # Kernel configuration +| |---- grouped_conv_config.hpp # Grouped conv configuration +| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder) +| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations +| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe) +| +---- grouped_conv_utils.hpp # Grouped conv utilities +| +|---- src/ # C++ implementation +| +|---- codegen/ # Kernel generation +| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings +| |---- unified_gemm_codegen.py # GEMM kernel generator +| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator +| +---- arch_specs.json # GPU specifications +| +|---- python/ # Python utilities +| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output +| |---- ctypes_utils.py # GEMM ctypes utilities +| +---- grouped_conv_utils.py # Grouped conv utilities +| +|---- scripts/ # Build scripts +| |---- compile_gemm_examples.py # GEMM build script +| +---- compile_grouped_conv_examples.py # Grouped conv build script +| +|---- bindings/ctypes/ # Python ctypes interface +| |---- gemm_ctypes_lib.cpp # GEMM Python library +| +---- conv_ctypes_lib.cpp # Grouped conv Python library +| +|---- examples/ # Examples +| |---- gemm/ +| | |---- cpp/ # C++ GEMM examples (01-07) +| | +---- python/ # Python GEMM examples (01-11) +| +---- grouped_conv/ +| |---- cpp/ # C++ Grouped Conv examples (01-07) +| +---- python/ # Python Grouped Conv examples (01-06) +| ++---- tests/ # Unit tests (C++ and Python) ``` --- @@ -820,17 +914,49 @@ dispatcher/ | GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) | | GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) | | Codegen | [codegen/README.md](codegen/README.md) | +| Python Utils | [python/README.md](python/README.md) | +| C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) | --- -## Archived Content +## Grouped Convolution Support -Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`: -- `examples/conv/cpp/` - 11 C++ convolution examples -- `examples/conv/python/` - 14 Python convolution examples -- `codegen/unified_conv_codegen.py` - Conv kernel generator -- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers -- `python/conv_utils.py` - Conv Python utilities +Grouped convolution is fully supported alongside GEMM, with shared infrastructure to 
eliminate duplication. + +### Python + +```bash +# Generate grouped conv kernels +python3 codegen/unified_grouped_conv_codegen.py \ + --output-dir build/generated_kernels \ + --datatype fp16 --variant forward --ndim-spatial 2 + +# Build grouped conv examples +python3 scripts/compile_grouped_conv_examples.py examples/grouped_conv/cpp/01_basic_grouped_conv.cpp +``` + +### Key Files + +| Component | File | +|-----------|------| +| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` | +| Python Codegen | `codegen/unified_grouped_conv_codegen.py` | +| Python Utils | `python/grouped_conv_utils.py` | +| Build Script | `scripts/compile_grouped_conv_examples.py` | +| Shared Codegen | `codegen/codegen_common.py` | +| Shared Utils | `python/dispatcher_common.py` | + +### Variants + +- **Forward** (`grouped_conv_fwd`) - Standard grouped convolution +- **Backward Data** (`grouped_conv_bwd_data`) - Gradient w.r.t. input +- **Backward Weight** (`grouped_conv_bwd_weight`) - Gradient w.r.t. weights + +### Shared Infrastructure + +GEMM and grouped convolution share common code to avoid duplication: +- `codegen/codegen_common.py` - TileConfig, TraitConfigBase, type mappings, parallel generation, arch-aware expansion +- `python/dispatcher_common.py` - Path helpers, validation, auto-correction, Colors, phased output --- diff --git a/dispatcher/bindings/README.md b/dispatcher/bindings/README.md index 7cda21f6ec..04029d32a9 100644 --- a/dispatcher/bindings/README.md +++ b/dispatcher/bindings/README.md @@ -6,13 +6,13 @@ This directory contains language bindings for the CK Tile Dispatcher. ``` bindings/ -├── ctypes/ # Python ctypes bindings (C API) -│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API -│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data) -│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API -│ ├── gpu_helper.cpp # CLI helper for Python -│ └── CMakeLists.txt -└── README.md +|---- ctypes/ # Python ctypes bindings (C API) +| |---- gemm_ctypes_lib.cpp # GEMM dispatcher C API +| |---- conv_ctypes_lib.cpp # Grouped conv dispatcher C API (fwd + bwd_data) +| |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library) +| |---- gpu_helper.cpp # CLI helper for Python +| +---- CMakeLists.txt ++---- README.md ``` ## ctypes Bindings @@ -65,7 +65,7 @@ lib.dispatcher_cleanup() | `dispatcher_export_registry_json()` | Export registry as JSON | | `dispatcher_cleanup()` | Release resources | -### Convolution API +### Grouped Convolution API | Function | Description | |----------|-------------| @@ -105,5 +105,11 @@ Output is JSON for easy parsing: See the examples that use these bindings: - **GEMM**: `dispatcher/examples/gemm/python/` -- **Conv**: `dispatcher/examples/conv/python/` + +### Grouped Convolution + +Grouped convolution C++ headers and Python utilities are in: +- **C++ Headers**: `dispatcher/include/ck_tile/dispatcher/grouped_conv_*.hpp` +- **Python Utils**: `dispatcher/python/grouped_conv_utils.py` +- **Build Script**: `dispatcher/scripts/compile_grouped_conv_examples.py` diff --git a/dispatcher/bindings/ctypes/CMakeLists.txt b/dispatcher/bindings/ctypes/CMakeLists.txt index 804e5e9bd7..18314017f2 100644 --- a/dispatcher/bindings/ctypes/CMakeLists.txt +++ b/dispatcher/bindings/ctypes/CMakeLists.txt @@ -78,7 +78,7 @@ endif() # Look for forward kernels file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp") # Look for backward data kernels -file(GLOB CONV_BWDD_KERNEL_HEADERS 
"${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp") +file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwd_data_*.hpp") # Fallback: any conv kernel (for backwards compatibility) file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp") @@ -112,7 +112,7 @@ endif() # Add backward data kernel if available if(CONV_BWDD_KERNEL_HEADERS) list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER) - message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}") + message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWD_DATA_KERNEL_HEADER}") target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER}) target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE) endif() diff --git a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp index 09e058f80f..96b4aa3462 100644 --- a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp +++ b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp @@ -53,6 +53,7 @@ struct ConvBwdwProblemC int stride_d, stride_h, stride_w; int pad_d, pad_h, pad_w; int dilation_d, dilation_h, dilation_w; + int split_k; }; // ============================================================================= @@ -108,8 +109,7 @@ static float run_bwd_weight_impl(const void* input_ptr, grad_weight_ptr, // wei_ptr = grad_weight (output) {}, // ds_ptr grad_output_ptr, // out_ptr = grad_output - 1 // k_batch - ); + (prob->split_k > 1) ? prob->split_k : 1); ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; diff --git a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp index d3c64621a7..002219c82e 100644 --- a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp +++ b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp @@ -1,128 +1,46 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT - -/** - * Convolution Dispatcher ctypes Library - * - * Provides C API for Python ctypes integration. - * Supports forward convolution. Backward operations require additional headers. - * - * REQUIRED: Forward kernel header must be force-included via -include flag. - * OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE - * - * Usage from Python: - * lib = ctypes.CDLL("libdispatcher_conv.so") - * lib.conv_dispatcher_init() - * lib.conv_dispatcher_run(...) - */ +// +// Multi-kernel grouped convolution dispatcher for Python ctypes. 
+// +// Supports: forward / backward-data / backward-weight x 2D / 3D +// +// The dispatch header (conv_python_dispatch.hpp) is force-included via +// -include and brings in ALL compiled kernels with these aliases: +// +// 2D launchers (from include_all headers): +// SelectedConvKernelLauncher (forward 2D) +// SelectedConvBwdDataLauncher (backward-data 2D) +// SelectedConvBwdWeightLauncher (backward-weight 2D) +// +// 3D launchers (from dispatch header): +// ConvFwd3dLauncher (forward 3D) +// ConvBwdData3dLauncher (backward-data 3D) +// ConvBwdWeight3dLauncher (backward-weight 3D) +// +// Usage from Python: +// lib = ctypes.CDLL("libdispatcher_conv_lib.so") +// lib.conv_dispatcher_init() +// lib.conv_dispatcher_run(input, weight, output, &problem, stream) #include -#include -#include +#include #include -#include "ck_tile/dispatcher/conv_utils.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" -using namespace ck_tile::dispatcher; - -// Global state (using shared_ptr for safe memory management) -static std::shared_ptr g_registry = nullptr; -static std::shared_ptr g_dispatcher = nullptr; -static std::vector g_kernels; - extern "C" { -// ============================================================================= -// Initialization -// ============================================================================= - -int conv_dispatcher_init() +// ========================================================================= +// Problem definition (matches Python ctypes struct exactly) +// ========================================================================= +enum ConvDirection { - if(g_registry) - return 0; // Already initialized - - g_registry = std::make_shared(); - g_dispatcher = std::make_shared(g_registry.get()); - - // Register kernel configurations using simple ConvKernelSet - // (actual kernel launch uses the force-included SelectedConvKernelLauncher) - using namespace ck_tile::dispatcher::conv_decl; - - // Forward kernels (required - must be force-included) - // Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb - ConvKernelSet fwd_set; - fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), - ConvAlgorithm() - .tile(128, 128, 64) // tile_m x tile_n x tile_k - .wave(2, 2, 1) - .warp(32, 32, 16) - .pipeline("compv4") - .scheduler("intrawave"), - "gfx942"); - g_registry->register_set(fwd_set, ConvRegistry::Priority::High); - -#ifdef CONV_BWD_DATA_AVAILABLE - // Backward data kernels - // Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16 - ConvKernelSet bwd_data_set; - bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), - ConvAlgorithm() - .tile(128, 128, 64) // tile_m x tile_n x tile_k - .wave(2, 2, 1) - .warp(32, 32, 16) - .pipeline("compv3") - .scheduler("intrawave"), - "gfx942"); - g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High); -#endif - - return 0; -} - -int conv_dispatcher_cleanup() -{ - // shared_ptr automatically handles cleanup when reset - g_dispatcher.reset(); - g_registry.reset(); - g_kernels.clear(); - return 0; -} - -// ============================================================================= -// Registry Management -// ============================================================================= - -int conv_dispatcher_get_kernel_count() -{ - if(!g_registry) - return 0; - return static_cast(g_registry->size()); -} - -int conv_dispatcher_get_kernel_name(int index, char* buffer, int 
buffer_size) -{ - if(index < 0 || !buffer || buffer_size <= 0) - return -1; - - if(!g_registry) - return -1; - - // Use registry to get kernel names (they are registered with full names) - const auto& kernels = g_registry->all_kernels(); - if(static_cast(index) >= kernels.size()) - return -1; - - const auto* kernel = kernels[index]; - std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1); - buffer[buffer_size - 1] = '\0'; - return 0; -} - -// ============================================================================= -// Problem Definition -// ============================================================================= + CONV_FORWARD = 0, + CONV_BWD_DATA = 1, + CONV_BWD_WEIGHT = 2 +}; struct ConvProblemC { @@ -132,267 +50,33 @@ struct ConvProblemC int stride_d, stride_h, stride_w; int pad_d, pad_h, pad_w; int dilation_d, dilation_h, dilation_w; - int direction; // 0=forward, 1=bwd_data, 2=bwd_weight + int direction; + int split_k; }; -// ============================================================================= -// Kernel Selection -// ============================================================================= +// ========================================================================= +// Initialization / lifecycle +// ========================================================================= +int conv_dispatcher_init() { return 0; } +int conv_dispatcher_cleanup() { return 0; } -int conv_dispatcher_is_supported(const ConvProblemC* prob) -{ - if(!g_registry || !prob) - return 0; - - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - const auto* kernel = g_dispatcher->select(problem); - return kernel ? 
1 : 0; -} - -int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size) -{ - if(!g_registry || !prob || !kernel_name || buffer_size <= 0) - return -1; - - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - const auto* kernel = g_dispatcher->select(problem); - if(!kernel) - return -1; - - std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1); - kernel_name[buffer_size - 1] = '\0'; - - return 0; -} - -// ============================================================================= -// Convolution Execution -// ============================================================================= - -// Helper to build ConvParam -static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob) -{ - // Determine if this is 2D or 3D convolution - const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); - - if(is_3d) - { - // 3D convolution: use all spatial dimensions - return ck_tile::conv::ConvParam{3, - prob->G, - prob->N, - prob->K, - prob->C, - {prob->filter_z, prob->filter_y, prob->filter_x}, - {prob->input_d, prob->input_h, prob->input_w}, - {prob->stride_d, prob->stride_h, prob->stride_w}, - {prob->dilation_d, prob->dilation_h, prob->dilation_w}, - {prob->pad_d, prob->pad_h, prob->pad_w}, - {prob->pad_d, prob->pad_h, prob->pad_w}}; - } - else - { - // 2D convolution: only use H, W dimensions - return ck_tile::conv::ConvParam{2, - prob->G, - prob->N, - prob->K, - prob->C, - {prob->filter_y, prob->filter_x}, - {prob->input_h, prob->input_w}, - {prob->stride_h, prob->stride_w}, - {prob->dilation_h, prob->dilation_w}, - {prob->pad_h, prob->pad_w}, - {prob->pad_h, prob->pad_w}}; - } -} - -// Forward convolution (required - kernel header must be force-included) -static float run_forward(const void* input_ptr, - const void* weight_ptr, - void* output_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - // SelectedConvKernelLauncher is defined in the force-included forward kernel header - return SelectedConvKernelLauncher::launch(args, stream_cfg); -} - -#ifdef CONV_BWD_DATA_AVAILABLE -// Backward data convolution (optional) -// Computes: grad_input = conv_bwd_data(weight, grad_output) -// -// Parameters: -// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT) -// weight_ptr: W - frozen weights (const, read-only INPUT) -// grad_input_ptr: dX - gradient for input (writable, OUTPUT) -static float run_bwd_data(const void* grad_output_ptr, - const void* weight_ptr, - void* grad_input_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - // CK Tile API uses tensor POSITION names (from forward pass), not data flow: - // in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data) - // wei_ptr = weight tensor = weight_ptr (W, const) - // out_ptr = output tensor 
position = grad_output_ptr (dY, INPUT to bwd_data) - ck_tile::GroupedConvBwdDataHostArgs args( - conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - return SelectedConvBwdDataLauncher::launch(args, stream_cfg); -} -#endif - -#ifdef CONV_BWD_WEIGHT_AVAILABLE -// Backward weight convolution (optional) -// Parameters: -// input_ptr: original forward input X (const, read-only) -// grad_output_ptr: gradient from next layer dY (const, read-only) -// grad_weight_ptr: gradient of weights dW (writable, OUTPUT) -static float run_bwd_weight(const void* input_ptr, - const void* grad_output_ptr, - void* grad_weight_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - // GroupedConvBwdWeightHostArgs constructor order: - // (param, in=X, wei=dW (output), ds, out=dY (input), k_batch) - // Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output) - ck_tile::GroupedConvBwdWeightHostArgs args( - conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); -} -#endif - -/** - * @brief Execute convolution based on direction specified in prob - * - * Parameter mapping varies by direction: - * Forward (direction=0): - * input_ptr = X (input tensor) - * weight_ptr = W (weight tensor) - * output_ptr = Y (output buffer) - * - * Backward Data (direction=1): - * input_ptr = dY (grad_output - gradient from next layer) - * weight_ptr = W (weight tensor, frozen) - * output_ptr = dX (grad_input buffer) - * - * Backward Weight (direction=2): - * input_ptr = X (forward input tensor) - * weight_ptr = dY (grad_output - gradient from next layer) - * output_ptr = dW (grad_weight buffer) - */ -float conv_dispatcher_run(const void* input_ptr, - const void* weight_ptr, - void* output_ptr, - const ConvProblemC* prob, - void* stream) -{ - // Validate all required pointers before kernel launch - if(!g_dispatcher || !prob) - return -1.0f; - if(!input_ptr || !weight_ptr || !output_ptr) - return -1.0f; // Null data pointer would cause kernel crash - - // Build problem for kernel selection - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - // Select kernel - const auto* kernel = g_dispatcher->select(problem); - if(!kernel) - return -1.0f; - - // Dispatch based on direction - switch(prob->direction) - { - case 0: // Forward (always available) - return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream); - -#ifdef CONV_BWD_DATA_AVAILABLE - case 1: // Backward data - // Convention: caller passes (grad_output, weight, grad_input_buffer) - // in the (input_ptr, weight_ptr, output_ptr) slots respectively. 
- // run_bwd_data expects: (grad_output, weight, grad_input) - return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream); -#endif - -#ifdef CONV_BWD_WEIGHT_AVAILABLE - case 2: // Backward weight - // Convention: caller passes (input, grad_output, grad_weight_buffer) - // in the (input_ptr, weight_ptr, output_ptr) slots respectively. - // run_bwd_weight expects: (input, grad_output, grad_weight) - return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream); -#endif - - default: return -1.0f; - } -} - -// ============================================================================= -// Info -// ============================================================================= - -const char* conv_dispatcher_version() { return "1.0.0"; } +// ========================================================================= +// Library info +// ========================================================================= +const char* conv_dispatcher_version() { return "2.0.0"; } int conv_dispatcher_has_kernels() { - return 1; // Forward kernel is required +#if defined(CONV_FWD_2D_AVAILABLE) || defined(CONV_FWD_3D_AVAILABLE) + return 1; +#else + return 0; +#endif } int conv_dispatcher_has_bwd_data() { -#ifdef CONV_BWD_DATA_AVAILABLE +#if defined(CONV_BWD_DATA_2D_AVAILABLE) || defined(CONV_BWD_DATA_3D_AVAILABLE) return 1; #else return 0; @@ -401,11 +85,240 @@ int conv_dispatcher_has_bwd_data() int conv_dispatcher_has_bwd_weight() { -#ifdef CONV_BWD_WEIGHT_AVAILABLE +#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE) || defined(CONV_BWD_WEIGHT_3D_AVAILABLE) return 1; #else return 0; #endif } +int conv_dispatcher_get_kernel_count() +{ + return CONV_KERNEL_COUNT; // defined in conv_python_dispatch.hpp +} + +int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size) +{ + if(!buffer || buffer_size <= 0 || index < 0 || index >= CONV_KERNEL_COUNT) + return -1; + std::strncpy(buffer, CONV_KERNEL_NAMES[index], buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +} + +// ========================================================================= +// Support query +// ========================================================================= +bool conv_dispatcher_is_supported(const ConvProblemC* prob) +{ + if(!prob) + return false; + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + switch(prob->direction) + { + case CONV_FORWARD: +#if defined(CONV_FWD_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_FWD_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + case CONV_BWD_DATA: +#if defined(CONV_BWD_DATA_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_BWD_DATA_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + case CONV_BWD_WEIGHT: +#if defined(CONV_BWD_WEIGHT_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + default: return false; + } +} + +// ========================================================================= +// ConvParam builders +// ========================================================================= +static ck_tile::conv::ConvParam make_param_2d(const ConvProblemC* p) +{ + return ck_tile::conv::ConvParam{2, + p->G, + p->N, + p->K, + p->C, + {p->filter_y, p->filter_x}, + {p->input_h, p->input_w}, + {p->stride_h, p->stride_w}, + {p->dilation_h, p->dilation_w}, + {p->pad_h, p->pad_w}, + {p->pad_h, p->pad_w}}; +} + +static ck_tile::conv::ConvParam make_param_3d(const ConvProblemC* p) +{ + return 
ck_tile::conv::ConvParam{3, + p->G, + p->N, + p->K, + p->C, + {p->filter_z, p->filter_y, p->filter_x}, + {p->input_d, p->input_h, p->input_w}, + {p->stride_d, p->stride_h, p->stride_w}, + {p->dilation_d, p->dilation_h, p->dilation_w}, + {p->pad_d, p->pad_h, p->pad_w}, + {p->pad_d, p->pad_h, p->pad_w}}; +} + +// ========================================================================= +// Kernel launch helpers +// ========================================================================= + +#ifdef CONV_FWD_2D_AVAILABLE +static float +launch_fwd_2d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return SelectedConvKernelLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_FWD_3D_AVAILABLE +static float +launch_fwd_3d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvFwd3dLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_DATA_2D_AVAILABLE +static float launch_bwd_data_2d( + const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return SelectedConvBwdDataLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_DATA_3D_AVAILABLE +static float launch_bwd_data_3d( + const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvBwdData3dLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE +static float launch_bwd_weight_2d( + const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + const int k_batch = (p->split_k > 1) ? p->split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return SelectedConvBwdWeightLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE +static float launch_bwd_weight_3d( + const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + const int k_batch = (p->split_k > 1) ? 
p->split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvBwdWeight3dLauncher::launch(args, sc); +} +#endif + +// ========================================================================= +// Main dispatch +// +// direction=0 (forward): a=X(input), b=W(weight), c=Y(output) +// direction=1 (bwd_data): a=dY(grad_out), b=W(weight), c=dX(grad_in) +// direction=2 (bwd_weight): a=X(input), b=dY(grad_out), c=dW(grad_wei) +// ========================================================================= +float conv_dispatcher_run( + const void* a_ptr, const void* b_ptr, void* c_ptr, const ConvProblemC* prob, void* stream) +{ + if(!prob || !a_ptr || !b_ptr || !c_ptr) + return -1.0f; + + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + hipStream_t hip_stream = static_cast<hipStream_t>(stream); + + try + { + switch(prob->direction) + { + case CONV_FORWARD: +#ifdef CONV_FWD_3D_AVAILABLE + if(is_3d) + return launch_fwd_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_FWD_2D_AVAILABLE + if(!is_3d) + return launch_fwd_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + case CONV_BWD_DATA: +#ifdef CONV_BWD_DATA_3D_AVAILABLE + if(is_3d) + return launch_bwd_data_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_BWD_DATA_2D_AVAILABLE + if(!is_3d) + return launch_bwd_data_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + case CONV_BWD_WEIGHT: +#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE + if(is_3d) + return launch_bwd_weight_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE + if(!is_3d) + return launch_bwd_weight_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + default: return -1.0f; + } + } + catch(const std::exception&) + { + return -3.0f; // Kernel rejected args (e.g. unsupported tile/channel combo) + } + catch(...)
+ { + return -3.0f; + } +} + } // extern "C" diff --git a/dispatcher/codegen/ADDING_NEW_GPU.md b/dispatcher/codegen/ADDING_NEW_GPU.md index 0bd2966a85..664b59b6b1 100644 --- a/dispatcher/codegen/ADDING_NEW_GPU.md +++ b/dispatcher/codegen/ADDING_NEW_GPU.md @@ -9,8 +9,8 @@ Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatche The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications: ``` -arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python) - → arch_specs_generated.hpp (C++) +arch_specs.json -> generate_arch_specs.py -> arch_specs_generated.py (Python) + -> arch_specs_generated.hpp (C++) ``` ## Quick Start @@ -175,14 +175,14 @@ for error in result.errors: ``` codegen/ -├── arch_specs.json # Single source of truth (EDIT THIS) -├── generate_arch_specs.py # Generator script -├── arch_specs_generated.py # Generated Python module -└── ADDING_NEW_GPU.md # This file +|---- arch_specs.json # Single source of truth (EDIT THIS) +|---- generate_arch_specs.py # Generator script +|---- arch_specs_generated.py # Generated Python module ++---- ADDING_NEW_GPU.md # This file include/ck_tile/dispatcher/ -├── arch_specs_generated.hpp # Generated C++ header -└── arch_filter.hpp # C++ filter +|---- arch_specs_generated.hpp # Generated C++ header ++---- arch_filter.hpp # C++ filter ``` ## Best Practices diff --git a/dispatcher/codegen/README.md b/dispatcher/codegen/README.md index 2d753924f5..40a9b7b8c1 100644 --- a/dispatcher/codegen/README.md +++ b/dispatcher/codegen/README.md @@ -1,11 +1,22 @@ -# CK Tile GEMM Unified Code Generator +# CK Tile Unified Code Generators -Single source of truth for all GEMM kernel generation. +Single source of truth for GEMM and Grouped Convolution kernel generation. > **See also:** [Main Dispatcher README](../README.md) for installation and core concepts. +## Shared Infrastructure + +Both GEMM and Grouped Conv generators share common code via `codegen_common.py`: +- `TileConfig` - Dataclass for tile dimensions +- `TraitConfigBase` - Base for kernel trait configurations with arch-aware validation +- `CommonTypeMappings` - Dtype-to-C++ type mappings +- `parallel_generate()` - Parallel kernel generation with per-kernel progress logging +- Arch-aware expansion helpers (`valid_wave_configs`, `valid_warp_configs`, etc.) + ## Quick Start +### GEMM + ```bash cd dispatcher/codegen @@ -22,6 +33,25 @@ python3 unified_gemm_codegen.py \ --variants standard preshuffle multi_d ``` +### Grouped Convolution + +```bash +cd dispatcher/codegen + +# Generate forward FP16 grouped conv kernels +python3 unified_grouped_conv_codegen.py \ + --output-dir ../build/generated_kernels \ + --datatype fp16 \ + --variant forward \ + --ndim-spatial 2 + +# Generate backward data kernels +python3 unified_grouped_conv_codegen.py \ + --output-dir ../build/generated_kernels \ + --variant backward_data \ + --ndim-spatial 2 +``` + ## Using from Python ```python @@ -58,13 +88,13 @@ results = codegen.generate_all() ## Variants ### Standard -Basic GEMM: `C = A × B` +Basic GEMM: `C = A x B` ### PreShuffle Optimized weight access with LDS pre-shuffling. Best for large matrices. 
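A minimal sketch of generating only this variant, assuming `unified_gemm_codegen.py` accepts the same `--output-dir`/`--datatype` flags as the grouped-conv generator in the Quick Start above (only `--variants` is shown there verbatim):

```bash
# Sketch: restrict generation to preshuffle GEMM kernels (flag spelling assumed from Quick Start)
cd dispatcher/codegen
python3 unified_gemm_codegen.py \
    --output-dir ../build/generated_kernels \
    --datatype fp16 \
    --variants preshuffle
```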
### Multi-D -Element-wise fusion: `C = op(A × B + D0 + D1 + ...)` +Element-wise fusion: `C = op(A x B + D0 + D1 + ...)` Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` @@ -72,10 +102,11 @@ Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` ``` generated_kernels/ -├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp -├── gemm_fp16_rcr_compv4_..._preshuffle.hpp -├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp -└── ... +|---- gemm_fp16_rcr_compv4_..._128x128x32_....hpp # GEMM kernels +|---- gemm_fp16_rcr_compv4_..._preshuffle.hpp +|---- gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp +|---- grouped_conv_fwd_fp16_nhwgc_..._128x128x32_....hpp # Grouped conv kernels ++---- ... ``` ## Configuration Files diff --git a/dispatcher/codegen/codegen_common.py b/dispatcher/codegen/codegen_common.py new file mode 100644 index 0000000000..4e9e8de1b3 --- /dev/null +++ b/dispatcher/codegen/codegen_common.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Shared codegen infrastructure for GEMM and grouped convolution code generators. + +Extracted from unified_gemm_codegen.py + arch-aware expansion helpers from conv. +Both unified_gemm_codegen.py and unified_grouped_conv_codegen.py import from here +to eliminate duplication. +""" + +import logging +import concurrent.futures +from dataclasses import dataclass +from typing import ( + Callable, + ClassVar, + Dict, + FrozenSet, + List, + Optional, + Sequence, + Tuple, + TypeVar, +) + +log = logging.getLogger(__name__) + +T = TypeVar("T") +R = TypeVar("R") + +ANY_INT = -1 + + +# ============================================================================ +# Tile and Trait Configuration (shared between GEMM and Conv) +# ============================================================================ + + +@dataclass +class TileConfig: + """Tile configuration parameters shared by GEMM and grouped conv.""" + + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + def is_valid(self) -> bool: + if self.tile_m <= 0 or self.tile_n <= 0 or self.tile_k <= 0: + return False + return ( + self.tile_m % (self.warp_m * self.warp_tile_m) == 0 + and self.tile_n % (self.warp_n * self.warp_tile_n) == 0 + and self.tile_k % (self.warp_k * self.warp_tile_k) == 0 + ) + + +@dataclass +class TraitConfigBase: + """ + Base kernel trait configuration shared by GEMM and grouped conv. + + GEMM extends this with ``persistent``; grouped conv extends with + ``double_smem_buffer`` and ``num_groups_to_merge``. + """ + + pipeline: str # mem, compv3, compv4, compv5, ... + epilogue: str # cshuffle, default + scheduler: str # intrawave, interwave + pad_m: bool + pad_n: bool + pad_k: bool + + # Unsupported (pipeline, epilogue, scheduler) combinations. + # Only 'mem' and 'basic_v1' pipelines support interwave; all compute + # pipelines (compv3/v4/v5/v6/async) only support intrawave. 
+ _UNSUPPORTED: ClassVar[FrozenSet] = frozenset( + { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + ("compv5", "cshuffle", "interwave"), + ("compv5", "default", "interwave"), + ("compv6", "cshuffle", "interwave"), + ("compv6", "default", "interwave"), + ("comp_async", "cshuffle", "interwave"), + ("comp_async", "default", "interwave"), + ("basic_async_v1", "cshuffle", "interwave"), + ("basic_async_v1", "default", "interwave"), + } + ) + + def is_valid(self) -> bool: + return (self.pipeline, self.epilogue, self.scheduler) not in self._UNSUPPORTED + + +# ============================================================================ +# Type Mappings (centralized for both GEMM and conv codegen) +# ============================================================================ + + +class CommonTypeMappings: + """Centralized type mappings shared by GEMM and grouped conv codegen.""" + + DTYPE_TO_CK = { + "fp16": "fp16_t", + "bf16": "bf16_t", + "fp32": "float", + "fp8": "fp8_t", + "bf8": "bf8_t", + "int8": "int8_t", + } + + DTYPE_TO_CK_QUALIFIED = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int8": "int8_t", + } + + DTYPE_TO_DISPATCHER = { + "fp16": "DataType::FP16", + "bf16": "DataType::BF16", + "fp32": "DataType::FP32", + "fp8": "DataType::FP8", + "bf8": "DataType::BF8", + "int8": "DataType::INT8", + } + + # GEMM-specific layout mappings ("r"/"c" for row/column major). + # Convolution layouts (NHWGC, GKYXC, etc.) are handled by + # unified_grouped_conv_codegen.py via GroupedConvLayout / GroupedConvTypeMappings. + GEMM_LAYOUT_TO_CK = { + "r": "tensor_layout::gemm::RowMajor", + "c": "tensor_layout::gemm::ColumnMajor", + } + LAYOUT_TO_CK = GEMM_LAYOUT_TO_CK # backward compat alias + + GEMM_LAYOUT_TO_DISPATCHER = { + "r": "LayoutTag::RowMajor", + "c": "LayoutTag::ColMajor", + } + LAYOUT_TO_DISPATCHER = GEMM_LAYOUT_TO_DISPATCHER # backward compat alias + + # GEMM-only pipeline mappings (used by unified_gemm_codegen.py). + # Convolution pipelines are in GroupedConvTypeMappings + # (unified_grouped_conv_codegen.py). CK Tile conv supports: + # BASIC_V1, Mem, CompV3, CompV4, CompV5, CompV6, ASYNC_V1, ASYNC_V4. + # The dispatcher currently generates: mem, compv3, compv4. + # preshufflev2 is GEMM-only (weight pre-shuffle for GEMM, not conv). 
+ PIPELINE_TO_CK = { + "mem": "GemmPipelineAgBgCrMem", + "compv3": "GemmPipelineAgBgCrCompV3", + "compv4": "GemmPipelineAgBgCrCompV4", + "compv5": "GemmPipelineAgBgCrCompV5", + "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_BASE = { + "mem": "BaseGemmPipelineAgBgCrMem", + "compv3": "BaseGemmPipelineAgBgCrCompV3", + "compv4": "BaseGemmPipelineAgBgCrCompV4", + "compv5": "BaseGemmPipelineAgBgCrCompV5", + "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_DISPATCHER = { + "mem": "Pipeline::Mem", + "compv3": "Pipeline::CompV3", + "compv4": "Pipeline::CompV4", + "compv5": "Pipeline::CompV5", + "preshufflev2": "Pipeline::PreShuffleV2", + } + + SCHEDULER_TO_CK = { + "intrawave": "GemmPipelineScheduler::Intrawave", + "interwave": "GemmPipelineScheduler::Interwave", + "default": "GemmPipelineScheduler::Default", + } + + SCHEDULER_TO_DISPATCHER = { + "intrawave": "Scheduler::Intrawave", + "interwave": "Scheduler::Interwave", + "default": "Scheduler::Auto", + } + + EPILOGUE_TO_DISPATCHER = { + "cshuffle": "Epilogue::CShuffle", + "default": "Epilogue::Default", + } + + @staticmethod + def get_output_dtype(dtype: str) -> str: + """Get output datatype (fp8/bf8 -> fp16).""" + return "fp16" if dtype in ("fp8", "bf8") else dtype + + +# ============================================================================ +# Code Generation Helpers +# ============================================================================ + + +def generate_cpp_compilation_unit(kernel_name: str) -> str: + """Generate a .cpp compilation unit that includes a kernel header. + + This is the standard pattern: one .cpp per kernel that just includes + the generated .hpp header, causing template instantiation. + """ + return ( + f"// Auto-generated compilation unit for {kernel_name}\n" + f'#include "{kernel_name}.hpp"\n' + ) + + +def parallel_generate( + generate_fn: Callable[[T], R], + items: Sequence[T], + parallel: bool = True, +) -> List[R]: + """Run ``generate_fn`` over ``items``, optionally in parallel. + + Logs per-item progress (best-of-conv pattern). + Returns a flat list of results in completion order. + """ + results: List[R] = [] + if not items: + return results + + if parallel and len(items) > 1: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = {executor.submit(generate_fn, item): item for item in items} + for future in concurrent.futures.as_completed(futures): + result = future.result() + results.append(result) + log.info("Generated: %s", futures[future]) + else: + for item in items: + result = generate_fn(item) + results.append(result) + log.info("Generated: %s", item) + + return results + + +# ============================================================================ +# Arch-Aware Expansion Helpers (adopted from conv kernel_decl.hpp) +# ============================================================================ + +# These load from arch_specs_generated when available, falling back to +# hardcoded defaults that match the most common arch (gfx942). 
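# Illustrative sketch only: a generator expanding wildcard wave/warp settings for a
# target arch could combine these helpers roughly as
#
#   for wave_m, wave_n, wave_k in valid_wave_configs("gfx942"):          # e.g. [2, 2, 1]
#       for wt_m, wt_n, wt_k in valid_warp_configs("gfx942", "fp16"):    # e.g. [32, 32, 16]
#           for pipeline, scheduler in valid_trait_configs():            # e.g. ("compv4", "intrawave")
#               ...  # build a TileConfig / trait config from the expanded values
#
# The example arch/dtype values match the gfx942 fallback data below.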
+ +_arch_data_cache: Optional[Dict] = None + + +def _get_arch_data() -> Dict: + """Load arch filter data, with caching.""" + global _arch_data_cache + if _arch_data_cache is not None: + return _arch_data_cache + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + _arch_data_cache = { + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + _arch_data_cache = { + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + }, + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv4", "cshuffle", "interwave"), + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + return _arch_data_cache + + +def valid_wave_configs(arch: str) -> List[List[int]]: + """Return valid [wave_m, wave_n, wave_k] combos for *arch*.""" + data = _get_arch_data() + return data["warp_combos"].get(arch, [[2, 2, 1]]) + + +def valid_warp_configs(arch: str, dtype: str) -> List[List[int]]: + """Return valid [warp_tile_m, warp_tile_n, warp_tile_k] combos for *arch*/*dtype*. + + The dtype key is constructed as ``{dtype}_{dtype}_{acc}`` where acc is + fp32 for float types and int32 for int8. + """ + data = _get_arch_data() + acc = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc}" + arch_tiles = data["warp_tile_combos"].get(arch, {}) + return arch_tiles.get(dtype_key, [[32, 32, 16]]) + + +def valid_trait_configs() -> List[Tuple[str, str]]: + """Return valid (pipeline, scheduler) pairs. + + Compute pipelines only support intrawave; mem supports both. 
+ """ + return [ + ("compv3", "intrawave"), + ("compv4", "intrawave"), + ("compv5", "intrawave"), + ("mem", "intrawave"), + ("mem", "interwave"), + ] + + +def needs_wave_expansion(config: dict) -> bool: + """True if wave_m or wave_n is a wildcard (ANY_INT = -1).""" + return config.get("wave_m", 2) == ANY_INT or config.get("wave_n", 2) == ANY_INT + + +def needs_warp_expansion(config: dict) -> bool: + """True if warp_m or warp_n is a wildcard (ANY_INT = -1).""" + return config.get("warp_m", 32) == ANY_INT or config.get("warp_n", 32) == ANY_INT + + +def needs_pipeline_expansion(config: dict) -> bool: + """True if pipeline is a wildcard (\"*\").""" + return config.get("pipeline", "compv4") == "*" diff --git a/dispatcher/codegen/generate_dispatcher_registration.py b/dispatcher/codegen/generate_dispatcher_registration.py index 024ec4a7c8..8e8b67376c 100644 --- a/dispatcher/codegen/generate_dispatcher_registration.py +++ b/dispatcher/codegen/generate_dispatcher_registration.py @@ -109,7 +109,7 @@ inline void register_all_kernels() """ output_file.write_text(content) - print(f"✓ Generated registration header: {output_file}") + print(f"OK Generated registration header: {output_file}") def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path): @@ -143,7 +143,7 @@ namespace generated { """ output_file.write_text(content) - print(f"✓ Generated registration implementation: {output_file}") + print(f"OK Generated registration implementation: {output_file}") def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path): @@ -414,8 +414,8 @@ def main(): with open(manifest_output, "w") as f: json.dump(manifest_data, f, indent=2) - print(f"✓ Generated manifest: {manifest_output}") - print("\n✓ Registration code generation complete!") + print(f"OK Generated manifest: {manifest_output}") + print("\nOK Registration code generation complete!") print(f" Total kernels: {len(kernels)}") print(" Output files:") print(f" - {registration_header}") diff --git a/dispatcher/codegen/generate_kernel_wrappers.py b/dispatcher/codegen/generate_kernel_wrappers.py index 53a9bff3ed..e11bd7a0a5 100644 --- a/dispatcher/codegen/generate_kernel_wrappers.py +++ b/dispatcher/codegen/generate_kernel_wrappers.py @@ -17,10 +17,10 @@ Usage: Output structure: build/kernel_wrappers/ - ├── gemm_fp16_rcr_128x128x32.cpp - ├── gemm_fp16_rcr_256x256x64.cpp - ├── conv_fwd_fp16_2d_128x128.cpp - └── ... + |---- gemm_fp16_rcr_128x128x32.cpp + |---- gemm_fp16_rcr_256x256x64.cpp + |---- conv_fwd_fp16_2d_128x128.cpp + +---- ... Each .cpp simply includes its corresponding .hpp and forces symbol emission. 
""" diff --git a/dispatcher/codegen/kernel_config_loader.py b/dispatcher/codegen/kernel_config_loader.py index 537fc40581..980b4e5fd0 100644 --- a/dispatcher/codegen/kernel_config_loader.py +++ b/dispatcher/codegen/kernel_config_loader.py @@ -359,8 +359,8 @@ class ConvTraitConfig: @dataclass -class ConvKernelConfig: - """Complete convolution kernel configuration""" +class GroupedConvKernelConfig: + """Complete grouped convolution kernel configuration""" tile: ConvTileConfig = field(default_factory=ConvTileConfig) trait: ConvTraitConfig = field(default_factory=ConvTraitConfig) @@ -419,7 +419,11 @@ class ConvKernelConfig: def kernel_name(self) -> str: """Generate kernel name from config""" - variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"} + variant_map = { + "forward": "fwd", + "bwd_data": "bwd_data", + "bwd_weight": "bwd_weight", + } var_str = variant_map.get(self.variant, self.variant) name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d" @@ -433,11 +437,11 @@ class ConvKernelConfig: @dataclass -class ConvKernelConfigSet: +class GroupedConvKernelConfigSet: """A set of convolution kernel configurations loaded from JSON""" name: str = "default" - configs: List[ConvKernelConfig] = field(default_factory=list) + configs: List[GroupedConvKernelConfig] = field(default_factory=list) # Tile parameter ranges tile_m_values: List[int] = field(default_factory=lambda: [128]) @@ -481,7 +485,7 @@ class ConvKernelConfigSet: layout: str = "nhwgc" gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) - def generate_configs(self) -> Iterator[ConvKernelConfig]: + def generate_configs(self) -> Iterator[GroupedConvKernelConfig]: """Generate all kernel configurations (cartesian product)""" # Tile parameters tile_params = itertools.product( @@ -548,7 +552,7 @@ class ConvKernelConfigSet: double_smem_buffer=trait[6], num_groups_to_merge=trait[7], ) - yield ConvKernelConfig( + yield GroupedConvKernelConfig( tile=tile_cfg, trait=trait_cfg, dtype_input=self.dtype_input, @@ -599,7 +603,9 @@ class ConvKernelConfigSet: return tile_count * trait_count * extra_count * len(self.gpu_targets) -def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: +def load_grouped_conv_kernel_configs( + json_path: str | Path, +) -> GroupedConvKernelConfigSet: """ Load convolution kernel configurations from a JSON file. @@ -607,14 +613,14 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: json_path: Path to JSON configuration file Returns: - ConvKernelConfigSet with all parameter values loaded + GroupedConvKernelConfigSet with all parameter values loaded """ json_path = Path(json_path) with open(json_path) as f: data = json.load(f) - config_set = ConvKernelConfigSet() + config_set = GroupedConvKernelConfigSet() # Name config_set.name = data.get("kernel_set_name", json_path.stem) @@ -680,15 +686,15 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: def generate_cpp_conv_kernel_set_declaration( - config_set: ConvKernelConfigSet, + config_set: GroupedConvKernelConfigSet, set_name: Optional[str] = None, ) -> str: """ - Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet. + Generate C++ DECL_GROUPED_CONV_KERNEL_SET code from a GroupedConvKernelConfigSet. 
""" name = set_name or config_set.name - lines = [f"DECL_CONV_KERNEL_SET({name},"] + lines = [f"DECL_GROUPED_CONV_KERNEL_SET({name},"] for config in config_set.generate_configs(): line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, ' diff --git a/dispatcher/codegen/unified_gemm_codegen.py b/dispatcher/codegen/unified_gemm_codegen.py index b0dd961be7..a818cec83e 100755 --- a/dispatcher/codegen/unified_gemm_codegen.py +++ b/dispatcher/codegen/unified_gemm_codegen.py @@ -7,7 +7,7 @@ Unified GEMM Code Generator - Single Source of Truth This is THE unified code generator for all GEMM kernel variants: -- Standard GEMM (C = A × B) +- Standard GEMM (C = A x B) - Preshuffle GEMM (optimized weight access) - Multi-D GEMM (element-wise fusion) @@ -25,6 +25,12 @@ from dataclasses import dataclass, asdict from enum import Enum import concurrent.futures +from codegen_common import ( + TileConfig, + TraitConfigBase, + CommonTypeMappings as TypeMappings, +) + # Import architecture filter for GPU-specific validation try: from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType @@ -194,62 +200,14 @@ class GemmVariant(Enum): MULTI_D = "multi_d" -@dataclass -class TileConfig: - """Tile configuration parameters""" - - tile_m: int - tile_n: int - tile_k: int - warp_m: int - warp_n: int - warp_k: int - warp_tile_m: int - warp_tile_n: int - warp_tile_k: int - - def is_valid(self) -> bool: - """Validate tile configuration""" - return ( - self.tile_m % (self.warp_m * self.warp_tile_m) == 0 - and self.tile_n % (self.warp_n * self.warp_tile_n) == 0 - and self.tile_k % (self.warp_k * self.warp_tile_k) == 0 - and self.tile_m > 0 - and self.tile_n > 0 - and self.tile_k > 0 - ) +# TileConfig imported from codegen_common @dataclass -class TraitConfig: - """Kernel trait configuration""" +class TraitConfig(TraitConfigBase): + """GEMM-specific trait configuration extending TraitConfigBase with persistent mode.""" - pipeline: str # mem, compv3, compv4 - epilogue: str # default, cshuffle - scheduler: str # intrawave, interwave - pad_m: bool - pad_n: bool - pad_k: bool - persistent: bool - - def is_valid(self) -> bool: - """Check if trait combination is valid""" - # Unsupported combinations - # Only 'mem' pipeline supports interwave scheduler. - # All compute pipelines (compv3/v4/v5/v6/async) only support intrawave. 
- unsupported = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), - ("compv5", "cshuffle", "interwave"), - ("compv5", "default", "interwave"), - ("compv6", "cshuffle", "interwave"), - ("compv6", "default", "interwave"), - ("comp_async", "cshuffle", "interwave"), - ("comp_async", "default", "interwave"), - } - return (self.pipeline, self.epilogue, self.scheduler) not in unsupported + persistent: bool = False @dataclass @@ -345,89 +303,7 @@ class KernelConfig: # ============================================================================ -class TypeMappings: - """Centralized type mappings for code generation""" - - DTYPE_TO_CK = { - "fp16": "fp16_t", - "bf16": "bf16_t", - "fp32": "float", - "fp8": "fp8_t", - "bf8": "bf8_t", - "int8": "int8_t", - } - - # Fully-qualified types for use outside of 'using namespace ck_tile' scope - DTYPE_TO_CK_QUALIFIED = { - "fp16": "ck_tile::fp16_t", - "bf16": "ck_tile::bf16_t", - "fp32": "float", # Built-in type, no namespace - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "int8": "int8_t", # Built-in type - } - - DTYPE_TO_DISPATCHER = { - "fp16": "DataType::FP16", - "bf16": "DataType::BF16", - "fp32": "DataType::FP32", - "fp8": "DataType::FP8", - "bf8": "DataType::BF8", - "int8": "DataType::INT8", - } - - LAYOUT_TO_CK = { - "r": "tensor_layout::gemm::RowMajor", - "c": "tensor_layout::gemm::ColumnMajor", - } - - LAYOUT_TO_DISPATCHER = { - "r": "LayoutTag::RowMajor", - "c": "LayoutTag::ColMajor", - } - - PIPELINE_TO_CK = { - "mem": "GemmPipelineAgBgCrMem", - "compv3": "GemmPipelineAgBgCrCompV3", - "compv4": "GemmPipelineAgBgCrCompV4", - "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2", - } - - PIPELINE_TO_BASE = { - "mem": "BaseGemmPipelineAgBgCrMem", - "compv3": "BaseGemmPipelineAgBgCrCompV3", - "compv4": "BaseGemmPipelineAgBgCrCompV4", - "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2", - } - - PIPELINE_TO_DISPATCHER = { - "mem": "Pipeline::Mem", - "compv3": "Pipeline::CompV3", - "compv4": "Pipeline::CompV4", - "preshufflev2": "Pipeline::PreShuffleV2", - } - - SCHEDULER_TO_CK = { - "intrawave": "GemmPipelineScheduler::Intrawave", - "interwave": "GemmPipelineScheduler::Interwave", - "default": "GemmPipelineScheduler::Default", - } - - SCHEDULER_TO_DISPATCHER = { - "intrawave": "Scheduler::Intrawave", - "interwave": "Scheduler::Interwave", - "default": "Scheduler::Auto", - } - - EPILOGUE_TO_DISPATCHER = { - "cshuffle": "Epilogue::CShuffle", - "default": "Epilogue::Default", - } - - @staticmethod - def get_output_dtype(dtype: str) -> str: - """Get output datatype (fp8/bf8 -> fp16)""" - return "fp16" if dtype in ["fp8", "bf8"] else dtype +# TypeMappings imported from codegen_common as CommonTypeMappings -> TypeMappings alias # ============================================================================ @@ -1068,7 +944,11 @@ class UnifiedGemmCodegen: } def generate_all(self, parallel: bool = True) -> Dict: - """Generate all kernels""" + """Generate all kernels. + + When parallel=True, all configs across all variants are collected first, + then generated concurrently in a single thread pool for maximum throughput. 
+ """ log.info("Generating GEMM kernels:") log.info(f" Datatype: {self.datatype}") log.info(f" Layout: {self.layout}") @@ -1078,49 +958,24 @@ class UnifiedGemmCodegen: results = {"kernels": [], "wrappers": [], "failed": []} - # Get configurations + # Collect ALL configs across all variants/preselected sets upfront + all_configs = [] if self.use_preselected: - configs = self._get_preselected_configs() - log.info(f" Total configurations: {len(configs)}") + all_configs = self._get_preselected_configs() + log.info(f" Total configurations: {len(all_configs)}") else: for variant in self.variants: - log.info(f"\nGenerating {variant.value} kernels...") configs = self._get_configs_for_variant(variant) - log.info(f" Configurations: {len(configs)}") + log.info(f" {variant.value}: {len(configs)} configurations") + all_configs.extend(configs) + log.info(f" Total across all variants: {len(all_configs)}") - if parallel: - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(self._generate_one, cfg) for cfg in configs - ] - for future in concurrent.futures.as_completed(futures): - try: - k, w = future.result() - results["kernels"].append(k) - results["wrappers"].append(w) - except Exception as e: - results["failed"].append(str(e)) - log.error(f"Failed: {e}") - else: - for cfg in configs: - try: - k, w = self._generate_one(cfg) - results["kernels"].append(k) - results["wrappers"].append(w) - except Exception as e: - results["failed"].append(str(e)) - log.error(f"Failed: {e}") - - # Generate registration header - if results["wrappers"]: - self._generate_registration_header(results["wrappers"]) - - return results - - # Generate from preselected set - if parallel: + # Generate all configs in a single parallel pass + if parallel and all_configs: with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(self._generate_one, cfg) for cfg in configs] + futures = [ + executor.submit(self._generate_one, cfg) for cfg in all_configs + ] for future in concurrent.futures.as_completed(futures): try: k, w = future.result() @@ -1130,7 +985,7 @@ class UnifiedGemmCodegen: results["failed"].append(str(e)) log.error(f"Failed: {e}") else: - for cfg in configs: + for cfg in all_configs: try: k, w = self._generate_one(cfg) results["kernels"].append(k) @@ -1139,7 +994,6 @@ class UnifiedGemmCodegen: results["failed"].append(str(e)) log.error(f"Failed: {e}") - # Generate registration header if results["wrappers"]: self._generate_registration_header(results["wrappers"]) @@ -1638,12 +1492,19 @@ def main(): # Write to temp file and use as config import tempfile + import os as _os - with tempfile.NamedTemporaryFile( + _tmp_config = tempfile.NamedTemporaryFile( mode="w", suffix=".json", delete=False - ) as f: - json.dump(full_config, f) - args.config = Path(f.name) + ) + try: + json.dump(full_config, _tmp_config) + _tmp_config.close() + args.config = Path(_tmp_config.name) + except Exception: + _tmp_config.close() + _os.unlink(_tmp_config.name) + raise except json.JSONDecodeError as e: logging.error(f"Invalid tile-config-json: {e}") return 1 @@ -1672,7 +1533,7 @@ def main(): results = codegen.generate_all(parallel=not args.no_parallel) - logging.info("\n✅ Generation complete!") + logging.info("\nGeneration complete.") logging.info(f" Kernels: {len(results['kernels'])}") logging.info(f" Wrappers: {len(results['wrappers'])}") logging.info(f" Failed: {len(results['failed'])}") @@ -1684,7 +1545,7 @@ def main(): # Generate dispatcher registration if requested if args.register: 
- logging.info("\n📝 Generating dispatcher registration code...") + logging.info("\nGenerating dispatcher registration code...") try: from generate_dispatcher_registration import ( scan_generated_headers, @@ -1701,11 +1562,20 @@ def main(): ) generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp") - logging.info(f"✓ Generated registration code for {len(kernels)} kernels") + logging.info(f"Generated registration code for {len(kernels)} kernels") except Exception as e: logging.error(f"Failed to generate registration code: {e}") return 1 + # Clean up temp config file if we created one + if args.tile_config_json and args.config and args.config.exists(): + try: + import os as _os + + _os.unlink(args.config) + except OSError: + pass + return 0 if not results["failed"] else 1 diff --git a/dispatcher/codegen/unified_grouped_conv_codegen.py b/dispatcher/codegen/unified_grouped_conv_codegen.py new file mode 100644 index 0000000000..ff40cb4ed4 --- /dev/null +++ b/dispatcher/codegen/unified_grouped_conv_codegen.py @@ -0,0 +1,1757 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Unified Grouped Convolution Code Generator + +This is the unified code generator for all grouped convolution kernel variants: +- Forward grouped convolution +- Backward data grouped convolution +- Backward weight grouped convolution + +Generates both CK Tile kernels AND dispatcher wrappers. +Based on the GEMM codegen pattern. +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from enum import Enum + +from codegen_common import ( + TileConfig, + TraitConfigBase, + parallel_generate, +) + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +log = logging.getLogger(__name__) + +# Import architecture filter for GPU-specific validation +try: + from arch_filter import ArchFilter, OperatorType + + HAS_ARCH_FILTER = True +except ImportError: + HAS_ARCH_FILTER = False + ArchFilter = None + OperatorType = None + + +# ============================================================================ +# Configuration and Data Structures +# ============================================================================ + + +class GroupedConvVariant(Enum): + """Grouped convolution kernel variants""" + + FORWARD = "forward" + BACKWARD_DATA = "bwd_data" + BACKWARD_WEIGHT = "bwd_weight" + + +class GroupedConvLayout(Enum): + """Grouped convolution data layouts""" + + # 1D + NWGC = "NWGC" # Input/Output: N W G C + GKXC = "GKXC" # Weight: G K X C + NWGK = "NWGK" # Output: N W G K + + # 2D + NHWGC = "NHWGC" # Input: N H W G C + GKYXC = "GKYXC" # Weight: G K Y X C + NHWGK = "NHWGK" # Output: N H W G K + + # 3D + NDHWGC = "NDHWGC" # Input: N D H W G C + GKZYXC = "GKZYXC" # Weight: G K Z Y X C + NDHWGK = "NDHWGK" # Output: N D H W G K + + +@dataclass +class GroupedConvTraitConfig(TraitConfigBase): + """Kernel trait configuration for grouped convolution (extends TraitConfigBase). + + Conv-specific extensions beyond TraitConfigBase. 
These map to + GroupedConvTraits template parameters in grouped_convolution_utils.hpp: + - double_smem_buffer: ping-pong LDS for compute V4+ pipelines + - num_groups_to_merge: fuse multiple groups into one tile (NumGroupsToMerge) + - split_image: split spatial dims for large tensors (EnableSplitImage) + - explicit_gemm: use explicit GEMM path (ExplicitGemm) + - two_stage: two-stage bwd_weight with fp32 workspace + elementwise convert + + Note: CK Tile already uses long_index_t (64-bit) for group strides and + batch offsets, so there is no separate "large_tensor" flag. For large + spatial dimensions, use split_image=True instead. + """ + + double_smem_buffer: bool = False + num_groups_to_merge: int = 1 + split_image: bool = False + explicit_gemm: bool = False + two_stage: bool = False + + +# Backward compatibility alias +TraitConfig = GroupedConvTraitConfig + + +@dataclass +class GroupedConvKernelConfig: + """Complete grouped convolution kernel configuration""" + + tile: TileConfig + trait: GroupedConvTraitConfig + variant: GroupedConvVariant = GroupedConvVariant.FORWARD + ndim_spatial: int = 2 # 1D, 2D, or 3D + arch: str = "gfx942" # Target architecture + layout: Union[str, GroupedConvLayout] = ( + "nhwgc" # Data layout (e.g., "nhwgc", "ndhwgc") + ) + + # Vector sizes: a=4 for fp16 input (8-byte aligned global loads), + # b=8 for weight tensor, c=8 for output stores. These match the + # CK Tile default vectorization widths for fp16 on CDNA3 (gfx942). + vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + vector_sizes: Optional[Tuple[int, int, int]] = None + + # Occupancy parameters + block_per_cu: int = 1 + num_wave_groups: int = 1 + num_groups_to_merge: int = 1 + + # Double buffering + double_smem_buffer: bool = False + + def __post_init__(self): + if self.vector_sizes is not None: + self.vector_size_a, self.vector_size_b, self.vector_size_c = ( + self.vector_sizes[:3] + ) + # Sync trait fields with top-level fields (trait is source of truth + # when both are specified, but top-level overrides default trait values). + if self.double_smem_buffer and not self.trait.double_smem_buffer: + self.trait.double_smem_buffer = self.double_smem_buffer + elif self.trait.double_smem_buffer: + self.double_smem_buffer = self.trait.double_smem_buffer + if self.num_groups_to_merge != 1 and self.trait.num_groups_to_merge == 1: + self.trait.num_groups_to_merge = self.num_groups_to_merge + elif self.trait.num_groups_to_merge != 1: + self.num_groups_to_merge = self.trait.num_groups_to_merge + + def _layout_str(self) -> str: + """Get layout as lowercase string for naming.""" + if hasattr(self.layout, "value"): + return self.layout.value.lower() + return str(self.layout).lower() + + def name(self, datatype: str) -> str: + """ + Generate kernel name that uniquely identifies the kernel configuration. + + Format: grouped_conv_{variant}_{dtype}_{layout}_{ndim}d_{pipeline}_{epilogue}_{scheduler} + _{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k} + _{warp_tile_m}x{warp_tile_n}x{warp_tile_k} + [_vec{a}_{b}_{c}][_bpc{n}][_wg{n}][_gm{n}][_dsb][_pad{mnk}] + + All parameters that affect kernel behavior MUST be included to ensure + unique names for unique configurations: + - Variant (fwd/bwd_data/bwd_weight) + - Data type + - Layout (nhwgc, nchw, ndhwgc, etc.) 
+ - Spatial dimensions (2d/3d) + - Pipeline, epilogue, scheduler + - Tile, warp, warp_tile dimensions + - Vector sizes, occupancy hints (if non-default) + - Double SMEM buffer, padding flags + """ + t = self.tile + tr = self.trait + layout_str = self._layout_str() + + variant_str = { + GroupedConvVariant.FORWARD: "fwd", + GroupedConvVariant.BACKWARD_DATA: "bwd_data", + GroupedConvVariant.BACKWARD_WEIGHT: "bwd_weight", + }[self.variant] + + # Core identity: variant, dtype, layout, dims + name = ( + f"grouped_conv_{variant_str}_{datatype}_{layout_str}_{self.ndim_spatial}d" + ) + + # Pipeline configuration + name += f"_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}" + + # Block tile dimensions (M_Tile x N_Tile x K_Tile) + name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}" + + # Wave distribution (M_Warp x N_Warp x K_Warp) + name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}" + + # Warp tile dimensions (M_Warp_Tile x N_Warp_Tile x K_Warp_Tile) + name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}" + + # Vector sizes (only if non-default) + if (self.vector_size_a, self.vector_size_b, self.vector_size_c) != (4, 8, 8): + name += ( + f"_vec{self.vector_size_a}_{self.vector_size_b}_{self.vector_size_c}" + ) + + # Occupancy hints (only if non-default) + if self.block_per_cu != 1: + name += f"_bpc{self.block_per_cu}" + + if self.num_wave_groups != 1: + name += f"_wg{self.num_wave_groups}" + + if self.num_groups_to_merge != 1: + name += f"_gm{self.num_groups_to_merge}" + + # Double SMEM buffer (for compute V4+) + if self.double_smem_buffer or tr.double_smem_buffer: + name += "_dsb" + + # Two-stage bwd_weight (fp32 workspace + elementwise convert) + if tr.two_stage: + name += "_2stage" + + # Padding suffix (only if not all enabled) + if not (tr.pad_m and tr.pad_n and tr.pad_k): + name += f"_pad{int(tr.pad_m)}{int(tr.pad_n)}{int(tr.pad_k)}" + + return name + + def is_valid_for_arch(self, arch: Optional[str] = None) -> bool: + """Check if configuration is valid for target architecture""" + target_arch = arch if arch is not None else self.arch + + # Check trait validity + if not self.trait.is_valid(): + return False + + # Backward operations have stricter pipeline requirements: + # - Backward weight: compv4/compv5 have transpose_tile2d issues + # - Backward data: compv4 has get_length issues in bwd_data kernel + # Both backward operations ONLY support compv3 and mem pipelines + if self.variant in ( + GroupedConvVariant.BACKWARD_WEIGHT, + GroupedConvVariant.BACKWARD_DATA, + ): + if self.trait.pipeline not in ("compv3", "mem"): + return False + + # Check warp configuration (from arch_specs) + try: + from arch_specs_generated import WARP_SUPPORTED_COMBINATIONS + + supported = WARP_SUPPORTED_COMBINATIONS.get(target_arch) + if supported is None: + return False # Unknown architecture + warp_cfg = [self.tile.warp_m, self.tile.warp_n, self.tile.warp_k] + if warp_cfg not in supported: + return False + except ImportError: + pass # Allow if arch_specs not available + + return True + + +# ============================================================================ +# Type Mappings +# ============================================================================ + + +class GroupedConvTypeMappings: + """Centralized type mappings for grouped convolution code generation""" + + DTYPE_TO_CK = { + "fp16": "half_t", + "bf16": "bf16_t", + "fp32": "float", + } + + # CK Tile conv pipelines (from conv_configs.hpp PipelineTypeTraits). 
+ # basic_v1/mem/compv3 use GroupedConvUniversalPipelineAgBgCrPolicy; + # compv4/compv5/compv6/comp_async/basic_async_v1 use their own default policy. + PIPELINE_TO_CK = { + "basic_v1": "GemmPipeline::BASIC_V1", + "mem": "GemmPipeline::MEMORY", + "compv3": "GemmPipeline::COMPUTE_V3", + "compv4": "GemmPipeline::COMPUTE_V4", + "compv5": "GemmPipeline::COMPUTE_V5", + "compv6": "GemmPipeline::COMPUTE_V6", + "comp_async": "GemmPipeline::COMPUTE_ASYNC", + "basic_async_v1": "GemmPipeline::BASIC_ASYNC_V1", + } + + SCHEDULER_TO_CK = { + "intrawave": "GemmPipelineScheduler::Intrawave", + "interwave": "GemmPipelineScheduler::Interwave", + } + + LAYOUT_1D = { + "in": "tensor_layout::convolution::NWGC", + "wei": "tensor_layout::convolution::GKXC", + "out": "tensor_layout::convolution::NWGK", + } + + LAYOUT_2D = { + "in": "tensor_layout::convolution::NHWGC", + "wei": "tensor_layout::convolution::GKYXC", + "out": "tensor_layout::convolution::NHWGK", + } + + LAYOUT_3D = { + "in": "tensor_layout::convolution::NDHWGC", + "wei": "tensor_layout::convolution::GKZYXC", + "out": "tensor_layout::convolution::NDHWGK", + } + + @classmethod + def get_layouts(cls, ndim: int) -> dict: + if ndim == 1: + return cls.LAYOUT_1D + elif ndim == 2: + return cls.LAYOUT_2D + else: + return cls.LAYOUT_3D + + +# ============================================================================ +# CK Tile Grouped Conv Kernel Generator +# ============================================================================ + + +class CKTileGroupedConvKernelGenerator: + """Generates CK Tile grouped convolution kernel instance code""" + + def __init__( + self, + datatype: str, + variant: GroupedConvVariant = GroupedConvVariant.FORWARD, + ): + self.datatype = datatype + self.variant = variant + self.tm = GroupedConvTypeMappings() + + def generate(self, config: GroupedConvKernelConfig) -> str: + """Generate complete CK Tile grouped convolution kernel""" + kernel_name = config.name(self.datatype) + return f"""{self._header(kernel_name, config)} +{self._config_struct(config, kernel_name)} +{self._kernel_instance(config, kernel_name)} +""" + + def _header(self, kernel_name: str, config: GroupedConvKernelConfig) -> str: + """Generate header includes based on variant""" + if self.variant == GroupedConvVariant.BACKWARD_DATA: + kernel_header = "grouped_convolution_backward_data_kernel.hpp" + elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT: + kernel_header = "grouped_convolution_backward_weight_kernel.hpp" + else: + kernel_header = "grouped_convolution_forward_kernel.hpp" + + elementwise_include = "" + if config.trait.two_stage: + elementwise_include = '\n#include "ck_tile/ops/elementwise.hpp"' + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated CK Tile Grouped Convolution kernel: {kernel_name} +// Variant: {self.variant.value} +#pragma once + +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/grouped_convolution/kernel/{kernel_header}" +#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp"{elementwise_include} + +using namespace ck_tile; +""" + + def _config_struct(self, config: GroupedConvKernelConfig, kernel_name: str) -> str: + """Generate config struct""" + t = config.tile + tr = config.trait + layouts = self.tm.get_layouts(config.ndim_spatial) + + return f""" +// Kernel configuration +struct 
{kernel_name}_Config {{ + // Data types + using InDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using WeiDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using AccDataType = float; + using OutDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + + // Layouts + using InLayout = {layouts["in"]}; + using WeiLayout = {layouts["wei"]}; + using OutLayout = {layouts["out"]}; + + // Tile shape + static constexpr index_t M_Tile = {t.tile_m}; + static constexpr index_t N_Tile = {t.tile_n}; + static constexpr index_t K_Tile = {t.tile_k}; + + static constexpr index_t M_Warp = {t.warp_m}; + static constexpr index_t N_Warp = {t.warp_n}; + static constexpr index_t K_Warp = {t.warp_k}; + + static constexpr index_t M_Warp_Tile = {t.warp_tile_m}; + static constexpr index_t N_Warp_Tile = {t.warp_tile_n}; + static constexpr index_t K_Warp_Tile = {t.warp_tile_k}; + + // Vector sizes + static constexpr index_t VectorSizeA = {config.vector_size_a}; + static constexpr index_t VectorSizeB = {config.vector_size_b}; + static constexpr index_t VectorSizeC = {config.vector_size_c}; + + // Padding + static constexpr bool kPadM = {str(tr.pad_m).lower()}; + static constexpr bool kPadN = {str(tr.pad_n).lower()}; + static constexpr bool kPadK = {str(tr.pad_k).lower()}; + + // Pipeline & Epilogue + static constexpr auto Pipeline = {self.tm.PIPELINE_TO_CK[tr.pipeline]}; + static constexpr auto Scheduler = {self.tm.SCHEDULER_TO_CK[tr.scheduler]}; + static constexpr bool DoubleSmemBuffer = {str(tr.double_smem_buffer).lower()}; + static constexpr bool UseCShuffleEpilogue = {str(tr.epilogue == "cshuffle").lower()}; + + // Other params + static constexpr int kBlockPerCu = {config.block_per_cu}; + static constexpr index_t NumWaveGroups = {config.num_wave_groups}; + static constexpr index_t NumGroupsToMerge = {tr.num_groups_to_merge}; + static constexpr bool EnableSplitImage = {str(tr.split_image).lower()}; + static constexpr bool ExplicitGemm = {str(tr.explicit_gemm).lower()}; + static constexpr index_t NDimSpatial = {config.ndim_spatial}; + + // Target architecture + static constexpr const char* TargetArch = "{config.arch}"; +}}; +""" + + def _kernel_instance( + self, config: GroupedConvKernelConfig, kernel_name: str + ) -> str: + """Generate kernel instantiation code with launch function""" + tr = config.trait + + if self.variant == GroupedConvVariant.BACKWARD_WEIGHT and tr.two_stage: + return self._kernel_instance_two_stage(config, kernel_name) + + # Variant-specific configuration + if self.variant == GroupedConvVariant.BACKWARD_DATA: + host_args_type = "GroupedConvBwdDataHostArgs" + kernel_type = "GroupedConvolutionBackwardDataKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsBwdData" + layout_suffix = "BwdData" + # For bwd_data: A=dOutput, B=Weight, C=dInput + a_dtype = "OutDataType" + b_dtype = "WeiDataType" + c_dtype = "InDataType" + gemm_k_calc = "args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()" + direction_prefix = "BWD_DATA" + launcher_alias = "SelectedConvBwdDataLauncher" + elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT: + host_args_type = "GroupedConvBwdWeightHostArgs" + kernel_type = "GroupedConvolutionBackwardWeightKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsBwdWeight" + layout_suffix = "BwdWeight" + # For bwd_weight: A=dOutput, B=Input, C=dWeight (per CK Tile invoker) + a_dtype = "OutDataType" + b_dtype = "InDataType" + c_dtype = "WeiDataType" + gemm_k_calc = "args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), 
args.output_spatial_lengths_.end()" + direction_prefix = "BWD_WEIGHT" + launcher_alias = "SelectedConvBwdWeightLauncher" + else: # Forward + host_args_type = "GroupedConvFwdHostArgs<>" + kernel_type = "GroupedConvolutionForwardKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsFwd" + layout_suffix = "Fwd" + a_dtype = "InDataType" + b_dtype = "WeiDataType" + c_dtype = "OutDataType" + gemm_k_calc = "args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()" + direction_prefix = "FWD" + launcher_alias = "SelectedConvKernelLauncher" + + # Create valid C++ namespace name + ns_name = "ns_" + kernel_name.replace("-", "_") + + return f""" +// Unique namespace for this kernel to avoid conflicts when including multiple kernels +namespace {ns_name} {{ + +// Bring Config into namespace +using Config = {kernel_name}_Config; + +// Kernel name for identification +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}"; + +// Selected kernel alias +using SelectedConv{direction_prefix.title()}Kernel = Config; + +// ============================================================================= +// Kernel Launch Implementation ({self.variant.value}) +// ============================================================================= + +struct {kernel_name}_Launcher {{ + using KernelConfig = Config; // Use the Config alias from namespace + using InDataType = typename Config::InDataType; + using WeiDataType = typename Config::WeiDataType; + using OutDataType = typename Config::OutDataType; + using AccDataType = typename Config::AccDataType; + using InLayout = typename Config::InLayout; + using WeiLayout = typename Config::WeiLayout; + using OutLayout = typename Config::OutLayout; + + static constexpr index_t NDimSpatial = Config::NDimSpatial; + + // Implicit GEMM shape + using GemmShape = TileGemmShape< + sequence, + sequence, + sequence>; + + // Convolution traits + static constexpr auto ConvSpec = ConvolutionSpecialization::Default; + using GroupedConvTraitsType = GroupedConvTraits< + NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout, + Config::VectorSizeA, Config::VectorSizeB, Config::VectorSizeC, + Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>; + + // Tile partitioner + using TilePartitioner = GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + // Universal traits - layout suffix changes per variant + using GemmUniversalTraits = TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + Config::DoubleSmemBuffer, + typename GroupedConvTraitsType::AsLayout{layout_suffix}, + typename GroupedConvTraitsType::BsLayout{layout_suffix}, + typename GroupedConvTraitsType::CLayout{layout_suffix}, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + Config::NumWaveGroups>; + + // Pipeline problem - data types change per variant + using GemmPipelineProblem = GemmPipelineProblem< + {a_dtype}, {b_dtype}, AccDataType, GemmShape, + typename GroupedConvTraitsType::template {gemm_traits}, + element_wise::PassThrough, element_wise::PassThrough, {c_dtype}, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, 
GroupedConvTraitsType::VectorSizeB>; + + // Base pipeline for tail handling + using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}; + + static float launch(const {host_args_type}& args, const stream_config& s) {{ + const index_t gemm_k = {gemm_k_calc}, 1, std::multiplies()); + + const index_t k_grain = args.k_batch * Config::K_Tile; + const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = Config::Scheduler; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + {a_dtype}, {b_dtype}, AccDataType, GemmShape, GemmUniversalTraits, + scheduler, + element_wise::PassThrough, element_wise::PassThrough, {c_dtype}, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")}; + + using ConvEpilogue = CShuffleEpilogue, AccDataType, {c_dtype}, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile, + Config::N_Warp_Tile, Config::K_Warp_Tile, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + Config::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + Config::VectorSizeC, false, 1, Config::DoubleSmemBuffer>>; + + using Kernel = {kernel_type}< + GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = Kernel::MakeKernelArgs(args); + + if (!Kernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for grouped conv kernel"); + }} + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + ave_time = launch_kernel(s, make_kernel( + Kernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }} +}}; + +// Launcher alias for tile_engine compatibility +using {launcher_alias} = {kernel_name}_Launcher; + +}} // namespace {ns_name} + +// Export specific launcher to global namespace +using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher; + +// When used with -include compiler flag, export aliases to global namespace +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {launcher_alias} = {ns_name}::{launcher_alias}; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME; +#endif +""" + + # Pipelines that accept GroupedConvUniversalPipelineAgBgCrPolicy + # as a second template parameter for conv-specific LDS layout. + # (from conv_configs.hpp PipelineTypeTraits -- basic_v1/mem/compv3) + # CompV4/V5/V6/comp_async/basic_async_v1 use their own default policies. 
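    # For example (illustrative, matching the helpers below): a "compv3" pipeline
    # with problem type "UniversalGemmProblem" is emitted as
    #   GemmPipelineAgBgCrCompV3<UniversalGemmProblem, GroupedConvUniversalPipelineAgBgCrPolicy>
    # whereas "compv4" keeps its own default policy and becomes
    #   GemmPipelineAgBgCrCompV4<UniversalGemmProblem>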
+ _CONV_POLICY_PIPELINES = {"basic_v1", "mem", "compv3"} + + def _get_pipeline(self, pipeline: str) -> str: + """Get pipeline class name.""" + pipelines = { + "basic_v1": "GemmPipelineAGmemBGmemCRegV1", + "mem": "GemmPipelineAgBgCrMem", + "compv3": "GemmPipelineAgBgCrCompV3", + "compv4": "GemmPipelineAgBgCrCompV4", + "compv5": "GemmPipelineAgBgCrCompV5", + "compv6": "GemmPipelineAgBgCrCompV6", + "comp_async": "GemmPipelineAgBgCrCompAsync", + "basic_async_v1": "GemmPipelineAGmemBGmemCRegAsyncV1", + } + return pipelines.get(pipeline, "GemmPipelineAgBgCrCompV3") + + def _get_pipeline_template_args(self, pipeline: str, problem_type: str) -> str: + """Get full template argument list for pipeline instantiation. + + For basic_v1/mem/compv3, passes GroupedConvUniversalPipelineAgBgCrPolicy + as a second template argument for conv-specific LDS banking. + """ + base = self._get_pipeline(pipeline) + if pipeline in self._CONV_POLICY_PIPELINES: + return f"{base}<{problem_type}, GroupedConvUniversalPipelineAgBgCrPolicy>" + return f"{base}<{problem_type}>" + + def _get_base_pipeline(self, pipeline: str) -> str: + """Get base pipeline class name (used for tail handling only). + + Note: basic_async_v1 inherits from BaseGemmPipelineAGmemBGmemCRegV1 + (there is no separate BaseGemmPipelineAGmemBGmemCRegAsyncV1). + """ + pipelines = { + "basic_v1": "BaseGemmPipelineAGmemBGmemCRegV1", + "mem": "BaseGemmPipelineAgBgCrMem", + "compv3": "BaseGemmPipelineAgBgCrCompV3", + "compv4": "BaseGemmPipelineAgBgCrCompV4", + "compv5": "BaseGemmPipelineAgBgCrCompV5", + "compv6": "BaseGemmPipelineAgBgCrCompV6", + "comp_async": "BaseGemmPipelineAgBgCrCompAsync", + "basic_async_v1": "BaseGemmPipelineAGmemBGmemCRegV1", + } + return pipelines.get(pipeline, "BaseGemmPipelineAgBgCrCompV3") + + def _kernel_instance_two_stage( + self, config: GroupedConvKernelConfig, kernel_name: str + ) -> str: + """Generate two-stage bwd_weight kernel: GEMM into fp32 workspace + ElementWise convert. + + Mirrors grouped_convolution_backward_weight_two_stage_invoker.hpp from + example/ck_tile/20_grouped_convolution/. 
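        Sketch of the launch sequence emitted by the template below:
          1. allocate an fp32 workspace of G_ * K_ * C_ * prod(filter_spatial_lengths_)
             elements (VectorSizeC is forced to 1 for the workspace writes);
          2. run the implicit-GEMM bwd_weight kernel into that workspace, zeroing it
             first via hipMemsetAsync when k_batch > 1;
          3. run an ElementWise UnaryConvert kernel that casts the workspace back to
             WeiDataType (fp16/bf16) in the user's output buffer.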
+ """ + tr = config.trait + ns_name = "ns_" + kernel_name.replace("-", "_") + direction_prefix = "BWD_WEIGHT" + launcher_alias = "SelectedConvBwdWeightLauncher" + + return f""" +namespace {ns_name} {{ + +using Config = {kernel_name}_Config; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}"; +using SelectedConv{direction_prefix.title()}Kernel = Config; + +struct {kernel_name}_Launcher {{ + using KernelConfig = Config; + using InDataType = typename Config::InDataType; + using WeiDataType = typename Config::WeiDataType; + using OutDataType = typename Config::OutDataType; + using AccDataType = typename Config::AccDataType; + using InLayout = typename Config::InLayout; + using WeiLayout = typename Config::WeiLayout; + using OutLayout = typename Config::OutLayout; + using WorkspaceDataType = float; + + static constexpr index_t NDimSpatial = Config::NDimSpatial; + // Two-stage forces VectorSizeC = 1 for workspace writes + static constexpr index_t VectorSizeC_TwoStage = 1; + + using GemmShape = TileGemmShape< + sequence, + sequence, + sequence>; + + static constexpr auto ConvSpec = ConvolutionSpecialization::Default; + using GroupedConvTraitsType = GroupedConvTraits< + NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout, + Config::VectorSizeA, Config::VectorSizeB, VectorSizeC_TwoStage, + Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>; + + using TilePartitioner = GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + Config::DoubleSmemBuffer, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + Config::NumWaveGroups>; + + using GemmPipelineProblem = GemmPipelineProblem< + OutDataType, InDataType, AccDataType, GemmShape, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight, + element_wise::PassThrough, element_wise::PassThrough, WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}; + + static float launch(const GroupedConvBwdWeightHostArgs& args, const stream_config& s) {{ + const index_t gemm_k = args.N_ * std::accumulate( + args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end(), + 1, std::multiplies()); + + const index_t k_grain = args.k_batch * Config::K_Tile; + const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = Config::Scheduler; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + OutDataType, InDataType, AccDataType, GemmShape, GemmUniversalTraits, + scheduler, + element_wise::PassThrough, element_wise::PassThrough, 
WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")}; + + // Epilogue writes to fp32 workspace (not fp16 output) + using ConvEpilogue = CShuffleEpilogue, AccDataType, WorkspaceDataType, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile, + Config::N_Warp_Tile, Config::K_Warp_Tile, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + Config::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeC>>; + + using Kernel = GroupedConvolutionBackwardWeightKernel< + GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>; + + // ElementWise kernel: fp32 workspace -> fp16/bf16 output + using XElementwiseOp = element_wise::UnaryConvert; + using EwBlockTile = sequence<2048>; + using EwBlockWarps = sequence<8>; + using EwWarpTile = sequence<64>; + using EwShape = ElementWiseShape; + using EwProblem = ElementWisePipelineProblem< + WorkspaceDataType, WorkspaceDataType, WeiDataType, EwShape, XElementwiseOp>; + using EwKernel = ElementWiseKernel; + + // Workspace: G * K * C * product(filter_spatial) elements in fp32 + const index_t spatial_accum = std::accumulate( + args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), + 1, std::multiplies()); + DeviceMem ws_buf(args.G_ * args.K_ * args.C_ * spatial_accum * sizeof(WorkspaceDataType)); + + GroupedConvBwdWeightHostArgs ws_args(args); + auto* c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_buf.GetDeviceBuffer(); + + auto kargs = Kernel::MakeKernelArgs(ws_args); + + if(!Kernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for two-stage bwd_weight kernel"); + }} + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + // ElementWise kernel setup + const index_t ew_block_size = EwKernel::BlockSize(); + const index_t total_elems = args.G_ * args.K_ * args.C_ * spatial_accum; + constexpr index_t elems_per_block = EwBlockTile::at(number<0>{{}}); + const index_t ew_grid_size = (total_elems + elems_per_block - 1) / elems_per_block; + + auto ew_shape = make_tuple(args.G_ * args.K_, + args.C_ * spatial_accum); + auto ew_inputs = make_tuple(static_cast(ws_args.wei_ptr)); + + if(!EwKernel::IsSupportedArgument(ew_shape)) {{ + throw std::runtime_error("ElementWise arguments not supported for two-stage convert"); + }} + + auto preprocess = [&]() {{ + if(kargs.k_batch > 1) + hip_check_error(hipMemsetAsync( + ws_args.wei_ptr, 0, + total_elems * sizeof(WorkspaceDataType), + s.stream_id_)); + }}; + + ave_time = launch_kernel_time_mask( + s, preprocess, + make_kernel(Kernel{{}}, grids, blocks, 0, kargs), + make_kernel( + EwKernel{{}}, ew_grid_size, ew_block_size, 0, + ew_shape, + make_tuple(args.C_ * spatial_accum, 1), + make_tuple(args.C_ * spatial_accum, 1), + ew_inputs, + static_cast(c_ptr))); + + return ave_time; + }} +}}; + +using {launcher_alias} = {kernel_name}_Launcher; + +}} // namespace {ns_name} + +using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher; + +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {launcher_alias} = {ns_name}::{launcher_alias}; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME 
= {ns_name}::CONV_{direction_prefix}_KERNEL_NAME; +#endif +""" + + +# ============================================================================ +# Dispatcher Wrapper Generator +# ============================================================================ + + +class GroupedConvDispatcherWrapperGenerator: + """Generates dispatcher integration wrapper following GEMM pattern""" + + # Static mappings for pipeline and scheduler enum names (matches kernel_key.hpp) + PIPELINE_TO_DISPATCHER = { + "mem": "Pipeline::Mem", + "compv3": "Pipeline::CompV3", + "compv4": "Pipeline::CompV4", + "compv5": "Pipeline::CompV5", + "preshufflev1": "Pipeline::PreShuffleV1", + "preshufflev2": "Pipeline::PreShuffleV2", + } + + SCHEDULER_TO_DISPATCHER = { + "default": "Scheduler::Default", + "intrawave": "Scheduler::Intrawave", + "interwave": "Scheduler::Interwave", + } + + def __init__( + self, + datatype: str, + variant: GroupedConvVariant = GroupedConvVariant.FORWARD, + ): + self.datatype = datatype + self.variant = variant + + def _pipeline_to_dispatcher(self, pipeline: str) -> str: + """Convert pipeline string to dispatcher enum value""" + return self.PIPELINE_TO_DISPATCHER.get( + pipeline.lower(), f"Pipeline::{pipeline.capitalize()}" + ) + + def _scheduler_to_dispatcher(self, scheduler: str) -> str: + """Convert scheduler string to dispatcher enum value""" + return self.SCHEDULER_TO_DISPATCHER.get( + scheduler.lower(), f"Scheduler::{scheduler.capitalize()}" + ) + + def generate( + self, + config: GroupedConvKernelConfig, + kernel_path: Path, + output_dir: Path, + ) -> str: + """Generate dispatcher wrapper with factory function for registry""" + kernel_name = config.name(self.datatype) + rel_path = kernel_path.relative_to(output_dir) + + # Determine launcher type based on variant + if self.variant == GroupedConvVariant.FORWARD: + launcher_alias = "SelectedConvKernelLauncher" + host_args_type = "GroupedConvFwdHostArgs<>" + conv_type_str = "forward" + elif self.variant == GroupedConvVariant.BACKWARD_DATA: + launcher_alias = "SelectedConvBwdDataLauncher" + host_args_type = "GroupedConvBwdDataHostArgs" + conv_type_str = "bwd_data" + else: # BACKWARD_WEIGHT + launcher_alias = "SelectedConvBwdWeightLauncher" + host_args_type = "GroupedConvBwdWeightHostArgs" + conv_type_str = "bwd_weight" + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated dispatcher wrapper for: {kernel_name} +#pragma once + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "../{rel_path}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +using ::ck_tile::dispatcher::GroupedConvKernelInstancePtr; +using ::ck_tile::dispatcher::GroupedConvKernelKey; +using ::ck_tile::dispatcher::DataType; +using ::ck_tile::dispatcher::LayoutTag; +using ::ck_tile::dispatcher::Pipeline; +using ::ck_tile::dispatcher::Scheduler; +using ::ck_tile::dispatcher::Epilogue; +using Priority = ::ck_tile::dispatcher::GroupedConvRegistry::Priority; + +// Factory function to create kernel instance for registry +inline GroupedConvKernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{ + GroupedConvKernelKey key; + key.signature.dtype_in = DataType::FP16; + key.signature.dtype_wei = DataType::FP16; + key.signature.dtype_out = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout = "nhwgc"; + key.signature.conv_type = "{conv_type_str}"; + key.signature.num_dims = {config.ndim_spatial}; + key.signature.groups = 1; + + key.algorithm.tile_shape = 
{{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}}; + key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, 1}}; + key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}}; + key.algorithm.pipeline = {self._pipeline_to_dispatcher(config.trait.pipeline)}; + key.algorithm.scheduler = {self._scheduler_to_dispatcher(config.trait.scheduler)}; + key.algorithm.epilogue = Epilogue::CShuffle; + key.gfx_arch = gfx_arch; + + // Create kernel instance that wraps the launcher + return std::make_shared( + key, + "{kernel_name}", + []({host_args_type}& args, const stream_config& cfg) -> float {{ + return {kernel_name}_Launcher::launch(args, cfg); + }} + ); +}} + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile + +// Export launcher alias to global namespace for direct use +using {launcher_alias} = {kernel_name}_Launcher; +""" + + +# ============================================================================ +# Configuration Parser +# ============================================================================ + + +def get_default_configs( + arch: str = "gfx942", + variants: Optional[List[GroupedConvVariant]] = None, + ndims: Optional[List[int]] = None, +) -> List[GroupedConvKernelConfig]: + """Get default grouped convolution configurations for target architecture""" + configs = [] + + if variants is None: + variants = [GroupedConvVariant.FORWARD] + if ndims is None: + ndims = [2] + + # Valid configurations per variant (based on CK Tile example configs) + # Forward and Backward Data: standard GEMM-like tiles + fwd_bwd_data_tiles = [ + # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k) + (128, 128, 32, 2, 2, 32, 32, 16), # Standard 128x128 + (256, 256, 32, 2, 2, 32, 32, 16), # Large 256x256 + (64, 64, 32, 1, 4, 16, 16, 16), # Small 64x64 + (128, 64, 32, 2, 2, 32, 32, 16), # Rectangular + (16, 64, 64, 1, 4, 16, 16, 32), # Tall and narrow + ] + + # Backward Weight: VERY specific tile configs that work with CK Tile's bwd_weight kernel + # Based on ConvConfigComputeV3 from CK Tile examples (example/ck_tile/20_grouped_convolution/) + # Note: Backward weight has strict constraints on warp configurations due to transpose_tile2d + # Only specific warp configs work: (1, 4, 1) and (4, 1, 1) are known to work + bwd_weight_tiles = [ + # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k) + # ConvConfigComputeV3: The primary working config for backward weight + (16, 64, 64, 1, 4, 16, 16, 32), + ] + + for variant in variants: + # Select tile configs based on variant + if variant == GroupedConvVariant.BACKWARD_WEIGHT: + tile_configs = bwd_weight_tiles + # Backward weight ONLY supports compv3 (compv4/compv5 have transpose_tile2d issues) + pipelines = [("compv3", "cshuffle")] + # Also generate two-stage variants (fp32 workspace + elementwise convert) + two_stage_flags = [False, True] + elif variant == GroupedConvVariant.BACKWARD_DATA: + tile_configs = fwd_bwd_data_tiles + # Backward data ONLY supports compv3 (compv4 has get_length issues in bwd_data kernel) + pipelines = [("compv3", "cshuffle")] + two_stage_flags = [False] + else: + tile_configs = fwd_bwd_data_tiles + # Only forward grouped convolution supports both compv3 and compv4 + pipelines = [("compv3", "cshuffle"), ("compv4", "cshuffle")] + two_stage_flags = [False] + for ndim in ndims: + for pipeline, epilogue in pipelines: + for ( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, 
+ warp_tile_m, + warp_tile_n, + warp_tile_k, + ) in tile_configs: + for two_stage in two_stage_flags: + adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k + + trait = GroupedConvTraitConfig( + pipeline=pipeline, + scheduler="intrawave", + epilogue=epilogue, + double_smem_buffer=(pipeline == "compv4"), + pad_m=True, + pad_n=True, + pad_k=True, + two_stage=two_stage, + ) + + if not trait.is_valid(): + continue + + config = GroupedConvKernelConfig( + tile=TileConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=adj_tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=1, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + ), + trait=trait, + variant=variant, + ndim_spatial=ndim, + arch=arch, + ) + + if config.is_valid_for_arch(): + configs.append(config) + + return configs + + +def get_arch_filter(): + """Get arch filter if available""" + try: + from arch_filter import ArchFilter + + return ArchFilter + except ImportError: + return None + + +# ============================================================================ +# Main Generator +# ============================================================================ + + +class _GenItem: + """Item for parallel generation with progress logging.""" + + def __init__( + self, + idx: int, + total: int, + config: GroupedConvKernelConfig, + datatype: str, + variant: GroupedConvVariant, + ): + self.idx = idx + self.total = total + self.config = config + self.datatype = datatype + self.variant = variant + + def __str__(self) -> str: + return f"kernel {self.idx}/{self.total}: {self.config.name(self.datatype)}" + + +class UnifiedGroupedConvCodegen: + """Main grouped convolution code generator""" + + def __init__( + self, + output_dir: Path, + gpu_target: str = "gfx942", + datatype: str = "fp16", + ndim_spatial: int = 2, + enable_arch_filter: bool = True, + ): + self.output_dir = output_dir + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Create wrapper directory for dispatcher integration + self.wrapper_dir = self.output_dir / "dispatcher_wrappers" + self.wrapper_dir.mkdir(parents=True, exist_ok=True) + + self.generated_files: List[Path] = [] + self.generated_wrappers: List[Path] = [] + self.gpu_target = gpu_target + self.datatype = datatype + self.ndim_spatial = ndim_spatial + + # Initialize architecture filter for GPU-specific validation + self.arch_filter = None + if enable_arch_filter and HAS_ARCH_FILTER: + try: + self.arch_filter = ArchFilter(gpu_target, strict_mode=False) + log.info(f"Architecture filter enabled for {gpu_target}") + except ValueError as e: + log.warning(f"Could not create arch filter: {e}") + + def _get_configs(self) -> List[GroupedConvKernelConfig]: + """Get configurations for this codegen's datatype and ndim_spatial.""" + return get_default_configs( + arch=self.gpu_target, + variants=[ + GroupedConvVariant.FORWARD, + GroupedConvVariant.BACKWARD_DATA, + GroupedConvVariant.BACKWARD_WEIGHT, + ], + ndims=[self.ndim_spatial], + ) + + def _get_operator_type( + self, variant: GroupedConvVariant + ) -> Optional["OperatorType"]: + """Map GroupedConvVariant to OperatorType for arch validation""" + if OperatorType is None: + return None + + variant_to_operator = { + GroupedConvVariant.FORWARD: OperatorType.CONV_FWD, + GroupedConvVariant.BACKWARD_DATA: OperatorType.CONV_BWD_DATA, + GroupedConvVariant.BACKWARD_WEIGHT: OperatorType.CONV_BWD_WEIGHT, + } + return variant_to_operator.get(variant, OperatorType.CONV_FWD) + + def is_config_valid( + self, config: GroupedConvKernelConfig, datatype: str = 
"fp16" + ) -> bool: + """Validate configuration against architecture constraints""" + if not self.arch_filter or not HAS_ARCH_FILTER: + return True + + operator = self._get_operator_type(config.variant) + + return self.arch_filter.is_kernel_valid( + datatype_a=datatype, + datatype_b=datatype, + datatype_c=datatype, + tile_m=config.tile.tile_m, + tile_n=config.tile.tile_n, + tile_k=config.tile.tile_k, + warp_m=config.tile.warp_m, + warp_n=config.tile.warp_n, + warp_k=1, # Grouped conv typically uses warp_k=1 + warp_tile_m=config.tile.warp_tile_m, + warp_tile_n=config.tile.warp_tile_n, + warp_tile_k=config.tile.warp_tile_k, + pipeline=config.trait.pipeline, + epilogue=config.trait.epilogue, + scheduler=config.trait.scheduler, + operator=operator, + ) + + def generate_kernel( + self, + config: GroupedConvKernelConfig, + datatype: str, + variant: GroupedConvVariant = GroupedConvVariant.FORWARD, + ) -> Tuple[Path, Path]: + """Generate a single kernel file and dispatcher wrapper. Returns (kernel_path, wrapper_path).""" + kernel_gen = CKTileGroupedConvKernelGenerator(datatype, variant) + wrapper_gen = GroupedConvDispatcherWrapperGenerator(datatype, variant) + + kernel_name = config.name(datatype) + filename = f"{kernel_name}.hpp" + filepath = self.output_dir / filename + + # Generate kernel header + content = kernel_gen.generate(config) + filepath.write_text(content) + self.generated_files.append(filepath) + + # Generate dispatcher wrapper + wrapper_content = wrapper_gen.generate(config, filepath, self.output_dir) + wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp" + wrapper_path.write_text(wrapper_content) + self.generated_wrappers.append(wrapper_path) + + # Generate .cpp compilation unit for per-kernel parallel builds + cpp_filename = f"{kernel_name}.cpp" + cpp_filepath = self.output_dir / cpp_filename + cpp_content = f"""// SPDX-License-Identifier: MIT +// Auto-generated compilation unit for: {kernel_name} +// Enables per-kernel parallel compilation with make -j + +#include "{filename}" + +namespace ck_tile {{ namespace generated {{ + volatile bool _{kernel_name.replace("-", "_")}_loaded = true; +}} }} +""" + cpp_filepath.write_text(cpp_content) + + return filepath, wrapper_path + + def _generate_single_kernel(self, item: _GenItem): + """Generate one kernel (used by parallel_generate). Returns (kernel_path, wrapper_path) or raises.""" + kernel_path, wrapper_path = self.generate_kernel( + item.config, item.datatype, item.variant + ) + log.info( + "Generated kernel %d/%d: %s", + item.idx, + item.total, + item.config.name(item.datatype), + ) + return (kernel_path, wrapper_path) + + def generate_all( + self, + configs: Optional[List[GroupedConvKernelConfig]] = None, + datatypes: Optional[List[str]] = None, + parallel: bool = True, + ) -> dict: + """Generate all kernel files (optionally in parallel). + + Configs are filtered using architecture validation before generation. + Returns dict with keys: kernels, wrappers, failed. 
+ """ + if configs is None: + configs = self._get_configs() + if datatypes is None: + datatypes = [self.datatype] + + results = {"kernels": [], "wrappers": [], "failed": []} + + # Filter configs using arch validation + valid_tasks = [] + rejected_count = 0 + + for datatype in datatypes: + for config in configs: + if self.is_config_valid(config, datatype): + valid_tasks.append((config, datatype, config.variant)) + else: + rejected_count += 1 + log.debug( + f"Rejected config for {self.gpu_target}: " + f"{config.tile.tile_m}x{config.tile.tile_n}x{config.tile.tile_k} " + f"variant={config.variant.value}" + ) + + if rejected_count > 0: + log.info( + f"Filtered {rejected_count} configs for {self.gpu_target}, " + f"{len(valid_tasks)} remaining" + ) + + total = len(valid_tasks) + items = [ + _GenItem(i, total, config, datatype, variant) + for i, (config, datatype, variant) in enumerate(valid_tasks) + ] + + def _safe_generate(item: _GenItem): + """Wrapper that catches exceptions for failure tracking.""" + try: + k, w = self._generate_single_kernel(item) + return ("ok", k, w, None) + except Exception as e: + return ("fail", None, None, str(e)) + + raw = parallel_generate( + _safe_generate, items, parallel=parallel and len(items) > 1 + ) + for r in raw: + if r[0] == "ok": + results["kernels"].append(r[1]) + results["wrappers"].append(r[2]) + else: + results["failed"].append(r[3]) + log.error("Failed: %s", r[3]) + + # Generate include_all_*.hpp headers for Python ctypes libraries + if results["wrappers"]: + self._generate_include_all_headers() + + return results + + def _generate_include_all_headers(self): + """Generate include_all_grouped_conv_*.hpp headers and registration header""" + # Scan output directory for ALL kernel files (not just this run's generated_files) + # This handles the case where fwd and bwd kernels are generated in separate make targets + fwd_headers = [] + bwd_data_headers = [] + bwd_weight_headers = [] + fwd_kernels = [] + bwd_data_kernels = [] + bwd_weight_kernels = [] + + for filepath in self.output_dir.glob("grouped_conv_*.hpp"): + name = filepath.name + kernel_name = name[:-4] + if name.startswith("grouped_conv_fwd_"): + fwd_headers.append(name) + fwd_kernels.append(kernel_name) + elif name.startswith(("grouped_conv_bwd_data_", "grouped_conv_bwdd_")): + bwd_data_headers.append(name) + bwd_data_kernels.append(kernel_name) + elif name.startswith(("grouped_conv_bwd_weight_", "grouped_conv_bwdw_")): + bwd_weight_headers.append(name) + bwd_weight_kernels.append(kernel_name) + + headers_to_generate = [ + ("include_all_grouped_conv_fwd_kernels.hpp", fwd_headers, "forward"), + ( + "include_all_grouped_conv_bwd_data_kernels.hpp", + bwd_data_headers, + "backward data", + ), + ( + "include_all_grouped_conv_bwd_weight_kernels.hpp", + bwd_weight_headers, + "backward weight", + ), + ] + + for header_name, kernel_headers, variant_desc in headers_to_generate: + header_path = self.output_dir / header_name + includes = "\n".join(f'#include "{h}"' for h in sorted(kernel_headers)) + + # Pick the first kernel as the default Selected*Launcher + if kernel_headers: + first_kernel = sorted(kernel_headers)[0][:-4] # Remove .hpp + if variant_desc == "forward": + launcher_alias = ( + f"using SelectedConvKernelLauncher = {first_kernel}_Launcher;" + ) + elif variant_desc == "backward data": + launcher_alias = ( + f"using SelectedConvBwdDataLauncher = {first_kernel}_Launcher;" + ) + else: # backward weight + launcher_alias = f"using SelectedConvBwdWeightLauncher = {first_kernel}_Launcher;" + else: + 
launcher_alias = "// No kernels generated for this variant" + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Auto-generated header for grouped conv {variant_desc} kernels +#pragma once + +{includes} + +// Default launcher alias (uses first kernel) +{launcher_alias} +""" + header_path.write_text(content) + if kernel_headers: + log.info(f"Generated: {header_name} ({len(kernel_headers)} kernels)") + + # Generate registration header (following GEMM pattern) + self._generate_registration_header( + fwd_kernels, bwd_data_kernels, bwd_weight_kernels + ) + + def _generate_registration_header( + self, + fwd_kernels: List[str], + bwd_data_kernels: List[str], + bwd_weight_kernels: List[str], + ): + """Generate master registration header for all grouped conv kernels""" + # Scan wrapper directory for ALL wrapper files + all_wrappers = [] + for wrapper_path in self.wrapper_dir.glob( + "dispatcher_wrapper_grouped_conv_*.hpp" + ): + all_wrappers.append(wrapper_path.name) + + wrapper_includes = "\n".join(f'#include "{w}"' for w in sorted(all_wrappers)) + + # Generate registration calls + fwd_registrations = "\n ".join( + f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);" + for k in sorted(fwd_kernels) + ) + bwd_data_registrations = "\n ".join( + f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);" + for k in sorted(bwd_data_kernels) + ) + bwd_weight_registrations = "\n ".join( + f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);" + for k in sorted(bwd_weight_kernels) + ) + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Auto-generated master registration header for grouped conv kernels +#pragma once + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" + +{wrapper_includes} + +namespace ck_tile {{ +namespace dispatcher {{ + +using Priority = GroupedConvRegistry::Priority; + +inline void register_all_grouped_conv_fwd_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = GroupedConvRegistry::instance(); + {fwd_registrations if fwd_registrations else "// No forward kernels"} +}} + +inline void register_all_grouped_conv_bwd_data_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = GroupedConvRegistry::instance(); + {bwd_data_registrations if bwd_data_registrations else "// No backward data kernels"} +}} + +inline void register_all_grouped_conv_bwd_weight_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = GroupedConvRegistry::instance(); + {bwd_weight_registrations if bwd_weight_registrations else "// No backward weight kernels"} +}} + +inline void register_all_grouped_conv_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + register_all_grouped_conv_fwd_kernels(gfx_arch, priority); + register_all_grouped_conv_bwd_data_kernels(gfx_arch, priority); + register_all_grouped_conv_bwd_weight_kernels(gfx_arch, priority); +}} + +inline std::size_t get_grouped_conv_fwd_kernel_count() {{ return {len(fwd_kernels)}; }} +inline std::size_t get_grouped_conv_bwd_data_kernel_count() {{ return {len(bwd_data_kernels)}; }} +inline std::size_t get_grouped_conv_bwd_weight_kernel_count() {{ return {len(bwd_weight_kernels)}; }} +inline std::size_t 
get_grouped_conv_kernel_count() {{ return {len(fwd_kernels) + len(bwd_data_kernels) + len(bwd_weight_kernels)}; }} + +}} // namespace dispatcher +}} // namespace ck_tile +""" + reg_path = self.wrapper_dir / "register_all_grouped_conv_kernels.hpp" + reg_path.write_text(content) + log.info(f"Generated registration header: {reg_path}") + + +# ============================================================================ +# CLI +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Unified Grouped Convolution Code Generator" + ) + parser.add_argument( + "--output", + "-o", + type=Path, + default=Path("build/generated_kernels"), + help="Output directory", + ) + parser.add_argument( + "--datatype", + "-d", + type=str, + nargs="+", + default=["fp16"], + choices=["fp16", "bf16", "fp32"], + help="Data types to generate", + ) + parser.add_argument( + "--variant", + "-v", + type=str, + nargs="+", + default=["forward"], + choices=["forward", "bwd_data", "bwd_weight"], + help="Grouped convolution variants", + ) + parser.add_argument( + "--ndim", + "-n", + type=int, + nargs="+", + default=[2], + choices=[1, 2, 3], + help="Spatial dimensions", + ) + parser.add_argument( + "--arch", + "-a", + type=str, + default="gfx942", + choices=["gfx90a", "gfx942", "gfx950", "gfx1201"], + help="Target GPU architecture", + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--list-configs", + action="store_true", + help="List configurations without generating", + ) + + # Individual kernel configuration (when not using predefined configs) + parser.add_argument("--tile-m", type=int, help="Block tile M dimension") + parser.add_argument("--tile-n", type=int, help="Block tile N dimension") + parser.add_argument("--tile-k", type=int, help="Block tile K dimension") + parser.add_argument("--warp-m", type=int, help="Wave distribution M") + parser.add_argument("--warp-n", type=int, help="Wave distribution N") + parser.add_argument("--warp-k", type=int, default=1, help="Wave distribution K") + parser.add_argument("--warp-tile-m", type=int, help="Warp tile M") + parser.add_argument("--warp-tile-n", type=int, help="Warp tile N") + parser.add_argument("--warp-tile-k", type=int, default=16, help="Warp tile K") + parser.add_argument( + "--pipeline", + type=str, + choices=["mem", "compv3", "compv4", "compv5"], + help="Pipeline type", + ) + parser.add_argument( + "--scheduler", + type=str, + choices=["intrawave", "interwave"], + help="Scheduler type", + ) + parser.add_argument( + "--epilogue", + type=str, + default="cshuffle", + choices=["cshuffle", "default"], + help="Epilogue type", + ) + parser.add_argument("--pad-m", type=bool, default=True, help="Pad M dimension") + parser.add_argument("--pad-n", type=bool, default=True, help="Pad N dimension") + parser.add_argument("--pad-k", type=bool, default=True, help="Pad K dimension") + parser.add_argument("--vector-a", type=int, default=4, help="Vector size A") + parser.add_argument("--vector-b", type=int, default=8, help="Vector size B") + parser.add_argument("--vector-c", type=int, default=8, help="Vector size C") + parser.add_argument("--block-per-cu", type=int, default=1, help="Blocks per CU") + parser.add_argument("--num-wave-groups", type=int, default=1, help="Wave groups") + parser.add_argument( + "--num-groups-to-merge", type=int, default=1, help="Groups to merge" + ) + parser.add_argument( + "--double-smem-buffer", + type=str, + 
default=None, + help="Double SMEM buffer (true/false)", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Map variant strings to enums + variant_map = { + "forward": GroupedConvVariant.FORWARD, + "bwd_data": GroupedConvVariant.BACKWARD_DATA, + "bwd_weight": GroupedConvVariant.BACKWARD_WEIGHT, + } + requested_variants = [variant_map[v] for v in args.variant] + + # Check if user specified custom configuration + custom_config = ( + args.tile_m is not None or args.tile_n is not None or args.pipeline is not None + ) + + if custom_config: + # Build custom config from CLI arguments + tile = TileConfig( + tile_m=args.tile_m or 128, + tile_n=args.tile_n or 128, + tile_k=args.tile_k or 64, + warp_m=args.warp_m or 2, + warp_n=args.warp_n or 2, + warp_k=args.warp_k or 1, + warp_tile_m=args.warp_tile_m or 32, + warp_tile_n=args.warp_tile_n or 32, + warp_tile_k=args.warp_tile_k or 16, + ) + pipeline = args.pipeline or "compv4" + # Determine double_smem_buffer: use CLI arg if given, else default based on pipeline + if args.double_smem_buffer is not None: + dsb = args.double_smem_buffer.lower() == "true" + else: + dsb = pipeline == "compv4" # compv4 requires double buffer + + trait = GroupedConvTraitConfig( + pipeline=pipeline, + scheduler=args.scheduler or "intrawave", + epilogue=args.epilogue or "cshuffle", + pad_m=args.pad_m, + pad_n=args.pad_n, + pad_k=args.pad_k, + double_smem_buffer=dsb, + num_groups_to_merge=args.num_groups_to_merge, + ) + config = GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=requested_variants[0] + if requested_variants + else GroupedConvVariant.FORWARD, + ndim_spatial=args.ndim[0] if args.ndim else 2, + arch=args.arch, + vector_size_a=args.vector_a, + vector_size_b=args.vector_b, + vector_size_c=args.vector_c, + block_per_cu=args.block_per_cu, + num_wave_groups=args.num_wave_groups, + ) + filtered_configs = [config] + else: + # Get predefined configurations for target arch with requested variants and ndims + filtered_configs = get_default_configs( + arch=args.arch, variants=requested_variants, ndims=args.ndim + ) + + if args.list_configs: + print(f"Grouped convolution configurations for {args.arch}:") + print(f" Datatypes: {args.datatype}") + print(f" Variants: {args.variant}") + print(f" Spatial dims: {args.ndim}") + print(f"\nConfigurations ({len(filtered_configs)}):") + for cfg in filtered_configs: + print(f" - {cfg.name('fp16')}") + print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}") + print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}") + print( + f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}" + ) + print( + f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}" + ) + print( + f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}" + ) + return + + # Generate + codegen = UnifiedGroupedConvCodegen( + output_dir=args.output, + gpu_target=args.arch, + enable_arch_filter=True, + ) + results = codegen.generate_all( + configs=filtered_configs, datatypes=args.datatype, parallel=True + ) + + print( + f"\nGenerated {len(results['kernels'])} grouped convolution kernel files " + f"for {args.arch} in {args.output}" + ) + if results["failed"]: + print(f" Failed: {len(results['failed'])}") + for err in results["failed"][:5]: + print(f" - {err}") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/examples/CMakeLists.txt b/dispatcher/examples/CMakeLists.txt index 
0359eb0d8d..ab094e90cf 100644 --- a/dispatcher/examples/CMakeLists.txt +++ b/dispatcher/examples/CMakeLists.txt @@ -187,7 +187,6 @@ function(add_gpu_example NAME SOURCE KERNEL_HEADER) if(HEADER_NAME STREQUAL "register_all_kernels.hpp") # Registration header - examples include it directly target_compile_options(${NAME} PRIVATE - -DGEMM_KERNEL_AVAILABLE=1 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal @@ -315,6 +314,7 @@ function(add_declarative_gpu_example NAME SOURCE) target_include_directories(${NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CMAKE_CURRENT_SOURCE_DIR}/../.. ${EXAMPLE_KERNEL_DIR} ${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers ) @@ -322,7 +322,6 @@ function(add_declarative_gpu_example NAME SOURCE) # Force-include the generated registration header target_compile_options(${NAME} PRIVATE -include ${EXAMPLE_HEADER} - -DGEMM_KERNEL_AVAILABLE=1 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal @@ -345,6 +344,56 @@ add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_v add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp) add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp) add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp) +add_declarative_gpu_example(gemm_07_gfx950_minimal gemm/cpp/07_gfx950_minimal.cpp) + +# ML Heuristic example -- requires LightGBM shared library +# Derive site-packages from active Python interpreter (respects virtualenvs) +find_package(Python3 COMPONENTS Interpreter) + +set(LIGHTGBM_SEARCH_PATHS) +if(Python3_FOUND AND Python3_EXECUTABLE) + execute_process( + COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_path('purelib'))" + OUTPUT_VARIABLE PYTHON_SITE_PACKAGES + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if(PYTHON_SITE_PACKAGES) + list(APPEND LIGHTGBM_SEARCH_PATHS "${PYTHON_SITE_PACKAGES}/lightgbm/lib") + endif() +endif() + +# Fallback to common Python 3.x site-packages if auto-detection failed +if(NOT PYTHON_SITE_PACKAGES) + list(APPEND LIGHTGBM_SEARCH_PATHS + "$ENV{HOME}/.local/lib/python3.12/site-packages/lightgbm/lib" + ) +endif() + +find_library(LIGHTGBM_LIB NAMES LightGBM lib_lightgbm _lightgbm + HINTS ${CMAKE_PREFIX_PATH} + PATHS ${LIGHTGBM_SEARCH_PATHS} + NO_DEFAULT_PATH + DOC "LightGBM shared library for ML heuristics" +) + +# Fallback: search default paths (respects LightGBM_DIR if set by user) +if(NOT LIGHTGBM_LIB) + find_library(LIGHTGBM_LIB NAMES LightGBM lib_lightgbm) +endif() + +if(LIGHTGBM_LIB) + add_declarative_gpu_example(gemm_09_ml_heuristic gemm/cpp/09_ml_heuristic.cpp) + target_link_libraries(gemm_09_ml_heuristic PRIVATE ${LIGHTGBM_LIB}) + message(STATUS "LightGBM found: ${LIGHTGBM_LIB} -- building gemm_09_ml_heuristic") +else() + message(STATUS "LightGBM not found -- skipping gemm_09_ml_heuristic") + message(STATUS " To enable ML heuristic example:") + message(STATUS " 1. Activate virtualenv: source .venv/bin/activate") + message(STATUS " 2. Install: pip install -r ../requirements-ml.txt") + message(STATUS " 3. 
Reconfigure: cmake ..") + message(STATUS " Or set CMAKE_PREFIX_PATH or LightGBM_DIR to LightGBM location") +endif() # ============================================================================= # GEMM Python Library - Single Fallback Kernel @@ -394,19 +443,79 @@ if(hip_FOUND) endif() add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel) +# ============================================================================= +# Grouped Convolution C++ Examples +# ============================================================================= + +add_declarative_gpu_example(grouped_conv_01_basic grouped_conv/cpp/01_basic_grouped_conv.cpp) +add_declarative_gpu_example(grouped_conv_02_all_dirs grouped_conv/cpp/02_all_directions.cpp) +add_declarative_gpu_example(grouped_conv_03_bench_val grouped_conv/cpp/03_benchmark_validation.cpp) +add_declarative_gpu_example(grouped_conv_04_registry_json grouped_conv/cpp/04_registry_json.cpp) +add_declarative_gpu_example(grouped_conv_05_bwd_data grouped_conv/cpp/05_bwd_data.cpp) +add_declarative_gpu_example(grouped_conv_06_bwd_weight grouped_conv/cpp/06_bwd_weight.cpp) +add_declarative_gpu_example(grouped_conv_07_benchmark grouped_conv/cpp/07_multi_tile_benchmark.cpp) + +# ============================================================================= +# Grouped Convolution Python Library - Multi-Kernel (fwd/bwd_data/bwd_weight x 2D/3D) +# ============================================================================= + +# Kernel output directory for the Python conv library +set(CONV_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/conv_python_fallback") +set(CONV_DISPATCH_HEADER "${CONV_FALLBACK_KERNEL_DIR}/conv_python_dispatch.hpp") + +# Generate ALL conv kernels (fwd/bwd_data/bwd_weight x 2D/3D x multiple tile configs) +# then create the dispatch header with 2D/3D aliases +add_custom_command( + OUTPUT ${CONV_DISPATCH_HEADER} + COMMAND ${CMAKE_COMMAND} -E make_directory ${CONV_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_grouped_conv_codegen.py + --variant forward bwd_data bwd_weight --ndim 2 3 + --datatype fp16 --arch ${GPU_TARGET} + --output ${CONV_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/generate_conv_dispatch_header.py + --kernel-dir ${CONV_FALLBACK_KERNEL_DIR} + --output ${CONV_DISPATCH_HEADER} + COMMENT "Generating conv kernels (fwd/bwd_data/bwd_weight x 2D/3D) for Python library..." 
+ VERBATIM +) + +add_custom_target(generate_conv_fallback_kernels DEPENDS ${CONV_DISPATCH_HEADER}) + +# Conv dynamic library for Python (all 6 kernel variants) +add_library(dispatcher_conv_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/conv_ctypes_lib.cpp) +target_link_libraries(dispatcher_conv_lib PRIVATE ck_tile_dispatcher) +target_include_directories(dispatcher_conv_lib PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CONV_FALLBACK_KERNEL_DIR} +) +target_compile_options(dispatcher_conv_lib PRIVATE + -include ${CONV_DISPATCH_HEADER} + -DGFX_ARCH="${GPU_TARGET}" + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +if(hip_FOUND) + target_link_libraries(dispatcher_conv_lib PRIVATE hip::device hip::host) +endif() +add_dependencies(dispatcher_conv_lib generate_conv_fallback_kernels) + message(STATUS "GEMM examples configured - kernels will be generated during 'make'") +message(STATUS "Grouped Conv examples configured - kernels will be generated during 'make'") # Convenience target to build all Python ctypes libraries add_custom_target(python_libs - DEPENDS dispatcher_gemm_lib - COMMENT "Building Python ctypes libraries (GEMM)" + DEPENDS dispatcher_gemm_lib dispatcher_conv_lib + COMMENT "Building Python ctypes libraries (GEMM + Conv)" ) # ============================================================================= # Per-Architecture Kernel Generation Targets # ============================================================================= -set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030) +set(SUPPORTED_GPU_ARCHS gfx942 gfx950 gfx90a gfx1100 gfx1030) foreach(ARCH ${SUPPORTED_GPU_ARCHS}) # GEMM kernels for this arch diff --git a/dispatcher/examples/README.md b/dispatcher/examples/README.md index fdee9c3583..24bea821ba 100644 --- a/dispatcher/examples/README.md +++ b/dispatcher/examples/README.md @@ -1,8 +1,6 @@ # CK Tile Dispatcher Examples -Comprehensive examples for GEMM operations with GPU execution. - -> **Note**: Convolution examples have been moved to `ck-2/conv_archive/` for reference. +Comprehensive examples for GEMM and Grouped Convolution operations with GPU execution. --- @@ -60,11 +58,11 @@ python3 examples/gemm/python/08_heuristics.py ``` examples/ -├── gemm/ -│ ├── cpp/ # 6 C++ GEMM examples -│ └── python/ # 11 Python GEMM examples -│ -└── README.md +|---- gemm/ +| |---- cpp/ # 6 C++ GEMM examples +| +---- python/ # 11 Python GEMM examples +| ++---- README.md ``` --- @@ -201,10 +199,31 @@ rocminfo | grep "Name:" --- -## Archived Examples +## Grouped Convolution -Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`: -- `examples/conv/cpp/` - 11 C++ convolution examples -- `examples/conv/python/` - 14 Python convolution examples +Grouped convolution support has been re-introduced with a unified infrastructure shared with GEMM. -See the archive for convolution functionality reference. 
+### Infrastructure + +The grouped convolution code generation, utilities, and build scripts are available: + +| Component | Location | +|-----------|----------| +| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` | +| Python Codegen | `codegen/unified_grouped_conv_codegen.py` | +| Python Utils | `python/grouped_conv_utils.py` | +| Build Script | `scripts/compile_grouped_conv_examples.py` | + +### Building Grouped Conv Kernels + +```bash +# Generate grouped conv kernels +python3 codegen/unified_grouped_conv_codegen.py \ + --output-dir build/generated_kernels \ + --datatype fp16 --variant forward --ndim-spatial 2 + +# Compile a grouped conv example +python3 scripts/compile_grouped_conv_examples.py my_grouped_conv_example.cpp +``` + +See the [main README](../README.md#grouped-convolution-support) for more details. diff --git a/dispatcher/examples/gemm/cpp/02_multi_size.cpp b/dispatcher/examples/gemm/cpp/02_multi_size.cpp index 5e620209f4..ffd2858be4 100644 --- a/dispatcher/examples/gemm/cpp/02_multi_size.cpp +++ b/dispatcher/examples/gemm/cpp/02_multi_size.cpp @@ -21,9 +21,9 @@ * - pipeline: "compv3" -> 1 option (compv4 requires special handling) * - scheduler: "intrawave" -> 1 option * - * Raw expansion: 3 × 2 = 6 configs, but arch filter validates each: - * - tile_m must be divisible by (warp_m × warp_tile_m) - * - tile_n must be divisible by (warp_n × warp_tile_n) + * Raw expansion: 3 x 2 = 6 configs, but arch filter validates each: + * - tile_m must be divisible by (warp_m x warp_tile_m) + * - tile_n must be divisible by (warp_n x warp_tile_n) * - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16) * Result: 4 valid wildcard kernels + 1 explicit = 5 total * @@ -70,13 +70,13 @@ DECL_KERNEL_SET(multi_size_kernels, .add(Signature().dtype("fp16").layout("rcr"), Algorithm() .tile(64, 64, 64) - .wave(ANY_INT, ANY_INT, 1) // ANY_INT → (1,4,1), (2,2,1), (4,1,1) - .warp(-1, -1, -1) // -1 same as ANY_INT → (16,16,32), (32,32,16) - .pipeline("*") // "*" → valid pipelines - .scheduler("*") // "*" → valid schedulers + .wave(ANY_INT, ANY_INT, 1) // ANY_INT -> (1,4,1), (2,2,1), (4,1,1) + .warp(-1, -1, -1) // -1 same as ANY_INT -> (16,16,32), (32,32,16) + .pipeline("*") // "*" -> valid pipelines + .scheduler("*") // "*" -> valid schedulers .epilogue("cshuffle"), "gfx942")); -// Raw: 3×2=6, arch filter removes 2 invalid → 4 valid kernels +// Raw: 3x2=6, arch filter removes 2 invalid -> 4 valid kernels // ============================================================================= // MAIN @@ -116,8 +116,8 @@ int main(int argc, char* argv[]) .pipeline("*") -> expands to valid pipelines = 1 .scheduler("*") -> expands to valid schedulers = 1 - Expanded: 3 × 2 = 6 configs, but arch filter validates each: - - wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64 + Expanded: 3 x 2 = 6 configs, but arch filter validates each: + - wave x warp must divide tile: (4,1,1)x(32,32,16) invalid for 64x64 - Result: 4 valid kernels from wildcard + 1 explicit = 5 total )"; diff --git a/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp b/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp new file mode 100644 index 0000000000..7e62ad2e4f --- /dev/null +++ b/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp @@ -0,0 +1,191 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * Example 07: Minimal gfx950 (CDNA4 / MI350) GEMM + * + * Demonstrates the dispatcher working with gfx950-specific kernels: + * + * - fp16 GEMM with standard tile configs + * - fp8 GEMM with gfx950-extended warp tiles (16x16x128) + * - 160KB LDS: gfx950 raises the LDS from 64KB to 160KB (2.5x) + * + * Build: cd dispatcher/build && cmake .. -DGPU_TARGETS=gfx950 && make gemm_07_gfx950_minimal + */ + +#include <algorithm> +#include <cmath> +#include <iomanip> +#include <iostream> +#include <string> +#include <vector> + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// gfx950-targeted kernel declarations +// ============================================================================= + +DECL_KERNEL_SET(gfx950_gemm_kernels, + + // fp16 128x128x32 -- bread-and-butter config, works on all CDNA + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950") + + // fp16 128x128x64 -- deeper K tile using more LDS + // LDS usage: 128*64*2 + 128*64*2 = 32768 bytes (fits 64KB, gfx950 has 160KB) + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950") + + // fp16 64x64x32 -- small-tile variant for small problems + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 07: gfx950 Minimal GEMM", + "Demonstrates gfx950 (CDNA4 / MI350) dispatcher"); + args.add_flag("--list", "List registered kernels"); + args.add_flag("--list-verbose", "List registered kernels with full details"); + args.add_option("--M", "1024", "Problem M dimension"); + args.add_option("--N", "1024", "Problem N dimension"); + args.add_option("--K", "1024", "Problem K dimension"); + args.add_option("--arch", "gfx950", "GPU architecture (default: gfx950)"); + + if(!args.parse(argc, argv)) + return 0; + + std::string gfx_arch = args.get("--arch", "gfx950"); + + print_header("Example 07: gfx950 (CDNA4) Minimal GEMM"); + + // ========================================================================= + // Architecture info + // ========================================================================= + std::cout << "\ngfx950 (CDNA4 / MI350) highlights:\n"; + std::cout << " - 160KB LDS (up from 64KB on gfx942)\n"; + std::cout << " - Extended FP8 warp tiles: 16x16x128, 32x32x64\n"; + std::cout << " - Packed FP4 support (pk_fp4)\n"; + std::cout << " - Same warp configs as gfx942: [1,4,1], [2,2,1], [4,1,1]\n\n"; + + // ========================================================================= + // Register kernels + // ========================================================================= + std::cout << "Registering kernels for " << gfx_arch << "...\n";
+ + Registry registry; + registry.set_name("gfx950_gemm"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + if(registry.size() == 0) + { + std::cerr << "ERROR: No kernels registered for " << gfx_arch << "!\n"; + std::cerr << " Did you build with -DGPU_TARGETS=gfx950?\n"; + return 1; + } + + // ========================================================================= + // Create Dispatcher + // ========================================================================= + Dispatcher dispatcher(&registry); + + // ========================================================================= + // Setup Problem + // ========================================================================= + const int M = args.get_int("--M", 1024); + const int N = args.get_int("--N", 1024); + const int K = args.get_int("--K", 1024); + + std::cout << "\nProblem: " << M << " x " << N << " x " << K << "\n"; + + Problem problem(M, N, K); + + using DataType = ck_tile::fp16_t; + GpuBuffer<DataType> a_dev(M * K); + GpuBuffer<DataType> b_dev(K * N); + GpuBuffer<DataType> c_dev(M * N); + + std::vector<DataType> a_host(M * K, DataType(1.0f)); + std::vector<DataType> b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // ========================================================================= + // Select and Run + // ========================================================================= + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "ERROR: No suitable kernel found for " << M << "x" << N << "x" << K << "\n"; + return 1; + } + std::cout << " Selected: " << selected->get_name() << "\n"; + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n"; + + // ========================================================================= + // Verify + // ========================================================================= + std::cout << "\nVerification:\n"; + std::vector<DataType> c_host(M * N); + c_dev.copy_to_host(c_host.data()); + + const float expected = static_cast<float>(K); + int errors = 0; + for(int i = 0; i < std::min(M * N, 1024); ++i) + { + if(std::abs(static_cast<float>(c_host[i]) - expected) > 0.01f * expected + 1.0f) + ++errors; + } + + bool passed = (errors == 0); + std::cout << " Expected value: " << expected << "\n"; + std::cout << " Errors (first 1024 elements): " << errors << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + print_separator(); + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp b/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp new file mode 100644 index 0000000000..cec6d1cd02 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp @@ -0,0 +1,211 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 09: ML-Based Kernel Selection (Native C++) + * + * Uses a trained LightGBM model loaded via the C API to predict TFLOPS + * for each kernel in the registry and select the best one.
The kernels + * are JIT-compiled at build time via DECL_KERNEL_SET (same as other examples). + * + * Build: cd dispatcher/build && cmake .. && make gemm_09_ml_heuristic + * Run: ./gemm_09_ml_heuristic --model <path/to/model_tflops.lgbm> + */ + +#include <chrono> +#include <iomanip> +#include <iostream> +#include <string> +#include <tuple> +#include <vector> + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" +#include "ck_tile/dispatcher/ml_heuristic.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// Multiple kernel configs for ML to choose from +DECL_KERNEL_SET(ml_kernels, + // Small tiles + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(16, 16, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 64) + .wave(2, 2, 1) + .warp(16, 16, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + // Medium tiles + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + // Large tiles + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(256, 256, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(256, 128, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 256, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 09: ML-Based Kernel Selection", + "Uses trained LightGBM model for kernel selection"); + args.add_option("--arch", "gfx942", "GPU architecture"); + args.add_option("--model", "", "Path to LightGBM model file (.lgbm)"); + args.add_option("--log_transform", "false", "Model uses log1p transform"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 09: ML-Based Kernel Selection"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + std::string model_path = args.get("--model", ""); + bool log_transform = (args.get("--log_transform", "false") == "true"); + + if(model_path.empty()) + { + std::cerr << "Error: --model is required" << std::endl; + std::cerr << "Usage: ./gemm_09_ml_heuristic --model path/to/model_tflops.lgbm" << std::endl; + return 1; + } + + // Setup Registry (kernels are JIT compiled from DECL_KERNEL_SET above) + Registry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << "Registry: " << registry.size() << " kernel(s)" << std::endl; + + // Load ML model and create heuristic + HardwareProfile hw; + MLHeuristic
ml_heuristic(model_path, &registry, hw, log_transform); + if(!ml_heuristic.is_loaded()) + { + std::cerr << "Failed to load model. Exiting." << std::endl; + return 1; + } + + // Wire ML heuristic into dispatcher + Dispatcher dispatcher(&registry); + dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic); + dispatcher.set_heuristic([&ml_heuristic](const Problem& p) { return ml_heuristic(p); }); + + std::cout << "Strategy: ML Heuristic (LightGBM)" << std::endl; + + // Test with different problem sizes + using DataType = ck_tile::fp16_t; + std::vector<std::tuple<int, int, int>> sizes = { + {128, 128, 64}, + {512, 512, 256}, + {1024, 1024, 512}, + {2048, 2048, 1024}, + }; + + std::cout << std::endl + << std::setw(20) << "Shape" << std::setw(30) << "Selected Kernel" << std::setw(15) + << "Pred TFLOPS" << std::setw(12) << "Select ms" << std::setw(10) << "Status" + << std::endl; + std::cout << std::string(87, '-') << std::endl; + + bool all_passed = true; + + for(const auto& [M, N, K] : sizes) + { + Problem problem; + problem.M = M; + problem.N = N; + problem.K = K; + problem.k_batch = 1; + + auto t0 = std::chrono::high_resolution_clock::now(); + auto kernel = dispatcher.select_kernel(problem); + auto t1 = std::chrono::high_resolution_clock::now(); + double select_ms = std::chrono::duration<double, std::milli>(t1 - t0).count(); + + std::string size_str = + std::to_string(M) + "x" + std::to_string(N) + "x" + std::to_string(K); + + if(!kernel) + { + std::cout << std::setw(20) << size_str << std::setw(30) << "NONE" << std::setw(15) + << "N/A" << std::setw(12) << std::fixed << std::setprecision(2) << select_ms + << std::setw(10) << "FAIL" << std::endl; + all_passed = false; + continue; + } + + double pred = ml_heuristic.predict_tflops(problem, kernel->get_key()); + std::string name = kernel->get_key().encode_identifier(); + if(name.length() > 27) + name = name.substr(0, 27) + ".."; + + std::cout << std::setw(20) << size_str << std::setw(30) << name << std::setw(15) + << std::fixed << std::setprecision(2) << pred << std::setw(12) + << std::setprecision(2) << select_ms << std::setw(10) << "OK" << std::endl; + } + + std::cout << std::endl + << (all_passed ? "*** ALL TESTS PASSED ***" : "*** SOME TESTS FAILED ***") + << std::endl; + + return all_passed ?
0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/README.md b/dispatcher/examples/gemm/cpp/README.md index 1d81a90a0e..79d60d1198 100644 --- a/dispatcher/examples/gemm/cpp/README.md +++ b/dispatcher/examples/gemm/cpp/README.md @@ -29,14 +29,14 @@ cd examples ## Examples -| Example | Description | Complexity | -|---------|-------------|------------| -| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ | -| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ | -| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ | -| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ | -| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ | -| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ | +| Example | Description | +|---------|-------------| +| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | +| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | +| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | +| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | +| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | +| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ## Example Details @@ -225,5 +225,5 @@ DECL_KERNEL_SET(my_kernels, ## Related Documentation - [Python GEMM Examples](../python/README.md) -- [Convolution Examples](../../conv/cpp/README.md) +- [C++ Headers (GEMM + Grouped Conv)](../../../include/ck_tile/dispatcher/README.md) - [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/01_basic_gemm.py b/dispatcher/examples/gemm/python/01_basic_gemm.py index 93a78d24d1..8c23da89e2 100644 --- a/dispatcher/examples/gemm/python/01_basic_gemm.py +++ b/dispatcher/examples/gemm/python/01_basic_gemm.py @@ -7,41 +7,37 @@ Example 01: Basic GEMM with Multiple Kernels Demonstrates: -1. Declaring multiple kernel configurations -2. Printing all registered kernels -3. Running each kernel and validating output +1. Building a Registry with multiple kernel configurations +2. Parallel JIT compilation via registry.build() +3. Running each kernel and validating output against NumPy reference 4. 
Comparing performance across kernels -Complexity: ★★☆☆☆ - Usage: python3 01_basic_gemm.py - python3 01_basic_gemm.py --help python3 01_basic_gemm.py --dtype bf16 python3 01_basic_gemm.py --size 2048 + python3 01_basic_gemm.py --num-kernels 4 + python3 01_basic_gemm.py --workers 4 """ import sys +import time import argparse from pathlib import Path from dataclasses import dataclass -from typing import List sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @dataclass class KernelSpec: - """Specification for a kernel configuration""" - name: str tile_m: int tile_n: int @@ -50,80 +46,37 @@ class KernelSpec: scheduler: str = "intrawave" -# Define multiple kernel configurations to test (50+ kernels) KERNEL_SPECS = [ - # Small tiles - compv3 + # Small tiles KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"), KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"), - # Small tiles - compv4 KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"), - KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"), - # Medium tiles - compv3 + # Medium tiles KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"), KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"), - KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"), - # Medium tiles - compv4 KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"), KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"), - KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"), - # Rectangular tiles - compv3 + # Rectangular tiles KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"), KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"), KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"), KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"), - # Rectangular tiles - compv4 KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"), - KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"), KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"), - KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"), - # Large tiles - compv3 + # Large tiles KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"), - KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"), KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"), - KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"), KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"), - KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"), - # Large tiles - compv4 KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"), - KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"), - KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"), - KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"), KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"), - KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"), - # Interwave scheduler variants - KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"), + # Interwave scheduler KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"), - KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"), KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"), - # More tile_k variations - compv3 - KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"), - KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"), - KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"), - # More tile_k variations - compv4 - KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"), - 
KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"), - KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"), - # Additional rectangular - KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"), - KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"), - KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"), - KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"), - # Additional compv4 variants - KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"), - KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"), - KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"), - KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"), ] -def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: - """Create a KernelConfig from a spec""" - # Adjust warp tiles based on tile size - if spec.tile_m <= 64: - warp_m, warp_n = 16, 16 - else: - warp_m, warp_n = 32, 32 - +def spec_to_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + warp_m, warp_n = (16, 16) if spec.tile_m <= 64 else (32, 32) return KernelConfig( dtype_a=dtype, dtype_b=dtype, @@ -148,180 +101,118 @@ def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfi ) -def print_kernel_table(specs: List[KernelSpec], dtype: str): - """Print a formatted table of kernel configurations""" - print("\n" + "=" * 70) - print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") - print("=" * 70) - print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") - print(" " + "-" * 68) - - for i, spec in enumerate(specs, 1): - tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" - print( - f" {i:<3} {spec.name:<18} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}" - ) - - print(" " + "-" * 68) - print(f" Data type: {dtype}") - - def main(): - parser = argparse.ArgumentParser( - description="Basic GEMM Example with Multiple Kernels", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - python3 01_basic_gemm.py # Default FP16 with 4 kernels - python3 01_basic_gemm.py --dtype bf16 # BF16 mode - python3 01_basic_gemm.py --size 2048 # Larger problem size - python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels - """, - ) + parser = argparse.ArgumentParser(description="Basic GEMM with Multiple Kernels") + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--size", type=int, default=512, help="Problem size MxNxK") + parser.add_argument("--num-kernels", type=int, default=0, help="0 = all") parser.add_argument( - "--dtype", - default="fp16", - choices=["fp16", "bf16", "fp32"], - help="Data type (default: fp16)", - ) - parser.add_argument( - "--arch", - default="gfx942", - help="Target architecture (default: gfx942)", - ) - parser.add_argument( - "--size", - type=int, - default=512, - help="Problem size MxNxK (default: 512)", - ) - parser.add_argument( - "--num-kernels", - type=int, - default=0, - help="Number of kernels to test (0 = all)", + "--workers", type=int, default=0, help="Max parallel JIT workers (0 = auto)" ) args = parser.parse_args() - reset_for_example() - print("=" * 70) print("Example 01: Basic GEMM with Multiple Kernels") print("=" * 70) - # Select kernels to test specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS - # ========================================================================= - # Step 1: Print all kernel configurations - # 
========================================================================= - print_kernel_table(specs, args.dtype) - - # ========================================================================= - # Step 2: Setup and test each kernel - # ========================================================================= - print("\n" + "=" * 70) - print(" RUNNING KERNELS") - print("=" * 70) - - np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 - M, N, K = args.size, args.size, args.size - - results = [] - - print(f"\n Problem size: {M}x{N}x{K}\n") + # Step 1: Build registry print( - f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}" + f"\n {len(specs)} kernel configurations, dtype={args.dtype}, arch={args.arch}" ) - print(" " + "-" * 78) - - for i, spec in enumerate(specs, 1): - # Create unique test data per kernel - np.random.seed(42 + i * 1000) - A = (np.random.randn(M, K) * 0.1).astype(np_dtype) - B = (np.random.randn(K, N) * 0.1).astype(np_dtype) - - # Create config and setup dispatcher - config = create_kernel_config(spec, args.dtype, args.arch) - - setup = setup_gemm_dispatcher( - config=config, - registry_name=f"kernel_{spec.name}", - verbose=False, - auto_rebuild=True, + print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") + print(" " + "-" * 64) + for i, s in enumerate(specs, 1): + print( + f" {i:<3} {s.name:<22} {s.tile_m}x{s.tile_n}x{s.tile_k:<6} {s.pipeline:<10} {s.scheduler:<12}" ) + reg = Registry(name="basic_gemm") + for s in specs: + reg.register_kernel(spec_to_config(s, args.dtype, args.arch)) + + # Step 2: Parallel JIT build via registry.build() + workers = args.workers if args.workers > 0 else None + print( + f"\n--- Parallel JIT Build ({len(specs)} kernels{f', workers={workers}' if workers else ''}) ---" + ) + + t0 = time.perf_counter() + setups = reg.build(verbose=False, max_workers=workers) + jit_build_s = time.perf_counter() - t0 + + built = sum(1 for s in setups if s.success) + print(f" Built: {built}/{len(specs)} kernels in {jit_build_s:.1f} s") + + if built == 0: + print(" ERROR: No kernels built") + return 1 + + # Step 3: Run each kernel and validate + print(f"\n--- Running Kernels (problem {args.size}x{args.size}x{args.size}) ---") + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + M = N = K = args.size + + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + + print( + f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':<6}" + ) + print(" " + "-" * 80) + + results = [] + for i, (spec, setup) in enumerate(zip(specs, setups), 1): tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" if not setup.success: print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}" ) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - dispatcher = setup.dispatcher - - # Check if size is supported - if not dispatcher.is_supported(M, N, K): + disp = setup.dispatcher + if not disp.is_supported(M, N, K): print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}" 
) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - # Run GEMM - result = dispatcher.run(A, B, M, N, K) - - if not result.success: + res = disp.run(A, B, M, N, K) + if not res.success: print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'FAIL':<6}" ) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - # Validate against NumPy reference - C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) - max_err = np.max(np.abs(result.output - C_ref)) - - # Check if within tolerance - passed = max_err < 1e-2 - status = "PASS" if passed else "FAIL" - + max_err = float(np.max(np.abs(res.output - C_ref))) + ok = max_err < 1e-2 + tag = "PASS" if ok else "FAIL" print( - f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {res.time_ms:>10.4f} {res.tflops:>10.2f} {max_err:>10.2e} {tag:<6}" ) - results.append((spec.name, passed, result.time_ms, result.tflops, max_err)) - - cleanup_gemm() - - # ========================================================================= - # Step 3: Summary - # ========================================================================= - print("\n" + "=" * 70) - print(" SUMMARY") - print("=" * 70) + results.append((spec.name, ok, res.time_ms, res.tflops, max_err)) + # Step 4: Summary passed = sum(1 for r in results if r[1]) failed = len(results) - passed + valid = [r for r in results if r[1]] - print(f"\n Results: {passed}/{len(results)} kernels passed") - print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") - - if results: - valid_results = [r for r in results if r[1]] - if valid_results: - best = max(valid_results, key=lambda x: x[3]) - print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)") - - if failed == 0: - print("\n *** ALL KERNELS PASSED ***") - else: - print(f"\n *** {failed} KERNELS FAILED ***") - + print("\n" + "=" * 70) + print(f" Results: {passed}/{len(results)} passed") + print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") + print(f" JIT time: {jit_build_s:.1f} s (parallel)") + if valid: + best = max(valid, key=lambda x: x[3]) + print(f" Best: {best[0]} ({best[3]:.2f} TFLOPS)") + print(f" Status: {'PASS' if failed == 0 else 'FAIL'}") print("=" * 70) return 0 if failed == 0 else 1 diff --git a/dispatcher/examples/gemm/python/02_batch_gemm.py b/dispatcher/examples/gemm/python/02_batch_gemm.py index 039aba2790..745ec1c494 100644 --- a/dispatcher/examples/gemm/python/02_batch_gemm.py +++ b/dispatcher/examples/gemm/python/02_batch_gemm.py @@ -6,9 +6,7 @@ """ Example 02: Batch GEMM -Runs multiple GEMM operations with different sizes. - -Complexity: ★★☆☆☆ +Runs multiple GEMM operations with different sizes using JIT compilation. 
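All of the reworked Python examples in this diff follow the same flow: register one or more `KernelConfig`s on a `Registry`, JIT-build them with `Registry.build()`, and run through the returned dispatcher. The following is only a rough sketch using names that appear in this diff (`Registry`, `KernelConfig`, `detect_gpu_arch`); the exact `KernelConfig` field list is abbreviated and should be treated as illustrative.

```python
import numpy as np
from ctypes_utils import KernelConfig, Registry, detect_gpu_arch

# One fp16 128x128x32 kernel; remaining KernelConfig fields left at defaults.
config = KernelConfig(
    dtype_a="fp16", dtype_b="fp16", dtype_c="fp16",
    tile_m=128, tile_n=128, tile_k=32,
    gfx_arch=detect_gpu_arch(),
)

reg = Registry(name="sketch")
reg.register_kernel(config)

setups = reg.build(verbose=True)       # JIT-compiles every registered config
ok = [s for s in setups if s.success]  # one setup object per config
dispatcher = ok[0].dispatcher

M = N = K = 512
A = (np.random.randn(M, K) * 0.1).astype(np.float16)
B = (np.random.randn(K, N) * 0.1).astype(np.float16)
res = dispatcher.run(A, B, M, N, K)    # res.output, res.time_ms, res.tflops
```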
Usage: python3 02_batch_gemm.py @@ -25,9 +23,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -55,20 +52,20 @@ Examples: help="Maximum problem size (default: 4096)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 02: Batch GEMM") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher + # Step 1: JIT build dispatcher # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -80,19 +77,22 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="batch_gemm") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Run batch of different sizes # ========================================================================= print("\nStep 2: Run Batch") - # Generate sizes up to max_size all_sizes = [ (256, 256, 256), (512, 512, 512), @@ -135,9 +135,6 @@ Examples: avg_tflops = (total_ops / 1e12) / (total_time / 1000) print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS") - # Cleanup - cleanup_gemm() - print("\n" + "=" * 60) print("Batch GEMM complete!") print("=" * 60) diff --git a/dispatcher/examples/gemm/python/03_benchmark.py b/dispatcher/examples/gemm/python/03_benchmark.py index bec1b7e2fb..508b3f8b35 100644 --- a/dispatcher/examples/gemm/python/03_benchmark.py +++ b/dispatcher/examples/gemm/python/03_benchmark.py @@ -6,9 +6,8 @@ """ Example 03: Benchmark -Performance benchmarking with compute-optimized kernel configuration. - -Complexity: ★★★☆☆ +Performance benchmarking with compute-optimized kernel configuration +using JIT compilation. 
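For reference, the TFLOPS figures these benchmarks print (e.g. the `avg_tflops = (total_ops / 1e12) / (total_time / 1000)` line above) follow from the standard 2*M*N*K operation count for a GEMM. A small, hypothetical helper mirroring that arithmetic:

```python
# Hypothetical helper mirroring the TFLOPS arithmetic in these examples:
# a GEMM performs 2*M*N*K floating-point operations (one multiply and one add
# per inner-product term).
def gemm_tflops(M: int, N: int, K: int, time_ms: float) -> float:
    ops = 2.0 * M * N * K
    return (ops / 1e12) / (time_ms / 1e3)

# Example: a 4096x4096x4096 GEMM finishing in 1.0 ms
# 2 * 4096**3 ~= 1.374e11 ops  ->  ~137.4 TFLOPS
print(f"{gemm_tflops(4096, 4096, 4096, 1.0):.1f}")
```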
Usage: python3 03_benchmark.py @@ -26,9 +25,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -63,20 +61,20 @@ Examples: "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)" ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 03: Benchmark") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher with compute-optimized config + # Step 1: JIT build dispatcher with compute-optimized config # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -90,12 +88,16 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="benchmark") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Benchmark @@ -130,11 +132,9 @@ Examples: A = np.random.randn(M, K).astype(np_dtype) * 0.1 B = np.random.randn(K, N).astype(np_dtype) * 0.1 - # Warmup for _ in range(args.warmup): dispatcher.run(A, B, M, N, K) - # Benchmark times = [] for _ in range(args.iterations): result = dispatcher.run(A, B, M, N, K) @@ -150,9 +150,6 @@ Examples: f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}" ) - # Cleanup - cleanup_gemm() - # Summary print("\n" + "=" * 60) print("Summary") diff --git a/dispatcher/examples/gemm/python/04_validation.py b/dispatcher/examples/gemm/python/04_validation.py index 2fe54c53f7..d56621c3c8 100644 --- a/dispatcher/examples/gemm/python/04_validation.py +++ b/dispatcher/examples/gemm/python/04_validation.py @@ -6,9 +6,7 @@ """ Example 04: Validation -Validates GPU GEMM against NumPy reference. - -Complexity: ★★★☆☆ +Validates GPU GEMM against NumPy reference using JIT compilation. 
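The validation examples all use the same reference check: compute the product in float32 with NumPy and compare the GPU output element-wise. A self-contained sketch (the helper name is hypothetical; `atol=1e-2` matches the default tolerance used in this diff):

```python
import numpy as np

def validate_fp16_gemm(A: np.ndarray, B: np.ndarray, C_gpu: np.ndarray,
                       atol: float = 1e-2) -> bool:
    """Compare GPU fp16 output against a float32 NumPy reference."""
    C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32))
    max_err = float(np.max(np.abs(C_gpu.astype(np.float32) - C_ref)))
    return max_err < atol
```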
Usage: python3 04_validation.py @@ -26,9 +24,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, Validator, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -56,20 +53,20 @@ Examples: "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)" ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 04: Validation") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher + # Step 1: JIT build dispatcher # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -81,12 +78,16 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="validation") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Run validation tests @@ -139,9 +140,6 @@ Examples: print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED") failed += 1 - # Cleanup - cleanup_gemm() - # Summary print("\n" + "=" * 60) total = passed + failed diff --git a/dispatcher/examples/gemm/python/05_numpy_integration.py b/dispatcher/examples/gemm/python/05_numpy_integration.py index 493ce46d22..b0af5fa700 100644 --- a/dispatcher/examples/gemm/python/05_numpy_integration.py +++ b/dispatcher/examples/gemm/python/05_numpy_integration.py @@ -8,7 +8,6 @@ Example 05: NumPy Integration Shows how to create a GPU-accelerated matmul wrapper. -Complexity: ★★☆☆☆ Usage: python3 05_numpy_integration.py @@ -29,6 +28,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -70,7 +70,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/06_json_export.py b/dispatcher/examples/gemm/python/06_json_export.py index 9e062e507b..780032ce06 100644 --- a/dispatcher/examples/gemm/python/06_json_export.py +++ b/dispatcher/examples/gemm/python/06_json_export.py @@ -8,7 +8,6 @@ Example 06: JSON Export Exports registry configuration to JSON. 
-Complexity: ★★☆☆☆ Usage: python3 06_json_export.py @@ -28,6 +27,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -54,7 +54,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/07_stress_test.py b/dispatcher/examples/gemm/python/07_stress_test.py index 8160030631..620e66eeaf 100644 --- a/dispatcher/examples/gemm/python/07_stress_test.py +++ b/dispatcher/examples/gemm/python/07_stress_test.py @@ -18,7 +18,6 @@ This tests: - Multiple data types (fp16, bf16) - Different schedulers (intrawave, interwave) -Complexity: ★★★★☆ Usage: python3 07_stress_test.py @@ -43,6 +42,7 @@ from ctypes_utils import ( cleanup_gemm, reset_for_example, Validator, + detect_gpu_arch, ) @@ -413,8 +413,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target architecture (default: gfx942)", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/08_heuristics.py b/dispatcher/examples/gemm/python/08_heuristics.py index e2763c0513..acbf1b3ae0 100644 --- a/dispatcher/examples/gemm/python/08_heuristics.py +++ b/dispatcher/examples/gemm/python/08_heuristics.py @@ -19,7 +19,6 @@ Heuristic strategies: - Memory-bound: Optimize memory access for bandwidth-limited cases - Latency-focused: Minimize kernel launch overhead for small problems -Complexity: ★★★★☆ Usage: python3 08_heuristics.py @@ -43,6 +42,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -561,8 +561,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target architecture (default: gfx942)", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/09_ml_heuristic.py b/dispatcher/examples/gemm/python/09_ml_heuristic.py new file mode 100644 index 0000000000..d6726a2033 --- /dev/null +++ b/dispatcher/examples/gemm/python/09_ml_heuristic.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 09: ML-Based Kernel Selection + +Uses a trained LightGBM model to select the optimal kernel for each problem +size. The model predicts TFLOPS for every candidate in the kernel pool and +picks the highest-scoring one, which is then JIT-compiled and run. + +This replaces the hand-crafted rules in 08_heuristics.py with a data-driven +approach achieving 97-98% of oracle-best TFLOPS efficiency. 
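The "97-98% of oracle-best" figure compares the ML pick against the best kernel in the pool for each problem. The exact aggregation is not shown in this diff; one plausible per-problem definition, as a sketch:

```python
def oracle_efficiency(selected_tflops, oracle_tflops):
    """Mean ratio (in %) of achieved TFLOPS to the per-problem oracle best.

    selected_tflops[i] is the TFLOPS of the ML-selected kernel on problem i;
    oracle_tflops[i] is the best TFLOPS any pool kernel achieved on problem i.
    """
    ratios = [s / o for s, o in zip(selected_tflops, oracle_tflops) if o > 0]
    return 100.0 * sum(ratios) / len(ratios)

# e.g. selections reaching [95.0, 100.0, 98.0] TFLOPS against a per-problem
# oracle of [100.0, 100.0, 100.0] -> ~97.7% efficiency
```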
+ +Complexity: ***** + +Prerequisites: + - Trained model in dispatcher/heuristics/models/gemm_universal_fp8_gfx950/ + - lightgbm, pandas, numpy, pyarrow installed + +Usage: + python3 09_ml_heuristic.py + python3 09_ml_heuristic.py --dtype fp16 --arch gfx942 +""" + +import sys +import argparse +import time +from pathlib import Path +from dataclasses import dataclass +from typing import List + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics")) + +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, +) + +from predict import Predictor + + +@dataclass +class KernelSpec: + """Kernel specification -- same structure as 08_heuristics.py""" + + name: str + tile_m: int + tile_n: int + tile_k: int + pipeline: str = "compv3" + scheduler: str = "intrawave" + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + warp_m: int = 32 + warp_n: int = 32 + warp_k: int = 16 + + +# Kernel pool: representative configs spanning small to large tiles, +# compv3/compv4/mem pipelines, and intrawave/interwave schedulers. +KERNEL_POOL = [ + # Small tiles + KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", warp_m=16, warp_n=16), + KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", warp_m=16, warp_n=16), + KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", warp_m=16, warp_n=16), + KernelSpec("s_64x64_k32_v4", 64, 64, 32, "compv4", warp_m=16, warp_n=16), + KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", warp_m=16, warp_n=16), + KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", warp_m=16, warp_n=16), + # Medium tiles + KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3"), + KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3"), + KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3"), + KernelSpec("m_128x128_k32_v4", 128, 128, 32, "compv4"), + KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4"), + KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem"), + KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem"), + # Rectangular medium + KernelSpec("r_64x128_k32", 64, 128, 32, "compv3", warp_m=16), + KernelSpec("r_128x64_k32", 128, 64, 32, "compv3", warp_n=16), + KernelSpec("r_64x128_k64", 64, 128, 64, "compv3", warp_m=16), + KernelSpec("r_128x64_k64", 128, 64, 64, "compv3", warp_n=16), + # Large tiles + KernelSpec("l_256x128_k32", 256, 128, 32, "compv3"), + KernelSpec("l_128x256_k32", 128, 256, 32, "compv3"), + KernelSpec("l_256x256_k32", 256, 256, 32, "compv3"), + KernelSpec("l_256x256_k64", 256, 256, 64, "compv3"), + # Interwave variants + KernelSpec("m_128x128_k64_iw", 128, 128, 64, "compv3", "interwave"), + KernelSpec("m_128x128_k64_mem_iw", 128, 128, 64, "mem", "interwave"), +] + + +def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict: + """Convert a KernelSpec to the dict format the feature engine expects. + + Note: pad_m/n/k default to True to match KernelConfig defaults and actual + compiled kernels. This ensures the ML model receives the correct padding + flags that will be used during JIT compilation. 
+ """ + return { + "kernel_name": spec.name, + "tile_m": spec.tile_m, + "tile_n": spec.tile_n, + "tile_k": spec.tile_k, + "warp_m": spec.wave_m, + "warp_n": spec.wave_n, + "warp_k": spec.wave_k, + "warp_tile_m": spec.warp_m, + "warp_tile_n": spec.warp_n, + "warp_tile_k": spec.warp_k, + "pipeline": spec.pipeline, + "scheduler": spec.scheduler, + "epilogue": "cshuffle", + "pad_m": True, # Match KernelConfig default + "pad_n": True, # Match KernelConfig default + "pad_k": True, # Match KernelConfig default + "persistent": False, + "dtype": dtype, + "layout": layout, + } + + +def spec_to_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + """Convert a KernelSpec to the dispatcher's KernelConfig for JIT compilation.""" + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=spec.tile_m, + tile_n=spec.tile_n, + tile_k=spec.tile_k, + wave_m=spec.wave_m, + wave_n=spec.wave_n, + wave_k=spec.wave_k, + warp_m=spec.warp_m, + warp_n=spec.warp_n, + warp_k=spec.warp_k, + pipeline=spec.pipeline, + scheduler=spec.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +def ml_select_kernel( + predictor: Predictor, + pool: List[KernelSpec], + M: int, + N: int, + K: int, + dtype: str, + layout: str, +) -> tuple: + """Score all kernels in the pool and return (best_spec, predicted_tflops).""" + problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1} + kernel_dicts = [spec_to_feature_dict(s, dtype, layout) for s in pool] + + ranked = predictor.rank_kernels(problem, kernel_dicts) + if not ranked: + return pool[0], 0.0 + + best_name, best_tflops = ranked[0] + best_spec = next((s for s in pool if s.name == best_name), pool[0]) + return best_spec, best_tflops + + +def main(): + parser = argparse.ArgumentParser(description="ML-based kernel selection for GEMM") + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp8"]) + parser.add_argument("--arch", default="gfx942") + parser.add_argument( + "--model_dir", + default=str( + Path(__file__).parent.parent.parent.parent + / "heuristics" + / "models" + / "gemm_universal_fp8_gfx950" + ), + ) + parser.add_argument( + "--no_run", action="store_true", help="Only predict, don't run GEMMs" + ) + args = parser.parse_args() + + print("=" * 75) + print(" Example 09: ML-Based Kernel Selection") + print("=" * 75) + print(f"\n Model: {args.model_dir}") + print(f" Dtype: {args.dtype}") + print(f" Arch: {args.arch}") + print(f" Pool: {len(KERNEL_POOL)} kernels") + + predictor = Predictor(args.model_dir) + print(" Model loaded successfully") + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float16 + + test_sizes = [ + (128, 128, 64), + (256, 256, 128), + (512, 512, 256), + (1024, 1024, 512), + (2048, 2048, 1024), + ] + + header = f"{'Shape':<20} {'Selected Kernel':<25} {'Pred TFLOPS':>12}" + if not args.no_run: + header += f" {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}" + print(f"\n {header}") + print(" " + "-" * len(header)) + + results = [] + + for M, N, K in test_sizes: + t0 = time.time() + best_spec, pred_tflops = ml_select_kernel( + predictor, KERNEL_POOL, M, N, K, args.dtype, "rcr" + ) + _ = (time.time() - t0) * 1000 # ML selection time (unused) + + size_str = f"{M}x{N}x{K}" + line = f" {size_str:<20} {best_spec.name:<25} {pred_tflops:>12.2f}" + + if args.no_run: + print(line) + results.append((size_str, best_spec.name, True, 0, pred_tflops)) + continue + + config = 
spec_to_kernel_config(best_spec, args.dtype, args.arch) + + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"ml_{best_spec.name}", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + line += f" {'N/A':>10} {'N/A':>10} {'BUILD':>8}" + print(line) + results.append((size_str, best_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + dispatcher = setup.dispatcher + if not dispatcher.is_supported(M, N, K): + line += f" {'N/A':>10} {'N/A':>10} {'UNSUP':>8}" + print(line) + results.append((size_str, best_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + result = dispatcher.run(A, B, M, N, K) + + if result.success: + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype( + np_dtype + ) + max_err = np.max(np.abs(result.output - C_ref)) + passed = max_err < 1e-2 + status = "PASS" if passed else "FAIL" + line += f" {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}" + results.append( + (size_str, best_spec.name, passed, result.time_ms, result.tflops) + ) + else: + line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + results.append((size_str, best_spec.name, False, 0, 0)) + + print(line) + cleanup_gemm() + + # Summary + print("\n" + "=" * 75) + print(" SUMMARY") + print("=" * 75) + passed = sum(1 for r in results if r[2]) + print(f"\n Results: {passed}/{len(results)} tests passed") + valid = [r for r in results if r[2] and r[4] > 0] + if valid: + avg = sum(r[4] for r in valid) / len(valid) + print(f" Average TFLOPS: {avg:.2f}") + if passed == len(results): + print("\n *** ALL TESTS PASSED ***") + print("=" * 75) + return 0 if passed == len(results) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/09_multi_registry.py b/dispatcher/examples/gemm/python/09_multi_registry.py index 97cbce3497..5d9af239d4 100644 --- a/dispatcher/examples/gemm/python/09_multi_registry.py +++ b/dispatcher/examples/gemm/python/09_multi_registry.py @@ -8,7 +8,6 @@ Example 09: Multiple Registries Demonstrates multiple registries for different optimization targets. 
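For orientation, the multi-registry pattern exercised here keeps independently named kernel sets (for example, one tuned for latency and one for throughput) and builds a dispatcher per set. A rough sketch using the `setup_gemm_dispatcher`/`cleanup_gemm` helpers from this diff; the specific tile choices and the minimal `KernelConfig` field set are illustrative only:

```python
from ctypes_utils import (
    KernelConfig,
    setup_gemm_dispatcher,
    cleanup_gemm,
    detect_gpu_arch,
)

arch = detect_gpu_arch()
targets = {
    # small tiles: lower overhead for small problems
    "latency": KernelConfig(dtype_a="fp16", dtype_b="fp16", dtype_c="fp16",
                            tile_m=64, tile_n=64, tile_k=32, gfx_arch=arch),
    # large tiles: higher throughput for big problems
    "throughput": KernelConfig(dtype_a="fp16", dtype_b="fp16", dtype_c="fp16",
                               tile_m=256, tile_n=256, tile_k=32, gfx_arch=arch),
}

for name, cfg in targets.items():
    setup = setup_gemm_dispatcher(cfg, registry_name=name, verbose=False)
    if setup.success:
        print(f"registry '{name}' ready")
    cleanup_gemm()
```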
-Complexity: ★★★★★ Usage: python3 09_multi_registry.py @@ -30,6 +29,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -50,7 +50,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/10_advanced_benchmark.py b/dispatcher/examples/gemm/python/10_advanced_benchmark.py index e16e4e271f..b1462478d0 100644 --- a/dispatcher/examples/gemm/python/10_advanced_benchmark.py +++ b/dispatcher/examples/gemm/python/10_advanced_benchmark.py @@ -33,6 +33,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -69,7 +70,11 @@ def parse_args(): # Kernel configuration parser.add_argument("--dtype", default="fp16", help="Data type") parser.add_argument("--pipeline", default="compv4", help="Pipeline type") - parser.add_argument("--arch", default="gfx942", help="GPU architecture") + parser.add_argument( + "--arch", + default=detect_gpu_arch(), + help="GPU architecture (auto-detected from rocminfo)", + ) return parser.parse_args() diff --git a/dispatcher/examples/gemm/python/11_json_import.py b/dispatcher/examples/gemm/python/11_json_import.py index 06743af406..d19395e553 100644 --- a/dispatcher/examples/gemm/python/11_json_import.py +++ b/dispatcher/examples/gemm/python/11_json_import.py @@ -15,7 +15,6 @@ Key Features: - Use arch_filter validation on loaded configs - Export to C++ DECL_KERNEL_SET format -Complexity: ★★★☆☆ Usage: python3 11_json_import.py @@ -45,6 +44,7 @@ from ctypes_utils import ( # noqa: E402 cleanup_gemm, reset_for_example, validate_kernel_config, + detect_gpu_arch, ) # Sample JSON configuration (embedded for demonstration) @@ -141,8 +141,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target GPU architecture (default: gfx942)", + default=detect_gpu_arch(), + help="Target GPU architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() @@ -236,13 +236,13 @@ Examples: else: invalid_count += 1 if invalid_count <= 3: # Show first 3 invalid - print(f"\n ✗ Invalid: {config.kernel_name()}") + print(f"\n FAIL Invalid: {config.kernel_name()}") for error in result.errors: print(f" Error: {error}") print("\n Validation Summary:") - print(f" ✓ Valid: {valid_count}") - print(f" ✗ Invalid: {invalid_count}") + print(f" OK Valid: {valid_count}") + print(f" FAIL Invalid: {invalid_count}") print(f" Total: {len(configs)}") # ========================================================================= @@ -275,12 +275,12 @@ Examples: disp_config, registry_name="json_import", verbose=False ) if setup.success: - print(" ✓ Dispatcher setup successful") + print(" OK Dispatcher setup successful") print( f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}" ) else: - print(f" ⚠ Dispatcher setup: {setup.error}") + print(f" WARNING Dispatcher setup: {setup.error}") print(" (This is expected if kernels aren't generated)") # ========================================================================= diff --git a/dispatcher/examples/gemm/python/README.md b/dispatcher/examples/gemm/python/README.md index 0a83f3533f..07757b951b 100644 --- a/dispatcher/examples/gemm/python/README.md +++ b/dispatcher/examples/gemm/python/README.md @@ -295,5 +295,5 @@ 
Compilation time scales roughly linearly with kernel count. ## Related Documentation - [C++ GEMM Examples](../cpp/README.md) -- [Python Conv Examples](../../conv/python/README.md) +- [Python Utilities](../../../python/README.md) - [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp b/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp new file mode 100644 index 0000000000..b503129c57 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp @@ -0,0 +1,203 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 01: Basic Grouped Convolution +// +// Demonstrates three declaration patterns (mirrors GEMM 01): +// 1. AUTOFILL - tile + pipeline only, wave/warp auto-filled +// 2. AUTOCORRECT - invalid wave(1,1,1) corrected to valid config +// 3. FULL - all parameters explicit (matches validated gfx942 config) +// +// Then runs the forward convolution on GPU and verifies output. +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_01_basic + +#include <cstddef> +#include <iomanip> +#include <iostream> +#include <string> +#include <vector> + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Three declaration patterns -- codegen auto-fills/auto-corrects as needed +DECL_GROUPED_CONV_KERNEL_SET( + basic_conv_kernels, + // Pattern 1: AUTOFILL - only tile + pipeline, rest auto-filled + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").scheduler("intrawave"), + "gfx950") + // Pattern 2: AUTOCORRECT - wave(1,1,1) invalid, corrected to (1,4,1) + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 64, 64) + .wave(1, 1, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8), + "gfx950") + // Pattern 3: FULL - all parameters explicit (validated config) + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 01: Basic Grouped Convolution", + "Declaration patterns + GPU execution"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-g", "1", "Groups"); + args.add_option("-c", "64", "Input channels C"); + args.add_option("-k", "128", "Output channels K"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 01: Basic Grouped Convolution"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1); + int G = args.get_int("-g", 1); + int
C = args.get_int("-c", 64); + int K = args.get_int("-k", 128); + int HW = args.get_int("--size", 14); + int Y = 3, X = 3; + + // Step 1: Show declared kernel sets + std::cout << "\nStep 1: Declared Kernel Sets\n"; + GroupedConvKernelSetRegistry::instance().print(); + + // Step 2: Register kernels + std::cout << "\nStep 2: Register Kernels\n"; + GroupedConvRegistry registry; + registry.set_name("basic_conv"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + // Step 3: Create dispatcher + std::cout << "\nStep 3: Create Dispatcher\n"; + GroupedConvDispatcher dispatcher(®istry); + + // Step 4: Build problem using CK Tile ConvParam + std::cout << "\nStep 4: Problem\n"; + auto problem = create_grouped_conv2d_problem(N, C, K, HW, HW, Y, X, 1, 1); + problem.op = GroupedConvOp::Forward; + print_grouped_conv_problem(problem); + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(HW), static_cast(HW)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input_host(in_desc); + ck_tile::HostTensor weight_host(wei_desc); + ck_tile::HostTensor output_host(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight_host); + + ck_tile::DeviceMem input_dev(input_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output_host.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input_host.data()); + weight_dev.ToDevice(weight_host.data()); + + // Step 5: Select and run + std::cout << "\nStep 5: Select and Run\n"; + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No kernel found for problem!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + + double tflops = calculate_conv_tflops(problem, time_ms); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 6: Verify + std::cout << "\nStep 6: Verify\n"; + output_dev.FromDevice(output_host.data()); + + size_t total = output_host.get_element_space_size(); + size_t nonzero = 0; + double checksum = 0.0; + for(size_t i = 0; i < total; ++i) + { + float v = static_cast(output_host.data()[i]); + if(v != 0.0f) + ++nonzero; + checksum += v; + } + + bool passed = nonzero > 0; + std::cout << " Output elements: " << total << "\n"; + std::cout << " Non-zero: " << nonzero << "/" << total + << (nonzero > 0 ? 
" (kernel produced output)" : " WARNING: all zeros!") << "\n"; + std::cout << " Checksum: " << std::fixed << std::setprecision(2) << checksum << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + utils::print_separator(); + std::cout << "DECLARATION PATTERNS:\n"; + std::cout << " 1. AUTOFILL: tile + pipeline only, wave/warp auto-filled\n"; + std::cout << " 2. AUTOCORRECT: invalid wave(1,1,1) corrected\n"; + std::cout << " 3. FULL: all parameters explicit\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp b/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp new file mode 100644 index 0000000000..a2f2b9d560 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp @@ -0,0 +1,216 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 02: All Convolution Directions +// +// Forward, backward-data, and backward-weight for 2D convolution, +// each executed on GPU with non-zero verification. +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_02_all_dirs + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + conv_fwd_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950")); + +DECL_GROUPED_CONV_KERNEL_SET( + conv_bwdd_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +DECL_GROUPED_CONV_KERNEL_SET( + conv_bwdw_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 02: All Convolution Directions", + "Forward/BwdData/BwdWeight with GPU execution and verification"); + args.add_option("--arch", "gfx950", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 02: All Convolution Directions"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + + GroupedConvRegistry registry; + registry.set_name("all_directions"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + GroupedConvDispatcher dispatcher(®istry); + + const int N = 1, G = 1, C = 64, K = 128, Hi = 14, Wi = 14, Y = 3, X = 3; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = 
ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + + ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input.data()); + weight_dev.ToDevice(weight.data()); + + std::cout << "\n " << std::left << std::setw(12) << "Direction" << std::right << std::setw(10) + << "Time(ms)" << std::setw(10) << "TFLOPS" << std::setw(14) << "NonZero" + << std::setw(10) << "Status" << "\n"; + std::cout << " " << std::string(56, '-') << "\n"; + + bool all_pass = true; + + auto print_result = + [](const char* label, float time_ms, double tflops, size_t nz, size_t total, bool ok) { + std::cout << " " << std::left << std::setw(12) << label << std::right << std::fixed + << std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2) + << std::setw(10) << tflops << std::setw(14) + << (std::to_string(nz) + "/" + std::to_string(total)) << std::setw(10) + << (ok ? "OK" : "FAIL") << "\n"; + }; + + // Forward: run(X, W, Y) + { + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::Forward); + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + output_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t i = 0; i < output.get_element_space_size(); ++i) + if(static_cast(output.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; + print_result("forward", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + output.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + // Backward Data: run(dY, W, dX) + { + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData); + ck_tile::HostTensor dx_host(in_desc); + ck_tile::DeviceMem dx_dev(dx_host.get_element_space_size_in_bytes()); + float time_ms = dispatcher.run(output_dev.GetDeviceBuffer(), // dY (from forward pass) + weight_dev.GetDeviceBuffer(), // W + dx_dev.GetDeviceBuffer(), // dX (output) + problem, + nullptr); + dx_dev.FromDevice(dx_host.data()); + size_t nz = 0; + for(size_t i = 0; i < dx_host.get_element_space_size(); ++i) + if(static_cast(dx_host.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; + print_result("bwd_data", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + dx_host.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + // Backward Weight: run(X, dY, dW) + { + auto problem = create_grouped_conv2d_problem( + N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight); + ck_tile::HostTensor dw_host(wei_desc); + ck_tile::DeviceMem dw_dev(dw_host.get_element_space_size_in_bytes()); + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), // X + output_dev.GetDeviceBuffer(), // dY + 
dw_dev.GetDeviceBuffer(), // dW (output) + problem, + nullptr); + dw_dev.FromDevice(dw_host.data()); + size_t nz = 0; + for(size_t i = 0; i < dw_host.get_element_space_size(); ++i) + if(static_cast(dw_host.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; + print_result("bwd_weight", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + dw_host.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + utils::print_separator(); + std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return all_pass ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp b/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp new file mode 100644 index 0000000000..12bd87d1a4 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp @@ -0,0 +1,263 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 03: Benchmark and CPU-Reference Validation +// +// Runs a 2D grouped conv forward kernel on the GPU via dispatcher.run() +// and compares against the CK Tile host reference implementation. +// Exposes warmup/repeat/log_level as CLI args (matches example 20 pattern). +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_03_bench_val + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; +using AccDataType = float; + +DECL_GROUPED_CONV_KERNEL_SET( + bench_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950") + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 03: Benchmark & Validation", + "GPU execution with CPU reference validation"); + args.add_option("-n", "1", "Batch size N"); + args.add_option("-g", "1", "Groups G"); + args.add_option("-c", "64", "Input channels C"); + args.add_option("-k", "128", "Output channels K"); + args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("--warmup", "3", "Warmup iterations"); + args.add_option("--repeat", "10", "Benchmark iterations"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_flag("--no-verify", "Skip CPU validation"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 03: Grouped Conv Benchmark & Validation"); + + int N = args.get_int("-n", 1); + int G = args.get_int("-g", 1); + int C = args.get_int("-c", 64); + int K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14); + int Wi = Hi; + int Y = 3, X = 3; + int warmup = args.get_int("--warmup", 3); + int repeat = 
args.get_int("--repeat", 10); + bool verify = !args.has("--no-verify"); + std::string gfx_arch = args.get("--arch", "gfx950"); + + std::cout << "\nProblem: N=" << N << " G=" << G << " C=" << C << " K=" << K << " Hi=" << Hi + << " Wi=" << Wi << " Y=" << Y << " X=" << X << "\n"; + std::cout << "Benchmark: warmup=" << warmup << " repeat=" << repeat << "\n"; + + // Step 1: Setup tensors using CK Tile descriptors + std::cout << "\nStep 1: Setup tensors\n"; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output_gpu(out_desc); + ck_tile::HostTensor output_cpu(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + output_cpu.SetZero(); + + std::cout << " Input: " << input.get_element_space_size() << " elements\n"; + std::cout << " Weight: " << weight.get_element_space_size() << " elements\n"; + std::cout << " Output: " << output_gpu.get_element_space_size() << " elements\n"; + + // Step 2: CPU reference + if(verify) + { + std::cout << "\nStep 2: CPU Reference\n"; + + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_fwd<2, InDataType, WeiDataType, OutDataType>( + input, weight, output_cpu, strides_v, dilations_v, left_pads_v, right_pads_v); + + std::cout << " CPU ref[0..7]: "; + for(int i = 0; i < std::min(8, static_cast(output_cpu.get_element_space_size())); ++i) + std::cout << std::fixed << std::setprecision(4) + << static_cast(output_cpu.data()[i]) << " "; + std::cout << "\n"; + } + + // Step 3: GPU execution via dispatcher + std::cout << "\nStep 3: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + registry.set_name("bench_val"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + GroupedConvDispatcher dispatcher(®istry); + + auto problem = create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1); + problem.op = GroupedConvOp::Forward; + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output_gpu.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input.data()); + weight_dev.ToDevice(weight.data()); + + float elapsed_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + + output_dev.FromDevice(output_gpu.data()); + + size_t total 
= output_gpu.get_element_space_size(); + std::cout << " GPU out[0..7]: "; + for(int i = 0; i < std::min(8, static_cast(total)); ++i) + std::cout << std::fixed << std::setprecision(4) << static_cast(output_gpu.data()[i]) + << " "; + std::cout << "\n"; + + size_t nonzero_gpu = 0; + double gpu_sum = 0.0; + for(size_t i = 0; i < total; ++i) + { + float v = static_cast(output_gpu.data()[i]); + if(v != 0.0f) + ++nonzero_gpu; + gpu_sum += v; + } + std::cout << " GPU checksum: " << std::fixed << std::setprecision(6) << gpu_sum << "\n"; + std::cout << " GPU non-zero: " << nonzero_gpu << "/" << total + << (nonzero_gpu > 0 ? " (kernel produced output)" : " WARNING: all zeros!") << "\n"; + + int Ho = static_cast(problem.Ho()); + int Wo = static_cast(problem.Wo()); + double flops = 2.0 * G * N * K * C * Y * X * Ho * Wo; + double tflops = flops / (elapsed_ms * 1e9); + + std::cout << " Time: " << std::fixed << std::setprecision(4) << elapsed_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 4: Validation + bool passed = true; + if(verify) + { + std::cout << "\nStep 4: Validation (GPU vs CPU)\n"; + + constexpr float rtol = 1e-2f; + constexpr float atol = 1e-2f; + + float max_diff = 0.0f; + float max_rel = 0.0f; + size_t max_diff_idx = 0; + size_t num_elements = output_gpu.get_element_space_size(); + size_t mismatches = 0; + + for(size_t i = 0; i < num_elements; ++i) + { + float gpu_val = static_cast(output_gpu.data()[i]); + float cpu_val = static_cast(output_cpu.data()[i]); + float diff = std::abs(gpu_val - cpu_val); + float tol = atol + rtol * std::abs(cpu_val); + float rel = diff / (std::abs(cpu_val) + 1e-6f); + if(diff > max_diff) + { + max_diff = diff; + max_diff_idx = i; + } + max_rel = std::max(max_rel, rel); + if(diff > tol) + ++mismatches; + } + + passed = (mismatches == 0); + + std::cout << " Side-by-side at worst element [" << max_diff_idx << "]:\n"; + std::cout << " GPU: " << std::fixed << std::setprecision(6) + << static_cast(output_gpu.data()[max_diff_idx]) + << " CPU: " << static_cast(output_cpu.data()[max_diff_idx]) + << " diff: " << std::scientific << max_diff << "\n"; + std::cout << " Elements: " << num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "/" << num_elements << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_diff << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + std::cout << " Status: " << (passed ? "PASSED" : "FAILED") << "\n"; + } + + utils::print_separator(); + std::cout << "BENCHMARK & VALIDATION:\n"; + std::cout << " GPU kernel: " << (selected ? selected->name() : "none") << "\n"; + std::cout << " Performance: " << std::fixed << std::setprecision(2) << tflops + << " TFLOPS\n"; + std::cout << " CPU reference: reference_grouped_conv_fwd<2>()\n"; + std::cout << " Validation: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp b/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp new file mode 100644 index 0000000000..0e5a6d33be --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 04: Heuristic Selection + JSON Export +// +// Demonstrates runtime kernel selection with heuristic ranking, +// GPU execution, and JSON registry export. +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_04_registry_json + +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Two tile configs for heuristic selection +DECL_GROUPED_CONV_KERNEL_SET( + heuristic_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950") + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +std::vector conv_heuristic(const GroupedConvProblem& problem) +{ + int64_t spatial = problem.Ho() * problem.Wo(); + if(spatial > 400) + return {"128x128", "64x64"}; + return {"64x64", "128x128"}; +} + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 04: Heuristic + JSON", + "Runtime kernel selection and JSON export"); + args.add_option("--arch", "gfx950", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 04: Heuristic Selection + JSON Export"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + + // Step 1: Register + std::cout << "\nStep 1: Register Kernels" << std::endl; + GroupedConvRegistry registry; + registry.set_name("heuristic_conv"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)" << std::endl; + + // Step 2: Heuristic dispatcher + std::cout << "\nStep 2: Heuristic Dispatcher" << std::endl; + GroupedConvDispatcher dispatcher(®istry); + dispatcher.set_strategy(GroupedConvDispatcher::SelectionStrategy::Heuristic); + dispatcher.set_heuristic(conv_heuristic); + + // Step 3: Select kernels (no GPU yet) + std::cout << "\nStep 3: Kernel Selection" << std::endl; + + auto problem = create_grouped_conv2d_problem(1, 64, 128, 14, 14, 3, 3, 1, 1); + + auto* selected = dispatcher.select_kernel(problem); + std::cout << " Selected: " << (selected ? selected->name() : "none") << std::endl; + + // Step 4: GPU execution + std::cout << "\nStep 4: GPU Execution" << std::endl; + + ck_tile::conv::ConvParam cp{ + 2, + static_cast(1), + static_cast(1), + static_cast(128), + static_cast(64), + {static_cast(3), static_cast(3)}, + {static_cast(14), static_cast(14)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + std::cout << " Creating tensors..." 
<< std::endl; + auto in_d = ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(cp); + auto wei_d = ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(cp); + auto out_d = ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(cp); + + ck_tile::HostTensor input(in_d); + ck_tile::HostTensor weight(wei_d); + ck_tile::HostTensor output(out_d); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + + std::cout << " Allocating device memory..." << std::endl; + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes()); + in_dev.ToDevice(input.data()); + wei_dev.ToDevice(weight.data()); + + std::cout << " Launching kernel..." << std::endl; + float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer(), + problem, + nullptr); + + std::cout << " Reading back..." << std::endl; + out_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t i = 0; i < output.get_element_space_size(); ++i) + if(static_cast(output.data()[i]) != 0.0f) + ++nz; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms" + << std::endl; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_conv_tflops(problem, time_ms) + << std::endl; + std::cout << " NonZero: " << nz << "/" << output.get_element_space_size() << std::endl; + + // Step 5: JSON export + std::cout << "\nStep 5: JSON Export" << std::endl; + std::string json = registry.export_json(false); + std::cout << " JSON size: " << json.size() << " bytes" << std::endl; + + bool passed = nz > 0; + utils::print_separator(); + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp b/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp new file mode 100644 index 0000000000..35595bb14c --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp @@ -0,0 +1,183 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 05: Backward Data with CPU Reference Validation +// +// Computes dX = ConvBwdData(dY, W) on GPU via dispatcher.run() +// and validates against ck_tile::reference_grouped_conv_bwd_data. +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_05_bwd_data + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + bwd_data_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .scheduler("intrawave") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 05: Backward Data Validation", + "dX = ConvBwdData(dY, W) with CPU reference"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-c", "64", "Input channels"); + args.add_option("-k", "128", "Output channels"); + args.add_option("--size", "14", "Spatial size (H=W)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 05: Backward Data with CPU Validation"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1), G = 1; + int C = args.get_int("-c", 64), K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3; + + // Setup + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + // dY (gradient from next layer) and W (weight) are inputs; dX is output + ck_tile::HostTensor dy(out_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor dx_gpu(in_desc); + ck_tile::HostTensor dx_cpu(in_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(dy); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + dx_cpu.SetZero(); + + // CPU reference + std::cout << "\nStep 1: CPU Reference (bwd_data)\n"; + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_bwd_data<2, InDataType, WeiDataType, OutDataType>( + dx_cpu, weight, dy, strides_v, dilations_v, left_pads_v, right_pads_v); + std::cout << " CPU complete\n"; + + // GPU execution via dispatcher + std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + registry.set_name("bwd_data"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + GroupedConvDispatcher 
dispatcher(®istry); + + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData); + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No bwd_data kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes()); + ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dx_dev(dx_gpu.get_element_space_size_in_bytes()); + + dy_dev.ToDevice(dy.data()); + wei_dev.ToDevice(weight.data()); + + // dispatcher.run(dY, W, dX, problem) for bwd_data + float time_ms = dispatcher.run(dy_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer(), + problem, + nullptr); + + dx_dev.FromDevice(dx_gpu.data()); + + double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validation + std::cout << "\nStep 3: Validation (GPU vs CPU)\n"; + + size_t num_elements = dx_gpu.get_element_space_size(); + float max_abs = 0, max_rel = 0; + size_t mismatches = 0; + constexpr float rtol = 5e-2f, atol = 5e-2f; + + for(size_t i = 0; i < num_elements; ++i) + { + float gv = static_cast(dx_gpu.data()[i]); + float cv = static_cast(dx_cpu.data()[i]); + float d = std::abs(gv - cv); + float r = d / (std::abs(cv) + 1e-6f); + max_abs = std::max(max_abs, d); + max_rel = std::max(max_rel, r); + if(d > atol + rtol * std::abs(cv)) + ++mismatches; + } + + bool passed = (mismatches == 0); + std::cout << " Elements: " << num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_abs << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + + utils::print_separator(); + std::cout << " dX = ConvBwdData(dY, W)\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp b/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp new file mode 100644 index 0000000000..41cb75aecf --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp @@ -0,0 +1,188 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 06: Backward Weight with CPU Reference Validation +// +// Computes dW = ConvBwdWeight(X, dY) on GPU via dispatcher.run() +// and validates against ck_tile::reference_grouped_conv_bwd_weight. +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_06_bwd_weight + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + bwd_weight_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .scheduler("intrawave") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 06: Backward Weight Validation", + "dW = ConvBwdWeight(X, dY) with CPU reference"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-c", "64", "Input channels"); + args.add_option("-k", "128", "Output channels"); + args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("--split-k", "1", "Split-K factor for bwd_weight (k_batch)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 06: Backward Weight with CPU Validation"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1), G = 1; + int C = args.get_int("-c", 64), K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3; + + // Setup + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + // X (input) and dY (gradient) are inputs; dW is output + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor dy(out_desc); + ck_tile::HostTensor dw_gpu(wei_desc); + ck_tile::HostTensor dw_cpu(wei_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(dy); + dw_cpu.SetZero(); + + // CPU reference + std::cout << "\nStep 1: CPU Reference (bwd_weight)\n"; + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_bwd_weight<2, InDataType, WeiDataType, OutDataType>( + input, dw_cpu, dy, strides_v, dilations_v, left_pads_v, right_pads_v); + std::cout << " CPU complete\n"; + + // GPU execution + std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + 
registry.set_name("bwd_weight"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + GroupedConvDispatcher dispatcher(®istry); + + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight); + problem.split_k = args.get_int("--split-k", 1); + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No bwd_weight kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dw_dev(dw_gpu.get_element_space_size_in_bytes()); + + in_dev.ToDevice(input.data()); + dy_dev.ToDevice(dy.data()); + if(problem.split_k > 1) + dw_dev.SetZero(); + + // dispatcher.run(X, dY, dW, problem) for bwd_weight + float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + dy_dev.GetDeviceBuffer(), + dw_dev.GetDeviceBuffer(), + problem, + nullptr); + + dw_dev.FromDevice(dw_gpu.data()); + + double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validation + std::cout << "\nStep 3: Validation (GPU vs CPU)\n"; + + size_t num_elements = dw_gpu.get_element_space_size(); + float max_abs = 0, max_rel = 0; + size_t mismatches = 0; + constexpr float rtol = 5e-2f, atol = 5e-2f; + + for(size_t i = 0; i < num_elements; ++i) + { + float gv = static_cast(dw_gpu.data()[i]); + float cv = static_cast(dw_cpu.data()[i]); + float d = std::abs(gv - cv); + float r = d / (std::abs(cv) + 1e-6f); + max_abs = std::max(max_abs, d); + max_rel = std::max(max_rel, r); + if(d > atol + rtol * std::abs(cv)) + ++mismatches; + } + + bool passed = (mismatches == 0); + std::cout << " Elements: " << num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_abs << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + + utils::print_separator(); + std::cout << " dW = ConvBwdWeight(X, dY)\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp b/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp new file mode 100644 index 0000000000..5c95f2c45a --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp @@ -0,0 +1,226 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 07: Multi-Tile Benchmark +// +// Benchmarks multiple tile configurations across ResNet-like problem sizes. +// Exposes warmup, repeat, and init method as CLI args (matching CK Tile +// example 20 patterns). +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_07_benchmark + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Multiple tile configurations for benchmarking +DECL_GROUPED_CONV_KERNEL_SET( + benchmark_tiles, + // Small tile - compv3 + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 64, 64) + .wave(1, 4, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950") + // Medium tile - compv3 + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950") + // Large tile - compv4 with double smem buffer + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 256, 256) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 07: Multi-Tile Benchmark", + "Multiple tiles across ResNet-like problem sizes"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--warmup", "5", "Warmup iterations (passed to stream_config)"); + args.add_option("--repeat", "20", "Benchmark iterations (passed to stream_config)"); + args.add_option("--init", "0", "Init method: 0=random, 1=linear, 2=constant(1)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 07: Multi-Tile Benchmark"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int warmup = args.get_int("--warmup", 5); + int repeat = args.get_int("--repeat", 20); + int init_method = args.get_int("--init", 0); + + std::cout << "\n Config: warmup=" << warmup << " repeat=" << repeat << " init=" << init_method + << "\n"; + + GroupedConvRegistry registry; + registry.set_name("benchmark"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + GroupedConvDispatcher dispatcher(®istry); + + // ResNet-like problem sizes + struct BenchProblem + { + const char* label; + int N, C, K, Hi, Wi, Y, X; + }; + + BenchProblem problems[] = { + {"ResNet-stage2", 1, 64, 64, 56, 56, 3, 3}, + {"ResNet-stage3", 1, 128, 128, 28, 28, 3, 3}, + {"ResNet-stage4", 1, 256, 256, 14, 14, 3, 3}, + {"ResNet-stage5", 1, 512, 512, 7, 7, 3, 3}, + {"Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1}, + {"Batch-8", 8, 64, 128, 56, 56, 3, 3}, + }; + + std::cout << "\n " << std::left << std::setw(16) << "Problem" << std::right << std::setw(5) + << "N" << std::setw(5) << "C" << std::setw(5) << "K" << std::setw(5) << "H" + << std::setw(5) << "W" << std::setw(4) << "F" << std::setw(10) << 
"Time(ms)" + << std::setw(10) << "TFLOPS" << std::setw(10) << "Status" << "\n"; + std::cout << " " << std::string(74, '-') << "\n"; + + bool all_pass = true; + for(const auto& bp : problems) + { + auto problem = + create_grouped_conv2d_problem(bp.N, bp.C, bp.K, bp.Hi, bp.Wi, bp.Y, bp.X, 1, 1); + problem.op = GroupedConvOp::Forward; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(1), + static_cast(bp.N), + static_cast(bp.K), + static_cast(bp.C), + {static_cast(bp.Y), static_cast(bp.X)}, + {static_cast(bp.Hi), static_cast(bp.Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output(out_desc); + + switch(init_method) + { + case 1: + ck_tile::FillMonotonicSeq{0.0f, 0.001f}(input); + ck_tile::FillMonotonicSeq{0.0f, 0.001f}(weight); + break; + case 2: + ck_tile::FillConstant{1.0f}(input); + ck_tile::FillConstant{1.0f}(weight); + break; + default: + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + break; + } + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes()); + + in_dev.ToDevice(input.data()); + wei_dev.ToDevice(weight.data()); + + float time_ms = 0; + bool ok = false; + try + { + time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer(), + problem, + nullptr); + + out_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t j = 0; j < output.get_element_space_size(); ++j) + if(static_cast(output.data()[j]) != 0.0f) + ++nz; + ok = nz > 0; + } + catch(const std::exception&) + { + ok = false; + } + + double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0; + + std::string filter_str = std::to_string(bp.Y) + "x" + std::to_string(bp.X); + std::cout << " " << std::left << std::setw(16) << bp.label << std::right << std::setw(5) + << bp.N << std::setw(5) << bp.C << std::setw(5) << bp.K << std::setw(5) << bp.Hi + << std::setw(5) << bp.Wi << std::setw(4) << filter_str << std::fixed + << std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2) + << std::setw(10) << tflops << std::setw(10) << (ok ? "OK" : "FAIL") << "\n"; + if(!ok) + all_pass = false; + } + + utils::print_separator(); + std::cout << " Warmup: " << warmup << ", Repeat: " << repeat << ", Init: " << init_method + << "\n"; + std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return all_pass ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py b/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py new file mode 100644 index 0000000000..46f57b3879 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +""" +Example 01: Basic Grouped Convolution + +Demonstrates: +1. Three kernel configuration patterns (minimal, explicit, full ConvConfigBase) +2. Adding kernels to a registry +3. Validation and auto-correction +4. JIT compilation via registry.build() +5. GPU execution with CPU reference verification + +Usage: + python3 01_basic_grouped_conv.py + python3 01_basic_grouped_conv.py --variant bwd_data + python3 01_basic_grouped_conv.py --arch gfx942 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + detect_gpu_arch, +) + + +def cpu_conv2d_fwd(inp, wei, prob): + """Naive CPU reference: 2D forward, NHWGC layout.""" + N, Hi, Wi, G, Cpg = inp.shape + _, Kpg, Y, X, _ = wei.shape + Ho, Wo = prob.Ho, prob.Wo + out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32) + for n in range(N): + for g in range(G): + for ho in range(Ho): + for wo in range(Wo): + for k in range(Kpg): + s = 0.0 + for y in range(Y): + for x in range(X): + hi = ( + ho * prob.stride_h + - prob.pad_h + + y * prob.dilation_h + ) + wi = ( + wo * prob.stride_w + - prob.pad_w + + x * prob.dilation_w + ) + if 0 <= hi < Hi and 0 <= wi < Wi: + for c in range(Cpg): + s += float(inp[n, hi, wi, g, c]) * float( + wei[g, k, y, x, c] + ) + out[n, ho, wo, g, k] = s + return out + + +def main(): + parser = argparse.ArgumentParser(description="Basic Grouped Conv Example") + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument( + "--variant", default="forward", choices=["forward", "bwd_data", "bwd_weight"] + ) + parser.add_argument("--ndim", type=int, default=2, choices=[2, 3]) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 01: Basic Grouped Convolution") + print("=" * 70) + + # ========================================================================= + # Step 1: Three kernel configuration patterns + # ========================================================================= + print("\n--- Step 1: Kernel Configuration Patterns ---") + + # Pattern 1: MINIMAL -- only variant/dtype/arch, everything else auto-filled + config_minimal = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + ) + print("\n Pattern 1: MINIMAL (defaults auto-filled)") + config_minimal.print_config(indent=" ") + + # Pattern 2: EXPLICIT tile/wave/warp -- user controls tiling strategy + config_explicit = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + ) + print("\n Pattern 2: EXPLICIT tile/wave/warp") + config_explicit.print_config(indent=" ") + + # Pattern 3: FULL ConvConfigBase -- every parameter specified + config_full = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + 
warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + print("\n Pattern 3: FULL (all ConvConfigBase fields)") + config_full.print_config(indent=" ") + + # ========================================================================= + # Step 2: Build a registry with multiple configs + # ========================================================================= + print("\n--- Step 2: Build Registry ---") + registry = GroupedConvRegistry("basic_conv") + registry.add(config_minimal) + registry.add(config_explicit) + registry.add(config_full) + registry.print_registry() + + # ========================================================================= + # Step 3: Validate and auto-correct + # ========================================================================= + print("\n--- Step 3: Validate & Auto-Correct ---") + for i, cfg in enumerate(registry.kernels): + result = validate_grouped_conv_config(cfg.to_dict()) + if result.is_valid: + print(f" Config [{i}] {cfg.tile_str}: VALID") + else: + print(f" Config [{i}] {cfg.tile_str}: needs correction") + corrected, result = auto_correct_grouped_conv_config(cfg.to_dict()) + print(f" After correction: valid={result.is_valid}") + + # ========================================================================= + # Step 4: JIT compile via registry.build() + # ========================================================================= + print("\n--- Step 4: JIT Build (via registry.build()) ---") + + # Use only the first config for the actual GPU run + jit_reg = GroupedConvRegistry("jit") + jit_reg.add(config_minimal) + + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = jit_reg.build(verbose=False, max_workers=workers) + jit_build_s = time.perf_counter() - t0 + + key = (args.variant, args.ndim) + if key not in runners: + print(" JIT build failed") + return 1 + runner = runners[key] + print(f" JIT build: {jit_build_s:.3f} s") + print(f" Library: {runner.library_path}") + print(f" Kernels: {runner.lib.kernel_names()}") + + # ========================================================================= + # Step 5: Define problem + GPU execution + # ========================================================================= + print("\n--- Step 5: GPU Execution ---") + prob = GroupedConvProblem( + N=1, + C=64, + K=128, + Hi=16, + Wi=16, + Y=3, + X=3, + stride_h=1, + stride_w=1, + pad_h=1, + pad_w=1, + direction=args.variant, + ) + prob.print_problem() + + inp = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(np.float16) + wei = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np.float16) + + res = runner.run(inp, wei, prob) + if not res.success: + print(f" GPU execution failed: {res.error}") + runner.cleanup() + return 1 + + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print( + f" Output: shape={res.output.shape}, range=[{res.output.min():.3f}, {res.output.max():.3f}]" + ) + + # ========================================================================= + # Step 6: CPU reference (forward 2D only) + # ========================================================================= + verified = False + if args.variant == "forward" and args.ndim == 2: + print("\n--- Step 6: CPU Reference Verification ---") + ref = cpu_conv2d_fwd(inp, wei, prob) + gpu_f32 = res.output.astype(np.float32) + diff = 
np.abs(gpu_f32 - ref) + max_abs = diff.max() + max_rel = (diff / (np.abs(ref) + 1e-6)).max() + match = np.allclose(gpu_f32, ref, atol=0.05, rtol=0.05) + print(f" max_abs_diff: {max_abs:.6f}") + print(f" max_rel_diff: {max_rel:.6f}") + print(f" Match: {match}") + verified = match + + runner.cleanup() + + # Summary + print("\n" + "=" * 70) + status = ( + "PASS" if res.success and (verified or args.variant != "forward") else "FAIL" + ) + print(f" Status: {status}") + print( + f" {config_minimal.name} | {prob.gflops:.2f} GFLOPs | {res.tflops:.2f} TFLOPS" + ) + print(f" JIT build time: {jit_build_s:.3f} s") + print(f" Registry: {len(registry)} configs (3 patterns demonstrated)") + print("=" * 70) + return 0 if status == "PASS" else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/02_forward.py b/dispatcher/examples/grouped_conv/python/02_forward.py new file mode 100644 index 0000000000..8f59db05a1 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/02_forward.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 02: Forward Convolution (2D + 3D) + +Declares forward kernels with explicit tile/wave/warp/pipeline parameters, +builds a registry, JIT compiles, runs on GPU, and validates against CPU reference. + +Usage: + python3 02_forward.py + python3 02_forward.py --arch gfx942 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_fwd(inp, wei, prob): + """Naive CPU reference: 2D forward, NHWGC layout.""" + N, Hi, Wi, G, C = inp.shape + _, Kpg, Y, X, _ = wei.shape + Ho, Wo = prob.Ho, prob.Wo + out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32) + for n in range(N): + for g in range(G): + for ho in range(Ho): + for wo in range(Wo): + for k in range(Kpg): + s = 0.0 + for y in range(Y): + for x in range(X): + hi = ho * prob.stride_h - prob.pad_h + y + wi = wo * prob.stride_w - prob.pad_w + x + if 0 <= hi < Hi and 0 <= wi < Wi: + for c in range(C): + s += float(inp[n, hi, wi, g, c]) * float( + wei[g, k, y, x, c] + ) + out[n, ho, wo, g, k] = s + return out + + +def main(): + parser = argparse.ArgumentParser(description="Forward Convolution (2D + 3D)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 02: Forward Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + + # ========================================================================= + # Step 1: Declare forward kernels with explicit parameters + # ========================================================================= + print("\n--- Step 1: Declare Forward Kernels ---") + reg = GroupedConvRegistry("forward_conv") + + # Forward 2D: compv4, 128x128 tile, wave 2x2x1, warp 32x32x16 + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + 
warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # Forward 3D: compv3, 64x64 tile, wave 1x4x1, warp 16x16x32 + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build via registry + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" Built {len(runners)} runners in {jit_s:.1f}s") + + for key in [("forward", 2), ("forward", 3)]: + tag = "OK" if key in runners else "FAILED" + print(f" {key[0]} {key[1]}D: {tag}") + + if ("forward", 2) not in runners: + print(" ERROR: forward 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: Forward 2D -- GPU + CPU reference + # ========================================================================= + print("\n--- Step 3: Forward 2D ---") + prob_2d = GroupedConvProblem( + N=1, C=64, K=64, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="forward" + ) + prob_2d.print_problem() + + x = np.random.uniform(-0.5, 0.5, prob_2d.input_shape()).astype(np_dtype) + w = np.random.uniform(-0.5, 0.5, prob_2d.weight_shape()).astype(np_dtype) + + res = runners[("forward", 2)].run(x, w, prob_2d) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print( + f" Output: shape={res.output.shape}, nonzero={np.count_nonzero(res.output)}/{res.output.size}" + ) + + ref = cpu_conv2d_fwd(x, w, prob_2d) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.05) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: Forward 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("forward", 3) in runners: + print("\n--- Step 4: Forward 3D ---") + prob_3d = GroupedConvProblem( + N=1, + C=64, + K=64, + Di=8, + Hi=8, + Wi=8, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="forward", + ) + prob_3d.print_problem() + + x3 = np.random.uniform(-0.5, 0.5, prob_3d.input_shape()).astype(np_dtype) + w3 = np.random.uniform(-0.5, 0.5, prob_3d.weight_shape()).astype(np_dtype) + + res3 = runners[("forward", 3)].run(x3, w3, prob_3d) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms") + print(f" TFLOPS: {res3.tflops:.2f}") + print(f" NonZero: {nz}/{res3.output.size}") + + for r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" Forward 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" Forward 3D: {'PASS' if ok_3d else 'FAIL'} 
(non-zero check)") + print(f" JIT build: {jit_s:.1f}s") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/03_bwd_data.py b/dispatcher/examples/grouped_conv/python/03_bwd_data.py new file mode 100644 index 0000000000..a000ba7c96 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/03_bwd_data.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 03: Backward Data Convolution (2D + 3D) + +dX = ConvBwdData(dY, W) + +Declares backward-data kernels with explicit parameters, +builds a registry, JIT compiles, runs on GPU, and validates +against a CPU reference. + +Usage: + python3 03_bwd_data.py +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_bwd_data(dy, wei, prob): + """CPU ref: compute dX from dY and W.""" + N, Ho, Wo, G, Kpg = dy.shape + _, _, Y, X, C = wei.shape + Hi, Wi = prob.Hi, prob.Wi + dx = np.zeros((N, Hi, Wi, G, C), dtype=np.float32) + for n in range(N): + for g in range(G): + for hi in range(Hi): + for wi in range(Wi): + for c in range(C): + s = 0.0 + for y in range(Y): + for x in range(X): + ho = hi + prob.pad_h - y + wo = wi + prob.pad_w - x + if ho % prob.stride_h == 0 and wo % prob.stride_w == 0: + ho //= prob.stride_h + wo //= prob.stride_w + if 0 <= ho < Ho and 0 <= wo < Wo: + for k in range(Kpg): + s += float(dy[n, ho, wo, g, k]) * float( + wei[g, k, y, x, c] + ) + dx[n, hi, wi, g, c] = s + return dx + + +def main(): + parser = argparse.ArgumentParser(description="Backward Data (2D + 3D)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 03: Backward Data Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + print(" dX = ConvBwdData(dY, W)") + + # ========================================================================= + # Step 1: Declare bwd_data kernels + # ========================================================================= + print("\n--- Step 1: Declare BwdData Kernels ---") + reg = GroupedConvRegistry("bwd_data_conv") + + # BwdData 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdData 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + 
reg.print_registry() + + # ========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" Built {len(runners)} runners in {jit_s:.1f}s") + + if ("bwd_data", 2) not in runners: + print(" ERROR: bwd_data 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: BwdData 2D -- GPU + CPU reference + # ========================================================================= + print("\n--- Step 3: Backward Data 2D ---") + prob = GroupedConvProblem( + N=1, C=32, K=32, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="bwd_data" + ) + prob.print_problem() + + dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(np_dtype) + w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np_dtype) + + res = runners[("bwd_data", 2)].run(dy, w, prob) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + ref = cpu_conv2d_bwd_data(dy, w, prob) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.1) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: BwdData 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("bwd_data", 3) in runners: + print("\n--- Step 4: Backward Data 3D ---") + prob3 = GroupedConvProblem( + N=1, + C=32, + K=32, + Di=6, + Hi=6, + Wi=6, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="bwd_data", + ) + dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype) + w3 = np.random.uniform(-0.5, 0.5, prob3.weight_shape()).astype(np_dtype) + res3 = runners[("bwd_data", 3)].run(dy3, w3, prob3) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}") + + for r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" BwdData 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" BwdData 3D: {'PASS' if ok_3d else 'FAIL'}") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/04_bwd_weight.py b/dispatcher/examples/grouped_conv/python/04_bwd_weight.py new file mode 100644 index 0000000000..48e50cd4a9 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/04_bwd_weight.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 04: Backward Weight Convolution (2D + 3D) + +dW = ConvBwdWeight(X, dY) + +Declares backward-weight kernels with explicit parameters, +builds a registry, JIT compiles, runs on GPU, and validates +against a CPU reference. 
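
The CPU reference computes, for every filter tap,
dW[g, k, y, x, c] = sum over (n, ho, wo) of
X[n, ho*stride_h - pad_h + y, wo*stride_w - pad_w + x, g, c] * dY[n, ho, wo, g, k],
skipping taps that fall outside the (padded) input.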
+ +Usage: + python3 04_bwd_weight.py +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_bwd_weight(x, dy, prob): + """CPU ref: compute dW from X and dY.""" + N, Hi, Wi, G, C = x.shape + _, Ho, Wo, _, Kpg = dy.shape + Y, X_ = prob.Y, prob.X + dw = np.zeros((G, Kpg, Y, X_, C), dtype=np.float32) + for g in range(G): + for k in range(Kpg): + for y in range(Y): + for xf in range(X_): + for c in range(C): + s = 0.0 + for n in range(N): + for ho in range(Ho): + for wo in range(Wo): + hi = ho * prob.stride_h - prob.pad_h + y + wi = wo * prob.stride_w - prob.pad_w + xf + if 0 <= hi < Hi and 0 <= wi < Wi: + s += float(x[n, hi, wi, g, c]) * float( + dy[n, ho, wo, g, k] + ) + dw[g, k, y, xf, c] = s + return dw + + +def main(): + parser = argparse.ArgumentParser(description="Backward Weight (2D + 3D)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + parser.add_argument( + "--split-k", type=int, default=1, help="Split-K factor for bwd_weight (k_batch)" + ) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 04: Backward Weight Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + print(" dW = ConvBwdWeight(X, dY)") + + # ========================================================================= + # Step 1: Declare bwd_weight kernels + # ========================================================================= + print("\n--- Step 1: Declare BwdWeight Kernels ---") + reg = GroupedConvRegistry("bwd_weight_conv") + + # BwdWeight 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_weight", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdWeight 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_weight", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" Built {len(runners)} runners in {jit_s:.1f}s") + + if ("bwd_weight", 2) not in runners: + print(" ERROR: bwd_weight 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: BwdWeight 2D -- GPU + CPU reference + # 
========================================================================= + print("\n--- Step 3: Backward Weight 2D ---") + prob = GroupedConvProblem( + N=1, + C=32, + K=32, + Hi=8, + Wi=8, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="bwd_weight", + split_k=args.split_k, + ) + prob.print_problem() + + x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(np_dtype) + dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(np_dtype) + + res = runners[("bwd_weight", 2)].run(x, dy, prob) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + ref = cpu_conv2d_bwd_weight(x, dy, prob) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.5) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: BwdWeight 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("bwd_weight", 3) in runners: + print("\n--- Step 4: Backward Weight 3D ---") + prob3 = GroupedConvProblem( + N=1, + C=32, + K=32, + Di=6, + Hi=6, + Wi=6, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="bwd_weight", + ) + x3 = np.random.uniform(-0.5, 0.5, prob3.input_shape()).astype(np_dtype) + dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype) + res3 = runners[("bwd_weight", 3)].run(x3, dy3, prob3) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}") + + for r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" BwdWeight 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" BwdWeight 3D: {'PASS' if ok_3d else 'FAIL'}") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/05_benchmark.py b/dispatcher/examples/grouped_conv/python/05_benchmark.py new file mode 100644 index 0000000000..9166ab988e --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/05_benchmark.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 05: Multi-Problem GPU Benchmark + +Declares kernels with explicit tile/wave/warp/pipeline parameters for +all directions, builds registries, JIT compiles, and benchmarks across +ResNet-like problem sizes with configurable warmup/repeat. 
+ +Usage: + python3 05_benchmark.py + python3 05_benchmark.py --warmup 3 --repeat 10 + python3 05_benchmark.py --workers 4 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def compute_bytes(prob, dtype_bytes=2): + in_elems = 1 + for d in prob.input_shape(): + in_elems *= d + wei_elems = 1 + for d in prob.weight_shape(): + wei_elems *= d + out_elems = 1 + for d in prob.output_shape(): + out_elems *= d + return (in_elems + wei_elems + out_elems) * dtype_bytes + + +def main(): + parser = argparse.ArgumentParser(description="Multi-Problem GPU Benchmark") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations") + parser.add_argument("--repeat", type=int, default=5, help="Benchmark iterations") + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 05: Multi-Problem GPU Benchmark") + print("=" * 70) + print(f"\n Arch: {args.arch}, Dtype: {args.dtype}") + print(f" Warmup: {args.warmup}, Repeat: {args.repeat}") + + # ========================================================================= + # Step 1: Declare all kernels with explicit parameters + # ========================================================================= + print("\n--- Step 1: Declare Kernels ---") + reg = GroupedConvRegistry("benchmark") + + # Forward 2D: compv4, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # Forward 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=3, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdData 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdWeight 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_weight", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # 
========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runner_by_key = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + + for key in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)]: + tag = "OK" if key in runner_by_key else "FAILED" + print(f" {key[0]:12s} {key[1]}D: {tag}") + print(f" JIT build time: {jit_s:.3f} s") + + missing = [ + k + for k in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)] + if k not in runner_by_key + ] + if missing: + print(f"\n ERROR: missing {missing}") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + def bench_run(runner, inp, wei, prob): + for _ in range(args.warmup): + runner.run(inp, wei, prob) + times = [] + for _ in range(args.repeat): + r = runner.run(inp, wei, prob) + if r.success: + times.append(r.time_ms) + if not times: + return 0.0, 0.0 + return min(times), sum(times) / len(times) + + # ========================================================================= + # Step 3: 2D Forward benchmark + # ========================================================================= + print("\n--- Step 3: Forward 2D Benchmark ---") + print( + f"{'Problem':<18} {'N':>3} {'C':>4} {'K':>4} {'H':>3} {'W':>3} " + f"{'F':>3} {'Min(ms)':>9} {'Avg(ms)':>9} {'TFLOPS':>8} {'GB/s':>8}" + ) + print("-" * 85) + + all_ok = True + for label, n, c, k, h, w, y, x, s, p in [ + ("ResNet-stage2", 1, 64, 64, 56, 56, 3, 3, 1, 1), + ("ResNet-stage3", 1, 128, 128, 28, 28, 3, 3, 1, 1), + ("ResNet-stage4", 1, 256, 256, 14, 14, 3, 3, 1, 1), + ("ResNet-stage5", 1, 512, 512, 7, 7, 3, 3, 1, 1), + ("Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1, 1, 0), + ("Batch-8", 8, 64, 128, 56, 56, 3, 3, 1, 1), + ("Batch-32", 32, 64, 128, 56, 56, 3, 3, 1, 1), + ]: + prob = GroupedConvProblem( + N=n, + C=c, + K=k, + Hi=h, + Wi=w, + Y=y, + X=x, + stride_h=s, + stride_w=s, + pad_h=p, + pad_w=p, + direction="forward", + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[("forward", 2)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + bw = compute_bytes(prob) / (avg_ms * 1e6) + print( + f"{label:<18} {n:>3} {c:>4} {k:>4} {h:>3} {w:>3} " + f"{y}x{x} {min_ms:>9.4f} {avg_ms:>9.4f} {tflops:>8.2f} {bw:>8.1f}" + ) + else: + all_ok = False + + # ========================================================================= + # Step 4: 3D Forward + # ========================================================================= + print("\n--- Step 4: Forward 3D ---") + for label, n, c, k, d, h, w, z, y, x in [ + ("3D-small", 1, 64, 64, 8, 16, 16, 3, 3, 3), + ("3D-medium", 1, 64, 128, 16, 32, 32, 3, 3, 3), + ]: + prob = GroupedConvProblem( + N=n, C=c, K=k, Di=d, Hi=h, Wi=w, Z=z, Y=y, X=x, direction="forward" + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[("forward", 3)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + print(f" {label:<14} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS") + + # 
========================================================================= + # Step 5: Backward directions + # ========================================================================= + print("\n--- Step 5: Backward Directions ---") + for label, direction in [ + ("bwd_data ResNet-s3", "bwd_data"), + ("bwd_weight ResNet-s3", "bwd_weight"), + ]: + prob = GroupedConvProblem( + N=1, + C=128, + K=128, + Hi=28, + Wi=28, + Y=3, + X=3, + stride_h=1, + stride_w=1, + pad_h=1, + pad_w=1, + direction=direction, + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[(direction, 2)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + print( + f" {label:<14} {direction:>12} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS" + ) + + for runner in runner_by_key.values(): + runner.cleanup() + + print("\n" + "=" * 70) + print(f" JIT build: {jit_s:.3f} s") + print(f" Warmup: {args.warmup}, Repeat: {args.repeat}") + print(f" Status: {'PASS' if all_ok else 'FAIL'}") + print("=" * 70) + return 0 if all_ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/06_registry_json.py b/dispatcher/examples/grouped_conv/python/06_registry_json.py new file mode 100644 index 0000000000..1a3dc854e7 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/06_registry_json.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 06: Registry, Heuristic Selection & JSON Export + +Declares multiple kernel configurations with different tile sizes, +builds a registry, demonstrates heuristic runtime kernel selection, +JSON round-trip, and GPU execution. 
+ +Usage: + python3 06_registry_json.py + python3 06_registry_json.py --workers 4 +""" + +import sys +import time +import argparse +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def conv_heuristic(problem): + spatial = problem.Ho * problem.Wo + if spatial > 400: + return ["256", "128", "64"] + return ["64", "128", "256"] + + +def main(): + parser = argparse.ArgumentParser(description="Registry, Heuristic & JSON") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 06: Registry, Heuristic Selection & JSON Export") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + + # Step 1: Declare kernels with full explicit parameters + print("\n--- Step 1: Declare Kernels + Build Registry ---") + reg = GroupedConvRegistry("conv_tiles") + + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=256, + tile_k=256, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.print_registry() + + # Step 2: Heuristic kernel selection + print("\n--- Step 2: Heuristic Kernel Selection ---") + problems = [ + ( + "small_7x7", + GroupedConvProblem( + N=1, + C=512, + K=512, + Hi=7, + Wi=7, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ( + "medium_14x14", + GroupedConvProblem( + N=1, + C=256, + K=256, + Hi=14, + Wi=14, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ( + "large_56x56", + GroupedConvProblem( + N=1, + C=64, + K=128, + Hi=56, + Wi=56, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ] + print(f" {'Problem':<16} {'Spatial':>8} {'Selected Kernel':<50}") + print(f" {'-' * 74}") + for label, prob in problems: + selected = reg.select(prob, heuristic=conv_heuristic) + spatial = prob.Ho * prob.Wo + sel_name = selected.name if selected else "none" + print(f" {label:<16} {spatial:>8} {sel_name:<50}") + + # Step 3: JSON round-trip + print("\n--- Step 3: JSON Round-Trip ---") + json_str = reg.to_json() + print(f" Exported: {len(json_str)} 
bytes, {len(reg)} kernels") + imported = GroupedConvRegistry.from_json(json_str) + print(f" Imported: {len(imported)} kernels") + orig = reg.kernels[0] + imp = imported.kernels[0] + rt_ok = ( + orig.vector_size_a == imp.vector_size_a + and orig.block_per_cu == imp.block_per_cu + and orig.tile_n == imp.tile_n + ) + print(f" Full fields round-trip: {'OK' if rt_ok else 'FAIL'}") + + # Step 4: JIT build + GPU execution + print("\n--- Step 4: JIT Build + GPU Execution ---") + workers = args.workers if args.workers > 0 else None + jit_reg = GroupedConvRegistry("jit_conv") + jit_reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + ) + ) + t0 = time.perf_counter() + runners = jit_reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + + if ("forward", 2) not in runners: + print(" JIT build failed") + return 1 + runner = runners[("forward", 2)] + print(f" JIT build: {jit_s:.3f} s") + print(f" Library: {runner.library_path}") + + prob = GroupedConvProblem( + N=1, C=128, K=128, Hi=16, Wi=16, Y=3, X=3, pad_h=1, pad_w=1, direction="forward" + ) + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + res = runner.run(inp, wei, prob) + runner.cleanup() + + if res.success: + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + gpu_ok = res.success + print("\n" + "=" * 70) + print(f" Registry: {len(reg)} kernels (3 tile configs)") + print(" Heuristic: spatial-based selection demonstrated") + print(f" JSON: round-trip {'OK' if rt_ok else 'FAIL'}") + print(f" GPU: {'OK' if gpu_ok else 'FAIL'}") + print(f" Status: {'PASS' if gpu_ok and rt_ok else 'FAIL'}") + print("=" * 70) + return 0 if gpu_ok and rt_ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/heuristics/.gitignore b/dispatcher/heuristics/.gitignore new file mode 100644 index 0000000000..d9523255bf --- /dev/null +++ b/dispatcher/heuristics/.gitignore @@ -0,0 +1,60 @@ +# Python bytecode and caches +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python + +# Jupyter notebooks +*.ipynb +.ipynb_checkpoints/ + +# Virtual environments +.venv/ +venv/ +ENV/ + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Test output and logs +*.log +test_output.log +custom_shapes_gpu_test.log + +# Benchmark and analysis output files +*.csv +*.json +!models/*/feature_spec.json +!models/*/train_manifest.json + +# Data files (parquet, arrow) +*.parquet +*.arrow + +# Temporary and NFS files +.nfs* +*.tmp +*.bak + +# Decompressed model files (compressed .lgbm.gz versions are tracked) +models/**/*.lgbm + +# User-specific test and analysis scripts +test_*.py +!tests/test_*.py +find_*.py +oracle_*.json +validation_results_*.csv +custom_shapes_*.csv +fp16_bf16_*.csv + +# Ignore all markdown files except tracked documentation +*.md +!DATA_GENERATION.md +!LEARNINGS.md +!README.md diff --git a/dispatcher/heuristics/DATA_GENERATION.md b/dispatcher/heuristics/DATA_GENERATION.md new file mode 100644 index 0000000000..819e77fe48 --- /dev/null +++ 
b/dispatcher/heuristics/DATA_GENERATION.md @@ -0,0 +1,412 @@ +# Data Generation Guide + +This document explains how to build benchmark binaries from the CK Tile engine, +generate benchmark datasets, and manage them for the ML kernel performance +prediction system. + +## Overview + +The ML heuristic needs benchmark data: measured TFLOPS, latency, and bandwidth +for every (problem shape, kernel config) pair. The tile engine builds one +executable per kernel configuration. Each executable benchmarks a single kernel +on a given problem size and outputs JSON with performance metrics. + +``` +CK source --> CMake configure --> ninja build --> benchmark binaries + (4608 per op/dtype/layout) + +benchmark binaries --> run on GPU --> streaming log --> parquet dataset + (per shape) (JSON blocks) (canonical schema) +``` + +## Prerequisites + +- **ROCm**: HIP >= 6.0.3 (for gfx950: HIP >= 6.0.4) +- **Build tools**: CMake >= 3.21, Ninja, HIP-aware clang compiler +- **Python**: 3.10+ with `pandas`, `pyarrow` +- **GPU**: ROCm-capable AMD GPU (MI250X, MI300X, MI355X, etc.) + +--- + +## Part 1: Building Benchmark Binaries from the Tile Engine + +If you already have pre-built binaries (e.g., in `/workspace/ck_tile/bin/`), +skip to Part 2. This section explains how to build them from source. + +### Step 1: CMake Configure + +From the CK repository root: + +```bash +cmake -S /workspace/rocm-libraries/projects/composablekernel \ + -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx950" \ + -DGEMM_UNIVERSAL_DATATYPE="fp8" \ + -DGEMM_UNIVERSAL_LAYOUT="rcr" \ + -G Ninja +``` + +**Key CMake variables:** + +| Variable | Default | Description | +|---|---|---| +| `GPU_TARGETS` | (required) | Target GPU architectures. Supported: `gfx90a`, `gfx942`, `gfx950`, `gfx1201`. Semicolon-separated for multiple. | +| `GEMM_UNIVERSAL_DATATYPE` | `"fp8;fp16"` | Data types to build. Options: `fp8`, `fp16`, `bf16`, `bf8`. Semicolon-separated. | +| `GEMM_UNIVERSAL_LAYOUT` | `"rcr;rrr;crr;ccr"` | Layouts to build. Semicolon-separated. | +| `GEMM_UNIVERSAL_CONFIG_FILE` | `"default_config.json"` | Kernel config file (in the `configs/` directory). Controls which tile sizes, warp configs, pipelines, etc. are enumerated. | +| `ENABLE_CCACHE_GEMM_UNIVERSAL` | `OFF` | Enable ccache for faster rebuilds. | + +**Example: build only fp8 RCR for gfx950 (fastest, ~4608 kernels):** +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx950" \ + -DGEMM_UNIVERSAL_DATATYPE="fp8" \ + -DGEMM_UNIVERSAL_LAYOUT="rcr" \ + -G Ninja +``` + +**Example: build all dtypes and layouts (slow, ~4608 * 4 * 4 = ~73K kernels):** +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx950" \ + -DGEMM_UNIVERSAL_DATATYPE="fp8;fp16;bf16;bf8" \ + -DGEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \ + -G Ninja +``` + +### What happens during configure + +1. CMake calls `gemm_universal_instance_builder.py --list_kernels` to enumerate + all valid kernel configurations from the config JSON. +2. It writes `gemm_universal_kernel_list.txt` (one kernel per line) and + `gemm_universal_kernel_count.txt` to the build directory. +3. For each kernel, it creates a ninja build target. 
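
After configuring, you can sanity-check the enumeration before committing to a long build. A minimal sketch, assuming the two files named above sit in the build directory and that the count file holds a single integer:

```python
from pathlib import Path

build = Path("build")  # adjust to your build directory

# Files written during CMake configure (see the list above).
kernels = (build / "gemm_universal_kernel_list.txt").read_text().splitlines()
count = int((build / "gemm_universal_kernel_count.txt").read_text().strip())

print(f"Enumerated {len(kernels)} kernels (count file reports {count})")
print("Example kernel target:", kernels[0])
```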
+ +### Step 2: Build + +```bash +# Build all benchmarks for the configured dtypes/layouts +ninja -C build benchmark_gemm_universal_all + +# Or build a specific dtype/layout combo +ninja -C build benchmark_gemm_universal_fp8_rcr + +# Or build by pipeline type +ninja -C build benchmark_gemm_universal_compv4_pipeline +ninja -C build benchmark_gemm_universal_mem_pipeline + +# Or build a single specific kernel +ninja -C build benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 +``` + +**Build time estimates:** +- ~4608 kernels (one dtype, one layout): 1-4 hours depending on CPU cores +- Use `-j ` to control parallelism: `ninja -C build -j 32 benchmark_gemm_universal_fp8_rcr` + +### Step 3: Verify binaries + +Binaries are placed in `build/bin/`: + +```bash +ls build/bin/benchmark_gemm_universal_fp8_rcr_* | wc -l +# Expected: 4608 (for default config) + +# Test one binary +./build/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \ + -m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0 +``` + +### Kernel config files + +The config files live in: +``` +tile_engine/ops/gemm/gemm_universal/configs/ + default_config.json # Default: full enumeration + default_ci_config.json # CI: reduced set for fast testing + user_provided_config.json # Custom: your own subset +``` + +To use a custom config: +```bash +cmake ... -DGEMM_UNIVERSAL_CONFIG_FILE="user_provided_config.json" +``` + +The config controls which tile sizes (e.g., 128x128x64, 256x256x32), warp +configurations (e.g., 2x2x1, 1x4x1), pipelines (compv3, compv4, mem), +schedulers, and other parameters are included in the kernel enumeration. + +### Building StreamK / other ops + +The same pattern applies to other tile engine ops: + +```bash +# StreamK +ninja -C build benchmark_gemm_streamk_fp8_rcr + +# Grouped convolution +ninja -C build benchmark_grouped_conv_fwd_fp16_nhwgc +``` + +Each op has its own instance builder and config directory. + +--- + +## Part 2: Running Benchmarks and Generating Data + +## Quick Start + +### 1. Run benchmarks for a set of shapes + +Each binary accepts `-m=`, `-n=`, `-k=`, `-warmup=`, `-repeat=`, `-verify=` flags +and outputs JSON to stdout: + +```bash +/workspace/ck_tile/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \ + -m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0 +``` + +Output: +```json +{ + "name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_...", + "problem": { + "split_k": 1, "m": 1024, "n": 1024, "k": 1024, + "dtype_a": "fp8", "dtype_b": "fp8", ... + }, + "perf_result": { + "latency(ms)": 0.04, + "tflops(TFlops)": 204.60, + "bandwidth(GB/s)": 624.39 + } +} +``` + +### 2. Batch generation using provided scripts + +**Wide coverage (diverse shapes across all regimes):** +```bash +python3 generate_wide_coverage.py \ + --bin_dir /workspace/ck_tile/bin \ + --out_dir data/wide_coverage \ + --batch_size 25 \ + --warmup 3 --repeat 10 +``` + +**Edge-case dimensions (N=1, K=1, small N/K):** +```bash +python3 generate_edge_dims.py +``` + +Both scripts write streaming log files that `data_pipeline.py` can parse. + +### 3. Parse logs into parquet + +```bash +python3 data_pipeline.py \ + -o data/my_dataset.parquet \ + --arch gfx950 \ + --capture_hw +``` + +The `--capture_hw` flag runs `rocminfo` once and injects the GPU hardware +profile (CU count, clock speed, cache sizes, etc.) into every row. 
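
Each JSON block like the one above becomes a single row in the canonical schema described in the next section. A minimal sketch of that mapping (illustrative only, not the actual `data_pipeline.py` implementation; the real pipeline also fills the kernel-config columns parsed from the kernel name, the hardware profile, and `run_id`):

```python
import json

def block_to_row(block: dict, arch: str = "gfx950") -> dict:
    """Flatten one benchmark JSON block into a schema-shaped row (sketch)."""
    prob, perf = block["problem"], block["perf_result"]
    tflops = float(perf.get("tflops(TFlops)", 0.0))
    latency = float(perf.get("latency(ms)", 0.0))
    return {
        "op_type": "gemm_universal",
        "dtype": prob["dtype_a"],
        "arch": arch,
        "kernel_name": block["name"],
        "m": prob["m"], "n": prob["n"], "k": prob["k"],
        "split_k": prob.get("split_k", 1),
        "measured_tflops": tflops,
        "latency_ms": latency,
        "bandwidth_gb_s": float(perf.get("bandwidth(GB/s)", 0.0)),
        "is_valid": tflops > 0 and latency > 0,
        # layout, tile_*/warp_*/pipeline/... come from the kernel name and
        # run metadata (omitted in this sketch)
    }

row = block_to_row(json.loads(one_json_block))  # one_json_block: a single JSON string
```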
+ +## Canonical Data Schema + +Every parquet file follows this schema: + +| Column | Type | Description | +|---|---|---| +| `op_type` | str | `gemm_universal`, `gemm_streamk`, etc. | +| `dtype` | str | `fp8`, `fp16`, `bf16`, `bf8` | +| `layout` | str | `rcr`, `rrr`, `crr`, `ccr` | +| `arch` | str | `gfx942`, `gfx950`, etc. | +| `kernel_name` | str | Full kernel identifier | +| `m`, `n`, `k` | int | Problem dimensions | +| `split_k` | int | Split-K factor (1 = standard) | +| `measured_tflops` | float | Ground-truth TFLOPS | +| `latency_ms` | float | Measured latency | +| `bandwidth_gb_s` | float | Measured bandwidth | +| `is_valid` | bool | True if tflops > 0 and latency > 0 | +| `tile_m`, `tile_n`, `tile_k` | int | Tile dimensions | +| `warp_m`, `warp_n`, `warp_k` | int | Warp config | +| `warp_tile_m/n/k` | int | Warp tile dimensions | +| `pipeline` | str | `compv3`, `compv4`, `mem`, etc. | +| `scheduler` | str | `intrawave`, `interwave` | +| `epilogue` | str | `cshuffle`, `default` | +| `pad_m`, `pad_n`, `pad_k` | bool | Padding flags | +| `persistent` | bool | Persistent kernel flag | +| `run_id` | str | Unique collection run identifier | + +## Shape Selection Guidelines + +Good training data requires diverse shapes. Cover all of these regimes: + +### By M dimension (batch size / output rows) +- **M=1**: single-token inference (hardest case for tiling) +- **Tiny M (2-16)**: small batch inference +- **Small M (32-128)**: medium batch +- **Medium M (256-2048)**: large batch / training +- **Large M (4096-20480)**: very large batch + +### By N and K dimension +- **N=1**: vector-matrix multiply (degenerate) +- **K=1**: rank-1 update / outer product (degenerate) +- **Small N or K (2-16)**: stress tile efficiency +- **Deep K (K > 4096)**: compute-bound regime +- **Shallow K (K < 256)**: memory-bound regime + +### By shape family +- **Square**: M ~ N ~ K (powers of 2) +- **Tall**: M >> N (tall output matrix) +- **Wide**: N >> M (wide output matrix) +- **Deep-K**: K >> M and K >> N + +### Special cases +- **Prime dimensions**: 17, 31, 127, 251, 509, 1021, 2039, 4093 + (worst-case for tile alignment, tests padding logic) +- **Non-power-of-2**: 48, 96, 192, 384, 576, 768, 1536, 3072, 4608 + (common in LLM architectures) +- **LLM inference shapes**: DeepSeek, LLaMA-7B, LLaMA-70B MLP/attention dims + +### Minimum recommended coverage + +For a production-quality model, aim for: +- At least 200 unique (M, N, K) shapes +- At least 10 shapes per shape family +- All kernel configs (4608 for fp8 RCR) run against every shape +- Multiple layouts if training a cross-layout model + +## Benchmark Quality Guidelines + +### Warmup and repeat +- Minimum `warmup=3`, `repeat=10` for fast iteration +- Production quality: `warmup=5`, `repeat=20` for stable measurements +- The `perf_result` values are averaged over `repeat` iterations + +### Noise handling +- Use **median** latency when aggregating multiple runs of the same benchmark +- Flag measurements where coefficient of variation exceeds 10% +- Avoid benchmarking under thermal throttling (check GPU temperature) +- Lock GPU clocks if possible for reproducibility + +### Environment metadata +Store with every dataset: +- GPU model and architecture (from `rocminfo`) +- ROCm driver version +- Clock mode (default / locked) +- Git hash of the CK tile engine build (if available) +- Timestamp + +## Adding Data for a New Op + +To generate benchmark data for a new operation (e.g., `gemm_streamk`): + +1. 
**Build the binaries** using the tile engine: + ```bash + ninja -C build benchmark_gemm_streamk_fp8_rcr + ``` + +2. **Write a generation script** (or modify `generate_wide_coverage.py`): + - Change the executable glob pattern to match the new op + - Add any op-specific CLI flags the binaries need + +3. **Run and parse**: + ```bash + python3 data_pipeline.py my_streamk_run.log \ + -o data/gemm_streamk_fp8_gfx950.parquet --arch gfx950 + ``` + +4. **Train**: + ```bash + python3 train.py --op gemm_streamk --dtype fp8 --arch gfx950 \ + --data_dir data/ --out_dir models/gemm_streamk_fp8_gfx950 + ``` + +## Adding Data for a New Layout + +Same binaries, same shapes -- just change the layout filter: + +```bash +# Build rrr binaries +ninja -C build benchmark_gemm_universal_fp8_rrr + +# Generate and parse +# ... (same flow, different bin_dir or executable glob) + +# Train a cross-layout model by putting all layouts in the same data_dir +python3 train.py --data_dir data/ --out_dir models/gemm_universal_fp8_gfx950_all_layouts +``` + +The feature engine includes `layout` as a categorical feature, so one model +can handle all layouts. + +## Incremental Data Collection + +When you have a trained model and want to add more data: + +1. Generate new data (new shapes, new layouts, etc.) +2. Parse into parquet alongside existing data +3. Warm-start from the previous model: + ```bash + python3 train.py --data_dir data/ --out_dir models/v2 \ + --warm_start models/v1 \ + --warm_start_n_estimators 200 + ``` + +This adds 200 new trees on top of the existing model. The feature schema +must match exactly (enforced automatically). + +## File Organization + +Recommended directory structure: + +``` +heuristics/ + data/ + gemm_universal_fp8_rcr_gfx950.parquet # original 108 shapes + wide_coverage/ # batch log files + wide_coverage_batch_001.log + wide_coverage_batch_002.log + ... + edge_dims/ # N=1, K=1 edge cases + edge_dims_batch_001.log + ... + models/ + gemm_universal_fp8_gfx950/ # trained model artifacts + model_tflops.lgbm + model_latency.lgbm + model_bandwidth.lgbm + feature_spec.json + train_manifest.json + cv_metrics_tflops.json + eval_report.json + ... +``` + +## Troubleshooting + +### Benchmark binary exits with non-zero code +Some kernel configs are invalid for certain problem sizes (e.g., tile_m=256 +with M=16). The data pipeline marks these as `is_valid=False` and they are +filtered out during training. This is expected. + +### Edge dims produce very few results +N=1 and K=1 shapes are degenerate -- most kernel configurations have minimum +dimension requirements and will fail or produce zero TFLOPS. The small number +of valid results is still useful (it tells the model which configs work for +these shapes). + +### Benchmarks are slow +Each shape requires running all 4608 kernel executables sequentially. At +~0.01s per kernel, that is ~46 seconds per shape. For 700 shapes, expect +~9 hours. Tips: +- Run on a dedicated GPU (no other workloads) +- Use `--batch_size 25` to get incremental output +- Parse and train on partial data while generation continues + +### Data from different GPUs / driver versions +Store `run_id` and hardware metadata with each dataset. Training on mixed +data is allowed but not recommended for production models. Filter to a +single `run_id` or `arch` for clean experiments. 
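
For example, a quick way to slice a mixed dataset down to a single clean experiment before training (column names follow the canonical schema; file paths are illustrative):

```python
import pandas as pd

# Combine any number of parquet files that follow the canonical schema.
frames = [
    pd.read_parquet("data/gemm_universal_fp8_rcr_gfx950.parquet"),
    pd.read_parquet("data/wide_coverage.parquet"),
]
df = pd.concat(frames, ignore_index=True)

# Keep only valid measurements from one arch (or filter on run_id instead).
clean = df[df["is_valid"] & (df["arch"] == "gfx950")]
print(len(clean), "rows,", clean[["m", "n", "k"]].drop_duplicates().shape[0], "unique shapes")
clean.to_parquet("data/train_gfx950_clean.parquet")
```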
diff --git a/dispatcher/heuristics/LEARNINGS.md b/dispatcher/heuristics/LEARNINGS.md new file mode 100644 index 0000000000..dba3514601 --- /dev/null +++ b/dispatcher/heuristics/LEARNINGS.md @@ -0,0 +1,151 @@ +# Learnings and Design Decisions + +Empirical findings from building the CK Tile kernel performance prediction system. +These inform the current defaults and explain why certain approaches were chosen. + +## 1. Log-Transform is Essential for Cross-Scale Accuracy + +**Problem**: GEMM TFLOPS spans 5 orders of magnitude across different problem +sizes. When training on raw TFLOPS, the regression loss (RMSE) is dominated by +large shapes where absolute errors are biggest. The model learns to predict +large shapes accurately but ignores tiny shapes where the TFLOPS values are +much lower. + +**Evidence** (168 shapes, 626K rows, 5-fold GroupKFold CV): + + +| Model | Mean Eff | P10 Eff | tiny_m Eff | Min Eff | +| ----------------------------- | ---------- | ---------- | ---------- | ---------- | +| Raw TFLOPS (500 trees) | 92.73% | 80.24% | 84.55% | 4.26% | +| **log1p(TFLOPS)** (500 trees) | **96.92%** | **94.34%** | **94.89%** | **60.27%** | +| log1p(TFLOPS) (2000 trees) | 97.51% | 93.89% | 96.04% | 63.56% | + + +**Solution**: Train on `log1p(measured_tflops)` and apply `expm1()` to +predictions. This is now the default in `train.py`. Pass `--no_log_transform` +to revert to raw regression (not recommended). + +**Why log1p, not log**: `log1p(x) = log(1 + x)` handles zero and near-zero +TFLOPS gracefully, whereas `log(x)` produces -inf for x=0. + +## 2. Tiny-M Shapes are the Hardest Case + +M=1 (single-token inference) shapes are fundamentally different from batch shapes: + +- Most kernel configurations produce very low TFLOPS +- The "best" kernel is often only marginally better than the rest +- The oracle performance itself is very low, so any prediction error tanks efficiency +- Many kernels fail outright (tile_m=128 with M=1 wastes 127/128 of the tile) + +The bottom shapes in our evaluation are all M=1, with efficiencies in the +63-70% range. These shapes have such low absolute performance that the model's +noise floor exceeds the performance difference between kernels. + +**Mitigation**: Log-transform helps significantly (tiny_m improved from 84% to +96%). For production use with M=1, consider a dedicated fallback (e.g., +hardcoded kernel selection for M < 4 based on known-good configs). + +## 3. IHEM (Hard Example Mining) Hurts When Scale is the Issue + +We tried Iterative Hard Example Mining with sample reweighting (2x-5x weight +on hard shapes). Result: it made things **worse**, degrading mean efficiency +from 94.31% to 92.90% over 3 iterations. + +**Why**: The hard shapes are hard because of scale mismatch, not because the +model lacks capacity. Reweighting amplifies the small-TFLOPS rows, which +distorts the learned relationship between features and performance for the +majority of shapes. The log-transform was the correct fix -- it addresses the +root cause (scale) rather than the symptom (bad predictions on tiny shapes). + +**Lesson**: IHEM is useful when the model has capacity gaps (e.g., certain +pipeline types are underrepresented). It is counterproductive when the issue +is target-variable scale. Always try target transforms before reweighting. + +## 4. GroupKFold Key = (M, N, K) Forces Generalization + +The validation uses `GroupKFold` where the group key is `(M, N, K)` -- all +kernels for the same shape go to the same fold. 
This means: + +- The model is always evaluated on shapes it has **never seen** during training +- Layout is excluded from the key, forcing the model to generalize across layouts +- Since models are per-arch, `arch` is implicit (constant within one training run) + +This is much stricter than random row splitting, where the model would see some +kernels for each shape during training. Our efficiency numbers are conservative +estimates of real-world performance on unseen shapes. + +## 5. Model Size vs Accuracy Tradeoff + + +| Config | Trees | Leaves | LR | Mean Eff | P10 Eff | Train Time | +| ------------------ | -------- | ------- | -------- | ---------- | ---------- | ------------- | +| Small (default v1) | 500 | 127 | 0.05 | 96.92% | 94.34% | ~20s | +| **Big (current)** | **2000** | **255** | **0.02** | **97.51%** | **93.89%** | **~25s/fold** | + + +The bigger model improved mean efficiency by 0.6% but P10 didn't improve +(actually slightly worse). The extra capacity helps on medium shapes but +doesn't crack the tiny-M floor. This suggests the feature set, not model +capacity, is the limiting factor for the hardest shapes. + +For C++ deployment, the bigger model (2000 trees, 255 leaves) is still fast +enough -- LightGBM inference is O(trees * log(leaves)) per sample, which is +~microseconds even at 2000 trees. + +## 6. N=1 and K=1 Shapes are Degenerate + +We generated benchmark data for 546 edge-case shapes (N=1, K=1, small N/K). +Result: **zero valid kernel results** across 94 shapes. All 4608 kernels either +fail or produce 0 TFLOPS for these degenerate dimensions. + +This means: + +- The tile engine kernels have hard minimum dimension requirements +- N=1 / K=1 shapes cannot be handled by the current kernel set +- These shapes need dedicated kernels (e.g., BLAS-1/BLAS-2 fallbacks) +- The ML model should not be expected to handle them -- they should be filtered +out before reaching the heuristic + +## 7. Feature Engineering Insights + +From LightGBM feature importances on the log-target model: + +**Top features** (by split count): + +- `M, N, K` -- raw dimensions are always the most important +- `tile_m, tile_n, tile_k` -- the tile shape is the primary kernel differentiator +- `overall_tile_efficiency` -- how well the shape fits the tile (the interaction) +- `num_tiles_m, total_output_tiles` -- work decomposition +- `arithmetic_intensity` -- compute vs memory bound regime +- `pipeline` -- pipeline type (compv3 vs compv4 vs mem) significantly affects perf + +**Low-importance features**: + +- Hardware constants (CUs, clock, caches) -- they're constant within one arch +model, so they provide no discriminative signal. They'll become important when +training cross-arch models. +- `split_k` -- always 1 in current data +- `persistent` -- rarely True in current kernel set + +## 8. Warm-Start Works for Incremental Updates + +LightGBM's `init_model` parameter successfully continues training from an +existing model. New trees are added on top of existing ones. Key considerations: + +- Feature schema must match exactly (enforced by `check_feature_compatibility`) +- Use fewer new trees (200-500) since we're refining, not starting fresh +- The `train_manifest.json` tracks the full lineage (total trees, data sizes) +- Quality should be at least as good as the base model (tested) + +## 9. 
Data Volume Matters More Than Model Complexity

| Dataset                  | Shapes | Rows | Mean Eff (log, 500 trees)     |
| ------------------------ | ------ | ---- | ----------------------------- |
| Original (DeepSeek only) | 108    | 418K | 98.28% (on seen distribution) |
| + Wide coverage          | 168    | 626K | 96.92% (harder distribution)  |

The original 108-shape model looked great (98.28%) but was overfitting to the
DeepSeek LLM inference M=1 distribution. Adding 60 diverse shapes (many M=1)
exposed the model's weakness on tiny shapes. More diverse training data is
always better than a bigger model on narrow data.

## Summary of Defaults

Based on these findings, the current defaults in `train.py` are:

- **Target transform**: `log1p` for TFLOPS and bandwidth (scale normalization)
- **Model**: 2000 trees, 255 leaves, max depth 15, LR 0.02
- **Validation**: 5-fold GroupKFold, key = (M, N, K)
- **Early stopping**: patience 100 (let trees fully converge)
- **Warm start**: 500 new trees (was 200, increased for bigger base model)

diff --git a/dispatcher/heuristics/README.md b/dispatcher/heuristics/README.md
new file mode 100644
index 0000000000..91b07466b6
--- /dev/null
+++ b/dispatcher/heuristics/README.md
@@ -0,0 +1,271 @@
# CK Tile Heuristics: ML-Based Kernel Selection

Fast, accurate kernel selection for CK Tile operations using LightGBM regression
with Origami-augmented feature engineering.

## What This Does

Instead of running all 4608+ kernel configurations on the GPU to find the best
one (exhaustive search taking ~46 seconds per shape), this system trains an ML
model that predicts TFLOPS for any (problem, kernel) pair in microseconds. It
scores all candidates instantly and picks the best kernel -- achieving 98.28%
of oracle-best TFLOPS efficiency across 108 tested shapes.

## Quick Start

### 1. Generate and convert benchmark data

**Step 1: Generate benchmark data**

```bash
python3 generate_benchmark_data.py \
    --build_dir /path/to/build \
    --output_dir data/fp16_original \
    --dtype fp16 \
    --layout rcr \
    --num_build_jobs 4 \
    --warmup 10 \
    --repeat 50
```

This outputs JSON with all benchmark results.

**Step 2: Convert JSON to parquet training format**

```bash
python3 convert_json_to_parquet.py \
    --input data/fp16_original/benchmark_results_fp16_rcr.json \
    --output data/fp16_original/fp16_training_data.parquet \
    --arch gfx950
```

The converter automatically fixes pad flags for `_mem` kernels and validates data.

**Alternative: Parse existing logs**

If you have raw benchmark logs from CK Tile:

```bash
python3 data_pipeline.py ck_tile_testrun_2.log \
    -o data/gemm_universal_fp8_rcr_gfx950.parquet \
    --arch gfx950 --capture_hw
```

### 2. Train a model

```bash
python3 train.py \
    --data_dir data/ \
    --out_dir models/gemm_universal_fp8_gfx950 \
    --op gemm_universal --dtype fp8 --arch gfx950
```

**Note**: Trained models are automatically compressed to `.lgbm.gz` format to save space (~67% reduction). The Python tools automatically decompress them on first use and cache the decompressed version. For warm-start training, decompression happens automatically.

### 3. Evaluate

```bash
python3 evaluate.py \
    --model_dir models/gemm_universal_fp8_gfx950 \
    --data_dir data/ --op gemm_universal --dtype fp8
```

### 4. Predict the best kernel for a problem

```bash
python3 predict.py \
    --model_dir models/gemm_universal_fp8_gfx950 \
    --m 128 --n 1536 --k 7168 --layout rcr
```

### 5. Search for optimal configs (optional)

```bash
python3 search.py \
    --model_dir models/gemm_universal_fp8_gfx950 \
    --m 128 --n 1536 --k 7168 \
    --strategy random --budget 500 --top_k 10
```

### 6. Using models in C++ (requires decompression)

C++ code uses the LightGBM C API which requires uncompressed `.lgbm` files. If you have compressed models (`.lgbm.gz`), decompress them first:

```bash
cd models/gemm_universal_fp16_gfx950
gunzip model_tflops.lgbm.gz
```

Then use in C++ examples:

```bash
cd dispatcher/build
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
```

**Note**: Python tools automatically decompress `.lgbm.gz` files on first use, so you can run Python scripts first to trigger decompression, then use the same models in C++.

## Architecture

```
Problem (M, N, K, dtype, layout)
        |
        v
FeatureEngine.extract_batch()   <-- 55 features: problem, kernel, interaction, hardware
        |
        v
LGBMRegressor.predict()         <-- predicts TFLOPS for each candidate kernel
        |
        v
Sort by predicted TFLOPS        <-- rank all candidates
        |
        v
Select Top-1 kernel             <-- 98.28% mean efficiency, <1ms inference
```

Three models are trained per (op, dtype, arch):
- **TFLOPS model** (primary): used for kernel ranking
- **Latency model** (auxiliary): for latency-sensitive workloads
- **Bandwidth model** (auxiliary): for memory-bound analysis

## File Inventory

| File | Purpose |
|---|---|
| `generate_benchmark_data.py` | Build and run benchmarks across ~25 diverse problem sizes, output JSON |
| `convert_json_to_parquet.py` | Convert benchmark JSON to parquet training format, fix `_mem` pad flags |
| `data_pipeline.py` | Parse raw benchmark logs into canonical parquet datasets |
| `feature_engine.py` | 55-feature extraction: problem, kernel, interaction, hardware profile |
| `train.py` | Multi-target LGBMRegressor training with GroupKFold CV, IHEM, warm-start |
| `predict.py` | Predictor class: predict TFLOPS/latency/bandwidth, rank kernels |
| `evaluate.py` | Full evaluation: global metrics, per-shape/layout/pipeline slices |
| `search.py` | Surrogate search: discrete DE, random top-K |
| `generate_wide_coverage.py` | Generate benchmark data across 706 diverse shapes |
| `generate_edge_dims.py` | Generate N=1, K=1, and other edge-case shapes |
| `DATA_GENERATION.md` | Detailed guide for building binaries and generating data |
| `plan.md` | Full design plan with architecture, milestones, and rationale |

## Features Used (55 total)

### Problem features (13)
`M, N, K, split_k, log2(M), log2(N), log2(K), log2(MNK),
arithmetic_intensity, aspect_ratio_mn, aspect_ratio_mk, aspect_ratio_nk, layout`

### Kernel features (21)
`tile_m, tile_n,
tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, +warp_tile_k, pipeline, scheduler, epilogue, pad_m, pad_n, pad_k, persistent, +num_warps, tile_volume, tile_mn, lds_usage_estimate, lds_usage_ratio` + +### Interaction features (9) +`num_tiles_m, num_tiles_n, num_tiles_k, total_output_tiles, +tile_eff_m, tile_eff_n, tile_eff_k, overall_tile_efficiency, cu_utilization` + +### Hardware profile features (12) +`hw_num_cus, hw_simds_per_cu, hw_total_simds, hw_shader_engines, +hw_max_clock_mhz, hw_max_waves_per_cu, hw_wavefront_size, hw_lds_capacity, +hw_l1_cache_kb, hw_l2_cache_kb, hw_l3_cache_kb, hw_num_xcd` + +## Model Performance + +### fp8 RCR, gfx950 + +| Metric | 108 shapes (original) | 168 shapes (wide coverage) | +|---|---|---| +| Mean TFLOPS Efficiency | 98.28% | 97.51% | +| P10 TFLOPS Efficiency | 94.64% | 93.89% | +| tiny_m (M=1) Efficiency | 95.57% | 96.04% | +| R2 (TFLOPS) | 0.997 | 0.993 | + +### fp16 RCR, gfx950 + +Trained on 25 shapes, 1,024 kernels, 21,920 valid benchmarks. + +| Metric | Value | +|---|---| +| Mean TFLOPS Efficiency | 99.36% | +| P10 TFLOPS Efficiency | 98.05% | +| P50 TFLOPS Efficiency | 100.00% | +| Min Efficiency | 95.45% | +| NDCG@1 | 64.00% | +| Top-5 Hit Rate | 88.00% | + +**Shape Family Breakdown:** + +| Shape Family | Mean Eff | P10 Eff | Shapes | +|---|---|---|---| +| Large M (M≥1024) | 99.54% | 99.07% | 4 | +| Medium M (128≤M<1024) | 99.62% | 98.74% | 7 | +| Small M (8≤M<128) | 98.82% | 96.22% | 8 | +| Tiny M (M<8) | 99.65% | 98.96% | 6 | + +**Pipeline Breakdown:** + +| Pipeline | Mean Eff | P10 Eff | +|---|---|---| +| compv3 | 99.75% | 99.09% | +| compv4 | 99.40% | 98.54% | +| mem | 99.08% | 96.59% | + +Training uses `log1p(TFLOPS)` as the target by default, which normalizes the +scale across shapes spanning 0.02 to 2230 TFLOPS. This was the key finding +that improved tiny-M shapes from 84% to 96% efficiency. See +[LEARNINGS.md](LEARNINGS.md) for details. + +## Validation + +Training uses `GroupKFold(n_splits=5)` with group key `(M, N, K)` to ensure +the model is evaluated on shapes it has never seen during training. Layout is +excluded from the group key to force cross-layout generalization. + +## Incremental Training (Warm Start) + +When new benchmark data arrives, update the model without retraining from scratch: + +```bash +python3 train.py \ + --data_dir data/ \ + --out_dir models/v2 \ + --warm_start models/gemm_universal_fp8_gfx950 \ + --warm_start_n_estimators 200 +``` + +This adds 200 new trees on top of the existing model. Feature schemas must +match exactly (automatically enforced). + +## Extending to New Ops + +Adding support for a new operation (e.g., `gemm_streamk`, `grouped_conv`): + +1. **Build binaries**: `ninja -C build benchmark_gemm_streamk_fp8_rcr` +2. **Subclass `FeatureEngine`**: add op-specific features (e.g., StreamK split factor) +3. **Generate data**: run benchmarks across diverse shapes +4. **Train**: `python3 train.py --op gemm_streamk --dtype fp8 --data_dir data/ --out_dir models/` + +The training, evaluation, prediction, and search infrastructure is fully +op-agnostic -- only the feature engine needs a new subclass. 
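+
+As a rough illustration of step 2, the sketch below subclasses
+`GemmUniversalFeatureEngine` from `feature_engine.py` and appends one extra
+feature. The class name `GemmStreamKFeatureEngine` and the
+`streamk_split_factor` field are hypothetical examples, not part of the
+current code:
+
+```python
+# Hypothetical sketch -- gemm_streamk support does not exist in this tree.
+import numpy as np
+import pandas as pd
+
+from feature_engine import FeatureEngine, GemmUniversalFeatureEngine
+
+
+class GemmStreamKFeatureEngine(GemmUniversalFeatureEngine):
+    """Adds a StreamK split-factor feature on top of the gemm_universal set."""
+
+    def get_feature_names(self) -> list[str]:
+        return super().get_feature_names() + ["streamk_split_factor"]
+
+    def extract(self, problem: dict, kernel: dict) -> np.ndarray:
+        base = super().extract(problem, kernel)
+        split = float(kernel.get("streamk_split_factor", 1))
+        return np.append(base, split)
+
+    def extract_batch(self, df: pd.DataFrame) -> np.ndarray:
+        # Fall back to the generic per-row loop from the base ABC so the
+        # extra column is included (the vectorized parent path only knows
+        # the original gemm_universal feature layout).
+        return FeatureEngine.extract_batch(self, df)
+```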
+ +## Tests + +102 tests covering all modules: + +```bash +python3 -m pytest tests/ -v +``` + +Test coverage includes: +- Log parsing with malformed JSON, empty logs, single-kernel shapes +- Feature formula correctness (tile efficiency, LDS usage, arithmetic intensity) +- Corner-case shapes: M=1, N=1, K=1, prime dimensions, 20480x7168x256 +- Batch vs single extraction parity +- Parameter space validation and projection +- Predictor: single/batch prediction, ranking, missing models, empty inputs +- Training: group keys, efficiency computation, warm-start, feature compatibility +- Search: random, DE, config validity, determinism + +## Documentation + +- **[README.md](README.md)**: This file -- quick start, architecture, performance +- **[DATA_GENERATION.md](DATA_GENERATION.md)**: Complete guide for building tile engine + binaries, running benchmarks, managing datasets, and troubleshooting +- **[LEARNINGS.md](LEARNINGS.md)**: Empirical findings and design decisions (log-transform, + IHEM results, tiny-M analysis, feature importance, N=1/K=1 edge cases) diff --git a/dispatcher/heuristics/__init__.py b/dispatcher/heuristics/__init__.py new file mode 100644 index 0000000000..e208c91163 --- /dev/null +++ b/dispatcher/heuristics/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Tile Heuristics: ML-based kernel selection diff --git a/dispatcher/heuristics/collect_additional.sh b/dispatcher/heuristics/collect_additional.sh new file mode 100755 index 0000000000..d963b1483a --- /dev/null +++ b/dispatcher/heuristics/collect_additional.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# Generate additional benchmark data for shapes NOT in the original log. +# Runs in background; outputs streaming JSON that can be parsed by data_pipeline.py. 
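+#
+# Example follow-up once the run completes (a sketch; the output parquet name
+# below is an arbitrary choice, not a path used elsewhere in this tree):
+#   nohup ./collect_additional.sh &
+#   python3 data_pipeline.py data/additional_shapes.log \
+#       -o data/additional_shapes.parquet --arch gfx950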
+ +BIN_DIR="/workspace/ck_tile/bin" +OUT_LOG="data/additional_shapes.log" +WARMUP=3 +REPEAT=10 + +mkdir -p data + +# Additional shapes: square powers-of-2 and common ML sizes not in original DeepSeek set +SHAPES=( + "64,64,64" + "128,128,128" + "256,256,256" + "512,512,512" + "1024,1024,1024" + "2048,2048,2048" + "4096,4096,4096" + "1,4096,4096" + "8,4096,4096" + "32,4096,4096" + "128,4096,4096" + "1,4096,11008" + "32,4096,11008" + "1,8192,8192" + "32,8192,8192" + "1,8192,28672" + "32,8192,28672" + "256,256,8192" + "8192,8192,256" + "1024,4096,1024" + "4096,1024,4096" + "2048,8192,2048" +) + +echo "CK Tile Additional Shapes Benchmark" > "$OUT_LOG" +echo "GPU ID: 0" >> "$OUT_LOG" +echo "Implementation: gemm_universal" >> "$OUT_LOG" +echo "" >> "$OUT_LOG" + +SHAPE_IDX=0 +for SHAPE in "${SHAPES[@]}"; do + IFS=',' read -r M N K <<< "$SHAPE" + SHAPE_IDX=$((SHAPE_IDX + 1)) + + echo "========================================" >> "$OUT_LOG" + echo "Shape $SHAPE_IDX: M=$M N=$N K=$K dtype=fp8 layout=rcr" >> "$OUT_LOG" + echo "========================================" >> "$OUT_LOG" + + KERNEL_COUNT=0 + for EXE in "$BIN_DIR"/benchmark_gemm_universal_fp8_rcr_*; do + KERNEL_COUNT=$((KERNEL_COUNT + 1)) + OUTPUT=$("$EXE" -m="$M" -n="$N" -k="$K" -warmup=$WARMUP -repeat=$REPEAT -verify=0 2>/dev/null) + # Extract just the JSON block + echo "$OUTPUT" | sed -n '/{/,/^}/p' >> "$OUT_LOG" + done + + echo "Found $KERNEL_COUNT kernels" >> "$OUT_LOG" + echo "Completed shape $SHAPE_IDX: M=$M N=$N K=$K ($KERNEL_COUNT kernels)" >&2 +done + +echo "Done generating additional data" >&2 diff --git a/dispatcher/heuristics/convert_json_to_parquet.py b/dispatcher/heuristics/convert_json_to_parquet.py new file mode 100644 index 0000000000..4cfd667c76 --- /dev/null +++ b/dispatcher/heuristics/convert_json_to_parquet.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Convert benchmark JSON results to parquet format for training. 
+ +Usage: + python convert_json_to_parquet.py \ + --input benchmark_results_fp16_rcr.json \ + --output fp16_training_data.parquet + +Features: + - Converts JSON benchmark results to flat row format + - Automatically fixes pad flags for _mem kernels + - Captures both successes and failures + - Compatible with existing training data format +""" + +import argparse +import json +import pandas as pd +from pathlib import Path + + +def convert_json_to_parquet(json_file: Path, output_file: Path, arch: str = "gfx950"): + """Convert benchmark JSON to parquet training data format.""" + + print(f"Loading {json_file}...") + with open(json_file) as f: + data = json.load(f) + + metadata = data.get("metadata", {}) + dtype = metadata.get("dtype", "fp16") + layout = metadata.get("layout", "rcr") + + print(f" Data type: {dtype}") + print(f" Layout: {layout}") + print(f" Kernels: {metadata.get('num_kernels', 0)}") + print(f" Problem sizes: {metadata.get('num_problems', 0)}") + print() + + rows = [] + for kernel_result in data["results"]: + kernel_config = kernel_result["kernel_config"] + + for benchmark in kernel_result["benchmarks"]: + # Common fields for both valid and invalid runs + row = { + "op_type": "gemm_universal", + "dtype": dtype, + "layout": layout, + "arch": arch, + "kernel_name": kernel_config["name"], + "m": benchmark["m"], + "n": benchmark["n"], + "k": benchmark["k"], + "split_k": 1, + "is_valid": benchmark["is_valid"], + "run_id": 0, + "pipeline": kernel_config["pipeline"], + "epilogue": kernel_config["epilogue"], + "scheduler": kernel_config["scheduler"], + "pad_m": kernel_config["pad_m"], + "pad_n": kernel_config["pad_n"], + "pad_k": kernel_config["pad_k"], + "persistent": kernel_config["persistent"], + "tile_m": kernel_config["tile_m"], + "tile_n": kernel_config["tile_n"], + "tile_k": kernel_config["tile_k"], + "warp_m": kernel_config["warp_m"], + "warp_n": kernel_config["warp_n"], + "warp_k": kernel_config["warp_k"], + "warp_tile_m": kernel_config["warp_tile_m"], + "warp_tile_n": kernel_config["warp_tile_n"], + "warp_tile_k": kernel_config["warp_tile_k"], + } + + if benchmark["is_valid"]: + # Valid run - include performance metrics + row["measured_tflops"] = benchmark["tflops"] + row["latency_ms"] = benchmark["avg_time_ms"] + # Calculate bandwidth if needed + m, n, k = benchmark["m"], benchmark["n"], benchmark["k"] + bytes_transferred = (m * k + k * n + m * n) * 2 # FP16 = 2 bytes + if benchmark["avg_time_ms"] > 0: + row["bandwidth_gb_s"] = (bytes_transferred / 1e9) / ( + benchmark["avg_time_ms"] / 1000 + ) + else: + row["bandwidth_gb_s"] = 0.0 + else: + # Failed run - zero metrics + row["measured_tflops"] = 0.0 + row["latency_ms"] = 0.0 + row["bandwidth_gb_s"] = 0.0 + + rows.append(row) + + df = pd.DataFrame(rows) + + print(f"Converted {len(df):,} benchmark results") + print(f" Valid: {df['is_valid'].sum():,}") + print(f" Failed: {(~df['is_valid']).sum():,}") + print() + + # Fix pad flags for _mem kernels (critical for P1 features!) 
+ print("Fixing pad flags for _mem kernels...") + mem_mask = df["pipeline"] == "mem" + mem_count = mem_mask.sum() + + if mem_count > 0: + df.loc[mem_mask, "pad_m"] = True + df.loc[mem_mask, "pad_n"] = True + df.loc[mem_mask, "pad_k"] = True + print(f" ✓ Fixed {mem_count:,} _mem kernel rows") + print() + + # Save to parquet + df.to_parquet(output_file, index=False) + print(f"✓ Saved to {output_file}") + print() + + # Show statistics + print("=" * 80) + print("STATISTICS") + print("=" * 80) + print() + + print("Dimension ranges:") + print(f" M: {df['m'].min():,} - {df['m'].max():,}") + print(f" N: {df['n'].min():,} - {df['n'].max():,}") + print(f" K: {df['k'].min():,} - {df['k'].max():,}") + print() + + print("Pipeline distribution:") + print(df["pipeline"].value_counts()) + print() + + print("Pad flag distribution:") + pad_combos = df[["pad_m", "pad_n", "pad_k"]].value_counts() + print(pad_combos) + print() + + if (~df["is_valid"]).sum() > 0: + print("Failure analysis:") + failed = df[~df["is_valid"]] + print(f" Total failures: {len(failed):,}") + + # Group by pipeline + print("\n By pipeline:") + for pipeline, count in failed["pipeline"].value_counts().items(): + print(f" {pipeline}: {count:,}") + + # Show sample failures + print("\n Sample failures:") + for _, row in failed.head(5).iterrows(): + print( + f" {row['kernel_name'][:60]:60s} M={row['m']:4d} N={row['n']:4d} K={row['k']:4d}" + ) + + return df + + +def merge_datasets(parquet_files: list[Path], output_file: Path): + """Merge multiple parquet files into one.""" + + print("=" * 80) + print("MERGING DATASETS") + print("=" * 80) + print() + + dfs = [] + for pq_file in parquet_files: + if pq_file.exists(): + df = pd.read_parquet(pq_file) + print(f" {pq_file.name}: {len(df):,} rows") + dfs.append(df) + else: + print(f" ✗ {pq_file} not found, skipping") + + if not dfs: + print("No files to merge!") + return + + combined = pd.concat(dfs, ignore_index=True) + combined.to_parquet(output_file, index=False) + + print() + print(f"✓ Merged {len(combined):,} total rows to {output_file}") + print() + + # Show dtype distribution + print("Data type distribution:") + print(combined["dtype"].value_counts()) + print() + + print("Layout distribution:") + print(combined["layout"].value_counts()) + + +def main(): + parser = argparse.ArgumentParser( + description="Convert benchmark JSON to parquet training data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--input", type=str, required=True, help="Input JSON file from benchmark" + ) + parser.add_argument("--output", type=str, required=True, help="Output parquet file") + parser.add_argument("--arch", type=str, default="gfx950", help="GPU architecture") + parser.add_argument( + "--merge_with", type=str, nargs="*", help="Additional parquet files to merge" + ) + + args = parser.parse_args() + + input_file = Path(args.input) + output_file = Path(args.output) + + # Convert JSON to parquet + df = convert_json_to_parquet(input_file, output_file, args.arch) + + # Merge if requested + if args.merge_with: + merge_files = [output_file] + [Path(f) for f in args.merge_with] + merged_output = output_file.parent / f"{output_file.stem}_merged.parquet" + merge_datasets(merge_files, merged_output) + + +if __name__ == "__main__": + main() diff --git a/dispatcher/heuristics/data_pipeline.py b/dispatcher/heuristics/data_pipeline.py new file mode 100644 index 0000000000..c3f5f9ced7 --- /dev/null +++ b/dispatcher/heuristics/data_pipeline.py @@ -0,0 +1,394 @@ 
+#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Data pipeline for CK Tile heuristics. + +Parses benchmark logs and structured JSON into a canonical parquet dataset. +Supports: + - Streaming log format (Shape N: headers + inline JSON) from ck_tile profiling runs + - Structured JSON from generate_benchmark_data.py + - Direct parquet passthrough +""" + +import json +import re +import subprocess +import hashlib +from pathlib import Path +from typing import Optional + +import pandas as pd + + +CANONICAL_COLUMNS = [ + "op_type", + "dtype", + "layout", + "arch", + "kernel_name", + "m", + "n", + "k", + "split_k", + "measured_tflops", + "latency_ms", + "bandwidth_gb_s", + "is_valid", + "tile_m", + "tile_n", + "tile_k", + "warp_m", + "warp_n", + "warp_k", + "warp_tile_m", + "warp_tile_n", + "warp_tile_k", + "pipeline", + "scheduler", + "epilogue", + "pad_m", + "pad_n", + "pad_k", + "persistent", + "run_id", +] + + +def parse_kernel_name(name: str) -> dict: + """Extract kernel config fields from a gemm_universal kernel name. + + Name format: + gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler} + _{padM}_{padN}_{padK}_{persistent}_{tileM}x{tileN}x{tileK} + _{warpM}x{warpN}x{warpK}_{warpTileM}x{warpTileN}x{warpTileK} + """ + result = {} + try: + prefix_match = re.match( + r"gemm_universal_(\w+?)_((?:rcr|rrr|crr|ccr))_(.*)", name + ) + if not prefix_match: + return result + result["dtype"] = prefix_match.group(1) + result["layout"] = prefix_match.group(2) + remainder = prefix_match.group(3) + + parts = remainder.split("_") + if len(parts) < 10: + return result + + result["pipeline"] = parts[0] + result["epilogue"] = parts[1] + result["scheduler"] = parts[2] + result["pad_m"] = parts[3] == "True" + result["pad_n"] = parts[4] == "True" + result["pad_k"] = parts[5] == "True" + result["persistent"] = parts[6] == "True" + + tile_dims = parts[7].split("x") + warp_dims = parts[8].split("x") + warp_tile_dims = parts[9].split("x") + + result["tile_m"] = int(tile_dims[0]) + result["tile_n"] = int(tile_dims[1]) + result["tile_k"] = int(tile_dims[2]) + result["warp_m"] = int(warp_dims[0]) + result["warp_n"] = int(warp_dims[1]) + result["warp_k"] = int(warp_dims[2]) + result["warp_tile_m"] = int(warp_tile_dims[0]) + result["warp_tile_n"] = int(warp_tile_dims[1]) + result["warp_tile_k"] = int(warp_tile_dims[2]) + except (IndexError, ValueError): + pass + return result + + +def _layout_from_problem(problem: dict) -> str: + """Derive layout shorthand (rcr/rrr/etc.) from problem JSON fields.""" + la = problem.get("layout_a", "") + lb = problem.get("layout_b", "") + lc = problem.get("layout_c", "") + + def _tag(s): + s = s.lower() + if "row" in s: + return "r" + if "col" in s: + return "c" + return "?" + + return _tag(la) + _tag(lb) + _tag(lc) + + +def parse_streaming_log( + path: str | Path, + arch: str = "unknown", + run_id: Optional[str] = None, + op_type: str = "gemm_universal", +) -> pd.DataFrame: + """Parse a CK Tile streaming benchmark log into a canonical DataFrame. + + The log alternates between shape headers and JSON result blocks: + Shape N: M=16 N=1536 K=7168 dtype=fp8 layout=rcr + { + "name": "gemm_universal_...", + "problem": { ... }, + "perf_result": { "latency(ms)": ..., "tflops(TFlops)": ..., "bandwidth(GB/s)": ... 
} + } + """ + path = Path(path) + if run_id is None: + run_id = hashlib.md5(path.name.encode()).hexdigest()[:12] + + shape_re = re.compile( + r"Shape\s+\d+:\s+M=(\d+)\s+N=(\d+)\s+K=(\d+)\s+dtype=(\w+)\s+layout=(\w+)" + ) + + rows = [] + current_m, current_n, current_k = 0, 0, 0 + current_dtype, current_layout = "", "" + json_buf = [] + brace_depth = 0 + + with open(path, "r") as f: + for line in f: + stripped = line.strip() + + shape_match = shape_re.search(stripped) + if shape_match: + current_m = int(shape_match.group(1)) + current_n = int(shape_match.group(2)) + current_k = int(shape_match.group(3)) + current_dtype = shape_match.group(4) + current_layout = shape_match.group(5) + continue + + if brace_depth == 0 and stripped.startswith("{"): + json_buf = [stripped] + brace_depth = stripped.count("{") - stripped.count("}") + if brace_depth == 0: + raw = "\n".join(json_buf) + try: + obj = json.loads(raw) + except json.JSONDecodeError: + continue + else: + continue + elif brace_depth > 0: + json_buf.append(stripped) + brace_depth += stripped.count("{") - stripped.count("}") + if brace_depth <= 0: + brace_depth = 0 + raw = "\n".join(json_buf) + try: + obj = json.loads(raw) + except json.JSONDecodeError: + continue + else: + continue + else: + continue + + # If we get here, obj was successfully parsed + kernel_name = obj.get("name", "") + problem = obj.get("problem", {}) + perf = obj.get("perf_result", {}) + + m = problem.get("m", current_m) + n = problem.get("n", current_n) + k = problem.get("k", current_k) + split_k = problem.get("split_k", 1) + dtype = problem.get("dtype_a", current_dtype) + layout = ( + _layout_from_problem(problem) + if problem.get("layout_a") + else current_layout + ) + + tflops = perf.get("tflops(TFlops)", 0.0) + latency = perf.get("latency(ms)", 0.0) + bandwidth = perf.get("bandwidth(GB/s)", 0.0) + + kp = parse_kernel_name(kernel_name) + + row = { + "op_type": op_type, + "dtype": dtype, + "layout": layout, + "arch": arch, + "kernel_name": kernel_name, + "m": m, + "n": n, + "k": k, + "split_k": split_k, + "measured_tflops": tflops, + "latency_ms": latency, + "bandwidth_gb_s": bandwidth, + "is_valid": tflops > 0 and latency > 0, + "run_id": run_id, + } + row.update(kp) + rows.append(row) + + df = pd.DataFrame(rows) + for col in CANONICAL_COLUMNS: + if col not in df.columns: + df[col] = None + return df + + +def get_hardware_profile() -> dict: + """Capture GPU hardware profile from rocminfo.""" + profile = {} + try: + result = subprocess.run( + ["rocminfo"], capture_output=True, text=True, timeout=30 + ) + output = result.stdout + + gpu_section = False + for line in output.split("\n"): + line = line.strip() + if "Device Type:" in line and "GPU" in line: + gpu_section = True + continue + if gpu_section and "Device Type:" in line and "GPU" not in line: + break + if not gpu_section: + continue + + if ":" not in line: + continue + key, _, val = line.partition(":") + key = key.strip() + val = val.strip() + + mapping = { + "Name": "gfx_name", + "Marketing Name": "marketing_name", + "Compute Unit": "num_cus", + "SIMDs per CU": "simds_per_cu", + "Shader Engines": "shader_engines", + "Shader Arrs. per Eng.": "shader_arrays_per_engine", + "Max Clock Freq. 
(MHz)": "max_clock_mhz", + "Wavefront Size": "wavefront_size", + "Max Waves Per CU": "max_waves_per_cu", + "Chip ID": "chip_id", + } + + if key in mapping: + raw = val.split("(")[0].strip() + try: + profile[mapping[key]] = int(raw) + except ValueError: + profile[mapping[key]] = raw + + for line in output.split("\n"): + line = line.strip() + if line.startswith("L1:") and "num_cus" in profile: + raw = line.split(":")[1].strip().split("(")[0].strip() + try: + profile["l1_cache_kb"] = int(raw) + except ValueError: + pass + elif line.startswith("L2:"): + raw = line.split(":")[1].strip().split("(")[0].strip() + try: + profile["l2_cache_kb"] = int(raw) + except ValueError: + pass + elif line.startswith("L3:"): + raw = line.split(":")[1].strip().split("(")[0].strip() + try: + profile["l3_cache_kb"] = int(raw) + except ValueError: + pass + + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + + return profile + + +def load_parquet(path: str | Path) -> pd.DataFrame: + """Load a canonical parquet dataset.""" + return pd.read_parquet(path) + + +def save_parquet(df: pd.DataFrame, path: str | Path): + """Save a DataFrame in canonical parquet format.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(path, index=False, engine="pyarrow") + + +def build_training_dataset( + data_dir: str | Path, + op_type: str = "gemm_universal", + dtype: str = "fp8", +) -> pd.DataFrame: + """Load and merge all parquet files matching the given op/dtype from a directory.""" + data_dir = Path(data_dir) + frames = [] + for f in sorted(data_dir.glob("*.parquet")): + df = pd.read_parquet(f) + if "op_type" in df.columns: + df = df[df["op_type"] == op_type] + if "dtype" in df.columns: + df = df[df["dtype"] == dtype] + if len(df) > 0: + frames.append(df) + if not frames: + raise FileNotFoundError( + f"No parquet files with op_type={op_type}, dtype={dtype} in {data_dir}" + ) + return pd.concat(frames, ignore_index=True) + + +if __name__ == "__main__": + import argparse + import time + + parser = argparse.ArgumentParser(description="Parse CK Tile benchmark data") + parser.add_argument("input", help="Input file (log or parquet)") + parser.add_argument("--output", "-o", required=True, help="Output parquet path") + parser.add_argument("--arch", default="gfx950", help="GPU architecture") + parser.add_argument("--op_type", default="gemm_universal", help="Operation type") + parser.add_argument( + "--capture_hw", + action="store_true", + help="Capture hardware profile from rocminfo", + ) + args = parser.parse_args() + + input_path = Path(args.input) + + print(f"Parsing {input_path}...") + t0 = time.time() + + if input_path.suffix == ".parquet": + df = load_parquet(input_path) + else: + df = parse_streaming_log(input_path, arch=args.arch, op_type=args.op_type) + + elapsed = time.time() - t0 + print(f"Parsed {len(df)} rows in {elapsed:.1f}s") + print(f" Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}") + print(f" Unique kernels: {df['kernel_name'].nunique()}") + print(f" Valid rows: {df['is_valid'].sum()} / {len(df)}") + + if df["measured_tflops"].max() > 0: + print( + f" TFLOPS range: {df['measured_tflops'].min():.2f} - {df['measured_tflops'].max():.2f}" + ) + + if args.capture_hw: + hw = get_hardware_profile() + print(f" Hardware profile: {hw}") + for k, v in hw.items(): + df[f"hw_{k}"] = v + + save_parquet(df, args.output) + print(f"Saved to {args.output}") diff --git a/dispatcher/heuristics/dispatcher_integration.py b/dispatcher/heuristics/dispatcher_integration.py new file mode 100644 
index 0000000000..c449c1e816 --- /dev/null +++ b/dispatcher/heuristics/dispatcher_integration.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Dispatcher integration for ML-based kernel selection. + +Bridges the trained LightGBM Predictor with the CK Tile dispatcher's +kernel selection flow. Provides heuristic functions compatible with +both the Python pre-selection pattern (08_heuristics.py style) and +the C++ HeuristicFunction signature. + +Name mapping between feature engine and dispatcher KernelConfig: + Feature engine Dispatcher KernelConfig + --------------------- ---------------------- + warp_m (warps/block) wave_m + warp_n wave_n + warp_k wave_k + warp_tile_m warp_m + warp_tile_n warp_n + warp_tile_k warp_k + +Usage: + from dispatcher_integration import create_ml_heuristic + + heuristic = create_ml_heuristic("models/gemm_universal_fp8_gfx950") + best_spec = heuristic(M=1024, N=1024, K=1024, kernel_pool=KERNEL_POOL) +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + + +from data_pipeline import parse_kernel_name +from predict import Predictor + + +LAYOUT_TO_DISPATCHER = { + "rcr": ("row", "col", "row"), + "rrr": ("row", "row", "row"), + "crr": ("col", "row", "row"), + "ccr": ("col", "col", "row"), +} + +DTYPE_TO_C_DTYPE = { + "fp8": "fp16", + "fp16": "fp16", + "bf16": "bf16", + "fp32": "fp32", +} + + +@dataclass +class MLKernelSpec: + """Kernel spec returned by the ML heuristic, compatible with the dispatcher + example pattern. Carries both the feature-engine-space config and the + dispatcher-space KernelConfig fields.""" + + kernel_name: str + predicted_tflops: float + + tile_m: int + tile_n: int + tile_k: int + wave_m: int + wave_n: int + wave_k: int + warp_m: int + warp_n: int + warp_k: int + pipeline: str + scheduler: str + epilogue: str + pad_m: bool + pad_n: bool + pad_k: bool + persistent: bool + + +def kernel_config_to_feature_dict(kernel_name: str) -> dict: + """Parse a tile-engine kernel name into a feature-engine-compatible dict. + + Returns a dict with fields matching what GemmUniversalFeatureEngine.extract() + expects for the kernel parameter: tile_m/n/k, warp_m/n/k (warps per block), + warp_tile_m/n/k, pipeline, scheduler, epilogue, pad_m/n/k, persistent. + """ + parsed = parse_kernel_name(kernel_name) + if not parsed: + return {} + parsed["kernel_name"] = kernel_name + return parsed + + +def feature_dict_to_dispatcher_config( + feat: dict, dtype: str = "fp8", arch: str = "gfx950" +) -> dict: + """Convert a feature-engine kernel dict to dispatcher KernelConfig fields. 
+ + Handles the naming inversion: + feature engine warp_m -> KernelConfig wave_m (warps per block) + feature engine warp_tile_m -> KernelConfig warp_m (elements per warp) + """ + layout = feat.get("layout", "rcr") + la, lb, lc = LAYOUT_TO_DISPATCHER.get(layout, ("row", "col", "row")) + c_dtype = DTYPE_TO_C_DTYPE.get(dtype, dtype) + + return { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": c_dtype, + "dtype_acc": "fp32", + "layout_a": la, + "layout_b": lb, + "layout_c": lc, + "tile_m": feat.get("tile_m", 128), + "tile_n": feat.get("tile_n", 128), + "tile_k": feat.get("tile_k", 64), + "wave_m": feat.get("warp_m", 2), + "wave_n": feat.get("warp_n", 2), + "wave_k": feat.get("warp_k", 1), + "warp_m": feat.get("warp_tile_m", 32), + "warp_n": feat.get("warp_tile_n", 32), + "warp_k": feat.get("warp_tile_k", 16), + "pipeline": feat.get("pipeline", "compv3"), + "scheduler": feat.get("scheduler", "intrawave"), + "epilogue": feat.get("epilogue", "cshuffle"), + "pad_m": feat.get("pad_m", True), + "pad_n": feat.get("pad_n", True), + "pad_k": feat.get("pad_k", True), + "gfx_arch": arch, + } + + +def feature_dict_to_ml_spec(feat: dict, predicted_tflops: float = 0.0) -> MLKernelSpec: + """Convert a feature-engine kernel dict + prediction to an MLKernelSpec.""" + return MLKernelSpec( + kernel_name=feat.get("kernel_name", "unknown"), + predicted_tflops=predicted_tflops, + tile_m=feat.get("tile_m", 128), + tile_n=feat.get("tile_n", 128), + tile_k=feat.get("tile_k", 64), + wave_m=feat.get("warp_m", 2), + wave_n=feat.get("warp_n", 2), + wave_k=feat.get("warp_k", 1), + warp_m=feat.get("warp_tile_m", 32), + warp_n=feat.get("warp_tile_n", 32), + warp_k=feat.get("warp_tile_k", 16), + pipeline=feat.get("pipeline", "compv3"), + scheduler=feat.get("scheduler", "intrawave"), + epilogue=feat.get("epilogue", "cshuffle"), + pad_m=feat.get("pad_m", False), + pad_n=feat.get("pad_n", False), + pad_k=feat.get("pad_k", False), + persistent=feat.get("persistent", False), + ) + + +def load_kernel_pool_from_binaries(bin_dir: str | Path) -> list[dict]: + """Discover benchmark executables and parse their names into feature dicts. + + Each executable name encodes the full kernel config. This creates the + candidate pool for the ML heuristic without needing a registry JSON export. + """ + bin_dir = Path(bin_dir) + configs = [] + for exe in sorted(bin_dir.glob("benchmark_gemm_universal_*")): + name = exe.stem.replace("benchmark_", "") + feat = kernel_config_to_feature_dict(name) + if feat and "tile_m" in feat: + configs.append(feat) + return configs + + +def create_ml_heuristic( + model_dir: str | Path, + dtype: str = "fp8", + arch: str = "gfx950", + layout: str = "rcr", + kernel_pool: Optional[list[dict]] = None, + bin_dir: Optional[str | Path] = None, +): + """Create an ML heuristic function for kernel selection. + + Returns a callable with signature: + (M: int, N: int, K: int) -> MLKernelSpec + + The returned function scores all candidate kernels using the trained + LightGBM regressor and returns the best one as an MLKernelSpec. + + Parameters + ---------- + model_dir : str or Path + Path to trained model directory (must contain model_tflops.lgbm or + model_tflops_log_big.lgbm and feature_spec.json). + dtype : str + Data type for the problem (fp8, fp16, bf16). + arch : str + GPU architecture (gfx942, gfx950). + layout : str + Matrix layout (rcr, rrr, crr, ccr). + kernel_pool : list of dict, optional + Pre-parsed kernel configs. If None, loads from bin_dir. 
+ bin_dir : str or Path, optional + Directory with benchmark executables. Used to build kernel_pool if + kernel_pool is not provided. Defaults to /workspace/ck_tile/bin. + """ + model_dir = Path(model_dir) + predictor = Predictor(model_dir) + + if kernel_pool is None: + if bin_dir is None: + bin_dir = Path("/workspace/ck_tile/bin") + kernel_pool = load_kernel_pool_from_binaries(bin_dir) + + if not kernel_pool: + raise ValueError( + "No kernel configs found. Check bin_dir or provide kernel_pool." + ) + + def heuristic(M: int, N: int, K: int) -> MLKernelSpec: + problem = { + "m": M, + "n": N, + "k": K, + "dtype": dtype, + "layout": layout, + "split_k": 1, + } + + ranked = predictor.rank_kernels(problem, kernel_pool) + + if not ranked: + feat = kernel_pool[0] + return feature_dict_to_ml_spec(feat, 0.0) + + best_name, best_tflops = ranked[0] + best_feat = next( + (kp for kp in kernel_pool if kp.get("kernel_name") == best_name), + kernel_pool[0], + ) + return feature_dict_to_ml_spec(best_feat, best_tflops) + + return heuristic + + +def create_ranked_heuristic( + model_dir: str | Path, + dtype: str = "fp8", + arch: str = "gfx950", + layout: str = "rcr", + kernel_pool: Optional[list[dict]] = None, + bin_dir: Optional[str | Path] = None, + top_k: int = 5, +): + """Create an ML heuristic that returns the top-K ranked kernel specs. + + Returns a callable with signature: + (M: int, N: int, K: int) -> list[MLKernelSpec] + + Useful when you want fallback options if the top-1 kernel fails to build. + """ + model_dir = Path(model_dir) + predictor = Predictor(model_dir) + + if kernel_pool is None: + if bin_dir is None: + bin_dir = Path("/workspace/ck_tile/bin") + kernel_pool = load_kernel_pool_from_binaries(bin_dir) + + name_to_feat = {kp.get("kernel_name", ""): kp for kp in kernel_pool} + + def heuristic(M: int, N: int, K: int) -> list[MLKernelSpec]: + problem = { + "m": M, + "n": N, + "k": K, + "dtype": dtype, + "layout": layout, + "split_k": 1, + } + + ranked = predictor.rank_kernels(problem, kernel_pool) + results = [] + for name, tflops in ranked[:top_k]: + feat = name_to_feat.get(name, kernel_pool[0]) + results.append(feature_dict_to_ml_spec(feat, tflops)) + return results + + return heuristic + + +def ml_spec_to_dispatcher_config( + spec: MLKernelSpec, dtype: str = "fp8", arch: str = "gfx950" +) -> dict: + """Convert an MLKernelSpec to a dict compatible with ctypes_utils.KernelConfig.""" + layout_a, layout_b, layout_c = "row", "col", "row" + c_dtype = DTYPE_TO_C_DTYPE.get(dtype, dtype) + + return { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": c_dtype, + "dtype_acc": "fp32", + "layout_a": layout_a, + "layout_b": layout_b, + "layout_c": layout_c, + "tile_m": spec.tile_m, + "tile_n": spec.tile_n, + "tile_k": spec.tile_k, + "wave_m": spec.wave_m, + "wave_n": spec.wave_n, + "wave_k": spec.wave_k, + "warp_m": spec.warp_m, + "warp_n": spec.warp_n, + "warp_k": spec.warp_k, + "pipeline": spec.pipeline, + "scheduler": spec.scheduler, + "epilogue": spec.epilogue, + "pad_m": spec.pad_m, + "pad_n": spec.pad_n, + "pad_k": spec.pad_k, + "gfx_arch": arch, + } diff --git a/dispatcher/heuristics/evaluate.py b/dispatcher/heuristics/evaluate.py new file mode 100644 index 0000000000..95c850aaf5 --- /dev/null +++ b/dispatcher/heuristics/evaluate.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Evaluation and reporting for CK Tile kernel performance models. 
+ +Computes: + - Global metrics: TFLOPS efficiency (mean, p10, p50, min), R2, NDCG@1, Top-K hit rate + - Per-slice breakdowns: by layout, shape family, K-depth regime, pipeline + - Cross-target consistency checks + - Feature importance analysis + +Usage: + python evaluate.py --model_dir models/gemm_universal_fp8_gfx950 --data_dir data/ +""" + +import argparse +import json + +import numpy as np +import pandas as pd + +from data_pipeline import build_training_dataset +from feature_engine import GemmUniversalFeatureEngine +from predict import Predictor +from train import compute_tflops_efficiency + + +def classify_shape_family(m: int, n: int, k: int) -> str: + """Classify a GEMM shape into a family for sliced evaluation. + + Families: + - tiny_m: M < 32 (single-token / very small batch inference) + - small_m: 32 <= M < 256 + - medium_m: 256 <= M < 4096 + - large_m: M >= 4096 + - square: 0.5 <= M/N <= 2.0 and 0.5 <= M/K <= 2.0 + - tall: M/N > 2.0 + - wide: M/N < 0.5 + """ + if m < 32: + return "tiny_m" + elif m < 256: + return "small_m" + elif m < 4096: + return "medium_m" + elif m >= 4096: + return "large_m" + return "other" + + +def classify_k_regime(k: int) -> str: + """Classify K dimension into depth regime.""" + if k < 512: + return "shallow_k" + elif k < 4096: + return "medium_k" + else: + return "deep_k" + + +def evaluate_model( + predictor: Predictor, + df: pd.DataFrame, + feature_engine: GemmUniversalFeatureEngine, +) -> dict: + """Run full evaluation on a dataset. Returns a metrics dictionary. + + Parameters + ---------- + predictor : Predictor + Trained predictor with at least a TFLOPS model loaded. + df : pd.DataFrame + Benchmark data in canonical schema. + feature_engine : GemmUniversalFeatureEngine + Feature engine matching the trained model. + + Returns + ------- + dict with keys: global_metrics, shape_family_metrics, k_regime_metrics, + pipeline_metrics, per_shape_efficiency. 
+ """ + valid = df[df["is_valid"].fillna(False) & (df["measured_tflops"] > 0)].copy() + valid = valid.reset_index(drop=True) + + X = feature_engine.extract_batch(valid) + model = predictor._load_model("tflops") + if model is None: + raise FileNotFoundError("No TFLOPS model found") + + # Predict and apply inverse log transform if model was trained in log-space + raw_pred = model.predict(X) + if "tflops" in predictor._log_targets: + valid["pred_tflops"] = np.expm1(raw_pred) + else: + # Clamp to non-negative even for non-log models + valid["pred_tflops"] = np.maximum(0.0, raw_pred) + + y_true = valid["measured_tflops"].values + y_pred = valid["pred_tflops"].values + + ss_res = np.sum((y_true - y_pred) ** 2) + ss_tot = np.sum((y_true - y_true.mean()) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + rmse = np.sqrt(np.mean((y_true - y_pred) ** 2)) + mae = np.mean(np.abs(y_true - y_pred)) + + eff_df = compute_tflops_efficiency(valid, "pred_tflops") + + ndcg1_count = 0 + total_shapes = 0 + topk_hits = {3: 0, 5: 0, 10: 0} + + for (m, n, k), group in valid.groupby(["m", "n", "k"]): + if group["measured_tflops"].max() <= 0: + continue + total_shapes += 1 + oracle_idx = group["measured_tflops"].idxmax() + pred_ranking = group.sort_values("pred_tflops", ascending=False).index.tolist() + + if pred_ranking[0] == oracle_idx: + ndcg1_count += 1 + + oracle_rank = pred_ranking.index(oracle_idx) + for topk in topk_hits: + if oracle_rank < topk: + topk_hits[topk] += 1 + + global_metrics = { + "r2": r2, + "rmse": rmse, + "mae": mae, + "num_valid_rows": len(valid), + "num_shapes": total_shapes, + "efficiency_mean": float(eff_df["efficiency"].mean()) if len(eff_df) > 0 else 0, + "efficiency_p10": float(eff_df["efficiency"].quantile(0.1)) + if len(eff_df) > 0 + else 0, + "efficiency_p50": float(eff_df["efficiency"].quantile(0.5)) + if len(eff_df) > 0 + else 0, + "efficiency_min": float(eff_df["efficiency"].min()) if len(eff_df) > 0 else 0, + "ndcg_at_1": ndcg1_count / max(total_shapes, 1), + "top3_hit_rate": topk_hits[3] / max(total_shapes, 1), + "top5_hit_rate": topk_hits[5] / max(total_shapes, 1), + "top10_hit_rate": topk_hits[10] / max(total_shapes, 1), + } + + def _slice_efficiency(slice_df): + if len(slice_df) == 0: + return {"count": 0} + eff = compute_tflops_efficiency(slice_df, "pred_tflops") + if len(eff) == 0: + return {"count": 0} + return { + "count": len(eff), + "mean": float(eff["efficiency"].mean()), + "p10": float(eff["efficiency"].quantile(0.1)), + "min": float(eff["efficiency"].min()), + } + + valid["shape_family"] = valid.apply( + lambda r: classify_shape_family(r["m"], r["n"], r["k"]), axis=1 + ) + valid["k_regime"] = valid["k"].apply(classify_k_regime) + + shape_family_metrics = {} + for family, group in valid.groupby("shape_family"): + shape_family_metrics[family] = _slice_efficiency(group) + + k_regime_metrics = {} + for regime, group in valid.groupby("k_regime"): + k_regime_metrics[regime] = _slice_efficiency(group) + + pipeline_metrics = {} + if "pipeline" in valid.columns: + for pipeline, group in valid.groupby("pipeline"): + pipeline_metrics[str(pipeline)] = _slice_efficiency(group) + + return { + "global_metrics": global_metrics, + "shape_family_metrics": shape_family_metrics, + "k_regime_metrics": k_regime_metrics, + "pipeline_metrics": pipeline_metrics, + "per_shape_efficiency": eff_df.to_dict(orient="records") + if len(eff_df) > 0 + else [], + } + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate CK Tile performance model") + parser.add_argument( + "--model_dir", 
required=True, help="Directory with trained models" + ) + parser.add_argument("--data_dir", required=True, help="Directory with parquet data") + parser.add_argument("--op", default="gemm_universal") + parser.add_argument("--dtype", default="fp8") + parser.add_argument("--output", "-o", help="Output JSON path for metrics") + args = parser.parse_args() + + print(f"Loading data from {args.data_dir}...") + df = build_training_dataset(args.data_dir, op_type=args.op, dtype=args.dtype) + print(f" {len(df)} rows, {df.groupby(['m', 'n', 'k']).ngroups} shapes") + + fe = GemmUniversalFeatureEngine() + predictor = Predictor(args.model_dir, feature_engine=fe) + + print("Evaluating...") + results = evaluate_model(predictor, df, fe) + + gm = results["global_metrics"] + print("\nGlobal Metrics:") + print(f" R2: {gm['r2']:.4f}") + print(f" RMSE: {gm['rmse']:.2f}") + print(f" Efficiency Mean: {gm['efficiency_mean']:.4f}") + print(f" Efficiency P10: {gm['efficiency_p10']:.4f}") + print(f" Efficiency P50: {gm['efficiency_p50']:.4f}") + print(f" Efficiency Min: {gm['efficiency_min']:.4f}") + print(f" NDCG@1: {gm['ndcg_at_1']:.4f}") + print(f" Top-3 Hit Rate: {gm['top3_hit_rate']:.4f}") + print(f" Top-5 Hit Rate: {gm['top5_hit_rate']:.4f}") + print(f" Top-10 Hit Rate: {gm['top10_hit_rate']:.4f}") + + print("\nShape Family Breakdown:") + for family, metrics in sorted(results["shape_family_metrics"].items()): + if metrics.get("count", 0) > 0: + print( + f" {family:12s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} min={metrics['min']:.4f} (n={metrics['count']})" + ) + + print("\nK-Depth Regime Breakdown:") + for regime, metrics in sorted(results["k_regime_metrics"].items()): + if metrics.get("count", 0) > 0: + print( + f" {regime:12s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} min={metrics['min']:.4f} (n={metrics['count']})" + ) + + print("\nPipeline Breakdown:") + for pipeline, metrics in sorted(results["pipeline_metrics"].items()): + if metrics.get("count", 0) > 0: + print( + f" {pipeline:15s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} (n={metrics['count']})" + ) + + if args.output: + with open(args.output, "w") as f: + json.dump(results, f, indent=2, default=str) + print(f"\nFull results saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/heuristics/feature_engine.py b/dispatcher/heuristics/feature_engine.py new file mode 100644 index 0000000000..557d9d8992 --- /dev/null +++ b/dispatcher/heuristics/feature_engine.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Feature engineering for CK Tile kernel performance prediction. + +Provides a strict FeatureEngine interface with per-op subclasses. +All feature engines produce a consistent numpy array for LightGBM. +""" + +import math +from abc import ABC, abstractmethod + +import numpy as np +import pandas as pd + + +DTYPE_BYTES = { + "fp32": 4.0, + "fp16": 2.0, + "bf16": 2.0, + "fp8": 1.0, + "bf8": 1.0, + "int8": 1.0, + "int4": 0.5, +} + +LAYOUT_MAP = {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3} +PIPELINE_MAP = {"compv3": 0, "compv4": 1, "compv5": 2, "mem": 3, "preshufflev2": 4} +SCHEDULER_MAP = {"intrawave": 0, "interwave": 1} +EPILOGUE_MAP = {"default": 0, "cshuffle": 1} + + +class FeatureEngine(ABC): + """Abstract base for per-op feature extraction.""" + + @abstractmethod + def get_feature_names(self) -> list[str]: + """Ordered list of feature names matching the output array columns.""" + ... 
+ + @abstractmethod + def get_categorical_features(self) -> list[str]: + """Feature names that should be treated as categorical by LightGBM.""" + ... + + @abstractmethod + def extract(self, problem: dict, kernel: dict) -> np.ndarray: + """Extract a single feature vector from a (problem, kernel) pair.""" + ... + + def extract_batch(self, df: pd.DataFrame) -> np.ndarray: + """Vectorized batch extraction from a DataFrame. Override for speed.""" + names = self.get_feature_names() + result = np.zeros((len(df), len(names)), dtype=np.float64) + for i in range(len(df)): + row = df.iloc[i] + prob = row.to_dict() + kern = row.to_dict() + result[i] = self.extract(prob, kern) + return result + + def get_parameter_space(self) -> dict[str, list]: + """Valid discrete values for each kernel parameter (for surrogate search).""" + return {} + + def get_constraints(self) -> list: + """Multi-param constraint functions returning True if config is valid.""" + return [] + + def validate_config(self, config: dict) -> bool: + """Check all constraints. Returns True if the config is valid.""" + ps = self.get_parameter_space() + for k, valid_vals in ps.items(): + if k in config and config[k] not in valid_vals: + return False + for constraint in self.get_constraints(): + if not constraint(config): + return False + return True + + def project_to_valid(self, config: dict) -> dict: + """Snap a config to the nearest valid discrete point.""" + ps = self.get_parameter_space() + result = dict(config) + for k, valid_vals in ps.items(): + if k not in result: + continue + v = result[k] + if isinstance(valid_vals[0], (int, float)): + result[k] = min(valid_vals, key=lambda x: abs(x - v)) + elif v not in valid_vals: + result[k] = valid_vals[0] + return result + + +class GemmUniversalFeatureEngine(FeatureEngine): + """Feature engine for gemm_universal kernels.""" + + def __init__( + self, + num_cus: int = 256, + lds_capacity: int = 65536, + max_clock_mhz: int = 2400, + simds_per_cu: int = 4, + shader_engines: int = 32, + max_waves_per_cu: int = 32, + wavefront_size: int = 64, + l1_cache_kb: int = 32, + l2_cache_kb: int = 4096, + l3_cache_kb: int = 262144, + num_xcd: int = 8, + ): + self._hw = { + "num_cus": num_cus, + "lds_capacity": lds_capacity, + "max_clock_mhz": max_clock_mhz, + "simds_per_cu": simds_per_cu, + "shader_engines": shader_engines, + "max_waves_per_cu": max_waves_per_cu, + "wavefront_size": wavefront_size, + "l1_cache_kb": l1_cache_kb, + "l2_cache_kb": l2_cache_kb, + "l3_cache_kb": l3_cache_kb, + "num_xcd": num_xcd, + "total_simds": num_cus * simds_per_cu, + } + + def get_feature_names(self) -> list[str]: + return [ + # Problem features + "M", + "N", + "K", + "split_k", + "log2_M", + "log2_N", + "log2_K", + "log2_MNK", + "arithmetic_intensity", + "aspect_ratio_mn", + "aspect_ratio_mk", + "aspect_ratio_nk", + "layout", + # Kernel features + "tile_m", + "tile_n", + "tile_k", + "warp_m", + "warp_n", + "warp_k", + "warp_tile_m", + "warp_tile_n", + "warp_tile_k", + "pipeline", + "scheduler", + "epilogue", + "pad_m", + "pad_n", + "pad_k", + "persistent", + "num_warps", + "tile_volume", + "tile_mn", + "lds_usage_estimate", + "lds_usage_ratio", + # Interaction features + "num_tiles_m", + "num_tiles_n", + "num_tiles_k", + "total_output_tiles", + "tile_eff_m", + "tile_eff_n", + "tile_eff_k", + "overall_tile_efficiency", + "cu_utilization", + # P0 FIX: Problem-to-tile ratio features + "ratio_M_to_tile_m", + "ratio_N_to_tile_n", + "ratio_K_to_tile_k", + "problem_smaller_than_tile_m", + "problem_smaller_than_tile_n", + 
"problem_smaller_than_tile_k", + "any_dim_too_small", + # P1 FIX: Padding requirement interaction features + "needs_padding_m", + "needs_padding_n", + "needs_padding_k", + "has_padding_when_needed_m", + "has_padding_when_needed_n", + "has_padding_when_needed_k", + "missing_required_padding_m", + "missing_required_padding_n", + "missing_required_padding_k", + "missing_any_required_padding", + # Hardware features + "hw_num_cus", + "hw_simds_per_cu", + "hw_total_simds", + "hw_shader_engines", + "hw_max_clock_mhz", + "hw_max_waves_per_cu", + "hw_wavefront_size", + "hw_lds_capacity", + "hw_l1_cache_kb", + "hw_l2_cache_kb", + "hw_l3_cache_kb", + "hw_num_xcd", + ] + + def get_categorical_features(self) -> list[str]: + return ["layout", "pipeline", "scheduler", "epilogue"] + + def extract(self, problem: dict, kernel: dict) -> np.ndarray: + M = int(problem.get("m", problem.get("M", 0))) + N = int(problem.get("n", problem.get("N", 0))) + K = int(problem.get("k", problem.get("K", 0))) + split_k = int(problem.get("split_k", 1)) + dtype = str(problem.get("dtype", "fp8")) + bpe = DTYPE_BYTES.get(dtype, 1.0) + + log2_M = math.log2(max(M, 1)) + log2_N = math.log2(max(N, 1)) + log2_K = math.log2(max(K, 1)) + log2_MNK = math.log2(max(M * N * K, 1)) + + mem_bytes = (M * K + K * N + M * N) * bpe + ai = (2.0 * M * N * K) / max(mem_bytes, 1) + + ar_mn = M / max(N, 1) + ar_mk = M / max(K, 1) + ar_nk = N / max(K, 1) + + layout_code = LAYOUT_MAP.get(str(problem.get("layout", "rcr")), 0) + + tile_m = int(kernel.get("tile_m", 128)) + tile_n = int(kernel.get("tile_n", 128)) + tile_k = int(kernel.get("tile_k", 64)) + warp_m = int(kernel.get("warp_m", 2)) + warp_n = int(kernel.get("warp_n", 2)) + warp_k = int(kernel.get("warp_k", 1)) + warp_tile_m = int(kernel.get("warp_tile_m", 32)) + warp_tile_n = int(kernel.get("warp_tile_n", 32)) + warp_tile_k = int(kernel.get("warp_tile_k", 16)) + + pipeline_code = PIPELINE_MAP.get(str(kernel.get("pipeline", "compv4")), 0) + scheduler_code = SCHEDULER_MAP.get(str(kernel.get("scheduler", "intrawave")), 0) + epilogue_code = EPILOGUE_MAP.get(str(kernel.get("epilogue", "cshuffle")), 0) + + pad_m = float(kernel.get("pad_m", False)) + pad_n = float(kernel.get("pad_n", False)) + pad_k = float(kernel.get("pad_k", False)) + persistent = float(kernel.get("persistent", False)) + + num_warps = warp_m * warp_n * warp_k + tile_volume = tile_m * tile_n * tile_k + tile_mn = tile_m * tile_n + + lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe + lds_cap = self._hw["lds_capacity"] + if str(kernel.get("pipeline", "")).startswith("compv4"): + lds_cap = 32768 + lds_ratio = lds_est / max(lds_cap, 1) + + num_tiles_m = math.ceil(M / max(tile_m, 1)) + num_tiles_n = math.ceil(N / max(tile_n, 1)) + num_tiles_k = math.ceil(K / max(tile_k, 1)) + total_output_tiles = num_tiles_m * num_tiles_n + + rem_m = M % tile_m if tile_m > 0 else 0 + tile_eff_m = rem_m / tile_m if rem_m > 0 else 1.0 + rem_n = N % tile_n if tile_n > 0 else 0 + tile_eff_n = rem_n / tile_n if rem_n > 0 else 1.0 + rem_k = K % tile_k if tile_k > 0 else 0 + tile_eff_k = rem_k / tile_k if rem_k > 0 else 1.0 + overall_eff = tile_eff_m * tile_eff_n * tile_eff_k + + cu_util = total_output_tiles / max(self._hw["num_cus"], 1) + + # P0 FIX: Problem-to-tile ratio features (avoid oversized tiles for tiny problems) + ratio_M_to_tile_m = M / max(tile_m, 1) + ratio_N_to_tile_n = N / max(tile_n, 1) + ratio_K_to_tile_k = K / max(tile_k, 1) + + # Binary features: is problem dimension smaller than tile? 
+ problem_smaller_than_tile_m = float(M < tile_m) + problem_smaller_than_tile_n = float(N < tile_n) + problem_smaller_than_tile_k = float(K < tile_k) + any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k)) + + # P1 FIX: Padding requirement features (does this kernel have padding when needed?) + needs_padding_m = float(M % tile_m != 0) if tile_m > 0 else 0.0 + needs_padding_n = float(N % tile_n != 0) if tile_n > 0 else 0.0 + needs_padding_k = float(K % tile_k != 0) if tile_k > 0 else 0.0 + + # Interaction features: kernel has padding capability when problem needs it + has_padding_when_needed_m = float(needs_padding_m and pad_m) + has_padding_when_needed_n = float(needs_padding_n and pad_n) + has_padding_when_needed_k = float(needs_padding_k and pad_k) + + # Critical feature: missing required padding (kernel will likely fail) + missing_required_padding_m = float(needs_padding_m and not pad_m) + missing_required_padding_n = float(needs_padding_n and not pad_n) + missing_required_padding_k = float(needs_padding_k and not pad_k) + missing_any_required_padding = float( + missing_required_padding_m + or missing_required_padding_n + or missing_required_padding_k + ) + + hw = self._hw + return np.array( + [ + M, + N, + K, + split_k, + log2_M, + log2_N, + log2_K, + log2_MNK, + ai, + ar_mn, + ar_mk, + ar_nk, + layout_code, + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline_code, + scheduler_code, + epilogue_code, + pad_m, + pad_n, + pad_k, + persistent, + num_warps, + tile_volume, + tile_mn, + lds_est, + lds_ratio, + num_tiles_m, + num_tiles_n, + num_tiles_k, + total_output_tiles, + tile_eff_m, + tile_eff_n, + tile_eff_k, + overall_eff, + cu_util, + # P0 FIX: New ratio and binary features + ratio_M_to_tile_m, + ratio_N_to_tile_n, + ratio_K_to_tile_k, + problem_smaller_than_tile_m, + problem_smaller_than_tile_n, + problem_smaller_than_tile_k, + any_dim_too_small, + # P1 FIX: Padding requirement interaction features + needs_padding_m, + needs_padding_n, + needs_padding_k, + has_padding_when_needed_m, + has_padding_when_needed_n, + has_padding_when_needed_k, + missing_required_padding_m, + missing_required_padding_n, + missing_required_padding_k, + missing_any_required_padding, + hw["num_cus"], + hw["simds_per_cu"], + hw["total_simds"], + hw["shader_engines"], + hw["max_clock_mhz"], + hw["max_waves_per_cu"], + hw["wavefront_size"], + hw["lds_capacity"], + hw["l1_cache_kb"], + hw["l2_cache_kb"], + hw["l3_cache_kb"], + hw["num_xcd"], + ], + dtype=np.float64, + ) + + def extract_batch(self, df: pd.DataFrame) -> np.ndarray: + """Vectorized batch extraction -- much faster than row-by-row.""" + n = len(df) + names = self.get_feature_names() + result = np.zeros((n, len(names)), dtype=np.float64) + + M = df["m"].values.astype(np.float64) + N = df["n"].values.astype(np.float64) + K = df["k"].values.astype(np.float64) + split_k = df["split_k"].fillna(1).values.astype(np.float64) + + dtype_col = df["dtype"].fillna("fp8") + bpe = dtype_col.map(DTYPE_BYTES).fillna(1.0).values + + result[:, 0] = M + result[:, 1] = N + result[:, 2] = K + result[:, 3] = split_k + result[:, 4] = np.log2(np.maximum(M, 1)) + result[:, 5] = np.log2(np.maximum(N, 1)) + result[:, 6] = np.log2(np.maximum(K, 1)) + result[:, 7] = np.log2(np.maximum(M * N * K, 1)) + + mem = (M * K + K * N + M * N) * bpe + result[:, 8] = (2.0 * M * N * K) / np.maximum(mem, 1) + result[:, 9] = M / np.maximum(N, 1) + result[:, 10] = M / np.maximum(K, 1) + result[:, 11] = N / 
np.maximum(K, 1) + + result[:, 12] = df["layout"].map(LAYOUT_MAP).fillna(0).values + + tile_m = df["tile_m"].fillna(128).values.astype(np.float64) + tile_n = df["tile_n"].fillna(128).values.astype(np.float64) + tile_k = df["tile_k"].fillna(64).values.astype(np.float64) + warp_m = df["warp_m"].fillna(2).values.astype(np.float64) + warp_n = df["warp_n"].fillna(2).values.astype(np.float64) + warp_k = df["warp_k"].fillna(1).values.astype(np.float64) + warp_tile_m = df["warp_tile_m"].fillna(32).values.astype(np.float64) + warp_tile_n = df["warp_tile_n"].fillna(32).values.astype(np.float64) + warp_tile_k = df["warp_tile_k"].fillna(16).values.astype(np.float64) + + result[:, 13] = tile_m + result[:, 14] = tile_n + result[:, 15] = tile_k + result[:, 16] = warp_m + result[:, 17] = warp_n + result[:, 18] = warp_k + result[:, 19] = warp_tile_m + result[:, 20] = warp_tile_n + result[:, 21] = warp_tile_k + + result[:, 22] = df["pipeline"].map(PIPELINE_MAP).fillna(0).values + result[:, 23] = df["scheduler"].map(SCHEDULER_MAP).fillna(0).values + result[:, 24] = df["epilogue"].map(EPILOGUE_MAP).fillna(0).values + + result[:, 25] = df["pad_m"].fillna(False).astype(float).values + result[:, 26] = df["pad_n"].fillna(False).astype(float).values + result[:, 27] = df["pad_k"].fillna(False).astype(float).values + result[:, 28] = df["persistent"].fillna(False).astype(float).values + + num_warps = warp_m * warp_n * warp_k + result[:, 29] = num_warps + result[:, 30] = tile_m * tile_n * tile_k + result[:, 31] = tile_m * tile_n + + lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe + result[:, 32] = lds_est + lds_cap = np.full(n, self._hw["lds_capacity"], dtype=np.float64) + is_compv4 = df["pipeline"].fillna("").str.startswith("compv4") + lds_cap[is_compv4] = 32768 + result[:, 33] = lds_est / np.maximum(lds_cap, 1) + + ntm = np.ceil(M / np.maximum(tile_m, 1)) + ntn = np.ceil(N / np.maximum(tile_n, 1)) + ntk = np.ceil(K / np.maximum(tile_k, 1)) + result[:, 34] = ntm + result[:, 35] = ntn + result[:, 36] = ntk + result[:, 37] = ntm * ntn + + rem_m = np.mod(M, np.maximum(tile_m, 1)) + result[:, 38] = np.where(rem_m > 0, rem_m / tile_m, 1.0) + rem_n = np.mod(N, np.maximum(tile_n, 1)) + result[:, 39] = np.where(rem_n > 0, rem_n / tile_n, 1.0) + rem_k = np.mod(K, np.maximum(tile_k, 1)) + result[:, 40] = np.where(rem_k > 0, rem_k / tile_k, 1.0) + result[:, 41] = result[:, 38] * result[:, 39] * result[:, 40] + + result[:, 42] = (ntm * ntn) / max(self._hw["num_cus"], 1) + + # P0 FIX: Problem-to-tile ratio features + result[:, 43] = M / np.maximum(tile_m, 1) # ratio_M_to_tile_m + result[:, 44] = N / np.maximum(tile_n, 1) # ratio_N_to_tile_n + result[:, 45] = K / np.maximum(tile_k, 1) # ratio_K_to_tile_k + + # Binary features: is problem smaller than tile? 
+ result[:, 46] = (M < tile_m).astype(float) # problem_smaller_than_tile_m + result[:, 47] = (N < tile_n).astype(float) # problem_smaller_than_tile_n + result[:, 48] = (K < tile_k).astype(float) # problem_smaller_than_tile_k + result[:, 49] = ((M < tile_m) | (N < tile_n) | (K < tile_k)).astype( + float + ) # any_dim_too_small + + # P1 FIX: Padding requirement features + pad_m_bool = df["pad_m"].fillna(False).astype(bool).values + pad_n_bool = df["pad_n"].fillna(False).astype(bool).values + pad_k_bool = df["pad_k"].fillna(False).astype(bool).values + + needs_padding_m = (np.mod(M, np.maximum(tile_m, 1)) != 0) + needs_padding_n = (np.mod(N, np.maximum(tile_n, 1)) != 0) + needs_padding_k = (np.mod(K, np.maximum(tile_k, 1)) != 0) + + result[:, 50] = needs_padding_m.astype(float) + result[:, 51] = needs_padding_n.astype(float) + result[:, 52] = needs_padding_k.astype(float) + + # Interaction features: kernel has padding when problem needs it + result[:, 53] = (needs_padding_m & pad_m_bool).astype(float) # has_padding_when_needed_m + result[:, 54] = (needs_padding_n & pad_n_bool).astype(float) # has_padding_when_needed_n + result[:, 55] = (needs_padding_k & pad_k_bool).astype(float) # has_padding_when_needed_k + + # Critical feature: missing required padding + result[:, 56] = (needs_padding_m & ~pad_m_bool).astype(float) # missing_required_padding_m + result[:, 57] = (needs_padding_n & ~pad_n_bool).astype(float) # missing_required_padding_n + result[:, 58] = (needs_padding_k & ~pad_k_bool).astype(float) # missing_required_padding_k + result[:, 59] = ((needs_padding_m & ~pad_m_bool) | (needs_padding_n & ~pad_n_bool) | (needs_padding_k & ~pad_k_bool)).astype(float) # missing_any_required_padding + + # Hardware profile features + hw = self._hw + result[:, 60] = hw["num_cus"] + result[:, 61] = hw["simds_per_cu"] + result[:, 62] = hw["total_simds"] + result[:, 63] = hw["shader_engines"] + result[:, 64] = hw["max_clock_mhz"] + result[:, 65] = hw["max_waves_per_cu"] + result[:, 66] = hw["wavefront_size"] + result[:, 67] = hw["lds_capacity"] + result[:, 68] = hw["l1_cache_kb"] + result[:, 69] = hw["l2_cache_kb"] + result[:, 70] = hw["l3_cache_kb"] + result[:, 71] = hw["num_xcd"] + + return result + + def get_parameter_space(self) -> dict[str, list]: + return { + "tile_m": [32, 64, 128, 192, 256], + "tile_n": [32, 64, 128, 192, 256], + "tile_k": [32, 64, 128, 256], + "warp_m": [1, 2, 4], + "warp_n": [1, 2, 4], + "warp_k": [1], + "warp_tile_m": [4, 16, 32, 64], + "warp_tile_n": [4, 16, 32, 64], + "warp_tile_k": [8, 16, 32, 64, 128], + "pipeline": list(PIPELINE_MAP.keys()), + "scheduler": list(SCHEDULER_MAP.keys()), + "epilogue": list(EPILOGUE_MAP.keys()), + "pad_m": [True, False], + "pad_n": [True, False], + "pad_k": [True, False], + "persistent": [True, False], + } + + def get_constraints(self) -> list: + lds_cap = self._hw["lds_capacity"] + + def _lds_constraint(cfg): + tm = cfg.get("tile_m", 128) + tn = cfg.get("tile_n", 128) + tk = cfg.get("tile_k", 64) + bpe = 1.0 # fp8 default + est = (tm * tk + tn * tk) * bpe + cap = ( + 32768 if str(cfg.get("pipeline", "")).startswith("compv4") else lds_cap + ) + return est <= cap + + def _warp_constraint(cfg): + wm = cfg.get("warp_m", 2) + wn = cfg.get("warp_n", 2) + wk = cfg.get("warp_k", 1) + return (wm * wn * wk) in [2, 4, 8] + + return [_lds_constraint, _warp_constraint] diff --git a/dispatcher/heuristics/generate_benchmark_data.py b/dispatcher/heuristics/generate_benchmark_data.py new file mode 100644 index 0000000000..17c76e5967 --- /dev/null +++ 
b/dispatcher/heuristics/generate_benchmark_data.py @@ -0,0 +1,553 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +GEMM Universal Benchmark Data Generation Script + +This script generates training data for ML-based kernel selection heuristics by: +1. Reading kernel configurations from the tile engine +2. Building benchmark executables (in parallel) +3. Running benchmarks across multiple problem sizes +4. Outputting performance data in JSON format + +Usage: + python generate_benchmark_data.py \ + --build_dir /tmp/build \ + --output_dir /tmp/benchmark_data \ + --dtype fp16 \ + --layout rcr \ + --num_build_jobs 4 \ + --num_benchmark_jobs 1 + +Requirements: + - ROCm-capable GPU + - CK tile engine built with CMake +""" + +import argparse +import json +import subprocess +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import List, Optional, Tuple +import re + + +@dataclass +class KernelConfig: + """Represents a single kernel configuration.""" + + name: str + dtype: str + layout: str + pipeline: str + epilogue: str + scheduler: str + pad_m: bool + pad_n: bool + pad_k: bool + persistent: bool + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + @classmethod + def from_kernel_name(cls, name: str, dtype: str, layout: str) -> "KernelConfig": + """Parse kernel name to extract configuration.""" + # Format: gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{padM}_{padN}_{padK}_{persistent}_{tile_config} + # tile_config: {tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k} + + parts = name.split("_") + prefix = f"gemm_universal_{dtype}_{layout}_" + trait_and_tile = name[len(prefix) :] + trait_parts = trait_and_tile.split("_") + + pipeline = trait_parts[0] + epilogue = trait_parts[1] + scheduler = trait_parts[2] + pad_m = trait_parts[3] == "True" + pad_n = trait_parts[4] == "True" + pad_k = trait_parts[5] == "True" + persistent = trait_parts[6] == "True" + + # Parse tile config + tile_dims = trait_parts[7].split("x") + warp_dims = trait_parts[8].split("x") + warp_tile_dims = trait_parts[9].split("x") + + return cls( + name=name, + dtype=dtype, + layout=layout, + pipeline=pipeline, + epilogue=epilogue, + scheduler=scheduler, + pad_m=pad_m, + pad_n=pad_n, + pad_k=pad_k, + persistent=persistent, + tile_m=int(tile_dims[0]), + tile_n=int(tile_dims[1]), + tile_k=int(tile_dims[2]), + warp_m=int(warp_dims[0]), + warp_n=int(warp_dims[1]), + warp_k=int(warp_dims[2]), + warp_tile_m=int(warp_tile_dims[0]), + warp_tile_n=int(warp_tile_dims[1]), + warp_tile_k=int(warp_tile_dims[2]), + ) + + +@dataclass +class BenchmarkResult: + """Result of a single benchmark run.""" + + kernel_name: str + m: int + n: int + k: int + avg_time_ms: float + tflops: float + is_valid: bool + error: Optional[str] = None + + +@dataclass +class ProblemSize: + """GEMM problem dimensions.""" + + m: int + n: int + k: int + + +def get_problem_sizes() -> List[ProblemSize]: + """ + Generate diverse problem sizes for benchmarking. 
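+
+ Representative entries (drawn from the lists below):
+ ProblemSize(1, 4096, 11008) -> single-token LLaMA MLP up-projection
+ ProblemSize(8192, 8192, 256) -> short-K, memory-bound case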
+ + Includes: + - Square matrices (powers of 2) + - Rectangular matrices (common in ML) + - LLM-specific sizes (attention, MLP) + - Edge cases (small, very large) + """ + sizes = [] + + # Powers of 2 (square) + for p in [6, 7, 8, 9, 10, 11, 12, 13]: # 64 to 8192 + dim = 2**p + sizes.append(ProblemSize(dim, dim, dim)) + + # Common ML sizes (batch x hidden) + ml_sizes = [ + (1, 4096, 4096), # Single token inference + (8, 4096, 4096), # Small batch + (32, 4096, 4096), # Medium batch + (128, 4096, 4096), # Large batch + (1, 4096, 11008), # LLaMA MLP up-projection + (1, 11008, 4096), # LLaMA MLP down-projection + (32, 4096, 11008), + (32, 11008, 4096), + (1, 8192, 8192), # Large model + (32, 8192, 8192), + (1, 8192, 28672), # LLaMA-70B MLP + (32, 8192, 28672), + ] + for m, n, k in ml_sizes: + sizes.append(ProblemSize(m, n, k)) + + # Rectangular matrices + rect_sizes = [ + (1024, 4096, 1024), + (4096, 1024, 4096), + (2048, 8192, 2048), + (256, 256, 8192), # Tall K + (8192, 8192, 256), # Short K + ] + for m, n, k in rect_sizes: + sizes.append(ProblemSize(m, n, k)) + + # Remove duplicates while preserving order + seen = set() + unique_sizes = [] + for s in sizes: + key = (s.m, s.n, s.k) + if key not in seen: + seen.add(key) + unique_sizes.append(s) + + return unique_sizes + + +def load_kernel_list(build_dir: Path, dtype: str, layout: str) -> List[KernelConfig]: + """Load kernel configurations from the tile engine build.""" + kernel_list_path = ( + build_dir + / "tile_engine" + / "ops" + / "gemm" + / "gemm_universal" + / dtype + / layout + / "gemm_universal_kernel_list.txt" + ) + + if not kernel_list_path.exists(): + raise FileNotFoundError(f"Kernel list not found: {kernel_list_path}") + + kernels = [] + with open(kernel_list_path, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + # Format: kernel_name|tile_config|trait_combo + parts = line.split("|") + kernel_name = parts[0] + kernels.append(KernelConfig.from_kernel_name(kernel_name, dtype, layout)) + + return kernels + + +def build_kernel(build_dir: Path, kernel: KernelConfig) -> Tuple[str, bool, str]: + """ + Build a single kernel benchmark executable. + + Returns: (kernel_name, success, error_message) + """ + target_name = f"benchmark_{kernel.name}" + + try: + result = subprocess.run( + ["ninja", "-j1", target_name], + cwd=build_dir, + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + ) + + if result.returncode != 0: + return (kernel.name, False, result.stderr[:500]) + + return (kernel.name, True, "") + except subprocess.TimeoutExpired: + return (kernel.name, False, "Build timeout") + except Exception as e: + return (kernel.name, False, str(e)) + + +def run_benchmark( + build_dir: Path, + kernel: KernelConfig, + problem: ProblemSize, + warmup: int = 10, + repeat: int = 50, +) -> BenchmarkResult: + """ + Run benchmark for a single kernel and problem size. 
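+
+ The benchmark executable (bin/benchmark_<kernel_name>) is expected to print a
+ JSON block of the form {"perf_result": {"latency(ms)": ..., "tflops(TFlops)": ...}}.
+ If no JSON is found, latency and TFLOPS are recovered from the text output via
+ regular expressions, and TFLOPS is derived from 2*M*N*K and the measured time
+ when not reported.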
+ """ + exe_path = build_dir / "bin" / f"benchmark_{kernel.name}" + + if not exe_path.exists(): + return BenchmarkResult( + kernel_name=kernel.name, + m=problem.m, + n=problem.n, + k=problem.k, + avg_time_ms=0, + tflops=0, + is_valid=False, + error="Executable not found", + ) + + try: + result = subprocess.run( + [ + str(exe_path), + f"-m={problem.m}", + f"-n={problem.n}", + f"-k={problem.k}", + f"-warmup={warmup}", + f"-repeat={repeat}", + "-verify=0", + "-json_output=true", + ], + capture_output=True, + text=True, + timeout=120, + ) + + if result.returncode != 0: + # Try to parse error + error = result.stderr[:200] if result.stderr else result.stdout[:200] + return BenchmarkResult( + kernel_name=kernel.name, + m=problem.m, + n=problem.n, + k=problem.k, + avg_time_ms=0, + tflops=0, + is_valid=False, + error=error, + ) + + # Parse JSON output + output = result.stdout.strip() + + # Try to find JSON in output + json_match = re.search(r"\{.*\}", output, re.DOTALL) + if json_match: + data = json.loads(json_match.group()) + # Extract from nested perf_result object + perf = data.get("perf_result", {}) + avg_time_ms = perf.get("latency(ms)", 0) + tflops = perf.get("tflops(TFlops)", 0) + + return BenchmarkResult( + kernel_name=kernel.name, + m=problem.m, + n=problem.n, + k=problem.k, + avg_time_ms=avg_time_ms, + tflops=tflops, + is_valid=True, + ) + else: + # Parse from text output + # Look for patterns like "avg_time: X ms" or "tflops: Y" + avg_time = 0.0 + tflops = 0.0 + + time_match = re.search( + r"(?:avg[_\s]?time|latency)[:\s]+(\d+\.?\d*)\s*(?:ms)?", output, re.I + ) + if time_match: + avg_time = float(time_match.group(1)) + + tflops_match = re.search(r"tflops[:\s]+(\d+\.?\d*)", output, re.I) + if tflops_match: + tflops = float(tflops_match.group(1)) + + # Calculate TFLOPs if not provided + if tflops == 0 and avg_time > 0: + flops = 2.0 * problem.m * problem.n * problem.k + tflops = flops / (avg_time * 1e-3) / 1e12 + + return BenchmarkResult( + kernel_name=kernel.name, + m=problem.m, + n=problem.n, + k=problem.k, + avg_time_ms=avg_time, + tflops=tflops, + is_valid=avg_time > 0, + error=None if avg_time > 0 else "Could not parse output", + ) + + except subprocess.TimeoutExpired: + return BenchmarkResult( + kernel_name=kernel.name, + m=problem.m, + n=problem.n, + k=problem.k, + avg_time_ms=0, + tflops=0, + is_valid=False, + error="Benchmark timeout", + ) + except Exception as e: + return BenchmarkResult( + kernel_name=kernel.name, + m=problem.m, + n=problem.n, + k=problem.k, + avg_time_ms=0, + tflops=0, + is_valid=False, + error=str(e), + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Generate GEMM benchmark data for ML training" + ) + parser.add_argument( + "--build_dir", type=str, default="/tmp/build", help="CK build directory" + ) + parser.add_argument( + "--output_dir", + type=str, + default="/tmp/benchmark_data", + help="Output directory for benchmark results", + ) + parser.add_argument( + "--dtype", + type=str, + default="fp16", + choices=["fp16", "fp8", "bf16", "bf8"], + help="Data type to benchmark", + ) + parser.add_argument( + "--layout", + type=str, + default="rcr", + choices=["rcr", "rrr", "crr", "ccr"], + help="Matrix layout to benchmark", + ) + parser.add_argument( + "--num_build_jobs", type=int, default=4, help="Number of parallel build jobs" + ) + parser.add_argument( + "--num_benchmark_jobs", + type=int, + default=1, + help="Number of parallel benchmark jobs (use 1 for accurate timing)", + ) + parser.add_argument( + "--max_kernels", + type=int, + 
default=None, + help="Maximum number of kernels to benchmark (for testing)", + ) + parser.add_argument( + "--skip_build", + action="store_true", + help="Skip building and only run benchmarks", + ) + parser.add_argument( + "--warmup", type=int, default=10, help="Number of warmup iterations" + ) + parser.add_argument( + "--repeat", type=int, default=50, help="Number of benchmark iterations" + ) + + args = parser.parse_args() + + build_dir = Path(args.build_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Load kernel configurations + print(f"Loading kernel list for {args.dtype}/{args.layout}...") + kernels = load_kernel_list(build_dir, args.dtype, args.layout) + print(f"Found {len(kernels)} kernel configurations") + + if args.max_kernels: + kernels = kernels[: args.max_kernels] + print(f"Limiting to {len(kernels)} kernels") + + # Build kernels + if not args.skip_build: + print( + f"\nBuilding {len(kernels)} kernels with {args.num_build_jobs} parallel jobs..." + ) + build_results = {"success": 0, "failed": 0, "failed_kernels": []} + + with ProcessPoolExecutor(max_workers=args.num_build_jobs) as executor: + futures = {executor.submit(build_kernel, build_dir, k): k for k in kernels} + + for i, future in enumerate(as_completed(futures)): + kernel_name, success, error = future.result() + if success: + build_results["success"] += 1 + else: + build_results["failed"] += 1 + build_results["failed_kernels"].append( + {"name": kernel_name, "error": error} + ) + + if (i + 1) % 10 == 0: + print( + f" Built {i + 1}/{len(kernels)} ({build_results['success']} success, {build_results['failed']} failed)" + ) + + print( + f"\nBuild complete: {build_results['success']} success, {build_results['failed']} failed" + ) + + # Save build results + with open(output_dir / "build_results.json", "w") as f: + json.dump(build_results, f, indent=2) + + # Get problem sizes + problem_sizes = get_problem_sizes() + print(f"\nBenchmarking {len(problem_sizes)} problem sizes...") + + # Run benchmarks + all_results = [] + total_benchmarks = len(kernels) * len(problem_sizes) + completed = 0 + + print(f"Total benchmarks to run: {total_benchmarks}") + + for kernel in kernels: + kernel_results = { + "kernel_config": asdict(kernel), + "benchmarks": [], + } + + for problem in problem_sizes: + result = run_benchmark( + build_dir, + kernel, + problem, + warmup=args.warmup, + repeat=args.repeat, + ) + kernel_results["benchmarks"].append(asdict(result)) + completed += 1 + + if completed % 100 == 0: + print(f" Progress: {completed}/{total_benchmarks} benchmarks complete") + + all_results.append(kernel_results) + + # Save intermediate results + intermediate_file = ( + output_dir / f"benchmark_results_{args.dtype}_{args.layout}_partial.json" + ) + with open(intermediate_file, "w") as f: + json.dump(all_results, f, indent=2) + + # Save final results + final_file = output_dir / f"benchmark_results_{args.dtype}_{args.layout}.json" + with open(final_file, "w") as f: + json.dump( + { + "metadata": { + "dtype": args.dtype, + "layout": args.layout, + "num_kernels": len(kernels), + "num_problems": len(problem_sizes), + "warmup": args.warmup, + "repeat": args.repeat, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + }, + "problem_sizes": [asdict(p) for p in problem_sizes], + "results": all_results, + }, + f, + indent=2, + ) + + print(f"\nResults saved to {final_file}") + + # Print summary + valid_count = sum( + 1 for kr in all_results for br in kr["benchmarks"] if br["is_valid"] + ) + print(f"Valid 
benchmarks: {valid_count}/{total_benchmarks}") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/heuristics/generate_edge_dims.py b/dispatcher/heuristics/generate_edge_dims.py new file mode 100644 index 0000000000..f5d243a5a9 --- /dev/null +++ b/dispatcher/heuristics/generate_edge_dims.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Supplementary edge-case benchmark generator for N=1 and K=1 dimensions. + +These shapes represent vector-matrix multiply (N=1), rank-1 updates (K=1), +and other degenerate GEMM cases that stress tile efficiency and padding logic. +""" + +import json +import subprocess +import sys +from pathlib import Path + + +def generate_edge_shapes(): + """Generate shapes with N=1, K=1, and other single-dimension edge cases.""" + shapes = set() + + # --- N=1: vector-matrix multiply / single output column --- + for m in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]: + for k in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 7168, 8192]: + shapes.add((m, 1, k)) + + # --- K=1: rank-1 update / outer product --- + for m in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]: + for n in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 7168, 8192]: + shapes.add((m, n, 1)) + + # --- M=1, N=1: dot product --- + for k in [1, 16, 64, 256, 1024, 4096, 8192]: + shapes.add((1, 1, k)) + + # --- M=1, K=1: scalar-vector --- + for n in [1, 16, 64, 256, 1024, 4096, 8192]: + shapes.add((1, n, 1)) + + # --- N=1, K=1: scalar-vector --- + for m in [1, 16, 64, 256, 1024, 4096, 8192]: + shapes.add((m, 1, 1)) + + # --- All ones: 1x1x1 --- + shapes.add((1, 1, 1)) + + # --- Small N (2-16) --- + for m in [64, 256, 1024, 4096]: + for n in [2, 3, 4, 7, 8, 15, 16]: + for k in [64, 256, 1024, 4096]: + shapes.add((m, n, k)) + + # --- Small K (2-16) --- + for m in [64, 256, 1024, 4096]: + for n in [64, 256, 1024, 4096]: + for k in [2, 3, 4, 7, 8, 15, 16]: + shapes.add((m, n, k)) + + return sorted(shapes) + + +def run_shapes(bin_dir, shapes, out_file, warmup=3, repeat=10): + """Run all kernels against shapes, writing streaming log.""" + executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*")) + if not executables: + print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr) + return 0 + + total = 0 + for idx, (m, n, k) in enumerate(shapes): + out_file.write("\n========================================\n") + out_file.write(f"Shape {idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n") + out_file.write("========================================\n") + out_file.write(f"Found {len(executables)} kernels\n") + out_file.flush() + + for exe in executables: + try: + result = subprocess.run( + [ + str(exe), + f"-m={m}", + f"-n={n}", + f"-k={k}", + f"-warmup={warmup}", + f"-repeat={repeat}", + "-verify=0", + ], + capture_output=True, + text=True, + timeout=60, + ) + output = result.stdout + json_start = output.find("{") + json_end = output.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: + json_block = output[json_start:json_end] + try: + json.loads(json_block) + out_file.write(json_block + "\n") + total += 1 + except json.JSONDecodeError: + pass + except (subprocess.TimeoutExpired, Exception): + pass + + out_file.flush() + print( + f" Shape {idx + 1}/{len(shapes)}: M={m} N={n} K={k}", + file=sys.stderr, + flush=True, + ) + + return total + + +if __name__ == "__main__": + bin_dir = "/workspace/ck_tile/bin" + out_dir = 
Path("data/edge_dims") + out_dir.mkdir(parents=True, exist_ok=True) + + shapes = generate_edge_shapes() + print(f"Generated {len(shapes)} edge-case shapes", file=sys.stderr, flush=True) + + n1_count = sum(1 for m, n, k in shapes if n == 1) + k1_count = sum(1 for m, n, k in shapes if k == 1) + both1 = sum(1 for m, n, k in shapes if n == 1 and k == 1) + small_n = sum(1 for m, n, k in shapes if 2 <= n <= 16) + small_k = sum(1 for m, n, k in shapes if 2 <= k <= 16) + print( + f" N=1: {n1_count}, K=1: {k1_count}, both=1: {both1}", + file=sys.stderr, + flush=True, + ) + print( + f" Small N(2-16): {small_n}, Small K(2-16): {small_k}", + file=sys.stderr, + flush=True, + ) + + batch_size = 25 + total = 0 + batch_idx = 0 + for i in range(0, len(shapes), batch_size): + batch = shapes[i : i + batch_size] + batch_idx += 1 + out_path = out_dir / f"edge_dims_batch_{batch_idx:03d}.log" + print( + f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}", + file=sys.stderr, + flush=True, + ) + + with open(out_path, "w") as f: + f.write(f"CK Tile Edge Dims Benchmark Batch {batch_idx}\n") + f.write("GPU ID: 0\nImplementation: gemm_universal\n\n") + count = run_shapes(bin_dir, batch, f, warmup=3, repeat=10) + total += count + + print(f" Batch {batch_idx} done: {count} results", file=sys.stderr, flush=True) + + print( + f"\nTotal: {total} benchmarks across {len(shapes)} shapes", + file=sys.stderr, + flush=True, + ) diff --git a/dispatcher/heuristics/generate_wide_coverage.py b/dispatcher/heuristics/generate_wide_coverage.py new file mode 100644 index 0000000000..e8e8116946 --- /dev/null +++ b/dispatcher/heuristics/generate_wide_coverage.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Wide-coverage benchmark data generator. + +Generates benchmark results for hundreds of diverse GEMM shapes across all +corner cases: skinny M, tall N, deep K, M=1, prime dimensions, power-of-2, +LLM inference shapes, training shapes, and edge cases. + +Runs all 4608 kernels in /workspace/ck_tile/bin/ against each shape and +writes streaming log output parseable by data_pipeline.py. + +Usage: + python3 generate_wide_coverage.py --bin_dir /workspace/ck_tile/bin \ + --out_dir data/ --batch_size 20 --warmup 3 --repeat 10 +""" + +import argparse +import json +import subprocess +import sys +from pathlib import Path + + +def generate_shape_list(): + """Generate a comprehensive list of (M, N, K) shapes covering all corner cases. + + Categories: + 1. M=1 (single token inference) -- the hardest case + 2. Tiny M (2-16) -- small batch inference + 3. Small M (32-128) -- medium batch + 4. Medium M (256-2048) -- large batch / training + 5. Large M (4096-20480) -- very large batch + 6. Square shapes (powers of 2) + 7. Skinny M, tall N (M << N) + 8. Tall M, skinny N (M >> N) + 9. Deep K (K >> M, N) -- compute-bound + 10. Shallow K (K << M, N) -- memory-bound + 11. Prime dimensions -- worst-case for tiling + 12. LLM-specific shapes (DeepSeek, LLaMA, etc.) + 13. Non-power-of-2 common sizes + """ + shapes = set() + + # --- 1. M=1 (single token) across various N, K --- + for n in [512, 1024, 1536, 2048, 3072, 4096, 4608, 7168, 8192, 11008, 14336, 28672]: + for k in [256, 512, 1024, 1536, 2048, 2304, 4096, 7168, 8192]: + shapes.add((1, n, k)) + + # --- 2. Tiny M (2-16) --- + for m in [2, 4, 8, 16]: + for n in [512, 1536, 4096, 7168]: + for k in [256, 1024, 4096, 7168]: + shapes.add((m, n, k)) + + # --- 3. 
Small M (32-128) --- + for m in [32, 48, 64, 96, 128]: + for n in [512, 1536, 4096, 7168, 8192]: + for k in [256, 512, 2048, 4096, 7168]: + shapes.add((m, n, k)) + + # --- 4. Medium M (256-2048) --- + for m in [256, 384, 512, 768, 1024, 1536, 2048]: + for n in [512, 1536, 4096, 7168]: + for k in [256, 1024, 2048, 4096, 7168]: + shapes.add((m, n, k)) + + # --- 5. Large M (4096-20480) --- + for m in [4096, 8192, 12288, 16384, 20480]: + for n in [512, 1536, 4096, 7168]: + for k in [256, 1024, 2048, 7168]: + shapes.add((m, n, k)) + + # --- 6. Square shapes (powers of 2) --- + for p in range(5, 14): # 32 to 8192 + d = 2**p + shapes.add((d, d, d)) + + # --- 7. Skinny M, tall N --- + for m in [1, 4, 16, 64]: + for n in [8192, 16384, 28672]: + for k in [1024, 4096, 8192]: + shapes.add((m, n, k)) + + # --- 8. Tall M, skinny N --- + for m in [4096, 8192, 16384]: + for n in [32, 64, 128, 256]: + for k in [1024, 4096]: + shapes.add((m, n, k)) + + # --- 9. Deep K (K >> M, N) --- + for m in [16, 64, 256]: + for n in [16, 64, 256]: + for k in [4096, 8192, 16384, 32768]: + shapes.add((m, n, k)) + + # --- 10. Shallow K (K << M, N) --- + for m in [1024, 4096, 8192]: + for n in [1024, 4096, 8192]: + for k in [16, 32, 64, 128]: + shapes.add((m, n, k)) + + # --- 11. Prime dimensions --- + primes = [17, 31, 37, 127, 251, 509, 1021, 2039, 4093] + for p in primes: + shapes.add((p, p, p)) + for p in primes[:5]: + shapes.add((p, 4096, 4096)) + shapes.add((4096, p, 4096)) + shapes.add((4096, 4096, p)) + + # --- 12. LLM-specific shapes --- + llm_shapes = [ + # DeepSeek MoE + (1, 1536, 7168), + (1, 4608, 7168), + (1, 7168, 2048), + (1, 7168, 2304), + (1, 7168, 256), + (1, 576, 7168), + (1, 512, 7168), + (1, 3072, 1536), + # LLaMA-7B + (1, 4096, 4096), + (32, 4096, 4096), + (128, 4096, 4096), + (1, 4096, 11008), + (32, 4096, 11008), + (1, 11008, 4096), + (32, 11008, 4096), + # LLaMA-70B + (1, 8192, 8192), + (32, 8192, 8192), + (128, 8192, 8192), + (1, 8192, 28672), + (32, 8192, 28672), + (1, 28672, 8192), + # GPT-style attention + (128, 128, 64), + (128, 128, 128), + (256, 256, 64), + (512, 512, 64), + (1024, 1024, 64), + (2048, 2048, 64), + ] + for s in llm_shapes: + shapes.add(s) + + # --- 13. 
Non-power-of-2 common sizes --- + for m in [48, 96, 192, 384, 576, 768, 1152, 1536, 2304, 3072, 4608, 6144]: + shapes.add((m, m, m)) + shapes.add((m, 4096, 4096)) + + sorted_shapes = sorted(shapes) + return sorted_shapes + + +def run_shape_batch(bin_dir, shapes, out_file, warmup=3, repeat=10): + """Run all kernels against a batch of shapes, writing streaming log output.""" + executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*")) + if not executables: + print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr) + return 0 + + total_benchmarks = 0 + + for shape_idx, (m, n, k) in enumerate(shapes): + out_file.write("\n========================================\n") + out_file.write( + f"Shape {shape_idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n" + ) + out_file.write("========================================\n") + out_file.write(f"Found {len(executables)} kernels\n") + out_file.flush() + + for exe in executables: + try: + result = subprocess.run( + [ + str(exe), + f"-m={m}", + f"-n={n}", + f"-k={k}", + f"-warmup={warmup}", + f"-repeat={repeat}", + "-verify=0", + ], + capture_output=True, + text=True, + timeout=60, + ) + output = result.stdout + # Extract JSON block from output + json_start = output.find("{") + json_end = output.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: + json_block = output[json_start:json_end] + try: + json.loads(json_block) + out_file.write(json_block + "\n") + total_benchmarks += 1 + except json.JSONDecodeError: + pass + except (subprocess.TimeoutExpired, Exception): + pass + + out_file.flush() + elapsed_kernels = len(executables) + print( + f" Shape {shape_idx + 1}/{len(shapes)}: M={m} N={n} K={k} " + f"({elapsed_kernels} kernels)", + file=sys.stderr, + flush=True, + ) + + return total_benchmarks + + +def main(): + parser = argparse.ArgumentParser( + description="Generate wide-coverage benchmark data" + ) + parser.add_argument( + "--bin_dir", + default="/workspace/ck_tile/bin", + help="Directory with benchmark executables", + ) + parser.add_argument("--out_dir", default="data", help="Output directory") + parser.add_argument( + "--batch_size", type=int, default=25, help="Shapes per output file" + ) + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--repeat", type=int, default=10) + parser.add_argument( + "--max_shapes", type=int, default=None, help="Limit total shapes (for testing)" + ) + args = parser.parse_args() + + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + shapes = generate_shape_list() + if args.max_shapes: + shapes = shapes[: args.max_shapes] + + print(f"Generated {len(shapes)} unique shapes", file=sys.stderr, flush=True) + print(f"Bin dir: {args.bin_dir}", file=sys.stderr, flush=True) + print(f"Output dir: {args.out_dir}", file=sys.stderr, flush=True) + print(f"Batch size: {args.batch_size}", file=sys.stderr, flush=True) + + total = 0 + batch_idx = 0 + for i in range(0, len(shapes), args.batch_size): + batch = shapes[i : i + args.batch_size] + batch_idx += 1 + out_path = out_dir / f"wide_coverage_batch_{batch_idx:03d}.log" + + print( + f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}", + file=sys.stderr, + flush=True, + ) + + with open(out_path, "w") as f: + f.write(f"CK Tile Wide Coverage Benchmark Batch {batch_idx}\n") + f.write("GPU ID: 0\n") + f.write("Implementation: gemm_universal\n\n") + count = run_shape_batch( + args.bin_dir, batch, f, warmup=args.warmup, repeat=args.repeat + ) + total += count + + print( + f" Batch 
{batch_idx} complete: {count} benchmarks", + file=sys.stderr, + flush=True, + ) + + print( + f"\nTotal: {total} benchmarks across {len(shapes)} shapes", + file=sys.stderr, + flush=True, + ) + + +if __name__ == "__main__": + main() diff --git a/dispatcher/heuristics/ml_heuristic_sweep.py b/dispatcher/heuristics/ml_heuristic_sweep.py new file mode 100644 index 0000000000..7190a19678 --- /dev/null +++ b/dispatcher/heuristics/ml_heuristic_sweep.py @@ -0,0 +1,867 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +ML Heuristic Sweep: Comprehensive GEMM Performance Evaluation + +Sweeps across diverse problem shapes with ML-based kernel selection to measure +TFLOPS performance. Supports multiple dtypes (fp16, bf16, fp8) and validates +ML model predictions by executing kernels on GPU. + +Shape Constraints (fp16/bf16 on gfx950): +- M >= 1 (any M is valid) +- N % 8 == 0 AND N >= 64 +- K % 2 == 0 AND K >= 32 + +Usage: + python ml_heuristic_sweep.py --dtypes fp16 --num_shapes 256 + python ml_heuristic_sweep.py --dtypes fp16 bf16 --output sweep_results.csv + python ml_heuristic_sweep.py --dtypes fp16 --dry_run # Prediction only, no GPU execution +""" + +import sys +import argparse +import time +import csv +from pathlib import Path +from dataclasses import dataclass +from typing import List, Tuple + +# Add parent directories to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent / "python")) + +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, +) + +try: + from predict import Predictor + # from feature_engine import GemmUniversalFeatureEngine + + HAS_ML = True +except ImportError: + HAS_ML = False + print("WARNING: ML heuristic modules not available.
Will use first-fit selection.") + + +@dataclass +class KernelSpec: + """Kernel specification for ML heuristic""" + + name: str + tile_m: int + tile_n: int + tile_k: int + pipeline: str = "compv3" + scheduler: str = "intrawave" + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + warp_m: int = 32 + warp_n: int = 32 + warp_k: int = 16 + + +# Comprehensive kernel pool covering diverse tile sizes and configurations +KERNEL_POOL = [ + # Small tiles (64x64) + KernelSpec( + "s_64x64_k32_v3", 64, 64, 32, "compv3", "intrawave", 2, 2, 1, 16, 16, 16 + ), + KernelSpec( + "s_64x64_k64_v3", 64, 64, 64, "compv3", "intrawave", 2, 2, 1, 16, 16, 16 + ), + KernelSpec( + "s_64x64_k128_v3", 64, 64, 128, "compv3", "intrawave", 2, 2, 1, 16, 16, 16 + ), + KernelSpec( + "s_64x64_k64_v4", 64, 64, 64, "compv4", "intrawave", 2, 2, 1, 16, 16, 16 + ), + KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", "intrawave", 2, 2, 1, 16, 16, 16), + KernelSpec( + "s_64x64_k128_mem", 64, 64, 128, "mem", "intrawave", 2, 2, 1, 16, 16, 16 + ), + # Medium tiles (128x128) + KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3", "intrawave"), + KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3", "intrawave"), + KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3", "intrawave"), + KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4", "intrawave"), + KernelSpec("m_128x128_k128_v4", 128, 128, 128, "compv4", "intrawave"), + KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem", "intrawave"), + KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem", "intrawave"), + # Rectangular medium (M != N) + KernelSpec( + "r_64x128_k32_v3", 64, 128, 32, "compv3", "intrawave", 2, 2, 1, 16, 32, 16 + ), + KernelSpec( + "r_128x64_k32_v3", 128, 64, 32, "compv3", "intrawave", 2, 2, 1, 32, 16, 16 + ), + KernelSpec( + "r_64x128_k64_v3", 64, 128, 64, "compv3", "intrawave", 2, 2, 1, 16, 32, 16 + ), + KernelSpec( + "r_128x64_k64_v3", 128, 64, 64, "compv3", "intrawave", 2, 2, 1, 32, 16, 16 + ), + KernelSpec( + "r_64x256_k32_v3", 64, 256, 32, "compv3", "intrawave", 2, 2, 1, 16, 32, 16 + ), + KernelSpec( + "r_256x64_k32_v3", 256, 64, 32, "compv3", "intrawave", 2, 2, 1, 32, 16, 16 + ), + # Large tiles (256x256) + KernelSpec("l_256x128_k32_v3", 256, 128, 32, "compv3", "intrawave"), + KernelSpec("l_128x256_k32_v3", 128, 256, 32, "compv3", "intrawave"), + KernelSpec("l_256x256_k32_v3", 256, 256, 32, "compv3", "intrawave"), + KernelSpec("l_256x256_k64_v3", 256, 256, 64, "compv3", "intrawave"), + KernelSpec("l_256x256_k64_v4", 256, 256, 64, "compv4", "intrawave"), + # Interwave variants + KernelSpec("m_128x128_k64_iw_v3", 128, 128, 64, "compv3", "interwave"), + KernelSpec("m_128x128_k128_iw_v3", 128, 128, 128, "compv3", "interwave"), + KernelSpec("l_256x256_k32_iw_v3", 256, 256, 32, "compv3", "interwave"), +] + + +def generate_problem_shapes(num_shapes: int = 1024) -> List[Tuple[int, int, int]]: + """ + Generate diverse problem shapes with hardware constraints: + - M >= 1 (any M is valid, including tiny M for inference) + - N % 8 == 0 AND N >= 64 (hardware alignment requirement) + - K % 2 == 0 AND K >= 32 (fp16 requirement) + + Covers: + - Powers of 2 (square and rectangular) + - ML workloads (LLM attention, MLP, batch inference) + - Non-power-of-2 dimensions (aligned to constraints) + - Edge cases (tiny M, very large matrices, extreme aspect ratios) + """ + shapes = [] + + # 1. 
Powers of 2 - Square (64 to 8192) with K variations + for p in range(6, 14): # 2^6=64 to 2^13=8192 + dim = 2**p + shapes.append((dim, dim, dim)) + if dim >= 128: + # K variations (must be even and >= 32) + shapes.append((dim, dim, dim // 2)) + shapes.append((dim, dim, dim * 2)) + shapes.append((dim, dim, max(32, dim // 4))) + + # 2. Small batch inference (1-256 batch, common hidden dims) + # N must be multiple of 8 and >= 64 + hidden_dims = [768, 1024, 2048, 3072, 4096, 5120, 8192, 11008, 12288, 16384] + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + + for hidden in hidden_dims: + for batch in batch_sizes[:8]: + shapes.append((batch, hidden, hidden)) + if hidden >= 4096: + # LLM MLP projections (ensure K is even) + k_mlp = hidden * 3 // 4 + if k_mlp % 2 == 1: + k_mlp += 1 # Make even + if k_mlp >= 32: + shapes.append((batch, hidden, k_mlp)) + shapes.append((batch, k_mlp, hidden)) + + # 3. Attention patterns (seq_len x head_dim) + # seq_len can be any value >= 1, total_dim must be multiple of 8 + seq_lens = [128, 256, 512, 1024, 2048, 4096, 8192] + head_dims = [64, 80, 96, 128, 256] + num_heads = [8, 12, 16, 32, 40, 64] + + for seq in seq_lens: + for head_dim in head_dims: + for nh in num_heads[:4]: + total_dim = nh * head_dim + # total_dim should be multiple of 8 (naturally satisfied for most cases) + if total_dim % 8 == 0 and total_dim >= 64: + # head_dim must be even for K + if head_dim % 2 == 0 and head_dim >= 32: + shapes.append((seq, total_dim, head_dim)) + shapes.append((seq, head_dim, total_dim)) + + # 4. Rectangular matrices (extreme aspect ratios) + # All dims must satisfy constraints + dims_m = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] + dims_n = [64, 128, 256, 512, 1024, 2048, 4096, 8192] # N >= 64, N % 8 == 0 + dims_k = [ + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + ] # K >= 32, K % 2 == 0 + + # Sample to avoid explosion + for i, m in enumerate(dims_m): + for j, n in enumerate(dims_n): + for _l, k in enumerate(dims_k): + if (i + j + _l) % 3 == 0: # Stratified sampling + shapes.append((m, n, k)) + + # 5. 
Non-power-of-2 dimensions (aligned to constraints) + # N values: multiples of 8, >= 64 + non_pow2_n = [ + 72, + 80, + 88, + 96, + 104, + 112, + 120, + 136, + 144, + 152, + 160, + 176, + 184, + 192, + 200, + 224, + 240, + 272, + 288, + 304, + 320, + 336, + 352, + 368, + 384, + 400, + 416, + 448, + 480, + 544, + 576, + 640, + 672, + 704, + 736, + 768, + 800, + 832, + 896, + 960, + 1088, + 1152, + 1216, + 1280, + 1344, + 1408, + 1472, + 1536, + 1600, + 1664, + 1728, + 1792, + 1856, + 1920, + 2176, + 2304, + 2432, + 2560, + 2688, + 2816, + 2944, + 3072, + 3200, + 3328, + 3456, + 3584, + 3712, + 3840, + 3968, + 4224, + 4352, + 4480, + 4608, + 4736, + 4864, + 4992, + ] + + # K values: even numbers >= 32 + non_pow2_k = [ + 34, + 36, + 38, + 40, + 42, + 44, + 48, + 50, + 52, + 56, + 60, + 66, + 68, + 72, + 76, + 80, + 88, + 96, + 100, + 112, + 120, + 136, + 144, + 160, + 176, + 192, + 224, + 240, + 272, + 288, + 320, + 352, + 384, + 416, + 448, + 480, + 544, + 576, + 640, + 672, + 704, + 768, + 800, + 832, + 896, + 960, + 1088, + 1152, + 1280, + 1344, + 1408, + 1536, + 1600, + 1664, + 1792, + 1920, + ] + + # M values: any value >= 1 + non_pow2_m = [ + 1, + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 23, + 27, + 31, + 33, + 37, + 41, + 47, + 51, + 57, + 63, + 65, + 71, + 79, + 87, + 95, + 97, + 111, + 119, + 127, + 129, + 143, + 159, + 175, + 191, + 193, + 223, + 239, + 255, + 257, + 287, + 319, + 351, + 383, + 385, + 447, + 479, + 511, + 513, + 575, + 639, + 703, + 767, + 769, + 895, + 959, + 1023, + 1025, + ] + + # Sample non-power-of-2 shapes + for i, m in enumerate(non_pow2_m[:30]): + for j, n in enumerate(non_pow2_n[:20]): + for _l, k in enumerate(non_pow2_k[:15]): + if (i + j + _l) % 4 == 0: # Stratified sampling + shapes.append((m, n, k)) + + # 6. Very tall K (memory-bound) - ensure N % 8 == 0, K % 2 == 0 + for mn in [64, 128, 256, 512, 1024]: + for k in [4096, 8192, 16384]: + shapes.append((mn, mn, k)) + + # 7. Very short K (compute-bound) - ensure K >= 32, K % 2 == 0 + for mn in [512, 1024, 2048, 4096]: + for k in [32, 64, 128]: + shapes.append((mn, mn, k)) + + # 8. Tiny M (edge cases for batch-1 inference) + for m in [1, 2, 4, 8, 16, 32]: + for n in [64, 128, 256, 512, 1024, 2048]: # N >= 64, N % 8 == 0 + for k in [32, 64, 128, 256, 512]: # K >= 32, K % 2 == 0 + shapes.append((m, n, k)) + + # 9. 
Stress test sizes (aligned to constraints) + stress_sizes = [ + (10000, 10000, 10000), + (1000, 10000, 1000), + (1000, 1000, 10000), + (5000, 5000, 5000), + (7168, 7168, 7168), # Common LLM hidden dim + (8192, 11008, 8192), # LLaMA MLP dimensions + ] + shapes.extend(stress_sizes) + + # Remove duplicates while preserving order + seen = set() + unique_shapes = [] + for s in shapes: + if s not in seen: + seen.add(s) + unique_shapes.append(s) + + # Filter to ensure all shapes meet constraints + valid_shapes = [] + for m, n, k in unique_shapes: + if m >= 1 and n >= 64 and n % 8 == 0 and k >= 32 and k % 2 == 0: + valid_shapes.append((m, n, k)) + + # Sample down to target number if we have too many + if len(valid_shapes) > num_shapes: + # Stratified sampling to preserve diversity + step = len(valid_shapes) / num_shapes + valid_shapes = [valid_shapes[int(i * step)] for i in range(num_shapes)] + + return valid_shapes + + +def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict: + """Convert KernelSpec to feature dict for ML predictor""" + return { + "kernel_name": spec.name, + "tile_m": spec.tile_m, + "tile_n": spec.tile_n, + "tile_k": spec.tile_k, + "warp_m": spec.wave_m, + "warp_n": spec.wave_n, + "warp_k": spec.wave_k, + "warp_tile_m": spec.warp_m, + "warp_tile_n": spec.warp_n, + "warp_tile_k": spec.warp_k, + "pipeline": spec.pipeline, + "scheduler": spec.scheduler, + "epilogue": "cshuffle", + "pad_m": True, # Enable padding to support arbitrary M dimensions + "pad_n": True, # Enable padding to support arbitrary N dimensions + "pad_k": True, # Enable padding to support arbitrary K dimensions + "persistent": False, + "dtype": dtype, + "layout": layout, + } + + +def spec_to_kernel_config( + spec: KernelSpec, dtype: str, arch: str, dtype_acc: str = "fp32" +) -> KernelConfig: + """Convert KernelSpec to KernelConfig for dispatcher""" + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc=dtype_acc, + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=spec.tile_m, + tile_n=spec.tile_n, + tile_k=spec.tile_k, + wave_m=spec.wave_m, + wave_n=spec.wave_n, + wave_k=spec.wave_k, + warp_m=spec.warp_m, + warp_n=spec.warp_n, + warp_k=spec.warp_k, + pipeline=spec.pipeline, + scheduler=spec.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +def ml_select_kernel( + predictor, pool: List[KernelSpec], M: int, N: int, K: int, dtype: str, layout: str +) -> Tuple[KernelSpec, float]: + """Use ML model to select best kernel""" + if not HAS_ML or predictor is None: + # Fallback: select first kernel + return pool[0], 0.0 + + problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1} + kernel_dicts = [spec_to_feature_dict(s, dtype, layout) for s in pool] + + ranked = predictor.rank_kernels(problem, kernel_dicts) + if not ranked: + return pool[0], 0.0 + + best_name, best_tflops = ranked[0] + best_spec = next((s for s in pool if s.name == best_name), pool[0]) + return best_spec, best_tflops + + +def run_single_gemm( + M: int, + N: int, + K: int, + dtype: str, + arch: str, + predictor, + dry_run: bool = False, + dtype_acc: str = "fp32", +) -> dict: + """Run a single GEMM with ML heuristic selection""" + + # Select kernel via ML heuristic + t0 = time.time() + best_spec, pred_tflops = ml_select_kernel( + predictor, KERNEL_POOL, M, N, K, dtype, "rcr" + ) + select_time_ms = (time.time() - t0) * 1000 + + result = { + "M": M, + "N": N, + "K": K, + "dtype": dtype, + "selected_kernel": best_spec.name, + "predicted_tflops": pred_tflops, + 
"selection_time_ms": select_time_ms, + "actual_time_ms": 0, + "actual_tflops": 0, + "status": "SKIP" if dry_run else "PENDING", + "error": None, + } + + if dry_run: + return result + + # Build and run kernel + config = spec_to_kernel_config(best_spec, dtype, arch, dtype_acc) + + try: + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"sweep_{dtype}_{best_spec.name}", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + result["status"] = "BUILD_FAIL" + result["error"] = "Failed to build kernel" + cleanup_gemm() + return result + + dispatcher = setup.dispatcher + if not dispatcher.is_supported(M, N, K): + result["status"] = "UNSUPPORTED" + result["error"] = "Problem size not supported by kernel" + cleanup_gemm() + return result + + # Create input data + np_dtype = {"fp16": np.float16, "bf16": np.float16, "fp8": np.float16}[dtype] + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Run GEMM + exec_result = dispatcher.run(A, B, M, N, K) + + if exec_result.success: + result["actual_time_ms"] = exec_result.time_ms + result["actual_tflops"] = exec_result.tflops + result["status"] = "SUCCESS" + else: + # Decode status code for better error message + status_messages = { + 0: "Success", + -1: "GPU/HIP error (check permissions, memory, or kernel validity)", + -2: "No suitable kernel found for this problem size", + } + error_msg = status_messages.get(exec_result.status, f"Unknown error (status={exec_result.status})") + result["status"] = "RUN_FAIL" + result["error"] = f"{error_msg} (status_code={exec_result.status})" + + # Print detailed error for debugging + print(f" ERROR: {error_msg}") + print(f" Status code: {exec_result.status}") + print(f" Time returned: {exec_result.time_ms}") + print(f" Kernel: {exec_result.kernel_name}") + + cleanup_gemm() + + except Exception as e: + result["status"] = "ERROR" + result["error"] = str(e)[:200] + cleanup_gemm() + + return result + + +def main(): + parser = argparse.ArgumentParser( + description="ML Heuristic Sweep: Test GEMM across many shapes and dtypes" + ) + parser.add_argument( + "--dtypes", + nargs="+", + default=["fp16"], + choices=["fp16", "bf16", "fp8"], + help="Data types to test (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx950", help="GPU architecture (default: gfx950)" + ) + parser.add_argument( + "--dtype_acc", + default="fp32", + choices=["fp16", "fp32"], + help="Accumulator data type (default: fp32)", + ) + parser.add_argument( + "--model_dir", + default=None, + help="Path to model directory (auto-detect if not specified)", + ) + parser.add_argument( + "--num_shapes", + type=int, + default=256, + help="Number of problem shapes to test (default: 256)", + ) + parser.add_argument( + "--output", + default="ml_heuristic_sweep_results.csv", + help="Output CSV file path", + ) + parser.add_argument( + "--dry_run", + action="store_true", + help="Only predict, do not run kernels (fast validation)", + ) + + args = parser.parse_args() + + # Setup ML predictor + predictor = None + if HAS_ML: + if args.model_dir is None: + # Auto-detect model directory based on first dtype + first_dtype = args.dtypes[0] + heuristics_dir = Path(__file__).parent + model_candidates = [ + heuristics_dir / "models" / f"gemm_universal_{first_dtype}_{args.arch}", + ] + for model_dir in model_candidates: + if model_dir.exists(): + args.model_dir = str(model_dir) + break + + if args.model_dir and Path(args.model_dir).exists(): + try: + predictor = 
Predictor(args.model_dir) + print(f"✓ Loaded ML model from: {args.model_dir}") + except Exception as e: + print(f"⚠ Failed to load ML model: {e}") + print(" Will use first-fit selection instead") + else: + print(f"⚠ Model directory not found: {args.model_dir}") + print(" Will use first-fit selection instead") + + # Generate problem shapes + print(f"\nGenerating {args.num_shapes} problem shapes...") + shapes = generate_problem_shapes(args.num_shapes) + print( + f"✓ Generated {len(shapes)} valid shapes (M>=1, N%8==0, N>=64, K%2==0, K>=32)" + ) + + # Validate all shapes meet constraints + invalid = [ + (m, n, k) + for m, n, k in shapes + if not (m >= 1 and n >= 64 and n % 8 == 0 and k >= 32 and k % 2 == 0) + ] + if invalid: + print(f"⚠ WARNING: {len(invalid)} shapes violate constraints!") + print(f" First few: {invalid[:5]}") + + # Print configuration + print("\n" + "=" * 80) + print(" ML Heuristic Sweep Configuration") + print("=" * 80) + print( + f" Model: {args.model_dir if args.model_dir else 'first-fit (no ML)'}" + ) + print(f" Data types: {', '.join(args.dtypes)}") + print(f" Accumulator: {args.dtype_acc}") + print(f" Architecture: {args.arch}") + print(f" Kernel pool: {len(KERNEL_POOL)} kernels") + print(f" Problem shapes: {len(shapes)}") + print(f" Total tests: {len(shapes) * len(args.dtypes)}") + print( + f" Mode: {'DRY RUN (prediction only)' if args.dry_run else 'FULL RUN (execute kernels)'}" + ) + print(f" Output: {args.output}") + print("=" * 80) + + # Open output CSV + csv_file = open(args.output, "w", newline="") + csv_writer = csv.DictWriter( + csv_file, + fieldnames=[ + "dtype", + "M", + "N", + "K", + "selected_kernel", + "predicted_tflops", + "selection_time_ms", + "actual_time_ms", + "actual_tflops", + "status", + "error", + ], + ) + csv_writer.writeheader() + + # Run sweep + total_tests = len(shapes) * len(args.dtypes) + completed = 0 + start_time = time.time() + + print("\nStarting sweep... (Ctrl+C to stop and save partial results)\n") + + try: + for dtype in args.dtypes: + print(f"\n{'=' * 80}") + print(f" Testing dtype: {dtype.upper()}") + print(f"{'=' * 80}\n") + + for i, (M, N, K) in enumerate(shapes): + result = run_single_gemm( + M, N, K, dtype, args.arch, predictor, args.dry_run, args.dtype_acc + ) + + # Write to CSV + csv_writer.writerow(result) + csv_file.flush() + + completed += 1 + + # Progress update + if completed % 10 == 0 or result["status"] != "SUCCESS": + elapsed = time.time() - start_time + rate = completed / elapsed if elapsed > 0 else 0 + eta = (total_tests - completed) / rate if rate > 0 else 0 + + status_emoji = { + "SUCCESS": "✓", + "SKIP": "→", + "BUILD_FAIL": "✗", + "UNSUPPORTED": "○", + "RUN_FAIL": "✗", + "ERROR": "✗", + }.get(result["status"], "?") + + print( + f" [{completed:4d}/{total_tests}] {status_emoji} " + f"{dtype:4s} {M:5d}x{N:5d}x{K:5d} → " + f"{result['selected_kernel']:20s} " + f"pred={result['predicted_tflops']:6.1f} " + f"actual={result['actual_tflops']:6.1f} TFLOPS " + f"[{rate:.1f} tests/s, ETA {eta / 60:.1f}m]" + ) + + except KeyboardInterrupt: + print(f"\n\n⚠ Interrupted! 
Saving partial results to {args.output}...") + + finally: + csv_file.close() + + # Summary + print("\n" + "=" * 80) + print(" SWEEP COMPLETE") + print("=" * 80) + + # Read back results and compute statistics + results = [] + with open(args.output, "r") as f: + reader = csv.DictReader(f) + results = list(reader) + + print(f"\n Total tests: {len(results)}") + print(f" Output file: {args.output}") + + if not args.dry_run: + success = [r for r in results if r["status"] == "SUCCESS"] + print( + f" Successful: {len(success)} ({100 * len(success) / len(results):.1f}%)" + ) + + if success: + avg_tflops = np.mean([float(r["actual_tflops"]) for r in success]) + max_tflops = max([float(r["actual_tflops"]) for r in success]) + print(f" Avg TFLOPS: {avg_tflops:.2f}") + print(f" Max TFLOPS: {max_tflops:.2f}") + + # Per-dtype breakdown + for dtype in args.dtypes: + dtype_results = [r for r in success if r["dtype"] == dtype] + if dtype_results: + avg = np.mean([float(r["actual_tflops"]) for r in dtype_results]) + print( + f" {dtype:4s}: {avg:.2f} TFLOPS (n={len(dtype_results)})" + ) + + print("=" * 80) + print() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/feature_spec.json b/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/feature_spec.json new file mode 100644 index 0000000000..dc4ed02e5e --- /dev/null +++ b/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/feature_spec.json @@ -0,0 +1,113 @@ +{ + "op_type": "gemm_universal", + "dtype": "fp16", + "arch": "gfx950", + "feature_names": [ + "M", + "N", + "K", + "split_k", + "log2_M", + "log2_N", + "log2_K", + "log2_MNK", + "arithmetic_intensity", + "aspect_ratio_mn", + "aspect_ratio_mk", + "aspect_ratio_nk", + "layout", + "tile_m", + "tile_n", + "tile_k", + "warp_m", + "warp_n", + "warp_k", + "warp_tile_m", + "warp_tile_n", + "warp_tile_k", + "pipeline", + "scheduler", + "epilogue", + "pad_m", + "pad_n", + "pad_k", + "persistent", + "num_warps", + "tile_volume", + "tile_mn", + "lds_usage_estimate", + "lds_usage_ratio", + "num_tiles_m", + "num_tiles_n", + "num_tiles_k", + "total_output_tiles", + "tile_eff_m", + "tile_eff_n", + "tile_eff_k", + "overall_tile_efficiency", + "cu_utilization", + "ratio_M_to_tile_m", + "ratio_N_to_tile_n", + "ratio_K_to_tile_k", + "problem_smaller_than_tile_m", + "problem_smaller_than_tile_n", + "problem_smaller_than_tile_k", + "any_dim_too_small", + "needs_padding_m", + "needs_padding_n", + "needs_padding_k", + "has_padding_when_needed_m", + "has_padding_when_needed_n", + "has_padding_when_needed_k", + "missing_required_padding_m", + "missing_required_padding_n", + "missing_required_padding_k", + "missing_any_required_padding", + "hw_num_cus", + "hw_simds_per_cu", + "hw_total_simds", + "hw_shader_engines", + "hw_max_clock_mhz", + "hw_max_waves_per_cu", + "hw_wavefront_size", + "hw_lds_capacity", + "hw_l1_cache_kb", + "hw_l2_cache_kb", + "hw_l3_cache_kb", + "hw_num_xcd" + ], + "categorical_features": [ + "layout", + "pipeline", + "scheduler", + "epilogue" + ], + "targets": [ + "tflops", + "latency", + "bandwidth" + ], + "log_targets": [ + "bandwidth", + "tflops" + ], + "params": { + "objective": "regression", + "metric": [ + "rmse", + "mae" + ], + "num_leaves": 255, + "max_depth": 15, + "n_estimators": 2000, + "learning_rate": 0.02, + "min_child_samples": 10, + "subsample": 0.85, + "colsample_bytree": 0.85, + "reg_alpha": 0.05, + "reg_lambda": 0.5, + "verbose": -1, + "n_jobs": 8, + "seed": 42 + } +} \ No newline at end of file diff 
--git a/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm.gz b/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm.gz new file mode 100644 index 0000000000..a59cc73c4f Binary files /dev/null and b/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm.gz differ diff --git a/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/train_manifest.json b/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/train_manifest.json new file mode 100644 index 0000000000..7028dc32fa --- /dev/null +++ b/dispatcher/heuristics/models/gemm_universal_fp16_gfx950/train_manifest.json @@ -0,0 +1,10 @@ +{ + "warm_start_from": null, + "prev_n_estimators": 0, + "new_n_estimators": 2000, + "total_n_estimators": 2000, + "data_rows": 25600, + "valid_rows": 21920, + "unique_shapes": 25, + "timestamp": "2026-03-20T05:00:55" +} \ No newline at end of file diff --git a/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/feature_spec.json b/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/feature_spec.json new file mode 100644 index 0000000000..ffc4052d9b --- /dev/null +++ b/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/feature_spec.json @@ -0,0 +1,113 @@ +{ + "op_type": "gemm_universal", + "dtype": "fp8", + "arch": "gfx950", + "feature_names": [ + "M", + "N", + "K", + "split_k", + "log2_M", + "log2_N", + "log2_K", + "log2_MNK", + "arithmetic_intensity", + "aspect_ratio_mn", + "aspect_ratio_mk", + "aspect_ratio_nk", + "layout", + "tile_m", + "tile_n", + "tile_k", + "warp_m", + "warp_n", + "warp_k", + "warp_tile_m", + "warp_tile_n", + "warp_tile_k", + "pipeline", + "scheduler", + "epilogue", + "pad_m", + "pad_n", + "pad_k", + "persistent", + "num_warps", + "tile_volume", + "tile_mn", + "lds_usage_estimate", + "lds_usage_ratio", + "num_tiles_m", + "num_tiles_n", + "num_tiles_k", + "total_output_tiles", + "tile_eff_m", + "tile_eff_n", + "tile_eff_k", + "overall_tile_efficiency", + "cu_utilization", + "ratio_M_to_tile_m", + "ratio_N_to_tile_n", + "ratio_K_to_tile_k", + "problem_smaller_than_tile_m", + "problem_smaller_than_tile_n", + "problem_smaller_than_tile_k", + "any_dim_too_small", + "needs_padding_m", + "needs_padding_n", + "needs_padding_k", + "has_padding_when_needed_m", + "has_padding_when_needed_n", + "has_padding_when_needed_k", + "missing_required_padding_m", + "missing_required_padding_n", + "missing_required_padding_k", + "missing_any_required_padding", + "hw_num_cus", + "hw_simds_per_cu", + "hw_total_simds", + "hw_shader_engines", + "hw_max_clock_mhz", + "hw_max_waves_per_cu", + "hw_wavefront_size", + "hw_lds_capacity", + "hw_l1_cache_kb", + "hw_l2_cache_kb", + "hw_l3_cache_kb", + "hw_num_xcd" + ], + "categorical_features": [ + "layout", + "pipeline", + "scheduler", + "epilogue" + ], + "targets": [ + "tflops", + "latency", + "bandwidth" + ], + "log_targets": [ + "bandwidth", + "tflops" + ], + "params": { + "objective": "regression", + "metric": [ + "rmse", + "mae" + ], + "num_leaves": 255, + "max_depth": 15, + "n_estimators": 2000, + "learning_rate": 0.02, + "min_child_samples": 10, + "subsample": 0.85, + "colsample_bytree": 0.85, + "reg_alpha": 0.05, + "reg_lambda": 0.5, + "verbose": -1, + "n_jobs": 8, + "seed": 42 + } +} \ No newline at end of file diff --git a/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/model_tflops.lgbm.gz b/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/model_tflops.lgbm.gz new file mode 100644 index 0000000000..a2a08ee01a Binary files /dev/null and 
b/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/model_tflops.lgbm.gz differ diff --git a/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/train_manifest.json b/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/train_manifest.json new file mode 100644 index 0000000000..d7ce61d2ff --- /dev/null +++ b/dispatcher/heuristics/models/gemm_universal_fp8_gfx950/train_manifest.json @@ -0,0 +1,10 @@ +{ + "warm_start_from": null, + "prev_n_estimators": 0, + "new_n_estimators": 2000, + "total_n_estimators": 2000, + "data_rows": 1296528, + "valid_rows": 1253076, + "unique_shapes": 168, + "timestamp": "2026-03-19T06:10:29" +} \ No newline at end of file diff --git a/dispatcher/heuristics/predict.py b/dispatcher/heuristics/predict.py new file mode 100644 index 0000000000..8738c76f23 --- /dev/null +++ b/dispatcher/heuristics/predict.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Predictor for CK Tile kernel performance. + +Loads trained LightGBM models and provides: + - predict_tflops(): predicted TFLOPS for a single (problem, kernel) pair + - predict_latency(): predicted latency in ms + - predict_bandwidth(): predicted bandwidth in GB/s + - predict_all(): all three predictions at once + - rank_kernels(): rank all candidate kernels by predicted TFLOPS + - select_best(): return the best kernel ID + +Usage: + predictor = Predictor("models/gemm_universal_fp8_gfx950") + best_kernel = predictor.select_best( + problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"}, + kernel_configs=[...], + ) +""" + +import gzip +import json +from pathlib import Path +from typing import Optional + +import lightgbm as lgb +import numpy as np +import pandas as pd + +from feature_engine import GemmUniversalFeatureEngine + + +class Predictor: + """Loads trained models and feature spec for kernel performance prediction. + + Parameters + ---------- + model_dir : str or Path + Directory containing model artifacts: + - model_tflops.lgbm (required) + - model_latency.lgbm (optional) + - model_bandwidth.lgbm (optional) + - feature_spec.json (required) + + feature_engine : FeatureEngine, optional + Override the feature engine. If None, constructs one from feature_spec.json. + """ + + def __init__(self, model_dir: str | Path, feature_engine=None): + self._model_dir = Path(model_dir) + self._models: dict[str, lgb.Booster] = {} + + spec_path = self._model_dir / "feature_spec.json" + if spec_path.exists(): + with open(spec_path) as f: + self._spec = json.load(f) + else: + self._spec = {} + + self._log_targets = set(self._spec.get("log_targets", [])) + + if feature_engine is not None: + self._feature_engine = feature_engine + else: + self._feature_engine = GemmUniversalFeatureEngine() + + def _load_model(self, target: str) -> Optional[lgb.Booster]: + """Lazy-load a model for the given target. + + Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist. + The decompressed file is cached to disk for subsequent loads. 
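+
+ Returns None when neither model_<target>.lgbm nor model_<target>.lgbm.gz is
+ present, so optional targets (latency, bandwidth) may simply be absent.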
+ """ + if target in self._models: + return self._models[target] + + path = self._model_dir / f"model_{target}.lgbm" + gz_path = self._model_dir / f"model_{target}.lgbm.gz" + + # Auto-decompress if needed + if not path.exists() and gz_path.exists(): + with gzip.open(gz_path, 'rb') as f_in: + with open(path, 'wb') as f_out: + f_out.write(f_in.read()) + + if not path.exists(): + return None + + model = lgb.Booster(model_file=str(path)) + self._models[target] = model + return model + + def _predict_single(self, target: str, problem: dict, kernel_config: dict) -> float: + """Predict a single target value, applying inverse log transform if needed.""" + model = self._load_model(target) + if model is None: + raise FileNotFoundError(f"No model_{target}.lgbm in {self._model_dir}") + features = self._feature_engine.extract(problem, kernel_config) + raw = float(model.predict(features.reshape(1, -1))[0]) + if target in self._log_targets: + return float(np.expm1(raw)) + # Clamp to non-negative even for non-log models + return float(max(0.0, raw)) + + def predict_tflops(self, problem: dict, kernel_config: dict) -> float: + """Predict TFLOPS for a single (problem, kernel) pair. + + Returns a real TFLOPS estimate (interpretable, usable as DE surrogate). + If the model was trained in log-space, the inverse transform is applied + automatically. + """ + return self._predict_single("tflops", problem, kernel_config) + + def predict_latency(self, problem: dict, kernel_config: dict) -> float: + """Predict latency in milliseconds for a single (problem, kernel) pair.""" + return self._predict_single("latency", problem, kernel_config) + + def predict_bandwidth(self, problem: dict, kernel_config: dict) -> float: + """Predict bandwidth in GB/s for a single (problem, kernel) pair.""" + return self._predict_single("bandwidth", problem, kernel_config) + + def predict_all(self, problem: dict, kernel_config: dict) -> dict[str, float]: + """Predict all available targets for a single (problem, kernel) pair. + + Returns dict with keys 'tflops', 'latency_ms', 'bandwidth_gb_s' (if models exist). + + Note: Applies inverse log transform for targets in log_targets and clamps + negatives to 0.0, consistent with _predict_single(). + """ + features = self._feature_engine.extract(problem, kernel_config).reshape(1, -1) + result = {} + for target, key in [ + ("tflops", "tflops"), + ("latency", "latency_ms"), + ("bandwidth", "bandwidth_gb_s"), + ]: + model = self._load_model(target) + if model is not None: + raw = float(model.predict(features)[0]) + # Apply inverse log transform if model was trained in log-space + if target in self._log_targets: + result[key] = float(np.expm1(raw)) + else: + # Clamp to non-negative even for non-log models + result[key] = float(max(0.0, raw)) + return result + + def rank_kernels( + self, problem: dict, kernel_configs: list[dict] + ) -> list[tuple[str, float]]: + """Rank candidate kernels by predicted TFLOPS (descending). + + Parameters + ---------- + problem : dict + Problem specification with keys: m, n, k, dtype, layout, split_k. + kernel_configs : list of dict + Each dict must have a 'kernel_name' key plus kernel parameters. + + Returns + ------- + list of (kernel_name, predicted_tflops) tuples, sorted descending. 
+ """ + if not kernel_configs: + return [] + + model = self._load_model("tflops") + if model is None: + raise FileNotFoundError(f"No model_tflops.lgbm in {self._model_dir}") + + rows = [] + for kc in kernel_configs: + merged = {**problem, **kc} + rows.append(merged) + + df = pd.DataFrame(rows) + X = self._feature_engine.extract_batch(df) + preds = model.predict(X) + if "tflops" in self._log_targets: + preds = np.expm1(preds) + + results = [] + for i, kc in enumerate(kernel_configs): + name = kc.get("kernel_name", f"kernel_{i}") + results.append((name, float(preds[i]))) + + results.sort(key=lambda x: -x[1]) + return results + + def select_best(self, problem: dict, kernel_configs: list[dict]) -> str: + """Return the kernel_name of the best predicted kernel.""" + ranked = self.rank_kernels(problem, kernel_configs) + if not ranked: + raise ValueError("No kernel configs provided") + return ranked[0][0] + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Predict kernel performance") + parser.add_argument( + "--model_dir", required=True, help="Directory with trained models" + ) + parser.add_argument("--m", type=int, required=True) + parser.add_argument("--n", type=int, required=True) + parser.add_argument("--k", type=int, required=True) + parser.add_argument("--layout", default="rcr") + parser.add_argument("--dtype", default="fp8") + args = parser.parse_args() + + predictor = Predictor(args.model_dir) + problem = { + "m": args.m, + "n": args.n, + "k": args.k, + "dtype": args.dtype, + "layout": args.layout, + "split_k": 1, + } + + print(f"Loading models from {args.model_dir}...") + print( + f"Problem: M={args.m} N={args.n} K={args.k} dtype={args.dtype} layout={args.layout}" + ) + + data_dir = Path(args.model_dir).parent.parent / "data" + if data_dir.exists(): + for pq in data_dir.glob("*.parquet"): + df = pd.read_parquet(pq) + kernel_names = df["kernel_name"].unique() + configs = [] + for kn in kernel_names[:10]: + row = df[df["kernel_name"] == kn].iloc[0] + configs.append(row.to_dict()) + if configs: + ranked = predictor.rank_kernels(problem, configs) + print(f"\nTop 5 kernels (from {len(configs)} candidates):") + for name, tflops in ranked[:5]: + print(f" {tflops:8.2f} TFLOPS {name}") + break diff --git a/dispatcher/heuristics/search.py b/dispatcher/heuristics/search.py new file mode 100644 index 0000000000..f9b7e13b09 --- /dev/null +++ b/dispatcher/heuristics/search.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Surrogate search for CK Tile kernel configuration optimization. + +Uses a trained LGBMRegressor as a cheap surrogate function to search the +discrete kernel parameter space (tile sizes, warp config, pipeline, etc.) +without running actual GPU benchmarks. + +Strategies: + - 'random': Sample N random valid configs, score all, return top-K. + - 'de': Discrete Differential Evolution with mutation over valid parameter choices. + +Usage: + from search import SurrogateSearch + from predict import Predictor + + predictor = Predictor("models/gemm_universal_fp8_gfx950") + searcher = SurrogateSearch(predictor, strategy='random') + results = searcher.search( + problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"}, + budget=500, + ) + # results: [(config_dict, predicted_tflops), ...] 
sorted descending +""" + +import random +from typing import Optional + +import numpy as np + +from feature_engine import GemmUniversalFeatureEngine +from predict import Predictor + + +class SurrogateSearch: + """Search kernel parameter space using ML regressor as surrogate objective. + + Parameters + ---------- + predictor : Predictor + Trained predictor with a TFLOPS model. + feature_engine : GemmUniversalFeatureEngine, optional + Feature engine for parameter space and validation. If None, uses default. + strategy : str + Search strategy: 'random' or 'de' (Discrete Differential Evolution). + seed : int + Random seed for reproducibility. + """ + + def __init__( + self, + predictor: Predictor, + feature_engine: Optional[GemmUniversalFeatureEngine] = None, + strategy: str = "random", + seed: int = 42, + ): + self._predictor = predictor + self._fe = feature_engine or GemmUniversalFeatureEngine() + self._strategy = strategy + self._rng = random.Random(seed) + self._np_rng = np.random.RandomState(seed) + self._param_space = self._fe.get_parameter_space() + + def _sample_random_config(self) -> dict: + """Sample a single random config from the parameter space.""" + config = {} + for param, values in self._param_space.items(): + config[param] = self._rng.choice(values) + return config + + def _sample_valid_config(self, max_attempts: int = 50) -> Optional[dict]: + """Sample a random config that passes all validation constraints.""" + for _ in range(max_attempts): + config = self._sample_random_config() + if self._fe.validate_config(config): + return config + return None + + def _score_config(self, problem: dict, config: dict) -> float: + """Score a config using the predictor.""" + return self._predictor.predict_tflops(problem, config) + + def _search_random( + self, problem: dict, budget: int, top_k: int + ) -> list[tuple[dict, float]]: + """Random search: sample valid configs, score all, return top-K.""" + configs = [] + for _ in range(budget): + cfg = self._sample_valid_config() + if cfg is not None: + configs.append(cfg) + + if not configs: + return [] + + scored = [] + for cfg in configs: + try: + score = self._score_config(problem, cfg) + scored.append((cfg, score)) + except Exception: + continue + + scored.sort(key=lambda x: -x[1]) + return scored[:top_k] + + def _search_de( + self, + problem: dict, + budget: int, + top_k: int, + pop_size: int = 20, + mutation_rate: float = 0.3, + crossover_rate: float = 0.7, + ) -> list[tuple[dict, float]]: + """Discrete Differential Evolution. + + Uses discrete mutation: randomly swap parameters to other valid values + from the parameter space (no continuous relaxation + snap). + + Each generation: + 1. For each member of the population, create a trial vector by: + - Selecting 3 random other members (a, b, c) + - For each parameter, with probability mutation_rate, take the value + from a, b, or c (uniform choice among the three donors) + - With probability crossover_rate, take the trial value; otherwise keep original + 2. Validate the trial; if invalid, resample that parameter from the space + 3. 
Score the trial; if better than parent, replace + """ + param_names = list(self._param_space.keys()) + + population = [] + for _ in range(pop_size): + cfg = self._sample_valid_config() + if cfg is not None: + score = self._score_config(problem, cfg) + population.append((cfg, score)) + + if len(population) < 4: + return self._search_random(problem, budget, top_k) + + evals_used = len(population) + max_gens = (budget - evals_used) // pop_size + + for gen in range(max_gens): + new_pop = [] + for i, (parent, parent_score) in enumerate(population): + candidates = [j for j in range(len(population)) if j != i] + if len(candidates) < 3: + new_pop.append((parent, parent_score)) + continue + + a_idx, b_idx, c_idx = self._rng.sample(candidates, 3) + a, b, c = ( + population[a_idx][0], + population[b_idx][0], + population[c_idx][0], + ) + + trial = dict(parent) + for param in param_names: + if self._rng.random() < mutation_rate: + donor = self._rng.choice([a, b, c]) + trial[param] = donor.get(param, parent.get(param)) + + if self._rng.random() > crossover_rate: + trial[param] = parent.get(param) + + if not self._fe.validate_config(trial): + for param in param_names: + if param in trial and trial[param] not in self._param_space.get( + param, [trial[param]] + ): + trial[param] = self._rng.choice(self._param_space[param]) + if not self._fe.validate_config(trial): + new_pop.append((parent, parent_score)) + continue + + try: + trial_score = self._score_config(problem, trial) + evals_used += 1 + except Exception: + new_pop.append((parent, parent_score)) + continue + + if trial_score > parent_score: + new_pop.append((trial, trial_score)) + else: + new_pop.append((parent, parent_score)) + + population = new_pop + + population.sort(key=lambda x: -x[1]) + return population[:top_k] + + def search( + self, + problem: dict, + budget: int = 500, + top_k: int = 10, + **kwargs, + ) -> list[tuple[dict, float]]: + """Search the kernel parameter space for the best configuration. + + Parameters + ---------- + problem : dict + Problem specification: m, n, k, dtype, layout, split_k. + budget : int + Maximum number of surrogate evaluations. + top_k : int + Number of top configurations to return. + **kwargs + Strategy-specific parameters (pop_size, mutation_rate, etc.). + + Returns + ------- + list of (config_dict, predicted_tflops), sorted descending by TFLOPS. 
+ """ + if self._strategy == "random": + return self._search_random(problem, budget, top_k) + elif self._strategy == "de": + return self._search_de(problem, budget, top_k, **kwargs) + else: + raise ValueError(f"Unknown strategy: {self._strategy}") + + +if __name__ == "__main__": + import argparse + import time + + parser = argparse.ArgumentParser( + description="Surrogate search for optimal kernel config" + ) + parser.add_argument("--model_dir", required=True) + parser.add_argument("--m", type=int, required=True) + parser.add_argument("--n", type=int, required=True) + parser.add_argument("--k", type=int, required=True) + parser.add_argument("--dtype", default="fp8") + parser.add_argument("--layout", default="rcr") + parser.add_argument("--strategy", default="random", choices=["random", "de"]) + parser.add_argument("--budget", type=int, default=500) + parser.add_argument("--top_k", type=int, default=10) + args = parser.parse_args() + + predictor = Predictor(args.model_dir) + searcher = SurrogateSearch(predictor, strategy=args.strategy) + problem = { + "m": args.m, + "n": args.n, + "k": args.k, + "dtype": args.dtype, + "layout": args.layout, + "split_k": 1, + } + + print(f"Searching with strategy={args.strategy}, budget={args.budget}...") + t0 = time.time() + results = searcher.search(problem, budget=args.budget, top_k=args.top_k) + elapsed = time.time() - t0 + + print(f"\nTop {len(results)} configs found in {elapsed * 1000:.1f}ms:") + for i, (cfg, tflops) in enumerate(results): + tile_str = f"{cfg.get('tile_m', '?')}x{cfg.get('tile_n', '?')}x{cfg.get('tile_k', '?')}" + warp_str = f"{cfg.get('warp_m', '?')}x{cfg.get('warp_n', '?')}x{cfg.get('warp_k', '?')}" + print( + f" #{i + 1}: {tflops:8.2f} TFLOPS tile={tile_str} warp={warp_str} " + f"pipeline={cfg.get('pipeline', '?')} scheduler={cfg.get('scheduler', '?')}" + ) diff --git a/dispatcher/heuristics/tests/__init__.py b/dispatcher/heuristics/tests/__init__.py new file mode 100644 index 0000000000..1df4857184 --- /dev/null +++ b/dispatcher/heuristics/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT diff --git a/dispatcher/heuristics/tests/test_data_pipeline.py b/dispatcher/heuristics/tests/test_data_pipeline.py new file mode 100644 index 0000000000..d643138693 --- /dev/null +++ b/dispatcher/heuristics/tests/test_data_pipeline.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for data_pipeline.py. + +Covers: kernel name parsing, layout derivation, streaming log parsing, +schema validation, and corner cases (empty logs, malformed JSON, single-shape). 
+""" + +import tempfile +from pathlib import Path + +import pandas as pd +import pytest + +import sys + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from data_pipeline import ( + parse_kernel_name, + _layout_from_problem, + parse_streaming_log, + save_parquet, + load_parquet, + CANONICAL_COLUMNS, +) + + +# --------------------------------------------------------------------------- +# parse_kernel_name +# --------------------------------------------------------------------------- + + +class TestParseKernelName: + def test_standard_name(self): + name = "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128" + result = parse_kernel_name(name) + assert result["dtype"] == "fp8" + assert result["layout"] == "rcr" + assert result["pipeline"] == "compv3" + assert result["epilogue"] == "cshuffle" + assert result["scheduler"] == "intrawave" + assert result["pad_m"] is False + assert result["pad_n"] is False + assert result["pad_k"] is False + assert result["persistent"] is False + assert result["tile_m"] == 128 + assert result["tile_n"] == 128 + assert result["tile_k"] == 128 + assert result["warp_m"] == 1 + assert result["warp_n"] == 4 + assert result["warp_k"] == 1 + assert result["warp_tile_m"] == 16 + assert result["warp_tile_n"] == 16 + assert result["warp_tile_k"] == 128 + + def test_with_padding_and_persistent(self): + name = "gemm_universal_fp16_rrr_compv4_default_interwave_True_True_True_True_256x256x64_2x2x1_32x32x16" + result = parse_kernel_name(name) + assert result["dtype"] == "fp16" + assert result["layout"] == "rrr" + assert result["pad_m"] is True + assert result["pad_n"] is True + assert result["pad_k"] is True + assert result["persistent"] is True + assert result["tile_m"] == 256 + + def test_empty_name(self): + assert parse_kernel_name("") == {} + + def test_malformed_name(self): + assert parse_kernel_name("not_a_kernel_name") == {} + + def test_partial_name(self): + result = parse_kernel_name("gemm_universal_fp8_rcr_compv3") + assert result.get("dtype") == "fp8" + assert result.get("layout") == "rcr" + assert "tile_m" not in result # not enough parts + + def test_all_layouts(self): + for layout in ["rcr", "rrr", "crr", "ccr"]: + name = f"gemm_universal_fp8_{layout}_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128" + result = parse_kernel_name(name) + assert result["layout"] == layout + + +# --------------------------------------------------------------------------- +# _layout_from_problem +# --------------------------------------------------------------------------- + + +class TestLayoutFromProblem: + def test_rcr(self): + assert ( + _layout_from_problem( + { + "layout_a": "RowMajor", + "layout_b": "ColumnMajor", + "layout_c": "RowMajor", + } + ) + == "rcr" + ) + + def test_rrr(self): + assert ( + _layout_from_problem( + {"layout_a": "RowMajor", "layout_b": "RowMajor", "layout_c": "RowMajor"} + ) + == "rrr" + ) + + def test_empty(self): + assert _layout_from_problem({}) == "???" 
+ + def test_case_insensitive(self): + assert ( + _layout_from_problem( + { + "layout_a": "rowmajor", + "layout_b": "COLUMNMAJOR", + "layout_c": "RowMajor", + } + ) + == "rcr" + ) + + +# --------------------------------------------------------------------------- +# parse_streaming_log +# --------------------------------------------------------------------------- + +SAMPLE_LOG = """\ +================================================================================ +LOG FILE: test.log +================================================================================ +CK Tile Profiling Run +GPU ID: 0 + +--- Running CK Tile benchmarks on GPU 0 --- + +======================================== +Shape 1: M=16 N=1536 K=7168 dtype=fp8 layout=rcr +======================================== +Found 2 kernels +{ + "name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128", + "problem": { + "split_k":1, + "m":16, + "n":1536, + "k":7168, + "stride_a":7168, + "stride_b":7168, + "stride_c":1536, + "dtype_a":"fp8", + "dtype_b":"fp8", + "dtype_acc":"fp32", + "dtype_c":"fp16", + "layout_a":"RowMajor", + "layout_b":"ColumnMajor", + "layout_c":"RowMajor", + "structured_sparsity":false +}, + "perf_result": { + "latency(ms)": 0.04, + "tflops(TFlops)": 8.81, + "bandwidth(GB/s)": 279.51 +} +} +{ + "name": "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16", + "problem": { + "split_k":1, + "m":16, + "n":1536, + "k":7168, + "stride_a":7168, + "stride_b":7168, + "stride_c":1536, + "dtype_a":"fp8", + "dtype_b":"fp8", + "dtype_acc":"fp32", + "dtype_c":"fp16", + "layout_a":"RowMajor", + "layout_b":"ColumnMajor", + "layout_c":"RowMajor", + "structured_sparsity":false +}, + "perf_result": { + "latency(ms)": 0.05, + "tflops(TFlops)": 7.22, + "bandwidth(GB/s)": 228.85 +} +} + +======================================== +Shape 2: M=20480 N=7168 K=256 dtype=fp8 layout=rcr +======================================== +Found 1 kernels +{ + "name": "gemm_universal_fp8_rcr_mem_cshuffle_intrawave_False_False_False_True_64x64x128_1x4x1_16x16x32", + "problem": { + "split_k":1, + "m":20480, + "n":7168, + "k":256, + "stride_a":256, + "stride_b":256, + "stride_c":7168, + "dtype_a":"fp8", + "dtype_b":"fp8", + "dtype_acc":"fp32", + "dtype_c":"fp16", + "layout_a":"RowMajor", + "layout_b":"ColumnMajor", + "layout_c":"RowMajor", + "structured_sparsity":false +}, + "perf_result": { + "latency(ms)": 0.15, + "tflops(TFlops)": 505.00, + "bandwidth(GB/s)": 1200.50 +} +} +""" + + +class TestParseStreamingLog: + def _write_log(self, content: str) -> Path: + f = tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) + f.write(content) + f.close() + return Path(f.name) + + def test_basic_parse(self): + path = self._write_log(SAMPLE_LOG) + df = parse_streaming_log(path, arch="gfx950") + assert len(df) == 3 + assert df["arch"].iloc[0] == "gfx950" + assert df["m"].tolist() == [16, 16, 20480] + assert df["n"].tolist() == [1536, 1536, 7168] + assert df["k"].tolist() == [7168, 7168, 256] + + def test_tflops_values(self): + path = self._write_log(SAMPLE_LOG) + df = parse_streaming_log(path) + assert df["measured_tflops"].tolist() == pytest.approx([8.81, 7.22, 505.0]) + + def test_kernel_config_parsed(self): + path = self._write_log(SAMPLE_LOG) + df = parse_streaming_log(path) + assert df["tile_m"].iloc[0] == 128 + assert df["pipeline"].iloc[0] == "compv3" + assert df["pipeline"].iloc[1] == "compv4" + + def test_layout_derived_from_json(self): + path = 
self._write_log(SAMPLE_LOG) + df = parse_streaming_log(path) + assert all(df["layout"] == "rcr") + + def test_empty_log(self): + path = self._write_log("No shapes here\nJust noise\n") + df = parse_streaming_log(path) + assert len(df) == 0 + for col in CANONICAL_COLUMNS: + assert col in df.columns + + def test_single_kernel(self): + log = """\ +Shape 1: M=1 N=1 K=1 dtype=fp8 layout=rcr +{ + "name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128", + "problem": {"split_k":1, "m":1, "n":1, "k":1, "dtype_a":"fp8", "dtype_b":"fp8", "layout_a":"RowMajor", "layout_b":"ColumnMajor", "layout_c":"RowMajor"}, + "perf_result": {"latency(ms)": 0.001, "tflops(TFlops)": 0.002, "bandwidth(GB/s)": 0.01} +} +""" + path = self._write_log(log) + df = parse_streaming_log(path) + assert len(df) == 1 + assert df["m"].iloc[0] == 1 + assert bool(df["is_valid"].iloc[0]) is True + + def test_zero_tflops_marked_invalid(self): + log = """\ +Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr +{ + "name": "test_kernel", + "problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"}, + "perf_result": {"latency(ms)": 0.0, "tflops(TFlops)": 0.0, "bandwidth(GB/s)": 0.0} +} +""" + path = self._write_log(log) + df = parse_streaming_log(path) + assert len(df) == 1 + assert bool(df["is_valid"].iloc[0]) is False + + def test_malformed_json_skipped(self): + log = """\ +Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr +{ + "name": "good_kernel", + "problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"}, + "perf_result": {"latency(ms)": 0.01, "tflops(TFlops)": 1.0, "bandwidth(GB/s)": 10.0} +} +{ this is not valid json } +{ + "name": "another_good", + "problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"}, + "perf_result": {"latency(ms)": 0.02, "tflops(TFlops)": 2.0, "bandwidth(GB/s)": 20.0} +} +""" + path = self._write_log(log) + df = parse_streaming_log(path) + assert len(df) == 2 + + def test_extreme_shapes(self): + """Tiny M=1 (single token) and very large M=20480.""" + path = self._write_log(SAMPLE_LOG) + df = parse_streaming_log(path) + assert 1 not in df["m"].values # sample has M=16, M=20480 + assert 16 in df["m"].values + assert 20480 in df["m"].values + + def test_run_id_assigned(self): + path = self._write_log(SAMPLE_LOG) + df = parse_streaming_log(path, run_id="test_run_123") + assert all(df["run_id"] == "test_run_123") + + def test_op_type_assigned(self): + path = self._write_log(SAMPLE_LOG) + df = parse_streaming_log(path, op_type="gemm_streamk") + assert all(df["op_type"] == "gemm_streamk") + + +# --------------------------------------------------------------------------- +# Parquet round-trip +# --------------------------------------------------------------------------- + + +class TestParquetIO: + def test_round_trip(self, tmp_path): + df = pd.DataFrame( + { + "m": [16, 32], + "n": [1536, 1536], + "k": [7168, 7168], + "measured_tflops": [8.81, 15.0], + } + ) + path = tmp_path / "test.parquet" + save_parquet(df, path) + loaded = load_parquet(path) + assert len(loaded) == 2 + assert loaded["m"].tolist() == [16, 32] + + def test_creates_parent_dirs(self, tmp_path): + path = tmp_path / "sub" / "dir" / "test.parquet" + df = pd.DataFrame({"x": [1]}) + save_parquet(df, path) + assert path.exists() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dispatcher/heuristics/tests/test_dispatcher_integration.py b/dispatcher/heuristics/tests/test_dispatcher_integration.py new file mode 100644 index 0000000000..a80438629d --- /dev/null 
+++ b/dispatcher/heuristics/tests/test_dispatcher_integration.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for dispatcher_integration.py. + +Covers: kernel name parsing to feature dict, feature dict to dispatcher config +(name mapping inversion), MLKernelSpec creation, binary pool loading, and +the ML heuristic function. +""" + +import json +import sys +from pathlib import Path + +import lightgbm as lgb +import numpy as np +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from dispatcher_integration import ( + kernel_config_to_feature_dict, + feature_dict_to_dispatcher_config, + feature_dict_to_ml_spec, + ml_spec_to_dispatcher_config, + create_ml_heuristic, + load_kernel_pool_from_binaries, + MLKernelSpec, + LAYOUT_TO_DISPATCHER, +) +from feature_engine import GemmUniversalFeatureEngine + + +SAMPLE_KERNEL_NAME = ( + "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave" + "_False_False_False_False_128x128x128_1x4x1_16x16x128" +) + + +# --------------------------------------------------------------------------- +# kernel_config_to_feature_dict +# --------------------------------------------------------------------------- + + +class TestKernelConfigToFeatureDict: + def test_parses_standard_name(self): + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + assert feat["tile_m"] == 128 + assert feat["tile_n"] == 128 + assert feat["tile_k"] == 128 + assert feat["warp_m"] == 1 # warps per block + assert feat["warp_n"] == 4 + assert feat["warp_k"] == 1 + assert feat["warp_tile_m"] == 16 + assert feat["warp_tile_n"] == 16 + assert feat["warp_tile_k"] == 128 + assert feat["pipeline"] == "compv3" + assert feat["scheduler"] == "intrawave" + assert feat["epilogue"] == "cshuffle" + assert feat["kernel_name"] == SAMPLE_KERNEL_NAME + + def test_empty_name_returns_empty(self): + assert kernel_config_to_feature_dict("") == {} + + def test_invalid_name_returns_empty(self): + assert kernel_config_to_feature_dict("not_a_kernel") == {} + + +# --------------------------------------------------------------------------- +# Name mapping: feature dict <-> dispatcher config +# --------------------------------------------------------------------------- + + +class TestNameMapping: + """The critical inversion: feature engine warp_m/n/k (warps per block) + maps to dispatcher wave_m/n/k, and feature engine warp_tile_m/n/k + maps to dispatcher warp_m/n/k.""" + + def test_warp_to_wave_mapping(self): + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + disp = feature_dict_to_dispatcher_config(feat) + assert disp["wave_m"] == feat["warp_m"] # 1 + assert disp["wave_n"] == feat["warp_n"] # 4 + assert disp["wave_k"] == feat["warp_k"] # 1 + + def test_warp_tile_to_warp_mapping(self): + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + disp = feature_dict_to_dispatcher_config(feat) + assert disp["warp_m"] == feat["warp_tile_m"] # 16 + assert disp["warp_n"] == feat["warp_tile_n"] # 16 + assert disp["warp_k"] == feat["warp_tile_k"] # 128 + + def test_tile_dims_pass_through(self): + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + disp = feature_dict_to_dispatcher_config(feat) + assert disp["tile_m"] == 128 + assert disp["tile_n"] == 128 + assert disp["tile_k"] == 128 + + def test_pipeline_passes_through(self): + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + disp = feature_dict_to_dispatcher_config(feat) + assert disp["pipeline"] == "compv3" + 
assert disp["scheduler"] == "intrawave" + assert disp["epilogue"] == "cshuffle" + + def test_rcr_layout_mapping(self): + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + disp = feature_dict_to_dispatcher_config(feat, dtype="fp8") + assert disp["layout_a"] == "row" + assert disp["layout_b"] == "col" + assert disp["layout_c"] == "row" + + def test_all_layouts(self): + for layout, (la, lb, lc) in LAYOUT_TO_DISPATCHER.items(): + feat = {"layout": layout, "tile_m": 128} + disp = feature_dict_to_dispatcher_config(feat) + assert disp["layout_a"] == la + assert disp["layout_b"] == lb + assert disp["layout_c"] == lc + + +# --------------------------------------------------------------------------- +# MLKernelSpec +# --------------------------------------------------------------------------- + + +class TestMLKernelSpec: + def test_from_feature_dict(self): + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + spec = feature_dict_to_ml_spec(feat, predicted_tflops=123.4) + assert spec.kernel_name == SAMPLE_KERNEL_NAME + assert spec.predicted_tflops == 123.4 + assert spec.tile_m == 128 + assert spec.wave_m == 1 # was warp_m in feature space + assert spec.warp_m == 16 # was warp_tile_m in feature space + + def test_spec_to_dispatcher_config(self): + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + spec = feature_dict_to_ml_spec(feat, 100.0) + disp = ml_spec_to_dispatcher_config(spec, dtype="fp8", arch="gfx950") + assert disp["tile_m"] == 128 + assert disp["wave_m"] == 1 + assert disp["warp_m"] == 16 + assert disp["gfx_arch"] == "gfx950" + assert disp["dtype_a"] == "fp8" + + def test_roundtrip_preserves_values(self): + """feature_dict -> MLKernelSpec -> dispatcher_config should be consistent.""" + feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME) + spec = feature_dict_to_ml_spec(feat, 0.0) + disp_from_spec = ml_spec_to_dispatcher_config(spec) + disp_from_feat = feature_dict_to_dispatcher_config(feat) + for key in [ + "tile_m", + "tile_n", + "tile_k", + "wave_m", + "wave_n", + "wave_k", + "warp_m", + "warp_n", + "warp_k", + "pipeline", + "scheduler", + "epilogue", + ]: + assert disp_from_spec[key] == disp_from_feat[key], f"Mismatch on {key}" + + +# --------------------------------------------------------------------------- +# Binary pool loading +# --------------------------------------------------------------------------- + + +class TestLoadKernelPool: + def test_loads_from_real_bin_dir(self): + bin_dir = Path("/workspace/ck_tile/bin") + if not bin_dir.exists(): + pytest.skip("No /workspace/ck_tile/bin") + pool = load_kernel_pool_from_binaries(bin_dir) + assert len(pool) > 0 + assert "tile_m" in pool[0] + assert "kernel_name" in pool[0] + + def test_empty_dir_returns_empty(self, tmp_path): + pool = load_kernel_pool_from_binaries(tmp_path) + assert pool == [] + + +# --------------------------------------------------------------------------- +# ML heuristic function +# --------------------------------------------------------------------------- + + +class TestCreateMLHeuristic: + @pytest.fixture + def mock_model_dir(self, tmp_path): + """Create a minimal model for testing the heuristic flow.""" + fe = GemmUniversalFeatureEngine() + n_features = len(fe.get_feature_names()) + np.random.seed(42) + X = np.random.rand(100, n_features) + y = np.random.rand(100) * 500 + model = lgb.LGBMRegressor(n_estimators=5, verbose=-1) + model.fit(X, y) + model.booster_.save_model(str(tmp_path / "model_tflops.lgbm")) + spec = { + "feature_names": fe.get_feature_names(), + 
"categorical_features": fe.get_categorical_features(), + } + with open(tmp_path / "feature_spec.json", "w") as f: + json.dump(spec, f) + return tmp_path + + def _make_pool(self): + """Create a small synthetic kernel pool.""" + names = [ + "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128", + "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16", + "gemm_universal_fp8_rcr_mem_cshuffle_interwave_False_False_False_False_64x64x128_1x4x1_16x16x32", + ] + return [kernel_config_to_feature_dict(n) for n in names] + + def test_returns_ml_kernel_spec(self, mock_model_dir): + pool = self._make_pool() + heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool) + result = heuristic(1024, 1024, 1024) + assert isinstance(result, MLKernelSpec) + assert result.tile_m > 0 + assert isinstance(result.predicted_tflops, float) + + def test_returns_valid_kernel_from_pool(self, mock_model_dir): + pool = self._make_pool() + pool_names = {p["kernel_name"] for p in pool} + heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool) + result = heuristic(1024, 1024, 1024) + assert result.kernel_name in pool_names + + def test_different_shapes_may_select_different_kernels(self, mock_model_dir): + pool = self._make_pool() + heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool) + r1 = heuristic(16, 1536, 7168) + r2 = heuristic(8192, 8192, 256) + # At minimum both should return valid specs + assert r1.tile_m > 0 + assert r2.tile_m > 0 + + def test_m1_corner_case(self, mock_model_dir): + pool = self._make_pool() + heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool) + result = heuristic(1, 4096, 4096) + assert isinstance(result, MLKernelSpec) + assert np.isfinite(result.predicted_tflops) + + def test_empty_pool_raises(self, mock_model_dir): + with pytest.raises(ValueError, match="No kernel configs"): + create_ml_heuristic(mock_model_dir, kernel_pool=[]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dispatcher/heuristics/tests/test_evaluate.py b/dispatcher/heuristics/tests/test_evaluate.py new file mode 100644 index 0000000000..bcbe39af9d --- /dev/null +++ b/dispatcher/heuristics/tests/test_evaluate.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for evaluate.py. + +Covers: shape family classification, K-depth regime classification, +and basic evaluation metric checks. 
+""" + +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from evaluate import classify_shape_family, classify_k_regime + + +class TestClassifyShapeFamily: + def test_tiny_m(self): + assert classify_shape_family(1, 4096, 4096) == "tiny_m" + assert classify_shape_family(16, 1536, 7168) == "tiny_m" + + def test_small_m(self): + assert classify_shape_family(32, 1536, 7168) == "small_m" + assert classify_shape_family(128, 4096, 4096) == "small_m" + + def test_medium_m(self): + assert classify_shape_family(256, 1024, 1024) == "medium_m" + assert classify_shape_family(2048, 2048, 2048) == "medium_m" + + def test_large_m(self): + assert classify_shape_family(4096, 4096, 4096) == "large_m" + assert classify_shape_family(20480, 7168, 256) == "large_m" + + +class TestClassifyKRegime: + def test_shallow(self): + assert classify_k_regime(256) == "shallow_k" + assert classify_k_regime(32) == "shallow_k" + + def test_medium(self): + assert classify_k_regime(1024) == "medium_k" + assert classify_k_regime(2048) == "medium_k" + + def test_deep(self): + assert classify_k_regime(4096) == "deep_k" + assert classify_k_regime(7168) == "deep_k" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dispatcher/heuristics/tests/test_feature_engine.py b/dispatcher/heuristics/tests/test_feature_engine.py new file mode 100644 index 0000000000..492623ce99 --- /dev/null +++ b/dispatcher/heuristics/tests/test_feature_engine.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for feature_engine.py. + +Covers: feature count consistency, formula correctness (tile efficiency, LDS, +arithmetic intensity), corner-case shapes (M=1, huge M, square, skinny-K), +parameter space validity, config validation, and batch vs single extraction parity. 
+""" + +import sys +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from feature_engine import ( + GemmUniversalFeatureEngine, +) + + +@pytest.fixture +def fe(): + """Default feature engine with MI355X-like hardware.""" + return GemmUniversalFeatureEngine( + num_cus=256, + lds_capacity=65536, + max_clock_mhz=2400, + simds_per_cu=4, + shader_engines=32, + max_waves_per_cu=32, + wavefront_size=64, + l1_cache_kb=32, + l2_cache_kb=4096, + l3_cache_kb=262144, + num_xcd=8, + ) + + +def _make_problem(m=1024, n=1024, k=1024, dtype="fp8", layout="rcr", split_k=1): + return { + "m": m, + "n": n, + "k": k, + "dtype": dtype, + "layout": layout, + "split_k": split_k, + } + + +def _make_kernel( + tile_m=128, + tile_n=128, + tile_k=64, + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + pad_m=False, + pad_n=False, + pad_k=False, + persistent=False, +): + return { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + "pipeline": pipeline, + "scheduler": scheduler, + "epilogue": epilogue, + "pad_m": pad_m, + "pad_n": pad_n, + "pad_k": pad_k, + "persistent": persistent, + } + + +# --------------------------------------------------------------------------- +# Basic consistency +# --------------------------------------------------------------------------- + + +class TestFeatureConsistency: + def test_feature_count_matches_names(self, fe): + prob = _make_problem() + kern = _make_kernel() + vec = fe.extract(prob, kern) + assert len(vec) == len(fe.get_feature_names()) + + def test_feature_count_is_72(self, fe): + assert len(fe.get_feature_names()) == 72 + + def test_no_nan_in_output(self, fe): + prob = _make_problem() + kern = _make_kernel() + vec = fe.extract(prob, kern) + assert not np.any(np.isnan(vec)) + + def test_no_inf_in_output(self, fe): + prob = _make_problem() + kern = _make_kernel() + vec = fe.extract(prob, kern) + assert not np.any(np.isinf(vec)) + + def test_categorical_features_in_names(self, fe): + names = fe.get_feature_names() + for cat in fe.get_categorical_features(): + assert cat in names + + +# --------------------------------------------------------------------------- +# Formula correctness +# --------------------------------------------------------------------------- + + +class TestTileEfficiency: + """Tile efficiency: fraction of the last tile that is useful work.""" + + def test_perfectly_divisible(self, fe): + prob = _make_problem(m=256, n=256, k=128) + kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64) + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + assert vec[names.index("tile_eff_m")] == 1.0 + assert vec[names.index("tile_eff_n")] == 1.0 + assert vec[names.index("tile_eff_k")] == 1.0 + assert vec[names.index("overall_tile_efficiency")] == 1.0 + + def test_not_divisible(self, fe): + prob = _make_problem(m=100, n=100, k=100) + kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64) + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + assert vec[names.index("tile_eff_m")] == pytest.approx(100 / 128) + assert vec[names.index("tile_eff_n")] == pytest.approx(100 / 128) + assert vec[names.index("tile_eff_k")] == pytest.approx(36 / 64) + + def test_m_equals_1(self, fe): + """Single-token 
inference: M=1, tile_m=128 => eff = 1/128.""" + prob = _make_problem(m=1) + kern = _make_kernel(tile_m=128) + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + assert vec[names.index("tile_eff_m")] == pytest.approx(1.0 / 128.0) + + +class TestLDSUsage: + def test_lds_formula(self, fe): + prob = _make_problem(dtype="fp8") + kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64) + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + expected = (128 * 64 + 128 * 64) * 1.0 # fp8 = 1 byte + assert vec[names.index("lds_usage_estimate")] == expected + + def test_lds_ratio_compv4(self, fe): + """compv4 has 32KB LDS limit, not 64KB.""" + prob = _make_problem(dtype="fp8") + kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64, pipeline="compv4") + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + lds_est = (128 * 64 + 128 * 64) * 1.0 + assert vec[names.index("lds_usage_ratio")] == pytest.approx(lds_est / 32768) + + def test_lds_fp16_doubles(self, fe): + prob = _make_problem(dtype="fp16") + kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64) + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + expected = (128 * 64 + 128 * 64) * 2.0 # fp16 = 2 bytes + assert vec[names.index("lds_usage_estimate")] == expected + + +class TestArithmeticIntensity: + def test_square_shape(self, fe): + M, N, K = 1024, 1024, 1024 + prob = _make_problem(m=M, n=N, k=K, dtype="fp8") + kern = _make_kernel() + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + mem = (M * K + K * N + M * N) * 1.0 + expected = (2.0 * M * N * K) / mem + assert vec[names.index("arithmetic_intensity")] == pytest.approx(expected) + + def test_skinny_k(self, fe): + """Small K => low arithmetic intensity (memory-bound).""" + prob = _make_problem(m=8192, n=8192, k=32, dtype="fp8") + kern = _make_kernel() + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + assert vec[names.index("arithmetic_intensity")] < 100 + + def test_deep_k(self, fe): + """Large K => high arithmetic intensity (compute-bound).""" + prob = _make_problem(m=256, n=256, k=8192, dtype="fp8") + kern = _make_kernel() + vec = fe.extract(prob, kern) + names = fe.get_feature_names() + assert vec[names.index("arithmetic_intensity")] > 100 + + +# --------------------------------------------------------------------------- +# Corner-case shapes +# --------------------------------------------------------------------------- + + +class TestCornerCaseShapes: + def test_m1_single_token(self, fe): + vec = fe.extract(_make_problem(m=1, n=4096, k=4096), _make_kernel()) + assert not np.any(np.isnan(vec)) + + def test_m1_n1_k1_minimum(self, fe): + vec = fe.extract(_make_problem(m=1, n=1, k=1), _make_kernel()) + assert not np.any(np.isnan(vec)) + assert not np.any(np.isinf(vec)) + + def test_very_large_m(self, fe): + vec = fe.extract(_make_problem(m=20480, n=7168, k=7168), _make_kernel()) + assert not np.any(np.isnan(vec)) + + def test_non_power_of_2(self, fe): + vec = fe.extract(_make_problem(m=1536, n=7168, k=2304), _make_kernel()) + assert not np.any(np.isnan(vec)) + + def test_prime_dimensions(self, fe): + vec = fe.extract(_make_problem(m=17, n=31, k=127), _make_kernel()) + assert not np.any(np.isnan(vec)) + + def test_tall_matrix(self, fe): + """M >> N (tall matrix).""" + prob = _make_problem(m=16384, n=64, k=1024) + vec = fe.extract(prob, _make_kernel()) + names = fe.get_feature_names() + assert vec[names.index("aspect_ratio_mn")] > 100 + + def test_wide_matrix(self, fe): + """N >> M (wide matrix).""" + prob = 
_make_problem(m=64, n=16384, k=1024) + vec = fe.extract(prob, _make_kernel()) + names = fe.get_feature_names() + assert vec[names.index("aspect_ratio_mn")] < 0.01 + + +# --------------------------------------------------------------------------- +# Batch vs single extraction parity +# --------------------------------------------------------------------------- + + +class TestBatchParity: + def test_batch_matches_single(self, fe): + """Vectorized batch should produce identical results to row-by-row.""" + rows = [ + { + "m": 16, + "n": 1536, + "k": 7168, + "split_k": 1, + "dtype": "fp8", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 128, + "warp_m": 1, + "warp_n": 4, + "warp_k": 1, + "warp_tile_m": 16, + "warp_tile_n": 16, + "warp_tile_k": 128, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + }, + { + "m": 20480, + "n": 7168, + "k": 256, + "split_k": 1, + "dtype": "fp8", + "layout": "rcr", + "tile_m": 64, + "tile_n": 64, + "tile_k": 128, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + "pipeline": "mem", + "scheduler": "interwave", + "epilogue": "default", + "pad_m": True, + "pad_n": True, + "pad_k": True, + "persistent": True, + }, + ] + df = pd.DataFrame(rows) + batch_result = fe.extract_batch(df) + + for i, row_dict in enumerate(rows): + single_result = fe.extract(row_dict, row_dict) + np.testing.assert_allclose( + batch_result[i], + single_result, + rtol=1e-5, + atol=1e-5, + err_msg=f"Mismatch at row {i}", + ) + + +# --------------------------------------------------------------------------- +# Parameter space and validation +# --------------------------------------------------------------------------- + + +class TestParameterSpace: + def test_parameter_space_non_empty(self, fe): + ps = fe.get_parameter_space() + assert len(ps) > 0 + assert "tile_m" in ps + assert "pipeline" in ps + + def test_valid_config_passes(self, fe): + config = { + "tile_m": 128, + "tile_n": 128, + "tile_k": 64, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + } + assert fe.validate_config(config) is True + + def test_invalid_tile_rejected(self, fe): + config = {"tile_m": 999} + assert fe.validate_config(config) is False + + def test_lds_constraint_rejects_huge_tile(self, fe): + config = { + "tile_m": 256, + "tile_n": 256, + "tile_k": 256, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "pipeline": "compv4", + } + assert fe.validate_config(config) is False + + def test_project_to_valid_snaps(self, fe): + config = {"tile_m": 100, "tile_n": 200, "pipeline": "compv3"} + projected = fe.project_to_valid(config) + assert projected["tile_m"] == 128 + assert projected["tile_n"] == 192 + assert projected["pipeline"] == "compv3" + + +# --------------------------------------------------------------------------- +# Hardware features +# --------------------------------------------------------------------------- + + +class TestHardwareFeatures: + def test_hardware_values_propagated(self, fe): + vec = fe.extract(_make_problem(), _make_kernel()) + names = fe.get_feature_names() + assert vec[names.index("hw_num_cus")] == 256 + assert vec[names.index("hw_max_clock_mhz")] == 2400 + assert vec[names.index("hw_total_simds")] == 256 * 4 + assert vec[names.index("hw_num_xcd")] == 8 + + def 
test_different_hardware(self): + fe_small = GemmUniversalFeatureEngine(num_cus=120, max_clock_mhz=1800) + vec = fe_small.extract(_make_problem(), _make_kernel()) + names = fe_small.get_feature_names() + assert vec[names.index("hw_num_cus")] == 120 + assert vec[names.index("hw_max_clock_mhz")] == 1800 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dispatcher/heuristics/tests/test_feature_parity.py b/dispatcher/heuristics/tests/test_feature_parity.py new file mode 100644 index 0000000000..43f6968b88 --- /dev/null +++ b/dispatcher/heuristics/tests/test_feature_parity.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Test that the C++ extract_features() in ml_heuristic.hpp produces identical +values to the Python GemmUniversalFeatureEngine.extract(). + +This test uses ctypes to call the C++ feature extraction compiled into a +small shared library, then compares against Python output. If compilation +fails (no HIP/ROCm), it falls back to verifying the Python feature engine +against manually computed expected values for specific test cases. +""" + +import math +import sys +from pathlib import Path + +import numpy as np +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) +from feature_engine import ( + GemmUniversalFeatureEngine, + PIPELINE_MAP, + SCHEDULER_MAP, + EPILOGUE_MAP, + LAYOUT_MAP, +) + + +def _compute_features_manually( + M, + N, + K, + split_k, + dtype, + layout, + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline, + scheduler, + epilogue, + pad_m, + pad_n, + pad_k, + persistent, + hw, +): + """Recompute features independently to verify the Python engine.""" + bpe_map = {"fp8": 1.0, "fp16": 2.0, "bf16": 2.0, "fp32": 4.0} + bpe = bpe_map.get(dtype, 1.0) + + log2_M = math.log2(max(M, 1)) + log2_N = math.log2(max(N, 1)) + log2_K = math.log2(max(K, 1)) + log2_MNK = math.log2(max(M * N * K, 1)) + mem = (M * K + K * N + M * N) * bpe + ai = (2.0 * M * N * K) / max(mem, 1) + + lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe + lds_cap = 32768 if pipeline == "compv4" else hw["lds_capacity"] + + ntm = math.ceil(M / max(tile_m, 1)) + ntn = math.ceil(N / max(tile_n, 1)) + ntk = math.ceil(K / max(tile_k, 1)) + + def eff(d, t): + if t <= 0: + return 1.0 + r = d % t + return r / t if r > 0 else 1.0 + + # Problem-to-tile ratios + ratio_M_to_tile_m = M / max(tile_m, 1) + ratio_N_to_tile_n = N / max(tile_n, 1) + ratio_K_to_tile_k = K / max(tile_k, 1) + + # Binary features: problem smaller than tile + problem_smaller_than_tile_m = float(M < tile_m) + problem_smaller_than_tile_n = float(N < tile_n) + problem_smaller_than_tile_k = float(K < tile_k) + any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k)) + + # Padding requirement features + needs_padding_m = float(tile_m > 0 and M % tile_m != 0) + needs_padding_n = float(tile_n > 0 and N % tile_n != 0) + needs_padding_k = float(tile_k > 0 and K % tile_k != 0) + + # Interaction features + has_padding_when_needed_m = float(needs_padding_m and pad_m) + has_padding_when_needed_n = float(needs_padding_n and pad_n) + has_padding_when_needed_k = float(needs_padding_k and pad_k) + + # Missing padding features + missing_required_padding_m = float(needs_padding_m and not pad_m) + missing_required_padding_n = float(needs_padding_n and not pad_n) + missing_required_padding_k = float(needs_padding_k and not pad_k) + 
missing_any_required_padding = float( + missing_required_padding_m or missing_required_padding_n or missing_required_padding_k + ) + + return [ + M, # 0 + N, # 1 + K, # 2 + split_k, # 3 + log2_M, # 4 + log2_N, # 5 + log2_K, # 6 + log2_MNK, # 7 + ai, # 8 + M / max(N, 1), # 9 (aspect_ratio_mn) + M / max(K, 1), # 10 (aspect_ratio_mk) + N / max(K, 1), # 11 (aspect_ratio_nk) + LAYOUT_MAP.get(layout, 0), # 12 + tile_m, # 13 + tile_n, # 14 + tile_k, # 15 + warp_m, # 16 + warp_n, # 17 + warp_k, # 18 + warp_tile_m, # 19 + warp_tile_n, # 20 + warp_tile_k, # 21 + PIPELINE_MAP.get(pipeline, 0), # 22 + SCHEDULER_MAP.get(scheduler, 0), # 23 + EPILOGUE_MAP.get(epilogue, 0), # 24 + float(pad_m), # 25 + float(pad_n), # 26 + float(pad_k), # 27 + float(persistent), # 28 + warp_m * warp_n * warp_k, # 29 (num_warps) + tile_m * tile_n * tile_k, # 30 (tile_volume) + tile_m * tile_n, # 31 (tile_mn) + lds_est, # 32 (lds_usage_estimate) + lds_est / max(lds_cap, 1), # 33 (lds_usage_ratio) + ntm, # 34 (num_tiles_m) + ntn, # 35 (num_tiles_n) + ntk, # 36 (num_tiles_k) + ntm * ntn, # 37 (total_output_tiles) + eff(M, tile_m), # 38 (tile_eff_m) + eff(N, tile_n), # 39 (tile_eff_n) + eff(K, tile_k), # 40 (tile_eff_k) + eff(M, tile_m) * eff(N, tile_n) * eff(K, tile_k), # 41 (overall_tile_efficiency) + ntm * ntn / max(hw["num_cus"], 1), # 42 (cu_utilization) + ratio_M_to_tile_m, # 43 + ratio_N_to_tile_n, # 44 + ratio_K_to_tile_k, # 45 + problem_smaller_than_tile_m, # 46 + problem_smaller_than_tile_n, # 47 + problem_smaller_than_tile_k, # 48 + any_dim_too_small, # 49 + needs_padding_m, # 50 + needs_padding_n, # 51 + needs_padding_k, # 52 + has_padding_when_needed_m, # 53 + has_padding_when_needed_n, # 54 + has_padding_when_needed_k, # 55 + missing_required_padding_m, # 56 + missing_required_padding_n, # 57 + missing_required_padding_k, # 58 + missing_any_required_padding, # 59 + hw["num_cus"], # 60 + hw["simds_per_cu"], # 61 + hw["num_cus"] * hw["simds_per_cu"], # 62 (total_simds) + hw["shader_engines"], # 63 + hw["max_clock_mhz"], # 64 + hw["max_waves_per_cu"], # 65 + hw["wavefront_size"], # 66 + hw["lds_capacity"], # 67 + hw["l1_cache_kb"], # 68 + hw["l2_cache_kb"], # 69 + hw["l3_cache_kb"], # 70 + hw["num_xcd"], # 71 + ] + + +TEST_CASES = [ + { + "problem": { + "m": 1024, + "n": 1024, + "k": 1024, + "split_k": 1, + "dtype": "fp8", + "layout": "rcr", + }, + "kernel": { + "tile_m": 128, + "tile_n": 128, + "tile_k": 64, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + }, + }, + { + "problem": { + "m": 1, + "n": 4096, + "k": 4096, + "split_k": 1, + "dtype": "fp8", + "layout": "rcr", + }, + "kernel": { + "tile_m": 64, + "tile_n": 64, + "tile_k": 128, + "warp_m": 1, + "warp_n": 4, + "warp_k": 1, + "warp_tile_m": 16, + "warp_tile_n": 16, + "warp_tile_k": 128, + "pipeline": "compv4", + "scheduler": "interwave", + "epilogue": "default", + "pad_m": True, + "pad_n": True, + "pad_k": True, + "persistent": True, + }, + }, + { + "problem": { + "m": 20480, + "n": 7168, + "k": 256, + "split_k": 1, + "dtype": "fp16", + "layout": "rrr", + }, + "kernel": { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 4, + "warp_n": 1, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + "pipeline": "mem", + "scheduler": "interwave", + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": 
False, + "persistent": False, + }, + }, +] + +HW = { + "num_cus": 256, + "simds_per_cu": 4, + "shader_engines": 32, + "max_clock_mhz": 2400, + "max_waves_per_cu": 32, + "wavefront_size": 64, + "lds_capacity": 65536, + "l1_cache_kb": 32, + "l2_cache_kb": 4096, + "l3_cache_kb": 262144, + "num_xcd": 8, +} + + +class TestFeatureParity: + """Verify Python feature engine matches manual computation (C++ uses same logic).""" + + @pytest.fixture + def fe(self): + return GemmUniversalFeatureEngine(**HW) + + @pytest.mark.parametrize("case_idx", range(len(TEST_CASES))) + def test_python_matches_manual(self, fe, case_idx): + case = TEST_CASES[case_idx] + prob = case["problem"] + kern = case["kernel"] + + py_features = fe.extract(prob, kern) + + manual = _compute_features_manually( + prob["m"], + prob["n"], + prob["k"], + prob["split_k"], + prob["dtype"], + prob["layout"], + kern["tile_m"], + kern["tile_n"], + kern["tile_k"], + kern["warp_m"], + kern["warp_n"], + kern["warp_k"], + kern["warp_tile_m"], + kern["warp_tile_n"], + kern["warp_tile_k"], + kern["pipeline"], + kern["scheduler"], + kern["epilogue"], + kern["pad_m"], + kern["pad_n"], + kern["pad_k"], + kern["persistent"], + HW, + ) + + manual_arr = np.array(manual, dtype=np.float64) + assert len(py_features) == len(manual_arr) == 72 + + for i in range(72): + assert py_features[i] == pytest.approx( + manual_arr[i], rel=1e-10, abs=1e-15 + ), ( + f"Feature {i} ({fe.get_feature_names()[i]}): Python={py_features[i]}, Manual={manual_arr[i]}" + ) + + def test_feature_count(self, fe): + assert len(fe.get_feature_names()) == 72 + + def test_encoding_maps_match_cpp(self): + """The C++ encode_* functions must use the same mapping as Python.""" + assert PIPELINE_MAP == { + "compv3": 0, + "compv4": 1, + "compv5": 2, + "mem": 3, + "preshufflev2": 4, + } + assert SCHEDULER_MAP == {"intrawave": 0, "interwave": 1} + assert EPILOGUE_MAP == {"default": 0, "cshuffle": 1} + assert LAYOUT_MAP == {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3} + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dispatcher/heuristics/tests/test_model_compression.py b/dispatcher/heuristics/tests/test_model_compression.py new file mode 100644 index 0000000000..50727f1242 --- /dev/null +++ b/dispatcher/heuristics/tests/test_model_compression.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +"""Test that compressed models can be loaded and used.""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from predict import Predictor + + +def test_fp16_model_decompression(): + """Test that fp16 model is auto-decompressed and usable.""" + model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp16_gfx950" + + # Ensure .lgbm.gz exists + gz_file = model_dir / "model_tflops.lgbm.gz" + + assert gz_file.exists(), f"Compressed model not found: {gz_file}" + + # Load predictor - should auto-decompress + predictor = Predictor(model_dir) + + # Test prediction + problem = {"m": 128, "n": 1536, "k": 7168, "dtype": "fp16", "layout": "rcr"} + kernel_config = { + "tile_shape": {"m0": 128, "n0": 128, "k0": 16}, + "wave_shape": {"m1": 2, "n1": 2, "k1": 1}, + "warp_tile": {"m2": 32, "n2": 32, "k2": 8}, + } + + tflops = predictor.predict_tflops(problem, kernel_config) + + assert isinstance(tflops, float), f"Expected float, got {type(tflops)}" + assert tflops > 0, f"Expected positive TFLOPS, got {tflops}" + + # Verify decompressed file was created + lgbm_file = model_dir / "model_tflops.lgbm" + assert lgbm_file.exists(), "Model 
should have been decompressed" + + print(f"✅ FP16 model decompression test passed") + print(f" Predicted TFLOPS: {tflops:.2f}") + print(f" Decompressed to: {lgbm_file}") + return True + + +def test_fp8_model_decompression(): + """Test that fp8 model is auto-decompressed and usable.""" + model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp8_gfx950" + + # Ensure .lgbm.gz exists + gz_file = model_dir / "model_tflops.lgbm.gz" + + assert gz_file.exists(), f"Compressed model not found: {gz_file}" + + # Load predictor - should auto-decompress + predictor = Predictor(model_dir) + + # Test prediction + problem = {"m": 2048, "n": 2048, "k": 2048, "dtype": "fp8", "layout": "rcr"} + kernel_config = { + "tile_shape": {"m0": 256, "n0": 256, "k0": 64}, + "wave_shape": {"m1": 2, "n1": 2, "k1": 1}, + "warp_tile": {"m2": 32, "n2": 32, "k2": 16}, + } + + tflops = predictor.predict_tflops(problem, kernel_config) + + assert isinstance(tflops, float), f"Expected float, got {type(tflops)}" + assert tflops > 0, f"Expected positive TFLOPS, got {tflops}" + + # Verify decompressed file was created + lgbm_file = model_dir / "model_tflops.lgbm" + assert lgbm_file.exists(), "Model should have been decompressed" + + print(f"✅ FP8 model decompression test passed") + print(f" Predicted TFLOPS: {tflops:.2f}") + print(f" Decompressed to: {lgbm_file}") + return True + + +if __name__ == "__main__": + print("Testing compressed model auto-decompression...") + print() + + test_fp16_model_decompression() + print() + test_fp8_model_decompression() + print() + print("✅ All model compression tests passed!") diff --git a/dispatcher/heuristics/tests/test_predict.py b/dispatcher/heuristics/tests/test_predict.py new file mode 100644 index 0000000000..24cb26c4fa --- /dev/null +++ b/dispatcher/heuristics/tests/test_predict.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for predict.py. + +Covers: Predictor initialization, single prediction, ranking, select_best, +missing model handling, and edge cases (single kernel, empty list). 
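+
+The fixture models are tiny LightGBM regressors fitted on random data, so the
+predicted values are meaningless; only API behaviour (types, ordering, error
+handling) is asserted.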
+""" + +import json +import sys +from pathlib import Path + +import lightgbm as lgb +import numpy as np +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from feature_engine import GemmUniversalFeatureEngine +from predict import Predictor + + +@pytest.fixture +def model_dir(tmp_path): + """Create a minimal trained model for testing.""" + fe = GemmUniversalFeatureEngine() + n_features = len(fe.get_feature_names()) + + np.random.seed(42) + X = np.random.rand(200, n_features) + y = np.random.rand(200) * 100 + + model = lgb.LGBMRegressor(n_estimators=10, verbose=-1) + model.fit(X, y) + model.booster_.save_model(str(tmp_path / "model_tflops.lgbm")) + + y_lat = np.random.rand(200) * 0.1 + model_lat = lgb.LGBMRegressor(n_estimators=10, verbose=-1) + model_lat.fit(X, y_lat) + model_lat.booster_.save_model(str(tmp_path / "model_latency.lgbm")) + + spec = { + "feature_names": fe.get_feature_names(), + "categorical_features": fe.get_categorical_features(), + } + with open(tmp_path / "feature_spec.json", "w") as f: + json.dump(spec, f) + + return tmp_path + + +@pytest.fixture +def predictor(model_dir): + return Predictor(model_dir) + + +def _problem(): + return { + "m": 1024, + "n": 1024, + "k": 1024, + "dtype": "fp8", + "layout": "rcr", + "split_k": 1, + } + + +def _kernel(tile_m=128, pipeline="compv3"): + return { + "kernel_name": f"test_kernel_{tile_m}_{pipeline}", + "tile_m": tile_m, + "tile_n": 128, + "tile_k": 64, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + "pipeline": pipeline, + "scheduler": "intrawave", + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + } + + +class TestPredictor: + def test_predict_tflops_returns_float(self, predictor): + result = predictor.predict_tflops(_problem(), _kernel()) + assert isinstance(result, float) + + def test_predict_latency_returns_float(self, predictor): + result = predictor.predict_latency(_problem(), _kernel()) + assert isinstance(result, float) + + def test_predict_all_returns_dict(self, predictor): + result = predictor.predict_all(_problem(), _kernel()) + assert "tflops" in result + assert "latency_ms" in result + + def test_rank_kernels_sorted_descending(self, predictor): + kernels = [_kernel(64, "compv3"), _kernel(128, "compv4"), _kernel(256, "mem")] + ranked = predictor.rank_kernels(_problem(), kernels) + assert len(ranked) == 3 + scores = [s for _, s in ranked] + assert scores == sorted(scores, reverse=True) + + def test_select_best_returns_name(self, predictor): + kernels = [_kernel(64), _kernel(128)] + best = predictor.select_best(_problem(), kernels) + assert isinstance(best, str) + assert best in [k["kernel_name"] for k in kernels] + + def test_single_kernel(self, predictor): + kernels = [_kernel(128)] + ranked = predictor.rank_kernels(_problem(), kernels) + assert len(ranked) == 1 + + def test_missing_bandwidth_model(self, model_dir): + pred = Predictor(model_dir) + with pytest.raises(FileNotFoundError): + pred.predict_bandwidth(_problem(), _kernel()) + + def test_empty_kernel_list(self, predictor): + with pytest.raises(ValueError): + predictor.select_best(_problem(), []) + + def test_corner_case_m1(self, predictor): + prob = { + "m": 1, + "n": 4096, + "k": 4096, + "dtype": "fp8", + "layout": "rcr", + "split_k": 1, + } + result = predictor.predict_tflops(prob, _kernel()) + assert np.isfinite(result) + + def test_different_shapes_give_different_results(self, predictor): + k = _kernel() + r1 = 
predictor.predict_tflops( + { + "m": 16, + "n": 1536, + "k": 7168, + "dtype": "fp8", + "layout": "rcr", + "split_k": 1, + }, + k, + ) + r2 = predictor.predict_tflops( + { + "m": 20480, + "n": 7168, + "k": 256, + "dtype": "fp8", + "layout": "rcr", + "split_k": 1, + }, + k, + ) + assert r1 != r2 + + +class TestPredictorEdgeCases: + def test_nonexistent_model_dir(self): + with pytest.raises(Exception): + pred = Predictor("/nonexistent/path") + pred.predict_tflops(_problem(), _kernel()) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dispatcher/heuristics/tests/test_search.py b/dispatcher/heuristics/tests/test_search.py new file mode 100644 index 0000000000..b1d1ac79b3 --- /dev/null +++ b/dispatcher/heuristics/tests/test_search.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for search.py. + +Covers: random search, DE search, config validity, result ordering, +budget compliance, and edge cases. +""" + +import json +import sys +from pathlib import Path + +import lightgbm as lgb +import numpy as np +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from feature_engine import GemmUniversalFeatureEngine +from predict import Predictor +from search import SurrogateSearch + + +@pytest.fixture +def model_dir(tmp_path): + """Create a minimal trained model.""" + fe = GemmUniversalFeatureEngine() + n_features = len(fe.get_feature_names()) + np.random.seed(42) + X = np.random.rand(200, n_features) + y = np.random.rand(200) * 500 + model = lgb.LGBMRegressor(n_estimators=10, verbose=-1) + model.fit(X, y) + model.booster_.save_model(str(tmp_path / "model_tflops.lgbm")) + spec = { + "feature_names": fe.get_feature_names(), + "categorical_features": fe.get_categorical_features(), + } + with open(tmp_path / "feature_spec.json", "w") as f: + json.dump(spec, f) + return tmp_path + + +@pytest.fixture +def predictor(model_dir): + return Predictor(model_dir) + + +def _problem(): + return { + "m": 1024, + "n": 1024, + "k": 1024, + "dtype": "fp8", + "layout": "rcr", + "split_k": 1, + } + + +class TestRandomSearch: + def test_returns_results(self, predictor): + searcher = SurrogateSearch(predictor, strategy="random") + results = searcher.search(_problem(), budget=50, top_k=5) + assert len(results) > 0 + assert len(results) <= 5 + + def test_results_sorted_descending(self, predictor): + searcher = SurrogateSearch(predictor, strategy="random") + results = searcher.search(_problem(), budget=100, top_k=10) + scores = [s for _, s in results] + assert scores == sorted(scores, reverse=True) + + def test_configs_are_valid(self, predictor): + fe = GemmUniversalFeatureEngine() + searcher = SurrogateSearch(predictor, feature_engine=fe, strategy="random") + results = searcher.search(_problem(), budget=50, top_k=5) + for cfg, _ in results: + ps = fe.get_parameter_space() + for k, v in cfg.items(): + if k in ps: + assert v in ps[k], f"{k}={v} not in {ps[k]}" + + def test_respects_top_k(self, predictor): + searcher = SurrogateSearch(predictor, strategy="random") + results = searcher.search(_problem(), budget=100, top_k=3) + assert len(results) <= 3 + + def test_different_problems_produce_results(self, predictor): + """Both problem sizes should produce valid search results.""" + searcher = SurrogateSearch(predictor, strategy="random", seed=42) + r1 = searcher.search( + { + "m": 16, + "n": 1536, + "k": 7168, + "dtype": "fp8", + "layout": "rcr", + "split_k": 1, + }, + 
budget=50, + top_k=3, + ) + searcher2 = SurrogateSearch(predictor, strategy="random", seed=42) + r2 = searcher2.search( + { + "m": 20480, + "n": 7168, + "k": 256, + "dtype": "fp8", + "layout": "rcr", + "split_k": 1, + }, + budget=50, + top_k=3, + ) + assert len(r1) > 0 + assert len(r2) > 0 + for _, score in r1 + r2: + assert np.isfinite(score) + + def test_m1_corner_case(self, predictor): + searcher = SurrogateSearch(predictor, strategy="random") + results = searcher.search( + { + "m": 1, + "n": 4096, + "k": 4096, + "dtype": "fp8", + "layout": "rcr", + "split_k": 1, + }, + budget=50, + top_k=5, + ) + assert len(results) > 0 + for _, score in results: + assert np.isfinite(score) + + +class TestDESearch: + def test_returns_results(self, predictor): + searcher = SurrogateSearch(predictor, strategy="de") + results = searcher.search(_problem(), budget=100, top_k=5) + assert len(results) > 0 + + def test_results_sorted_descending(self, predictor): + searcher = SurrogateSearch(predictor, strategy="de") + results = searcher.search(_problem(), budget=100, top_k=5) + scores = [s for _, s in results] + assert scores == sorted(scores, reverse=True) + + def test_de_improves_over_initial(self, predictor): + """DE should generally find at least as good as random initialization.""" + searcher_r = SurrogateSearch(predictor, strategy="random", seed=42) + r_results = searcher_r.search(_problem(), budget=100, top_k=1) + searcher_d = SurrogateSearch(predictor, strategy="de", seed=42) + d_results = searcher_d.search(_problem(), budget=100, top_k=1) + if r_results and d_results: + assert d_results[0][1] >= r_results[0][1] * 0.9 + + def test_small_budget(self, predictor): + searcher = SurrogateSearch(predictor, strategy="de") + results = searcher.search(_problem(), budget=30, top_k=5) + assert len(results) > 0 + + +class TestSearchEdgeCases: + def test_unknown_strategy_raises(self, predictor): + searcher = SurrogateSearch(predictor, strategy="unknown") + with pytest.raises(ValueError): + searcher.search(_problem(), budget=10) + + def test_zero_budget(self, predictor): + searcher = SurrogateSearch(predictor, strategy="random") + results = searcher.search(_problem(), budget=0, top_k=5) + assert len(results) == 0 + + def test_deterministic_with_same_seed(self, predictor): + s1 = SurrogateSearch(predictor, strategy="random", seed=123) + s2 = SurrogateSearch(predictor, strategy="random", seed=123) + r1 = s1.search(_problem(), budget=50, top_k=5) + r2 = s2.search(_problem(), budget=50, top_k=5) + assert len(r1) == len(r2) + for (c1, s1_), (c2, s2_) in zip(r1, r2): + assert s1_ == pytest.approx(s2_) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dispatcher/heuristics/tests/test_train.py b/dispatcher/heuristics/tests/test_train.py new file mode 100644 index 0000000000..d437030bfa --- /dev/null +++ b/dispatcher/heuristics/tests/test_train.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for train.py. + +Covers: group key computation, TFLOPS efficiency calculation, edge cases +(single group, all-invalid data, tied predictions), and warm-start +incremental training (feature compat, lineage, quality). 
+""" + +import json +import sys +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from feature_engine import GemmUniversalFeatureEngine +from train import ( + compute_group_keys, + compute_tflops_efficiency, + check_feature_compatibility, + load_warm_start_model, + train_final_model, + DEFAULT_PARAMS, +) + + +class TestComputeGroupKeys: + def test_basic(self): + df = pd.DataFrame( + {"m": [16, 16, 32], "n": [1536, 1536, 1536], "k": [7168, 7168, 7168]} + ) + keys = compute_group_keys(df) + assert keys[0] == keys[1] + assert keys[0] != keys[2] + + def test_unique_shapes(self): + df = pd.DataFrame({"m": [1, 2, 3], "n": [4, 5, 6], "k": [7, 8, 9]}) + keys = compute_group_keys(df) + assert len(set(keys)) == 3 + + +class TestComputeTflopsEfficiency: + def test_perfect_prediction(self): + """Model predicts highest TFLOPS kernel => efficiency = 1.0.""" + df = pd.DataFrame( + { + "m": [1024, 1024, 1024], + "n": [1024, 1024, 1024], + "k": [1024, 1024, 1024], + "measured_tflops": [100, 200, 150], + "pred_tflops": [50, 300, 100], # correctly ranks kernel 1 highest + } + ) + eff = compute_tflops_efficiency(df, "pred_tflops") + assert len(eff) == 1 + assert eff["efficiency"].iloc[0] == pytest.approx(1.0) + + def test_worst_prediction(self): + """Model picks the worst kernel.""" + df = pd.DataFrame( + { + "m": [1024, 1024, 1024], + "n": [1024, 1024, 1024], + "k": [1024, 1024, 1024], + "measured_tflops": [100, 200, 150], + "pred_tflops": [999, 1, 1], # incorrectly ranks kernel 0 highest + } + ) + eff = compute_tflops_efficiency(df, "pred_tflops") + assert eff["efficiency"].iloc[0] == pytest.approx(100 / 200) + + def test_multiple_shapes(self): + df = pd.DataFrame( + { + "m": [16, 16, 32, 32], + "n": [1536, 1536, 1536, 1536], + "k": [7168, 7168, 7168, 7168], + "measured_tflops": [10, 20, 100, 200], + "pred_tflops": [5, 25, 150, 190], + } + ) + eff = compute_tflops_efficiency(df, "pred_tflops") + assert len(eff) == 2 + assert eff.iloc[0]["efficiency"] == pytest.approx(1.0) + assert eff.iloc[1]["efficiency"] == pytest.approx(1.0) + + def test_zero_tflops_shape_skipped(self): + df = pd.DataFrame( + { + "m": [16, 16], + "n": [16, 16], + "k": [16, 16], + "measured_tflops": [0, 0], + "pred_tflops": [1, 2], + } + ) + eff = compute_tflops_efficiency(df, "pred_tflops") + assert len(eff) == 0 + + def test_single_kernel_per_shape(self): + df = pd.DataFrame( + { + "m": [1024], + "n": [1024], + "k": [1024], + "measured_tflops": [150], + "pred_tflops": [100], + } + ) + eff = compute_tflops_efficiency(df, "pred_tflops") + assert len(eff) == 1 + assert eff["efficiency"].iloc[0] == pytest.approx(1.0) + + def test_tied_predictions(self): + """When multiple kernels have the same predicted TFLOPS, pandas idxmax picks the first.""" + df = pd.DataFrame( + { + "m": [1024, 1024, 1024], + "n": [1024, 1024, 1024], + "k": [1024, 1024, 1024], + "measured_tflops": [100, 200, 200], + "pred_tflops": [50, 50, 50], + } + ) + eff = compute_tflops_efficiency(df, "pred_tflops") + assert len(eff) == 1 + assert eff["efficiency"].iloc[0] >= 0.5 + + +# --------------------------------------------------------------------------- +# Helpers for warm-start tests +# --------------------------------------------------------------------------- + + +def _make_dummy_data(n_rows=200, n_shapes=5): + """Create a small synthetic benchmark DataFrame for testing training.""" + rng = np.random.RandomState(42) + rows = [] + for _ in range(n_rows): + m = 
rng.choice([64, 128, 256, 512, 1024]) + n = rng.choice([64, 128, 256, 512, 1024]) + k = rng.choice([64, 128, 256, 512, 1024]) + rows.append( + { + "m": m, + "n": n, + "k": k, + "split_k": 1, + "dtype": "fp8", + "layout": "rcr", + "op_type": "gemm_universal", + "tile_m": rng.choice([64, 128, 256]), + "tile_n": rng.choice([64, 128, 256]), + "tile_k": rng.choice([32, 64, 128]), + "warp_m": rng.choice([1, 2, 4]), + "warp_n": rng.choice([1, 2, 4]), + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + "pipeline": rng.choice(["compv3", "compv4", "mem"]), + "scheduler": rng.choice(["intrawave", "interwave"]), + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + "measured_tflops": float(rng.uniform(10, 500)), + "latency_ms": float(rng.uniform(0.01, 1.0)), + "bandwidth_gb_s": float(rng.uniform(50, 1500)), + "is_valid": True, + "kernel_name": f"test_kernel_{rng.randint(0, 100)}", + } + ) + return pd.DataFrame(rows) + + +def _save_feature_spec(model_dir, fe): + """Save a feature_spec.json matching the given feature engine.""" + spec = { + "feature_names": fe.get_feature_names(), + "categorical_features": fe.get_categorical_features(), + } + with open(model_dir / "feature_spec.json", "w") as f: + json.dump(spec, f) + + +def _train_and_save_base_model(model_dir, df, fe, target="tflops"): + """Train a small base model and save it to model_dir.""" + params = dict(DEFAULT_PARAMS) + params["n_estimators"] = 20 + params["n_jobs"] = 1 + model = train_final_model(df, fe, target, params) + model.booster_.save_model(str(model_dir / f"model_{target}.lgbm")) + _save_feature_spec(model_dir, fe) + return model + + +# --------------------------------------------------------------------------- +# Warm-start tests +# --------------------------------------------------------------------------- + + +class TestCheckFeatureCompatibility: + def test_compatible_passes(self, tmp_path): + fe = GemmUniversalFeatureEngine() + _save_feature_spec(tmp_path, fe) + check_feature_compatibility(tmp_path, fe) + + def test_missing_spec_raises(self, tmp_path): + fe = GemmUniversalFeatureEngine() + with pytest.raises(FileNotFoundError, match="feature_spec.json"): + check_feature_compatibility(tmp_path, fe) + + def test_added_feature_raises(self, tmp_path): + fe = GemmUniversalFeatureEngine() + spec = { + "feature_names": fe.get_feature_names()[:-1], + "categorical_features": fe.get_categorical_features(), + } + with open(tmp_path / "feature_spec.json", "w") as f: + json.dump(spec, f) + with pytest.raises(ValueError, match="Feature schema mismatch"): + check_feature_compatibility(tmp_path, fe) + + def test_removed_feature_raises(self, tmp_path): + fe = GemmUniversalFeatureEngine() + spec = { + "feature_names": fe.get_feature_names() + ["extra_feature"], + "categorical_features": fe.get_categorical_features(), + } + with open(tmp_path / "feature_spec.json", "w") as f: + json.dump(spec, f) + with pytest.raises(ValueError, match="Feature schema mismatch"): + check_feature_compatibility(tmp_path, fe) + + def test_categorical_mismatch_raises(self, tmp_path): + fe = GemmUniversalFeatureEngine() + spec = { + "feature_names": fe.get_feature_names(), + "categorical_features": ["layout", "pipeline"], + } + with open(tmp_path / "feature_spec.json", "w") as f: + json.dump(spec, f) + with pytest.raises(ValueError, match="Categorical feature mismatch"): + check_feature_compatibility(tmp_path, fe) + + +class TestLoadWarmStartModel: + def test_loads_existing_model(self, 
tmp_path): + fe = GemmUniversalFeatureEngine() + df = _make_dummy_data() + _train_and_save_base_model(tmp_path, df, fe) + path = load_warm_start_model(tmp_path, "tflops") + assert path is not None + assert Path(path).exists() + + def test_returns_none_for_missing_target(self, tmp_path): + assert load_warm_start_model(tmp_path, "tflops") is None + + def test_returns_none_for_wrong_target(self, tmp_path): + fe = GemmUniversalFeatureEngine() + df = _make_dummy_data() + _train_and_save_base_model(tmp_path, df, fe, target="tflops") + assert load_warm_start_model(tmp_path, "bandwidth") is None + + +class TestWarmStartTraining: + def test_warm_start_produces_more_trees(self, tmp_path): + """A warm-started model should have more trees than the base.""" + fe = GemmUniversalFeatureEngine() + df = _make_dummy_data(n_rows=300) + + base_dir = tmp_path / "base" + base_dir.mkdir() + base_model = _train_and_save_base_model(base_dir, df, fe) + base_n_trees = base_model.booster_.num_trees() + + init_model_path = load_warm_start_model(base_dir, "tflops") + params = dict(DEFAULT_PARAMS) + params["n_estimators"] = 15 + params["n_jobs"] = 1 + warm_model = train_final_model( + df, fe, "tflops", params, init_model=init_model_path + ) + warm_n_trees = warm_model.booster_.num_trees() + + assert warm_n_trees > base_n_trees + + def test_warm_start_does_not_degrade(self, tmp_path): + """Warm-started model on the same data should not be significantly worse.""" + fe = GemmUniversalFeatureEngine() + df = _make_dummy_data(n_rows=300) + + base_dir = tmp_path / "base" + base_dir.mkdir() + base_model = _train_and_save_base_model(base_dir, df, fe) + + X = fe.extract_batch(df[df["is_valid"]].reset_index(drop=True)) + y = df[df["is_valid"]]["measured_tflops"].values + base_rmse = np.sqrt(np.mean((base_model.predict(X) - y) ** 2)) + + init_model_path = load_warm_start_model(base_dir, "tflops") + params = dict(DEFAULT_PARAMS) + params["n_estimators"] = 15 + params["n_jobs"] = 1 + warm_model = train_final_model( + df, fe, "tflops", params, init_model=init_model_path + ) + warm_rmse = np.sqrt(np.mean((warm_model.predict(X) - y) ** 2)) + + assert warm_rmse <= base_rmse * 1.1 + + def test_warm_start_from_nonexistent_dir(self): + with pytest.raises(FileNotFoundError): + check_feature_compatibility( + Path("/nonexistent/model/dir"), GemmUniversalFeatureEngine() + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dispatcher/heuristics/train.py b/dispatcher/heuristics/train.py new file mode 100644 index 0000000000..6d5dc772ac --- /dev/null +++ b/dispatcher/heuristics/train.py @@ -0,0 +1,555 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Training script for CK Tile kernel performance prediction. + +Trains LGBMRegressor models (TFLOPS, latency, bandwidth) with: + - Log-space regression (log1p transform) for scale-invariant accuracy + - GroupKFold cross-validation (group key = (M, N, K)) + - Iterative Hard Example Mining (IHEM) + - Model complexity bounds for C++ deployability + - Optional Optuna hyperparameter tuning + - Warm-start incremental training from a previous model via --warm_start + +Log-transform rationale: + GEMM TFLOPS spans 5 orders of magnitude (0.02 for M=1 to 2230 for large + shapes). Raw regression optimizes for absolute RMSE, which means the model + spends all its capacity predicting large shapes accurately and ignores tiny + shapes where TFLOPS is < 10. 
Training on log1p(TFLOPS) puts all shapes on + equal footing, improving tiny_m efficiency from 84% to 96%. +""" + +import argparse +import json +import time +from pathlib import Path + +import lightgbm as lgb +import numpy as np +import pandas as pd +from sklearn.model_selection import GroupKFold + +from data_pipeline import build_training_dataset +from feature_engine import GemmUniversalFeatureEngine + + +TARGET_COLUMNS = { + "tflops": "measured_tflops", + "latency": "latency_ms", + "bandwidth": "bandwidth_gb_s", +} + +# Targets where log1p transform is applied by default. +# TFLOPS and bandwidth span orders of magnitude; latency is already small-scale. +LOG_TARGETS = {"tflops", "bandwidth"} + +DEFAULT_PARAMS = { + "objective": "regression", + "metric": ["rmse", "mae"], + "num_leaves": 255, + "max_depth": 15, + "n_estimators": 2000, + "learning_rate": 0.02, + "min_child_samples": 10, + "subsample": 0.85, + "colsample_bytree": 0.85, + "reg_alpha": 0.05, + "reg_lambda": 0.5, + "verbose": -1, + "n_jobs": 8, + "seed": 42, +} + +MAX_ESTIMATORS = 5000 +WARM_START_N_ESTIMATORS = 500 + + +def check_feature_compatibility( + prev_model_dir: Path, + feature_engine: GemmUniversalFeatureEngine, +) -> None: + """Verify that the previous model's feature spec matches the current engine. + + Raises ValueError with a detailed message on mismatch. This prevents silent + corruption when warm-starting from a model trained with a different feature + schema (e.g., after adding a new feature or changing an encoding). + """ + spec_path = prev_model_dir / "feature_spec.json" + if not spec_path.exists(): + raise FileNotFoundError( + f"No feature_spec.json in {prev_model_dir}. " + "Cannot verify feature compatibility for warm start." + ) + + with open(spec_path) as f: + prev_spec = json.load(f) + + prev_names = prev_spec.get("feature_names", []) + curr_names = feature_engine.get_feature_names() + if prev_names != curr_names: + added = set(curr_names) - set(prev_names) + removed = set(prev_names) - set(curr_names) + parts = ["Feature schema mismatch between previous model and current engine."] + if added: + parts.append(f" Added features: {sorted(added)}") + if removed: + parts.append(f" Removed features: {sorted(removed)}") + if not added and not removed: + parts.append(" Feature order changed (names match but order differs).") + raise ValueError("\n".join(parts)) + + prev_cats = prev_spec.get("categorical_features", []) + curr_cats = feature_engine.get_categorical_features() + if sorted(prev_cats) != sorted(curr_cats): + raise ValueError( + f"Categorical feature mismatch.\n" + f" Previous: {sorted(prev_cats)}\n" + f" Current: {sorted(curr_cats)}" + ) + + +def load_warm_start_model(prev_model_dir: Path, target: str) -> str | None: + """Load the path to a previous model file for warm-start, or None if absent. + + Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist. + The decompressed file is cached to disk for subsequent loads. + + Returns the string path (what LightGBM's init_model expects) rather than + a loaded Booster, because LGBMRegressor.fit(init_model=...) accepts both + path strings and Booster objects and path strings avoid keeping the old + model in memory. 
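+
+    Illustrative sketch only (the model directory name, params, X, and y below are
+    assumptions for the example, not values defined in this module):
+
+        init_path = load_warm_start_model(Path("models/gemm_universal_fp8_gfx950"), "tflops")
+        model = lgb.LGBMRegressor(**params)
+        model.fit(X, y, init_model=init_path)  # init_model=None simply trains from scratch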
+ """ + import gzip + + model_path = prev_model_dir / f"model_{target}.lgbm" + gz_path = prev_model_dir / f"model_{target}.lgbm.gz" + + # Auto-decompress if needed + if not model_path.exists() and gz_path.exists(): + print(f" Decompressing {gz_path.name}...") + with gzip.open(gz_path, "rb") as f_in: + with open(model_path, "wb") as f_out: + f_out.write(f_in.read()) + + if not model_path.exists(): + return None + return str(model_path) + + +def compute_group_keys(df: pd.DataFrame) -> np.ndarray: + """Create GroupKFold group keys from (M, N, K).""" + return ( + df["m"].astype(str) + "_" + df["n"].astype(str) + "_" + df["k"].astype(str) + ).values + + +def compute_tflops_efficiency( + df: pd.DataFrame, pred_col: str = "pred_tflops" +) -> pd.DataFrame: + """Compute per-shape efficiency: predicted-best TFLOPS / oracle-best TFLOPS.""" + results = [] + for (m, n, k), group in df.groupby(["m", "n", "k"]): + oracle_best = group["measured_tflops"].max() + if oracle_best <= 0: + continue + pred_best_idx = group[pred_col].idxmax() + selected_tflops = group.loc[pred_best_idx, "measured_tflops"] + efficiency = selected_tflops / oracle_best + results.append( + { + "m": m, + "n": n, + "k": k, + "oracle_best_tflops": oracle_best, + "selected_tflops": selected_tflops, + "efficiency": efficiency, + } + ) + return pd.DataFrame(results) + + +def train_single_target( + X_train, + y_train, + X_val, + y_val, + params: dict, + categorical_features: list[str], + feature_names: list[str], + init_model=None, +) -> lgb.LGBMRegressor: + """Train a single LGBMRegressor with early stopping. + + Parameters + ---------- + init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None + If provided, training continues from this model (warm start). + Accepts a file path to a .lgbm file, a Booster instance, or an + LGBMModel instance. The new model adds n_estimators trees on top + of the existing ones. + """ + cat_indices = [ + feature_names.index(c) for c in categorical_features if c in feature_names + ] + + model = lgb.LGBMRegressor(**params) + model.fit( + X_train, + y_train, + eval_set=[(X_val, y_val)], + eval_metric=["rmse"], + callbacks=[ + lgb.early_stopping(50, verbose=False), + lgb.log_evaluation(0), + ], + categorical_feature=cat_indices if cat_indices else "auto", + init_model=init_model, + ) + return model + + +def run_cv( + df: pd.DataFrame, + feature_engine: GemmUniversalFeatureEngine, + target: str, + params: dict, + n_splits: int = 5, + use_log: bool = True, +) -> dict: + """Run GroupKFold cross-validation and return OOF predictions + metrics. + + Parameters + ---------- + use_log : bool + If True and target is in LOG_TARGETS, train on log1p(y) and invert + predictions with expm1 for efficiency calculation. This normalizes + the scale so that tiny-M shapes (TFLOPS ~ 1) get equal attention + as large-M shapes (TFLOPS ~ 2000). 
+ """ + target_col = TARGET_COLUMNS[target] + valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0) + df_valid = df[valid_mask].reset_index(drop=True) + + apply_log = use_log and target in LOG_TARGETS + + print( + f" Training on {len(df_valid)} valid rows for target={target}" + f"{' (log-space)' if apply_log else ''}" + ) + + X = feature_engine.extract_batch(df_valid) + y_raw = df_valid[target_col].values + y = np.log1p(y_raw) if apply_log else y_raw + groups = compute_group_keys(df_valid) + feature_names = feature_engine.get_feature_names() + cat_features = feature_engine.get_categorical_features() + + unique_groups = np.unique(groups) + actual_splits = min(n_splits, len(unique_groups)) + if actual_splits < 2: + print(f" WARNING: Only {len(unique_groups)} unique groups, skipping CV") + return {} + + gkf = GroupKFold(n_splits=actual_splits) + oof_preds = np.zeros(len(df_valid)) + fold_metrics = [] + + for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)): + X_tr, X_val = X[train_idx], X[val_idx] + y_tr, y_val = y[train_idx], y[val_idx] + + model = train_single_target( + X_tr, y_tr, X_val, y_val, params, cat_features, feature_names + ) + preds = model.predict(X_val) + oof_preds[val_idx] = preds + + rmse = np.sqrt(np.mean((preds - y_val) ** 2)) + r2 = 1 - np.sum((preds - y_val) ** 2) / max( + np.sum((y_val - y_val.mean()) ** 2), 1e-10 + ) + + if target == "tflops": + val_df = df_valid.iloc[val_idx].copy() + preds_raw = np.expm1(preds) if apply_log else preds + val_df["pred_tflops"] = preds_raw + eff_df = compute_tflops_efficiency(val_df) + mean_eff = eff_df["efficiency"].mean() if len(eff_df) > 0 else 0 + p10_eff = eff_df["efficiency"].quantile(0.1) if len(eff_df) > 0 else 0 + else: + mean_eff, p10_eff = None, None + + fold_metrics.append( + { + "fold": fold_idx, + "rmse": rmse, + "r2": r2, + "mean_efficiency": mean_eff, + "p10_efficiency": p10_eff, + "train_size": len(train_idx), + "val_size": len(val_idx), + "val_groups": len(np.unique(groups[val_idx])), + } + ) + + eff_str = ( + f", eff={mean_eff:.4f}, p10={p10_eff:.4f}" if mean_eff is not None else "" + ) + print(f" Fold {fold_idx}: RMSE={rmse:.4f}, R2={r2:.4f}{eff_str}") + + df_valid[f"oof_pred_{target}"] = oof_preds + + return { + "fold_metrics": fold_metrics, + "oof_df": df_valid, + "feature_names": feature_names, + "log_transform": apply_log, + } + + +def train_final_model( + df: pd.DataFrame, + feature_engine: GemmUniversalFeatureEngine, + target: str, + params: dict, + init_model=None, + use_log: bool = True, +) -> lgb.LGBMRegressor: + """Train the final model on all valid data. + + Parameters + ---------- + init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None + If provided, training continues from this model (warm start). + use_log : bool + If True and target is in LOG_TARGETS, train on log1p(y). + The saved model then predicts in log-space; callers must apply + expm1() to get raw values. 
+ """ + target_col = TARGET_COLUMNS[target] + valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0) + df_valid = df[valid_mask].reset_index(drop=True) + + apply_log = use_log and target in LOG_TARGETS + + X = feature_engine.extract_batch(df_valid) + y_raw = df_valid[target_col].values + y = np.log1p(y_raw) if apply_log else y_raw + feature_names = feature_engine.get_feature_names() + cat_features = feature_engine.get_categorical_features() + cat_indices = [feature_names.index(c) for c in cat_features if c in feature_names] + + model = lgb.LGBMRegressor(**params) + model.fit( + X, + y, + categorical_feature=cat_indices if cat_indices else "auto", + init_model=init_model, + ) + return model + + +def main(): + parser = argparse.ArgumentParser( + description="Train CK Tile kernel performance models" + ) + parser.add_argument( + "--data_dir", required=True, help="Directory with parquet files" + ) + parser.add_argument("--out_dir", required=True, help="Output directory for models") + parser.add_argument("--op", default="gemm_universal", help="Operation type") + parser.add_argument("--dtype", default="fp8", help="Data type filter") + parser.add_argument("--arch", default="gfx950", help="Architecture") + parser.add_argument( + "--targets", default="tflops,latency,bandwidth", help="Comma-separated targets" + ) + parser.add_argument("--n_splits", type=int, default=5, help="Number of CV folds") + parser.add_argument( + "--tune", action="store_true", help="Run Optuna hyperparameter tuning" + ) + parser.add_argument( + "--no_log_transform", + action="store_true", + help="Disable log1p transform on targets. By default, TFLOPS and bandwidth " + "are trained in log-space for scale-invariant accuracy across shape sizes.", + ) + parser.add_argument( + "--warm_start", + default=None, + help="Path to previous model directory to continue training from. " + "Uses LightGBM's init_model to add new trees on top of the " + "existing model. Feature schemas must match exactly.", + ) + parser.add_argument( + "--warm_start_n_estimators", + type=int, + default=WARM_START_N_ESTIMATORS, + help=f"Number of new trees to add when warm-starting (default: {WARM_START_N_ESTIMATORS}). 
" + "Lower than a full train since we're refining, not starting from scratch.", + ) + args = parser.parse_args() + + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + targets = [t.strip() for t in args.targets.split(",")] + + print(f"Loading data from {args.data_dir}...") + df = build_training_dataset(args.data_dir, op_type=args.op, dtype=args.dtype) + print(f" Total rows: {len(df)}") + print(f" Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}") + print(f" Unique kernels: {df['kernel_name'].nunique()}") + + hw_cols = [c for c in df.columns if c.startswith("hw_")] + hw_kwargs = {} + if hw_cols: + row0 = df.iloc[0] + if "hw_num_cus" in df.columns: + hw_kwargs["num_cus"] = int(row0.get("hw_num_cus", 256)) + if "hw_max_clock_mhz" in df.columns: + hw_kwargs["max_clock_mhz"] = int(row0.get("hw_max_clock_mhz", 2400)) + if "hw_simds_per_cu" in df.columns: + hw_kwargs["simds_per_cu"] = int(row0.get("hw_simds_per_cu", 4)) + if "hw_shader_engines" in df.columns: + hw_kwargs["shader_engines"] = int(row0.get("hw_shader_engines", 32)) + if "hw_max_waves_per_cu" in df.columns: + hw_kwargs["max_waves_per_cu"] = int(row0.get("hw_max_waves_per_cu", 32)) + if "hw_wavefront_size" in df.columns: + hw_kwargs["wavefront_size"] = int(row0.get("hw_wavefront_size", 64)) + if "hw_l1_cache_kb" in df.columns: + hw_kwargs["l1_cache_kb"] = int(row0.get("hw_l1_cache_kb", 32)) + if "hw_l2_cache_kb" in df.columns: + hw_kwargs["l2_cache_kb"] = int(row0.get("hw_l2_cache_kb", 4096)) + if "hw_l3_cache_kb" in df.columns: + hw_kwargs["l3_cache_kb"] = int(row0.get("hw_l3_cache_kb", 262144)) + + fe = GemmUniversalFeatureEngine(**hw_kwargs) + + params = dict(DEFAULT_PARAMS) + use_log = not args.no_log_transform + + prev_model_dir = None + prev_manifest = {} + if args.warm_start: + prev_model_dir = Path(args.warm_start) + if not prev_model_dir.exists(): + raise FileNotFoundError(f"Warm-start directory not found: {prev_model_dir}") + print(f" Warm-starting from {prev_model_dir}") + check_feature_compatibility(prev_model_dir, fe) + print(" Feature compatibility: OK") + params["n_estimators"] = args.warm_start_n_estimators + print(f" New trees to add: {args.warm_start_n_estimators}") + + prev_manifest_path = prev_model_dir / "train_manifest.json" + if prev_manifest_path.exists(): + with open(prev_manifest_path) as f: + prev_manifest = json.load(f) + + all_cv_results = {} + for target in targets: + if target not in TARGET_COLUMNS: + print(f" Skipping unknown target: {target}") + continue + + print(f"\n{'=' * 60}") + print(f"Training {target} model") + print(f"{'=' * 60}") + + init_model_path = None + if prev_model_dir is not None: + init_model_path = load_warm_start_model(prev_model_dir, target) + if init_model_path: + print(f" Warm-starting from {init_model_path}") + else: + print(f" No previous {target} model found, training from scratch") + + t0 = time.time() + cv_result = run_cv( + df, fe, target, params, n_splits=args.n_splits, use_log=use_log + ) + cv_time = time.time() - t0 + + if cv_result and cv_result["fold_metrics"]: + all_cv_results[target] = cv_result["fold_metrics"] + metrics_path = out_dir / f"cv_metrics_{target}.json" + with open(metrics_path, "w") as f: + json.dump(cv_result["fold_metrics"], f, indent=2) + print(f" CV completed in {cv_time:.1f}s, saved to {metrics_path}") + + if target == "tflops" and cv_result.get("oof_df") is not None: + oof_df = cv_result["oof_df"] + oof_df.to_parquet(out_dir / "oof_predictions.parquet", index=False) + + eff_df = compute_tflops_efficiency(oof_df, 
"oof_pred_tflops") + if len(eff_df) > 0: + print("\n OOF TFLOPS Efficiency:") + print(f" Mean: {eff_df['efficiency'].mean():.4f}") + print(f" P10: {eff_df['efficiency'].quantile(0.1):.4f}") + print(f" P50: {eff_df['efficiency'].quantile(0.5):.4f}") + print(f" Min: {eff_df['efficiency'].min():.4f}") + + print(f"\n Training final {target} model on all data...") + t0 = time.time() + model = train_final_model( + df, fe, target, params, init_model=init_model_path, use_log=use_log + ) + train_time = time.time() - t0 + + model_path = out_dir / f"model_{target}.lgbm" + model.booster_.save_model(str(model_path)) + print(f" Saved {model_path} ({train_time:.1f}s)") + + importances = dict( + zip( + fe.get_feature_names(), + model.feature_importances_.tolist(), + ) + ) + imp_path = out_dir / f"feature_importances_{target}.json" + with open(imp_path, "w") as f: + json.dump(importances, f, indent=2) + + log_targets_used = sorted(LOG_TARGETS & set(targets)) if use_log else [] + spec = { + "op_type": args.op, + "dtype": args.dtype, + "arch": args.arch, + "feature_names": fe.get_feature_names(), + "categorical_features": fe.get_categorical_features(), + "targets": targets, + "log_targets": log_targets_used, + "params": params, + } + with open(out_dir / "feature_spec.json", "w") as f: + json.dump(spec, f, indent=2) + + manifest = { + "warm_start_from": str(prev_model_dir) if prev_model_dir else None, + "prev_n_estimators": prev_manifest.get( + "total_n_estimators", params.get("n_estimators") + ) + if prev_model_dir + else 0, + "new_n_estimators": params["n_estimators"], + "total_n_estimators": ( + prev_manifest.get("total_n_estimators", 0) + params["n_estimators"] + if prev_model_dir + else params["n_estimators"] + ), + "data_rows": len(df), + "valid_rows": int(df["is_valid"].fillna(False).sum()), + "unique_shapes": int(df.groupby(["m", "n", "k"]).ngroups), + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), + } + with open(out_dir / "train_manifest.json", "w") as f: + json.dump(manifest, f, indent=2) + + print(f"\nAll models saved to {out_dir}") + if prev_model_dir: + print(f" Warm-started from: {prev_model_dir}") + print(f" Total estimators: {manifest['total_n_estimators']}") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/heuristics/validate_ml_heuristic.py b/dispatcher/heuristics/validate_ml_heuristic.py new file mode 100644 index 0000000000..ccd7a20cd9 --- /dev/null +++ b/dispatcher/heuristics/validate_ml_heuristic.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +ML Heuristic Validation: Test ML predictions against oracle-best from training data + +This script validates ML-based kernel selection by: +1. Loading benchmark data (oracle-best results for each shape) +2. Using ML model to predict best kernel for each shape +3. 
Comparing ML selection with oracle-best to compute efficiency + +Usage: + python validate_ml_heuristic.py --dtype fp16 --model_dir models/gemm_universal_fp16_gfx950 + python validate_ml_heuristic.py --dtype fp8 --layout rcr +""" + +import sys +import argparse +import pandas as pd +import numpy as np +from pathlib import Path + +from predict import Predictor + + +def validate_ml_heuristic(dtype: str, layout: str, model_dir: str, data_dir: str): + """Validate ML heuristic predictions against oracle-best""" + + print("=" * 100) + print(f" ML Heuristic Validation: {dtype.upper()} {layout.upper()}") + print("=" * 100) + print() + + # Load training data + print(f"Loading training data from {data_dir}...") + + # Try dtype-specific parquet first, then fall back to combined + dtype_specific = ( + Path(data_dir) / f"{dtype}_original" / f"{dtype}_training_data.parquet" + ) + combined = Path(data_dir) / "all_training_data_fixed.parquet" + + if dtype_specific.exists(): + training_data = pd.read_parquet(dtype_specific) + print(f"✓ Loaded {len(training_data):,} benchmark runs from {dtype_specific}") + elif combined.exists(): + training_data = pd.read_parquet(combined) + training_data = training_data[ + (training_data["dtype"] == dtype) & (training_data["layout"] == layout) + ] + print(f"✓ Loaded {len(training_data):,} benchmark runs from {combined}") + else: + print(f"❌ Error: No training data found at {dtype_specific} or {combined}") + return + + if len(training_data) == 0: + print(f"❌ Error: No data found for dtype={dtype}, layout={layout}") + return + + # Get unique shapes with oracle-best + shape_groups = training_data.groupby(["m", "n", "k"]) + print(f"Unique shapes: {len(shape_groups)}") + print() + + # Load ML predictor + print(f"Loading ML predictor from {model_dir}...") + try: + predictor = Predictor(model_dir) + print("✓ Loaded ML predictor") + print(f" Log targets: {predictor._log_targets}") + except Exception as e: + print(f"❌ Error loading model: {e}") + return + + print() + print("=" * 100) + print(" Computing Oracle-Best Efficiency for Each Shape") + print("=" * 100) + print() + + results = [] + + for shape_idx, ((m, n, k), group) in enumerate(shape_groups): + # Find oracle-best (max TFLOPS across all kernels tested) + oracle_best_row = group.loc[group["measured_tflops"].idxmax()] + oracle_best_tflops = oracle_best_row["measured_tflops"] + oracle_best_kernel = oracle_best_row["kernel_name"] + + # Get all kernel configs tested for this shape + kernel_configs = [] + for _, row in group.iterrows(): + kernel_dict = { + "tile_m": row["tile_m"], + "tile_n": row["tile_n"], + "tile_k": row["tile_k"], + "warp_m": row["warp_m"], + "warp_n": row["warp_n"], + "warp_k": row["warp_k"], + "warp_tile_m": row["warp_tile_m"], + "warp_tile_n": row["warp_tile_n"], + "warp_tile_k": row["warp_tile_k"], + "pipeline": row["pipeline"], + "scheduler": row["scheduler"], + "epilogue": row["epilogue"], + "pad_m": row["pad_m"], + "pad_n": row["pad_n"], + "pad_k": row["pad_k"], + "persistent": row["persistent"], + "kernel_name": row["kernel_name"], + } + kernel_configs.append(kernel_dict) + + # Use ML model to rank kernels + problem = { + "m": m, + "n": n, + "k": k, + "dtype": dtype, + "layout": layout, + "split_k": 1, + } + + try: + ranked = predictor.rank_kernels(problem, kernel_configs) + + if ranked: + ml_best_kernel, ml_predicted_tflops = ranked[0] + + # Find actual TFLOPS for the ML-predicted kernel + ml_kernel_row = group[group["kernel_name"] == ml_best_kernel] + if len(ml_kernel_row) > 0: + ml_actual_tflops = 
ml_kernel_row["measured_tflops"].values[0] + + # Calculate efficiency + efficiency_pct = 100.0 * (ml_actual_tflops / oracle_best_tflops) + + # Determine if ML picked oracle-best + is_oracle_best = ml_best_kernel == oracle_best_kernel + + results.append( + { + "m": m, + "n": n, + "k": k, + "oracle_best_tflops": oracle_best_tflops, + "oracle_best_kernel": oracle_best_kernel, + "ml_predicted_tflops": ml_predicted_tflops, + "ml_selected_kernel": ml_best_kernel, + "ml_actual_tflops": ml_actual_tflops, + "efficiency_pct": efficiency_pct, + "is_oracle_best": is_oracle_best, + "num_kernels": len(group), + } + ) + + if (shape_idx + 1) % 20 == 0: + status = "✓" if is_oracle_best else f"{efficiency_pct:.1f}%" + print( + f" [{shape_idx + 1:3d}/{len(shape_groups)}] " + f"M={m:4d} N={n:5d} K={k:5d}: {status}" + ) + except Exception as e: + print(f" Error on shape M={m} N={n} K={k}: {e}") + continue + + print() + print("=" * 100) + print(" Results Summary") + print("=" * 100) + print() + + if results: + df_results = pd.DataFrame(results) + efficiencies = df_results["efficiency_pct"].values + oracle_matches = df_results["is_oracle_best"].sum() + + print(f"Total shapes tested: {len(results)}") + print() + print("Efficiency Statistics (% of Oracle-Best TFLOPS):") + print(f" Mean: {np.mean(efficiencies):.2f}%") + print(f" Median: {np.median(efficiencies):.2f}%") + print(f" Min: {np.min(efficiencies):.2f}%") + print(f" Max: {np.max(efficiencies):.2f}%") + print(f" P10: {np.percentile(efficiencies, 10):.2f}%") + print(f" P50: {np.percentile(efficiencies, 50):.2f}%") + print(f" P90: {np.percentile(efficiencies, 90):.2f}%") + print() + print( + f"Oracle-best matches: {oracle_matches}/{len(results)} ({100 * oracle_matches / len(results):.1f}%)" + ) + print() + + # Classify by M size + df_results["m_class"] = pd.cut( + df_results["m"], + bins=[0, 8, 128, 1024, float("inf")], + labels=[ + "Tiny (M<8)", + "Small (8≤M<128)", + "Medium (128≤M<1024)", + "Large (M≥1024)", + ], + ) + + print("Efficiency by M size:") + for m_class in [ + "Tiny (M<8)", + "Small (8≤M<128)", + "Medium (128≤M<1024)", + "Large (M≥1024)", + ]: + subset = df_results[df_results["m_class"] == m_class] + if len(subset) > 0: + print( + f" {m_class:25s}: {subset['efficiency_pct'].mean():6.2f}% " + f"(n={len(subset)}, P10={subset['efficiency_pct'].quantile(0.1):.2f}%)" + ) + + print() + + # Save results + output_file = f"validation_results_{dtype}_{layout}.csv" + df_results.to_csv(output_file, index=False) + print(f"✓ Results saved to {output_file}") + + # Show best and worst shapes + print() + print("Top 5 shapes (best efficiency):") + top5 = df_results.nlargest(5, "efficiency_pct")[ + ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"] + ] + for idx, row in top5.iterrows(): + match = "✓" if row["is_oracle_best"] else " " + print( + f" {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: " + f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)" + ) + + print() + print("Bottom 5 shapes (worst efficiency):") + bottom5 = df_results.nsmallest(5, "efficiency_pct")[ + ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"] + ] + for idx, row in bottom5.iterrows(): + match = "✓" if row["is_oracle_best"] else " " + print( + f" {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: " + f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)" + ) + + else: + print("No results to display") + + print() + print("=" * 100) + + +def main(): + parser = argparse.ArgumentParser( + 
description="Validate ML heuristic predictions against oracle-best from training data" + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp8"], + help="Data type to validate", + ) + parser.add_argument( + "--layout", + default="rcr", + choices=["rcr", "rrr", "crr", "ccr"], + help="Matrix layout", + ) + parser.add_argument( + "--model_dir", + default=None, + help="Path to model directory (auto-detect if not specified)", + ) + parser.add_argument( + "--data_dir", + default=None, + help="Path to training data directory (auto-detect if not specified)", + ) + + args = parser.parse_args() + + # Auto-detect model directory if not specified + if args.model_dir is None: + heuristics_dir = Path(__file__).parent + model_candidates = [ + heuristics_dir / "models" / f"gemm_universal_{args.dtype}_gfx950", + heuristics_dir / "models" / f"gemm_universal_{args.dtype}_gfx942", + ] + for candidate in model_candidates: + if candidate.exists(): + args.model_dir = str(candidate) + break + + if args.model_dir is None: + print(f"❌ Error: Could not find model directory for {args.dtype}") + print(f" Searched: {[str(c) for c in model_candidates]}") + print(" Please specify --model_dir explicitly") + return 1 + + # Auto-detect data directory if not specified + if args.data_dir is None: + heuristics_dir = Path(__file__).parent + args.data_dir = str(heuristics_dir / "data") + + validate_ml_heuristic(args.dtype, args.layout, args.model_dir, args.data_dir) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/include/ck_tile/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher.hpp index 98d8bb9333..b3d8f10675 100644 --- a/dispatcher/include/ck_tile/dispatcher.hpp +++ b/dispatcher/include/ck_tile/dispatcher.hpp @@ -3,9 +3,17 @@ #pragma once -/// Main dispatcher header - includes all core components -/// Use this for convenient access to the full dispatcher API +/// Full dispatcher header - includes ALL operation types. +/// For minimal includes, use the per-operation headers instead: +/// ck_tile/dispatcher_gemm.hpp -- GEMM only +/// ck_tile/dispatcher_conv.hpp -- Grouped Convolution only +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// GEMM #include "ck_tile/dispatcher/kernel_key.hpp" #include "ck_tile/dispatcher/kernel_config.hpp" #include "ck_tile/dispatcher/kernel_decl.hpp" @@ -13,7 +21,15 @@ #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/registry.hpp" #include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/json_export.hpp" #include "ck_tile/dispatcher/arch_filter.hpp" #include "ck_tile/dispatcher/backends/tile_backend.hpp" #include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" #include "ck_tile/dispatcher/utils.hpp" + +// Grouped Convolution +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher/README.md b/dispatcher/include/ck_tile/dispatcher/README.md index db3ce996a9..430798aedd 100644 --- a/dispatcher/include/ck_tile/dispatcher/README.md +++ b/dispatcher/include/ck_tile/dispatcher/README.md @@ -1,6 +1,6 @@ # CK Tile Dispatcher - C++ Headers -C++ API for the CK Tile dispatcher. 
+C++ API for the CK Tile dispatcher (GEMM and Grouped Convolution). > **See also:** [Main Dispatcher README](../../../../README.md) for installation and core concepts. @@ -8,16 +8,25 @@ C++ API for the CK Tile dispatcher. ``` dispatcher/ -├── dispatcher.hpp # Main dispatcher (kernel selection) -├── registry.hpp # Kernel registry (storage & lookup) -├── problem.hpp # Problem specification -├── kernel_key.hpp # Kernel configuration key -├── kernel_instance.hpp # Kernel instance interface -├── utils.hpp # Utilities (timers, GPU buffers) -│ -└── backends/ # Backend implementations - ├── generated_tile_backend.hpp # CK Tile kernels (production) - └── tile_backend.hpp # Tile backend base +|---- dispatcher.hpp # Main include (includes all below) +| +|---- # GEMM Headers +|---- registry.hpp # Kernel registry (storage & lookup) +|---- problem.hpp # GEMM problem specification +|---- kernel_key.hpp # Kernel configuration key +|---- kernel_instance.hpp # Kernel instance interface +|---- utils.hpp # Utilities (timers, GPU buffers) +| +|---- # Grouped Convolution Headers +|---- grouped_conv_config.hpp # GroupedConvDirection, GroupedConvConfig +|---- grouped_conv_problem.hpp # GroupedConvProblem + ProblemBuilder +|---- grouped_conv_kernel_decl.hpp # GroupedConvKernelDecl, DECL_GROUPED_CONV_KERNEL_SET +|---- grouped_conv_registry.hpp # Thread-safe registry with JSON export & filtering +|---- grouped_conv_utils.hpp # Config creators, validation, benchmark utilities +| ++---- backends/ # Backend implementations + |---- generated_tile_backend.hpp # CK Tile kernels (production) + +---- tile_backend.hpp # Tile backend base ``` ## Quick Start @@ -148,6 +157,69 @@ auto kernel = create_generated_tile_kernel< >(key, name); ``` +## Grouped Convolution API + +### GroupedConvProblem (`grouped_conv_problem.hpp`) + +Problem specification with builder pattern: + +```cpp +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" + +using namespace ck_tile::dispatcher; + +auto problem = GroupedConvProblemBuilder() + .n(2).g(1).c(128).k(256) + .input_spatial({28, 28}) + .filter_spatial({3, 3}) + .strides({1, 1}) + .dilations({1, 1}) + .left_pads({1, 1}) + .right_pads({1, 1}) + .build(); + +bool ok = problem.is_valid(); +``` + +### GroupedConvRegistry (`grouped_conv_registry.hpp`) + +Thread-safe registry with JSON export and filtering: + +```cpp +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" + +auto& registry = GroupedConvRegistry::instance(); + +// Thread-safe registration +registry.register_kernel(kernel); + +// JSON export +std::string json = registry.export_json(); +registry.export_json_to_file("kernels.json"); + +// Filtering +auto gfx942_kernels = registry.filter_by_arch("gfx942"); +auto matched = registry.filter([](const auto& k) { return k.is_fwd(); }); +``` + +### DECL_GROUPED_CONV_KERNEL_SET (`grouped_conv_kernel_decl.hpp`) + +Declarative kernel definition: + +```cpp +DECL_GROUPED_CONV_KERNEL_SET(my_conv_kernels, + .add( + GroupedConvSignature().dtype("fp16").layout("nhwgc"), + GroupedConvAlgorithm().tile(128, 128, 32).wave(2, 2, 1) + .warp(32, 32, 16).pipeline("compv4"), + "gfx942" + ) +); + +// Register all matching current arch +DECL_GROUPED_CONV_KERNEL_ALL(all_conv_kernels, "gfx942"); +``` + ## Best Practices 1. Use `Release` build for performance @@ -155,6 +227,8 @@ auto kernel = create_generated_tile_kernel< 3. Use `Priority::High` for hand-tuned kernels 4. Reuse dispatcher instances 5. Clear registry between test runs +6. Use `GroupedConvProblemBuilder` for validated problem construction +7. 
Leverage `export_json()` for kernel inventory and debugging --- diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp new file mode 100644 index 0000000000..04ee1b2d11 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp @@ -0,0 +1,152 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Generated Convolution Kernel Backend +// +// Wraps CK Tile grouped convolution launchers for use through the +// GroupedConvDispatcher. Each generated kernel launcher is wrapped in +// a ConvKernelRunFn that builds the correct host-args type (forward, +// bwd-data, or bwd-weight) and calls Launcher::launch(). + +#pragma once + +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +// Buffer context is defined in grouped_conv_registry.hpp (g_conv_dispatch_buffers) +// so there's no circular dependency. + +// Helper: build ck_tile::conv::ConvParam from GroupedConvProblem +inline ck_tile::conv::ConvParam make_conv_param_2d(const GroupedConvProblem& p) +{ + return ck_tile::conv::ConvParam{ + 2, + static_cast(p.G), + static_cast(p.N), + static_cast(p.K), + static_cast(p.C), + {static_cast(p.filter_spatial[1]), + static_cast(p.filter_spatial[2])}, + {static_cast(p.input_spatial[1]), + static_cast(p.input_spatial[2])}, + {static_cast(p.stride[1]), static_cast(p.stride[2])}, + {static_cast(p.dilation[1]), + static_cast(p.dilation[2])}, + {static_cast(p.padding[1]), static_cast(p.padding[2])}, + {static_cast(p.padding[1]), static_cast(p.padding[2])}}; +} + +inline ck_tile::conv::ConvParam make_conv_param_3d(const GroupedConvProblem& p) +{ + return ck_tile::conv::ConvParam{3, + static_cast(p.G), + static_cast(p.N), + static_cast(p.K), + static_cast(p.C), + {static_cast(p.filter_spatial[0]), + static_cast(p.filter_spatial[1]), + static_cast(p.filter_spatial[2])}, + {static_cast(p.input_spatial[0]), + static_cast(p.input_spatial[1]), + static_cast(p.input_spatial[2])}, + {static_cast(p.stride[0]), + static_cast(p.stride[1]), + static_cast(p.stride[2])}, + {static_cast(p.dilation[0]), + static_cast(p.dilation[1]), + static_cast(p.dilation[2])}, + {static_cast(p.padding[0]), + static_cast(p.padding[1]), + static_cast(p.padding[2])}, + {static_cast(p.padding[0]), + static_cast(p.padding[1]), + static_cast(p.padding[2])}}; +} + +// Create a RunFn for a forward convolution launcher (2D or 3D) +template +inline GroupedConvKernelInstance::RunFn make_conv_fwd_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + ck_tile::GroupedConvFwdHostArgs<> args( + param, ctx.input_ptr, ctx.weight_ptr, {}, ctx.output_ptr, 1); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? 
ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +// Create a RunFn for a backward-data convolution launcher. +// Dispatcher convention: run(dY, W, dX, problem) where dX is computed. +// BwdDataHostArgs(param, in_ptr=dX, wei_ptr=W, {}, out_ptr=dY, k_batch) +template +inline GroupedConvKernelInstance::RunFn make_conv_bwd_data_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + ck_tile::GroupedConvBwdDataHostArgs args( + param, + ctx.output_ptr, // in_ptr = dX (being computed) + ctx.weight_ptr, // wei_ptr = W + {}, + ctx.input_ptr, // out_ptr = dY (gradient from next layer) + 1); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +// Create a RunFn for a backward-weight convolution launcher. +// Dispatcher convention: run(X, dY, dW, problem) where dW is computed. +// BwdWeightHostArgs(param, in_ptr=X, wei_ptr=dW, {}, out_ptr=dY, k_batch) +template +inline GroupedConvKernelInstance::RunFn make_conv_bwd_weight_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + const int k_batch = (ctx.split_k > 1) ? ctx.split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, + ctx.input_ptr, // in_ptr = X + ctx.output_ptr, // wei_ptr = dW (being computed) + {}, + ctx.weight_ptr, // out_ptr = dY + k_batch); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/base_registry.hpp b/dispatcher/include/ck_tile/dispatcher/base_registry.hpp new file mode 100644 index 0000000000..2bb940c320 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/base_registry.hpp @@ -0,0 +1,199 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Shared priority enum used by all registry types +enum class Priority +{ + Low = 0, + Normal = 1, + High = 2 +}; + +/// BaseRegistry: Thread-safe, priority-aware kernel storage shared by GEMM and Conv registries. 
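+///
+/// Illustrative derivation (a sketch, not the exact signature of the real derived
+/// classes): a GEMM registry keyed by kernel-name strings could be declared as
+///   class Registry : public BaseRegistry<Registry, std::string, KernelInstance>
+///   { /* adds export_json_to_file(...), lookup helpers, etc. */ };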
+/// +/// Template Parameters: +/// Derived - CRTP derived class (e.g., Registry, ConvRegistry) +/// KeyType - primary key type (std::string for GEMM, ConvKernelKey for Conv) +/// InstanceType - kernel instance type (KernelInstance, ConvKernelInstance) +/// KeyHash - hash functor for KeyType (defaults to std::hash<KeyType>) +template <typename Derived, typename KeyType, typename InstanceType, typename KeyHash = std::hash<KeyType>> +class BaseRegistry +{ + public: + using InstancePtr = std::shared_ptr<InstanceType>; + + struct Entry + { + InstancePtr instance; + Priority priority; + }; + + BaseRegistry() = default; + virtual ~BaseRegistry() = default; + + BaseRegistry(BaseRegistry&& other) noexcept + { + std::lock_guard lock(other.mutex_); + entries_ = std::move(other.entries_); + name_ = std::move(other.name_); + } + + BaseRegistry& operator=(BaseRegistry&& other) noexcept + { + if(this != &other) + { + std::scoped_lock lock(mutex_, other.mutex_); + entries_ = std::move(other.entries_); + name_ = std::move(other.name_); + } + return *this; + } + + BaseRegistry(const BaseRegistry&) = delete; + BaseRegistry& operator=(const BaseRegistry&) = delete; + + /// Register a kernel. If the key already exists, the new entry replaces it + /// unless the existing entry has strictly higher priority. + /// Same-priority registration overwrites (last-writer-wins at equal priority). + bool + register_kernel(const KeyType& key, InstancePtr instance, Priority priority = Priority::Normal) + { + std::lock_guard lock(mutex_); + auto it = entries_.find(key); + if(it != entries_.end() && it->second.priority > priority) + { + return false; + } + entries_[key] = Entry{std::move(instance), priority}; + return true; + } + + [[nodiscard]] std::size_t size() const + { + std::lock_guard lock(mutex_); + return entries_.size(); + } + + [[nodiscard]] bool empty() const + { + std::lock_guard lock(mutex_); + return entries_.empty(); + } + + void clear() + { + std::lock_guard lock(mutex_); + entries_.clear(); + } + + [[nodiscard]] std::string get_name() const + { + std::lock_guard lock(mutex_); + return name_; // return by value to avoid dangling reference + } + + void set_name(const std::string& name) + { + std::lock_guard lock(mutex_); + name_ = name; + } + + [[nodiscard]] std::vector<InstancePtr> get_all_instances() const + { + std::lock_guard lock(mutex_); + std::vector<InstancePtr> result; + result.reserve(entries_.size()); + for(const auto& [key, entry] : entries_) + { + result.push_back(entry.instance); + } + return result; + } + + std::size_t merge_from(const BaseRegistry& other, Priority priority = Priority::Normal) + { + std::scoped_lock lock(mutex_, other.mutex_); + std::size_t merged = 0; + for(const auto& [key, entry] : other.entries_) + { + auto it = entries_.find(key); + if(it == entries_.end() || it->second.priority <= priority) + { + entries_[key] = Entry{entry.instance, priority}; + ++merged; + } + } + return merged; + } + + /// Enable automatic JSON export after every kernel registration. + /// Requires the derived class to implement export_json_to_file(path, stats). 
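+    /// Illustrative use (assuming a derived registry object named reg):
+    ///   reg.enable_auto_export("kernels.json"); // re-export after each later registration
+    ///   reg.disable_auto_export();              // stop automatic exports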
+ void enable_auto_export(const std::string& path, + bool include_statistics = true, + bool export_on_every_registration = true) + { + std::lock_guard lock(mutex_); + auto_export_path_ = path; + auto_export_stats_ = include_statistics; + auto_export_on_register_ = export_on_every_registration; + auto_export_enabled_.store(true, std::memory_order_release); + } + + void disable_auto_export() { auto_export_enabled_.store(false, std::memory_order_release); } + + [[nodiscard]] bool is_auto_export_enabled() const + { + return auto_export_enabled_.load(std::memory_order_acquire); + } + + /// Call after registration to trigger auto-export if enabled. + void perform_auto_export() + { + if(!auto_export_enabled_.load(std::memory_order_acquire)) + return; + std::lock_guard lock(mutex_); + if(auto_export_on_register_) + { + static_cast(this)->export_json_to_file(auto_export_path_, auto_export_stats_); + } + } + + protected: + [[nodiscard]] const std::unordered_map& entries() const + { + return entries_; + } + + [[nodiscard]] std::unordered_map& entries_mut() { return entries_; } + + std::mutex& mutex() const { return mutex_; } + + private: + mutable std::mutex mutex_; + std::unordered_map entries_; + std::string name_ = "default"; + + std::atomic auto_export_enabled_{false}; + bool auto_export_on_register_ = true; + bool auto_export_stats_ = true; + std::string auto_export_path_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp index 6d3f548138..d266d693da 100644 --- a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp @@ -23,6 +23,7 @@ #pragma once +#include "ck_tile/dispatcher/dispatcher_error.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/problem.hpp" #include "ck_tile/dispatcher/registry.hpp" @@ -52,7 +53,11 @@ class Dispatcher /// Constructor /// @param registry Registry instance to use (default: global singleton) - explicit Dispatcher(Registry* registry = nullptr); + /// @param gfx_arch Target GPU architecture (e.g. 
"gfx950") + explicit Dispatcher(Registry* registry = nullptr, const std::string& gfx_arch = ""); + + void set_arch(const std::string& arch) { gfx_arch_ = arch; } + [[nodiscard]] const std::string& arch() const { return gfx_arch_; } /// Register a heuristic function for kernel selection /// @param heuristic Function that maps problems to ranked kernel identifiers @@ -74,7 +79,7 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if no suitable kernel found + /// @throws NoKernelFound if no suitable kernel found [[nodiscard]] float run(const void* a_ptr, const void* b_ptr, void* c_ptr, @@ -89,7 +94,7 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if no suitable kernel found + /// @throws NoKernelFound if no suitable kernel found [[nodiscard]] float run_fused(const void* a_ptr, const void* b_ptr, void* c_ptr, @@ -106,7 +111,8 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if kernel not found or doesn't support problem + /// @throws NoKernelFound if the kernel identifier is not registered + /// @throws UnsupportedProblem if the selected kernel does not support the problem [[nodiscard]] float run_explicit(const std::string& kernel_id, const void* a_ptr, const void* b_ptr, @@ -130,10 +136,18 @@ class Dispatcher const Problem& problem, float tolerance = 1e-3f) const; + /// Enable or disable GPU benchmarking (timing) on all kernels. + /// When disabled, kernels execute once with no timing overhead + /// (one-shot mode for production plugins). + void set_benchmarking(bool enable) { benchmarking_ = enable; } + [[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; } + private: Registry* registry_; HeuristicFunction heuristic_; SelectionStrategy strategy_; + std::string gfx_arch_; + bool benchmarking_ = true; /// Select kernel using first-fit strategy [[nodiscard]] KernelInstancePtr select_first_fit(const Problem& problem) const; diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp new file mode 100644 index 0000000000..98b079f8d9 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp @@ -0,0 +1,28 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +namespace ck_tile { +namespace dispatcher { + +struct DispatcherError : std::runtime_error +{ + using std::runtime_error::runtime_error; +}; + +struct NoKernelFound : DispatcherError +{ + using DispatcherError::DispatcherError; +}; + +struct UnsupportedProblem : DispatcherError +{ + using DispatcherError::DispatcherError; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp new file mode 100644 index 0000000000..6a39766649 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Log levels for dispatcher transparency: +/// 0 = silent (default) +/// 1 = print selected kernel name +/// 2 = print all candidates considered and acceptance/rejection reasons +inline int get_log_level() +{ + static int level = []() { + const char* env = std::getenv("CK_DISPATCHER_LOG_LEVEL"); + return env ? std::atoi(env) : 0; + }(); + return level; +} + +inline void log_kernel_selected(const std::string& kernel_name, const std::string& problem_desc) +{ + if(get_log_level() >= 1) + { + std::cerr << "[CK Dispatcher] Selected kernel: " << kernel_name << " for " << problem_desc + << std::endl; + } +} + +inline void +log_kernel_candidate(const std::string& kernel_name, bool accepted, const std::string& reason) +{ + if(get_log_level() >= 2) + { + std::cerr << "[CK Dispatcher] Candidate: " << kernel_name << " -> " + << (accepted ? "ACCEPTED" : "REJECTED") + << (reason.empty() ? "" : " (" + reason + ")") << std::endl; + } +} + +inline void log_no_kernel_found(const std::string& problem_desc) +{ + if(get_log_level() >= 1) + { + std::cerr << "[CK Dispatcher] No kernel found for " << problem_desc << std::endl; + } +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp new file mode 100644 index 0000000000..91b7b3ad74 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp @@ -0,0 +1,588 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_config.hpp + * @brief CK Tile Grouped Convolution Configuration with Builder-style naming + * + * This adopts the Signature/Algorithm/Arch pattern from: + * experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp + * + * Structure: + * - Signature: WHAT operation (types, layouts, direction, element ops) + * - Algorithm: HOW it's computed (tiles, warps, pipeline, scheduler, padding) + * - Arch: Target GPU architecture + */ + +#pragma once + +// Use common kernel_key types for DataType, Pipeline, etc. 
+#include "ck_tile/dispatcher/kernel_key.hpp" + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// DataType, Pipeline, Scheduler, Epilogue are defined in kernel_key.hpp +// No need to redefine them here + +// ============================================================================= +// Data Type Enum (matching CK Tile numeric types) +// ============================================================================= + +enum class ConvDataType +{ + // Standard floating point + FP32, // float + FP64, // double + FP16, // half_t + BF16, // bf16_t + + // 8-bit float variants (FP8/BF8) + FP8, // fp8_t (E4M3) + BF8, // bf8_t (E5M2) + FP8_E4M3, // Explicit E4M3 format + FP8_E5M2, // Explicit E5M2 format + + // Integer types + INT8, // int8_t + UINT8, // uint8_t + INT32, // int32_t (accumulator) + + // 4-bit types (gfx950+ only) + FP4, // MXFP4 + INT4 // pk_int4_t +}; + +// ============================================================================= +// Direction and Layout Enums +// ============================================================================= + +enum class GroupedConvDirection +{ + FORWARD, + BACKWARD_DATA, + BACKWARD_WEIGHT +}; + +enum class ConvLayout2D +{ + GNHWC_GKYXC_GNHWK, // NHWC-style + NHWGC_GKYXC_NHWGK, + NGCHW_GKYXC_NGKHW, // NCHW-style + NGCHW_GKCYX_NGKHW +}; + +enum class ConvLayout3D +{ + GNDHWC_GKZYXC_GNDHWK, + NDHWGC_GKZYXC_NDHWGK, + NGCDHW_GKZYXC_NGKDHW, + NGCDHW_GKCZYX_NGKDHW +}; + +// ============================================================================= +// Element-wise Operations +// ============================================================================= + +enum class ElementwiseOp +{ + PASS_THROUGH, + BIAS, + BIAS_CLAMP, + SCALE, + BILINEAR, + RELU, + GELU, + SIGMOID, + TANH +}; + +// ============================================================================= +// Grouped Convolution Specialization +// ============================================================================= + +enum class ConvSpecialization +{ + DEFAULT, + FILTER_1X1_PAD0, + FILTER_1X1_STRIDE1_PAD0, + FILTER_3X3, + FILTER_5X5, + FILTER_7X7 +}; + +// ============================================================================= +// Memory Operation Types (for accumulator operations) +// ============================================================================= + +enum class MemoryOperation +{ + SET, // Direct write (=) + ATOMIC_ADD, // Atomic addition (+=) + ATOMIC_MAX, // Atomic max + ADD // Non-atomic addition +}; + +// ============================================================================= +// Epilogue Types +// ============================================================================= + +enum class EpilogueType +{ + CSHUFFLE, // C-shuffle epilogue + DEFAULT_2D, // Default 2D epilogue + DEFAULT_GEMM_2D, // Default GEMM 2D epilogue + DIRECT_STORE, // Direct store without shuffle + BIAS_ADD, // Add bias + BIAS_ADD_RELU, // Add bias + ReLU + BIAS_ADD_GELU // Add bias + GELU +}; + +// ============================================================================= +// Algorithm Enums (matching builder/types.hpp and CK Tile pipelines) +// ============================================================================= + +enum class PipelineVersion +{ + V1, // Basic pipeline V1 + V2, // Basic pipeline V2 + V3, // Compute V3 (intrawave only) + V4, // Compute V4 (double buffer, ping-pong LDS) + V5, // Compute V5 (wave groups) + V6, // Compute V6 (newest) + MEMORY, // Memory pipeline + COMPUTE_ASYNC, // Compute with async copy + 
PRESHUFFLE_V2 // Preshuffle V2 pipeline +}; + +enum class PipelineScheduler +{ + DEFAULT, + INTRAWAVE, + INTERWAVE +}; + +enum class GemmPadding +{ + DEFAULT, + NO_PADDING, // No padding + M_PADDING, + N_PADDING, + K_PADDING, + MN_PADDING, + MK_PADDING, + NK_PADDING, + MNK_PADDING +}; + +// ============================================================================= +// Signature Info (WHAT operation) +// ============================================================================= + +struct GroupedConvSignatureInfo +{ + int spatial_dim = 2; // 1, 2, or 3 + GroupedConvDirection direction = GroupedConvDirection::FORWARD; + std::string in_type = "fp16"; + std::string wei_type = "fp16"; + std::string out_type = "fp16"; + std::string acc_type = "fp32"; + std::string workspace_type = "fp32"; // For two-stage algorithms + std::string bias_type = "fp16"; // For bias epilogue + ElementwiseOp in_element_op = ElementwiseOp::PASS_THROUGH; + ElementwiseOp wei_element_op = ElementwiseOp::PASS_THROUGH; + ElementwiseOp out_element_op = ElementwiseOp::PASS_THROUGH; + ConvSpecialization conv_spec = ConvSpecialization::DEFAULT; + int num_groups = 1; + + // String helpers + static const char* direction_str(GroupedConvDirection dir) + { + switch(dir) + { + case GroupedConvDirection::FORWARD: return "fwd"; + case GroupedConvDirection::BACKWARD_DATA: return "bwd_data"; + case GroupedConvDirection::BACKWARD_WEIGHT: return "bwd_weight"; + default: return "unknown"; + } + } + + static const char* datatype_str(ConvDataType dt) + { + switch(dt) + { + case ConvDataType::FP32: return "fp32"; + case ConvDataType::FP64: return "fp64"; + case ConvDataType::FP16: return "fp16"; + case ConvDataType::BF16: return "bf16"; + case ConvDataType::FP8: return "fp8"; + case ConvDataType::BF8: return "bf8"; + case ConvDataType::FP8_E4M3: return "fp8_e4m3"; + case ConvDataType::FP8_E5M2: return "fp8_e5m2"; + case ConvDataType::INT8: return "int8"; + case ConvDataType::UINT8: return "uint8"; + case ConvDataType::INT32: return "int32"; + case ConvDataType::FP4: return "fp4"; + case ConvDataType::INT4: return "int4"; + default: return "unknown"; + } + } +}; + +// ============================================================================= +// Algorithm Info (HOW it's computed) +// ============================================================================= + +struct DataTileInfo +{ + int m = 128; // M tile (output spatial * N) + int n = 128; // N tile (K output channels) + int k = 64; // K tile (C input channels) +}; + +struct WarpGemmParams +{ + int gemm_m = 16; // MFMA M dimension (MPerXDL) + int gemm_n = 16; // MFMA N dimension (NPerXDL) + int m_iter = 2; // M iterations per warp (MXdlPerWave) + int n_iter = 2; // N iterations per warp (NXdlPerWave) +}; + +struct BlockWarpConfig +{ + int m_warp = 2; // Warps along M + int n_warp = 2; // Warps along N + int k_warp = 1; // Warps along K + int m_warp_tile = 32; // Warp tile M + int n_warp_tile = 32; // Warp tile N + int k_warp_tile = 16; // Warp tile K +}; + +struct VectorSizeInfo +{ + int a = 4; // Input vector size + int b = 8; // Weight vector size + int c = 8; // Output vector size +}; + +struct GroupedConvAlgorithmInfo +{ + DataTileInfo tile; + BlockWarpConfig warp; + VectorSizeInfo vector_size; + + PipelineVersion pipeline = PipelineVersion::V4; + PipelineScheduler scheduler = PipelineScheduler::INTRAWAVE; + GemmPadding padding = GemmPadding::MNK_PADDING; + MemoryOperation memory_op = MemoryOperation::SET; + EpilogueType epilogue = EpilogueType::CSHUFFLE; + + int 
thread_block_size = 256; + bool double_smem_buffer = false; + int num_wave_groups = 1; + int block_per_cu = 1; + int num_groups_to_merge = 1; + + // Pipeline string + static const char* pipeline_str(PipelineVersion pv) + { + switch(pv) + { + case PipelineVersion::V1: return "v1"; + case PipelineVersion::V2: return "v2"; + case PipelineVersion::V3: return "compv3"; + case PipelineVersion::V4: return "compv4"; + case PipelineVersion::V5: return "compv5"; + case PipelineVersion::V6: return "compv6"; + case PipelineVersion::MEMORY: return "mem"; + case PipelineVersion::COMPUTE_ASYNC: return "comp_async"; + case PipelineVersion::PRESHUFFLE_V2: return "preshuffle_v2"; + default: return "unknown"; + } + } + + static const char* scheduler_str(PipelineScheduler ps) + { + switch(ps) + { + case PipelineScheduler::DEFAULT: return "default"; + case PipelineScheduler::INTRAWAVE: return "intrawave"; + case PipelineScheduler::INTERWAVE: return "interwave"; + default: return "unknown"; + } + } + + static const char* memory_op_str(MemoryOperation mo) + { + switch(mo) + { + case MemoryOperation::SET: return "set"; + case MemoryOperation::ATOMIC_ADD: return "atomic_add"; + case MemoryOperation::ATOMIC_MAX: return "atomic_max"; + case MemoryOperation::ADD: return "add"; + default: return "unknown"; + } + } + + static const char* epilogue_str(EpilogueType et) + { + switch(et) + { + case EpilogueType::CSHUFFLE: return "cshuffle"; + case EpilogueType::DEFAULT_2D: return "default_2d"; + case EpilogueType::DEFAULT_GEMM_2D: return "default_gemm_2d"; + case EpilogueType::DIRECT_STORE: return "direct_store"; + case EpilogueType::BIAS_ADD: return "bias_add"; + case EpilogueType::BIAS_ADD_RELU: return "bias_add_relu"; + case EpilogueType::BIAS_ADD_GELU: return "bias_add_gelu"; + default: return "unknown"; + } + } +}; + +// ============================================================================= +// Arch Info (Target GPU) +// ============================================================================= + +struct ArchInfo +{ + std::string name = "gfx942"; // MI300X default + int max_waves_per_cu = 8; + int lds_size_kb = 64; + int sgpr_count = 108; + int vgpr_count = 512; + + bool supports_mfma_fp16() const { return name.find("gfx9") != std::string::npos; } + bool supports_wmma() const { return name.find("gfx11") != std::string::npos; } +}; + +// ============================================================================= +// Full Grouped Conv Config (combines Signature + Algorithm + Arch) +// ============================================================================= + +struct GroupedConvConfig +{ + GroupedConvSignatureInfo signature; + GroupedConvAlgorithmInfo algorithm; + ArchInfo arch; + + // Generate unique kernel name + std::string name() const + { + std::ostringstream oss; + oss << "grouped_conv_" << GroupedConvSignatureInfo::direction_str(signature.direction) + << "_" << signature.in_type << "_" << signature.spatial_dim << "d" << "_" + << GroupedConvAlgorithmInfo::pipeline_str(algorithm.pipeline) << "_" << algorithm.tile.m + << "x" << algorithm.tile.n << "x" << algorithm.tile.k; + return oss.str(); + } + + // Brief description + std::string brief() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " + << GroupedConvSignatureInfo::direction_str(signature.direction) + << " Grouped Convolution (" << signature.in_type << ")"; + return oss.str(); + } + + // Detailed description (tree-like) + std::string detailed() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " + 
<< GroupedConvSignatureInfo::direction_str(signature.direction) + << " Grouped Convolution Kernel\n"; + + oss << " Signature:\n"; + oss << " Data Type: " << signature.in_type << "\n"; + oss << " Accumulator: " << signature.acc_type << "\n"; + oss << " Groups: " << signature.num_groups << "\n"; + + oss << " Algorithm:\n"; + oss << " Thread Block Size: " << algorithm.thread_block_size << "\n"; + oss << " Data Tile: " << algorithm.tile.m << "x" << algorithm.tile.n << "x" + << algorithm.tile.k << "\n"; + oss << " Warp Config: " << algorithm.warp.m_warp << "x" << algorithm.warp.n_warp << "x" + << algorithm.warp.k_warp << "\n"; + oss << " Warp Tile: " << algorithm.warp.m_warp_tile << "x" << algorithm.warp.n_warp_tile + << "x" << algorithm.warp.k_warp_tile << "\n"; + oss << " Pipeline: " << GroupedConvAlgorithmInfo::pipeline_str(algorithm.pipeline) + << "\n"; + oss << " Scheduler: " << GroupedConvAlgorithmInfo::scheduler_str(algorithm.scheduler) + << "\n"; + + oss << " Arch:\n"; + oss << " Target: " << arch.name << "\n"; + + return oss.str(); + } +}; + +// ============================================================================= +// Predefined Configs +// ============================================================================= + +namespace configs { + +// Memory-bound config +template +struct Memory : public GroupedConvConfig +{ + Memory() + { + algorithm.tile = {128, 32, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {4, 1, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::MEMORY; + algorithm.double_smem_buffer = false; + } +}; + +// Compute V3 - Small +template +struct CompV3_Small : public GroupedConvConfig +{ + CompV3_Small() + { + algorithm.tile = {16, 64, 64}; + algorithm.warp = {1, 4, 1, 16, 16, 32}; + algorithm.pipeline = PipelineVersion::V3; + } +}; + +// Compute V3 - Medium +template +struct CompV3_Medium : public GroupedConvConfig +{ + CompV3_Medium() + { + algorithm.tile = {128, 128, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 16, 16, 32}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.block_per_cu = 2; + } +}; + +// Compute V3 - Large +template +struct CompV3_Large : public GroupedConvConfig +{ + CompV3_Large() + { + algorithm.tile = {256, 256, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V3; + } +}; + +// Compute V4 - Double buffered +template +struct CompV4 : public GroupedConvConfig +{ + CompV4() + { + algorithm.tile = {256, 256, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V4; + algorithm.double_smem_buffer = true; + } +}; + +// Compute V5 - Wave groups +template +struct CompV5 : public GroupedConvConfig +{ + CompV5() + { + algorithm.tile = {128, 128, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {1, 1, 2, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V5; + algorithm.num_wave_groups = 2; + } +}; + +// WMMA config for gfx11xx +template +struct WMMA : public GroupedConvConfig +{ + WMMA() + { + algorithm.tile = {128, 128, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {4, 2, 1, 16, 16, 16}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.block_per_cu = 2; + arch.name = "gfx1100"; + } +}; + +// Merged groups config +template +struct CompV3_MergedGroups : public GroupedConvConfig +{ + CompV3_MergedGroups() + { + algorithm.tile = {16, 32, 32}; + algorithm.warp = {1, 2, 1, 16, 16, 32}; + algorithm.vector_size = {4, 8, 8}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.num_groups_to_merge = 
2; + } +}; + +} // namespace configs + +// ============================================================================= +// DataType Traits (compile-time type info for CK Tile types) +// ============================================================================= + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; + static constexpr int size_bytes = 4; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; + static constexpr int size_bytes = 8; +}; + +// Forward declare CK Tile types for traits +// Note: actual ck_tile types are defined in ck_tile/core/numeric/ +// These traits allow working with type names at compile time + +// ============================================================================= +// ConvTypeConfig (input/weight/acc/output type combinations) +// ============================================================================= + +template +struct ConvTypeConfig +{ + using input_type = InDataType; + using weight_type = WeiDataType; + using output_type = OutDataType; + using accumulator_type = AccDataType; +}; + +// Common type configurations as type aliases +// FP16 -> FP32 accumulator -> FP16 output (most common) +// BF16 -> FP32 accumulator -> BF16 output +// FP8 -> FP32 accumulator -> FP8 output +// INT8 -> INT32 accumulator -> INT8 output + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp new file mode 100644 index 0000000000..8ddfe445ff --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp @@ -0,0 +1,537 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_kernel_decl.hpp + * @brief Declarative grouped convolution kernel specification + * + * USAGE: + * ====== + * + * // Named kernel sets for grouped convolution + * DECL_GROUPED_CONV_KERNEL_SET(gconv_fwd, + * .add("fp16", "nhwc", "forward", 128, 128, 32) + * .add("fp16", "nhwc", "forward", 256, 256, 64) + * ); + * + * // Access at runtime + * auto& set = GroupedConvKernelSetRegistry::instance().get("gconv_fwd"); + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace grouped_conv_decl { + +// ============================================================================= +// Wildcard constants +// ============================================================================= + +constexpr const char* ANY = "*"; +constexpr int ANY_INT = -1; + +// ============================================================================= +// GroupedConvSignature - WHAT operation +// ============================================================================= + +class GroupedConvSignature +{ + public: + std::string dtype_in_ = "fp16"; // Input data type + std::string dtype_wei_ = "fp16"; // Weight data type + std::string dtype_out_ = "fp16"; // Output data type + std::string dtype_acc_ = "fp32"; // Accumulator type + std::string dtype_workspace_ = "fp32"; // Workspace type (two-stage algorithms) + std::string dtype_bias_ = "fp16"; // Bias type (bias epilogue) + std::string layout_ = "nhwc"; // Data layout: nhwc, nchw + std::string conv_op_ = "forward"; // forward, bwd_data, bwd_weight + int num_dims_ = 2; // Spatial dimensions: 1, 2, or 3 + int groups_ = 1; // Group grouped convolution + std::string specialization_ = "default"; // Filter specialization + + GroupedConvSignature& dtype(const std::string& in, + const std::string& wei, + const std::string& out, + const std::string& acc = "fp32") + { + dtype_in_ = in; + dtype_wei_ = wei; + dtype_out_ = out; + dtype_acc_ = acc; + return *this; + } + + GroupedConvSignature& dtype(const std::string& all) + { + dtype_in_ = dtype_wei_ = dtype_out_ = dtype_bias_ = all; + dtype_acc_ = dtype_workspace_ = "fp32"; + return *this; + } + + GroupedConvSignature& dtype_workspace(const std::string& ws) + { + dtype_workspace_ = ws; + return *this; + } + + GroupedConvSignature& dtype_bias(const std::string& b) + { + dtype_bias_ = b; + return *this; + } + + GroupedConvSignature& layout(const std::string& l) + { + layout_ = l; + return *this; + } + GroupedConvSignature& conv_type(const std::string& op) + { + conv_op_ = op; + return *this; + } + GroupedConvSignature& dims(int d) + { + num_dims_ = d; + return *this; + } + GroupedConvSignature& groups(int g) + { + groups_ = g; + return *this; + } + GroupedConvSignature& spec(const std::string& s) + { + specialization_ = s; + return *this; + } + + std::string op_str() const + { + if(conv_op_ == "forward") + return "fwd"; + if(conv_op_ == "bwd_data") + return "bwd_data"; + if(conv_op_ == "bwd_weight") + return "bwd_weight"; + return conv_op_; + } +}; + +// ============================================================================= +// GroupedConvAlgorithm - HOW it's implemented +// ============================================================================= + +class GroupedConvAlgorithm +{ + public: + // Tile shape (M, N, K per tile - M=spatial*N, N=K_out, K=C_in) + int tile_m_ = 1; // Tile M (output spatial * batch) + int tile_n_ = 128; // Tile N (output channels K) + int tile_k_ = 128; // 
Tile K (input channels C) + + // Output spatial tile + int tile_ho_ = 1; + int tile_wo_ = 16; + + // Wave/warp shape + int wave_m_ = ANY_INT; + int wave_n_ = ANY_INT; + int wave_k_ = 1; + int warp_m_ = ANY_INT; + int warp_n_ = ANY_INT; + int warp_k_ = 16; + + // Vector sizes + int vector_a_ = 4; // Input vector size + int vector_b_ = 8; // Weight vector size + int vector_c_ = 8; // Output vector size + + // Pipeline configuration + std::string pipeline_ = "compv4"; + std::string scheduler_ = "intrawave"; + std::string epilogue_ = "cshuffle"; + std::string memory_op_ = "set"; // Memory operation: set, atomic_add, atomic_max, add + + // Occupancy/performance hints + int block_size_ = 256; + int block_per_cu_ = 1; + int num_wave_groups_ = 1; + int num_groups_to_merge_ = 1; + bool double_smem_buffer_ = false; + + // Padding -- always enabled for convolution (MNK padding assumed) + static constexpr bool pad_m_ = true; + static constexpr bool pad_n_ = true; + static constexpr bool pad_k_ = true; + + // Tile setter (M, N, K) + GroupedConvAlgorithm& tile(int m, int n, int k) + { + tile_m_ = m; + tile_n_ = n; + tile_k_ = k; + return *this; + } + + GroupedConvAlgorithm& tile_output(int ho, int wo) + { + tile_ho_ = ho; + tile_wo_ = wo; + return *this; + } + + GroupedConvAlgorithm& wave(int m, int n, int k = 1) + { + wave_m_ = m; + wave_n_ = n; + wave_k_ = k; + return *this; + } + + GroupedConvAlgorithm& warp(int m, int n, int k = 16) + { + warp_m_ = m; + warp_n_ = n; + warp_k_ = k; + return *this; + } + + GroupedConvAlgorithm& vector_sizes(int a, int b, int c) + { + vector_a_ = a; + vector_b_ = b; + vector_c_ = c; + return *this; + } + + GroupedConvAlgorithm& pipeline(const std::string& p) + { + pipeline_ = p; + return *this; + } + GroupedConvAlgorithm& scheduler(const std::string& s) + { + scheduler_ = s; + return *this; + } + GroupedConvAlgorithm& epilogue(const std::string& e) + { + epilogue_ = e; + return *this; + } + GroupedConvAlgorithm& memory_op(const std::string& m) + { + memory_op_ = m; + return *this; + } + + // Occupancy setters + GroupedConvAlgorithm& block_per_cu(int b) + { + block_per_cu_ = b; + return *this; + } + GroupedConvAlgorithm& num_wave_groups(int n) + { + num_wave_groups_ = n; + return *this; + } + GroupedConvAlgorithm& num_groups_to_merge(int n) + { + num_groups_to_merge_ = n; + return *this; + } + GroupedConvAlgorithm& double_smem_buffer(bool d) + { + double_smem_buffer_ = d; + return *this; + } + + bool needs_expansion() const + { + return wave_m_ == ANY_INT || warp_m_ == ANY_INT || pipeline_ == "*" || scheduler_ == "*"; + } + + /// Check if specific parameter needs expansion + bool needs_wave_expansion() const { return wave_m_ == ANY_INT || wave_n_ == ANY_INT; } + bool needs_warp_expansion() const { return warp_m_ == ANY_INT || warp_n_ == ANY_INT; } + bool needs_pipeline_expansion() const { return pipeline_ == "*"; } + bool needs_scheduler_expansion() const { return scheduler_ == "*"; } + + /// Auto-fill with defaults (for single kernel generation) + void auto_fill() + { + if(wave_m_ == ANY_INT) + wave_m_ = 2; + if(wave_n_ == ANY_INT) + wave_n_ = 2; + if(warp_m_ == ANY_INT) + warp_m_ = 32; + if(warp_n_ == ANY_INT) + warp_n_ = 32; + if(pipeline_ == "*") + pipeline_ = "compv4"; + if(scheduler_ == "*") + scheduler_ = "intrawave"; + } + + /// Get all valid wave configurations for arch + static std::vector> valid_wave_configs(const std::string& arch) + { + // Match arch_specs_generated.py WARP_SUPPORTED_COMBINATIONS + if(arch == "gfx942" || arch == "gfx90a" || arch == 
"gfx950") + { + return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + } + return {{2, 2, 1}}; // Default + } + + /// Get all valid warp tile configurations + static std::vector> valid_warp_configs(const std::string& arch, + const std::string& dtype) + { + // Match arch_specs_generated.py WARP_TILE_SUPPORTED_COMBINATIONS + if(arch == "gfx942" && (dtype == "fp16" || dtype == "bf16")) + { + return {{16, 16, 16}, {32, 32, 16}}; + } + return {{32, 32, 16}}; // Default + } + + /// Get all valid pipeline/scheduler combinations for forward conv. + /// Backward operations (bwd_data/bwd_weight) only support compv3 and mem + /// due to transpose_tile2d and get_length constraints in CK Tile. + static std::vector> valid_trait_configs() + { + return { + {"compv3", "intrawave"}, + {"compv4", "intrawave"}, + {"compv5", "intrawave"}, + {"mem", "intrawave"}, + {"mem", "interwave"}, + }; + } +}; + +// ============================================================================= +// GroupedConvKernelDecl +// ============================================================================= + +struct GroupedConvKernelDecl +{ + GroupedConvSignature signature; + GroupedConvAlgorithm algorithm; + std::string arch = "gfx942"; + + GroupedConvKernelDecl() = default; + + GroupedConvKernelDecl(const GroupedConvSignature& sig, + const GroupedConvAlgorithm& algo, + const std::string& a = "gfx942") + : signature(sig), algorithm(algo), arch(a) + { + } + + std::string name() const + { + std::ostringstream oss; + // Generate full kernel name similar to GEMM: + // grouped_conv____d______ + oss << "grouped_conv_" << signature.op_str() << "_" << signature.dtype_in_ << "_" + << signature.layout_ << "_" << signature.num_dims_ << "d" << "_" << algorithm.pipeline_ + << "_" << algorithm.epilogue_ << "_" << algorithm.scheduler_ << "_" << algorithm.tile_m_ + << "x" << algorithm.tile_n_ << "x" << algorithm.tile_k_ << "_" << algorithm.wave_m_ + << "x" << algorithm.wave_n_ << "x" << algorithm.wave_k_ << "_" << algorithm.warp_m_ + << "x" << algorithm.warp_n_ << "x" << algorithm.warp_k_; + return oss.str(); + } + + bool has_wildcards() const { return algorithm.needs_expansion() || arch == "*"; } +}; + +// ============================================================================= +// GroupedConvKernelSet +// ============================================================================= + +class GroupedConvKernelSet +{ + public: + GroupedConvKernelSet() = default; + + GroupedConvKernelSet& add(const GroupedConvSignature& sig, + const GroupedConvAlgorithm& algo, + const std::string& arch = "gfx942") + { + decls_.emplace_back(sig, algo, arch); + return *this; + } + + // Simple add: dtype, layout, conv_type, tile_k, tile_c + GroupedConvKernelSet& add(const std::string& dtype, + const std::string& layout, + const std::string& conv_type, + int tile_k, + int tile_c, + const std::string& arch = "gfx942") + { + GroupedConvSignature sig; + sig.dtype(dtype).layout(layout).conv_type(conv_type); + GroupedConvAlgorithm algo; + algo.tile(1, tile_k, tile_c); + decls_.emplace_back(sig, algo, arch); + return *this; + } + + GroupedConvKernelSet& merge(const GroupedConvKernelSet& other) + { + decls_.insert(decls_.end(), other.decls_.begin(), other.decls_.end()); + return *this; + } + + const std::vector& declarations() const { return decls_; } + size_t size() const { return decls_.size(); } + + void print(std::ostream& os = std::cout) const + { + os << "GroupedConvKernelSet (" << size() << " declarations):\n"; + for(const auto& d : decls_) + { + os << " - " << d.name(); + 
if(d.algorithm.needs_expansion()) + os << " [expands]"; + os << "\n"; + } + } + + GroupedConvKernelSet& tag(const std::string& t) + { + tag_ = t; + return *this; + } + std::string tag() const { return tag_; } + + private: + std::vector decls_; + std::string tag_; +}; + +// ============================================================================= +// GroupedConvKernelSetRegistry +// ============================================================================= + +class GroupedConvKernelSetRegistry +{ + public: + static GroupedConvKernelSetRegistry& instance() + { + static GroupedConvKernelSetRegistry reg; + return reg; + } + + void add(const std::string& name, const GroupedConvKernelSet& set) + { + sets_[name] = set; + if(std::find(order_.begin(), order_.end(), name) == order_.end()) + { + order_.push_back(name); + } + } + + // Alias for add() for consistency with GEMM API + void register_set(const std::string& name, const GroupedConvKernelSet& set) { add(name, set); } + + const GroupedConvKernelSet& get(const std::string& name) const + { + static GroupedConvKernelSet empty; + auto it = sets_.find(name); + return it != sets_.end() ? it->second : empty; + } + + bool has(const std::string& name) const { return sets_.find(name) != sets_.end(); } + + std::vector names() const { return order_; } + size_t size() const { return sets_.size(); } + + void clear() + { + sets_.clear(); + order_.clear(); + } + + void print() const + { + std::cout << "Grouped Conv Kernel Sets (" << size() << "):\n"; + for(const auto& name : order_) + { + const auto& set = sets_.at(name); + std::cout << " " << name << ": " << set.size() << " declarations\n"; + } + } + + private: + GroupedConvKernelSetRegistry() = default; + std::unordered_map sets_; + std::vector order_; +}; + +// ============================================================================= +// Static Registrar +// ============================================================================= + +struct GroupedConvKernelSetRegistrar +{ + GroupedConvKernelSetRegistrar(const std::string& name, const GroupedConvKernelSet& set) + { + GroupedConvKernelSetRegistry::instance().add(name, set); + } +}; + +} // namespace grouped_conv_decl + +// Convenience aliases +using GroupedConvSignature = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgorithm = grouped_conv_decl::GroupedConvAlgorithm; +using GroupedConvKernelDecl = grouped_conv_decl::GroupedConvKernelDecl; +using GroupedConvKernelSet = grouped_conv_decl::GroupedConvKernelSet; +using GroupedConvKernelSetRegistry = grouped_conv_decl::GroupedConvKernelSetRegistry; + +} // namespace dispatcher +} // namespace ck_tile + +// ============================================================================= +// Declaration Macros +// ============================================================================= + +#define CK_GROUPED_CONV_DECL_CAT_(a, b) CK_GROUPED_CONV_DECL_CAT_IMPL_(a, b) +#define CK_GROUPED_CONV_DECL_CAT_IMPL_(a, b) a##b + +// Note: __extension__ suppresses warnings about __COUNTER__ being a GCC/Clang extension +#define DECL_GROUPED_CONV_KERNEL_SET(name, ...) 
\ + __extension__ static ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSetRegistrar \ + CK_GROUPED_CONV_DECL_CAT_(_gconv_kset_reg_, __COUNTER__)( \ + #name, \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet() __VA_ARGS__.tag(#name)) + +#define DECL_GROUPED_CONV_KERNEL_ALL(dtype, layout) \ + __extension__ static ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSetRegistrar \ + CK_GROUPED_CONV_DECL_CAT_(_gconv_kset_reg_, __COUNTER__)( \ + #dtype "_" #layout "_all", \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet().add( \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvSignature().dtype(#dtype).layout( \ + #layout), \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvAlgorithm(), \ + "*")) + +#define GROUPED_CONV_KERNEL_SET(name) \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet name +#define BEGIN_GROUPED_CONV_KERNEL_SET() \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet() diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp new file mode 100644 index 0000000000..5b58f37206 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp @@ -0,0 +1,255 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_problem.hpp + * @brief Grouped Convolution problem definition + */ + +#pragma once + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/** + * @brief Grouped Convolution operation type + */ +enum class GroupedConvOp +{ + Forward, // Y = Conv(X, W) + BackwardData, // dX = ConvBwdData(dY, W) + BackwardWeight // dW = ConvBwdWeight(X, dY) +}; + +/** + * @brief Grouped Convolution problem specification + */ +struct GroupedConvProblem +{ + // Batch and channels + std::int64_t N; // Batch size + std::int64_t C; // Input channels + std::int64_t K; // Output channels (filters) + std::int64_t G; // Number of groups (1 for standard conv) + + // Spatial dimensions (supports 1D, 2D, 3D) + std::array input_spatial; // {D, H, W} or {1, H, W} for 2D + std::array filter_spatial; // {Z, Y, X} or {1, Y, X} for 2D + std::array output_spatial; // {Do, Ho, Wo} or {1, Ho, Wo} for 2D + + // Convolution parameters + std::array stride; // Stride in each dimension + std::array padding; // Padding in each dimension + std::array dilation; // Dilation in each dimension + + // Operation type + GroupedConvOp op = GroupedConvOp::Forward; + + // Split-K for backward weight (k_batch parameter in CK Tile). + // Values > 1 split the reduction dimension across multiple thread blocks + // and use atomic accumulation. 
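+    // Illustrative example (assumed values, not a tuning recommendation): for a backward-weight
+    // problem whose reduction dimension (batch * output spatial) is small, setting
+    // problem.split_k = 4 launches four partial reductions that are combined with atomic adds.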
+ int split_k = 1; + + // Default constructor for 2D convolution + GroupedConvProblem() + : N(1), + C(64), + K(64), + G(1), + input_spatial{1, 28, 28}, + filter_spatial{1, 3, 3}, + output_spatial{1, 26, 26}, + stride{1, 1, 1}, + padding{0, 0, 0}, + dilation{1, 1, 1}, + op(GroupedConvOp::Forward) + { + } + + // Constructor for 2D convolution + GroupedConvProblem(std::int64_t n, + std::int64_t c, + std::int64_t k, + std::int64_t hi, + std::int64_t wi, + std::int64_t y, + std::int64_t x, + std::int64_t stride_h = 1, + std::int64_t stride_w = 1, + std::int64_t pad_h = 0, + std::int64_t pad_w = 0, + std::int64_t dilation_h = 1, + std::int64_t dilation_w = 1) + : N(n), + C(c), + K(k), + G(1), + input_spatial{1, hi, wi}, + filter_spatial{1, y, x}, + stride{1, stride_h, stride_w}, + padding{0, pad_h, pad_w}, + dilation{1, dilation_h, dilation_w}, + op(GroupedConvOp::Forward) + { + compute_output_size(); + } + + /// Check if problem dimensions are valid + bool is_valid() const + { + return N > 0 && C > 0 && K > 0 && G > 0 && (C % G == 0) && (K % G == 0); + } + + /// Compute output spatial dimensions + void compute_output_size() + { + for(int i = 0; i < 3; ++i) + { + std::int64_t effective_filter = (filter_spatial[i] - 1) * dilation[i] + 1; + output_spatial[i] = + (input_spatial[i] + 2 * padding[i] - effective_filter) / stride[i] + 1; + } + } + + /// Get 2D height/width accessors + std::int64_t Hi() const { return input_spatial[1]; } + std::int64_t Wi() const { return input_spatial[2]; } + std::int64_t Ho() const { return output_spatial[1]; } + std::int64_t Wo() const { return output_spatial[2]; } + std::int64_t Y() const { return filter_spatial[1]; } // Filter height + std::int64_t X() const { return filter_spatial[2]; } // Filter width + + /// Get total FLOPs for this convolution + double get_flops() const + { + // Forward: 2 * N * K * Ho * Wo * C * Y * X / G + double spatial_out = 1.0; + double filter_size = 1.0; + for(int i = 0; i < 3; ++i) + { + spatial_out *= output_spatial[i]; + filter_size *= filter_spatial[i]; + } + return 2.0 * N * K * spatial_out * (C / G) * filter_size; + } + + /// Check if this is a depthwise convolution + bool is_depthwise() const { return G == C && G == K; } + + /// Check if this is a pointwise (1x1) convolution + bool is_pointwise() const + { + return filter_spatial[0] == 1 && filter_spatial[1] == 1 && filter_spatial[2] == 1; + } + + /// String representation + std::string to_string() const + { + std::string s = "GroupedConvProblem(N=" + std::to_string(N); + s += ", C=" + std::to_string(C) + ", K=" + std::to_string(K); + s += ", G=" + std::to_string(G); + s += ", Hi=" + std::to_string(Hi()) + ", Wi=" + std::to_string(Wi()); + s += ", Y=" + std::to_string(Y()) + ", X=" + std::to_string(X()); + s += ", Ho=" + std::to_string(Ho()) + ", Wo=" + std::to_string(Wo()); + s += ")"; + return s; + } +}; + +// ============================================================================= +// GroupedConvProblemBuilder +// ============================================================================= + +/// Builder pattern for Grouped Convolution problem configuration +class GroupedConvProblemBuilder +{ + public: + GroupedConvProblemBuilder() = default; + + GroupedConvProblemBuilder& batch(std::int64_t n) + { + problem_.N = n; + return *this; + } + + GroupedConvProblemBuilder& channels(std::int64_t c, std::int64_t k) + { + problem_.C = c; + problem_.K = k; + return *this; + } + + GroupedConvProblemBuilder& groups(std::int64_t g) + { + problem_.G = g; + return *this; + } + + 
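+    // Usage sketch (illustrative; the concrete sizes are arbitrary, see the setters below
+    // and build() at the end of this class):
+    //     GroupedConvProblem p = GroupedConvProblemBuilder()
+    //                                .batch(16)
+    //                                .channels(/*C=*/64, /*K=*/128)
+    //                                .groups(1)
+    //                                .input_size(56, 56)
+    //                                .filter_size(3, 3)
+    //                                .stride(1, 1)
+    //                                .padding(1, 1)
+    //                                .build();   // computes output size and validates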
GroupedConvProblemBuilder& input_size(std::int64_t h, std::int64_t w) + { + problem_.input_spatial[0] = 1; + problem_.input_spatial[1] = h; + problem_.input_spatial[2] = w; + return *this; + } + + GroupedConvProblemBuilder& filter_size(std::int64_t y, std::int64_t x) + { + problem_.filter_spatial[0] = 1; + problem_.filter_spatial[1] = y; + problem_.filter_spatial[2] = x; + return *this; + } + + GroupedConvProblemBuilder& stride(std::int64_t sh, std::int64_t sw) + { + problem_.stride[0] = 1; + problem_.stride[1] = sh; + problem_.stride[2] = sw; + return *this; + } + + GroupedConvProblemBuilder& padding(std::int64_t ph, std::int64_t pw) + { + problem_.padding[0] = 0; + problem_.padding[1] = ph; + problem_.padding[2] = pw; + return *this; + } + + GroupedConvProblemBuilder& dilation(std::int64_t dh, std::int64_t dw) + { + problem_.dilation[0] = 1; + problem_.dilation[1] = dh; + problem_.dilation[2] = dw; + return *this; + } + + GroupedConvProblemBuilder& operation(GroupedConvOp op) + { + problem_.op = op; + return *this; + } + + [[nodiscard]] GroupedConvProblem build() const + { + GroupedConvProblem p = problem_; + p.compute_output_size(); + if(!p.is_valid()) + { + throw std::invalid_argument("Invalid grouped convolution problem dimensions"); + } + return p; + } + + private: + GroupedConvProblem problem_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp new file mode 100644 index 0000000000..42698a0bc8 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp @@ -0,0 +1,614 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_registry.hpp + * @brief Grouped Convolution kernel registry and dispatcher + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Thread-local buffer context for GroupedConvDispatcher::run() +// The generated conv backend RunFn reads these to get buffer pointers. 
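+// Flow (descriptive sketch): the buffer-taking GroupedConvDispatcher::run() overload below
+// fills g_conv_dispatch_buffers (pointers, warmup/repeat, benchmarking flag, split_k) and then
+// invokes the selected kernel's RunFn; the RunFn produced by the make_conv_*_run_fn factories
+// reads the same fields back when building the CK Tile host arguments.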
+// ============================================================================= + +struct ConvDispatchBuffers +{ + const void* input_ptr = nullptr; + const void* weight_ptr = nullptr; + void* output_ptr = nullptr; + int warmup = 3; + int repeat = 10; + bool benchmarking = true; + int split_k = 1; +}; + +inline thread_local ConvDispatchBuffers g_conv_dispatch_buffers; + +// ============================================================================= +// GroupedConvKernelKey - Unique identifier for a grouped convolution kernel +// ============================================================================= + +struct GroupedConvKernelKey +{ + // Signature fields + std::string dtype_in; + std::string dtype_wei; + std::string dtype_out; + std::string layout; // e.g., "nhwgc" + int ndim_spatial = 2; // 1, 2, or 3 + GroupedConvOp op = GroupedConvOp::Forward; + + // Tile configuration + int tile_m = 1; + int tile_n = 128; + int tile_k = 128; + + // Wave/warp configuration + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // Pipeline + std::string pipeline = "compv3"; + std::string scheduler = "intrawave"; + std::string epilogue = "cshuffle"; + + // ConvConfigBase parity fields + int vector_size_a = 4; + int vector_size_b = 8; + int vector_size_c = 8; + int block_per_cu = 1; + int num_wave_groups = 1; + int num_groups_to_merge = 1; + + // GPU architecture (for filter_by_arch) + std::string arch = "gfx942"; + + bool operator==(const GroupedConvKernelKey& other) const + { + return dtype_in == other.dtype_in && dtype_wei == other.dtype_wei && + dtype_out == other.dtype_out && layout == other.layout && + ndim_spatial == other.ndim_spatial && op == other.op && tile_m == other.tile_m && + tile_n == other.tile_n && tile_k == other.tile_k && wave_m == other.wave_m && + wave_n == other.wave_n && wave_k == other.wave_k && warp_m == other.warp_m && + warp_n == other.warp_n && warp_k == other.warp_k && pipeline == other.pipeline && + scheduler == other.scheduler && epilogue == other.epilogue && + vector_size_a == other.vector_size_a && vector_size_b == other.vector_size_b && + vector_size_c == other.vector_size_c && block_per_cu == other.block_per_cu && + num_wave_groups == other.num_wave_groups && + num_groups_to_merge == other.num_groups_to_merge && arch == other.arch; + } + + std::string to_string() const + { + std::string op_str; + switch(op) + { + case GroupedConvOp::Forward: op_str = "fwd"; break; + case GroupedConvOp::BackwardData: op_str = "bwd_data"; break; + case GroupedConvOp::BackwardWeight: op_str = "bwd_weight"; break; + } + return "grouped_conv_" + op_str + "_" + dtype_in + "_" + std::to_string(ndim_spatial) + + "d_" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "x" + + std::to_string(tile_k) + "_" + std::to_string(wave_m) + "x" + + std::to_string(wave_n) + "x" + std::to_string(wave_k) + "_" + + std::to_string(warp_m) + "x" + std::to_string(warp_n) + "x" + + std::to_string(warp_k) + "_" + pipeline; + } +}; + +struct GroupedConvKernelKeyHash +{ + std::size_t operator()(const GroupedConvKernelKey& key) const + { + std::size_t h = std::hash{}(key.dtype_in); + h ^= std::hash{}(key.layout) << 1; + h ^= std::hash{}(key.ndim_spatial) << 2; + h ^= std::hash{}(static_cast(key.op)) << 3; + h ^= std::hash{}(key.tile_m) << 4; + h ^= std::hash{}(key.tile_n) << 5; + h ^= std::hash{}(key.tile_k) << 6; + h ^= std::hash{}(key.wave_m) << 7; + h ^= std::hash{}(key.wave_n) << 8; + h ^= std::hash{}(key.warp_m) << 9; + h ^= 
std::hash{}(key.warp_n) << 10; + h ^= std::hash{}(key.pipeline) << 11; + h ^= std::hash{}(key.arch) << 12; + return h; + } +}; + +// ============================================================================= +// GroupedConvKernelInstance - Runtime representation of a kernel +// ============================================================================= + +// Forward declaration for shared_ptr type alias +class GroupedConvKernelInstance; +using GroupedConvKernelInstancePtr = std::shared_ptr; + +class GroupedConvKernelInstance +{ + public: + using RunFn = std::function; + + GroupedConvKernelInstance(const GroupedConvKernelKey& key, + const std::string& name, + RunFn run_fn) + : key_(key), name_(name), run_fn_(std::move(run_fn)) + { + } + + const GroupedConvKernelKey& key() const { return key_; } + const std::string& name() const { return name_; } + + float run(const GroupedConvProblem& problem, void* stream = nullptr) const + { + return run_fn_(problem, stream); + } + + bool matches(const GroupedConvProblem& problem) const + { + // Check if this kernel can handle the problem + return problem.op == key_.op; + } + + private: + GroupedConvKernelKey key_; + std::string name_; + RunFn run_fn_; +}; + +// ============================================================================= +// GroupedConvRegistry - Stores and manages grouped convolution kernels +// ============================================================================= + +class GroupedConvRegistry : public BaseRegistry +{ + using Base = BaseRegistry; + + public: + GroupedConvRegistry() = default; + + /// Singleton instance for global kernel registration + static GroupedConvRegistry& instance() + { + static GroupedConvRegistry registry; + return registry; + } + + /// Register kernels from a GroupedConvKernelSet (atomic batch registration) + bool register_set(const GroupedConvKernelSet& kernel_set, Priority priority = Priority::Normal) + { + // Build all instances first, then register under a single lock hold + // so readers never see a half-registered set. + std::vector>> + batch; + batch.reserve(kernel_set.declarations().size()); + + for(const auto& decl : kernel_set.declarations()) + { + GroupedConvKernelKey key; + key.dtype_in = decl.signature.dtype_in_; + key.dtype_wei = decl.signature.dtype_wei_; + key.dtype_out = decl.signature.dtype_out_; + key.layout = decl.signature.layout_; + key.ndim_spatial = decl.signature.num_dims_; + key.op = (decl.signature.conv_op_ == "forward") ? GroupedConvOp::Forward + : (decl.signature.conv_op_ == "bwd_data") ? 
GroupedConvOp::BackwardData + : GroupedConvOp::BackwardWeight; + key.tile_m = decl.algorithm.tile_m_; + key.tile_n = decl.algorithm.tile_n_; + key.tile_k = decl.algorithm.tile_k_; + key.wave_m = decl.algorithm.wave_m_; + key.wave_n = decl.algorithm.wave_n_; + key.wave_k = decl.algorithm.wave_k_; + key.warp_m = decl.algorithm.warp_m_; + key.warp_n = decl.algorithm.warp_n_; + key.warp_k = decl.algorithm.warp_k_; + key.pipeline = decl.algorithm.pipeline_; + key.scheduler = decl.algorithm.scheduler_; + key.epilogue = decl.algorithm.epilogue_; + key.vector_size_a = decl.algorithm.vector_a_; + key.vector_size_b = decl.algorithm.vector_b_; + key.vector_size_c = decl.algorithm.vector_c_; + key.block_per_cu = decl.algorithm.block_per_cu_; + key.num_wave_groups = decl.algorithm.num_wave_groups_; + key.num_groups_to_merge = decl.algorithm.num_groups_to_merge_; + key.arch = decl.arch; + + batch.emplace_back(key, + std::make_shared( + key, decl.name(), [](const GroupedConvProblem&, void*) -> float { + return 0.0f; + })); + } + + std::lock_guard lock(mutex()); + bool any_registered = false; + for(auto& [key, instance] : batch) + { + auto it = entries().find(key); + if(it == entries().end() || it->second.priority <= priority) + { + entries_mut()[key] = typename Base::Entry{std::move(instance), priority}; + any_registered = true; + } + } + return any_registered; + } + + /// Find the best kernel for a problem + const GroupedConvKernelInstance* find(const GroupedConvProblem& problem) const + { + std::lock_guard lock(mutex()); + const GroupedConvKernelInstance* best = nullptr; + Priority best_priority = Priority::Low; + + for(const auto& [key, entry] : entries()) + { + if(entry.instance->matches(problem)) + { + if(!best || entry.priority > best_priority) + { + best = entry.instance.get(); + best_priority = entry.priority; + } + } + } + + return best; + } + + /// Get all registered kernels + std::vector all_kernels() const + { + std::lock_guard lock(mutex()); + std::vector result; + for(const auto& [key, entry] : entries()) + { + result.push_back(entry.instance.get()); + } + return result; + } + + /// Export registry to JSON string + std::string export_json(bool include_statistics = false) const + { + // Note: get_name() acquires the mutex internally, so we must NOT hold + // the registry mutex here (std::mutex is not recursive). 
+ std::string reg_name = get_name(); + + std::lock_guard lock(mutex()); + std::ostringstream json; + + json << "{\n"; + json << " \"metadata\": {\n"; + json << " \"registry_name\": \"" << json_escape(reg_name) << "\",\n"; + json << " \"total_kernels\": " << entries().size() << "\n"; + json << " }"; + + if(include_statistics && !entries().empty()) + { + std::map by_datatype; + std::map by_pipeline; + std::map by_arch; + + for(const auto& [key, entry] : entries()) + { + std::string dtype_key = key.dtype_in + "_" + key.dtype_wei + "_" + key.dtype_out; + by_datatype[dtype_key]++; + by_pipeline[key.pipeline]++; + by_arch[key.arch]++; + } + + json << ",\n \"statistics\": {\n"; + json << " \"by_datatype\": {"; + bool first = true; + for(const auto& [dtype, count] : by_datatype) + { + if(!first) + json << ","; + json << "\"" << json_escape(dtype) << "\":" << count; + first = false; + } + json << "},\n"; + json << " \"by_pipeline\": {"; + first = true; + for(const auto& [pipeline, count] : by_pipeline) + { + if(!first) + json << ","; + json << "\"" << json_escape(pipeline) << "\":" << count; + first = false; + } + json << "},\n"; + json << " \"by_arch\": {"; + first = true; + for(const auto& [arch, count] : by_arch) + { + if(!first) + json << ","; + json << "\"" << json_escape(arch) << "\":" << count; + first = false; + } + json << "}\n }"; + } + + json << ",\n \"kernels\": [\n"; + bool first = true; + for(const auto& [key, entry] : entries()) + { + if(!first) + json << ",\n"; + json << " " << export_kernel_json(*entry.instance); + first = false; + } + json << "\n ]\n"; + json << "}\n"; + + return json.str(); + } + + /// Export registry to JSON file + void export_json_to_file(const std::string& filename, bool include_statistics = false) const + { + std::string json_str = export_json(include_statistics); + std::ofstream file(filename); + if(!file.is_open()) + { + throw std::runtime_error("Failed to open file for export: " + filename); + } + file << json_str; + } + + /// Get kernels matching a predicate + std::vector + filter(std::function predicate) const + { + std::lock_guard lock(mutex()); + std::vector result; + for(const auto& [key, entry] : entries()) + { + if(predicate(*entry.instance)) + { + result.push_back(entry.instance.get()); + } + } + return result; + } + + /// Remove kernels not matching the arch + std::size_t filter_by_arch(const std::string& gpu_arch) + { + std::lock_guard lock(mutex()); + std::vector to_remove; + for(const auto& [key, entry] : entries()) + { + if(key.arch != gpu_arch) + { + to_remove.push_back(key); + } + } + for(const auto& key : to_remove) + { + entries_mut().erase(key); + } + return to_remove.size(); + } + + private: + static std::string json_escape(const std::string& str) + { + std::ostringstream oss; + for(char c : str) + { + switch(c) + { + case '"': oss << "\\\""; break; + case '\\': oss << "\\\\"; break; + case '\b': oss << "\\b"; break; + case '\f': oss << "\\f"; break; + case '\n': oss << "\\n"; break; + case '\r': oss << "\\r"; break; + case '\t': oss << "\\t"; break; + default: + if(c < 0x20) + { + oss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (int)c; + } + else + { + oss << c; + } + } + } + return oss.str(); + } + + static std::string export_kernel_json(const GroupedConvKernelInstance& kernel) + { + std::ostringstream json; + const auto& key = kernel.key(); + + std::string op_str; + switch(key.op) + { + case GroupedConvOp::Forward: op_str = "fwd"; break; + case GroupedConvOp::BackwardData: op_str = "bwd_data"; break; + case 
GroupedConvOp::BackwardWeight: op_str = "bwd_weight"; break; + } + + json << "{\n"; + json << " \"name\": \"" << json_escape(kernel.name()) << "\",\n"; + json << " \"signature\": {\n"; + json << " \"dtype_in\": \"" << json_escape(key.dtype_in) << "\",\n"; + json << " \"dtype_wei\": \"" << json_escape(key.dtype_wei) << "\",\n"; + json << " \"dtype_out\": \"" << json_escape(key.dtype_out) << "\",\n"; + json << " \"layout\": \"" << json_escape(key.layout) << "\",\n"; + json << " \"ndim_spatial\": " << key.ndim_spatial << ",\n"; + json << " \"op\": \"" << op_str << "\"\n"; + json << " },\n"; + json << " \"algorithm\": {\n"; + json << " \"tile_m\": " << key.tile_m << ",\n"; + json << " \"tile_n\": " << key.tile_n << ",\n"; + json << " \"tile_k\": " << key.tile_k << ",\n"; + json << " \"wave\": \"" << key.wave_m << "x" << key.wave_n << "x" << key.wave_k + << "\",\n"; + json << " \"warp\": \"" << key.warp_m << "x" << key.warp_n << "x" << key.warp_k + << "\",\n"; + json << " \"pipeline\": \"" << json_escape(key.pipeline) << "\",\n"; + json << " \"scheduler\": \"" << json_escape(key.scheduler) << "\",\n"; + json << " \"epilogue\": \"" << json_escape(key.epilogue) << "\",\n"; + json << " \"vector_sizes\": [" << key.vector_size_a << "," << key.vector_size_b + << "," << key.vector_size_c << "],\n"; + json << " \"block_per_cu\": " << key.block_per_cu << ",\n"; + json << " \"num_wave_groups\": " << key.num_wave_groups << ",\n"; + json << " \"num_groups_to_merge\": " << key.num_groups_to_merge << "\n"; + json << " },\n"; + json << " \"arch\": \"" << json_escape(key.arch) << "\"\n"; + json << " }"; + + return json.str(); + } +}; + +// ============================================================================= +// GroupedConvDispatcher - Selects and runs the best kernel for a problem +// ============================================================================= + +class GroupedConvDispatcher +{ + public: + enum class SelectionStrategy + { + PriorityBased, + Heuristic + }; + + using HeuristicFunction = std::function(const GroupedConvProblem&)>; + + explicit GroupedConvDispatcher(GroupedConvRegistry* registry) + : registry_(registry), strategy_(SelectionStrategy::PriorityBased) + { + } + + void set_strategy(SelectionStrategy s) { strategy_ = s; } + void set_heuristic(HeuristicFunction fn) { heuristic_ = std::move(fn); } + + /// Select the best kernel for a problem (does not run it) + const GroupedConvKernelInstance* select_kernel(const GroupedConvProblem& problem) const + { + if(strategy_ == SelectionStrategy::Heuristic) + return select_heuristic(problem); + return registry_->find(problem); + } + + /// Run convolution with automatic kernel selection (legacy - no buffers) + float run(const GroupedConvProblem& problem, void* stream = nullptr) + { + const auto* kernel = select_kernel(problem); + if(!kernel) + { + throw NoKernelFound("No suitable grouped convolution kernel found for problem: " + + problem.to_string()); + } + return kernel->run(problem, stream); + } + + /// Run convolution with buffer pointers and automatic kernel selection. + /// Sets the thread-local buffer context before dispatching to the kernel. 
+ float run(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const GroupedConvProblem& problem, + void* stream = nullptr, + int warmup = 3, + int repeat = 10) + { + const auto* kernel = select_kernel(problem); + if(!kernel) + { + throw NoKernelFound("No suitable grouped convolution kernel found for problem: " + + problem.to_string()); + } + g_conv_dispatch_buffers.input_ptr = input_ptr; + g_conv_dispatch_buffers.weight_ptr = weight_ptr; + g_conv_dispatch_buffers.output_ptr = output_ptr; + g_conv_dispatch_buffers.warmup = warmup; + g_conv_dispatch_buffers.repeat = repeat; + g_conv_dispatch_buffers.benchmarking = benchmarking_; + g_conv_dispatch_buffers.split_k = problem.split_k; + return kernel->run(problem, stream); + } + + /// Enable or disable GPU benchmarking (timing). + /// When disabled, kernels execute once with no timing overhead. + void set_benchmarking(bool enable) { benchmarking_ = enable; } + [[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; } + + /// Alias kept for backward compatibility + const GroupedConvKernelInstance* select(const GroupedConvProblem& problem) const + { + return select_kernel(problem); + } + + private: + const GroupedConvKernelInstance* select_heuristic(const GroupedConvProblem& problem) const + { + if(!heuristic_) + return registry_->find(problem); + + auto ranked_names = heuristic_(problem); + auto all = registry_->all_kernels(); + for(const auto& name : ranked_names) + { + for(const auto* kernel : all) + { + if(kernel->name().find(name) != std::string::npos && kernel->matches(problem)) + { + return kernel; + } + } + } + return registry_->find(problem); + } + + GroupedConvRegistry* registry_; + SelectionStrategy strategy_; + HeuristicFunction heuristic_; + bool benchmarking_ = true; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp new file mode 100644 index 0000000000..c817d36673 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp @@ -0,0 +1,324 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
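
Before the utility header below, a minimal usage sketch of the registry/dispatcher pair defined above. This is not part of the patch: it assumes a `GroupedConvRegistry` that has already been populated elsewhere, the function name and shapes are illustrative, and the heuristic callback is assumed to return a ranked list of kernel-name substrings, as `select_heuristic()` suggests.

```cpp
// Hedged sketch: wiring a GroupedConvDispatcher to an already-populated registry.
#include "ck_tile/dispatcher_conv.hpp"

#include <string>
#include <vector>

using namespace ck_tile::dispatcher;

float run_conv2d_fwd_example(GroupedConvRegistry& registry,
                             const void* input, const void* weight, void* output)
{
    // N=32, C=64, K=128, 28x28 input, 3x3 filter, stride 1, pad 1 (illustrative shape)
    auto problem = grouped_conv_utils::create_grouped_conv2d_problem(
        32, 64, 128, 28, 28, 3, 3, /*stride=*/1, /*padding=*/1, GroupedConvOp::Forward);

    GroupedConvDispatcher dispatcher(&registry);
    dispatcher.set_benchmarking(true); // time the kernel with the warmup/repeat counts below

    // Optionally rank candidate kernels by (sub)name before the priority-based fallback.
    dispatcher.set_strategy(GroupedConvDispatcher::SelectionStrategy::Heuristic);
    dispatcher.set_heuristic([](const GroupedConvProblem&) {
        return std::vector<std::string>{"compv4", "compv3"};
    });

    // Throws NoKernelFound if nothing in the registry matches the problem.
    return dispatcher.run(input, weight, output, problem, /*stream=*/nullptr,
                          /*warmup=*/3, /*repeat=*/10);
}
```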
+// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_utils.hpp + * @brief CK Tile Grouped Convolution Dispatcher Utilities + */ + +#pragma once + +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include "ck_tile/dispatcher/utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +namespace grouped_conv_utils { + +inline GroupedConvKernelDecl create_grouped_conv2d_fwd(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv3d_fwd(const std::string& dtype = "fp16", + int tile_n = 64, + int tile_k = 64, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("ndhwc").conv_type("forward").dims(3), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv2d_bwd_data(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("bwd_data").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv2d_bwd_weight(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvProblem create_grouped_conv2d_problem(int N, + int C, + int K, + int Hi, + int Wi, + int Y, + int X, + int stride = 1, + int padding = 0, + GroupedConvOp op = GroupedConvOp::Forward) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = K; + p.G = 1; + p.input_spatial = {1, Hi, Wi}; + p.filter_spatial = {1, Y, X}; + p.stride = {1, stride, stride}; + p.padding = {0, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = op; + p.compute_output_size(); + return p; +} + +inline GroupedConvProblem create_grouped_conv3d_problem(int N, + int C, + int K, + int Di, + int Hi, + int Wi, + int Z, + int Y, + int X, + int stride = 1, + int padding = 0, + GroupedConvOp op = GroupedConvOp::Forward) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = K; + p.G = 1; + p.input_spatial = {Di, Hi, Wi}; + p.filter_spatial = {Z, Y, X}; + p.stride = {stride, stride, stride}; + p.padding = {padding, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = op; + p.compute_output_size(); + return p; +} + +inline 
GroupedConvProblem create_depthwise_grouped_conv2d_problem( + int N, int C, int Hi, int Wi, int Y, int X, int stride = 1, int padding = 0) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = C; + p.G = C; + p.input_spatial = {1, Hi, Wi}; + p.filter_spatial = {1, Y, X}; + p.stride = {1, stride, stride}; + p.padding = {0, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = GroupedConvOp::Forward; + p.compute_output_size(); + return p; +} + +inline void print_pattern_docs(std::ostream& os = std::cout) +{ + os << "Grouped Convolution Pattern Documentation\n"; + os << "==========================================\n"; + os << "Signature patterns: dtype, layout, conv_type (forward/bwd_data/bwd_weight), dims " + "(2/3)\n"; + os << "Algorithm patterns: tile(M,N,K), wave(M,N,K), warp(M,N,K), pipeline, vector_sizes\n"; + os << "Arch patterns: gfx942, gfx90a, gfx950, or '*' for all\n"; +} + +inline void print_grouped_conv_kernel_decl(const GroupedConvKernelDecl& decl, + std::ostream& os = std::cout) +{ + os << "GroupedConvKernelDecl: " << decl.name() << "\n"; + os << " Signature: dtype=" << decl.signature.dtype_in_ << ", layout=" << decl.signature.layout_ + << ", conv_type=" << decl.signature.conv_op_ << ", dims=" << decl.signature.num_dims_ + << "\n"; + os << " Algorithm: tile=" << decl.algorithm.tile_m_ << "x" << decl.algorithm.tile_n_ << "x" + << decl.algorithm.tile_k_ << ", wave=" << decl.algorithm.wave_m_ << "x" + << decl.algorithm.wave_n_ << "x" << decl.algorithm.wave_k_ + << ", warp=" << decl.algorithm.warp_m_ << "x" << decl.algorithm.warp_n_ << "x" + << decl.algorithm.warp_k_ << ", pipeline=" << decl.algorithm.pipeline_ << "\n"; + os << " Arch: " << decl.arch << "\n"; +} + +inline void print_grouped_conv_problem(const GroupedConvProblem& p, std::ostream& os = std::cout) +{ + os << p.to_string() << "\n"; + os << " FLOPs: " << std::scientific << p.get_flops() << "\n"; +} + +inline GroupedConvKernelSet build_grouped_conv2d_fwd_set(const std::string& dtype = "fp16", + const std::string& arch = "gfx942") +{ + GroupedConvKernelSet set; + auto decl1 = create_grouped_conv2d_fwd(dtype, 128, 128, arch); + set.add(decl1.signature, decl1.algorithm, decl1.arch); + auto decl2 = create_grouped_conv2d_fwd(dtype, 256, 256, arch); + set.add(decl2.signature, decl2.algorithm, decl2.arch); + return set; +} + +inline GroupedConvKernelSet build_grouped_conv2d_full_set(const std::string& dtype = "fp16", + const std::string& arch = "gfx942") +{ + GroupedConvKernelSet set; + set.merge(build_grouped_conv2d_fwd_set(dtype, arch)); + auto bwd_data = create_grouped_conv2d_bwd_data(dtype, 128, 128, arch); + set.add(bwd_data.signature, bwd_data.algorithm, bwd_data.arch); + auto bwd_weight = create_grouped_conv2d_bwd_weight(dtype, 128, 128, arch); + set.add(bwd_weight.signature, bwd_weight.algorithm, bwd_weight.arch); + return set; +} + +struct ValidationResult +{ + bool passed = false; + float max_abs_diff = 0.0f; + float max_rel_diff = 0.0f; + float rtol = 1e-3f; + float atol = 1e-3f; + + void print(std::ostream& os = std::cout) const + { + os << "ValidationResult: " << (passed ? 
"PASSED" : "FAILED") << "\n"; + os << " max_abs_diff: " << max_abs_diff << ", max_rel_diff: " << max_rel_diff << "\n"; + os << " rtol: " << rtol << ", atol: " << atol << "\n"; + } +}; + +template +inline ValidationResult validate_buffers( + const T* result, const T* reference, size_t count, float rtol = 1e-3f, float atol = 1e-3f) +{ + ValidationResult vr; + vr.rtol = rtol; + vr.atol = atol; + vr.passed = true; + + for(size_t i = 0; i < count; ++i) + { + float r = static_cast(result[i]); + float ref = static_cast(reference[i]); + float abs_diff = std::abs(r - ref); + float rel_diff = (std::abs(ref) > 1e-10f) ? (abs_diff / std::abs(ref)) : 0.0f; + + vr.max_abs_diff = std::max(vr.max_abs_diff, abs_diff); + vr.max_rel_diff = std::max(vr.max_rel_diff, rel_diff); + + float threshold = atol + rtol * std::abs(ref); + if(abs_diff > threshold) + { + vr.passed = false; + } + } + + return vr; +} + +struct BenchmarkResult +{ + std::string kernel_name; + float time_ms = 0.0f; + float tflops = 0.0f; + int warmup_runs = 0; + int benchmark_runs = 0; + + void print(std::ostream& os = std::cout) const + { + os << "BenchmarkResult: " << kernel_name << "\n"; + os << " time_ms: " << time_ms << ", tflops: " << tflops << "\n"; + os << " warmup_runs: " << warmup_runs << ", benchmark_runs: " << benchmark_runs << "\n"; + } +}; + +inline float calc_tflops(double flops, float time_ms) +{ + return static_cast(flops / (time_ms * 1e9)); +} + +inline double calculate_conv_tflops(const GroupedConvProblem& problem, double time_ms) +{ + return problem.get_flops() / (time_ms * 1e9); +} + +} // namespace grouped_conv_utils + +namespace examples { +inline int basic_grouped_conv_example_main(const std::string& example_name) +{ + std::cout << "=== " << example_name << " ===\n"; + + // Create a grouped convolution problem + auto problem = grouped_conv_utils::create_grouped_conv2d_problem( + 32, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + grouped_conv_utils::print_grouped_conv_problem(problem); + + // Create and print a kernel declaration + auto decl = grouped_conv_utils::create_grouped_conv2d_fwd("fp16", 128, 128, "gfx942"); + grouped_conv_utils::print_grouped_conv_kernel_decl(decl); + + // Build and print kernel set + auto kernel_set = grouped_conv_utils::build_grouped_conv2d_fwd_set("fp16", "gfx942"); + kernel_set.print(); + + return 0; +} +} // namespace examples + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp index f49b3a0d74..f5a93c6d34 100644 --- a/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp +++ b/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp @@ -140,6 +140,11 @@ struct KernelKey bool preshuffle; // Preshuffle (for weight preshuffle variants) bool transpose_c; // TransposeC std::uint8_t num_wave_groups; // NumWaveGroups + + // Padding support flags (kPadM, kPadN, kPadK in generated kernels) + bool pad_m = true; // Support arbitrary M dimensions via padding + bool pad_n = true; // Support arbitrary N dimensions via padding + bool pad_k = true; // Support arbitrary K dimensions via padding } algorithm; std::string gfx_arch; // e.g. 
"gfx942", "gfx90a", "gfx908" @@ -185,7 +190,10 @@ struct KernelKey algorithm.double_buffer, algorithm.preshuffle, algorithm.transpose_c, - algorithm.num_wave_groups); + algorithm.num_wave_groups, + algorithm.pad_m, + algorithm.pad_n, + algorithm.pad_k); } /// Equality comparison @@ -397,8 +405,14 @@ inline std::string KernelKey::encode_identifier() const // Include pipeline, scheduler, epilogue for uniqueness oss << to_string(algorithm.pipeline) << "_"; - oss << to_string(algorithm.scheduler) << "_"; oss << to_string(algorithm.epilogue) << "_"; + oss << to_string(algorithm.scheduler) << "_"; + + // Match tile_engine naming: padding flags (True/False) then persistent flag + oss << (algorithm.pad_m ? "True" : "False") << "_"; + oss << (algorithm.pad_n ? "True" : "False") << "_"; + oss << (algorithm.pad_k ? "True" : "False") << "_"; + oss << (algorithm.persistent ? "True" : "False") << "_"; // Match tile_engine naming: tile_m x tile_n x tile_k _ warp_m x warp_n x warp_k _ // warp_tile_m x warp_tile_n x warp_tile_k @@ -407,9 +421,6 @@ inline std::string KernelKey::encode_identifier() const << unsigned(algorithm.wave_shape.k) << "_" << unsigned(algorithm.warp_tile_shape.m) << "x" << unsigned(algorithm.warp_tile_shape.n) << "x" << unsigned(algorithm.warp_tile_shape.k); - // Add trait flags - oss << "_" << (algorithm.persistent ? "persist" : "nopers"); - if(signature.split_k > 1) oss << "_splitk" << unsigned(signature.split_k); if(!signature.elementwise_op.empty() && signature.elementwise_op != "PassThrough") diff --git a/dispatcher/include/ck_tile/dispatcher/ml_heuristic.hpp b/dispatcher/include/ck_tile/dispatcher/ml_heuristic.hpp new file mode 100644 index 0000000000..359d772735 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/ml_heuristic.hpp @@ -0,0 +1,379 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +namespace ck_tile { +namespace dispatcher { +extern "C" { +int LGBM_BoosterCreateFromModelfile(const char*, int*, void**); +int LGBM_BoosterPredictForMat( + void*, const void*, int, int, int, int, int, int, int, const char*, int64_t*, double*); +int LGBM_BoosterFree(void*); +} +inline int encode_pipeline(Pipeline p) +{ + switch(p) + { + case Pipeline::CompV3: return 0; + case Pipeline::CompV4: return 1; + case Pipeline::CompV5: return 2; + case Pipeline::Mem: return 3; + case Pipeline::PreShuffleV2: return 4; + default: return 0; + } +} +inline int encode_scheduler(Scheduler s) +{ + switch(s) + { + case Scheduler::Intrawave: return 0; + case Scheduler::Interwave: return 1; + default: return 0; + } +} +inline int encode_epilogue(Epilogue e) +{ + switch(e) + { + case Epilogue::Default: return 0; + case Epilogue::CShuffle: return 1; + default: return 0; + } +} +inline int encode_layout(LayoutTag a, LayoutTag b, LayoutTag c) +{ + bool ra = (a == LayoutTag::RowMajor), rb = (b == LayoutTag::RowMajor); + if(ra && !rb) + return 0; // RCR + if(ra && rb) + return 1; // RRR + if(!ra && rb) + return 2; // CCR + return 3; // CRR +} +inline double dtype_bytes_ml(DataType dt) +{ + switch(dt) + { + case DataType::FP32: return 4; + case DataType::FP16: + case DataType::BF16: return 2; + case DataType::FP8: + case DataType::BF8: + case DataType::INT8: return 1; + case DataType::INT4: return 0.5; + default: return 2; + } +} +struct HardwareProfile +{ + int num_cus = 256, simds_per_cu = 4, shader_engines = 32, max_clock_mhz = 2400, + max_waves_per_cu = 32, wavefront_size = 64, lds_capacity = 65536, l1_cache_kb = 32, + l2_cache_kb = 4096, l3_cache_kb = 262144, num_xcd = 8; + int total_simds() const { return num_cus * simds_per_cu; } +}; + +// CRITICAL: Feature count MUST match feature_spec.json +// Python training uses 72 features - this header MUST extract exactly 72 features in the same order +static constexpr int NUM_FEATURES = 72; + +inline std::array +extract_features(const Problem& prob, const KernelKey& key, const HardwareProfile& hw) +{ + // Problem dimensions + double M = prob.M, N = prob.N, K = prob.K; + double sk = (prob.k_batch > 0 ? 
prob.k_batch : 1); + double bpe = dtype_bytes_ml(key.signature.dtype_a); + + // Log-scale features + double l2M = std::log2(std::max(M, 1.0)); + double l2N = std::log2(std::max(N, 1.0)); + double l2K = std::log2(std::max(K, 1.0)); + double l2MNK = std::log2(std::max(M * N * K, 1.0)); + + // Arithmetic intensity + double mem = (M * K + K * N + M * N) * bpe; + double ai = 2.0 * M * N * K / std::max(mem, 1.0); + + // Aspect ratios + double ar_mn = M / std::max(N, 1.0); + double ar_mk = M / std::max(K, 1.0); + double ar_nk = N / std::max(K, 1.0); + + // Layout encoding + double layout = (double)encode_layout( + key.signature.layout_a, key.signature.layout_b, key.signature.layout_c); + + // Tile dimensions + double tm = key.algorithm.tile_shape.m; + double tn = key.algorithm.tile_shape.n; + double tk = key.algorithm.tile_shape.k; + + // Wave/warp dimensions + double wm = key.algorithm.wave_shape.m; + double wn = key.algorithm.wave_shape.n; + double wk = key.algorithm.wave_shape.k; + + // Warp tile dimensions + double wtm = key.algorithm.warp_tile_shape.m; + double wtn = key.algorithm.warp_tile_shape.n; + double wtk = key.algorithm.warp_tile_shape.k; + + // Algorithm encoding + double pipeline = (double)encode_pipeline(key.algorithm.pipeline); + double scheduler = (double)encode_scheduler(key.algorithm.scheduler); + double epilogue = (double)encode_epilogue(key.algorithm.epilogue); + + // Padding flags - read from KernelKey + double pad_m = key.algorithm.pad_m ? 1.0 : 0.0; + double pad_n = key.algorithm.pad_n ? 1.0 : 0.0; + double pad_k = key.algorithm.pad_k ? 1.0 : 0.0; + + // Persistent kernel flag + double persistent = key.algorithm.persistent ? 1.0 : 0.0; + + // Derived features + double num_warps = wm * wn * wk; + double tile_volume = tm * tn * tk; + double tile_mn = tm * tn; + + // LDS usage estimation + double lest = (tm * tk + tn * tk) * bpe; + double lcap = (key.algorithm.pipeline == Pipeline::CompV4) ? 32768.0 : (double)hw.lds_capacity; + double lds_ratio = lest / std::max(lcap, 1.0); + + // Tile counts + double ntm = std::ceil(M / std::max(tm, 1.0)); + double ntn = std::ceil(N / std::max(tn, 1.0)); + double ntk = std::ceil(K / std::max(tk, 1.0)); + double total_output_tiles = ntm * ntn; + + // Tile efficiency (fractional remainder utilization) + auto ef = [](double d, double t) -> double { + if(t <= 0) + return 1.0; + double r = std::fmod(d, t); + return r > 0 ? r / t : 1.0; + }; + double tile_eff_m = ef(M, tm); + double tile_eff_n = ef(N, tn); + double tile_eff_k = ef(K, tk); + double overall_tile_efficiency = tile_eff_m * tile_eff_n * tile_eff_k; + + // CU utilization + double cu_utilization = total_output_tiles / std::max((double)hw.num_cus, 1.0); + + // P0 FIX: Problem-to-tile ratio features (critical for small problems) + double ratio_M_to_tile_m = M / std::max(tm, 1.0); + double ratio_N_to_tile_n = N / std::max(tn, 1.0); + double ratio_K_to_tile_k = K / std::max(tk, 1.0); + + // Binary features: is problem dimension smaller than tile? + double problem_smaller_than_tile_m = (M < tm) ? 1.0 : 0.0; + double problem_smaller_than_tile_n = (N < tn) ? 1.0 : 0.0; + double problem_smaller_than_tile_k = (K < tk) ? 1.0 : 0.0; + double any_dim_too_small = ((M < tm) || (N < tn) || (K < tk)) ? 1.0 : 0.0; + + // P1 FIX: Padding requirement features + double needs_padding_m = (tm > 0 && std::fmod(M, tm) != 0.0) ? 1.0 : 0.0; + double needs_padding_n = (tn > 0 && std::fmod(N, tn) != 0.0) ? 1.0 : 0.0; + double needs_padding_k = (tk > 0 && std::fmod(K, tk) != 0.0) ? 
1.0 : 0.0; + + // Interaction features: kernel has padding when problem needs it + double has_padding_when_needed_m = (needs_padding_m && pad_m) ? 1.0 : 0.0; + double has_padding_when_needed_n = (needs_padding_n && pad_n) ? 1.0 : 0.0; + double has_padding_when_needed_k = (needs_padding_k && pad_k) ? 1.0 : 0.0; + + // Critical feature: missing required padding (kernel will likely fail) + double missing_required_padding_m = (needs_padding_m && !pad_m) ? 1.0 : 0.0; + double missing_required_padding_n = (needs_padding_n && !pad_n) ? 1.0 : 0.0; + double missing_required_padding_k = (needs_padding_k && !pad_k) ? 1.0 : 0.0; + double missing_any_required_padding = + (missing_required_padding_m || missing_required_padding_n || missing_required_padding_k) + ? 1.0 + : 0.0; + + // Hardware features + double hw_num_cus = (double)hw.num_cus; + double hw_simds_per_cu = (double)hw.simds_per_cu; + double hw_total_simds = (double)hw.total_simds(); + double hw_shader_engines = (double)hw.shader_engines; + double hw_max_clock_mhz = (double)hw.max_clock_mhz; + double hw_max_waves_per_cu = (double)hw.max_waves_per_cu; + double hw_wavefront_size = (double)hw.wavefront_size; + double hw_lds_capacity = (double)hw.lds_capacity; + double hw_l1_cache_kb = (double)hw.l1_cache_kb; + double hw_l2_cache_kb = (double)hw.l2_cache_kb; + double hw_l3_cache_kb = (double)hw.l3_cache_kb; + double hw_num_xcd = (double)hw.num_xcd; + + // Feature vector in EXACT order from feature_spec.json + // This order MUST match Python feature_engine.py::get_feature_names() + return {{ + M, // 0 + N, // 1 + K, // 2 + sk, // 3 (split_k) + l2M, // 4 (log2_M) + l2N, // 5 (log2_N) + l2K, // 6 (log2_K) + l2MNK, // 7 (log2_MNK) + ai, // 8 (arithmetic_intensity) + ar_mn, // 9 (aspect_ratio_mn) + ar_mk, // 10 (aspect_ratio_mk) + ar_nk, // 11 (aspect_ratio_nk) + layout, // 12 (layout) + tm, // 13 (tile_m) + tn, // 14 (tile_n) + tk, // 15 (tile_k) + wm, // 16 (warp_m) + wn, // 17 (warp_n) + wk, // 18 (warp_k) + wtm, // 19 (warp_tile_m) + wtn, // 20 (warp_tile_n) + wtk, // 21 (warp_tile_k) + pipeline, // 22 (pipeline) + scheduler, // 23 (scheduler) + epilogue, // 24 (epilogue) + pad_m, // 25 (pad_m) + pad_n, // 26 (pad_n) + pad_k, // 27 (pad_k) + persistent, // 28 (persistent) + num_warps, // 29 (num_warps) + tile_volume, // 30 (tile_volume) + tile_mn, // 31 (tile_mn) + lest, // 32 (lds_usage_estimate) + lds_ratio, // 33 (lds_usage_ratio) + ntm, // 34 (num_tiles_m) + ntn, // 35 (num_tiles_n) + ntk, // 36 (num_tiles_k) + total_output_tiles, // 37 (total_output_tiles) + tile_eff_m, // 38 (tile_eff_m) + tile_eff_n, // 39 (tile_eff_n) + tile_eff_k, // 40 (tile_eff_k) + overall_tile_efficiency, // 41 (overall_tile_efficiency) + cu_utilization, // 42 (cu_utilization) + ratio_M_to_tile_m, // 43 (ratio_M_to_tile_m) + ratio_N_to_tile_n, // 44 (ratio_N_to_tile_n) + ratio_K_to_tile_k, // 45 (ratio_K_to_tile_k) + problem_smaller_than_tile_m, // 46 (problem_smaller_than_tile_m) + problem_smaller_than_tile_n, // 47 (problem_smaller_than_tile_n) + problem_smaller_than_tile_k, // 48 (problem_smaller_than_tile_k) + any_dim_too_small, // 49 (any_dim_too_small) + needs_padding_m, // 50 (needs_padding_m) + needs_padding_n, // 51 (needs_padding_n) + needs_padding_k, // 52 (needs_padding_k) + has_padding_when_needed_m, // 53 (has_padding_when_needed_m) + has_padding_when_needed_n, // 54 (has_padding_when_needed_n) + has_padding_when_needed_k, // 55 (has_padding_when_needed_k) + missing_required_padding_m, // 56 (missing_required_padding_m) + missing_required_padding_n, // 57 
(missing_required_padding_n) + missing_required_padding_k, // 58 (missing_required_padding_k) + missing_any_required_padding, // 59 (missing_any_required_padding) + hw_num_cus, // 60 (hw_num_cus) + hw_simds_per_cu, // 61 (hw_simds_per_cu) + hw_total_simds, // 62 (hw_total_simds) + hw_shader_engines, // 63 (hw_shader_engines) + hw_max_clock_mhz, // 64 (hw_max_clock_mhz) + hw_max_waves_per_cu, // 65 (hw_max_waves_per_cu) + hw_wavefront_size, // 66 (hw_wavefront_size) + hw_lds_capacity, // 67 (hw_lds_capacity) + hw_l1_cache_kb, // 68 (hw_l1_cache_kb) + hw_l2_cache_kb, // 69 (hw_l2_cache_kb) + hw_l3_cache_kb, // 70 (hw_l3_cache_kb) + hw_num_xcd, // 71 (hw_num_xcd) + }}; +} + +class MLHeuristic +{ + public: + MLHeuristic(const std::string& path, + const Registry* reg, + HardwareProfile hw = {}, + bool log_t = false) + : registry_(reg), hw_(hw), log_t_(log_t) + { + int iters = 0; + if(LGBM_BoosterCreateFromModelfile(path.c_str(), &iters, &b_) != 0 || !b_) + { + std::cerr << "MLHeuristic: Failed to load " << path << std::endl; + + // Check if a compressed .gz version exists + std::string gz_path = path + ".gz"; + std::ifstream gz_check(gz_path); + if(gz_check.good()) + { + std::cerr << "MLHeuristic: Found compressed model at " << gz_path << std::endl; + std::cerr << "MLHeuristic: Please decompress it first:" << std::endl; + std::cerr << " gunzip " << gz_path << std::endl; + } + + b_ = nullptr; + } + else + std::cout << "MLHeuristic: Loaded (" << iters << " iters)" << std::endl; + } + ~MLHeuristic() + { + if(b_) + LGBM_BoosterFree(b_); + } + MLHeuristic(const MLHeuristic&) = delete; + MLHeuristic& operator=(const MLHeuristic&) = delete; + bool is_loaded() const { return b_ != nullptr; } + double predict_tflops(const Problem& prob, const KernelKey& key) const + { + if(!b_) + return 0; + auto f = extract_features(prob, key, hw_); + int64_t ol = 0; + double pred = 0; + if(LGBM_BoosterPredictForMat( + b_, f.data(), 0, 1, NUM_FEATURES, 1, 0, 0, 0, "", &ol, &pred) != 0) + return 0; + return log_t_ ? std::expm1(pred) : pred; + } + std::vector operator()(const Problem& prob) const + { + if(!b_ || !registry_) + return {}; + auto insts = registry_->get_all(); + struct C + { + std::string id; + double t; + }; + std::vector cs; + cs.reserve(insts.size()); + for(auto& i : insts) + { + auto& k = i->get_key(); + cs.push_back({k.encode_identifier(), predict_tflops(prob, k)}); + } + std::sort(cs.begin(), cs.end(), [](auto& a, auto& b) { return a.t > b.t; }); + std::vector r; + r.reserve(cs.size()); + for(auto& c : cs) + r.push_back(std::move(c.id)); + return r; + } + + private: + void* b_ = nullptr; + const Registry* registry_ = nullptr; + HardwareProfile hw_; + bool log_t_ = false; +}; +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/problem.hpp b/dispatcher/include/ck_tile/dispatcher/problem.hpp index 437511d1ba..5bffb56b49 100644 --- a/dispatcher/include/ck_tile/dispatcher/problem.hpp +++ b/dispatcher/include/ck_tile/dispatcher/problem.hpp @@ -98,7 +98,7 @@ struct Problem /** * Create Problem by inferring MNK from tensor shapes. 
* - * For GEMM: C[M,N] = A[M,K] × B[K,N] + * For GEMM: C[M,N] = A[M,K] x B[K,N] * * @param a_shape Shape of matrix A (M x K, or K x M if transposed) * @param b_shape Shape of matrix B (K x N, or N x K if transposed) @@ -113,7 +113,7 @@ struct Problem [[nodiscard]] static Problem from_shapes(TensorShape a_shape, TensorShape b_shape, TensorShape c_shape) { - // For C = A × B: + // For C = A x B: // A: [M, K] (or [K, M] if transposed) // B: [K, N] (or [N, K] if transposed) // C: [M, N] @@ -164,7 +164,7 @@ struct Problem * @throws std::invalid_argument if dimensions are inconsistent * * Example: - * // A[512,256] × B[256,1024] = C[512,1024] + * // A[512,256] x B[256,1024] = C[512,1024] * auto problem = Problem::from_dimensions(512, 256, 256, 1024, 512, 1024); */ [[nodiscard]] static Problem from_dimensions(std::int64_t a_rows, @@ -188,7 +188,7 @@ struct Problem * @throws std::invalid_argument if K dimensions don't match * * Example: - * // A[512,256] × B[256,1024] = C[512,1024] + * // A[512,256] x B[256,1024] = C[512,1024] * auto problem = Problem::from_ab(512, 256, 256, 1024); */ [[nodiscard]] static Problem diff --git a/dispatcher/include/ck_tile/dispatcher/registry.hpp b/dispatcher/include/ck_tile/dispatcher/registry.hpp index 93d1eb9f64..4f34e589ea 100644 --- a/dispatcher/include/ck_tile/dispatcher/registry.hpp +++ b/dispatcher/include/ck_tile/dispatcher/registry.hpp @@ -7,38 +7,20 @@ * Central registry for all available kernel instances with priority-based * ordering and efficient lookup. * - * Features: - * - Thread-safe registration and lookup - * - Priority-based ordering (High, Normal, Low) - * - Lookup by name or KernelKey - * - Filter by problem compatibility - * - Supports both singleton and multiple instance patterns - * - * Usage (Singleton - backward compatible): - * auto& registry = Registry::instance(); - * registry.register_kernel(kernel, Priority::High); - * auto kernel = registry.lookup("kernel_name"); - * - * Usage (Multiple registries): - * Registry fp16_registry; - * Registry bf16_registry; - * fp16_registry.register_kernel(fp16_kernel, Priority::High); - * bf16_registry.register_kernel(bf16_kernel, Priority::High); - * - * Dispatcher fp16_dispatcher(&fp16_registry); - * Dispatcher bf16_dispatcher(&bf16_registry); + * Derives from BaseRegistry for shared logic (thread safety, naming, priority, + * merge) while keeping GEMM-specific APIs (lookup by KernelKey, filter_by_arch, + * JSON export, auto-export). 
* * Status: Production ready, thread-safe */ #pragma once +#include "ck_tile/dispatcher/base_registry.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/kernel_key.hpp" #include -#include #include -#include #include #include @@ -47,20 +29,16 @@ namespace dispatcher { /// Registry: Central mapping from kernel configurations to executable instances /// Thread-safe kernel registration and lookup -/// Supports both singleton pattern and multiple independent instances -class Registry +/// Derives from BaseRegistry for shared functionality +class Registry : public BaseRegistry { + using Base = BaseRegistry; + public: - /// Priority levels for conflict resolution when multiple kernels have same key - enum class Priority - { - Low = 0, - Normal = 1, - High = 2 - }; + // Re-export Priority from the shared enum for backward compatibility + using Priority = ck_tile::dispatcher::Priority; /// Default constructor - creates an empty registry instance - /// Use this to create independent registries for different kernel sets Registry(); /// Destructor - triggers auto-export if enabled @@ -72,106 +50,51 @@ class Registry /// Move assignment Registry& operator=(Registry&& other) noexcept; - // Prevent copying (registries contain shared_ptrs that shouldn't be duplicated) + // Prevent copying Registry(const Registry&) = delete; Registry& operator=(const Registry&) = delete; /// Register a kernel instance with the registry - /// @param instance Kernel instance to register - /// @param priority Priority level for conflict resolution (default: Normal) - /// @return true if registered successfully, false if duplicate with higher priority exists bool register_kernel(KernelInstancePtr instance, Priority priority = Priority::Normal); /// Lookup a kernel by its string identifier - /// @param identifier Kernel identifier string - /// @return Kernel instance if found, nullptr otherwise [[nodiscard]] KernelInstancePtr lookup(const std::string& identifier) const; /// Lookup a kernel by its KernelKey - /// @param key Kernel configuration key - /// @return Kernel instance if found, nullptr otherwise [[nodiscard]] KernelInstancePtr lookup(const KernelKey& key) const; /// Get all registered kernels - /// @return Vector of all kernel instances [[nodiscard]] std::vector get_all() const; /// Get all kernels matching a predicate - /// @param predicate Function to filter kernels - /// @return Vector of matching kernel instances [[nodiscard]] std::vector filter(std::function predicate) const; - /// Get number of registered kernels - [[nodiscard]] std::size_t size() const; - - /// Check if registry is empty - [[nodiscard]] bool empty() const; - - /// Clear all registered kernels - void clear(); - - /// Get registry name (for logging/debugging) - [[nodiscard]] const std::string& get_name() const; - - /// Set registry name (for logging/debugging) - void set_name(const std::string& name); + // size(), empty(), clear(), get_name(), set_name(), merge_from() inherited from Base /// Export registry to JSON string - /// @param include_statistics Whether to include kernel statistics breakdown - /// @return JSON string with all kernel metadata [[nodiscard]] std::string export_json(bool include_statistics = true) const; /// Export registry to JSON file - /// @param filename Output filename - /// @param include_statistics Whether to include kernel statistics breakdown - /// @return true if export succeeded, false otherwise bool export_json_to_file(const std::string& filename, bool include_statistics = true) 
const; - /// Enable automatic JSON export on kernel registration - /// @param filename Output filename for auto-export - /// @param include_statistics Whether to include statistics in auto-export - /// @param export_on_every_registration If true, exports after every registration (default). - /// If false, only exports on destruction. void enable_auto_export(const std::string& filename, bool include_statistics = true, bool export_on_every_registration = true); - /// Disable automatic JSON export void disable_auto_export(); - /// Check if auto-export is enabled [[nodiscard]] bool is_auto_export_enabled() const; - /// Merge kernels from another registry into this one - /// @param other Registry to merge from - /// @param priority Priority for merged kernels (default: Normal) - /// @return Number of kernels successfully merged - std::size_t merge_from(const Registry& other, Priority priority = Priority::Normal); - /// Filter kernels in-place by architecture - /// @param gpu_arch Target GPU architecture string (e.g., "gfx942") - /// @return Number of kernels removed std::size_t filter_by_arch(const std::string& gpu_arch); - /// Get singleton instance of the global registry (backward compatible) - /// This is the default registry used when no specific registry is provided + /// Get singleton instance static Registry& instance(); private: - struct RegistryEntry - { - KernelInstancePtr instance; - Priority priority; - }; - - /// Perform auto-export if enabled void perform_auto_export(); - mutable std::mutex mutex_; - std::unordered_map kernels_; - std::string name_; - // Auto-export configuration bool auto_export_enabled_ = false; std::string auto_export_filename_; @@ -179,7 +102,7 @@ class Registry bool auto_export_on_every_registration_ = true; }; -/// Shared pointer type for registries (useful for managing lifetime) +/// Shared pointer type for registries using RegistryPtr = std::shared_ptr; /// Create a new registry instance (factory function) diff --git a/dispatcher/include/ck_tile/dispatcher_conv.hpp b/dispatcher/include/ck_tile/dispatcher_conv.hpp new file mode 100644 index 0000000000..46d14f90f3 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher_conv.hpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Grouped Convolution-only dispatcher header -- minimal include for conv operations. + +#pragma once + +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// Grouped Convolution +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher_gemm.hpp b/dispatcher/include/ck_tile/dispatcher_gemm.hpp new file mode 100644 index 0000000000..79317c7399 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher_gemm.hpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// GEMM-only dispatcher header -- minimal include for GEMM operations. 
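
Taken together with ml_heuristic.hpp above, this umbrella header is enough to rank registered kernels for a problem. A minimal sketch follows; it is not part of the patch, the model path is an assumption (the production layout in .gitignore suggests a `model_tflops.lgbm` file per op/dtype/arch), and the heuristic is assumed to return kernel identifiers in best-predicted-first order.

```cpp
// Hedged sketch: pruning the registry to one arch and ranking kernels with MLHeuristic.
#include "ck_tile/dispatcher_gemm.hpp"
#include "ck_tile/dispatcher/ml_heuristic.hpp"

#include <iostream>

void rank_kernels_for_problem()
{
    using namespace ck_tile::dispatcher;

    auto& reg = Registry::instance();  // singleton, backward compatible
    reg.filter_by_arch("gfx942");      // drop kernels built for other targets

    // A[512,256] x B[256,1024] = C[512,1024]
    auto problem = Problem::from_ab(512, 256, 256, 1024);

    MLHeuristic heuristic("model_tflops.lgbm", &reg); // path is an assumption
    if(heuristic.is_loaded())
    {
        for(const auto& id : heuristic(problem))       // identifiers, best predicted first
            std::cout << id << "\n";
    }
}
```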
+ +#pragma once + +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// GEMM +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/kernel_config.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "ck_tile/dispatcher/utils.hpp" diff --git a/dispatcher/python/CMakeLists.txt b/dispatcher/python/CMakeLists.txt index e57678952e..71634fa926 100644 --- a/dispatcher/python/CMakeLists.txt +++ b/dispatcher/python/CMakeLists.txt @@ -3,7 +3,7 @@ # This directory contains Python utilities for the dispatcher examples. # The main utility file is ctypes_utils.py which is used by GEMM Python examples. -# Conv Python examples use their own conv_utils.py in the examples directory. +# Grouped conv Python examples use grouped_conv_utils.py in this directory. # No build targets needed - these are pure Python utilities. message(STATUS "Python utilities directory configured (no build targets)") diff --git a/dispatcher/python/README.md b/dispatcher/python/README.md index 9286acbf72..edbc7acc9d 100644 --- a/dispatcher/python/README.md +++ b/dispatcher/python/README.md @@ -4,6 +4,19 @@ This directory contains Python utilities used by the dispatcher examples. ## Contents +### Shared Utilities (used by both GEMM and Grouped Conv) + +- `dispatcher_common.py` - Shared dispatcher infrastructure + - Path helpers (`get_dispatcher_root`, `get_build_dir`, etc.) + - `ValidationResultBase` - Structured validation feedback + - `validate_wave_config`, `validate_warp_tile_config`, `validate_trait_combo` + - `auto_correct_wave`, `auto_correct_trait` - Auto-correction helpers + - `Colors` - Cross-platform ANSI color support + - `print_phase`, `print_success`, `print_error`, `print_info` - Phased output + - `cleanup_generated_kernels` - Cleanup helper + +### GEMM Utilities + - `ctypes_utils.py` - Core ctypes utilities for GEMM Python examples - `KernelConfig` - Kernel configuration dataclass - `setup_gemm_dispatcher()` - Setup dispatcher with auto-correction @@ -11,11 +24,15 @@ This directory contains Python utilities used by the dispatcher examples. 
- `GemmRunner` - GPU execution helper - Auto-correction and validation utilities -- `conv_utils.py` - Core utilities for Conv Python examples - - `ConvSignature`, `ConvAlgorithm` - Convolution configuration - - `ConvProblem` - Problem definition - - `GpuConvRunner` - GPU execution helper - - `EnhancedConvCodegenRunner` - Kernel codegen utilities +### Grouped Convolution Utilities + +- `grouped_conv_utils.py` - Utilities for grouped convolution + - `GroupedConvValidationResult` - Validation result (extends `ValidationResultBase`) + - `validate_grouped_conv_config` - Validate a grouped conv config + - `auto_correct_grouped_conv_config` - Auto-correct invalid configs + - `get_grouped_conv_default_config` - Get default config for a variant + - `GroupedConvDataType` - Data type enum (FP16, BF16, FP32, FP8, BF8, INT8) + - `format_grouped_conv_summary` - Human-readable config summary ## Usage @@ -36,21 +53,26 @@ from ctypes_utils import ( ) ``` -### Conv Examples - -The Conv Python examples in `dispatcher/examples/conv/python/` import: +### Grouped Conv Usage ```python import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) -from conv_utils import ( - ConvSignature, - ConvAlgorithm, - ConvProblem, - GpuConvRunner, +from grouped_conv_utils import ( + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, + GroupedConvDataType, ) + +# Get a default config +config = get_grouped_conv_default_config(variant="forward", arch="gfx942") + +# Validate +result = validate_grouped_conv_config(config) +print(f"Valid: {result.is_valid}") ``` ## Requirements diff --git a/dispatcher/python/ctypes_utils.py b/dispatcher/python/ctypes_utils.py index 821fc2b08d..c11aaca835 100644 --- a/dispatcher/python/ctypes_utils.py +++ b/dispatcher/python/ctypes_utils.py @@ -37,6 +37,43 @@ import multiprocessing import time +# ============================================================================= +# GPU Architecture Auto-Detection +# ============================================================================= + +_detected_arch: Optional[str] = None + + +def detect_gpu_arch(fallback: str = "gfx942") -> str: + """ + Auto-detect the GPU architecture by querying rocminfo. + + Caches the result after the first call. Falls back to `fallback` if + detection fails (e.g. no GPU, rocminfo not installed). + """ + global _detected_arch + if _detected_arch is not None: + return _detected_arch + + try: + result = subprocess.run( + ["/opt/rocm/bin/rocminfo"], capture_output=True, text=True, timeout=10 + ) + for line in result.stdout.splitlines(): + stripped = line.strip() + if stripped.startswith("Name:") and "gfx" in stripped: + # Extract e.g. 
"gfx950" from "Name: gfx950" + name = stripped.split(":", 1)[1].strip() + if name.startswith("gfx") and name[3:].isdigit(): + _detected_arch = name + return _detected_arch + except Exception: + pass + + _detected_arch = fallback + return _detected_arch + + # ============================================================================= # Path Configuration # ============================================================================= @@ -159,9 +196,9 @@ class ValidationResult: def print_result(self, indent: str = " "): """Print validation result.""" if self.is_valid: - print(f"{indent}✓ Configuration valid") + print(f"{indent}OK Configuration valid") else: - print(f"{indent}⚠ Configuration has issues:") + print(f"{indent}WARNING Configuration has issues:") for err in self.errors: print(f"{indent} - {err}") @@ -300,7 +337,7 @@ def auto_correct_kernel_config( # Check each fix and describe what changed if "scheduler" in fixes and fixes["scheduler"] != config.scheduler: corrections.append( - f"Scheduler: {config.scheduler} → {fixes['scheduler']} " + f"Scheduler: {config.scheduler} -> {fixes['scheduler']} " f"('{config.scheduler}' not supported with pipeline={config.pipeline}, epilogue={config.epilogue})" ) @@ -309,7 +346,7 @@ def auto_correct_kernel_config( new_wave = f"[{fixes.get('wave_m', config.wave_m)}, {fixes.get('wave_n', config.wave_n)}, {fixes.get('wave_k', config.wave_k)}]" if old_wave != new_wave: corrections.append( - f"Wave config: {old_wave} → {new_wave} " + f"Wave config: {old_wave} -> {new_wave} " f"(original not supported on {config.gfx_arch})" ) @@ -318,7 +355,7 @@ def auto_correct_kernel_config( new_warp = f"[{fixes.get('warp_m', config.warp_m)}, {fixes.get('warp_n', config.warp_n)}, {fixes.get('warp_k', config.warp_k)}]" if old_warp != new_warp: corrections.append( - f"Warp tile: {old_warp} → {new_warp} " + f"Warp tile: {old_warp} -> {new_warp} " f"(original not supported for {config.dtype_a} on {config.gfx_arch})" ) @@ -386,13 +423,13 @@ def print_auto_correction( indent: Indentation for output """ if not corrections: - print(f"{indent}✓ Configuration valid - no corrections needed") + print(f"{indent}OK Configuration valid - no corrections needed") return - print(f"\n{indent}⚠ AUTO-CORRECTION APPLIED:") + print(f"\n{indent}WARNING AUTO-CORRECTION APPLIED:") print(f"{indent}" + "-" * 50) for correction in corrections: - print(f"{indent} • {correction}") + print(f"{indent} - {correction}") print(f"{indent}" + "-" * 50) print() @@ -976,6 +1013,226 @@ def _run_codegen_subprocess(args: Dict[str, Any]) -> CodegenResult: ) +def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]: + """Module-level function to run hipcc compilation in parallel.""" + import subprocess + from pathlib import Path + + compile_cmd = args["compile_cmd"] + link_cmd = args["link_cmd"] + lib_path = Path(args["lib_path"]) + + try: + res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300) + if res_c.returncode != 0: + return False, None, f"Compile failed: {res_c.stderr[:200]}" + + res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300) + if res_l.returncode != 0: + return False, None, f"Link failed: {res_l.stderr[:200]}" + + return True, lib_path, "" + except subprocess.TimeoutExpired: + return False, None, "Timeout" + except Exception as e: + return False, None, str(e) + + +def _generate_single_kernel_subprocess(args: dict) -> Tuple[bool, Optional[str], str]: + """Module-level function: generate ONE kernel .hpp via --config JSON file. 
+ + Used by setup_multiple_gemm_dispatchers for per-config parallel codegen. + Returns (success, header_path_or_None, error_msg). + """ + import subprocess + import json + import tempfile + import os + from pathlib import Path + + try: + out_dir = Path(args["output_dir"]) + out_dir.mkdir(parents=True, exist_ok=True) + + # Write the single-config JSON to a temp file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(args["tile_config_json"], f) + config_file = f.name + + cmd = [ + args["python"], + str(args["codegen_script"]), + "--output-dir", + str(out_dir), + "--datatype", + args["dtype"], + "--layout", + args["layout"], + "--gpu-target", + args["gpu_target"], + "--config", + config_file, + "--variants", + "standard", + ] + + res = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + os.unlink(config_file) + + if res.returncode != 0: + return False, None, f"Codegen failed: {res.stderr[:200]}" + + # Find the generated .hpp using the expected name pattern + pattern = args["hpp_glob_pattern"] + matches = sorted(out_dir.glob(pattern)) + if matches: + return True, str(matches[0]), "" + else: + return False, None, f"No .hpp matching {pattern} after codegen" + + except Exception as e: + return False, None, str(e) + + +def _parse_triplet(text: str) -> Optional[Tuple[int, int, int]]: + parts = text.split("x") + if len(parts) != 3: + return None + try: + return (int(parts[0]), int(parts[1]), int(parts[2])) + except ValueError: + return None + + +def _parse_gemm_header_metadata(header: Path) -> Optional[Dict[str, Any]]: + """ + Parse GEMM header name into configuration metadata. + + Expected stem format: + gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler} + _{pad_m}_{pad_n}_{pad_k}_{persistent} + _{tile_m}x{tile_n}x{tile_k}_{wave_m}x{wave_n}x{wave_k}_{warp_m}x{warp_n}x{warp_k} + """ + parts = header.stem.split("_") + if len(parts) < 13 or parts[0] != "gemm": + return None + + tile = _parse_triplet(parts[10]) + wave = _parse_triplet(parts[11]) + warp = _parse_triplet(parts[12]) + if tile is None or wave is None or warp is None: + return None + + def _as_bool(v: str) -> bool: + return v.lower() == "true" + + return { + "dtype": parts[1], + "layout": parts[2], + "pipeline": parts[3], + "epilogue": parts[4], + "scheduler": parts[5], + "pad_m": _as_bool(parts[6]), + "pad_n": _as_bool(parts[7]), + "pad_k": _as_bool(parts[8]), + "persistent": _as_bool(parts[9]), + "tile": tile, + "wave": wave, + "warp": warp, + } + + +def _generate_arch_valid_gemm_headers( + python_exe: str, + codegen_script: Path, + output_dir: Path, + dtype: str, + layout: str, + gpu_target: str, + variant: str = "standard", +) -> Tuple[bool, List[Path], str]: + """Generate (or reuse) an arch-filtered kernel catalog for fallback selection.""" + output_dir.mkdir(parents=True, exist_ok=True) + pattern = f"gemm_{dtype}_{layout}_*.hpp" + existing = sorted(output_dir.glob(pattern)) + if existing: + return True, existing, "" + + cmd = [ + python_exe, + str(codegen_script), + "--output-dir", + str(output_dir), + "--datatype", + dtype, + "--layout", + layout, + "--gpu-target", + gpu_target, + "--variants", + variant, + ] + res = subprocess.run(cmd, capture_output=True, text=True, timeout=600) + if res.returncode != 0: + err = (res.stderr or res.stdout or "").strip()[:500] + return False, [], f"Catalog codegen failed: {err}" + + generated = sorted(output_dir.glob(pattern)) + if not generated: + return False, [], "Catalog codegen produced no GEMM headers" + return True, generated, "" + 
+ +def _select_best_arch_valid_gemm_header( + config: "KernelConfig", + headers: List[Path], +) -> Tuple[Optional[Path], Optional[Dict[str, Any]]]: + """Choose nearest arch-valid header for a requested GEMM config.""" + best: Optional[Path] = None + best_meta: Optional[Dict[str, Any]] = None + best_score: Optional[Tuple[int, int, int, int, int, int]] = None + + for h in headers: + meta = _parse_gemm_header_metadata(h) + if meta is None: + continue + if meta["dtype"] != config.dtype_a or meta["layout"] != config.layout: + continue + + tile = meta["tile"] + wave = meta["wave"] + warp = meta["warp"] + tile_delta = ( + abs(tile[0] - config.tile_m) + + abs(tile[1] - config.tile_n) + + abs(tile[2] - config.tile_k) + ) + wave_delta = ( + abs(wave[0] - config.wave_m) + + abs(wave[1] - config.wave_n) + + abs(wave[2] - config.wave_k) + ) + warp_delta = ( + abs(warp[0] - config.warp_m) + + abs(warp[1] - config.warp_n) + + abs(warp[2] - config.warp_k) + ) + score = ( + 0 if meta["pipeline"] == config.pipeline else 1, + 0 if meta["scheduler"] == config.scheduler else 1, + 0 if meta["epilogue"] == config.epilogue else 1, + tile_delta, + wave_delta, + warp_delta, + ) + if best_score is None or score < best_score: + best_score = score + best = h + best_meta = meta + + return best, best_meta + + # ============================================================================= # Preshuffle Utilities # ============================================================================= @@ -1319,7 +1576,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1337,7 +1594,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {variant}: FAILED - {e}") + print(f" FAIL {variant}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1399,7 +1656,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {tile_str}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1417,7 +1674,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {tile_str}: FAILED - {e}") + print(f" FAIL {tile_str}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1481,7 +1738,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1499,7 +1756,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {variant}: FAILED - {e}") + print(f" FAIL {variant}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1767,7 +2024,7 @@ class CodegenRunner: link_cmd, capture_output=True, text=True, timeout=300 ) if result.returncode == 0: - print(f" ✓ Library rebuilt: {lib_path.name}") + print(f" OK Library rebuilt: {lib_path.name}") # Clean up object file obj_file.unlink(missing_ok=True) return lib_path @@ -1781,6 +2038,105 @@ class CodegenRunner: print(f" Build error: {e}") return None + def build_libraries_parallel( + self, configs_and_headers: List[Tuple[KernelConfig, Path]], verbose: bool = True + ) -> List[Optional[Path]]: + """ + Build multiple libraries in parallel using ProcessPoolExecutor. 
+ Returns a list of library paths (or None if a build failed) in the same order. + """ + import time + from concurrent.futures import ProcessPoolExecutor, as_completed + + start_time = time.time() + build_dir = get_build_dir() + root = get_dispatcher_root() + ck_root = root.parent + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + static_lib = build_dir / "libck_tile_dispatcher.a" + + if not ctypes_source.exists() or not static_lib.exists(): + if verbose: + print(" Required source or static library missing for parallel build.") + return [None] * len(configs_and_headers) + + args_list = [] + for config, kernel_header in configs_and_headers: + lib_name = f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_{config.tile_str}_{config.pipeline}.so" + lib_path = build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={config.gfx_arch}", + f'-DGFX_ARCH="{config.gfx_arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={config.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + args_list.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + "config_name": f"{config.dtype_a}_{config.layout}_{config.tile_str}", + } + ) + + if verbose: + print( + f"Building {len(args_list)} libraries in parallel (workers={self.max_workers})..." + ) + + results_map = {} + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, args): i + for i, args in enumerate(args_list) + } + for future in as_completed(futures): + idx = futures[future] + success, lib_path, err = future.result() + results_map[idx] = Path(lib_path) if success else None + if verbose: + status = "OK" if success else f"FAIL ({err})" + print( + f" {status} {Path(lib_path).name if success else args_list[idx]['config_name']}" + ) + + if verbose: + elapsed = time.time() - start_time + print(f"Parallel build finished in {elapsed:.2f}s") + + return [results_map[i] for i in range(len(configs_and_headers))] + def generate_preselected( self, preset: str = "fp16_rcr_essential", output_dir: Optional[Path] = None ) -> CodegenResult: @@ -1933,6 +2289,28 @@ class Registry: """Bind to a loaded dispatcher library.""" self._lib = lib + def build( + self, + verbose: bool = False, + max_workers: Optional[int] = None, + ) -> List["GemmSetupResult"]: + """Parallel JIT compile all kernels in this registry. + + Args: + verbose: Print progress during build. + max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8). + + Returns a GemmSetupResult per registered kernel (same order as get_kernels()). 
+ """ + if not self._kernels: + return [] + return setup_multiple_gemm_dispatchers( + self._kernels, + registry_name=self._name, + verbose=verbose, + max_workers=max_workers, + ) + def __repr__(self) -> str: return f"Registry(name='{self._name}', kernels={self.kernel_count})" @@ -2109,7 +2487,7 @@ def setup_gemm_dispatcher( log(" Validating config...") validation = validate_kernel_config(config) if not validation.is_valid: - log(" ⚠ Auto-correcting configuration...") + log(" WARNING Auto-correcting configuration...") config, was_modified, corrections = auto_correct_kernel_config( config, verbose=verbose ) @@ -2128,13 +2506,13 @@ def setup_gemm_dispatcher( codegen_result = codegen.generate_from_config(config) if not codegen_result.success: - log(" ⚠ Kernel generation: using existing") + log(" WARNING Kernel generation: using existing") # Step 3: Find matching kernel header kernel_header = find_matching_kernel_header(config) result.kernel_header = kernel_header if not kernel_header: - log(" ⚠ No matching kernel header found") + log(" WARNING No matching kernel header found") # Step 4: Load library log(" Loading library...") @@ -2188,11 +2566,11 @@ def setup_gemm_dispatcher( result.error = "Failed to load rebuilt library" return result result.lib = lib - log(f" ✓ Rebuilt library: {lib.get_kernel_name()}") + log(f" OK Rebuilt library: {lib.get_kernel_name()}") else: - log(" ⚠ Rebuild failed, using existing library") + log(" WARNING Rebuild failed, using existing library") else: - log(" ⚠ No kernel header found for config, using existing library") + log(" WARNING No kernel header found for config, using existing library") # Step 5: Create registry and dispatcher log(" Creating registry and dispatcher...") @@ -2203,12 +2581,305 @@ def setup_gemm_dispatcher( dispatcher = Dispatcher(registry=registry, lib=lib) result.dispatcher = dispatcher - log(f" ✓ Ready: {lib.get_kernel_name()}") + log(f" OK Ready: {lib.get_kernel_name()}") result.success = True return result +def setup_multiple_gemm_dispatchers( + configs: List[KernelConfig], + registry_name: str = "gemm_registry", + verbose: bool = True, + max_workers: Optional[int] = None, +) -> List[GemmSetupResult]: + """ + Setup multiple GEMM dispatchers in parallel. + + Pipeline: + 1. Validate + auto-correct each config + 2. Parallel codegen: generate .hpp for each config via --config JSON + 3. Parallel hipcc: compile each .hpp -> .so + 4. Load + wire up each .so into a GemmSetupResult + + Each config gets its own .so, so different tile sizes can coexist. + + Args: + max_workers: Max parallel processes for codegen/compile (default: cpu_count capped at 8). 
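# Illustrative sketch, not part of the patch: calling the parallel setup directly
# and reporting per-config failures; `configs` is an assumed list of KernelConfig.
results = setup_multiple_gemm_dispatchers(configs, registry_name="bench",
                                          verbose=False, max_workers=8)
for cfg, res in zip(configs, results):
    status = "ready" if res.success else f"failed: {res.error}"
    print(f"{cfg.dtype_a} {cfg.layout} {cfg.tile_str}: {status}")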
+ """ + import sys + + results = [GemmSetupResult(success=False, config=c) for c in configs] + max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + + # -- Step 1: Validate & correct --------------------------------------- + valid_configs = [] + for i, c in enumerate(configs): + val = validate_kernel_config(c) + if not val.is_valid: + c, modified, corrections = auto_correct_kernel_config(c, verbose=False) + results[i].config = c + results[i].corrections = corrections + valid_configs.append(c) + + # -- Step 2: Parallel codegen (one --config JSON per config) ---------- + codegen_script = get_codegen_path() + output_dir = get_generated_kernels_dir() + + codegen_args = [] + for c in valid_configs: + tile_str = c.tile_str + wave_str = f"{c.wave_m}x{c.wave_n}x{c.wave_k}" + warp_str = f"{c.warp_m}x{c.warp_n}x{c.warp_k}" + + tile_config_json = { + "tile_config": { + "tile_m": [c.tile_m], + "tile_n": [c.tile_n], + "tile_k": [c.tile_k], + "warp_m": [c.wave_m], + "warp_n": [c.wave_n], + "warp_k": [c.wave_k], + "warp_tile_m": [c.warp_m], + "warp_tile_n": [c.warp_n], + "warp_tile_k": [c.warp_k], + }, + "trait_config": { + "pipeline": [c.pipeline], + "epilogue": [c.epilogue], + "scheduler": [c.scheduler], + "pad_m": [c.pad_m], + "pad_n": [c.pad_n], + "pad_k": [c.pad_k], + "persistent": [False], + }, + } + + hpp_pattern = ( + f"gemm_{c.dtype_a}_{c.layout}_{c.pipeline}_{c.epilogue}_{c.scheduler}" + f"_*_{tile_str}_{wave_str}_{warp_str}.hpp" + ) + + codegen_args.append( + { + "python": sys.executable, + "codegen_script": str(codegen_script), + "output_dir": str(output_dir), + "dtype": c.dtype_a, + "layout": c.layout, + "gpu_target": c.gfx_arch, + "tile_config_json": tile_config_json, + "hpp_glob_pattern": hpp_pattern, + } + ) + + if verbose: + print( + f"Generating {len(codegen_args)} kernel headers in parallel (workers={max_workers})..." + ) + + headers: List[Optional[Path]] = [None] * len(valid_configs) + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_generate_single_kernel_subprocess, a): i + for i, a in enumerate(codegen_args) + } + for future in as_completed(futures): + idx = futures[future] + ok, hdr_str, err = future.result() + if ok and hdr_str: + headers[idx] = Path(hdr_str) + results[idx].kernel_header = Path(hdr_str) + if verbose: + print( + f" OK [{idx}] {valid_configs[idx].tile_str}: {Path(hdr_str).name}" + ) + else: + results[idx].error = f"Codegen: {err}" + if verbose: + print(f" FAIL [{idx}] {valid_configs[idx].tile_str}: {err}") + + # For configs rejected by arch filter, map to nearest arch-valid header. + fallback_needed = [i for i, h in enumerate(headers) if h is None] + if fallback_needed: + if verbose: + print( + f"Resolving {len(fallback_needed)} configs via arch-valid GEMM catalog..." 
+ ) + + catalog_cache: Dict[Tuple[str, str, str, str], List[Path]] = {} + for i in fallback_needed: + c = valid_configs[i] + key = (c.gfx_arch, c.dtype_a, c.layout, c.variant) + if key not in catalog_cache: + catalog_dir = ( + output_dir + / "_arch_valid_catalog" + / (f"{c.gfx_arch}_{c.dtype_a}_{c.layout}_{c.variant}") + ) + ok, catalog_headers, err = _generate_arch_valid_gemm_headers( + python_exe=sys.executable, + codegen_script=codegen_script, + output_dir=catalog_dir, + dtype=c.dtype_a, + layout=c.layout, + gpu_target=c.gfx_arch, + variant=c.variant, + ) + if not ok: + catalog_headers = [] + if verbose: + print(f" FAIL [{i}] catalog generation: {err}") + catalog_cache[key] = catalog_headers + + chosen, meta = _select_best_arch_valid_gemm_header(c, catalog_cache[key]) + if chosen is None or meta is None: + continue + + headers[i] = chosen + results[i].kernel_header = chosen + results[i].error = "" + + # Keep Python-side config aligned with the selected kernel header. + valid_configs[i].pipeline = str(meta["pipeline"]) + valid_configs[i].epilogue = str(meta["epilogue"]) + valid_configs[i].scheduler = str(meta["scheduler"]) + valid_configs[i].pad_m = bool(meta["pad_m"]) + valid_configs[i].pad_n = bool(meta["pad_n"]) + valid_configs[i].pad_k = bool(meta["pad_k"]) + valid_configs[i].tile_m = int(meta["tile"][0]) + valid_configs[i].tile_n = int(meta["tile"][1]) + valid_configs[i].tile_k = int(meta["tile"][2]) + valid_configs[i].wave_m = int(meta["wave"][0]) + valid_configs[i].wave_n = int(meta["wave"][1]) + valid_configs[i].wave_k = int(meta["wave"][2]) + valid_configs[i].warp_m = int(meta["warp"][0]) + valid_configs[i].warp_n = int(meta["warp"][1]) + valid_configs[i].warp_k = int(meta["warp"][2]) + results[i].config = valid_configs[i] + + if verbose: + print(f" INFO [{i}] mapped to arch-valid header: {chosen.name}") + + # -- Step 3: Parallel hipcc compilation ------------------------------- + root = get_dispatcher_root() + ck_root = root.parent + build_dir = get_build_dir() + ctypes_source = root / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" + static_lib = build_dir / "libck_tile_dispatcher.a" + + if not ctypes_source.exists() or not static_lib.exists(): + for i in range(len(valid_configs)): + if results[i].error == "": + results[ + i + ].error = "Missing ctypes source or static library for compilation" + return results + + compile_jobs = [] + compile_index_map = {} + for i, c in enumerate(valid_configs): + hdr = headers[i] + if hdr is None: + continue + + lib_name = ( + f"libdispatcher_gemm_{c.dtype_a}_{c.layout}_{c.tile_str}_{c.pipeline}.so" + ) + lib_path = build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{str(output_dir)}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{hdr}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={c.gfx_arch}", + f'-DGFX_ARCH="{c.gfx_arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={c.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + compile_index_map[len(compile_jobs)] = i + compile_jobs.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + } + ) + + if verbose and compile_jobs: + print( + 
f"Compiling {len(compile_jobs)} libraries in parallel (workers={max_workers})..." + ) + + lib_paths: Dict[int, Optional[Path]] = {} + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, job): j + for j, job in enumerate(compile_jobs) + } + for future in as_completed(futures): + j = futures[future] + i = compile_index_map[j] + ok, lp, err = future.result() + if ok and lp: + lib_paths[i] = Path(lp) + if verbose: + print(f" OK [{i}] {valid_configs[i].tile_str}: {Path(lp).name}") + else: + results[i].error = f"Compile: {err}" + if verbose: + print(f" FAIL [{i}] {valid_configs[i].tile_str}: {err}") + + # -- Step 4: Load libraries and create dispatchers -------------------- + for i, c in enumerate(valid_configs): + lp = lib_paths.get(i) + if lp is None: + continue + + lib = DispatcherLib.load(lp) + if lib is not None and lib.initialize(): + results[i].lib = lib + reg = Registry(name=f"{registry_name}_{i}", lib=lib) + reg.register_kernel(c) + results[i].registry = reg + results[i].dispatcher = Dispatcher(registry=reg, lib=lib) + results[i].success = True + else: + results[i].error = "Failed to load compiled library" + + if verbose: + ok_count = sum(1 for r in results if r.success) + print(f"Setup complete: {ok_count}/{len(results)} dispatchers ready") + + return results + + def cleanup_gemm(): """ Cleanup function to call after running GEMM examples. diff --git a/dispatcher/python/dispatcher_common.py b/dispatcher/python/dispatcher_common.py new file mode 100644 index 0000000000..a19ecbdb49 --- /dev/null +++ b/dispatcher/python/dispatcher_common.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Shared Python dispatcher utilities for GEMM and grouped convolution. + +Extracted from ctypes_utils.py (GEMM) + compile_grouped_conv_examples.py (grouped conv). +Both ctypes_utils.py and grouped_conv_utils.py import from here to +eliminate duplication. 
+ +Best-of-both: + - Validation and auto-correction return typed objects (GEMM pattern) + - Colors class with cross-platform ANSI handling (conv pattern) + - Phased output helpers (conv pattern) + - logging module instead of bare print() (shared improvement) +""" + +import logging +import shutil +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +log = logging.getLogger(__name__) + + +# ============================================================================ +# Path Configuration +# ============================================================================ + + +def get_dispatcher_root() -> Path: + """Get the dispatcher root directory (parent of python/).""" + return Path(__file__).parent.parent + + +def get_ck_root() -> Path: + """Get the CK root directory (parent of dispatcher/).""" + return get_dispatcher_root().parent + + +def get_build_dir() -> Path: + """Get the build directory.""" + return get_dispatcher_root() / "build" + + +def get_generated_kernels_dir() -> Path: + """Get the generated kernels directory.""" + return get_build_dir() / "generated_kernels" + + +def get_codegen_dir() -> Path: + """Get the codegen scripts directory.""" + return get_dispatcher_root() / "codegen" + + +# ============================================================================ +# Architecture Filter Data +# ============================================================================ + +_arch_data_cache: Optional[Dict[str, Any]] = None + + +def detect_gpu_arch(fallback: str = "gfx942") -> str: + """Detect the GPU architecture from rocminfo. Falls back to the given default.""" + import subprocess + + try: + out = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.DEVNULL + ) + for line in out.splitlines(): + if "Name:" in line and "gfx" in line: + return line.split()[-1].strip() + except Exception: + pass + return fallback + + +def get_arch_filter_data() -> Dict[str, Any]: + """Load arch filter data from arch_specs_generated if available. + + Returns dict with keys: trait_unsupported, warp_combos, + warp_tile_combos, supported_archs. 
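# Illustrative sketch, not part of the patch: how a consumer module would compose
# the arch helpers above into a quick capability check before attempting codegen.
from dispatcher_common import detect_gpu_arch, get_arch_filter_data

arch = detect_gpu_arch(fallback="gfx942")
data = get_arch_filter_data()
if arch not in data["supported_archs"]:
    raise RuntimeError(f"{arch} has no generated arch specs")
print("valid wave configs:", data["warp_combos"].get(arch, [[2, 2, 1]]))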
+ """ + global _arch_data_cache + if _arch_data_cache is not None: + return _arch_data_cache + + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + _arch_data_cache = { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + _arch_data_cache = { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + + return _arch_data_cache + + +# ============================================================================ +# Validation Result +# ============================================================================ + + +@dataclass +class ValidationResultBase: + """Result of kernel config validation (shared base for GEMM and conv).""" + + is_valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + suggested_fixes: Dict[str, Any] = field(default_factory=dict) + + def print_result(self, indent: str = " "): + if self.is_valid: + print(f"{indent}OK Configuration valid") + else: + print(f"{indent}WARNING Configuration has issues:") + for err in self.errors: + print(f"{indent} - {err}") + if self.warnings: + for warn in self.warnings: + print(f"{indent} Warning: {warn}") + if self.suggested_fixes: + print(f"{indent} Suggested fixes:") + for key, val in self.suggested_fixes.items(): + print(f"{indent} {key}: {val}") + + +# ============================================================================ +# Validation Helpers +# ============================================================================ + + +def validate_wave_config(wave_cfg: List[int], arch: str) -> Tuple[bool, str]: + """Validate a [wave_m, wave_n, wave_k] config for *arch*. + + Returns (is_valid, error_message). Empty string on success. + """ + data = get_arch_filter_data() + valid_waves = data["warp_combos"].get(arch, [[2, 2, 1]]) + if wave_cfg in valid_waves: + return True, "" + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in valid_waves) + return ( + False, + f"Unsupported wave configuration {wave_cfg} for {arch}. " + f"Valid wave configs: {valid_str}", + ) + + +def validate_warp_tile_config( + warp_cfg: List[int], arch: str, dtype: str +) -> Tuple[bool, str]: + """Validate a [warp_m, warp_n, warp_k] config for *arch*/*dtype*. + + Returns (is_valid, error_message). Empty string on success. + """ + data = get_arch_filter_data() + acc = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc}" + valid_tiles = ( + data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + if warp_cfg in valid_tiles: + return True, "" + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in valid_tiles[:5]) + return ( + False, + f"Unsupported warp tile {warp_cfg} for {arch}/{dtype}. 
" + f"Valid warp tiles: {valid_str}", + ) + + +def validate_trait_combo( + pipeline: str, epilogue: str, scheduler: str +) -> Tuple[bool, str]: + """Validate a (pipeline, epilogue, scheduler) combination. + + Returns (is_valid, error_message). Empty string on success. + """ + data = get_arch_filter_data() + combo = (pipeline, epilogue, scheduler) + if combo in data["trait_unsupported"]: + return ( + False, + f"Unsupported trait combination: pipeline={pipeline}, " + f"epilogue={epilogue}, scheduler={scheduler}", + ) + return True, "" + + +# ============================================================================ +# Auto-Correction Helpers +# ============================================================================ + + +def auto_correct_wave(wave_cfg: List[int], arch: str) -> List[int]: + """Return the first valid wave config for *arch*. + + If *wave_cfg* is already valid, returns it unchanged. + """ + data = get_arch_filter_data() + valid_waves = data["warp_combos"].get(arch, [[2, 2, 1]]) + if wave_cfg in valid_waves: + return wave_cfg + return valid_waves[0] if valid_waves else [2, 2, 1] + + +def auto_correct_trait(pipeline: str, scheduler: str) -> Tuple[str, str]: + """Return a corrected (pipeline, scheduler) pair. + + If the compute pipeline doesn't support interwave, switch to intrawave. + """ + data = get_arch_filter_data() + for epilogue in ("cshuffle", "default"): + if (pipeline, epilogue, scheduler) in data["trait_unsupported"]: + return pipeline, "intrawave" + return pipeline, scheduler + + +# ============================================================================ +# Colors (adopted from compile_grouped_conv_examples.py -- cross-platform) +# ============================================================================ + + +class Colors: + """Cross-platform ANSI color support. + + Respects sys.platform (no ANSI on Windows) and isatty() check so + piped/redirected output stays clean. + """ + + _GREEN = "\033[0;32m" + _YELLOW = "\033[1;33m" + _RED = "\033[0;31m" + _CYAN = "\033[0;36m" + _BOLD = "\033[1m" + _NC = "\033[0m" + + @classmethod + def _use_color(cls) -> bool: + return ( + sys.platform != "win32" + and hasattr(sys.stdout, "isatty") + and sys.stdout.isatty() + ) + + @classmethod + def green(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._GREEN}{text}{cls._NC}" + return text + + @classmethod + def red(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._RED}{text}{cls._NC}" + return text + + @classmethod + def yellow(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._YELLOW}{text}{cls._NC}" + return text + + @classmethod + def cyan(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._CYAN}{text}{cls._NC}" + return text + + @classmethod + def bold(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._BOLD}{text}{cls._NC}" + return text + + +# ============================================================================ +# Phased Output Helpers +# ============================================================================ + + +def print_phase(number: int, description: str) -> None: + """Print a phase header (e.g. 
'Phase 1: Codegen').""" + print(f"\n{'=' * 60}") + print(f" Phase {number}: {description}") + print(f"{'=' * 60}") + + +def print_success(message: str) -> None: + """Print a success message.""" + print(f" OK {Colors.green(message)}") + + +def print_error(message: str) -> None: + """Print an error message.""" + print(f" FAIL {Colors.red(message)}") + + +def print_info(message: str) -> None: + """Print an info message.""" + print(f" {Colors.cyan(message)}") + + +# ============================================================================ +# Cleanup Helpers +# ============================================================================ + + +def cleanup_generated_kernels(gen_dir: Optional[Path] = None) -> None: + """Remove generated kernel directory if it exists.""" + if gen_dir is None: + gen_dir = get_generated_kernels_dir() + if gen_dir.exists(): + shutil.rmtree(gen_dir, ignore_errors=True) + log.info("Cleaned up generated kernels at %s", gen_dir) + + +# ============================================================================ +# Tool Helpers +# ============================================================================ + + +def find_hipcc() -> Optional[str]: + """Find the hipcc compiler.""" + import os + + candidates = [ + os.environ.get("HIPCC"), + "/opt/rocm/bin/hipcc", + shutil.which("hipcc"), + ] + for path in candidates: + if path and os.path.isfile(path): + return path + return None diff --git a/dispatcher/python/grouped_conv_utils.py b/dispatcher/python/grouped_conv_utils.py new file mode 100644 index 0000000000..cd6ef5647c --- /dev/null +++ b/dispatcher/python/grouped_conv_utils.py @@ -0,0 +1,1806 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Grouped Convolution Dispatcher Utilities + +Typed Python API for grouped convolution kernels, matching the patterns from +the old conv_utils.py and the GEMM ctypes_utils.py. + +Classes: + GroupedConvKernelConfig - Kernel configuration (tile, wave, pipeline, arch) + GroupedConvProblem - Runtime problem specification (N,C,K,H,W,etc.) 
+ GroupedConvProblemC - ctypes struct matching C++ ConvProblemC + GroupedConvDispatcherLib - Wrapper for libdispatcher_conv_lib.so + GpuGroupedConvRunner - High-level GPU execution runner + GroupedConvResult - Result of GPU execution (output, time, tflops) + GroupedConvRegistry - Collection of kernel configs with JSON export + +Usage: + from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GpuGroupedConvRunner, + ) + + config = GroupedConvKernelConfig(variant="forward", ndim_spatial=2) + problem = GroupedConvProblem(N=1, C=64, K=128, Hi=28, Wi=28, Y=3, X=3, + stride_h=1, pad_h=1, direction="forward") + runner = GpuGroupedConvRunner() + if runner.is_available(): + result = runner.run(input_np, weight_np, problem) + print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}") +""" + +import ctypes +import json +import copy +import subprocess +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from dispatcher_common import ( + ValidationResultBase, + auto_correct_trait, + auto_correct_wave, + get_arch_filter_data, + validate_trait_combo, + validate_wave_config, + validate_warp_tile_config, +) + + +# ============================================================================= +# Constants +# ============================================================================= + +VALID_VARIANTS = ("forward", "bwd_data", "bwd_weight") +VALID_NDIM_SPATIAL = (1, 2, 3) +BACKWARD_VARIANTS = ("bwd_data", "bwd_weight") +BACKWARD_PIPELINES = ("compv3", "mem") + +VARIANT_ALIASES = { + "2d_fwd": "forward", + "2d_bwdd": "bwd_data", + "2d_bwdw": "bwd_weight", + "fwd": "forward", + "bwdd": "bwd_data", + "bwdw": "bwd_weight", +} + +DIRECTION_MAP = {"forward": 0, "bwd_data": 1, "bwd_weight": 2} + + +def _resolve_variant(v: str) -> str: + return VARIANT_ALIASES.get(v, v) + + +# ============================================================================= +# GroupedConvDataType +# ============================================================================= + + +class GroupedConvDataType(Enum): + FP16 = "fp16" + BF16 = "bf16" + FP32 = "fp32" + FP8 = "fp8" + BF8 = "bf8" + INT8 = "int8" + + +# ============================================================================= +# GroupedConvKernelConfig +# ============================================================================= + + +@dataclass +class GroupedConvKernelConfig: + """Complete kernel configuration for grouped convolution. + + Captures all parameters needed to identify and run a specific kernel. + Mirrors the C++ GroupedConvSignature + GroupedConvAlgorithm. 
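# Illustrative sketch, not part of the patch: legacy variant spellings are
# normalised through VARIANT_ALIASES before any validation or codegen.
for legacy in ("2d_fwd", "bwdd", "2d_bwdw", "bwd_data"):
    print(f"{legacy} -> {_resolve_variant(legacy)}")
# 2d_fwd -> forward, bwdd -> bwd_data, 2d_bwdw -> bwd_weight, bwd_data -> bwd_data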
+ """ + + # What: signature + variant: str = "forward" + ndim_spatial: int = 2 + dtype: str = "fp16" + layout: str = "nhwgc" + arch: str = "gfx942" + + # How: algorithm - tile shape + tile_m: int = 1 + tile_n: int = 128 + tile_k: int = 128 + + # How: wave config + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + + # How: warp tile + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + # How: pipeline traits + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + + # ConvConfigBase parity fields + vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + block_per_cu: int = 1 + num_wave_groups: int = 1 + num_groups_to_merge: int = 1 + + # Padding (enables arbitrary problem sizes) + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + + def __post_init__(self): + self.variant = _resolve_variant(self.variant) + if ( + self.variant in BACKWARD_VARIANTS + and self.pipeline not in BACKWARD_PIPELINES + ): + self.pipeline = "compv3" + + @property + def tile_str(self) -> str: + return f"{self.tile_m}x{self.tile_n}x{self.tile_k}" + + @property + def wave_str(self) -> str: + return f"{self.wave_m}x{self.wave_n}x{self.wave_k}" + + @property + def warp_str(self) -> str: + return f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}" + + @property + def vec_str(self) -> str: + return f"{self.vector_size_a}x{self.vector_size_b}x{self.vector_size_c}" + + @property + def name(self) -> str: + return ( + f"grouped_conv_{self.variant}_{self.dtype}_{self.ndim_spatial}d_" + f"{self.tile_str}_{self.pipeline}" + ) + + def to_dict(self) -> dict: + """Convert to legacy dict format for codegen compatibility.""" + return { + "tile_config": { + "tile_m": [self.tile_m], + "tile_n": [self.tile_n], + "tile_k": [self.tile_k], + "wave_m": [self.wave_m], + "wave_n": [self.wave_n], + "wave_k": [self.wave_k], + "warp_tile_m": [self.warp_tile_m], + "warp_tile_n": [self.warp_tile_n], + "warp_tile_k": [self.warp_tile_k], + }, + "trait_config": { + "pipeline": [self.pipeline], + "epilogue": [self.epilogue], + "scheduler": [self.scheduler], + "pad_m": [self.pad_m], + "pad_n": [self.pad_n], + "pad_k": [self.pad_k], + "vector_size_a": [self.vector_size_a], + "vector_size_b": [self.vector_size_b], + "vector_size_c": [self.vector_size_c], + "block_per_cu": [self.block_per_cu], + "num_wave_groups": [self.num_wave_groups], + "num_groups_to_merge": [self.num_groups_to_merge], + }, + "variant": self.variant, + "ndim_spatial": self.ndim_spatial, + "arch": self.arch, + "layout": self.layout, + "dtype": self.dtype, + } + + def to_json_obj(self) -> dict: + """Serializable dict for JSON export.""" + return { + "name": self.name, + "signature": { + "variant": self.variant, + "dtype": self.dtype, + "ndim_spatial": self.ndim_spatial, + "layout": self.layout, + }, + "algorithm": { + "tile_m": self.tile_m, + "tile_n": self.tile_n, + "tile_k": self.tile_k, + "wave": self.wave_str, + "warp": self.warp_str, + "pipeline": self.pipeline, + "epilogue": self.epilogue, + "scheduler": self.scheduler, + "vector_sizes": [ + self.vector_size_a, + self.vector_size_b, + self.vector_size_c, + ], + "block_per_cu": self.block_per_cu, + "num_wave_groups": self.num_wave_groups, + "num_groups_to_merge": self.num_groups_to_merge, + }, + "arch": self.arch, + } + + def print_config(self, indent: str = " "): + print(f"{indent}GroupedConvKernelConfig:") + print(f"{indent} Variant: {self.variant} {self.ndim_spatial}D") + print(f"{indent} Dtype: {self.dtype}") + print(f"{indent} 
Layout: {self.layout}") + print(f"{indent} Arch: {self.arch}") + print(f"{indent} Tile: {self.tile_str}") + print(f"{indent} Wave: {self.wave_str}") + print(f"{indent} Warp: {self.warp_str}") + print(f"{indent} Pipeline: {self.pipeline}/{self.scheduler}/{self.epilogue}") + print(f"{indent} VecSizes: {self.vec_str}") + print( + f"{indent} BlockCU: {self.block_per_cu} WaveGroups: {self.num_wave_groups} MergeGroups: {self.num_groups_to_merge}" + ) + + +# ============================================================================= +# GroupedConvProblem +# ============================================================================= + + +@dataclass +class GroupedConvProblem: + """Runtime convolution problem specification. + + Describes the actual sizes of a convolution to be computed. + Matches the old ConvProblem from conv_utils.py. + """ + + N: int = 1 + C: int = 64 + K: int = 128 + G: int = 1 + + Hi: int = 28 + Wi: int = 28 + Di: int = 1 + + Y: int = 3 + X: int = 3 + Z: int = 1 + + stride_h: int = 1 + stride_w: int = 1 + stride_d: int = 1 + + pad_h: int = 0 + pad_w: int = 0 + pad_d: int = 0 + + dilation_h: int = 1 + dilation_w: int = 1 + dilation_d: int = 1 + + direction: str = "forward" + split_k: int = 1 + + @property + def Ho(self) -> int: + eff_y = (self.Y - 1) * self.dilation_h + 1 + return (self.Hi + 2 * self.pad_h - eff_y) // self.stride_h + 1 + + @property + def Wo(self) -> int: + eff_x = (self.X - 1) * self.dilation_w + 1 + return (self.Wi + 2 * self.pad_w - eff_x) // self.stride_w + 1 + + @property + def Do(self) -> int: + eff_z = (self.Z - 1) * self.dilation_d + 1 + return (self.Di + 2 * self.pad_d - eff_z) // self.stride_d + 1 + + @property + def is_3d(self) -> bool: + return self.Di > 1 or self.Z > 1 or self.pad_d > 0 + + @property + def ndim_spatial(self) -> int: + return 3 if self.is_3d else 2 + + @property + def flops(self) -> float: + """Total FLOPs for this convolution (any direction, same count).""" + c_per_group = self.C // self.G + if self.is_3d: + return ( + 2.0 + * self.N + * self.K + * self.Do + * self.Ho + * self.Wo + * c_per_group + * self.Z + * self.Y + * self.X + ) + return 2.0 * self.N * self.K * self.Ho * self.Wo * c_per_group * self.Y * self.X + + @property + def gflops(self) -> float: + return self.flops / 1e9 + + def input_shape(self) -> tuple: + """NHWGC or NDHWGC layout.""" + c_per_g = self.C // self.G + if self.is_3d: + return (self.N, self.Di, self.Hi, self.Wi, self.G, c_per_g) + return (self.N, self.Hi, self.Wi, self.G, c_per_g) + + def weight_shape(self) -> tuple: + """GKYXC or GKZYXC layout.""" + c_per_g = self.C // self.G + k_per_g = self.K // self.G + if self.is_3d: + return (self.G, k_per_g, self.Z, self.Y, self.X, c_per_g) + return (self.G, k_per_g, self.Y, self.X, c_per_g) + + def output_shape(self) -> tuple: + """NHWGK or NDHWGK layout.""" + k_per_g = self.K // self.G + if self.is_3d: + return (self.N, self.Do, self.Ho, self.Wo, self.G, k_per_g) + return (self.N, self.Ho, self.Wo, self.G, k_per_g) + + def print_problem(self, indent: str = " "): + dim_str = "3D" if self.is_3d else "2D" + print(f"{indent}GroupedConvProblem ({dim_str} {self.direction}):") + print(f"{indent} Batch: N={self.N}, G={self.G}") + print(f"{indent} Channels: C={self.C}, K={self.K}") + if self.is_3d: + print(f"{indent} Input: Di={self.Di}, Hi={self.Hi}, Wi={self.Wi}") + print(f"{indent} Filter: Z={self.Z}, Y={self.Y}, X={self.X}") + print(f"{indent} Output: Do={self.Do}, Ho={self.Ho}, Wo={self.Wo}") + else: + print(f"{indent} Input: Hi={self.Hi}, Wi={self.Wi}") + 
print(f"{indent} Filter: Y={self.Y}, X={self.X}") + print(f"{indent} Output: Ho={self.Ho}, Wo={self.Wo}") + print(f"{indent} GFLOPs: {self.gflops:.2f}") + + +# ============================================================================= +# GroupedConvProblemC (ctypes struct matching C++) +# ============================================================================= + + +class GroupedConvProblemC(ctypes.Structure): + """C structure matching ConvProblemC in conv_ctypes_lib.cpp.""" + + _fields_ = [ + ("N", ctypes.c_int), + ("G", ctypes.c_int), + ("C", ctypes.c_int), + ("K", ctypes.c_int), + ("input_d", ctypes.c_int), + ("input_h", ctypes.c_int), + ("input_w", ctypes.c_int), + ("filter_z", ctypes.c_int), + ("filter_y", ctypes.c_int), + ("filter_x", ctypes.c_int), + ("stride_d", ctypes.c_int), + ("stride_h", ctypes.c_int), + ("stride_w", ctypes.c_int), + ("pad_d", ctypes.c_int), + ("pad_h", ctypes.c_int), + ("pad_w", ctypes.c_int), + ("dilation_d", ctypes.c_int), + ("dilation_h", ctypes.c_int), + ("dilation_w", ctypes.c_int), + ("direction", ctypes.c_int), + ("split_k", ctypes.c_int), + ] + + @classmethod + def from_problem(cls, p: GroupedConvProblem) -> "GroupedConvProblemC": + c = cls() + c.N, c.G, c.C, c.K = p.N, p.G, p.C, p.K + c.input_d, c.input_h, c.input_w = p.Di, p.Hi, p.Wi + c.filter_z, c.filter_y, c.filter_x = p.Z, p.Y, p.X + c.stride_d, c.stride_h, c.stride_w = p.stride_d, p.stride_h, p.stride_w + c.pad_d, c.pad_h, c.pad_w = p.pad_d, p.pad_h, p.pad_w + c.dilation_d, c.dilation_h, c.dilation_w = ( + p.dilation_d, + p.dilation_h, + p.dilation_w, + ) + c.direction = DIRECTION_MAP.get(p.direction, 0) + c.split_k = getattr(p, "split_k", 1) + return c + + +# ============================================================================= +# GroupedConvResult +# ============================================================================= + + +@dataclass +class GroupedConvResult: + """Result of GPU convolution execution.""" + + success: bool = False + time_ms: float = 0.0 + tflops: float = 0.0 + output: Optional[np.ndarray] = None + error: str = "" + + +# ============================================================================= +# GroupedConvDispatcherLib +# ============================================================================= + + +class GroupedConvDispatcherLib: + """Wrapper for the compiled convolution dispatcher library. + + Provides Python interface to the C API in conv_ctypes_lib.cpp. 
+ """ + + SEARCH_PATHS = [ + "build/examples/libdispatcher_conv_lib.so", + "build/bindings/libdispatcher_conv_lib.so", + "build/lib/libdispatcher_conv_lib.so", + ] + + def __init__(self, lib: ctypes.CDLL, path: Path): + self._lib = lib + self._path = path + self._setup_functions() + + def _setup_functions(self): + self._lib.conv_dispatcher_init.argtypes = [] + self._lib.conv_dispatcher_init.restype = ctypes.c_int + self._lib.conv_dispatcher_cleanup.argtypes = [] + self._lib.conv_dispatcher_cleanup.restype = ctypes.c_int + self._lib.conv_dispatcher_version.argtypes = [] + self._lib.conv_dispatcher_version.restype = ctypes.c_char_p + self._lib.conv_dispatcher_has_kernels.argtypes = [] + self._lib.conv_dispatcher_has_kernels.restype = ctypes.c_int + self._lib.conv_dispatcher_has_bwd_data.argtypes = [] + self._lib.conv_dispatcher_has_bwd_data.restype = ctypes.c_int + self._lib.conv_dispatcher_has_bwd_weight.argtypes = [] + self._lib.conv_dispatcher_has_bwd_weight.restype = ctypes.c_int + self._lib.conv_dispatcher_get_kernel_count.argtypes = [] + self._lib.conv_dispatcher_get_kernel_count.restype = ctypes.c_int + self._lib.conv_dispatcher_get_kernel_name.argtypes = [ + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_int, + ] + self._lib.conv_dispatcher_get_kernel_name.restype = ctypes.c_int + self._lib.conv_dispatcher_is_supported.argtypes = [ + ctypes.POINTER(GroupedConvProblemC), + ] + self._lib.conv_dispatcher_is_supported.restype = ctypes.c_int + self._lib.conv_dispatcher_run.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.POINTER(GroupedConvProblemC), + ctypes.c_void_p, + ] + self._lib.conv_dispatcher_run.restype = ctypes.c_float + + @classmethod + def find(cls) -> Optional["GroupedConvDispatcherLib"]: + """Search standard paths for the conv library.""" + root = Path(__file__).parent.parent + for rel in cls.SEARCH_PATHS: + path = root / rel + if path.exists(): + try: + lib = ctypes.CDLL(str(path)) + return cls(lib, path) + except OSError: + continue + return None + + @property + def path(self) -> Path: + return self._path + + def initialize(self): + self._lib.conv_dispatcher_init() + + def cleanup(self): + self._lib.conv_dispatcher_cleanup() + + def version(self) -> str: + return self._lib.conv_dispatcher_version().decode() + + def has_forward(self) -> bool: + return self._lib.conv_dispatcher_has_kernels() != 0 + + def has_bwd_data(self) -> bool: + return self._lib.conv_dispatcher_has_bwd_data() != 0 + + def has_bwd_weight(self) -> bool: + return self._lib.conv_dispatcher_has_bwd_weight() != 0 + + def kernel_count(self) -> int: + return self._lib.conv_dispatcher_get_kernel_count() + + def kernel_names(self) -> List[str]: + names = [] + for i in range(self.kernel_count()): + buf = ctypes.create_string_buffer(256) + if self._lib.conv_dispatcher_get_kernel_name(i, buf, 256) == 0: + names.append(buf.value.decode()) + return names + + def is_supported(self, problem: GroupedConvProblem) -> bool: + pc = GroupedConvProblemC.from_problem(problem) + return self._lib.conv_dispatcher_is_supported(ctypes.byref(pc)) != 0 + + def run( + self, a_ptr: int, b_ptr: int, c_ptr: int, problem: GroupedConvProblem + ) -> float: + """Run convolution. 
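# Illustrative sketch, not part of the patch: probing the compiled conv library
# before launching anything through the high-level runner.
lib = GroupedConvDispatcherLib.find()
if lib is not None:
    lib.initialize()
    print(lib.version(), "kernels:", lib.kernel_count())
    print("forward support built in:", lib.has_forward())
    print("problem supported:",
          lib.is_supported(GroupedConvProblem(N=1, C=64, K=128, Hi=28, Wi=28)))
    lib.cleanup()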
Returns time_ms (>0 success, <0 error).""" + pc = GroupedConvProblemC.from_problem(problem) + return self._lib.conv_dispatcher_run( + a_ptr, b_ptr, c_ptr, ctypes.byref(pc), None + ) + + +# ============================================================================= +# GpuGroupedConvRunner +# ============================================================================= + + +class GpuGroupedConvRunner: + """High-level GPU convolution runner. + + Handles library loading, HIP memory management, and kernel execution. + Follows the same pattern as the old GpuConvRunner from conv_utils.py. + + Usage: + runner = GpuGroupedConvRunner() + if runner.is_available(): + result = runner.run(input_np, weight_np, problem) + print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}") + """ + + HIP_MEMCPY_H2D = 1 + HIP_MEMCPY_D2H = 2 + + def __init__(self, lib_path: Optional[str] = None): + self._dispatch_lib: Optional[GroupedConvDispatcherLib] = None + self._hip = None + self._initialized = False + + try: + if lib_path: + lib = ctypes.CDLL(lib_path) + self._dispatch_lib = GroupedConvDispatcherLib(lib, Path(lib_path)) + else: + self._dispatch_lib = GroupedConvDispatcherLib.find() + + if self._dispatch_lib is None: + return + + self._hip = ctypes.CDLL("libamdhip64.so") + self._hip.hipMalloc.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_size_t, + ] + self._hip.hipMalloc.restype = ctypes.c_int + self._hip.hipFree.argtypes = [ctypes.c_void_p] + self._hip.hipFree.restype = ctypes.c_int + self._hip.hipMemcpy.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ] + self._hip.hipMemcpy.restype = ctypes.c_int + self._hip.hipDeviceSynchronize.argtypes = [] + self._hip.hipDeviceSynchronize.restype = ctypes.c_int + + self._dispatch_lib.initialize() + self._initialized = True + except Exception: + self._initialized = False + + def is_available(self) -> bool: + return self._initialized and self._dispatch_lib is not None + + @property + def library_path(self) -> Optional[str]: + if self._dispatch_lib: + return str(self._dispatch_lib.path) + return None + + @property + def lib(self) -> Optional[GroupedConvDispatcherLib]: + return self._dispatch_lib + + def run( + self, + input_np: np.ndarray, + weight_np: np.ndarray, + problem: GroupedConvProblem, + output_np: Optional[np.ndarray] = None, + ) -> GroupedConvResult: + """Run convolution on GPU. + + Args: + input_np: For forward: X (NHWGC). For bwd_data: dY. For bwd_weight: X. + weight_np: For forward: W (GKYXC). For bwd_data: W. For bwd_weight: dY. + problem: Problem specification. + output_np: Optional pre-allocated output buffer. + + Returns: + GroupedConvResult with success, time_ms, tflops, output. 
+ """ + if not self.is_available(): + return GroupedConvResult(error="GPU not available") + + try: + # Determine output shape based on direction + d = problem.direction + if d == "bwd_data": + out_shape = problem.input_shape() + elif d == "bwd_weight": + out_shape = problem.weight_shape() + else: + out_shape = problem.output_shape() + + if output_np is None: + output_np = np.zeros(out_shape, dtype=input_np.dtype) + + output_size = output_np.nbytes + + # Allocate GPU memory + d_a, d_b, d_c = ctypes.c_void_p(), ctypes.c_void_p(), ctypes.c_void_p() + self._hip.hipMalloc(ctypes.byref(d_a), input_np.nbytes) + self._hip.hipMalloc(ctypes.byref(d_b), weight_np.nbytes) + self._hip.hipMalloc(ctypes.byref(d_c), output_size) + + # Host to device + self._hip.hipMemcpy( + d_a, input_np.ctypes.data, input_np.nbytes, self.HIP_MEMCPY_H2D + ) + self._hip.hipMemcpy( + d_b, weight_np.ctypes.data, weight_np.nbytes, self.HIP_MEMCPY_H2D + ) + self._hip.hipDeviceSynchronize() + + # Launch kernel + time_ms = self._dispatch_lib.run(d_a.value, d_b.value, d_c.value, problem) + self._hip.hipDeviceSynchronize() + + result = GroupedConvResult() + + if time_ms > 0: + # Device to host + self._hip.hipMemcpy( + output_np.ctypes.data, d_c, output_size, self.HIP_MEMCPY_D2H + ) + self._hip.hipDeviceSynchronize() + result.success = True + result.time_ms = time_ms + result.tflops = problem.flops / (time_ms * 1e9) + result.output = output_np + else: + result.error = ( + "unsupported" + if time_ms == -3.0 + else "no kernel" + if time_ms == -2.0 + else f"error (code {time_ms})" + ) + + # Free GPU memory + self._hip.hipFree(d_a) + self._hip.hipFree(d_b) + self._hip.hipFree(d_c) + + return result + + except Exception as e: + return GroupedConvResult(error=str(e)) + + def cleanup(self): + if self._dispatch_lib: + try: + self._dispatch_lib.cleanup() + except Exception: + pass + + +# ============================================================================= +# GroupedConvRegistry +# ============================================================================= + + +class GroupedConvRegistry: + """Collection of grouped conv kernel configs with JSON export/import.""" + + def __init__(self, name: str = "default"): + self.name = name + self._kernels: List[GroupedConvKernelConfig] = [] + + def add(self, config: GroupedConvKernelConfig): + self._kernels.append(config) + + @property + def kernels(self) -> List[GroupedConvKernelConfig]: + return list(self._kernels) + + def __len__(self) -> int: + return len(self._kernels) + + def select( + self, problem: "GroupedConvProblem", heuristic=None + ) -> Optional[GroupedConvKernelConfig]: + """Select the best kernel for a problem. + + Args: + problem: The convolution problem. + heuristic: Optional callable(problem) -> List[str] returning + ranked kernel name substrings. The registry tries + each in order; falls back to first matching kernel. + + Returns: + The best matching GroupedConvKernelConfig, or None. 
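# Illustrative sketch, not part of the patch: select() takes an optional ranking
# heuristic that returns kernel-name substrings in preference order. The tile
# policy below is an assumption for demonstration, not a shipped heuristic.
reg = GroupedConvRegistry("demo")
reg.add(GroupedConvKernelConfig(variant="forward", tile_n=128, tile_k=128))
reg.add(GroupedConvKernelConfig(variant="forward", tile_n=256, tile_k=64))

def prefer_wide_n(problem):
    return ["256x64"] if problem.Ho * problem.Wo >= 56 * 56 else ["128x128"]

cfg = reg.select(GroupedConvProblem(Hi=56, Wi=56, Y=3, X=3, pad_h=1, pad_w=1),
                 heuristic=prefer_wide_n)
print(cfg.name if cfg else "no kernel for this variant")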
+ """ + matching = [k for k in self._kernels if k.variant == problem.direction] + if not matching: + return None + + if heuristic is not None: + ranked = heuristic(problem) + for hint in ranked: + for k in matching: + if hint in k.name: + return k + + return matching[0] if matching else None + + def filter_by_variant(self, variant: str) -> "GroupedConvRegistry": + variant = _resolve_variant(variant) + reg = GroupedConvRegistry(f"{self.name}_{variant}") + for k in self._kernels: + if k.variant == variant: + reg.add(k) + return reg + + def filter_by_arch(self, arch: str) -> "GroupedConvRegistry": + reg = GroupedConvRegistry(f"{self.name}_{arch}") + for k in self._kernels: + if k.arch == arch: + reg.add(k) + return reg + + def to_json(self, indent: int = 2) -> str: + return json.dumps( + { + "name": self.name, + "kernels": [k.to_json_obj() for k in self._kernels], + }, + indent=indent, + ) + + @classmethod + def from_json(cls, json_str: str) -> "GroupedConvRegistry": + data = json.loads(json_str) + reg = cls(data.get("name", "imported")) + for kd in data.get("kernels", []): + sig = kd.get("signature", {}) + algo = kd.get("algorithm", {}) + wave = algo.get("wave", "2x2x1").split("x") + warp = algo.get("warp", "32x32x16").split("x") + vec = algo.get("vector_sizes", [4, 8, 8]) + reg.add( + GroupedConvKernelConfig( + variant=sig.get("variant", "forward"), + ndim_spatial=sig.get("ndim_spatial", 2), + dtype=sig.get("dtype", "fp16"), + layout=sig.get("layout", "nhwgc"), + arch=kd.get("arch", "gfx942"), + tile_m=algo.get("tile_m", 1), + tile_n=algo.get("tile_n", 128), + tile_k=algo.get("tile_k", 128), + wave_m=int(wave[0]), + wave_n=int(wave[1]), + wave_k=int(wave[2]), + warp_tile_m=int(warp[0]), + warp_tile_n=int(warp[1]), + warp_tile_k=int(warp[2]), + pipeline=algo.get("pipeline", "compv3"), + epilogue=algo.get("epilogue", "cshuffle"), + scheduler=algo.get("scheduler", "intrawave"), + vector_size_a=vec[0] if len(vec) > 0 else 4, + vector_size_b=vec[1] if len(vec) > 1 else 8, + vector_size_c=vec[2] if len(vec) > 2 else 8, + block_per_cu=algo.get("block_per_cu", 1), + num_wave_groups=algo.get("num_wave_groups", 1), + num_groups_to_merge=algo.get("num_groups_to_merge", 1), + ) + ) + return reg + + def build( + self, + verbose: bool = False, + max_workers: Optional[int] = None, + ) -> Dict[Tuple[str, int], "GpuGroupedConvRunner"]: + """Parallel JIT compile all kernels in this registry. + + Args: + verbose: Print progress during build. + max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8). + + Returns a dict mapping (variant, ndim_spatial) to a ready-to-use + GpuGroupedConvRunner. 
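# Illustrative sketch, not part of the patch: registries round-trip through JSON,
# so a tuned kernel set can be exported, versioned, and re-imported later.
reg = GroupedConvRegistry("tuned_fwd")
reg.add(GroupedConvKernelConfig(variant="forward", dtype="fp16", arch="gfx942"))
blob = reg.to_json()
restored = GroupedConvRegistry.from_json(blob)
assert len(restored) == len(reg)
assert restored.kernels[0].name == reg.kernels[0].name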
+ """ + if not self._kernels: + return {} + + libs = setup_multiple_grouped_conv_dispatchers( + self._kernels, + verbose=verbose, + max_workers=max_workers, + ) + + runners: Dict[Tuple[str, int], GpuGroupedConvRunner] = {} + for cfg, lib in zip(self._kernels, libs): + if lib is None: + continue + key = (cfg.variant, cfg.ndim_spatial) + if key in runners: + continue + runner = GpuGroupedConvRunner(lib_path=str(lib.path)) + if runner.is_available(): + runners[key] = runner + return runners + + def print_registry(self, indent: str = " "): + print(f"{indent}Registry '{self.name}': {len(self)} kernels") + for i, k in enumerate(self._kernels): + print( + f"{indent} [{i}] {k.name} (valid={validate_grouped_conv_config(k.to_dict()).is_valid})" + ) + + +# ============================================================================= +# GroupedConvValidationResult +# ============================================================================= + + +@dataclass +class GroupedConvValidationResult(ValidationResultBase): + """Result of grouped conv kernel config validation.""" + + variant: str = "forward" + + def __init__( + self, + is_valid=True, + errors=None, + warnings=None, + suggested_fixes=None, + variant="forward", + ): + super().__init__( + is_valid=is_valid, + errors=errors or [], + warnings=warnings or [], + suggested_fixes=suggested_fixes or {}, + ) + self.variant = variant + + +# ============================================================================= +# Validation helpers (extracted from the original config extraction code) +# ============================================================================= + + +def _first(val): + if isinstance(val, list) and len(val) > 0: + return val[0] + return val + + +def _get_tile_config(config: dict) -> dict: + return config.get("tile_config") or {} + + +def _get_trait_config(config: dict) -> dict: + return config.get("trait_config") or {} + + +def _extract_wave_config(tile_config: dict) -> List[int]: + wm = tile_config.get("wave_m") or tile_config.get("warp_m") + wn = tile_config.get("wave_n") or tile_config.get("warp_n") + wk = tile_config.get("wave_k") or tile_config.get("warp_k") + if wm is not None and wn is not None and wk is not None: + return [_first(wm), _first(wn), _first(wk)] + return [2, 2, 1] + + +def _extract_warp_tile_config(tile_config: dict) -> List[int]: + wtm = tile_config.get("warp_tile_m") or tile_config.get("warp_m") + wtn = tile_config.get("warp_tile_n") or tile_config.get("warp_n") + wtk = tile_config.get("warp_tile_k") or tile_config.get("warp_k") + if wtm is not None and wtn is not None and wtk is not None: + return [_first(wtm), _first(wtn), _first(wtk)] + return [32, 32, 16] + + +def _extract_trait_values(trait_config: dict) -> Tuple[str, str, str]: + p = _first(trait_config.get("pipeline", "compv4")) + e = _first(trait_config.get("epilogue", "cshuffle")) + s = _first(trait_config.get("scheduler", "intrawave")) + if isinstance(p, list): + p = p[0] if p else "compv4" + if isinstance(e, list): + e = e[0] if e else "cshuffle" + if isinstance(s, list): + s = s[0] if s else "intrawave" + return (str(p), str(e), str(s)) + + +# ============================================================================= +# validate_grouped_conv_config / auto_correct_grouped_conv_config +# ============================================================================= + + +def validate_grouped_conv_config(config: dict) -> GroupedConvValidationResult: + """Validate a grouped conv kernel config dict. 
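# Illustrative sketch, not part of the patch: dataclass configs validate through
# the same dict-based path used for raw codegen configs; __post_init__ already
# forces backward variants onto a supported pipeline before validation runs.
cfg = GroupedConvKernelConfig(variant="bwd_data", pipeline="compv4")
assert cfg.pipeline == "compv3"          # corrected by __post_init__
res = validate_grouped_conv_config(cfg.to_dict())
res.print_result()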
+ + Accepts either a raw dict (legacy) or GroupedConvKernelConfig.to_dict() output. + """ + errors: List[str] = [] + warnings: List[str] = [] + suggested_fixes: Dict[str, Any] = {} + + required = ( + "tile_config", + "trait_config", + "variant", + "ndim_spatial", + "arch", + "layout", + ) + for key in required: + if key not in config: + errors.append(f"Missing required key: {key}") + if errors: + return GroupedConvValidationResult( + is_valid=False, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + variant=config.get("variant", "forward"), + ) + + tile_config = _get_tile_config(config) + trait_config = _get_trait_config(config) + variant = _first(config.get("variant", "forward")) + if isinstance(variant, list): + variant = variant[0] if variant else "forward" + variant = _resolve_variant(str(variant)) + + ndim_spatial = config.get("ndim_spatial") + arch = config.get("arch", "gfx942") + dtype = config.get("dtype", "fp16") + + if variant not in VALID_VARIANTS: + errors.append(f"Invalid variant: {variant}. Valid: {', '.join(VALID_VARIANTS)}") + suggested_fixes["variant"] = "forward" + + if ndim_spatial is not None: + ndim = ndim_spatial + if isinstance(ndim, list): + ndim = ndim[0] if ndim else 2 + if ndim not in VALID_NDIM_SPATIAL: + errors.append( + f"Invalid ndim_spatial: {ndim}. Valid: {', '.join(map(str, VALID_NDIM_SPATIAL))}" + ) + suggested_fixes["ndim_spatial"] = 2 + + pipeline, epilogue, scheduler = _extract_trait_values(trait_config) + if variant in BACKWARD_VARIANTS and pipeline not in BACKWARD_PIPELINES: + errors.append( + f"Backward variant '{variant}' requires pipeline compv3 or mem, got {pipeline}" + ) + suggested_fixes["pipeline"] = "compv3" + + ok, msg = validate_trait_combo(pipeline, epilogue, scheduler) + if not ok: + errors.append(msg) + suggested_fixes["scheduler"] = "intrawave" + + wave_cfg = _extract_wave_config(tile_config) + ok, msg = validate_wave_config(wave_cfg, arch) + if not ok: + errors.append(msg) + arch_data = get_arch_filter_data() + valid_waves = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + if valid_waves: + suggested_fixes["wave_m"] = valid_waves[0][0] + suggested_fixes["wave_n"] = valid_waves[0][1] + suggested_fixes["wave_k"] = valid_waves[0][2] + + warp_cfg = _extract_warp_tile_config(tile_config) + ok, msg = validate_warp_tile_config(warp_cfg, arch, dtype) + if not ok: + errors.append(msg) + arch_data = get_arch_filter_data() + acc = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc}" + valid_tiles = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + if valid_tiles: + suggested_fixes["warp_tile_m"] = valid_tiles[0][0] + suggested_fixes["warp_tile_n"] = valid_tiles[0][1] + suggested_fixes["warp_tile_k"] = valid_tiles[0][2] + + arch_data = get_arch_filter_data() + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}. Supported: {', '.join(arch_data['supported_archs'])}" + ) + + return GroupedConvValidationResult( + is_valid=len(errors) == 0, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + variant=variant, + ) + + +def auto_correct_grouped_conv_config( + config: dict, +) -> Tuple[dict, GroupedConvValidationResult]: + """Auto-correct invalid grouped conv config. 
Returns (corrected, result).""" + result = validate_grouped_conv_config(config) + corrected = copy.deepcopy(config) + + if result.is_valid: + return corrected, result + + tile_config = corrected.setdefault("tile_config", {}) + trait_config = corrected.setdefault("trait_config", {}) + + wave_cfg = _extract_wave_config(tile_config) + arch = config.get("arch", "gfx942") + fixed_wave = auto_correct_wave(wave_cfg, arch) + tile_config["wave_m"] = fixed_wave[0] + tile_config["wave_n"] = fixed_wave[1] + tile_config["wave_k"] = fixed_wave[2] + + pipeline, epilogue, scheduler = _extract_trait_values(trait_config) + fixed_pipeline, fixed_scheduler = auto_correct_trait(pipeline, scheduler) + trait_config["pipeline"] = fixed_pipeline + trait_config["scheduler"] = fixed_scheduler + + variant = _first(config.get("variant", "forward")) + if isinstance(variant, list): + variant = variant[0] if variant else "forward" + variant = _resolve_variant(str(variant)) + if variant in BACKWARD_VARIANTS and fixed_pipeline not in BACKWARD_PIPELINES: + trait_config["pipeline"] = "compv3" + + if "warp_tile_m" in result.suggested_fixes: + tile_config["warp_tile_m"] = result.suggested_fixes["warp_tile_m"] + tile_config["warp_tile_n"] = result.suggested_fixes["warp_tile_n"] + tile_config["warp_tile_k"] = result.suggested_fixes["warp_tile_k"] + + result = validate_grouped_conv_config(corrected) + return corrected, result + + +def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]: + """Run one hipcc compile+link job in a subprocess worker.""" + import subprocess + from pathlib import Path + + compile_cmd = args["compile_cmd"] + link_cmd = args["link_cmd"] + lib_path = Path(args["lib_path"]) + + try: + res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300) + if res_c.returncode != 0: + return False, None, f"Compile failed: {res_c.stderr[:400]}" + + res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300) + if res_l.returncode != 0: + return False, None, f"Link failed: {res_l.stderr[:400]}" + + return True, lib_path, "" + except subprocess.TimeoutExpired: + return False, None, "Timeout" + except Exception as e: + return False, None, f"Error: {e}" + + +def _run_conv_codegen_subprocess(args: dict) -> Tuple[bool, Optional[str], str]: + """Run grouped-conv codegen once and return generated kernel header path.""" + import subprocess + from pathlib import Path + + out_dir = Path(args["output_dir"]) + out_dir.mkdir(parents=True, exist_ok=True) + + # Remove stale kernels so header discovery is exact for this invocation. 
+ for stale in out_dir.glob("grouped_conv_*.hpp"): + stale.unlink(missing_ok=True) + for stale in out_dir.glob("include_all_grouped_conv_*.hpp"): + stale.unlink(missing_ok=True) + + try: + res = subprocess.run(args["cmd"], capture_output=True, text=True, timeout=300) + if res.returncode != 0: + err = (res.stderr or res.stdout or "").strip()[:500] + return False, None, f"Codegen failed: {err}" + + generated = sorted( + out_dir.glob("grouped_conv_*.hpp"), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + if not generated: + return False, None, "Codegen produced no grouped_conv_*.hpp header" + + return True, str(generated[0]), "" + except subprocess.TimeoutExpired: + return False, None, "Codegen timed out" + except Exception as e: + return False, None, f"Codegen error: {e}" + + +def _config_key(c: GroupedConvKernelConfig) -> Tuple[Any, ...]: + return ( + c.variant, + c.ndim_spatial, + c.dtype, + c.layout, + c.arch, + c.tile_m, + c.tile_n, + c.tile_k, + c.wave_m, + c.wave_n, + c.wave_k, + c.warp_tile_m, + c.warp_tile_n, + c.warp_tile_k, + c.pipeline, + c.epilogue, + c.scheduler, + ) + + +def _parse_triplet(value: str) -> Tuple[int, int, int]: + parts = value.split("x") + if len(parts) != 3: + raise ValueError(f"Invalid triplet: {value}") + return int(parts[0]), int(parts[1]), int(parts[2]) + + +def _list_arch_valid_grouped_conv_configs( + codegen_script: Path, + arch: str, + dtype: str, + variant: str, + ndim_spatial: int, +) -> List[GroupedConvKernelConfig]: + """Query codegen defaults for this (arch, dtype, variant, ndim) tuple.""" + import re + import sys + + cmd = [ + sys.executable, + str(codegen_script), + "--list-configs", + "--arch", + arch, + "--datatype", + dtype, + "--variant", + variant, + "--ndim", + str(ndim_spatial), + ] + res = subprocess.run(cmd, capture_output=True, text=True, timeout=180) + if res.returncode != 0: + return [] + + # Example: + # grouped_conv_fwd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1_32x32x16 + name_re = re.compile( + r"^grouped_conv_(fwd|bwd_data|bwd_weight|bwdd|bwdw)_([a-z0-9]+)_([a-z0-9]+)_([123])d_" + r"([a-z0-9]+)_([a-z0-9]+)_([a-z0-9]+)_" + r"([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)" + r"(?:_.*)?$" + ) + short_to_variant = { + "fwd": "forward", + "bwd_data": "bwd_data", + "bwd_weight": "bwd_weight", + "bwdd": "bwd_data", + "bwdw": "bwd_weight", + } + + out: List[GroupedConvKernelConfig] = [] + seen = set() + for raw in res.stdout.splitlines(): + line = raw.strip() + if not line.startswith("- grouped_conv_"): + continue + name = line[2:].strip() + m = name_re.match(name) + if not m: + continue + + v_short, dt, layout, ndim, pipe, epi, sched, tile_s, wave_s, warp_s = m.groups() + tm, tn, tk = _parse_triplet(tile_s) + wm, wn, wk = _parse_triplet(wave_s) + wtm, wtn, wtk = _parse_triplet(warp_s) + + cfg = GroupedConvKernelConfig( + variant=short_to_variant[v_short], + ndim_spatial=int(ndim), + dtype=dt, + layout=layout, + arch=arch, + tile_m=tm, + tile_n=tn, + tile_k=tk, + wave_m=wm, + wave_n=wn, + wave_k=wk, + warp_tile_m=wtm, + warp_tile_n=wtn, + warp_tile_k=wtk, + pipeline=pipe, + epilogue=epi, + scheduler=sched, + ) + key = _config_key(cfg) + if key not in seen: + out.append(cfg) + seen.add(key) + + return out + + +def _select_best_arch_valid_conv_config( + requested: GroupedConvKernelConfig, + candidates: List[GroupedConvKernelConfig], +) -> GroupedConvKernelConfig: + """Pick nearest arch-valid config while preferring trait exact matches.""" + + def score(c: GroupedConvKernelConfig) -> Tuple[int, int, 
int, int, int, int]: + tile_delta = ( + abs(c.tile_m - requested.tile_m) + + abs(c.tile_n - requested.tile_n) + + abs(c.tile_k - requested.tile_k) + ) + wave_delta = ( + abs(c.wave_m - requested.wave_m) + + abs(c.wave_n - requested.wave_n) + + abs(c.wave_k - requested.wave_k) + ) + warp_tile_delta = ( + abs(c.warp_tile_m - requested.warp_tile_m) + + abs(c.warp_tile_n - requested.warp_tile_n) + + abs(c.warp_tile_k - requested.warp_tile_k) + ) + return ( + 0 if c.pipeline == requested.pipeline else 1, + 0 if c.scheduler == requested.scheduler else 1, + 0 if c.epilogue == requested.epilogue else 1, + tile_delta, + wave_delta, + warp_tile_delta, + ) + + best = min(candidates, key=score) + selected = copy.deepcopy(best) + selected.arch = requested.arch + return selected + + +def _write_single_conv_dispatch_header( + config: GroupedConvKernelConfig, + kernel_header: Path, + dispatch_header: Path, +) -> None: + """Create a tiny dispatch header consumed by conv_ctypes_lib.cpp.""" + macros: List[str] = [] + aliases: List[str] = [] + + if config.variant == "forward": + kernel_name_symbol = "CONV_FWD_KERNEL_NAME" + if config.ndim_spatial == 3: + macros.append("#define CONV_FWD_3D_AVAILABLE 1") + aliases.append("using ConvFwd3dLauncher = SelectedConvKernelLauncher;") + else: + macros.append("#define CONV_FWD_2D_AVAILABLE 1") + elif config.variant == "bwd_data": + kernel_name_symbol = "CONV_BWD_DATA_KERNEL_NAME" + if config.ndim_spatial == 3: + macros.append("#define CONV_BWD_DATA_3D_AVAILABLE 1") + aliases.append("using ConvBwdData3dLauncher = SelectedConvBwdDataLauncher;") + else: + macros.append("#define CONV_BWD_DATA_2D_AVAILABLE 1") + else: + kernel_name_symbol = "CONV_BWD_WEIGHT_KERNEL_NAME" + if config.ndim_spatial == 3: + macros.append("#define CONV_BWD_WEIGHT_3D_AVAILABLE 1") + aliases.append( + "using ConvBwdWeight3dLauncher = SelectedConvBwdWeightLauncher;" + ) + else: + macros.append("#define CONV_BWD_WEIGHT_2D_AVAILABLE 1") + + content = ( + "// Auto-generated single-kernel dispatch header for Python JIT\n" + "#pragma once\n\n" + f'#include "{kernel_header.name}"\n\n' + + "\n".join(macros) + + "\n\n" + + "\n".join(aliases) + + "\n\n" + + f"static const char* CONV_KERNEL_NAMES[] = {{{kernel_name_symbol}}};\n" + + "static constexpr int CONV_KERNEL_COUNT = 1;\n" + ) + dispatch_header.write_text(content) + + +class GroupedConvCodegenRunner: + """Generate and compile grouped-conv JIT libraries in parallel.""" + + def __init__(self, max_workers: Optional[int] = None): + import multiprocessing + + self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + self.root = Path(__file__).parent.parent + self.build_dir = self.root / "build" + self.codegen_script = self.root / "codegen" / "unified_grouped_conv_codegen.py" + + def generate_and_compile_parallel( + self, + configs: List[GroupedConvKernelConfig], + verbose: bool = True, + ) -> List[Optional[Path]]: + import sys + from concurrent.futures import ProcessPoolExecutor, as_completed + + if not configs: + return [] + + if not self.build_dir.exists(): + self.build_dir.mkdir(parents=True, exist_ok=True) + + ctypes_source = self.root / "bindings" / "ctypes" / "conv_ctypes_lib.cpp" + static_lib = self.build_dir / "libck_tile_dispatcher.a" + jit_root = self.build_dir / "generated_kernels" / "python_jit" + jit_root.mkdir(parents=True, exist_ok=True) + (self.build_dir / "examples").mkdir(parents=True, exist_ok=True) + + if not self.codegen_script.exists(): + if verbose: + print(f"Codegen script missing: {self.codegen_script}") + return [None] 
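# A minimal, self-contained sketch of the ranking behaviour of score() above: trait
# mismatches (pipeline/scheduler/epilogue) dominate, then tile/wave/warp-tile distances
# break ties. The candidate tuples below are hypothetical stand-ins for configs.
requested = ("compv4", "intrawave", "cshuffle", 128, 128, 64)

def demo_score(cand):
    # Same lexicographic idea as score(): trait mismatches first, then tile distance.
    return (
        0 if cand[0] == requested[0] else 1,
        0 if cand[1] == requested[1] else 1,
        0 if cand[2] == requested[2] else 1,
        sum(abs(a - b) for a, b in zip(cand[3:], requested[3:])),
    )

candidates = [
    ("compv4", "intrawave", "cshuffle", 128, 128, 32),  # exact traits, tile_k off by 32
    ("compv3", "intrawave", "cshuffle", 128, 128, 64),  # exact tile, pipeline differs
]
# min() over the lexicographic tuples prefers the exact trait match with a small tile delta.
assert min(candidates, key=demo_score) == candidates[0]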
* len(configs) + if not ctypes_source.exists() or not static_lib.exists(): + if verbose: + print("Missing conv ctypes source or static dispatcher library") + return [None] * len(configs) + + if verbose: + print( + f"Generating {len(configs)} grouped-conv kernels in parallel " + f"(workers={self.max_workers})..." + ) + + gen_jobs: List[Dict[str, Any]] = [] + job_dirs: List[Path] = [] + for i, c in enumerate(configs): + cfg_dir = jit_root / f"cfg_{i}" + cfg_dir.mkdir(parents=True, exist_ok=True) + job_dirs.append(cfg_dir) + + cmd = [ + sys.executable, + str(self.codegen_script), + "--output", + str(cfg_dir), + "--datatype", + c.dtype, + "--variant", + c.variant, + "--ndim", + str(c.ndim_spatial), + "--arch", + c.arch, + "--tile-m", + str(c.tile_m), + "--tile-n", + str(c.tile_n), + "--tile-k", + str(c.tile_k), + "--warp-m", + str(c.wave_m), + "--warp-n", + str(c.wave_n), + "--warp-k", + str(c.wave_k), + "--warp-tile-m", + str(c.warp_tile_m), + "--warp-tile-n", + str(c.warp_tile_n), + "--warp-tile-k", + str(c.warp_tile_k), + "--pipeline", + c.pipeline, + "--scheduler", + c.scheduler, + "--epilogue", + c.epilogue, + ] + gen_jobs.append({"cmd": cmd, "output_dir": str(cfg_dir)}) + + generated_headers: List[Optional[Path]] = [None] * len(configs) + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_conv_codegen_subprocess, job): idx + for idx, job in enumerate(gen_jobs) + } + for future in as_completed(futures): + idx = futures[future] + ok, header_path, err = future.result() + if ok and header_path: + generated_headers[idx] = Path(header_path) + if verbose: + print(f" OK [{idx}] codegen: {Path(header_path).name}") + else: + if verbose: + print(f" FAIL [{idx}] codegen: {err}") + + if verbose: + compile_count = sum(1 for h in generated_headers if h is not None) + print( + f"Compiling {compile_count} grouped-conv libraries in parallel " + f"(workers={self.max_workers})..." 
+ ) + + compile_jobs: List[Dict[str, Any]] = [] + compile_to_input_index: Dict[int, int] = {} + for i, c in enumerate(configs): + hdr_path = generated_headers[i] + if hdr_path is None: + continue + + cfg_dir = job_dirs[i] + dispatch_header = cfg_dir / "conv_python_dispatch.hpp" + _write_single_conv_dispatch_header(c, hdr_path, dispatch_header) + + lib_name = ( + f"libdispatcher_conv_{c.variant}_{c.ndim_spatial}d_{c.dtype}_" + f"{c.tile_str}_{c.wave_str}_{c.warp_str}_{c.pipeline}_{c.scheduler}.so" + ) + lib_path = self.build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{self.root / 'include'}", + f"-I{self.root.parent / 'include'}", + f"-I{self.root.parent}", + f"-I{cfg_dir}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{dispatch_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={c.arch}", + f'-DGFX_ARCH="{c.arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={c.arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + compile_to_input_index[len(compile_jobs)] = i + compile_jobs.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + "config_name": c.name, + } + ) + + results_map: Dict[int, Optional[Path]] = {i: None for i in range(len(configs))} + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, job): j + for j, job in enumerate(compile_jobs) + } + for future in as_completed(futures): + job_idx = futures[future] + idx = compile_to_input_index[job_idx] + success, lib_path, err = future.result() + if success and lib_path: + results_map[idx] = Path(lib_path) + if verbose: + status = "OK" if success else f"FAIL ({err})" + name = ( + Path(lib_path).name + if success and lib_path + else compile_jobs[job_idx]["config_name"] + ) + print(f" {status} {name}") + + return [results_map.get(i) for i in range(len(configs))] + + +# ============================================================================= +# Convenience functions +# ============================================================================= + + +def get_grouped_conv_default_config( + variant: str = "forward", + ndim_spatial: int = 2, + arch: str = "gfx942", + dtype: str = "fp16", +) -> GroupedConvKernelConfig: + """Return a valid default GroupedConvKernelConfig.""" + return GroupedConvKernelConfig( + variant=variant, + ndim_spatial=ndim_spatial, + arch=arch, + dtype=dtype, + ) + + +def format_grouped_conv_summary(config) -> str: + """Format a config (dict or GroupedConvKernelConfig) into a human-readable string.""" + if isinstance(config, GroupedConvKernelConfig): + lines = [ + f"Grouped Conv Config: {config.variant} {config.ndim_spatial}D", + f" Arch: {config.arch}", + f" Layout: {config.layout}", + f" Dtype: {config.dtype}", + f" Tile: {config.tile_str}", + f" Wave: {config.wave_str}", + f" Warp: {config.warp_str}", + f" Traits: pipeline={config.pipeline} epilogue={config.epilogue} scheduler={config.scheduler}", + ] + return "\n".join(lines) + + # Legacy dict support + tile_config = _get_tile_config(config) if isinstance(config, dict) else {} + trait_config = _get_trait_config(config) if isinstance(config, dict) else {} + variant = config.get("variant", "?") if 
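# Usage sketch for GroupedConvCodegenRunner above, with two hypothetical configs. It assumes
# a built dispatcher tree: the runner itself checks for the codegen script, the conv ctypes
# source, and the static dispatcher library, and returns None for any config that fails.
configs = [
    get_grouped_conv_default_config(variant="forward", ndim_spatial=2, arch="gfx942", dtype="fp16"),
    get_grouped_conv_default_config(variant="bwd_data", ndim_spatial=3, arch="gfx942", dtype="bf16"),
]
runner = GroupedConvCodegenRunner(max_workers=4)
lib_paths = runner.generate_and_compile_parallel(configs, verbose=True)
for cfg, path in zip(configs, lib_paths):
    print(f"{cfg.variant} {cfg.ndim_spatial}D ({cfg.dtype}): {path if path else 'build failed'}")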
isinstance(config, dict) else "?" + ndim = config.get("ndim_spatial", "?") if isinstance(config, dict) else "?" + arch = config.get("arch", "?") if isinstance(config, dict) else "?" + layout = config.get("layout", "?") if isinstance(config, dict) else "?" + dtype = config.get("dtype", "fp16") if isinstance(config, dict) else "fp16" + + lines = [f"Grouped Conv Config: {variant} {ndim}D"] + lines.append(f" Arch: {arch}") + lines.append(f" Layout: {layout}") + lines.append(f" Dtype: {dtype}") + + if tile_config: + wave = _extract_wave_config(tile_config) + warp = _extract_warp_tile_config(tile_config) + lines.append( + f" Tile: M={_first(tile_config.get('tile_m', 1))} N={_first(tile_config.get('tile_n', 128))} K={_first(tile_config.get('tile_k', 128))}" + ) + lines.append(f" Wave: {wave[0]}x{wave[1]}x{wave[2]}") + lines.append(f" Warp: {warp[0]}x{warp[1]}x{warp[2]}") + + if trait_config: + pipeline = _first(trait_config.get("pipeline", "?")) + epilogue = _first(trait_config.get("epilogue", "?")) + scheduler = _first(trait_config.get("scheduler", "?")) + lines.append( + f" Traits: pipeline={pipeline} epilogue={epilogue} scheduler={scheduler}" + ) + + return "\n".join(lines) if lines else "(empty config)" + + +def setup_multiple_grouped_conv_dispatchers( + configs: List[GroupedConvKernelConfig], + verbose: bool = True, + max_workers: Optional[int] = None, +) -> List[Optional[GroupedConvDispatcherLib]]: + """ + Setup multiple grouped-conv dispatchers in parallel. + + This keeps architecture filtering strict: + 1. Validate + auto-correct each requested config + 2. Query codegen's arch-valid config set for each (arch, dtype, variant, ndim) + 3. Map each request to nearest valid config + 4. Parallel codegen + parallel compile + """ + if not configs: + return [] + + codegen_script = ( + Path(__file__).parent.parent / "codegen" / "unified_grouped_conv_codegen.py" + ) + arch_valid_cache: Dict[ + Tuple[str, str, str, int], List[GroupedConvKernelConfig] + ] = {} + + selected_configs: List[Optional[GroupedConvKernelConfig]] = [] + for i, original in enumerate(configs): + c = copy.deepcopy(original) + + val = validate_grouped_conv_config(c.to_dict()) + if not val.is_valid: + corrected, corrected_result = auto_correct_grouped_conv_config(c.to_dict()) + if not corrected_result.is_valid: + if verbose: + print(f" FAIL [{i}] config remains invalid after auto-correct") + selected_configs.append(None) + continue + + tile_cfg = corrected.get("tile_config", {}) + trait_cfg = corrected.get("trait_config", {}) + c.variant = _resolve_variant( + str(_first(corrected.get("variant", c.variant))) + ) + c.ndim_spatial = int(_first(corrected.get("ndim_spatial", c.ndim_spatial))) + c.arch = str(corrected.get("arch", c.arch)) + c.layout = str(corrected.get("layout", c.layout)) + c.dtype = str(corrected.get("dtype", c.dtype)) + c.tile_m = int(_first(tile_cfg.get("tile_m", c.tile_m))) + c.tile_n = int(_first(tile_cfg.get("tile_n", c.tile_n))) + c.tile_k = int(_first(tile_cfg.get("tile_k", c.tile_k))) + c.wave_m = int(_first(tile_cfg.get("wave_m", c.wave_m))) + c.wave_n = int(_first(tile_cfg.get("wave_n", c.wave_n))) + c.wave_k = int(_first(tile_cfg.get("wave_k", c.wave_k))) + c.warp_tile_m = int(_first(tile_cfg.get("warp_tile_m", c.warp_tile_m))) + c.warp_tile_n = int(_first(tile_cfg.get("warp_tile_n", c.warp_tile_n))) + c.warp_tile_k = int(_first(tile_cfg.get("warp_tile_k", c.warp_tile_k))) + c.pipeline = str(_first(trait_cfg.get("pipeline", c.pipeline))) + c.scheduler = str(_first(trait_cfg.get("scheduler", c.scheduler))) + 
c.epilogue = str(_first(trait_cfg.get("epilogue", c.epilogue))) + + cache_key = (c.arch, c.dtype, c.variant, c.ndim_spatial) + if cache_key not in arch_valid_cache: + arch_valid_cache[cache_key] = _list_arch_valid_grouped_conv_configs( + codegen_script=codegen_script, + arch=c.arch, + dtype=c.dtype, + variant=c.variant, + ndim_spatial=c.ndim_spatial, + ) + if verbose and not arch_valid_cache[cache_key]: + print( + f" FAIL [{i}] no arch-valid configs listed for " + f"{c.arch}/{c.dtype}/{c.variant}/{c.ndim_spatial}d" + ) + + candidates = arch_valid_cache[cache_key] + if not candidates: + selected_configs.append(None) + continue + + selected = _select_best_arch_valid_conv_config(c, candidates) + if verbose and _config_key(selected) != _config_key(c): + print( + f" INFO [{i}] mapped to arch-valid config: " + f"{selected.tile_str} {selected.wave_str} {selected.warp_str} " + f"{selected.pipeline}/{selected.scheduler}/{selected.epilogue}" + ) + selected_configs.append(selected) + + unique_configs: List[GroupedConvKernelConfig] = [] + unique_index_by_key: Dict[Tuple[Any, ...], int] = {} + input_to_unique: List[Optional[int]] = [] + for cfg in selected_configs: + if cfg is None: + input_to_unique.append(None) + continue + key = _config_key(cfg) + if key not in unique_index_by_key: + unique_index_by_key[key] = len(unique_configs) + unique_configs.append(cfg) + input_to_unique.append(unique_index_by_key[key]) + + runner = GroupedConvCodegenRunner(max_workers=max_workers) + unique_lib_paths = runner.generate_and_compile_parallel( + unique_configs, verbose=verbose + ) + + libs: List[Optional[GroupedConvDispatcherLib]] = [] + loaded_cache: Dict[int, Optional[GroupedConvDispatcherLib]] = {} + for input_idx, unique_idx in enumerate(input_to_unique): + if unique_idx is None: + libs.append(None) + continue + + if unique_idx in loaded_cache: + libs.append(loaded_cache[unique_idx]) + continue + + path = ( + unique_lib_paths[unique_idx] if unique_idx < len(unique_lib_paths) else None + ) + disp: Optional[GroupedConvDispatcherLib] = None + if path and path.exists(): + try: + lib = ctypes.CDLL(str(path)) + disp = GroupedConvDispatcherLib(lib, path) + disp.initialize() + except Exception as e: + if verbose: + print(f" FAIL [{input_idx}] failed to load {path}: {e}") + loaded_cache[unique_idx] = disp + libs.append(disp) + + return libs + + +def detect_gpu_arch() -> str: + """Detect GPU architecture using rocminfo.""" + try: + out = subprocess.check_output( + ["rocminfo"], stderr=subprocess.DEVNULL, text=True + ) + for line in out.split("\n"): + if "gfx" in line.lower() and "name:" in line.lower(): + for part in line.split(): + if part.startswith("gfx"): + return part + except Exception: + pass + return "gfx942" diff --git a/dispatcher/python/requirements.txt b/dispatcher/python/requirements.txt index 9d429235f7..3ed0a13de8 100644 --- a/dispatcher/python/requirements.txt +++ b/dispatcher/python/requirements.txt @@ -1,11 +1,16 @@ # Core dependencies numpy>=1.19.0 +# ML Heuristic dependencies (OPTIONAL - large dependencies) +# For ML-based kernel selection, install separately: +# pip install -r ../requirements-ml.txt +# This avoids mandatory large dependencies (pyarrow, lightgbm) for users who don't need ML features + # Optional dependencies (install with pip install -e ".[torch]") # torch>=2.0.0 # Development dependencies (install with pip install -e ".[dev]") -# pytest>=6.0.0 +pytest>=6.0.0 # pytest-cov>=2.0.0 # black>=21.0 # flake8>=3.9.0 diff --git a/dispatcher/requirements-ml.txt b/dispatcher/requirements-ml.txt 
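# End-to-end sketch of the four-step flow documented in setup_multiple_grouped_conv_dispatchers
# (validate/auto-correct, query arch-valid configs, map to nearest, parallel build). The two
# requests below are hypothetical; detect_gpu_arch() falls back to "gfx942" if rocminfo is absent.
arch = detect_gpu_arch()
requests = [
    get_grouped_conv_default_config(variant="forward", ndim_spatial=2, arch=arch, dtype="fp16"),
    get_grouped_conv_default_config(variant="bwd_weight", ndim_spatial=2, arch=arch, dtype="fp16"),
]
dispatchers = setup_multiple_grouped_conv_dispatchers(requests, verbose=True, max_workers=4)
for req, disp in zip(requests, dispatchers):
    status = "ready" if disp is not None else "unavailable"
    print(f"{req.variant} {req.ndim_spatial}D ({req.dtype}): {status}")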
new file mode 100644 index 0000000000..68f60a3d91 --- /dev/null +++ b/dispatcher/requirements-ml.txt @@ -0,0 +1,6 @@ +# ML Heuristic dependencies for ML-based kernel selection +# Install with: pip install -r requirements-ml.txt +lightgbm>=3.0.0 +pandas>=1.3.0 +pyarrow>=6.0.0 +scikit-learn>=0.24.0 diff --git a/dispatcher/scripts/compile_gemm_examples.py b/dispatcher/scripts/compile_gemm_examples.py index b19c18a13a..98ba18ab51 100644 --- a/dispatcher/scripts/compile_gemm_examples.py +++ b/dispatcher/scripts/compile_gemm_examples.py @@ -94,17 +94,17 @@ def find_hipcc() -> str: def extract_conv_kernel_declarations(source_file: Path) -> list: - """Extract CONVOLUTION kernel declarations from C++ source file. + """Extract GROUPED CONVOLUTION kernel declarations from C++ source file. - Supports DECL_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. + Supports DECL_GROUPED_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. Extracts all parameters: dtype, layout, conv_type, dims, tile, wave, warp, pipeline, scheduler. """ content = source_file.read_text() declarations = [] seen = set() - # Pattern: DECL_CONV_KERNEL_SET(name, .add(...).add(...)) - set_pattern = r"DECL_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" + # Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...)) + set_pattern = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" for match in re.finditer(set_pattern, content, re.DOTALL): set_name = match.group(1) @@ -396,24 +396,23 @@ def expand_conv_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") - def generate_conv_kernels(declarations: list, gpu_target: str = "gfx942") -> int: - """Generate convolution kernels using unified_conv_codegen.""" + """Generate grouped convolution kernels using unified_grouped_conv_codegen.""" kernel_dir = get_generated_kernels_dir() kernel_dir.mkdir(parents=True, exist_ok=True) - # Import conv codegen codegen_dir = get_dispatcher_root() / "codegen" sys.path.insert(0, str(codegen_dir)) try: - from unified_conv_codegen import ( - UnifiedConvCodegen, - ConvKernelConfig, - ConvVariant, + from unified_grouped_conv_codegen import ( + UnifiedGroupedConvCodegen as UnifiedConvCodegen, + GroupedConvKernelConfig as ConvKernelConfig, + GroupedConvVariant as ConvVariant, TileConfig, - TraitConfig, + GroupedConvTraitConfig as TraitConfig, ) except ImportError as e: - print_error(f" Failed to import conv codegen: {e}") + print_error(f" Failed to import grouped conv codegen: {e}") return 0 codegen = UnifiedConvCodegen(kernel_dir) @@ -1564,9 +1563,9 @@ def build_exact_conv_kernel_filename(decl: dict) -> str: if conv_type == "forward": type_prefix = "fwd" elif conv_type == "bwd_data": - type_prefix = "bwdd" + type_prefix = "bwd_data" elif conv_type == "bwd_weight": - type_prefix = "bwdw" + type_prefix = "bwd_weight" else: type_prefix = conv_type @@ -1601,9 +1600,9 @@ def generate_specific_conv_kernel(decl: dict, gpu_target: str = "gfx942") -> boo else: variant = "forward" - # Use unified_conv_codegen + # Use unified_grouped_conv_codegen codegen_dir = get_dispatcher_root() / "codegen" - codegen_script = codegen_dir / "unified_conv_codegen.py" + codegen_script = codegen_dir / "unified_grouped_conv_codegen.py" output_dir = get_generated_kernels_dir() cmd = [ @@ -1661,9 +1660,9 @@ def find_conv_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: if conv_type == "forward": type_prefix = "fwd" elif conv_type == "bwd_data": - type_prefix = "bwdd" + type_prefix = "bwd_data" elif conv_type == "bwd_weight": - type_prefix = "bwdw" + type_prefix = 
"bwd_weight" else: type_prefix = conv_type @@ -1865,7 +1864,9 @@ In your C++ code, declare kernels like: if not gemm_declarations and not conv_declarations: print_error(" No kernel declarations found!") - print(" Add DECL_KERNEL_SET for GEMM or DECL_CONV_KERNEL_SET for Conv") + print( + " Add DECL_KERNEL_SET for GEMM or DECL_GROUPED_CONV_KERNEL_SET for Grouped Conv" + ) return 1 # Handle GEMM declarations @@ -1913,7 +1914,7 @@ In your C++ code, declare kernels like: is_valid, error_msg = validate_kernel_config(decl, arch) if not is_valid: - print(f"\n ⚠ Invalid configuration: {decl_name}") + print(f"\n WARNING Invalid configuration: {decl_name}") # Parse the error and show specific auto-corrections corrections = [] @@ -1926,7 +1927,7 @@ In your C++ code, declare kernels like: decl["wave_m"] = -1 decl["wave_n"] = -1 corrections.append( - f"wave: {original_values['wave']} → [wildcard expansion]" + f"wave: {original_values['wave']} -> [wildcard expansion]" ) if "warp tile" in error_msg.lower(): @@ -1936,7 +1937,7 @@ In your C++ code, declare kernels like: decl["warp_m"] = -1 decl["warp_n"] = -1 corrections.append( - f"warp_tile: {original_values['warp']} → [wildcard expansion]" + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" ) if "trait combination" in error_msg.lower(): @@ -1945,16 +1946,16 @@ In your C++ code, declare kernels like: decl["pipeline"] = "*" decl["scheduler"] = "*" corrections.append( - f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" ) corrections.append( - f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" ) # Print the auto-corrections print(" AUTO-CORRECTION:") for corr in corrections: - print(f" • {corr}") + print(f" - {corr}") auto_corrections.append((decl_name, corrections)) invalid_count += 1 @@ -1962,15 +1963,15 @@ In your C++ code, declare kernels like: if invalid_count > 0: print( - f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" ) if wildcard_count > 0: print( - f" ✓ {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + f" OK {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" ) else: - print(f" ✓ All {len(gemm_declarations)} configurations valid") + print(f" OK All {len(gemm_declarations)} configurations valid") # Expand GEMM declarations (for wildcards) print("\n Expanding wildcards to valid configurations...") @@ -1994,7 +1995,7 @@ In your C++ code, declare kernels like: wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" print( - f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" ) if len(expanded) > 3: print(f" ... 
and {len(expanded) - 3} more") @@ -2002,11 +2003,11 @@ In your C++ code, declare kernels like: exp = expanded[0] wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" - print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") if len(expanded_gemm) > len(gemm_declarations): print( - f"\n Total: {len(gemm_declarations)} declarations → {len(expanded_gemm)} configurations" + f"\n Total: {len(gemm_declarations)} declarations -> {len(expanded_gemm)} configurations" ) gemm_declarations = expanded_gemm @@ -2054,7 +2055,7 @@ In your C++ code, declare kernels like: is_valid, error_msg = validate_conv_kernel_config(decl, arch) if not is_valid: - print(f"\n ⚠ Invalid conv configuration: {decl_name}") + print(f"\n WARNING Invalid conv configuration: {decl_name}") # Parse the error and show specific auto-corrections corrections = [] @@ -2067,7 +2068,7 @@ In your C++ code, declare kernels like: decl["wave_m"] = -1 decl["wave_n"] = -1 corrections.append( - f"wave: {original_values['wave']} → [wildcard expansion]" + f"wave: {original_values['wave']} -> [wildcard expansion]" ) if "warp tile" in error_msg.lower(): @@ -2077,7 +2078,7 @@ In your C++ code, declare kernels like: decl["warp_m"] = -1 decl["warp_n"] = -1 corrections.append( - f"warp_tile: {original_values['warp']} → [wildcard expansion]" + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" ) if "trait combination" in error_msg.lower(): @@ -2086,16 +2087,16 @@ In your C++ code, declare kernels like: decl["pipeline"] = "*" decl["scheduler"] = "*" corrections.append( - f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" ) corrections.append( - f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" ) # Print the auto-corrections print(" AUTO-CORRECTION:") for corr in corrections: - print(f" • {corr}") + print(f" - {corr}") auto_corrections.append((decl_name, corrections)) invalid_count += 1 @@ -2103,15 +2104,15 @@ In your C++ code, declare kernels like: if invalid_count > 0: print( - f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" ) if wildcard_count > 0: print( - f" ✓ {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + f" OK {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" ) else: - print(f" ✓ All {len(conv_declarations)} configurations valid") + print(f" OK All {len(conv_declarations)} configurations valid") # Expand Conv declarations (for wildcards) print("\n Expanding wildcards to valid configurations...") @@ -2134,7 +2135,7 @@ In your C++ code, declare kernels like: wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" print( - f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" ) if len(expanded) > 3: print(f" ... 
and {len(expanded) - 3} more") @@ -2142,11 +2143,11 @@ In your C++ code, declare kernels like: exp = expanded[0] wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" - print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") if len(expanded_conv) > len(conv_declarations): print( - f"\n Total: {len(conv_declarations)} declarations → {len(expanded_conv)} configurations" + f"\n Total: {len(conv_declarations)} declarations -> {len(expanded_conv)} configurations" ) conv_declarations = expanded_conv diff --git a/dispatcher/scripts/compile_grouped_conv_examples.py b/dispatcher/scripts/compile_grouped_conv_examples.py new file mode 100644 index 0000000000..32fe70a2de --- /dev/null +++ b/dispatcher/scripts/compile_grouped_conv_examples.py @@ -0,0 +1,882 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Self-contained build script for C++ grouped convolution examples. + +Parses DECL_GROUPED_CONV_KERNEL_SET declarations from source files, +generates the needed kernels, and compiles the example. + +Includes validation and auto-correction via wildcard expansion. + +Usage: + python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/02_grouped_conv_forward.cpp + python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/03_grouped_conv_validation.cpp --no-compile +""" + +import argparse +import os +import re +import subprocess +import sys +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Optional + +# Setup paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +CK_ROOT = DISPATCHER_DIR.parent + +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ( # noqa: E402 + print_phase, + print_success, + print_error, + print_info, + find_hipcc, + get_arch_filter_data, + get_build_dir, + get_ck_root, + get_dispatcher_root, + get_generated_kernels_dir, +) + + +def extract_grouped_conv_declarations(source_file: Path) -> list: + """Extract DECL_GROUPED_CONV_KERNEL_SET declarations from C++ source.""" + content = source_file.read_text() + declarations = [] + + # Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...)) + # Find all DECL_GROUPED_CONV_KERNEL_SET blocks by matching parentheses + pattern_start = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*," + for match in re.finditer(pattern_start, content): + set_name = match.group(1) + start_pos = match.end() + + # Find matching closing paren by counting parens + paren_count = 1 # We're already inside the first paren + end_pos = start_pos + for i, c in enumerate(content[start_pos:]): + if c == "(": + paren_count += 1 + elif c == ")": + paren_count -= 1 + if paren_count == 0: + end_pos = start_pos + i + break + + set_body = content[start_pos:end_pos] + + # Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c) + simple_add = ( + r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)' + ) + for add_match in re.finditer(simple_add, set_body): + conv_type = add_match.group(3) + default_pipeline = ( + "compv3" if conv_type in ("bwd_data", "bwd_weight") else "compv4" + ) + declarations.append( + { + "set": set_name, + "dtype": add_match.group(1), + "layout": add_match.group(2), + "conv_type": conv_type, + "tile_k": 
int(add_match.group(4)), + "tile_c": int(add_match.group(5)), + "num_dims": 2, + "pipeline": default_pipeline, + "scheduler": "intrawave", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "arch": "gfx942", + } + ) + + # Pattern 2: Full ConvSig()/ConvAlgo() specification + # Find all .add( positions that start with ConvSig() + full_add = r"\.add\s*\(\s*ConvSig\(\)" + add_positions = [m.start() for m in re.finditer(full_add, set_body)] + + for pos in add_positions: + # Find matching closing paren by counting parens + paren_count = 0 + in_add = False + end = pos + for i, c in enumerate(set_body[pos:]): + if c == "(": + paren_count += 1 + in_add = True + elif c == ")": + paren_count -= 1 + if in_add and paren_count == 0: + end = pos + i + 1 + break + + add_str = set_body[pos:end] + + # Extract signature part (between ConvSig() and ConvAlgo()) + sig_match = re.search(r"ConvSig\(\)(.*?)ConvAlgo\(\)", add_str, re.DOTALL) + if not sig_match: + continue + sig_str = sig_match.group(1) + + # Extract algorithm part (between ConvAlgo() and arch string) + algo_match = re.search( + r'ConvAlgo\(\)(.*?),\s*"(\w+)"\s*\)', add_str, re.DOTALL + ) + if not algo_match: + continue + algo_str = algo_match.group(1) + arch = algo_match.group(2) + + # Parse signature + dtype = "fp16" + dtype_match = re.search(r'\.dtype\s*\(\s*"(\w+)"', sig_str) + if dtype_match: + dtype = dtype_match.group(1) + + layout = "nhwgc" + layout_match = re.search(r'\.layout\s*\(\s*"(\w+)"', sig_str) + if layout_match: + layout = layout_match.group(1) + + conv_type = "forward" + conv_type_match = re.search(r'\.conv_type\s*\(\s*"(\w+)"', sig_str) + if conv_type_match: + conv_type = conv_type_match.group(1) + + num_dims = 2 + dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str) + if dims_match: + num_dims = int(dims_match.group(1)) + + # Parse algorithm + tile_k, tile_c = 128, 128 + tile_match = re.search( + r"\.tile\s*\(\s*\d+\s*,\s*(\d+)\s*,\s*(\d+)", algo_str + ) + if tile_match: + tile_k = int(tile_match.group(1)) + tile_c = int(tile_match.group(2)) + + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str + ) + if wave_match: + wave_m = int(wave_match.group(1)) + wave_n = int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str + ) + if warp_match: + warp_m = int(warp_match.group(1)) + warp_n = int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + pipeline = "compv4" + pipeline_match = re.search(r'\.pipeline\s*\(\s*"(\w+)"', algo_str) + if pipeline_match: + pipeline = pipeline_match.group(1) + + scheduler = "intrawave" + scheduler_match = re.search(r'\.scheduler\s*\(\s*"(\w+)"', algo_str) + if scheduler_match: + scheduler = scheduler_match.group(1) + + # Parse additional parameters + vector_a, vector_b, vector_c = 4, 8, 8 + vector_match = re.search( + r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", algo_str + ) + if vector_match: + vector_a = int(vector_match.group(1)) + vector_b = int(vector_match.group(2)) + vector_c = int(vector_match.group(3)) + + block_per_cu = 1 + block_per_cu_match = re.search(r"\.block_per_cu\s*\(\s*(\d+)", algo_str) + if block_per_cu_match: + block_per_cu = int(block_per_cu_match.group(1)) + + memory_op = "set" + memory_op_match = re.search(r'\.memory_op\s*\(\s*"(\w+)"', algo_str) + if memory_op_match: + memory_op = 
memory_op_match.group(1) + + epilogue = "cshuffle" + epilogue_match = re.search(r'\.epilogue\s*\(\s*"(\w+)"', algo_str) + if epilogue_match: + epilogue = epilogue_match.group(1) + + # Parse num_wave_groups (for V5 pipeline) + num_wave_groups = 1 + nwg_match = re.search(r"\.num_wave_groups\s*\(\s*(\d+)", algo_str) + if nwg_match: + num_wave_groups = int(nwg_match.group(1)) + + # Parse num_groups_to_merge (for merged group grouped convolution) + num_groups_to_merge = 1 + ngm_match = re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)", algo_str) + if ngm_match: + num_groups_to_merge = int(ngm_match.group(1)) + + # Parse double_smem_buffer (for V4 pipeline) + double_smem_buffer = False + dsb_match = re.search( + r"\.double_smem_buffer\s*\(\s*(true|false)", algo_str, re.I + ) + if dsb_match: + double_smem_buffer = dsb_match.group(1).lower() == "true" + + # Parse padding flags + pad_m, pad_n, pad_k = True, True, True + padding_match = re.search( + r"\.padding\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)", + algo_str, + re.I, + ) + if padding_match: + pad_m = padding_match.group(1).lower() == "true" + pad_n = padding_match.group(2).lower() == "true" + pad_k = padding_match.group(3).lower() == "true" + + declarations.append( + { + "set": set_name, + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "tile_k": tile_k, + "tile_c": tile_c, + "num_dims": num_dims, + "pipeline": pipeline, + "scheduler": scheduler, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "vector_a": vector_a, + "vector_b": vector_b, + "vector_c": vector_c, + "block_per_cu": block_per_cu, + "memory_op": memory_op, + "epilogue": epilogue, + "num_wave_groups": num_wave_groups, + "num_groups_to_merge": num_groups_to_merge, + "double_smem_buffer": double_smem_buffer, + "pad_m": pad_m, + "pad_n": pad_n, + "pad_k": pad_k, + "arch": arch, + } + ) + + return declarations + + +# ============================================================================= +# VALIDATION AND AUTO-CORRECTION +# ============================================================================= + + +def is_grouped_conv_wildcard_declaration(decl: dict) -> bool: + """Check if a declaration uses wildcards (-1 or '*').""" + wildcard_fields = ["wave_m", "wave_n", "warp_m", "warp_n", "pipeline", "scheduler"] + for field in wildcard_fields: + val = decl.get(field) + if val == -1 or val == "*": + return True + return False + + +def validate_grouped_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple: + """Validate a grouped conv kernel configuration against known supported combinations. 
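# Sketch of the declaration shape extract_grouped_conv_declarations() above understands
# (Pattern 2, the full ConvSig()/ConvAlgo() form). The C++ snippet is a hypothetical
# example, embedded as a string only to show what the regexes pull out of it.
example_source = """
DECL_GROUPED_CONV_KERNEL_SET(demo_set,
    .add(ConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
         ConvAlgo().tile(1, 128, 128).wave(2, 2, 1).warp(32, 32, 16)
                   .pipeline("compv4").scheduler("intrawave"),
         "gfx942"));
"""
import tempfile
from pathlib import Path as _Path
with tempfile.NamedTemporaryFile("w", suffix=".cpp", delete=False) as f:
    f.write(example_source)
decls = extract_grouped_conv_declarations(_Path(f.name))
# Expected: one dict with dtype "fp16", conv_type "forward", num_dims 2, tile_k/tile_c 128,
# wave [2, 2, 1], warp tile [32, 32, 16], pipeline "compv4", arch "gfx942".
assert decls and decls[0]["tile_k"] == 128 and decls[0]["arch"] == "gfx942"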
+ + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards - expansion will filter invalid combos + if is_grouped_conv_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv4") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, "cshuffle", scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}: intrawave" + ) + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration for this arch and dtype + acc_dtype = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc_dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16], [16, 16, 32]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" Supported: {', '.join(arch_data['supported_archs'])}" + ) + + if errors: + return (False, "\n".join(errors)) + + return (True, None) + + +def expand_grouped_conv_declaration_with_arch_filter( + decl: dict, arch: str = "gfx942" +) -> list: + """Expand a grouped conv declaration with wildcards into valid configurations. + + Wildcards: + - wave_m/wave_n = -1: Try all valid wave configs for this arch + - warp_m/warp_n = -1: Try all valid warp tiles for this arch/dtype + - pipeline/scheduler = "*": Try all valid combinations + + Returns a list of fully-specified declarations. 
+ """ + arch_data = get_arch_filter_data() + dtype = decl.get("dtype", "fp16") + + # Get valid combinations for this arch + valid_wave_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + acc_dtype = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc_dtype}" + valid_warp_tiles = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + + # Valid pipelines and schedulers + valid_pipelines = ["compv3", "compv4"] + valid_schedulers = ["intrawave"] # interwave often unsupported + + # Determine which fields need expansion + expand_wave = decl.get("wave_m", 2) == -1 or decl.get("wave_n", 2) == -1 + expand_warp = decl.get("warp_m", 32) == -1 or decl.get("warp_n", 32) == -1 + expand_pipeline = decl.get("pipeline", "compv4") == "*" + expand_scheduler = decl.get("scheduler", "intrawave") == "*" + + # Build combinations + wave_options = ( + valid_wave_combos + if expand_wave + else [[decl.get("wave_m", 2), decl.get("wave_n", 2), decl.get("wave_k", 1)]] + ) + warp_options = ( + valid_warp_tiles + if expand_warp + else [[decl.get("warp_m", 32), decl.get("warp_n", 32), decl.get("warp_k", 16)]] + ) + pipeline_options = ( + valid_pipelines if expand_pipeline else [decl.get("pipeline", "compv4")] + ) + scheduler_options = ( + valid_schedulers if expand_scheduler else [decl.get("scheduler", "intrawave")] + ) + + expanded = [] + for wave in wave_options: + for warp in warp_options: + for pipeline in pipeline_options: + for scheduler in scheduler_options: + # Skip known invalid combinations + if (pipeline, "cshuffle", scheduler) in arch_data[ + "trait_unsupported" + ]: + continue + + new_decl = decl.copy() + new_decl["wave_m"] = wave[0] + new_decl["wave_n"] = wave[1] + new_decl["wave_k"] = wave[2] + new_decl["warp_m"] = warp[0] + new_decl["warp_n"] = warp[1] + new_decl["warp_k"] = warp[2] + new_decl["pipeline"] = pipeline + new_decl["scheduler"] = scheduler + + expanded.append(new_decl) + + # If no valid expansions, return original (will fail validation later) + if not expanded: + return [decl] + + # Return first valid config (or all if needed) + return expanded[:1] # Just use first valid config for grouped conv + + +def validate_and_expand_grouped_conv_declarations( + declarations: list, arch: str, verbose: bool = False +) -> list: + """Validate declarations and auto-correct invalid ones via wildcard expansion.""" + print(f"\n Validating against {arch} arch filter...") + + wildcard_count = 0 + invalid_count = 0 + auto_corrections = [] + + for decl in declarations: + decl_arch = decl.get("arch", arch) + decl_name = ( + f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}" + ) + + # Check for wildcards + if is_grouped_conv_wildcard_declaration(decl): + wildcard_count += 1 + continue + + is_valid, error_msg = validate_grouped_conv_kernel_config(decl, decl_arch) + if not is_valid: + print(f"\n WARNING Invalid grouped conv configuration: {decl_name}") + + # Parse the error and show specific auto-corrections + corrections = [] + original_values = {} + + if "wave configuration" in error_msg.lower(): + original_values["wave"] = ( + f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]" + ) + decl["wave_m"] = -1 + decl["wave_n"] = -1 + corrections.append( + f"wave: {original_values['wave']} -> [wildcard expansion]" + ) + + if "warp tile" in error_msg.lower(): + original_values["warp"] = ( + f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]" + ) + decl["warp_m"] = -1 + 
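# Sketch of the wildcard flow implemented above: a wildcarded declaration is narrowed to the
# first arch-valid combination. The concrete values the expansion settles on depend on what
# get_arch_filter_data() reports for the target arch, so the comment below is only indicative.
decl = {
    "dtype": "fp16", "conv_type": "forward", "tile_k": 128, "tile_c": 128, "num_dims": 2,
    "wave_m": -1, "wave_n": -1, "wave_k": 1,      # wildcard wave config
    "warp_m": -1, "warp_n": -1, "warp_k": 16,     # wildcard warp tile
    "pipeline": "*", "scheduler": "*", "arch": "gfx942",
}
assert is_grouped_conv_wildcard_declaration(decl)
expanded = expand_grouped_conv_declaration_with_arch_filter(decl, "gfx942")
# Grouped conv keeps only the first valid expansion, e.g. wave [2, 2, 1],
# warp tile [32, 32, 16], pipeline "compv3", scheduler "intrawave".
print(expanded[0]["wave_m"], expanded[0]["warp_m"], expanded[0]["pipeline"])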
decl["warp_n"] = -1 + corrections.append( + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" + ) + + if "trait combination" in error_msg.lower(): + original_values["pipeline"] = decl.get("pipeline", "compv4") + original_values["scheduler"] = decl.get("scheduler", "intrawave") + decl["pipeline"] = "*" + decl["scheduler"] = "*" + corrections.append( + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" + ) + corrections.append( + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" + ) + + # Print the auto-corrections + print(" AUTO-CORRECTION:") + for corr in corrections: + print(f" - {corr}") + auto_corrections.append((decl_name, corrections)) + + invalid_count += 1 + wildcard_count += 1 + + if invalid_count > 0: + print( + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + ) + + if wildcard_count > 0: + print( + f" OK {len(declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + ) + else: + print(f" OK All {len(declarations)} configurations valid") + + # Expand wildcards + print("\n Expanding wildcards to valid configurations...") + expanded_declarations = [] + for decl in declarations: + decl_arch = decl.get("arch", arch) + decl_name = ( + f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}" + ) + + expanded = expand_grouped_conv_declaration_with_arch_filter(decl, decl_arch) + expanded_declarations.extend(expanded) + + if len(expanded) > 1: + print( + f" {decl_name}: expanded to {len(expanded)} valid configurations" + ) + for exp in expanded[:3]: + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print( + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}" + ) + if len(expanded) > 3: + print(f" ... and {len(expanded) - 3} more") + elif is_grouped_conv_wildcard_declaration(decl) and len(expanded) == 1: + exp = expanded[0] + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") + + if len(expanded_declarations) != len(declarations): + print( + f"\n Total: {len(declarations)} declarations -> {len(expanded_declarations)} configurations" + ) + + return expanded_declarations + + +def _generate_single_grouped_conv_kernel(args: tuple) -> tuple: + """Generate one grouped conv kernel (picklable for ProcessPoolExecutor). 
+ + Args: (decl, output_dir_str, gpu_target) + Returns: (idx, filepath_str or None, error_str or None) + """ + decl, output_dir_str, gpu_target = args + output_dir = Path(output_dir_str) + idx = decl.get("_idx", 0) + + try: + from codegen_common import TileConfig + from unified_grouped_conv_codegen import ( + GroupedConvKernelConfig, + GroupedConvTraitConfig, + GroupedConvVariant, + UnifiedGroupedConvCodegen, + ) + + # Map conv_type to variant + variant = GroupedConvVariant.FORWARD + if decl["conv_type"] == "bwd_data": + variant = GroupedConvVariant.BACKWARD_DATA + elif decl["conv_type"] == "bwd_weight": + variant = GroupedConvVariant.BACKWARD_WEIGHT + + pipeline = decl.get("pipeline", "compv4") + adj_tile_k = 64 * 2 if pipeline == "compv4" else 64 + + # Create tile config (tile_m=tile_k, tile_n=tile_c for conv GEMM view) + tile = TileConfig( + tile_m=decl["tile_k"], + tile_n=decl["tile_c"], + tile_k=adj_tile_k, + warp_m=decl["wave_m"], + warp_n=decl["wave_n"], + warp_k=decl.get("wave_k", 1), + warp_tile_m=decl["warp_m"], + warp_tile_n=decl["warp_n"], + warp_tile_k=decl["warp_k"], + ) + + trait = GroupedConvTraitConfig( + pipeline=pipeline, + scheduler=decl["scheduler"], + epilogue=decl.get("epilogue", "cshuffle"), + double_smem_buffer=decl.get("double_smem_buffer", False), + pad_m=decl.get("pad_m", True), + pad_n=decl.get("pad_n", True), + pad_k=decl.get("pad_k", True), + num_groups_to_merge=decl.get("num_groups_to_merge", 1), + ) + + config = GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=variant, + ndim_spatial=decl["num_dims"], + arch=decl.get("arch", gpu_target), + vector_size_a=decl.get("vector_a", 4), + vector_size_b=decl.get("vector_b", 8), + vector_size_c=decl.get("vector_c", 8), + block_per_cu=decl.get("block_per_cu", 1), + num_wave_groups=decl.get("num_wave_groups", 1), + num_groups_to_merge=decl.get("num_groups_to_merge", 1), + double_smem_buffer=decl.get("double_smem_buffer", False), + ) + + codegen = UnifiedGroupedConvCodegen(output_dir, gpu_target=gpu_target) + kernel_path, _ = codegen.generate_kernel(config, decl["dtype"], variant) + return (idx, str(kernel_path), None) + + except Exception as e: + return (idx, None, str(e)) + + +def generate_grouped_conv_kernels( + declarations: list, + output_dir: Path, + gpu_target: str = "gfx942", + max_workers: Optional[int] = None, +) -> list: + """Generate grouped convolution kernels using unified_grouped_conv_codegen. + + Uses ProcessPoolExecutor for parallel kernel generation. + """ + output_dir.mkdir(parents=True, exist_ok=True) + + # Prepare work items (add _idx for ordering) + work_items = [] + for idx, decl in enumerate(declarations): + decl_copy = decl.copy() + decl_copy["_idx"] = idx + work_items.append((decl_copy, str(output_dir), gpu_target)) + + max_workers = max_workers or min(len(work_items), os.cpu_count() or 4) + generated = [] + failed = [] + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_generate_single_grouped_conv_kernel, w): w[0]["_idx"] + for w in work_items + } + for future in as_completed(futures): + idx, path, err = future.result() + if path: + generated.append(Path(path)) + print_info(f" Generated: {Path(path).name}") + else: + failed.append((idx, err)) + print_error(f" Failed kernel {idx + 1}: {err}") + + if failed: + for idx, err in failed[:3]: + print_error(f" Kernel {idx + 1}: {err[:200]}") + if len(failed) > 3: + print_error(f" ... 
and {len(failed) - 3} more failures") + + return generated + + +def compile_grouped_conv_example( + source_file: Path, + output_bin: Path, + kernel_headers: list, + hipcc: str, + gpu_target: str, +) -> bool: + """Compile the C++ example with generated kernels.""" + kernel_dir = get_generated_kernels_dir() + ck_root = get_ck_root() + dispatcher_dir = get_dispatcher_root() + + includes = [ + f"-I{ck_root / 'include'}", + f"-I{dispatcher_dir / 'include'}", + f"-I{kernel_dir}", + ] + + # Build include flags for generated kernels + kernel_includes = [] + for header in kernel_headers: + kernel_includes.extend(["-include", str(header)]) + + # Add define to indicate kernels are available + defines = ["-DGROUPED_CONV_KERNEL_AVAILABLE=1"] + + cmd = [ + hipcc, + "-std=c++20", + "-O2", + f"--offload-arch={gpu_target}", + *includes, + *defines, + *kernel_includes, + "-o", + str(output_bin), + str(source_file), + ] + + print_info(f" Compiling: {source_file.name}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + if result.stderr: + lines = result.stderr.split("\n") + errors = [line for line in lines if "error:" in line.lower()][:5] + for err_line in errors: + print_error(f" {err_line}") + return False + + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Build C++ grouped convolution example with self-contained kernel generation" + ) + parser.add_argument("source", help="Source file (.cpp)") + parser.add_argument("--output", "-o", help="Output binary name") + parser.add_argument("--gpu-target", default="gfx942", help="GPU target") + parser.add_argument( + "--no-compile", action="store_true", help="Only generate kernels, don't compile" + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument( + "--jobs", + "-j", + type=int, + default=None, + help="Parallel jobs for kernel generation (default: cpu_count)", + ) + args = parser.parse_args() + + # Resolve source file + source_file = Path(args.source) + if not source_file.is_absolute(): + candidates = [ + get_dispatcher_root() / args.source, + Path.cwd() / args.source, + ] + for c in candidates: + if c.exists(): + source_file = c + break + + if not source_file.exists(): + print_error(f"Source file not found: {source_file}") + return 1 + + build_dir = get_build_dir() + kernel_dir = get_generated_kernels_dir() + output_name = args.output or source_file.stem + output_bin = build_dir / output_name + + print_success("=== Grouped Conv Example Builder (Self-Contained) ===") + + # Phase 1: Extract declarations + print_phase(1, "Scanning for DECL_GROUPED_CONV_KERNEL_SET...") + declarations = extract_grouped_conv_declarations(source_file) + + if not declarations: + print_error(" No DECL_GROUPED_CONV_KERNEL_SET declarations found!") + return 1 + + print(f" Found {len(declarations)} kernel declaration(s):") + for decl in declarations: + name = f"{decl['dtype']}_{decl['conv_type']}_{decl['num_dims']}d_{decl['tile_k']}x{decl['tile_c']}" + print(f" [{decl['set']}] {name}") + + # Phase 2: Validate and expand + print_phase(2, "Validating and expanding declarations...") + declarations = validate_and_expand_grouped_conv_declarations( + declarations, args.gpu_target, args.verbose + ) + print() + + # Phase 3: Generate kernels + print_phase(3, "Generating kernels...") + generated = generate_grouped_conv_kernels( + declarations, kernel_dir, args.gpu_target, max_workers=args.jobs + ) + + if not generated: + print_error(" No kernels generated!") + return 1 + + print(f" 
Generated {len(generated)} kernel file(s)") + print() + + # Phase 4: Compile (optional) + if args.no_compile: + print_info("Skipping compilation (--no-compile)") + print() + print_success("=== Kernel Generation Complete ===") + print(f"Kernels in: {kernel_dir}") + return 0 + + print_phase(4, "Compiling example...") + hipcc_path = find_hipcc() + + if not hipcc_path: + print_error(" hipcc not found. Install ROCm or set HIPCC env var.") + print(" To compile manually:") + ck_root = get_dispatcher_root().parent + print( + f" hipcc -std=c++20 -O2 -I{ck_root / 'include'} -I{get_dispatcher_root() / 'include'} \\" + ) + print(f" -I{kernel_dir} \\") + for h in generated[:1]: + print(f" -include {h} \\") + print(" -DGROUPED_CONV_KERNEL_AVAILABLE=1 \\") + print(f" --offload-arch={args.gpu_target} \\") + print(f" {source_file} -o {output_bin}") + return 1 + + build_dir.mkdir(parents=True, exist_ok=True) + + if not compile_grouped_conv_example( + source_file, output_bin, generated, hipcc_path, args.gpu_target + ): + print_error(" Compilation failed!") + return 1 + + print_success(f" Output: {output_bin}") + print() + + print_success("=== Build Complete ===") + print() + print("Run with:") + print(f" {output_bin}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/scripts/example_kernel_builder.py b/dispatcher/scripts/example_kernel_builder.py index d3bb619174..20952cd91f 100755 --- a/dispatcher/scripts/example_kernel_builder.py +++ b/dispatcher/scripts/example_kernel_builder.py @@ -55,10 +55,10 @@ def extract_balanced_parens(text: str, start_pos: int) -> str: def parse_conv_declarations(content: str) -> List[Dict]: - """Parse DECL_CONV_KERNEL_SET declarations with all parameters.""" + """Parse DECL_GROUPED_CONV_KERNEL_SET declarations with all parameters.""" kernels = [] - for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content): + for match in re.finditer(r"DECL_GROUPED_CONV_KERNEL_SET\s*\(", content): body = extract_balanced_parens(content, match.end() - 1) if not body: continue @@ -619,7 +619,7 @@ def strip_cpp_strings_and_comments(content: str) -> str: n = len(content) # Patterns that indicate a string is problematic and should be stripped - problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("] + problematic_patterns = ["DECL_KERNEL_SET", "DECL_GROUPED_CONV_KERNEL_SET", ".add("] while i < n: # Check for raw string literal: R"delimiter(...)delimiter" @@ -697,7 +697,7 @@ def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]: content = source_path.read_text() content = strip_cpp_strings_and_comments(content) - if "DECL_CONV_KERNEL_SET" in content: + if "DECL_GROUPED_CONV_KERNEL_SET" in content: return "conv", parse_conv_declarations(content) elif "DECL_KERNEL_SET" in content: return "gemm", parse_gemm_declarations(content) @@ -966,30 +966,128 @@ def generate_per_set_functions(source_stem: str) -> str: def generate_conv_registration( kernel_headers: List[Path], example_name: str, kernels: List[Dict] ) -> str: - """Generate Conv kernel registration code for the dispatcher registry.""" + """Generate Conv kernel registration code for the dispatcher registry. + + Creates real GroupedConvKernelInstance entries backed by the generated + launcher's launch() method via the conv backend RunFn factories. 
+ """ if not kernel_headers: return " // No kernels to register" lines = [] - lines.append( - " (void)registry; (void)arch; // Conv uses direct launcher pattern for now" - ) - # For conv, we provide direct access to kernel launchers for i, h in enumerate(kernel_headers): - kernel_name = h.stem - lines.append(f" // Kernel {i + 1}: {kernel_name}") + kname = h.stem + ns = f"ns_{kname}" + launcher = f"{ns}::{kname}_Launcher" + + # Determine direction and ndim from the kernel header name + if "_fwd_" in kname: + direction = "Forward" + run_fn_factory = "make_conv_fwd_run_fn" + elif "_bwd_data_" in kname or "_bwdd_" in kname: + direction = "BackwardData" + run_fn_factory = "make_conv_bwd_data_run_fn" + elif "_bwd_weight_" in kname or "_bwdw_" in kname: + direction = "BackwardWeight" + run_fn_factory = "make_conv_bwd_weight_run_fn" + else: + direction = "Forward" + run_fn_factory = "make_conv_fwd_run_fn" + + ndim = 3 if "_3d_" in kname else 2 + + # Parse dtype from name (e.g. grouped_conv_fwd_fp16_...) + dtype = "fp16" + for dt in ["fp16", "bf16", "fp32"]: + if f"_{dt}_" in kname: + dtype = dt + break + + # Parse tile, wave, warp from name. + # Format: ..._TILExTILExTILE_WAVExWAVExWAVE_WARPxWARPxWARP_... + import re as _re + + tile_m, tile_n, tile_k = 1, 128, 128 + wave_m, wave_n, wave_k = 2, 2, 1 + warp_m, warp_n, warp_k = 32, 32, 16 + + triplets = _re.findall(r"_(\d+)x(\d+)x(\d+)", kname) + if len(triplets) >= 1: + tile_m, tile_n, tile_k = ( + int(triplets[0][0]), + int(triplets[0][1]), + int(triplets[0][2]), + ) + if len(triplets) >= 2: + wave_m, wave_n, wave_k = ( + int(triplets[1][0]), + int(triplets[1][1]), + int(triplets[1][2]), + ) + if len(triplets) >= 3: + warp_m, warp_n, warp_k = ( + int(triplets[2][0]), + int(triplets[2][1]), + int(triplets[2][2]), + ) + + pipeline = "compv4" if "compv4" in kname else "compv3" + scheduler = "interwave" if "interwave" in kname else "intrawave" + epilogue = "cshuffle" if "cshuffle" in kname else "default" + + # ConvConfigBase defaults + vec_a, vec_b, vec_c = 4, 8, 8 + block_per_cu = 1 + num_wave_groups = 1 + num_groups_to_merge = 1 + + lines.append(f" // Kernel {i + 1}: {kname}") + lines.append(" {") + lines.append(f" ck_tile::dispatcher::GroupedConvKernelKey key_{i};") + lines.append(f' key_{i}.dtype_in = "{dtype}";') + lines.append(f' key_{i}.dtype_wei = "{dtype}";') + lines.append(f' key_{i}.dtype_out = "{dtype}";') + lines.append(f' key_{i}.layout = "nhwgc";') + lines.append(f" key_{i}.ndim_spatial = {ndim};") + lines.append( + f" key_{i}.op = ck_tile::dispatcher::GroupedConvOp::{direction};" + ) + lines.append(f" key_{i}.tile_m = {tile_m};") + lines.append(f" key_{i}.tile_n = {tile_n};") + lines.append(f" key_{i}.tile_k = {tile_k};") + lines.append(f" key_{i}.wave_m = {wave_m};") + lines.append(f" key_{i}.wave_n = {wave_n};") + lines.append(f" key_{i}.wave_k = {wave_k};") + lines.append(f" key_{i}.warp_m = {warp_m};") + lines.append(f" key_{i}.warp_n = {warp_n};") + lines.append(f" key_{i}.warp_k = {warp_k};") + lines.append(f' key_{i}.pipeline = "{pipeline}";') + lines.append(f' key_{i}.scheduler = "{scheduler}";') + lines.append(f' key_{i}.epilogue = "{epilogue}";') + lines.append(f" key_{i}.vector_size_a = {vec_a};") + lines.append(f" key_{i}.vector_size_b = {vec_b};") + lines.append(f" key_{i}.vector_size_c = {vec_c};") + lines.append(f" key_{i}.block_per_cu = {block_per_cu};") + lines.append(f" key_{i}.num_wave_groups = {num_wave_groups};") + lines.append(f" key_{i}.num_groups_to_merge = {num_groups_to_merge};") + lines.append(f" key_{i}.arch 
= arch;") + lines.append( + f" auto run_fn_{i} = ck_tile::dispatcher::backends::{run_fn_factory}<{launcher}, {ndim}>();" + ) + lines.append( + f' auto inst_{i} = std::make_shared(key_{i}, "{kname}", std::move(run_fn_{i}));' + ) + lines.append(f" registry.register_kernel(key_{i}, inst_{i});") + lines.append(" }") return "\n".join(lines) -def generate_conv_kernels( - kernels: List[Dict], output_dir: Path, codegen_dir: Path -) -> bool: - """Generate Conv kernels for ALL declarations using unified codegen.""" - if not kernels: - return False - +def _build_conv_codegen_cmd( + idx: int, k: Dict, codegen_dir: Path, output_dir: Path +) -> Tuple[int, List[str], str]: + """Build the command for a single conv kernel codegen invocation.""" variant_map = { "forward": "forward", "bwd_data": "bwd_data", @@ -997,93 +1095,130 @@ def generate_conv_kernels( "bwd_weight": "bwd_weight", "backward_weight": "bwd_weight", } + variant = variant_map.get(k.get("conv_type", "forward"), "forward") + + cmd = [ + sys.executable, + str(codegen_dir / "unified_grouped_conv_codegen.py"), + "--datatype", + k.get("dtype", "fp16"), + "--variant", + variant, + "--ndim", + str(k.get("ndim", 2)), + "--output", + str(output_dir), + ] + + if k.get("tile_m"): + cmd.extend(["--tile-m", str(k["tile_m"])]) + if k.get("tile_n"): + cmd.extend(["--tile-n", str(k["tile_n"])]) + if k.get("warp_m"): + cmd.extend(["--warp-m", str(k["warp_m"])]) + if k.get("warp_n"): + cmd.extend(["--warp-n", str(k["warp_n"])]) + if k.get("warp_k"): + cmd.extend(["--warp-k", str(k["warp_k"])]) + if k.get("warp_tile_m"): + cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) + if k.get("warp_tile_n"): + cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) + if k.get("warp_tile_k"): + cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) + if k.get("pipeline"): + cmd.extend(["--pipeline", k["pipeline"]]) + if k.get("scheduler"): + cmd.extend(["--scheduler", k["scheduler"]]) + if k.get("epilogue"): + cmd.extend(["--epilogue", k["epilogue"]]) + if k.get("vector_a"): + cmd.extend(["--vector-a", str(k["vector_a"])]) + if k.get("vector_b"): + cmd.extend(["--vector-b", str(k["vector_b"])]) + if k.get("vector_c"): + cmd.extend(["--vector-c", str(k["vector_c"])]) + if k.get("block_per_cu"): + cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) + if k.get("num_wave_groups"): + cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) + if k.get("num_groups_to_merge"): + cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) + if k.get("double_smem_buffer") is not None: + cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) + if k.get("tile_k"): + cmd.extend(["--tile-k", str(k["tile_k"])]) + + return (idx, cmd, str(codegen_dir)) + + +def _run_conv_codegen(args: Tuple) -> Tuple[int, bool, str]: + """Run unified_grouped_conv_codegen.py for a single kernel config (picklable for ProcessPoolExecutor).""" + idx, cmd, cwd = args + result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd) + if result.returncode != 0: + return (idx, False, result.stderr[:300]) + return (idx, True, "") + + +def generate_conv_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate Conv kernels for ALL declarations using unified codegen. + + Launches all codegen subprocesses in parallel via ProcessPoolExecutor + for significantly faster generation when multiple conv kernels are declared. 
+ """ + if not kernels: + return False + + work_items = [ + _build_conv_codegen_cmd(idx, k, codegen_dir, output_dir) + for idx, k in enumerate(kernels) + ] success_count = 0 + max_workers = min(len(work_items), os.cpu_count() or 4) - # Generate a kernel for EACH declaration - for idx, k in enumerate(kernels): - variant = variant_map.get(k.get("conv_type", "forward"), "forward") - - cmd = [ - sys.executable, - str(codegen_dir / "unified_conv_codegen.py"), - "--datatype", - k.get("dtype", "fp16"), - "--variant", - variant, - "--ndim", - str(k.get("ndim", 2)), - "--output", - str(output_dir), - ] - - # Add optional parameters if specified - if k.get("tile_m"): - cmd.extend(["--tile-m", str(k["tile_m"])]) - if k.get("tile_n"): - cmd.extend(["--tile-n", str(k["tile_n"])]) - if k.get("warp_m"): - cmd.extend(["--warp-m", str(k["warp_m"])]) - if k.get("warp_n"): - cmd.extend(["--warp-n", str(k["warp_n"])]) - if k.get("warp_k"): - cmd.extend(["--warp-k", str(k["warp_k"])]) - if k.get("warp_tile_m"): - cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) - if k.get("warp_tile_n"): - cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) - if k.get("warp_tile_k"): - cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) - if k.get("pipeline"): - cmd.extend(["--pipeline", k["pipeline"]]) - if k.get("scheduler"): - cmd.extend(["--scheduler", k["scheduler"]]) - if k.get("epilogue"): - cmd.extend(["--epilogue", k["epilogue"]]) - if k.get("vector_a"): - cmd.extend(["--vector-a", str(k["vector_a"])]) - if k.get("vector_b"): - cmd.extend(["--vector-b", str(k["vector_b"])]) - if k.get("vector_c"): - cmd.extend(["--vector-c", str(k["vector_c"])]) - if k.get("block_per_cu"): - cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) - if k.get("num_wave_groups"): - cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) - if k.get("num_groups_to_merge"): - cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) - if k.get("double_smem_buffer") is not None: - cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) - if k.get("tile_k"): - cmd.extend(["--tile-k", str(k["tile_k"])]) - - result = subprocess.run( - cmd, capture_output=True, text=True, cwd=str(codegen_dir) - ) - if result.returncode != 0: - print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") - else: - success_count += 1 + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(_run_conv_codegen, w): w[0] for w in work_items} + for future in as_completed(futures): + idx, ok, err = future.result() + if ok: + success_count += 1 + else: + print(f" Codegen error for kernel {idx + 1}: {err}") return success_count > 0 +def _run_gemm_codegen(args: Tuple) -> Tuple[int, bool, str]: + """Run unified_gemm_codegen.py for a single kernel config (picklable for ProcessPoolExecutor).""" + idx, cmd, cwd = args + result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd) + if result.returncode != 0: + return (idx, False, result.stderr[:300]) + return (idx, True, "") + + def generate_gemm_kernels( kernels: List[Dict], output_dir: Path, codegen_dir: Path ) -> bool: - """Generate GEMM kernels for ALL declarations using unified codegen.""" + """Generate GEMM kernels for ALL declarations using unified codegen. + + Launches all codegen subprocesses in parallel via ProcessPoolExecutor + for significantly faster generation when multiple kernels are declared. 
+ """ import json if not kernels: return False - success_count = 0 - - # Generate a kernel for EACH declaration + # Build all commands upfront + work_items = [] for idx, k in enumerate(kernels): variant = "multi_d" if k.get("elementwise_op") else "standard" - # Build tile config JSON for this specific kernel tile_config = { "tile_m": [k.get("tile_m", 128)], "tile_n": [k.get("tile_n", 128)], @@ -1125,13 +1260,20 @@ def generate_gemm_kernels( config_json, ] - result = subprocess.run( - cmd, capture_output=True, text=True, cwd=str(codegen_dir) - ) - if result.returncode != 0: - print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") - else: - success_count += 1 + work_items.append((idx, cmd, str(codegen_dir))) + + # Run all codegen subprocesses in parallel + success_count = 0 + max_workers = min(len(work_items), os.cpu_count() or 4) + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(_run_gemm_codegen, w): w[0] for w in work_items} + for future in as_completed(futures): + idx, ok, err = future.result() + if ok: + success_count += 1 + else: + print(f" Codegen error for kernel {idx + 1}: {err}") return success_count > 0 @@ -1229,15 +1371,17 @@ def main(): if example_type == "gemm": kernel_headers = list(args.output_dir.glob("gemm_*.hpp")) else: - k = kernels[0] if kernels else {} - variant = k.get("conv_type", "forward") prefix_map = { - "forward": "conv_fwd", - "bwd_data": "conv_bwdd", - "bwd_weight": "conv_bwdw", + "forward": "grouped_conv_fwd", + "bwd_data": "grouped_conv_bwd_data", + "bwd_weight": "grouped_conv_bwd_weight", } - prefix = prefix_map.get(variant, "conv_fwd") - kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp")) + # Collect headers from ALL variants present in declarations + variants_used = set(k.get("conv_type", "forward") for k in kernels) + kernel_headers = [] + for variant in variants_used: + prefix = prefix_map.get(variant, "grouped_conv_fwd") + kernel_headers.extend(args.output_dir.glob(f"{prefix}_*.hpp")) if not kernel_headers: print(f"[{target_name}] No kernel headers generated!") @@ -1347,29 +1491,39 @@ def main(): ) if has_bwd_data: - bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_") - if bwdd_kernel: - bwdd_ns = f"ns_{bwdd_kernel.stem}" - launcher_aliases.append( - f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + bwd_data_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwd_data_" + ) + if not bwd_data_kernel: + bwd_data_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwdd_" ) - if not has_fwd: # If no fwd, use bwd_data as first + if bwd_data_kernel: + bwd_data_ns = f"ns_{bwd_data_kernel.stem}" + launcher_aliases.append( + f"using BwdDataKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;" + ) + if not has_fwd: launcher_aliases.append( - f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + f"using FirstKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;" ) if has_bwd_weight: - bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_") - if bwdw_kernel: - bwdw_ns = f"ns_{bwdw_kernel.stem}" - launcher_aliases.append( - f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + bwd_weight_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwd_weight_" + ) + if not bwd_weight_kernel: + bwd_weight_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwdw_" ) - if ( - not has_fwd and not has_bwd_data - ): # If no fwd or bwdd, 
use bwdw as first + if bwd_weight_kernel: + bwd_weight_ns = f"ns_{bwd_weight_kernel.stem}" + launcher_aliases.append( + f"using BwdWeightKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;" + ) + if not has_fwd and not has_bwd_data: launcher_aliases.append( - f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + f"using FirstKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;" ) launcher_section = "\n".join(launcher_aliases) @@ -1382,14 +1536,16 @@ def main(): #include "ck_tile/dispatcher/registry.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp" namespace generated {{ // Kernel launchers for direct use {launcher_section} -// Registration function -inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +// Registration function (takes GroupedConvRegistry for conv kernels) +inline void {func_name}(ck_tile::dispatcher::GroupedConvRegistry& registry, const std::string& arch) {{ {register_body} }} @@ -1439,7 +1595,7 @@ inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::stri """ header_path.write_text(header_content) - print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled") + print(f"[{target_name}] OK {len(obj_files)} kernels compiled") return 0 diff --git a/dispatcher/scripts/generate_conv_dispatch_header.py b/dispatcher/scripts/generate_conv_dispatch_header.py new file mode 100644 index 0000000000..55cc085ed9 --- /dev/null +++ b/dispatcher/scripts/generate_conv_dispatch_header.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""Generate the conv_python_dispatch.hpp header for the Python conv library. + +Reads the include_all headers to find available kernels and creates dispatch +aliases for 2D/3D x fwd/bwd_data/bwd_weight. 
+""" + +import argparse +import re +from pathlib import Path + + +def find_3d_launcher(include_all_path: Path, variant_prefix: str) -> str: + """Find first 3D launcher name from an include_all header.""" + text = include_all_path.read_text() + pattern = rf"(grouped_conv_{variant_prefix}_\w+_3d_\w+)\.hpp" + match = re.search(pattern, text) + if match: + return match.group(1) + "_Launcher" + return "" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--kernel-dir", required=True) + parser.add_argument("--output", required=True) + args = parser.parse_args() + + kdir = Path(args.kernel_dir) + + fwd_3d = find_3d_launcher(kdir / "include_all_grouped_conv_fwd_kernels.hpp", "fwd") + bwd_data_3d = find_3d_launcher( + kdir / "include_all_grouped_conv_bwd_data_kernels.hpp", "bwd_data" + ) + bwd_weight_3d = find_3d_launcher( + kdir / "include_all_grouped_conv_bwd_weight_kernels.hpp", "bwd_weight" + ) + + lines = [ + "// Auto-generated dispatch header for Python conv library", + "#pragma once", + "", + "// Forward kernels", + '#include "include_all_grouped_conv_fwd_kernels.hpp"', + "#define CONV_FWD_2D_AVAILABLE 1", + ] + if fwd_3d: + lines += [ + "#define CONV_FWD_3D_AVAILABLE 1", + f"using ConvFwd3dLauncher = {fwd_3d};", + ] + lines += [ + "", + "// Backward data kernels", + '#include "include_all_grouped_conv_bwd_data_kernels.hpp"', + "#define CONV_BWD_DATA_2D_AVAILABLE 1", + ] + if bwd_data_3d: + lines += [ + "#define CONV_BWD_DATA_3D_AVAILABLE 1", + f"using ConvBwdData3dLauncher = {bwd_data_3d};", + ] + lines += [ + "", + "// Backward weight kernels", + '#include "include_all_grouped_conv_bwd_weight_kernels.hpp"', + "#define CONV_BWD_WEIGHT_2D_AVAILABLE 1", + ] + if bwd_weight_3d: + lines += [ + "#define CONV_BWD_WEIGHT_3D_AVAILABLE 1", + f"using ConvBwdWeight3dLauncher = {bwd_weight_3d};", + ] + + # Kernel name table for Python introspection + names = [] + if True: # fwd 2D always present + names.append('"fwd_2d"') + if fwd_3d: + names.append('"fwd_3d"') + if True: # bwd_data 2D + names.append('"bwd_data_2d"') + if bwd_data_3d: + names.append('"bwd_data_3d"') + if True: # bwd_weight 2D + names.append('"bwd_weight_2d"') + if bwd_weight_3d: + names.append('"bwd_weight_3d"') + + lines += [ + "", + "// Kernel inventory for Python", + f"static const char* CONV_KERNEL_NAMES[] = {{{', '.join(names)}}};", + f"static const int CONV_KERNEL_COUNT = {len(names)};", + "", + ] + + Path(args.output).write_text("\n".join(lines) + "\n") + print(f"Generated dispatch header: {args.output} ({len(names)} kernels)") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/scripts/parallel_kernel_builder.py b/dispatcher/scripts/parallel_kernel_builder.py index 911ea61bd7..aef8f4ff0b 100755 --- a/dispatcher/scripts/parallel_kernel_builder.py +++ b/dispatcher/scripts/parallel_kernel_builder.py @@ -132,7 +132,7 @@ def main(): print(f"Linking failed: {result.stderr}") return 1 - print(f"✓ Built: {lib_path}") + print(f"OK Built: {lib_path}") return 0 diff --git a/dispatcher/scripts/stress_test_autocorrect.py b/dispatcher/scripts/stress_test_autocorrect.py index 13e92abffa..63b250071e 100644 --- a/dispatcher/scripts/stress_test_autocorrect.py +++ b/dispatcher/scripts/stress_test_autocorrect.py @@ -34,9 +34,9 @@ from compile_gemm_examples import ( # noqa: E402 validate_kernel_config, expand_declaration_with_arch_filter, ) -from compile_conv_examples import ( # noqa: E402 - validate_conv_kernel_config, - expand_conv_declaration_with_arch_filter, +from compile_grouped_conv_examples import ( # 
noqa: E402 + validate_grouped_conv_kernel_config as validate_conv_kernel_config, + expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter, ) @@ -316,7 +316,7 @@ def test_python_autocorrect(verbose=False): if was_modified: print(f" Modified: {len(corrections)} correction(s)") for c in corrections: - print(f" • {c}") + print(f" - {c}") except Exception as e: results["failed"] += 1 @@ -465,7 +465,7 @@ def run_stress_test(arch, num_samples, verbose): } expanded = expand_declaration_with_arch_filter(config, test_arch) - status = "✓" if expanded else "✗" + status = "OK" if expanded else "FAIL" expected = test_arch in test["expected_archs"] match = "OK" if (bool(expanded) == expected) else "MISMATCH" diff --git a/dispatcher/src/dispatcher.cpp b/dispatcher/src/dispatcher.cpp index fdb400921e..2cb589adf2 100644 --- a/dispatcher/src/dispatcher.cpp +++ b/dispatcher/src/dispatcher.cpp @@ -2,17 +2,18 @@ // SPDX-License-Identifier: MIT #include "ck_tile/dispatcher/dispatcher.hpp" -#include +#include "ck_tile/dispatcher/dispatcher_error.hpp" #include #include namespace ck_tile { namespace dispatcher { -Dispatcher::Dispatcher(Registry* registry) +Dispatcher::Dispatcher(Registry* registry, const std::string& gfx_arch) : registry_(registry ? registry : &Registry::instance()), heuristic_(nullptr), - strategy_(SelectionStrategy::FirstFit) + strategy_(SelectionStrategy::FirstFit), + gfx_arch_(gfx_arch) { } @@ -61,7 +62,7 @@ float Dispatcher::run_fused(const void* a_ptr, std::ostringstream oss; oss << "No suitable kernel found for problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K; - throw std::runtime_error(oss.str()); + throw NoKernelFound(oss.str()); } return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); @@ -78,7 +79,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id, auto kernel = registry_->lookup(kernel_id); if(!kernel) { - throw std::runtime_error("Kernel not found: " + kernel_id); + throw NoKernelFound("Kernel not found: " + kernel_id); } if(!kernel->supports(problem)) @@ -86,7 +87,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id, std::ostringstream oss; oss << "Kernel " << kernel_id << " does not support problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K; - throw std::runtime_error(oss.str()); + throw UnsupportedProblem(oss.str()); } return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); diff --git a/dispatcher/src/registry.cpp b/dispatcher/src/registry.cpp index 0d83afd613..f565885181 100644 --- a/dispatcher/src/registry.cpp +++ b/dispatcher/src/registry.cpp @@ -5,39 +5,32 @@ #include "ck_tile/dispatcher/json_export.hpp" #include "ck_tile/dispatcher/arch_filter.hpp" #include +#include +#include namespace ck_tile { namespace dispatcher { -Registry::Registry() - : name_("default"), - auto_export_enabled_(false), - auto_export_include_statistics_(true), - auto_export_on_every_registration_(true) -{ -} +Registry::Registry() = default; Registry::~Registry() { - // Perform auto-export on destruction if enabled (regardless of export_on_every_registration - // setting) if(auto_export_enabled_) { perform_auto_export(); } } -Registry::Registry(Registry&& other) noexcept - : mutex_() // mutex is not movable, create new one - , - kernels_(std::move(other.kernels_)), - name_(std::move(other.name_)), - auto_export_enabled_(other.auto_export_enabled_), - auto_export_filename_(std::move(other.auto_export_filename_)), - auto_export_include_statistics_(other.auto_export_include_statistics_), 
- auto_export_on_every_registration_(other.auto_export_on_every_registration_) +Registry::Registry(Registry&& other) noexcept : Base(std::move(other)) { - // Disable auto-export on the moved-from object to prevent double export + // Base move constructor already locked+released other.mutex_. + // Re-acquire to safely read the remaining fields. + std::lock_guard lock(other.mutex()); + auto_export_enabled_ = other.auto_export_enabled_; + auto_export_filename_ = std::move(other.auto_export_filename_); + auto_export_include_statistics_ = other.auto_export_include_statistics_; + auto_export_on_every_registration_ = other.auto_export_on_every_registration_; + other.auto_export_enabled_ = false; } @@ -45,11 +38,7 @@ Registry& Registry::operator=(Registry&& other) noexcept { if(this != &other) { - std::lock_guard lock(mutex_); - std::lock_guard other_lock(other.mutex_); - - kernels_ = std::move(other.kernels_); - name_ = std::move(other.name_); + Base::operator=(std::move(other)); auto_export_enabled_ = other.auto_export_enabled_; auto_export_filename_ = std::move(other.auto_export_filename_); auto_export_include_statistics_ = other.auto_export_include_statistics_; @@ -64,55 +53,27 @@ Registry& Registry::operator=(Registry&& other) noexcept bool Registry::register_kernel(KernelInstancePtr instance, Priority priority) { if(!instance) - { return false; - } - const std::string identifier = instance->get_key().encode_identifier(); - - bool registered = false; + if(Base::register_kernel(instance->get_name(), instance, priority)) { - std::lock_guard lock(mutex_); - - auto it = kernels_.find(identifier); - if(it != kernels_.end()) + if(auto_export_enabled_ && auto_export_on_every_registration_) { - // Kernel with this identifier already exists - // Only replace if new priority is higher - if(priority > it->second.priority) - { - it->second.instance = instance; - it->second.priority = priority; - registered = true; - } - } - else - { - // New kernel, insert it - kernels_[identifier] = RegistryEntry{instance, priority}; - registered = true; + perform_auto_export(); } + return true; } - - // Perform auto-export if enabled and configured to export on every registration - if(registered && auto_export_enabled_ && auto_export_on_every_registration_) - { - perform_auto_export(); - } - - return registered; + return false; } KernelInstancePtr Registry::lookup(const std::string& identifier) const { - std::lock_guard lock(mutex_); - - auto it = kernels_.find(identifier); - if(it != kernels_.end()) + std::lock_guard lock(mutex()); + auto it = entries().find(identifier); + if(it != entries().end()) { return it->second.instance; } - return nullptr; } @@ -121,75 +82,23 @@ KernelInstancePtr Registry::lookup(const KernelKey& key) const return lookup(key.encode_identifier()); } -std::vector Registry::get_all() const -{ - std::lock_guard lock(mutex_); - - std::vector result; - result.reserve(kernels_.size()); - - for(const auto& pair : kernels_) - { - result.push_back(pair.second.instance); - } - - return result; -} +std::vector Registry::get_all() const { return Base::get_all_instances(); } std::vector Registry::filter(std::function predicate) const { - std::lock_guard lock(mutex_); - + std::lock_guard lock(mutex()); std::vector result; - - for(const auto& pair : kernels_) + for(const auto& [name, entry] : entries()) { - if(predicate(*pair.second.instance)) + if(predicate(*(entry.instance))) { - result.push_back(pair.second.instance); + result.push_back(entry.instance); } } - return result; } -std::size_t Registry::size() 
const -{ - std::lock_guard lock(mutex_); - return kernels_.size(); -} - -bool Registry::empty() const -{ - std::lock_guard lock(mutex_); - return kernels_.empty(); -} - -void Registry::clear() -{ - std::lock_guard lock(mutex_); - kernels_.clear(); -} - -const std::string& Registry::get_name() const -{ - std::lock_guard lock(mutex_); - return name_; -} - -void Registry::set_name(const std::string& name) -{ - std::lock_guard lock(mutex_); - name_ = name; -} - -Registry& Registry::instance() -{ - static Registry global_registry; - return global_registry; -} - std::string Registry::export_json(bool include_statistics) const { return export_registry_json(*this, include_statistics); @@ -204,7 +113,7 @@ void Registry::enable_auto_export(const std::string& filename, bool include_statistics, bool export_on_every_registration) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); auto_export_enabled_ = true; auto_export_filename_ = filename; auto_export_include_statistics_ = include_statistics; @@ -213,13 +122,13 @@ void Registry::enable_auto_export(const std::string& filename, void Registry::disable_auto_export() { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); auto_export_enabled_ = false; } bool Registry::is_auto_export_enabled() const { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); return auto_export_enabled_; } @@ -230,7 +139,7 @@ void Registry::perform_auto_export() bool include_stats; { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); if(!auto_export_enabled_) { return; @@ -243,31 +152,15 @@ void Registry::perform_auto_export() export_json_to_file(filename, include_stats); } -std::size_t Registry::merge_from(const Registry& other, Priority priority) -{ - auto other_kernels = other.get_all(); - std::size_t merged_count = 0; - - for(const auto& kernel : other_kernels) - { - if(register_kernel(kernel, priority)) - { - merged_count++; - } - } - - return merged_count; -} - std::size_t Registry::filter_by_arch(const std::string& gpu_arch) { ArchFilter filter(gpu_arch); std::vector to_remove; { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); - for(const auto& pair : kernels_) + for(const auto& pair : entries()) { if(!filter.is_valid(pair.second.instance->get_key())) { @@ -277,12 +170,18 @@ std::size_t Registry::filter_by_arch(const std::string& gpu_arch) for(const auto& key : to_remove) { - kernels_.erase(key); + entries_mut().erase(key); } } return to_remove.size(); } +Registry& Registry::instance() +{ + static Registry global_registry; + return global_registry; +} + } // namespace dispatcher -} // namespace ck_tile +} // namespace ck_tile \ No newline at end of file diff --git a/dispatcher/tests/CMakeLists.txt b/dispatcher/tests/CMakeLists.txt index 6c20c18c95..a54feba284 100644 --- a/dispatcher/tests/CMakeLists.txt +++ b/dispatcher/tests/CMakeLists.txt @@ -217,6 +217,10 @@ endforeach() # Standalone integration tests (with their own main()) set(STANDALONE_TESTS test_minimal.cpp + test_grouped_conv_config.cpp + test_grouped_conv_problem.cpp + test_grouped_conv_kernel_decl.cpp + test_grouped_conv_registry.cpp ) foreach(test_source ${STANDALONE_TESTS}) diff --git a/dispatcher/tests/test_autocorrect.py b/dispatcher/tests/test_autocorrect.py index 0ec3ebda3c..3f52049f74 100644 --- a/dispatcher/tests/test_autocorrect.py +++ b/dispatcher/tests/test_autocorrect.py @@ -42,10 +42,10 @@ from compile_gemm_examples import ( # noqa: E402 expand_declaration_with_arch_filter, is_wildcard_declaration, ) -from compile_conv_examples 
import ( # noqa: E402 - validate_conv_kernel_config, - expand_conv_declaration_with_arch_filter, - is_conv_wildcard_declaration, +from compile_grouped_conv_examples import ( # noqa: E402 + validate_grouped_conv_kernel_config as validate_conv_kernel_config, + expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter, + is_grouped_conv_wildcard_declaration as is_conv_wildcard_declaration, ) from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402 diff --git a/dispatcher/tests/test_codegen_common.py b/dispatcher/tests/test_codegen_common.py new file mode 100644 index 0000000000..2efeaefb4d --- /dev/null +++ b/dispatcher/tests/test_codegen_common.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for codegen/codegen_common.py -- shared infrastructure for GEMM and grouped conv codegen. + +Phase 1a TDD: these tests are written BEFORE the implementation exists. +Run: python3 -m pytest tests/test_codegen_common.py -v +""" + +import sys +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from codegen_common import ( # noqa: E402 + TileConfig, + TraitConfigBase, + CommonTypeMappings, + generate_cpp_compilation_unit, + parallel_generate, + valid_wave_configs, + valid_warp_configs, + valid_trait_configs, + needs_wave_expansion, + needs_warp_expansion, + needs_pipeline_expansion, +) + + +class TestTileConfig(unittest.TestCase): + """TileConfig dataclass tests.""" + + def test_valid_config(self): + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertTrue(tc.is_valid()) + + def test_zero_tile_invalid(self): + tc = TileConfig(0, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertFalse(tc.is_valid()) + + def test_non_divisible_invalid(self): + tc = TileConfig(127, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertFalse(tc.is_valid()) + + def test_all_fields_accessible(self): + tc = TileConfig(256, 128, 64, 4, 1, 1, 32, 32, 16) + self.assertEqual(tc.tile_m, 256) + self.assertEqual(tc.tile_n, 128) + self.assertEqual(tc.tile_k, 64) + self.assertEqual(tc.warp_m, 4) + self.assertEqual(tc.warp_n, 1) + self.assertEqual(tc.warp_k, 1) + self.assertEqual(tc.warp_tile_m, 32) + self.assertEqual(tc.warp_tile_n, 32) + self.assertEqual(tc.warp_tile_k, 16) + + def test_small_valid_config(self): + tc = TileConfig(16, 16, 16, 1, 1, 1, 16, 16, 16) + self.assertTrue(tc.is_valid()) + + +class TestTraitConfigBase(unittest.TestCase): + """TraitConfigBase dataclass tests.""" + + def test_valid_intrawave(self): + tc = TraitConfigBase("compv3", "cshuffle", "intrawave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_invalid_interwave_compv3(self): + tc = TraitConfigBase("compv3", "cshuffle", "interwave", False, False, False) + self.assertFalse(tc.is_valid()) + + def test_invalid_interwave_compv4(self): + tc = TraitConfigBase("compv4", "cshuffle", "interwave", False, False, False) + self.assertFalse(tc.is_valid()) + + def test_valid_mem_interwave(self): + tc = TraitConfigBase("mem", "cshuffle", "interwave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_valid_mem_intrawave(self): + tc = TraitConfigBase("mem", "cshuffle", "intrawave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_padding_fields(self): + tc = TraitConfigBase("compv3", "cshuffle", "intrawave", True, True, True) + 
self.assertTrue(tc.pad_m) + self.assertTrue(tc.pad_n) + self.assertTrue(tc.pad_k) + + +class TestCommonTypeMappings(unittest.TestCase): + """CommonTypeMappings tests.""" + + def test_dtype_to_ck(self): + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp16"], "fp16_t") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["bf16"], "bf16_t") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp32"], "float") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp8"], "fp8_t") + + def test_pipeline_to_ck(self): + self.assertEqual( + CommonTypeMappings.PIPELINE_TO_CK["mem"], "GemmPipelineAgBgCrMem" + ) + self.assertIn("compv3", CommonTypeMappings.PIPELINE_TO_CK) + self.assertIn("compv4", CommonTypeMappings.PIPELINE_TO_CK) + + def test_pipeline_to_base(self): + self.assertIn("mem", CommonTypeMappings.PIPELINE_TO_BASE) + self.assertIn("compv3", CommonTypeMappings.PIPELINE_TO_BASE) + self.assertIn("compv4", CommonTypeMappings.PIPELINE_TO_BASE) + + def test_scheduler_to_ck(self): + self.assertIn("intrawave", CommonTypeMappings.SCHEDULER_TO_CK) + self.assertIn("interwave", CommonTypeMappings.SCHEDULER_TO_CK) + + def test_epilogue_to_dispatcher(self): + self.assertIn("cshuffle", CommonTypeMappings.EPILOGUE_TO_DISPATCHER) + self.assertIn("default", CommonTypeMappings.EPILOGUE_TO_DISPATCHER) + + def test_layout_to_ck(self): + self.assertIn("r", CommonTypeMappings.LAYOUT_TO_CK) + self.assertIn("c", CommonTypeMappings.LAYOUT_TO_CK) + + def test_get_output_dtype(self): + self.assertEqual(CommonTypeMappings.get_output_dtype("fp8"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("bf8"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("fp16"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("fp32"), "fp32") + + +class TestGenerateCppCompilationUnit(unittest.TestCase): + """Tests for generate_cpp_compilation_unit.""" + + def test_includes_kernel_header(self): + result = generate_cpp_compilation_unit("my_kernel") + self.assertIn('#include "my_kernel.hpp"', result) + + def test_contains_pragma_once_or_guard(self): + result = generate_cpp_compilation_unit("test_kernel") + self.assertIn("test_kernel", result) + + def test_different_names_different_output(self): + a = generate_cpp_compilation_unit("kernel_a") + b = generate_cpp_compilation_unit("kernel_b") + self.assertNotEqual(a, b) + + +class TestParallelGenerate(unittest.TestCase): + """Tests for parallel_generate helper.""" + + def _dummy_generate(self, item): + return f"generated_{item}" + + def test_parallel_returns_all(self): + items = ["a", "b", "c", "d"] + results = parallel_generate(self._dummy_generate, items, parallel=True) + self.assertEqual(len(results), 4) + for item in items: + self.assertIn(f"generated_{item}", results) + + def test_sequential_returns_all(self): + items = ["x", "y", "z"] + results = parallel_generate(self._dummy_generate, items, parallel=False) + self.assertEqual(len(results), 3) + for item in items: + self.assertIn(f"generated_{item}", results) + + def test_empty_items(self): + results = parallel_generate(self._dummy_generate, [], parallel=True) + self.assertEqual(len(results), 0) + + def test_logs_per_kernel_progress(self): + items = ["k1", "k2"] + with self.assertLogs(level="INFO") as cm: + parallel_generate(self._dummy_generate, items, parallel=False) + log_output = "\n".join(cm.output) + self.assertIn("k1", log_output) + self.assertIn("k2", log_output) + + +class TestArchAwareExpansion(unittest.TestCase): + """Tests for arch-aware expansion helpers (best-of-conv).""" + + def 
test_valid_wave_configs_gfx942(self): + configs = valid_wave_configs("gfx942") + self.assertIsInstance(configs, list) + self.assertIn([2, 2, 1], configs) + self.assertIn([1, 4, 1], configs) + + def test_valid_wave_configs_unknown_arch(self): + configs = valid_wave_configs("gfx_unknown") + self.assertIsInstance(configs, list) + self.assertGreater(len(configs), 0) + + def test_valid_warp_configs_gfx942_fp16(self): + configs = valid_warp_configs("gfx942", "fp16") + self.assertIsInstance(configs, list) + self.assertIn([32, 32, 16], configs) + + def test_valid_warp_configs_unknown_arch(self): + configs = valid_warp_configs("gfx_unknown", "fp16") + self.assertIsInstance(configs, list) + self.assertGreater(len(configs), 0) + + def test_valid_trait_configs_excludes_interwave_compute(self): + configs = valid_trait_configs() + self.assertIsInstance(configs, list) + self.assertNotIn(("compv3", "cshuffle", "interwave"), configs) + self.assertNotIn(("compv4", "cshuffle", "interwave"), configs) + + def test_valid_trait_configs_includes_mem_interwave(self): + configs = valid_trait_configs() + has_mem_interwave = any(p == "mem" and s == "interwave" for p, s in configs) + self.assertTrue(has_mem_interwave) + + def test_needs_wave_expansion_wildcard(self): + self.assertTrue(needs_wave_expansion({"wave_m": -1, "wave_n": 2})) + self.assertTrue(needs_wave_expansion({"wave_m": 2, "wave_n": -1})) + + def test_needs_wave_expansion_explicit(self): + self.assertFalse(needs_wave_expansion({"wave_m": 2, "wave_n": 2})) + + def test_needs_warp_expansion_wildcard(self): + self.assertTrue(needs_warp_expansion({"warp_m": -1, "warp_n": 32})) + + def test_needs_warp_expansion_explicit(self): + self.assertFalse(needs_warp_expansion({"warp_m": 32, "warp_n": 32})) + + def test_needs_pipeline_expansion_wildcard(self): + self.assertTrue(needs_pipeline_expansion({"pipeline": "*"})) + + def test_needs_pipeline_expansion_explicit(self): + self.assertFalse(needs_pipeline_expansion({"pipeline": "compv4"})) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_dispatcher_common.py b/dispatcher/tests/test_dispatcher_common.py new file mode 100644 index 0000000000..2c0fc8307c --- /dev/null +++ b/dispatcher/tests/test_dispatcher_common.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for python/dispatcher_common.py -- shared Python dispatcher utilities. + +Phase 1b TDD: tests written BEFORE implementation exists. 
+Run: python3 -m pytest tests/test_dispatcher_common.py -v +""" + +import io +import sys +import unittest +from pathlib import Path +from unittest.mock import patch + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ( # noqa: E402 + get_dispatcher_root, + get_ck_root, + get_build_dir, + get_generated_kernels_dir, + get_arch_filter_data, + ValidationResultBase, + validate_wave_config, + validate_warp_tile_config, + validate_trait_combo, + auto_correct_wave, + auto_correct_trait, + Colors, + print_phase, + print_success, + print_error, + print_info, + cleanup_generated_kernels, +) + + +class TestPathHelpers(unittest.TestCase): + """Tests for path helper functions.""" + + def test_dispatcher_root_contains_codegen(self): + root = get_dispatcher_root() + self.assertTrue((root / "codegen").exists()) + + def test_ck_root_contains_include_or_is_parent(self): + root = get_ck_root() + self.assertTrue(root.exists()) + self.assertEqual(root, get_dispatcher_root().parent) + + def test_build_dir_is_under_dispatcher(self): + build = get_build_dir() + self.assertEqual(build.parent, get_dispatcher_root()) + + def test_generated_kernels_dir_under_build(self): + gen_dir = get_generated_kernels_dir() + self.assertEqual(gen_dir.parent, get_build_dir()) + + +class TestGetArchFilterData(unittest.TestCase): + """Tests for get_arch_filter_data.""" + + def test_returns_dict(self): + data = get_arch_filter_data() + self.assertIsInstance(data, dict) + + def test_has_warp_combos(self): + data = get_arch_filter_data() + self.assertIn("warp_combos", data) + + def test_has_warp_tile_combos(self): + data = get_arch_filter_data() + self.assertIn("warp_tile_combos", data) + + def test_has_trait_unsupported(self): + data = get_arch_filter_data() + self.assertIn("trait_unsupported", data) + + def test_has_supported_archs(self): + data = get_arch_filter_data() + self.assertIn("supported_archs", data) + self.assertIn("gfx942", data["supported_archs"]) + + def test_gfx942_wave_configs(self): + data = get_arch_filter_data() + gfx942 = data["warp_combos"].get("gfx942", []) + self.assertIn([2, 2, 1], gfx942) + + +class TestValidationResultBase(unittest.TestCase): + """Tests for ValidationResultBase dataclass.""" + + def test_valid_result(self): + vr = ValidationResultBase(is_valid=True) + self.assertTrue(vr.is_valid) + self.assertEqual(vr.errors, []) + self.assertEqual(vr.warnings, []) + self.assertEqual(vr.suggested_fixes, {}) + + def test_invalid_result(self): + vr = ValidationResultBase( + is_valid=False, + errors=["bad wave"], + suggested_fixes={"wave_m": 2}, + ) + self.assertFalse(vr.is_valid) + self.assertEqual(len(vr.errors), 1) + self.assertIn("wave_m", vr.suggested_fixes) + + +class TestValidateWaveConfig(unittest.TestCase): + """Tests for validate_wave_config.""" + + def test_valid_wave(self): + is_valid, msg = validate_wave_config([2, 2, 1], "gfx942") + self.assertTrue(is_valid) + self.assertEqual(msg, "") + + def test_invalid_wave(self): + is_valid, msg = validate_wave_config([3, 3, 1], "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", msg.lower()) + + +class TestValidateWarpTileConfig(unittest.TestCase): + """Tests for validate_warp_tile_config.""" + + def test_valid_warp_tile(self): + is_valid, msg = validate_warp_tile_config([32, 32, 16], "gfx942", "fp16") + self.assertTrue(is_valid) + + def test_invalid_warp_tile(self): + is_valid, msg 
= validate_warp_tile_config([99, 99, 99], "gfx942", "fp16") + self.assertFalse(is_valid) + self.assertIn("warp", msg.lower()) + + +class TestValidateTraitCombo(unittest.TestCase): + """Tests for validate_trait_combo.""" + + def test_valid_trait(self): + is_valid, msg = validate_trait_combo("compv3", "cshuffle", "intrawave") + self.assertTrue(is_valid) + + def test_invalid_trait_interwave_compute(self): + is_valid, msg = validate_trait_combo("compv4", "cshuffle", "interwave") + self.assertFalse(is_valid) + + def test_valid_mem_interwave(self): + is_valid, msg = validate_trait_combo("mem", "cshuffle", "interwave") + self.assertTrue(is_valid) + + +class TestAutoCorrectWave(unittest.TestCase): + """Tests for auto_correct_wave.""" + + def test_corrects_invalid_wave(self): + corrected = auto_correct_wave([1, 1, 1], "gfx942") + self.assertIsInstance(corrected, list) + self.assertEqual(len(corrected), 3) + data = get_arch_filter_data() + valid_waves = data["warp_combos"].get("gfx942", [[2, 2, 1]]) + self.assertIn(corrected, valid_waves) + + +class TestAutoCorrectTrait(unittest.TestCase): + """Tests for auto_correct_trait.""" + + def test_corrects_invalid_scheduler(self): + corrected_pipeline, corrected_scheduler = auto_correct_trait( + "compv4", "interwave" + ) + self.assertEqual(corrected_scheduler, "intrawave") + + +class TestColors(unittest.TestCase): + """Tests for Colors class (cross-platform ANSI support from conv).""" + + def test_green_returns_string(self): + result = Colors.green("ok") + self.assertIn("ok", result) + + def test_red_returns_string(self): + result = Colors.red("error") + self.assertIn("error", result) + + def test_yellow_returns_string(self): + result = Colors.yellow("warn") + self.assertIn("warn", result) + + def test_bold_returns_string(self): + result = Colors.bold("title") + self.assertIn("title", result) + + def test_plain_mode_no_ansi(self): + with patch.object(Colors, "_use_color", return_value=False): + result = Colors.green("plain") + self.assertEqual(result, "plain") + + +class TestPhasedOutput(unittest.TestCase): + """Tests for phased output helpers.""" + + def test_print_phase(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_phase(1, "Setup") + self.assertIn("Setup", buf.getvalue()) + + def test_print_success(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_success("Done") + self.assertIn("Done", buf.getvalue()) + + def test_print_error(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_error("Oops") + self.assertIn("Oops", buf.getvalue()) + + def test_print_info(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_info("FYI") + self.assertIn("FYI", buf.getvalue()) + + +class TestCleanup(unittest.TestCase): + """Tests for cleanup_generated_kernels.""" + + def test_cleanup_nonexistent_dir_no_error(self): + cleanup_generated_kernels(Path("/tmp/nonexistent_ck_test_dir_12345")) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_examples_integration.py b/dispatcher/tests/test_examples_integration.py index cfd18a3305..d02ea69787 100644 --- a/dispatcher/tests/test_examples_integration.py +++ b/dispatcher/tests/test_examples_integration.py @@ -28,14 +28,18 @@ sys.path.insert(0, str(PYTHON_DIR)) def run_python_example( - example_path: Path, timeout: int = 120 + example_path: Path, timeout: int = 120, extra_args: list = None ) -> subprocess.CompletedProcess: """Run a Python example and capture output.""" env = os.environ.copy() env["PYTHONPATH"] = 
str(PYTHON_DIR) + cmd = [sys.executable, str(example_path)] + if extra_args: + cmd.extend(extra_args) + return subprocess.run( - [sys.executable, str(example_path)], + cmd, capture_output=True, text=True, timeout=timeout, @@ -111,61 +115,74 @@ class TestGemmPythonExamples(unittest.TestCase): result = run_python_example(example) self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - # Should pass validation self.assertIn("PASS", result.stdout.upper(), "Validation should pass") class TestConvPythonExamples(unittest.TestCase): - """Test Conv Python examples.""" + """Test grouped conv Python examples.""" @classmethod def setUpClass(cls): """Check if examples directory exists.""" - cls.conv_examples_dir = EXAMPLES_DIR / "conv" / "python" + cls.conv_examples_dir = EXAMPLES_DIR / "grouped_conv" / "python" if not cls.conv_examples_dir.exists(): - raise unittest.SkipTest("Conv Python examples not found") + raise unittest.SkipTest("Grouped conv Python examples not found") - def test_01_basic_conv(self): - """Test basic conv example.""" - example = self.conv_examples_dir / "01_basic_conv.py" + def test_01_basic_grouped_conv(self): + """Test basic grouped conv example.""" + example = self.conv_examples_dir / "01_basic_grouped_conv.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + self.assertIn("PASS", result.stdout.upper()) - def test_02_conv2d_fwd(self): - """Test 2D forward conv example.""" - example = self.conv_examples_dir / "02_conv2d_fwd.py" + def test_02_forward(self): + """Test forward conv example (2D + 3D).""" + example = self.conv_examples_dir / "02_forward.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) - def test_03_conv3d_fwd(self): - """Test 3D forward conv example.""" - example = self.conv_examples_dir / "03_conv3d_fwd.py" + def test_03_bwd_data(self): + """Test backward data example.""" + example = self.conv_examples_dir / "03_bwd_data.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) - def test_07_validation(self): - """Test validation example.""" - example = self.conv_examples_dir / "07_validation.py" + def test_04_bwd_weight(self): + """Test backward weight example.""" + example = self.conv_examples_dir / "04_bwd_weight.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + self.assertIn("PASS", result.stdout.upper()) + + def test_05_benchmark(self): + """Test benchmark example.""" + example = self.conv_examples_dir / "05_benchmark.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + result = run_python_example( + example, extra_args=["--warmup", "1", "--repeat", "1"] + ) + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) + + def test_06_registry_json(self): + """Test registry + heuristic 
+ JSON example.""" + example = self.conv_examples_dir / "06_registry_json.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + result = run_python_example(example) + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) class TestGemmCppExamples(unittest.TestCase): @@ -195,18 +212,18 @@ class TestGemmCppExamples(unittest.TestCase): self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - def test_gemm_04_validation(self): - """Test validation GEMM C++ example.""" - result = run_cpp_example("gemm_04_validation") + def test_gemm_03_benchmark_validation(self): + """Test benchmark+validation GEMM C++ example.""" + result = run_cpp_example("gemm_03_benchmark_validation") if result is None: - self.skipTest("gemm_04_validation not built") + self.skipTest("gemm_03_benchmark_validation not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") self.assertIn("PASS", result.stdout.upper(), "Validation should pass") class TestConvCppExamples(unittest.TestCase): - """Test Conv C++ examples.""" + """Test grouped conv C++ examples.""" @classmethod def setUpClass(cls): @@ -215,23 +232,29 @@ class TestConvCppExamples(unittest.TestCase): if not cls.examples_dir.exists(): raise unittest.SkipTest("C++ examples not built") - def test_conv_01_forward(self): - """Test forward conv C++ example.""" - result = run_cpp_example("conv_01_forward") + def test_grouped_conv_01_basic(self): + """Test basic grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_01_basic") if result is None: - self.skipTest("conv_01_forward not built") - + self.skipTest("grouped_conv_01_basic not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + self.assertIn("PASS", result.stdout.upper()) - def test_conv_02_validation(self): - """Test validation conv C++ example.""" - result = run_cpp_example("conv_02_validation") + def test_grouped_conv_02_all_dirs(self): + """Test all directions grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_02_all_dirs") if result is None: - self.skipTest("conv_02_validation not built") - + self.skipTest("grouped_conv_02_all_dirs not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + self.assertIn("PASS", result.stdout.upper()) + + def test_grouped_conv_03_bench_val(self): + """Test benchmark+validation grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_03_bench_val") + if result is None: + self.skipTest("grouped_conv_03_bench_val not built") + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) class TestUtilityImports(unittest.TestCase): @@ -246,14 +269,18 @@ class TestUtilityImports(unittest.TestCase): except ImportError as e: self.fail(f"Failed to import ctypes_utils: {e}") - def test_import_conv_utils(self): - """Test importing conv_utils.""" + def test_import_grouped_conv_utils(self): + """Test importing grouped_conv_utils.""" try: - from conv_utils import ConvSignature, ConvAlgorithm, ConvProblem # noqa: F401 + from grouped_conv_utils import ( # noqa: F401 + GroupedConvValidationResult, + validate_grouped_conv_config, + GroupedConvDataType, + ) self.assertTrue(True) except ImportError as e: - self.fail(f"Failed to import conv_utils: {e}") + 
self.fail(f"Failed to import grouped_conv_utils: {e}") def test_kernel_config_creation(self): """Test creating a KernelConfig.""" @@ -272,22 +299,19 @@ class TestUtilityImports(unittest.TestCase): self.assertEqual(config.dtype_a, "fp16") self.assertEqual(config.layout_a, "row") - def test_conv_signature_creation(self): - """Test creating a ConvSignature.""" - from conv_utils import ConvSignature + def test_grouped_conv_default_config(self): + """Test creating a grouped conv default config.""" + from grouped_conv_utils import get_grouped_conv_default_config - sig = ConvSignature( - dtype_in="fp16", - dtype_wei="fp16", - dtype_out="fp16", - dtype_acc="fp32", - layout="nhwgc", - direction="forward", - num_dims=2, + config = get_grouped_conv_default_config( + variant="forward", + ndim_spatial=2, + arch="gfx942", ) - self.assertEqual(sig.dtype_in, "fp16") - self.assertEqual(sig.direction, "forward") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertEqual(d["variant"], "forward") + self.assertEqual(d["arch"], "gfx942") class TestAutoCorrection(unittest.TestCase): @@ -316,21 +340,22 @@ class TestAutoCorrection(unittest.TestCase): self.assertTrue(was_modified, "Config should be modified") self.assertGreater(len(corrections), 0, "Should have corrections") - def test_conv_auto_correct(self): - """Test Conv auto-correction.""" - from conv_utils import auto_correct_conv_config - - # Call with invalid wave config parameters - corrected, was_modified, corrections = auto_correct_conv_config( - wave_m=99, # Invalid - wave_n=99, # Invalid - wave_k=99, # Invalid - dtype="fp16", - arch="gfx942", + def test_grouped_conv_auto_correct(self): + """Test Grouped Conv auto-correction.""" + from grouped_conv_utils import ( + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, ) - self.assertTrue(was_modified, "Config should be modified") - self.assertGreater(len(corrections), 0, "Should have corrections") + config = get_grouped_conv_default_config() + d = config.to_dict() if hasattr(config, "to_dict") else config + d["tile_config"]["warp_m"] = [99] + d["tile_config"]["warp_n"] = [99] + + corrected, result = auto_correct_grouped_conv_config(d) + + self.assertIsInstance(corrected, dict) + self.assertIn("tile_config", corrected) if __name__ == "__main__": diff --git a/dispatcher/tests/test_grouped_conv_codegen.py b/dispatcher/tests/test_grouped_conv_codegen.py new file mode 100644 index 0000000000..acfa5abd8f --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_codegen.py @@ -0,0 +1,589 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +TDD tests for codegen/unified_grouped_conv_codegen.py -- grouped convolution code generator. + +These tests are written BEFORE the implementation exists. 
+Run: python3 -m pytest dispatcher/tests/test_grouped_conv_codegen.py -v +""" + +import sys +import unittest +from pathlib import Path +from unittest.mock import patch + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +from codegen_common import TileConfig, TraitConfigBase # noqa: E402 + +from unified_grouped_conv_codegen import ( # noqa: E402 + GroupedConvVariant, + GroupedConvLayout, + GroupedConvKernelConfig, + GroupedConvTypeMappings, + GroupedConvTraitConfig, + CKTileGroupedConvKernelGenerator, + GroupedConvDispatcherWrapperGenerator, + UnifiedGroupedConvCodegen, +) + + +# ============================================================================= +# TestGroupedConvVariant +# ============================================================================= + + +class TestGroupedConvVariant(unittest.TestCase): + """Test GroupedConvVariant enum values.""" + + def test_forward_value(self): + self.assertEqual(GroupedConvVariant.FORWARD.value, "forward") + + def test_backward_data_value(self): + self.assertEqual(GroupedConvVariant.BACKWARD_DATA.value, "bwd_data") + + def test_backward_weight_value(self): + self.assertEqual(GroupedConvVariant.BACKWARD_WEIGHT.value, "bwd_weight") + + def test_all_variants_exist(self): + self.assertIn(GroupedConvVariant.FORWARD, GroupedConvVariant) + self.assertIn(GroupedConvVariant.BACKWARD_DATA, GroupedConvVariant) + self.assertIn(GroupedConvVariant.BACKWARD_WEIGHT, GroupedConvVariant) + + +# ============================================================================= +# TestGroupedConvLayout +# ============================================================================= + + +class TestGroupedConvLayout(unittest.TestCase): + """Test GroupedConvLayout enum for 1D/2D/3D layouts.""" + + def test_nhwgc_value(self): + self.assertEqual(GroupedConvLayout.NHWGC.value, "NHWGC") + + def test_gkyxc_value(self): + self.assertEqual(GroupedConvLayout.GKYXC.value, "GKYXC") + + def test_nhwgk_value(self): + self.assertEqual(GroupedConvLayout.NHWGK.value, "NHWGK") + + def test_1d_layouts_exist(self): + """1D conv layouts (e.g., NWGC, GYXC, NWGK).""" + layouts_1d = [ + lay + for lay in GroupedConvLayout + if "W" in lay.value and "H" not in lay.value + ] + self.assertGreater(len(layouts_1d), 0) + + def test_2d_layouts_exist(self): + """2D conv layouts (e.g., NHWGC, GKYXC, NHWGK).""" + layouts_2d = [lay for lay in GroupedConvLayout if "HW" in lay.value] + self.assertGreater(len(layouts_2d), 0) + + def test_3d_layouts_exist(self): + """3D conv layouts (e.g., NDHWGC, GDKYXC).""" + layouts_3d = [ + lay for lay in GroupedConvLayout if "D" in lay.value or "DHW" in lay.value + ] + self.assertGreater(len(layouts_3d), 0) + + +# ============================================================================= +# TestGroupedConvKernelConfig +# ============================================================================= + + +class TestGroupedConvKernelConfig(unittest.TestCase): + """Test GroupedConvKernelConfig dataclass.""" + + def _make_tile(self): + return TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + + def _make_trait(self): + return GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + + def test_name_contains_grouped_conv_fwd(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + 
variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + name = config.name("fp16") + self.assertIn("grouped_conv_fwd", name) + + def test_name_backward_data_contains_bwd_data(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.BACKWARD_DATA, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + name = config.name("fp16") + self.assertIn("bwd_data", name) + + def test_is_valid_for_arch_supported(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + self.assertTrue(config.is_valid_for_arch("gfx942")) + + def test_is_valid_for_arch_unsupported(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + self.assertFalse(config.is_valid_for_arch("gfx600")) + + +# ============================================================================= +# TestGroupedConvTypeMappings +# ============================================================================= + + +class TestGroupedConvTypeMappings(unittest.TestCase): + """Test GroupedConvTypeMappings class.""" + + def test_dtype_to_ck_fp16(self): + self.assertEqual(GroupedConvTypeMappings.DTYPE_TO_CK["fp16"], "half_t") + + def test_dtype_to_ck_bf16(self): + self.assertIn("bf16", GroupedConvTypeMappings.DTYPE_TO_CK) + + def test_dtype_to_ck_fp32(self): + self.assertIn("fp32", GroupedConvTypeMappings.DTYPE_TO_CK) + + def test_get_layouts_2d_has_in_wei_out_keys(self): + layouts = GroupedConvTypeMappings.get_layouts(2) + self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + def test_get_layouts_2d_returns_dict(self): + layouts = GroupedConvTypeMappings.get_layouts(2) + self.assertIsInstance(layouts, dict) + + def test_get_layouts_1d(self): + layouts = GroupedConvTypeMappings.get_layouts(1) + self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + def test_get_layouts_3d(self): + layouts = GroupedConvTypeMappings.get_layouts(3) + self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + +# ============================================================================= +# TestCKTileGroupedConvKernelGenerator +# ============================================================================= + + +class TestCKTileGroupedConvKernelGenerator(unittest.TestCase): + """Test CKTileGroupedConvKernelGenerator.generate().""" + + def _make_config(self): + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + return GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + def test_generate_contains_pragma_once(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("#pragma once", result) + + def 
test_generate_contains_forward_kernel_include(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("grouped_convolution_forward_kernel.hpp", result) + + def test_generate_returns_non_empty_string(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIsInstance(result, str) + self.assertGreater(len(result), 100) + + def test_generate_valid_cpp_structure(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("#include", result) + self.assertIn("ck_tile", result) + + +# ============================================================================= +# TestGroupedConvDispatcherWrapperGenerator +# ============================================================================= + + +class TestGroupedConvDispatcherWrapperGenerator(unittest.TestCase): + """Test GroupedConvDispatcherWrapperGenerator.generate().""" + + def _make_config(self): + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + return GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + def test_generate_contains_dispatcher_registration(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("dispatcher", result) + self.assertIn("KernelKey", result) + self.assertIn("KernelInstancePtr", result) + + def test_generate_contains_pragma_once(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("#pragma once", result) + + def test_generate_valid_cpp(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("#include", result) + self.assertIn("namespace", result) + + +# ============================================================================= +# TestUnifiedGroupedConvCodegen +# ============================================================================= + + +class TestUnifiedGroupedConvCodegen(unittest.TestCase): + """Test UnifiedGroupedConvCodegen.generate_all().""" + + def test_generate_all_returns_dict_with_expected_keys(self): + output_dir = DISPATCHER_DIR / "build" / "generated" / "grouped_conv" + output_dir.mkdir(parents=True, exist_ok=True) + codegen = UnifiedGroupedConvCodegen( + output_dir=output_dir, + datatype="fp16", + ndim_spatial=2, + gpu_target="gfx942", + ) + with patch.object( + codegen, + "_get_configs", + return_value=[], # Mock empty config list for fast test + ): + results = codegen.generate_all(parallel=False) + self.assertIn("kernels", results) + self.assertIn("wrappers", results) + 
self.assertIn("failed", results) + self.assertIsInstance(results["kernels"], list) + self.assertIsInstance(results["wrappers"], list) + self.assertIsInstance(results["failed"], list) + + def test_generate_all_with_mock_config_produces_output(self): + output_dir = DISPATCHER_DIR / "build" / "generated" / "grouped_conv_test" + output_dir.mkdir(parents=True, exist_ok=True) + codegen = UnifiedGroupedConvCodegen( + output_dir=output_dir, + datatype="fp16", + ndim_spatial=2, + gpu_target="gfx942", + ) + # Use a real config - patch the config source to return one config + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + config = GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + with patch.object(codegen, "_get_configs", return_value=[config]): + results = codegen.generate_all(parallel=False) + self.assertIsInstance(results, dict) + self.assertIn("kernels", results) + + +# ============================================================================= +# TestSharedImports +# ============================================================================= + + +class TestSharedImports(unittest.TestCase): + """Verify TileConfig from codegen_common and GroupedConvTraitConfig extends TraitConfigBase.""" + + def test_tile_config_has_expected_fields(self): + """TileConfig from codegen_common has tile_m, tile_n, tile_k, etc.""" + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertEqual(tc.tile_m, 128) + self.assertEqual(tc.tile_n, 128) + self.assertEqual(tc.tile_k, 32) + self.assertEqual(tc.warp_m, 2) + self.assertEqual(tc.warp_n, 2) + self.assertEqual(tc.warp_k, 1) + self.assertEqual(tc.warp_tile_m, 32) + self.assertEqual(tc.warp_tile_n, 32) + self.assertEqual(tc.warp_tile_k, 16) + + def test_tile_config_is_from_codegen_common(self): + """TileConfig used by grouped conv is the same as codegen_common.TileConfig.""" + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertTrue(tc.is_valid()) + + def test_grouped_conv_trait_config_extends_trait_config_base(self): + """GroupedConvTraitConfig extends TraitConfigBase.""" + self.assertTrue(issubclass(GroupedConvTraitConfig, TraitConfigBase)) + + def test_grouped_conv_trait_config_has_double_smem_buffer(self): + """GroupedConvTraitConfig has double_smem_buffer field.""" + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=True, + num_groups_to_merge=2, + ) + self.assertTrue(trait.double_smem_buffer) + self.assertEqual(trait.num_groups_to_merge, 2) + + def test_grouped_conv_trait_config_has_num_groups_to_merge(self): + """GroupedConvTraitConfig has num_groups_to_merge field.""" + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=4, + ) + self.assertEqual(trait.num_groups_to_merge, 4) + + def test_grouped_conv_trait_config_inherits_base_fields(self): + """GroupedConvTraitConfig inherits pipeline, epilogue, scheduler from base.""" + trait = GroupedConvTraitConfig( + "compv4", + "cshuffle", + "intrawave", + True, + True, + True, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + self.assertEqual(trait.pipeline, "compv4") + self.assertEqual(trait.epilogue, "cshuffle") + 
self.assertEqual(trait.scheduler, "intrawave") + self.assertTrue(trait.pad_m) + self.assertTrue(trait.pad_n) + self.assertTrue(trait.pad_k) + + +# ============================================================================= +# TestTwoStageBwdWeightCodegen +# ============================================================================= + + +def _make_two_stage_config(): + """Helper: create a two-stage bwd_weight config.""" + return GroupedConvKernelConfig( + tile=TileConfig(16, 64, 64, 1, 4, 1, 16, 16, 32), + trait=GroupedConvTraitConfig( + pipeline="compv3", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + two_stage=True, + ), + variant=GroupedConvVariant.BACKWARD_WEIGHT, + ndim_spatial=2, + arch="gfx942", + ) + + +class TestTwoStageBwdWeightCodegen(unittest.TestCase): + """Tests for two-stage backward weight kernel generation.""" + + def test_kernel_name_contains_2stage(self): + config = _make_two_stage_config() + name = config.name("fp16") + self.assertIn("_2stage", name) + self.assertIn("bwd_weight", name) + + def test_single_stage_name_has_no_2stage(self): + config = _make_two_stage_config() + config.trait.two_stage = False + name = config.name("fp16") + self.assertNotIn("_2stage", name) + + def test_generate_contains_elementwise_include(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("elementwise.hpp", code) + + def test_generate_contains_workspace_type(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("WorkspaceDataType", code) + + def test_generate_contains_elementwise_kernel(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("ElementWiseKernel", code) + + def test_generate_contains_launch_kernel_time_mask(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("launch_kernel_time_mask", code) + + def test_generate_forces_vector_size_c_to_1(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("VectorSizeC_TwoStage = 1", code) + + def test_generate_contains_workspace_memset(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("hipMemsetAsync", code) + + def test_single_stage_does_not_contain_workspace(self): + config = _make_two_stage_config() + config.trait.two_stage = False + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertNotIn("WorkspaceDataType", code) + self.assertNotIn("ElementWiseKernel", code) + self.assertNotIn("launch_kernel_time_mask", code) + + def test_default_configs_include_two_stage(self): + from unified_grouped_conv_codegen import get_default_configs + + configs = get_default_configs( + arch="gfx942", + variants=[GroupedConvVariant.BACKWARD_WEIGHT], + ndims=[2], + ) + two_stage = [c for c in configs if c.trait.two_stage] + single_stage = [c for c in configs if not c.trait.two_stage] + 
self.assertGreater(len(two_stage), 0, "Should have two-stage configs") + self.assertGreater( + len(single_stage), 0, "Should still have single-stage configs" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_grouped_conv_config.cpp b/dispatcher/tests/test_grouped_conv_config.cpp new file mode 100644 index 0000000000..c9a1faeaf9 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_config.cpp @@ -0,0 +1,112 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvConfig using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include <cassert> +#include <iostream> + +using namespace ck_tile::dispatcher; + +void test_grouped_conv_direction_enum() +{ + std::cout << " test_grouped_conv_direction_enum... "; + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::FORWARD) == + std::string("fwd")); + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::BACKWARD_DATA) == + std::string("bwd_data")); + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::BACKWARD_WEIGHT) == + std::string("bwd_weight")); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_signature_info() +{ + std::cout << " test_grouped_conv_signature_info... "; + GroupedConvSignatureInfo sig; + assert(sig.spatial_dim == 2); + assert(sig.direction == GroupedConvDirection::FORWARD); + assert(sig.in_type == "fp16"); + assert(sig.wei_type == "fp16"); + assert(sig.out_type == "fp16"); + assert(sig.acc_type == "fp32"); + assert(sig.num_groups == 1); + sig.in_type = "bf16"; + sig.num_groups = 4; + assert(sig.in_type == "bf16"); + assert(sig.num_groups == 4); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_algorithm_info() +{ + std::cout << " test_grouped_conv_algorithm_info... "; + GroupedConvAlgorithmInfo algo; + assert(algo.tile.m == 128); + assert(algo.tile.n == 128); + assert(algo.tile.k == 64); + assert(algo.pipeline == PipelineVersion::V4); + assert(algo.scheduler == PipelineScheduler::INTRAWAVE); + assert(GroupedConvAlgorithmInfo::pipeline_str(PipelineVersion::V4) == std::string("compv4")); + assert(GroupedConvAlgorithmInfo::scheduler_str(PipelineScheduler::INTRAWAVE) == + std::string("intrawave")); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_config() +{ + std::cout << " test_grouped_conv_config... "; + GroupedConvConfig cfg; + std::string name = cfg.name(); + assert(!name.empty()); + assert(name.find("grouped_conv_") != std::string::npos); + assert(name.find("fwd") != std::string::npos); + assert(name.find("fp16") != std::string::npos); + assert(name.find("2d") != std::string::npos); + + std::string brief = cfg.brief(); + assert(!brief.empty()); + assert(brief.find("2D") != std::string::npos || brief.find("Grouped") != std::string::npos); + + std::string detailed = cfg.detailed(); + assert(!detailed.empty()); + assert(detailed.find("Signature:") != std::string::npos); + assert(detailed.find("Algorithm:") != std::string::npos); + assert(detailed.find("Arch:") != std::string::npos); + std::cout << "PASSED\n"; +} + +void test_predefined_grouped_conv_configs() +{ + std::cout << " test_predefined_grouped_conv_configs... 
"; + configs::Memory mem_cfg; + assert(mem_cfg.algorithm.pipeline == PipelineVersion::MEMORY); + assert(mem_cfg.algorithm.tile.m == 128); + assert(mem_cfg.algorithm.tile.n == 32); + + configs::CompV3_Small compv3_small; + assert(compv3_small.algorithm.pipeline == PipelineVersion::V3); + assert(compv3_small.algorithm.tile.m == 16); + assert(compv3_small.algorithm.tile.n == 64); + + configs::CompV4 compv4; + assert(compv4.algorithm.pipeline == PipelineVersion::V4); + assert(compv4.algorithm.double_smem_buffer == true); + + configs::WMMA wmma_cfg; + assert(wmma_cfg.arch.name == "gfx1100"); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Config ===\n\n"; + test_grouped_conv_direction_enum(); + test_grouped_conv_signature_info(); + test_grouped_conv_algorithm_info(); + test_grouped_conv_config(); + test_predefined_grouped_conv_configs(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_kernel_decl.cpp b/dispatcher/tests/test_grouped_conv_kernel_decl.cpp new file mode 100644 index 0000000000..7b28a451bc --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_kernel_decl.cpp @@ -0,0 +1,141 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvKernelDecl using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_decl; + +void test_grouped_conv_signature_builder() +{ + std::cout << " test_grouped_conv_signature_builder... "; + GroupedConvSignature sig; + sig.dtype("fp16").layout("nhwc").conv_type("forward").dims(2).groups(4); + assert(sig.dtype_in_ == "fp16"); + assert(sig.dtype_wei_ == "fp16"); + assert(sig.dtype_out_ == "fp16"); + assert(sig.layout_ == "nhwc"); + assert(sig.conv_op_ == "forward"); + assert(sig.num_dims_ == 2); + assert(sig.groups_ == 4); + assert(sig.op_str() == "fwd"); + sig.conv_type("bwd_data"); + assert(sig.op_str() == "bwd_data"); + sig.conv_type("bwd_weight"); + assert(sig.op_str() == "bwd_weight"); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_algorithm_builder() +{ + std::cout << " test_grouped_conv_algorithm_builder... "; + GroupedConvAlgorithm algo; + algo.tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"); + assert(algo.tile_m_ == 128); + assert(algo.tile_n_ == 128); + assert(algo.tile_k_ == 64); + assert(algo.wave_m_ == 2); + assert(algo.wave_n_ == 2); + assert(algo.warp_m_ == 32); + assert(algo.warp_n_ == 32); + assert(algo.pipeline_ == "compv4"); + assert(algo.scheduler_ == "intrawave"); + assert(!algo.needs_expansion()); + algo.wave_m_ = ANY_INT; + assert(algo.needs_wave_expansion()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_decl() +{ + std::cout << " test_grouped_conv_kernel_decl... 
"; + GroupedConvSignature sig; + sig.dtype("fp16").layout("nhwc").conv_type("forward").dims(2); + GroupedConvAlgorithm algo; + algo.tile(128, 128, 64).wave(2, 2, 1).warp(32, 32, 16); + GroupedConvKernelDecl decl(sig, algo, "gfx942"); + std::string name = decl.name(); + assert(!name.empty()); + assert(name.find("grouped_conv_") != std::string::npos); + assert(name.find("fwd") != std::string::npos); + assert(name.find("fp16") != std::string::npos); + assert(name.find("128x128x64") != std::string::npos); + assert(!decl.has_wildcards()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set() +{ + std::cout << " test_grouped_conv_kernel_set... "; + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + assert(set.size() == 1); + set.add("fp16", "nhwc", "forward", 256, 256); + assert(set.size() == 2); + const auto& decls = set.declarations(); + assert(decls[0].algorithm.tile_n_ == 128); + assert(decls[0].algorithm.tile_k_ == 128); + assert(decls[1].algorithm.tile_n_ == 256); + assert(decls[1].algorithm.tile_k_ == 256); + set.tag("test_set"); + assert(set.tag() == "test_set"); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set_merge() +{ + std::cout << " test_grouped_conv_kernel_set_merge... "; + GroupedConvKernelSet set1; + set1.add("fp16", "nhwc", "forward", 128, 128); + GroupedConvKernelSet set2; + set2.add("fp16", "nhwc", "forward", 256, 256); + set1.merge(set2); + assert(set1.size() == 2); + assert(set1.declarations()[0].algorithm.tile_n_ == 128); + assert(set1.declarations()[1].algorithm.tile_n_ == 256); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set_registry() +{ + std::cout << " test_grouped_conv_kernel_set_registry... "; + auto& reg = GroupedConvKernelSetRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set("gconv_test", set); + assert(reg.has("gconv_test")); + assert(reg.size() >= 1); + + const auto& retrieved = reg.get("gconv_test"); + assert(retrieved.size() == 1); + + const auto& empty = reg.get("nonexistent"); + assert(empty.size() == 0); + + reg.clear(); + assert(!reg.has("gconv_test")); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Kernel Decl ===\n\n"; + test_grouped_conv_signature_builder(); + test_grouped_conv_algorithm_builder(); + test_grouped_conv_kernel_decl(); + test_grouped_conv_kernel_set(); + test_grouped_conv_kernel_set_merge(); + test_grouped_conv_kernel_set_registry(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_problem.cpp b/dispatcher/tests/test_grouped_conv_problem.cpp new file mode 100644 index 0000000000..a6a4d8ba08 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_problem.cpp @@ -0,0 +1,245 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvProblem using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; + +void test_grouped_conv_problem_defaults() +{ + std::cout << " test_grouped_conv_problem_defaults... 
"; + GroupedConvProblem p; + assert(p.N == 1); + assert(p.C == 64); + assert(p.K == 64); + assert(p.G == 1); + assert(p.Hi() == 28); + assert(p.Wi() == 28); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.op == GroupedConvOp::Forward); + assert(p.stride[0] == 1 && p.stride[1] == 1 && p.stride[2] == 1); + assert(p.padding[0] == 0 && p.padding[1] == 0 && p.padding[2] == 0); + assert(p.dilation[0] == 1 && p.dilation[1] == 1 && p.dilation[2] == 1); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_2d() +{ + std::cout << " test_grouped_conv_problem_2d... "; + GroupedConvProblem p(4, 64, 128, 28, 28, 3, 3); + p.compute_output_size(); + assert(p.N == 4); + assert(p.C == 64); + assert(p.K == 128); + assert(p.Hi() == 28); + assert(p.Wi() == 28); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.Ho() == 26); + assert(p.Wo() == 26); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_strided() +{ + std::cout << " test_grouped_conv_problem_strided... "; + GroupedConvProblem p; + p.N = 1; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 2, 2}; + p.padding = {0, 1, 1}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.Ho() == 7); + assert(p.Wo() == 7); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_grouped() +{ + std::cout << " test_grouped_conv_problem_grouped... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 4; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.G == 4); + assert(p.C % p.G == 0); + assert(p.K % p.G == 0); + assert(p.is_valid()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_depthwise() +{ + std::cout << " test_grouped_conv_problem_depthwise... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 64; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.is_depthwise()); + assert(p.G == p.C && p.G == p.K); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_pointwise() +{ + std::cout << " test_grouped_conv_problem_pointwise... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 128; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 1, 1}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.is_pointwise()); + assert(p.Y() == 1 && p.X() == 1); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_flops() +{ + std::cout << " test_grouped_conv_problem_flops... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + double flops = p.get_flops(); + assert(flops > 0); + assert(flops == 2.0 * p.N * p.K * p.Ho() * p.Wo() * (p.C / p.G) * p.Y() * p.X()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_is_valid() +{ + std::cout << " test_grouped_conv_problem_is_valid... 
"; + GroupedConvProblem p; + p.N = 1; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.compute_output_size(); + assert(p.is_valid()); + + p.N = 0; + assert(!p.is_valid()); + p.N = 1; + + p.C = 0; + assert(!p.is_valid()); + p.C = 64; + + p.K = 0; + assert(!p.is_valid()); + p.K = 64; + + p.G = 0; + assert(!p.is_valid()); + p.G = 1; + + p.C = 64; + p.K = 64; + p.G = 3; + assert(!p.is_valid()); + p.G = 4; + assert(p.is_valid()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_builder() +{ + std::cout << " test_grouped_conv_problem_builder... "; + auto p = GroupedConvProblemBuilder() + .batch(8) + .channels(128, 256) + .groups(4) + .input_size(32, 32) + .filter_size(3, 3) + .stride(2, 2) + .padding(1, 1) + .dilation(1, 1) + .operation(GroupedConvOp::Forward) + .build(); + assert(p.N == 8); + assert(p.C == 128); + assert(p.K == 256); + assert(p.G == 4); + assert(p.Hi() == 32); + assert(p.Wi() == 32); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.stride[1] == 2 && p.stride[2] == 2); + assert(p.padding[1] == 1 && p.padding[2] == 1); + assert(p.op == GroupedConvOp::Forward); + assert(p.is_valid()); + + bool threw = false; + try + { + (void)GroupedConvProblemBuilder() + .batch(0) + .channels(64, 64) + .groups(1) + .input_size(14, 14) + .filter_size(3, 3) + .build(); + } + catch(const std::invalid_argument&) + { + threw = true; + } + assert(threw); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Problem ===\n\n"; + test_grouped_conv_problem_defaults(); + test_grouped_conv_problem_2d(); + test_grouped_conv_problem_strided(); + test_grouped_conv_problem_grouped(); + test_grouped_conv_problem_depthwise(); + test_grouped_conv_problem_pointwise(); + test_grouped_conv_problem_flops(); + test_grouped_conv_problem_is_valid(); + test_grouped_conv_problem_builder(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_registry.cpp b/dispatcher/tests/test_grouped_conv_registry.cpp new file mode 100644 index 0000000000..47d13a9997 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_registry.cpp @@ -0,0 +1,230 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvRegistry and GroupedConvDispatcher using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_decl; + +void test_grouped_conv_registry_basic() +{ + std::cout << " test_grouped_conv_registry_basic... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + reg.set_name("test_registry"); + assert(reg.name() == "test_registry"); + + assert(reg.size() == 0); + assert(reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_register_set() +{ + std::cout << " test_grouped_conv_registry_register_set... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + + bool ok = reg.register_set(set); + assert(ok); + assert(reg.size() == 2); + assert(!reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_all_kernels() +{ + std::cout << " test_grouped_conv_registry_all_kernels... 
"; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + auto all = reg.all_kernels(); + assert(all.size() == 1); + assert(all[0]->name().find("grouped_conv_") != std::string::npos); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_clear() +{ + std::cout << " test_grouped_conv_registry_clear... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + assert(reg.size() == 1); + + reg.clear(); + assert(reg.size() == 0); + assert(reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_thread_safe() +{ + std::cout << " test_grouped_conv_registry_thread_safe... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + const int num_threads = 4; + const int sets_per_thread = 10; + std::vector threads; + std::atomic success_count{0}; + + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([t, ®, &success_count]() { + for(int k = 0; k < sets_per_thread; k++) + { + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128 + t * 32 + k, 128); + if(reg.register_set(set)) + { + success_count++; + } + } + }); + } + + for(auto& th : threads) + th.join(); + + assert(reg.size() == num_threads * sets_per_thread); + assert(success_count.load() == num_threads * sets_per_thread); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_export_json() +{ + std::cout << " test_grouped_conv_registry_export_json... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + std::string json = reg.export_json(false); + assert(!json.empty()); + assert(json.find("\"kernels\"") != std::string::npos); + assert(json.find("\"metadata\"") != std::string::npos); + assert(json.find("grouped_conv_") != std::string::npos); + + std::string json_stats = reg.export_json(true); + assert(json_stats.find("\"statistics\"") != std::string::npos); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_filter() +{ + std::cout << " test_grouped_conv_registry_filter... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + set.add("bf16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + auto fp16_only = + reg.filter([](const GroupedConvKernelInstance& k) { return k.key().dtype_in == "fp16"; }); + assert(fp16_only.size() == 2); + + auto large_tile = reg.filter([](const GroupedConvKernelInstance& k) { + return k.key().tile_m >= 256 || k.key().tile_n >= 256; + }); + assert(large_tile.size() >= 1); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_dispatcher_basic() +{ + std::cout << " test_grouped_conv_dispatcher_basic... 
"; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + GroupedConvDispatcher dispatcher(®); + GroupedConvProblem problem = grouped_conv_utils::create_grouped_conv2d_problem( + 4, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + float time = dispatcher.run(problem, nullptr); + assert(time >= 0.0f); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_dispatcher_select() +{ + std::cout << " test_grouped_conv_dispatcher_select... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + reg.register_set(set); + + GroupedConvDispatcher dispatcher(®); + GroupedConvProblem problem = grouped_conv_utils::create_grouped_conv2d_problem( + 4, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + const auto* selected = dispatcher.select(problem); + assert(selected != nullptr); + assert(selected->name().find("grouped_conv_") != std::string::npos); + assert(selected->matches(problem)); + + reg.clear(); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Registry ===\n\n"; + test_grouped_conv_registry_basic(); + test_grouped_conv_registry_register_set(); + test_grouped_conv_registry_all_kernels(); + test_grouped_conv_registry_clear(); + test_grouped_conv_registry_thread_safe(); + test_grouped_conv_registry_export_json(); + test_grouped_conv_registry_filter(); + test_grouped_conv_dispatcher_basic(); + test_grouped_conv_dispatcher_select(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_utils.py b/dispatcher/tests/test_grouped_conv_utils.py new file mode 100644 index 0000000000..9d0638dc08 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_utils.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +TDD tests for python/grouped_conv_utils.py -- grouped convolution Python utilities. + +Phase 1 TDD: tests written BEFORE implementation exists. 
+Run: python3 -m pytest tests/test_grouped_conv_utils.py -v +""" + +import sys +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ValidationResultBase # noqa: E402 +from grouped_conv_utils import ( # noqa: E402 + GroupedConvValidationResult, + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, + GroupedConvDataType, + format_grouped_conv_summary, +) + + +# ============================================================================= +# VALID CONFIG FIXTURES +# ============================================================================= + + +def make_valid_grouped_conv_config(): + """Return a valid grouped conv config dict for gfx942.""" + return { + "tile_config": { + "tile_k": 128, + "tile_c": 128, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + }, + "trait_config": { + "pipeline": "compv4", + "epilogue": "cshuffle", + "scheduler": "intrawave", + }, + "variant": "2d_fwd", + "ndim_spatial": 2, + "arch": "gfx942", + "layout": "nhwgc", + "dtype": "fp16", + } + + +# ============================================================================= +# TestGroupedConvValidationResult +# ============================================================================= + + +class TestGroupedConvValidationResult(unittest.TestCase): + """Tests for GroupedConvValidationResult dataclass.""" + + def test_inherits_from_validation_result_base(self): + """GroupedConvValidationResult should inherit from ValidationResultBase.""" + self.assertTrue( + issubclass(GroupedConvValidationResult, ValidationResultBase), + "GroupedConvValidationResult must inherit from ValidationResultBase", + ) + + def test_valid_result_has_is_valid(self): + """Valid result has is_valid=True.""" + vr = GroupedConvValidationResult(is_valid=True) + self.assertTrue(vr.is_valid) + + def test_invalid_result_has_is_valid_false(self): + """Invalid result has is_valid=False.""" + vr = GroupedConvValidationResult(is_valid=False, errors=["bad config"]) + self.assertFalse(vr.is_valid) + + def test_has_errors_list(self): + """Result has errors list.""" + vr = GroupedConvValidationResult( + is_valid=False, + errors=["invalid wave", "invalid trait"], + ) + self.assertEqual(len(vr.errors), 2) + self.assertIn("invalid wave", vr.errors) + self.assertIn("invalid trait", vr.errors) + + def test_has_warnings_list(self): + """Result has warnings list.""" + vr = GroupedConvValidationResult( + is_valid=True, + warnings=["deprecated option"], + ) + self.assertEqual(len(vr.warnings), 1) + self.assertIn("deprecated option", vr.warnings) + + def test_has_suggested_fixes_dict(self): + """Result has suggested_fixes dict.""" + vr = GroupedConvValidationResult( + is_valid=False, + suggested_fixes={"wave_m": 2, "wave_n": 2}, + ) + self.assertIn("wave_m", vr.suggested_fixes) + self.assertEqual(vr.suggested_fixes["wave_m"], 2) + self.assertIn("wave_n", vr.suggested_fixes) + self.assertEqual(vr.suggested_fixes["wave_n"], 2) + + def test_default_empty_errors_warnings_fixes(self): + """Default result has empty errors, warnings, suggested_fixes.""" + vr = GroupedConvValidationResult(is_valid=True) + self.assertEqual(vr.errors, []) + self.assertEqual(vr.warnings, []) + self.assertEqual(vr.suggested_fixes, {}) + + +# 
============================================================================= +# TestValidateGroupedConvConfig +# ============================================================================= + + +class TestValidateGroupedConvConfig(unittest.TestCase): + """Tests for validate_grouped_conv_config.""" + + def test_valid_config_passes(self): + """Valid config should pass validation.""" + config = make_valid_grouped_conv_config() + result = validate_grouped_conv_config(config) + self.assertTrue(result.is_valid, f"Expected valid, got errors: {result.errors}") + self.assertEqual(result.errors, []) + + def test_invalid_wave_config_fails(self): + """Invalid wave config should fail validation.""" + config = make_valid_grouped_conv_config() + config["tile_config"]["wave_m"] = 3 + config["tile_config"]["wave_n"] = 3 + result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + error_str = " ".join(result.errors).lower() + self.assertIn("wave", error_str) + + def test_invalid_trait_fails(self): + """Invalid trait combination should fail validation.""" + config = make_valid_grouped_conv_config() + config["trait_config"]["pipeline"] = "compv4" + config["trait_config"]["epilogue"] = "cshuffle" + config["trait_config"]["scheduler"] = "interwave" # Invalid combo + result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + error_str = " ".join(result.errors).lower() + self.assertIn("trait", error_str) + + def test_missing_fields_fails(self): + """Config with missing required fields should fail validation.""" + config = {"arch": "gfx942"} # Missing tile_config, trait_config, etc. + result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + + +# ============================================================================= +# TestAutoCorrectGroupedConvConfig +# ============================================================================= + + +class TestAutoCorrectGroupedConvConfig(unittest.TestCase): + """Tests for auto_correct_grouped_conv_config.""" + + def test_invalid_wave_gets_corrected(self): + """Invalid wave config should be auto-corrected.""" + config = make_valid_grouped_conv_config() + config["tile_config"]["wave_m"] = 3 + config["tile_config"]["wave_n"] = 3 + corrected, result = auto_correct_grouped_conv_config(config) + self.assertIsInstance(corrected, dict) + self.assertIsInstance(result, GroupedConvValidationResult) + # Corrected wave should be valid for arch + wave_m = corrected.get("tile_config", {}).get("wave_m") + wave_n = corrected.get("tile_config", {}).get("wave_n") + self.assertIn(wave_m, [1, 2, 4]) + self.assertIn(wave_n, [1, 2, 4]) + + def test_invalid_trait_gets_corrected(self): + """Invalid trait combination should be auto-corrected.""" + config = make_valid_grouped_conv_config() + config["trait_config"]["scheduler"] = "interwave" + config["trait_config"]["pipeline"] = "compv4" + config["trait_config"]["epilogue"] = "cshuffle" + corrected, result = auto_correct_grouped_conv_config(config) + self.assertIsInstance(corrected, dict) + self.assertIsInstance(result, GroupedConvValidationResult) + # Scheduler should be corrected to intrawave for compv4+cshuffle + scheduler = corrected.get("trait_config", {}).get("scheduler") + self.assertEqual(scheduler, "intrawave") + + +# ============================================================================= +# TestGetGroupedConvDefaultConfig +# 
============================================================================= + + +class TestGetGroupedConvDefaultConfig(unittest.TestCase): + """Tests for get_grouped_conv_default_config.""" + + def test_returns_config(self): + """Should return a GroupedConvKernelConfig (or dict via to_dict).""" + config = get_grouped_conv_default_config("2d_fwd") + # Accepts both dataclass and dict + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIsInstance(d, dict) + + def test_has_tile_config(self): + """Returned config has tile_config key.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("tile_config", d) + self.assertIsInstance(d["tile_config"], dict) + + def test_has_trait_config(self): + """Returned config has trait_config key.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("trait_config", d) + self.assertIsInstance(d["trait_config"], dict) + + def test_has_variant(self): + """Returned config has variant.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("variant", d) + + def test_has_ndim_spatial(self): + """Returned config has ndim_spatial.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("ndim_spatial", d) + + def test_has_arch(self): + """Returned config has arch.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("arch", d) + + def test_has_layout(self): + """Returned config has layout.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("layout", d) + + +# ============================================================================= +# TestGroupedConvDataType +# ============================================================================= + + +class TestGroupedConvDataType(unittest.TestCase): + """Tests for GroupedConvDataType enum.""" + + def test_fp16_exists(self): + """GroupedConvDataType has FP16.""" + self.assertIsNotNone(GroupedConvDataType.FP16) + + def test_bf16_exists(self): + """GroupedConvDataType has BF16.""" + self.assertIsNotNone(GroupedConvDataType.BF16) + + def test_fp32_exists(self): + """GroupedConvDataType has FP32.""" + self.assertIsNotNone(GroupedConvDataType.FP32) + + def test_fp8_exists(self): + """GroupedConvDataType has FP8.""" + self.assertIsNotNone(GroupedConvDataType.FP8) + + def test_bf8_exists(self): + """GroupedConvDataType has BF8.""" + self.assertIsNotNone(GroupedConvDataType.BF8) + + def test_int8_exists(self): + """GroupedConvDataType has INT8.""" + self.assertIsNotNone(GroupedConvDataType.INT8) + + def test_enum_values_unique(self): + """All enum values should be unique.""" + values = [ + GroupedConvDataType.FP16, + GroupedConvDataType.BF16, + GroupedConvDataType.FP32, + GroupedConvDataType.FP8, + GroupedConvDataType.BF8, + GroupedConvDataType.INT8, + ] + self.assertEqual(len(values), len(set(values))) + + +# ============================================================================= +# TestFormatGroupedConvSummary +# ============================================================================= + + +class TestFormatGroupedConvSummary(unittest.TestCase): + """Tests for format_grouped_conv_summary.""" + + def 
test_returns_non_empty_string(self): + """Should return a non-empty string.""" + config = make_valid_grouped_conv_config() + summary = format_grouped_conv_summary(config) + self.assertIsInstance(summary, str) + self.assertGreater(len(summary), 0) + + def test_contains_key_info(self): + """Summary should contain key config info (variant, arch, layout, dtype).""" + config = make_valid_grouped_conv_config() + summary = format_grouped_conv_summary(config) + # Should mention at least some of: variant, arch, layout, dtype + summary_lower = summary.lower() + has_key_info = ( + "2d" in summary_lower + or "fwd" in summary_lower + or "gfx" in summary_lower + or "nhwgc" in summary_lower + or "fp16" in summary_lower + ) + self.assertTrue( + has_key_info, + f"Summary should contain key info, got: {summary}", + ) + + def test_empty_config_returns_something(self): + """Empty or minimal config should still return a string.""" + summary = format_grouped_conv_summary({}) + self.assertIsInstance(summary, str) + self.assertGreaterEqual(len(summary), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_problem_extended.cpp b/dispatcher/tests/test_problem_extended.cpp index 21ea545292..ba6068e3ee 100644 --- a/dispatcher/tests/test_problem_extended.cpp +++ b/dispatcher/tests/test_problem_extended.cpp @@ -19,7 +19,7 @@ class ProblemDimensionInferenceTest : public ::testing::Test TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) { - // A: M×K (1024×512), B: K×N (512×2048) + // A: MxK (1024x512), B: KxN (512x2048) auto problem = Problem::from_ab(1024, 512, 512, 2048); EXPECT_EQ(problem.M, 1024); @@ -30,7 +30,7 @@ TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) TEST_F(ProblemDimensionInferenceTest, FromDimensions_Valid) { - // A: 1024×512, B: 512×2048, C: 1024×2048 + // A: 1024x512, B: 512x2048, C: 1024x2048 auto problem = Problem::from_dimensions(1024, 512, 512, 2048, 1024, 2048); EXPECT_EQ(problem.M, 1024); @@ -55,7 +55,7 @@ TEST_F(ProblemDimensionInferenceTest, FromShapes_WithC) TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) { - // A stored as K×M (transposed) + // A stored as KxM (transposed) TensorShape A{512, 1024, true}; TensorShape B{512, 2048, false}; TensorShape C{1024, 2048, false}; @@ -70,7 +70,7 @@ TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedB) { TensorShape A{1024, 512, false}; - // B stored as N×K (transposed) + // B stored as NxK (transposed) TensorShape B{2048, 512, true}; TensorShape C{1024, 2048, false}; diff --git a/dispatcher/tests/test_real_kernel_multi_size.cpp b/dispatcher/tests/test_real_kernel_multi_size.cpp index f23f684631..79282da557 100644 --- a/dispatcher/tests/test_real_kernel_multi_size.cpp +++ b/dispatcher/tests/test_real_kernel_multi_size.cpp @@ -187,7 +187,7 @@ int main() for(const auto& r : results) { char size_str[32]; - snprintf(size_str, sizeof(size_str), "%4d×%4d×%4d", r.M, r.N, r.K); + snprintf(size_str, sizeof(size_str), "%4dx%4dx%4d", r.M, r.N, r.K); printf(" %-14s | %9.4f | %6.2f | %7.2f%% | %s\n", size_str, diff --git a/dispatcher/tests/test_real_kernel_performance.cpp b/dispatcher/tests/test_real_kernel_performance.cpp index ff3d635968..29c7c80ac3 100644 --- a/dispatcher/tests/test_real_kernel_performance.cpp +++ b/dispatcher/tests/test_real_kernel_performance.cpp @@ -144,7 +144,7 @@ int main() all_passed = all_passed && passed; char size_label[32]; - snprintf(size_label, sizeof(size_label), "%s %d³", label, M); + snprintf(size_label, 
sizeof(size_label), "%s %d^3", label, M); printf(" %-9s | %9.4f | %6.2f | %9.1f | %s\n", size_label, diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index b7226270fc..c4c69cb751 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -209,7 +209,7 @@ referencing==0.37.0 # via # jsonschema # jsonschema-specifications -requests==2.32.5 +requests==2.33.0 # via # pygithub # sphinx diff --git a/example/26_contraction/common_instances.hpp b/example/26_contraction/common_instances.hpp index 457bae21aa..808c548042 100644 --- a/example/26_contraction/common_instances.hpp +++ b/example/26_contraction/common_instances.hpp @@ -194,3 +194,35 @@ using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; // clang-format on + +// Macro to instantiate all four layout variants of DeviceOpInstance. +// +// BASE: Generic (for fp16/bf16/fp32) or FP64 (for fp64 — different tile sizes) +// SUFFIX: NN for bilinear (DsDataType = Tuple), +// N for scale (DsDataType = Tuple<>) +// +// Requires these names to be defined in the calling TU before invocation: +// NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, +// CShuffleDataType, DsDataType, EDataType, ComputeDataType, +// AElementOp, BElementOp, CDEElementOp +// +// Example: CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); +// expands to DeviceOpInstanceKKNN, DeviceOpInstanceKNNN, +// DeviceOpInstanceMKNN, DeviceOpInstanceMNNN, +// and sets DeviceOpInstance = DeviceOpInstanceKKNN. +// clang-format off +#define CK_CONTRACTION_DEVICE_OP_INSTANCES(BASE, SUFFIX) \ + using DeviceOpInstanceKK##SUFFIX = DeviceOpInstanceKK_##BASE; \ + using DeviceOpInstanceKN##SUFFIX = DeviceOpInstanceKN_##BASE; \ + using DeviceOpInstanceMK##SUFFIX = DeviceOpInstanceMK_##BASE; \ + using DeviceOpInstanceMN##SUFFIX = DeviceOpInstanceMN_##BASE; \ + using DeviceOpInstance = DeviceOpInstanceKK##SUFFIX +// clang-format on diff --git a/example/26_contraction/contraction_bilinear_xdl_bf16.cpp b/example/26_contraction/contraction_bilinear_xdl_bf16.cpp index 8899b54fbf..b5758ed428 100644 --- a/example/26_contraction/contraction_bilinear_xdl_bf16.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_bf16.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. 
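+// Rough sketch of what this expands to at this call site, mirroring the aliases it replaces above (illustrative only; the macro in common_instances.hpp is the authoritative definition): +//   using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; +//   using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; +//   using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; +//   using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; +//   using DeviceOpInstance = DeviceOpInstanceKKNN;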
+CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp index 2dac449e99..be03613bd1 100644 --- a/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_bilinear_xdl_fp16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp16.cpp index 16e33e0886..5d6d401836 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp16.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp16.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp index 494670bcca..ded63dec25 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. 
+CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp index e960199fc3..8779e1fab9 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp index 2963152eb1..467672986e 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp index 01966960cc..dff5a0446a 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. 
+CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp index 1ea9bcedfd..2d697f3e07 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(FP64, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp index 9e40e28485..341dad6d5b 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(FP64, NN); #include "run_contraction_bilinear_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_bf16.cpp b/example/26_contraction/contraction_scale_xdl_bf16.cpp index 586b022397..003bc0274a 100644 --- a/example/26_contraction/contraction_scale_xdl_bf16.cpp +++ b/example/26_contraction/contraction_scale_xdl_bf16.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. 
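+// Note: for the scale examples SUFFIX is N (DsDataType is the empty tuple), so the invocation
+// below reproduces the aliases this file previously spelled out by hand:
+//   using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic;   (and likewise KNN, MKN, MNN)
+//   using DeviceOpInstance    = DeviceOpInstanceKKN;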
+CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, N); #include "run_contraction_scale_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp index 9e4a02967a..bada39204e 100644 --- a/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, N); #include "run_contraction_scale_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_fp16.cpp b/example/26_contraction/contraction_scale_xdl_fp16.cpp index 1f29e16223..4f3adef47a 100644 --- a/example/26_contraction/contraction_scale_xdl_fp16.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp16.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, N); #include "run_contraction_scale_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp index 878011afd1..9be3b616f6 100644 --- a/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. 
+CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, N); #include "run_contraction_scale_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp index 5d8aa7b9c5..d7754ef546 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, N); #include "run_contraction_scale_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp index 57b1052a83..deaf7e7bdc 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, N); #include "run_contraction_scale_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp index ae23986bc9..de52096712 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. 
+CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, N); #include "run_contraction_scale_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp index 66f22ce63c..3d5d23968f 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(FP64, N); #include "run_contraction_scale_example.inc" diff --git a/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp index 2d72be8157..ee2533ca0a 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp @@ -22,63 +22,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. 
+CK_CONTRACTION_DEVICE_OP_INSTANCES(FP64, N); #include "run_contraction_scale_example.inc" diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp index b0b2d29d98..2ceca3c877 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -238,16 +238,6 @@ int main(int argc, char* argv[]) Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); -#if 0 - for(int n = 0; n < N; ++n) - { - for(int k = 0; k < K; ++k) - { - b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); - } - } -#endif - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); - // max_token_id.mData[0] = valid_size; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for(int i = 0; i < sorted_tile_num; i++) diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 552d3cd7b5..8ae97ef1c2 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -261,16 +261,6 @@ int main(int argc, char* argv[]) Tensor max_token_id(HostTensorDescriptor({1})); max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - // int eids[] = {0, 1, 3, 3, 3}; - // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; - // int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - // 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - // 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - // 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, - // 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, - // 7, 7, - // 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); diff --git a/example/66_complex_contraction_bilinear/common_instances.hpp b/example/66_complex_contraction_bilinear/common_instances.hpp index cb6157b29b..3ae168cb72 100644 --- a/example/66_complex_contraction_bilinear/common_instances.hpp +++ b/example/66_complex_contraction_bilinear/common_instances.hpp @@ -194,3 +194,35 @@ using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; // clang-format on + +// Macro to instantiate all four layout variants of DeviceOpInstance. 
+// +// BASE: Generic (for fp16/bf16/fp32) or FP64 (for fp64 — different tile sizes) +// SUFFIX: NN for bilinear (DsDataType = Tuple), +// N for scale (DsDataType = Tuple<>) +// +// Requires these names to be defined in the calling TU before invocation: +// NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, +// CShuffleDataType, DsDataType, EDataType, ComputeDataType, +// AElementOp, BElementOp, CDEElementOp +// +// Example: CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); +// expands to DeviceOpInstanceKKNN, DeviceOpInstanceKNNN, +// DeviceOpInstanceMKNN, DeviceOpInstanceMNNN, +// and sets DeviceOpInstance = DeviceOpInstanceKKNN. +// clang-format off +#define CK_CONTRACTION_DEVICE_OP_INSTANCES(BASE, SUFFIX) \ + using DeviceOpInstanceKK##SUFFIX = DeviceOpInstanceKK_##BASE; \ + using DeviceOpInstanceKN##SUFFIX = DeviceOpInstanceKN_##BASE; \ + using DeviceOpInstanceMK##SUFFIX = DeviceOpInstanceMK_##BASE; \ + using DeviceOpInstanceMN##SUFFIX = DeviceOpInstanceMN_##BASE; \ + using DeviceOpInstance = DeviceOpInstanceKK##SUFFIX +// clang-format on diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp index e2cae7a1f8..7533281f1a 100644 --- a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp +++ b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. +CK_CONTRACTION_DEVICE_OP_INSTANCES(Generic, NN); #include "run_complex_contraction_bilinear_example.inc" diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp index a2021b5eaa..a41e1f1785 100644 --- a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp +++ b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp @@ -23,63 +23,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKNN; +// Instantiate DeviceOpInstance for all four layout variants (KK, KN, MK, MN). +// See common_instances.hpp for macro definition and available BASE/SUFFIX options. 
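+// Note: BASE=FP64 selects the fp64-tuned instances from common_instances.hpp
+// (DeviceOpInstanceKK_FP64 and friends), which use different tile parameters than the
+// Generic fp16/bf16/fp32 configuration.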
+CK_CONTRACTION_DEVICE_OP_INSTANCES(FP64, NN); #include "run_complex_contraction_bilinear_example.inc" diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index e9ae11fb5f..79fe6492a6 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -139,6 +139,7 @@ LAYOUT_MAP = {"row": "true", "col": "false"} PIPELINE_MAP = { "qr": "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_hpad": "ck_tile::BlockFmhaPipelineQRKSVSHpad", "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", @@ -147,6 +148,7 @@ PIPELINE_MAP = { PIPELINE_ENUM_MAP = { "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_hpad": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_HPAD", "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index f172bb6ab6..35e8c1be49 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -635,6 +635,7 @@ class KernelComponentFactory: elif dtype in ["fp8bf16"]: return { 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + 256 : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip else: return None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6739abf621..7105f1aa5c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -533,6 +533,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::LSEDataType, /* BlockSize = M0 = */ {F_bm0}, {F_hdim}, {F_mode}, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1849068161..c64a19104e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -60,6 +60,22 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ +FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved.\n +// auto generated by generate.py +#if defined(__HIP_DEVICE_COMPILE__) && \ + (defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \ + defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)) +#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK) +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 +#endif +#endif +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + FMHA_FWD_KERNEL_BODY_TEMPLATE = """ #include @@ -206,22 +222,14 @@ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fw """ FMHA_FWD_API_FOOTER_TEMPLATE = """ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ - const std::string device_name = ck_tile::get_device_name(); - - const bool is_swa = (traits.mask_type != mask_enum::no_mask) and - ((0 < args.window_size_left) or (0 < args.window_size_right)); - const bool can_dispatch_v3 = - (device_name.compare(0, 6, "gfx950") == 0) and - (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and - traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and - (not traits.has_lse) and (not traits.has_dropout) and - (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and - (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128); - if ({F_is_v3_enabled} and can_dispatch_v3) {{ - return fmha_fwd_v3(traits, args, config); - }} else {{ - return fmha_fwd_v2(traits, args, config); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunreachable-code" + if ({F_is_v3_enabled}) {{ + float r = fmha_fwd_v3(traits, args, config); + if (r >= 0) return r; }} +#pragma clang diagnostic pop + return fmha_fwd_v2(traits, args, config); }} """ @@ -308,7 +316,7 @@ class FmhaFwdApiTrait: return "true" # always support else: return "true" - elif self.pipeline_tag in ["qr", "qs"]: + elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]: if self.spad == "t": return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: @@ -331,7 +339,7 @@ class FmhaFwdApiTrait: return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" else: return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" - elif self.pipeline_tag in ["qr", "qs"]: + elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]: if self.skpad == "t": return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) else: @@ -352,6 +360,11 @@ class FmhaFwdApiTrait: return f"a.hdim_q % {vec} == 0" else: assert False + elif self.pipeline_tag == "qr_hpad": + if self.dpad == "t": + return "a.hdim_q % 8 == 0" + else: + assert False elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": @@ -369,6 +382,11 @@ class FmhaFwdApiTrait: return f"a.hdim_v % {vec} == 0" else: assert False + elif self.pipeline_tag == "qr_hpad": + if self.dvpad == "t": + return "a.hdim_v % 8 == 0" + else: + assert False elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": @@ -642,6 +660,7 @@ class FmhaFwdKernel: F_pipeline: FmhaFwdPipeline _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER + _KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE @classmethod @@ -651,6 +670,12 @@ class FmhaFwdKernel: else: return "ck_tile::FmhaFwdKernel" + @classmethod + def _get_kernel_header(cls, pipeline_tag): + if pipeline_tag == "qr_hpad": + return cls._KERNEL_HEADER_QR_HPAD + return cls._KERNEL_HEADER + @classmethod def _get_cpp_kargs_creator_func_name(cls, pipeline_tag): if pipeline_tag == "qr_async_trload_v3": @@ -659,7 +684,9 @@ class FmhaFwdKernel: return "fmha_fwd_create_kargs_and_grids" def render(self) -> str: - return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( + return type(self)._get_kernel_header(self.F_pipeline.tag) + type( + self + )._KERNEL_BODY_TEMPLATE.format( F_kname=self.name, F_arch=self.F_arch, F_hdim=self.F_hdim, @@ -1059,10 +1086,11 @@ class KernelComponentFactoryGfx950( def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) if dtype in cls._DT_FP16_BF16: - # add tile for qr_async_trload_v3 - if (128, 128) in result.keys(): - result[(128, 128)].append( - FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + # # add tile for qr_async_trload_v3 (bf16/fp16 V3 not ready) + # if (128, 128) in result.keys(): + # result[(128, 128)].append( + # FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + pass elif dtype in cls._DT_MXFP8: return { # bm0, bn0, bk0, bn1, bk1, @@ -1075,6 +1103,10 @@ class KernelComponentFactoryGfx950( (128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)], } # fmt: skip + elif dtype in cls._DT_FP8BF16: + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip return result @classmethod @@ -1105,12 +1137,19 @@ class KernelComponentFactoryGfx950( pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip - # qr_async_trload_v3 only supports hdim=hdim_v=128 for now - if (hdim, hdim_v) == (128, 128): - # qr_async_trload_v3 only supports (generic) causal mask - for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): - 
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", - F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + # # qr_async_trload_v3 bf16/fp16 not ready + # if (hdim, hdim_v) == (128, 128): + # for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): + # pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", + # F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + elif dtype in cls._DT_FP8BF16: + # qr_async_trload_v3 only supports (generic) causal mask + for logits, qscale, mask in itertools.product( + ["t", "f"], + ["no", "pertensor"], + ["no", "causal"], + ): + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4: # no need dropout kernels @@ -1140,6 +1179,37 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): def supported_dtypes(cls) -> Tuple[str]: return cls._DT_FP16_BF16 + @classmethod + def get_rules(cls) -> List[CompatibilityRule]: + rules = super().get_rules() + + # For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile: + # the exact-hdim variant (dpad=dvpad=f) is much slower here. + def check_d128_tile_pipeline( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if problem_ctx.dtype not in cls._DT_FP16_BF16: + return True + + if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128): + return True + + is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32 + pads_hdim = ( + kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t" + ) + exact_hdim = ( + kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f" + ) + + if is_64x32_tile: + return pads_hdim + + return exact_hdim + + rules.append(check_d128_tile_pipeline) + return rules + @classmethod def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: if dtype in cls._DT_FP16_BF16: @@ -1148,7 +1218,8 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), + (128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")), + FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)], (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)] @@ -1175,7 +1246,9 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): # Keep only ttff/tttt for gfx11: ffff path is often similar or worse # 
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + if receipt == 1: + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip return pipelines @@ -1209,7 +1282,8 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 8192")), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)], (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip @@ -1244,9 +1318,11 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): ["t", "f"], ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + if receipt == 1: + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( @@ -1303,7 +1379,23 @@ class Product: def get_product(receipt: int) -> Product: # Flash attention integration - if receipt in (2, 3): + if receipt == 2: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_sink == "f" + # FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled. 
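+            # (Clarifying note: F_logits == "f" excludes the logit soft-cap variants, and
+            #  F_lse == "t" keeps the LSE output that FlashAttention's backward pass reads back.)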
+ cond &= kernel_ctx.pipeline.F_logits == "f" + cond &= kernel_ctx.pipeline.F_lse == "t" + return cond + + return Product(name="Flash attention integration", rule=fit) + # Receipt 3 forward coverage used by CK library / smoke tests + elif receipt == 3: def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond = problem_ctx.dtype in ["fp16", "bf16"] @@ -1477,8 +1569,8 @@ def write_fwd_api( FMHA_FWD_API_FOOTER_TEMPLATE.format( F_is_v3_enabled=BOOL_MAP[ # NOTE: enable v3 pipelines when ready - # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) - False + 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + # False ] ), ] diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index e0ccde8a6b..c9bac50da1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -939,6 +939,8 @@ def get_fwd_splitkv_blobs( cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] + # FlashAttention splitkv paths use softcap-disabled kernels only. + cond &= pipeline.F_logits == "f" cond &= pipeline.F_squant == "f" cond &= pipeline.F_sink == "f" if not cond: @@ -1142,4 +1144,7 @@ def list_blobs( ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + "\n") + f.write( + (file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + + "\n" + ) diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index c1f3a4fce3..bec7da0a2f 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -87,6 +87,7 @@ auto create_args(int argc, char* argv[]) "0", "if set to 1 will use multi-buffer reduction strategy for dq, atomic operation " "will not be used") + .insert("sink_grad", "0", "if set to 1, compute and validate sink token gradient") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "fmha_bwd.json", "json file name to dump results"); @@ -122,6 +123,7 @@ auto run(const ck_tile::ArgParser& arg_parser) bool deterministic = arg_parser.get_bool("deterministic"); std::string init_method = arg_parser.get_str("init"); uint32_t seed = arg_parser.get_uint32("seed"); + bool sink_grad = arg_parser.get_bool("sink_grad"); ck_tile::stream_config stream_config{nullptr, true, @@ -154,6 +156,7 @@ auto run(const ck_tile::ArgParser& arg_parser) drop_offset, drop_prefs, mask_str, + sink_grad, deterministic, init_method, seed, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 8eb8834e12..4496a6c9dd 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -116,6 +116,9 @@ struct fmha_bwd_args void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; + const void* + sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink + void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient // Usage notes for sequence length pointer parameters: // @@ -362,11 +365,15 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, args.do_ptr, args.d_ptr, + args.lse_ptr, + args.sink_ptr, + args.d_sink_ptr, args.p_undrop, args.seqstart_q_ptr, args.seqlen_q_ptr, 
args.cu_seqlen_q_ptr, args.hdim_v, + args.nhead_q, args.stride_do, args.stride_o, args.nhead_stride_do, @@ -378,9 +385,13 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, args.do_ptr, args.d_ptr, + args.lse_ptr, + args.sink_ptr, + args.d_sink_ptr, args.p_undrop, args.seqlen_q, args.hdim_v, + args.nhead_q, args.stride_do, args.stride_o, args.nhead_stride_do, diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 3123e4f2a8..361bda20eb 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -77,6 +77,7 @@ bwd_result fmha_bwd_run(mode_enum mode, uint64_t drop_offset, bool drop_prefs, std::string mask_str, + bool sink_grad, // if true, compute and validate sink gradient bool deterministic, std::string init_method, uint32_t seed, @@ -284,6 +285,16 @@ bwd_result fmha_bwd_run(mode_enum mode, get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); ck_tile::HostTensor lse_host( std::array{shape_batch, nhead, shape_seqlen_q}); + ck_tile::HostTensor sink_host( + sink_grad ? std::array{shape_batch, nhead} + : std::array{1, 1} /* dummy when sink is disabled */); + if(sink_grad) + { + std::uniform_real_distribution sink_dist(30.0f, 60.0f); + sink_host.ForEach([&](auto& self, auto i) { + self(i) = static_cast(sink_dist(random_engine)); + }); + } ck_tile::HostTensor d_host( std::array{shape_batch, nhead, shape_seqlen_q}); ck_tile::HostTensor randval_host( @@ -301,6 +312,12 @@ bwd_result fmha_bwd_run(mode_enum mode, use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor d_sink_host(sink_grad ? std::array{nhead} + : std::array{0}); + if(sink_grad) + { + d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; }); + } ck_tile::HostTensor dq_acc_host( std::array{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q}); @@ -361,11 +378,13 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sink_buf(sink_grad ? sink_host.get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_sink_buf(sink_grad ? d_sink_host.get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); @@ -396,6 +415,11 @@ bwd_result fmha_bwd_run(mode_enum mode, drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); drop_offset_buf.ToDevice(drop_prefs ? 
&drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); + if(sink_grad) + { + sink_buf.ToDevice(sink_host.data()); + d_sink_buf.ToDevice(d_sink_host.data()); + } // clang-format off auto layout_str = [&](bool permute){ @@ -415,7 +439,8 @@ bwd_result fmha_bwd_run(mode_enum mode, << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop - << ", s_randval:" << s_randval << ", deterministic:" << deterministic + << (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval + << ", deterministic:" << deterministic << (deterministic ? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) + "MiB|" + std::to_string(nsplits) + "splits" @@ -479,7 +504,6 @@ bwd_result fmha_bwd_run(mode_enum mode, const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr; const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr; - return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), @@ -495,6 +519,8 @@ bwd_result fmha_bwd_run(mode_enum mode, dv_buf.GetDeviceBuffer(), dbias_buf.GetDeviceBuffer(), dq_acc_buf.GetDeviceBuffer(), + sink_buf.GetDeviceBuffer(), + d_sink_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), seqlen_q_ptr_dev, @@ -589,6 +615,7 @@ bwd_result fmha_bwd_run(mode_enum mode, std::vector> randval_host_refs; std::vector> p_hp_host_refs; std::vector> p_lp_host_refs; + std::vector> p_sink_host_refs; randval_buf.FromDevice(randval_host.data()); @@ -765,6 +792,46 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::reference_batched_softmax( s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); + // Incorporate sink token into the softmax distribution (reference computation). + // The sink acts as an extra key whose score is sink_host(wb, i_h) (in log-space), + // which is a per-head random value in [30, 60]. + // lse_new = log(exp(lse_old) + exp(sink)) + // P_new = P_old * exp(lse_old - lse_new) (rescaled token attention) + // P_sink = exp(sink - lse_new) (sink attention weight) + ck_tile::HostTensor p_sink_host_ref( + sink_grad ? std::array{nhead, real_seqlen_q} + : std::array{0, 0}); + if(sink_grad) + { + for(int i_h = 0; i_h < nhead; ++i_h) + { + AccDataType sink_val = sink_host(wb, i_h); + for(int i_q = 0; i_q < real_seqlen_q; ++i_q) + { + // Use numerically stable log-sum-exp: lse_new = log(exp(lse_old)+exp(sink)) + // = max(lse_old, sink) + log(1 + exp(min - max)) + // This handles lse_old = -inf (fully-masked rows) without producing NaN: + // if lse_old=-inf: max=sink, min=-inf, exp(-inf-sink)=0, lse_new=sink + // It also avoids exp(lse_old) overflow when lse_old is large. + // p_scale = exp(lse_old - lse_new) [fraction kept by regular tokens] + // p_sink = exp(sink - lse_new) [sink attention weight] + AccDataType lse_old = lse_host_ref(i_h, i_q); + AccDataType hi = lse_old > sink_val ? lse_old : sink_val; + AccDataType lo = lse_old > sink_val ? 
sink_val : lse_old; + AccDataType lse_new = + hi + ck_tile::log(AccDataType(1) + ck_tile::exp(lo - hi)); + AccDataType p_scale = ck_tile::exp(lse_old - lse_new); + + lse_host_ref(i_h, i_q) = lse_new; + + for(int i_k = 0; i_k < real_seqlen_k; ++i_k) + p_hp_host_ref(i_h, i_q, i_k) *= p_scale; + + p_sink_host_ref(i_h, i_q) = ck_tile::exp(sink_val - lse_new); + } + } + } + if(p_drop > 0) { p_dropped_hp_host_ref = p_hp_host_ref; @@ -823,6 +890,7 @@ bwd_result fmha_bwd_run(mode_enum mode, o_host_refs.push_back(o_host_ref); p_hp_host_refs.push_back(p_hp_host_ref); p_lp_host_refs.push_back(p_lp_host_ref); + p_sink_host_refs.push_back(p_sink_host_ref); if(p_drop > 0) { randval_host_refs.push_back(randval_host_ref); @@ -842,6 +910,8 @@ bwd_result fmha_bwd_run(mode_enum mode, o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); dbias_buf.SetZero(); + if(sink_grad) + d_sink_buf.SetZero(); if(launcher.needs_zero_dq_acc) dq_acc_buf.SetZero(); @@ -853,10 +923,19 @@ bwd_result fmha_bwd_run(mode_enum mode, dk_buf.FromDevice(dk_host.data()); dv_buf.FromDevice(dv_host.data()); dbias_buf.FromDevice(dbias_host.data()); + if(sink_grad) + d_sink_buf.FromDevice(d_sink_host.data()); // Track the index into reference vectors (may differ from wb if batches were skipped) ck_tile::index_t ref_idx = 0; + // validation sink accumulator: global over batch, shape [nhead] + ck_tile::HostTensor d_sink_host_ref( + sink_grad ? std::array{nhead} + : std::array{0}); + if(sink_grad) + d_sink_host_ref.ForEach([&](auto& self, auto i) { self(i) = 0; }); + for(ck_tile::index_t wb = 0; wb < batch; ++wb) { // When padding is enabled, use logical lengths instead of computing from padded @@ -932,6 +1011,30 @@ bwd_result fmha_bwd_run(mode_enum mode, ds_hp_host_ref.mDesc.get_lengths()[1], ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); + if(sink_grad) + { + // Reference: dSink[h] = -sum_q( P_sink[h,q] * D[h,q] ) + // where D[h,q] = sum_j(dO[h,q,j] * O[h,q,j]) * p_undrop + for(int i_h = 0; i_h < nhead; ++i_h) + { + AccDataType d_sink_head_acc = 0; + for(int i_q = 0; i_q < real_seqlen_q; ++i_q) + { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + do_dot_o += + ck_tile::type_convert(do_host_ref(i_h, i_q, o)) * + ck_tile::type_convert( + o_host_refs[ref_idx](i_h, i_q, o)) * + p_undrop; + } + d_sink_head_acc += -p_sink_host_refs[ref_idx](i_h, i_q) * do_dot_o; + } + d_sink_host_ref(i_h) += d_sink_head_acc; + } + } + if(use_dbias) { dbias_host_ref = ds_hp_host_ref.template CopyAsType(); @@ -1044,6 +1147,17 @@ bwd_result fmha_bwd_run(mode_enum mode, ref_idx++; } + if(pass && sink_grad) + { + auto [rtol, atol] = get_elimit(hdim_q, hdim_v); + bool dsink_pass = ck_tile::check_err(d_sink_host, + d_sink_host_ref, + std::string("Error: SinkGrad Incorrect results!"), + rtol, + atol); + pass &= dsink_pass; + } + std::cout << ", valid:" << (pass ? 
"y" : "n") << std::flush << std::endl; } diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 521f1e4738..7d7d01bd05 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -844,6 +844,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqstart_q_ptr, @@ -877,6 +880,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) return FmhaKernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, nullptr, // lse_ptr args.o_ptr, args.seqlen_q, diff --git a/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp b/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp index 9cd1fb9cdc..9dad951d41 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp @@ -8,13 +8,16 @@ #include #include #include -#include #include #include #include #include #include +#ifdef __linux__ +#include +#endif + #ifndef CK_TILE_FMHA_ENABLE_HEAD_GROUPING #define CK_TILE_FMHA_ENABLE_HEAD_GROUPING 1 #endif @@ -70,6 +73,8 @@ inline std::optional read_property_value(const std::string& filepath, return std::nullopt; } +#if defined(__linux__) + struct kfd_device_location { int domain = 0; @@ -176,6 +181,12 @@ inline size_t get_kfd_sysfs_llc_cache_bytes() return read_kfd_node_l3_bytes(*node); } +#else + +inline size_t get_kfd_sysfs_llc_cache_bytes() { return 0; } + +#endif + inline size_t get_default_llc_cache_bytes_for_arch(const std::string& arch); inline size_t resolve_llc_cache_bytes_uncached(const std::string& arch) diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index 456c3986fa..4fbde37cae 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -39,7 +39,6 @@ function print_log_header(){ #run verification tests time example/ck_tile/01_fmha/script/smoke_test_fwd.sh time example/ck_tile/01_fmha/script/smoke_test_bwd.sh -time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 81617ee16c..c246ccb98f 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -69,6 +69,28 @@ test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS done + +# sink gradient tests: same coverage as main tests but with -sink_grad=1 +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 64 128 256 ; do +for mode in 0 1 ; do +for bias in "n" "a" ; do +for p_drop in 0.0 0.2 ; do +test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=0 -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=0 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -sink_grad=1 +done +done +done +done +done +done + +# sink gradient additional cases: non-standard hdim +for hdim in 40 48 72 96 ; do +test_h_s_mask 
-prec=fp16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS -sink_grad=1 +test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1 +test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1 +done set +x new_fails_count=0 diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 227f26c8f3..1e9942a6e1 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -235,6 +235,64 @@ run_padding_basic_boundary_tests() { done } +# Sink-specific mask pattern tests (sliding window + sink token). +run_sink_mask_tests() { + # window_size[2,0], sink_size=2 (top-left causal + sink) + # before: after: + # 1 * * * * * * * 1 * * * * * * * + # 1 1 * * * * * * 1 1 * * * * * * + # 1 1 1 * * * * * 1 1 1 * * * * * + # * 1 1 1 * * * * 1 1 1 1 * * * * + # * * 1 1 1 * * * 1 1 1 1 1 * * * + # * * * 1 1 1 * * 1 1 * 1 1 1 * * + # * * * * 1 1 1 * 1 1 * * 1 1 1 * + # * * * * * 1 1 1 1 1 * * * 1 1 1 + run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2 + run_exe -prec=bf16 -mode=0 -b=2 -h=2 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2 + + # window_size[0,3], sink_size=2 (top-left + sink) + # before: after: + # 1 1 1 1 * * * * 1 1 1 1 * * * * + # * 1 1 1 1 * * * 1 1 1 1 1 * * * + # * * 1 1 1 1 * * 1 1 1 1 1 1 * * + # * * * 1 1 1 1 * 1 1 * 1 1 1 1 * + # * * * * 1 1 1 1 1 1 * * 1 1 1 1 + run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2 + run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2 + + # window_size[1,0], sink_size=2 (bottom-right + sink) + # before: after: + # * * 1 1 * * * * 1 1 1 1 * * * * + # * * * 1 1 * * * 1 1 * 1 1 * * * + # * * * * 1 1 * * 1 1 * * 1 1 * * + # * * * * * 1 1 * 1 1 * * * 1 1 * + # * * * * * * 1 1 1 1 * * * * 1 1 + run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2 + run_exe -prec=bf16 -mode=0 -b=2 -h=4 -d=128 -d_v=128 -s=2048 -s_k=2048 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2 + + # window_size[2,0], sink_size=2 (bottom-right, group mode + sink) + run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2 + run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2 + + # window_size[-1,1], 
sink_size=2 (bottom-right, large seqlen + sink) + run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2 + run_exe -prec=bf16 -mode=1 -b=1 -h=2 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2 +} + +# init_sink tests: validate sink token initialization across prec/hdim/mode. +run_sink_init_tests() { + for prec in "fp16" "bf16" ; do + for hdim in 64 128 256 ; do + for mode in 0 1 ; do + for mask in 0 1 ; do + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -d=$hdim -d_v=$hdim -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask + run_exe -prec=$prec -mode=$mode -b=2 -h=4 -d=$hdim -d_v=$hdim -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask + done + done + done + done +} + set -x run_fp16_bf16_tests @@ -242,6 +300,8 @@ run_padding_smoke_tests run_padding_basic_boundary_tests run_fp8bf16_tests run_fp8fp32_tests +run_sink_mask_tests +run_sink_init_tests if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh deleted file mode 100755 index 5c9d3132b3..0000000000 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ /dev/null @@ -1,93 +0,0 @@ -#!/bin/bash -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -# TODO: run this script from CK root or build directory -#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd" -set -euo pipefail - -SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) -EXE_NAME=tile_example_fmha_fwd -EXE="$(find . 
-name $EXE_NAME -type f | head -n 1)" -KNAME=1 -GPU_arch=$GPU_arch -if [ -z "$GPU_arch" ] ; then - GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') -fi -set -x - -COMMON_ARGS='-v=1 -warmup=0 -repeat=1' - - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2 - -# window_size[2,0], sink_size = 2 - -# x=1/y=3 -# 1 * * * * * * * 1 * * * * * * * -# 1 1 * * * * * * 1 1 * * * * * * -# 1 1 1 * * * * * ----> 1 1 1 * * * * * -# * 1 1 1 * * * * 1 1 1 1 * * * * -# * * 1 1 1 * * * 1 1 1 1 1 * * * -# * * * 1 1 1 * * 1 1 * 1 1 1 * * -# * * * * 1 1 1 * 1 1 * * 1 1 1 * -# * * * * * 1 1 1 1 1 * * * 1 1 1 -# l=2/r=0(tl) l=2/r=0/s=2(tl) - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2 - -# x=4/y=1 -# 1 1 1 1 * * * * 1 1 1 1 * * * * -# * 1 1 1 1 * * * 1 1 1 1 1 * * * -# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * * -# * * * 1 1 1 1 * 1 1 * 1 1 1 1 * -# * * * * 1 1 1 1 1 1 * * 1 1 1 1 -# l=0/r=3(tl) l=0/r=3/s=2(tl) -# l=3/r=0(br) l=3/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2 - -# x=4/y=-1 -# * * 1 1 * * * * 1 1 1 1 * * * * -# * * * 1 1 * * * 1 1 * 1 1 * * * -# * * * * 1 1 * * ----> 1 1 * * 1 1 * * -# * * * * * 1 1 * 1 1 * * * 1 1 * -# * * * * * * 1 1 1 1 * * * * 1 1 -# l=1/r=0(br) l=1/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2 - -# x=-1/y=5 - -# * * * * * * * * * * * * -# * * * * * * * * * * * * -# 1 * * * * * 1 * * * * * -# 1 1 * * * * 1 1 * * * * -# 1 1 1 * * * ----> 1 1 1 * * * -# * 1 1 1 * * 1 1 1 1 * * -# * * 1 1 1 * 1 1 1 1 1 * -# * * * 1 1 1 1 1 * 1 1 1 -# l=2/r=0(br) l=2/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2 -# x=-1/y=8 -# * * * * * * * * * * -# * * * * * * * * * * -# 1 * * * * ----> 1 * * * * -# 1 1 * * * 1 1 * * * -# 1 1 1 * * 1 1 1 * * -# 1 1 1 1 * 1 1 1 1 * -# 1 1 1 1 1 1 1 1 1 1 -# 1 1 1 1 1 1 1 1 1 1 -# l=2/r=0(br) l=2/r=0/s=2(br) - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0 - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 
diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index 39dd6357e5..4d13bca2a0 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -284,12 +284,9 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
         b_k_n.SetZero();
     }

-    if(!preshuffle && GemmConfig::UseStructuredSparsity)
+    if constexpr(!preshuffle && GemmConfig::UseStructuredSparsity)
     {
-        if constexpr(GemmConfig::UseStructuredSparsity)
-        {
-            ck_tile::AdjustToStructuredSparsity{}(a_m_k);
-        }
+        ck_tile::AdjustToStructuredSparsity{}(a_m_k);
     }

     ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
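A minimal sketch of why the single collapsed condition above is sufficient, assuming `preshuffle` and `GemmConfig::UseStructuredSparsity` are both compile-time constants; the types below are hypothetical stand-ins for the real CK types, not the library API. When the combined `if constexpr` condition is false inside a template, the branch body is discarded without being instantiated, so the old nested `if constexpr` added nothing.

```cpp
#include <iostream>

// Illustrative stand-ins; names here are hypothetical, not the CK types.
struct DenseConfig  { static constexpr bool UseStructuredSparsity = false; };
struct SparseConfig { static constexpr bool UseStructuredSparsity = true; };

struct Tensor { bool adjusted = false; };

// Stand-in for ck_tile::AdjustToStructuredSparsity.
struct AdjustToStructuredSparsityStub
{
    void operator()(Tensor& t) const { t.adjusted = true; }
};

template <typename GemmConfig, bool preshuffle>
void prepare_a(Tensor& a_m_k)
{
    // One compile-time condition; the discarded branch is never instantiated,
    // so a nested `if constexpr` on UseStructuredSparsity would be redundant.
    if constexpr(!preshuffle && GemmConfig::UseStructuredSparsity)
    {
        AdjustToStructuredSparsityStub{}(a_m_k);
    }
}

int main()
{
    Tensor a, b;
    prepare_a<SparseConfig, /*preshuffle=*/false>(a); // branch taken
    prepare_a<DenseConfig,  /*preshuffle=*/false>(b); // branch discarded
    std::cout << a.adjusted << ' ' << b.adjusted << '\n'; // prints: 1 0
}
```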
diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp
index 8f4813a47e..ca49114844 100644
--- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp
+++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp
@@ -5,14 +5,6 @@
 // clang-format off
 // rm rn tm tn vn pd x 3p
-#if 0
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-#endif
 template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
 template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp
index e357d7e3ac..f754d8e959 100644
--- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp
+++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp
@@ -5,14 +5,6 @@
 // clang-format off
 // rm rn tm tn vn pd x 3p
-#if 0
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-
-template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
-#endif
 template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
 template float add_rmsnorm2d_rdquant_fwd_>(const S&, A);
diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp
index 8a5e0c74a0..66f427247a 100644
--- a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp
+++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp
@@ -5,14 +5,6 @@
 // clang-format off
 // rm rn tm tn vn pd 2p
-#if 0
-template float smoothquant_>(const S&, A);
-template float smoothquant_>(const S&, A);
-template float smoothquant_>(const S&, A);
-template float smoothquant_>(const S&, A);
-
-template float smoothquant_>(const S&, A);
-#endif
 template float smoothquant_>(const S&, A);
 template float smoothquant_>(const S&, A);
diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp
index 9c08cf64f0..103f7281b0 100644
--- a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp
+++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp
@@ -5,14 +5,6 @@
 // clang-format off
 // rm rn tm tn vn pd 2p
-#if 0
-template float smoothquant_>(const S&, A);
-template float smoothquant_>(const S&, A);
-template float smoothquant_>(const S&, A);
-template float smoothquant_>(const S&, A);
-
-template float smoothquant_>(const S&, A);
-#endif
 template float smoothquant_>(const S&, A);
 template float smoothquant_>(const S&, A);
diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp
index 8c72b81dc1..56fcca3beb 100644
--- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp
+++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp
@@ -5,14 +5,6 @@
 // clang-format off
 // rm rn tm tn vn pd 2p
-#if 0
-template float moe_smoothquant_>(const S&, A);
-template float moe_smoothquant_>(const S&, A);
-template float moe_smoothquant_>(const S&, A);
-template float moe_smoothquant_>(const S&, A);
-
-template float moe_smoothquant_>(const S&, A);
-#endif
 template float moe_smoothquant_>(const S&, A);
 template float moe_smoothquant_>(const S&, A);
diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp
index 6d7a5e7c1f..2462cd218e 100644
--- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp
+++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp
@@ -5,14 +5,6 @@
 // clang-format off
 // rm rn tm tn vn pd 2p
-#if 0
-template float moe_smoothquant_>(const S&, A);
-template float moe_smoothquant_>(const S&, A);
-template float moe_smoothquant_>(const S&, A);
-template float moe_smoothquant_>(const S&, A);
-
-template float moe_smoothquant_>(const S&, A);
-#endif
 template float moe_smoothquant_>(const S&, A);
 template float moe_smoothquant_>(const S&, A);
diff --git a/example/ck_tile/20_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt
index 090aae482b..18e71c255d 100644
--- a/example/ck_tile/20_grouped_convolution/CMakeLists.txt
+++ b/example/ck_tile/20_grouped_convolution/CMakeLists.txt
@@ -17,6 +17,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
     add_executable(tile_example_grouped_conv_bwd_weight grouped_convolution_backward_weight.cpp)
     target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})

+    # StreamK requires cross-CU coherence (StreamKCoherency), CDNA only.
+    if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
+        add_executable(tile_example_grouped_conv_bwd_weight_streamk grouped_convolution_backward_weight_streamk.cpp)
+        target_compile_options(tile_example_grouped_conv_bwd_weight_streamk PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
+    endif()
+
     add_executable(tile_example_grouped_conv_bwd_weight_two_stage grouped_convolution_backward_weight_two_stage.cpp)
     target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp
index 8287d1171c..6abc002207 100644
--- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp
+++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp
@@ -17,7 +17,7 @@ template