Merge origin/develop into users/yiding12/fmha-bwd-workspace

This commit is contained in:
Ding, Yi
2026-04-07 05:28:49 -05:00
153 changed files with 15806 additions and 2387 deletions

25
.gitignore vendored
View File

@@ -114,3 +114,28 @@ experimental/grouped_convolution_tile_instances/instances/*
!experimental/grouped_convolution_tile_instances/instances/*.inc
!experimental/grouped_convolution_tile_instances/instances/*.hpp
experimental/grouped_convolution_tile_instances/*.inc
# Heuristics: benchmark data (never in git)
dispatcher/heuristics/data/
# Heuristics: experimental/training artifacts (exclude from git)
dispatcher/heuristics/models/**/oof_predictions.parquet
dispatcher/heuristics/models/**/cv_metrics_*.json
dispatcher/heuristics/models/**/eval_report.json
dispatcher/heuristics/models/**/feature_importances_*.json
dispatcher/heuristics/models/**/model_tflops_ihem.lgbm
dispatcher/heuristics/models/**/model_tflops_log.lgbm
dispatcher/heuristics/models/**/model_tflops_log_big.lgbm
# Heuristics: keep in git (production model files):
# models/{op}_{dtype}_{arch}/model_tflops.lgbm
# models/{op}_{dtype}_{arch}/model_latency.lgbm
# models/{op}_{dtype}_{arch}/model_bandwidth.lgbm
# models/{op}_{dtype}_{arch}/feature_spec.json
# models/{op}_{dtype}_{arch}/train_manifest.json
# Heuristics: logs and caches
dispatcher/heuristics/*.log
dispatcher/heuristics/__pycache__/
dispatcher/heuristics/tests/__pycache__/
dispatcher/heuristics/.pytest_cache/

View File

@@ -2,22 +2,33 @@
FROM ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=7.1.1
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
ARG TARBALL_URL=https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx90X-dcgpu-7.12.0a20260218.tar.gz
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
ENV DEBIAN_FRONTEND=noninteractive
ENV PATH=$PATH:/opt/rocm/bin
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib
ENV HIP_PLATFORM=amd
# Add rocm repository
RUN set -xe && \
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl
RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
apt update && \
apt install python3-setuptools python3-wheel -y && \
apt install rocm-dev -y
RUN if [ "$compiler_version" = "therock" ]; then \
rm -rf /opt/rocm && mkdir /opt/rocm && \
echo "Downloading ROCm tarball from $TARBALL_URL..." && \
wget -q -O /tmp/rocm.tar.gz "$TARBALL_URL" && \
echo "Extracting tarball to /opt/rocm..." && \
tar -xzf /tmp/rocm.tar.gz -C /opt/rocm --strip-components=1 ; \
else echo "using the release compiler" && \
wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
apt update && \
apt install python3-setuptools python3-wheel -y && \
apt install rocm-dev -y; \
fi
# Install SCCACHE
ENV SCCACHE_VERSION="0.14.0"
@@ -34,7 +45,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
build-essential \
cmake \
git \
hip-rocclr \
iputils-ping \
jq \
libelf-dev \
@@ -44,8 +54,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
net-tools \
pkg-config \
python3-full \
python3-pip \
redis \
rocm-llvm-dev \
sshpass \
stunnel \
software-properties-common \
@@ -88,26 +98,3 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
git clone -b master https://github.com/ROCm/rocm-cmake.git && \
cd rocm-cmake && mkdir build && cd build && \
cmake .. && cmake --build . && cmake --build . --target install
WORKDIR /
# Add alternative compilers, if necessary
ENV compiler_version=$compiler_version
ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'" && \
sh -c "echo compiler commit = '$compiler_commit'"
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 8 ; \
else echo "using the release compiler"; \
fi
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 8 ; \
else echo "using the release compiler"; \
fi

View File

@@ -9,7 +9,7 @@ ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'" && \
sh -c "echo compiler commit = '$compiler_commit'"
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && git log -1 && mkdir build && cd build && \
cmake -G Ninja \
@@ -43,7 +43,7 @@ RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-sta
else echo "using the release compiler"; \
fi
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" != "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
cmake -G Ninja \

View File

@@ -3,7 +3,6 @@ ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=7.2
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
ENV DEBIAN_FRONTEND=noninteractive
@@ -19,16 +18,15 @@ RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2
dnf install python3-setuptools python3-wheel -y && \
dnf install rocm-dev -y
## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined
ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache
# Install SCCACHE
ENV SCCACHE_VERSION="0.14.0"
ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
ENV CK_SCCACHE=$CK_SCCACHE
RUN if [ "$CK_SCCACHE" != "" ]; then \
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \
fi
RUN set -x && \
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
wget -qO sccache.tar.gz https://github.com/mozilla/sccache/releases/latest/download/sccache-v$SCCACHE_VERSION-x86_64-unknown-linux-musl.tar.gz && \
tar -xzf sccache.tar.gz --strip-components=1 -C ${SCCACHE_INSTALL_LOCATION} && \
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache
# Install dependencies
RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \
@@ -83,19 +81,71 @@ ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'" && \
sh -c "echo compiler commit = '$compiler_commit'"
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 8 ; \
cmake -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_PROJECTS="clang;lld;clang-tools-extra;flang;mlir" \
-DLLVM_LIT_ARGS="-vv --show-unsupported --show-xfail -j 32" \
-DPACKAGE_VENDOR="AMD" \
-DCMAKE_INSTALL_PREFIX=/home/$USER/rocm/pure_llvm_1.0 \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_BUILD_DOCS=ON \
-DLLVM_TARGETS_TO_BUILD=all \
-DLIBOMPTARGET_ENABLE_DEBUG=ON \
-DOFFLOAD_ENABLE_EMISSARY_APIS=OFF \
-DCLANG_DEFAULT_LINKER=lld \
-DCLANG_DEFAULT_PIE_ON_LINUX=0 \
-DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi;openmp;compiler-rt;libunwind;flang-rt" \
-DLIBCXX_ENABLE_SHARED=OFF \
-DLIBCXX_ENABLE_STATIC=ON \
-DLIBCXX_INSTALL_LIBRARY=OFF \
-DLIBCXX_INSTALL_HEADERS=OFF \
-DLIBCXXABI_ENABLE_SHARED=OFF \
-DLIBCXXABI_ENABLE_STATIC=ON \
-DLIBCXXABI_INSTALL_STATIC_LIBRARY=OFF \
-DLLVM_ENABLE_ASSERTIONS=1 \
-DLLVM_ENABLE_Z3_SOLVER=OFF \
-DLLVM_ENABLE_ZLIB=ON \
-DLLVM_LINK_LLVM_DYLIB=OFF \
-DCLANG_LINK_CLANG_DYLIB=OFF \
../llvm && \
ninja -j16 ; \
else echo "using the release compiler"; \
fi
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" != "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 8 ; \
cmake -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_PROJECTS="clang;lld;clang-tools-extra;flang;mlir" \
-DLLVM_LIT_ARGS="-vv --show-unsupported --show-xfail -j 32" \
-DPACKAGE_VENDOR="AMD" \
-DCMAKE_INSTALL_PREFIX=/home/$USER/rocm/pure_llvm_1.0 \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_BUILD_DOCS=ON \
-DLLVM_TARGETS_TO_BUILD=all \
-DLIBOMPTARGET_ENABLE_DEBUG=ON \
-DOFFLOAD_ENABLE_EMISSARY_APIS=OFF \
-DCLANG_DEFAULT_LINKER=lld \
-DCLANG_DEFAULT_PIE_ON_LINUX=0 \
-DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi;openmp;compiler-rt;libunwind;flang-rt" \
-DLIBCXX_ENABLE_SHARED=OFF \
-DLIBCXX_ENABLE_STATIC=ON \
-DLIBCXX_INSTALL_LIBRARY=OFF \
-DLIBCXX_INSTALL_HEADERS=OFF \
-DLIBCXXABI_ENABLE_SHARED=OFF \
-DLIBCXXABI_ENABLE_STATIC=ON \
-DLIBCXXABI_INSTALL_STATIC_LIBRARY=OFF \
-DLLVM_ENABLE_ASSERTIONS=1 \
-DLLVM_ENABLE_Z3_SOLVER=OFF \
-DLLVM_ENABLE_ZLIB=ON \
-DLLVM_LINK_LLVM_DYLIB=OFF \
-DCLANG_LINK_CLANG_DYLIB=OFF \
../llvm && \
ninja -j16 ; \
else echo "using the release compiler"; \
fi

22
Jenkinsfile vendored
View File

@@ -421,16 +421,19 @@ def buildDocker(install_prefix){
def image_name = getDockerImageName()
def base_image_name = getBaseDockerImageName()
echo "Building Docker for ${image_name}"
def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if(params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if(params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_COMMIT != ""){
dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f projects/composablekernel/Dockerfile.compiler . "
}
else if(params.COMPILER_VERSION == "therock"){
dockerArgs = dockerArgs + " --no-cache -f projects/composablekernel/Dockerfile . "
}
else if(params.RUN_AITER_TESTS){
image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter"
dockerArgs = dockerArgs + " --no-cache -f projects/composablekernel/Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
}
else if(params.RUN_PYTORCH_TESTS){
image_name = "${env.CK_DOCKERHUB}:ck_pytorch"
image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_pytorch"
dockerArgs = dockerArgs + " --no-cache -f projects/composablekernel/Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . "
}
else{
@@ -470,10 +473,10 @@ def get_docker_options(){
else{ //only add kfd and dri paths if you're actually going to run something on GPUs
dockerOpts = "--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
}
if (params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
if (params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "therock" || params.COMPILER_COMMIT != ""){
// the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with
// newer clang22 compilers and running with older hip runtime libraries
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 "
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 --env HIP_PLATFORM=amd "
}
// on some machines the group ids for video and render groups may not be the same as in the docker image!
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
@@ -1140,7 +1143,7 @@ def run_pytorch_tests(Map conf=[:]){
show_node_info()
checkoutComposableKernel()
//use the latest pytorch-nightly image
def image = "${env.CK_DOCKERHUB}:ck_pytorch"
def image = "${env.CK_DOCKERHUB_PRIVATE}:ck_pytorch"
def dockerOpts=get_docker_options() + ' --group-add irc '
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
@@ -1183,16 +1186,19 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_
0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_BASIC_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=false;BUILD_GFX908=false;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=develop;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=therock;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 15 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true
0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : ""
POLL_SPEC = BRANCH_NAME == "develop" ? 'H H/6 * * *' : ''
pipeline {
agent none
triggers {
parameterizedCron(CRON_SETTINGS)
pollSCM(POLL_SPEC)
}
options {
skipDefaultCheckout()
@@ -1214,7 +1220,7 @@ pipeline {
string(
name: 'COMPILER_VERSION',
defaultValue: '',
description: 'Specify which version of compiler to use: release, develop, amd-staging, amd-mainline, or leave blank (default).')
description: 'Specify which version of compiler to use: develop, amd-staging, therock, or leave blank (default).')
string(
name: 'COMPILER_COMMIT',
defaultValue: '',

View File

@@ -154,6 +154,8 @@ rocminfo | grep -i "gfx"
### Install Python Dependencies
#### Core Dependencies (Required)
NumPy is required for Python examples and kernel generation. We recommend using a virtual environment:
**Option 1: Using standard venv**
@@ -165,8 +167,8 @@ python3 -m venv .venv
source .venv/bin/activate # Linux/macOS
# .venv\Scripts\activate # Windows
# Install NumPy
pip install numpy
# Install core dependencies
pip install -r python/requirements.txt
```
**Option 2: Using uv (faster alternative)**
@@ -179,17 +181,38 @@ uv venv .venv
source .venv/bin/activate # Linux/macOS
# .venv\Scripts\activate # Windows
# Install NumPy
uv pip install numpy
# Install core dependencies
uv pip install -r python/requirements.txt
```
**Option 3: System-wide install (not recommended)**
```bash
pip install numpy
pip install -r python/requirements.txt
```
> **Note:** Always activate your virtual environment before running CMake or Python examples.
#### ML Heuristics Dependencies (Optional)
For ML-based kernel selection (examples 09-11), install additional dependencies:
```bash
# Activate your virtual environment first
source .venv/bin/activate
# Install ML dependencies (LightGBM, pandas, pyarrow, scikit-learn)
pip install -r requirements-ml.txt
```
**Why separate?** ML dependencies are large (especially pyarrow) and not needed for basic dispatcher usage. Install only if you need:
- ML-based kernel selection (`examples/gemm/python/09_ml_heuristic.py`)
- Model training (`heuristics/train.py`)
- Model evaluation (`heuristics/evaluate.py`)
- Automated benchmark analysis
**Core dependencies:** ~50 MB (NumPy only)
**With ML dependencies:** ~500 MB (includes LightGBM, pandas, pyarrow, scikit-learn)
### Supported Data Types
CK Tile supports a wide range of data types for GEMM operations:
@@ -470,6 +493,42 @@ python3 examples/gemm/python/10_advanced_benchmark.py \
---
## ML-Based Kernel Selection (Optional)
The dispatcher includes ML heuristics for automated kernel selection using trained LightGBM models.
**Prerequisites:** Install ML dependencies first:
```bash
pip install -r requirements-ml.txt # ~500 MB (LightGBM, pandas, pyarrow, scikit-learn)
```
**Documentation:** See [heuristics/README.md](heuristics/README.md) for:
- Training and evaluating models
- Feature engineering (72 features)
- Using pre-trained models
- Python API reference
**Examples:**
```bash
python3 examples/gemm/python/09_ml_heuristic.py # ML-based kernel selection
python3 examples/gemm/python/10_rank_kernels.py # Kernel ranking
```
**Model Compression:** Trained models are stored in compressed `.lgbm.gz` format to save space (~67% size reduction). Python tools automatically decompress models on first use. For C++ examples, decompress manually:
```bash
# If you have compressed models
cd heuristics/models/gemm_universal_fp16_gfx950
gunzip model_tflops.lgbm.gz
# Then use in C++ example
cd ../../../build
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
```
---
## External Integration
### Using Dispatcher in Your Own Project

View File

@@ -346,6 +346,55 @@ add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.
add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp)
add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp)
# ML Heuristic example -- requires LightGBM shared library
# Derive site-packages from active Python interpreter (respects virtualenvs)
find_package(Python3 COMPONENTS Interpreter)
set(LIGHTGBM_SEARCH_PATHS)
if(Python3_FOUND AND Python3_EXECUTABLE)
execute_process(
COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_path('purelib'))"
OUTPUT_VARIABLE PYTHON_SITE_PACKAGES
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_QUIET
)
if(PYTHON_SITE_PACKAGES)
list(APPEND LIGHTGBM_SEARCH_PATHS "${PYTHON_SITE_PACKAGES}/lightgbm/lib")
endif()
endif()
# Fallback to common Python 3.x site-packages if auto-detection failed
if(NOT PYTHON_SITE_PACKAGES)
list(APPEND LIGHTGBM_SEARCH_PATHS
"$ENV{HOME}/.local/lib/python3.12/site-packages/lightgbm/lib"
)
endif()
find_library(LIGHTGBM_LIB NAMES LightGBM lib_lightgbm _lightgbm
HINTS ${CMAKE_PREFIX_PATH}
PATHS ${LIGHTGBM_SEARCH_PATHS}
NO_DEFAULT_PATH
DOC "LightGBM shared library for ML heuristics"
)
# Fallback: search default paths (respects LightGBM_DIR if set by user)
if(NOT LIGHTGBM_LIB)
find_library(LIGHTGBM_LIB NAMES LightGBM lib_lightgbm)
endif()
if(LIGHTGBM_LIB)
add_declarative_gpu_example(gemm_09_ml_heuristic gemm/cpp/09_ml_heuristic.cpp)
target_link_libraries(gemm_09_ml_heuristic PRIVATE ${LIGHTGBM_LIB})
message(STATUS "LightGBM found: ${LIGHTGBM_LIB} -- building gemm_09_ml_heuristic")
else()
message(STATUS "LightGBM not found -- skipping gemm_09_ml_heuristic")
message(STATUS " To enable ML heuristic example:")
message(STATUS " 1. Activate virtualenv: source .venv/bin/activate")
message(STATUS " 2. Install: pip install -r ../requirements-ml.txt")
message(STATUS " 3. Reconfigure: cmake ..")
message(STATUS " Or set CMAKE_PREFIX_PATH or LightGBM_DIR to LightGBM location")
endif()
# =============================================================================
# GEMM Python Library - Single Fallback Kernel
# =============================================================================

View File

@@ -0,0 +1,211 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 09: ML-Based Kernel Selection (Native C++)
*
* Uses a trained LightGBM model loaded via the C API to predict TFLOPS
* for each kernel in the registry and select the best one. The kernels
* are JIT-compiled at build time via DECL_KERNEL_SET (same as other examples).
*
* Build: cd dispatcher/build && cmake .. && make gemm_09_ml_heuristic
* Run: ./gemm_09_ml_heuristic --model <path_to_model.lgbm>
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <chrono>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
#include "ck_tile/dispatcher/ml_heuristic.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// Multiple kernel configs for ML to choose from
// Candidate kernel pool for the ML heuristic: eight fp16 "rcr"-layout GEMM
// configurations spanning small/medium/large block tiles, so the model has a
// meaningful choice per problem shape. Every entry targets gfx942 and is
// JIT-compiled at build time by DECL_KERNEL_SET (same mechanism as the other
// examples); at runtime the LightGBM model scores each one and the dispatcher
// picks the highest predicted TFLOPS.
DECL_KERNEL_SET(ml_kernels,
                // Small tiles
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(64, 64, 32)      // block tile M x N x K
                         .wave(2, 2, 1)
                         .warp(16, 16, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(64, 64, 64)      // deeper K for the small tile
                         .wave(2, 2, 1)
                         .warp(16, 16, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                // Medium tiles
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 128, 32)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 128, 64)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                // Same 128x128x64 tile, but the compv4 pipeline variant.
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 128, 64)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv4")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                // Large tiles
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(256, 256, 32)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(256, 128, 32)    // rectangular: tall M
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 256, 32)    // rectangular: wide N
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942"));
// Entry point: loads a trained LightGBM model, wires it into the dispatcher
// as the selection heuristic, then selects (but does not launch) a kernel for
// several problem shapes, reporting the predicted TFLOPS and selection time.
// Returns 0 when every shape got a kernel, 1 otherwise.
int main(int argc, char* argv[])
{
    // Command-line options: target arch, model path (required), and whether
    // the model was trained on log1p-transformed TFLOPS targets.
    ExampleArgs args("Example 09: ML-Based Kernel Selection",
                     "Uses trained LightGBM model for kernel selection");
    args.add_option("--arch", "gfx942", "GPU architecture");
    args.add_option("--model", "", "Path to LightGBM model file (.lgbm)");
    args.add_option("--log_transform", "false", "Model uses log1p transform");
    if(!args.parse(argc, argv))
        return 0; // parse() returning false here is help/usage, not an error
    print_header("Example 09: ML-Based Kernel Selection");
    std::string gfx_arch   = args.get("--arch", "gfx942");
    std::string model_path = args.get("--model", "");
    bool log_transform     = (args.get("--log_transform", "false") == "true");
    // The model path has no default: fail fast with usage guidance.
    if(model_path.empty())
    {
        std::cerr << "Error: --model <path> is required" << std::endl;
        std::cerr << "Usage: ./gemm_09_ml_heuristic --model path/to/model_tflops.lgbm" << std::endl;
        return 1;
    }
    // Setup Registry (kernels are JIT compiled from DECL_KERNEL_SET above)
    Registry registry;
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << "Registry: " << registry.size() << " kernel(s)" << std::endl;
    // Load ML model and create heuristic; MLHeuristic reads the registry so
    // it can score every registered kernel for a given problem.
    HardwareProfile hw;
    MLHeuristic ml_heuristic(model_path, &registry, hw, log_transform);
    if(!ml_heuristic.is_loaded())
    {
        std::cerr << "Failed to load model. Exiting." << std::endl;
        return 1;
    }
    // Wire ML heuristic into dispatcher: the lambda adapts the heuristic's
    // call operator to the dispatcher's heuristic callback signature.
    Dispatcher dispatcher(&registry);
    dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic);
    dispatcher.set_heuristic([&ml_heuristic](const Problem& p) { return ml_heuristic(p); });
    std::cout << "Strategy: ML Heuristic (LightGBM)" << std::endl;
    // Test with different problem sizes.
    // NOTE(review): DataType is declared but unused in the visible code.
    using DataType = ck_tile::fp16_t;
    std::vector<std::tuple<int, int, int>> sizes = {
        {128, 128, 64},
        {512, 512, 256},
        {1024, 1024, 512},
        {2048, 2048, 1024},
    };
    // Result table header.
    std::cout << std::endl
              << std::setw(20) << "Shape" << std::setw(30) << "Selected Kernel" << std::setw(15)
              << "Pred TFLOPS" << std::setw(12) << "Select ms" << std::setw(10) << "Status"
              << std::endl;
    std::cout << std::string(87, '-') << std::endl;
    bool all_passed = true;
    for(const auto& [M, N, K] : sizes)
    {
        Problem problem;
        problem.M       = M;
        problem.N       = N;
        problem.K       = K;
        problem.k_batch = 1; // no split-K in this example
        // Time only the selection step (model inference over the pool).
        auto t0          = std::chrono::high_resolution_clock::now();
        auto kernel      = dispatcher.select_kernel(problem);
        auto t1          = std::chrono::high_resolution_clock::now();
        double select_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
        std::string size_str =
            std::to_string(M) + "x" + std::to_string(N) + "x" + std::to_string(K);
        if(!kernel)
        {
            // No kernel selected: report FAIL but keep testing other shapes.
            std::cout << std::setw(20) << size_str << std::setw(30) << "NONE" << std::setw(15)
                      << "N/A" << std::setw(12) << std::fixed << std::setprecision(2) << select_ms
                      << std::setw(10) << "FAIL" << std::endl;
            all_passed = false;
            continue;
        }
        double pred      = ml_heuristic.predict_tflops(problem, kernel->get_key());
        std::string name = kernel->get_key().encode_identifier();
        if(name.length() > 27)
            name = name.substr(0, 27) + ".."; // truncate to fit the column
        std::cout << std::setw(20) << size_str << std::setw(30) << name << std::setw(15)
                  << std::fixed << std::setprecision(2) << pred << std::setw(12)
                  << std::setprecision(2) << select_ms << std::setw(10) << "OK" << std::endl;
    }
    std::cout << std::endl
              << (all_passed ? "*** ALL TESTS PASSED ***" : "*** SOME TESTS FAILED ***")
              << std::endl;
    return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,305 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 09: ML-Based Kernel Selection
Uses a trained LightGBM model to select the optimal kernel for each problem
size. The model predicts TFLOPS for every candidate in the kernel pool and
picks the highest-scoring one, which is then JIT-compiled and run.
This replaces the hand-crafted rules in 08_heuristics.py with a data-driven
approach achieving 97-98% of oracle-best TFLOPS efficiency.
Complexity: *****
Prerequisites:
- Trained model in dispatcher/heuristics/models/gemm_universal_fp8_gfx950/
- lightgbm, pandas, numpy, pyarrow installed
Usage:
python3 09_ml_heuristic.py
python3 09_ml_heuristic.py --dtype fp16 --arch gfx942
"""
import sys
import argparse
import time
from pathlib import Path
from dataclasses import dataclass
from typing import List
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
)
from predict import Predictor
@dataclass
class KernelSpec:
    """Kernel specification -- same structure as 08_heuristics.py"""

    # Short identifier; used to match ranked prediction results back to specs.
    name: str
    # Block tile dimensions (M x N x K).
    tile_m: int
    tile_n: int
    tile_k: int
    # Pipeline variant used in this file: "compv3", "compv4", or "mem".
    pipeline: str = "compv3"
    # Scheduler: "intrawave" (default) or "interwave".
    scheduler: str = "intrawave"
    # Wave layout -- fed to the feature engine as its "warp_*" columns
    # (see spec_to_feature_dict) and to KernelConfig as wave_m/n/k.
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    # Per-warp tile shape -- the feature engine's "warp_tile_*" columns and
    # KernelConfig's warp_m/n/k.
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16
# Kernel pool: representative configs spanning small to large tiles,
# compv3/compv4/mem pipelines, and intrawave/interwave schedulers.
# Name prefixes: s_=small, m_=medium, r_=rectangular, l_=large;
# the "_iw" suffix marks interwave-scheduler variants.
KERNEL_POOL = [
    # Small tiles
    KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k32_v4", 64, 64, 32, "compv4", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", warp_m=16, warp_n=16),
    # Medium tiles
    KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3"),
    KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3"),
    KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3"),
    KernelSpec("m_128x128_k32_v4", 128, 128, 32, "compv4"),
    KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4"),
    KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem"),
    KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem"),
    # Rectangular medium
    KernelSpec("r_64x128_k32", 64, 128, 32, "compv3", warp_m=16),
    KernelSpec("r_128x64_k32", 128, 64, 32, "compv3", warp_n=16),
    KernelSpec("r_64x128_k64", 64, 128, 64, "compv3", warp_m=16),
    KernelSpec("r_128x64_k64", 128, 64, 64, "compv3", warp_n=16),
    # Large tiles
    KernelSpec("l_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("l_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("l_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("l_256x256_k64", 256, 256, 64, "compv3"),
    # Interwave variants
    KernelSpec("m_128x128_k64_iw", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("m_128x128_k64_mem_iw", 128, 128, 64, "mem", "interwave"),
]
def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
    """Translate a KernelSpec into the flat dict consumed by the feature engine.

    Note: pad_m/n/k default to True to match KernelConfig defaults and actual
    compiled kernels, so the ML model receives the padding flags that will be
    used during JIT compilation.
    """
    # Feature-engine naming: its "warp_*" fields are warps per block
    # (spec.wave_*), and its "warp_tile_*" fields are per-warp tile dims
    # (spec.warp_*).
    features = dict(
        kernel_name=spec.name,
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        warp_m=spec.wave_m,
        warp_n=spec.wave_n,
        warp_k=spec.wave_k,
        warp_tile_m=spec.warp_m,
        warp_tile_n=spec.warp_n,
        warp_tile_k=spec.warp_k,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        pad_m=True,  # Match KernelConfig default
        pad_n=True,  # Match KernelConfig default
        pad_k=True,  # Match KernelConfig default
        persistent=False,
        dtype=dtype,
        layout=layout,
    )
    return features
def spec_to_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Build the dispatcher's KernelConfig for JIT-compiling this spec.

    Uses the same dtype for A/B/C with fp32 accumulation, and fixes the
    layout to row/col/row (rcr) with the cshuffle epilogue.
    """
    kwargs = dict(
        # Data types: uniform operand/output dtype, fp32 accumulator.
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        # Fixed rcr layout.
        layout_a="row",
        layout_b="col",
        layout_c="row",
        # Tiling taken straight from the spec.
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=spec.wave_m,
        wave_n=spec.wave_n,
        wave_k=spec.wave_k,
        warp_m=spec.warp_m,
        warp_n=spec.warp_n,
        warp_k=spec.warp_k,
        # Pipeline/scheduler selection.
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
    return KernelConfig(**kwargs)
def ml_select_kernel(
    predictor: Predictor,
    pool: List[KernelSpec],
    M: int,
    N: int,
    K: int,
    dtype: str,
    layout: str,
) -> tuple:
    """Rank every kernel in the pool with the ML model for one problem.

    Returns (best_spec, predicted_tflops). Falls back to the first pool
    entry with 0.0 predicted TFLOPS when the predictor yields no ranking,
    or when the top-ranked name is not found in the pool.
    """
    problem = {
        "m": M,
        "n": N,
        "k": K,
        "dtype": dtype,
        "layout": layout,
        "split_k": 1,
    }
    candidates = [spec_to_feature_dict(spec, dtype, layout) for spec in pool]
    ranking = predictor.rank_kernels(problem, candidates)
    if not ranking:
        return pool[0], 0.0
    top_name, top_tflops = ranking[0]
    # First pool entry matching the top-ranked name; default to pool[0].
    best = pool[0]
    for spec in pool:
        if spec.name == top_name:
            best = spec
            break
    return best, top_tflops
def main():
    """Entry point: ML-select a kernel per test shape, optionally JIT-run it.

    For each (M, N, K) in the test list: rank KERNEL_POOL with the trained
    predictor, then (unless --no_run) compile the winner via the dispatcher,
    execute the GEMM, and verify against a float32 numpy reference.

    Returns:
        0 if all tests passed, 1 otherwise (used as the process exit code).
    """
    parser = argparse.ArgumentParser(description="ML-based kernel selection for GEMM")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp8"])
    parser.add_argument("--arch", default="gfx942")
    # NOTE(review): the default model dir is an fp8/gfx950 model while
    # --dtype defaults to fp16 and --arch to gfx942 -- pass a matching
    # --model_dir for meaningful predictions. TODO confirm intended defaults.
    parser.add_argument(
        "--model_dir",
        default=str(
            Path(__file__).parent.parent.parent.parent
            / "heuristics"
            / "models"
            / "gemm_universal_fp8_gfx950"
        ),
    )
    parser.add_argument(
        "--no_run", action="store_true", help="Only predict, don't run GEMMs"
    )
    args = parser.parse_args()
    print("=" * 75)
    print(" Example 09: ML-Based Kernel Selection")
    print("=" * 75)
    print(f"\n Model: {args.model_dir}")
    print(f" Dtype: {args.dtype}")
    print(f" Arch: {args.arch}")
    print(f" Pool: {len(KERNEL_POOL)} kernels")
    predictor = Predictor(args.model_dir)
    print(" Model loaded successfully")
    # Host-side buffer dtype. numpy has no native bf16 or fp8, so fp16 is
    # used for all supported dtypes. (The previous ternary selected
    # np.float16 in both branches; simplified to a plain assignment.)
    np_dtype = np.float16
    test_sizes = [
        (128, 128, 64),
        (256, 256, 128),
        (512, 512, 256),
        (1024, 1024, 512),
        (2048, 2048, 1024),
    ]
    header = f"{'Shape':<20} {'Selected Kernel':<25} {'Pred TFLOPS':>12}"
    if not args.no_run:
        header += f" {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    print(f"\n {header}")
    print(" " + "-" * len(header))
    # Each result tuple: (shape_str, kernel_name, passed, time_ms, tflops).
    results = []
    for M, N, K in test_sizes:
        best_spec, pred_tflops = ml_select_kernel(
            predictor, KERNEL_POOL, M, N, K, args.dtype, "rcr"
        )
        size_str = f"{M}x{N}x{K}"
        line = f" {size_str:<20} {best_spec.name:<25} {pred_tflops:>12.2f}"
        if args.no_run:
            # Prediction-only mode: report the selection, count it as a pass.
            print(line)
            results.append((size_str, best_spec.name, True, 0, pred_tflops))
            continue
        config = spec_to_kernel_config(best_spec, args.dtype, args.arch)
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"ml_{best_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )
        if not setup.success:
            # JIT compilation failed for the selected kernel.
            line += f" {'N/A':>10} {'N/A':>10} {'BUILD':>8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue
        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            # Kernel compiled but cannot handle this problem size.
            line += f" {'N/A':>10} {'N/A':>10} {'UNSUP':>8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue
        # Fixed seed for reproducible inputs; scale down to limit fp16
        # accumulation error in the reference comparison.
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
        result = dispatcher.run(A, B, M, N, K)
        if result.success:
            # Verify against a float32 numpy reference, rounded back to
            # the device dtype.
            C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(
                np_dtype
            )
            max_err = np.max(np.abs(result.output - C_ref))
            passed = max_err < 1e-2
            status = "PASS" if passed else "FAIL"
            line += f" {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
            results.append(
                (size_str, best_spec.name, passed, result.time_ms, result.tflops)
            )
        else:
            line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            results.append((size_str, best_spec.name, False, 0, 0))
        print(line)
        cleanup_gemm()
    # Summary
    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)
    passed = sum(1 for r in results if r[2])
    print(f"\n Results: {passed}/{len(results)} tests passed")
    valid = [r for r in results if r[2] and r[4] > 0]
    if valid:
        avg = sum(r[4] for r in valid) / len(valid)
        print(f" Average TFLOPS: {avg:.2f}")
    if passed == len(results):
        print("\n *** ALL TESTS PASSED ***")
    print("=" * 75)
    return 0 if passed == len(results) else 1
# Run as a standalone script; propagate main()'s status (0 = all passed)
# as the process exit code.
if __name__ == "__main__":
    sys.exit(main())

60
dispatcher/heuristics/.gitignore vendored Normal file
View File

@@ -0,0 +1,60 @@
# Python bytecode and caches
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
# Jupyter notebooks
*.ipynb
.ipynb_checkpoints/
# Virtual environments
.venv/
venv/
ENV/
# IDE and editor files
.vscode/
.idea/
*.swp
*.swo
*~
# Test output and logs
*.log
test_output.log
custom_shapes_gpu_test.log
# Benchmark and analysis output files
*.csv
*.json
!models/*/feature_spec.json
!models/*/train_manifest.json
# Data files (parquet, arrow)
*.parquet
*.arrow
# Temporary and NFS files
.nfs*
*.tmp
*.bak
# Decompressed model files (compressed .lgbm.gz versions are tracked)
models/**/*.lgbm
# User-specific test and analysis scripts
test_*.py
!tests/test_*.py
find_*.py
oracle_*.json
validation_results_*.csv
custom_shapes_*.csv
fp16_bf16_*.csv
# Ignore all markdown files except tracked documentation
*.md
!DATA_GENERATION.md
!LEARNINGS.md
!README.md

View File

@@ -0,0 +1,412 @@
# Data Generation Guide
This document explains how to build benchmark binaries from the CK Tile engine,
generate benchmark datasets, and manage them for the ML kernel performance
prediction system.
## Overview
The ML heuristic needs benchmark data: measured TFLOPS, latency, and bandwidth
for every (problem shape, kernel config) pair. The tile engine builds one
executable per kernel configuration. Each executable benchmarks a single kernel
on a given problem size and outputs JSON with performance metrics.
```
CK source --> CMake configure --> ninja build --> benchmark binaries
(4608 per op/dtype/layout)
benchmark binaries --> run on GPU --> streaming log --> parquet dataset
(per shape) (JSON blocks) (canonical schema)
```
## Prerequisites
- **ROCm**: HIP >= 6.0.3 (for gfx950: HIP >= 6.0.4)
- **Build tools**: CMake >= 3.21, Ninja, HIP-aware clang compiler
- **Python**: 3.10+ with `pandas`, `pyarrow`
- **GPU**: ROCm-capable AMD GPU (MI250X, MI300X, MI355X, etc.)
---
## Part 1: Building Benchmark Binaries from the Tile Engine
If you already have pre-built binaries (e.g., in `/workspace/ck_tile/bin/`),
skip to Part 2. This section explains how to build them from source.
### Step 1: CMake Configure
From the CK repository root:
```bash
cmake -S /workspace/rocm-libraries/projects/composablekernel \
-B build \
-DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx950" \
-DGEMM_UNIVERSAL_DATATYPE="fp8" \
-DGEMM_UNIVERSAL_LAYOUT="rcr" \
-G Ninja
```
**Key CMake variables:**
| Variable | Default | Description |
|---|---|---|
| `GPU_TARGETS` | (required) | Target GPU architectures. Supported: `gfx90a`, `gfx942`, `gfx950`, `gfx1201`. Semicolon-separated for multiple. |
| `GEMM_UNIVERSAL_DATATYPE` | `"fp8;fp16"` | Data types to build. Options: `fp8`, `fp16`, `bf16`, `bf8`. Semicolon-separated. |
| `GEMM_UNIVERSAL_LAYOUT` | `"rcr;rrr;crr;ccr"` | Layouts to build. Semicolon-separated. |
| `GEMM_UNIVERSAL_CONFIG_FILE` | `"default_config.json"` | Kernel config file (in the `configs/` directory). Controls which tile sizes, warp configs, pipelines, etc. are enumerated. |
| `ENABLE_CCACHE_GEMM_UNIVERSAL` | `OFF` | Enable ccache for faster rebuilds. |
**Example: build only fp8 RCR for gfx950 (fastest, ~4608 kernels):**
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx950" \
-DGEMM_UNIVERSAL_DATATYPE="fp8" \
-DGEMM_UNIVERSAL_LAYOUT="rcr" \
-G Ninja
```
**Example: build all dtypes and layouts (slow, ~4608 * 4 * 4 = ~73K kernels):**
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx950" \
-DGEMM_UNIVERSAL_DATATYPE="fp8;fp16;bf16;bf8" \
-DGEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
-G Ninja
```
### What happens during configure
1. CMake calls `gemm_universal_instance_builder.py --list_kernels` to enumerate
all valid kernel configurations from the config JSON.
2. It writes `gemm_universal_kernel_list.txt` (one kernel per line) and
`gemm_universal_kernel_count.txt` to the build directory.
3. For each kernel, it creates a ninja build target.
### Step 2: Build
```bash
# Build all benchmarks for the configured dtypes/layouts
ninja -C build benchmark_gemm_universal_all
# Or build a specific dtype/layout combo
ninja -C build benchmark_gemm_universal_fp8_rcr
# Or build by pipeline type
ninja -C build benchmark_gemm_universal_compv4_pipeline
ninja -C build benchmark_gemm_universal_mem_pipeline
# Or build a single specific kernel
ninja -C build benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128
```
**Build time estimates:**
- ~4608 kernels (one dtype, one layout): 1-4 hours depending on CPU cores
- Use `-j <N>` to control parallelism: `ninja -C build -j 32 benchmark_gemm_universal_fp8_rcr`
### Step 3: Verify binaries
Binaries are placed in `build/bin/`:
```bash
ls build/bin/benchmark_gemm_universal_fp8_rcr_* | wc -l
# Expected: 4608 (for default config)
# Test one binary
./build/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \
-m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0
```
### Kernel config files
The config files live in:
```
tile_engine/ops/gemm/gemm_universal/configs/
default_config.json # Default: full enumeration
default_ci_config.json # CI: reduced set for fast testing
user_provided_config.json # Custom: your own subset
```
To use a custom config:
```bash
cmake ... -DGEMM_UNIVERSAL_CONFIG_FILE="user_provided_config.json"
```
The config controls which tile sizes (e.g., 128x128x64, 256x256x32), warp
configurations (e.g., 2x2x1, 1x4x1), pipelines (compv3, compv4, mem),
schedulers, and other parameters are included in the kernel enumeration.
### Building StreamK / other ops
The same pattern applies to other tile engine ops:
```bash
# StreamK
ninja -C build benchmark_gemm_streamk_fp8_rcr
# Grouped convolution
ninja -C build benchmark_grouped_conv_fwd_fp16_nhwgc
```
Each op has its own instance builder and config directory.
---
## Part 2: Running Benchmarks and Generating Data
## Quick Start
### 1. Run benchmarks for a set of shapes
Each binary accepts `-m=`, `-n=`, `-k=`, `-warmup=`, `-repeat=`, `-verify=` flags
and outputs JSON to stdout:
```bash
/workspace/ck_tile/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \
-m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0
```
Output:
```json
{
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_...",
"problem": {
"split_k": 1, "m": 1024, "n": 1024, "k": 1024,
"dtype_a": "fp8", "dtype_b": "fp8", ...
},
"perf_result": {
"latency(ms)": 0.04,
"tflops(TFlops)": 204.60,
"bandwidth(GB/s)": 624.39
}
}
```
### 2. Batch generation using provided scripts
**Wide coverage (diverse shapes across all regimes):**
```bash
python3 generate_wide_coverage.py \
--bin_dir /workspace/ck_tile/bin \
--out_dir data/wide_coverage \
--batch_size 25 \
--warmup 3 --repeat 10
```
**Edge-case dimensions (N=1, K=1, small N/K):**
```bash
python3 generate_edge_dims.py
```
Both scripts write streaming log files that `data_pipeline.py` can parse.
### 3. Parse logs into parquet
```bash
python3 data_pipeline.py <log_file> \
-o data/my_dataset.parquet \
--arch gfx950 \
--capture_hw
```
The `--capture_hw` flag runs `rocminfo` once and injects the GPU hardware
profile (CU count, clock speed, cache sizes, etc.) into every row.
## Canonical Data Schema
Every parquet file follows this schema:
| Column | Type | Description |
|---|---|---|
| `op_type` | str | `gemm_universal`, `gemm_streamk`, etc. |
| `dtype` | str | `fp8`, `fp16`, `bf16`, `bf8` |
| `layout` | str | `rcr`, `rrr`, `crr`, `ccr` |
| `arch` | str | `gfx942`, `gfx950`, etc. |
| `kernel_name` | str | Full kernel identifier |
| `m`, `n`, `k` | int | Problem dimensions |
| `split_k` | int | Split-K factor (1 = standard) |
| `measured_tflops` | float | Ground-truth TFLOPS |
| `latency_ms` | float | Measured latency |
| `bandwidth_gb_s` | float | Measured bandwidth |
| `is_valid` | bool | True if tflops > 0 and latency > 0 |
| `tile_m`, `tile_n`, `tile_k` | int | Tile dimensions |
| `warp_m`, `warp_n`, `warp_k` | int | Warp config |
| `warp_tile_m/n/k` | int | Warp tile dimensions |
| `pipeline` | str | `compv3`, `compv4`, `mem`, etc. |
| `scheduler` | str | `intrawave`, `interwave` |
| `epilogue` | str | `cshuffle`, `default` |
| `pad_m`, `pad_n`, `pad_k` | bool | Padding flags |
| `persistent` | bool | Persistent kernel flag |
| `run_id` | str | Unique collection run identifier |
## Shape Selection Guidelines
Good training data requires diverse shapes. Cover all of these regimes:
### By M dimension (batch size / output rows)
- **M=1**: single-token inference (hardest case for tiling)
- **Tiny M (2-16)**: small batch inference
- **Small M (32-128)**: medium batch
- **Medium M (256-2048)**: large batch / training
- **Large M (4096-20480)**: very large batch
### By N and K dimension
- **N=1**: vector-matrix multiply (degenerate)
- **K=1**: rank-1 update / outer product (degenerate)
- **Small N or K (2-16)**: stress tile efficiency
- **Deep K (K > 4096)**: compute-bound regime
- **Shallow K (K < 256)**: memory-bound regime
### By shape family
- **Square**: M ~ N ~ K (powers of 2)
- **Tall**: M >> N (tall output matrix)
- **Wide**: N >> M (wide output matrix)
- **Deep-K**: K >> M and K >> N
### Special cases
- **Prime dimensions**: 17, 31, 127, 251, 509, 1021, 2039, 4093
(worst-case for tile alignment, tests padding logic)
- **Non-power-of-2**: 48, 96, 192, 384, 576, 768, 1536, 3072, 4608
(common in LLM architectures)
- **LLM inference shapes**: DeepSeek, LLaMA-7B, LLaMA-70B MLP/attention dims
### Minimum recommended coverage
For a production-quality model, aim for:
- At least 200 unique (M, N, K) shapes
- At least 10 shapes per shape family
- All kernel configs (4608 for fp8 RCR) run against every shape
- Multiple layouts if training a cross-layout model
## Benchmark Quality Guidelines
### Warmup and repeat
- Minimum `warmup=3`, `repeat=10` for fast iteration
- Production quality: `warmup=5`, `repeat=20` for stable measurements
- The `perf_result` values are averaged over `repeat` iterations
### Noise handling
- Use **median** latency when aggregating multiple runs of the same benchmark
- Flag measurements where coefficient of variation exceeds 10%
- Avoid benchmarking under thermal throttling (check GPU temperature)
- Lock GPU clocks if possible for reproducibility
### Environment metadata
Store with every dataset:
- GPU model and architecture (from `rocminfo`)
- ROCm driver version
- Clock mode (default / locked)
- Git hash of the CK tile engine build (if available)
- Timestamp
## Adding Data for a New Op
To generate benchmark data for a new operation (e.g., `gemm_streamk`):
1. **Build the binaries** using the tile engine:
```bash
ninja -C build benchmark_gemm_streamk_fp8_rcr
```
2. **Write a generation script** (or modify `generate_wide_coverage.py`):
- Change the executable glob pattern to match the new op
- Add any op-specific CLI flags the binaries need
3. **Run and parse**:
```bash
python3 data_pipeline.py my_streamk_run.log \
-o data/gemm_streamk_fp8_gfx950.parquet --arch gfx950
```
4. **Train**:
```bash
python3 train.py --op gemm_streamk --dtype fp8 --arch gfx950 \
--data_dir data/ --out_dir models/gemm_streamk_fp8_gfx950
```
## Adding Data for a New Layout
Same binaries, same shapes -- just change the layout filter:
```bash
# Build rrr binaries
ninja -C build benchmark_gemm_universal_fp8_rrr
# Generate and parse
# ... (same flow, different bin_dir or executable glob)
# Train a cross-layout model by putting all layouts in the same data_dir
python3 train.py --data_dir data/ --out_dir models/gemm_universal_fp8_gfx950_all_layouts
```
The feature engine includes `layout` as a categorical feature, so one model
can handle all layouts.
## Incremental Data Collection
When you have a trained model and want to add more data:
1. Generate new data (new shapes, new layouts, etc.)
2. Parse into parquet alongside existing data
3. Warm-start from the previous model:
```bash
python3 train.py --data_dir data/ --out_dir models/v2 \
--warm_start models/v1 \
--warm_start_n_estimators 200
```
This adds 200 new trees on top of the existing model. The feature schema
must match exactly (enforced automatically).
## File Organization
Recommended directory structure:
```
heuristics/
data/
gemm_universal_fp8_rcr_gfx950.parquet # original 108 shapes
wide_coverage/ # batch log files
wide_coverage_batch_001.log
wide_coverage_batch_002.log
...
edge_dims/ # N=1, K=1 edge cases
edge_dims_batch_001.log
...
models/
gemm_universal_fp8_gfx950/ # trained model artifacts
model_tflops.lgbm
model_latency.lgbm
model_bandwidth.lgbm
feature_spec.json
train_manifest.json
cv_metrics_tflops.json
eval_report.json
...
```
## Troubleshooting
### Benchmark binary exits with non-zero code
Some kernel configs are invalid for certain problem sizes (e.g., tile_m=256
with M=16). The data pipeline marks these as `is_valid=False` and they are
filtered out during training. This is expected.
### Edge dims produce very few results
N=1 and K=1 shapes are degenerate -- most kernel configurations have minimum
dimension requirements and will fail or produce zero TFLOPS. The small number
of valid results is still useful (it tells the model which configs work for
these shapes).
### Benchmarks are slow
Each shape requires running all 4608 kernel executables sequentially. At
~0.01s per kernel, that is ~46 seconds per shape. For 700 shapes, expect
~9 hours. Tips:
- Run on a dedicated GPU (no other workloads)
- Use `--batch_size 25` to get incremental output
- Parse and train on partial data while generation continues
### Data from different GPUs / driver versions
Store `run_id` and hardware metadata with each dataset. Training on mixed
data is allowed but not recommended for production models. Filter to a
single `run_id` or `arch` for clean experiments.

View File

@@ -0,0 +1,151 @@
# Learnings and Design Decisions
Empirical findings from building the CK Tile kernel performance prediction system.
These inform the current defaults and explain why certain approaches were chosen.
## 1. Log-Transform is Essential for Cross-Scale Accuracy
**Problem**: GEMM TFLOPS spans 5 orders of magnitude across different problem
sizes. When training on raw TFLOPS, the regression loss (RMSE) is dominated by
large shapes where absolute errors are biggest. The model learns to predict
large shapes accurately but ignores tiny shapes where the TFLOPS values are
much lower.
**Evidence** (168 shapes, 626K rows, 5-fold GroupKFold CV):
| Model | Mean Eff | P10 Eff | tiny_m Eff | Min Eff |
| ----------------------------- | ---------- | ---------- | ---------- | ---------- |
| Raw TFLOPS (500 trees) | 92.73% | 80.24% | 84.55% | 4.26% |
| **log1p(TFLOPS)** (500 trees) | **96.92%** | **94.34%** | **94.89%** | **60.27%** |
| log1p(TFLOPS) (2000 trees) | 97.51% | 93.89% | 96.04% | 63.56% |
**Solution**: Train on `log1p(measured_tflops)` and apply `expm1()` to
predictions. This is now the default in `train.py`. Pass `--no_log_transform`
to revert to raw regression (not recommended).
**Why log1p, not log**: `log1p(x) = log(1 + x)` handles zero and near-zero
TFLOPS gracefully, whereas `log(x)` produces -inf for x=0.
## 2. Tiny-M Shapes are the Hardest Case
M=1 (single-token inference) shapes are fundamentally different from batch shapes:
- Most kernel configurations produce very low TFLOPS
- The "best" kernel is often only marginally better than the rest
- The oracle performance itself is very low, so any prediction error tanks efficiency
- Many kernels fail outright (tile_m=128 with M=1 wastes 127/128 of the tile)
The bottom shapes in our evaluation are all M=1, with efficiencies in the
63-70% range. These shapes have such low absolute performance that the model's
noise floor exceeds the performance difference between kernels.
**Mitigation**: Log-transform helps significantly (tiny_m improved from 84% to
96%). For production use with M=1, consider a dedicated fallback (e.g.,
hardcoded kernel selection for M < 4 based on known-good configs).
## 3. IHEM (Hard Example Mining) Hurts When Scale is the Issue
We tried Iterative Hard Example Mining with sample reweighting (2x-5x weight
on hard shapes). Result: it made things **worse**, degrading mean efficiency
from 94.31% to 92.90% over 3 iterations.
**Why**: The hard shapes are hard because of scale mismatch, not because the
model lacks capacity. Reweighting amplifies the small-TFLOPS rows, which
distorts the learned relationship between features and performance for the
majority of shapes. The log-transform was the correct fix -- it addresses the
root cause (scale) rather than the symptom (bad predictions on tiny shapes).
**Lesson**: IHEM is useful when the model has capacity gaps (e.g., certain
pipeline types are underrepresented). It is counterproductive when the issue
is target-variable scale. Always try target transforms before reweighting.
## 4. GroupKFold Key = (M, N, K) Forces Generalization
The validation uses `GroupKFold` where the group key is `(M, N, K)` -- all
kernels for the same shape go to the same fold. This means:
- The model is always evaluated on shapes it has **never seen** during training
- Layout is excluded from the key, forcing the model to generalize across layouts
- Since models are per-arch, `arch` is implicit (constant within one training run)
This is much stricter than random row splitting, where the model would see some
kernels for each shape during training. Our efficiency numbers are conservative
estimates of real-world performance on unseen shapes.
## 5. Model Size vs Accuracy Tradeoff
| Config | Trees | Leaves | LR | Mean Eff | P10 Eff | Train Time |
| ------------------ | -------- | ------- | -------- | ---------- | ---------- | ------------- |
| Small (default v1) | 500 | 127 | 0.05 | 96.92% | 94.34% | ~20s |
| **Big (current)** | **2000** | **255** | **0.02** | **97.51%** | **93.89%** | **~25s/fold** |
The bigger model improved mean efficiency by 0.6% but P10 didn't improve
(actually slightly worse). The extra capacity helps on medium shapes but
doesn't crack the tiny-M floor. This suggests the feature set, not model
capacity, is the limiting factor for the hardest shapes.
For C++ deployment, the bigger model (2000 trees, 255 leaves) is still fast
enough -- LightGBM inference is O(trees * log(leaves)) per sample, which is
~microseconds even at 2000 trees.
## 6. N=1 and K=1 Shapes are Degenerate
We generated benchmark data for 546 edge-case shapes (N=1, K=1, small N/K).
Result: **zero valid kernel results** across 94 shapes. All 4608 kernels either
fail or produce 0 TFLOPS for these degenerate dimensions.
This means:
- The tile engine kernels have hard minimum dimension requirements
- N=1 / K=1 shapes cannot be handled by the current kernel set
- These shapes need dedicated kernels (e.g., BLAS-1/BLAS-2 fallbacks)
- The ML model should not be expected to handle them -- they should be filtered
out before reaching the heuristic
## 7. Feature Engineering Insights
From LightGBM feature importances on the log-target model:
**Top features** (by split count):
- `M, N, K` -- raw dimensions are always the most important
- `tile_m, tile_n, tile_k` -- the tile shape is the primary kernel differentiator
- `overall_tile_efficiency` -- how well the shape fits the tile (the interaction)
- `num_tiles_m, total_output_tiles` -- work decomposition
- `arithmetic_intensity` -- compute vs memory bound regime
- `pipeline` -- pipeline type (compv3 vs compv4 vs mem) significantly affects perf
**Low-importance features**:
- Hardware constants (CUs, clock, caches) -- they're constant within one arch
model, so they provide no discriminative signal. They'll become important when
training cross-arch models.
- `split_k` -- always 1 in current data
- `persistent` -- rarely True in current kernel set
## 8. Warm-Start Works for Incremental Updates
LightGBM's `init_model` parameter successfully continues training from an
existing model. New trees are added on top of existing ones. Key considerations:
- Feature schema must match exactly (enforced by `check_feature_compatibility`)
- Use fewer new trees (200-500) since we're refining, not starting fresh
- The `train_manifest.json` tracks the full lineage (total trees, data sizes)
- Quality should be at least as good as the base model (tested)
## 9. Data Volume Matters More Than Model Complexity
| Dataset                  | Shapes | Rows | Mean Eff (log, 500 trees)     |
| ------------------------ | ------ | ---- | ----------------------------- |
| Original (DeepSeek only) | 108    | 418K | 98.28% (on seen distribution) |
| + Wide coverage          | 168    | 626K | 96.92% (harder distribution)  |

The original 108-shape model looked great (98.28%) but was overfitting to the
DeepSeek LLM inference distribution. Adding 60 diverse shapes (many M=1)
exposed the model's weakness on tiny shapes. More diverse training data is
always better than a bigger model on narrow data.

## Summary of Defaults

Based on these findings, the current defaults in `train.py` are:

- **Target transform**: `log1p` for TFLOPS and bandwidth (scale normalization)
- **Model**: 2000 trees, 255 leaves, max depth 15, LR 0.02
- **Validation**: 5-fold GroupKFold, key = (M, N, K)
- **Early stopping**: patience 100 (let trees fully converge)
- **Warm start**: 500 new trees (was 200, increased for bigger base model)

View File

@@ -0,0 +1,271 @@
# CK Tile Heuristics: ML-Based Kernel Selection
Fast, accurate kernel selection for CK Tile operations using LightGBM regression
with Origami-augmented feature engineering.
## What This Does
Instead of running all 4608+ kernel configurations on the GPU to find the best
one (exhaustive search taking ~46 seconds per shape), this system trains an ML
model that predicts TFLOPS for any (problem, kernel) pair in microseconds. It
scores all candidates instantly and picks the best kernel -- achieving 98.28%
of oracle-best TFLOPS efficiency across 108 tested shapes.
## Quick Start
### 1. Generate and convert benchmark data
**Step 1: Generate benchmark data**
```bash
python3 generate_benchmark_data.py \
--build_dir /path/to/build \
--output_dir data/fp16_original \
--dtype fp16 \
--layout rcr \
--num_build_jobs 4 \
--warmup 10 \
--repeat 50
```
This outputs JSON with all benchmark results.
**Step 2: Convert JSON to parquet training format**
```bash
python3 convert_json_to_parquet.py \
--input data/fp16_original/benchmark_results_fp16_rcr.json \
--output data/fp16_original/fp16_training_data.parquet \
--arch gfx950
```
The converter automatically fixes pad flags for `_mem` kernels and validates data.
**Alternative: Parse existing logs**
If you have raw benchmark logs from CK Tile:
```bash
python3 data_pipeline.py ck_tile_testrun_2.log \
-o data/gemm_universal_fp8_rcr_gfx950.parquet \
--arch gfx950 --capture_hw
```
### 2. Train a model
```bash
python3 train.py \
--data_dir data/ \
--out_dir models/gemm_universal_fp8_gfx950 \
--op gemm_universal --dtype fp8 --arch gfx950
```
**Note**: Trained models are automatically compressed to `.lgbm.gz` format to save space (~67% reduction). The Python tools automatically decompress them on first use and cache the decompressed version. For warm-start training, decompression happens automatically.
### 3. Evaluate
```bash
python3 evaluate.py \
--model_dir models/gemm_universal_fp8_gfx950 \
--data_dir data/ --op gemm_universal --dtype fp8
```
### 4. Predict the best kernel for a problem
```bash
python3 predict.py \
--model_dir models/gemm_universal_fp8_gfx950 \
--m 128 --n 1536 --k 7168 --layout rcr
```
### 5. Search for optimal configs (optional)
```bash
python3 search.py \
--model_dir models/gemm_universal_fp8_gfx950 \
--m 128 --n 1536 --k 7168 \
--strategy random --budget 500 --top_k 10
```
### 6. Using models in C++ (requires decompression)
C++ code uses the LightGBM C API which requires uncompressed `.lgbm` files. If you have compressed models (`.lgbm.gz`), decompress them first:
```bash
cd models/gemm_universal_fp16_gfx950
gunzip model_tflops.lgbm.gz
```
Then use in C++ examples:
```bash
cd dispatcher/build
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
```
**Note**: Python tools automatically decompress `.lgbm.gz` files on first use, so you can run Python scripts first to trigger decompression, then use the same models in C++.
## Architecture
```
Problem (M, N, K, dtype, layout)
|
v
FeatureEngine.extract_batch() <-- 55 features: problem, kernel, interaction, hardware
|
v
LGBMRegressor.predict() <-- predicts TFLOPS for each candidate kernel
|
v
Sort by predicted TFLOPS <-- rank all candidates
|
v
Select Top-1 kernel <-- 98.28% mean efficiency, <1ms inference
```
Three models are trained per (op, dtype, arch):
- **TFLOPS model** (primary): used for kernel ranking
- **Latency model** (auxiliary): for latency-sensitive workloads
- **Bandwidth model** (auxiliary): for memory-bound analysis
## File Inventory
| File | Purpose |
|---|---|
| `generate_benchmark_data.py` | Build and run benchmarks across ~25 diverse problem sizes, output JSON |
| `convert_json_to_parquet.py` | Convert benchmark JSON to parquet training format, fix `_mem` pad flags |
| `data_pipeline.py` | Parse raw benchmark logs into canonical parquet datasets |
| `feature_engine.py` | 55-feature extraction: problem, kernel, interaction, hardware profile |
| `train.py` | Multi-target LGBMRegressor training with GroupKFold CV, IHEM, warm-start |
| `predict.py` | Predictor class: predict TFLOPS/latency/bandwidth, rank kernels |
| `evaluate.py` | Full evaluation: global metrics, per-shape/layout/pipeline slices |
| `search.py` | Surrogate search: discrete DE, random top-K |
| `generate_wide_coverage.py` | Generate benchmark data across 706 diverse shapes |
| `generate_edge_dims.py` | Generate N=1, K=1, and other edge-case shapes |
| `DATA_GENERATION.md` | Detailed guide for building binaries and generating data |
| `plan.md` | Full design plan with architecture, milestones, and rationale |
## Features Used (55 total)
### Problem features (13)
`M, N, K, split_k, log2(M), log2(N), log2(K), log2(MNK),
arithmetic_intensity, aspect_ratio_mn, aspect_ratio_mk, aspect_ratio_nk, layout`
### Kernel features (21)
`tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n,
warp_tile_k, pipeline, scheduler, epilogue, pad_m, pad_n, pad_k, persistent,
num_warps, tile_volume, tile_mn, lds_usage_estimate, lds_usage_ratio`
### Interaction features (9)
`num_tiles_m, num_tiles_n, num_tiles_k, total_output_tiles,
tile_eff_m, tile_eff_n, tile_eff_k, overall_tile_efficiency, cu_utilization`
### Hardware profile features (12)
`hw_num_cus, hw_simds_per_cu, hw_total_simds, hw_shader_engines,
hw_max_clock_mhz, hw_max_waves_per_cu, hw_wavefront_size, hw_lds_capacity,
hw_l1_cache_kb, hw_l2_cache_kb, hw_l3_cache_kb, hw_num_xcd`
## Model Performance
### fp8 RCR, gfx950
| Metric | 108 shapes (original) | 168 shapes (wide coverage) |
|---|---|---|
| Mean TFLOPS Efficiency | 98.28% | 97.51% |
| P10 TFLOPS Efficiency | 94.64% | 93.89% |
| tiny_m (M=1) Efficiency | 95.57% | 96.04% |
| R2 (TFLOPS) | 0.997 | 0.993 |
### fp16 RCR, gfx950
Trained on 25 shapes, 1,024 kernels, 21,920 valid benchmarks.
| Metric | Value |
|---|---|
| Mean TFLOPS Efficiency | 99.36% |
| P10 TFLOPS Efficiency | 98.05% |
| P50 TFLOPS Efficiency | 100.00% |
| Min Efficiency | 95.45% |
| NDCG@1 | 64.00% |
| Top-5 Hit Rate | 88.00% |
**Shape Family Breakdown:**
| Shape Family | Mean Eff | P10 Eff | Shapes |
|---|---|---|---|
| Large M (M≥1024) | 99.54% | 99.07% | 4 |
| Medium M (128≤M<1024) | 99.62% | 98.74% | 7 |
| Small M (8≤M<128) | 98.82% | 96.22% | 8 |
| Tiny M (M<8) | 99.65% | 98.96% | 6 |
**Pipeline Breakdown:**
| Pipeline | Mean Eff | P10 Eff |
|---|---|---|
| compv3 | 99.75% | 99.09% |
| compv4 | 99.40% | 98.54% |
| mem | 99.08% | 96.59% |
Training uses `log1p(TFLOPS)` as the target by default, which normalizes the
scale across shapes spanning 0.02 to 2230 TFLOPS. This was the key finding
that improved tiny-M shapes from 84% to 96% efficiency. See
[LEARNINGS.md](LEARNINGS.md) for details.
## Validation
Training uses `GroupKFold(n_splits=5)` with group key `(M, N, K)` to ensure
the model is evaluated on shapes it has never seen during training. Layout is
excluded from the group key to force cross-layout generalization.
## Incremental Training (Warm Start)
When new benchmark data arrives, update the model without retraining from scratch:
```bash
python3 train.py \
--data_dir data/ \
--out_dir models/v2 \
--warm_start models/gemm_universal_fp8_gfx950 \
--warm_start_n_estimators 200
```
This adds 200 new trees on top of the existing model. Feature schemas must
match exactly (automatically enforced).
## Extending to New Ops
Adding support for a new operation (e.g., `gemm_streamk`, `grouped_conv`):
1. **Build binaries**: `ninja -C build benchmark_gemm_streamk_fp8_rcr`
2. **Subclass `FeatureEngine`**: add op-specific features (e.g., StreamK split factor)
3. **Generate data**: run benchmarks across diverse shapes
4. **Train**: `python3 train.py --op gemm_streamk --dtype fp8 --data_dir data/ --out_dir models/`
The training, evaluation, prediction, and search infrastructure is fully
op-agnostic -- only the feature engine needs a new subclass.
## Tests
102 tests covering all modules:
```bash
python3 -m pytest tests/ -v
```
Test coverage includes:
- Log parsing with malformed JSON, empty logs, single-kernel shapes
- Feature formula correctness (tile efficiency, LDS usage, arithmetic intensity)
- Corner-case shapes: M=1, N=1, K=1, prime dimensions, 20480x7168x256
- Batch vs single extraction parity
- Parameter space validation and projection
- Predictor: single/batch prediction, ranking, missing models, empty inputs
- Training: group keys, efficiency computation, warm-start, feature compatibility
- Search: random, DE, config validity, determinism
## Documentation
- **[README.md](README.md)**: This file -- quick start, architecture, performance
- **[DATA_GENERATION.md](DATA_GENERATION.md)**: Complete guide for building tile engine
binaries, running benchmarks, managing datasets, and troubleshooting
- **[LEARNINGS.md](LEARNINGS.md)**: Empirical findings and design decisions (log-transform,
IHEM results, tiny-M analysis, feature importance, N=1/K=1 edge cases)

View File

@@ -0,0 +1,4 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Tile Heuristics: ML-based kernel selection

View File

@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Generate additional benchmark data for shapes NOT in the original log.
# Runs in background; outputs streaming JSON that can be parsed by data_pipeline.py.

set -u  # fail fast on unset variables

BIN_DIR="/workspace/ck_tile/bin"
OUT_LOG="data/additional_shapes.log"
WARMUP=3
REPEAT=10

mkdir -p data

# Additional shapes: square powers-of-2 and common ML sizes not in original DeepSeek set
SHAPES=(
    "64,64,64"
    "128,128,128"
    "256,256,256"
    "512,512,512"
    "1024,1024,1024"
    "2048,2048,2048"
    "4096,4096,4096"
    "1,4096,4096"
    "8,4096,4096"
    "32,4096,4096"
    "128,4096,4096"
    "1,4096,11008"
    "32,4096,11008"
    "1,8192,8192"
    "32,8192,8192"
    "1,8192,28672"
    "32,8192,28672"
    "256,256,8192"
    "8192,8192,256"
    "1024,4096,1024"
    "4096,1024,4096"
    "2048,8192,2048"
)

echo "CK Tile Additional Shapes Benchmark" > "$OUT_LOG"
echo "GPU ID: 0" >> "$OUT_LOG"
echo "Implementation: gemm_universal" >> "$OUT_LOG"
echo "" >> "$OUT_LOG"

SHAPE_IDX=0
for SHAPE in "${SHAPES[@]}"; do
    IFS=',' read -r M N K <<< "$SHAPE"
    SHAPE_IDX=$((SHAPE_IDX + 1))
    echo "========================================" >> "$OUT_LOG"
    echo "Shape $SHAPE_IDX: M=$M N=$N K=$K dtype=fp8 layout=rcr" >> "$OUT_LOG"
    echo "========================================" >> "$OUT_LOG"
    KERNEL_COUNT=0
    for EXE in "$BIN_DIR"/benchmark_gemm_universal_fp8_rcr_*; do
        # When the glob matches nothing, bash leaves the literal pattern in
        # $EXE; the original loop then counted one phantom "kernel". Skip
        # anything that is not an executable file.
        [ -x "$EXE" ] || continue
        KERNEL_COUNT=$((KERNEL_COUNT + 1))
        OUTPUT=$("$EXE" -m="$M" -n="$N" -k="$K" -warmup="$WARMUP" -repeat="$REPEAT" -verify=0 2>/dev/null)
        # Extract just the JSON block from the benchmark's stdout
        echo "$OUTPUT" | sed -n '/{/,/^}/p' >> "$OUT_LOG"
    done
    echo "Found $KERNEL_COUNT kernels" >> "$OUT_LOG"
    echo "Completed shape $SHAPE_IDX: M=$M N=$N K=$K ($KERNEL_COUNT kernels)" >&2
done
echo "Done generating additional data" >&2

View File

@@ -0,0 +1,233 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Convert benchmark JSON results to parquet format for training.
Usage:
python convert_json_to_parquet.py \
--input benchmark_results_fp16_rcr.json \
--output fp16_training_data.parquet
Features:
- Converts JSON benchmark results to flat row format
- Automatically fixes pad flags for _mem kernels
- Captures both successes and failures
- Compatible with existing training data format
"""
import argparse
import json
import pandas as pd
from pathlib import Path
def convert_json_to_parquet(json_file: Path, output_file: Path, arch: str = "gfx950"):
    """Convert benchmark JSON to parquet training data format.

    Parameters
    ----------
    json_file : Path
        Input JSON from the benchmark runner, with "metadata" and "results".
    output_file : Path
        Destination parquet path.
    arch : str
        GPU architecture tag recorded on every row.

    Returns
    -------
    pd.DataFrame
        The converted rows (also written to output_file).
    """
    print(f"Loading {json_file}...")
    with open(json_file) as f:
        data = json.load(f)
    metadata = data.get("metadata", {})
    dtype = metadata.get("dtype", "fp16")
    layout = metadata.get("layout", "rcr")
    print(f" Data type: {dtype}")
    print(f" Layout: {layout}")
    print(f" Kernels: {metadata.get('num_kernels', 0)}")
    print(f" Problem sizes: {metadata.get('num_problems', 0)}")
    print()
    # BUG FIX: bandwidth previously assumed 2 bytes/element ("FP16 = 2 bytes")
    # for every dtype, overstating bandwidth 2x for fp8 and understating it
    # for fp32. Derive element size from the dataset's dtype instead.
    bytes_per_elem = {"fp8": 1, "int8": 1, "fp16": 2, "bf16": 2, "fp32": 4}.get(dtype, 2)
    rows = []
    for kernel_result in data["results"]:
        kernel_config = kernel_result["kernel_config"]
        for benchmark in kernel_result["benchmarks"]:
            # Common fields for both valid and invalid runs
            row = {
                "op_type": "gemm_universal",
                "dtype": dtype,
                "layout": layout,
                "arch": arch,
                "kernel_name": kernel_config["name"],
                "m": benchmark["m"],
                "n": benchmark["n"],
                "k": benchmark["k"],
                "split_k": 1,
                "is_valid": benchmark["is_valid"],
                "run_id": 0,
                "pipeline": kernel_config["pipeline"],
                "epilogue": kernel_config["epilogue"],
                "scheduler": kernel_config["scheduler"],
                "pad_m": kernel_config["pad_m"],
                "pad_n": kernel_config["pad_n"],
                "pad_k": kernel_config["pad_k"],
                "persistent": kernel_config["persistent"],
                "tile_m": kernel_config["tile_m"],
                "tile_n": kernel_config["tile_n"],
                "tile_k": kernel_config["tile_k"],
                "warp_m": kernel_config["warp_m"],
                "warp_n": kernel_config["warp_n"],
                "warp_k": kernel_config["warp_k"],
                "warp_tile_m": kernel_config["warp_tile_m"],
                "warp_tile_n": kernel_config["warp_tile_n"],
                "warp_tile_k": kernel_config["warp_tile_k"],
            }
            if benchmark["is_valid"]:
                # Valid run - include performance metrics
                row["measured_tflops"] = benchmark["tflops"]
                row["latency_ms"] = benchmark["avg_time_ms"]
                # Estimate achieved bandwidth: read A (m*k) + read B (k*n)
                # + write C (m*n), each at the dtype's element width.
                m, n, k = benchmark["m"], benchmark["n"], benchmark["k"]
                bytes_transferred = (m * k + k * n + m * n) * bytes_per_elem
                if benchmark["avg_time_ms"] > 0:
                    row["bandwidth_gb_s"] = (bytes_transferred / 1e9) / (
                        benchmark["avg_time_ms"] / 1000
                    )
                else:
                    row["bandwidth_gb_s"] = 0.0
            else:
                # Failed run - zero metrics
                row["measured_tflops"] = 0.0
                row["latency_ms"] = 0.0
                row["bandwidth_gb_s"] = 0.0
            rows.append(row)
    df = pd.DataFrame(rows)
    print(f"Converted {len(df):,} benchmark results")
    print(f" Valid: {df['is_valid'].sum():,}")
    print(f" Failed: {(~df['is_valid']).sum():,}")
    print()
    # Fix pad flags for _mem kernels (critical for P1 features!)
    # The _mem pipeline always pads; the exported config flags are wrong.
    print("Fixing pad flags for _mem kernels...")
    mem_mask = df["pipeline"] == "mem"
    mem_count = mem_mask.sum()
    if mem_count > 0:
        df.loc[mem_mask, "pad_m"] = True
        df.loc[mem_mask, "pad_n"] = True
        df.loc[mem_mask, "pad_k"] = True
        print(f" ✓ Fixed {mem_count:,} _mem kernel rows")
    print()
    # Save to parquet
    df.to_parquet(output_file, index=False)
    print(f"✓ Saved to {output_file}")
    print()
    # Show statistics
    print("=" * 80)
    print("STATISTICS")
    print("=" * 80)
    print()
    print("Dimension ranges:")
    print(f" M: {df['m'].min():,} - {df['m'].max():,}")
    print(f" N: {df['n'].min():,} - {df['n'].max():,}")
    print(f" K: {df['k'].min():,} - {df['k'].max():,}")
    print()
    print("Pipeline distribution:")
    print(df["pipeline"].value_counts())
    print()
    print("Pad flag distribution:")
    pad_combos = df[["pad_m", "pad_n", "pad_k"]].value_counts()
    print(pad_combos)
    print()
    if (~df["is_valid"]).sum() > 0:
        print("Failure analysis:")
        failed = df[~df["is_valid"]]
        print(f" Total failures: {len(failed):,}")
        # Group by pipeline
        print("\n By pipeline:")
        for pipeline, count in failed["pipeline"].value_counts().items():
            print(f" {pipeline}: {count:,}")
        # Show sample failures
        print("\n Sample failures:")
        for _, row in failed.head(5).iterrows():
            print(
                f" {row['kernel_name'][:60]:60s} M={row['m']:4d} N={row['n']:4d} K={row['k']:4d}"
            )
    return df
def merge_datasets(parquet_files: list[Path], output_file: Path):
    """Concatenate multiple parquet datasets into one file on disk.

    Missing input files are reported and skipped; if nothing loads, a
    notice is printed and no output is written.
    """
    banner = "=" * 80
    print(banner)
    print("MERGING DATASETS")
    print(banner)
    print()
    frames = []
    for source in parquet_files:
        if not source.exists():
            print(f"{source} not found, skipping")
            continue
        frame = pd.read_parquet(source)
        print(f" {source.name}: {len(frame):,} rows")
        frames.append(frame)
    if not frames:
        print("No files to merge!")
        return
    merged = pd.concat(frames, ignore_index=True)
    merged.to_parquet(output_file, index=False)
    print()
    print(f"✓ Merged {len(merged):,} total rows to {output_file}")
    print()
    # Summarize the merged dataset composition.
    print("Data type distribution:")
    print(merged["dtype"].value_counts())
    print()
    print("Layout distribution:")
    print(merged["layout"].value_counts())
def main():
    """CLI entry point: convert one JSON file, optionally merging extras."""
    parser = argparse.ArgumentParser(
        description="Convert benchmark JSON to parquet training data",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("--input", type=str, required=True, help="Input JSON file from benchmark")
    parser.add_argument("--output", type=str, required=True, help="Output parquet file")
    parser.add_argument("--arch", type=str, default="gfx950", help="GPU architecture")
    parser.add_argument("--merge_with", type=str, nargs="*", help="Additional parquet files to merge")
    args = parser.parse_args()

    source = Path(args.input)
    destination = Path(args.output)

    # Convert JSON to parquet
    convert_json_to_parquet(source, destination, args.arch)

    # Merge the freshly converted file with any extra datasets if requested.
    if args.merge_with:
        extras = [Path(f) for f in args.merge_with]
        merged_output = destination.parent / f"{destination.stem}_merged.parquet"
        merge_datasets([destination] + extras, merged_output)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,394 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Data pipeline for CK Tile heuristics.
Parses benchmark logs and structured JSON into a canonical parquet dataset.
Supports:
- Streaming log format (Shape N: headers + inline JSON) from ck_tile profiling runs
- Structured JSON from generate_benchmark_data.py
- Direct parquet passthrough
"""
import json
import re
import subprocess
import hashlib
from pathlib import Path
from typing import Optional
import pandas as pd
# Canonical column schema for all parquet training datasets. Parsers ensure
# every one of these columns exists (None-filled when a source lacks it) so
# downstream feature extraction can rely on their presence.
CANONICAL_COLUMNS = [
    # Problem identity
    "op_type",
    "dtype",
    "layout",
    "arch",
    "kernel_name",
    # Problem size
    "m",
    "n",
    "k",
    "split_k",
    # Measurements
    "measured_tflops",
    "latency_ms",
    "bandwidth_gb_s",
    "is_valid",
    # Kernel tile configuration (parsed from the kernel name)
    "tile_m",
    "tile_n",
    "tile_k",
    "warp_m",
    "warp_n",
    "warp_k",
    "warp_tile_m",
    "warp_tile_n",
    "warp_tile_k",
    "pipeline",
    "scheduler",
    "epilogue",
    "pad_m",
    "pad_n",
    "pad_k",
    "persistent",
    # Provenance
    "run_id",
]
def parse_kernel_name(name: str) -> dict:
    """Extract kernel config fields from a gemm_universal kernel name.

    Name format:
        gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}
        _{padM}_{padN}_{padK}_{persistent}_{tileM}x{tileN}x{tileK}
        _{warpM}x{warpN}x{warpK}_{warpTileM}x{warpTileN}x{warpTileK}

    Returns a (possibly partial) dict; an unrecognized name yields {}.
    """
    fields: dict = {}
    try:
        match = re.match(
            r"gemm_universal_(\w+?)_((?:rcr|rrr|crr|ccr))_(.*)", name
        )
        if match is None:
            return fields
        fields["dtype"] = match.group(1)
        fields["layout"] = match.group(2)
        tokens = match.group(3).split("_")
        if len(tokens) < 10:
            return fields
        for key, token in zip(("pipeline", "epilogue", "scheduler"), tokens[:3]):
            fields[key] = token
        for key, token in zip(("pad_m", "pad_n", "pad_k", "persistent"), tokens[3:7]):
            fields[key] = token == "True"
        tile_dims = tokens[7].split("x")
        wave_dims = tokens[8].split("x")
        warp_tile_dims = tokens[9].split("x")
        fields["tile_m"] = int(tile_dims[0])
        fields["tile_n"] = int(tile_dims[1])
        fields["tile_k"] = int(tile_dims[2])
        fields["warp_m"] = int(wave_dims[0])
        fields["warp_n"] = int(wave_dims[1])
        fields["warp_k"] = int(wave_dims[2])
        fields["warp_tile_m"] = int(warp_tile_dims[0])
        fields["warp_tile_n"] = int(warp_tile_dims[1])
        fields["warp_tile_k"] = int(warp_tile_dims[2])
    except (IndexError, ValueError):
        # Malformed dimension triples: keep whatever parsed successfully.
        pass
    return fields
def _layout_from_problem(problem: dict) -> str:
"""Derive layout shorthand (rcr/rrr/etc.) from problem JSON fields."""
la = problem.get("layout_a", "")
lb = problem.get("layout_b", "")
lc = problem.get("layout_c", "")
def _tag(s):
s = s.lower()
if "row" in s:
return "r"
if "col" in s:
return "c"
return "?"
return _tag(la) + _tag(lb) + _tag(lc)
def parse_streaming_log(
    path: str | Path,
    arch: str = "unknown",
    run_id: Optional[str] = None,
    op_type: str = "gemm_universal",
) -> pd.DataFrame:
    """Parse a CK Tile streaming benchmark log into a canonical DataFrame.

    The log alternates between shape headers and JSON result blocks:

        Shape N: M=16 N=1536 K=7168 dtype=fp8 layout=rcr
        {
          "name": "gemm_universal_...",
          "problem": { ... },
          "perf_result": { "latency(ms)": ..., "tflops(TFlops)": ..., "bandwidth(GB/s)": ... }
        }

    Shape headers carry the current (m, n, k, dtype, layout) context, used
    as a fallback when a JSON block lacks its own "problem" fields. JSON
    blocks are accumulated line by line with a brace-depth counter, so
    multi-line pretty-printed objects are handled; malformed JSON blocks
    are silently skipped. run_id defaults to a hash of the file name.
    Returns a DataFrame where every CANONICAL_COLUMNS entry exists
    (None-filled when absent).
    """
    path = Path(path)
    if run_id is None:
        # Stable per-file identifier so merged datasets can trace provenance.
        run_id = hashlib.md5(path.name.encode()).hexdigest()[:12]
    shape_re = re.compile(
        r"Shape\s+\d+:\s+M=(\d+)\s+N=(\d+)\s+K=(\d+)\s+dtype=(\w+)\s+layout=(\w+)"
    )
    rows = []
    current_m, current_n, current_k = 0, 0, 0
    current_dtype, current_layout = "", ""
    json_buf = []       # lines of the JSON object currently being accumulated
    brace_depth = 0     # net open braces; 0 means "not inside a JSON block"
    with open(path, "r") as f:
        for line in f:
            stripped = line.strip()
            shape_match = shape_re.search(stripped)
            if shape_match:
                # New shape header: update the fallback context.
                current_m = int(shape_match.group(1))
                current_n = int(shape_match.group(2))
                current_k = int(shape_match.group(3))
                current_dtype = shape_match.group(4)
                current_layout = shape_match.group(5)
                continue
            if brace_depth == 0 and stripped.startswith("{"):
                # Start of a JSON block; may also complete on this same line.
                json_buf = [stripped]
                brace_depth = stripped.count("{") - stripped.count("}")
                if brace_depth == 0:
                    raw = "\n".join(json_buf)
                    try:
                        obj = json.loads(raw)
                    except json.JSONDecodeError:
                        continue
                else:
                    continue
            elif brace_depth > 0:
                # Inside a multi-line JSON block: accumulate until balanced.
                json_buf.append(stripped)
                brace_depth += stripped.count("{") - stripped.count("}")
                if brace_depth <= 0:
                    brace_depth = 0
                    raw = "\n".join(json_buf)
                    try:
                        obj = json.loads(raw)
                    except json.JSONDecodeError:
                        continue
                else:
                    continue
            else:
                # Plain text outside any JSON block: ignore.
                continue
            # If we get here, obj was successfully parsed
            kernel_name = obj.get("name", "")
            problem = obj.get("problem", {})
            perf = obj.get("perf_result", {})
            # Prefer the block's own problem fields; fall back to the header.
            m = problem.get("m", current_m)
            n = problem.get("n", current_n)
            k = problem.get("k", current_k)
            split_k = problem.get("split_k", 1)
            dtype = problem.get("dtype_a", current_dtype)
            layout = (
                _layout_from_problem(problem)
                if problem.get("layout_a")
                else current_layout
            )
            tflops = perf.get("tflops(TFlops)", 0.0)
            latency = perf.get("latency(ms)", 0.0)
            bandwidth = perf.get("bandwidth(GB/s)", 0.0)
            kp = parse_kernel_name(kernel_name)
            row = {
                "op_type": op_type,
                "dtype": dtype,
                "layout": layout,
                "arch": arch,
                "kernel_name": kernel_name,
                "m": m,
                "n": n,
                "k": k,
                "split_k": split_k,
                "measured_tflops": tflops,
                "latency_ms": latency,
                "bandwidth_gb_s": bandwidth,
                # A run is valid only if it produced positive perf numbers.
                "is_valid": tflops > 0 and latency > 0,
                "run_id": run_id,
            }
            row.update(kp)
            rows.append(row)
    df = pd.DataFrame(rows)
    # Guarantee the canonical schema even for columns no row provided.
    for col in CANONICAL_COLUMNS:
        if col not in df.columns:
            df[col] = None
    return df
def get_hardware_profile() -> dict:
    """Capture GPU hardware profile from rocminfo.

    Parses the first GPU agent section of `rocminfo` output into a dict of
    fields (gfx_name, num_cus, simds_per_cu, ...), plus cache sizes in a
    second pass. Returns an empty dict when rocminfo is unavailable or
    times out (e.g. on a non-ROCm host).
    """
    profile = {}
    try:
        result = subprocess.run(
            ["rocminfo"], capture_output=True, text=True, timeout=30
        )
        output = result.stdout
        gpu_section = False
        for line in output.split("\n"):
            line = line.strip()
            # Enter the first agent section whose Device Type is GPU.
            if "Device Type:" in line and "GPU" in line:
                gpu_section = True
                continue
            # Stop at the next non-GPU agent section.
            if gpu_section and "Device Type:" in line and "GPU" not in line:
                break
            if not gpu_section:
                continue
            if ":" not in line:
                continue
            key, _, val = line.partition(":")
            key = key.strip()
            val = val.strip()
            # rocminfo label -> canonical profile field name.
            mapping = {
                "Name": "gfx_name",
                "Marketing Name": "marketing_name",
                "Compute Unit": "num_cus",
                "SIMDs per CU": "simds_per_cu",
                "Shader Engines": "shader_engines",
                "Shader Arrs. per Eng.": "shader_arrays_per_engine",
                "Max Clock Freq. (MHz)": "max_clock_mhz",
                "Wavefront Size": "wavefront_size",
                "Max Waves Per CU": "max_waves_per_cu",
                "Chip ID": "chip_id",
            }
            if key in mapping:
                # Values may look like "304(0x130)": keep the part before "(",
                # storing as int when possible, raw string otherwise.
                raw = val.split("(")[0].strip()
                try:
                    profile[mapping[key]] = int(raw)
                except ValueError:
                    profile[mapping[key]] = raw
        # Second pass: cache sizes. NOTE(review): this scans the entire
        # rocminfo output, not just the GPU section — L2/L3 lines from other
        # agents could be picked up; confirm against rocminfo ordering.
        for line in output.split("\n"):
            line = line.strip()
            # Only trust L1 once we know we saw a GPU section (num_cus set).
            if line.startswith("L1:") and "num_cus" in profile:
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l1_cache_kb"] = int(raw)
                except ValueError:
                    pass
            elif line.startswith("L2:"):
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l2_cache_kb"] = int(raw)
                except ValueError:
                    pass
            elif line.startswith("L3:"):
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l3_cache_kb"] = int(raw)
                except ValueError:
                    pass
    except (subprocess.TimeoutExpired, FileNotFoundError):
        # rocminfo missing or hung: best-effort, return what we have.
        pass
    return profile
def load_parquet(path: str | Path) -> pd.DataFrame:
    """Load a canonical parquet dataset from disk."""
    dataset_path = Path(path)
    return pd.read_parquet(dataset_path)
def save_parquet(df: pd.DataFrame, path: str | Path):
    """Write a DataFrame in canonical parquet format, creating parent dirs."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(target, index=False, engine="pyarrow")
def build_training_dataset(
    data_dir: str | Path,
    op_type: str = "gemm_universal",
    dtype: str = "fp8",
) -> pd.DataFrame:
    """Load and merge all parquet files matching the given op/dtype from a directory.

    Files lacking op_type/dtype columns pass through unfiltered on that key.
    Raises FileNotFoundError when no file contributes any matching rows.
    """
    directory = Path(data_dir)
    matching = []
    for parquet_path in sorted(directory.glob("*.parquet")):
        frame = pd.read_parquet(parquet_path)
        if "op_type" in frame.columns:
            frame = frame[frame["op_type"] == op_type]
        if "dtype" in frame.columns:
            frame = frame[frame["dtype"] == dtype]
        if len(frame) > 0:
            matching.append(frame)
    if not matching:
        raise FileNotFoundError(
            f"No parquet files with op_type={op_type}, dtype={dtype} in {directory}"
        )
    return pd.concat(matching, ignore_index=True)
if __name__ == "__main__":
    import argparse
    import time

    # CLI: parse a raw benchmark log (or pass through an existing parquet)
    # into the canonical parquet schema, optionally attaching hw_* columns.
    parser = argparse.ArgumentParser(description="Parse CK Tile benchmark data")
    parser.add_argument("input", help="Input file (log or parquet)")
    parser.add_argument("--output", "-o", required=True, help="Output parquet path")
    parser.add_argument("--arch", default="gfx950", help="GPU architecture")
    parser.add_argument("--op_type", default="gemm_universal", help="Operation type")
    parser.add_argument(
        "--capture_hw",
        action="store_true",
        help="Capture hardware profile from rocminfo",
    )
    args = parser.parse_args()
    input_path = Path(args.input)
    print(f"Parsing {input_path}...")
    t0 = time.time()
    # Parquet inputs load as-is; everything else is treated as a streaming log.
    if input_path.suffix == ".parquet":
        df = load_parquet(input_path)
    else:
        df = parse_streaming_log(input_path, arch=args.arch, op_type=args.op_type)
    elapsed = time.time() - t0
    print(f"Parsed {len(df)} rows in {elapsed:.1f}s")
    print(f" Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    print(f" Unique kernels: {df['kernel_name'].nunique()}")
    print(f" Valid rows: {df['is_valid'].sum()} / {len(df)}")
    if df["measured_tflops"].max() > 0:
        print(
            f" TFLOPS range: {df['measured_tflops'].min():.2f} - {df['measured_tflops'].max():.2f}"
        )
    if args.capture_hw:
        # Broadcast each hardware field across all rows as an hw_* column.
        hw = get_hardware_profile()
        print(f" Hardware profile: {hw}")
        for k, v in hw.items():
            df[f"hw_{k}"] = v
    save_parquet(df, args.output)
    print(f"Saved to {args.output}")

View File

@@ -0,0 +1,324 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Dispatcher integration for ML-based kernel selection.
Bridges the trained LightGBM Predictor with the CK Tile dispatcher's
kernel selection flow. Provides heuristic functions compatible with
both the Python pre-selection pattern (08_heuristics.py style) and
the C++ HeuristicFunction signature.
Name mapping between feature engine and dispatcher KernelConfig:
Feature engine Dispatcher KernelConfig
--------------------- ----------------------
warp_m (warps/block) wave_m
warp_n wave_n
warp_k wave_k
warp_tile_m warp_m
warp_tile_n warp_n
warp_tile_k warp_k
Usage:
from dispatcher_integration import create_ml_heuristic
heuristic = create_ml_heuristic("models/gemm_universal_fp8_gfx950")
best_spec = heuristic(M=1024, N=1024, K=1024, kernel_pool=KERNEL_POOL)
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from data_pipeline import parse_kernel_name
from predict import Predictor
# Maps the compact layout shorthand used in kernel names to the
# (layout_a, layout_b, layout_c) triple the dispatcher KernelConfig expects.
LAYOUT_TO_DISPATCHER = {
    "rcr": ("row", "col", "row"),
    "rrr": ("row", "row", "row"),
    "crr": ("col", "row", "row"),
    "ccr": ("col", "col", "row"),
}

# Maps the problem dtype to the output (C matrix) dtype string.
# NOTE(review): fp8 problems map to fp16 output here — presumably the fp8
# tile-engine kernels store results in fp16; confirm against kernel configs.
DTYPE_TO_C_DTYPE = {
    "fp8": "fp16",
    "fp16": "fp16",
    "bf16": "bf16",
    "fp32": "fp32",
}
@dataclass
class MLKernelSpec:
    """Kernel spec returned by the ML heuristic, compatible with the dispatcher
    example pattern. Carries both the feature-engine-space config and the
    dispatcher-space KernelConfig fields."""

    kernel_name: str          # full tile-engine kernel name
    predicted_tflops: float   # model-predicted TFLOPS for the queried problem
    # Block tile dimensions
    tile_m: int
    tile_n: int
    tile_k: int
    # Warps per block (feature-engine warp_m/n/k)
    wave_m: int
    wave_n: int
    wave_k: int
    # Per-warp tile dimensions (feature-engine warp_tile_m/n/k)
    warp_m: int
    warp_n: int
    warp_k: int
    # Pipeline / scheduling configuration
    pipeline: str
    scheduler: str
    epilogue: str
    # Dimension padding flags
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool          # persistent-kernel variant flag
def kernel_config_to_feature_dict(kernel_name: str) -> dict:
    """Parse a tile-engine kernel name into a feature-engine-compatible dict.

    The returned dict carries the fields GemmUniversalFeatureEngine.extract()
    expects for its kernel parameter (tile_m/n/k, warp_m/n/k as warps per
    block, warp_tile_m/n/k, pipeline, scheduler, epilogue, pad flags,
    persistent) plus the original kernel_name. Unparseable names yield {}.
    """
    config = parse_kernel_name(kernel_name)
    if config:
        config["kernel_name"] = kernel_name
        return config
    return {}
def feature_dict_to_dispatcher_config(
    feat: dict, dtype: str = "fp8", arch: str = "gfx950"
) -> dict:
    """Convert a feature-engine kernel dict to dispatcher KernelConfig fields.

    Handles the naming inversion between the two schemas:
        feature engine warp_m      -> KernelConfig wave_m (warps per block)
        feature engine warp_tile_m -> KernelConfig warp_m (elements per warp)
    """
    la, lb, lc = LAYOUT_TO_DISPATCHER.get(
        feat.get("layout", "rcr"), ("row", "col", "row")
    )
    config = {
        "dtype_a": dtype,
        "dtype_b": dtype,
        "dtype_c": DTYPE_TO_C_DTYPE.get(dtype, dtype),
        "dtype_acc": "fp32",
        "layout_a": la,
        "layout_b": lb,
        "layout_c": lc,
    }
    # (KernelConfig key) <- (feature-engine key, default) in output order.
    field_map = (
        ("tile_m", "tile_m", 128),
        ("tile_n", "tile_n", 128),
        ("tile_k", "tile_k", 64),
        ("wave_m", "warp_m", 2),
        ("wave_n", "warp_n", 2),
        ("wave_k", "warp_k", 1),
        ("warp_m", "warp_tile_m", 32),
        ("warp_n", "warp_tile_n", 32),
        ("warp_k", "warp_tile_k", 16),
        ("pipeline", "pipeline", "compv3"),
        ("scheduler", "scheduler", "intrawave"),
        ("epilogue", "epilogue", "cshuffle"),
        ("pad_m", "pad_m", True),
        ("pad_n", "pad_n", True),
        ("pad_k", "pad_k", True),
    )
    for out_key, feat_key, default in field_map:
        config[out_key] = feat.get(feat_key, default)
    config["gfx_arch"] = arch
    return config
def feature_dict_to_ml_spec(feat: dict, predicted_tflops: float = 0.0) -> MLKernelSpec:
    """Convert a feature-engine kernel dict + prediction to an MLKernelSpec."""

    def pick(key, default):
        # Fall back to a conventional default for any missing field.
        return feat.get(key, default)

    return MLKernelSpec(
        kernel_name=pick("kernel_name", "unknown"),
        predicted_tflops=predicted_tflops,
        tile_m=pick("tile_m", 128),
        tile_n=pick("tile_n", 128),
        tile_k=pick("tile_k", 64),
        wave_m=pick("warp_m", 2),
        wave_n=pick("warp_n", 2),
        wave_k=pick("warp_k", 1),
        warp_m=pick("warp_tile_m", 32),
        warp_n=pick("warp_tile_n", 32),
        warp_k=pick("warp_tile_k", 16),
        pipeline=pick("pipeline", "compv3"),
        scheduler=pick("scheduler", "intrawave"),
        epilogue=pick("epilogue", "cshuffle"),
        pad_m=pick("pad_m", False),
        pad_n=pick("pad_n", False),
        pad_k=pick("pad_k", False),
        persistent=pick("persistent", False),
    )
def load_kernel_pool_from_binaries(bin_dir: str | Path) -> list[dict]:
    """Discover benchmark executables and parse their names into feature dicts.

    Each executable name encodes the full kernel config, so the candidate
    pool for the ML heuristic can be built without a registry JSON export.
    Names that fail to parse (or lack tile dimensions) are dropped.
    """
    pool = []
    for exe_path in sorted(Path(bin_dir).glob("benchmark_gemm_universal_*")):
        encoded_name = exe_path.stem.replace("benchmark_", "")
        parsed = kernel_config_to_feature_dict(encoded_name)
        if parsed and "tile_m" in parsed:
            pool.append(parsed)
    return pool
def create_ml_heuristic(
    model_dir: str | Path,
    dtype: str = "fp8",
    arch: str = "gfx950",
    layout: str = "rcr",
    kernel_pool: Optional[list[dict]] = None,
    bin_dir: Optional[str | Path] = None,
):
    """Build a top-1 ML kernel-selection heuristic.

    The returned callable has signature ``(M, N, K) -> MLKernelSpec``: it
    scores every candidate in the kernel pool with the trained LightGBM
    regressor and returns the best-predicted kernel.

    Parameters
    ----------
    model_dir : str or Path
        Trained model directory (must contain model_tflops.lgbm or
        model_tflops_log_big.lgbm and feature_spec.json).
    dtype : str
        Problem data type (fp8, fp16, bf16).
    arch : str
        GPU architecture (gfx942, gfx950).
        NOTE(review): currently unused inside this factory — verify whether
        the predictor should receive it.
    layout : str
        Matrix layout (rcr, rrr, crr, ccr).
    kernel_pool : list of dict, optional
        Pre-parsed kernel configs; discovered from bin_dir when omitted.
    bin_dir : str or Path, optional
        Directory with benchmark executables, used to build kernel_pool when
        it is not provided. Defaults to /workspace/ck_tile/bin.

    Raises
    ------
    ValueError
        If no kernel configs could be discovered.
    """
    predictor = Predictor(Path(model_dir))
    if kernel_pool is None:
        search_dir = Path(bin_dir) if bin_dir is not None else Path("/workspace/ck_tile/bin")
        kernel_pool = load_kernel_pool_from_binaries(search_dir)
    if not kernel_pool:
        raise ValueError(
            "No kernel configs found. Check bin_dir or provide kernel_pool."
        )

    def heuristic(M: int, N: int, K: int) -> MLKernelSpec:
        """Score the pool for (M, N, K) and return the best-predicted spec."""
        problem = {
            "m": M,
            "n": N,
            "k": K,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }
        ranked = predictor.rank_kernels(problem, kernel_pool)
        if not ranked:
            # No ranking available: fall back to the first candidate.
            return feature_dict_to_ml_spec(kernel_pool[0], 0.0)
        top_name, top_tflops = ranked[0]
        for candidate in kernel_pool:
            if candidate.get("kernel_name") == top_name:
                return feature_dict_to_ml_spec(candidate, top_tflops)
        return feature_dict_to_ml_spec(kernel_pool[0], top_tflops)

    return heuristic
def create_ranked_heuristic(
    model_dir: str | Path,
    dtype: str = "fp8",
    arch: str = "gfx950",
    layout: str = "rcr",
    kernel_pool: Optional[list[dict]] = None,
    bin_dir: Optional[str | Path] = None,
    top_k: int = 5,
):
    """Create an ML heuristic that returns the top-K ranked kernel specs.

    Returns a callable with signature:
        (M: int, N: int, K: int) -> list[MLKernelSpec]

    Useful when you want fallback options if the top-1 kernel fails to build.
    Parameters mirror create_ml_heuristic; top_k bounds the number of specs
    returned per query.

    Raises
    ------
    ValueError
        If no kernel configs could be discovered. (Consistency fix: this
        guard exists in create_ml_heuristic but was missing here, which let
        an empty pool surface later as an IndexError inside the heuristic.)
    """
    model_dir = Path(model_dir)
    predictor = Predictor(model_dir)
    if kernel_pool is None:
        if bin_dir is None:
            bin_dir = Path("/workspace/ck_tile/bin")
        kernel_pool = load_kernel_pool_from_binaries(bin_dir)
    if not kernel_pool:
        # Fail fast with the same message as create_ml_heuristic.
        raise ValueError(
            "No kernel configs found. Check bin_dir or provide kernel_pool."
        )
    name_to_feat = {kp.get("kernel_name", ""): kp for kp in kernel_pool}

    def heuristic(M: int, N: int, K: int) -> list[MLKernelSpec]:
        """Score the pool for (M, N, K) and return the top_k specs."""
        problem = {
            "m": M,
            "n": N,
            "k": K,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }
        ranked = predictor.rank_kernels(problem, kernel_pool)
        results = []
        for name, tflops in ranked[:top_k]:
            feat = name_to_feat.get(name, kernel_pool[0])
            results.append(feature_dict_to_ml_spec(feat, tflops))
        return results

    return heuristic
def ml_spec_to_dispatcher_config(
    spec: MLKernelSpec, dtype: str = "fp8", arch: str = "gfx950", layout: str = "rcr"
) -> dict:
    """Convert an MLKernelSpec to a dict compatible with ctypes_utils.KernelConfig.

    Parameters
    ----------
    spec : MLKernelSpec
        Spec produced by one of the ML heuristics.
    dtype : str
        Problem data type; also determines the output (C) dtype.
    arch : str
        GPU architecture string recorded as gfx_arch.
    layout : str
        Layout shorthand (rcr, rrr, crr, ccr). Previously the A/B/C layouts
        were hardcoded to row/col/row (rcr); this parameter generalizes that
        via LAYOUT_TO_DISPATCHER — consistent with
        feature_dict_to_dispatcher_config — while the default preserves the
        old behavior for existing callers.
    """
    layout_a, layout_b, layout_c = LAYOUT_TO_DISPATCHER.get(
        layout, ("row", "col", "row")
    )
    c_dtype = DTYPE_TO_C_DTYPE.get(dtype, dtype)
    return {
        "dtype_a": dtype,
        "dtype_b": dtype,
        "dtype_c": c_dtype,
        "dtype_acc": "fp32",
        "layout_a": layout_a,
        "layout_b": layout_b,
        "layout_c": layout_c,
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
        "wave_m": spec.wave_m,
        "wave_n": spec.wave_n,
        "wave_k": spec.wave_k,
        "warp_m": spec.warp_m,
        "warp_n": spec.warp_n,
        "warp_k": spec.warp_k,
        "pipeline": spec.pipeline,
        "scheduler": spec.scheduler,
        "epilogue": spec.epilogue,
        "pad_m": spec.pad_m,
        "pad_n": spec.pad_n,
        "pad_k": spec.pad_k,
        "gfx_arch": arch,
    }

View File

@@ -0,0 +1,254 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Evaluation and reporting for CK Tile kernel performance models.
Computes:
- Global metrics: TFLOPS efficiency (mean, p10, p50, min), R2, NDCG@1, Top-K hit rate
- Per-slice breakdowns: by layout, shape family, K-depth regime, pipeline
- Cross-target consistency checks
- Feature importance analysis
Usage:
python evaluate.py --model_dir models/gemm_universal_fp8_gfx950 --data_dir data/
"""
import argparse
import json
import numpy as np
import pandas as pd
from data_pipeline import build_training_dataset
from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor
from train import compute_tflops_efficiency
def classify_shape_family(m: int, n: int, k: int) -> str:
    """Classify a GEMM shape into a family for sliced evaluation.

    Families (determined by M only; n and k are accepted for a stable call
    signature but do not affect the result):
      - tiny_m:   M < 32    (single-token / very small batch inference)
      - small_m:  32 <= M < 256
      - medium_m: 256 <= M < 4096
      - large_m:  M >= 4096

    Note: the original docstring also advertised square/tall/wide families
    that were never produced, and the trailing ``return "other"`` was
    unreachable; both have been removed without changing behavior.
    """
    if m < 32:
        return "tiny_m"
    if m < 256:
        return "small_m"
    if m < 4096:
        return "medium_m"
    return "large_m"
def classify_k_regime(k: int) -> str:
    """Classify K dimension into depth regime (shallow/medium/deep)."""
    # Upper-exclusive bounds paired with their regime labels.
    for bound, label in ((512, "shallow_k"), (4096, "medium_k")):
        if k < bound:
            return label
    return "deep_k"
def evaluate_model(
    predictor: Predictor,
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
) -> dict:
    """Run full evaluation on a dataset. Returns a metrics dictionary.
    Parameters
    ----------
    predictor : Predictor
        Trained predictor with at least a TFLOPS model loaded.
    df : pd.DataFrame
        Benchmark data in canonical schema.
    feature_engine : GemmUniversalFeatureEngine
        Feature engine matching the trained model.
    Returns
    -------
    dict with keys: global_metrics, shape_family_metrics, k_regime_metrics,
    pipeline_metrics, per_shape_efficiency.
    """
    # Keep only rows that ran successfully and produced a positive TFLOPS
    # measurement; reset the index so predictions align positionally.
    valid = df[df["is_valid"].fillna(False) & (df["measured_tflops"] > 0)].copy()
    valid = valid.reset_index(drop=True)
    X = feature_engine.extract_batch(valid)
    # NOTE(review): reaches into Predictor internals (_load_model and,
    # below, _log_targets) -- confirm these private members are stable.
    model = predictor._load_model("tflops")
    if model is None:
        raise FileNotFoundError("No TFLOPS model found")
    # Predict and apply inverse log transform if model was trained in log-space
    raw_pred = model.predict(X)
    if "tflops" in predictor._log_targets:
        # expm1 inverts a log1p-style target transform -- presumably the
        # transform applied in train.py; verify against the training code.
        valid["pred_tflops"] = np.expm1(raw_pred)
    else:
        # Clamp to non-negative even for non-log models
        valid["pred_tflops"] = np.maximum(0.0, raw_pred)
    # Global regression-quality metrics computed on raw TFLOPS values.
    y_true = valid["measured_tflops"].values
    y_pred = valid["pred_tflops"].values
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1 - ss_res / max(ss_tot, 1e-10)
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mae = np.mean(np.abs(y_true - y_pred))
    eff_df = compute_tflops_efficiency(valid, "pred_tflops")
    # Ranking metrics per problem shape: "NDCG@1" here is top-1 accuracy
    # (model's best-predicted kernel == measured-fastest kernel); top-K hit
    # rate counts shapes where the oracle kernel is within the model's top K.
    ndcg1_count = 0
    total_shapes = 0
    topk_hits = {3: 0, 5: 0, 10: 0}
    for (m, n, k), group in valid.groupby(["m", "n", "k"]):
        if group["measured_tflops"].max() <= 0:
            continue
        total_shapes += 1
        # idxmax/sort_values use the reset DataFrame index as the kernel id.
        oracle_idx = group["measured_tflops"].idxmax()
        pred_ranking = group.sort_values("pred_tflops", ascending=False).index.tolist()
        if pred_ranking[0] == oracle_idx:
            ndcg1_count += 1
        oracle_rank = pred_ranking.index(oracle_idx)
        for topk in topk_hits:
            if oracle_rank < topk:
                topk_hits[topk] += 1
    # max(..., 1) guards against division by zero when no shapes qualify.
    global_metrics = {
        "r2": r2,
        "rmse": rmse,
        "mae": mae,
        "num_valid_rows": len(valid),
        "num_shapes": total_shapes,
        "efficiency_mean": float(eff_df["efficiency"].mean()) if len(eff_df) > 0 else 0,
        "efficiency_p10": float(eff_df["efficiency"].quantile(0.1))
        if len(eff_df) > 0
        else 0,
        "efficiency_p50": float(eff_df["efficiency"].quantile(0.5))
        if len(eff_df) > 0
        else 0,
        "efficiency_min": float(eff_df["efficiency"].min()) if len(eff_df) > 0 else 0,
        "ndcg_at_1": ndcg1_count / max(total_shapes, 1),
        "top3_hit_rate": topk_hits[3] / max(total_shapes, 1),
        "top5_hit_rate": topk_hits[5] / max(total_shapes, 1),
        "top10_hit_rate": topk_hits[10] / max(total_shapes, 1),
    }
    # Helper: summarize efficiency over one slice of the data; returns a
    # {"count": 0} stub when the slice is empty so callers can skip it.
    def _slice_efficiency(slice_df):
        if len(slice_df) == 0:
            return {"count": 0}
        eff = compute_tflops_efficiency(slice_df, "pred_tflops")
        if len(eff) == 0:
            return {"count": 0}
        return {
            "count": len(eff),
            "mean": float(eff["efficiency"].mean()),
            "p10": float(eff["efficiency"].quantile(0.1)),
            "min": float(eff["efficiency"].min()),
        }
    # Per-slice breakdowns: shape family (by M), K-depth regime, pipeline.
    valid["shape_family"] = valid.apply(
        lambda r: classify_shape_family(r["m"], r["n"], r["k"]), axis=1
    )
    valid["k_regime"] = valid["k"].apply(classify_k_regime)
    shape_family_metrics = {}
    for family, group in valid.groupby("shape_family"):
        shape_family_metrics[family] = _slice_efficiency(group)
    k_regime_metrics = {}
    for regime, group in valid.groupby("k_regime"):
        k_regime_metrics[regime] = _slice_efficiency(group)
    pipeline_metrics = {}
    if "pipeline" in valid.columns:
        for pipeline, group in valid.groupby("pipeline"):
            pipeline_metrics[str(pipeline)] = _slice_efficiency(group)
    return {
        "global_metrics": global_metrics,
        "shape_family_metrics": shape_family_metrics,
        "k_regime_metrics": k_regime_metrics,
        "pipeline_metrics": pipeline_metrics,
        "per_shape_efficiency": eff_df.to_dict(orient="records")
        if len(eff_df) > 0
        else [],
    }
def main():
    """CLI entry point: load benchmark data, evaluate the trained TFLOPS
    model, print a human-readable report, and optionally dump the full
    metrics dictionary as JSON."""
    parser = argparse.ArgumentParser(description="Evaluate CK Tile performance model")
    parser.add_argument(
        "--model_dir", required=True, help="Directory with trained models"
    )
    parser.add_argument("--data_dir", required=True, help="Directory with parquet data")
    parser.add_argument("--op", default="gemm_universal")
    parser.add_argument("--dtype", default="fp8")
    parser.add_argument("--output", "-o", help="Output JSON path for metrics")
    args = parser.parse_args()
    print(f"Loading data from {args.data_dir}...")
    df = build_training_dataset(args.data_dir, op_type=args.op, dtype=args.dtype)
    print(f"  {len(df)} rows, {df.groupby(['m', 'n', 'k']).ngroups} shapes")
    # The feature engine must match the one used at training time.
    fe = GemmUniversalFeatureEngine()
    predictor = Predictor(args.model_dir, feature_engine=fe)
    print("Evaluating...")
    results = evaluate_model(predictor, df, fe)
    gm = results["global_metrics"]
    print("\nGlobal Metrics:")
    print(f"  R2:              {gm['r2']:.4f}")
    print(f"  RMSE:            {gm['rmse']:.2f}")
    print(f"  Efficiency Mean: {gm['efficiency_mean']:.4f}")
    print(f"  Efficiency P10:  {gm['efficiency_p10']:.4f}")
    print(f"  Efficiency P50:  {gm['efficiency_p50']:.4f}")
    print(f"  Efficiency Min:  {gm['efficiency_min']:.4f}")
    print(f"  NDCG@1:          {gm['ndcg_at_1']:.4f}")
    print(f"  Top-3 Hit Rate:  {gm['top3_hit_rate']:.4f}")
    print(f"  Top-5 Hit Rate:  {gm['top5_hit_rate']:.4f}")
    print(f"  Top-10 Hit Rate: {gm['top10_hit_rate']:.4f}")
    # Only slices that actually contained rows are printed.
    print("\nShape Family Breakdown:")
    for family, metrics in sorted(results["shape_family_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f"  {family:12s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} min={metrics['min']:.4f} (n={metrics['count']})"
            )
    print("\nK-Depth Regime Breakdown:")
    for regime, metrics in sorted(results["k_regime_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f"  {regime:12s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} min={metrics['min']:.4f} (n={metrics['count']})"
            )
    print("\nPipeline Breakdown:")
    for pipeline, metrics in sorted(results["pipeline_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f"  {pipeline:15s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} (n={metrics['count']})"
            )
    # default=str lets numpy scalars and other non-JSON types serialize.
    if args.output:
        with open(args.output, "w") as f:
            json.dump(results, f, indent=2, default=str)
        print(f"\nFull results saved to {args.output}")
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,577 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Feature engineering for CK Tile kernel performance prediction.
Provides a strict FeatureEngine interface with per-op subclasses.
All feature engines produce a consistent numpy array for LightGBM.
"""
import math
from abc import ABC, abstractmethod
import numpy as np
import pandas as pd
# Bytes per element for each supported data type (int4 packs two values per byte).
DTYPE_BYTES = {
    "fp32": 4.0,
    "fp16": 2.0,
    "bf16": 2.0,
    "fp8": 1.0,
    "bf8": 1.0,
    "int8": 1.0,
    "int4": 0.5,
}
# Integer codes for categorical attributes fed to LightGBM; call sites fall
# back to code 0 for unknown values (dict.get(..., 0) / .map(...).fillna(0)).
LAYOUT_MAP = {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3}
PIPELINE_MAP = {"compv3": 0, "compv4": 1, "compv5": 2, "mem": 3, "preshufflev2": 4}
SCHEDULER_MAP = {"intrawave": 0, "interwave": 1}
EPILOGUE_MAP = {"default": 0, "cshuffle": 1}
class FeatureEngine(ABC):
    """Abstract base for per-op feature extraction.

    Subclasses declare an ordered feature vector (``get_feature_names``)
    and how to compute it for one (problem, kernel) pair (``extract``);
    this base class supplies generic batch extraction plus kernel-config
    validation and projection onto the discrete parameter space.
    """
    @abstractmethod
    def get_feature_names(self) -> list[str]:
        """Ordered list of feature names matching the output array columns."""
        ...
    @abstractmethod
    def get_categorical_features(self) -> list[str]:
        """Feature names that should be treated as categorical by LightGBM."""
        ...
    @abstractmethod
    def extract(self, problem: dict, kernel: dict) -> np.ndarray:
        """Extract a single feature vector from a (problem, kernel) pair."""
        ...
    def extract_batch(self, df: pd.DataFrame) -> np.ndarray:
        """Row-by-row batch extraction from a DataFrame. Override for speed."""
        feature_names = self.get_feature_names()
        out = np.zeros((len(df), len(feature_names)), dtype=np.float64)
        for row_idx in range(len(df)):
            record = df.iloc[row_idx]
            # The same row serves as both the problem view and the kernel
            # view; two separate dicts are handed to extract().
            out[row_idx] = self.extract(record.to_dict(), record.to_dict())
        return out
    def get_parameter_space(self) -> dict[str, list]:
        """Valid discrete values for each kernel parameter (for surrogate search)."""
        return {}
    def get_constraints(self) -> list:
        """Multi-param constraint functions returning True if config is valid."""
        return []
    def validate_config(self, config: dict) -> bool:
        """Return True iff the config passes the parameter space and all constraints."""
        space = self.get_parameter_space()
        if any(
            key in config and config[key] not in allowed
            for key, allowed in space.items()
        ):
            return False
        return all(check(config) for check in self.get_constraints())
    def project_to_valid(self, config: dict) -> dict:
        """Snap a config to the nearest valid discrete point."""
        snapped = dict(config)
        for key, allowed in self.get_parameter_space().items():
            if key not in snapped:
                continue
            current = snapped[key]
            if isinstance(allowed[0], (int, float)):
                # Numeric axis: choose the closest allowed value.
                snapped[key] = min(allowed, key=lambda cand: abs(cand - current))
            elif current not in allowed:
                # Categorical axis: fall back to the first allowed value.
                snapped[key] = allowed[0]
        return snapped
class GemmUniversalFeatureEngine(FeatureEngine):
    """Feature engine for gemm_universal kernels.

    Produces a 72-column float64 feature vector combining problem shape,
    kernel tiling configuration, derived interaction features, and a fixed
    hardware profile. Column order is defined by get_feature_names() and is
    relied upon positionally in extract()/extract_batch().
    """
    def __init__(
        self,
        num_cus: int = 256,
        lds_capacity: int = 65536,
        max_clock_mhz: int = 2400,
        simds_per_cu: int = 4,
        shader_engines: int = 32,
        max_waves_per_cu: int = 32,
        wavefront_size: int = 64,
        l1_cache_kb: int = 32,
        l2_cache_kb: int = 4096,
        l3_cache_kb: int = 262144,
        num_xcd: int = 8,
    ):
        # NOTE(review): defaults look like a CDNA-class accelerator profile
        # (256 CUs, 64 KiB LDS, 8 XCDs) -- confirm the intended target arch.
        self._hw = {
            "num_cus": num_cus,
            "lds_capacity": lds_capacity,
            "max_clock_mhz": max_clock_mhz,
            "simds_per_cu": simds_per_cu,
            "shader_engines": shader_engines,
            "max_waves_per_cu": max_waves_per_cu,
            "wavefront_size": wavefront_size,
            "l1_cache_kb": l1_cache_kb,
            "l2_cache_kb": l2_cache_kb,
            "l3_cache_kb": l3_cache_kb,
            "num_xcd": num_xcd,
            "total_simds": num_cus * simds_per_cu,
        }
    def get_feature_names(self) -> list[str]:
        # The order here is the positional contract for extract() and
        # extract_batch(); keep all three in sync when adding features.
        return [
            # Problem features
            "M",
            "N",
            "K",
            "split_k",
            "log2_M",
            "log2_N",
            "log2_K",
            "log2_MNK",
            "arithmetic_intensity",
            "aspect_ratio_mn",
            "aspect_ratio_mk",
            "aspect_ratio_nk",
            "layout",
            # Kernel features
            "tile_m",
            "tile_n",
            "tile_k",
            "warp_m",
            "warp_n",
            "warp_k",
            "warp_tile_m",
            "warp_tile_n",
            "warp_tile_k",
            "pipeline",
            "scheduler",
            "epilogue",
            "pad_m",
            "pad_n",
            "pad_k",
            "persistent",
            "num_warps",
            "tile_volume",
            "tile_mn",
            "lds_usage_estimate",
            "lds_usage_ratio",
            # Interaction features
            "num_tiles_m",
            "num_tiles_n",
            "num_tiles_k",
            "total_output_tiles",
            "tile_eff_m",
            "tile_eff_n",
            "tile_eff_k",
            "overall_tile_efficiency",
            "cu_utilization",
            # P0 FIX: Problem-to-tile ratio features
            "ratio_M_to_tile_m",
            "ratio_N_to_tile_n",
            "ratio_K_to_tile_k",
            "problem_smaller_than_tile_m",
            "problem_smaller_than_tile_n",
            "problem_smaller_than_tile_k",
            "any_dim_too_small",
            # P1 FIX: Padding requirement interaction features
            "needs_padding_m",
            "needs_padding_n",
            "needs_padding_k",
            "has_padding_when_needed_m",
            "has_padding_when_needed_n",
            "has_padding_when_needed_k",
            "missing_required_padding_m",
            "missing_required_padding_n",
            "missing_required_padding_k",
            "missing_any_required_padding",
            # Hardware features
            "hw_num_cus",
            "hw_simds_per_cu",
            "hw_total_simds",
            "hw_shader_engines",
            "hw_max_clock_mhz",
            "hw_max_waves_per_cu",
            "hw_wavefront_size",
            "hw_lds_capacity",
            "hw_l1_cache_kb",
            "hw_l2_cache_kb",
            "hw_l3_cache_kb",
            "hw_num_xcd",
        ]
    def get_categorical_features(self) -> list[str]:
        return ["layout", "pipeline", "scheduler", "epilogue"]
    def extract(self, problem: dict, kernel: dict) -> np.ndarray:
        """Build the feature vector for one (problem, kernel) pair.

        Missing keys fall back to defaults; unknown categorical strings map
        to code 0. Output order must match get_feature_names().
        """
        # Problem dims: accept either lower- or upper-case keys.
        M = int(problem.get("m", problem.get("M", 0)))
        N = int(problem.get("n", problem.get("N", 0)))
        K = int(problem.get("k", problem.get("K", 0)))
        split_k = int(problem.get("split_k", 1))
        dtype = str(problem.get("dtype", "fp8"))
        bpe = DTYPE_BYTES.get(dtype, 1.0)
        # max(..., 1) guards log2/division against zero-sized dims.
        log2_M = math.log2(max(M, 1))
        log2_N = math.log2(max(N, 1))
        log2_K = math.log2(max(K, 1))
        log2_MNK = math.log2(max(M * N * K, 1))
        # Arithmetic intensity: 2*M*N*K FLOPs over A+B+C footprint in bytes.
        mem_bytes = (M * K + K * N + M * N) * bpe
        ai = (2.0 * M * N * K) / max(mem_bytes, 1)
        ar_mn = M / max(N, 1)
        ar_mk = M / max(K, 1)
        ar_nk = N / max(K, 1)
        layout_code = LAYOUT_MAP.get(str(problem.get("layout", "rcr")), 0)
        tile_m = int(kernel.get("tile_m", 128))
        tile_n = int(kernel.get("tile_n", 128))
        tile_k = int(kernel.get("tile_k", 64))
        warp_m = int(kernel.get("warp_m", 2))
        warp_n = int(kernel.get("warp_n", 2))
        warp_k = int(kernel.get("warp_k", 1))
        warp_tile_m = int(kernel.get("warp_tile_m", 32))
        warp_tile_n = int(kernel.get("warp_tile_n", 32))
        warp_tile_k = int(kernel.get("warp_tile_k", 16))
        pipeline_code = PIPELINE_MAP.get(str(kernel.get("pipeline", "compv4")), 0)
        scheduler_code = SCHEDULER_MAP.get(str(kernel.get("scheduler", "intrawave")), 0)
        epilogue_code = EPILOGUE_MAP.get(str(kernel.get("epilogue", "cshuffle")), 0)
        pad_m = float(kernel.get("pad_m", False))
        pad_n = float(kernel.get("pad_n", False))
        pad_k = float(kernel.get("pad_k", False))
        persistent = float(kernel.get("persistent", False))
        num_warps = warp_m * warp_n * warp_k
        tile_volume = tile_m * tile_n * tile_k
        tile_mn = tile_m * tile_n
        # LDS estimate: A-tile + B-tile bytes per iteration.
        lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
        lds_cap = self._hw["lds_capacity"]
        if str(kernel.get("pipeline", "")).startswith("compv4"):
            # compv4 pipelines are modeled with a reduced 32 KiB LDS budget
            # -- hard-coded here; confirm against the pipeline implementation.
            lds_cap = 32768
        lds_ratio = lds_est / max(lds_cap, 1)
        num_tiles_m = math.ceil(M / max(tile_m, 1))
        num_tiles_n = math.ceil(N / max(tile_n, 1))
        num_tiles_k = math.ceil(K / max(tile_k, 1))
        total_output_tiles = num_tiles_m * num_tiles_n
        # Tile efficiency: fraction of the last (partial) tile that is used;
        # 1.0 when the dimension divides evenly.
        rem_m = M % tile_m if tile_m > 0 else 0
        tile_eff_m = rem_m / tile_m if rem_m > 0 else 1.0
        rem_n = N % tile_n if tile_n > 0 else 0
        tile_eff_n = rem_n / tile_n if rem_n > 0 else 1.0
        rem_k = K % tile_k if tile_k > 0 else 0
        tile_eff_k = rem_k / tile_k if rem_k > 0 else 1.0
        overall_eff = tile_eff_m * tile_eff_n * tile_eff_k
        cu_util = total_output_tiles / max(self._hw["num_cus"], 1)
        # P0 FIX: Problem-to-tile ratio features (avoid oversized tiles for tiny problems)
        ratio_M_to_tile_m = M / max(tile_m, 1)
        ratio_N_to_tile_n = N / max(tile_n, 1)
        ratio_K_to_tile_k = K / max(tile_k, 1)
        # Binary features: is problem dimension smaller than tile?
        problem_smaller_than_tile_m = float(M < tile_m)
        problem_smaller_than_tile_n = float(N < tile_n)
        problem_smaller_than_tile_k = float(K < tile_k)
        any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k))
        # P1 FIX: Padding requirement features (does this kernel have padding when needed?)
        needs_padding_m = float(M % tile_m != 0) if tile_m > 0 else 0.0
        needs_padding_n = float(N % tile_n != 0) if tile_n > 0 else 0.0
        needs_padding_k = float(K % tile_k != 0) if tile_k > 0 else 0.0
        # Interaction features: kernel has padding capability when problem needs it
        has_padding_when_needed_m = float(needs_padding_m and pad_m)
        has_padding_when_needed_n = float(needs_padding_n and pad_n)
        has_padding_when_needed_k = float(needs_padding_k and pad_k)
        # Critical feature: missing required padding (kernel will likely fail)
        missing_required_padding_m = float(needs_padding_m and not pad_m)
        missing_required_padding_n = float(needs_padding_n and not pad_n)
        missing_required_padding_k = float(needs_padding_k and not pad_k)
        missing_any_required_padding = float(
            missing_required_padding_m
            or missing_required_padding_n
            or missing_required_padding_k
        )
        hw = self._hw
        # Assemble in get_feature_names() order.
        return np.array(
            [
                M,
                N,
                K,
                split_k,
                log2_M,
                log2_N,
                log2_K,
                log2_MNK,
                ai,
                ar_mn,
                ar_mk,
                ar_nk,
                layout_code,
                tile_m,
                tile_n,
                tile_k,
                warp_m,
                warp_n,
                warp_k,
                warp_tile_m,
                warp_tile_n,
                warp_tile_k,
                pipeline_code,
                scheduler_code,
                epilogue_code,
                pad_m,
                pad_n,
                pad_k,
                persistent,
                num_warps,
                tile_volume,
                tile_mn,
                lds_est,
                lds_ratio,
                num_tiles_m,
                num_tiles_n,
                num_tiles_k,
                total_output_tiles,
                tile_eff_m,
                tile_eff_n,
                tile_eff_k,
                overall_eff,
                cu_util,
                # P0 FIX: New ratio and binary features
                ratio_M_to_tile_m,
                ratio_N_to_tile_n,
                ratio_K_to_tile_k,
                problem_smaller_than_tile_m,
                problem_smaller_than_tile_n,
                problem_smaller_than_tile_k,
                any_dim_too_small,
                # P1 FIX: Padding requirement interaction features
                needs_padding_m,
                needs_padding_n,
                needs_padding_k,
                has_padding_when_needed_m,
                has_padding_when_needed_n,
                has_padding_when_needed_k,
                missing_required_padding_m,
                missing_required_padding_n,
                missing_required_padding_k,
                missing_any_required_padding,
                hw["num_cus"],
                hw["simds_per_cu"],
                hw["total_simds"],
                hw["shader_engines"],
                hw["max_clock_mhz"],
                hw["max_waves_per_cu"],
                hw["wavefront_size"],
                hw["lds_capacity"],
                hw["l1_cache_kb"],
                hw["l2_cache_kb"],
                hw["l3_cache_kb"],
                hw["num_xcd"],
            ],
            dtype=np.float64,
        )
    def extract_batch(self, df: pd.DataFrame) -> np.ndarray:
        """Vectorized batch extraction -- much faster than row-by-row."""
        # Column indices below are positional and must stay in sync with
        # get_feature_names(). NOTE(review): assumes all referenced columns
        # exist in df (KeyError otherwise) -- confirm against the canonical
        # schema produced by the data pipeline.
        n = len(df)
        names = self.get_feature_names()
        result = np.zeros((n, len(names)), dtype=np.float64)
        M = df["m"].values.astype(np.float64)
        N = df["n"].values.astype(np.float64)
        K = df["k"].values.astype(np.float64)
        split_k = df["split_k"].fillna(1).values.astype(np.float64)
        dtype_col = df["dtype"].fillna("fp8")
        bpe = dtype_col.map(DTYPE_BYTES).fillna(1.0).values
        result[:, 0] = M
        result[:, 1] = N
        result[:, 2] = K
        result[:, 3] = split_k
        result[:, 4] = np.log2(np.maximum(M, 1))
        result[:, 5] = np.log2(np.maximum(N, 1))
        result[:, 6] = np.log2(np.maximum(K, 1))
        result[:, 7] = np.log2(np.maximum(M * N * K, 1))
        # Arithmetic intensity and aspect ratios (mirrors extract()).
        mem = (M * K + K * N + M * N) * bpe
        result[:, 8] = (2.0 * M * N * K) / np.maximum(mem, 1)
        result[:, 9] = M / np.maximum(N, 1)
        result[:, 10] = M / np.maximum(K, 1)
        result[:, 11] = N / np.maximum(K, 1)
        result[:, 12] = df["layout"].map(LAYOUT_MAP).fillna(0).values
        tile_m = df["tile_m"].fillna(128).values.astype(np.float64)
        tile_n = df["tile_n"].fillna(128).values.astype(np.float64)
        tile_k = df["tile_k"].fillna(64).values.astype(np.float64)
        warp_m = df["warp_m"].fillna(2).values.astype(np.float64)
        warp_n = df["warp_n"].fillna(2).values.astype(np.float64)
        warp_k = df["warp_k"].fillna(1).values.astype(np.float64)
        warp_tile_m = df["warp_tile_m"].fillna(32).values.astype(np.float64)
        warp_tile_n = df["warp_tile_n"].fillna(32).values.astype(np.float64)
        warp_tile_k = df["warp_tile_k"].fillna(16).values.astype(np.float64)
        result[:, 13] = tile_m
        result[:, 14] = tile_n
        result[:, 15] = tile_k
        result[:, 16] = warp_m
        result[:, 17] = warp_n
        result[:, 18] = warp_k
        result[:, 19] = warp_tile_m
        result[:, 20] = warp_tile_n
        result[:, 21] = warp_tile_k
        result[:, 22] = df["pipeline"].map(PIPELINE_MAP).fillna(0).values
        result[:, 23] = df["scheduler"].map(SCHEDULER_MAP).fillna(0).values
        result[:, 24] = df["epilogue"].map(EPILOGUE_MAP).fillna(0).values
        result[:, 25] = df["pad_m"].fillna(False).astype(float).values
        result[:, 26] = df["pad_n"].fillna(False).astype(float).values
        result[:, 27] = df["pad_k"].fillna(False).astype(float).values
        result[:, 28] = df["persistent"].fillna(False).astype(float).values
        num_warps = warp_m * warp_n * warp_k
        result[:, 29] = num_warps
        result[:, 30] = tile_m * tile_n * tile_k
        result[:, 31] = tile_m * tile_n
        # LDS estimate and ratio, with the compv4-specific 32 KiB budget.
        lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
        result[:, 32] = lds_est
        lds_cap = np.full(n, self._hw["lds_capacity"], dtype=np.float64)
        is_compv4 = df["pipeline"].fillna("").str.startswith("compv4")
        lds_cap[is_compv4] = 32768
        result[:, 33] = lds_est / np.maximum(lds_cap, 1)
        ntm = np.ceil(M / np.maximum(tile_m, 1))
        ntn = np.ceil(N / np.maximum(tile_n, 1))
        ntk = np.ceil(K / np.maximum(tile_k, 1))
        result[:, 34] = ntm
        result[:, 35] = ntn
        result[:, 36] = ntk
        result[:, 37] = ntm * ntn
        # Partial-tile efficiency per dimension (1.0 for an even split).
        rem_m = np.mod(M, np.maximum(tile_m, 1))
        result[:, 38] = np.where(rem_m > 0, rem_m / tile_m, 1.0)
        rem_n = np.mod(N, np.maximum(tile_n, 1))
        result[:, 39] = np.where(rem_n > 0, rem_n / tile_n, 1.0)
        rem_k = np.mod(K, np.maximum(tile_k, 1))
        result[:, 40] = np.where(rem_k > 0, rem_k / tile_k, 1.0)
        result[:, 41] = result[:, 38] * result[:, 39] * result[:, 40]
        result[:, 42] = (ntm * ntn) / max(self._hw["num_cus"], 1)
        # P0 FIX: Problem-to-tile ratio features
        result[:, 43] = M / np.maximum(tile_m, 1)  # ratio_M_to_tile_m
        result[:, 44] = N / np.maximum(tile_n, 1)  # ratio_N_to_tile_n
        result[:, 45] = K / np.maximum(tile_k, 1)  # ratio_K_to_tile_k
        # Binary features: is problem smaller than tile?
        result[:, 46] = (M < tile_m).astype(float)  # problem_smaller_than_tile_m
        result[:, 47] = (N < tile_n).astype(float)  # problem_smaller_than_tile_n
        result[:, 48] = (K < tile_k).astype(float)  # problem_smaller_than_tile_k
        result[:, 49] = ((M < tile_m) | (N < tile_n) | (K < tile_k)).astype(
            float
        )  # any_dim_too_small
        # P1 FIX: Padding requirement features
        pad_m_bool = df["pad_m"].fillna(False).astype(bool).values
        pad_n_bool = df["pad_n"].fillna(False).astype(bool).values
        pad_k_bool = df["pad_k"].fillna(False).astype(bool).values
        needs_padding_m = (np.mod(M, np.maximum(tile_m, 1)) != 0)
        needs_padding_n = (np.mod(N, np.maximum(tile_n, 1)) != 0)
        needs_padding_k = (np.mod(K, np.maximum(tile_k, 1)) != 0)
        result[:, 50] = needs_padding_m.astype(float)
        result[:, 51] = needs_padding_n.astype(float)
        result[:, 52] = needs_padding_k.astype(float)
        # Interaction features: kernel has padding when problem needs it
        result[:, 53] = (needs_padding_m & pad_m_bool).astype(float)  # has_padding_when_needed_m
        result[:, 54] = (needs_padding_n & pad_n_bool).astype(float)  # has_padding_when_needed_n
        result[:, 55] = (needs_padding_k & pad_k_bool).astype(float)  # has_padding_when_needed_k
        # Critical feature: missing required padding
        result[:, 56] = (needs_padding_m & ~pad_m_bool).astype(float)  # missing_required_padding_m
        result[:, 57] = (needs_padding_n & ~pad_n_bool).astype(float)  # missing_required_padding_n
        result[:, 58] = (needs_padding_k & ~pad_k_bool).astype(float)  # missing_required_padding_k
        result[:, 59] = ((needs_padding_m & ~pad_m_bool) | (needs_padding_n & ~pad_n_bool) | (needs_padding_k & ~pad_k_bool)).astype(float)  # missing_any_required_padding
        # Hardware profile features
        hw = self._hw
        result[:, 60] = hw["num_cus"]
        result[:, 61] = hw["simds_per_cu"]
        result[:, 62] = hw["total_simds"]
        result[:, 63] = hw["shader_engines"]
        result[:, 64] = hw["max_clock_mhz"]
        result[:, 65] = hw["max_waves_per_cu"]
        result[:, 66] = hw["wavefront_size"]
        result[:, 67] = hw["lds_capacity"]
        result[:, 68] = hw["l1_cache_kb"]
        result[:, 69] = hw["l2_cache_kb"]
        result[:, 70] = hw["l3_cache_kb"]
        result[:, 71] = hw["num_xcd"]
        return result
    def get_parameter_space(self) -> dict[str, list]:
        """Discrete search space for gemm_universal kernel parameters."""
        return {
            "tile_m": [32, 64, 128, 192, 256],
            "tile_n": [32, 64, 128, 192, 256],
            "tile_k": [32, 64, 128, 256],
            "warp_m": [1, 2, 4],
            "warp_n": [1, 2, 4],
            "warp_k": [1],
            "warp_tile_m": [4, 16, 32, 64],
            "warp_tile_n": [4, 16, 32, 64],
            "warp_tile_k": [8, 16, 32, 64, 128],
            "pipeline": list(PIPELINE_MAP.keys()),
            "scheduler": list(SCHEDULER_MAP.keys()),
            "epilogue": list(EPILOGUE_MAP.keys()),
            "pad_m": [True, False],
            "pad_n": [True, False],
            "pad_k": [True, False],
            "persistent": [True, False],
        }
    def get_constraints(self) -> list:
        """Constraints: estimated LDS usage must fit the (pipeline-dependent)
        capacity, and the warp count must be one of {2, 4, 8}."""
        lds_cap = self._hw["lds_capacity"]
        def _lds_constraint(cfg):
            # A-tile + B-tile bytes at 1 byte/elem; compv4 gets a 32 KiB cap
            # (same hard-coded budget as extract()).
            tm = cfg.get("tile_m", 128)
            tn = cfg.get("tile_n", 128)
            tk = cfg.get("tile_k", 64)
            bpe = 1.0  # fp8 default
            est = (tm * tk + tn * tk) * bpe
            cap = (
                32768 if str(cfg.get("pipeline", "")).startswith("compv4") else lds_cap
            )
            return est <= cap
        def _warp_constraint(cfg):
            wm = cfg.get("warp_m", 2)
            wn = cfg.get("warp_n", 2)
            wk = cfg.get("warp_k", 1)
            return (wm * wn * wk) in [2, 4, 8]
        return [_lds_constraint, _warp_constraint]

View File

@@ -0,0 +1,553 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
GEMM Universal Benchmark Data Generation Script
This script generates training data for ML-based kernel selection heuristics by:
1. Reading kernel configurations from the tile engine
2. Building benchmark executables (in parallel)
3. Running benchmarks across multiple problem sizes
4. Outputting performance data in JSON format
Usage:
python generate_benchmark_data.py \
--build_dir /tmp/build \
--output_dir /tmp/benchmark_data \
--dtype fp16 \
--layout rcr \
--num_build_jobs 4 \
--num_benchmark_jobs 1
Requirements:
- ROCm-capable GPU
- CK tile engine built with CMake
"""
import argparse
import json
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import List, Optional, Tuple
import re
@dataclass
class KernelConfig:
    """Represents a single kernel configuration.

    Fields mirror the tile engine's kernel naming scheme; see
    from_kernel_name() for the exact encoding.
    """
    name: str  # full kernel name as emitted by the tile engine
    dtype: str
    layout: str
    pipeline: str
    epilogue: str
    scheduler: str
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool
    tile_m: int
    tile_n: int
    tile_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int
    @classmethod
    def from_kernel_name(cls, name: str, dtype: str, layout: str) -> "KernelConfig":
        """Parse kernel name to extract configuration.

        Expected format:
        gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{padM}_{padN}_{padK}_{persistent}_{tile}_{warp}_{warp_tile}
        where the last three fields are AxBxC triples, e.g.
        128x128x64_2x2x1_32x32x16.

        Raises
        ------
        ValueError
            If the name does not start with the expected prefix or does not
            contain the expected number of fields.
        """
        prefix = f"gemm_universal_{dtype}_{layout}_"
        # Fail loudly on malformed names instead of silently mis-parsing
        # (the old code produced garbage or IndexError in these cases).
        if not name.startswith(prefix):
            raise ValueError(f"Unrecognized kernel name format: {name}")
        trait_and_tile = name[len(prefix):]
        trait_parts = trait_and_tile.split("_")
        # 7 trait fields + 3 "AxBxC" tile fields.
        if len(trait_parts) < 10:
            raise ValueError(f"Unrecognized kernel name format: {name}")
        pipeline = trait_parts[0]
        epilogue = trait_parts[1]
        scheduler = trait_parts[2]
        pad_m = trait_parts[3] == "True"
        pad_n = trait_parts[4] == "True"
        pad_k = trait_parts[5] == "True"
        persistent = trait_parts[6] == "True"
        # Parse the three tile-config triples.
        tile_dims = trait_parts[7].split("x")
        warp_dims = trait_parts[8].split("x")
        warp_tile_dims = trait_parts[9].split("x")
        return cls(
            name=name,
            dtype=dtype,
            layout=layout,
            pipeline=pipeline,
            epilogue=epilogue,
            scheduler=scheduler,
            pad_m=pad_m,
            pad_n=pad_n,
            pad_k=pad_k,
            persistent=persistent,
            tile_m=int(tile_dims[0]),
            tile_n=int(tile_dims[1]),
            tile_k=int(tile_dims[2]),
            warp_m=int(warp_dims[0]),
            warp_n=int(warp_dims[1]),
            warp_k=int(warp_dims[2]),
            warp_tile_m=int(warp_tile_dims[0]),
            warp_tile_n=int(warp_tile_dims[1]),
            warp_tile_k=int(warp_tile_dims[2]),
        )
@dataclass
class BenchmarkResult:
    """Result of a single benchmark run."""
    kernel_name: str  # kernel identifier, as emitted by the tile engine
    m: int  # GEMM M dimension
    n: int  # GEMM N dimension
    k: int  # GEMM K dimension
    avg_time_ms: float  # mean kernel latency in milliseconds (0 on failure)
    tflops: float  # achieved throughput in TFLOPS (0 on failure)
    is_valid: bool  # True when the run produced a usable timing
    error: Optional[str] = None  # failure description; None on success
@dataclass
class ProblemSize:
    """GEMM problem dimensions."""
    m: int  # rows of A / C
    n: int  # columns of B / C
    k: int  # reduction (inner) dimension
def get_problem_sizes() -> List[ProblemSize]:
    """
    Generate diverse problem sizes for benchmarking.
    Includes:
    - Square matrices (powers of 2)
    - Rectangular matrices (common in ML)
    - LLM-specific sizes (attention, MLP)
    - Edge cases (small, very large)
    """
    # Square powers of two: 64^3 .. 8192^3.
    dims: List[Tuple[int, int, int]] = [(2**p, 2**p, 2**p) for p in range(6, 14)]
    # Common ML sizes (batch x hidden).
    dims += [
        (1, 4096, 4096),  # Single token inference
        (8, 4096, 4096),  # Small batch
        (32, 4096, 4096),  # Medium batch
        (128, 4096, 4096),  # Large batch
        (1, 4096, 11008),  # LLaMA MLP up-projection
        (1, 11008, 4096),  # LLaMA MLP down-projection
        (32, 4096, 11008),
        (32, 11008, 4096),
        (1, 8192, 8192),  # Large model
        (32, 8192, 8192),
        (1, 8192, 28672),  # LLaMA-70B MLP
        (32, 8192, 28672),
    ]
    # Rectangular matrices.
    dims += [
        (1024, 4096, 1024),
        (4096, 1024, 4096),
        (2048, 8192, 2048),
        (256, 256, 8192),  # Tall K
        (8192, 8192, 256),  # Short K
    ]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return [ProblemSize(m, n, k) for m, n, k in dict.fromkeys(dims)]
def load_kernel_list(build_dir: Path, dtype: str, layout: str) -> List[KernelConfig]:
    """Load kernel configurations from the tile engine build."""
    list_file = build_dir.joinpath(
        "tile_engine",
        "ops",
        "gemm",
        "gemm_universal",
        dtype,
        layout,
        "gemm_universal_kernel_list.txt",
    )
    if not list_file.exists():
        raise FileNotFoundError(f"Kernel list not found: {list_file}")
    configs: List[KernelConfig] = []
    for raw_line in list_file.read_text().splitlines():
        entry = raw_line.strip()
        if not entry:
            continue
        # Each line: kernel_name|tile_config|trait_combo -- only the name
        # is needed; everything else is re-derived from it.
        kernel_name = entry.split("|")[0]
        configs.append(KernelConfig.from_kernel_name(kernel_name, dtype, layout))
    return configs
def build_kernel(build_dir: Path, kernel: KernelConfig) -> Tuple[str, bool, str]:
    """
    Build a single kernel benchmark executable.
    Returns: (kernel_name, success, error_message)
    """
    target = f"benchmark_{kernel.name}"
    try:
        proc = subprocess.run(
            ["ninja", "-j1", target],
            cwd=build_dir,
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute cap per target
        )
    except subprocess.TimeoutExpired:
        return (kernel.name, False, "Build timeout")
    except Exception as exc:
        return (kernel.name, False, str(exc))
    if proc.returncode == 0:
        return (kernel.name, True, "")
    # Truncate stderr so aggregated build logs stay manageable.
    return (kernel.name, False, proc.stderr[:500])
def run_benchmark(
    build_dir: Path,
    kernel: KernelConfig,
    problem: ProblemSize,
    warmup: int = 10,
    repeat: int = 50,
) -> BenchmarkResult:
    """
    Run benchmark for a single kernel and problem size.

    Launches the prebuilt benchmark executable with the given GEMM
    dimensions, then parses its JSON output (preferred) or falls back to
    scraping free-text output. Always returns a BenchmarkResult; failures
    are reported via is_valid=False plus an error string, never exceptions.
    """
    exe_path = build_dir / "bin" / f"benchmark_{kernel.name}"
    if not exe_path.exists():
        return BenchmarkResult(
            kernel_name=kernel.name,
            m=problem.m,
            n=problem.n,
            k=problem.k,
            avg_time_ms=0,
            tflops=0,
            is_valid=False,
            error="Executable not found",
        )
    try:
        # -verify=0: skip correctness checking for speed; timings only.
        result = subprocess.run(
            [
                str(exe_path),
                f"-m={problem.m}",
                f"-n={problem.n}",
                f"-k={problem.k}",
                f"-warmup={warmup}",
                f"-repeat={repeat}",
                "-verify=0",
                "-json_output=true",
            ],
            capture_output=True,
            text=True,
            timeout=120,
        )
        if result.returncode != 0:
            # Try to parse error
            error = result.stderr[:200] if result.stderr else result.stdout[:200]
            return BenchmarkResult(
                kernel_name=kernel.name,
                m=problem.m,
                n=problem.n,
                k=problem.k,
                avg_time_ms=0,
                tflops=0,
                is_valid=False,
                error=error,
            )
        # Parse JSON output
        output = result.stdout.strip()
        # Try to find JSON in output
        # NOTE(review): greedy `.*` with DOTALL spans from the first "{" to
        # the last "}" -- assumes the executable prints exactly one JSON
        # object; multiple objects would break the parse. Confirm.
        json_match = re.search(r"\{.*\}", output, re.DOTALL)
        if json_match:
            data = json.loads(json_match.group())
            # Extract from nested perf_result object
            perf = data.get("perf_result", {})
            avg_time_ms = perf.get("latency(ms)", 0)
            tflops = perf.get("tflops(TFlops)", 0)
            return BenchmarkResult(
                kernel_name=kernel.name,
                m=problem.m,
                n=problem.n,
                k=problem.k,
                avg_time_ms=avg_time_ms,
                tflops=tflops,
                is_valid=True,
            )
        else:
            # Parse from text output
            # Look for patterns like "avg_time: X ms" or "tflops: Y"
            avg_time = 0.0
            tflops = 0.0
            time_match = re.search(
                r"(?:avg[_\s]?time|latency)[:\s]+(\d+\.?\d*)\s*(?:ms)?", output, re.I
            )
            if time_match:
                avg_time = float(time_match.group(1))
            tflops_match = re.search(r"tflops[:\s]+(\d+\.?\d*)", output, re.I)
            if tflops_match:
                tflops = float(tflops_match.group(1))
            # Calculate TFLOPs if not provided
            if tflops == 0 and avg_time > 0:
                flops = 2.0 * problem.m * problem.n * problem.k
                tflops = flops / (avg_time * 1e-3) / 1e12
            return BenchmarkResult(
                kernel_name=kernel.name,
                m=problem.m,
                n=problem.n,
                k=problem.k,
                avg_time_ms=avg_time,
                tflops=tflops,
                is_valid=avg_time > 0,
                error=None if avg_time > 0 else "Could not parse output",
            )
    except subprocess.TimeoutExpired:
        return BenchmarkResult(
            kernel_name=kernel.name,
            m=problem.m,
            n=problem.n,
            k=problem.k,
            avg_time_ms=0,
            tflops=0,
            is_valid=False,
            error="Benchmark timeout",
        )
    except Exception as e:
        # Catch-all keeps one bad run from aborting a long sweep.
        return BenchmarkResult(
            kernel_name=kernel.name,
            m=problem.m,
            n=problem.n,
            k=problem.k,
            avg_time_ms=0,
            tflops=0,
            is_valid=False,
            error=str(e),
        )
def main():
    """CLI entry point: build selected CK GEMM kernels, benchmark each across
    the predefined problem sizes, and write JSON results to --output_dir.

    Relies on helpers defined elsewhere in this file: load_kernel_list,
    build_kernel, get_problem_sizes, run_benchmark.
    """
    parser = argparse.ArgumentParser(
        description="Generate GEMM benchmark data for ML training"
    )
    parser.add_argument(
        "--build_dir", type=str, default="/tmp/build", help="CK build directory"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/tmp/benchmark_data",
        help="Output directory for benchmark results",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type to benchmark",
    )
    parser.add_argument(
        "--layout",
        type=str,
        default="rcr",
        choices=["rcr", "rrr", "crr", "ccr"],
        help="Matrix layout to benchmark",
    )
    parser.add_argument(
        "--num_build_jobs", type=int, default=4, help="Number of parallel build jobs"
    )
    # NOTE(review): --num_benchmark_jobs is accepted but never read below;
    # benchmarks run sequentially — confirm whether parallel benchmarking
    # was intended.
    parser.add_argument(
        "--num_benchmark_jobs",
        type=int,
        default=1,
        help="Number of parallel benchmark jobs (use 1 for accurate timing)",
    )
    parser.add_argument(
        "--max_kernels",
        type=int,
        default=None,
        help="Maximum number of kernels to benchmark (for testing)",
    )
    parser.add_argument(
        "--skip_build",
        action="store_true",
        help="Skip building and only run benchmarks",
    )
    parser.add_argument(
        "--warmup", type=int, default=10, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--repeat", type=int, default=50, help="Number of benchmark iterations"
    )
    args = parser.parse_args()
    build_dir = Path(args.build_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Load kernel configurations
    print(f"Loading kernel list for {args.dtype}/{args.layout}...")
    kernels = load_kernel_list(build_dir, args.dtype, args.layout)
    print(f"Found {len(kernels)} kernel configurations")
    if args.max_kernels:
        kernels = kernels[: args.max_kernels]
        print(f"Limiting to {len(kernels)} kernels")
    # Build kernels (compilation is parallel; each future returns
    # a (kernel_name, success, error) tuple from build_kernel).
    if not args.skip_build:
        print(
            f"\nBuilding {len(kernels)} kernels with {args.num_build_jobs} parallel jobs..."
        )
        build_results = {"success": 0, "failed": 0, "failed_kernels": []}
        with ProcessPoolExecutor(max_workers=args.num_build_jobs) as executor:
            futures = {executor.submit(build_kernel, build_dir, k): k for k in kernels}
            for i, future in enumerate(as_completed(futures)):
                kernel_name, success, error = future.result()
                if success:
                    build_results["success"] += 1
                else:
                    build_results["failed"] += 1
                    build_results["failed_kernels"].append(
                        {"name": kernel_name, "error": error}
                    )
                # Periodic progress line every 10 builds.
                if (i + 1) % 10 == 0:
                    print(
                        f"  Built {i + 1}/{len(kernels)} ({build_results['success']} success, {build_results['failed']} failed)"
                    )
        print(
            f"\nBuild complete: {build_results['success']} success, {build_results['failed']} failed"
        )
        # Save build results
        with open(output_dir / "build_results.json", "w") as f:
            json.dump(build_results, f, indent=2)
    # Get problem sizes
    problem_sizes = get_problem_sizes()
    print(f"\nBenchmarking {len(problem_sizes)} problem sizes...")
    # Run benchmarks sequentially (one kernel x one problem at a time).
    all_results = []
    total_benchmarks = len(kernels) * len(problem_sizes)
    completed = 0
    print(f"Total benchmarks to run: {total_benchmarks}")
    for kernel in kernels:
        kernel_results = {
            "kernel_config": asdict(kernel),
            "benchmarks": [],
        }
        for problem in problem_sizes:
            result = run_benchmark(
                build_dir,
                kernel,
                problem,
                warmup=args.warmup,
                repeat=args.repeat,
            )
            kernel_results["benchmarks"].append(asdict(result))
            completed += 1
            if completed % 100 == 0:
                print(f"  Progress: {completed}/{total_benchmarks} benchmarks complete")
        all_results.append(kernel_results)
        # Save intermediate results after every kernel so a crash or
        # interruption does not lose completed work.
        intermediate_file = (
            output_dir / f"benchmark_results_{args.dtype}_{args.layout}_partial.json"
        )
        with open(intermediate_file, "w") as f:
            json.dump(all_results, f, indent=2)
    # Save final results (metadata + problem sizes + per-kernel results).
    final_file = output_dir / f"benchmark_results_{args.dtype}_{args.layout}.json"
    with open(final_file, "w") as f:
        json.dump(
            {
                "metadata": {
                    "dtype": args.dtype,
                    "layout": args.layout,
                    "num_kernels": len(kernels),
                    "num_problems": len(problem_sizes),
                    "warmup": args.warmup,
                    "repeat": args.repeat,
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                },
                "problem_sizes": [asdict(p) for p in problem_sizes],
                "results": all_results,
            },
            f,
            indent=2,
        )
    print(f"\nResults saved to {final_file}")
    # Print summary of how many benchmark runs produced a valid timing.
    valid_count = sum(
        1 for kr in all_results for br in kr["benchmarks"] if br["is_valid"]
    )
    print(f"Valid benchmarks: {valid_count}/{total_benchmarks}")
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Supplementary edge-case benchmark generator for N=1 and K=1 dimensions.
These shapes represent vector-matrix multiply (N=1), rank-1 updates (K=1),
and other degenerate GEMM cases that stress tile efficiency and padding logic.
"""
import json
import subprocess
import sys
from pathlib import Path
def generate_edge_shapes():
    """Build the edge-case (M, N, K) list: degenerate GEMMs with N=1 (vector-
    matrix), K=1 (rank-1 update), dot products, scalar-vector cases, and
    small-N / small-K variants.

    Returns:
        Sorted list of unique (m, n, k) tuples.
    """
    dims_full = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
    dims_other = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 7168, 8192]
    dims_sparse = [1, 16, 64, 256, 1024, 4096, 8192]
    small_vals = [2, 3, 4, 7, 8, 15, 16]
    mid_vals = [64, 256, 1024, 4096]

    collected = set()
    # Vector-matrix multiply / single output column (N=1).
    collected.update((m, 1, k) for m in dims_full for k in dims_other)
    # Rank-1 update / outer product (K=1).
    collected.update((m, n, 1) for m in dims_full for n in dims_other)
    # Dot product (M=1, N=1).
    collected.update((1, 1, k) for k in dims_sparse)
    # Scalar-vector (M=1, K=1).
    collected.update((1, n, 1) for n in dims_sparse)
    # Scalar-vector (N=1, K=1).
    collected.update((m, 1, 1) for m in dims_sparse)
    # Fully degenerate 1x1x1.
    collected.add((1, 1, 1))
    # Narrow N (2-16) against mid-sized M and K.
    collected.update(
        (m, n, k) for m in mid_vals for n in small_vals for k in mid_vals
    )
    # Shallow K (2-16) against mid-sized M and N.
    collected.update(
        (m, n, k) for m in mid_vals for n in mid_vals for k in small_vals
    )
    return sorted(collected)
def run_shapes(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Run every fp8/rcr benchmark executable in ``bin_dir`` against each
    (m, n, k) shape, appending validated JSON result blocks to ``out_file``.

    Args:
        bin_dir: Directory containing benchmark_gemm_universal_fp8_rcr_* binaries.
        shapes: Iterable of (m, n, k) tuples to benchmark.
        out_file: Open, writable text file receiving the streaming log.
        warmup: Warmup iterations passed to each executable.
        repeat: Timed iterations passed to each executable.

    Returns:
        Number of benchmark results successfully parsed and written.
    """
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0
    total = 0
    for idx, (m, n, k) in enumerate(shapes):
        # Per-shape header block in the streaming log.
        out_file.write("\n========================================\n")
        out_file.write(f"Shape {idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n")
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()
        for exe in executables:
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                output = result.stdout
                # Extract the outermost {...} span and keep it only if it is
                # valid JSON; malformed or partial output is silently dropped.
                json_start = output.find("{")
                json_end = output.rfind("}") + 1
                if json_start >= 0 and json_end > json_start:
                    json_block = output[json_start:json_end]
                    try:
                        json.loads(json_block)
                        out_file.write(json_block + "\n")
                        total += 1
                    except json.JSONDecodeError:
                        pass
            except Exception:
                # Best-effort sweep: a timeout or crash in one kernel must not
                # abort the whole batch.  (Exception already subsumes
                # subprocess.TimeoutExpired, so listing both was redundant.)
                pass
        out_file.flush()
        print(
            f" Shape {idx + 1}/{len(shapes)}: M={m} N={n} K={k}",
            file=sys.stderr,
            flush=True,
        )
    return total
if __name__ == "__main__":
    # Fixed paths: benchmark binaries and the streaming-log output directory.
    bin_dir = "/workspace/ck_tile/bin"
    out_dir = Path("data/edge_dims")
    out_dir.mkdir(parents=True, exist_ok=True)
    shapes = generate_edge_shapes()
    print(f"Generated {len(shapes)} edge-case shapes", file=sys.stderr, flush=True)
    # Summary counts of each degenerate-shape category, logged to stderr.
    n1_count = sum(1 for m, n, k in shapes if n == 1)
    k1_count = sum(1 for m, n, k in shapes if k == 1)
    both1 = sum(1 for m, n, k in shapes if n == 1 and k == 1)
    small_n = sum(1 for m, n, k in shapes if 2 <= n <= 16)
    small_k = sum(1 for m, n, k in shapes if 2 <= k <= 16)
    print(
        f" N=1: {n1_count}, K=1: {k1_count}, both=1: {both1}",
        file=sys.stderr,
        flush=True,
    )
    print(
        f" Small N(2-16): {small_n}, Small K(2-16): {small_k}",
        file=sys.stderr,
        flush=True,
    )
    # Run shapes in fixed-size batches, one log file per batch, so partial
    # results survive an interrupted run.
    batch_size = 25
    total = 0
    batch_idx = 0
    for i in range(0, len(shapes), batch_size):
        batch = shapes[i : i + batch_size]
        batch_idx += 1
        out_path = out_dir / f"edge_dims_batch_{batch_idx:03d}.log"
        print(
            f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}",
            file=sys.stderr,
            flush=True,
        )
        with open(out_path, "w") as f:
            # Header lines expected by the downstream log parser.
            f.write(f"CK Tile Edge Dims Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\nImplementation: gemm_universal\n\n")
            count = run_shapes(bin_dir, batch, f, warmup=3, repeat=10)
        total += count
        print(f" Batch {batch_idx} done: {count} results", file=sys.stderr, flush=True)
    print(
        f"\nTotal: {total} benchmarks across {len(shapes)} shapes",
        file=sys.stderr,
        flush=True,
    )

View File

@@ -0,0 +1,289 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Wide-coverage benchmark data generator.
Generates benchmark results for hundreds of diverse GEMM shapes across all
corner cases: skinny M, tall N, deep K, M=1, prime dimensions, power-of-2,
LLM inference shapes, training shapes, and edge cases.
Runs all 4608 kernels in /workspace/ck_tile/bin/ against each shape and
writes streaming log output parseable by data_pipeline.py.
Usage:
python3 generate_wide_coverage.py --bin_dir /workspace/ck_tile/bin \
--out_dir data/ --batch_size 20 --warmup 3 --repeat 10
"""
import argparse
import json
import subprocess
import sys
from pathlib import Path
def generate_shape_list():
"""Generate a comprehensive list of (M, N, K) shapes covering all corner cases.
Categories:
1. M=1 (single token inference) -- the hardest case
2. Tiny M (2-16) -- small batch inference
3. Small M (32-128) -- medium batch
4. Medium M (256-2048) -- large batch / training
5. Large M (4096-20480) -- very large batch
6. Square shapes (powers of 2)
7. Skinny M, tall N (M << N)
8. Tall M, skinny N (M >> N)
9. Deep K (K >> M, N) -- compute-bound
10. Shallow K (K << M, N) -- memory-bound
11. Prime dimensions -- worst-case for tiling
12. LLM-specific shapes (DeepSeek, LLaMA, etc.)
13. Non-power-of-2 common sizes
"""
shapes = set()
# --- 1. M=1 (single token) across various N, K ---
for n in [512, 1024, 1536, 2048, 3072, 4096, 4608, 7168, 8192, 11008, 14336, 28672]:
for k in [256, 512, 1024, 1536, 2048, 2304, 4096, 7168, 8192]:
shapes.add((1, n, k))
# --- 2. Tiny M (2-16) ---
for m in [2, 4, 8, 16]:
for n in [512, 1536, 4096, 7168]:
for k in [256, 1024, 4096, 7168]:
shapes.add((m, n, k))
# --- 3. Small M (32-128) ---
for m in [32, 48, 64, 96, 128]:
for n in [512, 1536, 4096, 7168, 8192]:
for k in [256, 512, 2048, 4096, 7168]:
shapes.add((m, n, k))
# --- 4. Medium M (256-2048) ---
for m in [256, 384, 512, 768, 1024, 1536, 2048]:
for n in [512, 1536, 4096, 7168]:
for k in [256, 1024, 2048, 4096, 7168]:
shapes.add((m, n, k))
# --- 5. Large M (4096-20480) ---
for m in [4096, 8192, 12288, 16384, 20480]:
for n in [512, 1536, 4096, 7168]:
for k in [256, 1024, 2048, 7168]:
shapes.add((m, n, k))
# --- 6. Square shapes (powers of 2) ---
for p in range(5, 14): # 32 to 8192
d = 2**p
shapes.add((d, d, d))
# --- 7. Skinny M, tall N ---
for m in [1, 4, 16, 64]:
for n in [8192, 16384, 28672]:
for k in [1024, 4096, 8192]:
shapes.add((m, n, k))
# --- 8. Tall M, skinny N ---
for m in [4096, 8192, 16384]:
for n in [32, 64, 128, 256]:
for k in [1024, 4096]:
shapes.add((m, n, k))
# --- 9. Deep K (K >> M, N) ---
for m in [16, 64, 256]:
for n in [16, 64, 256]:
for k in [4096, 8192, 16384, 32768]:
shapes.add((m, n, k))
# --- 10. Shallow K (K << M, N) ---
for m in [1024, 4096, 8192]:
for n in [1024, 4096, 8192]:
for k in [16, 32, 64, 128]:
shapes.add((m, n, k))
# --- 11. Prime dimensions ---
primes = [17, 31, 37, 127, 251, 509, 1021, 2039, 4093]
for p in primes:
shapes.add((p, p, p))
for p in primes[:5]:
shapes.add((p, 4096, 4096))
shapes.add((4096, p, 4096))
shapes.add((4096, 4096, p))
# --- 12. LLM-specific shapes ---
llm_shapes = [
# DeepSeek MoE
(1, 1536, 7168),
(1, 4608, 7168),
(1, 7168, 2048),
(1, 7168, 2304),
(1, 7168, 256),
(1, 576, 7168),
(1, 512, 7168),
(1, 3072, 1536),
# LLaMA-7B
(1, 4096, 4096),
(32, 4096, 4096),
(128, 4096, 4096),
(1, 4096, 11008),
(32, 4096, 11008),
(1, 11008, 4096),
(32, 11008, 4096),
# LLaMA-70B
(1, 8192, 8192),
(32, 8192, 8192),
(128, 8192, 8192),
(1, 8192, 28672),
(32, 8192, 28672),
(1, 28672, 8192),
# GPT-style attention
(128, 128, 64),
(128, 128, 128),
(256, 256, 64),
(512, 512, 64),
(1024, 1024, 64),
(2048, 2048, 64),
]
for s in llm_shapes:
shapes.add(s)
# --- 13. Non-power-of-2 common sizes ---
for m in [48, 96, 192, 384, 576, 768, 1152, 1536, 2304, 3072, 4608, 6144]:
shapes.add((m, m, m))
shapes.add((m, 4096, 4096))
sorted_shapes = sorted(shapes)
return sorted_shapes
def run_shape_batch(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Run every fp8/rcr benchmark executable in ``bin_dir`` against each
    (m, n, k) shape, streaming validated JSON result blocks to ``out_file``
    in the format expected by data_pipeline.py.

    Args:
        bin_dir: Directory containing benchmark_gemm_universal_fp8_rcr_* binaries.
        shapes: Iterable of (m, n, k) tuples to benchmark.
        out_file: Open, writable text file receiving the streaming log.
        warmup: Warmup iterations passed to each executable.
        repeat: Timed iterations passed to each executable.

    Returns:
        Count of benchmark results successfully parsed and written.
    """
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0
    total_benchmarks = 0
    for shape_idx, (m, n, k) in enumerate(shapes):
        # Per-shape header block in the streaming log.
        out_file.write("\n========================================\n")
        out_file.write(
            f"Shape {shape_idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n"
        )
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()
        for exe in executables:
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                output = result.stdout
                # Extract JSON block from output; keep it only if it parses.
                json_start = output.find("{")
                json_end = output.rfind("}") + 1
                if json_start >= 0 and json_end > json_start:
                    json_block = output[json_start:json_end]
                    try:
                        json.loads(json_block)
                        out_file.write(json_block + "\n")
                        total_benchmarks += 1
                    except json.JSONDecodeError:
                        pass
            except Exception:
                # Best-effort: timeouts/crashes of individual kernels are
                # skipped.  Exception already subsumes TimeoutExpired, so the
                # original (TimeoutExpired, Exception) tuple was redundant.
                pass
        out_file.flush()
        print(
            f" Shape {shape_idx + 1}/{len(shapes)}: M={m} N={n} K={k} "
            f"({len(executables)} kernels)",
            file=sys.stderr,
            flush=True,
        )
    return total_benchmarks
def main():
    """CLI entry point: generate the wide-coverage shape list, then run all
    benchmark executables against each shape in batches, writing one
    streaming log file per batch under --out_dir.
    """
    parser = argparse.ArgumentParser(
        description="Generate wide-coverage benchmark data"
    )
    parser.add_argument(
        "--bin_dir",
        default="/workspace/ck_tile/bin",
        help="Directory with benchmark executables",
    )
    parser.add_argument("--out_dir", default="data", help="Output directory")
    parser.add_argument(
        "--batch_size", type=int, default=25, help="Shapes per output file"
    )
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--repeat", type=int, default=10)
    parser.add_argument(
        "--max_shapes", type=int, default=None, help="Limit total shapes (for testing)"
    )
    args = parser.parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    shapes = generate_shape_list()
    if args.max_shapes:
        shapes = shapes[: args.max_shapes]
    print(f"Generated {len(shapes)} unique shapes", file=sys.stderr, flush=True)
    print(f"Bin dir: {args.bin_dir}", file=sys.stderr, flush=True)
    print(f"Output dir: {args.out_dir}", file=sys.stderr, flush=True)
    print(f"Batch size: {args.batch_size}", file=sys.stderr, flush=True)
    # Run in batches, one log file per batch, so partial results survive an
    # interrupted run.
    total = 0
    batch_idx = 0
    for i in range(0, len(shapes), args.batch_size):
        batch = shapes[i : i + args.batch_size]
        batch_idx += 1
        out_path = out_dir / f"wide_coverage_batch_{batch_idx:03d}.log"
        print(
            f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}",
            file=sys.stderr,
            flush=True,
        )
        with open(out_path, "w") as f:
            # Header lines expected by the downstream log parser.
            f.write(f"CK Tile Wide Coverage Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\n")
            f.write("Implementation: gemm_universal\n\n")
            count = run_shape_batch(
                args.bin_dir, batch, f, warmup=args.warmup, repeat=args.repeat
            )
        total += count
        print(
            f" Batch {batch_idx} complete: {count} benchmarks",
            file=sys.stderr,
            flush=True,
        )
    print(
        f"\nTotal: {total} benchmarks across {len(shapes)} shapes",
        file=sys.stderr,
        flush=True,
    )
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,867 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
ML Heuristic Sweep: Comprehensive GEMM Performance Evaluation
Sweeps across diverse problem shapes with ML-based kernel selection to measure
TFLOPS performance. Supports multiple dtypes (fp16, bf16, fp8) and validates
ML model predictions by executing kernels on GPU.
Shape Constraints (fp16/bf16 on gfx950):
- M >= 1 (any M is valid)
- N % 8 == 0 AND N >= 64
- K % 2 == 0 AND K >= 32
Usage:
python ml_heuristic_sweep.py --dtype fp16 --num_shapes 256
python ml_heuristic_sweep.py --dtypes fp16 bf16 --output sweep_results.csv
python ml_heuristic_sweep.py --dtype fp16 --dry_run # Prediction only, no GPU execution
"""
import sys
import argparse
import time
import csv
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple
# Add parent directories to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
)
try:
from predict import Predictor
# from feature_engine import GemmUniversalFeatureEngine
HAS_ML = True
except ImportError:
HAS_ML = False
print("WARNING: ML heuristic modules not available. Will use first-fit selection.")
@dataclass
class KernelSpec:
    """Kernel specification for ML heuristic"""
    # Identifier used to match ML ranking results back to a spec
    # (see ml_select_kernel).
    name: str
    # Block tile sizes along M/N/K.
    tile_m: int
    tile_n: int
    tile_k: int
    # Pipeline variant; values used in this file: "compv3", "compv4", "mem".
    pipeline: str = "compv3"
    # Scheduling strategy; values used in this file: "intrawave", "interwave".
    scheduler: str = "intrawave"
    # Wave partitioning along M/N/K (mapped to warp_m/n/k in
    # spec_to_feature_dict and wave_m/n/k in spec_to_kernel_config).
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    # Per-warp tile sizes along M/N/K (mapped to warp_tile_m/n/k in
    # spec_to_feature_dict).
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16
# Comprehensive kernel pool covering diverse tile sizes and configurations
# Positional argument order follows the KernelSpec field order:
#   name, tile_m, tile_n, tile_k, pipeline, scheduler,
#   wave_m, wave_n, wave_k, warp_m, warp_n, warp_k
KERNEL_POOL = [
    # Small tiles (64x64)
    KernelSpec(
        "s_64x64_k32_v3", 64, 64, 32, "compv3", "intrawave", 2, 2, 1, 16, 16, 16
    ),
    KernelSpec(
        "s_64x64_k64_v3", 64, 64, 64, "compv3", "intrawave", 2, 2, 1, 16, 16, 16
    ),
    KernelSpec(
        "s_64x64_k128_v3", 64, 64, 128, "compv3", "intrawave", 2, 2, 1, 16, 16, 16
    ),
    KernelSpec(
        "s_64x64_k64_v4", 64, 64, 64, "compv4", "intrawave", 2, 2, 1, 16, 16, 16
    ),
    KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec(
        "s_64x64_k128_mem", 64, 64, 128, "mem", "intrawave", 2, 2, 1, 16, 16, 16
    ),
    # Medium tiles (128x128) — wave/warp fields use KernelSpec defaults here.
    KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3", "intrawave"),
    KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3", "intrawave"),
    KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3", "intrawave"),
    KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4", "intrawave"),
    KernelSpec("m_128x128_k128_v4", 128, 128, 128, "compv4", "intrawave"),
    KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem", "intrawave"),
    KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem", "intrawave"),
    # Rectangular medium (M != N)
    KernelSpec(
        "r_64x128_k32_v3", 64, 128, 32, "compv3", "intrawave", 2, 2, 1, 16, 32, 16
    ),
    KernelSpec(
        "r_128x64_k32_v3", 128, 64, 32, "compv3", "intrawave", 2, 2, 1, 32, 16, 16
    ),
    KernelSpec(
        "r_64x128_k64_v3", 64, 128, 64, "compv3", "intrawave", 2, 2, 1, 16, 32, 16
    ),
    KernelSpec(
        "r_128x64_k64_v3", 128, 64, 64, "compv3", "intrawave", 2, 2, 1, 32, 16, 16
    ),
    KernelSpec(
        "r_64x256_k32_v3", 64, 256, 32, "compv3", "intrawave", 2, 2, 1, 16, 32, 16
    ),
    KernelSpec(
        "r_256x64_k32_v3", 256, 64, 32, "compv3", "intrawave", 2, 2, 1, 32, 16, 16
    ),
    # Large tiles (256x256)
    KernelSpec("l_256x128_k32_v3", 256, 128, 32, "compv3", "intrawave"),
    KernelSpec("l_128x256_k32_v3", 128, 256, 32, "compv3", "intrawave"),
    KernelSpec("l_256x256_k32_v3", 256, 256, 32, "compv3", "intrawave"),
    KernelSpec("l_256x256_k64_v3", 256, 256, 64, "compv3", "intrawave"),
    KernelSpec("l_256x256_k64_v4", 256, 256, 64, "compv4", "intrawave"),
    # Interwave variants
    KernelSpec("m_128x128_k64_iw_v3", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("m_128x128_k128_iw_v3", 128, 128, 128, "compv3", "interwave"),
    KernelSpec("l_256x256_k32_iw_v3", 256, 256, 32, "compv3", "interwave"),
]
def generate_problem_shapes(num_shapes: int = 1024) -> List[Tuple[int, int, int]]:
    """
    Generate diverse problem shapes with hardware constraints:
    - M >= 1 (any M is valid, including tiny M for inference)
    - N % 8 == 0 AND N >= 64 (hardware alignment requirement)
    - K % 2 == 0 AND K >= 32 (fp16 requirement)

    Covers powers of 2 (square and rectangular), ML workloads (LLM attention,
    MLP, batch inference), non-power-of-2 dimensions, and edge cases (tiny M,
    very large matrices, extreme aspect ratios).

    Args:
        num_shapes: Target number of shapes to return.  If fewer valid shapes
            exist, all of them are returned.  Non-positive values yield [].

    Returns:
        List of (M, N, K) tuples, each satisfying the constraints above.
    """
    shapes = []
    # 1. Powers of 2 - Square (64 to 8192) with K variations
    for p in range(6, 14):  # 2^6=64 to 2^13=8192
        dim = 2**p
        shapes.append((dim, dim, dim))
        if dim >= 128:
            # K variations (must be even and >= 32)
            shapes.append((dim, dim, dim // 2))
            shapes.append((dim, dim, dim * 2))
            shapes.append((dim, dim, max(32, dim // 4)))
    # 2. Small batch inference (1-256 batch, common hidden dims);
    #    N must be a multiple of 8 and >= 64.
    hidden_dims = [768, 1024, 2048, 3072, 4096, 5120, 8192, 11008, 12288, 16384]
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
    for hidden in hidden_dims:
        for batch in batch_sizes[:8]:  # NOTE(review): 256 is unused — confirm intent
            shapes.append((batch, hidden, hidden))
            if hidden >= 4096:
                # LLM MLP projections (ensure K is even)
                k_mlp = hidden * 3 // 4
                if k_mlp % 2 == 1:
                    k_mlp += 1  # Make even
                if k_mlp >= 32:
                    shapes.append((batch, hidden, k_mlp))
                    shapes.append((batch, k_mlp, hidden))
    # 3. Attention patterns (seq_len x head_dim); total_dim = heads * head_dim
    #    must be a multiple of 8 and >= 64, head_dim even and >= 32.
    seq_lens = [128, 256, 512, 1024, 2048, 4096, 8192]
    head_dims = [64, 80, 96, 128, 256]
    num_heads = [8, 12, 16, 32, 40, 64]
    for seq in seq_lens:
        for head_dim in head_dims:
            for nh in num_heads[:4]:
                total_dim = nh * head_dim
                if total_dim % 8 == 0 and total_dim >= 64:
                    if head_dim % 2 == 0 and head_dim >= 32:
                        shapes.append((seq, total_dim, head_dim))
                        shapes.append((seq, head_dim, total_dim))
    # 4. Rectangular matrices (extreme aspect ratios)
    dims_m = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    dims_n = [64, 128, 256, 512, 1024, 2048, 4096, 8192]  # N >= 64, N % 8 == 0
    dims_k = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]  # K >= 32, even
    # Stratified sampling keeps the cross product from exploding.
    for i, m in enumerate(dims_m):
        for j, n in enumerate(dims_n):
            for _l, k in enumerate(dims_k):
                if (i + j + _l) % 3 == 0:
                    shapes.append((m, n, k))
    # 5. Non-power-of-2 dimensions (aligned to constraints).
    # N values: multiples of 8, >= 64 (only the first 20 are sampled below).
    non_pow2_n = [72, 80, 88, 96, 104, 112, 120, 136, 144, 152, 160, 176, 184,
                  192, 200, 224, 240, 272, 288, 304, 320, 336, 352, 368, 384,
                  400, 416, 448, 480, 544, 576, 640, 672, 704, 736, 768, 800,
                  832, 896, 960, 1088, 1152, 1216, 1280, 1344, 1408, 1472,
                  1536, 1600, 1664, 1728, 1792, 1856, 1920, 2176, 2304, 2432,
                  2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712,
                  3840, 3968, 4224, 4352, 4480, 4608, 4736, 4864, 4992]
    # K values: even numbers >= 32 (only the first 15 are sampled below).
    non_pow2_k = [34, 36, 38, 40, 42, 44, 48, 50, 52, 56, 60, 66, 68, 72, 76,
                  80, 88, 96, 100, 112, 120, 136, 144, 160, 176, 192, 224, 240,
                  272, 288, 320, 352, 384, 416, 448, 480, 544, 576, 640, 672,
                  704, 768, 800, 832, 896, 960, 1088, 1152, 1280, 1344, 1408,
                  1536, 1600, 1664, 1792, 1920]
    # M values: any value >= 1 (only the first 30 are sampled below).
    non_pow2_m = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 23, 27, 31, 33, 37, 41,
                  47, 51, 57, 63, 65, 71, 79, 87, 95, 97, 111, 119, 127, 129,
                  143, 159, 175, 191, 193, 223, 239, 255, 257, 287, 319, 351,
                  383, 385, 447, 479, 511, 513, 575, 639, 703, 767, 769, 895,
                  959, 1023, 1025]
    for i, m in enumerate(non_pow2_m[:30]):
        for j, n in enumerate(non_pow2_n[:20]):
            for _l, k in enumerate(non_pow2_k[:15]):
                if (i + j + _l) % 4 == 0:  # Stratified sampling
                    shapes.append((m, n, k))
    # 6. Very tall K (memory-bound)
    for mn in [64, 128, 256, 512, 1024]:
        for k in [4096, 8192, 16384]:
            shapes.append((mn, mn, k))
    # 7. Very short K (compute-bound)
    for mn in [512, 1024, 2048, 4096]:
        for k in [32, 64, 128]:
            shapes.append((mn, mn, k))
    # 8. Tiny M (edge cases for batch-1 inference)
    for m in [1, 2, 4, 8, 16, 32]:
        for n in [64, 128, 256, 512, 1024, 2048]:  # N >= 64, N % 8 == 0
            for k in [32, 64, 128, 256, 512]:  # K >= 32, K % 2 == 0
                shapes.append((m, n, k))
    # 9. Stress test sizes (aligned to constraints)
    shapes.extend(
        [
            (10000, 10000, 10000),
            (1000, 10000, 1000),
            (1000, 1000, 10000),
            (5000, 5000, 5000),
            (7168, 7168, 7168),  # Common LLM hidden dim
            (8192, 11008, 8192),  # LLaMA MLP dimensions
        ]
    )
    # Remove duplicates while preserving first-seen order.
    seen = set()
    unique_shapes = []
    for s in shapes:
        if s not in seen:
            seen.add(s)
            unique_shapes.append(s)
    # Filter to enforce the hardware constraints stated in the docstring.
    valid_shapes = [
        (m, n, k)
        for m, n, k in unique_shapes
        if m >= 1 and n >= 64 and n % 8 == 0 and k >= 32 and k % 2 == 0
    ]
    # Guard: the strided sampling below divides by num_shapes; the original
    # code raised ZeroDivisionError for num_shapes <= 0.
    if num_shapes <= 0:
        return []
    # Evenly-strided subsample preserves the diversity of the full list.
    if len(valid_shapes) > num_shapes:
        step = len(valid_shapes) / num_shapes
        valid_shapes = [valid_shapes[int(i * step)] for i in range(num_shapes)]
    return valid_shapes
def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
    """Flatten a KernelSpec into the feature dictionary consumed by the ML
    predictor, tagging it with the problem dtype and layout.

    Note the naming translation: spec.wave_* -> warp_* and
    spec.warp_* -> warp_tile_* in the feature keys.
    """
    features = {
        "kernel_name": spec.name,
        "tile_m": spec.tile_m,
        "tile_n": spec.tile_n,
        "tile_k": spec.tile_k,
        "warp_m": spec.wave_m,
        "warp_n": spec.wave_n,
        "warp_k": spec.wave_k,
        "warp_tile_m": spec.warp_m,
        "warp_tile_n": spec.warp_n,
        "warp_tile_k": spec.warp_k,
        "pipeline": spec.pipeline,
        "scheduler": spec.scheduler,
        "epilogue": "cshuffle",
    }
    # Padding enabled on every axis so arbitrary M/N/K are representable.
    features.update(pad_m=True, pad_n=True, pad_k=True)
    features["persistent"] = False
    features["dtype"] = dtype
    features["layout"] = layout
    return features
def spec_to_kernel_config(
    spec: KernelSpec, dtype: str, arch: str, dtype_acc: str = "fp32"
) -> KernelConfig:
    """Translate a KernelSpec into the dispatcher's KernelConfig, fixing the
    layout to row/col/row and the epilogue to cshuffle; A/B/C all use the
    same element dtype."""
    kwargs = dict(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc=dtype_acc,
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=spec.wave_m,
        wave_n=spec.wave_n,
        wave_k=spec.wave_k,
        warp_m=spec.warp_m,
        warp_n=spec.warp_n,
        warp_k=spec.warp_k,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
    return KernelConfig(**kwargs)
def ml_select_kernel(
    predictor, pool: List[KernelSpec], M: int, N: int, K: int, dtype: str, layout: str
) -> Tuple[KernelSpec, float]:
    """Pick the kernel the ML model ranks highest for (M, N, K).

    Falls back to the first pool entry with 0.0 predicted TFLOPS when the ML
    stack is unavailable, no predictor was supplied, or the ranking comes
    back empty.
    """
    fallback = (pool[0], 0.0)
    if predictor is None or not HAS_ML:
        return fallback
    ranked = predictor.rank_kernels(
        {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1},
        [spec_to_feature_dict(s, dtype, layout) for s in pool],
    )
    if not ranked:
        return fallback
    top_name, top_tflops = ranked[0]
    # Map the winning name back to its spec; keep the predicted TFLOPS even
    # if the name is somehow not in the pool (matches original fallback).
    for spec in pool:
        if spec.name == top_name:
            return spec, top_tflops
    return pool[0], top_tflops
def run_single_gemm(
    M: int,
    N: int,
    K: int,
    dtype: str,
    arch: str,
    predictor,
    dry_run: bool = False,
    dtype_acc: str = "fp32",
) -> dict:
    """Select a kernel for one (M, N, K) GEMM via the ML heuristic and,
    unless dry_run, build and execute it on the GPU.

    Returns a dict with the shape, selected kernel name, predicted and actual
    TFLOPS/timings, a status string (SKIP/SUCCESS/BUILD_FAIL/UNSUPPORTED/
    RUN_FAIL/ERROR), and an error message when applicable.
    """
    # Select kernel via ML heuristic; selection latency is recorded too.
    t0 = time.time()
    best_spec, pred_tflops = ml_select_kernel(
        predictor, KERNEL_POOL, M, N, K, dtype, "rcr"
    )
    select_time_ms = (time.time() - t0) * 1000
    result = {
        "M": M,
        "N": N,
        "K": K,
        "dtype": dtype,
        "selected_kernel": best_spec.name,
        "predicted_tflops": pred_tflops,
        "selection_time_ms": select_time_ms,
        "actual_time_ms": 0,
        "actual_tflops": 0,
        "status": "SKIP" if dry_run else "PENDING",
        "error": None,
    }
    if dry_run:
        # Prediction-only mode: no build, no GPU execution.
        return result
    # Build and run kernel; every exit path below pairs setup_gemm_dispatcher
    # with cleanup_gemm.
    config = spec_to_kernel_config(best_spec, dtype, arch, dtype_acc)
    try:
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"sweep_{dtype}_{best_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )
        if not setup.success:
            result["status"] = "BUILD_FAIL"
            result["error"] = "Failed to build kernel"
            cleanup_gemm()
            return result
        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            result["status"] = "UNSUPPORTED"
            result["error"] = "Problem size not supported by kernel"
            cleanup_gemm()
            return result
        # Create input data (fixed seed for reproducibility across shapes).
        # NOTE(review): bf16 and fp8 are both materialized as np.float16 —
        # presumably converted downstream by the dispatcher; confirm.
        np_dtype = {"fp16": np.float16, "bf16": np.float16, "fp8": np.float16}[dtype]
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
        # Run GEMM
        exec_result = dispatcher.run(A, B, M, N, K)
        if exec_result.success:
            result["actual_time_ms"] = exec_result.time_ms
            result["actual_tflops"] = exec_result.tflops
            result["status"] = "SUCCESS"
        else:
            # Decode status code for better error message
            status_messages = {
                0: "Success",
                -1: "GPU/HIP error (check permissions, memory, or kernel validity)",
                -2: "No suitable kernel found for this problem size",
            }
            error_msg = status_messages.get(exec_result.status, f"Unknown error (status={exec_result.status})")
            result["status"] = "RUN_FAIL"
            result["error"] = f"{error_msg} (status_code={exec_result.status})"
            # Print detailed error for debugging
            print(f" ERROR: {error_msg}")
            print(f" Status code: {exec_result.status}")
            print(f" Time returned: {exec_result.time_ms}")
            print(f" Kernel: {exec_result.kernel_name}")
        cleanup_gemm()
    except Exception as e:
        # Any unexpected failure (build, dispatch, numpy) is captured in the
        # result dict rather than propagated, so the sweep can continue.
        result["status"] = "ERROR"
        result["error"] = str(e)[:200]
        cleanup_gemm()
    return result
def main():
    """CLI entry point for the ML heuristic sweep.

    Parses arguments, optionally loads a trained ML predictor (auto-detected
    from models/gemm_universal_{dtype}_{arch} when --model_dir is omitted),
    generates problem shapes, runs one GEMM per (shape, dtype) pair via
    run_single_gemm(), streams each result row to a CSV file, and prints a
    summary.

    Returns
    -------
    int
        0 on completion; partial results are still saved on Ctrl+C.
    """
    parser = argparse.ArgumentParser(
        description="ML Heuristic Sweep: Test GEMM across many shapes and dtypes"
    )
    parser.add_argument(
        "--dtypes",
        nargs="+",
        default=["fp16"],
        choices=["fp16", "bf16", "fp8"],
        help="Data types to test (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx950", help="GPU architecture (default: gfx950)"
    )
    parser.add_argument(
        "--dtype_acc",
        default="fp32",
        choices=["fp16", "fp32"],
        help="Accumulator data type (default: fp32)",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path to model directory (auto-detect if not specified)",
    )
    parser.add_argument(
        "--num_shapes",
        type=int,
        default=256,
        help="Number of problem shapes to test (default: 256)",
    )
    parser.add_argument(
        "--output",
        default="ml_heuristic_sweep_results.csv",
        help="Output CSV file path",
    )
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Only predict, do not run kernels (fast validation)",
    )
    args = parser.parse_args()
    # Set up the ML predictor; on any failure we fall back to first-fit
    # selection inside run_single_gemm (predictor stays None).
    predictor = None
    if HAS_ML:
        if args.model_dir is None:
            # Auto-detect model directory based on first dtype
            first_dtype = args.dtypes[0]
            heuristics_dir = Path(__file__).parent
            model_candidates = [
                heuristics_dir / "models" / f"gemm_universal_{first_dtype}_{args.arch}",
            ]
            for model_dir in model_candidates:
                if model_dir.exists():
                    args.model_dir = str(model_dir)
                    break
        if args.model_dir and Path(args.model_dir).exists():
            try:
                predictor = Predictor(args.model_dir)
                print(f"✓ Loaded ML model from: {args.model_dir}")
            except Exception as e:
                print(f"⚠ Failed to load ML model: {e}")
                print(" Will use first-fit selection instead")
        else:
            print(f"⚠ Model directory not found: {args.model_dir}")
            print(" Will use first-fit selection instead")
    # Generate problem shapes
    print(f"\nGenerating {args.num_shapes} problem shapes...")
    shapes = generate_problem_shapes(args.num_shapes)
    print(
        f"✓ Generated {len(shapes)} valid shapes (M>=1, N%8==0, N>=64, K%2==0, K>=32)"
    )
    # Double-check the generator honored the documented constraints; a
    # violation here indicates a bug in generate_problem_shapes().
    invalid = [
        (m, n, k)
        for m, n, k in shapes
        if not (m >= 1 and n >= 64 and n % 8 == 0 and k >= 32 and k % 2 == 0)
    ]
    if invalid:
        print(f"⚠ WARNING: {len(invalid)} shapes violate constraints!")
        print(f" First few: {invalid[:5]}")
    # Print configuration banner
    print("\n" + "=" * 80)
    print(" ML Heuristic Sweep Configuration")
    print("=" * 80)
    print(
        f" Model: {args.model_dir if args.model_dir else 'first-fit (no ML)'}"
    )
    print(f" Data types: {', '.join(args.dtypes)}")
    print(f" Accumulator: {args.dtype_acc}")
    print(f" Architecture: {args.arch}")
    print(f" Kernel pool: {len(KERNEL_POOL)} kernels")
    print(f" Problem shapes: {len(shapes)}")
    print(f" Total tests: {len(shapes) * len(args.dtypes)}")
    print(
        f" Mode: {'DRY RUN (prediction only)' if args.dry_run else 'FULL RUN (execute kernels)'}"
    )
    print(f" Output: {args.output}")
    print("=" * 80)
    # Open output CSV; rows are flushed as they complete so an interrupted
    # run still leaves a usable (partial) file on disk.
    csv_file = open(args.output, "w", newline="")
    csv_writer = csv.DictWriter(
        csv_file,
        fieldnames=[
            "dtype",
            "M",
            "N",
            "K",
            "selected_kernel",
            "predicted_tflops",
            "selection_time_ms",
            "actual_time_ms",
            "actual_tflops",
            "status",
            "error",
        ],
    )
    csv_writer.writeheader()
    # Run sweep
    total_tests = len(shapes) * len(args.dtypes)
    completed = 0
    start_time = time.time()
    print("\nStarting sweep... (Ctrl+C to stop and save partial results)\n")
    try:
        for dtype in args.dtypes:
            print(f"\n{'=' * 80}")
            print(f" Testing dtype: {dtype.upper()}")
            print(f"{'=' * 80}\n")
            for M, N, K in shapes:
                result = run_single_gemm(
                    M, N, K, dtype, args.arch, predictor, args.dry_run, args.dtype_acc
                )
                # Write to CSV immediately so no completed work is lost.
                csv_writer.writerow(result)
                csv_file.flush()
                completed += 1
                # Progress update every 10 tests, and always on failures.
                if completed % 10 == 0 or result["status"] != "SUCCESS":
                    elapsed = time.time() - start_time
                    rate = completed / elapsed if elapsed > 0 else 0
                    eta = (total_tests - completed) / rate if rate > 0 else 0
                    # NOTE(review): the marker glyphs below appear to have been
                    # stripped (every status maps to ""); status remains visible
                    # in the CSV's "status" column.
                    status_emoji = {
                        "SUCCESS": "",
                        "SKIP": "",
                        "BUILD_FAIL": "",
                        "UNSUPPORTED": "",
                        "RUN_FAIL": "",
                        "ERROR": "",
                    }.get(result["status"], "?")
                    print(
                        f" [{completed:4d}/{total_tests}] {status_emoji} "
                        # Trailing space added: the shape was previously fused
                        # with the kernel name in the output line.
                        f"{dtype:4s} {M:5d}x{N:5d}x{K:5d} "
                        f"{result['selected_kernel']:20s} "
                        f"pred={result['predicted_tflops']:6.1f} "
                        f"actual={result['actual_tflops']:6.1f} TFLOPS "
                        f"[{rate:.1f} tests/s, ETA {eta / 60:.1f}m]"
                    )
    except KeyboardInterrupt:
        print(f"\n\n⚠ Interrupted! Saving partial results to {args.output}...")
    finally:
        csv_file.close()
    # Summary
    print("\n" + "=" * 80)
    print(" SWEEP COMPLETE")
    print("=" * 80)
    # Read back results and compute statistics
    results = []
    with open(args.output, "r") as f:
        reader = csv.DictReader(f)
        results = list(reader)
    print(f"\n Total tests: {len(results)}")
    print(f" Output file: {args.output}")
    if not args.dry_run:
        success = [r for r in results if r["status"] == "SUCCESS"]
        # Guard against an empty results file (e.g. interrupted before the
        # first test completed) to avoid ZeroDivisionError.
        pct = 100 * len(success) / len(results) if results else 0.0
        print(
            f" Successful: {len(success)} ({pct:.1f}%)"
        )
        if success:
            avg_tflops = np.mean([float(r["actual_tflops"]) for r in success])
            max_tflops = max([float(r["actual_tflops"]) for r in success])
            print(f" Avg TFLOPS: {avg_tflops:.2f}")
            print(f" Max TFLOPS: {max_tflops:.2f}")
            # Per-dtype breakdown
            for dtype in args.dtypes:
                dtype_results = [r for r in success if r["dtype"] == dtype]
                if dtype_results:
                    avg = np.mean([float(r["actual_tflops"]) for r in dtype_results])
                    print(
                        f" {dtype:4s}: {avg:.2f} TFLOPS (n={len(dtype_results)})"
                    )
    print("=" * 80)
    print()
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    sys.exit(main())

View File

@@ -0,0 +1,113 @@
{
"op_type": "gemm_universal",
"dtype": "fp16",
"arch": "gfx950",
"feature_names": [
"M",
"N",
"K",
"split_k",
"log2_M",
"log2_N",
"log2_K",
"log2_MNK",
"arithmetic_intensity",
"aspect_ratio_mn",
"aspect_ratio_mk",
"aspect_ratio_nk",
"layout",
"tile_m",
"tile_n",
"tile_k",
"warp_m",
"warp_n",
"warp_k",
"warp_tile_m",
"warp_tile_n",
"warp_tile_k",
"pipeline",
"scheduler",
"epilogue",
"pad_m",
"pad_n",
"pad_k",
"persistent",
"num_warps",
"tile_volume",
"tile_mn",
"lds_usage_estimate",
"lds_usage_ratio",
"num_tiles_m",
"num_tiles_n",
"num_tiles_k",
"total_output_tiles",
"tile_eff_m",
"tile_eff_n",
"tile_eff_k",
"overall_tile_efficiency",
"cu_utilization",
"ratio_M_to_tile_m",
"ratio_N_to_tile_n",
"ratio_K_to_tile_k",
"problem_smaller_than_tile_m",
"problem_smaller_than_tile_n",
"problem_smaller_than_tile_k",
"any_dim_too_small",
"needs_padding_m",
"needs_padding_n",
"needs_padding_k",
"has_padding_when_needed_m",
"has_padding_when_needed_n",
"has_padding_when_needed_k",
"missing_required_padding_m",
"missing_required_padding_n",
"missing_required_padding_k",
"missing_any_required_padding",
"hw_num_cus",
"hw_simds_per_cu",
"hw_total_simds",
"hw_shader_engines",
"hw_max_clock_mhz",
"hw_max_waves_per_cu",
"hw_wavefront_size",
"hw_lds_capacity",
"hw_l1_cache_kb",
"hw_l2_cache_kb",
"hw_l3_cache_kb",
"hw_num_xcd"
],
"categorical_features": [
"layout",
"pipeline",
"scheduler",
"epilogue"
],
"targets": [
"tflops",
"latency",
"bandwidth"
],
"log_targets": [
"bandwidth",
"tflops"
],
"params": {
"objective": "regression",
"metric": [
"rmse",
"mae"
],
"num_leaves": 255,
"max_depth": 15,
"n_estimators": 2000,
"learning_rate": 0.02,
"min_child_samples": 10,
"subsample": 0.85,
"colsample_bytree": 0.85,
"reg_alpha": 0.05,
"reg_lambda": 0.5,
"verbose": -1,
"n_jobs": 8,
"seed": 42
}
}

View File

@@ -0,0 +1,10 @@
{
"warm_start_from": null,
"prev_n_estimators": 0,
"new_n_estimators": 2000,
"total_n_estimators": 2000,
"data_rows": 25600,
"valid_rows": 21920,
"unique_shapes": 25,
"timestamp": "2026-03-20T05:00:55"
}

View File

@@ -0,0 +1,113 @@
{
"op_type": "gemm_universal",
"dtype": "fp8",
"arch": "gfx950",
"feature_names": [
"M",
"N",
"K",
"split_k",
"log2_M",
"log2_N",
"log2_K",
"log2_MNK",
"arithmetic_intensity",
"aspect_ratio_mn",
"aspect_ratio_mk",
"aspect_ratio_nk",
"layout",
"tile_m",
"tile_n",
"tile_k",
"warp_m",
"warp_n",
"warp_k",
"warp_tile_m",
"warp_tile_n",
"warp_tile_k",
"pipeline",
"scheduler",
"epilogue",
"pad_m",
"pad_n",
"pad_k",
"persistent",
"num_warps",
"tile_volume",
"tile_mn",
"lds_usage_estimate",
"lds_usage_ratio",
"num_tiles_m",
"num_tiles_n",
"num_tiles_k",
"total_output_tiles",
"tile_eff_m",
"tile_eff_n",
"tile_eff_k",
"overall_tile_efficiency",
"cu_utilization",
"ratio_M_to_tile_m",
"ratio_N_to_tile_n",
"ratio_K_to_tile_k",
"problem_smaller_than_tile_m",
"problem_smaller_than_tile_n",
"problem_smaller_than_tile_k",
"any_dim_too_small",
"needs_padding_m",
"needs_padding_n",
"needs_padding_k",
"has_padding_when_needed_m",
"has_padding_when_needed_n",
"has_padding_when_needed_k",
"missing_required_padding_m",
"missing_required_padding_n",
"missing_required_padding_k",
"missing_any_required_padding",
"hw_num_cus",
"hw_simds_per_cu",
"hw_total_simds",
"hw_shader_engines",
"hw_max_clock_mhz",
"hw_max_waves_per_cu",
"hw_wavefront_size",
"hw_lds_capacity",
"hw_l1_cache_kb",
"hw_l2_cache_kb",
"hw_l3_cache_kb",
"hw_num_xcd"
],
"categorical_features": [
"layout",
"pipeline",
"scheduler",
"epilogue"
],
"targets": [
"tflops",
"latency",
"bandwidth"
],
"log_targets": [
"bandwidth",
"tflops"
],
"params": {
"objective": "regression",
"metric": [
"rmse",
"mae"
],
"num_leaves": 255,
"max_depth": 15,
"n_estimators": 2000,
"learning_rate": 0.02,
"min_child_samples": 10,
"subsample": 0.85,
"colsample_bytree": 0.85,
"reg_alpha": 0.05,
"reg_lambda": 0.5,
"verbose": -1,
"n_jobs": 8,
"seed": 42
}
}

View File

@@ -0,0 +1,10 @@
{
"warm_start_from": null,
"prev_n_estimators": 0,
"new_n_estimators": 2000,
"total_n_estimators": 2000,
"data_rows": 1296528,
"valid_rows": 1253076,
"unique_shapes": 168,
"timestamp": "2026-03-19T06:10:29"
}

View File

@@ -0,0 +1,243 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Predictor for CK Tile kernel performance.
Loads trained LightGBM models and provides:
- predict_tflops(): predicted TFLOPS for a single (problem, kernel) pair
- predict_latency(): predicted latency in ms
- predict_bandwidth(): predicted bandwidth in GB/s
- predict_all(): all three predictions at once
- rank_kernels(): rank all candidate kernels by predicted TFLOPS
- select_best(): return the best kernel ID
Usage:
predictor = Predictor("models/gemm_universal_fp8_gfx950")
best_kernel = predictor.select_best(
problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
kernel_configs=[...],
)
"""
import gzip
import json
from pathlib import Path
from typing import Optional
import lightgbm as lgb
import numpy as np
import pandas as pd
from feature_engine import GemmUniversalFeatureEngine
class Predictor:
    """Loads trained models and feature spec for kernel performance prediction.

    Parameters
    ----------
    model_dir : str or Path
        Directory containing model artifacts:
        - model_tflops.lgbm (required)
        - model_latency.lgbm (optional)
        - model_bandwidth.lgbm (optional)
        - feature_spec.json (required)
    feature_engine : FeatureEngine, optional
        Override the feature engine. If None, constructs one from feature_spec.json.
    """

    def __init__(self, model_dir: str | Path, feature_engine=None):
        self._model_dir = Path(model_dir)
        # Lazily-populated cache of loaded boosters, keyed by target name.
        self._models: dict[str, lgb.Booster] = {}
        spec_path = self._model_dir / "feature_spec.json"
        if spec_path.exists():
            with open(spec_path) as f:
                self._spec = json.load(f)
        else:
            # Missing spec degrades gracefully: no log targets, default engine.
            self._spec = {}
        # Targets trained in log1p-space; predictions get expm1 applied.
        self._log_targets = set(self._spec.get("log_targets", []))
        if feature_engine is not None:
            self._feature_engine = feature_engine
        else:
            self._feature_engine = GemmUniversalFeatureEngine()

    def _load_model(self, target: str) -> Optional[lgb.Booster]:
        """Lazy-load a model for the given target.

        Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
        The decompressed file is cached to disk for subsequent loads.
        Returns None when neither artifact is present.
        """
        if target in self._models:
            return self._models[target]
        path = self._model_dir / f"model_{target}.lgbm"
        gz_path = self._model_dir / f"model_{target}.lgbm.gz"
        # Auto-decompress if needed; streamed so the whole model is never
        # held in memory as a single bytes object.
        if not path.exists() and gz_path.exists():
            import shutil

            with gzip.open(gz_path, "rb") as f_in, open(path, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
        if not path.exists():
            return None
        model = lgb.Booster(model_file=str(path))
        self._models[target] = model
        return model

    def _postprocess(self, target: str, raw: float) -> float:
        """Map a raw model output to a physical (non-negative) value.

        Applies the inverse log1p transform for log-space targets and clamps
        non-log outputs to zero.  Centralized so _predict_single(),
        predict_all() and rank_kernels() stay mutually consistent.
        """
        if target in self._log_targets:
            return float(np.expm1(raw))
        # Clamp to non-negative even for non-log models
        return float(max(0.0, raw))

    def _predict_single(self, target: str, problem: dict, kernel_config: dict) -> float:
        """Predict a single target value, applying inverse log transform if needed."""
        model = self._load_model(target)
        if model is None:
            raise FileNotFoundError(f"No model_{target}.lgbm in {self._model_dir}")
        features = self._feature_engine.extract(problem, kernel_config)
        raw = float(model.predict(features.reshape(1, -1))[0])
        return self._postprocess(target, raw)

    def predict_tflops(self, problem: dict, kernel_config: dict) -> float:
        """Predict TFLOPS for a single (problem, kernel) pair.

        Returns a real TFLOPS estimate (interpretable, usable as DE surrogate).
        If the model was trained in log-space, the inverse transform is applied
        automatically.
        """
        return self._predict_single("tflops", problem, kernel_config)

    def predict_latency(self, problem: dict, kernel_config: dict) -> float:
        """Predict latency in milliseconds for a single (problem, kernel) pair."""
        return self._predict_single("latency", problem, kernel_config)

    def predict_bandwidth(self, problem: dict, kernel_config: dict) -> float:
        """Predict bandwidth in GB/s for a single (problem, kernel) pair."""
        return self._predict_single("bandwidth", problem, kernel_config)

    def predict_all(self, problem: dict, kernel_config: dict) -> dict[str, float]:
        """Predict all available targets for a single (problem, kernel) pair.

        Returns dict with keys 'tflops', 'latency_ms', 'bandwidth_gb_s' for
        each target whose model exists.  Features are extracted once and
        reused across targets; post-processing matches _predict_single().
        """
        features = self._feature_engine.extract(problem, kernel_config).reshape(1, -1)
        result = {}
        for target, key in [
            ("tflops", "tflops"),
            ("latency", "latency_ms"),
            ("bandwidth", "bandwidth_gb_s"),
        ]:
            model = self._load_model(target)
            if model is not None:
                raw = float(model.predict(features)[0])
                result[key] = self._postprocess(target, raw)
        return result

    def rank_kernels(
        self, problem: dict, kernel_configs: list[dict]
    ) -> list[tuple[str, float]]:
        """Rank candidate kernels by predicted TFLOPS (descending).

        Parameters
        ----------
        problem : dict
            Problem specification with keys: m, n, k, dtype, layout, split_k.
        kernel_configs : list of dict
            Each dict must have a 'kernel_name' key plus kernel parameters.

        Returns
        -------
        list of (kernel_name, predicted_tflops) tuples, sorted descending.
        """
        if not kernel_configs:
            return []
        model = self._load_model("tflops")
        if model is None:
            raise FileNotFoundError(f"No model_tflops.lgbm in {self._model_dir}")
        # Batch feature extraction: one row per (problem, kernel) pair.
        rows = [{**problem, **kc} for kc in kernel_configs]
        X = self._feature_engine.extract_batch(pd.DataFrame(rows))
        preds = model.predict(X)
        # Post-process per element so ranked scores are directly comparable
        # with predict_tflops() (previously the non-log clamp was skipped here).
        results = []
        for i, kc in enumerate(kernel_configs):
            name = kc.get("kernel_name", f"kernel_{i}")
            results.append((name, self._postprocess("tflops", float(preds[i]))))
        results.sort(key=lambda x: -x[1])
        return results

    def select_best(self, problem: dict, kernel_configs: list[dict]) -> str:
        """Return the kernel_name of the best predicted kernel."""
        ranked = self.rank_kernels(problem, kernel_configs)
        if not ranked:
            raise ValueError("No kernel configs provided")
        return ranked[0][0]
if __name__ == "__main__":
    import argparse

    # Small CLI demo: load models, then rank a handful of known kernel
    # configs (pulled from benchmark parquet data) for the given problem.
    parser = argparse.ArgumentParser(description="Predict kernel performance")
    parser.add_argument(
        "--model_dir", required=True, help="Directory with trained models"
    )
    parser.add_argument("--m", type=int, required=True)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--k", type=int, required=True)
    parser.add_argument("--layout", default="rcr")
    parser.add_argument("--dtype", default="fp8")
    args = parser.parse_args()
    predictor = Predictor(args.model_dir)
    problem = {
        "m": args.m,
        "n": args.n,
        "k": args.k,
        "dtype": args.dtype,
        "layout": args.layout,
        "split_k": 1,
    }
    print(f"Loading models from {args.model_dir}...")
    print(
        f"Problem: M={args.m} N={args.n} K={args.k} dtype={args.dtype} layout={args.layout}"
    )
    # Candidate configs come from the first parquet found in ../../data
    # relative to the model directory (layout assumed from repo structure —
    # TODO confirm against the heuristics directory layout).
    data_dir = Path(args.model_dir).parent.parent / "data"
    if data_dir.exists():
        for pq in data_dir.glob("*.parquet"):
            df = pd.read_parquet(pq)
            kernel_names = df["kernel_name"].unique()
            configs = []
            # One representative row per kernel name, capped at 10 candidates.
            for kn in kernel_names[:10]:
                row = df[df["kernel_name"] == kn].iloc[0]
                configs.append(row.to_dict())
            if configs:
                ranked = predictor.rank_kernels(problem, configs)
                print(f"\nTop 5 kernels (from {len(configs)} candidates):")
                for name, tflops in ranked[:5]:
                    print(f" {tflops:8.2f} TFLOPS {name}")
                # Only the first parquet with usable configs is demoed.
                break

View File

@@ -0,0 +1,272 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Surrogate search for CK Tile kernel configuration optimization.
Uses a trained LGBMRegressor as a cheap surrogate function to search the
discrete kernel parameter space (tile sizes, warp config, pipeline, etc.)
without running actual GPU benchmarks.
Strategies:
- 'random': Sample N random valid configs, score all, return top-K.
- 'de': Discrete Differential Evolution with mutation over valid parameter choices.
Usage:
from search import SurrogateSearch
from predict import Predictor
predictor = Predictor("models/gemm_universal_fp8_gfx950")
searcher = SurrogateSearch(predictor, strategy='random')
results = searcher.search(
problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
budget=500,
)
# results: [(config_dict, predicted_tflops), ...] sorted descending
"""
import random
from typing import Optional
import numpy as np
from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor
class SurrogateSearch:
    """Search kernel parameter space using ML regressor as surrogate objective.

    Results are deterministic for a given seed: both strategies consume the
    internal RNG in a fixed call order, so restructuring the sampling code
    would change the configs explored.

    Parameters
    ----------
    predictor : Predictor
        Trained predictor with a TFLOPS model.
    feature_engine : GemmUniversalFeatureEngine, optional
        Feature engine for parameter space and validation. If None, uses default.
    strategy : str
        Search strategy: 'random' or 'de' (Discrete Differential Evolution).
    seed : int
        Random seed for reproducibility.
    """

    def __init__(
        self,
        predictor: Predictor,
        feature_engine: Optional[GemmUniversalFeatureEngine] = None,
        strategy: str = "random",
        seed: int = 42,
    ):
        self._predictor = predictor
        self._fe = feature_engine or GemmUniversalFeatureEngine()
        self._strategy = strategy
        # Dedicated RNG instance so searches don't perturb global random state.
        self._rng = random.Random(seed)
        # NOTE(review): _np_rng is initialized but not referenced by either
        # strategy in this class — confirm external use before removing.
        self._np_rng = np.random.RandomState(seed)
        # Mapping: parameter name -> list of discrete valid choices.
        self._param_space = self._fe.get_parameter_space()

    def _sample_random_config(self) -> dict:
        """Sample a single random config from the parameter space."""
        config = {}
        for param, values in self._param_space.items():
            config[param] = self._rng.choice(values)
        return config

    def _sample_valid_config(self, max_attempts: int = 50) -> Optional[dict]:
        """Sample a random config that passes all validation constraints."""
        # Rejection sampling: returns None when no valid config is drawn
        # within max_attempts tries.
        for _ in range(max_attempts):
            config = self._sample_random_config()
            if self._fe.validate_config(config):
                return config
        return None

    def _score_config(self, problem: dict, config: dict) -> float:
        """Score a config using the predictor."""
        return self._predictor.predict_tflops(problem, config)

    def _search_random(
        self, problem: dict, budget: int, top_k: int
    ) -> list[tuple[dict, float]]:
        """Random search: sample valid configs, score all, return top-K."""
        configs = []
        for _ in range(budget):
            cfg = self._sample_valid_config()
            if cfg is not None:
                configs.append(cfg)
        if not configs:
            return []
        scored = []
        for cfg in configs:
            try:
                score = self._score_config(problem, cfg)
                scored.append((cfg, score))
            except Exception:
                # Surrogate failures on odd configs are skipped, not fatal.
                continue
        # Sort by score descending.
        scored.sort(key=lambda x: -x[1])
        return scored[:top_k]

    def _search_de(
        self,
        problem: dict,
        budget: int,
        top_k: int,
        pop_size: int = 20,
        mutation_rate: float = 0.3,
        crossover_rate: float = 0.7,
    ) -> list[tuple[dict, float]]:
        """Discrete Differential Evolution.

        Uses discrete mutation: randomly swap parameters to other valid values
        from the parameter space (no continuous relaxation + snap).

        Each generation:
        1. For each member of the population, create a trial vector by:
           - Selecting 3 random other members (a, b, c)
           - For each parameter, with probability mutation_rate, take the value
             from a, b, or c (uniform choice among the three donors)
           - With probability crossover_rate, take the trial value; otherwise keep original
        2. Validate the trial; if invalid, resample that parameter from the space
        3. Score the trial; if better than parent, replace
        """
        param_names = list(self._param_space.keys())
        # Seed the initial population with scored valid configs.
        population = []
        for _ in range(pop_size):
            cfg = self._sample_valid_config()
            if cfg is not None:
                score = self._score_config(problem, cfg)
                population.append((cfg, score))
        # DE needs at least 4 members (parent + 3 donors); otherwise degrade
        # to plain random search with the full budget.
        if len(population) < 4:
            return self._search_random(problem, budget, top_k)
        evals_used = len(population)
        # Remaining budget determines how many full generations we can run.
        max_gens = (budget - evals_used) // pop_size
        for gen in range(max_gens):
            new_pop = []
            for i, (parent, parent_score) in enumerate(population):
                candidates = [j for j in range(len(population)) if j != i]
                if len(candidates) < 3:
                    # Not enough distinct donors; keep the parent unchanged.
                    new_pop.append((parent, parent_score))
                    continue
                a_idx, b_idx, c_idx = self._rng.sample(candidates, 3)
                a, b, c = (
                    population[a_idx][0],
                    population[b_idx][0],
                    population[c_idx][0],
                )
                trial = dict(parent)
                for param in param_names:
                    # Mutation: pull the value from one of the three donors.
                    if self._rng.random() < mutation_rate:
                        donor = self._rng.choice([a, b, c])
                        trial[param] = donor.get(param, parent.get(param))
                    # Crossover: with prob (1 - crossover_rate) revert to parent.
                    if self._rng.random() > crossover_rate:
                        trial[param] = parent.get(param)
                if not self._fe.validate_config(trial):
                    # Repair pass: resample any parameter whose value is not
                    # in the legal choice set, then re-validate once.
                    for param in param_names:
                        if param in trial and trial[param] not in self._param_space.get(
                            param, [trial[param]]
                        ):
                            trial[param] = self._rng.choice(self._param_space[param])
                    if not self._fe.validate_config(trial):
                        new_pop.append((parent, parent_score))
                        continue
                try:
                    trial_score = self._score_config(problem, trial)
                    evals_used += 1
                except Exception:
                    # Scoring failure: keep the parent for this slot.
                    new_pop.append((parent, parent_score))
                    continue
                # Greedy selection: trial replaces parent only if strictly better.
                if trial_score > parent_score:
                    new_pop.append((trial, trial_score))
                else:
                    new_pop.append((parent, parent_score))
            population = new_pop
        population.sort(key=lambda x: -x[1])
        return population[:top_k]

    def search(
        self,
        problem: dict,
        budget: int = 500,
        top_k: int = 10,
        **kwargs,
    ) -> list[tuple[dict, float]]:
        """Search the kernel parameter space for the best configuration.

        Parameters
        ----------
        problem : dict
            Problem specification: m, n, k, dtype, layout, split_k.
        budget : int
            Maximum number of surrogate evaluations.
        top_k : int
            Number of top configurations to return.
        **kwargs
            Strategy-specific parameters (pop_size, mutation_rate, etc.).

        Returns
        -------
        list of (config_dict, predicted_tflops), sorted descending by TFLOPS.
        """
        if self._strategy == "random":
            return self._search_random(problem, budget, top_k)
        elif self._strategy == "de":
            return self._search_de(problem, budget, top_k, **kwargs)
        else:
            raise ValueError(f"Unknown strategy: {self._strategy}")
if __name__ == "__main__":
    import argparse
    import time

    # CLI demo: run the surrogate search for one problem and print the top-K
    # configs with their predicted TFLOPS.
    parser = argparse.ArgumentParser(
        description="Surrogate search for optimal kernel config"
    )
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--m", type=int, required=True)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--k", type=int, required=True)
    parser.add_argument("--dtype", default="fp8")
    parser.add_argument("--layout", default="rcr")
    parser.add_argument("--strategy", default="random", choices=["random", "de"])
    parser.add_argument("--budget", type=int, default=500)
    parser.add_argument("--top_k", type=int, default=10)
    args = parser.parse_args()
    predictor = Predictor(args.model_dir)
    searcher = SurrogateSearch(predictor, strategy=args.strategy)
    problem = {
        "m": args.m,
        "n": args.n,
        "k": args.k,
        "dtype": args.dtype,
        "layout": args.layout,
        "split_k": 1,
    }
    print(f"Searching with strategy={args.strategy}, budget={args.budget}...")
    t0 = time.time()
    results = searcher.search(problem, budget=args.budget, top_k=args.top_k)
    elapsed = time.time() - t0
    print(f"\nTop {len(results)} configs found in {elapsed * 1000:.1f}ms:")
    for i, (cfg, tflops) in enumerate(results):
        # Unknown fields render as '?' so partial configs still print cleanly.
        tile_str = f"{cfg.get('tile_m', '?')}x{cfg.get('tile_n', '?')}x{cfg.get('tile_k', '?')}"
        warp_str = f"{cfg.get('warp_m', '?')}x{cfg.get('warp_n', '?')}x{cfg.get('warp_k', '?')}"
        print(
            f" #{i + 1}: {tflops:8.2f} TFLOPS tile={tile_str} warp={warp_str} "
            f"pipeline={cfg.get('pipeline', '?')} scheduler={cfg.get('scheduler', '?')}"
        )

View File

@@ -0,0 +1,2 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,368 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for data_pipeline.py.
Covers: kernel name parsing, layout derivation, streaming log parsing,
schema validation, and corner cases (empty logs, malformed JSON, single-shape).
"""
import tempfile
from pathlib import Path
import pandas as pd
import pytest
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from data_pipeline import (
parse_kernel_name,
_layout_from_problem,
parse_streaming_log,
save_parquet,
load_parquet,
CANONICAL_COLUMNS,
)
# ---------------------------------------------------------------------------
# parse_kernel_name
# ---------------------------------------------------------------------------
class TestParseKernelName:
    """Parsing of encoded kernel-name strings into config dicts."""

    def test_standard_name(self):
        parsed = parse_kernel_name(
            "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_"
            "False_False_False_False_128x128x128_1x4x1_16x16x128"
        )
        string_fields = {
            "dtype": "fp8",
            "layout": "rcr",
            "pipeline": "compv3",
            "epilogue": "cshuffle",
            "scheduler": "intrawave",
        }
        for field, want in string_fields.items():
            assert parsed[field] == want
        for flag in ("pad_m", "pad_n", "pad_k", "persistent"):
            assert parsed[flag] is False
        int_fields = {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 128,
            "warp_m": 1,
            "warp_n": 4,
            "warp_k": 1,
            "warp_tile_m": 16,
            "warp_tile_n": 16,
            "warp_tile_k": 128,
        }
        for field, want in int_fields.items():
            assert parsed[field] == want

    def test_with_padding_and_persistent(self):
        parsed = parse_kernel_name(
            "gemm_universal_fp16_rrr_compv4_default_interwave_"
            "True_True_True_True_256x256x64_2x2x1_32x32x16"
        )
        assert parsed["dtype"] == "fp16"
        assert parsed["layout"] == "rrr"
        for flag in ("pad_m", "pad_n", "pad_k", "persistent"):
            assert parsed[flag] is True
        assert parsed["tile_m"] == 256

    def test_empty_name(self):
        # An empty string yields an empty config dict.
        assert parse_kernel_name("") == {}

    def test_malformed_name(self):
        # A name without the expected structure also yields an empty dict.
        assert parse_kernel_name("not_a_kernel_name") == {}

    def test_partial_name(self):
        # A truncated name parses its leading fields but omits tile info.
        parsed = parse_kernel_name("gemm_universal_fp8_rcr_compv3")
        assert parsed.get("dtype") == "fp8"
        assert parsed.get("layout") == "rcr"
        assert "tile_m" not in parsed  # not enough parts

    def test_all_layouts(self):
        template = (
            "gemm_universal_fp8_{}_compv3_cshuffle_intrawave_"
            "False_False_False_False_128x128x128_1x4x1_16x16x128"
        )
        for layout in ("rcr", "rrr", "crr", "ccr"):
            assert parse_kernel_name(template.format(layout))["layout"] == layout
# ---------------------------------------------------------------------------
# _layout_from_problem
# ---------------------------------------------------------------------------
class TestLayoutFromProblem:
    """Derivation of the 3-letter layout code from a problem dict."""

    def test_rcr(self):
        problem = {
            "layout_a": "RowMajor",
            "layout_b": "ColumnMajor",
            "layout_c": "RowMajor",
        }
        assert _layout_from_problem(problem) == "rcr"

    def test_rrr(self):
        problem = {
            "layout_a": "RowMajor",
            "layout_b": "RowMajor",
            "layout_c": "RowMajor",
        }
        assert _layout_from_problem(problem) == "rrr"

    def test_empty(self):
        # Missing layout keys degrade to the unknown sentinel.
        assert _layout_from_problem({}) == "???"

    def test_case_insensitive(self):
        problem = {
            "layout_a": "rowmajor",
            "layout_b": "COLUMNMAJOR",
            "layout_c": "RowMajor",
        }
        assert _layout_from_problem(problem) == "rcr"
# ---------------------------------------------------------------------------
# parse_streaming_log
# ---------------------------------------------------------------------------
# Synthetic profiler log used by TestParseStreamingLog: two shapes and three
# kernel records total, in the streaming banner-plus-JSON format the CK Tile
# benchmark driver emits.  Content must stay byte-identical — the tests below
# assert exact shape lists and TFLOPS values parsed from it.
SAMPLE_LOG = """\
================================================================================
LOG FILE: test.log
================================================================================
CK Tile Profiling Run
GPU ID: 0
--- Running CK Tile benchmarks on GPU 0 ---
========================================
Shape 1: M=16 N=1536 K=7168 dtype=fp8 layout=rcr
========================================
Found 2 kernels
{
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
"problem": {
"split_k":1,
"m":16,
"n":1536,
"k":7168,
"stride_a":7168,
"stride_b":7168,
"stride_c":1536,
"dtype_a":"fp8",
"dtype_b":"fp8",
"dtype_acc":"fp32",
"dtype_c":"fp16",
"layout_a":"RowMajor",
"layout_b":"ColumnMajor",
"layout_c":"RowMajor",
"structured_sparsity":false
},
"perf_result": {
"latency(ms)": 0.04,
"tflops(TFlops)": 8.81,
"bandwidth(GB/s)": 279.51
}
}
{
"name": "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16",
"problem": {
"split_k":1,
"m":16,
"n":1536,
"k":7168,
"stride_a":7168,
"stride_b":7168,
"stride_c":1536,
"dtype_a":"fp8",
"dtype_b":"fp8",
"dtype_acc":"fp32",
"dtype_c":"fp16",
"layout_a":"RowMajor",
"layout_b":"ColumnMajor",
"layout_c":"RowMajor",
"structured_sparsity":false
},
"perf_result": {
"latency(ms)": 0.05,
"tflops(TFlops)": 7.22,
"bandwidth(GB/s)": 228.85
}
}
========================================
Shape 2: M=20480 N=7168 K=256 dtype=fp8 layout=rcr
========================================
Found 1 kernels
{
"name": "gemm_universal_fp8_rcr_mem_cshuffle_intrawave_False_False_False_True_64x64x128_1x4x1_16x16x32",
"problem": {
"split_k":1,
"m":20480,
"n":7168,
"k":256,
"stride_a":256,
"stride_b":256,
"stride_c":7168,
"dtype_a":"fp8",
"dtype_b":"fp8",
"dtype_acc":"fp32",
"dtype_c":"fp16",
"layout_a":"RowMajor",
"layout_b":"ColumnMajor",
"layout_c":"RowMajor",
"structured_sparsity":false
},
"perf_result": {
"latency(ms)": 0.15,
"tflops(TFlops)": 505.00,
"bandwidth(GB/s)": 1200.50
}
}
"""
class TestParseStreamingLog:
    """End-to-end parsing of the streaming benchmark log format."""

    def _write_log(self, content: str) -> Path:
        # Helper: persist log text to a temp file and return its path.
        # delete=False so the file survives close(); files are intentionally
        # left for the OS / test runner to clean up.
        f = tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False)
        f.write(content)
        f.close()
        return Path(f.name)

    def test_basic_parse(self):
        # Three kernel records across two shapes -> three rows.
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path, arch="gfx950")
        assert len(df) == 3
        assert df["arch"].iloc[0] == "gfx950"
        assert df["m"].tolist() == [16, 16, 20480]
        assert df["n"].tolist() == [1536, 1536, 7168]
        assert df["k"].tolist() == [7168, 7168, 256]

    def test_tflops_values(self):
        # Measured TFLOPS are taken from each record's perf_result block.
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert df["measured_tflops"].tolist() == pytest.approx([8.81, 7.22, 505.0])

    def test_kernel_config_parsed(self):
        # Kernel-name fields (tile sizes, pipeline) are expanded into columns.
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert df["tile_m"].iloc[0] == 128
        assert df["pipeline"].iloc[0] == "compv3"
        assert df["pipeline"].iloc[1] == "compv4"

    def test_layout_derived_from_json(self):
        # Layout is derived from the JSON problem block, not the banner line.
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert all(df["layout"] == "rcr")

    def test_empty_log(self):
        # A log with no records yields an empty frame with the full schema.
        path = self._write_log("No shapes here\nJust noise\n")
        df = parse_streaming_log(path)
        assert len(df) == 0
        for col in CANONICAL_COLUMNS:
            assert col in df.columns

    def test_single_kernel(self):
        # A single record parses to one row marked valid.
        log = """\
Shape 1: M=1 N=1 K=1 dtype=fp8 layout=rcr
{
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
"problem": {"split_k":1, "m":1, "n":1, "k":1, "dtype_a":"fp8", "dtype_b":"fp8", "layout_a":"RowMajor", "layout_b":"ColumnMajor", "layout_c":"RowMajor"},
"perf_result": {"latency(ms)": 0.001, "tflops(TFlops)": 0.002, "bandwidth(GB/s)": 0.01}
}
"""
        path = self._write_log(log)
        df = parse_streaming_log(path)
        assert len(df) == 1
        assert df["m"].iloc[0] == 1
        assert bool(df["is_valid"].iloc[0]) is True

    def test_zero_tflops_marked_invalid(self):
        # Zero-throughput measurements are kept but flagged invalid.
        log = """\
Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr
{
"name": "test_kernel",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.0, "tflops(TFlops)": 0.0, "bandwidth(GB/s)": 0.0}
}
"""
        path = self._write_log(log)
        df = parse_streaming_log(path)
        assert len(df) == 1
        assert bool(df["is_valid"].iloc[0]) is False

    def test_malformed_json_skipped(self):
        # A broken JSON object between two good ones must not abort parsing.
        log = """\
Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr
{
"name": "good_kernel",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.01, "tflops(TFlops)": 1.0, "bandwidth(GB/s)": 10.0}
}
{ this is not valid json }
{
"name": "another_good",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.02, "tflops(TFlops)": 2.0, "bandwidth(GB/s)": 20.0}
}
"""
        path = self._write_log(log)
        df = parse_streaming_log(path)
        assert len(df) == 2

    def test_extreme_shapes(self):
        """SAMPLE_LOG spans M=16 (small) to M=20480 (large); M=1 is absent."""
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert 1 not in df["m"].values  # sample has M=16, M=20480
        assert 16 in df["m"].values
        assert 20480 in df["m"].values

    def test_run_id_assigned(self):
        # The caller-supplied run_id is stamped onto every row.
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path, run_id="test_run_123")
        assert all(df["run_id"] == "test_run_123")

    def test_op_type_assigned(self):
        # The caller-supplied op_type is stamped onto every row.
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path, op_type="gemm_streamk")
        assert all(df["op_type"] == "gemm_streamk")
# ---------------------------------------------------------------------------
# Parquet round-trip
# ---------------------------------------------------------------------------
class TestParquetIO:
    """Round-trip and directory-creation behavior of the parquet helpers."""

    def test_round_trip(self, tmp_path):
        frame = pd.DataFrame(
            {
                "m": [16, 32],
                "n": [1536, 1536],
                "k": [7168, 7168],
                "measured_tflops": [8.81, 15.0],
            }
        )
        target = tmp_path / "test.parquet"
        save_parquet(frame, target)
        restored = load_parquet(target)
        assert len(restored) == 2
        assert restored["m"].tolist() == [16, 32]

    def test_creates_parent_dirs(self, tmp_path):
        # Nested, not-yet-existing directories must be created on save.
        nested = tmp_path / "sub" / "dir" / "test.parquet"
        save_parquet(pd.DataFrame({"x": [1]}), nested)
        assert nested.exists()
# Allow running this test module directly: `python <this file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,264 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for dispatcher_integration.py.
Covers: kernel name parsing to feature dict, feature dict to dispatcher config
(name mapping inversion), MLKernelSpec creation, binary pool loading, and
the ML heuristic function.
"""
import json
import sys
from pathlib import Path
import lightgbm as lgb
import numpy as np
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from dispatcher_integration import (
kernel_config_to_feature_dict,
feature_dict_to_dispatcher_config,
feature_dict_to_ml_spec,
ml_spec_to_dispatcher_config,
create_ml_heuristic,
load_kernel_pool_from_binaries,
MLKernelSpec,
LAYOUT_TO_DISPATCHER,
)
from feature_engine import GemmUniversalFeatureEngine
# Canonical ck_tile GEMM kernel name used throughout these tests. It encodes,
# in order: dtype=fp8, layout=rcr, pipeline=compv3, epilogue=cshuffle,
# scheduler=intrawave, four boolean flags, block tile 128x128x128,
# warps-per-block 1x4x1, and warp tile 16x16x128.
SAMPLE_KERNEL_NAME = (
    "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave"
    "_False_False_False_False_128x128x128_1x4x1_16x16x128"
)
# ---------------------------------------------------------------------------
# kernel_config_to_feature_dict
# ---------------------------------------------------------------------------
class TestKernelConfigToFeatureDict:
    """Parsing a full ck_tile kernel name into the feature-space dict."""

    def test_parses_standard_name(self):
        parsed = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        # warp_m/n/k are warps per block; warp_tile_* are per-warp tiles.
        expected = {
            "tile_m": 128, "tile_n": 128, "tile_k": 128,
            "warp_m": 1, "warp_n": 4, "warp_k": 1,
            "warp_tile_m": 16, "warp_tile_n": 16, "warp_tile_k": 128,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "kernel_name": SAMPLE_KERNEL_NAME,
        }
        for key, want in expected.items():
            assert parsed[key] == want, key

    def test_empty_name_returns_empty(self):
        result = kernel_config_to_feature_dict("")
        assert result == {}

    def test_invalid_name_returns_empty(self):
        result = kernel_config_to_feature_dict("not_a_kernel")
        assert result == {}
# ---------------------------------------------------------------------------
# Name mapping: feature dict <-> dispatcher config
# ---------------------------------------------------------------------------
class TestNameMapping:
    """The critical inversion between feature space and dispatcher space.

    Feature-engine warp_m/n/k (warps per block) map to dispatcher wave_m/n/k,
    while feature-engine warp_tile_m/n/k map to dispatcher warp_m/n/k.
    """

    @staticmethod
    def _feat_and_disp(**kwargs):
        # Parse the shared sample name once and convert it.
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        return feat, feature_dict_to_dispatcher_config(feat, **kwargs)

    def test_warp_to_wave_mapping(self):
        feat, disp = self._feat_and_disp()
        assert (disp["wave_m"], disp["wave_n"], disp["wave_k"]) == (
            feat["warp_m"],  # 1
            feat["warp_n"],  # 4
            feat["warp_k"],  # 1
        )

    def test_warp_tile_to_warp_mapping(self):
        feat, disp = self._feat_and_disp()
        assert (disp["warp_m"], disp["warp_n"], disp["warp_k"]) == (
            feat["warp_tile_m"],  # 16
            feat["warp_tile_n"],  # 16
            feat["warp_tile_k"],  # 128
        )

    def test_tile_dims_pass_through(self):
        _, disp = self._feat_and_disp()
        assert [disp[key] for key in ("tile_m", "tile_n", "tile_k")] == [128, 128, 128]

    def test_pipeline_passes_through(self):
        _, disp = self._feat_and_disp()
        assert disp["pipeline"] == "compv3"
        assert disp["scheduler"] == "intrawave"
        assert disp["epilogue"] == "cshuffle"

    def test_rcr_layout_mapping(self):
        _, disp = self._feat_and_disp(dtype="fp8")
        assert (disp["layout_a"], disp["layout_b"], disp["layout_c"]) == ("row", "col", "row")

    def test_all_layouts(self):
        for layout, expected in LAYOUT_TO_DISPATCHER.items():
            disp = feature_dict_to_dispatcher_config({"layout": layout, "tile_m": 128})
            assert (disp["layout_a"], disp["layout_b"], disp["layout_c"]) == tuple(expected)
# ---------------------------------------------------------------------------
# MLKernelSpec
# ---------------------------------------------------------------------------
class TestMLKernelSpec:
    """MLKernelSpec construction and its mapping back to dispatcher configs."""

    def test_from_feature_dict(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        spec = feature_dict_to_ml_spec(feat, predicted_tflops=123.4)
        assert spec.kernel_name == SAMPLE_KERNEL_NAME
        assert spec.predicted_tflops == 123.4
        assert spec.tile_m == 128
        assert spec.wave_m == 1  # was warp_m in feature space
        assert spec.warp_m == 16  # was warp_tile_m in feature space

    def test_spec_to_dispatcher_config(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        spec = feature_dict_to_ml_spec(feat, 100.0)
        disp = ml_spec_to_dispatcher_config(spec, dtype="fp8", arch="gfx950")
        assert disp["tile_m"] == 128
        assert disp["wave_m"] == 1
        assert disp["warp_m"] == 16
        assert disp["gfx_arch"] == "gfx950"
        assert disp["dtype_a"] == "fp8"

    def test_roundtrip_preserves_values(self):
        """feature_dict -> MLKernelSpec -> dispatcher_config should be consistent."""
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        spec = feature_dict_to_ml_spec(feat, 0.0)
        disp_from_spec = ml_spec_to_dispatcher_config(spec)
        disp_from_feat = feature_dict_to_dispatcher_config(feat)
        # Both conversion paths must agree on every tile/wave/warp dimension
        # and on the pipeline/scheduler/epilogue selections.
        for key in [
            "tile_m",
            "tile_n",
            "tile_k",
            "wave_m",
            "wave_n",
            "wave_k",
            "warp_m",
            "warp_n",
            "warp_k",
            "pipeline",
            "scheduler",
            "epilogue",
        ]:
            assert disp_from_spec[key] == disp_from_feat[key], f"Mismatch on {key}"
# ---------------------------------------------------------------------------
# Binary pool loading
# ---------------------------------------------------------------------------
class TestLoadKernelPool:
    """Loading the kernel pool by scanning compiled benchmark binaries."""

    def test_loads_from_real_bin_dir(self):
        """Smoke test against a real binary directory, when one is present.

        The directory defaults to /workspace/ck_tile/bin but can be pointed
        elsewhere via the CK_TILE_BIN_DIR environment variable, so the test
        is usable outside the canonical container layout.
        """
        import os

        bin_dir = Path(os.environ.get("CK_TILE_BIN_DIR", "/workspace/ck_tile/bin"))
        if not bin_dir.exists():
            pytest.skip(f"No kernel binary dir at {bin_dir}")
        pool = load_kernel_pool_from_binaries(bin_dir)
        assert len(pool) > 0
        # Every pool entry is a feature dict carrying at least the tile
        # sizes and the originating kernel name.
        assert "tile_m" in pool[0]
        assert "kernel_name" in pool[0]

    def test_empty_dir_returns_empty(self, tmp_path):
        """A directory with no kernel binaries yields an empty pool."""
        pool = load_kernel_pool_from_binaries(tmp_path)
        assert pool == []
# ---------------------------------------------------------------------------
# ML heuristic function
# ---------------------------------------------------------------------------
class TestCreateMLHeuristic:
    """End-to-end flow of the ML heuristic: load model + pool, rank, select."""

    @pytest.fixture
    def mock_model_dir(self, tmp_path):
        """Create a minimal model for testing the heuristic flow.

        Trains a tiny LightGBM regressor on random data (fixed seed for
        reproducibility) and writes the model + feature spec files that
        create_ml_heuristic expects to find in a model directory.
        """
        fe = GemmUniversalFeatureEngine()
        n_features = len(fe.get_feature_names())
        np.random.seed(42)
        X = np.random.rand(100, n_features)
        y = np.random.rand(100) * 500
        model = lgb.LGBMRegressor(n_estimators=5, verbose=-1)
        model.fit(X, y)
        model.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))
        spec = {
            "feature_names": fe.get_feature_names(),
            "categorical_features": fe.get_categorical_features(),
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec, f)
        return tmp_path

    def _make_pool(self):
        """Create a small synthetic kernel pool."""
        # Three distinct configs covering different pipelines/tiles.
        names = [
            "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
            "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16",
            "gemm_universal_fp8_rcr_mem_cshuffle_interwave_False_False_False_False_64x64x128_1x4x1_16x16x32",
        ]
        return [kernel_config_to_feature_dict(n) for n in names]

    def test_returns_ml_kernel_spec(self, mock_model_dir):
        pool = self._make_pool()
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        result = heuristic(1024, 1024, 1024)
        assert isinstance(result, MLKernelSpec)
        assert result.tile_m > 0
        assert isinstance(result.predicted_tflops, float)

    def test_returns_valid_kernel_from_pool(self, mock_model_dir):
        # The heuristic must never invent a kernel outside the pool.
        pool = self._make_pool()
        pool_names = {p["kernel_name"] for p in pool}
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        result = heuristic(1024, 1024, 1024)
        assert result.kernel_name in pool_names

    def test_different_shapes_may_select_different_kernels(self, mock_model_dir):
        pool = self._make_pool()
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        r1 = heuristic(16, 1536, 7168)
        r2 = heuristic(8192, 8192, 256)
        # At minimum both should return valid specs
        assert r1.tile_m > 0
        assert r2.tile_m > 0

    def test_m1_corner_case(self, mock_model_dir):
        # Single-token inference shape must still produce a finite score.
        pool = self._make_pool()
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        result = heuristic(1, 4096, 4096)
        assert isinstance(result, MLKernelSpec)
        assert np.isfinite(result.predicted_tflops)

    def test_empty_pool_raises(self, mock_model_dir):
        with pytest.raises(ValueError, match="No kernel configs"):
            create_ml_heuristic(mock_model_dir, kernel_pool=[])
# Allow running this test module directly: `python <this file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for evaluate.py.
Covers: shape family classification, K-depth regime classification,
and basic evaluation metric checks.
"""
import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from evaluate import classify_shape_family, classify_k_regime
class TestClassifyShapeFamily:
    """Shape-family buckets are keyed off the M dimension."""

    def _check(self, cases, family):
        # All (m, n, k) triples must land in the same family bucket.
        for m, n, k in cases:
            assert classify_shape_family(m, n, k) == family, (m, n, k)

    def test_tiny_m(self):
        self._check([(1, 4096, 4096), (16, 1536, 7168)], "tiny_m")

    def test_small_m(self):
        self._check([(32, 1536, 7168), (128, 4096, 4096)], "small_m")

    def test_medium_m(self):
        self._check([(256, 1024, 1024), (2048, 2048, 2048)], "medium_m")

    def test_large_m(self):
        self._check([(4096, 4096, 4096), (20480, 7168, 256)], "large_m")
class TestClassifyKRegime:
    """K-depth regimes; exact thresholds live in evaluate.py."""

    def test_shallow(self):
        assert [classify_k_regime(k) for k in (256, 32)] == ["shallow_k"] * 2

    def test_medium(self):
        assert [classify_k_regime(k) for k in (1024, 2048)] == ["medium_k"] * 2

    def test_deep(self):
        assert [classify_k_regime(k) for k in (4096, 7168)] == ["deep_k"] * 2
# Allow running this test module directly: `python <this file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,409 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for feature_engine.py.
Covers: feature count consistency, formula correctness (tile efficiency, LDS,
arithmetic intensity), corner-case shapes (M=1, huge M, square, skinny-K),
parameter space validity, config validation, and batch vs single extraction parity.
"""
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from feature_engine import (
GemmUniversalFeatureEngine,
)
@pytest.fixture
def fe():
    """Default feature engine with MI355X-like hardware.

    The constants here (256 CUs, 64 KiB LDS, 8 XCDs, ...) mirror the HW
    dict used by the C++/Python parity tests; tests in this module assert
    against these exact values.
    """
    return GemmUniversalFeatureEngine(
        num_cus=256,
        lds_capacity=65536,
        max_clock_mhz=2400,
        simds_per_cu=4,
        shader_engines=32,
        max_waves_per_cu=32,
        wavefront_size=64,
        l1_cache_kb=32,
        l2_cache_kb=4096,
        l3_cache_kb=262144,
        num_xcd=8,
    )
def _make_problem(m=1024, n=1024, k=1024, dtype="fp8", layout="rcr", split_k=1):
return {
"m": m,
"n": n,
"k": k,
"dtype": dtype,
"layout": layout,
"split_k": split_k,
}
def _make_kernel(
tile_m=128,
tile_n=128,
tile_k=64,
warp_m=2,
warp_n=2,
warp_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
pad_m=False,
pad_n=False,
pad_k=False,
persistent=False,
):
return {
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
"pipeline": pipeline,
"scheduler": scheduler,
"epilogue": epilogue,
"pad_m": pad_m,
"pad_n": pad_n,
"pad_k": pad_k,
"persistent": persistent,
}
# ---------------------------------------------------------------------------
# Basic consistency
# ---------------------------------------------------------------------------
class TestFeatureConsistency:
    """Structural invariants: vector length matches the name list, all
    values are finite, and categoricals are a subset of the names."""

    @staticmethod
    def _vector(fe):
        return fe.extract(_make_problem(), _make_kernel())

    def test_feature_count_matches_names(self, fe):
        assert len(self._vector(fe)) == len(fe.get_feature_names())

    def test_feature_count_is_72(self, fe):
        assert len(fe.get_feature_names()) == 72

    def test_no_nan_in_output(self, fe):
        assert not np.isnan(self._vector(fe)).any()

    def test_no_inf_in_output(self, fe):
        assert not np.isinf(self._vector(fe)).any()

    def test_categorical_features_in_names(self, fe):
        names = fe.get_feature_names()
        missing = [cat for cat in fe.get_categorical_features() if cat not in names]
        assert missing == []
# ---------------------------------------------------------------------------
# Formula correctness
# ---------------------------------------------------------------------------
class TestTileEfficiency:
    """Tile efficiency: fraction of the last tile that is useful work."""

    def test_perfectly_divisible(self, fe):
        # Every problem dim is an exact multiple of its tile dim => 1.0.
        prob = _make_problem(m=256, n=256, k=128)
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        assert vec[names.index("tile_eff_m")] == 1.0
        assert vec[names.index("tile_eff_n")] == 1.0
        assert vec[names.index("tile_eff_k")] == 1.0
        assert vec[names.index("overall_tile_efficiency")] == 1.0

    def test_not_divisible(self, fe):
        # Efficiency is remainder/tile for a partial last tile:
        # 100 % 128 = 100 -> 100/128; 100 % 64 = 36 -> 36/64.
        prob = _make_problem(m=100, n=100, k=100)
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        assert vec[names.index("tile_eff_m")] == pytest.approx(100 / 128)
        assert vec[names.index("tile_eff_n")] == pytest.approx(100 / 128)
        assert vec[names.index("tile_eff_k")] == pytest.approx(36 / 64)

    def test_m_equals_1(self, fe):
        """Single-token inference: M=1, tile_m=128 => eff = 1/128."""
        prob = _make_problem(m=1)
        kern = _make_kernel(tile_m=128)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        assert vec[names.index("tile_eff_m")] == pytest.approx(1.0 / 128.0)
class TestLDSUsage:
    """LDS estimate: (tile_m*tile_k + tile_n*tile_k) * bytes_per_element."""

    def test_lds_formula(self, fe):
        prob = _make_problem(dtype="fp8")
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        expected = (128 * 64 + 128 * 64) * 1.0  # fp8 = 1 byte
        assert vec[names.index("lds_usage_estimate")] == expected

    def test_lds_ratio_compv4(self, fe):
        """compv4 has 32KB LDS limit, not 64KB."""
        prob = _make_problem(dtype="fp8")
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64, pipeline="compv4")
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        lds_est = (128 * 64 + 128 * 64) * 1.0
        # Ratio is taken against the pipeline-specific 32 KiB capacity.
        assert vec[names.index("lds_usage_ratio")] == pytest.approx(lds_est / 32768)

    def test_lds_fp16_doubles(self, fe):
        # Same tiles, 2-byte elements => twice the estimate of fp8.
        prob = _make_problem(dtype="fp16")
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        vec = fe.extract(prob, kern)
        names = fe.get_feature_names()
        expected = (128 * 64 + 128 * 64) * 2.0  # fp16 = 2 bytes
        assert vec[names.index("lds_usage_estimate")] == expected
class TestArithmeticIntensity:
    """arithmetic_intensity = 2*M*N*K / ((M*K + K*N + M*N) * bytes_per_elem)."""

    @staticmethod
    def _intensity(fe, m, n, k):
        vec = fe.extract(_make_problem(m=m, n=n, k=k, dtype="fp8"), _make_kernel())
        return vec[fe.get_feature_names().index("arithmetic_intensity")]

    def test_square_shape(self, fe):
        m = n = k = 1024
        traffic = (m * k + k * n + m * n) * 1.0  # fp8 = 1 byte/element
        assert self._intensity(fe, m, n, k) == pytest.approx(2.0 * m * n * k / traffic)

    def test_skinny_k(self, fe):
        """Small K => low arithmetic intensity (memory-bound)."""
        assert self._intensity(fe, 8192, 8192, 32) < 100

    def test_deep_k(self, fe):
        """Large K => high arithmetic intensity (compute-bound)."""
        assert self._intensity(fe, 256, 256, 8192) > 100
# ---------------------------------------------------------------------------
# Corner-case shapes
# ---------------------------------------------------------------------------
class TestCornerCaseShapes:
    """Degenerate and extreme problem shapes must never yield NaN/Inf."""

    @staticmethod
    def _features(fe, **dims):
        return fe.extract(_make_problem(**dims), _make_kernel())

    def test_m1_single_token(self, fe):
        assert not np.isnan(self._features(fe, m=1, n=4096, k=4096)).any()

    def test_m1_n1_k1_minimum(self, fe):
        features = self._features(fe, m=1, n=1, k=1)
        assert not np.isnan(features).any()
        assert not np.isinf(features).any()

    def test_very_large_m(self, fe):
        assert not np.isnan(self._features(fe, m=20480, n=7168, k=7168)).any()

    def test_non_power_of_2(self, fe):
        assert not np.isnan(self._features(fe, m=1536, n=7168, k=2304)).any()

    def test_prime_dimensions(self, fe):
        assert not np.isnan(self._features(fe, m=17, n=31, k=127)).any()

    def test_tall_matrix(self, fe):
        """M >> N (tall matrix)."""
        features = self._features(fe, m=16384, n=64, k=1024)
        assert features[fe.get_feature_names().index("aspect_ratio_mn")] > 100

    def test_wide_matrix(self, fe):
        """N >> M (wide matrix)."""
        features = self._features(fe, m=64, n=16384, k=1024)
        assert features[fe.get_feature_names().index("aspect_ratio_mn")] < 0.01
# ---------------------------------------------------------------------------
# Batch vs single extraction parity
# ---------------------------------------------------------------------------
class TestBatchParity:
    """extract_batch() over a DataFrame must equal per-row extract()."""

    def test_batch_matches_single(self, fe):
        """Vectorized batch should produce identical results to row-by-row."""
        # Two deliberately dissimilar rows: a tiny-M fp8/compv3 config and a
        # large-M, fully padded, persistent mem-pipeline config.
        rows = [
            {
                "m": 16,
                "n": 1536,
                "k": 7168,
                "split_k": 1,
                "dtype": "fp8",
                "layout": "rcr",
                "tile_m": 128,
                "tile_n": 128,
                "tile_k": 128,
                "warp_m": 1,
                "warp_n": 4,
                "warp_k": 1,
                "warp_tile_m": 16,
                "warp_tile_n": 16,
                "warp_tile_k": 128,
                "pipeline": "compv3",
                "scheduler": "intrawave",
                "epilogue": "cshuffle",
                "pad_m": False,
                "pad_n": False,
                "pad_k": False,
                "persistent": False,
            },
            {
                "m": 20480,
                "n": 7168,
                "k": 256,
                "split_k": 1,
                "dtype": "fp8",
                "layout": "rcr",
                "tile_m": 64,
                "tile_n": 64,
                "tile_k": 128,
                "warp_m": 2,
                "warp_n": 2,
                "warp_k": 1,
                "warp_tile_m": 32,
                "warp_tile_n": 32,
                "warp_tile_k": 16,
                "pipeline": "mem",
                "scheduler": "interwave",
                "epilogue": "default",
                "pad_m": True,
                "pad_n": True,
                "pad_k": True,
                "persistent": True,
            },
        ]
        df = pd.DataFrame(rows)
        batch_result = fe.extract_batch(df)
        # Each row dict carries both problem and kernel keys, so it is
        # passed as both arguments to extract().
        for i, row_dict in enumerate(rows):
            single_result = fe.extract(row_dict, row_dict)
            np.testing.assert_allclose(
                batch_result[i],
                single_result,
                rtol=1e-5,
                atol=1e-5,
                err_msg=f"Mismatch at row {i}",
            )
# ---------------------------------------------------------------------------
# Parameter space and validation
# ---------------------------------------------------------------------------
class TestParameterSpace:
    """Parameter-space enumeration, config validation, and projection."""

    def test_parameter_space_non_empty(self, fe):
        ps = fe.get_parameter_space()
        assert len(ps) > 0
        assert "tile_m" in ps
        assert "pipeline" in ps

    def test_valid_config_passes(self, fe):
        # A conventional fp8 config should satisfy all constraints.
        config = {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 64,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        }
        assert fe.validate_config(config) is True

    def test_invalid_tile_rejected(self, fe):
        # 999 is not a value in the tile_m parameter space.
        config = {"tile_m": 999}
        assert fe.validate_config(config) is False

    def test_lds_constraint_rejects_huge_tile(self, fe):
        # 256x256x256 tiles exceed the (compv4, 32 KiB) LDS budget.
        config = {
            "tile_m": 256,
            "tile_n": 256,
            "tile_k": 256,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "pipeline": "compv4",
        }
        assert fe.validate_config(config) is False

    def test_project_to_valid_snaps(self, fe):
        # Off-grid values are snapped onto the parameter space
        # (100 -> 128, 200 -> 192); valid entries pass through unchanged.
        config = {"tile_m": 100, "tile_n": 200, "pipeline": "compv3"}
        projected = fe.project_to_valid(config)
        assert projected["tile_m"] == 128
        assert projected["tile_n"] == 192
        assert projected["pipeline"] == "compv3"
# ---------------------------------------------------------------------------
# Hardware features
# ---------------------------------------------------------------------------
class TestHardwareFeatures:
    """Hardware constructor arguments must surface verbatim as hw_* features."""

    def test_hardware_values_propagated(self, fe):
        vec = fe.extract(_make_problem(), _make_kernel())
        names = fe.get_feature_names()

        def value(key):
            return vec[names.index(key)]

        assert value("hw_num_cus") == 256
        assert value("hw_max_clock_mhz") == 2400
        assert value("hw_total_simds") == 256 * 4  # num_cus * simds_per_cu
        assert value("hw_num_xcd") == 8

    def test_different_hardware(self):
        # A differently-configured engine must report its own values.
        engine = GemmUniversalFeatureEngine(num_cus=120, max_clock_mhz=1800)
        vec = engine.extract(_make_problem(), _make_kernel())
        names = engine.get_feature_names()
        assert vec[names.index("hw_num_cus")] == 120
        assert vec[names.index("hw_max_clock_mhz")] == 1800
# Allow running this test module directly: `python <this file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,357 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Test that the C++ extract_features() in ml_heuristic.hpp produces identical
values to the Python GemmUniversalFeatureEngine.extract().
This test uses ctypes to call the C++ feature extraction compiled into a
small shared library, then compares against Python output. If compilation
fails (no HIP/ROCm), it falls back to verifying the Python feature engine
against manually computed expected values for specific test cases.
"""
import math
import sys
from pathlib import Path
import numpy as np
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from feature_engine import (
GemmUniversalFeatureEngine,
PIPELINE_MAP,
SCHEDULER_MAP,
EPILOGUE_MAP,
LAYOUT_MAP,
)
def _compute_features_manually(
    M,
    N,
    K,
    split_k,
    dtype,
    layout,
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    warp_tile_m,
    warp_tile_n,
    warp_tile_k,
    pipeline,
    scheduler,
    epilogue,
    pad_m,
    pad_n,
    pad_k,
    persistent,
    hw,
):
    """Recompute features independently to verify the Python engine.

    Returns the 72-element feature list in exactly the engine's order
    (indices noted inline); `hw` is the hardware dict with num_cus,
    lds_capacity, etc. The C++ extract_features() implements the same
    formulas, so this doubles as the parity oracle.
    """
    # Bytes per element by dtype; unknown dtypes default to 1 byte.
    bpe_map = {"fp8": 1.0, "fp16": 2.0, "bf16": 2.0, "fp32": 4.0}
    bpe = bpe_map.get(dtype, 1.0)
    # log2 features; max(..., 1) guards against log2(0) for degenerate dims.
    log2_M = math.log2(max(M, 1))
    log2_N = math.log2(max(N, 1))
    log2_K = math.log2(max(K, 1))
    log2_MNK = math.log2(max(M * N * K, 1))
    # Arithmetic intensity: 2*M*N*K flops over total A+B+C traffic in bytes.
    mem = (M * K + K * N + M * N) * bpe
    ai = (2.0 * M * N * K) / max(mem, 1)
    # LDS: A-tile + B-tile staging; compv4 only has a 32 KiB budget.
    lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
    lds_cap = 32768 if pipeline == "compv4" else hw["lds_capacity"]
    # Tile counts per dimension (ceil-division).
    ntm = math.ceil(M / max(tile_m, 1))
    ntn = math.ceil(N / max(tile_n, 1))
    ntk = math.ceil(K / max(tile_k, 1))

    def eff(d, t):
        # Fraction of the final tile that is useful; 1.0 when d % t == 0.
        if t <= 0:
            return 1.0
        r = d % t
        return r / t if r > 0 else 1.0

    # Problem-to-tile ratios
    ratio_M_to_tile_m = M / max(tile_m, 1)
    ratio_N_to_tile_n = N / max(tile_n, 1)
    ratio_K_to_tile_k = K / max(tile_k, 1)
    # Binary features: problem smaller than tile
    problem_smaller_than_tile_m = float(M < tile_m)
    problem_smaller_than_tile_n = float(N < tile_n)
    problem_smaller_than_tile_k = float(K < tile_k)
    any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k))
    # Padding requirement features
    needs_padding_m = float(tile_m > 0 and M % tile_m != 0)
    needs_padding_n = float(tile_n > 0 and N % tile_n != 0)
    needs_padding_k = float(tile_k > 0 and K % tile_k != 0)
    # Interaction features
    has_padding_when_needed_m = float(needs_padding_m and pad_m)
    has_padding_when_needed_n = float(needs_padding_n and pad_n)
    has_padding_when_needed_k = float(needs_padding_k and pad_k)
    # Missing padding features
    missing_required_padding_m = float(needs_padding_m and not pad_m)
    missing_required_padding_n = float(needs_padding_n and not pad_n)
    missing_required_padding_k = float(needs_padding_k and not pad_k)
    missing_any_required_padding = float(
        missing_required_padding_m or missing_required_padding_n or missing_required_padding_k
    )
    return [
        M,  # 0
        N,  # 1
        K,  # 2
        split_k,  # 3
        log2_M,  # 4
        log2_N,  # 5
        log2_K,  # 6
        log2_MNK,  # 7
        ai,  # 8
        M / max(N, 1),  # 9 (aspect_ratio_mn)
        M / max(K, 1),  # 10 (aspect_ratio_mk)
        N / max(K, 1),  # 11 (aspect_ratio_nk)
        LAYOUT_MAP.get(layout, 0),  # 12
        tile_m,  # 13
        tile_n,  # 14
        tile_k,  # 15
        warp_m,  # 16
        warp_n,  # 17
        warp_k,  # 18
        warp_tile_m,  # 19
        warp_tile_n,  # 20
        warp_tile_k,  # 21
        PIPELINE_MAP.get(pipeline, 0),  # 22
        SCHEDULER_MAP.get(scheduler, 0),  # 23
        EPILOGUE_MAP.get(epilogue, 0),  # 24
        float(pad_m),  # 25
        float(pad_n),  # 26
        float(pad_k),  # 27
        float(persistent),  # 28
        warp_m * warp_n * warp_k,  # 29 (num_warps)
        tile_m * tile_n * tile_k,  # 30 (tile_volume)
        tile_m * tile_n,  # 31 (tile_mn)
        lds_est,  # 32 (lds_usage_estimate)
        lds_est / max(lds_cap, 1),  # 33 (lds_usage_ratio)
        ntm,  # 34 (num_tiles_m)
        ntn,  # 35 (num_tiles_n)
        ntk,  # 36 (num_tiles_k)
        ntm * ntn,  # 37 (total_output_tiles)
        eff(M, tile_m),  # 38 (tile_eff_m)
        eff(N, tile_n),  # 39 (tile_eff_n)
        eff(K, tile_k),  # 40 (tile_eff_k)
        eff(M, tile_m) * eff(N, tile_n) * eff(K, tile_k),  # 41 (overall_tile_efficiency)
        ntm * ntn / max(hw["num_cus"], 1),  # 42 (cu_utilization)
        ratio_M_to_tile_m,  # 43
        ratio_N_to_tile_n,  # 44
        ratio_K_to_tile_k,  # 45
        problem_smaller_than_tile_m,  # 46
        problem_smaller_than_tile_n,  # 47
        problem_smaller_than_tile_k,  # 48
        any_dim_too_small,  # 49
        needs_padding_m,  # 50
        needs_padding_n,  # 51
        needs_padding_k,  # 52
        has_padding_when_needed_m,  # 53
        has_padding_when_needed_n,  # 54
        has_padding_when_needed_k,  # 55
        missing_required_padding_m,  # 56
        missing_required_padding_n,  # 57
        missing_required_padding_k,  # 58
        missing_any_required_padding,  # 59
        hw["num_cus"],  # 60
        hw["simds_per_cu"],  # 61
        hw["num_cus"] * hw["simds_per_cu"],  # 62 (total_simds)
        hw["shader_engines"],  # 63
        hw["max_clock_mhz"],  # 64
        hw["max_waves_per_cu"],  # 65
        hw["wavefront_size"],  # 66
        hw["lds_capacity"],  # 67
        hw["l1_cache_kb"],  # 68
        hw["l2_cache_kb"],  # 69
        hw["l3_cache_kb"],  # 70
        hw["num_xcd"],  # 71
    ]
# Parity fixtures spanning three regimes:
#   0: square 1024^3 fp8/rcr, unpadded, compv3.
#   1: single-token M=1 fp8/rcr, fully padded + persistent, compv4.
#   2: large-M skinny-K fp16/rrr, mem pipeline.
TEST_CASES = [
    {
        "problem": {
            "m": 1024,
            "n": 1024,
            "k": 1024,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 64,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    },
    {
        "problem": {
            "m": 1,
            "n": 4096,
            "k": 4096,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 64,
            "tile_n": 64,
            "tile_k": 128,
            "warp_m": 1,
            "warp_n": 4,
            "warp_k": 1,
            "warp_tile_m": 16,
            "warp_tile_n": 16,
            "warp_tile_k": 128,
            "pipeline": "compv4",
            "scheduler": "interwave",
            "epilogue": "default",
            "pad_m": True,
            "pad_n": True,
            "pad_k": True,
            "persistent": True,
        },
    },
    {
        "problem": {
            "m": 20480,
            "n": 7168,
            "k": 256,
            "split_k": 1,
            "dtype": "fp16",
            "layout": "rrr",
        },
        "kernel": {
            "tile_m": 256,
            "tile_n": 256,
            "tile_k": 32,
            "warp_m": 4,
            "warp_n": 1,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "mem",
            "scheduler": "interwave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    },
]
# MI355X-like hardware description shared by all parity cases; matches the
# fixture values used in the feature-engine test module.
HW = {
    "num_cus": 256,
    "simds_per_cu": 4,
    "shader_engines": 32,
    "max_clock_mhz": 2400,
    "max_waves_per_cu": 32,
    "wavefront_size": 64,
    "lds_capacity": 65536,
    "l1_cache_kb": 32,
    "l2_cache_kb": 4096,
    "l3_cache_kb": 262144,
    "num_xcd": 8,
}
class TestFeatureParity:
    """Verify Python feature engine matches manual computation (C++ uses same logic)."""

    @pytest.fixture
    def fe(self):
        # Engine configured with the shared HW dict so hw_* features match.
        return GemmUniversalFeatureEngine(**HW)

    @pytest.mark.parametrize("case_idx", range(len(TEST_CASES)))
    def test_python_matches_manual(self, fe, case_idx):
        """Each of the 72 features must match the independent recomputation
        to near machine precision."""
        case = TEST_CASES[case_idx]
        prob = case["problem"]
        kern = case["kernel"]
        py_features = fe.extract(prob, kern)
        manual = _compute_features_manually(
            prob["m"],
            prob["n"],
            prob["k"],
            prob["split_k"],
            prob["dtype"],
            prob["layout"],
            kern["tile_m"],
            kern["tile_n"],
            kern["tile_k"],
            kern["warp_m"],
            kern["warp_n"],
            kern["warp_k"],
            kern["warp_tile_m"],
            kern["warp_tile_n"],
            kern["warp_tile_k"],
            kern["pipeline"],
            kern["scheduler"],
            kern["epilogue"],
            kern["pad_m"],
            kern["pad_n"],
            kern["pad_k"],
            kern["persistent"],
            HW,
        )
        manual_arr = np.array(manual, dtype=np.float64)
        assert len(py_features) == len(manual_arr) == 72
        # Compare element-by-element so the failing feature is named.
        for i in range(72):
            assert py_features[i] == pytest.approx(
                manual_arr[i], rel=1e-10, abs=1e-15
            ), (
                f"Feature {i} ({fe.get_feature_names()[i]}): Python={py_features[i]}, Manual={manual_arr[i]}"
            )

    def test_feature_count(self, fe):
        assert len(fe.get_feature_names()) == 72

    def test_encoding_maps_match_cpp(self):
        """The C++ encode_* functions must use the same mapping as Python."""
        assert PIPELINE_MAP == {
            "compv3": 0,
            "compv4": 1,
            "compv5": 2,
            "mem": 3,
            "preshufflev2": 4,
        }
        assert SCHEDULER_MAP == {"intrawave": 0, "interwave": 1}
        assert EPILOGUE_MAP == {"default": 0, "cshuffle": 1}
        assert LAYOUT_MAP == {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3}
# Allow running this test module directly: `python <this file>`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python3
"""Test that compressed models can be loaded and used."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from predict import Predictor
def test_fp16_model_decompression():
"""Test that fp16 model is auto-decompressed and usable."""
model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp16_gfx950"
# Ensure .lgbm.gz exists
gz_file = model_dir / "model_tflops.lgbm.gz"
assert gz_file.exists(), f"Compressed model not found: {gz_file}"
# Load predictor - should auto-decompress
predictor = Predictor(model_dir)
# Test prediction
problem = {"m": 128, "n": 1536, "k": 7168, "dtype": "fp16", "layout": "rcr"}
kernel_config = {
"tile_shape": {"m0": 128, "n0": 128, "k0": 16},
"wave_shape": {"m1": 2, "n1": 2, "k1": 1},
"warp_tile": {"m2": 32, "n2": 32, "k2": 8},
}
tflops = predictor.predict_tflops(problem, kernel_config)
assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"
# Verify decompressed file was created
lgbm_file = model_dir / "model_tflops.lgbm"
assert lgbm_file.exists(), "Model should have been decompressed"
print(f"✅ FP16 model decompression test passed")
print(f" Predicted TFLOPS: {tflops:.2f}")
print(f" Decompressed to: {lgbm_file}")
return True
def test_fp8_model_decompression():
    """Test that fp8 model is auto-decompressed and usable.

    Same flow as the fp16 variant: verify the .lgbm.gz exists, let
    Predictor auto-decompress it, check the prediction is a positive
    float, and confirm the decompressed .lgbm file was cached.
    Returns True so the __main__ driver can run it as a plain script.
    """
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp8_gfx950"
    # Ensure .lgbm.gz exists
    gz_file = model_dir / "model_tflops.lgbm.gz"
    assert gz_file.exists(), f"Compressed model not found: {gz_file}"
    # Load predictor - should auto-decompress
    predictor = Predictor(model_dir)
    # Test prediction
    problem = {"m": 2048, "n": 2048, "k": 2048, "dtype": "fp8", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 256, "n0": 256, "k0": 64},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 16},
    }
    tflops = predictor.predict_tflops(problem, kernel_config)
    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"
    # Verify decompressed file was created
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"
    # No placeholders in the first message, so the f-prefix was dropped (F541).
    print("✅ FP8 model decompression test passed")
    print(f" Predicted TFLOPS: {tflops:.2f}")
    print(f" Decompressed to: {lgbm_file}")
    return True
# Script entry point: run both decompression checks sequentially so this file
# can be executed directly (python <file>) without pytest.
if __name__ == "__main__":
    print("Testing compressed model auto-decompression...")
    print()
    test_fp16_model_decompression()
    print()
    test_fp8_model_decompression()
    print()
    print("✅ All model compression tests passed!")

View File

@@ -0,0 +1,181 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for predict.py.
Covers: Predictor initialization, single prediction, ranking, select_best,
missing model handling, and edge cases (single kernel, empty list).
"""
import json
import sys
from pathlib import Path
import lightgbm as lgb
import numpy as np
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor
@pytest.fixture
def model_dir(tmp_path):
    """Train tiny TFLOPS and latency models plus a feature spec in tmp_path."""
    engine = GemmUniversalFeatureEngine()
    np.random.seed(42)
    features = np.random.rand(200, len(engine.get_feature_names()))
    # Fit and persist one 10-tree model per target, in the same RNG order
    # as before: tflops labels first, then latency labels.
    for target_name, scale in (("tflops", 100), ("latency", 0.1)):
        labels = np.random.rand(200) * scale
        regressor = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
        regressor.fit(features, labels)
        regressor.booster_.save_model(str(tmp_path / f"model_{target_name}.lgbm"))
    (tmp_path / "feature_spec.json").write_text(
        json.dumps(
            {
                "feature_names": engine.get_feature_names(),
                "categorical_features": engine.get_categorical_features(),
            }
        )
    )
    return tmp_path
@pytest.fixture
def predictor(model_dir):
    # Predictor wired to the tiny models produced by the model_dir fixture.
    return Predictor(model_dir)
def _problem():
return {
"m": 1024,
"n": 1024,
"k": 1024,
"dtype": "fp8",
"layout": "rcr",
"split_k": 1,
}
def _kernel(tile_m=128, pipeline="compv3"):
return {
"kernel_name": f"test_kernel_{tile_m}_{pipeline}",
"tile_m": tile_m,
"tile_n": 128,
"tile_k": 64,
"warp_m": 2,
"warp_n": 2,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 16,
"pipeline": pipeline,
"scheduler": "intrawave",
"epilogue": "cshuffle",
"pad_m": False,
"pad_n": False,
"pad_k": False,
"persistent": False,
}
class TestPredictor:
    """Behavioral tests for Predictor's prediction and ranking API."""

    def test_predict_tflops_returns_float(self, predictor):
        value = predictor.predict_tflops(_problem(), _kernel())
        assert isinstance(value, float)

    def test_predict_latency_returns_float(self, predictor):
        value = predictor.predict_latency(_problem(), _kernel())
        assert isinstance(value, float)

    def test_predict_all_returns_dict(self, predictor):
        out = predictor.predict_all(_problem(), _kernel())
        for key in ("tflops", "latency_ms"):
            assert key in out

    def test_rank_kernels_sorted_descending(self, predictor):
        candidates = [
            _kernel(64, "compv3"),
            _kernel(128, "compv4"),
            _kernel(256, "mem"),
        ]
        ranked = predictor.rank_kernels(_problem(), candidates)
        assert len(ranked) == 3
        scores = [score for _, score in ranked]
        assert scores == sorted(scores, reverse=True)

    def test_select_best_returns_name(self, predictor):
        candidates = [_kernel(64), _kernel(128)]
        chosen = predictor.select_best(_problem(), candidates)
        assert isinstance(chosen, str)
        assert chosen in [c["kernel_name"] for c in candidates]

    def test_single_kernel(self, predictor):
        ranked = predictor.rank_kernels(_problem(), [_kernel(128)])
        assert len(ranked) == 1

    def test_missing_bandwidth_model(self, model_dir):
        # The fixture only writes tflops/latency models, so bandwidth is absent.
        pred = Predictor(model_dir)
        with pytest.raises(FileNotFoundError):
            pred.predict_bandwidth(_problem(), _kernel())

    def test_empty_kernel_list(self, predictor):
        with pytest.raises(ValueError):
            predictor.select_best(_problem(), [])

    def test_corner_case_m1(self, predictor):
        # Extreme skinny-M shape must still yield a finite prediction.
        skinny = dict(_problem(), m=1, n=4096, k=4096)
        assert np.isfinite(predictor.predict_tflops(skinny, _kernel()))

    def test_different_shapes_give_different_results(self, predictor):
        cfg = _kernel()
        small = dict(_problem(), m=16, n=1536, k=7168)
        large = dict(_problem(), m=20480, n=7168, k=256)
        assert predictor.predict_tflops(small, cfg) != predictor.predict_tflops(
            large, cfg
        )
class TestPredictorEdgeCases:
    """Failure-mode coverage for Predictor construction."""

    def test_nonexistent_model_dir(self):
        # Either construction or the first prediction must fail for a bogus path.
        with pytest.raises(Exception):
            broken = Predictor("/nonexistent/path")
            broken.predict_tflops(_problem(), _kernel())
# Allow running this test module directly without invoking pytest externally.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,192 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for search.py.
Covers: random search, DE search, config validity, result ordering,
budget compliance, and edge cases.
"""
import json
import sys
from pathlib import Path
import lightgbm as lgb
import numpy as np
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from feature_engine import GemmUniversalFeatureEngine
from predict import Predictor
from search import SurrogateSearch
@pytest.fixture
def model_dir(tmp_path):
    """Train a throwaway 10-tree TFLOPS model plus feature spec in tmp_path."""
    engine = GemmUniversalFeatureEngine()
    np.random.seed(42)
    features = np.random.rand(200, len(engine.get_feature_names()))
    targets = np.random.rand(200) * 500
    regressor = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    regressor.fit(features, targets)
    regressor.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))
    (tmp_path / "feature_spec.json").write_text(
        json.dumps(
            {
                "feature_names": engine.get_feature_names(),
                "categorical_features": engine.get_categorical_features(),
            }
        )
    )
    return tmp_path
@pytest.fixture
def predictor(model_dir):
    # Predictor wired to the tiny model produced by the model_dir fixture.
    return Predictor(model_dir)
def _problem():
return {
"m": 1024,
"n": 1024,
"k": 1024,
"dtype": "fp8",
"layout": "rcr",
"split_k": 1,
}
class TestRandomSearch:
    """Random-strategy search: result count, ordering, validity, corner shapes."""

    def test_returns_results(self, predictor):
        results = SurrogateSearch(predictor, strategy="random").search(
            _problem(), budget=50, top_k=5
        )
        assert 0 < len(results) <= 5

    def test_results_sorted_descending(self, predictor):
        results = SurrogateSearch(predictor, strategy="random").search(
            _problem(), budget=100, top_k=10
        )
        scores = [score for _, score in results]
        assert scores == sorted(scores, reverse=True)

    def test_configs_are_valid(self, predictor):
        fe = GemmUniversalFeatureEngine()
        searcher = SurrogateSearch(predictor, feature_engine=fe, strategy="random")
        results = searcher.search(_problem(), budget=50, top_k=5)
        for cfg, _ in results:
            ps = fe.get_parameter_space()
            for k, v in cfg.items():
                if k in ps:
                    assert v in ps[k], f"{k}={v} not in {ps[k]}"

    def test_respects_top_k(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        assert len(searcher.search(_problem(), budget=100, top_k=3)) <= 3

    def test_different_problems_produce_results(self, predictor):
        """Both problem sizes should produce valid search results."""
        first = SurrogateSearch(predictor, strategy="random", seed=42).search(
            dict(_problem(), m=16, n=1536, k=7168), budget=50, top_k=3
        )
        second = SurrogateSearch(predictor, strategy="random", seed=42).search(
            dict(_problem(), m=20480, n=7168, k=256), budget=50, top_k=3
        )
        assert len(first) > 0
        assert len(second) > 0
        for _, score in first + second:
            assert np.isfinite(score)

    def test_m1_corner_case(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(
            dict(_problem(), m=1, n=4096, k=4096), budget=50, top_k=5
        )
        assert len(results) > 0
        for _, score in results:
            assert np.isfinite(score)
class TestDESearch:
    """Differential-evolution strategy: basic behavior and quality floor."""

    def test_returns_results(self, predictor):
        results = SurrogateSearch(predictor, strategy="de").search(
            _problem(), budget=100, top_k=5
        )
        assert len(results) > 0

    def test_results_sorted_descending(self, predictor):
        results = SurrogateSearch(predictor, strategy="de").search(
            _problem(), budget=100, top_k=5
        )
        scores = [score for _, score in results]
        assert scores == sorted(scores, reverse=True)

    def test_de_improves_over_initial(self, predictor):
        """DE should generally find at least as good as random initialization."""
        random_best = SurrogateSearch(predictor, strategy="random", seed=42).search(
            _problem(), budget=100, top_k=1
        )
        de_best = SurrogateSearch(predictor, strategy="de", seed=42).search(
            _problem(), budget=100, top_k=1
        )
        if random_best and de_best:
            # Allow 10% slack: DE is stochastic, not strictly monotone.
            assert de_best[0][1] >= random_best[0][1] * 0.9

    def test_small_budget(self, predictor):
        results = SurrogateSearch(predictor, strategy="de").search(
            _problem(), budget=30, top_k=5
        )
        assert len(results) > 0
class TestSearchEdgeCases:
    """Degenerate inputs: bad strategy, zero budget, seeding determinism."""

    def test_unknown_strategy_raises(self, predictor):
        # Construction succeeds; the error must come from search() itself.
        searcher = SurrogateSearch(predictor, strategy="unknown")
        with pytest.raises(ValueError):
            searcher.search(_problem(), budget=10)

    def test_zero_budget(self, predictor):
        results = SurrogateSearch(predictor, strategy="random").search(
            _problem(), budget=0, top_k=5
        )
        assert len(results) == 0

    def test_deterministic_with_same_seed(self, predictor):
        first = SurrogateSearch(predictor, strategy="random", seed=123).search(
            _problem(), budget=50, top_k=5
        )
        second = SurrogateSearch(predictor, strategy="random", seed=123).search(
            _problem(), budget=50, top_k=5
        )
        assert len(first) == len(second)
        for (_, score_a), (_, score_b) in zip(first, second):
            assert score_a == pytest.approx(score_b)
# Allow running this test module directly without invoking pytest externally.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,329 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for train.py.
Covers: group key computation, TFLOPS efficiency calculation, edge cases
(single group, all-invalid data, tied predictions), and warm-start
incremental training (feature compat, lineage, quality).
"""
import json
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from feature_engine import GemmUniversalFeatureEngine
from train import (
compute_group_keys,
compute_tflops_efficiency,
check_feature_compatibility,
load_warm_start_model,
train_final_model,
DEFAULT_PARAMS,
)
class TestComputeGroupKeys:
    """compute_group_keys must map identical (m, n, k) rows to one key."""

    def test_basic(self):
        frame = pd.DataFrame(
            {"m": [16, 16, 32], "n": [1536, 1536, 1536], "k": [7168, 7168, 7168]}
        )
        keys = compute_group_keys(frame)
        assert keys[0] == keys[1]
        assert keys[0] != keys[2]

    def test_unique_shapes(self):
        frame = pd.DataFrame({"m": [1, 2, 3], "n": [4, 5, 6], "k": [7, 8, 9]})
        assert len(set(compute_group_keys(frame))) == 3
class TestComputeTflopsEfficiency:
    """Efficiency = measured TFLOPS of the predicted-best kernel / oracle best."""

    @staticmethod
    def _frame(ms, ns, ks, measured, predicted):
        # Small helper: build the benchmark frame shape the function expects.
        return pd.DataFrame(
            {
                "m": ms,
                "n": ns,
                "k": ks,
                "measured_tflops": measured,
                "pred_tflops": predicted,
            }
        )

    def test_perfect_prediction(self):
        """Model predicts highest TFLOPS kernel => efficiency = 1.0."""
        frame = self._frame(
            [1024] * 3, [1024] * 3, [1024] * 3, [100, 200, 150], [50, 300, 100]
        )
        eff = compute_tflops_efficiency(frame, "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_worst_prediction(self):
        """Model picks the worst kernel."""
        frame = self._frame(
            [1024] * 3, [1024] * 3, [1024] * 3, [100, 200, 150], [999, 1, 1]
        )
        eff = compute_tflops_efficiency(frame, "pred_tflops")
        assert eff["efficiency"].iloc[0] == pytest.approx(100 / 200)

    def test_multiple_shapes(self):
        frame = self._frame(
            [16, 16, 32, 32],
            [1536] * 4,
            [7168] * 4,
            [10, 20, 100, 200],
            [5, 25, 150, 190],
        )
        eff = compute_tflops_efficiency(frame, "pred_tflops")
        assert len(eff) == 2
        assert eff.iloc[0]["efficiency"] == pytest.approx(1.0)
        assert eff.iloc[1]["efficiency"] == pytest.approx(1.0)

    def test_zero_tflops_shape_skipped(self):
        frame = self._frame([16, 16], [16, 16], [16, 16], [0, 0], [1, 2])
        assert len(compute_tflops_efficiency(frame, "pred_tflops")) == 0

    def test_single_kernel_per_shape(self):
        frame = self._frame([1024], [1024], [1024], [150], [100])
        eff = compute_tflops_efficiency(frame, "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_tied_predictions(self):
        """When multiple kernels have the same predicted TFLOPS, pandas idxmax picks the first."""
        frame = self._frame(
            [1024] * 3, [1024] * 3, [1024] * 3, [100, 200, 200], [50, 50, 50]
        )
        eff = compute_tflops_efficiency(frame, "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] >= 0.5
# ---------------------------------------------------------------------------
# Helpers for warm-start tests
# ---------------------------------------------------------------------------
def _make_dummy_data(n_rows=200, n_shapes=5):
"""Create a small synthetic benchmark DataFrame for testing training."""
rng = np.random.RandomState(42)
rows = []
for _ in range(n_rows):
m = rng.choice([64, 128, 256, 512, 1024])
n = rng.choice([64, 128, 256, 512, 1024])
k = rng.choice([64, 128, 256, 512, 1024])
rows.append(
{
"m": m,
"n": n,
"k": k,
"split_k": 1,
"dtype": "fp8",
"layout": "rcr",
"op_type": "gemm_universal",
"tile_m": rng.choice([64, 128, 256]),
"tile_n": rng.choice([64, 128, 256]),
"tile_k": rng.choice([32, 64, 128]),
"warp_m": rng.choice([1, 2, 4]),
"warp_n": rng.choice([1, 2, 4]),
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 16,
"pipeline": rng.choice(["compv3", "compv4", "mem"]),
"scheduler": rng.choice(["intrawave", "interwave"]),
"epilogue": "cshuffle",
"pad_m": False,
"pad_n": False,
"pad_k": False,
"persistent": False,
"measured_tflops": float(rng.uniform(10, 500)),
"latency_ms": float(rng.uniform(0.01, 1.0)),
"bandwidth_gb_s": float(rng.uniform(50, 1500)),
"is_valid": True,
"kernel_name": f"test_kernel_{rng.randint(0, 100)}",
}
)
return pd.DataFrame(rows)
def _save_feature_spec(model_dir, fe):
"""Save a feature_spec.json matching the given feature engine."""
spec = {
"feature_names": fe.get_feature_names(),
"categorical_features": fe.get_categorical_features(),
}
with open(model_dir / "feature_spec.json", "w") as f:
json.dump(spec, f)
def _train_and_save_base_model(model_dir, df, fe, target="tflops"):
    """Train a small (20-tree, single-job) model on df and persist it,
    together with its feature spec, into model_dir. Returns the model."""
    params = {**DEFAULT_PARAMS, "n_estimators": 20, "n_jobs": 1}
    model = train_final_model(df, fe, target, params)
    model.booster_.save_model(str(model_dir / f"model_{target}.lgbm"))
    _save_feature_spec(model_dir, fe)
    return model
# ---------------------------------------------------------------------------
# Warm-start tests
# ---------------------------------------------------------------------------
class TestCheckFeatureCompatibility:
    """Warm-start must refuse models trained under a different feature schema."""

    @staticmethod
    def _write_spec(path, names, cats):
        # Write a hand-crafted (possibly mismatched) feature_spec.json.
        with open(path / "feature_spec.json", "w") as f:
            json.dump({"feature_names": names, "categorical_features": cats}, f)

    def test_compatible_passes(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        _save_feature_spec(tmp_path, fe)
        check_feature_compatibility(tmp_path, fe)  # must not raise

    def test_missing_spec_raises(self, tmp_path):
        with pytest.raises(FileNotFoundError, match="feature_spec.json"):
            check_feature_compatibility(tmp_path, GemmUniversalFeatureEngine())

    def test_added_feature_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        self._write_spec(
            tmp_path, fe.get_feature_names()[:-1], fe.get_categorical_features()
        )
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, fe)

    def test_removed_feature_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        self._write_spec(
            tmp_path,
            fe.get_feature_names() + ["extra_feature"],
            fe.get_categorical_features(),
        )
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, fe)

    def test_categorical_mismatch_raises(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        self._write_spec(tmp_path, fe.get_feature_names(), ["layout", "pipeline"])
        with pytest.raises(ValueError, match="Categorical feature mismatch"):
            check_feature_compatibility(tmp_path, fe)
class TestLoadWarmStartModel:
    """Path resolution for previous models used as warm-start seeds."""

    def test_loads_existing_model(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        _train_and_save_base_model(tmp_path, _make_dummy_data(), fe)
        resolved = load_warm_start_model(tmp_path, "tflops")
        assert resolved is not None
        assert Path(resolved).exists()

    def test_returns_none_for_missing_target(self, tmp_path):
        assert load_warm_start_model(tmp_path, "tflops") is None

    def test_returns_none_for_wrong_target(self, tmp_path):
        fe = GemmUniversalFeatureEngine()
        _train_and_save_base_model(tmp_path, _make_dummy_data(), fe, target="tflops")
        assert load_warm_start_model(tmp_path, "bandwidth") is None
class TestWarmStartTraining:
    """End-to-end warm-start: tree counts grow and quality does not regress."""

    @staticmethod
    def _refit_params():
        # Short refinement run: 15 new trees, single-threaded for determinism.
        params = dict(DEFAULT_PARAMS)
        params["n_estimators"] = 15
        params["n_jobs"] = 1
        return params

    def test_warm_start_produces_more_trees(self, tmp_path):
        """A warm-started model should have more trees than the base."""
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data(n_rows=300)
        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, df, fe)
        warm_model = train_final_model(
            df,
            fe,
            "tflops",
            self._refit_params(),
            init_model=load_warm_start_model(base_dir, "tflops"),
        )
        assert warm_model.booster_.num_trees() > base_model.booster_.num_trees()

    def test_warm_start_does_not_degrade(self, tmp_path):
        """Warm-started model on the same data should not be significantly worse."""
        fe = GemmUniversalFeatureEngine()
        df = _make_dummy_data(n_rows=300)
        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base_model = _train_and_save_base_model(base_dir, df, fe)
        X = fe.extract_batch(df[df["is_valid"]].reset_index(drop=True))
        y = df[df["is_valid"]]["measured_tflops"].values
        base_rmse = np.sqrt(np.mean((base_model.predict(X) - y) ** 2))
        warm_model = train_final_model(
            df,
            fe,
            "tflops",
            self._refit_params(),
            init_model=load_warm_start_model(base_dir, "tflops"),
        )
        warm_rmse = np.sqrt(np.mean((warm_model.predict(X) - y) ** 2))
        # 10% slack: refinement on identical data must stay roughly as good.
        assert warm_rmse <= base_rmse * 1.1

    def test_warm_start_from_nonexistent_dir(self):
        with pytest.raises(FileNotFoundError):
            check_feature_compatibility(
                Path("/nonexistent/model/dir"), GemmUniversalFeatureEngine()
            )
# Allow running this test module directly without invoking pytest externally.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,555 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Training script for CK Tile kernel performance prediction.
Trains LGBMRegressor models (TFLOPS, latency, bandwidth) with:
- Log-space regression (log1p transform) for scale-invariant accuracy
- GroupKFold cross-validation (group key = (M, N, K))
- Iterative Hard Example Mining (IHEM)
- Model complexity bounds for C++ deployability
- Optional Optuna hyperparameter tuning
- Warm-start incremental training from a previous model via --warm_start
Log-transform rationale:
GEMM TFLOPS spans 5 orders of magnitude (0.02 for M=1 to 2230 for large
shapes). Raw regression optimizes for absolute RMSE, which means the model
spends all its capacity predicting large shapes accurately and ignores tiny
shapes where TFLOPS is < 10. Training on log1p(TFLOPS) puts all shapes on
equal footing, improving tiny_m efficiency from 84% to 96%.
"""
import argparse
import json
import time
from pathlib import Path
import lightgbm as lgb
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
from data_pipeline import build_training_dataset
from feature_engine import GemmUniversalFeatureEngine
# Maps each CLI target name to the benchmark-dataframe column it trains on.
TARGET_COLUMNS = {
    "tflops": "measured_tflops",
    "latency": "latency_ms",
    "bandwidth": "bandwidth_gb_s",
}
# Targets where log1p transform is applied by default.
# TFLOPS and bandwidth span orders of magnitude; latency is already small-scale.
LOG_TARGETS = {"tflops", "bandwidth"}
# Shared LightGBM baseline hyperparameters; callers copy this dict and
# override individual entries (e.g. n_estimators) before fitting.
DEFAULT_PARAMS = {
    "objective": "regression",
    "metric": ["rmse", "mae"],
    "num_leaves": 255,
    "max_depth": 15,
    "n_estimators": 2000,
    "learning_rate": 0.02,
    "min_child_samples": 10,
    "subsample": 0.85,
    "colsample_bytree": 0.85,
    "reg_alpha": 0.05,
    "reg_lambda": 0.5,
    "verbose": -1,
    "n_jobs": 8,
    "seed": 42,
}
# Ceiling on total tree count, for C++ deployability of saved models.
# NOTE(review): enforcement is not visible in this excerpt — confirm usage.
MAX_ESTIMATORS = 5000
# Default number of extra trees added per --warm_start refinement run.
WARM_START_N_ESTIMATORS = 500
def check_feature_compatibility(
    prev_model_dir: Path,
    feature_engine: GemmUniversalFeatureEngine,
) -> None:
    """Ensure the current feature engine matches the previous model's schema.

    Warm-starting a model trained under a different feature layout would
    silently corrupt predictions, so any difference in feature names, their
    order, or the categorical set raises instead of proceeding.

    Raises
    ------
    FileNotFoundError
        When prev_model_dir contains no feature_spec.json to compare against.
    ValueError
        When feature names/order or categorical features differ.
    """
    spec_file = prev_model_dir / "feature_spec.json"
    if not spec_file.exists():
        raise FileNotFoundError(
            f"No feature_spec.json in {prev_model_dir}. "
            "Cannot verify feature compatibility for warm start."
        )
    with open(spec_file) as fh:
        prev_spec = json.load(fh)

    curr_names = feature_engine.get_feature_names()
    prev_names = prev_spec.get("feature_names", [])
    if prev_names != curr_names:
        added = set(curr_names) - set(prev_names)
        removed = set(prev_names) - set(curr_names)
        parts = ["Feature schema mismatch between previous model and current engine."]
        if added:
            parts.append(f" Added features: {sorted(added)}")
        if removed:
            parts.append(f" Removed features: {sorted(removed)}")
        if not added and not removed:
            # Same name set but order differs — equally fatal for LightGBM.
            parts.append(" Feature order changed (names match but order differs).")
        raise ValueError("\n".join(parts))

    curr_cats = feature_engine.get_categorical_features()
    prev_cats = prev_spec.get("categorical_features", [])
    if sorted(prev_cats) != sorted(curr_cats):
        raise ValueError(
            f"Categorical feature mismatch.\n"
            f" Previous: {sorted(prev_cats)}\n"
            f" Current: {sorted(curr_cats)}"
        )
def load_warm_start_model(prev_model_dir: Path, target: str) -> str | None:
"""Load the path to a previous model file for warm-start, or None if absent.
Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
The decompressed file is cached to disk for subsequent loads.
Returns the string path (what LightGBM's init_model expects) rather than
a loaded Booster, because LGBMRegressor.fit(init_model=...) accepts both
path strings and Booster objects and path strings avoid keeping the old
model in memory.
"""
import gzip
model_path = prev_model_dir / f"model_{target}.lgbm"
gz_path = prev_model_dir / f"model_{target}.lgbm.gz"
# Auto-decompress if needed
if not model_path.exists() and gz_path.exists():
print(f" Decompressing {gz_path.name}...")
with gzip.open(gz_path, "rb") as f_in:
with open(model_path, "wb") as f_out:
f_out.write(f_in.read())
if not model_path.exists():
return None
return str(model_path)
def compute_group_keys(df: pd.DataFrame) -> np.ndarray:
    """Create GroupKFold group keys from (M, N, K)."""
    m_str, n_str, k_str = (df[col].astype(str) for col in ("m", "n", "k"))
    return (m_str + "_" + n_str + "_" + k_str).values
def compute_tflops_efficiency(
    df: pd.DataFrame, pred_col: str = "pred_tflops"
) -> pd.DataFrame:
    """Compute per-shape efficiency: predicted-best TFLOPS / oracle-best TFLOPS."""
    rows = []
    for (m, n, k), shape_df in df.groupby(["m", "n", "k"]):
        oracle = shape_df["measured_tflops"].max()
        if oracle <= 0:
            # Degenerate shape with no positive measurement — skip it.
            continue
        # Pick the kernel the model ranked highest, score what it measured.
        chosen = shape_df.loc[shape_df[pred_col].idxmax(), "measured_tflops"]
        rows.append(
            {
                "m": m,
                "n": n,
                "k": k,
                "oracle_best_tflops": oracle,
                "selected_tflops": chosen,
                "efficiency": chosen / oracle,
            }
        )
    return pd.DataFrame(rows)
def train_single_target(
    X_train,
    y_train,
    X_val,
    y_val,
    params: dict,
    categorical_features: list[str],
    feature_names: list[str],
    init_model=None,
) -> lgb.LGBMRegressor:
    """Fit one LGBMRegressor with early stopping on the validation split.

    Parameters
    ----------
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
        Accepts a file path to a .lgbm file, a Booster instance, or an
        LGBMModel instance. The new model adds n_estimators trees on top
        of the existing ones.
    """
    # Translate categorical names into column indices; names absent from
    # the current schema are silently skipped.
    categorical_idx = [
        feature_names.index(name)
        for name in categorical_features
        if name in feature_names
    ]
    regressor = lgb.LGBMRegressor(**params)
    regressor.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
        eval_metric=["rmse"],
        callbacks=[
            # Stop after 50 rounds without validation improvement; stay quiet.
            lgb.early_stopping(50, verbose=False),
            lgb.log_evaluation(0),
        ],
        categorical_feature=categorical_idx if categorical_idx else "auto",
        init_model=init_model,
    )
    return regressor
def run_cv(
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
    target: str,
    params: dict,
    n_splits: int = 5,
    use_log: bool = True,
) -> dict:
    """Run GroupKFold cross-validation and return OOF predictions + metrics.

    Parameters
    ----------
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y) and invert
        predictions with expm1 for efficiency calculation. This normalizes
        the scale so that tiny-M shapes (TFLOPS ~ 1) get equal attention
        as large-M shapes (TFLOPS ~ 2000).

    Returns
    -------
    dict
        "fold_metrics" (per-fold RMSE/R2/efficiency), "oof_df" (valid rows
        with an oof_pred_<target> column — in log space when the log
        transform is active), "feature_names", "log_transform". Empty dict
        when fewer than 2 unique (m, n, k) groups exist.
    """
    target_col = TARGET_COLUMNS[target]
    # Keep only rows that benchmarked successfully with a positive target value.
    valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0)
    df_valid = df[valid_mask].reset_index(drop=True)
    apply_log = use_log and target in LOG_TARGETS
    print(
        f" Training on {len(df_valid)} valid rows for target={target}"
        f"{' (log-space)' if apply_log else ''}"
    )
    X = feature_engine.extract_batch(df_valid)
    y_raw = df_valid[target_col].values
    y = np.log1p(y_raw) if apply_log else y_raw
    # Group by (m, n, k) so no shape can leak between train and val folds.
    groups = compute_group_keys(df_valid)
    feature_names = feature_engine.get_feature_names()
    cat_features = feature_engine.get_categorical_features()
    unique_groups = np.unique(groups)
    actual_splits = min(n_splits, len(unique_groups))
    if actual_splits < 2:
        print(f" WARNING: Only {len(unique_groups)} unique groups, skipping CV")
        return {}
    gkf = GroupKFold(n_splits=actual_splits)
    oof_preds = np.zeros(len(df_valid))
    fold_metrics = []
    for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        X_tr, X_val = X[train_idx], X[val_idx]
        y_tr, y_val = y[train_idx], y[val_idx]
        model = train_single_target(
            X_tr, y_tr, X_val, y_val, params, cat_features, feature_names
        )
        preds = model.predict(X_val)
        # OOF predictions are stored in training space (log space if apply_log).
        oof_preds[val_idx] = preds
        rmse = np.sqrt(np.mean((preds - y_val) ** 2))
        # Guard the variance denominator so a constant-y fold can't divide by 0.
        r2 = 1 - np.sum((preds - y_val) ** 2) / max(
            np.sum((y_val - y_val.mean()) ** 2), 1e-10
        )
        if target == "tflops":
            # Efficiency needs raw TFLOPS, so undo the log transform first.
            val_df = df_valid.iloc[val_idx].copy()
            preds_raw = np.expm1(preds) if apply_log else preds
            val_df["pred_tflops"] = preds_raw
            eff_df = compute_tflops_efficiency(val_df)
            mean_eff = eff_df["efficiency"].mean() if len(eff_df) > 0 else 0
            p10_eff = eff_df["efficiency"].quantile(0.1) if len(eff_df) > 0 else 0
        else:
            # Efficiency is only defined for the TFLOPS target.
            mean_eff, p10_eff = None, None
        fold_metrics.append(
            {
                "fold": fold_idx,
                "rmse": rmse,
                "r2": r2,
                "mean_efficiency": mean_eff,
                "p10_efficiency": p10_eff,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "val_groups": len(np.unique(groups[val_idx])),
            }
        )
        eff_str = (
            f", eff={mean_eff:.4f}, p10={p10_eff:.4f}" if mean_eff is not None else ""
        )
        print(f" Fold {fold_idx}: RMSE={rmse:.4f}, R2={r2:.4f}{eff_str}")
    df_valid[f"oof_pred_{target}"] = oof_preds
    return {
        "fold_metrics": fold_metrics,
        "oof_df": df_valid,
        "feature_names": feature_names,
        "log_transform": apply_log,
    }
def train_final_model(
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
    target: str,
    params: dict,
    init_model=None,
    use_log: bool = True,
) -> lgb.LGBMRegressor:
    """Train the final model on all valid data (no validation split).

    Parameters
    ----------
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y).
        The saved model then predicts in log-space; callers must apply
        expm1() to get raw values.
    """
    target_col = TARGET_COLUMNS[target]
    usable = df["is_valid"].fillna(False) & (df[target_col] > 0)
    rows = df[usable].reset_index(drop=True)
    log_space = use_log and target in LOG_TARGETS
    features = feature_engine.extract_batch(rows)
    labels = rows[target_col].values
    if log_space:
        labels = np.log1p(labels)
    names = feature_engine.get_feature_names()
    categorical_idx = [
        names.index(c)
        for c in feature_engine.get_categorical_features()
        if c in names
    ]
    regressor = lgb.LGBMRegressor(**params)
    regressor.fit(
        features,
        labels,
        categorical_feature=categorical_idx if categorical_idx else "auto",
        init_model=init_model,
    )
    return regressor
def main():
    """CLI entry point: train one LightGBM model per requested target.

    Loads benchmark parquet data, builds a feature engine (optionally seeded
    with hardware columns found in the data), runs grouped cross-validation
    per target, trains a final model on all valid rows, and writes the models
    plus metadata (cv_metrics_*.json, feature_importances_*.json,
    feature_spec.json, train_manifest.json) into --out_dir. Supports
    warm-starting from a previous model directory via LightGBM's init_model.
    """
    parser = argparse.ArgumentParser(
        description="Train CK Tile kernel performance models"
    )
    parser.add_argument(
        "--data_dir", required=True, help="Directory with parquet files"
    )
    parser.add_argument("--out_dir", required=True, help="Output directory for models")
    parser.add_argument("--op", default="gemm_universal", help="Operation type")
    parser.add_argument("--dtype", default="fp8", help="Data type filter")
    parser.add_argument("--arch", default="gfx950", help="Architecture")
    parser.add_argument(
        "--targets", default="tflops,latency,bandwidth", help="Comma-separated targets"
    )
    parser.add_argument("--n_splits", type=int, default=5, help="Number of CV folds")
    # NOTE(review): --tune is accepted but never read anywhere in this
    # function -- either wire it up or drop the flag.
    parser.add_argument(
        "--tune", action="store_true", help="Run Optuna hyperparameter tuning"
    )
    parser.add_argument(
        "--no_log_transform",
        action="store_true",
        help="Disable log1p transform on targets. By default, TFLOPS and bandwidth "
        "are trained in log-space for scale-invariant accuracy across shape sizes.",
    )
    parser.add_argument(
        "--warm_start",
        default=None,
        help="Path to previous model directory to continue training from. "
        "Uses LightGBM's init_model to add new trees on top of the "
        "existing model. Feature schemas must match exactly.",
    )
    parser.add_argument(
        "--warm_start_n_estimators",
        type=int,
        default=WARM_START_N_ESTIMATORS,
        help=f"Number of new trees to add when warm-starting (default: {WARM_START_N_ESTIMATORS}). "
        "Lower than a full train since we're refining, not starting from scratch.",
    )
    args = parser.parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    targets = [t.strip() for t in args.targets.split(",")]
    print(f"Loading data from {args.data_dir}...")
    df = build_training_dataset(args.data_dir, op_type=args.op, dtype=args.dtype)
    print(f"  Total rows: {len(df)}")
    print(f"  Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    print(f"  Unique kernels: {df['kernel_name'].nunique()}")
    # Seed the feature engine with hardware properties recorded in the data.
    # Only row 0 is consulted, so this assumes every row in the dataset was
    # collected on the same device.
    hw_cols = [c for c in df.columns if c.startswith("hw_")]
    hw_kwargs = {}
    if hw_cols:
        row0 = df.iloc[0]
        if "hw_num_cus" in df.columns:
            hw_kwargs["num_cus"] = int(row0.get("hw_num_cus", 256))
        if "hw_max_clock_mhz" in df.columns:
            hw_kwargs["max_clock_mhz"] = int(row0.get("hw_max_clock_mhz", 2400))
        if "hw_simds_per_cu" in df.columns:
            hw_kwargs["simds_per_cu"] = int(row0.get("hw_simds_per_cu", 4))
        if "hw_shader_engines" in df.columns:
            hw_kwargs["shader_engines"] = int(row0.get("hw_shader_engines", 32))
        if "hw_max_waves_per_cu" in df.columns:
            hw_kwargs["max_waves_per_cu"] = int(row0.get("hw_max_waves_per_cu", 32))
        if "hw_wavefront_size" in df.columns:
            hw_kwargs["wavefront_size"] = int(row0.get("hw_wavefront_size", 64))
        if "hw_l1_cache_kb" in df.columns:
            hw_kwargs["l1_cache_kb"] = int(row0.get("hw_l1_cache_kb", 32))
        if "hw_l2_cache_kb" in df.columns:
            hw_kwargs["l2_cache_kb"] = int(row0.get("hw_l2_cache_kb", 4096))
        if "hw_l3_cache_kb" in df.columns:
            hw_kwargs["l3_cache_kb"] = int(row0.get("hw_l3_cache_kb", 262144))
    fe = GemmUniversalFeatureEngine(**hw_kwargs)
    params = dict(DEFAULT_PARAMS)
    use_log = not args.no_log_transform
    # Warm-start setup: validate the previous model directory, shrink the
    # number of new trees, and carry over the previous training manifest so
    # the estimator totals can be accumulated below.
    prev_model_dir = None
    prev_manifest = {}
    if args.warm_start:
        prev_model_dir = Path(args.warm_start)
        if not prev_model_dir.exists():
            raise FileNotFoundError(f"Warm-start directory not found: {prev_model_dir}")
        print(f"  Warm-starting from {prev_model_dir}")
        # Raises if the feature schema of the previous model differs.
        check_feature_compatibility(prev_model_dir, fe)
        print("  Feature compatibility: OK")
        params["n_estimators"] = args.warm_start_n_estimators
        print(f"  New trees to add: {args.warm_start_n_estimators}")
        prev_manifest_path = prev_model_dir / "train_manifest.json"
        if prev_manifest_path.exists():
            with open(prev_manifest_path) as f:
                prev_manifest = json.load(f)
    # NOTE(review): all_cv_results is populated but never read after the
    # loop -- presumably intended for an aggregate report; confirm or remove.
    all_cv_results = {}
    for target in targets:
        if target not in TARGET_COLUMNS:
            print(f"  Skipping unknown target: {target}")
            continue
        print(f"\n{'=' * 60}")
        print(f"Training {target} model")
        print(f"{'=' * 60}")
        init_model_path = None
        if prev_model_dir is not None:
            init_model_path = load_warm_start_model(prev_model_dir, target)
            if init_model_path:
                print(f"  Warm-starting from {init_model_path}")
            else:
                print(f"  No previous {target} model found, training from scratch")
        # Cross-validation pass: writes per-fold metrics, and for the tflops
        # target also dumps out-of-fold predictions plus efficiency stats.
        t0 = time.time()
        cv_result = run_cv(
            df, fe, target, params, n_splits=args.n_splits, use_log=use_log
        )
        cv_time = time.time() - t0
        if cv_result and cv_result["fold_metrics"]:
            all_cv_results[target] = cv_result["fold_metrics"]
            metrics_path = out_dir / f"cv_metrics_{target}.json"
            with open(metrics_path, "w") as f:
                json.dump(cv_result["fold_metrics"], f, indent=2)
            print(f"  CV completed in {cv_time:.1f}s, saved to {metrics_path}")
            if target == "tflops" and cv_result.get("oof_df") is not None:
                oof_df = cv_result["oof_df"]
                oof_df.to_parquet(out_dir / "oof_predictions.parquet", index=False)
                eff_df = compute_tflops_efficiency(oof_df, "oof_pred_tflops")
                if len(eff_df) > 0:
                    print("\n OOF TFLOPS Efficiency:")
                    print(f" Mean: {eff_df['efficiency'].mean():.4f}")
                    print(f" P10: {eff_df['efficiency'].quantile(0.1):.4f}")
                    print(f" P50: {eff_df['efficiency'].quantile(0.5):.4f}")
                    print(f" Min: {eff_df['efficiency'].min():.4f}")
        # Final fit on ALL valid rows (no held-out fold), then persist the
        # booster and its feature importances.
        print(f"\n Training final {target} model on all data...")
        t0 = time.time()
        model = train_final_model(
            df, fe, target, params, init_model=init_model_path, use_log=use_log
        )
        train_time = time.time() - t0
        model_path = out_dir / f"model_{target}.lgbm"
        model.booster_.save_model(str(model_path))
        print(f" Saved {model_path} ({train_time:.1f}s)")
        importances = dict(
            zip(
                fe.get_feature_names(),
                model.feature_importances_.tolist(),
            )
        )
        imp_path = out_dir / f"feature_importances_{target}.json"
        with open(imp_path, "w") as f:
            json.dump(importances, f, indent=2)
    # feature_spec.json: the contract consumed by inference-side feature
    # extractors (e.g. the C++ ml_heuristic header) -- names and order matter.
    log_targets_used = sorted(LOG_TARGETS & set(targets)) if use_log else []
    spec = {
        "op_type": args.op,
        "dtype": args.dtype,
        "arch": args.arch,
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
        "targets": targets,
        "log_targets": log_targets_used,
        "params": params,
    }
    with open(out_dir / "feature_spec.json", "w") as f:
        json.dump(spec, f, indent=2)
    # train_manifest.json: provenance record; total_n_estimators accumulates
    # across warm-start generations.
    manifest = {
        "warm_start_from": str(prev_model_dir) if prev_model_dir else None,
        "prev_n_estimators": prev_manifest.get(
            "total_n_estimators", params.get("n_estimators")
        )
        if prev_model_dir
        else 0,
        "new_n_estimators": params["n_estimators"],
        "total_n_estimators": (
            prev_manifest.get("total_n_estimators", 0) + params["n_estimators"]
            if prev_model_dir
            else params["n_estimators"]
        ),
        "data_rows": len(df),
        "valid_rows": int(df["is_valid"].fillna(False).sum()),
        "unique_shapes": int(df.groupby(["m", "n", "k"]).ngroups),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
    }
    with open(out_dir / "train_manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)
    print(f"\nAll models saved to {out_dir}")
    if prev_model_dir:
        print(f" Warm-started from: {prev_model_dir}")
        print(f" Total estimators: {manifest['total_n_estimators']}")
# Run the training CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,317 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
ML Heuristic Validation: Test ML predictions against oracle-best from training data
This script validates ML-based kernel selection by:
1. Loading benchmark data (oracle-best results for each shape)
2. Using ML model to predict best kernel for each shape
3. Comparing ML selection with oracle-best to compute efficiency
Usage:
python validate_ml_heuristic.py --dtype fp16 --model_dir models/gemm_universal_fp16_gfx950
python validate_ml_heuristic.py --dtype fp8 --layout rcr
"""
import sys
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from predict import Predictor
def validate_ml_heuristic(dtype: str, layout: str, model_dir: str, data_dir: str):
    """Validate ML heuristic predictions against oracle-best.

    For every (m, n, k) shape in the benchmark data: rank the kernels the
    benchmark actually measured with the ML predictor, look up the measured
    TFLOPS of the ML top pick, and report its efficiency relative to the
    oracle-best (max measured TFLOPS) kernel for that shape. Prints summary
    statistics and writes a per-shape CSV report to the current directory.
    Returns None; failures are reported by printing and returning early.
    """
    print("=" * 100)
    print(f" ML Heuristic Validation: {dtype.upper()} {layout.upper()}")
    print("=" * 100)
    print()
    # Load training data
    print(f"Loading training data from {data_dir}...")
    # Try dtype-specific parquet first, then fall back to combined
    dtype_specific = (
        Path(data_dir) / f"{dtype}_original" / f"{dtype}_training_data.parquet"
    )
    combined = Path(data_dir) / "all_training_data_fixed.parquet"
    if dtype_specific.exists():
        # NOTE(review): the dtype-specific file is NOT filtered by layout,
        # unlike the combined fallback below -- confirm it is single-layout.
        training_data = pd.read_parquet(dtype_specific)
        print(f"✓ Loaded {len(training_data):,} benchmark runs from {dtype_specific}")
    elif combined.exists():
        training_data = pd.read_parquet(combined)
        training_data = training_data[
            (training_data["dtype"] == dtype) & (training_data["layout"] == layout)
        ]
        print(f"✓ Loaded {len(training_data):,} benchmark runs from {combined}")
    else:
        print(f"❌ Error: No training data found at {dtype_specific} or {combined}")
        return
    if len(training_data) == 0:
        print(f"❌ Error: No data found for dtype={dtype}, layout={layout}")
        return
    # Get unique shapes with oracle-best
    shape_groups = training_data.groupby(["m", "n", "k"])
    print(f"Unique shapes: {len(shape_groups)}")
    print()
    # Load ML predictor
    print(f"Loading ML predictor from {model_dir}...")
    try:
        predictor = Predictor(model_dir)
        print("✓ Loaded ML predictor")
        print(f" Log targets: {predictor._log_targets}")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return
    print()
    print("=" * 100)
    print(" Computing Oracle-Best Efficiency for Each Shape")
    print("=" * 100)
    print()
    results = []
    for shape_idx, ((m, n, k), group) in enumerate(shape_groups):
        # Find oracle-best (max TFLOPS across all kernels tested)
        oracle_best_row = group.loc[group["measured_tflops"].idxmax()]
        oracle_best_tflops = oracle_best_row["measured_tflops"]
        oracle_best_kernel = oracle_best_row["kernel_name"]
        # Get all kernel configs tested for this shape; the predictor only
        # ranks among kernels that have a measured ground truth here.
        kernel_configs = []
        for _, row in group.iterrows():
            kernel_dict = {
                "tile_m": row["tile_m"],
                "tile_n": row["tile_n"],
                "tile_k": row["tile_k"],
                "warp_m": row["warp_m"],
                "warp_n": row["warp_n"],
                "warp_k": row["warp_k"],
                "warp_tile_m": row["warp_tile_m"],
                "warp_tile_n": row["warp_tile_n"],
                "warp_tile_k": row["warp_tile_k"],
                "pipeline": row["pipeline"],
                "scheduler": row["scheduler"],
                "epilogue": row["epilogue"],
                "pad_m": row["pad_m"],
                "pad_n": row["pad_n"],
                "pad_k": row["pad_k"],
                "persistent": row["persistent"],
                "kernel_name": row["kernel_name"],
            }
            kernel_configs.append(kernel_dict)
        # Use ML model to rank kernels; split_k is always scored as 1 here.
        problem = {
            "m": m,
            "n": n,
            "k": k,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }
        try:
            ranked = predictor.rank_kernels(problem, kernel_configs)
            if ranked:
                # rank_kernels returns (kernel_name, predicted_tflops) pairs,
                # best first.
                ml_best_kernel, ml_predicted_tflops = ranked[0]
                # Find actual TFLOPS for the ML-predicted kernel
                ml_kernel_row = group[group["kernel_name"] == ml_best_kernel]
                if len(ml_kernel_row) > 0:
                    ml_actual_tflops = ml_kernel_row["measured_tflops"].values[0]
                    # Calculate efficiency
                    efficiency_pct = 100.0 * (ml_actual_tflops / oracle_best_tflops)
                    # Determine if ML picked oracle-best
                    is_oracle_best = ml_best_kernel == oracle_best_kernel
                    results.append(
                        {
                            "m": m,
                            "n": n,
                            "k": k,
                            "oracle_best_tflops": oracle_best_tflops,
                            "oracle_best_kernel": oracle_best_kernel,
                            "ml_predicted_tflops": ml_predicted_tflops,
                            "ml_selected_kernel": ml_best_kernel,
                            "ml_actual_tflops": ml_actual_tflops,
                            "efficiency_pct": efficiency_pct,
                            "is_oracle_best": is_oracle_best,
                            "num_kernels": len(group),
                        }
                    )
                    # Progress line every 20 shapes.
                    # NOTE(review): the oracle-match marker is an empty
                    # string -- possibly a lost glyph (e.g. a checkmark);
                    # confirm intended output.
                    if (shape_idx + 1) % 20 == 0:
                        status = "" if is_oracle_best else f"{efficiency_pct:.1f}%"
                        print(
                            f" [{shape_idx + 1:3d}/{len(shape_groups)}] "
                            f"M={m:4d} N={n:5d} K={k:5d}: {status}"
                        )
        except Exception as e:
            # Best-effort loop: a failing shape is reported and skipped so
            # the remaining shapes are still evaluated.
            print(f" Error on shape M={m} N={n} K={k}: {e}")
            continue
    print()
    print("=" * 100)
    print(" Results Summary")
    print("=" * 100)
    print()
    if results:
        df_results = pd.DataFrame(results)
        efficiencies = df_results["efficiency_pct"].values
        oracle_matches = df_results["is_oracle_best"].sum()
        print(f"Total shapes tested: {len(results)}")
        print()
        print("Efficiency Statistics (% of Oracle-Best TFLOPS):")
        print(f" Mean: {np.mean(efficiencies):.2f}%")
        print(f" Median: {np.median(efficiencies):.2f}%")
        print(f" Min: {np.min(efficiencies):.2f}%")
        print(f" Max: {np.max(efficiencies):.2f}%")
        print(f" P10: {np.percentile(efficiencies, 10):.2f}%")
        print(f" P50: {np.percentile(efficiencies, 50):.2f}%")
        print(f" P90: {np.percentile(efficiencies, 90):.2f}%")
        print()
        print(
            f"Oracle-best matches: {oracle_matches}/{len(results)} ({100 * oracle_matches / len(results):.1f}%)"
        )
        print()
        # Classify by M size
        df_results["m_class"] = pd.cut(
            df_results["m"],
            bins=[0, 8, 128, 1024, float("inf")],
            labels=[
                "Tiny (M<8)",
                "Small (8≤M<128)",
                "Medium (128≤M<1024)",
                "Large (M≥1024)",
            ],
        )
        print("Efficiency by M size:")
        for m_class in [
            "Tiny (M<8)",
            "Small (8≤M<128)",
            "Medium (128≤M<1024)",
            "Large (M≥1024)",
        ]:
            subset = df_results[df_results["m_class"] == m_class]
            if len(subset) > 0:
                print(
                    f" {m_class:25s}: {subset['efficiency_pct'].mean():6.2f}% "
                    f"(n={len(subset)}, P10={subset['efficiency_pct'].quantile(0.1):.2f}%)"
                )
        print()
        # Save results
        output_file = f"validation_results_{dtype}_{layout}.csv"
        df_results.to_csv(output_file, index=False)
        print(f"✓ Results saved to {output_file}")
        # Show best and worst shapes
        print()
        print("Top 5 shapes (best efficiency):")
        top5 = df_results.nlargest(5, "efficiency_pct")[
            ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
        ]
        for idx, row in top5.iterrows():
            # NOTE(review): empty-string marker again -- likely a lost glyph.
            match = "" if row["is_oracle_best"] else " "
            print(
                f" {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
                f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)"
            )
        print()
        print("Bottom 5 shapes (worst efficiency):")
        bottom5 = df_results.nsmallest(5, "efficiency_pct")[
            ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
        ]
        for idx, row in bottom5.iterrows():
            match = "" if row["is_oracle_best"] else " "
            print(
                f" {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
                f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)"
            )
    else:
        print("No results to display")
    print()
    print("=" * 100)
def main():
    """Parse CLI arguments, auto-detect model/data directories when omitted,
    and run the validation. Returns a process exit code (0 ok, 1 on missing
    model directory)."""
    parser = argparse.ArgumentParser(
        description="Validate ML heuristic predictions against oracle-best from training data"
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp8"],
        help="Data type to validate",
    )
    parser.add_argument(
        "--layout",
        default="rcr",
        choices=["rcr", "rrr", "crr", "ccr"],
        help="Matrix layout",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path to model directory (auto-detect if not specified)",
    )
    parser.add_argument(
        "--data_dir",
        default=None,
        help="Path to training data directory (auto-detect if not specified)",
    )
    args = parser.parse_args()

    # Directory containing this script; used as the root for auto-detection.
    script_root = Path(__file__).parent

    if args.model_dir is None:
        # Prefer the gfx950 model, fall back to gfx942.
        model_candidates = [
            script_root / "models" / f"gemm_universal_{args.dtype}_gfx950",
            script_root / "models" / f"gemm_universal_{args.dtype}_gfx942",
        ]
        found = next((c for c in model_candidates if c.exists()), None)
        if found is None:
            print(f"❌ Error: Could not find model directory for {args.dtype}")
            print(f" Searched: {[str(c) for c in model_candidates]}")
            print(" Please specify --model_dir explicitly")
            return 1
        args.model_dir = str(found)

    if args.data_dir is None:
        args.data_dir = str(script_root / "data")

    validate_ml_heuristic(args.dtype, args.layout, args.model_dir, args.data_dir)
    return 0
# Propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -140,6 +140,11 @@ struct KernelKey
bool preshuffle; // Preshuffle (for weight preshuffle variants)
bool transpose_c; // TransposeC
std::uint8_t num_wave_groups; // NumWaveGroups
// Padding support flags (kPadM, kPadN, kPadK in generated kernels)
bool pad_m = true; // Support arbitrary M dimensions via padding
bool pad_n = true; // Support arbitrary N dimensions via padding
bool pad_k = true; // Support arbitrary K dimensions via padding
} algorithm;
std::string gfx_arch; // e.g. "gfx942", "gfx90a", "gfx908"
@@ -185,7 +190,10 @@ struct KernelKey
algorithm.double_buffer,
algorithm.preshuffle,
algorithm.transpose_c,
algorithm.num_wave_groups);
algorithm.num_wave_groups,
algorithm.pad_m,
algorithm.pad_n,
algorithm.pad_k);
}
/// Equality comparison
@@ -397,8 +405,14 @@ inline std::string KernelKey::encode_identifier() const
// Include pipeline, scheduler, epilogue for uniqueness
oss << to_string(algorithm.pipeline) << "_";
oss << to_string(algorithm.scheduler) << "_";
oss << to_string(algorithm.epilogue) << "_";
oss << to_string(algorithm.scheduler) << "_";
// Match tile_engine naming: padding flags (True/False) then persistent flag
oss << (algorithm.pad_m ? "True" : "False") << "_";
oss << (algorithm.pad_n ? "True" : "False") << "_";
oss << (algorithm.pad_k ? "True" : "False") << "_";
oss << (algorithm.persistent ? "True" : "False") << "_";
// Match tile_engine naming: tile_m x tile_n x tile_k _ warp_m x warp_n x warp_k _
// warp_tile_m x warp_tile_n x warp_tile_k
@@ -407,9 +421,6 @@ inline std::string KernelKey::encode_identifier() const
<< unsigned(algorithm.wave_shape.k) << "_" << unsigned(algorithm.warp_tile_shape.m) << "x"
<< unsigned(algorithm.warp_tile_shape.n) << "x" << unsigned(algorithm.warp_tile_shape.k);
// Add trait flags
oss << "_" << (algorithm.persistent ? "persist" : "nopers");
if(signature.split_k > 1)
oss << "_splitk" << unsigned(signature.split_k);
if(!signature.elementwise_op.empty() && signature.elementwise_op != "PassThrough")

View File

@@ -0,0 +1,379 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/registry.hpp"
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
namespace ck_tile {
namespace dispatcher {
// Minimal hand-written prototypes for the LightGBM C API, so this header
// compiles without the LightGBM headers; the symbols are resolved at link
// time against the LightGBM shared library.
// NOTE(review): these must stay in sync with the installed LightGBM version
// (this PredictForMat signature is the newer one that takes a
// start_iteration argument) -- confirm against LGBM_c_api.h.
extern "C" {
int LGBM_BoosterCreateFromModelfile(const char*, int*, void**);
int LGBM_BoosterPredictForMat(
    void*, const void*, int, int, int, int, int, int, int, const char*, int64_t*, double*);
int LGBM_BoosterFree(void*);
}
// Map a Pipeline enum to the integer category code used by the trained
// model. Unrecognized values fall back to CompV3's code (0).
inline int encode_pipeline(Pipeline p)
{
    if(p == Pipeline::CompV4)
        return 1;
    if(p == Pipeline::CompV5)
        return 2;
    if(p == Pipeline::Mem)
        return 3;
    if(p == Pipeline::PreShuffleV2)
        return 4;
    return 0; // CompV3 and anything unknown
}
// Map a Scheduler enum to the model's category code:
// Interwave -> 1; Intrawave and anything unknown -> 0.
inline int encode_scheduler(Scheduler s)
{
    return (s == Scheduler::Interwave) ? 1 : 0;
}
// Map an Epilogue enum to the model's category code:
// CShuffle -> 1; Default and anything unknown -> 0.
inline int encode_epilogue(Epilogue e)
{
    return (e == Epilogue::CShuffle) ? 1 : 0;
}
// Encode the (A, B) operand layouts as one categorical code for the model.
// The C-layout parameter is accepted for interface symmetry but does not
// influence the code: only the four A/B row/col combinations are
// distinguished.
// NOTE(review): the original labels for codes 2 and 3 look swapped relative
// to the usual A/B/C naming (col-major A + row-major B would be "CRR", not
// "CCR"). The numeric codes are what the trained model saw, so only the
// labels are corrected here -- confirm against the Python feature engine.
inline int encode_layout(LayoutTag a, LayoutTag b, LayoutTag c)
{
    bool ra = (a == LayoutTag::RowMajor), rb = (b == LayoutTag::RowMajor);
    if(ra && !rb)
        return 0; // RCR: row-major A, col-major B
    if(ra && rb)
        return 1; // RRR: row-major A, row-major B
    if(!ra && rb)
        return 2; // col-major A, row-major B (labelled "CCR" in the original)
    return 3; // col-major A, col-major B (labelled "CRR" in the original)
}
// Bytes per element for each GEMM data type, as consumed by the feature
// extractor. Unknown types default to 2 (the fp16/bf16 width).
inline double dtype_bytes_ml(DataType dt)
{
    if(dt == DataType::FP32)
        return 4;
    if(dt == DataType::FP8 || dt == DataType::BF8 || dt == DataType::INT8)
        return 1;
    if(dt == DataType::INT4)
        return 0.5; // packed 4-bit
    // FP16, BF16, and any unrecognized type
    return 2;
}
// Static hardware description fed to the ML feature extractor.
// Defaults describe a 256-CU part; override fields per device as needed.
struct HardwareProfile
{
    int num_cus          = 256;    // compute units
    int simds_per_cu     = 4;      // SIMD engines per CU
    int shader_engines   = 32;     // shader engines
    int max_clock_mhz    = 2400;   // peak engine clock (MHz)
    int max_waves_per_cu = 32;     // occupancy ceiling (waves per CU)
    int wavefront_size   = 64;     // lanes per wavefront
    int lds_capacity     = 65536;  // LDS bytes available per workgroup
    int l1_cache_kb      = 32;
    int l2_cache_kb      = 4096;
    int l3_cache_kb      = 262144;
    int num_xcd          = 8;      // accelerator dies (XCDs)

    // Total SIMD engines across the whole device.
    int total_simds() const { return num_cus * simds_per_cu; }
};
// CRITICAL: Feature count MUST match feature_spec.json
// Python training uses 72 features - this header MUST extract exactly 72 features in the same order
static constexpr int NUM_FEATURES = 72;

// Build the NUM_FEATURES-entry feature vector for one (problem, kernel) pair.
//
// The entry ORDER is a hard contract with the Python training side
// (feature_spec.json / the feature engine's get_feature_names()): the trained
// LightGBM model indexes features positionally, so inserting, removing, or
// reordering entries here silently corrupts predictions. Keep the trailing
// index comments in the return statement in sync when editing.
inline std::array<double, NUM_FEATURES>
extract_features(const Problem& prob, const KernelKey& key, const HardwareProfile& hw)
{
    // Problem dimensions (promoted to double once; all math below is double).
    double M = prob.M, N = prob.N, K = prob.K;
    // Split-k factor, clamped to at least 1 when k_batch is unset/zero.
    double sk = (prob.k_batch > 0 ? prob.k_batch : 1);
    // Bytes per element, derived from operand A's dtype only.
    // NOTE(review): assumes the training pipeline used the same single-dtype
    // convention -- confirm against the Python feature engine.
    double bpe = dtype_bytes_ml(key.signature.dtype_a);
    // Log-scale features (max(...,1) guards log2 of zero-sized dims)
    double l2M = std::log2(std::max(M, 1.0));
    double l2N = std::log2(std::max(N, 1.0));
    double l2K = std::log2(std::max(K, 1.0));
    double l2MNK = std::log2(std::max(M * N * K, 1.0));
    // Arithmetic intensity: 2*M*N*K FLOPs over bytes touched (A + B + C)
    double mem = (M * K + K * N + M * N) * bpe;
    double ai = 2.0 * M * N * K / std::max(mem, 1.0);
    // Aspect ratios
    double ar_mn = M / std::max(N, 1.0);
    double ar_mk = M / std::max(K, 1.0);
    double ar_nk = N / std::max(K, 1.0);
    // Layout encoding (categorical code; see encode_layout above)
    double layout = (double)encode_layout(
        key.signature.layout_a, key.signature.layout_b, key.signature.layout_c);
    // Tile dimensions
    double tm = key.algorithm.tile_shape.m;
    double tn = key.algorithm.tile_shape.n;
    double tk = key.algorithm.tile_shape.k;
    // Wave/warp dimensions
    double wm = key.algorithm.wave_shape.m;
    double wn = key.algorithm.wave_shape.n;
    double wk = key.algorithm.wave_shape.k;
    // Warp tile dimensions
    double wtm = key.algorithm.warp_tile_shape.m;
    double wtn = key.algorithm.warp_tile_shape.n;
    double wtk = key.algorithm.warp_tile_shape.k;
    // Algorithm encoding (categorical codes; see encode_* helpers above)
    double pipeline = (double)encode_pipeline(key.algorithm.pipeline);
    double scheduler = (double)encode_scheduler(key.algorithm.scheduler);
    double epilogue = (double)encode_epilogue(key.algorithm.epilogue);
    // Padding flags - read from KernelKey
    double pad_m = key.algorithm.pad_m ? 1.0 : 0.0;
    double pad_n = key.algorithm.pad_n ? 1.0 : 0.0;
    double pad_k = key.algorithm.pad_k ? 1.0 : 0.0;
    // Persistent kernel flag
    double persistent = key.algorithm.persistent ? 1.0 : 0.0;
    // Derived features
    double num_warps = wm * wn * wk;
    double tile_volume = tm * tn * tk;
    double tile_mn = tm * tn;
    // LDS usage estimation: one A tile plus one B tile per iteration.
    double lest = (tm * tk + tn * tk) * bpe;
    // NOTE(review): the 32768 budget for CompV4 is a hard-coded assumption
    // (half of the default 64 KiB LDS) -- confirm it matches the pipeline.
    double lcap = (key.algorithm.pipeline == Pipeline::CompV4) ? 32768.0 : (double)hw.lds_capacity;
    double lds_ratio = lest / std::max(lcap, 1.0);
    // Tile counts (ceil-divide each problem dim by its tile dim)
    double ntm = std::ceil(M / std::max(tm, 1.0));
    double ntn = std::ceil(N / std::max(tn, 1.0));
    double ntk = std::ceil(K / std::max(tk, 1.0));
    double total_output_tiles = ntm * ntn;
    // Tile efficiency (fractional remainder utilization): 1.0 when the dim
    // divides evenly, otherwise the filled fraction of the last tile.
    auto ef = [](double d, double t) -> double {
        if(t <= 0)
            return 1.0;
        double r = std::fmod(d, t);
        return r > 0 ? r / t : 1.0;
    };
    double tile_eff_m = ef(M, tm);
    double tile_eff_n = ef(N, tn);
    double tile_eff_k = ef(K, tk);
    double overall_tile_efficiency = tile_eff_m * tile_eff_n * tile_eff_k;
    // CU utilization (output tiles per compute unit)
    double cu_utilization = total_output_tiles / std::max((double)hw.num_cus, 1.0);
    // P0 FIX: Problem-to-tile ratio features (critical for small problems)
    double ratio_M_to_tile_m = M / std::max(tm, 1.0);
    double ratio_N_to_tile_n = N / std::max(tn, 1.0);
    double ratio_K_to_tile_k = K / std::max(tk, 1.0);
    // Binary features: is problem dimension smaller than tile?
    double problem_smaller_than_tile_m = (M < tm) ? 1.0 : 0.0;
    double problem_smaller_than_tile_n = (N < tn) ? 1.0 : 0.0;
    double problem_smaller_than_tile_k = (K < tk) ? 1.0 : 0.0;
    double any_dim_too_small = ((M < tm) || (N < tn) || (K < tk)) ? 1.0 : 0.0;
    // P1 FIX: Padding requirement features (does the problem need padding?)
    double needs_padding_m = (tm > 0 && std::fmod(M, tm) != 0.0) ? 1.0 : 0.0;
    double needs_padding_n = (tn > 0 && std::fmod(N, tn) != 0.0) ? 1.0 : 0.0;
    double needs_padding_k = (tk > 0 && std::fmod(K, tk) != 0.0) ? 1.0 : 0.0;
    // Interaction features: kernel has padding when problem needs it
    double has_padding_when_needed_m = (needs_padding_m && pad_m) ? 1.0 : 0.0;
    double has_padding_when_needed_n = (needs_padding_n && pad_n) ? 1.0 : 0.0;
    double has_padding_when_needed_k = (needs_padding_k && pad_k) ? 1.0 : 0.0;
    // Critical feature: missing required padding (kernel will likely fail)
    double missing_required_padding_m = (needs_padding_m && !pad_m) ? 1.0 : 0.0;
    double missing_required_padding_n = (needs_padding_n && !pad_n) ? 1.0 : 0.0;
    double missing_required_padding_k = (needs_padding_k && !pad_k) ? 1.0 : 0.0;
    double missing_any_required_padding =
        (missing_required_padding_m || missing_required_padding_n || missing_required_padding_k)
            ? 1.0
            : 0.0;
    // Hardware features (constant per device; let the model condition on HW)
    double hw_num_cus = (double)hw.num_cus;
    double hw_simds_per_cu = (double)hw.simds_per_cu;
    double hw_total_simds = (double)hw.total_simds();
    double hw_shader_engines = (double)hw.shader_engines;
    double hw_max_clock_mhz = (double)hw.max_clock_mhz;
    double hw_max_waves_per_cu = (double)hw.max_waves_per_cu;
    double hw_wavefront_size = (double)hw.wavefront_size;
    double hw_lds_capacity = (double)hw.lds_capacity;
    double hw_l1_cache_kb = (double)hw.l1_cache_kb;
    double hw_l2_cache_kb = (double)hw.l2_cache_kb;
    double hw_l3_cache_kb = (double)hw.l3_cache_kb;
    double hw_num_xcd = (double)hw.num_xcd;
    // Feature vector in EXACT order from feature_spec.json
    // This order MUST match Python feature_engine.py::get_feature_names()
    return {{
        M,                            // 0
        N,                            // 1
        K,                            // 2
        sk,                           // 3 (split_k)
        l2M,                          // 4 (log2_M)
        l2N,                          // 5 (log2_N)
        l2K,                          // 6 (log2_K)
        l2MNK,                        // 7 (log2_MNK)
        ai,                           // 8 (arithmetic_intensity)
        ar_mn,                        // 9 (aspect_ratio_mn)
        ar_mk,                        // 10 (aspect_ratio_mk)
        ar_nk,                        // 11 (aspect_ratio_nk)
        layout,                       // 12 (layout)
        tm,                           // 13 (tile_m)
        tn,                           // 14 (tile_n)
        tk,                           // 15 (tile_k)
        wm,                           // 16 (warp_m)
        wn,                           // 17 (warp_n)
        wk,                           // 18 (warp_k)
        wtm,                          // 19 (warp_tile_m)
        wtn,                          // 20 (warp_tile_n)
        wtk,                          // 21 (warp_tile_k)
        pipeline,                     // 22 (pipeline)
        scheduler,                    // 23 (scheduler)
        epilogue,                     // 24 (epilogue)
        pad_m,                        // 25 (pad_m)
        pad_n,                        // 26 (pad_n)
        pad_k,                        // 27 (pad_k)
        persistent,                   // 28 (persistent)
        num_warps,                    // 29 (num_warps)
        tile_volume,                  // 30 (tile_volume)
        tile_mn,                      // 31 (tile_mn)
        lest,                         // 32 (lds_usage_estimate)
        lds_ratio,                    // 33 (lds_usage_ratio)
        ntm,                          // 34 (num_tiles_m)
        ntn,                          // 35 (num_tiles_n)
        ntk,                          // 36 (num_tiles_k)
        total_output_tiles,           // 37 (total_output_tiles)
        tile_eff_m,                   // 38 (tile_eff_m)
        tile_eff_n,                   // 39 (tile_eff_n)
        tile_eff_k,                   // 40 (tile_eff_k)
        overall_tile_efficiency,      // 41 (overall_tile_efficiency)
        cu_utilization,               // 42 (cu_utilization)
        ratio_M_to_tile_m,            // 43 (ratio_M_to_tile_m)
        ratio_N_to_tile_n,            // 44 (ratio_N_to_tile_n)
        ratio_K_to_tile_k,            // 45 (ratio_K_to_tile_k)
        problem_smaller_than_tile_m,  // 46 (problem_smaller_than_tile_m)
        problem_smaller_than_tile_n,  // 47 (problem_smaller_than_tile_n)
        problem_smaller_than_tile_k,  // 48 (problem_smaller_than_tile_k)
        any_dim_too_small,            // 49 (any_dim_too_small)
        needs_padding_m,              // 50 (needs_padding_m)
        needs_padding_n,              // 51 (needs_padding_n)
        needs_padding_k,              // 52 (needs_padding_k)
        has_padding_when_needed_m,    // 53 (has_padding_when_needed_m)
        has_padding_when_needed_n,    // 54 (has_padding_when_needed_n)
        has_padding_when_needed_k,    // 55 (has_padding_when_needed_k)
        missing_required_padding_m,   // 56 (missing_required_padding_m)
        missing_required_padding_n,   // 57 (missing_required_padding_n)
        missing_required_padding_k,   // 58 (missing_required_padding_k)
        missing_any_required_padding, // 59 (missing_any_required_padding)
        hw_num_cus,                   // 60 (hw_num_cus)
        hw_simds_per_cu,              // 61 (hw_simds_per_cu)
        hw_total_simds,               // 62 (hw_total_simds)
        hw_shader_engines,            // 63 (hw_shader_engines)
        hw_max_clock_mhz,             // 64 (hw_max_clock_mhz)
        hw_max_waves_per_cu,          // 65 (hw_max_waves_per_cu)
        hw_wavefront_size,            // 66 (hw_wavefront_size)
        hw_lds_capacity,              // 67 (hw_lds_capacity)
        hw_l1_cache_kb,               // 68 (hw_l1_cache_kb)
        hw_l2_cache_kb,               // 69 (hw_l2_cache_kb)
        hw_l3_cache_kb,               // 70 (hw_l3_cache_kb)
        hw_num_xcd,                   // 71 (hw_num_xcd)
    }};
}
// Heuristic functor that ranks registered kernels for a problem by TFLOPS
// predicted from a LightGBM model loaded through the C API.
//
// Non-copyable: it owns the raw booster handle and releases it in the
// destructor (manual RAII over the C API).
class MLHeuristic
{
    public:
    /// Load a LightGBM model from `path`.
    ///
    /// On failure the heuristic stays unloaded (is_loaded() == false); if
    /// only a gzip-compressed model file exists next to `path`, a hint on
    /// how to decompress it is printed.
    ///
    /// @param path  filesystem path of the .lgbm model file
    /// @param reg   registry whose kernel instances operator() will rank
    /// @param hw    hardware profile fed to the feature extractor
    /// @param log_t set true when the model was trained on log1p(tflops);
    ///              predictions are then expm1()'d back to raw TFLOPS
    MLHeuristic(const std::string& path,
                const Registry* reg,
                HardwareProfile hw = {},
                bool log_t = false)
        : registry_(reg), hw_(hw), log_t_(log_t)
    {
        int iters = 0;
        if(LGBM_BoosterCreateFromModelfile(path.c_str(), &iters, &b_) != 0 || !b_)
        {
            std::cerr << "MLHeuristic: Failed to load " << path << std::endl;
            // Check if a compressed .gz version exists
            std::string gz_path = path + ".gz";
            std::ifstream gz_check(gz_path);
            if(gz_check.good())
            {
                std::cerr << "MLHeuristic: Found compressed model at " << gz_path << std::endl;
                std::cerr << "MLHeuristic: Please decompress it first:" << std::endl;
                std::cerr << " gunzip " << gz_path << std::endl;
            }
            // Normalize the handle so every later call takes the unloaded path.
            b_ = nullptr;
        }
        else
            std::cout << "MLHeuristic: Loaded (" << iters << " iters)" << std::endl;
    }
    ~MLHeuristic()
    {
        if(b_)
            LGBM_BoosterFree(b_);
    }
    // Owns a raw C handle: copying would double-free it.
    MLHeuristic(const MLHeuristic&) = delete;
    MLHeuristic& operator=(const MLHeuristic&) = delete;
    /// True when the booster was loaded successfully.
    bool is_loaded() const { return b_ != nullptr; }
    /// Predict TFLOPS for one (problem, kernel) pair.
    /// Returns 0 when no model is loaded or the prediction call fails.
    double predict_tflops(const Problem& prob, const KernelKey& key) const
    {
        if(!b_)
            return 0;
        auto f = extract_features(prob, key, hw_);
        int64_t ol = 0;
        double pred = 0;
        // BUGFIX: the feature buffer holds doubles, so data_type must be
        // C_API_DTYPE_FLOAT64 (1), not C_API_DTYPE_FLOAT32 (0); passing 0
        // made LightGBM reinterpret the double bytes as floats, producing
        // garbage predictions.
        // Args: handle, data, data_type, nrow=1, ncol, is_row_major=1,
        // predict_type=C_API_PREDICT_NORMAL (0), start_iteration=0,
        // num_iteration=0 (all trees), parameter, out_len, out_result.
        if(LGBM_BoosterPredictForMat(
               b_, f.data(), 1, 1, NUM_FEATURES, 1, 0, 0, 0, "", &ol, &pred) != 0)
            return 0;
        // Undo the training-time log1p transform when requested.
        return log_t_ ? std::expm1(pred) : pred;
    }
    /// Rank every kernel instance in the registry for `prob`, best first.
    /// Returns encoded kernel identifiers sorted by descending predicted
    /// TFLOPS; empty when no model or registry is available.
    std::vector<std::string> operator()(const Problem& prob) const
    {
        if(!b_ || !registry_)
            return {};
        auto insts = registry_->get_all();
        // Candidate: encoded kernel id plus its predicted TFLOPS.
        struct C
        {
            std::string id;
            double t;
        };
        std::vector<C> cs;
        cs.reserve(insts.size());
        for(auto& i : insts)
        {
            auto& k = i->get_key();
            cs.push_back({k.encode_identifier(), predict_tflops(prob, k)});
        }
        // Highest predicted TFLOPS first.
        std::sort(cs.begin(), cs.end(), [](auto& a, auto& b) { return a.t > b.t; });
        std::vector<std::string> r;
        r.reserve(cs.size());
        for(auto& c : cs)
            r.push_back(std::move(c.id));
        return r;
    }
    private:
    void* b_ = nullptr;              // LightGBM booster handle (null = unloaded)
    const Registry* registry_ = nullptr; // not owned
    HardwareProfile hw_;             // device description for feature extraction
    bool log_t_ = false;             // model predicts in log1p space
};
} // namespace dispatcher
} // namespace ck_tile

View File

@@ -1,11 +1,16 @@
# Core dependencies
numpy>=1.19.0
# ML Heuristic dependencies (OPTIONAL - large dependencies)
# For ML-based kernel selection, install separately:
# pip install -r ../requirements-ml.txt
# This avoids mandatory large dependencies (pyarrow, lightgbm) for users who don't need ML features
# Optional dependencies (install with pip install -e ".[torch]")
# torch>=2.0.0
# Development dependencies (install with pip install -e ".[dev]")
# pytest>=6.0.0
pytest>=6.0.0
# pytest-cov>=2.0.0
# black>=21.0
# flake8>=3.9.0

View File

@@ -0,0 +1,6 @@
# ML Heuristic dependencies for ML-based kernel selection
# Install with: pip install -r requirements-ml.txt
lightgbm>=3.0.0
pandas>=1.3.0
pyarrow>=6.0.0
scikit-learn>=0.24.0

View File

@@ -538,6 +538,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
/* BlockSize = M0 = */ {F_bm0},
{F_hdim},
{F_mode},

View File

@@ -1209,7 +1209,8 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 8192")),
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
@@ -1244,7 +1245,7 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
@@ -1303,7 +1304,23 @@ class Product:
def get_product(receipt: int) -> Product:
# Flash attention integration
if receipt in (2, 3):
if receipt == 2:
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
cond = problem_ctx.dtype in ["fp16", "bf16"]
cond &= kernel_ctx.pipeline.F_vlayout == "row"
cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
cond &= kernel_ctx.pipeline.F_qscale == "no"
cond &= kernel_ctx.pipeline.F_skip == "f"
cond &= kernel_ctx.pipeline.F_sink == "f"
# FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled.
cond &= kernel_ctx.pipeline.F_logits == "f"
cond &= kernel_ctx.pipeline.F_lse == "t"
return cond
return Product(name="Flash attention integration", rule=fit)
# Receipt 3 forward coverage used by CK library / smoke tests
elif receipt == 3:
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
cond = problem_ctx.dtype in ["fp16", "bf16"]

View File

@@ -939,6 +939,8 @@ def get_fwd_splitkv_blobs(
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
# FlashAttention splitkv paths use softcap-disabled kernels only.
cond &= pipeline.F_logits == "f"
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
@@ -1142,4 +1144,7 @@ def list_blobs(
)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + "\n")
f.write(
(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix()
+ "\n"
)

View File

@@ -87,6 +87,7 @@ auto create_args(int argc, char* argv[])
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
"will not be used")
.insert("sink_grad", "0", "if set to 1, compute and validate sink token gradient")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
@@ -122,6 +123,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
bool deterministic = arg_parser.get_bool("deterministic");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
bool sink_grad = arg_parser.get_bool("sink_grad");
ck_tile::stream_config stream_config{nullptr,
true,
@@ -154,6 +156,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
drop_offset,
drop_prefs,
mask_str,
sink_grad,
deterministic,
init_method,
seed,

View File

@@ -117,6 +117,9 @@ struct fmha_bwd_args
void* dv_ptr;
void* dbias_ptr;
void* workspace_ptr;
const void*
sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink
void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient
// Usage notes for sequence length pointer parameters:
//
@@ -353,11 +356,15 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.lse_ptr,
args.sink_ptr,
args.d_sink_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.cu_seqlen_q_ptr,
args.hdim_v,
args.nhead_q,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
@@ -369,9 +376,13 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.lse_ptr,
args.sink_ptr,
args.d_sink_ptr,
args.p_undrop,
args.seqlen_q,
args.hdim_v,
args.nhead_q,
args.stride_do,
args.stride_o,
args.nhead_stride_do,

View File

@@ -77,6 +77,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
uint64_t drop_offset,
bool drop_prefs,
std::string mask_str,
bool sink_grad, // if true, compute and validate sink gradient
bool deterministic,
std::string init_method,
uint32_t seed,
@@ -285,6 +286,16 @@ bwd_result fmha_bwd_run(mode_enum mode,
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<LSEDataType> sink_host(
sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
: std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
if(sink_grad)
{
std::uniform_real_distribution<float> sink_dist(30.0f, 60.0f);
sink_host.ForEach([&](auto& self, auto i) {
self(i) = static_cast<LSEDataType>(sink_dist(random_engine));
});
}
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
@@ -302,6 +313,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<LSEDataType> d_sink_host(sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
: std::array<ck_tile::index_t, 1>{0});
if(sink_grad)
{
d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; });
}
if(init_method == "ui" || init_method == "0")
{
@@ -360,11 +377,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sink_buf(sink_grad ? sink_host.get_element_space_size_in_bytes() : 0);
ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_sink_buf(sink_grad ? d_sink_host.get_element_space_size_in_bytes() : 0);
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
@@ -396,6 +415,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
if(sink_grad)
{
sink_buf.ToDevice(sink_host.data());
d_sink_buf.ToDevice(d_sink_host.data());
}
// clang-format off
auto layout_str = [&](bool permute){
@@ -415,7 +439,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
<< ", mask:" << mask << std::flush;
@@ -474,7 +499,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
@@ -490,6 +514,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
ws_ptr,
sink_buf.GetDeviceBuffer(),
d_sink_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
seqlen_q_ptr_dev,
@@ -580,6 +606,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
std::vector<ck_tile::HostTensor<AccDataType>> p_sink_host_refs;
randval_buf.FromDevice(randval_host.data());
@@ -756,6 +783,46 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
// Incorporate sink token into the softmax distribution (reference computation).
// The sink acts as an extra key whose score is sink_host(wb, i_h) (in log-space),
// which is a per-head random value in [30, 60].
// lse_new = log(exp(lse_old) + exp(sink))
// P_new = P_old * exp(lse_old - lse_new) (rescaled token attention)
// P_sink = exp(sink - lse_new) (sink attention weight)
ck_tile::HostTensor<AccDataType> p_sink_host_ref(
sink_grad ? std::array<ck_tile::index_t, 2>{nhead, real_seqlen_q}
: std::array<ck_tile::index_t, 2>{0, 0});
if(sink_grad)
{
for(int i_h = 0; i_h < nhead; ++i_h)
{
AccDataType sink_val = sink_host(wb, i_h);
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
{
// Use numerically stable log-sum-exp: lse_new = log(exp(lse_old)+exp(sink))
// = max(lse_old, sink) + log(1 + exp(min - max))
// This handles lse_old = -inf (fully-masked rows) without producing NaN:
// if lse_old=-inf: max=sink, min=-inf, exp(-inf-sink)=0, lse_new=sink
// It also avoids exp(lse_old) overflow when lse_old is large.
// p_scale = exp(lse_old - lse_new) [fraction kept by regular tokens]
// p_sink = exp(sink - lse_new) [sink attention weight]
AccDataType lse_old = lse_host_ref(i_h, i_q);
AccDataType hi = lse_old > sink_val ? lse_old : sink_val;
AccDataType lo = lse_old > sink_val ? sink_val : lse_old;
AccDataType lse_new =
hi + ck_tile::log(AccDataType(1) + ck_tile::exp(lo - hi));
AccDataType p_scale = ck_tile::exp(lse_old - lse_new);
lse_host_ref(i_h, i_q) = lse_new;
for(int i_k = 0; i_k < real_seqlen_k; ++i_k)
p_hp_host_ref(i_h, i_q, i_k) *= p_scale;
p_sink_host_ref(i_h, i_q) = ck_tile::exp(sink_val - lse_new);
}
}
}
if(p_drop > 0)
{
p_dropped_hp_host_ref = p_hp_host_ref;
@@ -814,6 +881,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
o_host_refs.push_back(o_host_ref);
p_hp_host_refs.push_back(p_hp_host_ref);
p_lp_host_refs.push_back(p_lp_host_ref);
p_sink_host_refs.push_back(p_sink_host_ref);
if(p_drop > 0)
{
randval_host_refs.push_back(randval_host_ref);
@@ -833,6 +901,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dbias_buf.SetZero();
if(sink_grad)
d_sink_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
launcher(fmha_args, stream_config_v);
@@ -841,10 +911,19 @@ bwd_result fmha_bwd_run(mode_enum mode,
dk_buf.FromDevice(dk_host.data());
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
if(sink_grad)
d_sink_buf.FromDevice(d_sink_host.data());
// Track the index into reference vectors (may differ from wb if batches were skipped)
ck_tile::index_t ref_idx = 0;
// validation sink accumulator: global over batch, shape [nhead]
ck_tile::HostTensor<AccDataType> d_sink_host_ref(
sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
: std::array<ck_tile::index_t, 1>{0});
if(sink_grad)
d_sink_host_ref.ForEach([&](auto& self, auto i) { self(i) = 0; });
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
// When padding is enabled, use logical lengths instead of computing from padded
@@ -920,6 +999,30 @@ bwd_result fmha_bwd_run(mode_enum mode,
ds_hp_host_ref.mDesc.get_lengths()[1],
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
if(sink_grad)
{
// Reference: dSink[h] = -sum_q( P_sink[h,q] * D[h,q] )
// where D[h,q] = sum_j(dO[h,q,j] * O[h,q,j]) * p_undrop
for(int i_h = 0; i_h < nhead; ++i_h)
{
AccDataType d_sink_head_acc = 0;
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
{
AccDataType do_dot_o = 0;
for(int o = 0; o < hdim_v; o++)
{
do_dot_o +=
ck_tile::type_convert<AccDataType>(do_host_ref(i_h, i_q, o)) *
ck_tile::type_convert<AccDataType>(
o_host_refs[ref_idx](i_h, i_q, o)) *
p_undrop;
}
d_sink_head_acc += -p_sink_host_refs[ref_idx](i_h, i_q) * do_dot_o;
}
d_sink_host_ref(i_h) += d_sink_head_acc;
}
}
if(use_dbias)
{
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
@@ -1032,6 +1135,17 @@ bwd_result fmha_bwd_run(mode_enum mode,
ref_idx++;
}
if(pass && sink_grad)
{
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
bool dsink_pass = ck_tile::check_err(d_sink_host,
d_sink_host_ref,
std::string("Error: SinkGrad Incorrect results!"),
rtol,
atol);
pass &= dsink_pass;
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}

View File

@@ -39,7 +39,6 @@ function print_log_header(){
#run verification tests
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"

View File

@@ -69,6 +69,28 @@ test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
done
# sink gradient tests: same coverage as main tests but with -sink_grad=1
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "a" ; do
for p_drop in 0.0 0.2 ; do
test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=0 -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=0 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -sink_grad=1
done
done
done
done
done
done
# sink gradient additional cases: non-standard hdim
for hdim in 40 48 72 96 ; do
test_h_s_mask -prec=fp16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS -sink_grad=1
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
done
set +x
new_fails_count=0

View File

@@ -235,6 +235,64 @@ run_padding_basic_boundary_tests() {
done
}
# Sink-specific mask pattern tests (sliding window + sink token).
run_sink_mask_tests() {
# window_size[2,0], sink_size=2 (top-left causal + sink)
# before: after:
# 1 * * * * * * * 1 * * * * * * *
# 1 1 * * * * * * 1 1 * * * * * *
# 1 1 1 * * * * * 1 1 1 * * * * *
# * 1 1 1 * * * * 1 1 1 1 * * * *
# * * 1 1 1 * * * 1 1 1 1 1 * * *
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
# * * * * * 1 1 1 1 1 * * * 1 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
run_exe -prec=bf16 -mode=0 -b=2 -h=2 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
# window_size[0,3], sink_size=2 (top-left + sink)
# before: after:
# 1 1 1 1 * * * * 1 1 1 1 * * * *
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
# * * 1 1 1 1 * * 1 1 1 1 1 1 * *
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
# window_size[1,0], sink_size=2 (bottom-right + sink)
# before: after:
# * * 1 1 * * * * 1 1 1 1 * * * *
# * * * 1 1 * * * 1 1 * 1 1 * * *
# * * * * 1 1 * * 1 1 * * 1 1 * *
# * * * * * 1 1 * 1 1 * * * 1 1 *
# * * * * * * 1 1 1 1 * * * * 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
run_exe -prec=bf16 -mode=0 -b=2 -h=4 -d=128 -d_v=128 -s=2048 -s_k=2048 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
# window_size[2,0], sink_size=2 (bottom-right, group mode + sink)
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
# window_size[-1,1], sink_size=2 (bottom-right, large seqlen + sink)
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
run_exe -prec=bf16 -mode=1 -b=1 -h=2 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
}
# init_sink tests: validate sink token initialization across prec/hdim/mode.
run_sink_init_tests() {
for prec in "fp16" "bf16" ; do
for hdim in 64 128 256 ; do
for mode in 0 1 ; do
for mask in 0 1 ; do
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -d=$hdim -d_v=$hdim -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
run_exe -prec=$prec -mode=$mode -b=2 -h=4 -d=$hdim -d_v=$hdim -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
done
done
done
done
}
set -x
run_fp16_bf16_tests
@@ -242,6 +300,8 @@ run_padding_smoke_tests
run_padding_basic_boundary_tests
run_fp8bf16_tests
run_fp8fp32_tests
run_sink_mask_tests
run_sink_init_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests

View File

@@ -1,93 +0,0 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# TODO: run this script from CK root or build directory
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
set -euo pipefail
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_fwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
GPU_arch=$GPU_arch
if [ -z "$GPU_arch" ] ; then
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
set -x
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
# window_size[2,0], sink_size = 2
# x=1/y=3
# 1 * * * * * * * 1 * * * * * * *
# 1 1 * * * * * * 1 1 * * * * * *
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
# * 1 1 1 * * * * 1 1 1 1 * * * *
# * * 1 1 1 * * * 1 1 1 1 1 * * *
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
# * * * * * 1 1 1 1 1 * * * 1 1 1
# l=2/r=0(tl) l=2/r=0/s=2(tl)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
# x=4/y=1
# 1 1 1 1 * * * * 1 1 1 1 * * * *
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
# l=0/r=3(tl) l=0/r=3/s=2(tl)
# l=3/r=0(br) l=3/r=0/s=2(br)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
# x=4/y=-1
# * * 1 1 * * * * 1 1 1 1 * * * *
# * * * 1 1 * * * 1 1 * 1 1 * * *
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
# * * * * * 1 1 * 1 1 * * * 1 1 *
# * * * * * * 1 1 1 1 * * * * 1 1
# l=1/r=0(br) l=1/r=0/s=2(br)
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
# x=-1/y=5
# * * * * * * * * * * * *
# * * * * * * * * * * * *
# 1 * * * * * 1 * * * * *
# 1 1 * * * * 1 1 * * * *
# 1 1 1 * * * ----> 1 1 1 * * *
# * 1 1 1 * * 1 1 1 1 * *
# * * 1 1 1 * 1 1 1 1 1 *
# * * * 1 1 1 1 1 * 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
# x=-1/y=8
# * * * * * * * * * *
# * * * * * * * * * *
# 1 * * * * ----> 1 * * * *
# 1 1 * * * 1 1 * * *
# 1 1 1 * * 1 1 1 * *
# 1 1 1 1 * 1 1 1 1 *
# 1 1 1 1 1 1 1 1 1 1
# 1 1 1 1 1 1 1 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1

View File

@@ -63,8 +63,8 @@
#define __gfx101__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
defined(__gfx10_3_generic__)
defined(__gfx1033__) || defined(__gfx1034__) || defined(__gfx1035__) || \
defined(__gfx1036__) || defined(__gfx10_3_generic__)
#define __gfx103__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \

View File

@@ -125,8 +125,9 @@ inline bool is_gfx101_supported()
inline bool is_gfx103_supported()
{
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1033" ||
ck::get_device_name() == "gfx1034" || ck::get_device_name() == "gfx1035" ||
ck::get_device_name() == "gfx1036";
}
inline bool is_wmma_supported()

View File

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/numeric.hpp"
namespace ck {
namespace tensor_operation {
@@ -492,6 +493,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
c_grid_desc_m_n_container_.push_back(descs[I2]);
}
}
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
input_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()) *
Conv_N_ * Conv_C_ * sizeof(CDataType);
}
const ADataType* p_a_grid_;
@@ -512,6 +517,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
long_index_t c_space_size_bytes;
};
// Invoker
@@ -571,18 +578,47 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
DeviceOp::BGridDesc_K0_N_K1,
DeviceOp::CGridDesc_M_N,
true>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
if(stream_config.flush_cache)
{
// Clear input only for perf measurement.
// For non-grouped solver user has to clear input on his own.
const auto clear_input = [&]() {
if(i == 0)
{
hip_check_error(hipMemsetAsync(arg.p_c_grid_,
0,
arg.c_space_size_bytes,
stream_config.stream_id_));
}
};
ave_time += launch_and_time_kernel_with_preprocess_flush_cache(
stream_config,
clear_input,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
else
{
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
}
else
{
@@ -594,18 +630,47 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
DeviceOp::BGridDesc_K0_N_K1,
DeviceOp::CGridDesc_M_N,
false>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
if(stream_config.flush_cache)
{
// Clear input only for perf measurement.
// For non-grouped solver user has to clear input on his own.
const auto clear_input = [&]() {
if(i == 0)
{
hip_check_error(hipMemsetAsync(arg.p_c_grid_,
0,
arg.c_space_size_bytes,
stream_config.stream_id_));
}
};
ave_time += launch_and_time_kernel_with_preprocess_flush_cache(
stream_config,
clear_input,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
else
{
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
}
}
return ave_time;

View File

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/numeric.hpp"
namespace ck {
namespace tensor_operation {
@@ -1050,6 +1051,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_right_pads_{input_right_pads}
{
CreateABCDesc<NDimSpatial>();
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
input_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()) *
Conv_N_ * Conv_C_ * sizeof(CDataType);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
@@ -1216,6 +1221,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
long_index_t c_space_size_bytes;
};
// Invoker
@@ -1273,18 +1280,47 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
DeviceOp::BGridDesc_K0_N_K1,
DeviceOp::CGridDesc_M_N,
true>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
if(stream_config.flush_cache)
{
// Clear input only for perf measurement.
// For non-grouped solver user has to clear input on his own.
const auto clear_input = [&]() {
if(i == 0)
{
hip_check_error(hipMemsetAsync(arg.p_c_grid_,
0,
arg.c_space_size_bytes,
stream_config.stream_id_));
}
};
ave_time += launch_and_time_kernel_with_preprocess_flush_cache(
stream_config,
clear_input,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
else
{
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
}
else
{
@@ -1296,18 +1332,47 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
DeviceOp::BGridDesc_K0_N_K1,
DeviceOp::CGridDesc_M_N,
false>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
if(stream_config.flush_cache)
{
// Clear input only for perf measurement.
// For non-grouped solver user has to clear input on his own.
const auto clear_input = [&]() {
if(i == 0)
{
hip_check_error(hipMemsetAsync(arg.p_c_grid_,
0,
arg.c_space_size_bytes,
stream_config.stream_id_));
}
};
ave_time += launch_and_time_kernel_with_preprocess_flush_cache(
stream_config,
clear_input,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
else
{
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
}
}
return ave_time;

View File

@@ -1225,26 +1225,50 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
has_main_loop,
no_main_loop,
CTranspose>;
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_b_grid,
p_a_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.b_element_op_,
arg.a_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
if(stream_config.flush_cache)
{
return launch_and_time_kernel_with_preprocess_flush_cache(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_b_grid,
p_a_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.b_element_op_,
arg.a_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
}
else
{
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_b_grid,
p_a_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.b_element_op_,
arg.a_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
}
}
else
{
@@ -1264,26 +1288,50 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
has_main_loop,
no_main_loop,
CTranspose>;
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
if(stream_config.flush_cache)
{
return launch_and_time_kernel_with_preprocess_flush_cache(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
}
else
{
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
}
}
};
if(has_loop_in_all_gemm)

View File

@@ -88,6 +88,7 @@ enum struct amdgcn_target_id
GFX1030 = 0x1030,
GFX1031 = 0x1031,
GFX1032 = 0x1032,
GFX1033 = 0x1033,
GFX1034 = 0x1034,
GFX1035 = 0x1035,
GFX1036 = 0x1036,
@@ -284,6 +285,7 @@ constexpr auto get_compiler_target()
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1033, GFX1033);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
@@ -351,6 +353,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const*
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1030", GFX1030);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1031", GFX1031);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1032", GFX1032);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1033", GFX1033);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1034", GFX1034);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1035", GFX1035);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1036", GFX1036);
@@ -607,6 +610,7 @@ CK_TILE_HOST_DEVICE constexpr auto get_compiler_target()
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1033, GFX1033);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
@@ -688,6 +692,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* tes
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1030", GFX1030);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1031", GFX1031);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1032", GFX1032);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1033", GFX1033);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1034", GFX1034);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1035", GFX1035);
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1036", GFX1036);

View File

@@ -15,8 +15,8 @@
#define __gfx101__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
defined(__gfx10_3_generic__)
defined(__gfx1033__) || defined(__gfx1034__) || defined(__gfx1035__) || \
defined(__gfx1036__) || defined(__gfx10_3_generic__)
#define __gfx103__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
@@ -405,6 +405,12 @@ struct amdgcn_compiler_target_state
static constexpr bool CK_TILE_ARCH_GFX1032 = false;
#endif // __gfx1032__
#if defined(__gfx1033__)
static constexpr bool CK_TILE_ARCH_GFX1033 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1033 = false;
#endif // __gfx1033__
#if defined(__gfx1034__)
static constexpr bool CK_TILE_ARCH_GFX1034 = true;
#else
@@ -537,6 +543,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1033, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1034, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1035, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1036, \

View File

@@ -39,7 +39,7 @@ CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}
@@ -89,7 +89,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<T
{
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
return make_tuple(old_tuple[number<IRs>{}]...);
}
@@ -109,7 +109,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is..
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}
@@ -120,7 +120,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is..
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};

View File

@@ -144,9 +144,11 @@ struct sequence
static_assert(MapOld2New::size() == size(),
"wrong! reorder map should have the same size as sequence to be rerodered");
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<remove_cvref_t<MapOld2New>>::value,
"wrong! invalid reorder map");
return reorder_new_to_old(typename sequence_map_inverse<MapOld2New>::type{});
return reorder_new_to_old(
typename sequence_map_inverse<remove_cvref_t<MapOld2New>>::type{});
}
CK_TILE_HOST_DEVICE static constexpr auto reverse()
@@ -548,163 +550,59 @@ struct sequence_reduce<Reduce, Seq>
};
#endif
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
// Sorts a sequence using constexpr insertion sort. O(1) template instantiation
// depth, replacing the recursive merge sort that created O(N log N) intermediate types.
namespace detail {
template <typename Values, typename Compare, typename IndexSeq>
struct sequence_sort_helper;
template <index_t... Vs, typename Compare, index_t... Idx>
struct sequence_sort_helper<sequence<Vs...>, Compare, sequence<Idx...>>
{
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl
struct sort_result
{
static constexpr bool choose_left = LeftValues::front() < RightValues::front();
static constexpr index_t chosen_value =
choose_left ? LeftValues::front() : RightValues::front();
static constexpr index_t chosen_id = choose_left ? LeftIds::front() : RightIds::front();
using new_merged_values = decltype(MergedValues::push_back(number<chosen_value>{}));
using new_merged_ids = decltype(MergedIds::push_back(number<chosen_id>{}));
using new_left_values = typename std::
conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
using new_left_ids =
typename std::conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;
using new_right_values = typename std::
conditional<choose_left, RightValues, decltype(RightValues::pop_front())>::type;
using new_right_ids =
typename std::conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;
using merge = sorted_sequence_merge_impl<new_left_values,
new_left_ids,
new_right_values,
new_right_ids,
new_merged_values,
new_merged_ids,
Comp>;
// this is output
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
static_array<index_t, sizeof...(Vs)> values;
static_array<index_t, sizeof...(Vs)> ids;
};
template <typename LeftValues,
typename LeftIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<LeftValues,
LeftIds,
sequence<>,
sequence<>,
MergedValues,
MergedIds,
Comp>
static constexpr sort_result compute()
{
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
};
constexpr index_t n = sizeof...(Vs);
sort_result r{{{Vs...}}, {{Idx...}}};
// insertion sort — O(N^2) constexpr steps, O(1) template depth
for(index_t i = 1; i < n; ++i)
{
for(index_t j = i; j > 0 && Compare{}(r.values[j], r.values[j - 1]); --j)
{
auto tv = r.values[j];
r.values[j] = r.values[j - 1];
r.values[j - 1] = tv;
auto ti = r.ids[j];
r.ids[j] = r.ids[j - 1];
r.ids[j - 1] = ti;
}
}
return r;
}
template <typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<sequence<>,
sequence<>,
RightValues,
RightIds,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
};
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename Comp>
struct sorted_sequence_merge
{
using merge = sorted_sequence_merge_impl<LeftValues,
LeftIds,
RightValues,
RightIds,
sequence<>,
sequence<>,
Comp>;
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
static constexpr index_t nsize = Values::size();
using split_unsorted_values = sequence_split<Values, nsize / 2>;
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
using left_unsorted_values = typename split_unsorted_values::left_type;
using left_unsorted_ids = typename split_unsorted_ids::left_type;
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
using left_sorted_values = typename left_sort::sorted_values;
using left_sorted_ids = typename left_sort::sorted_ids;
using right_unsorted_values = typename split_unsorted_values::right_type;
using right_unsorted_ids = typename split_unsorted_ids::right_type;
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
using right_sorted_values = typename right_sort::sorted_values;
using right_sorted_ids = typename right_sort::sorted_ids;
using merged_sorted = sorted_sequence_merge<left_sorted_values,
left_sorted_ids,
right_sorted_values,
right_sorted_ids,
Compare>;
using sorted_values = typename merged_sorted::merged_values;
using sorted_ids = typename merged_sorted::merged_ids;
static constexpr sort_result sorted = compute();
using sorted_values = sequence<sorted.values[Idx]...>;
using sorted_ids = sequence<sorted.ids[Idx]...>;
};
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<sequence<ValueX, ValueY>, sequence<IdX, IdY>, Compare>
{
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
using sorted_values = typename std::
conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
using sorted_ids =
typename std::conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
};
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<sequence<Value>, sequence<Id>, Compare>
{
using sorted_values = sequence<Value>;
using sorted_ids = sequence<Id>;
};
template <typename Compare>
struct sequence_sort_impl<sequence<>, sequence<>, Compare>
{
using sorted_values = sequence<>;
using sorted_ids = sequence<>;
};
} // namespace detail
template <typename Values, typename Compare>
struct sequence_sort
{
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type;
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
static constexpr index_t n = Values::size();
using idx_seq = make_index_sequence<n>;
// this is output
using type = typename sort::sorted_values;
using sorted2unsorted_map = typename sort::sorted_ids;
using helper = detail::sequence_sort_helper<remove_cvref_t<Values>, Compare, idx_seq>;
using type = typename helper::sorted_values;
using sorted2unsorted_map = typename helper::sorted_ids;
};
template <typename Values, typename Less, typename Equal>
@@ -782,10 +680,42 @@ struct sequence_unique_sort
using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
// Validates that a sequence is a permutation of {0, 1, ..., N-1}.
// Uses a constexpr loop instead of instantiating sequence_sort.
namespace detail {
template <index_t... Is>
constexpr bool check_valid_sequence_map()
{
constexpr index_t n = sizeof...(Is);
if constexpr(n == 0)
{
return true;
}
else
{
constexpr index_t vals[] = {Is...};
static_array<bool, n> seen{};
for(index_t i = 0; i < n; ++i)
{
if(vals[i] < 0 || vals[i] >= n || seen[vals[i]])
return false;
seen[vals[i]] = true;
}
return true;
}
}
} // namespace detail
template <typename SeqMap>
struct is_valid_sequence_map
: std::is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
typename sequence_sort<SeqMap, less<index_t>>::type>
struct is_valid_sequence_map : std::false_type
{
};
// Partial specialization: a sequence is a valid reorder map iff it is a
// permutation of {0, ..., N-1}, checked by a constexpr loop (no sort needed).
template <index_t... Is>
struct is_valid_sequence_map<sequence<Is...>>
    : std::bool_constant<detail::check_valid_sequence_map<Is...>()>
{
};

View File

@@ -376,9 +376,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transf
constexpr auto all_up_dim_new_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
static_assert(
is_valid_sequence_map<remove_cvref_t<decltype(all_low_dim_old_top_ids)>>::value &&
is_valid_sequence_map<remove_cvref_t<decltype(all_up_dim_new_top_ids)>>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();
@@ -443,8 +444,8 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
static_assert(is_valid_sequence_map<remove_cvref_t<decltype(all_old_top_ids)>>::value &&
is_valid_sequence_map<remove_cvref_t<decltype(all_new_top_ids)>>::value,
"wrong!");
}

View File

@@ -135,65 +135,147 @@ struct idx_identity
namespace detail {
// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
// Computes the inverse of a permutation as a constexpr array.
// Avoids the sequence_map_inverse -> is_valid_sequence_map -> sequence_sort chain.
template <class Perm>
struct inverse_perm;
template <index_t... Ps>
struct inverse_perm<sequence<Ps...>>
{
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
static constexpr auto compute()
{
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
constexpr index_t n = sizeof...(Ps);
static_array<index_t, n> result{};
constexpr index_t input[] = {Ps...};
for(index_t i = 0; i < n; ++i)
{
result[input[i]] = i;
}
return result;
}
static constexpr auto value = compute();
};
// Decomposes a linear index into multi-dimensional indices using pre-computed
// strides. Uses a single flat static_for instead of recursive nesting, which
// eliminates intermediate lambda closure instantiations.
template <class OrderedLengths, class IndexSeq>
struct index_decomposer;
template <index_t... Ls, index_t... Is>
struct index_decomposer<sequence<Ls...>, sequence<Is...>>
{
static constexpr index_t n_dim = sizeof...(Ls);
static constexpr static_array<index_t, n_dim> lengths = {{Ls...}};
static constexpr static_array<index_t, n_dim> compute_all_strides()
{
static_array<index_t, n_dim> result{};
if constexpr(n_dim > 0)
{
result[n_dim - 1] = 1;
for(index_t i = n_dim - 1; i > 0; --i)
{
result[i - 1] = result[i] * lengths[i];
}
}
return result;
}
// F signature: F(sequence<...>)
// CurrentOrderedId: sequence<...>
template <class F, class CurrentOrderedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
static constexpr static_array<index_t, n_dim> strides = compute_all_strides();
// Compile-time decomposition: linear index -> sequence of per-dimension indices
template <index_t LinearIdx>
using decompose = sequence<((LinearIdx / strides[Is]) % lengths[Is])...>;
// Decompose AND reorder in one step using a pre-computed inverse permutation.
// Produces the unordered multi-index directly, avoiding per-iteration
// reorder_old_to_new member function instantiations on each unique sequence type.
template <index_t LinearIdx, class New2Old>
using decompose_reordered = sequence<((LinearIdx / strides[inverse_perm<New2Old>::value[Is]]) %
lengths[inverse_perm<New2Old>::value[Is]])...>;
};
// Calls f(decompose<I>{}) for each linear index I in the pack, using a single
// fold expression. Bypasses the static_for lambda entirely, eliminating M*N
// intermediate lambda closure instantiations that the lambda-based approach creates.
template <class Decomposer, class LinearIdxSeq>
struct ford_applier;
template <class Decomposer, index_t... LinearIds>
struct ford_applier<Decomposer, sequence<LinearIds...>>
{
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
f, CurrentOrderedId::push_back(I));
});
if constexpr(sizeof...(LinearIds) > 0)
{
(f(typename Decomposer::template decompose<LinearIds>{}), ...);
}
}
};
template <class Orders>
struct static_ford_impl<sequence<>, Orders>
// Same as ford_applier but applies reordering during decomposition.
template <class Decomposer, class New2Old, class LinearIdxSeq>
struct ford_applier_reordered;
template <class Decomposer, class New2Old, index_t... LinearIds>
struct ford_applier_reordered<Decomposer, New2Old, sequence<LinearIds...>>
{
// F signature: F(sequence<...>)
// OrderedId: sequence<...>
template <class F, class OrderedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
// retrive unordered Id
f(OrderedId::reorder_old_to_new(Orders{}));
if constexpr(sizeof...(LinearIds) > 0)
{
(f(typename Decomposer::template decompose_reordered<LinearIds, New2Old>{}), ...);
}
}
};
} // namespace detail
// Lengths is sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is sequence<...>, it is the order of dimension in which static_ford
// will loop over each
// dimension
// Compile-time N-dimensional loop with static multi-indices.
// Uses direct fold expansion with index decomposition, producing zero
// intermediate lambda closures. Each iteration calls f with a compile-time
// sequence<i0, i1, ...> containing the multi-dimensional index.
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
static constexpr index_t n_dim = Lengths::size();
static constexpr index_t total_size =
reduce_on_sequence(Lengths{}, multiplies<>{}, number<1>{});
static constexpr bool is_identity_order = std::is_same_v<Orders, make_index_sequence<n_dim>>;
// For identity order, OrderedLengths == Lengths (no reorder needed).
// For non-identity, reorder lengths according to iteration order.
// Both branches must be valid types, but only the active one is used.
using OrderedLengths =
std::conditional_t<is_identity_order,
Lengths,
remove_cvref_t<decltype(Lengths::reorder_new_to_old(Orders{}))>>;
using Decomposer = detail::index_decomposer<OrderedLengths, make_index_sequence<n_dim>>;
CK_TILE_HOST_DEVICE constexpr static_ford()
{
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
}
// F signature: F(sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
if constexpr(is_identity_order)
{
detail::ford_applier<Decomposer, make_index_sequence<total_size>>{}(f);
}
else
{
detail::ford_applier_reordered<Decomposer, Orders, make_index_sequence<total_size>>{}(
f);
}
}
};

View File

@@ -103,6 +103,42 @@ struct Max
}
};
// Binary reduction functor selecting the smaller operand.
// Identity value is numeric<T>::max(), so the first real element always wins.
struct Min
{
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
    {
        return numeric<T>::max();
    };

    // Plain reduction step: fold x into the running minimum y.
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
    {
        return min(y, x);
    }

    // Reduction step with a changed flag, used for argmin-style index tracking.
    // The flag is only ever raised (set true when x strictly improves on y),
    // never cleared, so callers can accumulate it across steps.
    template <
        typename T,
        typename = std::enable_if_t<
            is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
    CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
    {
        const T reduced = min(y, x);
        changed         = (x < y) ? true : changed;
        return reduced;
    }
};
struct AbsMax
{
template <

View File

@@ -129,7 +129,13 @@ struct CShuffleEpilogue
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr bool EightWave = (MWave * NWave == 8);
#if defined(__gfx9__)
static constexpr bool EightWave = (MWave * NWave == 8);
#else
static constexpr bool EightWave = false;
#endif
static constexpr index_t BlockedXDLN_PerWarp =
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

View File

@@ -1516,6 +1516,7 @@ struct FmhaBwdOGradDotOKernel
using DDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::LSEDataType>;
static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
@@ -1557,25 +1558,31 @@ struct FmhaBwdOGradDotOKernel
const void* o_ptr;
const void* do_ptr;
void* d_ptr;
const void* lse_ptr; // log-sum-exp from forward pass, shape [batch, nhead, seqlen_q]
const LSEDataType* sink_ptr; // sink scores, shape [batch, nhead]; nullptr disables sink
LSEDataType* d_sink_ptr; // sink gradient output, shape [nhead]; nullptr disables sink grad
float p_undrop;
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead; // used to index sink_ptr / d_sink_ptr
ck_tile::index_t stride_do;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_d;
// LSE and D always share the same layout; this stride covers both.
ck_tile::index_t nhead_stride_lsed;
};
struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
{
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_d;
// LSE and D always share the same layout; this stride covers both.
ck_tile::index_t batch_stride_lsed;
};
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
@@ -1593,32 +1600,40 @@ struct FmhaBwdOGradDotOKernel
MakeKargs(const void* o_ptr,
const void* do_ptr,
void* d_ptr,
const void* lse_ptr,
const void* sink_ptr,
void* d_sink_ptr,
float p_undrop,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t nhead,
ck_tile::index_t stride_do,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_d,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_d)
ck_tile::index_t batch_stride_lsed)
{
Kargs kargs{{o_ptr,
do_ptr,
d_ptr,
lse_ptr,
reinterpret_cast<const LSEDataType*>(sink_ptr),
reinterpret_cast<LSEDataType*>(d_sink_ptr),
p_undrop,
seqlen_q,
hdim_v,
nhead,
stride_do,
stride_o,
nhead_stride_do,
nhead_stride_o,
nhead_stride_d},
nhead_stride_lsed},
batch_stride_do,
batch_stride_o,
batch_stride_d};
batch_stride_lsed};
return kargs;
}
@@ -1628,28 +1643,36 @@ struct FmhaBwdOGradDotOKernel
MakeKargs(const void* o_ptr,
const void* do_ptr,
void* d_ptr,
const void* lse_ptr,
const void* sink_ptr,
void* d_sink_ptr,
float p_undrop,
const void* seqstart_q_ptr,
const void* seqlen_q_ptr,
const void* cu_seqlen_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t nhead,
ck_tile::index_t stride_do,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_d)
ck_tile::index_t nhead_stride_lsed)
{
Kargs kargs{{o_ptr,
do_ptr,
d_ptr,
lse_ptr,
reinterpret_cast<const LSEDataType*>(sink_ptr),
reinterpret_cast<LSEDataType*>(d_sink_ptr),
p_undrop,
-1, // seqlen will be updated by another pointer
hdim_v,
nhead,
stride_do,
stride_o,
nhead_stride_do,
nhead_stride_o,
nhead_stride_d},
nhead_stride_lsed},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
@@ -1683,18 +1706,18 @@ struct FmhaBwdOGradDotOKernel
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
long_index_t batch_offset_o = 0;
long_index_t batch_offset_do = 0;
long_index_t batch_offset_d = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_do = 0;
long_index_t batch_offset_lsed = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.stride_o;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_d = query_start;
batch_offset_o = query_start * kargs.stride_o;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = query_start;
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
if(kargs.cu_seqlen_q_ptr != nullptr)
@@ -1722,11 +1745,20 @@ struct FmhaBwdOGradDotOKernel
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
}
// Read per-head sink score and convert to log2 domain so the pipeline can use exp2.
// Pre-multiply by log2e so that exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
// -inf is left unchanged (log2e * -inf == -inf) to keep P_sink -> 0 when sink is disabled.
const LSEDataType sink_value =
kargs.sink_ptr != nullptr
? log2e_v<LSEDataType> *
kargs.sink_ptr[static_cast<long_index_t>(i_batch) * kargs.nhead + i_nhead]
: -numeric<LSEDataType>::infinity();
// for simplicity, batch stride we just modify the pointer
const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
@@ -1734,9 +1766,13 @@ struct FmhaBwdOGradDotOKernel
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
batch_offset_do;
const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
batch_offset_lsed;
DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
batch_offset_d;
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
batch_offset_lsed;
// O/dO/D DRAM and DRAM window
const auto o_dram = [&]() {
@@ -1770,13 +1806,31 @@ struct FmhaBwdOGradDotOKernel
auto o_dram_window =
make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
auto do_dram_window =
make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
// nullptr when sink grad is disabled; the pipeline checks this to skip the sink path
LSEDataType* atomic_sink_grad_ptr =
kargs.d_sink_ptr == nullptr ? nullptr : kargs.d_sink_ptr + i_nhead;
// lse_ptr is always valid (also needed by the main bwd kernel).
// The actual load happens inside the pipeline only when atomic_sink_grad_ptr != nullptr.
auto lse_dram = [&]() {
const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
return pad_tensor_view(
lse_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
}();
auto lse_dram_window = make_tile_window(lse_dram, make_tuple(number<kM0>{}), {i_m0});
FmhaBwdOGradDotO{}(o_dram_window,
do_dram_window,
lse_dram_window,
d_dram_window,
sink_value,
kargs.p_undrop,
atomic_sink_grad_ptr);
}
};

View File

@@ -14,6 +14,7 @@ struct BlockFmhaBwdOGradDotO
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; // needed for sink gradient
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -32,11 +33,18 @@ struct BlockFmhaBwdOGradDotO
template <typename ODramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp>
// Computes D = diag(dO * O) and optionally accumulates the sink token gradient.
// sink_value: log-space sink score; pass -inf and atomic_sink_grad_ptr=nullptr to skip sink.
// atomic_sink_grad_ptr: per-head accumulator in global memory; nullptr disables sink path.
CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
DDramBlockWindowTmp& d_dram_block_window_tmp,
float p_undrop) const
const LSEDataType sink_value,
float p_undrop,
LSEDataType* atomic_sink_grad_ptr = nullptr) const
{
static_assert(
std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
@@ -44,6 +52,10 @@ struct BlockFmhaBwdOGradDotO
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
"wrong!");
// atomic_sink_grad_ptr is reinterpret_cast to float* in the sink path;
// ensure LSEDataType is float so the cast is well-defined.
static_assert(std::is_same_v<LSEDataType, float>,
"sink gradient atomicAdd requires LSEDataType == float");
static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kBlockSize ==
@@ -67,14 +79,13 @@ struct BlockFmhaBwdOGradDotO
auto do_ = load_tile(do_dram_window);
// declare d
// D[q] = sum_j(O[q,j] * dO[q,j]), used in softmax backward
constexpr auto d_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
auto d = make_static_distributed_tensor<DDataType>(d_dstr);
clear_tile(d); // Initialize D
clear_tile(d);
constexpr auto o_spans = decltype(o)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -86,9 +97,67 @@ struct BlockFmhaBwdOGradDotO
});
});
// Scale by p_undrop (=1 when dropout is disabled)
tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);
store_tile(d_dram_block_window_tmp, d);
// Sink gradient path: skipped entirely when atomic_sink_grad_ptr is nullptr
if(atomic_sink_grad_ptr != nullptr)
{
// Load LSE only on the sink path to avoid unnecessary global memory reads
constexpr auto lse_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
o.get_tile_distribution().get_static_tile_distribution_encoding(),
sequence<1>{}));
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
lse_dram_block_window_tmp.get_window_origin(),
lse_dstr);
auto lse_ = load_tile(lse_dram_window);
// Compute per-query contribution: -P_sink[q] * D[q]
// where P_sink[q] = exp2(sink_value - log2e*lse[q])
// sink_value has already been pre-multiplied by log2e at the kernel call site,
// so exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
// exp2 maps directly to the v_exp_f32 hardware instruction on AMD GPUs.
// Always accumulate in float regardless of DDataType to avoid precision loss
// and to ensure atomicAdd works correctly on all architectures.
auto sink_val_tensor = make_static_distributed_tensor<float>(d_dstr);
tile_elementwise_inout(
[&](auto& s_out, const auto& l_in, const auto& d_in) {
float p_sink = exp2(type_convert<float>(sink_value) -
log2e_v<float> * type_convert<float>(l_in));
s_out = -p_sink * type_convert<float>(d_in);
},
sink_val_tensor,
lse_,
d);
// Reduce contributions held by this thread
float thread_sum = 0.f;
constexpr auto s_spans = decltype(sink_val_tensor)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
thread_sum += sink_val_tensor(i_idx);
});
// Warp-level reduction: fold thread_sum across lanes so only one
// atomicAdd per warp is issued instead of one per thread.
#if defined(__HIP_DEVICE_COMPILE__) || defined(__CUDA_ARCH__)
const index_t warp_sz = get_warp_size();
for(index_t offset = warp_sz >> 1; offset > 0; offset >>= 1)
thread_sum += warp_shuffle_down(thread_sum, offset);
// Only lane 0 of each warp writes to global memory.
// Note: this atomicAdd is non-deterministic across runs regardless of the
// -deterministic flag, because d_sink is a single scalar per head accumulated
// across all thread-blocks. The practical impact is negligible for this value.
if(get_lane_id() == 0)
atomicAdd(reinterpret_cast<float*>(atomic_sink_grad_ptr), thread_sum);
#endif
}
}
};

View File

@@ -163,7 +163,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
amd_wave_read_first_lane(integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0));
// check early exit if no work to do.
if(num_total_loop <= 0)
// __builtin_expect is load-bearing: omitting it causes incorrect AGPR allocation in
// the dK/dV accumulation loop on some compiler versions, leading to wrong results.
if(__builtin_expect(num_total_loop <= 0, 0))
{
// Note: here dk_acc&dv_acc are all cleared, return it
return make_tuple(dk_acc, dv_acc);

View File

@@ -67,6 +67,7 @@ struct BlockFmhaBwdPipelineProblem
template <typename ODataType_,
typename OGradDataType_,
typename DDataType_,
typename LSEDataType_,
index_t kBlockSize_,
index_t kVHeaddim_,
bool kIsGroupMode_,
@@ -76,6 +77,7 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using OGradDataType = remove_cvref_t<OGradDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using Traits = remove_cvref_t<Traits_>;
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,

View File

@@ -1226,23 +1226,37 @@ struct UniversalGemmKernel
s_waitcnt_barrier();
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
// Apply pivot to M tile index first, then use the same pivoted index
// for both data-tile selection and chunk-signal wait.
auto iM_eff = amd_wave_read_first_lane(iM);
if(kargs.async_input_scheduler.chunk_signals != nullptr)
{
const auto tile_idx_pivot =
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
const auto tiles_m = amd_wave_read_first_lane(
integer_divide_ceil(kargs.M, TilePartitioner::MPerBlock));
if(tiles_m > 0)
{
iM_eff = amd_wave_read_first_lane((iM_eff + tile_idx_pivot) % tiles_m);
}
}
const index_t i_m = amd_wave_read_first_lane(iM_eff * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
// Synchronize with producer to ensure input data is ready before processing tile
if(kargs.async_input_scheduler.chunk_signals != nullptr)
{
const auto tiles_per_chunk =
amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
const auto tile_idx_pivot =
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
const auto num_chunks =
amd_wave_read_first_lane(kargs.async_input_scheduler.num_chunks);
if(tiles_per_chunk > 0 && num_chunks > 0)
{
// Pivot allows rotating chunk assignments for load balancing
const auto chunk_idx = amd_wave_read_first_lane(
((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
const auto chunk_idx =
amd_wave_read_first_lane((iM_eff / tiles_per_chunk) % num_chunks);
workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
}

View File

@@ -903,13 +903,11 @@ struct GroupedConvolutionBackwardDataKernel
const auto& d_block_window =
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
a_block_window, b_block_window, num_loop, smem_ptr_0);
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);

View File

@@ -65,7 +65,9 @@ run_grouped_conv_backward_data_tile_algs(const ckt::Args<SIGNATURE>& args,
const ckt::Outputs<SIGNATURE>& outputs,
const ck_tile::stream_config& s_conf)
{
float best_avg_time = std::numeric_limits<float>::max();
// Run first instance as dummy to get proper time from the first instance
bool dummy_run_executed = false;
float best_avg_time = std::numeric_limits<float>::max();
std::string best_op_name, op_name;
int best_split_k = 0;
ck::index_t best_instance_index = -1;
@@ -121,6 +123,13 @@ run_grouped_conv_backward_data_tile_algs(const ckt::Args<SIGNATURE>& args,
run_alg_func(args_k_batch, inputs, outputs, s_conf);
if(is_supported)
{
if((s_conf.time_kernel_ || s_conf.flush_cache_) && !dummy_run_executed)
{
// Run first instance twice
std::tie(is_supported, avg_time, op_name) =
run_alg_func(args_k_batch, inputs, outputs, s_conf);
dummy_run_executed = true;
}
ckt::ValidationReport report;
auto&& [rtol, atol] =
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);

View File

@@ -106,17 +106,17 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
{
ckt::Args<SIGNATURE> args_k_batch = args;
args_k_batch.k_batch = k_batch;
if((s_conf.time_kernel_ || s_conf.flush_cache_) && !dummy_run_executed)
{
// Run first instance twice when profiling to stabilize timing
std::tie(is_supported, avg_time, op_name) =
run_alg_func(args_k_batch, inputs, outputs, s_conf);
dummy_run_executed = true;
}
std::tie(is_supported, avg_time, op_name) =
run_alg_func(args_k_batch, inputs, outputs, s_conf);
if(is_supported)
{
if((s_conf.time_kernel_ || s_conf.flush_cache_) && !dummy_run_executed)
{
// Run first instance twice when profiling to stabilize timing
std::tie(is_supported, avg_time, op_name) =
run_alg_func(args_k_batch, inputs, outputs, s_conf);
dummy_run_executed = true;
}
ckt::ValidationReport report;
auto&& [rtol, atol] =
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);

View File

@@ -86,15 +86,16 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
auto ref_conv = ReferenceInstance{};
auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
auto run_alg = [&](auto&& run_alg_func) {
if(!dummy_run_executed)
{
// Run first instance twice
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
dummy_run_executed = true;
}
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
if(is_supported)
{
if((s_conf.time_kernel_ || s_conf.flush_cache_) && !dummy_run_executed)
{
// Run first instance twice
std::tie(is_supported, avg_time, op_name) =
run_alg_func(args, inputs, outputs, s_conf);
dummy_run_executed = true;
}
best_avg_time = std::min(best_avg_time, avg_time);
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name

View File

@@ -197,10 +197,11 @@ bool profile_conv_bwd_data_impl(int do_verification,
}
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
bool dummy_run_executed = false;
for(auto& op_ptr : op_ptrs)
{
@@ -230,16 +231,38 @@ bool profile_conv_bwd_data_impl(int do_verification,
// skip test if instance_index is specified
continue;
}
// for conv bwd data, some input tensor element are zero, but not written by kernel,
// need to set zero
in_device_buf.SetZero();
if(!time_kernel)
{
// Don't clear for perf measurement.
// For non-grouped solver user has to clear input on his own.
// for conv bwd data, some input tensor element are zero, but not written by kernel,
// need to set zero
in_device_buf.SetZero();
}
std::string op_name = op_ptr->GetTypeString();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// Run first instance twice to get proper time
if(time_kernel && !dummy_run_executed)
{
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0 /*log_level*/,
5 /*cold_iters*/,
50 /*nrepeat_*/,
time_kernel /*flush_cache*/});
dummy_run_executed = true;
}
float avg_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0 /*log_level*/,
5 /*cold_iters*/,
50 /*nrepeat_*/,
time_kernel /*flush_cache*/});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();

View File

@@ -287,6 +287,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
bool pass = true;
index_t num_kernel = 0;
index_t valid_instances = 0;
bool dummy_run_executed = false;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
// workspace_sz will be equal to 0 for other layout than NGCHW
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
@@ -317,8 +319,25 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
auto invoker_ptr = op_ptr->MakeInvokerPointer();
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// Run first instance twice to get proper time
if(time_kernel && !dummy_run_executed)
{
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0 /*log_level*/,
5 /*cold_iters*/,
50 /*nrepeat_*/,
time_kernel /*flush_cache*/});
dummy_run_executed = true;
}
float avg_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0 /*log_level*/,
5 /*cold_iters*/,
50 /*nrepeat_*/,
time_kernel /*flush_cache*/});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
@@ -495,7 +514,6 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
{
std::cout << "\nValid instances for this problem:" << std::endl;
}
for(auto& op_ptr : op_ptrs)
{
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)

View File

@@ -296,7 +296,8 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
index_t best_instance_index = 0;
// profile device op instances
bool pass = true;
bool pass = true;
bool dummy_run_executed = false;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
// workspace_sz will be equal to 0 for other layout than NGCHW
@@ -331,6 +332,19 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
auto invoker_ptr = op_ptr->MakeInvokerPointer();
// Run first instance twice to get proper time
if(time_kernel && !dummy_run_executed)
{
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0 /*log_level*/,
5 /*cold_iters*/,
50 /*nrepeat_*/,
time_kernel /*flush_cache*/});
dummy_run_executed = true;
}
float avg_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
@@ -437,30 +451,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
std::cout << "\nValid instances for this problem:" << std::endl;
}
// Run first instance twice to get proper time
{
auto argument_ptr = op_ptrs[0]->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
{},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
{},
{},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
run_impl(op_ptrs[0], argument_ptr);
}
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),

View File

@@ -89,7 +89,8 @@ int call_profiler(const ckt::Args<SIGNATURE>& args,
0 /*log_level*/,
5 /*cold_iters*/,
50 /*nrepeat_*/,
true /*is_gpu_timer_*/});
true /*is_gpu_timer_*/,
time_kernel /*flush_cache*/});
if(time_kernel)
{
std::cout << "\nBest configuration parameters:" << "\n\tname: " << op_name << " (instance "

View File

@@ -69,4 +69,4 @@ add_subdirectory(fmha)
add_subdirectory(gemm_tile_engine)
add_subdirectory(pooling)
add_subdirectory(grouped_conv)
add_subdirectory(gemm_streamk_tile_engine)
add_subdirectory(pooling_tile_engine)

View File

@@ -355,6 +355,102 @@ TEST(SequenceSort, SortSingleElement)
EXPECT_TRUE((std::is_same<Result, Expected>::value));
}
// Verify that sequence_sort exposes sorted2unsorted_map, which maps each
// sorted position back to the index the value occupied in the unsorted input.
TEST(SequenceSort, SortedMapUnsorted)
{
    using Input  = sequence<5, 2, 8, 1, 9>;
    using SortOp = sequence_sort<Input, less<index_t>>;
    // Sorting yields <1,2,5,8,9>; those values lived at indices <3,1,0,2,4>.
    using WantedMap = sequence<3, 1, 0, 2, 4>;
    EXPECT_TRUE((std::is_same_v<typename SortOp::sorted2unsorted_map, WantedMap>));
}
TEST(SequenceSort, SortedMapAlreadySorted)
{
    // An already-ordered input must produce the identity permutation.
    using Input    = sequence<1, 2, 3, 4, 5>;
    using SortOp   = sequence_sort<Input, less<index_t>>;
    using Identity = sequence<0, 1, 2, 3, 4>;
    EXPECT_TRUE((std::is_same_v<typename SortOp::sorted2unsorted_map, Identity>));
}
TEST(SequenceSort, SortedMapRoundTrip)
{
    // Invariant under test: original[sorted2unsorted_map[i]] == sorted[i] for every i.
    using Input  = sequence<5, 2, 8, 1, 9>;
    using SortOp = sequence_sort<Input, less<index_t>>;
    using Sorted = typename SortOp::type;                // <1,2,5,8,9>
    using Map    = typename SortOp::sorted2unsorted_map; // <3,1,0,2,4>
    EXPECT_EQ(Input::at(Map::at(0)), Sorted::at(0));
    EXPECT_EQ(Input::at(Map::at(1)), Sorted::at(1));
    EXPECT_EQ(Input::at(Map::at(2)), Sorted::at(2));
    EXPECT_EQ(Input::at(Map::at(3)), Sorted::at(3));
    EXPECT_EQ(Input::at(Map::at(4)), Sorted::at(4));
}
TEST(SequenceSort, SortedMapWithDuplicates)
{
    using Input  = sequence<3, 1, 3, 1>;
    using SortOp = sequence_sort<Input, less<index_t>>;
    using Sorted = typename SortOp::type;
    using Map    = typename SortOp::sorted2unsorted_map;
    // Duplicates still sort into <1,1,3,3>.
    EXPECT_TRUE((std::is_same_v<Sorted, sequence<1, 1, 3, 3>>));
    // For equal keys the index order is unspecified (sort stability may vary),
    // so only assert the round-trip property original[map[i]] == sorted[i].
    EXPECT_EQ(Input::at(Map::at(0)), Sorted::at(0));
    EXPECT_EQ(Input::at(Map::at(1)), Sorted::at(1));
    EXPECT_EQ(Input::at(Map::at(2)), Sorted::at(2));
    EXPECT_EQ(Input::at(Map::at(3)), Sorted::at(3));
}
TEST(SequenceSort, SortedMapReverseSorted)
{
    // A strictly descending input reverses completely, so the index map is the
    // reversed identity permutation.
    using Input  = sequence<5, 4, 3, 2, 1>;
    using SortOp = sequence_sort<Input, less<index_t>>;
    EXPECT_TRUE((std::is_same_v<typename SortOp::type, sequence<1, 2, 3, 4, 5>>));
    EXPECT_TRUE(
        (std::is_same_v<typename SortOp::sorted2unsorted_map, sequence<4, 3, 2, 1, 0>>));
}
TEST(SequenceSort, SortedMapEmpty)
{
    // Degenerate case: an empty sequence yields an empty permutation map.
    using SortOp = sequence_sort<sequence<>, less<index_t>>;
    EXPECT_TRUE((std::is_same_v<typename SortOp::sorted2unsorted_map, sequence<>>));
}
TEST(SequenceSort, SortedMapSingleElement)
{
    // A one-element sequence is trivially sorted; its map is <0>.
    using SortOp = sequence_sort<sequence<42>, less<index_t>>;
    EXPECT_TRUE((std::is_same_v<typename SortOp::sorted2unsorted_map, sequence<0>>));
}
// Verify sequence_unique_sort's sorted2unsorted_map: each entry must point at
// an occurrence of the corresponding unique value in the original sequence.
TEST(SequenceUniqueSort, UniqueSortMap)
{
    using Input  = sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>;
    using SortOp = sequence_unique_sort<Input, less<index_t>, equal<index_t>>;
    using Values = typename SortOp::type; // sorted unique = <1,2,3,4,5,6,9>
    using Map    = typename SortOp::sorted2unsorted_map;
    // Round-trip property: original[map[i]] == sorted_unique[i] for every i.
    EXPECT_EQ(Input::at(Map::at(0)), Values::at(0)); // 1
    EXPECT_EQ(Input::at(Map::at(1)), Values::at(1)); // 2
    EXPECT_EQ(Input::at(Map::at(2)), Values::at(2)); // 3
    EXPECT_EQ(Input::at(Map::at(3)), Values::at(3)); // 4
    EXPECT_EQ(Input::at(Map::at(4)), Values::at(4)); // 5
    EXPECT_EQ(Input::at(Map::at(5)), Values::at(5)); // 6
    EXPECT_EQ(Input::at(Map::at(6)), Values::at(6)); // 9
}
// Test sequence_unique_sort
TEST(SequenceUniqueSort, UniqueSort)
{
@@ -405,6 +501,24 @@ TEST(SequenceMap, InvalidMapMissing)
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
}
TEST(SequenceMap, InvalidMapNegative)
{
    // A permutation may not contain negative indices.
    EXPECT_FALSE((is_valid_sequence_map<sequence<0, -1, 2>>::value));
}
TEST(SequenceMap, ValidMapSingleElement)
{
    // <0> is the only valid one-element permutation.
    EXPECT_TRUE((is_valid_sequence_map<sequence<0>>::value));
}
TEST(SequenceMap, InvalidMapSingleElement)
{
    // <1> skips index 0, so it is not a permutation of {0}.
    EXPECT_FALSE((is_valid_sequence_map<sequence<1>>::value));
}
// The empty map is vacuously a valid permutation.
TEST(SequenceMap, ValidMapEmpty) { EXPECT_TRUE((is_valid_sequence_map<sequence<>>::value)); }
// Test sequence_map_inverse
// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new
// position j. The inverse tells us that new position i came from old position inverse[i].

View File

@@ -91,7 +91,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param)
drop_offset,
drop_prefs,
mask_str,
det, // deterministic
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -333,7 +334,8 @@ TEST_P(BasicQPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -419,7 +421,8 @@ TEST_P(BasicKVPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -513,7 +516,8 @@ TEST_P(QKVPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -620,7 +624,8 @@ TEST_P(ZeroLengthPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -741,7 +746,8 @@ TEST_P(VariedPaddingRatios, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -843,7 +849,8 @@ TEST_P(PaddingWithMask, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -977,7 +984,8 @@ TEST_P(MultiBatchPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,

View File

@@ -19,55 +19,93 @@ set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
set(STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
test_gemm_streamk_util.cpp)
# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
if(GPU_TARGETS MATCHES "gfx942|gfx950")
list(APPEND STREAMK_EXTENDED_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
endif()
# ---- Code-generate test .cpp files from types header ----
set(STREAMK_TYPES_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_types.hpp)
set(STREAMK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/generate_test_files.py)
add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES})
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# Re-run configure automatically if the types header changes (e.g. types added/removed)
# or if the generation script changes
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT})
# Define the targets and their corresponding executable names
set(STREAMK_GEN_TARGETS extended atomic_smoke linear_smoke tree_smoke pipelines_smoke)
set(STREAMK_GEN_EXEC_EXTENDED test_ck_tile_streamk_extended)
set(STREAMK_GEN_EXEC_ATOMIC_SMOKE test_ck_tile_streamk_atomic_smoke)
set(STREAMK_GEN_EXEC_LINEAR_SMOKE test_ck_tile_streamk_linear_smoke)
set(STREAMK_GEN_EXEC_TREE_SMOKE test_ck_tile_streamk_tree_smoke)
set(STREAMK_GEN_EXEC_PIPELINES_SMOKE test_ck_tile_streamk_pipelines_smoke)
# Collect all test targets for umbrella label
set(CK_TILE_GEMM_STREAMK_TEST_TARGETS
test_ck_tile_streamk_tile_partitioner
test_ck_tile_streamk_extended
test_ck_tile_streamk_tile_partitioner)
foreach(target IN LISTS STREAMK_GEN_TARGETS)
string(TOUPPER ${target} TARGET_UPPER)
set(GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/${target})
set(EXEC_NAME ${STREAMK_GEN_EXEC_${TARGET_UPPER}})
set(LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/${target}_files.txt)
# Phase 1 (configure time): discover the list of files that will be generated
execute_process(
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
--types_header ${STREAMK_TYPES_HEADER}
--output_dir ${GEN_DIR}
--target ${target}
--list_files ${LIST_FILE}
RESULT_VARIABLE ret
ERROR_VARIABLE list_files_stderr)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR
"Failed to list ${target} test files via Python: ${ret}\n"
"stderr: ${list_files_stderr}"
)
endif()
file(STRINGS ${LIST_FILE} ALL_SOURCES_${target})
# Phase 2 (build time): generate the .cpp files when the types header changes
add_custom_command(
OUTPUT ${ALL_SOURCES_${target}}
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
--types_header ${STREAMK_TYPES_HEADER}
--output_dir ${GEN_DIR}
--target ${target}
--gen_files
DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT}
COMMENT "Generating StreamK ${target} test sources from types header")
# Filter out fp8/bf8 sources on gfx90a since those types are not natively supported
set(FILTERED_SOURCES)
foreach(src IN LISTS ALL_SOURCES_${target})
if(NOT src MATCHES "_(fp8|bf8)_" OR GPU_TARGETS MATCHES "gfx942|gfx950")
list(APPEND FILTERED_SOURCES ${src})
endif()
endforeach()
list(APPEND FILTERED_SOURCES test_gemm_streamk_util.cpp)
add_gtest_executable(${EXEC_NAME} ${FILTERED_SOURCES})
target_compile_options(${EXEC_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
list(APPEND CK_TILE_GEMM_STREAMK_TEST_TARGETS ${EXEC_NAME})
endforeach()
# Add python unit tests to validate the code gen logic in generate_test_files.py
add_test(
NAME test_ck_tile_streamk_generate_test_files
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_generate_test_files.py -v
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..
)
# Label all ck_tile gemm_streamk tests with CK_TILE_GEMM_STREAMK_TESTS for selective execution
foreach(test_target ${CK_TILE_GEMM_STREAMK_TEST_TARGETS})
set_tests_properties(${test_target} PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
endforeach()
# Also label the Python test
set_tests_properties(test_ck_tile_streamk_generate_test_files PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
# Umbrella target to build and run all ck_tile gemm_streamk tests
# Usage: ninja ck_tile_gemm_streamk_tests

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3,
KernelTypesStreamKBf16NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,18 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
KernelTypesStreamKBf16NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
// Typed-test fixture for the Stream-K GEMM variant named bf8 / non-persistent
// / CompV4. Intentionally empty: all shared logic lives in
// TestCkTileStreamK<Tuple>; the derived type only supplies a distinct suite
// name.
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is consumed by the shared .inc included below, which stamps
// out the extended Stream-K test cases for this suite; #undef'd afterwards so
// sibling files can reuse the macro name.
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
// Typed-test fixture for the Stream-K GEMM variant named bf8 / non-persistent
// / Mem pipeline. Intentionally empty: all shared logic lives in
// TestCkTileStreamK<Tuple>; the derived type only supplies a distinct suite
// name.
class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is consumed by the shared .inc included below, which stamps
// out the extended Stream-K test cases for this suite; #undef'd afterwards so
// sibling files can reuse the macro name.
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
// Typed-test fixture for the Stream-K GEMM variant named bf8 / persistent /
// CompV3. Intentionally empty: all shared logic lives in
// TestCkTileStreamK<Tuple>; the derived type only supplies a distinct suite
// name.
class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is consumed by the shared .inc included below, which stamps
// out the extended Stream-K test cases for this suite; #undef'd afterwards so
// sibling files can reuse the macro name.
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
// Typed-test fixture for the Stream-K GEMM variant named bf8 / persistent /
// CompV4. Intentionally empty: all shared logic lives in
// TestCkTileStreamK<Tuple>; the derived type only supplies a distinct suite
// name.
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is consumed by the shared .inc included below, which stamps
// out the extended Stream-K test cases for this suite; #undef'd afterwards so
// sibling files can reuse the macro name.
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
template <typename Tuple>
// Typed-test fixture for the Stream-K GEMM variant named bf8 / persistent /
// Mem pipeline. Intentionally empty: all shared logic lives in
// TestCkTileStreamK<Tuple>; the derived type only supplies a distinct suite
// name.
class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is consumed by the shared .inc included below, which stamps
// out the extended Stream-K test cases for this suite; #undef'd afterwards so
// sibling files can reuse the macro name.
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem);
#include "test_gemm_streamk_extended_cases.inc"
#undef TEST_SUITE_NAME

Some files were not shown because too many files have changed in this diff. [Show More]