mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Merge origin/develop into users/yiding12/fmha-bwd-workspace
This commit is contained in:
25
.gitignore
vendored
25
.gitignore
vendored
@@ -114,3 +114,28 @@ experimental/grouped_convolution_tile_instances/instances/*
|
||||
!experimental/grouped_convolution_tile_instances/instances/*.inc
|
||||
!experimental/grouped_convolution_tile_instances/instances/*.hpp
|
||||
experimental/grouped_convolution_tile_instances/*.inc
|
||||
# Heuristics: benchmark data (never in git)
|
||||
dispatcher/heuristics/data/
|
||||
|
||||
# Heuristics: experimental/training artifacts (exclude from git)
|
||||
dispatcher/heuristics/models/**/oof_predictions.parquet
|
||||
dispatcher/heuristics/models/**/cv_metrics_*.json
|
||||
dispatcher/heuristics/models/**/eval_report.json
|
||||
dispatcher/heuristics/models/**/feature_importances_*.json
|
||||
dispatcher/heuristics/models/**/model_tflops_ihem.lgbm
|
||||
dispatcher/heuristics/models/**/model_tflops_log.lgbm
|
||||
dispatcher/heuristics/models/**/model_tflops_log_big.lgbm
|
||||
|
||||
# Heuristics: keep in git (production model files):
|
||||
# models/{op}_{dtype}_{arch}/model_tflops.lgbm
|
||||
# models/{op}_{dtype}_{arch}/model_latency.lgbm
|
||||
# models/{op}_{dtype}_{arch}/model_bandwidth.lgbm
|
||||
# models/{op}_{dtype}_{arch}/feature_spec.json
|
||||
# models/{op}_{dtype}_{arch}/train_manifest.json
|
||||
|
||||
# Heuristics: logs and caches
|
||||
dispatcher/heuristics/*.log
|
||||
dispatcher/heuristics/__pycache__/
|
||||
dispatcher/heuristics/tests/__pycache__/
|
||||
dispatcher/heuristics/.pytest_cache/
|
||||
|
||||
|
||||
51
Dockerfile
51
Dockerfile
@@ -2,22 +2,33 @@
|
||||
FROM ubuntu:24.04
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG ROCMVERSION=7.1.1
|
||||
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
|
||||
ARG TARBALL_URL=https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx90X-dcgpu-7.12.0a20260218.tar.gz
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
ARG CK_SCCACHE=""
|
||||
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
|
||||
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=$PATH:/opt/rocm/bin
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib
|
||||
ENV HIP_PLATFORM=amd
|
||||
|
||||
# Add rocm repository
|
||||
RUN set -xe && \
|
||||
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl
|
||||
|
||||
RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
|
||||
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
|
||||
apt update && \
|
||||
apt install python3-setuptools python3-wheel -y && \
|
||||
apt install rocm-dev -y
|
||||
RUN if [ "$compiler_version" = "therock" ]; then \
|
||||
rm -rf /opt/rocm && mkdir /opt/rocm && \
|
||||
echo "Downloading ROCm tarball from $TARBALL_URL..." && \
|
||||
wget -q -O /tmp/rocm.tar.gz "$TARBALL_URL" && \
|
||||
echo "Extracting tarball to /opt/rocm..." && \
|
||||
tar -xzf /tmp/rocm.tar.gz -C /opt/rocm --strip-components=1 ; \
|
||||
else echo "using the release compiler" && \
|
||||
wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
|
||||
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
|
||||
apt update && \
|
||||
apt install python3-setuptools python3-wheel -y && \
|
||||
apt install rocm-dev -y; \
|
||||
fi
|
||||
|
||||
# Install SCCACHE
|
||||
ENV SCCACHE_VERSION="0.14.0"
|
||||
@@ -34,7 +45,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
build-essential \
|
||||
cmake \
|
||||
git \
|
||||
hip-rocclr \
|
||||
iputils-ping \
|
||||
jq \
|
||||
libelf-dev \
|
||||
@@ -44,8 +54,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
net-tools \
|
||||
pkg-config \
|
||||
python3-full \
|
||||
python3-pip \
|
||||
redis \
|
||||
rocm-llvm-dev \
|
||||
sshpass \
|
||||
stunnel \
|
||||
software-properties-common \
|
||||
@@ -88,26 +98,3 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
git clone -b master https://github.com/ROCm/rocm-cmake.git && \
|
||||
cd rocm-cmake && mkdir build && cd build && \
|
||||
cmake .. && cmake --build . && cmake --build . --target install
|
||||
|
||||
WORKDIR /
|
||||
# Add alternative compilers, if necessary
|
||||
ENV compiler_version=$compiler_version
|
||||
ENV compiler_commit=$compiler_commit
|
||||
RUN sh -c "echo compiler version = '$compiler_version'" && \
|
||||
sh -c "echo compiler commit = '$compiler_commit'"
|
||||
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
|
||||
make -j 8 ; \
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
|
||||
make -j 8 ; \
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
@@ -9,7 +9,7 @@ ENV compiler_commit=$compiler_commit
|
||||
RUN sh -c "echo compiler version = '$compiler_version'" && \
|
||||
sh -c "echo compiler commit = '$compiler_commit'"
|
||||
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" = "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && git log -1 && mkdir build && cd build && \
|
||||
cmake -G Ninja \
|
||||
@@ -43,7 +43,7 @@ RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-sta
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" != "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
|
||||
cmake -G Ninja \
|
||||
|
||||
@@ -3,7 +3,6 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG ROCMVERSION=7.2
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
ARG CK_SCCACHE=""
|
||||
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
|
||||
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
@@ -19,16 +18,15 @@ RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2
|
||||
dnf install python3-setuptools python3-wheel -y && \
|
||||
dnf install rocm-dev -y
|
||||
|
||||
## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined
|
||||
ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache
|
||||
# Install SCCACHE
|
||||
ENV SCCACHE_VERSION="0.14.0"
|
||||
ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
|
||||
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
|
||||
ENV CK_SCCACHE=$CK_SCCACHE
|
||||
RUN if [ "$CK_SCCACHE" != "" ]; then \
|
||||
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
|
||||
curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \
|
||||
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \
|
||||
fi
|
||||
RUN set -x && \
|
||||
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
|
||||
wget -qO sccache.tar.gz https://github.com/mozilla/sccache/releases/latest/download/sccache-v$SCCACHE_VERSION-x86_64-unknown-linux-musl.tar.gz && \
|
||||
tar -xzf sccache.tar.gz --strip-components=1 -C ${SCCACHE_INSTALL_LOCATION} && \
|
||||
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache
|
||||
|
||||
# Install dependencies
|
||||
RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \
|
||||
@@ -83,19 +81,71 @@ ENV compiler_commit=$compiler_commit
|
||||
RUN sh -c "echo compiler version = '$compiler_version'" && \
|
||||
sh -c "echo compiler commit = '$compiler_commit'"
|
||||
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" = "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
|
||||
make -j 8 ; \
|
||||
cmake -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DLLVM_ENABLE_PROJECTS="clang;lld;clang-tools-extra;flang;mlir" \
|
||||
-DLLVM_LIT_ARGS="-vv --show-unsupported --show-xfail -j 32" \
|
||||
-DPACKAGE_VENDOR="AMD" \
|
||||
-DCMAKE_INSTALL_PREFIX=/home/$USER/rocm/pure_llvm_1.0 \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DLLVM_BUILD_DOCS=ON \
|
||||
-DLLVM_TARGETS_TO_BUILD=all \
|
||||
-DLIBOMPTARGET_ENABLE_DEBUG=ON \
|
||||
-DOFFLOAD_ENABLE_EMISSARY_APIS=OFF \
|
||||
-DCLANG_DEFAULT_LINKER=lld \
|
||||
-DCLANG_DEFAULT_PIE_ON_LINUX=0 \
|
||||
-DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi;openmp;compiler-rt;libunwind;flang-rt" \
|
||||
-DLIBCXX_ENABLE_SHARED=OFF \
|
||||
-DLIBCXX_ENABLE_STATIC=ON \
|
||||
-DLIBCXX_INSTALL_LIBRARY=OFF \
|
||||
-DLIBCXX_INSTALL_HEADERS=OFF \
|
||||
-DLIBCXXABI_ENABLE_SHARED=OFF \
|
||||
-DLIBCXXABI_ENABLE_STATIC=ON \
|
||||
-DLIBCXXABI_INSTALL_STATIC_LIBRARY=OFF \
|
||||
-DLLVM_ENABLE_ASSERTIONS=1 \
|
||||
-DLLVM_ENABLE_Z3_SOLVER=OFF \
|
||||
-DLLVM_ENABLE_ZLIB=ON \
|
||||
-DLLVM_LINK_LLVM_DYLIB=OFF \
|
||||
-DCLANG_LINK_CLANG_DYLIB=OFF \
|
||||
../llvm && \
|
||||
ninja -j16 ; \
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
|
||||
RUN if ( [ "$compiler_version" = "develop" ] || [ "$compiler_version" = "amd-staging" ] ) && [ "$compiler_commit" != "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
|
||||
make -j 8 ; \
|
||||
cmake -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DLLVM_ENABLE_PROJECTS="clang;lld;clang-tools-extra;flang;mlir" \
|
||||
-DLLVM_LIT_ARGS="-vv --show-unsupported --show-xfail -j 32" \
|
||||
-DPACKAGE_VENDOR="AMD" \
|
||||
-DCMAKE_INSTALL_PREFIX=/home/$USER/rocm/pure_llvm_1.0 \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DLLVM_BUILD_DOCS=ON \
|
||||
-DLLVM_TARGETS_TO_BUILD=all \
|
||||
-DLIBOMPTARGET_ENABLE_DEBUG=ON \
|
||||
-DOFFLOAD_ENABLE_EMISSARY_APIS=OFF \
|
||||
-DCLANG_DEFAULT_LINKER=lld \
|
||||
-DCLANG_DEFAULT_PIE_ON_LINUX=0 \
|
||||
-DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi;openmp;compiler-rt;libunwind;flang-rt" \
|
||||
-DLIBCXX_ENABLE_SHARED=OFF \
|
||||
-DLIBCXX_ENABLE_STATIC=ON \
|
||||
-DLIBCXX_INSTALL_LIBRARY=OFF \
|
||||
-DLIBCXX_INSTALL_HEADERS=OFF \
|
||||
-DLIBCXXABI_ENABLE_SHARED=OFF \
|
||||
-DLIBCXXABI_ENABLE_STATIC=ON \
|
||||
-DLIBCXXABI_INSTALL_STATIC_LIBRARY=OFF \
|
||||
-DLLVM_ENABLE_ASSERTIONS=1 \
|
||||
-DLLVM_ENABLE_Z3_SOLVER=OFF \
|
||||
-DLLVM_ENABLE_ZLIB=ON \
|
||||
-DLLVM_LINK_LLVM_DYLIB=OFF \
|
||||
-DCLANG_LINK_CLANG_DYLIB=OFF \
|
||||
../llvm && \
|
||||
ninja -j16 ; \
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
|
||||
22
Jenkinsfile
vendored
22
Jenkinsfile
vendored
@@ -421,16 +421,19 @@ def buildDocker(install_prefix){
|
||||
def image_name = getDockerImageName()
|
||||
def base_image_name = getBaseDockerImageName()
|
||||
echo "Building Docker for ${image_name}"
|
||||
def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if(params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
|
||||
def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if(params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_COMMIT != ""){
|
||||
dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f projects/composablekernel/Dockerfile.compiler . "
|
||||
}
|
||||
else if(params.COMPILER_VERSION == "therock"){
|
||||
dockerArgs = dockerArgs + " --no-cache -f projects/composablekernel/Dockerfile . "
|
||||
}
|
||||
else if(params.RUN_AITER_TESTS){
|
||||
image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter"
|
||||
dockerArgs = dockerArgs + " --no-cache -f projects/composablekernel/Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
|
||||
}
|
||||
else if(params.RUN_PYTORCH_TESTS){
|
||||
image_name = "${env.CK_DOCKERHUB}:ck_pytorch"
|
||||
image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_pytorch"
|
||||
dockerArgs = dockerArgs + " --no-cache -f projects/composablekernel/Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . "
|
||||
}
|
||||
else{
|
||||
@@ -470,10 +473,10 @@ def get_docker_options(){
|
||||
else{ //only add kfd and dri paths if you actually going to run somthing on GPUs
|
||||
dockerOpts = "--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
}
|
||||
if (params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
|
||||
if (params.COMPILER_VERSION == "develop" || params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "therock" || params.COMPILER_COMMIT != ""){
|
||||
// the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with
|
||||
// newer clang22 compilers and running with older hip runtima libraries
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 "
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 --env HIP_PLATFORM=amd "
|
||||
}
|
||||
// on some machines the group ids for video and render groups may not be the same as in the docker image!
|
||||
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
|
||||
@@ -1140,7 +1143,7 @@ def run_pytorch_tests(Map conf=[:]){
|
||||
show_node_info()
|
||||
checkoutComposableKernel()
|
||||
//use the latest pytorch-nightly image
|
||||
def image = "${env.CK_DOCKERHUB}:ck_pytorch"
|
||||
def image = "${env.CK_DOCKERHUB_PRIVATE}:ck_pytorch"
|
||||
def dockerOpts=get_docker_options() + ' --group-add irc '
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
|
||||
@@ -1183,16 +1186,19 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_
|
||||
0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_BASIC_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=false;BUILD_GFX908=false;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=develop;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=therock;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 15 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
|
||||
0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true
|
||||
0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : ""
|
||||
|
||||
POLL_SPEC = BRANCH_NAME == "develop" ? 'H H/6 * * *' : ''
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
triggers {
|
||||
parameterizedCron(CRON_SETTINGS)
|
||||
pollSCM(POLL_SPEC)
|
||||
}
|
||||
options {
|
||||
skipDefaultCheckout()
|
||||
@@ -1214,7 +1220,7 @@ pipeline {
|
||||
string(
|
||||
name: 'COMPILER_VERSION',
|
||||
defaultValue: '',
|
||||
description: 'Specify which version of compiler to use: release, develop, amd-staging, amd-mainline, or leave blank (default).')
|
||||
description: 'Specify which version of compiler to use: develop, amd-staging, therock, or leave blank (default).')
|
||||
string(
|
||||
name: 'COMPILER_COMMIT',
|
||||
defaultValue: '',
|
||||
|
||||
@@ -154,6 +154,8 @@ rocminfo | grep -i "gfx"
|
||||
|
||||
### Install Python Dependencies
|
||||
|
||||
#### Core Dependencies (Required)
|
||||
|
||||
NumPy is required for Python examples and kernel generation. We recommend using a virtual environment:
|
||||
|
||||
**Option 1: Using standard venv**
|
||||
@@ -165,8 +167,8 @@ python3 -m venv .venv
|
||||
source .venv/bin/activate # Linux/macOS
|
||||
# .venv\Scripts\activate # Windows
|
||||
|
||||
# Install NumPy
|
||||
pip install numpy
|
||||
# Install core dependencies
|
||||
pip install -r python/requirements.txt
|
||||
```
|
||||
|
||||
**Option 2: Using uv (faster alternative)**
|
||||
@@ -179,17 +181,38 @@ uv venv .venv
|
||||
source .venv/bin/activate # Linux/macOS
|
||||
# .venv\Scripts\activate # Windows
|
||||
|
||||
# Install NumPy
|
||||
uv pip install numpy
|
||||
# Install core dependencies
|
||||
uv pip install -r python/requirements.txt
|
||||
```
|
||||
|
||||
**Option 3: System-wide install (not recommended)**
|
||||
```bash
|
||||
pip install numpy
|
||||
pip install -r python/requirements.txt
|
||||
```
|
||||
|
||||
> **Note:** Always activate your virtual environment before running CMake or Python examples.
|
||||
|
||||
#### ML Heuristics Dependencies (Optional)
|
||||
|
||||
For ML-based kernel selection (examples 09-11), install additional dependencies:
|
||||
|
||||
```bash
|
||||
# Activate your virtual environment first
|
||||
source .venv/bin/activate
|
||||
|
||||
# Install ML dependencies (LightGBM, pandas, pyarrow, scikit-learn)
|
||||
pip install -r requirements-ml.txt
|
||||
```
|
||||
|
||||
**Why separate?** ML dependencies are large (especially pyarrow) and not needed for basic dispatcher usage. Install only if you need:
|
||||
- ML-based kernel selection (`examples/gemm/python/09_ml_heuristic.py`)
|
||||
- Model training (`heuristics/train.py`)
|
||||
- Model evaluation (`heuristics/evaluate.py`)
|
||||
- Automated benchmark analysis
|
||||
|
||||
**Core dependencies:** ~50 MB (NumPy only)
|
||||
**With ML dependencies:** ~500 MB (includes LightGBM, pandas, pyarrow, scikit-learn)
|
||||
|
||||
### Supported Data Types
|
||||
|
||||
CK Tile supports a wide range of data types for GEMM operations:
|
||||
@@ -470,6 +493,42 @@ python3 examples/gemm/python/10_advanced_benchmark.py \
|
||||
|
||||
---
|
||||
|
||||
## ML-Based Kernel Selection (Optional)
|
||||
|
||||
The dispatcher includes ML heuristics for automated kernel selection using trained LightGBM models.
|
||||
|
||||
**Prerequisites:** Install ML dependencies first:
|
||||
|
||||
```bash
|
||||
pip install -r requirements-ml.txt # ~500 MB (LightGBM, pandas, pyarrow, scikit-learn)
|
||||
```
|
||||
|
||||
**Documentation:** See [heuristics/README.md](heuristics/README.md) for:
|
||||
- Training and evaluating models
|
||||
- Feature engineering (72 features)
|
||||
- Using pre-trained models
|
||||
- Python API reference
|
||||
|
||||
**Examples:**
|
||||
```bash
|
||||
python3 examples/gemm/python/09_ml_heuristic.py # ML-based kernel selection
|
||||
python3 examples/gemm/python/10_rank_kernels.py # Kernel ranking
|
||||
```
|
||||
|
||||
**Model Compression:** Trained models are stored in compressed `.lgbm.gz` format to save space (~67% size reduction). Python tools automatically decompress models on first use. For C++ examples, decompress manually:
|
||||
|
||||
```bash
|
||||
# If you have compressed models
|
||||
cd heuristics/models/gemm_universal_fp16_gfx950
|
||||
gunzip model_tflops.lgbm.gz
|
||||
|
||||
# Then use in C++ example
|
||||
cd ../../../build
|
||||
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## External Integration
|
||||
|
||||
### Using Dispatcher in Your Own Project
|
||||
|
||||
@@ -346,6 +346,55 @@ add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.
|
||||
add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp)
|
||||
add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp)
|
||||
|
||||
# ML Heuristic example -- requires LightGBM shared library
|
||||
# Derive site-packages from active Python interpreter (respects virtualenvs)
|
||||
find_package(Python3 COMPONENTS Interpreter)
|
||||
|
||||
set(LIGHTGBM_SEARCH_PATHS)
|
||||
if(Python3_FOUND AND Python3_EXECUTABLE)
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_path('purelib'))"
|
||||
OUTPUT_VARIABLE PYTHON_SITE_PACKAGES
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
ERROR_QUIET
|
||||
)
|
||||
if(PYTHON_SITE_PACKAGES)
|
||||
list(APPEND LIGHTGBM_SEARCH_PATHS "${PYTHON_SITE_PACKAGES}/lightgbm/lib")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Fallback to common Python 3.x site-packages if auto-detection failed
|
||||
if(NOT PYTHON_SITE_PACKAGES)
|
||||
list(APPEND LIGHTGBM_SEARCH_PATHS
|
||||
"$ENV{HOME}/.local/lib/python3.12/site-packages/lightgbm/lib"
|
||||
)
|
||||
endif()
|
||||
|
||||
find_library(LIGHTGBM_LIB NAMES LightGBM lib_lightgbm _lightgbm
|
||||
HINTS ${CMAKE_PREFIX_PATH}
|
||||
PATHS ${LIGHTGBM_SEARCH_PATHS}
|
||||
NO_DEFAULT_PATH
|
||||
DOC "LightGBM shared library for ML heuristics"
|
||||
)
|
||||
|
||||
# Fallback: search default paths (respects LightGBM_DIR if set by user)
|
||||
if(NOT LIGHTGBM_LIB)
|
||||
find_library(LIGHTGBM_LIB NAMES LightGBM lib_lightgbm)
|
||||
endif()
|
||||
|
||||
if(LIGHTGBM_LIB)
|
||||
add_declarative_gpu_example(gemm_09_ml_heuristic gemm/cpp/09_ml_heuristic.cpp)
|
||||
target_link_libraries(gemm_09_ml_heuristic PRIVATE ${LIGHTGBM_LIB})
|
||||
message(STATUS "LightGBM found: ${LIGHTGBM_LIB} -- building gemm_09_ml_heuristic")
|
||||
else()
|
||||
message(STATUS "LightGBM not found -- skipping gemm_09_ml_heuristic")
|
||||
message(STATUS " To enable ML heuristic example:")
|
||||
message(STATUS " 1. Activate virtualenv: source .venv/bin/activate")
|
||||
message(STATUS " 2. Install: pip install -r ../requirements-ml.txt")
|
||||
message(STATUS " 3. Reconfigure: cmake ..")
|
||||
message(STATUS " Or set CMAKE_PREFIX_PATH or LightGBM_DIR to LightGBM location")
|
||||
endif()
|
||||
|
||||
# =============================================================================
|
||||
# GEMM Python Library - Single Fallback Kernel
|
||||
# =============================================================================
|
||||
|
||||
211
dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp
Normal file
211
dispatcher/examples/gemm/cpp/09_ml_heuristic.cpp
Normal file
@@ -0,0 +1,211 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Example 09: ML-Based Kernel Selection (Native C++)
|
||||
*
|
||||
* Uses a trained LightGBM model loaded via the C API to predict TFLOPS
|
||||
* for each kernel in the registry and select the best one. The kernels
|
||||
* are JIT-compiled at build time via DECL_KERNEL_SET (same as other examples).
|
||||
*
|
||||
* Build: cd dispatcher/build && cmake .. && make gemm_09_ml_heuristic
|
||||
* Run: ./gemm_09_ml_heuristic --model <path_to_model.lgbm>
|
||||
*/
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_decl.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
#include "ck_tile/dispatcher/ml_heuristic.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
using Signature = decl::Signature;
|
||||
using Algorithm = decl::Algorithm;
|
||||
|
||||
// Multiple kernel configs for ML to choose from
|
||||
DECL_KERNEL_SET(ml_kernels,
|
||||
// Small tiles
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(64, 64, 32)
|
||||
.wave(2, 2, 1)
|
||||
.warp(16, 16, 16)
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942")
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(64, 64, 64)
|
||||
.wave(2, 2, 1)
|
||||
.warp(16, 16, 16)
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942")
|
||||
// Medium tiles
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(128, 128, 32)
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942")
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(128, 128, 64)
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942")
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(128, 128, 64)
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv4")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942")
|
||||
// Large tiles
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(256, 256, 32)
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942")
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(256, 128, 32)
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942")
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(128, 256, 32)
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 09: ML-Based Kernel Selection",
|
||||
"Uses trained LightGBM model for kernel selection");
|
||||
args.add_option("--arch", "gfx942", "GPU architecture");
|
||||
args.add_option("--model", "", "Path to LightGBM model file (.lgbm)");
|
||||
args.add_option("--log_transform", "false", "Model uses log1p transform");
|
||||
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
print_header("Example 09: ML-Based Kernel Selection");
|
||||
|
||||
std::string gfx_arch = args.get("--arch", "gfx942");
|
||||
std::string model_path = args.get("--model", "");
|
||||
bool log_transform = (args.get("--log_transform", "false") == "true");
|
||||
|
||||
if(model_path.empty())
|
||||
{
|
||||
std::cerr << "Error: --model <path> is required" << std::endl;
|
||||
std::cerr << "Usage: ./gemm_09_ml_heuristic --model path/to/model_tflops.lgbm" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Setup Registry (kernels are JIT compiled from DECL_KERNEL_SET above)
|
||||
Registry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << "Registry: " << registry.size() << " kernel(s)" << std::endl;
|
||||
|
||||
// Load ML model and create heuristic
|
||||
HardwareProfile hw;
|
||||
MLHeuristic ml_heuristic(model_path, ®istry, hw, log_transform);
|
||||
if(!ml_heuristic.is_loaded())
|
||||
{
|
||||
std::cerr << "Failed to load model. Exiting." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Wire ML heuristic into dispatcher
|
||||
Dispatcher dispatcher(®istry);
|
||||
dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic);
|
||||
dispatcher.set_heuristic([&ml_heuristic](const Problem& p) { return ml_heuristic(p); });
|
||||
|
||||
std::cout << "Strategy: ML Heuristic (LightGBM)" << std::endl;
|
||||
|
||||
// Test with different problem sizes
|
||||
using DataType = ck_tile::fp16_t;
|
||||
std::vector<std::tuple<int, int, int>> sizes = {
|
||||
{128, 128, 64},
|
||||
{512, 512, 256},
|
||||
{1024, 1024, 512},
|
||||
{2048, 2048, 1024},
|
||||
};
|
||||
|
||||
std::cout << std::endl
|
||||
<< std::setw(20) << "Shape" << std::setw(30) << "Selected Kernel" << std::setw(15)
|
||||
<< "Pred TFLOPS" << std::setw(12) << "Select ms" << std::setw(10) << "Status"
|
||||
<< std::endl;
|
||||
std::cout << std::string(87, '-') << std::endl;
|
||||
|
||||
bool all_passed = true;
|
||||
|
||||
for(const auto& [M, N, K] : sizes)
|
||||
{
|
||||
Problem problem;
|
||||
problem.M = M;
|
||||
problem.N = N;
|
||||
problem.K = K;
|
||||
problem.k_batch = 1;
|
||||
|
||||
auto t0 = std::chrono::high_resolution_clock::now();
|
||||
auto kernel = dispatcher.select_kernel(problem);
|
||||
auto t1 = std::chrono::high_resolution_clock::now();
|
||||
double select_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
|
||||
|
||||
std::string size_str =
|
||||
std::to_string(M) + "x" + std::to_string(N) + "x" + std::to_string(K);
|
||||
|
||||
if(!kernel)
|
||||
{
|
||||
std::cout << std::setw(20) << size_str << std::setw(30) << "NONE" << std::setw(15)
|
||||
<< "N/A" << std::setw(12) << std::fixed << std::setprecision(2) << select_ms
|
||||
<< std::setw(10) << "FAIL" << std::endl;
|
||||
all_passed = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
double pred = ml_heuristic.predict_tflops(problem, kernel->get_key());
|
||||
std::string name = kernel->get_key().encode_identifier();
|
||||
if(name.length() > 27)
|
||||
name = name.substr(0, 27) + "..";
|
||||
|
||||
std::cout << std::setw(20) << size_str << std::setw(30) << name << std::setw(15)
|
||||
<< std::fixed << std::setprecision(2) << pred << std::setw(12)
|
||||
<< std::setprecision(2) << select_ms << std::setw(10) << "OK" << std::endl;
|
||||
}
|
||||
|
||||
std::cout << std::endl
|
||||
<< (all_passed ? "*** ALL TESTS PASSED ***" : "*** SOME TESTS FAILED ***")
|
||||
<< std::endl;
|
||||
|
||||
return all_passed ? 0 : 1;
|
||||
}
|
||||
305
dispatcher/examples/gemm/python/09_ml_heuristic.py
Normal file
305
dispatcher/examples/gemm/python/09_ml_heuristic.py
Normal file
@@ -0,0 +1,305 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 09: ML-Based Kernel Selection
|
||||
|
||||
Uses a trained LightGBM model to select the optimal kernel for each problem
|
||||
size. The model predicts TFLOPS for every candidate in the kernel pool and
|
||||
picks the highest-scoring one, which is then JIT-compiled and run.
|
||||
|
||||
This replaces the hand-crafted rules in 08_heuristics.py with a data-driven
|
||||
approach achieving 97-98% of oracle-best TFLOPS efficiency.
|
||||
|
||||
Complexity: *****
|
||||
|
||||
Prerequisites:
|
||||
- Trained model in dispatcher/heuristics/models/gemm_universal_fp8_gfx950/
|
||||
- lightgbm, pandas, numpy, pyarrow installed
|
||||
|
||||
Usage:
|
||||
python3 09_ml_heuristic.py
|
||||
python3 09_ml_heuristic.py --dtype fp16 --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics"))
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
)
|
||||
|
||||
from predict import Predictor
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """Kernel specification -- same structure as 08_heuristics.py"""

    # Human-readable identifier; used to look the spec back up after ranking.
    name: str
    # Block-tile dimensions (elements of C computed per workgroup).
    tile_m: int
    tile_n: int
    tile_k: int
    # Pipeline variant: "compv3", "compv4", or "mem".
    pipeline: str = "compv3"
    # Scheduler: "intrawave" or "interwave".
    scheduler: str = "intrawave"
    # Warp (wave) layout within a block tile.
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    # Per-warp tile dimensions (MFMA instruction shape).
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16
|
||||
|
||||
|
||||
# Kernel pool: representative configs spanning small to large tiles,
# compv3/compv4/mem pipelines, and intrawave/interwave schedulers.
# The ML model scores every entry here for each problem size; only the
# winner is JIT-compiled and run.
KERNEL_POOL = [
    # Small tiles (16x16 warp tiles suit small problem dimensions)
    KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k32_v4", 64, 64, 32, "compv4", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", warp_m=16, warp_n=16),
    KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", warp_m=16, warp_n=16),
    # Medium tiles
    KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3"),
    KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3"),
    KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3"),
    KernelSpec("m_128x128_k32_v4", 128, 128, 32, "compv4"),
    KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4"),
    KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem"),
    KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem"),
    # Rectangular medium (asymmetric warp tiles follow the short dimension)
    KernelSpec("r_64x128_k32", 64, 128, 32, "compv3", warp_m=16),
    KernelSpec("r_128x64_k32", 128, 64, 32, "compv3", warp_n=16),
    KernelSpec("r_64x128_k64", 64, 128, 64, "compv3", warp_m=16),
    KernelSpec("r_128x64_k64", 128, 64, 64, "compv3", warp_n=16),
    # Large tiles
    KernelSpec("l_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("l_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("l_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("l_256x256_k64", 256, 256, 64, "compv3"),
    # Interwave variants
    KernelSpec("m_128x128_k64_iw", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("m_128x128_k64_mem_iw", 128, 128, 64, "mem", "interwave"),
]
|
||||
|
||||
|
||||
def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
    """Convert a KernelSpec to the dict format the feature engine expects.

    Note: pad_m/n/k default to True to match KernelConfig defaults and actual
    compiled kernels. This ensures the ML model receives the correct padding
    flags that will be used during JIT compilation.
    """
    # NB: the feature engine's "warp_*" keys come from KernelSpec.wave_*
    # (warp layout), while "warp_tile_*" comes from KernelSpec.warp_*
    # (per-warp tile dimensions).
    return dict(
        kernel_name=spec.name,
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        warp_m=spec.wave_m,
        warp_n=spec.wave_n,
        warp_k=spec.wave_k,
        warp_tile_m=spec.warp_m,
        warp_tile_n=spec.warp_n,
        warp_tile_k=spec.warp_k,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        pad_m=True,  # match KernelConfig default
        pad_n=True,  # match KernelConfig default
        pad_k=True,  # match KernelConfig default
        persistent=False,
        dtype=dtype,
        layout=layout,
    )
|
||||
|
||||
|
||||
def spec_to_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Convert a KernelSpec to the dispatcher's KernelConfig for JIT compilation."""
    return KernelConfig(
        # A, B, and C all use the requested I/O dtype.
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",  # accumulate in fp32 regardless of I/O dtype
        # row/col/row corresponds to the "rcr" layout used throughout this example.
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=spec.wave_m,
        wave_n=spec.wave_n,
        wave_k=spec.wave_k,
        warp_m=spec.warp_m,
        warp_n=spec.warp_n,
        warp_k=spec.warp_k,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
|
||||
|
||||
|
||||
def ml_select_kernel(
    predictor: Predictor,
    pool: List[KernelSpec],
    M: int,
    N: int,
    K: int,
    dtype: str,
    layout: str,
) -> tuple:
    """Score all kernels in the pool and return (best_spec, predicted_tflops)."""
    problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1}
    candidates = [spec_to_feature_dict(s, dtype, layout) for s in pool]

    ranking = predictor.rank_kernels(problem, candidates)
    if not ranking:
        # Predictor produced no ranking -- fall back to the first pool entry.
        return pool[0], 0.0

    top_name, top_tflops = ranking[0]
    # Map the winning name back to its spec; unknown names fall back to pool[0].
    specs_by_name = {s.name: s for s in pool}
    return specs_by_name.get(top_name, pool[0]), top_tflops
|
||||
|
||||
|
||||
def main():
    """Run ML-based kernel selection over a set of GEMM test shapes.

    For each shape: rank the kernel pool with the trained LightGBM predictor,
    JIT-compile the winning kernel, run the GEMM, and verify the output
    against an fp32 NumPy reference.

    Returns:
        0 if every shape passes, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="ML-based kernel selection for GEMM")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp8"])
    parser.add_argument("--arch", default="gfx942")
    parser.add_argument(
        "--model_dir",
        default=str(
            Path(__file__).parent.parent.parent.parent
            / "heuristics"
            / "models"
            / "gemm_universal_fp8_gfx950"
        ),
    )
    parser.add_argument(
        "--no_run", action="store_true", help="Only predict, don't run GEMMs"
    )
    args = parser.parse_args()

    print("=" * 75)
    print(" Example 09: ML-Based Kernel Selection")
    print("=" * 75)
    print(f"\n Model: {args.model_dir}")
    print(f" Dtype: {args.dtype}")
    print(f" Arch: {args.arch}")
    print(f" Pool: {len(KERNEL_POOL)} kernels")

    predictor = Predictor(args.model_dir)
    print(" Model loaded successfully")

    # Host-side buffers are float16 for every dtype: numpy has no native fp8,
    # and bf16 input data is also staged as float16.  (The previous
    # conditional here had identical branches -- dead code, now simplified.)
    np_dtype = np.float16

    test_sizes = [
        (128, 128, 64),
        (256, 256, 128),
        (512, 512, 256),
        (1024, 1024, 512),
        (2048, 2048, 1024),
    ]

    header = f"{'Shape':<20} {'Selected Kernel':<25} {'Pred TFLOPS':>12}"
    if not args.no_run:
        header += f" {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    print(f"\n {header}")
    print(" " + "-" * len(header))

    # Each entry: (size_str, kernel_name, passed, time_ms, tflops)
    results = []

    for M, N, K in test_sizes:
        t0 = time.time()
        best_spec, pred_tflops = ml_select_kernel(
            predictor, KERNEL_POOL, M, N, K, args.dtype, "rcr"
        )
        _ = (time.time() - t0) * 1000  # ML selection time (unused)

        size_str = f"{M}x{N}x{K}"
        line = f" {size_str:<20} {best_spec.name:<25} {pred_tflops:>12.2f}"

        if args.no_run:
            # Prediction-only mode: report the selection without compiling.
            print(line)
            results.append((size_str, best_spec.name, True, 0, pred_tflops))
            continue

        config = spec_to_kernel_config(best_spec, args.dtype, args.arch)

        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"ml_{best_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )

        if not setup.success:
            # Status column is left-aligned to match the header (was
            # inconsistently right-aligned for BUILD/UNSUP).
            line += f" {'N/A':>10} {'N/A':>10} {'BUILD':<8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            line += f" {'N/A':>10} {'N/A':>10} {'UNSUP':<8}"
            print(line)
            results.append((size_str, best_spec.name, False, 0, 0))
            cleanup_gemm()
            continue

        # Deterministic inputs so failures are reproducible across runs.
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)

        if result.success:
            # Verify against an fp32 NumPy reference rounded back to np_dtype.
            C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(
                np_dtype
            )
            max_err = np.max(np.abs(result.output - C_ref))
            passed = max_err < 1e-2
            status = "PASS" if passed else "FAIL"
            line += f" {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
            results.append(
                (size_str, best_spec.name, passed, result.time_ms, result.tflops)
            )
        else:
            line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            results.append((size_str, best_spec.name, False, 0, 0))

        print(line)
        cleanup_gemm()

    # Summary
    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)
    passed = sum(1 for r in results if r[2])
    print(f"\n Results: {passed}/{len(results)} tests passed")
    valid = [r for r in results if r[2] and r[4] > 0]
    if valid:
        avg = sum(r[4] for r in valid) / len(valid)
        print(f" Average TFLOPS: {avg:.2f}")
    if passed == len(results):
        print("\n *** ALL TESTS PASSED ***")
    print("=" * 75)
    return 0 if passed == len(results) else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status.
    sys.exit(main())
|
||||
60
dispatcher/heuristics/.gitignore
vendored
Normal file
60
dispatcher/heuristics/.gitignore
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
# Python bytecode and caches
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
|
||||
# Jupyter notebooks
|
||||
*.ipynb
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# IDE and editor files
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Test output and logs
|
||||
*.log
|
||||
test_output.log
|
||||
custom_shapes_gpu_test.log
|
||||
|
||||
# Benchmark and analysis output files
|
||||
*.csv
|
||||
*.json
|
||||
!models/*/feature_spec.json
|
||||
!models/*/train_manifest.json
|
||||
|
||||
# Data files (parquet, arrow)
|
||||
*.parquet
|
||||
*.arrow
|
||||
|
||||
# Temporary and NFS files
|
||||
.nfs*
|
||||
*.tmp
|
||||
*.bak
|
||||
|
||||
# Decompressed model files (compressed .lgbm.gz versions are tracked)
|
||||
models/**/*.lgbm
|
||||
|
||||
# User-specific test and analysis scripts
|
||||
test_*.py
|
||||
!tests/test_*.py
|
||||
find_*.py
|
||||
oracle_*.json
|
||||
validation_results_*.csv
|
||||
custom_shapes_*.csv
|
||||
fp16_bf16_*.csv
|
||||
|
||||
# Ignore all markdown files except tracked documentation
|
||||
*.md
|
||||
!DATA_GENERATION.md
|
||||
!LEARNINGS.md
|
||||
!README.md
|
||||
412
dispatcher/heuristics/DATA_GENERATION.md
Normal file
412
dispatcher/heuristics/DATA_GENERATION.md
Normal file
@@ -0,0 +1,412 @@
|
||||
# Data Generation Guide
|
||||
|
||||
This document explains how to build benchmark binaries from the CK Tile engine,
|
||||
generate benchmark datasets, and manage them for the ML kernel performance
|
||||
prediction system.
|
||||
|
||||
## Overview
|
||||
|
||||
The ML heuristic needs benchmark data: measured TFLOPS, latency, and bandwidth
|
||||
for every (problem shape, kernel config) pair. The tile engine builds one
|
||||
executable per kernel configuration. Each executable benchmarks a single kernel
|
||||
on a given problem size and outputs JSON with performance metrics.
|
||||
|
||||
```
|
||||
CK source --> CMake configure --> ninja build --> benchmark binaries
|
||||
(4608 per op/dtype/layout)
|
||||
|
||||
benchmark binaries --> run on GPU --> streaming log --> parquet dataset
|
||||
(per shape) (JSON blocks) (canonical schema)
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **ROCm**: HIP >= 6.0.3 (for gfx950: HIP >= 6.0.4)
|
||||
- **Build tools**: CMake >= 3.21, Ninja, HIP-aware clang compiler
|
||||
- **Python**: 3.10+ with `pandas`, `pyarrow`
|
||||
- **GPU**: ROCm-capable AMD GPU (MI250X, MI300X, MI355X, etc.)
|
||||
|
||||
---
|
||||
|
||||
## Part 1: Building Benchmark Binaries from the Tile Engine
|
||||
|
||||
If you already have pre-built binaries (e.g., in `/workspace/ck_tile/bin/`),
|
||||
skip to Part 2. This section explains how to build them from source.
|
||||
|
||||
### Step 1: CMake Configure
|
||||
|
||||
From the CK repository root:
|
||||
|
||||
```bash
|
||||
cmake -S /workspace/rocm-libraries/projects/composablekernel \
|
||||
-B build \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DGEMM_UNIVERSAL_DATATYPE="fp8" \
|
||||
-DGEMM_UNIVERSAL_LAYOUT="rcr" \
|
||||
-G Ninja
|
||||
```
|
||||
|
||||
**Key CMake variables:**
|
||||
|
||||
| Variable | Default | Description |
|
||||
|---|---|---|
|
||||
| `GPU_TARGETS` | (required) | Target GPU architectures. Supported: `gfx90a`, `gfx942`, `gfx950`, `gfx1201`. Semicolon-separated for multiple. |
|
||||
| `GEMM_UNIVERSAL_DATATYPE` | `"fp8;fp16"` | Data types to build. Options: `fp8`, `fp16`, `bf16`, `bf8`. Semicolon-separated. |
|
||||
| `GEMM_UNIVERSAL_LAYOUT` | `"rcr;rrr;crr;ccr"` | Layouts to build. Semicolon-separated. |
|
||||
| `GEMM_UNIVERSAL_CONFIG_FILE` | `"default_config.json"` | Kernel config file (in the `configs/` directory). Controls which tile sizes, warp configs, pipelines, etc. are enumerated. |
|
||||
| `ENABLE_CCACHE_GEMM_UNIVERSAL` | `OFF` | Enable ccache for faster rebuilds. |
|
||||
|
||||
**Example: build only fp8 RCR for gfx950 (fastest, ~4608 kernels):**
|
||||
```bash
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DGEMM_UNIVERSAL_DATATYPE="fp8" \
|
||||
-DGEMM_UNIVERSAL_LAYOUT="rcr" \
|
||||
-G Ninja
|
||||
```
|
||||
|
||||
**Example: build all dtypes and layouts (slow, ~4608 * 4 * 4 = ~73K kernels):**
|
||||
```bash
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DGEMM_UNIVERSAL_DATATYPE="fp8;fp16;bf16;bf8" \
|
||||
-DGEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-G Ninja
|
||||
```
|
||||
|
||||
### What happens during configure
|
||||
|
||||
1. CMake calls `gemm_universal_instance_builder.py --list_kernels` to enumerate
|
||||
all valid kernel configurations from the config JSON.
|
||||
2. It writes `gemm_universal_kernel_list.txt` (one kernel per line) and
|
||||
`gemm_universal_kernel_count.txt` to the build directory.
|
||||
3. For each kernel, it creates a ninja build target.
|
||||
|
||||
### Step 2: Build
|
||||
|
||||
```bash
|
||||
# Build all benchmarks for the configured dtypes/layouts
|
||||
ninja -C build benchmark_gemm_universal_all
|
||||
|
||||
# Or build a specific dtype/layout combo
|
||||
ninja -C build benchmark_gemm_universal_fp8_rcr
|
||||
|
||||
# Or build by pipeline type
|
||||
ninja -C build benchmark_gemm_universal_compv4_pipeline
|
||||
ninja -C build benchmark_gemm_universal_mem_pipeline
|
||||
|
||||
# Or build a single specific kernel
|
||||
ninja -C build benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128
|
||||
```
|
||||
|
||||
**Build time estimates:**
|
||||
- ~4608 kernels (one dtype, one layout): 1-4 hours depending on CPU cores
|
||||
- Use `-j <N>` to control parallelism: `ninja -C build -j 32 benchmark_gemm_universal_fp8_rcr`
|
||||
|
||||
### Step 3: Verify binaries
|
||||
|
||||
Binaries are placed in `build/bin/`:
|
||||
|
||||
```bash
|
||||
ls build/bin/benchmark_gemm_universal_fp8_rcr_* | wc -l
|
||||
# Expected: 4608 (for default config)
|
||||
|
||||
# Test one binary
|
||||
./build/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \
|
||||
-m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0
|
||||
```
|
||||
|
||||
### Kernel config files
|
||||
|
||||
The config files live in:
|
||||
```
|
||||
tile_engine/ops/gemm/gemm_universal/configs/
|
||||
default_config.json # Default: full enumeration
|
||||
default_ci_config.json # CI: reduced set for fast testing
|
||||
user_provided_config.json # Custom: your own subset
|
||||
```
|
||||
|
||||
To use a custom config:
|
||||
```bash
|
||||
cmake ... -DGEMM_UNIVERSAL_CONFIG_FILE="user_provided_config.json"
|
||||
```
|
||||
|
||||
The config controls which tile sizes (e.g., 128x128x64, 256x256x32), warp
|
||||
configurations (e.g., 2x2x1, 1x4x1), pipelines (compv3, compv4, mem),
|
||||
schedulers, and other parameters are included in the kernel enumeration.
|
||||
|
||||
### Building StreamK / other ops
|
||||
|
||||
The same pattern applies to other tile engine ops:
|
||||
|
||||
```bash
|
||||
# StreamK
|
||||
ninja -C build benchmark_gemm_streamk_fp8_rcr
|
||||
|
||||
# Grouped convolution
|
||||
ninja -C build benchmark_grouped_conv_fwd_fp16_nhwgc
|
||||
```
|
||||
|
||||
Each op has its own instance builder and config directory.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Running Benchmarks and Generating Data
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Run benchmarks for a set of shapes
|
||||
|
||||
Each binary accepts `-m=`, `-n=`, `-k=`, `-warmup=`, `-repeat=`, `-verify=` flags
|
||||
and outputs JSON to stdout:
|
||||
|
||||
```bash
|
||||
/workspace/ck_tile/bin/benchmark_gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128 \
|
||||
-m=1024 -n=1024 -k=1024 -warmup=3 -repeat=10 -verify=0
|
||||
```
|
||||
|
||||
Output:
|
||||
```json
|
||||
{
|
||||
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_...",
|
||||
"problem": {
|
||||
"split_k": 1, "m": 1024, "n": 1024, "k": 1024,
|
||||
"dtype_a": "fp8", "dtype_b": "fp8", ...
|
||||
},
|
||||
"perf_result": {
|
||||
"latency(ms)": 0.04,
|
||||
"tflops(TFlops)": 204.60,
|
||||
"bandwidth(GB/s)": 624.39
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Batch generation using provided scripts
|
||||
|
||||
**Wide coverage (diverse shapes across all regimes):**
|
||||
```bash
|
||||
python3 generate_wide_coverage.py \
|
||||
--bin_dir /workspace/ck_tile/bin \
|
||||
--out_dir data/wide_coverage \
|
||||
--batch_size 25 \
|
||||
--warmup 3 --repeat 10
|
||||
```
|
||||
|
||||
**Edge-case dimensions (N=1, K=1, small N/K):**
|
||||
```bash
|
||||
python3 generate_edge_dims.py
|
||||
```
|
||||
|
||||
Both scripts write streaming log files that `data_pipeline.py` can parse.
|
||||
|
||||
### 3. Parse logs into parquet
|
||||
|
||||
```bash
|
||||
python3 data_pipeline.py <log_file> \
|
||||
-o data/my_dataset.parquet \
|
||||
--arch gfx950 \
|
||||
--capture_hw
|
||||
```
|
||||
|
||||
The `--capture_hw` flag runs `rocminfo` once and injects the GPU hardware
|
||||
profile (CU count, clock speed, cache sizes, etc.) into every row.
|
||||
|
||||
## Canonical Data Schema
|
||||
|
||||
Every parquet file follows this schema:
|
||||
|
||||
| Column | Type | Description |
|
||||
|---|---|---|
|
||||
| `op_type` | str | `gemm_universal`, `gemm_streamk`, etc. |
|
||||
| `dtype` | str | `fp8`, `fp16`, `bf16`, `bf8` |
|
||||
| `layout` | str | `rcr`, `rrr`, `crr`, `ccr` |
|
||||
| `arch` | str | `gfx942`, `gfx950`, etc. |
|
||||
| `kernel_name` | str | Full kernel identifier |
|
||||
| `m`, `n`, `k` | int | Problem dimensions |
|
||||
| `split_k` | int | Split-K factor (1 = standard) |
|
||||
| `measured_tflops` | float | Ground-truth TFLOPS |
|
||||
| `latency_ms` | float | Measured latency |
|
||||
| `bandwidth_gb_s` | float | Measured bandwidth |
|
||||
| `is_valid` | bool | True if tflops > 0 and latency > 0 |
|
||||
| `tile_m`, `tile_n`, `tile_k` | int | Tile dimensions |
|
||||
| `warp_m`, `warp_n`, `warp_k` | int | Warp config |
|
||||
| `warp_tile_m/n/k` | int | Warp tile dimensions |
|
||||
| `pipeline` | str | `compv3`, `compv4`, `mem`, etc. |
|
||||
| `scheduler` | str | `intrawave`, `interwave` |
|
||||
| `epilogue` | str | `cshuffle`, `default` |
|
||||
| `pad_m`, `pad_n`, `pad_k` | bool | Padding flags |
|
||||
| `persistent` | bool | Persistent kernel flag |
|
||||
| `run_id` | str | Unique collection run identifier |
|
||||
|
||||
## Shape Selection Guidelines
|
||||
|
||||
Good training data requires diverse shapes. Cover all of these regimes:
|
||||
|
||||
### By M dimension (batch size / output rows)
|
||||
- **M=1**: single-token inference (hardest case for tiling)
|
||||
- **Tiny M (2-16)**: small batch inference
|
||||
- **Small M (32-128)**: medium batch
|
||||
- **Medium M (256-2048)**: large batch / training
|
||||
- **Large M (4096-20480)**: very large batch
|
||||
|
||||
### By N and K dimension
|
||||
- **N=1**: vector-matrix multiply (degenerate)
|
||||
- **K=1**: rank-1 update / outer product (degenerate)
|
||||
- **Small N or K (2-16)**: stress tile efficiency
|
||||
- **Deep K (K > 4096)**: compute-bound regime
|
||||
- **Shallow K (K < 256)**: memory-bound regime
|
||||
|
||||
### By shape family
|
||||
- **Square**: M ~ N ~ K (powers of 2)
|
||||
- **Tall**: M >> N (tall output matrix)
|
||||
- **Wide**: N >> M (wide output matrix)
|
||||
- **Deep-K**: K >> M and K >> N
|
||||
|
||||
### Special cases
|
||||
- **Prime dimensions**: 17, 31, 127, 251, 509, 1021, 2039, 4093
|
||||
(worst-case for tile alignment, tests padding logic)
|
||||
- **Non-power-of-2**: 48, 96, 192, 384, 576, 768, 1536, 3072, 4608
|
||||
(common in LLM architectures)
|
||||
- **LLM inference shapes**: DeepSeek, LLaMA-7B, LLaMA-70B MLP/attention dims
|
||||
|
||||
### Minimum recommended coverage
|
||||
|
||||
For a production-quality model, aim for:
|
||||
- At least 200 unique (M, N, K) shapes
|
||||
- At least 10 shapes per shape family
|
||||
- All kernel configs (4608 for fp8 RCR) run against every shape
|
||||
- Multiple layouts if training a cross-layout model
|
||||
|
||||
## Benchmark Quality Guidelines
|
||||
|
||||
### Warmup and repeat
|
||||
- Minimum `warmup=3`, `repeat=10` for fast iteration
|
||||
- Production quality: `warmup=5`, `repeat=20` for stable measurements
|
||||
- The `perf_result` values are averaged over `repeat` iterations
|
||||
|
||||
### Noise handling
|
||||
- Use **median** latency when aggregating multiple runs of the same benchmark
|
||||
- Flag measurements where coefficient of variation exceeds 10%
|
||||
- Avoid benchmarking under thermal throttling (check GPU temperature)
|
||||
- Lock GPU clocks if possible for reproducibility
|
||||
|
||||
### Environment metadata
|
||||
Store with every dataset:
|
||||
- GPU model and architecture (from `rocminfo`)
|
||||
- ROCm driver version
|
||||
- Clock mode (default / locked)
|
||||
- Git hash of the CK tile engine build (if available)
|
||||
- Timestamp
|
||||
|
||||
## Adding Data for a New Op
|
||||
|
||||
To generate benchmark data for a new operation (e.g., `gemm_streamk`):
|
||||
|
||||
1. **Build the binaries** using the tile engine:
|
||||
```bash
|
||||
ninja -C build benchmark_gemm_streamk_fp8_rcr
|
||||
```
|
||||
|
||||
2. **Write a generation script** (or modify `generate_wide_coverage.py`):
|
||||
- Change the executable glob pattern to match the new op
|
||||
- Add any op-specific CLI flags the binaries need
|
||||
|
||||
3. **Run and parse**:
|
||||
```bash
|
||||
python3 data_pipeline.py my_streamk_run.log \
|
||||
-o data/gemm_streamk_fp8_gfx950.parquet --arch gfx950
|
||||
```
|
||||
|
||||
4. **Train**:
|
||||
```bash
|
||||
python3 train.py --op gemm_streamk --dtype fp8 --arch gfx950 \
|
||||
--data_dir data/ --out_dir models/gemm_streamk_fp8_gfx950
|
||||
```
|
||||
|
||||
## Adding Data for a New Layout
|
||||
|
||||
Same binaries, same shapes -- just change the layout filter:
|
||||
|
||||
```bash
|
||||
# Build rrr binaries
|
||||
ninja -C build benchmark_gemm_universal_fp8_rrr
|
||||
|
||||
# Generate and parse
|
||||
# ... (same flow, different bin_dir or executable glob)
|
||||
|
||||
# Train a cross-layout model by putting all layouts in the same data_dir
|
||||
python3 train.py --data_dir data/ --out_dir models/gemm_universal_fp8_gfx950_all_layouts
|
||||
```
|
||||
|
||||
The feature engine includes `layout` as a categorical feature, so one model
|
||||
can handle all layouts.
|
||||
|
||||
## Incremental Data Collection
|
||||
|
||||
When you have a trained model and want to add more data:
|
||||
|
||||
1. Generate new data (new shapes, new layouts, etc.)
|
||||
2. Parse into parquet alongside existing data
|
||||
3. Warm-start from the previous model:
|
||||
```bash
|
||||
python3 train.py --data_dir data/ --out_dir models/v2 \
|
||||
--warm_start models/v1 \
|
||||
--warm_start_n_estimators 200
|
||||
```
|
||||
|
||||
This adds 200 new trees on top of the existing model. The feature schema
|
||||
must match exactly (enforced automatically).
|
||||
|
||||
## File Organization
|
||||
|
||||
Recommended directory structure:
|
||||
|
||||
```
|
||||
heuristics/
|
||||
data/
|
||||
gemm_universal_fp8_rcr_gfx950.parquet # original 108 shapes
|
||||
wide_coverage/ # batch log files
|
||||
wide_coverage_batch_001.log
|
||||
wide_coverage_batch_002.log
|
||||
...
|
||||
edge_dims/ # N=1, K=1 edge cases
|
||||
edge_dims_batch_001.log
|
||||
...
|
||||
models/
|
||||
gemm_universal_fp8_gfx950/ # trained model artifacts
|
||||
model_tflops.lgbm
|
||||
model_latency.lgbm
|
||||
model_bandwidth.lgbm
|
||||
feature_spec.json
|
||||
train_manifest.json
|
||||
cv_metrics_tflops.json
|
||||
eval_report.json
|
||||
...
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Benchmark binary exits with non-zero code
|
||||
Some kernel configs are invalid for certain problem sizes (e.g., tile_m=256
|
||||
with M=16). The data pipeline marks these as `is_valid=False` and they are
|
||||
filtered out during training. This is expected.
|
||||
|
||||
### Edge dims produce very few results
|
||||
N=1 and K=1 shapes are degenerate -- most kernel configurations have minimum
|
||||
dimension requirements and will fail or produce zero TFLOPS. The small number
|
||||
of valid results is still useful (it tells the model which configs work for
|
||||
these shapes).
|
||||
|
||||
### Benchmarks are slow
|
||||
Each shape requires running all 4608 kernel executables sequentially. At
|
||||
~0.01s per kernel, that is ~46 seconds per shape. For 700 shapes, expect
|
||||
~9 hours. Tips:
|
||||
- Run on a dedicated GPU (no other workloads)
|
||||
- Use `--batch_size 25` to get incremental output
|
||||
- Parse and train on partial data while generation continues
|
||||
|
||||
### Data from different GPUs / driver versions
|
||||
Store `run_id` and hardware metadata with each dataset. Training on mixed
|
||||
data is allowed but not recommended for production models. Filter to a
|
||||
single `run_id` or `arch` for clean experiments.
|
||||
151
dispatcher/heuristics/LEARNINGS.md
Normal file
151
dispatcher/heuristics/LEARNINGS.md
Normal file
@@ -0,0 +1,151 @@
|
||||
# Learnings and Design Decisions
|
||||
|
||||
Empirical findings from building the CK Tile kernel performance prediction system.
|
||||
These inform the current defaults and explain why certain approaches were chosen.
|
||||
|
||||
## 1. Log-Transform is Essential for Cross-Scale Accuracy
|
||||
|
||||
**Problem**: GEMM TFLOPS spans 5 orders of magnitude across different problem
|
||||
sizes. When training on raw TFLOPS, the regression loss (RMSE) is dominated by
|
||||
large shapes where absolute errors are biggest. The model learns to predict
|
||||
large shapes accurately but ignores tiny shapes where the TFLOPS values are
|
||||
much lower.
|
||||
|
||||
**Evidence** (168 shapes, 626K rows, 5-fold GroupKFold CV):
|
||||
|
||||
|
||||
| Model | Mean Eff | P10 Eff | tiny_m Eff | Min Eff |
|
||||
| ----------------------------- | ---------- | ---------- | ---------- | ---------- |
|
||||
| Raw TFLOPS (500 trees) | 92.73% | 80.24% | 84.55% | 4.26% |
|
||||
| **log1p(TFLOPS)** (500 trees) | **96.92%** | **94.34%** | **94.89%** | **60.27%** |
|
||||
| log1p(TFLOPS) (2000 trees) | 97.51% | 93.89% | 96.04% | 63.56% |
|
||||
|
||||
|
||||
**Solution**: Train on `log1p(measured_tflops)` and apply `expm1()` to
|
||||
predictions. This is now the default in `train.py`. Pass `--no_log_transform`
|
||||
to revert to raw regression (not recommended).
|
||||
|
||||
**Why log1p, not log**: `log1p(x) = log(1 + x)` handles zero and near-zero
|
||||
TFLOPS gracefully, whereas `log(x)` produces -inf for x=0.
|
||||
|
||||
## 2. Tiny-M Shapes are the Hardest Case
|
||||
|
||||
M=1 (single-token inference) shapes are fundamentally different from batch shapes:
|
||||
|
||||
- Most kernel configurations produce very low TFLOPS
|
||||
- The "best" kernel is often only marginally better than the rest
|
||||
- The oracle performance itself is very low, so any prediction error tanks efficiency
|
||||
- Many kernels fail outright (tile_m=128 with M=1 wastes 127/128 of the tile)
|
||||
|
||||
The bottom shapes in our evaluation are all M=1, with efficiencies in the
|
||||
63-70% range. These shapes have such low absolute performance that the model's
|
||||
noise floor exceeds the performance difference between kernels.
|
||||
|
||||
**Mitigation**: Log-transform helps significantly (tiny_m improved from 84% to
|
||||
96%). For production use with M=1, consider a dedicated fallback (e.g.,
|
||||
hardcoded kernel selection for M < 4 based on known-good configs).
|
||||
|
||||
## 3. IHEM (Hard Example Mining) Hurts When Scale is the Issue
|
||||
|
||||
We tried Iterative Hard Example Mining with sample reweighting (2x-5x weight
|
||||
on hard shapes). Result: it made things **worse**, degrading mean efficiency
|
||||
from 94.31% to 92.90% over 3 iterations.
|
||||
|
||||
**Why**: The hard shapes are hard because of scale mismatch, not because the
|
||||
model lacks capacity. Reweighting amplifies the small-TFLOPS rows, which
|
||||
distorts the learned relationship between features and performance for the
|
||||
majority of shapes. The log-transform was the correct fix -- it addresses the
|
||||
root cause (scale) rather than the symptom (bad predictions on tiny shapes).
|
||||
|
||||
**Lesson**: IHEM is useful when the model has capacity gaps (e.g., certain
|
||||
pipeline types are underrepresented). It is counterproductive when the issue
|
||||
is target-variable scale. Always try target transforms before reweighting.
|
||||
|
||||
## 4. GroupKFold Key = (M, N, K) Forces Generalization
|
||||
|
||||
The validation uses `GroupKFold` where the group key is `(M, N, K)` -- all
|
||||
kernels for the same shape go to the same fold. This means:
|
||||
|
||||
- The model is always evaluated on shapes it has **never seen** during training
|
||||
- Layout is excluded from the key, forcing the model to generalize across layouts
|
||||
- Since models are per-arch, `arch` is implicit (constant within one training run)
|
||||
|
||||
This is much stricter than random row splitting, where the model would see some
|
||||
kernels for each shape during training. Our efficiency numbers are conservative
|
||||
estimates of real-world performance on unseen shapes.
|
||||
|
||||
## 5. Model Size vs Accuracy Tradeoff
|
||||
|
||||
|
||||
| Config | Trees | Leaves | LR | Mean Eff | P10 Eff | Train Time |
|
||||
| ------------------ | -------- | ------- | -------- | ---------- | ---------- | ------------- |
|
||||
| Small (default v1) | 500 | 127 | 0.05 | 96.92% | 94.34% | ~20s |
|
||||
| **Big (current)** | **2000** | **255** | **0.02** | **97.51%** | **93.89%** | **~25s/fold** |
|
||||
|
||||
|
||||
The bigger model improved mean efficiency by 0.6% but P10 didn't improve
|
||||
(actually slightly worse). The extra capacity helps on medium shapes but
|
||||
doesn't crack the tiny-M floor. This suggests the feature set, not model
|
||||
capacity, is the limiting factor for the hardest shapes.
|
||||
|
||||
For C++ deployment, the bigger model (2000 trees, 255 leaves) is still fast
|
||||
enough -- LightGBM inference is O(trees * log(leaves)) per sample, which is
|
||||
~microseconds even at 2000 trees.
|
||||
|
||||
## 6. N=1 and K=1 Shapes are Degenerate
|
||||
|
||||
We generated benchmark data for 546 edge-case shapes (N=1, K=1, small N/K).
|
||||
Result: **zero valid kernel results** across 94 shapes. All 4608 kernels either
|
||||
fail or produce 0 TFLOPS for these degenerate dimensions.
|
||||
|
||||
This means:
|
||||
|
||||
- The tile engine kernels have hard minimum dimension requirements
|
||||
- N=1 / K=1 shapes cannot be handled by the current kernel set
|
||||
- These shapes need dedicated kernels (e.g., BLAS-1/BLAS-2 fallbacks)
|
||||
- The ML model should not be expected to handle them -- they should be filtered
|
||||
out before reaching the heuristic
|
||||
|
||||
## 7. Feature Engineering Insights
|
||||
|
||||
From LightGBM feature importances on the log-target model:
|
||||
|
||||
**Top features** (by split count):
|
||||
|
||||
- `M, N, K` -- raw dimensions are always the most important
|
||||
- `tile_m, tile_n, tile_k` -- the tile shape is the primary kernel differentiator
|
||||
- `overall_tile_efficiency` -- how well the shape fits the tile (the interaction)
|
||||
- `num_tiles_m, total_output_tiles` -- work decomposition
|
||||
- `arithmetic_intensity` -- compute vs memory bound regime
|
||||
- `pipeline` -- pipeline type (compv3 vs compv4 vs mem) significantly affects perf
|
||||
|
||||
**Low-importance features**:
|
||||
|
||||
- Hardware constants (CUs, clock, caches) -- they're constant within one arch
|
||||
model, so they provide no discriminative signal. They'll become important when
|
||||
training cross-arch models.
|
||||
- `split_k` -- always 1 in current data
|
||||
- `persistent` -- rarely True in current kernel set
|
||||
|
||||
## 8. Warm-Start Works for Incremental Updates
|
||||
|
||||
LightGBM's `init_model` parameter successfully continues training from an
|
||||
existing model. New trees are added on top of existing ones. Key considerations:
|
||||
|
||||
- Feature schema must match exactly (enforced by `check_feature_compatibility`)
|
||||
- Use fewer new trees (200-500) since we're refining, not starting fresh
|
||||
- The `train_manifest.json` tracks the full lineage (total trees, data sizes)
|
||||
- Quality should be at least as good as the base model (tested)
|
||||
|
||||
## 9. Data Volume Matters More Than Model Complexity
|
||||
|
||||
|
||||
| Dataset | Shapes | Rows | Mean Eff (log, 500 trees) |
|
||||
| ------------------------ | ------ | ---- | ----------------------------- |
|
||||
| Original (DeepSeek only) | 108 | 418K | 98.28% (on seen distribution) |
|
||||
| + Wide coverage M=1 | 168 | 626K | 96.92% (harder distribution) |

The original 108-shape model looked great (98.28%) but was overfitting to the
DeepSeek LLM inference distribution. Adding 60 diverse shapes (many M=1)
exposed the model's weakness on tiny shapes. More diverse training data is
always better than a bigger model on narrow data.

## Summary of Defaults

Based on these findings, the current defaults in `train.py` are:

- **Target transform**: `log1p` for TFLOPS and bandwidth (scale normalization)
- **Model**: 2000 trees, 255 leaves, max depth 15, LR 0.02
- **Validation**: 5-fold GroupKFold, key = (M, N, K)
- **Early stopping**: patience 100 (let trees fully converge)
- **Warm start**: 500 new trees (was 200, increased for bigger base model)
||||
271
dispatcher/heuristics/README.md
Normal file
271
dispatcher/heuristics/README.md
Normal file
@@ -0,0 +1,271 @@
|
||||
# CK Tile Heuristics: ML-Based Kernel Selection
|
||||
|
||||
Fast, accurate kernel selection for CK Tile operations using LightGBM regression
|
||||
with Origami-augmented feature engineering.
|
||||
|
||||
## What This Does
|
||||
|
||||
Instead of running all 4608+ kernel configurations on the GPU to find the best
|
||||
one (exhaustive search taking ~46 seconds per shape), this system trains an ML
|
||||
model that predicts TFLOPS for any (problem, kernel) pair in microseconds. It
|
||||
scores all candidates instantly and picks the best kernel -- achieving 98.28%
|
||||
of oracle-best TFLOPS efficiency across 108 tested shapes.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Generate and convert benchmark data
|
||||
|
||||
**Step 1: Generate benchmark data**
|
||||
|
||||
```bash
|
||||
python3 generate_benchmark_data.py \
|
||||
--build_dir /path/to/build \
|
||||
--output_dir data/fp16_original \
|
||||
--dtype fp16 \
|
||||
--layout rcr \
|
||||
--num_build_jobs 4 \
|
||||
--warmup 10 \
|
||||
--repeat 50
|
||||
```
|
||||
|
||||
This outputs JSON with all benchmark results.
|
||||
|
||||
**Step 2: Convert JSON to parquet training format**
|
||||
|
||||
```bash
|
||||
python3 convert_json_to_parquet.py \
|
||||
--input data/fp16_original/benchmark_results_fp16_rcr.json \
|
||||
--output data/fp16_original/fp16_training_data.parquet \
|
||||
--arch gfx950
|
||||
```
|
||||
|
||||
The converter automatically fixes pad flags for `_mem` kernels and validates data.
|
||||
|
||||
**Alternative: Parse existing logs**
|
||||
|
||||
If you have raw benchmark logs from CK Tile:
|
||||
|
||||
```bash
|
||||
python3 data_pipeline.py ck_tile_testrun_2.log \
|
||||
-o data/gemm_universal_fp8_rcr_gfx950.parquet \
|
||||
--arch gfx950 --capture_hw
|
||||
```
|
||||
|
||||
### 2. Train a model
|
||||
|
||||
```bash
|
||||
python3 train.py \
|
||||
--data_dir data/ \
|
||||
--out_dir models/gemm_universal_fp8_gfx950 \
|
||||
--op gemm_universal --dtype fp8 --arch gfx950
|
||||
```
|
||||
|
||||
**Note**: Trained models are automatically compressed to `.lgbm.gz` format to save space (~67% reduction). The Python tools automatically decompress them on first use and cache the decompressed version. For warm-start training, decompression happens automatically.
|
||||
|
||||
### 3. Evaluate
|
||||
|
||||
```bash
|
||||
python3 evaluate.py \
|
||||
--model_dir models/gemm_universal_fp8_gfx950 \
|
||||
--data_dir data/ --op gemm_universal --dtype fp8
|
||||
```
|
||||
|
||||
### 4. Predict the best kernel for a problem
|
||||
|
||||
```bash
|
||||
python3 predict.py \
|
||||
--model_dir models/gemm_universal_fp8_gfx950 \
|
||||
--m 128 --n 1536 --k 7168 --layout rcr
|
||||
```
|
||||
|
||||
### 5. Search for optimal configs (optional)
|
||||
|
||||
```bash
|
||||
python3 search.py \
|
||||
--model_dir models/gemm_universal_fp8_gfx950 \
|
||||
--m 128 --n 1536 --k 7168 \
|
||||
--strategy random --budget 500 --top_k 10
|
||||
```
|
||||
|
||||
### 6. Using models in C++ (requires decompression)
|
||||
|
||||
C++ code uses the LightGBM C API which requires uncompressed `.lgbm` files. If you have compressed models (`.lgbm.gz`), decompress them first:
|
||||
|
||||
```bash
|
||||
cd models/gemm_universal_fp16_gfx950
|
||||
gunzip model_tflops.lgbm.gz
|
||||
```
|
||||
|
||||
Then use in C++ examples:
|
||||
|
||||
```bash
|
||||
cd dispatcher/build
|
||||
./gemm_09_ml_heuristic --model ../heuristics/models/gemm_universal_fp16_gfx950/model_tflops.lgbm
|
||||
```
|
||||
|
||||
**Note**: Python tools automatically decompress `.lgbm.gz` files on first use, so you can run Python scripts first to trigger decompression, then use the same models in C++.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Problem (M, N, K, dtype, layout)
|
||||
|
|
||||
v
|
||||
FeatureEngine.extract_batch() <-- 55 features: problem, kernel, interaction, hardware
|
||||
|
|
||||
v
|
||||
LGBMRegressor.predict() <-- predicts TFLOPS for each candidate kernel
|
||||
|
|
||||
v
|
||||
Sort by predicted TFLOPS <-- rank all candidates
|
||||
|
|
||||
v
|
||||
Select Top-1 kernel <-- 98.28% mean efficiency, <1ms inference
|
||||
```
|
||||
|
||||
Three models are trained per (op, dtype, arch):
|
||||
- **TFLOPS model** (primary): used for kernel ranking
|
||||
- **Latency model** (auxiliary): for latency-sensitive workloads
|
||||
- **Bandwidth model** (auxiliary): for memory-bound analysis
|
||||
|
||||
## File Inventory
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `generate_benchmark_data.py` | Build and run benchmarks across ~25 diverse problem sizes, output JSON |
|
||||
| `convert_json_to_parquet.py` | Convert benchmark JSON to parquet training format, fix `_mem` pad flags |
|
||||
| `data_pipeline.py` | Parse raw benchmark logs into canonical parquet datasets |
|
||||
| `feature_engine.py` | 55-feature extraction: problem, kernel, interaction, hardware profile |
|
||||
| `train.py` | Multi-target LGBMRegressor training with GroupKFold CV, IHEM, warm-start |
|
||||
| `predict.py` | Predictor class: predict TFLOPS/latency/bandwidth, rank kernels |
|
||||
| `evaluate.py` | Full evaluation: global metrics, per-shape/layout/pipeline slices |
|
||||
| `search.py` | Surrogate search: discrete DE, random top-K |
|
||||
| `generate_wide_coverage.py` | Generate benchmark data across 706 diverse shapes |
|
||||
| `generate_edge_dims.py` | Generate N=1, K=1, and other edge-case shapes |
|
||||
| `DATA_GENERATION.md` | Detailed guide for building binaries and generating data |
|
||||
| `plan.md` | Full design plan with architecture, milestones, and rationale |
|
||||
|
||||
## Features Used (55 total)
|
||||
|
||||
### Problem features (13)
|
||||
`M, N, K, split_k, log2(M), log2(N), log2(K), log2(MNK),
|
||||
arithmetic_intensity, aspect_ratio_mn, aspect_ratio_mk, aspect_ratio_nk, layout`
|
||||
|
||||
### Kernel features (17)
|
||||
`tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n,
|
||||
warp_tile_k, pipeline, scheduler, epilogue, pad_m, pad_n, pad_k, persistent,
|
||||
num_warps, tile_volume, tile_mn, lds_usage_estimate, lds_usage_ratio`
|
||||
|
||||
### Interaction features (9)
|
||||
`num_tiles_m, num_tiles_n, num_tiles_k, total_output_tiles,
|
||||
tile_eff_m, tile_eff_n, tile_eff_k, overall_tile_efficiency, cu_utilization`
|
||||
|
||||
### Hardware profile features (12)
|
||||
`hw_num_cus, hw_simds_per_cu, hw_total_simds, hw_shader_engines,
|
||||
hw_max_clock_mhz, hw_max_waves_per_cu, hw_wavefront_size, hw_lds_capacity,
|
||||
hw_l1_cache_kb, hw_l2_cache_kb, hw_l3_cache_kb, hw_num_xcd`
|
||||
|
||||
## Model Performance
|
||||
|
||||
### fp8 RCR, gfx950
|
||||
|
||||
| Metric | 108 shapes (original) | 168 shapes (wide coverage) |
|
||||
|---|---|---|
|
||||
| Mean TFLOPS Efficiency | 98.28% | 97.51% |
|
||||
| P10 TFLOPS Efficiency | 94.64% | 93.89% |
|
||||
| tiny_m (M=1) Efficiency | 95.57% | 96.04% |
|
||||
| R2 (TFLOPS) | 0.997 | 0.993 |
|
||||
|
||||
### fp16 RCR, gfx950
|
||||
|
||||
Trained on 25 shapes, 1,024 kernels, 21,920 valid benchmarks.
|
||||
|
||||
| Metric | Value |
|
||||
|---|---|
|
||||
| Mean TFLOPS Efficiency | 99.36% |
|
||||
| P10 TFLOPS Efficiency | 98.05% |
|
||||
| P50 TFLOPS Efficiency | 100.00% |
|
||||
| Min Efficiency | 95.45% |
|
||||
| NDCG@1 | 64.00% |
|
||||
| Top-5 Hit Rate | 88.00% |
|
||||
|
||||
**Shape Family Breakdown:**
|
||||
|
||||
| Shape Family | Mean Eff | P10 Eff | Shapes |
|
||||
|---|---|---|---|
|
||||
| Large M (M≥1024) | 99.54% | 99.07% | 4 |
|
||||
| Medium M (128≤M<1024) | 99.62% | 98.74% | 7 |
|
||||
| Small M (8≤M<128) | 98.82% | 96.22% | 8 |
|
||||
| Tiny M (M<8) | 99.65% | 98.96% | 6 |
|
||||
|
||||
**Pipeline Breakdown:**
|
||||
|
||||
| Pipeline | Mean Eff | P10 Eff |
|
||||
|---|---|---|
|
||||
| compv3 | 99.75% | 99.09% |
|
||||
| compv4 | 99.40% | 98.54% |
|
||||
| mem | 99.08% | 96.59% |
|
||||
|
||||
Training uses `log1p(TFLOPS)` as the target by default, which normalizes the
|
||||
scale across shapes spanning 0.02 to 2230 TFLOPS. This was the key finding
|
||||
that improved tiny-M shapes from 84% to 96% efficiency. See
|
||||
[LEARNINGS.md](LEARNINGS.md) for details.
|
||||
|
||||
## Validation
|
||||
|
||||
Training uses `GroupKFold(n_splits=5)` with group key `(M, N, K)` to ensure
|
||||
the model is evaluated on shapes it has never seen during training. Layout is
|
||||
excluded from the group key to force cross-layout generalization.
|
||||
|
||||
## Incremental Training (Warm Start)
|
||||
|
||||
When new benchmark data arrives, update the model without retraining from scratch:
|
||||
|
||||
```bash
|
||||
python3 train.py \
|
||||
--data_dir data/ \
|
||||
--out_dir models/v2 \
|
||||
--warm_start models/gemm_universal_fp8_gfx950 \
|
||||
--warm_start_n_estimators 200
|
||||
```
|
||||
|
||||
This adds 200 new trees on top of the existing model. Feature schemas must
|
||||
match exactly (automatically enforced).
|
||||
|
||||
## Extending to New Ops
|
||||
|
||||
Adding support for a new operation (e.g., `gemm_streamk`, `grouped_conv`):
|
||||
|
||||
1. **Build binaries**: `ninja -C build benchmark_gemm_streamk_fp8_rcr`
|
||||
2. **Subclass `FeatureEngine`**: add op-specific features (e.g., StreamK split factor)
|
||||
3. **Generate data**: run benchmarks across diverse shapes
|
||||
4. **Train**: `python3 train.py --op gemm_streamk --dtype fp8 --data_dir data/ --out_dir models/`
|
||||
|
||||
The training, evaluation, prediction, and search infrastructure is fully
|
||||
op-agnostic -- only the feature engine needs a new subclass.
|
||||
|
||||
## Tests
|
||||
|
||||
102 tests covering all modules:
|
||||
|
||||
```bash
|
||||
python3 -m pytest tests/ -v
|
||||
```
|
||||
|
||||
Test coverage includes:
|
||||
- Log parsing with malformed JSON, empty logs, single-kernel shapes
|
||||
- Feature formula correctness (tile efficiency, LDS usage, arithmetic intensity)
|
||||
- Corner-case shapes: M=1, N=1, K=1, prime dimensions, 20480x7168x256
|
||||
- Batch vs single extraction parity
|
||||
- Parameter space validation and projection
|
||||
- Predictor: single/batch prediction, ranking, missing models, empty inputs
|
||||
- Training: group keys, efficiency computation, warm-start, feature compatibility
|
||||
- Search: random, DE, config validity, determinism
|
||||
|
||||
## Documentation
|
||||
|
||||
- **[README.md](README.md)**: This file -- quick start, architecture, performance
|
||||
- **[DATA_GENERATION.md](DATA_GENERATION.md)**: Complete guide for building tile engine
|
||||
binaries, running benchmarks, managing datasets, and troubleshooting
|
||||
- **[LEARNINGS.md](LEARNINGS.md)**: Empirical findings and design decisions (log-transform,
|
||||
IHEM results, tiny-M analysis, feature importance, N=1/K=1 edge cases)
|
||||
4
dispatcher/heuristics/__init__.py
Normal file
4
dispatcher/heuristics/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Tile Heuristics: ML-based kernel selection
|
||||
67
dispatcher/heuristics/collect_additional.sh
Executable file
67
dispatcher/heuristics/collect_additional.sh
Executable file
@@ -0,0 +1,67 @@
|
||||
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# Generate additional benchmark data for shapes NOT in the original log.
# Runs in background; outputs streaming JSON that can be parsed by data_pipeline.py.

BIN_DIR="/workspace/ck_tile/bin"
OUT_LOG="data/additional_shapes.log"
WARMUP=3
REPEAT=10

mkdir -p data

# Additional shapes: square powers-of-2 and common ML sizes not in original DeepSeek set
SHAPES=(
    "64,64,64"
    "128,128,128"
    "256,256,256"
    "512,512,512"
    "1024,1024,1024"
    "2048,2048,2048"
    "4096,4096,4096"
    "1,4096,4096"
    "8,4096,4096"
    "32,4096,4096"
    "128,4096,4096"
    "1,4096,11008"
    "32,4096,11008"
    "1,8192,8192"
    "32,8192,8192"
    "1,8192,28672"
    "32,8192,28672"
    "256,256,8192"
    "8192,8192,256"
    "1024,4096,1024"
    "4096,1024,4096"
    "2048,8192,2048"
)

# Log header (single grouped redirect truncates and writes in one pass).
{
    echo "CK Tile Additional Shapes Benchmark"
    echo "GPU ID: 0"
    echo "Implementation: gemm_universal"
    echo ""
} > "$OUT_LOG"

# Run every benchmark binary against one (M, N, K) shape, appending the
# JSON block of each run to the shared log.
run_shape() {
    local M="$1" N="$2" K="$3" idx="$4"
    local count=0
    local exe

    {
        echo "========================================"
        echo "Shape $idx: M=$M N=$N K=$K dtype=fp8 layout=rcr"
        echo "========================================"
    } >> "$OUT_LOG"

    for exe in "$BIN_DIR"/benchmark_gemm_universal_fp8_rcr_*; do
        count=$((count + 1))
        # Keep only the JSON block from the benchmark's stdout.
        "$exe" -m="$M" -n="$N" -k="$K" -warmup=$WARMUP -repeat=$REPEAT -verify=0 2>/dev/null \
            | sed -n '/{/,/^}/p' >> "$OUT_LOG"
    done

    echo "Found $count kernels" >> "$OUT_LOG"
    echo "Completed shape $idx: M=$M N=$N K=$K ($count kernels)" >&2
}

SHAPE_IDX=0
for SHAPE in "${SHAPES[@]}"; do
    IFS=',' read -r M N K <<< "$SHAPE"
    SHAPE_IDX=$((SHAPE_IDX + 1))
    run_shape "$M" "$N" "$K" "$SHAPE_IDX"
done

echo "Done generating additional data" >&2
|
||||
233
dispatcher/heuristics/convert_json_to_parquet.py
Normal file
233
dispatcher/heuristics/convert_json_to_parquet.py
Normal file
@@ -0,0 +1,233 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Convert benchmark JSON results to parquet format for training.
|
||||
|
||||
Usage:
|
||||
python convert_json_to_parquet.py \
|
||||
--input benchmark_results_fp16_rcr.json \
|
||||
--output fp16_training_data.parquet
|
||||
|
||||
Features:
|
||||
- Converts JSON benchmark results to flat row format
|
||||
- Automatically fixes pad flags for _mem kernels
|
||||
- Captures both successes and failures
|
||||
- Compatible with existing training data format
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Bytes per element for supported data types (used for bandwidth estimates).
# Unknown dtypes fall back to 2 bytes, preserving the previous FP16 behavior.
_DTYPE_BYTES = {
    "fp32": 4,
    "f32": 4,
    "fp16": 2,
    "f16": 2,
    "bf16": 2,
    "fp8": 1,
    "bf8": 1,
    "int8": 1,
    "i8": 1,
}


def convert_json_to_parquet(json_file: Path, output_file: Path, arch: str = "gfx950"):
    """Convert benchmark JSON to parquet training data format.

    Args:
        json_file: JSON produced by generate_benchmark_data.py
            (``metadata`` block plus per-kernel ``benchmarks`` lists).
        output_file: Destination parquet path.
        arch: GPU architecture tag stored in every row.

    Returns:
        The converted DataFrame (also written to ``output_file``).
    """

    print(f"Loading {json_file}...")
    with open(json_file) as f:
        data = json.load(f)

    metadata = data.get("metadata", {})
    dtype = metadata.get("dtype", "fp16")
    layout = metadata.get("layout", "rcr")

    print(f"  Data type: {dtype}")
    print(f"  Layout: {layout}")
    print(f"  Kernels: {metadata.get('num_kernels', 0)}")
    print(f"  Problem sizes: {metadata.get('num_problems', 0)}")
    print()

    # BUGFIX: element size now follows the dataset dtype instead of being
    # hardcoded to 2 bytes (FP16); fp8 bandwidth was previously inflated 2x.
    elem_bytes = _DTYPE_BYTES.get(dtype, 2)

    rows = []
    for kernel_result in data["results"]:
        kernel_config = kernel_result["kernel_config"]

        for benchmark in kernel_result["benchmarks"]:
            # Common fields for both valid and invalid runs
            row = {
                "op_type": "gemm_universal",
                "dtype": dtype,
                "layout": layout,
                "arch": arch,
                "kernel_name": kernel_config["name"],
                "m": benchmark["m"],
                "n": benchmark["n"],
                "k": benchmark["k"],
                "split_k": 1,
                "is_valid": benchmark["is_valid"],
                "run_id": 0,
                "pipeline": kernel_config["pipeline"],
                "epilogue": kernel_config["epilogue"],
                "scheduler": kernel_config["scheduler"],
                "pad_m": kernel_config["pad_m"],
                "pad_n": kernel_config["pad_n"],
                "pad_k": kernel_config["pad_k"],
                "persistent": kernel_config["persistent"],
                "tile_m": kernel_config["tile_m"],
                "tile_n": kernel_config["tile_n"],
                "tile_k": kernel_config["tile_k"],
                "warp_m": kernel_config["warp_m"],
                "warp_n": kernel_config["warp_n"],
                "warp_k": kernel_config["warp_k"],
                "warp_tile_m": kernel_config["warp_tile_m"],
                "warp_tile_n": kernel_config["warp_tile_n"],
                "warp_tile_k": kernel_config["warp_tile_k"],
            }

            if benchmark["is_valid"]:
                # Valid run - include performance metrics
                row["measured_tflops"] = benchmark["tflops"]
                row["latency_ms"] = benchmark["avg_time_ms"]
                # Bandwidth: A (m*k) + B (k*n) + C (m*n) elements, each
                # elem_bytes wide, moved once per kernel invocation.
                m, n, k = benchmark["m"], benchmark["n"], benchmark["k"]
                bytes_transferred = (m * k + k * n + m * n) * elem_bytes
                if benchmark["avg_time_ms"] > 0:
                    row["bandwidth_gb_s"] = (bytes_transferred / 1e9) / (
                        benchmark["avg_time_ms"] / 1000
                    )
                else:
                    row["bandwidth_gb_s"] = 0.0
            else:
                # Failed run - zero metrics (failures stay in the dataset so
                # the model can learn which configs are invalid for a shape)
                row["measured_tflops"] = 0.0
                row["latency_ms"] = 0.0
                row["bandwidth_gb_s"] = 0.0

            rows.append(row)

    df = pd.DataFrame(rows)

    print(f"Converted {len(df):,} benchmark results")
    print(f"  Valid: {df['is_valid'].sum():,}")
    print(f"  Failed: {(~df['is_valid']).sum():,}")
    print()

    # Fix pad flags for _mem kernels (critical for P1 features!)
    print("Fixing pad flags for _mem kernels...")
    mem_mask = df["pipeline"] == "mem"
    mem_count = mem_mask.sum()

    if mem_count > 0:
        df.loc[mem_mask, "pad_m"] = True
        df.loc[mem_mask, "pad_n"] = True
        df.loc[mem_mask, "pad_k"] = True
        print(f"  ✓ Fixed {mem_count:,} _mem kernel rows")
    print()

    # Save to parquet
    df.to_parquet(output_file, index=False)
    print(f"✓ Saved to {output_file}")
    print()

    # Show statistics
    print("=" * 80)
    print("STATISTICS")
    print("=" * 80)
    print()

    print("Dimension ranges:")
    print(f"  M: {df['m'].min():,} - {df['m'].max():,}")
    print(f"  N: {df['n'].min():,} - {df['n'].max():,}")
    print(f"  K: {df['k'].min():,} - {df['k'].max():,}")
    print()

    print("Pipeline distribution:")
    print(df["pipeline"].value_counts())
    print()

    print("Pad flag distribution:")
    pad_combos = df[["pad_m", "pad_n", "pad_k"]].value_counts()
    print(pad_combos)
    print()

    if (~df["is_valid"]).sum() > 0:
        print("Failure analysis:")
        failed = df[~df["is_valid"]]
        print(f"  Total failures: {len(failed):,}")

        # Group by pipeline
        print("\n  By pipeline:")
        for pipeline, count in failed["pipeline"].value_counts().items():
            print(f"    {pipeline}: {count:,}")

        # Show sample failures
        print("\n  Sample failures:")
        for _, row in failed.head(5).iterrows():
            print(
                f"    {row['kernel_name'][:60]:60s} M={row['m']:4d} N={row['n']:4d} K={row['k']:4d}"
            )

    return df
|
||||
|
||||
|
||||
def merge_datasets(parquet_files: list[Path], output_file: Path):
    """Merge multiple parquet files into one.

    Missing inputs are reported and skipped; if nothing loads, no output
    file is written.
    """

    print("=" * 80)
    print("MERGING DATASETS")
    print("=" * 80)
    print()

    frames = []
    for source in parquet_files:
        if not source.exists():
            print(f"  ✗ {source} not found, skipping")
            continue
        frame = pd.read_parquet(source)
        print(f"  {source.name}: {len(frame):,} rows")
        frames.append(frame)

    if not frames:
        print("No files to merge!")
        return

    combined = pd.concat(frames, ignore_index=True)
    combined.to_parquet(output_file, index=False)

    print()
    print(f"✓ Merged {len(combined):,} total rows to {output_file}")
    print()

    # Quick sanity summary of what the merged dataset covers.
    print("Data type distribution:")
    print(combined["dtype"].value_counts())
    print()

    print("Layout distribution:")
    print(combined["layout"].value_counts())
|
||||
|
||||
|
||||
def main():
    """CLI entry point: convert one benchmark JSON, optionally merging extras."""
    parser = argparse.ArgumentParser(
        description="Convert benchmark JSON to parquet training data",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--input", type=str, required=True, help="Input JSON file from benchmark"
    )
    parser.add_argument("--output", type=str, required=True, help="Output parquet file")
    parser.add_argument("--arch", type=str, default="gfx950", help="GPU architecture")
    parser.add_argument(
        "--merge_with", type=str, nargs="*", help="Additional parquet files to merge"
    )
    args = parser.parse_args()

    in_path = Path(args.input)
    out_path = Path(args.output)

    # Always convert; the resulting parquet becomes the first merge input.
    convert_json_to_parquet(in_path, out_path, args.arch)

    if args.merge_with:
        to_merge = [out_path, *(Path(f) for f in args.merge_with)]
        merged_path = out_path.parent / f"{out_path.stem}_merged.parquet"
        merge_datasets(to_merge, merged_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
394
dispatcher/heuristics/data_pipeline.py
Normal file
394
dispatcher/heuristics/data_pipeline.py
Normal file
@@ -0,0 +1,394 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Data pipeline for CK Tile heuristics.
|
||||
|
||||
Parses benchmark logs and structured JSON into a canonical parquet dataset.
|
||||
Supports:
|
||||
- Streaming log format (Shape N: headers + inline JSON) from ck_tile profiling runs
|
||||
- Structured JSON from generate_benchmark_data.py
|
||||
- Direct parquet passthrough
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# Canonical column order for the training parquet schema. Grouped by role so
# the schema is easier to audit; the concatenation preserves the exact order.
CANONICAL_COLUMNS = (
    # Problem identity
    ["op_type", "dtype", "layout", "arch", "kernel_name", "m", "n", "k", "split_k"]
    # Measured targets
    + ["measured_tflops", "latency_ms", "bandwidth_gb_s", "is_valid"]
    # Kernel tiling configuration
    + [
        "tile_m",
        "tile_n",
        "tile_k",
        "warp_m",
        "warp_n",
        "warp_k",
        "warp_tile_m",
        "warp_tile_n",
        "warp_tile_k",
    ]
    # Kernel variant flags
    + ["pipeline", "scheduler", "epilogue", "pad_m", "pad_n", "pad_k", "persistent"]
    # Provenance
    + ["run_id"]
)
|
||||
|
||||
|
||||
def parse_kernel_name(name: str) -> dict:
    """Extract kernel config fields from a gemm_universal kernel name.

    Name format:
        gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}
        _{padM}_{padN}_{padK}_{persistent}_{tileM}x{tileN}x{tileK}
        _{warpM}x{warpN}x{warpK}_{warpTileM}x{warpTileN}x{warpTileK}

    Returns:
        A dict of parsed fields. Unrecognized names return an empty dict;
        partially malformed names return whatever prefix parsed cleanly
        (the try/except keeps the parser best-effort).
    """
    result = {}
    try:
        # Layout is any 3-letter row/col combination ([rc]{3}); the previous
        # explicit alternation (rcr|rrr|crr|ccr) silently dropped the other
        # four valid layouts (rrc, rcc, crc, ccc). Existing names still parse
        # identically because the lazy dtype group matches the earliest split.
        prefix_match = re.match(r"gemm_universal_(\w+?)_([rc]{3})_(.*)", name)
        if not prefix_match:
            return result
        result["dtype"] = prefix_match.group(1)
        result["layout"] = prefix_match.group(2)
        remainder = prefix_match.group(3)

        parts = remainder.split("_")
        if len(parts) < 10:
            # Not enough components to contain flags + all three tile triples.
            return result

        result["pipeline"] = parts[0]
        result["epilogue"] = parts[1]
        result["scheduler"] = parts[2]
        # Boolean flags are serialized as literal "True"/"False" strings.
        result["pad_m"] = parts[3] == "True"
        result["pad_n"] = parts[4] == "True"
        result["pad_k"] = parts[5] == "True"
        result["persistent"] = parts[6] == "True"

        # The three "AxBxC" triples: block tile, warp layout, warp tile.
        tile_dims = parts[7].split("x")
        warp_dims = parts[8].split("x")
        warp_tile_dims = parts[9].split("x")

        result["tile_m"] = int(tile_dims[0])
        result["tile_n"] = int(tile_dims[1])
        result["tile_k"] = int(tile_dims[2])
        result["warp_m"] = int(warp_dims[0])
        result["warp_n"] = int(warp_dims[1])
        result["warp_k"] = int(warp_dims[2])
        result["warp_tile_m"] = int(warp_tile_dims[0])
        result["warp_tile_n"] = int(warp_tile_dims[1])
        result["warp_tile_k"] = int(warp_tile_dims[2])
    except (IndexError, ValueError):
        # Best-effort parse: keep whatever was extracted before the failure.
        pass
    return result
|
||||
|
||||
|
||||
def _layout_from_problem(problem: dict) -> str:
|
||||
"""Derive layout shorthand (rcr/rrr/etc.) from problem JSON fields."""
|
||||
la = problem.get("layout_a", "")
|
||||
lb = problem.get("layout_b", "")
|
||||
lc = problem.get("layout_c", "")
|
||||
|
||||
def _tag(s):
|
||||
s = s.lower()
|
||||
if "row" in s:
|
||||
return "r"
|
||||
if "col" in s:
|
||||
return "c"
|
||||
return "?"
|
||||
|
||||
return _tag(la) + _tag(lb) + _tag(lc)
|
||||
|
||||
|
||||
def parse_streaming_log(
    path: str | Path,
    arch: str = "unknown",
    run_id: Optional[str] = None,
    op_type: str = "gemm_universal",
) -> pd.DataFrame:
    """Parse a CK Tile streaming benchmark log into a canonical DataFrame.

    The log alternates between shape headers and JSON result blocks:
        Shape N: M=16 N=1536 K=7168 dtype=fp8 layout=rcr
        {
          "name": "gemm_universal_...",
          "problem": { ... },
          "perf_result": { "latency(ms)": ..., "tflops(TFlops)": ..., "bandwidth(GB/s)": ... }
        }
    """
    path = Path(path)
    if run_id is None:
        # Stable pseudo-id derived from the log file name.
        run_id = hashlib.md5(path.name.encode()).hexdigest()[:12]

    shape_re = re.compile(
        r"Shape\s+\d+:\s+M=(\d+)\s+N=(\d+)\s+K=(\d+)\s+dtype=(\w+)\s+layout=(\w+)"
    )

    rows = []
    # Context carried forward from the most recent "Shape" header.
    ctx = {"m": 0, "n": 0, "k": 0, "dtype": "", "layout": ""}
    buf: list = []
    depth = 0

    with open(path, "r") as f:
        for raw_line in f:
            text = raw_line.strip()

            header = shape_re.search(text)
            if header:
                ctx["m"] = int(header.group(1))
                ctx["n"] = int(header.group(2))
                ctx["k"] = int(header.group(3))
                ctx["dtype"] = header.group(4)
                ctx["layout"] = header.group(5)
                continue

            # Accumulate a (possibly multi-line) JSON object by brace depth.
            if depth == 0 and text.startswith("{"):
                buf = [text]
                depth = text.count("{") - text.count("}")
                if depth != 0:
                    continue
            elif depth > 0:
                buf.append(text)
                depth += text.count("{") - text.count("}")
                if depth > 0:
                    continue
                depth = 0
            else:
                continue

            try:
                obj = json.loads("\n".join(buf))
            except json.JSONDecodeError:
                continue

            # Object parsed successfully: flatten into a canonical row.
            kernel_name = obj.get("name", "")
            problem = obj.get("problem", {})
            perf = obj.get("perf_result", {})

            tflops = perf.get("tflops(TFlops)", 0.0)
            latency = perf.get("latency(ms)", 0.0)

            entry = {
                "op_type": op_type,
                "dtype": problem.get("dtype_a", ctx["dtype"]),
                "layout": (
                    _layout_from_problem(problem)
                    if problem.get("layout_a")
                    else ctx["layout"]
                ),
                "arch": arch,
                "kernel_name": kernel_name,
                "m": problem.get("m", ctx["m"]),
                "n": problem.get("n", ctx["n"]),
                "k": problem.get("k", ctx["k"]),
                "split_k": problem.get("split_k", 1),
                "measured_tflops": tflops,
                "latency_ms": latency,
                "bandwidth_gb_s": perf.get("bandwidth(GB/s)", 0.0),
                "is_valid": tflops > 0 and latency > 0,
                "run_id": run_id,
            }
            entry.update(parse_kernel_name(kernel_name))
            rows.append(entry)

    df = pd.DataFrame(rows)
    # Guarantee the canonical schema even when some columns never appeared.
    for col in CANONICAL_COLUMNS:
        if col not in df.columns:
            df[col] = None
    return df
|
||||
|
||||
|
||||
def get_hardware_profile() -> dict:
    """Capture GPU hardware profile from rocminfo.

    Runs ``rocminfo`` and scrapes key/value pairs from the first GPU agent
    section, then cache sizes. Returns an empty dict when ``rocminfo`` is
    missing or times out.
    """
    profile = {}
    try:
        result = subprocess.run(
            ["rocminfo"], capture_output=True, text=True, timeout=30
        )
        output = result.stdout

        # Pass 1: collect fields only from the first GPU agent section.
        gpu_section = False
        for line in output.split("\n"):
            line = line.strip()
            if "Device Type:" in line and "GPU" in line:
                gpu_section = True
                continue
            # Stop at the next non-GPU agent so we read only one GPU.
            if gpu_section and "Device Type:" in line and "GPU" not in line:
                break
            if not gpu_section:
                continue

            if ":" not in line:
                continue
            key, _, val = line.partition(":")
            key = key.strip()
            val = val.strip()

            # rocminfo label -> profile field name.
            mapping = {
                "Name": "gfx_name",
                "Marketing Name": "marketing_name",
                "Compute Unit": "num_cus",
                "SIMDs per CU": "simds_per_cu",
                "Shader Engines": "shader_engines",
                "Shader Arrs. per Eng.": "shader_arrays_per_engine",
                "Max Clock Freq. (MHz)": "max_clock_mhz",
                "Wavefront Size": "wavefront_size",
                "Max Waves Per CU": "max_waves_per_cu",
                "Chip ID": "chip_id",
            }

            if key in mapping:
                # Values may carry a parenthesized suffix; keep the prefix.
                raw = val.split("(")[0].strip()
                try:
                    profile[mapping[key]] = int(raw)
                except ValueError:
                    # Non-numeric values (e.g. "gfx950") are stored as-is.
                    profile[mapping[key]] = raw

        # Pass 2: cache sizes. NOTE(review): this scans the full output,
        # not just the GPU section, so cache lines from other agents could
        # be picked up — confirm rocminfo layout on target machines.
        for line in output.split("\n"):
            line = line.strip()
            # The "num_cus" guard ties L1 capture to having seen a GPU agent.
            if line.startswith("L1:") and "num_cus" in profile:
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l1_cache_kb"] = int(raw)
                except ValueError:
                    pass
            elif line.startswith("L2:"):
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l2_cache_kb"] = int(raw)
                except ValueError:
                    pass
            elif line.startswith("L3:"):
                raw = line.split(":")[1].strip().split("(")[0].strip()
                try:
                    profile["l3_cache_kb"] = int(raw)
                except ValueError:
                    pass

    except (subprocess.TimeoutExpired, FileNotFoundError):
        # Best effort: no rocminfo available -> empty profile.
        pass

    return profile
|
||||
|
||||
|
||||
def load_parquet(path: str | Path) -> pd.DataFrame:
    """Load a canonical parquet dataset.

    Thin wrapper over :func:`pandas.read_parquet`; kept as the symmetric
    counterpart of :func:`save_parquet`.
    """
    return pd.read_parquet(path)
|
||||
|
||||
|
||||
def save_parquet(df: pd.DataFrame, path: str | Path):
    """Save a DataFrame in canonical parquet format.

    Creates parent directories as needed; writes without the index using
    the pyarrow engine.
    """
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(out_path, index=False, engine="pyarrow")
|
||||
|
||||
|
||||
def build_training_dataset(
    data_dir: str | Path,
    op_type: str = "gemm_universal",
    dtype: str = "fp8",
) -> pd.DataFrame:
    """Load and merge all parquet files matching the given op/dtype from a directory.

    Raises FileNotFoundError when no file contributes any matching rows.
    """
    directory = Path(data_dir)
    collected = []
    for parquet_file in sorted(directory.glob("*.parquet")):
        frame = pd.read_parquet(parquet_file)
        # Filter on whichever of the two discriminator columns exist.
        for col, wanted in (("op_type", op_type), ("dtype", dtype)):
            if col in frame.columns:
                frame = frame[frame[col] == wanted]
        if len(frame) > 0:
            collected.append(frame)
    if not collected:
        raise FileNotFoundError(
            f"No parquet files with op_type={op_type}, dtype={dtype} in {data_dir}"
        )
    return pd.concat(collected, ignore_index=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    import time

    # CLI: parse a benchmark log (or pass through a parquet file), print a
    # summary, optionally attach the hardware profile, and save as parquet.
    cli = argparse.ArgumentParser(description="Parse CK Tile benchmark data")
    cli.add_argument("input", help="Input file (log or parquet)")
    cli.add_argument("--output", "-o", required=True, help="Output parquet path")
    cli.add_argument("--arch", default="gfx950", help="GPU architecture")
    cli.add_argument("--op_type", default="gemm_universal", help="Operation type")
    cli.add_argument(
        "--capture_hw",
        action="store_true",
        help="Capture hardware profile from rocminfo",
    )
    args = cli.parse_args()

    input_path = Path(args.input)

    print(f"Parsing {input_path}...")
    start = time.time()

    if input_path.suffix == ".parquet":
        df = load_parquet(input_path)
    else:
        df = parse_streaming_log(input_path, arch=args.arch, op_type=args.op_type)

    elapsed = time.time() - start
    print(f"Parsed {len(df)} rows in {elapsed:.1f}s")
    print(f" Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    print(f" Unique kernels: {df['kernel_name'].nunique()}")
    print(f" Valid rows: {df['is_valid'].sum()} / {len(df)}")

    if df["measured_tflops"].max() > 0:
        print(
            f" TFLOPS range: {df['measured_tflops'].min():.2f} - {df['measured_tflops'].max():.2f}"
        )

    if args.capture_hw:
        hw = get_hardware_profile()
        print(f" Hardware profile: {hw}")
        # Broadcast each hardware field onto every row as an hw_* column.
        for hw_key, hw_val in hw.items():
            df[f"hw_{hw_key}"] = hw_val

    save_parquet(df, args.output)
    print(f"Saved to {args.output}")
|
||||
324
dispatcher/heuristics/dispatcher_integration.py
Normal file
324
dispatcher/heuristics/dispatcher_integration.py
Normal file
@@ -0,0 +1,324 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Dispatcher integration for ML-based kernel selection.
|
||||
|
||||
Bridges the trained LightGBM Predictor with the CK Tile dispatcher's
|
||||
kernel selection flow. Provides heuristic functions compatible with
|
||||
both the Python pre-selection pattern (08_heuristics.py style) and
|
||||
the C++ HeuristicFunction signature.
|
||||
|
||||
Name mapping between feature engine and dispatcher KernelConfig:
|
||||
Feature engine Dispatcher KernelConfig
|
||||
--------------------- ----------------------
|
||||
warp_m (warps/block) wave_m
|
||||
warp_n wave_n
|
||||
warp_k wave_k
|
||||
warp_tile_m warp_m
|
||||
warp_tile_n warp_n
|
||||
warp_tile_k warp_k
|
||||
|
||||
Usage:
|
||||
from dispatcher_integration import create_ml_heuristic
|
||||
|
||||
heuristic = create_ml_heuristic("models/gemm_universal_fp8_gfx950")
|
||||
best_spec = heuristic(M=1024, N=1024, K=1024, kernel_pool=KERNEL_POOL)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from data_pipeline import parse_kernel_name
|
||||
from predict import Predictor
|
||||
|
||||
|
||||
# Layout shorthand -> (layout_a, layout_b, layout_c) dispatcher strings.
LAYOUT_TO_DISPATCHER = {
    "rcr": ("row", "col", "row"),
    "rrr": ("row", "row", "row"),
    "crr": ("col", "row", "row"),
    "ccr": ("col", "col", "row"),
}

# Input dtype -> output (C-matrix) dtype. fp8 inputs produce fp16 output;
# wider types keep their own dtype.
DTYPE_TO_C_DTYPE = {
    "fp8": "fp16",
    "fp16": "fp16",
    "bf16": "bf16",
    "fp32": "fp32",
}
|
||||
|
||||
|
||||
@dataclass
class MLKernelSpec:
    """Kernel spec returned by the ML heuristic, compatible with the dispatcher
    example pattern. Carries both the feature-engine-space config and the
    dispatcher-space KernelConfig fields."""

    # Identity and model score.
    kernel_name: str
    predicted_tflops: float

    # Block tile sizes.
    tile_m: int
    tile_n: int
    tile_k: int
    # Waves (warps) per block — feature-engine "warp_*" fields.
    wave_m: int
    wave_n: int
    wave_k: int
    # Per-warp tile sizes — feature-engine "warp_tile_*" fields.
    warp_m: int
    warp_n: int
    warp_k: int
    # Pipeline / scheduler / epilogue selection.
    pipeline: str
    scheduler: str
    epilogue: str
    # Dimension padding flags and persistent-kernel flag.
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool
|
||||
|
||||
|
||||
def kernel_config_to_feature_dict(kernel_name: str) -> dict:
    """Parse a tile-engine kernel name into a feature-engine-compatible dict.

    Returns a dict with fields matching what GemmUniversalFeatureEngine.extract()
    expects for the kernel parameter: tile_m/n/k, warp_m/n/k (warps per block),
    warp_tile_m/n/k, pipeline, scheduler, epilogue, pad_m/n/k, persistent.
    Returns {} when the name does not parse.
    """
    feat = parse_kernel_name(kernel_name)
    if feat:
        feat["kernel_name"] = kernel_name
        return feat
    return {}
|
||||
|
||||
|
||||
def feature_dict_to_dispatcher_config(
    feat: dict, dtype: str = "fp8", arch: str = "gfx950"
) -> dict:
    """Convert a feature-engine kernel dict to dispatcher KernelConfig fields.

    Handles the naming inversion:
        feature engine warp_m      -> KernelConfig wave_m (warps per block)
        feature engine warp_tile_m -> KernelConfig warp_m (elements per warp)
    """
    la, lb, lc = LAYOUT_TO_DISPATCHER.get(
        feat.get("layout", "rcr"), ("row", "col", "row")
    )

    config = {
        "dtype_a": dtype,
        "dtype_b": dtype,
        "dtype_c": DTYPE_TO_C_DTYPE.get(dtype, dtype),
        "dtype_acc": "fp32",
        "layout_a": la,
        "layout_b": lb,
        "layout_c": lc,
    }

    # (dispatcher field, feature-engine field, default) — note the
    # wave_*/warp_* inversion between the two naming schemes.
    tile_fields = [
        ("tile_m", "tile_m", 128),
        ("tile_n", "tile_n", 128),
        ("tile_k", "tile_k", 64),
        ("wave_m", "warp_m", 2),
        ("wave_n", "warp_n", 2),
        ("wave_k", "warp_k", 1),
        ("warp_m", "warp_tile_m", 32),
        ("warp_n", "warp_tile_n", 32),
        ("warp_k", "warp_tile_k", 16),
        ("pipeline", "pipeline", "compv3"),
        ("scheduler", "scheduler", "intrawave"),
        ("epilogue", "epilogue", "cshuffle"),
        ("pad_m", "pad_m", True),
        ("pad_n", "pad_n", True),
        ("pad_k", "pad_k", True),
    ]
    for dst, src, default in tile_fields:
        config[dst] = feat.get(src, default)

    config["gfx_arch"] = arch
    return config
|
||||
|
||||
|
||||
def feature_dict_to_ml_spec(feat: dict, predicted_tflops: float = 0.0) -> MLKernelSpec:
    """Convert a feature-engine kernel dict + prediction to an MLKernelSpec.

    Missing fields fall back to conservative defaults (128x128x64 tile,
    2x2x1 waves, 32x32x16 warp tile, compv3/intrawave/cshuffle, no padding).
    """
    g = feat.get
    return MLKernelSpec(
        kernel_name=g("kernel_name", "unknown"),
        predicted_tflops=predicted_tflops,
        tile_m=g("tile_m", 128),
        tile_n=g("tile_n", 128),
        tile_k=g("tile_k", 64),
        wave_m=g("warp_m", 2),
        wave_n=g("warp_n", 2),
        wave_k=g("warp_k", 1),
        warp_m=g("warp_tile_m", 32),
        warp_n=g("warp_tile_n", 32),
        warp_k=g("warp_tile_k", 16),
        pipeline=g("pipeline", "compv3"),
        scheduler=g("scheduler", "intrawave"),
        epilogue=g("epilogue", "cshuffle"),
        pad_m=g("pad_m", False),
        pad_n=g("pad_n", False),
        pad_k=g("pad_k", False),
        persistent=g("persistent", False),
    )
|
||||
|
||||
|
||||
def load_kernel_pool_from_binaries(bin_dir: str | Path) -> list[dict]:
    """Discover benchmark executables and parse their names into feature dicts.

    Each executable name encodes the full kernel config. This creates the
    candidate pool for the ML heuristic without needing a registry JSON export.
    """
    pool = []
    for exe in sorted(Path(bin_dir).glob("benchmark_gemm_universal_*")):
        kernel_name = exe.stem.replace("benchmark_", "")
        feat = kernel_config_to_feature_dict(kernel_name)
        # Keep only names that parsed far enough to yield tile dims.
        if feat and "tile_m" in feat:
            pool.append(feat)
    return pool
|
||||
|
||||
|
||||
def create_ml_heuristic(
    model_dir: str | Path,
    dtype: str = "fp8",
    arch: str = "gfx950",
    layout: str = "rcr",
    kernel_pool: Optional[list[dict]] = None,
    bin_dir: Optional[str | Path] = None,
):
    """Create an ML heuristic function for kernel selection.

    Returns a callable with signature:
        (M: int, N: int, K: int) -> MLKernelSpec

    The returned function scores all candidate kernels using the trained
    LightGBM regressor and returns the best one as an MLKernelSpec.

    Parameters
    ----------
    model_dir : str or Path
        Path to trained model directory (must contain model_tflops.lgbm or
        model_tflops_log_big.lgbm and feature_spec.json).
    dtype : str
        Data type for the problem (fp8, fp16, bf16).
    arch : str
        GPU architecture (gfx942, gfx950).
    layout : str
        Matrix layout (rcr, rrr, crr, ccr).
    kernel_pool : list of dict, optional
        Pre-parsed kernel configs. If None, loads from bin_dir.
    bin_dir : str or Path, optional
        Directory with benchmark executables. Used to build kernel_pool if
        kernel_pool is not provided. Defaults to /workspace/ck_tile/bin.
    """
    model_dir = Path(model_dir)
    predictor = Predictor(model_dir)

    if kernel_pool is None:
        if bin_dir is None:
            bin_dir = Path("/workspace/ck_tile/bin")
        kernel_pool = load_kernel_pool_from_binaries(bin_dir)

    if not kernel_pool:
        raise ValueError(
            "No kernel configs found. Check bin_dir or provide kernel_pool."
        )

    # Index the pool once so each call does an O(1) lookup instead of a
    # linear scan — consistent with create_ranked_heuristic.
    name_to_feat = {kp.get("kernel_name", ""): kp for kp in kernel_pool}

    def heuristic(M: int, N: int, K: int) -> MLKernelSpec:
        problem = {
            "m": M,
            "n": N,
            "k": K,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }

        ranked = predictor.rank_kernels(problem, kernel_pool)

        if not ranked:
            # No scores available: fall back to the first candidate.
            return feature_dict_to_ml_spec(kernel_pool[0], 0.0)

        best_name, best_tflops = ranked[0]
        best_feat = name_to_feat.get(best_name, kernel_pool[0])
        return feature_dict_to_ml_spec(best_feat, best_tflops)

    return heuristic
|
||||
|
||||
|
||||
def create_ranked_heuristic(
    model_dir: str | Path,
    dtype: str = "fp8",
    arch: str = "gfx950",
    layout: str = "rcr",
    kernel_pool: Optional[list[dict]] = None,
    bin_dir: Optional[str | Path] = None,
    top_k: int = 5,
):
    """Create an ML heuristic that returns the top-K ranked kernel specs.

    Returns a callable with signature:
        (M: int, N: int, K: int) -> list[MLKernelSpec]

    Useful when you want fallback options if the top-1 kernel fails to build.

    Raises
    ------
    ValueError
        If no kernel configs can be found (empty or undiscoverable pool),
        mirroring create_ml_heuristic.
    """
    model_dir = Path(model_dir)
    predictor = Predictor(model_dir)

    if kernel_pool is None:
        if bin_dir is None:
            bin_dir = Path("/workspace/ck_tile/bin")
        kernel_pool = load_kernel_pool_from_binaries(bin_dir)

    # Fail fast on an empty pool (previously this silently produced a
    # heuristic that could only return [] or raise IndexError later).
    if not kernel_pool:
        raise ValueError(
            "No kernel configs found. Check bin_dir or provide kernel_pool."
        )

    name_to_feat = {kp.get("kernel_name", ""): kp for kp in kernel_pool}

    def heuristic(M: int, N: int, K: int) -> list[MLKernelSpec]:
        problem = {
            "m": M,
            "n": N,
            "k": K,
            "dtype": dtype,
            "layout": layout,
            "split_k": 1,
        }

        ranked = predictor.rank_kernels(problem, kernel_pool)
        results = []
        for name, tflops in ranked[:top_k]:
            feat = name_to_feat.get(name, kernel_pool[0])
            results.append(feature_dict_to_ml_spec(feat, tflops))
        return results

    return heuristic
|
||||
|
||||
|
||||
def ml_spec_to_dispatcher_config(
    spec: MLKernelSpec, dtype: str = "fp8", arch: str = "gfx950"
) -> dict:
    """Convert an MLKernelSpec to a dict compatible with ctypes_utils.KernelConfig."""
    # NOTE(review): layout is fixed to row/col/row here; MLKernelSpec carries
    # no layout field, so non-rcr problems are not representable — confirm
    # this matches the callers' expectations.
    config = {
        "dtype_a": dtype,
        "dtype_b": dtype,
        "dtype_c": DTYPE_TO_C_DTYPE.get(dtype, dtype),
        "dtype_acc": "fp32",
        "layout_a": "row",
        "layout_b": "col",
        "layout_c": "row",
    }

    # These spec attributes map 1:1 onto identically-named config keys.
    for field in (
        "tile_m", "tile_n", "tile_k",
        "wave_m", "wave_n", "wave_k",
        "warp_m", "warp_n", "warp_k",
        "pipeline", "scheduler", "epilogue",
        "pad_m", "pad_n", "pad_k",
    ):
        config[field] = getattr(spec, field)

    config["gfx_arch"] = arch
    return config
|
||||
254
dispatcher/heuristics/evaluate.py
Normal file
254
dispatcher/heuristics/evaluate.py
Normal file
@@ -0,0 +1,254 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Evaluation and reporting for CK Tile kernel performance models.
|
||||
|
||||
Computes:
|
||||
- Global metrics: TFLOPS efficiency (mean, p10, p50, min), R2, NDCG@1, Top-K hit rate
|
||||
- Per-slice breakdowns: by layout, shape family, K-depth regime, pipeline
|
||||
- Cross-target consistency checks
|
||||
- Feature importance analysis
|
||||
|
||||
Usage:
|
||||
python evaluate.py --model_dir models/gemm_universal_fp8_gfx950 --data_dir data/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from data_pipeline import build_training_dataset
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
from predict import Predictor
|
||||
from train import compute_tflops_efficiency
|
||||
|
||||
|
||||
def classify_shape_family(m: int, n: int, k: int) -> str:
    """Classify a GEMM shape into an M-size family for sliced evaluation.

    Families (based on M only; n and k are accepted for interface
    stability but currently unused):
      - tiny_m:   M < 32   (single-token / very small batch inference)
      - small_m:  32 <= M < 256
      - medium_m: 256 <= M < 4096
      - large_m:  M >= 4096

    Note: the previous docstring also advertised square/tall/wide aspect
    families, but they were never implemented; the dead ``return "other"``
    branch (unreachable for integer M) has been removed.
    """
    if m < 32:
        return "tiny_m"
    if m < 256:
        return "small_m"
    if m < 4096:
        return "medium_m"
    return "large_m"
|
||||
|
||||
|
||||
def classify_k_regime(k: int) -> str:
    """Bucket the K dimension into a depth regime for sliced evaluation.

    shallow_k: K < 512, medium_k: 512 <= K < 4096, deep_k: K >= 4096.
    """
    if k < 512:
        return "shallow_k"
    if k < 4096:
        return "medium_k"
    return "deep_k"
|
||||
|
||||
|
||||
def evaluate_model(
    predictor: Predictor,
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
) -> dict:
    """Run full evaluation on a dataset. Returns a metrics dictionary.

    Parameters
    ----------
    predictor : Predictor
        Trained predictor with at least a TFLOPS model loaded.
    df : pd.DataFrame
        Benchmark data in canonical schema.
    feature_engine : GemmUniversalFeatureEngine
        Feature engine matching the trained model.

    Returns
    -------
    dict with keys: global_metrics, shape_family_metrics, k_regime_metrics,
    pipeline_metrics, per_shape_efficiency.

    Raises
    ------
    FileNotFoundError
        If no TFLOPS model can be loaded from the predictor.
    """
    # Keep only rows with a positive measurement; reset the index so
    # idxmax/ranking below works on a clean 0..n-1 index.
    valid = df[df["is_valid"].fillna(False) & (df["measured_tflops"] > 0)].copy()
    valid = valid.reset_index(drop=True)

    X = feature_engine.extract_batch(valid)
    model = predictor._load_model("tflops")
    if model is None:
        raise FileNotFoundError("No TFLOPS model found")

    # Predict and apply inverse log transform if model was trained in log-space
    raw_pred = model.predict(X)
    if "tflops" in predictor._log_targets:
        valid["pred_tflops"] = np.expm1(raw_pred)
    else:
        # Clamp to non-negative even for non-log models
        valid["pred_tflops"] = np.maximum(0.0, raw_pred)

    y_true = valid["measured_tflops"].values
    y_pred = valid["pred_tflops"].values

    # Regression metrics (R2 guarded against a zero-variance target).
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1 - ss_res / max(ss_tot, 1e-10)
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mae = np.mean(np.abs(y_true - y_pred))

    # Per-shape efficiency of the model-selected kernel vs the oracle best.
    eff_df = compute_tflops_efficiency(valid, "pred_tflops")

    # Ranking metrics: per shape, does the predicted-best kernel match the
    # measured-best (NDCG@1), or land within the top-K (hit rates)?
    ndcg1_count = 0
    total_shapes = 0
    topk_hits = {3: 0, 5: 0, 10: 0}

    for (m, n, k), group in valid.groupby(["m", "n", "k"]):
        if group["measured_tflops"].max() <= 0:
            continue
        total_shapes += 1
        oracle_idx = group["measured_tflops"].idxmax()
        pred_ranking = group.sort_values("pred_tflops", ascending=False).index.tolist()

        if pred_ranking[0] == oracle_idx:
            ndcg1_count += 1

        # Position of the oracle kernel inside the predicted ranking.
        oracle_rank = pred_ranking.index(oracle_idx)
        for topk in topk_hits:
            if oracle_rank < topk:
                topk_hits[topk] += 1

    global_metrics = {
        "r2": r2,
        "rmse": rmse,
        "mae": mae,
        "num_valid_rows": len(valid),
        "num_shapes": total_shapes,
        "efficiency_mean": float(eff_df["efficiency"].mean()) if len(eff_df) > 0 else 0,
        "efficiency_p10": float(eff_df["efficiency"].quantile(0.1))
        if len(eff_df) > 0
        else 0,
        "efficiency_p50": float(eff_df["efficiency"].quantile(0.5))
        if len(eff_df) > 0
        else 0,
        "efficiency_min": float(eff_df["efficiency"].min()) if len(eff_df) > 0 else 0,
        "ndcg_at_1": ndcg1_count / max(total_shapes, 1),
        "top3_hit_rate": topk_hits[3] / max(total_shapes, 1),
        "top5_hit_rate": topk_hits[5] / max(total_shapes, 1),
        "top10_hit_rate": topk_hits[10] / max(total_shapes, 1),
    }

    def _slice_efficiency(slice_df):
        # Summarize per-shape efficiency for one slice of the data.
        if len(slice_df) == 0:
            return {"count": 0}
        eff = compute_tflops_efficiency(slice_df, "pred_tflops")
        if len(eff) == 0:
            return {"count": 0}
        return {
            "count": len(eff),
            "mean": float(eff["efficiency"].mean()),
            "p10": float(eff["efficiency"].quantile(0.1)),
            "min": float(eff["efficiency"].min()),
        }

    # Sliced breakdowns: by shape family, K-depth regime, and pipeline.
    valid["shape_family"] = valid.apply(
        lambda r: classify_shape_family(r["m"], r["n"], r["k"]), axis=1
    )
    valid["k_regime"] = valid["k"].apply(classify_k_regime)

    shape_family_metrics = {}
    for family, group in valid.groupby("shape_family"):
        shape_family_metrics[family] = _slice_efficiency(group)

    k_regime_metrics = {}
    for regime, group in valid.groupby("k_regime"):
        k_regime_metrics[regime] = _slice_efficiency(group)

    pipeline_metrics = {}
    if "pipeline" in valid.columns:
        for pipeline, group in valid.groupby("pipeline"):
            pipeline_metrics[str(pipeline)] = _slice_efficiency(group)

    return {
        "global_metrics": global_metrics,
        "shape_family_metrics": shape_family_metrics,
        "k_regime_metrics": k_regime_metrics,
        "pipeline_metrics": pipeline_metrics,
        "per_shape_efficiency": eff_df.to_dict(orient="records")
        if len(eff_df) > 0
        else [],
    }
|
||||
|
||||
|
||||
def main():
    """CLI entry point: load data, evaluate the model, print a report and
    optionally dump the full metrics dict to JSON."""
    parser = argparse.ArgumentParser(description="Evaluate CK Tile performance model")
    parser.add_argument(
        "--model_dir", required=True, help="Directory with trained models"
    )
    parser.add_argument("--data_dir", required=True, help="Directory with parquet data")
    parser.add_argument("--op", default="gemm_universal")
    parser.add_argument("--dtype", default="fp8")
    parser.add_argument("--output", "-o", help="Output JSON path for metrics")
    args = parser.parse_args()

    print(f"Loading data from {args.data_dir}...")
    df = build_training_dataset(args.data_dir, op_type=args.op, dtype=args.dtype)
    print(f" {len(df)} rows, {df.groupby(['m', 'n', 'k']).ngroups} shapes")

    fe = GemmUniversalFeatureEngine()
    predictor = Predictor(args.model_dir, feature_engine=fe)

    print("Evaluating...")
    results = evaluate_model(predictor, df, fe)

    # Human-readable report: global metrics followed by per-slice breakdowns.
    gm = results["global_metrics"]
    print("\nGlobal Metrics:")
    print(f" R2: {gm['r2']:.4f}")
    print(f" RMSE: {gm['rmse']:.2f}")
    print(f" Efficiency Mean: {gm['efficiency_mean']:.4f}")
    print(f" Efficiency P10: {gm['efficiency_p10']:.4f}")
    print(f" Efficiency P50: {gm['efficiency_p50']:.4f}")
    print(f" Efficiency Min: {gm['efficiency_min']:.4f}")
    print(f" NDCG@1: {gm['ndcg_at_1']:.4f}")
    print(f" Top-3 Hit Rate: {gm['top3_hit_rate']:.4f}")
    print(f" Top-5 Hit Rate: {gm['top5_hit_rate']:.4f}")
    print(f" Top-10 Hit Rate: {gm['top10_hit_rate']:.4f}")

    print("\nShape Family Breakdown:")
    for family, metrics in sorted(results["shape_family_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f" {family:12s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} min={metrics['min']:.4f} (n={metrics['count']})"
            )

    print("\nK-Depth Regime Breakdown:")
    for regime, metrics in sorted(results["k_regime_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f" {regime:12s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} min={metrics['min']:.4f} (n={metrics['count']})"
            )

    print("\nPipeline Breakdown:")
    for pipeline, metrics in sorted(results["pipeline_metrics"].items()):
        if metrics.get("count", 0) > 0:
            print(
                f" {pipeline:15s}: mean={metrics['mean']:.4f} p10={metrics['p10']:.4f} (n={metrics['count']})"
            )

    if args.output:
        # default=str keeps numpy scalars and other non-JSON types printable.
        with open(args.output, "w") as f:
            json.dump(results, f, indent=2, default=str)
        print(f"\nFull results saved to {args.output}")
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
577
dispatcher/heuristics/feature_engine.py
Normal file
577
dispatcher/heuristics/feature_engine.py
Normal file
@@ -0,0 +1,577 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Feature engineering for CK Tile kernel performance prediction.
|
||||
|
||||
Provides a strict FeatureEngine interface with per-op subclasses.
|
||||
All feature engines produce a consistent numpy array for LightGBM.
|
||||
"""
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# Bytes per element for each supported data type (int4 is sub-byte packed).
DTYPE_BYTES = {
    "fp32": 4.0,
    "fp16": 2.0,
    "bf16": 2.0,
    "fp8": 1.0,
    "bf8": 1.0,
    "int8": 1.0,
    "int4": 0.5,
}

# Categorical string -> integer id encodings used as LightGBM features.
LAYOUT_MAP = {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3}
PIPELINE_MAP = {"compv3": 0, "compv4": 1, "compv5": 2, "mem": 3, "preshufflev2": 4}
SCHEDULER_MAP = {"intrawave": 0, "interwave": 1}
EPILOGUE_MAP = {"default": 0, "cshuffle": 1}
|
||||
|
||||
|
||||
class FeatureEngine(ABC):
    """Abstract base for per-op feature extraction.

    Subclasses define the feature schema (names, categoricals), the
    per-sample extraction, and optionally a discrete parameter space with
    validity constraints for surrogate-model search.
    """

    @abstractmethod
    def get_feature_names(self) -> list[str]:
        """Ordered list of feature names matching the output array columns."""
        ...

    @abstractmethod
    def get_categorical_features(self) -> list[str]:
        """Feature names that should be treated as categorical by LightGBM."""
        ...

    @abstractmethod
    def extract(self, problem: dict, kernel: dict) -> np.ndarray:
        """Extract a single feature vector from a (problem, kernel) pair."""
        ...

    def extract_batch(self, df: pd.DataFrame) -> np.ndarray:
        """Batch extraction from a DataFrame. Override for speed.

        Default implementation is a row-by-row Python loop: each DataFrame
        row is passed as both the problem dict and the kernel dict, since
        the canonical schema stores both in the same row.
        """
        names = self.get_feature_names()
        result = np.zeros((len(df), len(names)), dtype=np.float64)
        for i in range(len(df)):
            row = df.iloc[i]
            prob = row.to_dict()
            kern = row.to_dict()
            result[i] = self.extract(prob, kern)
        return result

    def get_parameter_space(self) -> dict[str, list]:
        """Valid discrete values for each kernel parameter (for surrogate search)."""
        return {}

    def get_constraints(self) -> list:
        """Multi-param constraint functions returning True if config is valid."""
        return []

    def validate_config(self, config: dict) -> bool:
        """Check all constraints. Returns True if the config is valid.

        A config fails if any known parameter takes a value outside its
        discrete domain, or if any constraint function rejects it.
        """
        ps = self.get_parameter_space()
        for k, valid_vals in ps.items():
            if k in config and config[k] not in valid_vals:
                return False
        for constraint in self.get_constraints():
            if not constraint(config):
                return False
        return True

    def project_to_valid(self, config: dict) -> dict:
        """Snap a config to the nearest valid discrete point.

        Numeric parameters snap to the closest allowed value; non-numeric
        parameters fall back to the first allowed value when invalid.
        Returns a new dict; the input is not mutated.
        """
        ps = self.get_parameter_space()
        result = dict(config)
        for k, valid_vals in ps.items():
            if k not in result:
                continue
            v = result[k]
            if isinstance(valid_vals[0], (int, float)):
                result[k] = min(valid_vals, key=lambda x: abs(x - v))
            elif v not in valid_vals:
                result[k] = valid_vals[0]
        return result
|
||||
|
||||
|
||||
class GemmUniversalFeatureEngine(FeatureEngine):
    """Feature engine for gemm_universal kernels.

    Produces a 72-column feature vector combining problem shape, kernel
    tiling/trait parameters, derived interaction terms, and a static
    hardware profile. Column order is fixed by get_feature_names(); the
    hard-coded column indices in extract_batch() must stay in sync with it.
    """

    def __init__(
        self,
        num_cus: int = 256,
        lds_capacity: int = 65536,
        max_clock_mhz: int = 2400,
        simds_per_cu: int = 4,
        shader_engines: int = 32,
        max_waves_per_cu: int = 32,
        wavefront_size: int = 64,
        l1_cache_kb: int = 32,
        l2_cache_kb: int = 4096,
        l3_cache_kb: int = 262144,
        num_xcd: int = 8,
    ):
        # Static hardware profile: emitted verbatim as the hw_* feature
        # columns and used for the LDS / CU-utilization estimates below.
        # NOTE(review): defaults look like an MI300-class part — confirm.
        self._hw = {
            "num_cus": num_cus,
            "lds_capacity": lds_capacity,
            "max_clock_mhz": max_clock_mhz,
            "simds_per_cu": simds_per_cu,
            "shader_engines": shader_engines,
            "max_waves_per_cu": max_waves_per_cu,
            "wavefront_size": wavefront_size,
            "l1_cache_kb": l1_cache_kb,
            "l2_cache_kb": l2_cache_kb,
            "l3_cache_kb": l3_cache_kb,
            "num_xcd": num_xcd,
            "total_simds": num_cus * simds_per_cu,
        }

    def get_feature_names(self) -> list[str]:
        """Return the 72 feature column names, in output-array order."""
        return [
            # Problem features
            "M",
            "N",
            "K",
            "split_k",
            "log2_M",
            "log2_N",
            "log2_K",
            "log2_MNK",
            "arithmetic_intensity",
            "aspect_ratio_mn",
            "aspect_ratio_mk",
            "aspect_ratio_nk",
            "layout",
            # Kernel features
            "tile_m",
            "tile_n",
            "tile_k",
            "warp_m",
            "warp_n",
            "warp_k",
            "warp_tile_m",
            "warp_tile_n",
            "warp_tile_k",
            "pipeline",
            "scheduler",
            "epilogue",
            "pad_m",
            "pad_n",
            "pad_k",
            "persistent",
            "num_warps",
            "tile_volume",
            "tile_mn",
            "lds_usage_estimate",
            "lds_usage_ratio",
            # Interaction features
            "num_tiles_m",
            "num_tiles_n",
            "num_tiles_k",
            "total_output_tiles",
            "tile_eff_m",
            "tile_eff_n",
            "tile_eff_k",
            "overall_tile_efficiency",
            "cu_utilization",
            # P0 FIX: Problem-to-tile ratio features
            "ratio_M_to_tile_m",
            "ratio_N_to_tile_n",
            "ratio_K_to_tile_k",
            "problem_smaller_than_tile_m",
            "problem_smaller_than_tile_n",
            "problem_smaller_than_tile_k",
            "any_dim_too_small",
            # P1 FIX: Padding requirement interaction features
            "needs_padding_m",
            "needs_padding_n",
            "needs_padding_k",
            "has_padding_when_needed_m",
            "has_padding_when_needed_n",
            "has_padding_when_needed_k",
            "missing_required_padding_m",
            "missing_required_padding_n",
            "missing_required_padding_k",
            "missing_any_required_padding",
            # Hardware features
            "hw_num_cus",
            "hw_simds_per_cu",
            "hw_total_simds",
            "hw_shader_engines",
            "hw_max_clock_mhz",
            "hw_max_waves_per_cu",
            "hw_wavefront_size",
            "hw_lds_capacity",
            "hw_l1_cache_kb",
            "hw_l2_cache_kb",
            "hw_l3_cache_kb",
            "hw_num_xcd",
        ]

    def get_categorical_features(self) -> list[str]:
        """Columns carrying *_MAP integer codes; categorical for LightGBM."""
        return ["layout", "pipeline", "scheduler", "epilogue"]

    def extract(self, problem: dict, kernel: dict) -> np.ndarray:
        """Extract one (problem, kernel) feature vector.

        Accepts both lower- and upper-case M/N/K keys in ``problem``;
        missing kernel fields fall back to the same defaults used by
        extract_batch().
        """
        M = int(problem.get("m", problem.get("M", 0)))
        N = int(problem.get("n", problem.get("N", 0)))
        K = int(problem.get("k", problem.get("K", 0)))
        split_k = int(problem.get("split_k", 1))
        dtype = str(problem.get("dtype", "fp8"))
        bpe = DTYPE_BYTES.get(dtype, 1.0)  # bytes per element; 1.0 if unknown

        # max(..., 1) guards log2/division against zero-sized dimensions.
        log2_M = math.log2(max(M, 1))
        log2_N = math.log2(max(N, 1))
        log2_K = math.log2(max(K, 1))
        log2_MNK = math.log2(max(M * N * K, 1))

        # Arithmetic intensity: 2*M*N*K FLOPs over A+B+C traffic in bytes.
        mem_bytes = (M * K + K * N + M * N) * bpe
        ai = (2.0 * M * N * K) / max(mem_bytes, 1)

        ar_mn = M / max(N, 1)
        ar_mk = M / max(K, 1)
        ar_nk = N / max(K, 1)

        layout_code = LAYOUT_MAP.get(str(problem.get("layout", "rcr")), 0)

        tile_m = int(kernel.get("tile_m", 128))
        tile_n = int(kernel.get("tile_n", 128))
        tile_k = int(kernel.get("tile_k", 64))
        warp_m = int(kernel.get("warp_m", 2))
        warp_n = int(kernel.get("warp_n", 2))
        warp_k = int(kernel.get("warp_k", 1))
        warp_tile_m = int(kernel.get("warp_tile_m", 32))
        warp_tile_n = int(kernel.get("warp_tile_n", 32))
        warp_tile_k = int(kernel.get("warp_tile_k", 16))

        pipeline_code = PIPELINE_MAP.get(str(kernel.get("pipeline", "compv4")), 0)
        scheduler_code = SCHEDULER_MAP.get(str(kernel.get("scheduler", "intrawave")), 0)
        epilogue_code = EPILOGUE_MAP.get(str(kernel.get("epilogue", "cshuffle")), 0)

        pad_m = float(kernel.get("pad_m", False))
        pad_n = float(kernel.get("pad_n", False))
        pad_k = float(kernel.get("pad_k", False))
        persistent = float(kernel.get("persistent", False))

        num_warps = warp_m * warp_n * warp_k
        tile_volume = tile_m * tile_n * tile_k
        tile_mn = tile_m * tile_n

        # LDS estimate: one A tile (tile_m x tile_k) + one B tile (tile_n x tile_k).
        lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
        lds_cap = self._hw["lds_capacity"]
        if str(kernel.get("pipeline", "")).startswith("compv4"):
            # compv4 pipelines are budgeted against 32 KiB instead of the full
            # LDS capacity — presumably due to pipeline buffering; TODO confirm.
            lds_cap = 32768
        lds_ratio = lds_est / max(lds_cap, 1)

        num_tiles_m = math.ceil(M / max(tile_m, 1))
        num_tiles_n = math.ceil(N / max(tile_n, 1))
        num_tiles_k = math.ceil(K / max(tile_k, 1))
        total_output_tiles = num_tiles_m * num_tiles_n

        # Tile efficiency: fraction of the last (partial) tile that is used;
        # 1.0 when the dimension divides evenly.
        rem_m = M % tile_m if tile_m > 0 else 0
        tile_eff_m = rem_m / tile_m if rem_m > 0 else 1.0
        rem_n = N % tile_n if tile_n > 0 else 0
        tile_eff_n = rem_n / tile_n if rem_n > 0 else 1.0
        rem_k = K % tile_k if tile_k > 0 else 0
        tile_eff_k = rem_k / tile_k if rem_k > 0 else 1.0
        overall_eff = tile_eff_m * tile_eff_n * tile_eff_k

        # Output tiles per CU; <1 means some CUs are idle.
        cu_util = total_output_tiles / max(self._hw["num_cus"], 1)

        # P0 FIX: Problem-to-tile ratio features (avoid oversized tiles for tiny problems)
        ratio_M_to_tile_m = M / max(tile_m, 1)
        ratio_N_to_tile_n = N / max(tile_n, 1)
        ratio_K_to_tile_k = K / max(tile_k, 1)

        # Binary features: is problem dimension smaller than tile?
        problem_smaller_than_tile_m = float(M < tile_m)
        problem_smaller_than_tile_n = float(N < tile_n)
        problem_smaller_than_tile_k = float(K < tile_k)
        any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k))

        # P1 FIX: Padding requirement features (does this kernel have padding when needed?)
        needs_padding_m = float(M % tile_m != 0) if tile_m > 0 else 0.0
        needs_padding_n = float(N % tile_n != 0) if tile_n > 0 else 0.0
        needs_padding_k = float(K % tile_k != 0) if tile_k > 0 else 0.0

        # Interaction features: kernel has padding capability when problem needs it
        has_padding_when_needed_m = float(needs_padding_m and pad_m)
        has_padding_when_needed_n = float(needs_padding_n and pad_n)
        has_padding_when_needed_k = float(needs_padding_k and pad_k)

        # Critical feature: missing required padding (kernel will likely fail)
        missing_required_padding_m = float(needs_padding_m and not pad_m)
        missing_required_padding_n = float(needs_padding_n and not pad_n)
        missing_required_padding_k = float(needs_padding_k and not pad_k)
        missing_any_required_padding = float(
            missing_required_padding_m
            or missing_required_padding_n
            or missing_required_padding_k
        )

        hw = self._hw
        # Order below must match get_feature_names() exactly.
        return np.array(
            [
                M,
                N,
                K,
                split_k,
                log2_M,
                log2_N,
                log2_K,
                log2_MNK,
                ai,
                ar_mn,
                ar_mk,
                ar_nk,
                layout_code,
                tile_m,
                tile_n,
                tile_k,
                warp_m,
                warp_n,
                warp_k,
                warp_tile_m,
                warp_tile_n,
                warp_tile_k,
                pipeline_code,
                scheduler_code,
                epilogue_code,
                pad_m,
                pad_n,
                pad_k,
                persistent,
                num_warps,
                tile_volume,
                tile_mn,
                lds_est,
                lds_ratio,
                num_tiles_m,
                num_tiles_n,
                num_tiles_k,
                total_output_tiles,
                tile_eff_m,
                tile_eff_n,
                tile_eff_k,
                overall_eff,
                cu_util,
                # P0 FIX: New ratio and binary features
                ratio_M_to_tile_m,
                ratio_N_to_tile_n,
                ratio_K_to_tile_k,
                problem_smaller_than_tile_m,
                problem_smaller_than_tile_n,
                problem_smaller_than_tile_k,
                any_dim_too_small,
                # P1 FIX: Padding requirement interaction features
                needs_padding_m,
                needs_padding_n,
                needs_padding_k,
                has_padding_when_needed_m,
                has_padding_when_needed_n,
                has_padding_when_needed_k,
                missing_required_padding_m,
                missing_required_padding_n,
                missing_required_padding_k,
                missing_any_required_padding,
                hw["num_cus"],
                hw["simds_per_cu"],
                hw["total_simds"],
                hw["shader_engines"],
                hw["max_clock_mhz"],
                hw["max_waves_per_cu"],
                hw["wavefront_size"],
                hw["lds_capacity"],
                hw["l1_cache_kb"],
                hw["l2_cache_kb"],
                hw["l3_cache_kb"],
                hw["num_xcd"],
            ],
            dtype=np.float64,
        )

    def extract_batch(self, df: pd.DataFrame) -> np.ndarray:
        """Vectorized batch extraction -- much faster than row-by-row.

        The hard-coded column indices below mirror the order returned by
        get_feature_names(); keep the two in sync when adding features.
        Unlike extract(), this expects lower-case m/n/k column names.
        """
        n = len(df)
        names = self.get_feature_names()
        result = np.zeros((n, len(names)), dtype=np.float64)

        M = df["m"].values.astype(np.float64)
        N = df["n"].values.astype(np.float64)
        K = df["k"].values.astype(np.float64)
        split_k = df["split_k"].fillna(1).values.astype(np.float64)

        dtype_col = df["dtype"].fillna("fp8")
        bpe = dtype_col.map(DTYPE_BYTES).fillna(1.0).values

        result[:, 0] = M
        result[:, 1] = N
        result[:, 2] = K
        result[:, 3] = split_k
        result[:, 4] = np.log2(np.maximum(M, 1))
        result[:, 5] = np.log2(np.maximum(N, 1))
        result[:, 6] = np.log2(np.maximum(K, 1))
        result[:, 7] = np.log2(np.maximum(M * N * K, 1))

        # Arithmetic intensity and aspect ratios (same formulas as extract()).
        mem = (M * K + K * N + M * N) * bpe
        result[:, 8] = (2.0 * M * N * K) / np.maximum(mem, 1)
        result[:, 9] = M / np.maximum(N, 1)
        result[:, 10] = M / np.maximum(K, 1)
        result[:, 11] = N / np.maximum(K, 1)

        result[:, 12] = df["layout"].map(LAYOUT_MAP).fillna(0).values

        tile_m = df["tile_m"].fillna(128).values.astype(np.float64)
        tile_n = df["tile_n"].fillna(128).values.astype(np.float64)
        tile_k = df["tile_k"].fillna(64).values.astype(np.float64)
        warp_m = df["warp_m"].fillna(2).values.astype(np.float64)
        warp_n = df["warp_n"].fillna(2).values.astype(np.float64)
        warp_k = df["warp_k"].fillna(1).values.astype(np.float64)
        warp_tile_m = df["warp_tile_m"].fillna(32).values.astype(np.float64)
        warp_tile_n = df["warp_tile_n"].fillna(32).values.astype(np.float64)
        warp_tile_k = df["warp_tile_k"].fillna(16).values.astype(np.float64)

        result[:, 13] = tile_m
        result[:, 14] = tile_n
        result[:, 15] = tile_k
        result[:, 16] = warp_m
        result[:, 17] = warp_n
        result[:, 18] = warp_k
        result[:, 19] = warp_tile_m
        result[:, 20] = warp_tile_n
        result[:, 21] = warp_tile_k

        result[:, 22] = df["pipeline"].map(PIPELINE_MAP).fillna(0).values
        result[:, 23] = df["scheduler"].map(SCHEDULER_MAP).fillna(0).values
        result[:, 24] = df["epilogue"].map(EPILOGUE_MAP).fillna(0).values

        result[:, 25] = df["pad_m"].fillna(False).astype(float).values
        result[:, 26] = df["pad_n"].fillna(False).astype(float).values
        result[:, 27] = df["pad_k"].fillna(False).astype(float).values
        result[:, 28] = df["persistent"].fillna(False).astype(float).values

        num_warps = warp_m * warp_n * warp_k
        result[:, 29] = num_warps
        result[:, 30] = tile_m * tile_n * tile_k
        result[:, 31] = tile_m * tile_n

        # LDS estimate and ratio; compv4 rows use the reduced 32 KiB budget,
        # matching the scalar path in extract().
        lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
        result[:, 32] = lds_est
        lds_cap = np.full(n, self._hw["lds_capacity"], dtype=np.float64)
        is_compv4 = df["pipeline"].fillna("").str.startswith("compv4")
        lds_cap[is_compv4] = 32768
        result[:, 33] = lds_est / np.maximum(lds_cap, 1)

        ntm = np.ceil(M / np.maximum(tile_m, 1))
        ntn = np.ceil(N / np.maximum(tile_n, 1))
        ntk = np.ceil(K / np.maximum(tile_k, 1))
        result[:, 34] = ntm
        result[:, 35] = ntn
        result[:, 36] = ntk
        result[:, 37] = ntm * ntn

        # Tile efficiency per dimension (1.0 when evenly divisible).
        rem_m = np.mod(M, np.maximum(tile_m, 1))
        result[:, 38] = np.where(rem_m > 0, rem_m / tile_m, 1.0)
        rem_n = np.mod(N, np.maximum(tile_n, 1))
        result[:, 39] = np.where(rem_n > 0, rem_n / tile_n, 1.0)
        rem_k = np.mod(K, np.maximum(tile_k, 1))
        result[:, 40] = np.where(rem_k > 0, rem_k / tile_k, 1.0)
        result[:, 41] = result[:, 38] * result[:, 39] * result[:, 40]

        result[:, 42] = (ntm * ntn) / max(self._hw["num_cus"], 1)

        # P0 FIX: Problem-to-tile ratio features
        result[:, 43] = M / np.maximum(tile_m, 1)  # ratio_M_to_tile_m
        result[:, 44] = N / np.maximum(tile_n, 1)  # ratio_N_to_tile_n
        result[:, 45] = K / np.maximum(tile_k, 1)  # ratio_K_to_tile_k

        # Binary features: is problem smaller than tile?
        result[:, 46] = (M < tile_m).astype(float)  # problem_smaller_than_tile_m
        result[:, 47] = (N < tile_n).astype(float)  # problem_smaller_than_tile_n
        result[:, 48] = (K < tile_k).astype(float)  # problem_smaller_than_tile_k
        result[:, 49] = ((M < tile_m) | (N < tile_n) | (K < tile_k)).astype(
            float
        )  # any_dim_too_small

        # P1 FIX: Padding requirement features
        pad_m_bool = df["pad_m"].fillna(False).astype(bool).values
        pad_n_bool = df["pad_n"].fillna(False).astype(bool).values
        pad_k_bool = df["pad_k"].fillna(False).astype(bool).values

        needs_padding_m = (np.mod(M, np.maximum(tile_m, 1)) != 0)
        needs_padding_n = (np.mod(N, np.maximum(tile_n, 1)) != 0)
        needs_padding_k = (np.mod(K, np.maximum(tile_k, 1)) != 0)

        result[:, 50] = needs_padding_m.astype(float)
        result[:, 51] = needs_padding_n.astype(float)
        result[:, 52] = needs_padding_k.astype(float)

        # Interaction features: kernel has padding when problem needs it
        result[:, 53] = (needs_padding_m & pad_m_bool).astype(float)  # has_padding_when_needed_m
        result[:, 54] = (needs_padding_n & pad_n_bool).astype(float)  # has_padding_when_needed_n
        result[:, 55] = (needs_padding_k & pad_k_bool).astype(float)  # has_padding_when_needed_k

        # Critical feature: missing required padding
        result[:, 56] = (needs_padding_m & ~pad_m_bool).astype(float)  # missing_required_padding_m
        result[:, 57] = (needs_padding_n & ~pad_n_bool).astype(float)  # missing_required_padding_n
        result[:, 58] = (needs_padding_k & ~pad_k_bool).astype(float)  # missing_required_padding_k
        result[:, 59] = ((needs_padding_m & ~pad_m_bool) | (needs_padding_n & ~pad_n_bool) | (needs_padding_k & ~pad_k_bool)).astype(float)  # missing_any_required_padding

        # Hardware profile features
        hw = self._hw
        result[:, 60] = hw["num_cus"]
        result[:, 61] = hw["simds_per_cu"]
        result[:, 62] = hw["total_simds"]
        result[:, 63] = hw["shader_engines"]
        result[:, 64] = hw["max_clock_mhz"]
        result[:, 65] = hw["max_waves_per_cu"]
        result[:, 66] = hw["wavefront_size"]
        result[:, 67] = hw["lds_capacity"]
        result[:, 68] = hw["l1_cache_kb"]
        result[:, 69] = hw["l2_cache_kb"]
        result[:, 70] = hw["l3_cache_kb"]
        result[:, 71] = hw["num_xcd"]

        return result

    def get_parameter_space(self) -> dict[str, list]:
        """Discrete search space for each tunable kernel parameter."""
        return {
            "tile_m": [32, 64, 128, 192, 256],
            "tile_n": [32, 64, 128, 192, 256],
            "tile_k": [32, 64, 128, 256],
            "warp_m": [1, 2, 4],
            "warp_n": [1, 2, 4],
            "warp_k": [1],
            "warp_tile_m": [4, 16, 32, 64],
            "warp_tile_n": [4, 16, 32, 64],
            "warp_tile_k": [8, 16, 32, 64, 128],
            "pipeline": list(PIPELINE_MAP.keys()),
            "scheduler": list(SCHEDULER_MAP.keys()),
            "epilogue": list(EPILOGUE_MAP.keys()),
            "pad_m": [True, False],
            "pad_n": [True, False],
            "pad_k": [True, False],
            "persistent": [True, False],
        }

    def get_constraints(self) -> list:
        """Config-validity predicates: LDS fit and allowed warp counts.

        NOTE(review): the LDS check assumes 1 byte/element (fp8) regardless
        of the actual dtype — confirm this is intentional for wider types.
        """
        lds_cap = self._hw["lds_capacity"]

        def _lds_constraint(cfg):
            # A tile + B tile must fit in the (pipeline-dependent) LDS budget.
            tm = cfg.get("tile_m", 128)
            tn = cfg.get("tile_n", 128)
            tk = cfg.get("tile_k", 64)
            bpe = 1.0  # fp8 default
            est = (tm * tk + tn * tk) * bpe
            cap = (
                32768 if str(cfg.get("pipeline", "")).startswith("compv4") else lds_cap
            )
            return est <= cap

        def _warp_constraint(cfg):
            # Total warps per workgroup must be 2, 4, or 8.
            wm = cfg.get("warp_m", 2)
            wn = cfg.get("warp_n", 2)
            wk = cfg.get("warp_k", 1)
            return (wm * wn * wk) in [2, 4, 8]

        return [_lds_constraint, _warp_constraint]
|
||||
553
dispatcher/heuristics/generate_benchmark_data.py
Normal file
553
dispatcher/heuristics/generate_benchmark_data.py
Normal file
@@ -0,0 +1,553 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
GEMM Universal Benchmark Data Generation Script
|
||||
|
||||
This script generates training data for ML-based kernel selection heuristics by:
|
||||
1. Reading kernel configurations from the tile engine
|
||||
2. Building benchmark executables (in parallel)
|
||||
3. Running benchmarks across multiple problem sizes
|
||||
4. Outputting performance data in JSON format
|
||||
|
||||
Usage:
|
||||
python generate_benchmark_data.py \
|
||||
--build_dir /tmp/build \
|
||||
--output_dir /tmp/benchmark_data \
|
||||
--dtype fp16 \
|
||||
--layout rcr \
|
||||
--num_build_jobs 4 \
|
||||
--num_benchmark_jobs 1
|
||||
|
||||
Requirements:
|
||||
- ROCm-capable GPU
|
||||
- CK tile engine built with CMake
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
import re
|
||||
|
||||
|
||||
@dataclass
class KernelConfig:
    """Represents a single kernel configuration.

    Fields mirror the tile-engine kernel naming scheme: trait flags
    (pipeline/epilogue/scheduler/padding/persistence) plus the block tile,
    warp partition, and per-warp tile dimensions.
    """

    name: str
    dtype: str
    layout: str
    pipeline: str
    epilogue: str
    scheduler: str
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool
    tile_m: int
    tile_n: int
    tile_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    @classmethod
    def from_kernel_name(cls, name: str, dtype: str, layout: str) -> "KernelConfig":
        """Parse kernel name to extract configuration.

        Expected format:
        gemm_universal_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{padM}_{padN}_{padK}_{persistent}_{tile_config}
        where tile_config is
        {tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}

        Args:
            name: Full kernel name as emitted by the tile engine.
            dtype: Data-type token embedded in the name (e.g. "fp16").
            layout: Layout token embedded in the name (e.g. "rcr").

        Returns:
            A populated KernelConfig.

        Raises:
            IndexError / ValueError: if `name` does not follow the expected
            format (too few underscore-separated fields or non-numeric tiles).
        """
        # Strip the fixed prefix; everything after it is the trait/tile encoding.
        # (The previously computed `parts = name.split("_")` was dead code and
        # has been removed.)
        prefix = f"gemm_universal_{dtype}_{layout}_"
        trait_and_tile = name[len(prefix) :]
        trait_parts = trait_and_tile.split("_")

        pipeline = trait_parts[0]
        epilogue = trait_parts[1]
        scheduler = trait_parts[2]
        # Boolean trait flags are serialized as the literal strings "True"/"False".
        pad_m = trait_parts[3] == "True"
        pad_n = trait_parts[4] == "True"
        pad_k = trait_parts[5] == "True"
        persistent = trait_parts[6] == "True"

        # Parse tile config: three "AxBxC" triplets.
        tile_dims = trait_parts[7].split("x")
        warp_dims = trait_parts[8].split("x")
        warp_tile_dims = trait_parts[9].split("x")

        return cls(
            name=name,
            dtype=dtype,
            layout=layout,
            pipeline=pipeline,
            epilogue=epilogue,
            scheduler=scheduler,
            pad_m=pad_m,
            pad_n=pad_n,
            pad_k=pad_k,
            persistent=persistent,
            tile_m=int(tile_dims[0]),
            tile_n=int(tile_dims[1]),
            tile_k=int(tile_dims[2]),
            warp_m=int(warp_dims[0]),
            warp_n=int(warp_dims[1]),
            warp_k=int(warp_dims[2]),
            warp_tile_m=int(warp_tile_dims[0]),
            warp_tile_n=int(warp_tile_dims[1]),
            warp_tile_k=int(warp_tile_dims[2]),
        )
|
||||
|
||||
|
||||
@dataclass
class BenchmarkResult:
    """Result of a single benchmark run (one kernel x one problem size)."""

    kernel_name: str  # full tile-engine kernel name this result belongs to
    m: int  # GEMM M dimension
    n: int  # GEMM N dimension
    k: int  # GEMM K dimension
    avg_time_ms: float  # reported kernel latency in milliseconds (0 on failure)
    tflops: float  # achieved throughput in TFLOPs (0 on failure)
    is_valid: bool  # True only when the run produced usable timing data
    error: Optional[str] = None  # failure description when is_valid is False
|
||||
|
||||
|
||||
@dataclass
class ProblemSize:
    """GEMM problem dimensions (C[m, n] = A[m, k] @ B[k, n])."""

    m: int  # rows of A / C
    n: int  # columns of B / C
    k: int  # shared reduction dimension
|
||||
|
||||
|
||||
def get_problem_sizes() -> List[ProblemSize]:
    """
    Generate diverse problem sizes for benchmarking.

    Includes:
    - Square matrices (powers of 2)
    - Rectangular matrices (common in ML)
    - LLM-specific sizes (attention, MLP)
    - Edge cases (small, very large)
    """
    # Square powers of two: 64 .. 8192.
    candidates = [(2**p, 2**p, 2**p) for p in range(6, 14)]

    # Common ML shapes (batch x hidden).
    candidates += [
        (1, 4096, 4096),  # Single token inference
        (8, 4096, 4096),  # Small batch
        (32, 4096, 4096),  # Medium batch
        (128, 4096, 4096),  # Large batch
        (1, 4096, 11008),  # LLaMA MLP up-projection
        (1, 11008, 4096),  # LLaMA MLP down-projection
        (32, 4096, 11008),
        (32, 11008, 4096),
        (1, 8192, 8192),  # Large model
        (32, 8192, 8192),
        (1, 8192, 28672),  # LLaMA-70B MLP
        (32, 8192, 28672),
    ]

    # Rectangular shapes.
    candidates += [
        (1024, 4096, 1024),
        (4096, 1024, 4096),
        (2048, 8192, 2048),
        (256, 256, 8192),  # Tall K
        (8192, 8192, 256),  # Short K
    ]

    # De-duplicate while keeping first-seen order (dict preserves insertion
    # order), then materialize ProblemSize records.
    return [ProblemSize(m, n, k) for m, n, k in dict.fromkeys(candidates)]
|
||||
|
||||
|
||||
def load_kernel_list(build_dir: Path, dtype: str, layout: str) -> List[KernelConfig]:
    """Load kernel configurations from the tile engine build.

    Reads `gemm_universal_kernel_list.txt` under the build tree for the
    given dtype/layout and parses every non-empty line into a KernelConfig.

    Raises:
        FileNotFoundError: if the kernel list file does not exist.
    """
    list_file = (
        build_dir
        / "tile_engine"
        / "ops"
        / "gemm"
        / "gemm_universal"
        / dtype
        / layout
        / "gemm_universal_kernel_list.txt"
    )

    if not list_file.exists():
        raise FileNotFoundError(f"Kernel list not found: {list_file}")

    configs: List[KernelConfig] = []
    with open(list_file, "r") as fh:
        for raw in fh:
            entry = raw.strip()
            if not entry:
                continue
            # Line format: kernel_name|tile_config|trait_combo — only the
            # name is needed; everything else is re-derived from it.
            kernel_name = entry.split("|")[0]
            configs.append(KernelConfig.from_kernel_name(kernel_name, dtype, layout))

    return configs
|
||||
|
||||
|
||||
def build_kernel(build_dir: Path, kernel: KernelConfig) -> Tuple[str, bool, str]:
    """
    Build a single kernel benchmark executable via ninja.

    Returns: (kernel_name, success, error_message)
    """
    target = f"benchmark_{kernel.name}"

    try:
        proc = subprocess.run(
            ["ninja", "-j1", target],
            cwd=build_dir,
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute timeout
        )
    except subprocess.TimeoutExpired:
        return (kernel.name, False, "Build timeout")
    except Exception as e:
        return (kernel.name, False, str(e))

    if proc.returncode == 0:
        return (kernel.name, True, "")
    # Truncate compiler output so aggregated result files stay small.
    return (kernel.name, False, proc.stderr[:500])
|
||||
|
||||
|
||||
def run_benchmark(
    build_dir: Path,
    kernel: KernelConfig,
    problem: ProblemSize,
    warmup: int = 10,
    repeat: int = 50,
) -> BenchmarkResult:
    """
    Run benchmark for a single kernel and problem size.

    Invokes the prebuilt `benchmark_<kernel>` executable and parses either
    its JSON output or, failing that, free-form text for latency/TFLOPs.

    Args:
        build_dir: CK build directory containing bin/benchmark_* executables.
        kernel: Kernel configuration to benchmark.
        problem: GEMM problem dimensions.
        warmup: Warmup iterations forwarded to the executable.
        repeat: Timed iterations forwarded to the executable.

    Returns:
        A BenchmarkResult; on any failure `is_valid` is False and `error`
        describes the cause.
    """

    def _make_result(
        avg_time_ms: float,
        tflops: float,
        is_valid: bool,
        error: Optional[str] = None,
    ) -> BenchmarkResult:
        # Single construction point so every exit path carries the same
        # kernel/problem identification (previously duplicated six times).
        return BenchmarkResult(
            kernel_name=kernel.name,
            m=problem.m,
            n=problem.n,
            k=problem.k,
            avg_time_ms=avg_time_ms,
            tflops=tflops,
            is_valid=is_valid,
            error=error,
        )

    exe_path = build_dir / "bin" / f"benchmark_{kernel.name}"
    if not exe_path.exists():
        return _make_result(0, 0, False, "Executable not found")

    try:
        result = subprocess.run(
            [
                str(exe_path),
                f"-m={problem.m}",
                f"-n={problem.n}",
                f"-k={problem.k}",
                f"-warmup={warmup}",
                f"-repeat={repeat}",
                "-verify=0",
                "-json_output=true",
            ],
            capture_output=True,
            text=True,
            timeout=120,
        )

        if result.returncode != 0:
            # Prefer stderr for the error message; fall back to stdout.
            error = result.stderr[:200] if result.stderr else result.stdout[:200]
            return _make_result(0, 0, False, error)

        output = result.stdout.strip()

        # Preferred path: structured JSON embedded somewhere in the output.
        json_match = re.search(r"\{.*\}", output, re.DOTALL)
        if json_match:
            data = json.loads(json_match.group())
            # Perf numbers live in a nested "perf_result" object.
            perf = data.get("perf_result", {})
            avg_time_ms = perf.get("latency(ms)", 0)
            tflops = perf.get("tflops(TFlops)", 0)
            # NOTE(review): kept always-valid to preserve prior behavior even
            # when the JSON lacks perf keys — confirm whether zero latency
            # should be treated as invalid here, as it is in the text path.
            return _make_result(avg_time_ms, tflops, True)

        # Fallback: scrape latency/TFLOPs from free-form text output,
        # matching patterns like "avg_time: X ms" or "tflops: Y".
        avg_time = 0.0
        tflops = 0.0

        time_match = re.search(
            r"(?:avg[_\s]?time|latency)[:\s]+(\d+\.?\d*)\s*(?:ms)?", output, re.I
        )
        if time_match:
            avg_time = float(time_match.group(1))

        tflops_match = re.search(r"tflops[:\s]+(\d+\.?\d*)", output, re.I)
        if tflops_match:
            tflops = float(tflops_match.group(1))

        # Derive TFLOPs from latency when the executable did not report it.
        if tflops == 0 and avg_time > 0:
            flops = 2.0 * problem.m * problem.n * problem.k
            tflops = flops / (avg_time * 1e-3) / 1e12

        return _make_result(
            avg_time,
            tflops,
            avg_time > 0,
            None if avg_time > 0 else "Could not parse output",
        )

    except subprocess.TimeoutExpired:
        return _make_result(0, 0, False, "Benchmark timeout")
    except Exception as e:
        return _make_result(0, 0, False, str(e))
|
||||
|
||||
|
||||
def main():
    """CLI entry point: load kernel configs for one dtype/layout, build the
    benchmark executables in parallel, run every (kernel, problem) benchmark
    serially, and write JSON results plus metadata under --output_dir.
    """
    parser = argparse.ArgumentParser(
        description="Generate GEMM benchmark data for ML training"
    )
    parser.add_argument(
        "--build_dir", type=str, default="/tmp/build", help="CK build directory"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/tmp/benchmark_data",
        help="Output directory for benchmark results",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type to benchmark",
    )
    parser.add_argument(
        "--layout",
        type=str,
        default="rcr",
        choices=["rcr", "rrr", "crr", "ccr"],
        help="Matrix layout to benchmark",
    )
    parser.add_argument(
        "--num_build_jobs", type=int, default=4, help="Number of parallel build jobs"
    )
    # NOTE(review): --num_benchmark_jobs is parsed but never used below;
    # benchmarks currently run serially regardless of this flag.
    parser.add_argument(
        "--num_benchmark_jobs",
        type=int,
        default=1,
        help="Number of parallel benchmark jobs (use 1 for accurate timing)",
    )
    parser.add_argument(
        "--max_kernels",
        type=int,
        default=None,
        help="Maximum number of kernels to benchmark (for testing)",
    )
    parser.add_argument(
        "--skip_build",
        action="store_true",
        help="Skip building and only run benchmarks",
    )
    parser.add_argument(
        "--warmup", type=int, default=10, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--repeat", type=int, default=50, help="Number of benchmark iterations"
    )

    args = parser.parse_args()

    build_dir = Path(args.build_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load kernel configurations
    print(f"Loading kernel list for {args.dtype}/{args.layout}...")
    kernels = load_kernel_list(build_dir, args.dtype, args.layout)
    print(f"Found {len(kernels)} kernel configurations")

    if args.max_kernels:
        kernels = kernels[: args.max_kernels]
        print(f"Limiting to {len(kernels)} kernels")

    # Build kernels (one ninja invocation per kernel, fanned out over a
    # process pool; failures are recorded, not fatal)
    if not args.skip_build:
        print(
            f"\nBuilding {len(kernels)} kernels with {args.num_build_jobs} parallel jobs..."
        )
        build_results = {"success": 0, "failed": 0, "failed_kernels": []}

        with ProcessPoolExecutor(max_workers=args.num_build_jobs) as executor:
            futures = {executor.submit(build_kernel, build_dir, k): k for k in kernels}

            for i, future in enumerate(as_completed(futures)):
                kernel_name, success, error = future.result()
                if success:
                    build_results["success"] += 1
                else:
                    build_results["failed"] += 1
                    build_results["failed_kernels"].append(
                        {"name": kernel_name, "error": error}
                    )

                if (i + 1) % 10 == 0:
                    print(
                        f" Built {i + 1}/{len(kernels)} ({build_results['success']} success, {build_results['failed']} failed)"
                    )

        print(
            f"\nBuild complete: {build_results['success']} success, {build_results['failed']} failed"
        )

        # Save build results
        with open(output_dir / "build_results.json", "w") as f:
            json.dump(build_results, f, indent=2)

    # Get problem sizes
    problem_sizes = get_problem_sizes()
    print(f"\nBenchmarking {len(problem_sizes)} problem sizes...")

    # Run benchmarks (serially, for timing accuracy)
    all_results = []
    total_benchmarks = len(kernels) * len(problem_sizes)
    completed = 0

    print(f"Total benchmarks to run: {total_benchmarks}")

    for kernel in kernels:
        kernel_results = {
            "kernel_config": asdict(kernel),
            "benchmarks": [],
        }

        for problem in problem_sizes:
            result = run_benchmark(
                build_dir,
                kernel,
                problem,
                warmup=args.warmup,
                repeat=args.repeat,
            )
            kernel_results["benchmarks"].append(asdict(result))
            completed += 1

            if completed % 100 == 0:
                print(f" Progress: {completed}/{total_benchmarks} benchmarks complete")

        all_results.append(kernel_results)

        # Save intermediate results after every kernel, so a crash or kill
        # part-way through loses at most one kernel's worth of runs
        intermediate_file = (
            output_dir / f"benchmark_results_{args.dtype}_{args.layout}_partial.json"
        )
        with open(intermediate_file, "w") as f:
            json.dump(all_results, f, indent=2)

    # Save final results (metadata + problem sizes + per-kernel benchmarks)
    final_file = output_dir / f"benchmark_results_{args.dtype}_{args.layout}.json"
    with open(final_file, "w") as f:
        json.dump(
            {
                "metadata": {
                    "dtype": args.dtype,
                    "layout": args.layout,
                    "num_kernels": len(kernels),
                    "num_problems": len(problem_sizes),
                    "warmup": args.warmup,
                    "repeat": args.repeat,
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                },
                "problem_sizes": [asdict(p) for p in problem_sizes],
                "results": all_results,
            },
            f,
            indent=2,
        )

    print(f"\nResults saved to {final_file}")

    # Print summary
    valid_count = sum(
        1 for kr in all_results for br in kr["benchmarks"] if br["is_valid"]
    )
    print(f"Valid benchmarks: {valid_count}/{total_benchmarks}")


if __name__ == "__main__":
    main()
|
||||
166
dispatcher/heuristics/generate_edge_dims.py
Normal file
166
dispatcher/heuristics/generate_edge_dims.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Supplementary edge-case benchmark generator for N=1 and K=1 dimensions.
|
||||
|
||||
These shapes represent vector-matrix multiply (N=1), rank-1 updates (K=1),
|
||||
and other degenerate GEMM cases that stress tile efficiency and padding logic.
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def generate_edge_shapes():
    """Generate (M, N, K) GEMM shapes covering N=1 / K=1 degenerate cases.

    Covers vector-matrix products (N=1), rank-1 updates (K=1), dot
    products (M=N=1), scalar-vector cases (two of the three dims equal
    to 1), the fully degenerate 1x1x1, and small-N / small-K bands
    (2-16) against mid-size other dimensions.

    Returns:
        Sorted list of unique (m, n, k) tuples.
    """
    pow2_line = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
    wide_line = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 7168, 8192]
    sparse_line = [1, 16, 64, 256, 1024, 4096, 8192]
    quad = [64, 256, 1024, 4096]
    tiny = [2, 3, 4, 7, 8, 15, 16]

    collected = set()

    # N=1: vector-matrix multiply / single output column.
    collected.update((m, 1, k) for m in pow2_line for k in wide_line)
    # K=1: rank-1 update / outer product.
    collected.update((m, n, 1) for m in pow2_line for n in wide_line)
    # M=N=1: dot product.
    collected.update((1, 1, k) for k in sparse_line)
    # M=K=1: scalar times vector.
    collected.update((1, n, 1) for n in sparse_line)
    # N=K=1: scalar times vector (column form).
    collected.update((m, 1, 1) for m in sparse_line)
    # Fully degenerate case.
    collected.add((1, 1, 1))
    # Small N (2-16) against mid-size M/K.
    collected.update((m, n, k) for m in quad for n in tiny for k in quad)
    # Small K (2-16) against mid-size M/N.
    collected.update((m, n, k) for m in quad for n in quad for k in tiny)

    return sorted(collected)
|
||||
|
||||
|
||||
def run_shapes(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Run every benchmark executable against each (M, N, K) shape.

    Writes a streaming, line-oriented log to ``out_file`` (a shape header
    followed by one validated JSON block per successful kernel run).
    Individual kernel failures, timeouts, and malformed output are
    skipped so a single bad binary cannot abort the sweep.

    Args:
        bin_dir: directory containing benchmark_gemm_universal_fp8_rcr_* binaries.
        shapes: iterable of (m, n, k) tuples.
        out_file: writable text stream receiving the streaming log.
        warmup: warmup iterations passed to each benchmark binary.
        repeat: timed iterations passed to each benchmark binary.

    Returns:
        int: number of successfully parsed benchmark results.
    """
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0

    total = 0
    for idx, (m, n, k) in enumerate(shapes):
        out_file.write("\n========================================\n")
        out_file.write(f"Shape {idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n")
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()

        for exe in executables:
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
            except Exception:
                # Best-effort sweep: a hung binary (TimeoutExpired) or a
                # launch failure (OSError, ...) is skipped, not fatal.
                # Note: the original `except (TimeoutExpired, Exception)`
                # tuple was redundant — Exception already covers it.
                continue

            # Extract the single JSON block from stdout and validate it
            # before appending, so the log stays parseable downstream.
            output = result.stdout
            json_start = output.find("{")
            json_end = output.rfind("}") + 1
            if 0 <= json_start < json_end:
                json_block = output[json_start:json_end]
                try:
                    json.loads(json_block)
                except json.JSONDecodeError:
                    pass
                else:
                    out_file.write(json_block + "\n")
                    total += 1

        out_file.flush()
        print(
            f"  Shape {idx + 1}/{len(shapes)}: M={m} N={n} K={k}",
            file=sys.stderr,
            flush=True,
        )

    return total
|
||||
|
||||
|
||||
# Script driver: enumerate edge-case shapes, benchmark them in batches of
# 25, and write one streaming log per batch under data/edge_dims/.
if __name__ == "__main__":
    # Hard-coded location of the prebuilt fp8/rcr benchmark binaries.
    bin_dir = "/workspace/ck_tile/bin"
    out_dir = Path("data/edge_dims")
    out_dir.mkdir(parents=True, exist_ok=True)

    shapes = generate_edge_shapes()
    print(f"Generated {len(shapes)} edge-case shapes", file=sys.stderr, flush=True)

    # Per-category counts, logged so the operator can sanity-check coverage.
    n1_count = sum(1 for m, n, k in shapes if n == 1)
    k1_count = sum(1 for m, n, k in shapes if k == 1)
    both1 = sum(1 for m, n, k in shapes if n == 1 and k == 1)
    small_n = sum(1 for m, n, k in shapes if 2 <= n <= 16)
    small_k = sum(1 for m, n, k in shapes if 2 <= k <= 16)
    print(
        f"  N=1: {n1_count}, K=1: {k1_count}, both=1: {both1}",
        file=sys.stderr,
        flush=True,
    )
    print(
        f"  Small N(2-16): {small_n}, Small K(2-16): {small_k}",
        file=sys.stderr,
        flush=True,
    )

    # Batched execution: one log file per 25 shapes so partial progress
    # survives an interrupted run.
    batch_size = 25
    total = 0
    batch_idx = 0
    for i in range(0, len(shapes), batch_size):
        batch = shapes[i : i + batch_size]
        batch_idx += 1
        out_path = out_dir / f"edge_dims_batch_{batch_idx:03d}.log"
        print(
            f"\nBatch {batch_idx}: shapes {i + 1}-{i + len(batch)} -> {out_path}",
            file=sys.stderr,
            flush=True,
        )

        with open(out_path, "w") as f:
            # Header lines expected by the downstream log parser.
            f.write(f"CK Tile Edge Dims Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\nImplementation: gemm_universal\n\n")
            count = run_shapes(bin_dir, batch, f, warmup=3, repeat=10)
        total += count

        print(f"  Batch {batch_idx} done: {count} results", file=sys.stderr, flush=True)

    print(
        f"\nTotal: {total} benchmarks across {len(shapes)} shapes",
        file=sys.stderr,
        flush=True,
    )
|
||||
289
dispatcher/heuristics/generate_wide_coverage.py
Normal file
289
dispatcher/heuristics/generate_wide_coverage.py
Normal file
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Wide-coverage benchmark data generator.
|
||||
|
||||
Generates benchmark results for hundreds of diverse GEMM shapes across all
|
||||
corner cases: skinny M, tall N, deep K, M=1, prime dimensions, power-of-2,
|
||||
LLM inference shapes, training shapes, and edge cases.
|
||||
|
||||
Runs all 4608 kernels in /workspace/ck_tile/bin/ against each shape and
|
||||
writes streaming log output parseable by data_pipeline.py.
|
||||
|
||||
Usage:
|
||||
python3 generate_wide_coverage.py --bin_dir /workspace/ck_tile/bin \
|
||||
--out_dir data/ --batch_size 20 --warmup 3 --repeat 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def generate_shape_list():
    """Generate a comprehensive list of (M, N, K) GEMM shapes.

    Covers thirteen corner-case families: M=1 single-token inference,
    tiny/small/medium/large M bands, power-of-2 squares, extreme aspect
    ratios (skinny M / tall M), deep and shallow K, prime dimensions,
    LLM-specific shapes (DeepSeek MoE, LLaMA, GPT-style attention), and
    common non-power-of-2 sizes.

    Returns:
        Sorted list of unique (m, n, k) tuples.
    """
    pool = set()

    # 1. M=1 (single token) — the hardest case.
    decode_n = [512, 1024, 1536, 2048, 3072, 4096, 4608, 7168, 8192, 11008, 14336, 28672]
    decode_k = [256, 512, 1024, 1536, 2048, 2304, 4096, 7168, 8192]
    pool.update((1, n, k) for n in decode_n for k in decode_k)

    # 2. Tiny M (2-16) — small batch inference.
    pool.update(
        (m, n, k)
        for m in (2, 4, 8, 16)
        for n in (512, 1536, 4096, 7168)
        for k in (256, 1024, 4096, 7168)
    )

    # 3. Small M (32-128) — medium batch.
    pool.update(
        (m, n, k)
        for m in (32, 48, 64, 96, 128)
        for n in (512, 1536, 4096, 7168, 8192)
        for k in (256, 512, 2048, 4096, 7168)
    )

    # 4. Medium M (256-2048) — large batch / training.
    pool.update(
        (m, n, k)
        for m in (256, 384, 512, 768, 1024, 1536, 2048)
        for n in (512, 1536, 4096, 7168)
        for k in (256, 1024, 2048, 4096, 7168)
    )

    # 5. Large M (4096-20480) — very large batch.
    pool.update(
        (m, n, k)
        for m in (4096, 8192, 12288, 16384, 20480)
        for n in (512, 1536, 4096, 7168)
        for k in (256, 1024, 2048, 7168)
    )

    # 6. Square powers of 2: 32^3 .. 8192^3.
    pool.update((2**p,) * 3 for p in range(5, 14))

    # 7. Skinny M, tall N (M << N).
    pool.update(
        (m, n, k)
        for m in (1, 4, 16, 64)
        for n in (8192, 16384, 28672)
        for k in (1024, 4096, 8192)
    )

    # 8. Tall M, skinny N (M >> N).
    pool.update(
        (m, n, k)
        for m in (4096, 8192, 16384)
        for n in (32, 64, 128, 256)
        for k in (1024, 4096)
    )

    # 9. Deep K (K >> M, N) — compute-bound.
    pool.update(
        (m, n, k)
        for m in (16, 64, 256)
        for n in (16, 64, 256)
        for k in (4096, 8192, 16384, 32768)
    )

    # 10. Shallow K (K << M, N) — memory-bound.
    pool.update(
        (m, n, k)
        for m in (1024, 4096, 8192)
        for n in (1024, 4096, 8192)
        for k in (16, 32, 64, 128)
    )

    # 11. Prime dimensions — worst case for tiling.
    primes = (17, 31, 37, 127, 251, 509, 1021, 2039, 4093)
    pool.update((p, p, p) for p in primes)
    for p in primes[:5]:
        pool.update({(p, 4096, 4096), (4096, p, 4096), (4096, 4096, p)})

    # 12. LLM-specific shapes.
    pool.update(
        [
            # DeepSeek MoE
            (1, 1536, 7168),
            (1, 4608, 7168),
            (1, 7168, 2048),
            (1, 7168, 2304),
            (1, 7168, 256),
            (1, 576, 7168),
            (1, 512, 7168),
            (1, 3072, 1536),
            # LLaMA-7B
            (1, 4096, 4096),
            (32, 4096, 4096),
            (128, 4096, 4096),
            (1, 4096, 11008),
            (32, 4096, 11008),
            (1, 11008, 4096),
            (32, 11008, 4096),
            # LLaMA-70B
            (1, 8192, 8192),
            (32, 8192, 8192),
            (128, 8192, 8192),
            (1, 8192, 28672),
            (32, 8192, 28672),
            (1, 28672, 8192),
            # GPT-style attention
            (128, 128, 64),
            (128, 128, 128),
            (256, 256, 64),
            (512, 512, 64),
            (1024, 1024, 64),
            (2048, 2048, 64),
        ]
    )

    # 13. Non-power-of-2 common sizes.
    for m in (48, 96, 192, 384, 576, 768, 1152, 1536, 2304, 3072, 4608, 6144):
        pool.add((m, m, m))
        pool.add((m, 4096, 4096))

    return sorted(pool)
|
||||
|
||||
|
||||
def run_shape_batch(bin_dir, shapes, out_file, warmup=3, repeat=10):
    """Run all kernels against a batch of shapes, writing streaming log output.

    For each shape, a header block is written followed by one validated
    JSON block per successful kernel run; downstream data_pipeline.py
    parses this format.  Per-kernel failures (timeouts, launch errors,
    malformed stdout) are skipped so a single bad binary cannot abort
    the batch.

    Args:
        bin_dir: directory containing benchmark_gemm_universal_fp8_rcr_* binaries.
        shapes: iterable of (m, n, k) tuples.
        out_file: writable text stream receiving the streaming log.
        warmup: warmup iterations per benchmark invocation.
        repeat: timed iterations per benchmark invocation.

    Returns:
        int: number of successfully parsed benchmark results.
    """
    executables = sorted(Path(bin_dir).glob("benchmark_gemm_universal_fp8_rcr_*"))
    if not executables:
        print(f"ERROR: No executables found in {bin_dir}", file=sys.stderr)
        return 0

    total_benchmarks = 0

    for shape_idx, (m, n, k) in enumerate(shapes):
        out_file.write("\n========================================\n")
        out_file.write(
            f"Shape {shape_idx + 1}: M={m} N={n} K={k} dtype=fp8 layout=rcr\n"
        )
        out_file.write("========================================\n")
        out_file.write(f"Found {len(executables)} kernels\n")
        out_file.flush()

        for exe in executables:
            try:
                result = subprocess.run(
                    [
                        str(exe),
                        f"-m={m}",
                        f"-n={n}",
                        f"-k={k}",
                        f"-warmup={warmup}",
                        f"-repeat={repeat}",
                        "-verify=0",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
            except Exception:
                # Best-effort: skip kernels that hang (TimeoutExpired) or
                # fail to launch.  The original `except (TimeoutExpired,
                # Exception)` tuple was redundant — Exception covers both.
                continue

            # Extract and validate the JSON block on stdout before
            # appending, so downstream parsing never sees partial JSON.
            output = result.stdout
            json_start = output.find("{")
            json_end = output.rfind("}") + 1
            if 0 <= json_start < json_end:
                json_block = output[json_start:json_end]
                try:
                    json.loads(json_block)
                except json.JSONDecodeError:
                    pass
                else:
                    out_file.write(json_block + "\n")
                    total_benchmarks += 1

        out_file.flush()
        elapsed_kernels = len(executables)
        print(
            f"  Shape {shape_idx + 1}/{len(shapes)}: M={m} N={n} K={k} "
            f"({elapsed_kernels} kernels)",
            file=sys.stderr,
            flush=True,
        )

    return total_benchmarks
|
||||
|
||||
|
||||
def main():
    """CLI driver: build the wide-coverage shape list, benchmark it in batches.

    Each batch of shapes goes to its own streaming log file
    (wide_coverage_batch_NNN.log) under --out_dir, so partial progress
    survives an interrupted run.
    """
    cli = argparse.ArgumentParser(description="Generate wide-coverage benchmark data")
    cli.add_argument(
        "--bin_dir",
        default="/workspace/ck_tile/bin",
        help="Directory with benchmark executables",
    )
    cli.add_argument("--out_dir", default="data", help="Output directory")
    cli.add_argument("--batch_size", type=int, default=25, help="Shapes per output file")
    cli.add_argument("--warmup", type=int, default=3)
    cli.add_argument("--repeat", type=int, default=10)
    cli.add_argument(
        "--max_shapes", type=int, default=None, help="Limit total shapes (for testing)"
    )
    args = cli.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    shapes = generate_shape_list()
    if args.max_shapes:
        shapes = shapes[: args.max_shapes]

    # Run configuration summary goes to stderr (stdout stays clean).
    for line in (
        f"Generated {len(shapes)} unique shapes",
        f"Bin dir: {args.bin_dir}",
        f"Output dir: {args.out_dir}",
        f"Batch size: {args.batch_size}",
    ):
        print(line, file=sys.stderr, flush=True)

    total = 0
    for batch_idx, start in enumerate(range(0, len(shapes), args.batch_size), start=1):
        batch = shapes[start : start + args.batch_size]
        out_path = out_dir / f"wide_coverage_batch_{batch_idx:03d}.log"

        print(
            f"\nBatch {batch_idx}: shapes {start + 1}-{start + len(batch)} -> {out_path}",
            file=sys.stderr,
            flush=True,
        )

        with open(out_path, "w") as f:
            # Header lines expected by the downstream log parser.
            f.write(f"CK Tile Wide Coverage Benchmark Batch {batch_idx}\n")
            f.write("GPU ID: 0\n")
            f.write("Implementation: gemm_universal\n\n")
            count = run_shape_batch(
                args.bin_dir, batch, f, warmup=args.warmup, repeat=args.repeat
            )
        total += count

        print(
            f"  Batch {batch_idx} complete: {count} benchmarks",
            file=sys.stderr,
            flush=True,
        )

    print(
        f"\nTotal: {total} benchmarks across {len(shapes)} shapes",
        file=sys.stderr,
        flush=True,
    )
|
||||
|
||||
|
||||
# Standard script entry point.
if __name__ == "__main__":
    main()
|
||||
867
dispatcher/heuristics/ml_heuristic_sweep.py
Normal file
867
dispatcher/heuristics/ml_heuristic_sweep.py
Normal file
@@ -0,0 +1,867 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
ML Heuristic Sweep: Comprehensive GEMM Performance Evaluation
|
||||
|
||||
Sweeps across diverse problem shapes with ML-based kernel selection to measure
|
||||
TFLOPS performance. Supports multiple dtypes (fp16, bf16, fp8) and validates
|
||||
ML model predictions by executing kernels on GPU.
|
||||
|
||||
Shape Constraints (fp16/bf16 on gfx950):
|
||||
- M >= 1 (any M is valid)
|
||||
- N % 8 == 0 AND N >= 64
|
||||
- K % 2 == 0 AND K >= 32
|
||||
|
||||
Usage:
|
||||
python ml_heuristic_sweep.py --dtype fp16 --num_shapes 256
|
||||
python ml_heuristic_sweep.py --dtypes fp16 bf16 --output sweep_results.csv
|
||||
python ml_heuristic_sweep.py --dtype fp16 --dry_run # Prediction only, no GPU execution
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import time
|
||||
import csv
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
# Add parent directories to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "python"))
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
)
|
||||
|
||||
# Optional ML dependency: if the predictor package is unavailable the
# sweep degrades gracefully to first-fit kernel selection
# (see ml_select_kernel), gated on the HAS_ML flag.
try:
    from predict import Predictor
    # from feature_engine import GemmUniversalFeatureEngine

    HAS_ML = True
except ImportError:
    HAS_ML = False
    print("WARNING: ML heuristic modules not available. Will use first-fit selection.")
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """Kernel specification for ML heuristic.

    Describes one candidate GEMM kernel configuration: block tile sizes,
    pipeline/scheduler variant, wave layout, and per-warp tile sizes.
    Instances are ranked by the ML predictor (see ml_select_kernel) and
    translated to dispatcher configs (see spec_to_kernel_config).
    """

    name: str  # unique identifier, used to match predictor rankings back to a spec
    tile_m: int  # block tile size along M
    tile_n: int  # block tile size along N
    tile_k: int  # block tile size along K
    pipeline: str = "compv3"  # pipeline variant: "compv3", "compv4", or "mem"
    scheduler: str = "intrawave"  # "intrawave" or "interwave"
    wave_m: int = 2  # wave layout — presumably waves per block along M; TODO confirm
    wave_n: int = 2  # wave layout along N
    wave_k: int = 1  # wave layout along K
    warp_m: int = 32  # per-warp tile size along M
    warp_n: int = 32  # per-warp tile size along N
    warp_k: int = 16  # per-warp tile size along K
|
||||
|
||||
|
||||
# Comprehensive kernel pool covering diverse tile sizes and configurations
|
||||
# Comprehensive kernel pool covering diverse tile sizes and configurations.
# Positional argument order:
#   name, tile_m, tile_n, tile_k, pipeline, scheduler,
#   wave_m, wave_n, wave_k, warp_m, warp_n, warp_k
# Name prefixes: s_ = small tiles, m_ = medium, r_ = rectangular, l_ = large;
# suffixes encode pipeline (v3/v4/mem) and scheduler (iw = interwave).
KERNEL_POOL = [
    # Small tiles (64x64) — 16x16x16 warp tiles
    KernelSpec("s_64x64_k32_v3", 64, 64, 32, "compv3", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k64_v3", 64, 64, 64, "compv3", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k128_v3", 64, 64, 128, "compv3", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k64_v4", 64, 64, 64, "compv4", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k64_mem", 64, 64, 64, "mem", "intrawave", 2, 2, 1, 16, 16, 16),
    KernelSpec("s_64x64_k128_mem", 64, 64, 128, "mem", "intrawave", 2, 2, 1, 16, 16, 16),
    # Medium tiles (128x128) — default wave/warp layout
    KernelSpec("m_128x128_k32_v3", 128, 128, 32, "compv3", "intrawave"),
    KernelSpec("m_128x128_k64_v3", 128, 128, 64, "compv3", "intrawave"),
    KernelSpec("m_128x128_k128_v3", 128, 128, 128, "compv3", "intrawave"),
    KernelSpec("m_128x128_k64_v4", 128, 128, 64, "compv4", "intrawave"),
    KernelSpec("m_128x128_k128_v4", 128, 128, 128, "compv4", "intrawave"),
    KernelSpec("m_128x128_k64_mem", 128, 128, 64, "mem", "intrawave"),
    KernelSpec("m_128x128_k128_mem", 128, 128, 128, "mem", "intrawave"),
    # Rectangular medium (M != N) — warp tile follows the wider dimension
    KernelSpec("r_64x128_k32_v3", 64, 128, 32, "compv3", "intrawave", 2, 2, 1, 16, 32, 16),
    KernelSpec("r_128x64_k32_v3", 128, 64, 32, "compv3", "intrawave", 2, 2, 1, 32, 16, 16),
    KernelSpec("r_64x128_k64_v3", 64, 128, 64, "compv3", "intrawave", 2, 2, 1, 16, 32, 16),
    KernelSpec("r_128x64_k64_v3", 128, 64, 64, "compv3", "intrawave", 2, 2, 1, 32, 16, 16),
    KernelSpec("r_64x256_k32_v3", 64, 256, 32, "compv3", "intrawave", 2, 2, 1, 16, 32, 16),
    KernelSpec("r_256x64_k32_v3", 256, 64, 32, "compv3", "intrawave", 2, 2, 1, 32, 16, 16),
    # Large tiles (up to 256x256)
    KernelSpec("l_256x128_k32_v3", 256, 128, 32, "compv3", "intrawave"),
    KernelSpec("l_128x256_k32_v3", 128, 256, 32, "compv3", "intrawave"),
    KernelSpec("l_256x256_k32_v3", 256, 256, 32, "compv3", "intrawave"),
    KernelSpec("l_256x256_k64_v3", 256, 256, 64, "compv3", "intrawave"),
    KernelSpec("l_256x256_k64_v4", 256, 256, 64, "compv4", "intrawave"),
    # Interwave scheduler variants
    KernelSpec("m_128x128_k64_iw_v3", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("m_128x128_k128_iw_v3", 128, 128, 128, "compv3", "interwave"),
    KernelSpec("l_256x256_k32_iw_v3", 256, 256, 32, "compv3", "interwave"),
]
|
||||
|
||||
|
||||
def generate_problem_shapes(num_shapes: int = 1024) -> List[Tuple[int, int, int]]:
    """
    Generate diverse problem shapes with hardware constraints:
    - M >= 1 (any M is valid, including tiny M for inference)
    - N % 8 == 0 AND N >= 64 (hardware alignment requirement)
    - K % 2 == 0 AND K >= 32 (fp16 requirement)

    Covers:
    - Powers of 2 (square and rectangular)
    - ML workloads (LLM attention, MLP, batch inference)
    - Non-power-of-2 dimensions (aligned to constraints)
    - Edge cases (tiny M, very large matrices, extreme aspect ratios)

    NOTE: output depends on the exact append order — shapes are deduped
    in insertion order and then stratified-sampled down to num_shapes.
    """
    shapes = []

    # 1. Powers of 2 - Square (64 to 8192) with K variations
    for p in range(6, 14):  # 2^6=64 to 2^13=8192
        dim = 2**p
        shapes.append((dim, dim, dim))
        if dim >= 128:
            # K variations (must be even and >= 32)
            shapes.append((dim, dim, dim // 2))
            shapes.append((dim, dim, dim * 2))
            shapes.append((dim, dim, max(32, dim // 4)))

    # 2. Small batch inference (1-256 batch, common hidden dims)
    # N must be multiple of 8 and >= 64
    hidden_dims = [768, 1024, 2048, 3072, 4096, 5120, 8192, 11008, 12288, 16384]
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

    for hidden in hidden_dims:
        for batch in batch_sizes[:8]:  # batches 1..128; 256 is intentionally unused here
            shapes.append((batch, hidden, hidden))
            if hidden >= 4096:
                # LLM MLP projections (ensure K is even)
                k_mlp = hidden * 3 // 4
                if k_mlp % 2 == 1:
                    k_mlp += 1  # Make even
                if k_mlp >= 32:
                    shapes.append((batch, hidden, k_mlp))
                    shapes.append((batch, k_mlp, hidden))

    # 3. Attention patterns (seq_len x head_dim)
    # seq_len can be any value >= 1, total_dim must be multiple of 8
    seq_lens = [128, 256, 512, 1024, 2048, 4096, 8192]
    head_dims = [64, 80, 96, 128, 256]
    num_heads = [8, 12, 16, 32, 40, 64]

    for seq in seq_lens:
        for head_dim in head_dims:
            for nh in num_heads[:4]:  # 8..32 heads
                total_dim = nh * head_dim
                # total_dim should be multiple of 8 (naturally satisfied for most cases)
                if total_dim % 8 == 0 and total_dim >= 64:
                    # head_dim must be even for K
                    if head_dim % 2 == 0 and head_dim >= 32:
                        shapes.append((seq, total_dim, head_dim))
                        shapes.append((seq, head_dim, total_dim))

    # 4. Rectangular matrices (extreme aspect ratios)
    # All dims must satisfy constraints
    dims_m = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    dims_n = [64, 128, 256, 512, 1024, 2048, 4096, 8192]  # N >= 64, N % 8 == 0
    dims_k = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]  # K >= 32, K % 2 == 0

    # Sample to avoid explosion: keep every third index-sum combination
    for i, m in enumerate(dims_m):
        for j, n in enumerate(dims_n):
            for _l, k in enumerate(dims_k):
                if (i + j + _l) % 3 == 0:  # Stratified sampling
                    shapes.append((m, n, k))

    # 5. Non-power-of-2 dimensions (aligned to constraints)
    # N values: multiples of 8, >= 64 (powers of two deliberately excluded)
    non_pow2_n = [
        72, 80, 88, 96, 104, 112, 120, 136, 144, 152, 160, 176, 184, 192, 200,
        224, 240, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 448, 480,
        544, 576, 640, 672, 704, 736, 768, 800, 832, 896, 960, 1088, 1152,
        1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792, 1856, 1920,
        2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584,
        3712, 3840, 3968, 4224, 4352, 4480, 4608, 4736, 4864, 4992,
    ]

    # K values: even numbers >= 32
    non_pow2_k = [
        34, 36, 38, 40, 42, 44, 48, 50, 52, 56, 60, 66, 68, 72, 76, 80, 88,
        96, 100, 112, 120, 136, 144, 160, 176, 192, 224, 240, 272, 288, 320,
        352, 384, 416, 448, 480, 544, 576, 640, 672, 704, 768, 800, 832, 896,
        960, 1088, 1152, 1280, 1344, 1408, 1536, 1600, 1664, 1792, 1920,
    ]

    # M values: any value >= 1 (mostly odd / off-by-one around powers of two)
    non_pow2_m = [
        1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 23, 27, 31, 33, 37, 41, 47, 51,
        57, 63, 65, 71, 79, 87, 95, 97, 111, 119, 127, 129, 143, 159, 175,
        191, 193, 223, 239, 255, 257, 287, 319, 351, 383, 385, 447, 479, 511,
        513, 575, 639, 703, 767, 769, 895, 959, 1023, 1025,
    ]

    # Sample non-power-of-2 shapes (prefixes keep the cross product bounded)
    for i, m in enumerate(non_pow2_m[:30]):
        for j, n in enumerate(non_pow2_n[:20]):
            for _l, k in enumerate(non_pow2_k[:15]):
                if (i + j + _l) % 4 == 0:  # Stratified sampling
                    shapes.append((m, n, k))

    # 6. Very tall K (memory-bound) - ensure N % 8 == 0, K % 2 == 0
    for mn in [64, 128, 256, 512, 1024]:
        for k in [4096, 8192, 16384]:
            shapes.append((mn, mn, k))

    # 7. Very short K (compute-bound) - ensure K >= 32, K % 2 == 0
    for mn in [512, 1024, 2048, 4096]:
        for k in [32, 64, 128]:
            shapes.append((mn, mn, k))

    # 8. Tiny M (edge cases for batch-1 inference)
    for m in [1, 2, 4, 8, 16, 32]:
        for n in [64, 128, 256, 512, 1024, 2048]:  # N >= 64, N % 8 == 0
            for k in [32, 64, 128, 256, 512]:  # K >= 32, K % 2 == 0
                shapes.append((m, n, k))

    # 9. Stress test sizes (aligned to constraints)
    stress_sizes = [
        (10000, 10000, 10000),
        (1000, 10000, 1000),
        (1000, 1000, 10000),
        (5000, 5000, 5000),
        (7168, 7168, 7168),  # Common LLM hidden dim
        (8192, 11008, 8192),  # LLaMA MLP dimensions
    ]
    shapes.extend(stress_sizes)

    # Remove duplicates while preserving order (order matters for the
    # stratified downsampling below)
    seen = set()
    unique_shapes = []
    for s in shapes:
        if s not in seen:
            seen.add(s)
            unique_shapes.append(s)

    # Filter to ensure all shapes meet constraints
    valid_shapes = []
    for m, n, k in unique_shapes:
        if m >= 1 and n >= 64 and n % 8 == 0 and k >= 32 and k % 2 == 0:
            valid_shapes.append((m, n, k))

    # Sample down to target number if we have too many
    if len(valid_shapes) > num_shapes:
        # Stratified sampling to preserve diversity
        step = len(valid_shapes) / num_shapes
        valid_shapes = [valid_shapes[int(i * step)] for i in range(num_shapes)]

    return valid_shapes
|
||||
|
||||
|
||||
def spec_to_feature_dict(spec: KernelSpec, dtype: str, layout: str) -> dict:
    """Flatten a KernelSpec into the feature dict consumed by the ML predictor.

    Note the naming translation: the spec's wave_* fields map to warp_*
    features, and its warp_* fields map to warp_tile_* features.
    Insertion order of keys is preserved exactly.
    """
    features = {"kernel_name": spec.name}
    # Block tile sizes.
    features["tile_m"] = spec.tile_m
    features["tile_n"] = spec.tile_n
    features["tile_k"] = spec.tile_k
    # Wave layout -> "warp_*" feature names.
    features["warp_m"] = spec.wave_m
    features["warp_n"] = spec.wave_n
    features["warp_k"] = spec.wave_k
    # Per-warp tile sizes -> "warp_tile_*" feature names.
    features["warp_tile_m"] = spec.warp_m
    features["warp_tile_n"] = spec.warp_n
    features["warp_tile_k"] = spec.warp_k
    features["pipeline"] = spec.pipeline
    features["scheduler"] = spec.scheduler
    features["epilogue"] = "cshuffle"
    # Padding enabled on all dims so arbitrary problem sizes are legal.
    features["pad_m"] = True
    features["pad_n"] = True
    features["pad_k"] = True
    features["persistent"] = False
    features["dtype"] = dtype
    features["layout"] = layout
    return features
|
||||
|
||||
|
||||
def spec_to_kernel_config(
    spec: KernelSpec, dtype: str, arch: str, dtype_acc: str = "fp32"
) -> KernelConfig:
    """Translate a KernelSpec into a dispatcher KernelConfig (RCR layout)."""
    # A/B/C all share the same element type; accumulation uses dtype_acc.
    type_and_layout = dict(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc=dtype_acc,
        layout_a="row",
        layout_b="col",
        layout_c="row",
    )
    tiling = dict(
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=spec.wave_m,
        wave_n=spec.wave_n,
        wave_k=spec.wave_k,
        warp_m=spec.warp_m,
        warp_n=spec.warp_n,
        warp_k=spec.warp_k,
    )
    return KernelConfig(
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
        **type_and_layout,
        **tiling,
    )
|
||||
|
||||
|
||||
def ml_select_kernel(
    predictor, pool: List[KernelSpec], M: int, N: int, K: int, dtype: str, layout: str
) -> Tuple[KernelSpec, float]:
    """Pick the pool kernel the ML model ranks highest for (M, N, K).

    Falls back to the first kernel in the pool (predicted 0.0 TFLOPS)
    when no predictor is available or the ranking comes back empty.
    Returns (selected_spec, predicted_tflops).
    """
    if predictor is None or not HAS_ML:
        # No model available: degrade to first-fit selection.
        return pool[0], 0.0

    problem = {"m": M, "n": N, "k": K, "dtype": dtype, "layout": layout, "split_k": 1}
    candidates = [spec_to_feature_dict(s, dtype, layout) for s in pool]

    ranking = predictor.rank_kernels(problem, candidates)
    if not ranking:
        return pool[0], 0.0

    top_name, top_tflops = ranking[0]
    # Map the winning name back to its spec; first-fit if it is unknown.
    for spec in pool:
        if spec.name == top_name:
            return spec, top_tflops
    return pool[0], top_tflops
|
||||
|
||||
|
||||
def run_single_gemm(
    M: int,
    N: int,
    K: int,
    dtype: str,
    arch: str,
    predictor,
    dry_run: bool = False,
    dtype_acc: str = "fp32",
) -> dict:
    """Run a single GEMM with ML heuristic selection.

    Selects a kernel from KERNEL_POOL via the predictor, then (unless
    dry_run) builds it through the dispatcher and executes it on GPU.

    Returns a result dict with keys M/N/K/dtype, selected_kernel,
    predicted_tflops, selection_time_ms, actual_time_ms, actual_tflops,
    status (SKIP / SUCCESS / BUILD_FAIL / UNSUPPORTED / RUN_FAIL /
    ERROR), and error (None on success).
    """

    # Select kernel via ML heuristic; time the selection itself so the
    # heuristic's overhead is reported alongside the kernel's runtime.
    t0 = time.time()
    best_spec, pred_tflops = ml_select_kernel(
        predictor, KERNEL_POOL, M, N, K, dtype, "rcr"
    )
    select_time_ms = (time.time() - t0) * 1000

    result = {
        "M": M,
        "N": N,
        "K": K,
        "dtype": dtype,
        "selected_kernel": best_spec.name,
        "predicted_tflops": pred_tflops,
        "selection_time_ms": select_time_ms,
        "actual_time_ms": 0,
        "actual_tflops": 0,
        "status": "SKIP" if dry_run else "PENDING",
        "error": None,
    }

    if dry_run:
        # Prediction-only mode: report the selection without touching the GPU.
        return result

    # Build and run kernel
    config = spec_to_kernel_config(best_spec, dtype, arch, dtype_acc)

    try:
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"sweep_{dtype}_{best_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )

        if not setup.success:
            result["status"] = "BUILD_FAIL"
            result["error"] = "Failed to build kernel"
            cleanup_gemm()
            return result

        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            result["status"] = "UNSUPPORTED"
            result["error"] = "Problem size not supported by kernel"
            cleanup_gemm()
            return result

        # Create input data.
        # NOTE(review): bf16 and fp8 are mapped to np.float16 here —
        # presumably a host-side buffer workaround (numpy has no native
        # bf16/fp8); confirm the dispatcher reinterprets the bytes.
        np_dtype = {"fp16": np.float16, "bf16": np.float16, "fp8": np.float16}[dtype]
        np.random.seed(42)  # fixed seed: repeatable inputs across shapes
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        # Run GEMM
        exec_result = dispatcher.run(A, B, M, N, K)

        if exec_result.success:
            result["actual_time_ms"] = exec_result.time_ms
            result["actual_tflops"] = exec_result.tflops
            result["status"] = "SUCCESS"
        else:
            # Decode status code for better error message
            status_messages = {
                0: "Success",
                -1: "GPU/HIP error (check permissions, memory, or kernel validity)",
                -2: "No suitable kernel found for this problem size",
            }
            error_msg = status_messages.get(exec_result.status, f"Unknown error (status={exec_result.status})")
            result["status"] = "RUN_FAIL"
            result["error"] = f"{error_msg} (status_code={exec_result.status})"

            # Print detailed error for debugging
            print(f"  ERROR: {error_msg}")
            print(f"  Status code: {exec_result.status}")
            print(f"  Time returned: {exec_result.time_ms}")
            print(f"  Kernel: {exec_result.kernel_name}")

        cleanup_gemm()

    except Exception as e:
        # Any unexpected failure is captured (truncated) so the sweep
        # can continue with the next shape.
        result["status"] = "ERROR"
        result["error"] = str(e)[:200]
        cleanup_gemm()

    return result
|
||||
|
||||
|
||||
def main():
    """Run the ML heuristic sweep over all requested dtypes and shapes.

    Parses CLI arguments, optionally loads the LightGBM predictor, generates
    problem shapes, executes one GEMM per (dtype, shape) pair via
    run_single_gemm(), streams each result row to a CSV file, and prints a
    summary. Partial results are preserved on Ctrl+C.

    Returns
    -------
    int
        0 on completion (used as the process exit code).
    """
    parser = argparse.ArgumentParser(
        description="ML Heuristic Sweep: Test GEMM across many shapes and dtypes"
    )
    parser.add_argument(
        "--dtypes",
        nargs="+",
        default=["fp16"],
        choices=["fp16", "bf16", "fp8"],
        help="Data types to test (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx950", help="GPU architecture (default: gfx950)"
    )
    parser.add_argument(
        "--dtype_acc",
        default="fp32",
        choices=["fp16", "fp32"],
        help="Accumulator data type (default: fp32)",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path to model directory (auto-detect if not specified)",
    )
    parser.add_argument(
        "--num_shapes",
        type=int,
        default=256,
        help="Number of problem shapes to test (default: 256)",
    )
    parser.add_argument(
        "--output",
        default="ml_heuristic_sweep_results.csv",
        help="Output CSV file path",
    )
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Only predict, do not run kernels (fast validation)",
    )

    args = parser.parse_args()

    # Setup ML predictor; on any failure we fall back to first-fit selection
    # (predictor stays None and run_single_gemm handles that case).
    predictor = None
    if HAS_ML:
        if args.model_dir is None:
            # Auto-detect model directory based on first dtype
            first_dtype = args.dtypes[0]
            heuristics_dir = Path(__file__).parent
            model_candidates = [
                heuristics_dir / "models" / f"gemm_universal_{first_dtype}_{args.arch}",
            ]
            for model_dir in model_candidates:
                if model_dir.exists():
                    args.model_dir = str(model_dir)
                    break

        if args.model_dir and Path(args.model_dir).exists():
            try:
                predictor = Predictor(args.model_dir)
                print(f"✓ Loaded ML model from: {args.model_dir}")
            except Exception as e:
                print(f"⚠ Failed to load ML model: {e}")
                print(" Will use first-fit selection instead")
        else:
            print(f"⚠ Model directory not found: {args.model_dir}")
            print(" Will use first-fit selection instead")

    # Generate problem shapes
    print(f"\nGenerating {args.num_shapes} problem shapes...")
    shapes = generate_problem_shapes(args.num_shapes)
    print(
        f"✓ Generated {len(shapes)} valid shapes (M>=1, N%8==0, N>=64, K%2==0, K>=32)"
    )

    # Defensive double-check that the generator respected the constraints.
    invalid = [
        (m, n, k)
        for m, n, k in shapes
        if not (m >= 1 and n >= 64 and n % 8 == 0 and k >= 32 and k % 2 == 0)
    ]
    if invalid:
        print(f"⚠ WARNING: {len(invalid)} shapes violate constraints!")
        print(f" First few: {invalid[:5]}")

    # Print configuration
    print("\n" + "=" * 80)
    print(" ML Heuristic Sweep Configuration")
    print("=" * 80)
    print(
        f" Model: {args.model_dir if args.model_dir else 'first-fit (no ML)'}"
    )
    print(f" Data types: {', '.join(args.dtypes)}")
    print(f" Accumulator: {args.dtype_acc}")
    print(f" Architecture: {args.arch}")
    print(f" Kernel pool: {len(KERNEL_POOL)} kernels")
    print(f" Problem shapes: {len(shapes)}")
    print(f" Total tests: {len(shapes) * len(args.dtypes)}")
    print(
        f" Mode: {'DRY RUN (prediction only)' if args.dry_run else 'FULL RUN (execute kernels)'}"
    )
    print(f" Output: {args.output}")
    print("=" * 80)

    # Open output CSV; rows are flushed immediately so an interrupt loses
    # nothing already measured.
    csv_file = open(args.output, "w", newline="")
    csv_writer = csv.DictWriter(
        csv_file,
        fieldnames=[
            "dtype",
            "M",
            "N",
            "K",
            "selected_kernel",
            "predicted_tflops",
            "selection_time_ms",
            "actual_time_ms",
            "actual_tflops",
            "status",
            "error",
        ],
    )
    csv_writer.writeheader()

    # Run sweep
    total_tests = len(shapes) * len(args.dtypes)
    completed = 0
    start_time = time.time()

    print("\nStarting sweep... (Ctrl+C to stop and save partial results)\n")

    try:
        for dtype in args.dtypes:
            print(f"\n{'=' * 80}")
            print(f" Testing dtype: {dtype.upper()}")
            print(f"{'=' * 80}\n")

            for i, (M, N, K) in enumerate(shapes):
                result = run_single_gemm(
                    M, N, K, dtype, args.arch, predictor, args.dry_run, args.dtype_acc
                )

                # Write to CSV and flush so partial results survive interrupts.
                csv_writer.writerow(result)
                csv_file.flush()

                completed += 1

                # Progress update: every 10 tests, and always on non-success.
                if completed % 10 == 0 or result["status"] != "SUCCESS":
                    elapsed = time.time() - start_time
                    rate = completed / elapsed if elapsed > 0 else 0
                    eta = (total_tests - completed) / rate if rate > 0 else 0

                    status_emoji = {
                        "SUCCESS": "✓",
                        "SKIP": "→",
                        "BUILD_FAIL": "✗",
                        "UNSUPPORTED": "○",
                        "RUN_FAIL": "✗",
                        "ERROR": "✗",
                    }.get(result["status"], "?")

                    print(
                        f" [{completed:4d}/{total_tests}] {status_emoji} "
                        f"{dtype:4s} {M:5d}x{N:5d}x{K:5d} → "
                        f"{result['selected_kernel']:20s} "
                        f"pred={result['predicted_tflops']:6.1f} "
                        f"actual={result['actual_tflops']:6.1f} TFLOPS "
                        f"[{rate:.1f} tests/s, ETA {eta / 60:.1f}m]"
                    )

    except KeyboardInterrupt:
        print(f"\n\n⚠ Interrupted! Saving partial results to {args.output}...")

    finally:
        csv_file.close()

    # Summary
    print("\n" + "=" * 80)
    print(" SWEEP COMPLETE")
    print("=" * 80)

    # Read back results and compute statistics
    results = []
    with open(args.output, "r") as f:
        reader = csv.DictReader(f)
        results = list(reader)

    print(f"\n Total tests: {len(results)}")
    print(f" Output file: {args.output}")

    if not args.dry_run:
        success = [r for r in results if r["status"] == "SUCCESS"]
        # Guard: results can be empty (e.g. interrupted before the first row),
        # which previously raised ZeroDivisionError in the percentage below.
        if results:
            print(
                f" Successful: {len(success)} ({100 * len(success) / len(results):.1f}%)"
            )
        else:
            print(" Successful: 0")

        if success:
            avg_tflops = np.mean([float(r["actual_tflops"]) for r in success])
            max_tflops = max([float(r["actual_tflops"]) for r in success])
            print(f" Avg TFLOPS: {avg_tflops:.2f}")
            print(f" Max TFLOPS: {max_tflops:.2f}")

            # Per-dtype breakdown
            for dtype in args.dtypes:
                dtype_results = [r for r in success if r["dtype"] == dtype]
                if dtype_results:
                    avg = np.mean([float(r["actual_tflops"]) for r in dtype_results])
                    print(
                        f" {dtype:4s}: {avg:.2f} TFLOPS (n={len(dtype_results)})"
                    )

    print("=" * 80)
    print()

    return 0
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s integer return value to the shell as the exit code.
    raise SystemExit(main())
|
||||
@@ -0,0 +1,113 @@
|
||||
{
|
||||
"op_type": "gemm_universal",
|
||||
"dtype": "fp16",
|
||||
"arch": "gfx950",
|
||||
"feature_names": [
|
||||
"M",
|
||||
"N",
|
||||
"K",
|
||||
"split_k",
|
||||
"log2_M",
|
||||
"log2_N",
|
||||
"log2_K",
|
||||
"log2_MNK",
|
||||
"arithmetic_intensity",
|
||||
"aspect_ratio_mn",
|
||||
"aspect_ratio_mk",
|
||||
"aspect_ratio_nk",
|
||||
"layout",
|
||||
"tile_m",
|
||||
"tile_n",
|
||||
"tile_k",
|
||||
"warp_m",
|
||||
"warp_n",
|
||||
"warp_k",
|
||||
"warp_tile_m",
|
||||
"warp_tile_n",
|
||||
"warp_tile_k",
|
||||
"pipeline",
|
||||
"scheduler",
|
||||
"epilogue",
|
||||
"pad_m",
|
||||
"pad_n",
|
||||
"pad_k",
|
||||
"persistent",
|
||||
"num_warps",
|
||||
"tile_volume",
|
||||
"tile_mn",
|
||||
"lds_usage_estimate",
|
||||
"lds_usage_ratio",
|
||||
"num_tiles_m",
|
||||
"num_tiles_n",
|
||||
"num_tiles_k",
|
||||
"total_output_tiles",
|
||||
"tile_eff_m",
|
||||
"tile_eff_n",
|
||||
"tile_eff_k",
|
||||
"overall_tile_efficiency",
|
||||
"cu_utilization",
|
||||
"ratio_M_to_tile_m",
|
||||
"ratio_N_to_tile_n",
|
||||
"ratio_K_to_tile_k",
|
||||
"problem_smaller_than_tile_m",
|
||||
"problem_smaller_than_tile_n",
|
||||
"problem_smaller_than_tile_k",
|
||||
"any_dim_too_small",
|
||||
"needs_padding_m",
|
||||
"needs_padding_n",
|
||||
"needs_padding_k",
|
||||
"has_padding_when_needed_m",
|
||||
"has_padding_when_needed_n",
|
||||
"has_padding_when_needed_k",
|
||||
"missing_required_padding_m",
|
||||
"missing_required_padding_n",
|
||||
"missing_required_padding_k",
|
||||
"missing_any_required_padding",
|
||||
"hw_num_cus",
|
||||
"hw_simds_per_cu",
|
||||
"hw_total_simds",
|
||||
"hw_shader_engines",
|
||||
"hw_max_clock_mhz",
|
||||
"hw_max_waves_per_cu",
|
||||
"hw_wavefront_size",
|
||||
"hw_lds_capacity",
|
||||
"hw_l1_cache_kb",
|
||||
"hw_l2_cache_kb",
|
||||
"hw_l3_cache_kb",
|
||||
"hw_num_xcd"
|
||||
],
|
||||
"categorical_features": [
|
||||
"layout",
|
||||
"pipeline",
|
||||
"scheduler",
|
||||
"epilogue"
|
||||
],
|
||||
"targets": [
|
||||
"tflops",
|
||||
"latency",
|
||||
"bandwidth"
|
||||
],
|
||||
"log_targets": [
|
||||
"bandwidth",
|
||||
"tflops"
|
||||
],
|
||||
"params": {
|
||||
"objective": "regression",
|
||||
"metric": [
|
||||
"rmse",
|
||||
"mae"
|
||||
],
|
||||
"num_leaves": 255,
|
||||
"max_depth": 15,
|
||||
"n_estimators": 2000,
|
||||
"learning_rate": 0.02,
|
||||
"min_child_samples": 10,
|
||||
"subsample": 0.85,
|
||||
"colsample_bytree": 0.85,
|
||||
"reg_alpha": 0.05,
|
||||
"reg_lambda": 0.5,
|
||||
"verbose": -1,
|
||||
"n_jobs": 8,
|
||||
"seed": 42
|
||||
}
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"warm_start_from": null,
|
||||
"prev_n_estimators": 0,
|
||||
"new_n_estimators": 2000,
|
||||
"total_n_estimators": 2000,
|
||||
"data_rows": 25600,
|
||||
"valid_rows": 21920,
|
||||
"unique_shapes": 25,
|
||||
"timestamp": "2026-03-20T05:00:55"
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
{
|
||||
"op_type": "gemm_universal",
|
||||
"dtype": "fp8",
|
||||
"arch": "gfx950",
|
||||
"feature_names": [
|
||||
"M",
|
||||
"N",
|
||||
"K",
|
||||
"split_k",
|
||||
"log2_M",
|
||||
"log2_N",
|
||||
"log2_K",
|
||||
"log2_MNK",
|
||||
"arithmetic_intensity",
|
||||
"aspect_ratio_mn",
|
||||
"aspect_ratio_mk",
|
||||
"aspect_ratio_nk",
|
||||
"layout",
|
||||
"tile_m",
|
||||
"tile_n",
|
||||
"tile_k",
|
||||
"warp_m",
|
||||
"warp_n",
|
||||
"warp_k",
|
||||
"warp_tile_m",
|
||||
"warp_tile_n",
|
||||
"warp_tile_k",
|
||||
"pipeline",
|
||||
"scheduler",
|
||||
"epilogue",
|
||||
"pad_m",
|
||||
"pad_n",
|
||||
"pad_k",
|
||||
"persistent",
|
||||
"num_warps",
|
||||
"tile_volume",
|
||||
"tile_mn",
|
||||
"lds_usage_estimate",
|
||||
"lds_usage_ratio",
|
||||
"num_tiles_m",
|
||||
"num_tiles_n",
|
||||
"num_tiles_k",
|
||||
"total_output_tiles",
|
||||
"tile_eff_m",
|
||||
"tile_eff_n",
|
||||
"tile_eff_k",
|
||||
"overall_tile_efficiency",
|
||||
"cu_utilization",
|
||||
"ratio_M_to_tile_m",
|
||||
"ratio_N_to_tile_n",
|
||||
"ratio_K_to_tile_k",
|
||||
"problem_smaller_than_tile_m",
|
||||
"problem_smaller_than_tile_n",
|
||||
"problem_smaller_than_tile_k",
|
||||
"any_dim_too_small",
|
||||
"needs_padding_m",
|
||||
"needs_padding_n",
|
||||
"needs_padding_k",
|
||||
"has_padding_when_needed_m",
|
||||
"has_padding_when_needed_n",
|
||||
"has_padding_when_needed_k",
|
||||
"missing_required_padding_m",
|
||||
"missing_required_padding_n",
|
||||
"missing_required_padding_k",
|
||||
"missing_any_required_padding",
|
||||
"hw_num_cus",
|
||||
"hw_simds_per_cu",
|
||||
"hw_total_simds",
|
||||
"hw_shader_engines",
|
||||
"hw_max_clock_mhz",
|
||||
"hw_max_waves_per_cu",
|
||||
"hw_wavefront_size",
|
||||
"hw_lds_capacity",
|
||||
"hw_l1_cache_kb",
|
||||
"hw_l2_cache_kb",
|
||||
"hw_l3_cache_kb",
|
||||
"hw_num_xcd"
|
||||
],
|
||||
"categorical_features": [
|
||||
"layout",
|
||||
"pipeline",
|
||||
"scheduler",
|
||||
"epilogue"
|
||||
],
|
||||
"targets": [
|
||||
"tflops",
|
||||
"latency",
|
||||
"bandwidth"
|
||||
],
|
||||
"log_targets": [
|
||||
"bandwidth",
|
||||
"tflops"
|
||||
],
|
||||
"params": {
|
||||
"objective": "regression",
|
||||
"metric": [
|
||||
"rmse",
|
||||
"mae"
|
||||
],
|
||||
"num_leaves": 255,
|
||||
"max_depth": 15,
|
||||
"n_estimators": 2000,
|
||||
"learning_rate": 0.02,
|
||||
"min_child_samples": 10,
|
||||
"subsample": 0.85,
|
||||
"colsample_bytree": 0.85,
|
||||
"reg_alpha": 0.05,
|
||||
"reg_lambda": 0.5,
|
||||
"verbose": -1,
|
||||
"n_jobs": 8,
|
||||
"seed": 42
|
||||
}
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"warm_start_from": null,
|
||||
"prev_n_estimators": 0,
|
||||
"new_n_estimators": 2000,
|
||||
"total_n_estimators": 2000,
|
||||
"data_rows": 1296528,
|
||||
"valid_rows": 1253076,
|
||||
"unique_shapes": 168,
|
||||
"timestamp": "2026-03-19T06:10:29"
|
||||
}
|
||||
243
dispatcher/heuristics/predict.py
Normal file
243
dispatcher/heuristics/predict.py
Normal file
@@ -0,0 +1,243 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Predictor for CK Tile kernel performance.
|
||||
|
||||
Loads trained LightGBM models and provides:
|
||||
- predict_tflops(): predicted TFLOPS for a single (problem, kernel) pair
|
||||
- predict_latency(): predicted latency in ms
|
||||
- predict_bandwidth(): predicted bandwidth in GB/s
|
||||
- predict_all(): all three predictions at once
|
||||
- rank_kernels(): rank all candidate kernels by predicted TFLOPS
|
||||
- select_best(): return the best kernel ID
|
||||
|
||||
Usage:
|
||||
predictor = Predictor("models/gemm_universal_fp8_gfx950")
|
||||
best_kernel = predictor.select_best(
|
||||
problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
|
||||
kernel_configs=[...],
|
||||
)
|
||||
"""
|
||||
|
||||
import gzip
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
|
||||
|
||||
class Predictor:
    """Loads trained models and feature spec for kernel performance prediction.

    Parameters
    ----------
    model_dir : str or Path
        Directory containing model artifacts:
        - model_tflops.lgbm (required)
        - model_latency.lgbm (optional)
        - model_bandwidth.lgbm (optional)
        - feature_spec.json (required)

    feature_engine : FeatureEngine, optional
        Override the feature engine. If None, constructs one from feature_spec.json.
    """

    def __init__(self, model_dir: str | Path, feature_engine=None):
        self._model_dir = Path(model_dir)
        # Boosters are lazy-loaded on first use, keyed by target name
        # ('tflops', 'latency', 'bandwidth').
        self._models: dict[str, lgb.Booster] = {}

        spec_path = self._model_dir / "feature_spec.json"
        if spec_path.exists():
            with open(spec_path) as f:
                self._spec = json.load(f)
        else:
            self._spec = {}

        # Targets listed here were trained in log1p space; their predictions
        # must be inverted with expm1 before being reported.
        self._log_targets = set(self._spec.get("log_targets", []))

        if feature_engine is not None:
            self._feature_engine = feature_engine
        else:
            self._feature_engine = GemmUniversalFeatureEngine()

    def _inverse_transform(self, target: str, raw: float) -> float:
        """Map a raw model output back to physical units, clamped to >= 0.

        Log-space targets are inverted with expm1. The clamp applies to both
        paths: expm1(raw) is negative whenever raw < 0, and linear-space
        models can also extrapolate below zero. (Previously only the
        linear-space path was clamped.)
        """
        value = float(np.expm1(raw)) if target in self._log_targets else float(raw)
        return max(0.0, value)

    def _load_model(self, target: str) -> Optional[lgb.Booster]:
        """Lazy-load a model for the given target.

        Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
        The decompressed file is cached to disk for subsequent loads.

        Returns None when neither model_<target>.lgbm nor its .gz exists.
        """
        if target in self._models:
            return self._models[target]

        path = self._model_dir / f"model_{target}.lgbm"
        gz_path = self._model_dir / f"model_{target}.lgbm.gz"

        # Auto-decompress if needed
        if not path.exists() and gz_path.exists():
            with gzip.open(gz_path, 'rb') as f_in:
                with open(path, 'wb') as f_out:
                    f_out.write(f_in.read())

        if not path.exists():
            return None

        model = lgb.Booster(model_file=str(path))
        self._models[target] = model
        return model

    def _predict_single(self, target: str, problem: dict, kernel_config: dict) -> float:
        """Predict a single target value, applying inverse log transform if needed.

        Raises FileNotFoundError if the target's model file is absent.
        """
        model = self._load_model(target)
        if model is None:
            raise FileNotFoundError(f"No model_{target}.lgbm in {self._model_dir}")
        features = self._feature_engine.extract(problem, kernel_config)
        raw = float(model.predict(features.reshape(1, -1))[0])
        return self._inverse_transform(target, raw)

    def predict_tflops(self, problem: dict, kernel_config: dict) -> float:
        """Predict TFLOPS for a single (problem, kernel) pair.

        Returns a real TFLOPS estimate (interpretable, usable as DE surrogate).
        If the model was trained in log-space, the inverse transform is applied
        automatically.
        """
        return self._predict_single("tflops", problem, kernel_config)

    def predict_latency(self, problem: dict, kernel_config: dict) -> float:
        """Predict latency in milliseconds for a single (problem, kernel) pair."""
        return self._predict_single("latency", problem, kernel_config)

    def predict_bandwidth(self, problem: dict, kernel_config: dict) -> float:
        """Predict bandwidth in GB/s for a single (problem, kernel) pair."""
        return self._predict_single("bandwidth", problem, kernel_config)

    def predict_all(self, problem: dict, kernel_config: dict) -> dict[str, float]:
        """Predict all available targets for a single (problem, kernel) pair.

        Returns dict with keys 'tflops', 'latency_ms', 'bandwidth_gb_s' (if models exist).

        Features are extracted once and shared across all targets; the same
        inverse transform and non-negative clamping as _predict_single() is
        applied via _inverse_transform().
        """
        features = self._feature_engine.extract(problem, kernel_config).reshape(1, -1)
        result = {}
        for target, key in [
            ("tflops", "tflops"),
            ("latency", "latency_ms"),
            ("bandwidth", "bandwidth_gb_s"),
        ]:
            model = self._load_model(target)
            if model is not None:
                raw = float(model.predict(features)[0])
                result[key] = self._inverse_transform(target, raw)
        return result

    def rank_kernels(
        self, problem: dict, kernel_configs: list[dict]
    ) -> list[tuple[str, float]]:
        """Rank candidate kernels by predicted TFLOPS (descending).

        Parameters
        ----------
        problem : dict
            Problem specification with keys: m, n, k, dtype, layout, split_k.
        kernel_configs : list of dict
            Each dict must have a 'kernel_name' key plus kernel parameters.

        Returns
        -------
        list of (kernel_name, predicted_tflops) tuples, sorted descending.
        """
        if not kernel_configs:
            return []

        model = self._load_model("tflops")
        if model is None:
            raise FileNotFoundError(f"No model_tflops.lgbm in {self._model_dir}")

        # Batch-extract features for all candidates in one pass.
        rows = [{**problem, **kc} for kc in kernel_configs]
        df = pd.DataFrame(rows)
        X = self._feature_engine.extract_batch(df)
        preds = model.predict(X)
        if "tflops" in self._log_targets:
            preds = np.expm1(preds)
        # Clamp batch predictions too, consistent with _inverse_transform().
        preds = np.maximum(preds, 0.0)

        results = [
            (kc.get("kernel_name", f"kernel_{i}"), float(preds[i]))
            for i, kc in enumerate(kernel_configs)
        ]
        results.sort(key=lambda x: -x[1])
        return results

    def select_best(self, problem: dict, kernel_configs: list[dict]) -> str:
        """Return the kernel_name of the best predicted kernel.

        Raises ValueError when kernel_configs is empty.
        """
        ranked = self.rank_kernels(problem, kernel_configs)
        if not ranked:
            raise ValueError("No kernel configs provided")
        return ranked[0][0]
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Predict kernel performance")
    parser.add_argument(
        "--model_dir", required=True, help="Directory with trained models"
    )
    parser.add_argument("--m", type=int, required=True)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--k", type=int, required=True)
    parser.add_argument("--layout", default="rcr")
    parser.add_argument("--dtype", default="fp8")
    args = parser.parse_args()

    predictor = Predictor(args.model_dir)
    problem = {
        "m": args.m,
        "n": args.n,
        "k": args.k,
        "dtype": args.dtype,
        "layout": args.layout,
        "split_k": 1,
    }

    print(f"Loading models from {args.model_dir}...")
    print(
        f"Problem: M={args.m} N={args.n} K={args.k} dtype={args.dtype} layout={args.layout}"
    )

    # Demo: pull sample kernel configs from benchmark parquet files expected
    # under <model_dir>/../../data, and rank them for the given problem.
    data_dir = Path(args.model_dir).parent.parent / "data"
    ranked_any = False
    if data_dir.exists():
        for pq in data_dir.glob("*.parquet"):
            df = pd.read_parquet(pq)
            kernel_names = df["kernel_name"].unique()
            configs = []
            for kn in kernel_names[:10]:
                # NOTE(review): row.to_dict() carries every benchmark column,
                # so rank_kernels' merge {**problem, **kc} lets data columns
                # override the CLI problem if names collide — confirm the
                # feature engine reads only kernel params from the config.
                row = df[df["kernel_name"] == kn].iloc[0]
                configs.append(row.to_dict())
            if configs:
                ranked = predictor.rank_kernels(problem, configs)
                print(f"\nTop 5 kernels (from {len(configs)} candidates):")
                for name, tflops in ranked[:5]:
                    print(f" {tflops:8.2f} TFLOPS {name}")
                ranked_any = True
                break
    # Previously this path exited silently with no output at all; tell the
    # user why nothing was ranked.
    if not ranked_any:
        print(f"\nNo candidate kernel configs found under {data_dir}; nothing to rank.")
||||
272
dispatcher/heuristics/search.py
Normal file
272
dispatcher/heuristics/search.py
Normal file
@@ -0,0 +1,272 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Surrogate search for CK Tile kernel configuration optimization.
|
||||
|
||||
Uses a trained LGBMRegressor as a cheap surrogate function to search the
|
||||
discrete kernel parameter space (tile sizes, warp config, pipeline, etc.)
|
||||
without running actual GPU benchmarks.
|
||||
|
||||
Strategies:
|
||||
- 'random': Sample N random valid configs, score all, return top-K.
|
||||
- 'de': Discrete Differential Evolution with mutation over valid parameter choices.
|
||||
|
||||
Usage:
|
||||
from search import SurrogateSearch
|
||||
from predict import Predictor
|
||||
|
||||
predictor = Predictor("models/gemm_universal_fp8_gfx950")
|
||||
searcher = SurrogateSearch(predictor, strategy='random')
|
||||
results = searcher.search(
|
||||
problem={"m": 128, "n": 1536, "k": 7168, "dtype": "fp8", "layout": "rcr"},
|
||||
budget=500,
|
||||
)
|
||||
# results: [(config_dict, predicted_tflops), ...] sorted descending
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
from predict import Predictor
|
||||
|
||||
|
||||
class SurrogateSearch:
    """Search kernel parameter space using ML regressor as surrogate objective.

    Parameters
    ----------
    predictor : Predictor
        Trained predictor with a TFLOPS model.
    feature_engine : GemmUniversalFeatureEngine, optional
        Feature engine for parameter space and validation. If None, uses default.
    strategy : str
        Search strategy: 'random' or 'de' (Discrete Differential Evolution).
    seed : int
        Random seed for reproducibility.
    """

    def __init__(
        self,
        predictor: Predictor,
        feature_engine: Optional[GemmUniversalFeatureEngine] = None,
        strategy: str = "random",
        seed: int = 42,
    ):
        self._predictor = predictor
        self._fe = feature_engine or GemmUniversalFeatureEngine()
        self._strategy = strategy
        # Seeded RNGs make searches reproducible for a given seed.
        self._rng = random.Random(seed)
        # NOTE(review): _np_rng is initialized but not used by any method in
        # this class — confirm whether it is needed.
        self._np_rng = np.random.RandomState(seed)
        # Dict mapping each kernel parameter name to its list of valid values.
        self._param_space = self._fe.get_parameter_space()

    def _sample_random_config(self) -> dict:
        """Sample a single random config from the parameter space."""
        config = {}
        for param, values in self._param_space.items():
            config[param] = self._rng.choice(values)
        return config

    def _sample_valid_config(self, max_attempts: int = 50) -> Optional[dict]:
        """Sample a random config that passes all validation constraints.

        Returns None if no valid config is found within max_attempts draws.
        """
        for _ in range(max_attempts):
            config = self._sample_random_config()
            if self._fe.validate_config(config):
                return config
        return None

    def _score_config(self, problem: dict, config: dict) -> float:
        """Score a config using the predictor (predicted TFLOPS; higher is better)."""
        return self._predictor.predict_tflops(problem, config)

    def _search_random(
        self, problem: dict, budget: int, top_k: int
    ) -> list[tuple[dict, float]]:
        """Random search: sample valid configs, score all, return top-K."""
        # Sampling budget: each draw attempts one valid config; failed
        # attempts (None) are dropped, so fewer than `budget` configs may
        # actually be scored.
        configs = []
        for _ in range(budget):
            cfg = self._sample_valid_config()
            if cfg is not None:
                configs.append(cfg)

        if not configs:
            return []

        scored = []
        for cfg in configs:
            try:
                score = self._score_config(problem, cfg)
                scored.append((cfg, score))
            except Exception:
                # The surrogate can fail on unusual configs; skip those
                # rather than aborting the whole search.
                continue

        # Sort by predicted TFLOPS, highest first.
        scored.sort(key=lambda x: -x[1])
        return scored[:top_k]

    def _search_de(
        self,
        problem: dict,
        budget: int,
        top_k: int,
        pop_size: int = 20,
        mutation_rate: float = 0.3,
        crossover_rate: float = 0.7,
    ) -> list[tuple[dict, float]]:
        """Discrete Differential Evolution.

        Uses discrete mutation: randomly swap parameters to other valid values
        from the parameter space (no continuous relaxation + snap).

        Each generation:
        1. For each member of the population, create a trial vector by:
           - Selecting 3 random other members (a, b, c)
           - For each parameter, with probability mutation_rate, take the value
             from a, b, or c (uniform choice among the three donors)
           - With probability crossover_rate, take the trial value; otherwise keep original
        2. Validate the trial; if invalid, resample that parameter from the space
        3. Score the trial; if better than parent, replace
        """
        param_names = list(self._param_space.keys())

        # Seed the population with valid random configs; each scored member
        # consumes one surrogate evaluation.
        population = []
        for _ in range(pop_size):
            cfg = self._sample_valid_config()
            if cfg is not None:
                score = self._score_config(problem, cfg)
                population.append((cfg, score))

        # DE needs a parent plus 3 distinct donors; fall back to random
        # search when the population is too small.
        if len(population) < 4:
            return self._search_random(problem, budget, top_k)

        evals_used = len(population)
        # Remaining budget determines how many full generations we can run.
        max_gens = (budget - evals_used) // pop_size

        for gen in range(max_gens):
            new_pop = []
            for i, (parent, parent_score) in enumerate(population):
                candidates = [j for j in range(len(population)) if j != i]
                if len(candidates) < 3:
                    new_pop.append((parent, parent_score))
                    continue

                # Pick three distinct donors other than the parent.
                a_idx, b_idx, c_idx = self._rng.sample(candidates, 3)
                a, b, c = (
                    population[a_idx][0],
                    population[b_idx][0],
                    population[c_idx][0],
                )

                # Build the trial: discrete mutation from a random donor,
                # then binomial crossover reverting some params to the parent.
                trial = dict(parent)
                for param in param_names:
                    if self._rng.random() < mutation_rate:
                        donor = self._rng.choice([a, b, c])
                        trial[param] = donor.get(param, parent.get(param))

                    if self._rng.random() > crossover_rate:
                        trial[param] = parent.get(param)

                # Repair invalid trials: resample any out-of-space value,
                # then re-validate; keep the parent if still invalid.
                if not self._fe.validate_config(trial):
                    for param in param_names:
                        if param in trial and trial[param] not in self._param_space.get(
                            param, [trial[param]]
                        ):
                            trial[param] = self._rng.choice(self._param_space[param])
                    if not self._fe.validate_config(trial):
                        new_pop.append((parent, parent_score))
                        continue

                try:
                    trial_score = self._score_config(problem, trial)
                    evals_used += 1
                except Exception:
                    # Surrogate failure on the trial: keep the parent.
                    new_pop.append((parent, parent_score))
                    continue

                # Greedy selection: the better of parent/trial survives.
                if trial_score > parent_score:
                    new_pop.append((trial, trial_score))
                else:
                    new_pop.append((parent, parent_score))

            population = new_pop

        population.sort(key=lambda x: -x[1])
        return population[:top_k]

    def search(
        self,
        problem: dict,
        budget: int = 500,
        top_k: int = 10,
        **kwargs,
    ) -> list[tuple[dict, float]]:
        """Search the kernel parameter space for the best configuration.

        Parameters
        ----------
        problem : dict
            Problem specification: m, n, k, dtype, layout, split_k.
        budget : int
            Maximum number of surrogate evaluations.
        top_k : int
            Number of top configurations to return.
        **kwargs
            Strategy-specific parameters (pop_size, mutation_rate, etc.).

        Returns
        -------
        list of (config_dict, predicted_tflops), sorted descending by TFLOPS.

        Raises
        ------
        ValueError
            If the configured strategy is not 'random' or 'de'.
        """
        if self._strategy == "random":
            return self._search_random(problem, budget, top_k)
        elif self._strategy == "de":
            return self._search_de(problem, budget, top_k, **kwargs)
        else:
            raise ValueError(f"Unknown strategy: {self._strategy}")
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    import time

    # CLI demo: rank kernel configs for one problem using the surrogate.
    parser = argparse.ArgumentParser(
        description="Surrogate search for optimal kernel config"
    )
    parser.add_argument("--model_dir", required=True)
    for dim_flag in ("--m", "--n", "--k"):
        parser.add_argument(dim_flag, type=int, required=True)
    parser.add_argument("--dtype", default="fp8")
    parser.add_argument("--layout", default="rcr")
    parser.add_argument("--strategy", default="random", choices=["random", "de"])
    parser.add_argument("--budget", type=int, default=500)
    parser.add_argument("--top_k", type=int, default=10)
    args = parser.parse_args()

    predictor = Predictor(args.model_dir)
    searcher = SurrogateSearch(predictor, strategy=args.strategy)
    problem = dict(
        m=args.m,
        n=args.n,
        k=args.k,
        dtype=args.dtype,
        layout=args.layout,
        split_k=1,
    )

    print(f"Searching with strategy={args.strategy}, budget={args.budget}...")
    t0 = time.time()
    results = searcher.search(problem, budget=args.budget, top_k=args.top_k)
    elapsed = time.time() - t0

    print(f"\nTop {len(results)} configs found in {elapsed * 1000:.1f}ms:")
    for rank, (cfg, tflops) in enumerate(results):
        tile_str = f"{cfg.get('tile_m', '?')}x{cfg.get('tile_n', '?')}x{cfg.get('tile_k', '?')}"
        warp_str = f"{cfg.get('warp_m', '?')}x{cfg.get('warp_n', '?')}x{cfg.get('warp_k', '?')}"
        print(
            f" #{rank + 1}: {tflops:8.2f} TFLOPS tile={tile_str} warp={warp_str} "
            f"pipeline={cfg.get('pipeline', '?')} scheduler={cfg.get('scheduler', '?')}"
        )
||||
2
dispatcher/heuristics/tests/__init__.py
Normal file
2
dispatcher/heuristics/tests/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
368
dispatcher/heuristics/tests/test_data_pipeline.py
Normal file
368
dispatcher/heuristics/tests/test_data_pipeline.py
Normal file
@@ -0,0 +1,368 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for data_pipeline.py.
|
||||
|
||||
Covers: kernel name parsing, layout derivation, streaming log parsing,
|
||||
schema validation, and corner cases (empty logs, malformed JSON, single-shape).
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from data_pipeline import (
|
||||
parse_kernel_name,
|
||||
_layout_from_problem,
|
||||
parse_streaming_log,
|
||||
save_parquet,
|
||||
load_parquet,
|
||||
CANONICAL_COLUMNS,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_kernel_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseKernelName:
    """parse_kernel_name splits an encoded kernel name into config fields."""

    def test_standard_name(self):
        parsed = parse_kernel_name(
            "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128"
        )
        expected = {
            "dtype": "fp8",
            "layout": "rcr",
            "pipeline": "compv3",
            "epilogue": "cshuffle",
            "scheduler": "intrawave",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 128,
            "warp_m": 1,
            "warp_n": 4,
            "warp_k": 1,
            "warp_tile_m": 16,
            "warp_tile_n": 16,
            "warp_tile_k": 128,
        }
        for key, value in expected.items():
            # Boolean flags are checked by identity, everything else by value.
            if isinstance(value, bool):
                assert parsed[key] is value
            else:
                assert parsed[key] == value

    def test_with_padding_and_persistent(self):
        parsed = parse_kernel_name(
            "gemm_universal_fp16_rrr_compv4_default_interwave_True_True_True_True_256x256x64_2x2x1_32x32x16"
        )
        assert parsed["dtype"] == "fp16"
        assert parsed["layout"] == "rrr"
        for flag in ("pad_m", "pad_n", "pad_k", "persistent"):
            assert parsed[flag] is True
        assert parsed["tile_m"] == 256

    def test_empty_name(self):
        assert parse_kernel_name("") == {}

    def test_malformed_name(self):
        assert parse_kernel_name("not_a_kernel_name") == {}

    def test_partial_name(self):
        parsed = parse_kernel_name("gemm_universal_fp8_rcr_compv3")
        assert parsed.get("dtype") == "fp8"
        assert parsed.get("layout") == "rcr"
        # Tile dims are absent when the name is truncated: not enough parts.
        assert "tile_m" not in parsed

    def test_all_layouts(self):
        template = (
            "gemm_universal_fp8_{}_compv3_cshuffle_intrawave"
            "_False_False_False_False_128x128x128_1x4x1_16x16x128"
        )
        for layout in ("rcr", "rrr", "crr", "ccr"):
            assert parse_kernel_name(template.format(layout))["layout"] == layout
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _layout_from_problem
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLayoutFromProblem:
    """_layout_from_problem maps layout_a/b/c strings to a 3-letter code."""

    @staticmethod
    def _layout(a, b, c):
        # Helper: build the problem dict and derive its layout code.
        return _layout_from_problem({"layout_a": a, "layout_b": b, "layout_c": c})

    def test_rcr(self):
        assert self._layout("RowMajor", "ColumnMajor", "RowMajor") == "rcr"

    def test_rrr(self):
        assert self._layout("RowMajor", "RowMajor", "RowMajor") == "rrr"

    def test_empty(self):
        # Missing layout keys fall back to the "???" placeholder.
        assert _layout_from_problem({}) == "???"

    def test_case_insensitive(self):
        assert self._layout("rowmajor", "COLUMNMAJOR", "RowMajor") == "rcr"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_streaming_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_LOG = """\
|
||||
================================================================================
|
||||
LOG FILE: test.log
|
||||
================================================================================
|
||||
CK Tile Profiling Run
|
||||
GPU ID: 0
|
||||
|
||||
--- Running CK Tile benchmarks on GPU 0 ---
|
||||
|
||||
========================================
|
||||
Shape 1: M=16 N=1536 K=7168 dtype=fp8 layout=rcr
|
||||
========================================
|
||||
Found 2 kernels
|
||||
{
|
||||
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
|
||||
"problem": {
|
||||
"split_k":1,
|
||||
"m":16,
|
||||
"n":1536,
|
||||
"k":7168,
|
||||
"stride_a":7168,
|
||||
"stride_b":7168,
|
||||
"stride_c":1536,
|
||||
"dtype_a":"fp8",
|
||||
"dtype_b":"fp8",
|
||||
"dtype_acc":"fp32",
|
||||
"dtype_c":"fp16",
|
||||
"layout_a":"RowMajor",
|
||||
"layout_b":"ColumnMajor",
|
||||
"layout_c":"RowMajor",
|
||||
"structured_sparsity":false
|
||||
},
|
||||
"perf_result": {
|
||||
"latency(ms)": 0.04,
|
||||
"tflops(TFlops)": 8.81,
|
||||
"bandwidth(GB/s)": 279.51
|
||||
}
|
||||
}
|
||||
{
|
||||
"name": "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16",
|
||||
"problem": {
|
||||
"split_k":1,
|
||||
"m":16,
|
||||
"n":1536,
|
||||
"k":7168,
|
||||
"stride_a":7168,
|
||||
"stride_b":7168,
|
||||
"stride_c":1536,
|
||||
"dtype_a":"fp8",
|
||||
"dtype_b":"fp8",
|
||||
"dtype_acc":"fp32",
|
||||
"dtype_c":"fp16",
|
||||
"layout_a":"RowMajor",
|
||||
"layout_b":"ColumnMajor",
|
||||
"layout_c":"RowMajor",
|
||||
"structured_sparsity":false
|
||||
},
|
||||
"perf_result": {
|
||||
"latency(ms)": 0.05,
|
||||
"tflops(TFlops)": 7.22,
|
||||
"bandwidth(GB/s)": 228.85
|
||||
}
|
||||
}
|
||||
|
||||
========================================
|
||||
Shape 2: M=20480 N=7168 K=256 dtype=fp8 layout=rcr
|
||||
========================================
|
||||
Found 1 kernels
|
||||
{
|
||||
"name": "gemm_universal_fp8_rcr_mem_cshuffle_intrawave_False_False_False_True_64x64x128_1x4x1_16x16x32",
|
||||
"problem": {
|
||||
"split_k":1,
|
||||
"m":20480,
|
||||
"n":7168,
|
||||
"k":256,
|
||||
"stride_a":256,
|
||||
"stride_b":256,
|
||||
"stride_c":7168,
|
||||
"dtype_a":"fp8",
|
||||
"dtype_b":"fp8",
|
||||
"dtype_acc":"fp32",
|
||||
"dtype_c":"fp16",
|
||||
"layout_a":"RowMajor",
|
||||
"layout_b":"ColumnMajor",
|
||||
"layout_c":"RowMajor",
|
||||
"structured_sparsity":false
|
||||
},
|
||||
"perf_result": {
|
||||
"latency(ms)": 0.15,
|
||||
"tflops(TFlops)": 505.00,
|
||||
"bandwidth(GB/s)": 1200.50
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class TestParseStreamingLog:
    """Tests for parse_streaming_log over synthetic log files.

    Fix over the original: _write_log created NamedTemporaryFile(delete=False)
    files that were never removed, leaking one temp file per test. Created
    paths are now tracked per test and unlinked in teardown_method.
    """

    def setup_method(self):
        # Temp log files created by _write_log during the current test.
        self._tmp_paths = []

    def teardown_method(self):
        # Best-effort cleanup; never fail a test on a removal error.
        for path in self._tmp_paths:
            try:
                path.unlink()
            except OSError:
                pass

    def _write_log(self, content: str) -> Path:
        """Write `content` to a fresh temp .log file and return its Path."""
        f = tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False)
        f.write(content)
        f.close()
        path = Path(f.name)
        self._tmp_paths.append(path)
        return path

    def test_basic_parse(self):
        # Three kernel records across two shapes; arch is propagated per row.
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path, arch="gfx950")
        assert len(df) == 3
        assert df["arch"].iloc[0] == "gfx950"
        assert df["m"].tolist() == [16, 16, 20480]
        assert df["n"].tolist() == [1536, 1536, 7168]
        assert df["k"].tolist() == [7168, 7168, 256]

    def test_tflops_values(self):
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert df["measured_tflops"].tolist() == pytest.approx([8.81, 7.22, 505.0])

    def test_kernel_config_parsed(self):
        # Tile/pipeline fields come from parsing each record's kernel name.
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert df["tile_m"].iloc[0] == 128
        assert df["pipeline"].iloc[0] == "compv3"
        assert df["pipeline"].iloc[1] == "compv4"

    def test_layout_derived_from_json(self):
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert all(df["layout"] == "rcr")

    def test_empty_log(self):
        # A log without shape headers yields an empty frame that still
        # carries every canonical column.
        path = self._write_log("No shapes here\nJust noise\n")
        df = parse_streaming_log(path)
        assert len(df) == 0
        for col in CANONICAL_COLUMNS:
            assert col in df.columns

    def test_single_kernel(self):
        log = """\
Shape 1: M=1 N=1 K=1 dtype=fp8 layout=rcr
{
"name": "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
"problem": {"split_k":1, "m":1, "n":1, "k":1, "dtype_a":"fp8", "dtype_b":"fp8", "layout_a":"RowMajor", "layout_b":"ColumnMajor", "layout_c":"RowMajor"},
"perf_result": {"latency(ms)": 0.001, "tflops(TFlops)": 0.002, "bandwidth(GB/s)": 0.01}
}
"""
        path = self._write_log(log)
        df = parse_streaming_log(path)
        assert len(df) == 1
        assert df["m"].iloc[0] == 1
        assert bool(df["is_valid"].iloc[0]) is True

    def test_zero_tflops_marked_invalid(self):
        # A record with 0 TFLOPS is kept but flagged invalid.
        log = """\
Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr
{
"name": "test_kernel",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.0, "tflops(TFlops)": 0.0, "bandwidth(GB/s)": 0.0}
}
"""
        path = self._write_log(log)
        df = parse_streaming_log(path)
        assert len(df) == 1
        assert bool(df["is_valid"].iloc[0]) is False

    def test_malformed_json_skipped(self):
        # A broken record between two valid ones is dropped, not fatal.
        log = """\
Shape 1: M=16 N=16 K=16 dtype=fp8 layout=rcr
{
"name": "good_kernel",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.01, "tflops(TFlops)": 1.0, "bandwidth(GB/s)": 10.0}
}
{ this is not valid json }
{
"name": "another_good",
"problem": {"split_k":1, "m":16, "n":16, "k":16, "dtype_a":"fp8"},
"perf_result": {"latency(ms)": 0.02, "tflops(TFlops)": 2.0, "bandwidth(GB/s)": 20.0}
}
"""
        path = self._write_log(log)
        df = parse_streaming_log(path)
        assert len(df) == 2

    def test_extreme_shapes(self):
        """Tiny M=1 (single token) and very large M=20480."""
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path)
        assert 1 not in df["m"].values  # sample has M=16, M=20480
        assert 16 in df["m"].values
        assert 20480 in df["m"].values

    def test_run_id_assigned(self):
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path, run_id="test_run_123")
        assert all(df["run_id"] == "test_run_123")

    def test_op_type_assigned(self):
        path = self._write_log(SAMPLE_LOG)
        df = parse_streaming_log(path, op_type="gemm_streamk")
        assert all(df["op_type"] == "gemm_streamk")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parquet round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParquetIO:
    """Round-trip and parent-directory behavior of save_parquet/load_parquet."""

    def test_round_trip(self, tmp_path):
        frame = pd.DataFrame(
            {
                "m": [16, 32],
                "n": [1536, 1536],
                "k": [7168, 7168],
                "measured_tflops": [8.81, 15.0],
            }
        )
        target = tmp_path / "test.parquet"
        save_parquet(frame, target)
        restored = load_parquet(target)
        assert len(restored) == 2
        assert restored["m"].tolist() == [16, 32]

    def test_creates_parent_dirs(self, tmp_path):
        # save_parquet should create missing intermediate directories.
        nested = tmp_path / "sub" / "dir" / "test.parquet"
        save_parquet(pd.DataFrame({"x": [1]}), nested)
        assert nested.exists()
|
||||
|
||||
|
||||
# Allow running this file directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
264
dispatcher/heuristics/tests/test_dispatcher_integration.py
Normal file
264
dispatcher/heuristics/tests/test_dispatcher_integration.py
Normal file
@@ -0,0 +1,264 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for dispatcher_integration.py.
|
||||
|
||||
Covers: kernel name parsing to feature dict, feature dict to dispatcher config
|
||||
(name mapping inversion), MLKernelSpec creation, binary pool loading, and
|
||||
the ML heuristic function.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from dispatcher_integration import (
|
||||
kernel_config_to_feature_dict,
|
||||
feature_dict_to_dispatcher_config,
|
||||
feature_dict_to_ml_spec,
|
||||
ml_spec_to_dispatcher_config,
|
||||
create_ml_heuristic,
|
||||
load_kernel_pool_from_binaries,
|
||||
MLKernelSpec,
|
||||
LAYOUT_TO_DISPATCHER,
|
||||
)
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
|
||||
|
||||
# Canonical fp8/rcr kernel name used throughout these tests:
# 128x128x128 block tile, 1x4x1 warps per block, 16x16x128 warp tile.
SAMPLE_KERNEL_NAME = (
    "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave"
    "_False_False_False_False_128x128x128_1x4x1_16x16x128"
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# kernel_config_to_feature_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKernelConfigToFeatureDict:
    """kernel_config_to_feature_dict turns a kernel name into feature keys."""

    def test_parses_standard_name(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        expected = {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 128,
            "warp_m": 1,  # warps per block
            "warp_n": 4,
            "warp_k": 1,
            "warp_tile_m": 16,
            "warp_tile_n": 16,
            "warp_tile_k": 128,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "kernel_name": SAMPLE_KERNEL_NAME,
        }
        for key, value in expected.items():
            assert feat[key] == value

    def test_empty_name_returns_empty(self):
        assert kernel_config_to_feature_dict("") == {}

    def test_invalid_name_returns_empty(self):
        assert kernel_config_to_feature_dict("not_a_kernel") == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Name mapping: feature dict <-> dispatcher config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNameMapping:
    """The critical inversion: feature-engine warp_m/n/k (warps per block)
    maps to dispatcher wave_m/n/k, and feature-engine warp_tile_m/n/k
    maps to dispatcher warp_m/n/k."""

    def _convert(self):
        # Parse the canonical name and convert once; used by most tests.
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        return feat, feature_dict_to_dispatcher_config(feat)

    def test_warp_to_wave_mapping(self):
        feat, disp = self._convert()
        for axis in "mnk":
            assert disp[f"wave_{axis}"] == feat[f"warp_{axis}"]

    def test_warp_tile_to_warp_mapping(self):
        feat, disp = self._convert()
        for axis in "mnk":
            assert disp[f"warp_{axis}"] == feat[f"warp_tile_{axis}"]

    def test_tile_dims_pass_through(self):
        _, disp = self._convert()
        assert (disp["tile_m"], disp["tile_n"], disp["tile_k"]) == (128, 128, 128)

    def test_pipeline_passes_through(self):
        _, disp = self._convert()
        assert disp["pipeline"] == "compv3"
        assert disp["scheduler"] == "intrawave"
        assert disp["epilogue"] == "cshuffle"

    def test_rcr_layout_mapping(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        disp = feature_dict_to_dispatcher_config(feat, dtype="fp8")
        assert (disp["layout_a"], disp["layout_b"], disp["layout_c"]) == (
            "row",
            "col",
            "row",
        )

    def test_all_layouts(self):
        for layout, (la, lb, lc) in LAYOUT_TO_DISPATCHER.items():
            disp = feature_dict_to_dispatcher_config({"layout": layout, "tile_m": 128})
            assert (disp["layout_a"], disp["layout_b"], disp["layout_c"]) == (la, lb, lc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MLKernelSpec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMLKernelSpec:
    """MLKernelSpec construction and conversion back to dispatcher configs."""

    def test_from_feature_dict(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        spec = feature_dict_to_ml_spec(feat, predicted_tflops=123.4)
        assert spec.kernel_name == SAMPLE_KERNEL_NAME
        assert spec.predicted_tflops == 123.4
        assert spec.tile_m == 128
        # Name inversion: feature-space warp_m -> spec.wave_m,
        # feature-space warp_tile_m -> spec.warp_m.
        assert spec.wave_m == 1
        assert spec.warp_m == 16

    def test_spec_to_dispatcher_config(self):
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        disp = ml_spec_to_dispatcher_config(
            feature_dict_to_ml_spec(feat, 100.0), dtype="fp8", arch="gfx950"
        )
        assert disp["tile_m"] == 128
        assert disp["wave_m"] == 1
        assert disp["warp_m"] == 16
        assert disp["gfx_arch"] == "gfx950"
        assert disp["dtype_a"] == "fp8"

    def test_roundtrip_preserves_values(self):
        """feature_dict -> MLKernelSpec -> dispatcher_config should be consistent."""
        feat = kernel_config_to_feature_dict(SAMPLE_KERNEL_NAME)
        via_spec = ml_spec_to_dispatcher_config(feature_dict_to_ml_spec(feat, 0.0))
        direct = feature_dict_to_dispatcher_config(feat)
        checked_keys = (
            "tile_m",
            "tile_n",
            "tile_k",
            "wave_m",
            "wave_n",
            "wave_k",
            "warp_m",
            "warp_n",
            "warp_k",
            "pipeline",
            "scheduler",
            "epilogue",
        )
        for key in checked_keys:
            assert via_spec[key] == direct[key], f"Mismatch on {key}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary pool loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadKernelPool:
    """load_kernel_pool_from_binaries over a real bin dir (when present)
    and over an empty directory."""

    def test_loads_from_real_bin_dir(self):
        # Only runs in environments that ship compiled kernel binaries.
        bin_dir = Path("/workspace/ck_tile/bin")
        if not bin_dir.exists():
            pytest.skip("No /workspace/ck_tile/bin")
        pool = load_kernel_pool_from_binaries(bin_dir)
        assert len(pool) > 0
        first = pool[0]
        assert "tile_m" in first
        assert "kernel_name" in first

    def test_empty_dir_returns_empty(self, tmp_path):
        assert load_kernel_pool_from_binaries(tmp_path) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ML heuristic function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateMLHeuristic:
    """End-to-end checks of the ML heuristic built from a tiny trained model."""

    @pytest.fixture
    def mock_model_dir(self, tmp_path):
        """Create a minimal model for testing the heuristic flow."""
        engine = GemmUniversalFeatureEngine()
        feature_names = engine.get_feature_names()
        np.random.seed(42)
        # Random features/targets are fine: the tests exercise the plumbing,
        # not prediction quality.
        features = np.random.rand(100, len(feature_names))
        targets = np.random.rand(100) * 500
        regressor = lgb.LGBMRegressor(n_estimators=5, verbose=-1)
        regressor.fit(features, targets)
        regressor.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))
        spec_payload = {
            "feature_names": feature_names,
            "categorical_features": engine.get_categorical_features(),
        }
        with open(tmp_path / "feature_spec.json", "w") as f:
            json.dump(spec_payload, f)
        return tmp_path

    def _make_pool(self):
        """Create a small synthetic kernel pool."""
        kernel_names = (
            "gemm_universal_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_128x128x128_1x4x1_16x16x128",
            "gemm_universal_fp8_rcr_compv4_default_intrawave_False_False_False_False_128x128x64_2x2x1_32x32x16",
            "gemm_universal_fp8_rcr_mem_cshuffle_interwave_False_False_False_False_64x64x128_1x4x1_16x16x32",
        )
        return [kernel_config_to_feature_dict(name) for name in kernel_names]

    def test_returns_ml_kernel_spec(self, mock_model_dir):
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=self._make_pool())
        chosen = heuristic(1024, 1024, 1024)
        assert isinstance(chosen, MLKernelSpec)
        assert chosen.tile_m > 0
        assert isinstance(chosen.predicted_tflops, float)

    def test_returns_valid_kernel_from_pool(self, mock_model_dir):
        pool = self._make_pool()
        known_names = {entry["kernel_name"] for entry in pool}
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=pool)
        assert heuristic(1024, 1024, 1024).kernel_name in known_names

    def test_different_shapes_may_select_different_kernels(self, mock_model_dir):
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=self._make_pool())
        # At minimum both shapes should yield valid specs.
        for shape in ((16, 1536, 7168), (8192, 8192, 256)):
            assert heuristic(*shape).tile_m > 0

    def test_m1_corner_case(self, mock_model_dir):
        heuristic = create_ml_heuristic(mock_model_dir, kernel_pool=self._make_pool())
        chosen = heuristic(1, 4096, 4096)
        assert isinstance(chosen, MLKernelSpec)
        assert np.isfinite(chosen.predicted_tflops)

    def test_empty_pool_raises(self, mock_model_dir):
        with pytest.raises(ValueError, match="No kernel configs"):
            create_ml_heuristic(mock_model_dir, kernel_pool=[])
|
||||
|
||||
|
||||
# Allow running this file directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
55
dispatcher/heuristics/tests/test_evaluate.py
Normal file
55
dispatcher/heuristics/tests/test_evaluate.py
Normal file
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for evaluate.py.
|
||||
|
||||
Covers: shape family classification, K-depth regime classification,
|
||||
and basic evaluation metric checks.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from evaluate import classify_shape_family, classify_k_regime
|
||||
|
||||
|
||||
class TestClassifyShapeFamily:
    """classify_shape_family buckets problems by M into tiny/small/medium/large."""

    def test_tiny_m(self):
        for shape in ((1, 4096, 4096), (16, 1536, 7168)):
            assert classify_shape_family(*shape) == "tiny_m"

    def test_small_m(self):
        for shape in ((32, 1536, 7168), (128, 4096, 4096)):
            assert classify_shape_family(*shape) == "small_m"

    def test_medium_m(self):
        for shape in ((256, 1024, 1024), (2048, 2048, 2048)):
            assert classify_shape_family(*shape) == "medium_m"

    def test_large_m(self):
        for shape in ((4096, 4096, 4096), (20480, 7168, 256)):
            assert classify_shape_family(*shape) == "large_m"
|
||||
|
||||
|
||||
class TestClassifyKRegime:
    """classify_k_regime buckets K depth into shallow/medium/deep."""

    def test_shallow(self):
        for k in (256, 32):
            assert classify_k_regime(k) == "shallow_k"

    def test_medium(self):
        for k in (1024, 2048):
            assert classify_k_regime(k) == "medium_k"

    def test_deep(self):
        for k in (4096, 7168):
            assert classify_k_regime(k) == "deep_k"
|
||||
|
||||
|
||||
# Allow running this file directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
409
dispatcher/heuristics/tests/test_feature_engine.py
Normal file
409
dispatcher/heuristics/tests/test_feature_engine.py
Normal file
@@ -0,0 +1,409 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for feature_engine.py.
|
||||
|
||||
Covers: feature count consistency, formula correctness (tile efficiency, LDS,
|
||||
arithmetic intensity), corner-case shapes (M=1, huge M, square, skinny-K),
|
||||
parameter space validity, config validation, and batch vs single extraction parity.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from feature_engine import (
|
||||
GemmUniversalFeatureEngine,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def fe():
    """Default feature engine with MI355X-like hardware."""
    # Hardware constants below describe an MI355X-class GPU; the exact values
    # only need to be internally consistent for the feature formulas under
    # test (LDS capacity in bytes, cache sizes in KB).
    return GemmUniversalFeatureEngine(
        num_cus=256,
        lds_capacity=65536,  # 64 KiB LDS per CU
        max_clock_mhz=2400,
        simds_per_cu=4,
        shader_engines=32,
        max_waves_per_cu=32,
        wavefront_size=64,
        l1_cache_kb=32,
        l2_cache_kb=4096,
        l3_cache_kb=262144,
        num_xcd=8,
    )
|
||||
|
||||
|
||||
def _make_problem(m=1024, n=1024, k=1024, dtype="fp8", layout="rcr", split_k=1):
|
||||
return {
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k,
|
||||
"dtype": dtype,
|
||||
"layout": layout,
|
||||
"split_k": split_k,
|
||||
}
|
||||
|
||||
|
||||
def _make_kernel(
|
||||
tile_m=128,
|
||||
tile_n=128,
|
||||
tile_k=64,
|
||||
warp_m=2,
|
||||
warp_n=2,
|
||||
warp_k=1,
|
||||
warp_tile_m=32,
|
||||
warp_tile_n=32,
|
||||
warp_tile_k=16,
|
||||
pipeline="compv3",
|
||||
scheduler="intrawave",
|
||||
epilogue="cshuffle",
|
||||
pad_m=False,
|
||||
pad_n=False,
|
||||
pad_k=False,
|
||||
persistent=False,
|
||||
):
|
||||
return {
|
||||
"tile_m": tile_m,
|
||||
"tile_n": tile_n,
|
||||
"tile_k": tile_k,
|
||||
"warp_m": warp_m,
|
||||
"warp_n": warp_n,
|
||||
"warp_k": warp_k,
|
||||
"warp_tile_m": warp_tile_m,
|
||||
"warp_tile_n": warp_tile_n,
|
||||
"warp_tile_k": warp_tile_k,
|
||||
"pipeline": pipeline,
|
||||
"scheduler": scheduler,
|
||||
"epilogue": epilogue,
|
||||
"pad_m": pad_m,
|
||||
"pad_n": pad_n,
|
||||
"pad_k": pad_k,
|
||||
"persistent": persistent,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic consistency
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFeatureConsistency:
    """Feature vectors are finite and sized consistently with the name list."""

    def _default_vector(self, fe):
        # One extraction with default problem/kernel, shared by the checks.
        return fe.extract(_make_problem(), _make_kernel())

    def test_feature_count_matches_names(self, fe):
        assert len(self._default_vector(fe)) == len(fe.get_feature_names())

    def test_feature_count_is_72(self, fe):
        assert len(fe.get_feature_names()) == 72

    def test_no_nan_in_output(self, fe):
        assert not np.isnan(self._default_vector(fe)).any()

    def test_no_inf_in_output(self, fe):
        assert not np.isinf(self._default_vector(fe)).any()

    def test_categorical_features_in_names(self, fe):
        names = fe.get_feature_names()
        missing = [c for c in fe.get_categorical_features() if c not in names]
        assert missing == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Formula correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTileEfficiency:
    """Tile efficiency: fraction of the last tile that is useful work."""

    @staticmethod
    def _feature(fe, prob, kern, name):
        # Extract and index a single named feature value.
        vec = fe.extract(prob, kern)
        return vec[fe.get_feature_names().index(name)]

    def test_perfectly_divisible(self, fe):
        prob = _make_problem(m=256, n=256, k=128)
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        for feat in ("tile_eff_m", "tile_eff_n", "tile_eff_k", "overall_tile_efficiency"):
            assert self._feature(fe, prob, kern, feat) == 1.0

    def test_not_divisible(self, fe):
        prob = _make_problem(m=100, n=100, k=100)
        kern = _make_kernel(tile_m=128, tile_n=128, tile_k=64)
        # 100 % 128 leaves a 100-wide remainder tile; 100 % 64 leaves 36.
        assert self._feature(fe, prob, kern, "tile_eff_m") == pytest.approx(100 / 128)
        assert self._feature(fe, prob, kern, "tile_eff_n") == pytest.approx(100 / 128)
        assert self._feature(fe, prob, kern, "tile_eff_k") == pytest.approx(36 / 64)

    def test_m_equals_1(self, fe):
        """Single-token inference: M=1, tile_m=128 => eff = 1/128."""
        value = self._feature(fe, _make_problem(m=1), _make_kernel(tile_m=128), "tile_eff_m")
        assert value == pytest.approx(1.0 / 128.0)
|
||||
|
||||
|
||||
class TestLDSUsage:
    """LDS estimate = (A-tile + B-tile) bytes; ratio uses the pipeline's budget."""

    @staticmethod
    def _feature(fe, prob, kern, name):
        vec = fe.extract(prob, kern)
        return vec[fe.get_feature_names().index(name)]

    def test_lds_formula(self, fe):
        value = self._feature(
            fe,
            _make_problem(dtype="fp8"),
            _make_kernel(tile_m=128, tile_n=128, tile_k=64),
            "lds_usage_estimate",
        )
        assert value == (128 * 64 + 128 * 64) * 1.0  # fp8 = 1 byte

    def test_lds_ratio_compv4(self, fe):
        """compv4 has 32KB LDS limit, not 64KB."""
        value = self._feature(
            fe,
            _make_problem(dtype="fp8"),
            _make_kernel(tile_m=128, tile_n=128, tile_k=64, pipeline="compv4"),
            "lds_usage_ratio",
        )
        assert value == pytest.approx((128 * 64 + 128 * 64) * 1.0 / 32768)

    def test_lds_fp16_doubles(self, fe):
        value = self._feature(
            fe,
            _make_problem(dtype="fp16"),
            _make_kernel(tile_m=128, tile_n=128, tile_k=64),
            "lds_usage_estimate",
        )
        assert value == (128 * 64 + 128 * 64) * 2.0  # fp16 = 2 bytes
|
||||
|
||||
|
||||
class TestArithmeticIntensity:
    """arithmetic_intensity = 2*M*N*K flops over bytes moved for A, B, and C."""

    def _intensity(self, fe, prob):
        vec = fe.extract(prob, _make_kernel())
        return vec[fe.get_feature_names().index("arithmetic_intensity")]

    def test_square_shape(self, fe):
        m = n = k = 1024
        bytes_moved = (m * k + k * n + m * n) * 1.0  # fp8 = 1 byte/element
        expected = (2.0 * m * n * k) / bytes_moved
        intensity = self._intensity(fe, _make_problem(m=m, n=n, k=k, dtype="fp8"))
        assert intensity == pytest.approx(expected)

    def test_skinny_k(self, fe):
        """Small K => low arithmetic intensity (memory-bound)."""
        assert self._intensity(fe, _make_problem(m=8192, n=8192, k=32, dtype="fp8")) < 100

    def test_deep_k(self, fe):
        """Large K => high arithmetic intensity (compute-bound)."""
        assert self._intensity(fe, _make_problem(m=256, n=256, k=8192, dtype="fp8")) > 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Corner-case shapes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCornerCaseShapes:
    """Degenerate and awkward problem shapes must never yield NaN/Inf."""

    def test_m1_single_token(self, fe):
        feats = fe.extract(_make_problem(m=1, n=4096, k=4096), _make_kernel())
        assert not np.any(np.isnan(feats))

    def test_m1_n1_k1_minimum(self, fe):
        # Absolute minimum problem: everything clamps, nothing divides by zero.
        feats = fe.extract(_make_problem(m=1, n=1, k=1), _make_kernel())
        assert not np.any(np.isnan(feats))
        assert not np.any(np.isinf(feats))

    def test_very_large_m(self, fe):
        feats = fe.extract(_make_problem(m=20480, n=7168, k=7168), _make_kernel())
        assert not np.any(np.isnan(feats))

    def test_non_power_of_2(self, fe):
        feats = fe.extract(_make_problem(m=1536, n=7168, k=2304), _make_kernel())
        assert not np.any(np.isnan(feats))

    def test_prime_dimensions(self, fe):
        feats = fe.extract(_make_problem(m=17, n=31, k=127), _make_kernel())
        assert not np.any(np.isnan(feats))

    def test_tall_matrix(self, fe):
        """M >> N (tall matrix)."""
        feats = fe.extract(_make_problem(m=16384, n=64, k=1024), _make_kernel())
        assert feats[fe.get_feature_names().index("aspect_ratio_mn")] > 100

    def test_wide_matrix(self, fe):
        """N >> M (wide matrix)."""
        feats = fe.extract(_make_problem(m=64, n=16384, k=1024), _make_kernel())
        assert feats[fe.get_feature_names().index("aspect_ratio_mn")] < 0.01
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch vs single extraction parity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchParity:
    # Guards the invariant that the vectorized extract_batch() path stays
    # numerically equivalent to the scalar extract() path.
    def test_batch_matches_single(self, fe):
        """Vectorized batch should produce identical results to row-by-row."""
        # Each row is a flat dict carrying BOTH the problem fields
        # (m/n/k/split_k/dtype/layout) and the kernel-config fields, so the
        # same dict can be passed as both arguments to scalar extract() below.
        rows = [
            {
                "m": 16,
                "n": 1536,
                "k": 7168,
                "split_k": 1,
                "dtype": "fp8",
                "layout": "rcr",
                "tile_m": 128,
                "tile_n": 128,
                "tile_k": 128,
                "warp_m": 1,
                "warp_n": 4,
                "warp_k": 1,
                "warp_tile_m": 16,
                "warp_tile_n": 16,
                "warp_tile_k": 128,
                "pipeline": "compv3",
                "scheduler": "intrawave",
                "epilogue": "cshuffle",
                "pad_m": False,
                "pad_n": False,
                "pad_k": False,
                "persistent": False,
            },
            {
                # Deliberately contrasting row: large M, different pipeline/
                # scheduler/epilogue, all padding flags and persistence on.
                "m": 20480,
                "n": 7168,
                "k": 256,
                "split_k": 1,
                "dtype": "fp8",
                "layout": "rcr",
                "tile_m": 64,
                "tile_n": 64,
                "tile_k": 128,
                "warp_m": 2,
                "warp_n": 2,
                "warp_k": 1,
                "warp_tile_m": 32,
                "warp_tile_n": 32,
                "warp_tile_k": 16,
                "pipeline": "mem",
                "scheduler": "interwave",
                "epilogue": "default",
                "pad_m": True,
                "pad_n": True,
                "pad_k": True,
                "persistent": True,
            },
        ]
        df = pd.DataFrame(rows)
        batch_result = fe.extract_batch(df)

        for i, row_dict in enumerate(rows):
            # row_dict doubles as problem dict and kernel dict (it is a
            # superset of both key sets).
            single_result = fe.extract(row_dict, row_dict)
            np.testing.assert_allclose(
                batch_result[i],
                single_result,
                rtol=1e-5,
                atol=1e-5,
                err_msg=f"Mismatch at row {i}",
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parameter space and validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParameterSpace:
    """Tunable-parameter space exposure and kernel-config validation."""

    def test_parameter_space_non_empty(self, fe):
        space = fe.get_parameter_space()
        assert len(space) > 0
        assert "tile_m" in space
        assert "pipeline" in space

    def test_valid_config_passes(self, fe):
        # A known-good config drawn from the production tuning space.
        cfg = dict(
            tile_m=128,
            tile_n=128,
            tile_k=64,
            warp_m=2,
            warp_n=2,
            warp_k=1,
            pipeline="compv3",
            scheduler="intrawave",
            epilogue="cshuffle",
            pad_m=False,
            pad_n=False,
            pad_k=False,
            persistent=False,
        )
        assert fe.validate_config(cfg) is True

    def test_invalid_tile_rejected(self, fe):
        # 999 is not a member of any tile dimension's allowed values.
        assert fe.validate_config({"tile_m": 999}) is False

    def test_lds_constraint_rejects_huge_tile(self, fe):
        # A 256x256x256 tile far exceeds compv4's 32KB LDS budget.
        huge = dict(
            tile_m=256,
            tile_n=256,
            tile_k=256,
            warp_m=2,
            warp_n=2,
            warp_k=1,
            pipeline="compv4",
        )
        assert fe.validate_config(huge) is False

    def test_project_to_valid_snaps(self, fe):
        # Off-grid tile sizes snap to the nearest valid values; categorical
        # fields that are already valid pass through untouched.
        snapped = fe.project_to_valid(
            {"tile_m": 100, "tile_n": 200, "pipeline": "compv3"}
        )
        assert snapped["tile_m"] == 128
        assert snapped["tile_n"] == 192
        assert snapped["pipeline"] == "compv3"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hardware features
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHardwareFeatures:
    """Hardware properties must flow through into the feature vector."""

    def test_hardware_values_propagated(self, fe):
        feats = fe.extract(_make_problem(), _make_kernel())
        names = fe.get_feature_names()

        def feat(key):
            return feats[names.index(key)]

        assert feat("hw_num_cus") == 256
        assert feat("hw_max_clock_mhz") == 2400
        assert feat("hw_total_simds") == 256 * 4
        assert feat("hw_num_xcd") == 8

    def test_different_hardware(self):
        # A fresh engine with different HW params must reflect them verbatim.
        engine = GemmUniversalFeatureEngine(num_cus=120, max_clock_mhz=1800)
        feats = engine.extract(_make_problem(), _make_kernel())
        names = engine.get_feature_names()
        assert feats[names.index("hw_num_cus")] == 120
        assert feats[names.index("hw_max_clock_mhz")] == 1800
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
357
dispatcher/heuristics/tests/test_feature_parity.py
Normal file
357
dispatcher/heuristics/tests/test_feature_parity.py
Normal file
@@ -0,0 +1,357 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Test that the C++ extract_features() in ml_heuristic.hpp produces identical
|
||||
values to the Python GemmUniversalFeatureEngine.extract().
|
||||
|
||||
This test uses ctypes to call the C++ feature extraction compiled into a
|
||||
small shared library, then compares against Python output. If compilation
|
||||
fails (no HIP/ROCm), it falls back to verifying the Python feature engine
|
||||
against manually computed expected values for specific test cases.
|
||||
"""
|
||||
|
||||
import math
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
from feature_engine import (
|
||||
GemmUniversalFeatureEngine,
|
||||
PIPELINE_MAP,
|
||||
SCHEDULER_MAP,
|
||||
EPILOGUE_MAP,
|
||||
LAYOUT_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _compute_features_manually(
    M,
    N,
    K,
    split_k,
    dtype,
    layout,
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    warp_tile_m,
    warp_tile_n,
    warp_tile_k,
    pipeline,
    scheduler,
    epilogue,
    pad_m,
    pad_n,
    pad_k,
    persistent,
    hw,
):
    """Recompute features independently to verify the Python engine.

    This mirrors the 72-entry feature layout of
    GemmUniversalFeatureEngine.extract() (and, per the module docstring, the
    C++ extract_features()) from first principles, so any drift in the
    engine's formulas is caught by the parity test below.

    Returns a plain list of 72 numbers in feature-index order; the trailing
    ``# i`` comments in the returned list give each entry's index and name.
    """
    # Bytes-per-element by dtype; unknown dtypes fall back to 1 byte.
    bpe_map = {"fp8": 1.0, "fp16": 2.0, "bf16": 2.0, "fp32": 4.0}
    bpe = bpe_map.get(dtype, 1.0)

    # log2 shape features; max(..., 1) clamps degenerate dims away from log2(0).
    log2_M = math.log2(max(M, 1))
    log2_N = math.log2(max(N, 1))
    log2_K = math.log2(max(K, 1))
    log2_MNK = math.log2(max(M * N * K, 1))
    # Total operand traffic (A + B + C) in bytes; 2*M*N*K FLOPs for a GEMM.
    mem = (M * K + K * N + M * N) * bpe
    ai = (2.0 * M * N * K) / max(mem, 1)

    # LDS footprint of the A and B tiles; compv4 only has a 32KB budget.
    lds_est = (tile_m * tile_k + tile_n * tile_k) * bpe
    lds_cap = 32768 if pipeline == "compv4" else hw["lds_capacity"]

    # Tile-grid dimensions (ceil-divide so partial tiles are counted).
    ntm = math.ceil(M / max(tile_m, 1))
    ntn = math.ceil(N / max(tile_n, 1))
    ntk = math.ceil(K / max(tile_k, 1))

    def eff(d, t):
        # Fraction of the last tile that is useful work; an exact multiple
        # (remainder 0) counts as fully efficient (1.0).
        if t <= 0:
            return 1.0
        r = d % t
        return r / t if r > 0 else 1.0

    # Problem-to-tile ratios
    ratio_M_to_tile_m = M / max(tile_m, 1)
    ratio_N_to_tile_n = N / max(tile_n, 1)
    ratio_K_to_tile_k = K / max(tile_k, 1)

    # Binary features: problem smaller than tile
    problem_smaller_than_tile_m = float(M < tile_m)
    problem_smaller_than_tile_n = float(N < tile_n)
    problem_smaller_than_tile_k = float(K < tile_k)
    any_dim_too_small = float((M < tile_m) or (N < tile_n) or (K < tile_k))

    # Padding requirement features (1.0 when the dim isn't tile-divisible)
    needs_padding_m = float(tile_m > 0 and M % tile_m != 0)
    needs_padding_n = float(tile_n > 0 and N % tile_n != 0)
    needs_padding_k = float(tile_k > 0 and K % tile_k != 0)

    # Interaction features: padding is both required AND enabled.
    # (needs_padding_* are floats; `0.0 and x` short-circuits to 0.0.)
    has_padding_when_needed_m = float(needs_padding_m and pad_m)
    has_padding_when_needed_n = float(needs_padding_n and pad_n)
    has_padding_when_needed_k = float(needs_padding_k and pad_k)

    # Missing padding features: required but NOT enabled (likely-invalid combo).
    missing_required_padding_m = float(needs_padding_m and not pad_m)
    missing_required_padding_n = float(needs_padding_n and not pad_n)
    missing_required_padding_k = float(needs_padding_k and not pad_k)
    missing_any_required_padding = float(
        missing_required_padding_m or missing_required_padding_n or missing_required_padding_k
    )

    # Feature vector — index order is the contract shared with the engine
    # and the C++ side; do not reorder.
    return [
        M,  # 0
        N,  # 1
        K,  # 2
        split_k,  # 3
        log2_M,  # 4
        log2_N,  # 5
        log2_K,  # 6
        log2_MNK,  # 7
        ai,  # 8
        M / max(N, 1),  # 9 (aspect_ratio_mn)
        M / max(K, 1),  # 10 (aspect_ratio_mk)
        N / max(K, 1),  # 11 (aspect_ratio_nk)
        LAYOUT_MAP.get(layout, 0),  # 12
        tile_m,  # 13
        tile_n,  # 14
        tile_k,  # 15
        warp_m,  # 16
        warp_n,  # 17
        warp_k,  # 18
        warp_tile_m,  # 19
        warp_tile_n,  # 20
        warp_tile_k,  # 21
        PIPELINE_MAP.get(pipeline, 0),  # 22
        SCHEDULER_MAP.get(scheduler, 0),  # 23
        EPILOGUE_MAP.get(epilogue, 0),  # 24
        float(pad_m),  # 25
        float(pad_n),  # 26
        float(pad_k),  # 27
        float(persistent),  # 28
        warp_m * warp_n * warp_k,  # 29 (num_warps)
        tile_m * tile_n * tile_k,  # 30 (tile_volume)
        tile_m * tile_n,  # 31 (tile_mn)
        lds_est,  # 32 (lds_usage_estimate)
        lds_est / max(lds_cap, 1),  # 33 (lds_usage_ratio)
        ntm,  # 34 (num_tiles_m)
        ntn,  # 35 (num_tiles_n)
        ntk,  # 36 (num_tiles_k)
        ntm * ntn,  # 37 (total_output_tiles)
        eff(M, tile_m),  # 38 (tile_eff_m)
        eff(N, tile_n),  # 39 (tile_eff_n)
        eff(K, tile_k),  # 40 (tile_eff_k)
        eff(M, tile_m) * eff(N, tile_n) * eff(K, tile_k),  # 41 (overall_tile_efficiency)
        ntm * ntn / max(hw["num_cus"], 1),  # 42 (cu_utilization)
        ratio_M_to_tile_m,  # 43
        ratio_N_to_tile_n,  # 44
        ratio_K_to_tile_k,  # 45
        problem_smaller_than_tile_m,  # 46
        problem_smaller_than_tile_n,  # 47
        problem_smaller_than_tile_k,  # 48
        any_dim_too_small,  # 49
        needs_padding_m,  # 50
        needs_padding_n,  # 51
        needs_padding_k,  # 52
        has_padding_when_needed_m,  # 53
        has_padding_when_needed_n,  # 54
        has_padding_when_needed_k,  # 55
        missing_required_padding_m,  # 56
        missing_required_padding_n,  # 57
        missing_required_padding_k,  # 58
        missing_any_required_padding,  # 59
        hw["num_cus"],  # 60
        hw["simds_per_cu"],  # 61
        hw["num_cus"] * hw["simds_per_cu"],  # 62 (total_simds)
        hw["shader_engines"],  # 63
        hw["max_clock_mhz"],  # 64
        hw["max_waves_per_cu"],  # 65
        hw["wavefront_size"],  # 66
        hw["lds_capacity"],  # 67
        hw["l1_cache_kb"],  # 68
        hw["l2_cache_kb"],  # 69
        hw["l3_cache_kb"],  # 70
        hw["num_xcd"],  # 71
    ]
|
||||
|
||||
|
||||
# Hand-picked parity cases covering contrasting regions of the space:
#   1. square fp8 problem with a standard compv3/intrawave/cshuffle kernel;
#   2. M=1 skinny fp8 problem, compv4 pipeline, all padding + persistence on;
#   3. large fp16 rrr problem with the mem pipeline.
TEST_CASES = [
    {
        "problem": {
            "m": 1024,
            "n": 1024,
            "k": 1024,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 64,
            "warp_m": 2,
            "warp_n": 2,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "compv3",
            "scheduler": "intrawave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    },
    {
        "problem": {
            "m": 1,
            "n": 4096,
            "k": 4096,
            "split_k": 1,
            "dtype": "fp8",
            "layout": "rcr",
        },
        "kernel": {
            "tile_m": 64,
            "tile_n": 64,
            "tile_k": 128,
            "warp_m": 1,
            "warp_n": 4,
            "warp_k": 1,
            "warp_tile_m": 16,
            "warp_tile_n": 16,
            "warp_tile_k": 128,
            "pipeline": "compv4",
            "scheduler": "interwave",
            "epilogue": "default",
            "pad_m": True,
            "pad_n": True,
            "pad_k": True,
            "persistent": True,
        },
    },
    {
        "problem": {
            "m": 20480,
            "n": 7168,
            "k": 256,
            "split_k": 1,
            "dtype": "fp16",
            "layout": "rrr",
        },
        "kernel": {
            "tile_m": 256,
            "tile_n": 256,
            "tile_k": 32,
            "warp_m": 4,
            "warp_n": 1,
            "warp_k": 1,
            "warp_tile_m": 32,
            "warp_tile_n": 32,
            "warp_tile_k": 16,
            "pipeline": "mem",
            "scheduler": "interwave",
            "epilogue": "cshuffle",
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    },
]

# Hardware description fed both to the feature engine (as constructor
# kwargs, see the `fe` fixture) and to _compute_features_manually (as `hw`).
HW = {
    "num_cus": 256,
    "simds_per_cu": 4,
    "shader_engines": 32,
    "max_clock_mhz": 2400,
    "max_waves_per_cu": 32,
    "wavefront_size": 64,
    "lds_capacity": 65536,
    "l1_cache_kb": 32,
    "l2_cache_kb": 4096,
    "l3_cache_kb": 262144,
    "num_xcd": 8,
}
|
||||
|
||||
|
||||
class TestFeatureParity:
    """Verify Python feature engine matches manual computation (C++ uses same logic)."""

    @pytest.fixture
    def fe(self):
        # Engine constructed from the same HW dict the manual path uses,
        # so both sides see identical hardware constants.
        return GemmUniversalFeatureEngine(**HW)

    @pytest.mark.parametrize("case_idx", range(len(TEST_CASES)))
    def test_python_matches_manual(self, fe, case_idx):
        # Extract with the engine, then recompute the same 72 features from
        # first principles and compare element-by-element.
        case = TEST_CASES[case_idx]
        prob = case["problem"]
        kern = case["kernel"]

        py_features = fe.extract(prob, kern)

        manual = _compute_features_manually(
            prob["m"],
            prob["n"],
            prob["k"],
            prob["split_k"],
            prob["dtype"],
            prob["layout"],
            kern["tile_m"],
            kern["tile_n"],
            kern["tile_k"],
            kern["warp_m"],
            kern["warp_n"],
            kern["warp_k"],
            kern["warp_tile_m"],
            kern["warp_tile_n"],
            kern["warp_tile_k"],
            kern["pipeline"],
            kern["scheduler"],
            kern["epilogue"],
            kern["pad_m"],
            kern["pad_n"],
            kern["pad_k"],
            kern["persistent"],
            HW,
        )

        manual_arr = np.array(manual, dtype=np.float64)
        assert len(py_features) == len(manual_arr) == 72

        # Tight tolerances: both sides are plain float64 arithmetic, so any
        # real divergence in the formulas will blow well past these bounds.
        for i in range(72):
            assert py_features[i] == pytest.approx(
                manual_arr[i], rel=1e-10, abs=1e-15
            ), (
                f"Feature {i} ({fe.get_feature_names()[i]}): Python={py_features[i]}, Manual={manual_arr[i]}"
            )

    def test_feature_count(self, fe):
        # The 72-feature contract shared with the C++ side.
        assert len(fe.get_feature_names()) == 72

    def test_encoding_maps_match_cpp(self):
        """The C++ encode_* functions must use the same mapping as Python."""
        # These literal dicts pin the categorical encodings; changing either
        # side without the other would silently corrupt model inputs.
        assert PIPELINE_MAP == {
            "compv3": 0,
            "compv4": 1,
            "compv5": 2,
            "mem": 3,
            "preshufflev2": 4,
        }
        assert SCHEDULER_MAP == {"intrawave": 0, "interwave": 1}
        assert EPILOGUE_MAP == {"default": 0, "cshuffle": 1}
        assert LAYOUT_MAP == {"rcr": 0, "rrr": 1, "crr": 2, "ccr": 3}
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
90
dispatcher/heuristics/tests/test_model_compression.py
Normal file
90
dispatcher/heuristics/tests/test_model_compression.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test that compressed models can be loaded and used."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from predict import Predictor
|
||||
|
||||
|
||||
def test_fp16_model_decompression():
    """Test that the fp16 model is auto-decompressed and usable.

    Verifies that:
      * the compressed ``model_tflops.lgbm.gz`` artifact exists,
      * ``Predictor`` transparently decompresses it on load,
      * a prediction on a representative problem returns a positive float,
      * the decompressed ``.lgbm`` file is left on disk afterwards.

    Note: returns None (not True) — pytest treats a non-None return from a
    test as an error (PytestReturnNotNoneWarning).
    """
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp16_gfx950"

    # Ensure the compressed artifact is present before exercising the loader.
    gz_file = model_dir / "model_tflops.lgbm.gz"
    assert gz_file.exists(), f"Compressed model not found: {gz_file}"

    # Load predictor - should auto-decompress the .gz model.
    predictor = Predictor(model_dir)

    # Representative fp16 GEMM problem and kernel config.
    problem = {"m": 128, "n": 1536, "k": 7168, "dtype": "fp16", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 128, "n0": 128, "k0": 16},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 8},
    }

    tflops = predictor.predict_tflops(problem, kernel_config)

    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"

    # Verify the decompressed file was created as a side effect of loading.
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"

    print("✅ FP16 model decompression test passed")
    print(f"   Predicted TFLOPS: {tflops:.2f}")
    print(f"   Decompressed to: {lgbm_file}")
|
||||
|
||||
|
||||
def test_fp8_model_decompression():
    """Test that the fp8 model is auto-decompressed and usable.

    Same checks as the fp16 case: compressed artifact exists, Predictor
    transparently decompresses on load, prediction is a positive float, and
    the decompressed file remains on disk.

    Note: returns None (not True) — pytest treats a non-None return from a
    test as an error (PytestReturnNotNoneWarning).
    """
    model_dir = Path(__file__).parent.parent / "models" / "gemm_universal_fp8_gfx950"

    # Ensure the compressed artifact is present before exercising the loader.
    gz_file = model_dir / "model_tflops.lgbm.gz"
    assert gz_file.exists(), f"Compressed model not found: {gz_file}"

    # Load predictor - should auto-decompress the .gz model.
    predictor = Predictor(model_dir)

    # Representative fp8 GEMM problem and kernel config.
    problem = {"m": 2048, "n": 2048, "k": 2048, "dtype": "fp8", "layout": "rcr"}
    kernel_config = {
        "tile_shape": {"m0": 256, "n0": 256, "k0": 64},
        "wave_shape": {"m1": 2, "n1": 2, "k1": 1},
        "warp_tile": {"m2": 32, "n2": 32, "k2": 16},
    }

    tflops = predictor.predict_tflops(problem, kernel_config)

    assert isinstance(tflops, float), f"Expected float, got {type(tflops)}"
    assert tflops > 0, f"Expected positive TFLOPS, got {tflops}"

    # Verify the decompressed file was created as a side effect of loading.
    lgbm_file = model_dir / "model_tflops.lgbm"
    assert lgbm_file.exists(), "Model should have been decompressed"

    print("✅ FP8 model decompression test passed")
    print(f"   Predicted TFLOPS: {tflops:.2f}")
    print(f"   Decompressed to: {lgbm_file}")
|
||||
|
||||
|
||||
# Manual entry point: run both decompression checks outside pytest.
# (Any assertion failure aborts with a traceback before the final banner.)
if __name__ == "__main__":
    print("Testing compressed model auto-decompression...")
    print()

    test_fp16_model_decompression()
    print()
    test_fp8_model_decompression()
    print()
    print("✅ All model compression tests passed!")
|
||||
181
dispatcher/heuristics/tests/test_predict.py
Normal file
181
dispatcher/heuristics/tests/test_predict.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for predict.py.
|
||||
|
||||
Covers: Predictor initialization, single prediction, ranking, select_best,
|
||||
missing model handling, and edge cases (single kernel, empty list).
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
from predict import Predictor
|
||||
|
||||
|
||||
@pytest.fixture
def model_dir(tmp_path):
    """Create a minimal trained model for testing.

    Trains tiny LightGBM regressors on random data — predictions only need
    to be finite and deterministic, not meaningful. Writes model_tflops.lgbm,
    model_latency.lgbm and feature_spec.json into tmp_path and returns it.
    Deliberately writes NO bandwidth model so test_missing_bandwidth_model
    can exercise the FileNotFoundError path.
    """
    fe = GemmUniversalFeatureEngine()
    n_features = len(fe.get_feature_names())

    # Fixed seed keeps the fixture (and thus all predictions) deterministic.
    np.random.seed(42)
    X = np.random.rand(200, n_features)
    y = np.random.rand(200) * 100

    model = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    model.fit(X, y)
    model.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))

    # Second target on a different scale for the latency model.
    y_lat = np.random.rand(200) * 0.1
    model_lat = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    model_lat.fit(X, y_lat)
    model_lat.booster_.save_model(str(tmp_path / "model_latency.lgbm"))

    # feature_spec.json tells the Predictor how to build its input vectors.
    spec = {
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
    }
    with open(tmp_path / "feature_spec.json", "w") as f:
        json.dump(spec, f)

    return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
def predictor(model_dir):
    # Predictor backed by the minimal models from the model_dir fixture.
    return Predictor(model_dir)
|
||||
|
||||
|
||||
def _problem():
|
||||
return {
|
||||
"m": 1024,
|
||||
"n": 1024,
|
||||
"k": 1024,
|
||||
"dtype": "fp8",
|
||||
"layout": "rcr",
|
||||
"split_k": 1,
|
||||
}
|
||||
|
||||
|
||||
def _kernel(tile_m=128, pipeline="compv3"):
|
||||
return {
|
||||
"kernel_name": f"test_kernel_{tile_m}_{pipeline}",
|
||||
"tile_m": tile_m,
|
||||
"tile_n": 128,
|
||||
"tile_k": 64,
|
||||
"warp_m": 2,
|
||||
"warp_n": 2,
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 32,
|
||||
"warp_tile_n": 32,
|
||||
"warp_tile_k": 16,
|
||||
"pipeline": pipeline,
|
||||
"scheduler": "intrawave",
|
||||
"epilogue": "cshuffle",
|
||||
"pad_m": False,
|
||||
"pad_n": False,
|
||||
"pad_k": False,
|
||||
"persistent": False,
|
||||
}
|
||||
|
||||
|
||||
class TestPredictor:
    # Exercises the Predictor API against the minimal random-data models
    # from the model_dir fixture: predicted values are arbitrary but finite,
    # so these tests check types, ordering, and error paths — not accuracy.
    def test_predict_tflops_returns_float(self, predictor):
        result = predictor.predict_tflops(_problem(), _kernel())
        assert isinstance(result, float)

    def test_predict_latency_returns_float(self, predictor):
        result = predictor.predict_latency(_problem(), _kernel())
        assert isinstance(result, float)

    def test_predict_all_returns_dict(self, predictor):
        result = predictor.predict_all(_problem(), _kernel())
        assert "tflops" in result
        assert "latency_ms" in result

    def test_rank_kernels_sorted_descending(self, predictor):
        # rank_kernels returns (kernel, score) pairs, best-first.
        kernels = [_kernel(64, "compv3"), _kernel(128, "compv4"), _kernel(256, "mem")]
        ranked = predictor.rank_kernels(_problem(), kernels)
        assert len(ranked) == 3
        scores = [s for _, s in ranked]
        assert scores == sorted(scores, reverse=True)

    def test_select_best_returns_name(self, predictor):
        # select_best returns the chosen kernel's name, not the dict.
        kernels = [_kernel(64), _kernel(128)]
        best = predictor.select_best(_problem(), kernels)
        assert isinstance(best, str)
        assert best in [k["kernel_name"] for k in kernels]

    def test_single_kernel(self, predictor):
        kernels = [_kernel(128)]
        ranked = predictor.rank_kernels(_problem(), kernels)
        assert len(ranked) == 1

    def test_missing_bandwidth_model(self, model_dir):
        # model_dir deliberately contains no model_bandwidth.lgbm.
        pred = Predictor(model_dir)
        with pytest.raises(FileNotFoundError):
            pred.predict_bandwidth(_problem(), _kernel())

    def test_empty_kernel_list(self, predictor):
        with pytest.raises(ValueError):
            predictor.select_best(_problem(), [])

    def test_corner_case_m1(self, predictor):
        # Degenerate M=1 (single-token) shape must still predict cleanly.
        prob = {
            "m": 1,
            "n": 4096,
            "k": 4096,
            "dtype": "fp8",
            "layout": "rcr",
            "split_k": 1,
        }
        result = predictor.predict_tflops(prob, _kernel())
        assert np.isfinite(result)

    def test_different_shapes_give_different_results(self, predictor):
        # Sanity check that problem features actually influence the output
        # (guards against a model/feature-wiring bug that ignores the input).
        k = _kernel()
        r1 = predictor.predict_tflops(
            {
                "m": 16,
                "n": 1536,
                "k": 7168,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            k,
        )
        r2 = predictor.predict_tflops(
            {
                "m": 20480,
                "n": 7168,
                "k": 256,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            k,
        )
        assert r1 != r2
|
||||
|
||||
|
||||
class TestPredictorEdgeCases:
    """Failure modes outside the happy path."""

    def test_nonexistent_model_dir(self):
        # Either construction or the first prediction must fail when the
        # model directory does not exist.
        with pytest.raises(Exception):
            broken = Predictor("/nonexistent/path")
            broken.predict_tflops(_problem(), _kernel())
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
192
dispatcher/heuristics/tests/test_search.py
Normal file
192
dispatcher/heuristics/tests/test_search.py
Normal file
@@ -0,0 +1,192 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for search.py.
|
||||
|
||||
Covers: random search, DE search, config validity, result ordering,
|
||||
budget compliance, and edge cases.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
from predict import Predictor
|
||||
from search import SurrogateSearch
|
||||
|
||||
|
||||
@pytest.fixture
def model_dir(tmp_path):
    """Create a minimal trained model.

    Trains one tiny LightGBM TFLOPS model on random data (search tests only
    need a deterministic, finite objective, not an accurate one), writes it
    plus feature_spec.json into tmp_path, and returns tmp_path.
    """
    fe = GemmUniversalFeatureEngine()
    n_features = len(fe.get_feature_names())
    # Fixed seed keeps the surrogate deterministic across tests.
    np.random.seed(42)
    X = np.random.rand(200, n_features)
    y = np.random.rand(200) * 500
    model = lgb.LGBMRegressor(n_estimators=10, verbose=-1)
    model.fit(X, y)
    model.booster_.save_model(str(tmp_path / "model_tflops.lgbm"))
    # feature_spec.json tells the Predictor how to build its input vectors.
    spec = {
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
    }
    with open(tmp_path / "feature_spec.json", "w") as f:
        json.dump(spec, f)
    return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
def predictor(model_dir):
    # Predictor backed by the minimal surrogate model from model_dir.
    return Predictor(model_dir)
|
||||
|
||||
|
||||
def _problem():
|
||||
return {
|
||||
"m": 1024,
|
||||
"n": 1024,
|
||||
"k": 1024,
|
||||
"dtype": "fp8",
|
||||
"layout": "rcr",
|
||||
"split_k": 1,
|
||||
}
|
||||
|
||||
|
||||
class TestRandomSearch:
    # Random-sampling strategy of SurrogateSearch against the minimal
    # surrogate model: scores are arbitrary but finite, so these tests
    # check result shape, ordering, validity, and determinism.
    def test_returns_results(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=50, top_k=5)
        assert len(results) > 0
        assert len(results) <= 5

    def test_results_sorted_descending(self, predictor):
        # search returns (config, score) pairs, best-first.
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=100, top_k=10)
        scores = [s for _, s in results]
        assert scores == sorted(scores, reverse=True)

    def test_configs_are_valid(self, predictor):
        # Every returned config value must come from the engine's declared
        # parameter space (no off-grid samples leak through).
        fe = GemmUniversalFeatureEngine()
        searcher = SurrogateSearch(predictor, feature_engine=fe, strategy="random")
        results = searcher.search(_problem(), budget=50, top_k=5)
        for cfg, _ in results:
            ps = fe.get_parameter_space()
            for k, v in cfg.items():
                if k in ps:
                    assert v in ps[k], f"{k}={v} not in {ps[k]}"

    def test_respects_top_k(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(_problem(), budget=100, top_k=3)
        assert len(results) <= 3

    def test_different_problems_produce_results(self, predictor):
        """Both problem sizes should produce valid search results."""
        # Fresh searcher per problem so the seed applies identically.
        searcher = SurrogateSearch(predictor, strategy="random", seed=42)
        r1 = searcher.search(
            {
                "m": 16,
                "n": 1536,
                "k": 7168,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=3,
        )
        searcher2 = SurrogateSearch(predictor, strategy="random", seed=42)
        r2 = searcher2.search(
            {
                "m": 20480,
                "n": 7168,
                "k": 256,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=3,
        )
        assert len(r1) > 0
        assert len(r2) > 0
        for _, score in r1 + r2:
            assert np.isfinite(score)

    def test_m1_corner_case(self, predictor):
        # Degenerate M=1 problem must still search cleanly (finite scores).
        searcher = SurrogateSearch(predictor, strategy="random")
        results = searcher.search(
            {
                "m": 1,
                "n": 4096,
                "k": 4096,
                "dtype": "fp8",
                "layout": "rcr",
                "split_k": 1,
            },
            budget=50,
            top_k=5,
        )
        assert len(results) > 0
        for _, score in results:
            assert np.isfinite(score)
|
||||
|
||||
|
||||
class TestDESearch:
    # Differential-evolution strategy: same surrogate, different optimizer.
    def test_returns_results(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="de")
        results = searcher.search(_problem(), budget=100, top_k=5)
        assert len(results) > 0

    def test_results_sorted_descending(self, predictor):
        searcher = SurrogateSearch(predictor, strategy="de")
        results = searcher.search(_problem(), budget=100, top_k=5)
        scores = [s for _, s in results]
        assert scores == sorted(scores, reverse=True)

    def test_de_improves_over_initial(self, predictor):
        """DE should generally find at least as good as random initialization."""
        searcher_r = SurrogateSearch(predictor, strategy="random", seed=42)
        r_results = searcher_r.search(_problem(), budget=100, top_k=1)
        searcher_d = SurrogateSearch(predictor, strategy="de", seed=42)
        d_results = searcher_d.search(_problem(), budget=100, top_k=1)
        # 0.9 factor tolerates stochastic variation between the strategies;
        # this is a soft "not much worse" bound, not a strict improvement check.
        if r_results and d_results:
            assert d_results[0][1] >= r_results[0][1] * 0.9

    def test_small_budget(self, predictor):
        # DE must degrade gracefully when the budget is tiny.
        searcher = SurrogateSearch(predictor, strategy="de")
        results = searcher.search(_problem(), budget=30, top_k=5)
        assert len(results) > 0
|
||||
|
||||
|
||||
class TestSearchEdgeCases:
    """Degenerate inputs: unknown strategy, zero budget, seeded determinism."""

    def test_unknown_strategy_raises(self, predictor):
        bad = SurrogateSearch(predictor, strategy="unknown")
        with pytest.raises(ValueError):
            bad.search(_problem(), budget=10)

    def test_zero_budget(self, predictor):
        # No evaluations allowed => no candidates returned.
        searcher = SurrogateSearch(predictor, strategy="random")
        assert len(searcher.search(_problem(), budget=0, top_k=5)) == 0

    def test_deterministic_with_same_seed(self, predictor):
        first = SurrogateSearch(predictor, strategy="random", seed=123)
        second = SurrogateSearch(predictor, strategy="random", seed=123)
        res_a = first.search(_problem(), budget=50, top_k=5)
        res_b = second.search(_problem(), budget=50, top_k=5)
        assert len(res_a) == len(res_b)
        for (_, score_a), (_, score_b) in zip(res_a, res_b):
            assert score_a == pytest.approx(score_b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
329
dispatcher/heuristics/tests/test_train.py
Normal file
329
dispatcher/heuristics/tests/test_train.py
Normal file
@@ -0,0 +1,329 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for train.py.
|
||||
|
||||
Covers: group key computation, TFLOPS efficiency calculation, edge cases
|
||||
(single group, all-invalid data, tied predictions), and warm-start
|
||||
incremental training (feature compat, lineage, quality).
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
from train import (
|
||||
compute_group_keys,
|
||||
compute_tflops_efficiency,
|
||||
check_feature_compatibility,
|
||||
load_warm_start_model,
|
||||
train_final_model,
|
||||
DEFAULT_PARAMS,
|
||||
)
|
||||
|
||||
|
||||
class TestComputeGroupKeys:
    """Group keys: identical for identical (m, n, k), distinct otherwise."""

    def test_basic(self):
        frame = pd.DataFrame(
            {"m": [16, 16, 32], "n": [1536, 1536, 1536], "k": [7168, 7168, 7168]}
        )
        group_keys = compute_group_keys(frame)
        # Rows 0 and 1 share a shape; row 2 differs in m.
        assert group_keys[0] == group_keys[1]
        assert group_keys[0] != group_keys[2]

    def test_unique_shapes(self):
        frame = pd.DataFrame({"m": [1, 2, 3], "n": [4, 5, 6], "k": [7, 8, 9]})
        assert len(set(compute_group_keys(frame))) == 3
|
||||
|
||||
|
||||
class TestComputeTflopsEfficiency:
    """Per-shape efficiency = TFLOPS of predicted-best kernel / oracle best."""

    @staticmethod
    def _df(measured, predicted, m=1024, n=1024, k=1024):
        # Build a single-shape frame: every row shares the same (m, n, k).
        count = len(measured)
        return pd.DataFrame(
            {
                "m": [m] * count,
                "n": [n] * count,
                "k": [k] * count,
                "measured_tflops": measured,
                "pred_tflops": predicted,
            }
        )

    def test_perfect_prediction(self):
        """Model predicts highest TFLOPS kernel => efficiency = 1.0."""
        # pred ranks kernel 1 highest, which is also the measured best.
        eff = compute_tflops_efficiency(
            self._df([100, 200, 150], [50, 300, 100]), "pred_tflops"
        )
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_worst_prediction(self):
        """Model picks the worst kernel."""
        # pred ranks kernel 0 highest although it measures slowest.
        eff = compute_tflops_efficiency(
            self._df([100, 200, 150], [999, 1, 1]), "pred_tflops"
        )
        assert eff["efficiency"].iloc[0] == pytest.approx(100 / 200)

    def test_multiple_shapes(self):
        # Two shapes, two kernels each; the model ranks correctly in both.
        frame = pd.DataFrame(
            {
                "m": [16, 16, 32, 32],
                "n": [1536, 1536, 1536, 1536],
                "k": [7168, 7168, 7168, 7168],
                "measured_tflops": [10, 20, 100, 200],
                "pred_tflops": [5, 25, 150, 190],
            }
        )
        eff = compute_tflops_efficiency(frame, "pred_tflops")
        assert len(eff) == 2
        assert eff.iloc[0]["efficiency"] == pytest.approx(1.0)
        assert eff.iloc[1]["efficiency"] == pytest.approx(1.0)

    def test_zero_tflops_shape_skipped(self):
        # Shapes whose oracle best is <= 0 are dropped entirely.
        eff = compute_tflops_efficiency(
            self._df([0, 0], [1, 2], m=16, n=16, k=16), "pred_tflops"
        )
        assert len(eff) == 0

    def test_single_kernel_per_shape(self):
        # With one candidate, the pick is trivially the oracle best.
        eff = compute_tflops_efficiency(self._df([150], [100]), "pred_tflops")
        assert len(eff) == 1
        assert eff["efficiency"].iloc[0] == pytest.approx(1.0)

    def test_tied_predictions(self):
        """When multiple kernels have the same predicted TFLOPS, pandas idxmax picks the first."""
        eff = compute_tflops_efficiency(
            self._df([100, 200, 200], [50, 50, 50]), "pred_tflops"
        )
        assert len(eff) == 1
        # Worst case the tie-break selects the 100-TFLOPS kernel => 0.5.
        assert eff["efficiency"].iloc[0] >= 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for warm-start tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dummy_data(n_rows=200, n_shapes=5):
|
||||
"""Create a small synthetic benchmark DataFrame for testing training."""
|
||||
rng = np.random.RandomState(42)
|
||||
rows = []
|
||||
for _ in range(n_rows):
|
||||
m = rng.choice([64, 128, 256, 512, 1024])
|
||||
n = rng.choice([64, 128, 256, 512, 1024])
|
||||
k = rng.choice([64, 128, 256, 512, 1024])
|
||||
rows.append(
|
||||
{
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k,
|
||||
"split_k": 1,
|
||||
"dtype": "fp8",
|
||||
"layout": "rcr",
|
||||
"op_type": "gemm_universal",
|
||||
"tile_m": rng.choice([64, 128, 256]),
|
||||
"tile_n": rng.choice([64, 128, 256]),
|
||||
"tile_k": rng.choice([32, 64, 128]),
|
||||
"warp_m": rng.choice([1, 2, 4]),
|
||||
"warp_n": rng.choice([1, 2, 4]),
|
||||
"warp_k": 1,
|
||||
"warp_tile_m": 32,
|
||||
"warp_tile_n": 32,
|
||||
"warp_tile_k": 16,
|
||||
"pipeline": rng.choice(["compv3", "compv4", "mem"]),
|
||||
"scheduler": rng.choice(["intrawave", "interwave"]),
|
||||
"epilogue": "cshuffle",
|
||||
"pad_m": False,
|
||||
"pad_n": False,
|
||||
"pad_k": False,
|
||||
"persistent": False,
|
||||
"measured_tflops": float(rng.uniform(10, 500)),
|
||||
"latency_ms": float(rng.uniform(0.01, 1.0)),
|
||||
"bandwidth_gb_s": float(rng.uniform(50, 1500)),
|
||||
"is_valid": True,
|
||||
"kernel_name": f"test_kernel_{rng.randint(0, 100)}",
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def _save_feature_spec(model_dir, fe):
|
||||
"""Save a feature_spec.json matching the given feature engine."""
|
||||
spec = {
|
||||
"feature_names": fe.get_feature_names(),
|
||||
"categorical_features": fe.get_categorical_features(),
|
||||
}
|
||||
with open(model_dir / "feature_spec.json", "w") as f:
|
||||
json.dump(spec, f)
|
||||
|
||||
|
||||
def _train_and_save_base_model(model_dir, df, fe, target="tflops"):
    """Fit a tiny base model, persist it and its feature spec, return it."""
    # Small single-threaded config keeps the test fast and deterministic.
    small_params = dict(DEFAULT_PARAMS)
    small_params.update(n_estimators=20, n_jobs=1)
    fitted = train_final_model(df, fe, target, small_params)
    fitted.booster_.save_model(str(model_dir / f"model_{target}.lgbm"))
    _save_feature_spec(model_dir, fe)
    return fitted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Warm-start tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckFeatureCompatibility:
    """Warm-start guard: any feature-schema drift must raise before training."""

    @staticmethod
    def _write_spec(model_dir, names, cats):
        # Persist a feature_spec.json with an arbitrary (possibly wrong) schema.
        with open(model_dir / "feature_spec.json", "w") as f:
            json.dump({"feature_names": names, "categorical_features": cats}, f)

    def test_compatible_passes(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        _save_feature_spec(tmp_path, engine)
        # Must not raise when the stored spec matches the live engine.
        check_feature_compatibility(tmp_path, engine)

    def test_missing_spec_raises(self, tmp_path):
        with pytest.raises(FileNotFoundError, match="feature_spec.json"):
            check_feature_compatibility(tmp_path, GemmUniversalFeatureEngine())

    def test_added_feature_raises(self, tmp_path):
        # Stored spec is missing the engine's last feature => "added" on our side.
        engine = GemmUniversalFeatureEngine()
        self._write_spec(
            tmp_path,
            engine.get_feature_names()[:-1],
            engine.get_categorical_features(),
        )
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, engine)

    def test_removed_feature_raises(self, tmp_path):
        # Stored spec has an extra feature the engine no longer produces.
        engine = GemmUniversalFeatureEngine()
        self._write_spec(
            tmp_path,
            engine.get_feature_names() + ["extra_feature"],
            engine.get_categorical_features(),
        )
        with pytest.raises(ValueError, match="Feature schema mismatch"):
            check_feature_compatibility(tmp_path, engine)

    def test_categorical_mismatch_raises(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        self._write_spec(tmp_path, engine.get_feature_names(), ["layout", "pipeline"])
        with pytest.raises(ValueError, match="Categorical feature mismatch"):
            check_feature_compatibility(tmp_path, engine)
|
||||
|
||||
|
||||
class TestLoadWarmStartModel:
    """Resolving the path of a previously-saved model for warm starts."""

    def test_loads_existing_model(self, tmp_path):
        engine = GemmUniversalFeatureEngine()
        _train_and_save_base_model(tmp_path, _make_dummy_data(), engine)
        located = load_warm_start_model(tmp_path, "tflops")
        assert located is not None
        assert Path(located).exists()

    def test_returns_none_for_missing_target(self, tmp_path):
        # Empty directory: nothing to warm-start from.
        assert load_warm_start_model(tmp_path, "tflops") is None

    def test_returns_none_for_wrong_target(self, tmp_path):
        # A saved tflops model must not satisfy a bandwidth lookup.
        engine = GemmUniversalFeatureEngine()
        _train_and_save_base_model(tmp_path, _make_dummy_data(), engine, target="tflops")
        assert load_warm_start_model(tmp_path, "bandwidth") is None
|
||||
|
||||
|
||||
class TestWarmStartTraining:
    """End-to-end warm-start behavior: tree counts, quality, missing dirs."""

    @staticmethod
    def _fast_params(n_estimators):
        # Small single-threaded config so tests stay fast.
        params = dict(DEFAULT_PARAMS)
        params["n_estimators"] = n_estimators
        params["n_jobs"] = 1
        return params

    def test_warm_start_produces_more_trees(self, tmp_path):
        """A warm-started model should have more trees than the base."""
        engine = GemmUniversalFeatureEngine()
        data = _make_dummy_data(n_rows=300)

        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base = _train_and_save_base_model(base_dir, data, engine)

        warm = train_final_model(
            data,
            engine,
            "tflops",
            self._fast_params(15),
            init_model=load_warm_start_model(base_dir, "tflops"),
        )
        assert warm.booster_.num_trees() > base.booster_.num_trees()

    def test_warm_start_does_not_degrade(self, tmp_path):
        """Warm-started model on the same data should not be significantly worse."""
        engine = GemmUniversalFeatureEngine()
        data = _make_dummy_data(n_rows=300)

        base_dir = tmp_path / "base"
        base_dir.mkdir()
        base = _train_and_save_base_model(base_dir, data, engine)

        valid = data[data["is_valid"]]
        features = engine.extract_batch(valid.reset_index(drop=True))
        truth = valid["measured_tflops"].values

        def rmse(model):
            # Same-space comparison: both models are scored identically,
            # so only the relative RMSE matters.
            return np.sqrt(np.mean((model.predict(features) - truth) ** 2))

        warm = train_final_model(
            data,
            engine,
            "tflops",
            self._fast_params(15),
            init_model=load_warm_start_model(base_dir, "tflops"),
        )
        # Allow 10% slack for stochastic variation between runs.
        assert rmse(warm) <= rmse(base) * 1.1

    def test_warm_start_from_nonexistent_dir(self):
        with pytest.raises(FileNotFoundError):
            check_feature_compatibility(
                Path("/nonexistent/model/dir"), GemmUniversalFeatureEngine()
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
555
dispatcher/heuristics/train.py
Normal file
555
dispatcher/heuristics/train.py
Normal file
@@ -0,0 +1,555 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Training script for CK Tile kernel performance prediction.
|
||||
|
||||
Trains LGBMRegressor models (TFLOPS, latency, bandwidth) with:
|
||||
- Log-space regression (log1p transform) for scale-invariant accuracy
|
||||
- GroupKFold cross-validation (group key = (M, N, K))
|
||||
- Iterative Hard Example Mining (IHEM)
|
||||
- Model complexity bounds for C++ deployability
|
||||
- Optional Optuna hyperparameter tuning
|
||||
- Warm-start incremental training from a previous model via --warm_start
|
||||
|
||||
Log-transform rationale:
|
||||
GEMM TFLOPS spans 5 orders of magnitude (0.02 for M=1 to 2230 for large
|
||||
shapes). Raw regression optimizes for absolute RMSE, which means the model
|
||||
spends all its capacity predicting large shapes accurately and ignores tiny
|
||||
shapes where TFLOPS is < 10. Training on log1p(TFLOPS) puts all shapes on
|
||||
equal footing, improving tiny_m efficiency from 84% to 96%.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import GroupKFold
|
||||
|
||||
from data_pipeline import build_training_dataset
|
||||
from feature_engine import GemmUniversalFeatureEngine
|
||||
|
||||
|
||||
# Maps CLI target names to the benchmark DataFrame columns they predict.
TARGET_COLUMNS = {
    "tflops": "measured_tflops",
    "latency": "latency_ms",
    "bandwidth": "bandwidth_gb_s",
}

# Targets where log1p transform is applied by default.
# TFLOPS and bandwidth span orders of magnitude; latency is already small-scale.
LOG_TARGETS = {"tflops", "bandwidth"}

# Baseline LightGBM hyperparameters shared by CV and final training.
# Individual runs may override n_estimators / n_jobs (see warm-start path).
DEFAULT_PARAMS = {
    "objective": "regression",
    "metric": ["rmse", "mae"],
    "num_leaves": 255,
    "max_depth": 15,
    "n_estimators": 2000,
    "learning_rate": 0.02,
    "min_child_samples": 10,
    "subsample": 0.85,
    "colsample_bytree": 0.85,
    "reg_alpha": 0.05,
    "reg_lambda": 0.5,
    "verbose": -1,  # silence LightGBM's per-iteration logging
    "n_jobs": 8,
    "seed": 42,
}

# NOTE(review): MAX_ESTIMATORS is not referenced anywhere in this file's
# visible code — presumably an upper bound enforced by a tuning path
# elsewhere; confirm before removing.
MAX_ESTIMATORS = 5000
# Default number of NEW trees added per --warm_start run (refinement, so far
# fewer than the 2000 used for a from-scratch train).
WARM_START_N_ESTIMATORS = 500
|
||||
|
||||
|
||||
def check_feature_compatibility(
    prev_model_dir: Path,
    feature_engine: GemmUniversalFeatureEngine,
) -> None:
    """Verify that the previous model's feature spec matches the current engine.

    Raises FileNotFoundError when no spec exists, and ValueError with a
    detailed message on any mismatch. This prevents silent corruption when
    warm-starting from a model trained with a different feature schema
    (e.g., after adding a new feature or changing an encoding).
    """
    spec_path = prev_model_dir / "feature_spec.json"
    if not spec_path.exists():
        raise FileNotFoundError(
            f"No feature_spec.json in {prev_model_dir}. "
            "Cannot verify feature compatibility for warm start."
        )

    with open(spec_path) as f:
        prev_spec = json.load(f)

    # Names must match in content AND order — LightGBM features are positional.
    recorded = prev_spec.get("feature_names", [])
    expected = feature_engine.get_feature_names()
    if recorded != expected:
        added = set(expected) - set(recorded)
        removed = set(recorded) - set(expected)
        lines = ["Feature schema mismatch between previous model and current engine."]
        if added:
            lines.append(f"  Added features: {sorted(added)}")
        if removed:
            lines.append(f"  Removed features: {sorted(removed)}")
        if not added and not removed:
            lines.append("  Feature order changed (names match but order differs).")
        raise ValueError("\n".join(lines))

    # Categorical designation is order-insensitive, so compare sorted sets.
    recorded_cats = sorted(prev_spec.get("categorical_features", []))
    expected_cats = sorted(feature_engine.get_categorical_features())
    if recorded_cats != expected_cats:
        raise ValueError(
            f"Categorical feature mismatch.\n"
            f"  Previous: {recorded_cats}\n"
            f"  Current:  {expected_cats}"
        )
|
||||
|
||||
|
||||
def load_warm_start_model(prev_model_dir: Path, target: str) -> str | None:
|
||||
"""Load the path to a previous model file for warm-start, or None if absent.
|
||||
|
||||
Automatically decompresses .lgbm.gz files if the .lgbm file doesn't exist.
|
||||
The decompressed file is cached to disk for subsequent loads.
|
||||
|
||||
Returns the string path (what LightGBM's init_model expects) rather than
|
||||
a loaded Booster, because LGBMRegressor.fit(init_model=...) accepts both
|
||||
path strings and Booster objects and path strings avoid keeping the old
|
||||
model in memory.
|
||||
"""
|
||||
import gzip
|
||||
|
||||
model_path = prev_model_dir / f"model_{target}.lgbm"
|
||||
gz_path = prev_model_dir / f"model_{target}.lgbm.gz"
|
||||
|
||||
# Auto-decompress if needed
|
||||
if not model_path.exists() and gz_path.exists():
|
||||
print(f" Decompressing {gz_path.name}...")
|
||||
with gzip.open(gz_path, "rb") as f_in:
|
||||
with open(model_path, "wb") as f_out:
|
||||
f_out.write(f_in.read())
|
||||
|
||||
if not model_path.exists():
|
||||
return None
|
||||
return str(model_path)
|
||||
|
||||
|
||||
def compute_group_keys(df: pd.DataFrame) -> np.ndarray:
    """Build GroupKFold group labels: one "m_n_k" string per row."""
    m_str, n_str, k_str = (df[col].astype(str) for col in ("m", "n", "k"))
    return (m_str + "_" + n_str + "_" + k_str).values
|
||||
|
||||
|
||||
def compute_tflops_efficiency(
    df: pd.DataFrame, pred_col: str = "pred_tflops"
) -> pd.DataFrame:
    """Compute per-shape efficiency: predicted-best TFLOPS / oracle-best TFLOPS.

    For each (m, n, k) shape, the kernel with the highest value in ``pred_col``
    is "selected"; efficiency is its measured TFLOPS divided by the best
    measured TFLOPS for that shape. Shapes whose oracle best is non-positive
    are skipped.
    """
    records = []
    for (m, n, k), shape_df in df.groupby(["m", "n", "k"]):
        oracle = shape_df["measured_tflops"].max()
        if oracle <= 0:
            continue  # degenerate shape: efficiency undefined
        # Ties in pred_col resolve to the first row (pandas idxmax semantics).
        chosen = shape_df.loc[shape_df[pred_col].idxmax(), "measured_tflops"]
        records.append(
            {
                "m": m,
                "n": n,
                "k": k,
                "oracle_best_tflops": oracle,
                "selected_tflops": chosen,
                "efficiency": chosen / oracle,
            }
        )
    return pd.DataFrame(records)
|
||||
|
||||
|
||||
def train_single_target(
    X_train,
    y_train,
    X_val,
    y_val,
    params: dict,
    categorical_features: list[str],
    feature_names: list[str],
    init_model=None,
) -> lgb.LGBMRegressor:
    """Fit one LGBMRegressor with early stopping on the validation split.

    Parameters
    ----------
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
        Accepts a file path to a .lgbm file, a Booster instance, or an
        LGBMModel instance. The new model adds n_estimators trees on top
        of the existing ones.
    """
    # Translate categorical column names into positional indices, dropping
    # any name that isn't present in the feature matrix.
    cat_indices = []
    for name in categorical_features:
        if name in feature_names:
            cat_indices.append(feature_names.index(name))

    regressor = lgb.LGBMRegressor(**params)
    regressor.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
        eval_metric=["rmse"],
        callbacks=[
            # Stop after 50 rounds without validation improvement; silence logs.
            lgb.early_stopping(50, verbose=False),
            lgb.log_evaluation(0),
        ],
        categorical_feature=cat_indices if cat_indices else "auto",
        init_model=init_model,
    )
    return regressor
|
||||
|
||||
|
||||
def run_cv(
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
    target: str,
    params: dict,
    n_splits: int = 5,
    use_log: bool = True,
) -> dict:
    """Run GroupKFold cross-validation and return OOF predictions + metrics.

    Groups are (m, n, k) shapes, so a shape never appears in both train and
    validation folds. Returns {} when fewer than 2 unique shapes exist.

    Parameters
    ----------
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y) and invert
        predictions with expm1 for efficiency calculation. This normalizes
        the scale so that tiny-M shapes (TFLOPS ~ 1) get equal attention
        as large-M shapes (TFLOPS ~ 2000).

    Returns
    -------
    dict with keys "fold_metrics", "oof_df", "feature_names",
    "log_transform". NOTE: the oof_pred_{target} column stored on oof_df is
    in the TRAINED space (log1p when apply_log) — downstream ranking via
    idxmax is unaffected because log1p is monotonic.
    """
    target_col = TARGET_COLUMNS[target]
    # Keep only rows marked valid with a strictly positive target value.
    valid_mask = df["is_valid"].fillna(False) & (df[target_col] > 0)
    df_valid = df[valid_mask].reset_index(drop=True)

    apply_log = use_log and target in LOG_TARGETS

    print(
        f" Training on {len(df_valid)} valid rows for target={target}"
        f"{' (log-space)' if apply_log else ''}"
    )

    X = feature_engine.extract_batch(df_valid)
    y_raw = df_valid[target_col].values
    y = np.log1p(y_raw) if apply_log else y_raw
    groups = compute_group_keys(df_valid)
    feature_names = feature_engine.get_feature_names()
    cat_features = feature_engine.get_categorical_features()

    # GroupKFold needs at least 2 groups; cap the fold count at the group count.
    unique_groups = np.unique(groups)
    actual_splits = min(n_splits, len(unique_groups))
    if actual_splits < 2:
        print(f" WARNING: Only {len(unique_groups)} unique groups, skipping CV")
        return {}

    gkf = GroupKFold(n_splits=actual_splits)
    oof_preds = np.zeros(len(df_valid))
    fold_metrics = []

    for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        X_tr, X_val = X[train_idx], X[val_idx]
        y_tr, y_val = y[train_idx], y[val_idx]

        model = train_single_target(
            X_tr, y_tr, X_val, y_val, params, cat_features, feature_names
        )
        preds = model.predict(X_val)
        # Folds are disjoint, so this assembles full out-of-fold predictions.
        oof_preds[val_idx] = preds

        # RMSE/R2 computed in the trained space (log-space when apply_log).
        rmse = np.sqrt(np.mean((preds - y_val) ** 2))
        r2 = 1 - np.sum((preds - y_val) ** 2) / max(
            np.sum((y_val - y_val.mean()) ** 2), 1e-10
        )

        if target == "tflops":
            # Efficiency needs raw-space predictions; invert log1p per fold.
            val_df = df_valid.iloc[val_idx].copy()
            preds_raw = np.expm1(preds) if apply_log else preds
            val_df["pred_tflops"] = preds_raw
            eff_df = compute_tflops_efficiency(val_df)
            mean_eff = eff_df["efficiency"].mean() if len(eff_df) > 0 else 0
            p10_eff = eff_df["efficiency"].quantile(0.1) if len(eff_df) > 0 else 0
        else:
            mean_eff, p10_eff = None, None

        fold_metrics.append(
            {
                "fold": fold_idx,
                "rmse": rmse,
                "r2": r2,
                "mean_efficiency": mean_eff,
                "p10_efficiency": p10_eff,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "val_groups": len(np.unique(groups[val_idx])),
            }
        )

        eff_str = (
            f", eff={mean_eff:.4f}, p10={p10_eff:.4f}" if mean_eff is not None else ""
        )
        print(f" Fold {fold_idx}: RMSE={rmse:.4f}, R2={r2:.4f}{eff_str}")

    # Stored in trained space (see docstring note).
    df_valid[f"oof_pred_{target}"] = oof_preds

    return {
        "fold_metrics": fold_metrics,
        "oof_df": df_valid,
        "feature_names": feature_names,
        "log_transform": apply_log,
    }
|
||||
|
||||
|
||||
def train_final_model(
    df: pd.DataFrame,
    feature_engine: GemmUniversalFeatureEngine,
    target: str,
    params: dict,
    init_model=None,
    use_log: bool = True,
) -> lgb.LGBMRegressor:
    """Train the final model on every valid row (no held-out split).

    Parameters
    ----------
    init_model : str, Path, lgb.Booster, lgb.LGBMModel, or None
        If provided, training continues from this model (warm start).
    use_log : bool
        If True and target is in LOG_TARGETS, train on log1p(y).
        The saved model then predicts in log-space; callers must apply
        expm1() to get raw values.
    """
    target_col = TARGET_COLUMNS[target]
    # Use only rows flagged valid with a strictly positive target.
    usable = df[df["is_valid"].fillna(False) & (df[target_col] > 0)].reset_index(
        drop=True
    )

    in_log_space = use_log and target in LOG_TARGETS

    features = feature_engine.extract_batch(usable)
    labels = usable[target_col].values
    if in_log_space:
        labels = np.log1p(labels)

    names = feature_engine.get_feature_names()
    categorical = [
        names.index(col)
        for col in feature_engine.get_categorical_features()
        if col in names
    ]

    regressor = lgb.LGBMRegressor(**params)
    regressor.fit(
        features,
        labels,
        categorical_feature=categorical if categorical else "auto",
        init_model=init_model,
    )
    return regressor
|
||||
|
||||
|
||||
def main():
    """CLI entry point: load benchmark data, run CV, train and persist models.

    Side effects: writes model_{target}.lgbm, cv_metrics_{target}.json,
    feature_importances_{target}.json, oof_predictions.parquet (tflops only),
    feature_spec.json, and train_manifest.json into --out_dir.
    """
    parser = argparse.ArgumentParser(
        description="Train CK Tile kernel performance models"
    )
    parser.add_argument(
        "--data_dir", required=True, help="Directory with parquet files"
    )
    parser.add_argument("--out_dir", required=True, help="Output directory for models")
    parser.add_argument("--op", default="gemm_universal", help="Operation type")
    parser.add_argument("--dtype", default="fp8", help="Data type filter")
    parser.add_argument("--arch", default="gfx950", help="Architecture")
    parser.add_argument(
        "--targets", default="tflops,latency,bandwidth", help="Comma-separated targets"
    )
    parser.add_argument("--n_splits", type=int, default=5, help="Number of CV folds")
    # NOTE(review): --tune is parsed but never consulted in this function —
    # presumably the Optuna path lives elsewhere or is TODO; confirm.
    parser.add_argument(
        "--tune", action="store_true", help="Run Optuna hyperparameter tuning"
    )
    parser.add_argument(
        "--no_log_transform",
        action="store_true",
        help="Disable log1p transform on targets. By default, TFLOPS and bandwidth "
        "are trained in log-space for scale-invariant accuracy across shape sizes.",
    )
    parser.add_argument(
        "--warm_start",
        default=None,
        help="Path to previous model directory to continue training from. "
        "Uses LightGBM's init_model to add new trees on top of the "
        "existing model. Feature schemas must match exactly.",
    )
    parser.add_argument(
        "--warm_start_n_estimators",
        type=int,
        default=WARM_START_N_ESTIMATORS,
        help=f"Number of new trees to add when warm-starting (default: {WARM_START_N_ESTIMATORS}). "
        "Lower than a full train since we're refining, not starting from scratch.",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    targets = [t.strip() for t in args.targets.split(",")]

    print(f"Loading data from {args.data_dir}...")
    df = build_training_dataset(args.data_dir, op_type=args.op, dtype=args.dtype)
    print(f" Total rows: {len(df)}")
    print(f" Unique shapes: {df.groupby(['m', 'n', 'k']).ngroups}")
    print(f" Unique kernels: {df['kernel_name'].nunique()}")

    # Hardware columns (hw_*) describe the GPU the data was benchmarked on.
    # Forward them to the feature engine, taking the first row as
    # representative (assumes a single-device dataset — TODO confirm).
    hw_cols = [c for c in df.columns if c.startswith("hw_")]
    hw_kwargs = {}
    if hw_cols:
        row0 = df.iloc[0]
        if "hw_num_cus" in df.columns:
            hw_kwargs["num_cus"] = int(row0.get("hw_num_cus", 256))
        if "hw_max_clock_mhz" in df.columns:
            hw_kwargs["max_clock_mhz"] = int(row0.get("hw_max_clock_mhz", 2400))
        if "hw_simds_per_cu" in df.columns:
            hw_kwargs["simds_per_cu"] = int(row0.get("hw_simds_per_cu", 4))
        if "hw_shader_engines" in df.columns:
            hw_kwargs["shader_engines"] = int(row0.get("hw_shader_engines", 32))
        if "hw_max_waves_per_cu" in df.columns:
            hw_kwargs["max_waves_per_cu"] = int(row0.get("hw_max_waves_per_cu", 32))
        if "hw_wavefront_size" in df.columns:
            hw_kwargs["wavefront_size"] = int(row0.get("hw_wavefront_size", 64))
        if "hw_l1_cache_kb" in df.columns:
            hw_kwargs["l1_cache_kb"] = int(row0.get("hw_l1_cache_kb", 32))
        if "hw_l2_cache_kb" in df.columns:
            hw_kwargs["l2_cache_kb"] = int(row0.get("hw_l2_cache_kb", 4096))
        if "hw_l3_cache_kb" in df.columns:
            hw_kwargs["l3_cache_kb"] = int(row0.get("hw_l3_cache_kb", 262144))

    fe = GemmUniversalFeatureEngine(**hw_kwargs)

    params = dict(DEFAULT_PARAMS)
    use_log = not args.no_log_transform

    # Warm-start setup: verify schema compatibility BEFORE any training, and
    # carry the previous manifest forward for cumulative tree accounting.
    prev_model_dir = None
    prev_manifest = {}
    if args.warm_start:
        prev_model_dir = Path(args.warm_start)
        if not prev_model_dir.exists():
            raise FileNotFoundError(f"Warm-start directory not found: {prev_model_dir}")
        print(f" Warm-starting from {prev_model_dir}")
        check_feature_compatibility(prev_model_dir, fe)
        print(" Feature compatibility: OK")
        # Warm starts add fewer trees — they refine rather than rebuild.
        params["n_estimators"] = args.warm_start_n_estimators
        print(f" New trees to add: {args.warm_start_n_estimators}")

        prev_manifest_path = prev_model_dir / "train_manifest.json"
        if prev_manifest_path.exists():
            with open(prev_manifest_path) as f:
                prev_manifest = json.load(f)

    all_cv_results = {}
    for target in targets:
        if target not in TARGET_COLUMNS:
            print(f" Skipping unknown target: {target}")
            continue

        print(f"\n{'=' * 60}")
        print(f"Training {target} model")
        print(f"{'=' * 60}")

        # Per-target warm start: fall back to scratch if that target's
        # previous model file is missing.
        init_model_path = None
        if prev_model_dir is not None:
            init_model_path = load_warm_start_model(prev_model_dir, target)
            if init_model_path:
                print(f" Warm-starting from {init_model_path}")
            else:
                print(f" No previous {target} model found, training from scratch")

        t0 = time.time()
        cv_result = run_cv(
            df, fe, target, params, n_splits=args.n_splits, use_log=use_log
        )
        cv_time = time.time() - t0

        if cv_result and cv_result["fold_metrics"]:
            all_cv_results[target] = cv_result["fold_metrics"]
            metrics_path = out_dir / f"cv_metrics_{target}.json"
            with open(metrics_path, "w") as f:
                json.dump(cv_result["fold_metrics"], f, indent=2)
            print(f" CV completed in {cv_time:.1f}s, saved to {metrics_path}")

            # Persist out-of-fold predictions and summarize selection quality
            # for the primary target only.
            if target == "tflops" and cv_result.get("oof_df") is not None:
                oof_df = cv_result["oof_df"]
                oof_df.to_parquet(out_dir / "oof_predictions.parquet", index=False)

                eff_df = compute_tflops_efficiency(oof_df, "oof_pred_tflops")
                if len(eff_df) > 0:
                    print("\n OOF TFLOPS Efficiency:")
                    print(f" Mean: {eff_df['efficiency'].mean():.4f}")
                    print(f" P10: {eff_df['efficiency'].quantile(0.1):.4f}")
                    print(f" P50: {eff_df['efficiency'].quantile(0.5):.4f}")
                    print(f" Min: {eff_df['efficiency'].min():.4f}")

        # Final model uses ALL valid rows (no held-out fold).
        print(f"\n Training final {target} model on all data...")
        t0 = time.time()
        model = train_final_model(
            df, fe, target, params, init_model=init_model_path, use_log=use_log
        )
        train_time = time.time() - t0

        model_path = out_dir / f"model_{target}.lgbm"
        model.booster_.save_model(str(model_path))
        print(f" Saved {model_path} ({train_time:.1f}s)")

        importances = dict(
            zip(
                fe.get_feature_names(),
                model.feature_importances_.tolist(),
            )
        )
        imp_path = out_dir / f"feature_importances_{target}.json"
        with open(imp_path, "w") as f:
            json.dump(importances, f, indent=2)

    # feature_spec.json records the schema future warm starts are checked
    # against (see check_feature_compatibility).
    log_targets_used = sorted(LOG_TARGETS & set(targets)) if use_log else []
    spec = {
        "op_type": args.op,
        "dtype": args.dtype,
        "arch": args.arch,
        "feature_names": fe.get_feature_names(),
        "categorical_features": fe.get_categorical_features(),
        "targets": targets,
        "log_targets": log_targets_used,
        "params": params,
    }
    with open(out_dir / "feature_spec.json", "w") as f:
        json.dump(spec, f, indent=2)

    # train_manifest.json tracks lineage: cumulative tree counts across
    # successive warm-start runs.
    manifest = {
        "warm_start_from": str(prev_model_dir) if prev_model_dir else None,
        "prev_n_estimators": prev_manifest.get(
            "total_n_estimators", params.get("n_estimators")
        )
        if prev_model_dir
        else 0,
        "new_n_estimators": params["n_estimators"],
        "total_n_estimators": (
            prev_manifest.get("total_n_estimators", 0) + params["n_estimators"]
            if prev_model_dir
            else params["n_estimators"]
        ),
        "data_rows": len(df),
        "valid_rows": int(df["is_valid"].fillna(False).sum()),
        "unique_shapes": int(df.groupby(["m", "n", "k"]).ngroups),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
    }
    with open(out_dir / "train_manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)

    print(f"\nAll models saved to {out_dir}")
    if prev_model_dir:
        print(f" Warm-started from: {prev_model_dir}")
        print(f" Total estimators: {manifest['total_n_estimators']}")


if __name__ == "__main__":
    main()
|
||||
317
dispatcher/heuristics/validate_ml_heuristic.py
Normal file
317
dispatcher/heuristics/validate_ml_heuristic.py
Normal file
@@ -0,0 +1,317 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
ML Heuristic Validation: Test ML predictions against oracle-best from training data
|
||||
|
||||
This script validates ML-based kernel selection by:
|
||||
1. Loading benchmark data (oracle-best results for each shape)
|
||||
2. Using ML model to predict best kernel for each shape
|
||||
3. Comparing ML selection with oracle-best to compute efficiency
|
||||
|
||||
Usage:
|
||||
python validate_ml_heuristic.py --dtype fp16 --model_dir models/gemm_universal_fp16_gfx950
|
||||
python validate_ml_heuristic.py --dtype fp8 --layout rcr
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from predict import Predictor
|
||||
|
||||
|
||||
def validate_ml_heuristic(dtype: str, layout: str, model_dir: str, data_dir: str):
    """Validate ML heuristic predictions against oracle-best.

    For every unique (m, n, k) shape present in the benchmark parquet data,
    asks the ML predictor to rank the kernels that were actually benchmarked
    for that shape, then compares the top-ranked kernel's measured TFLOPS with
    the empirically fastest ("oracle-best") kernel. Prints summary statistics
    and writes a per-shape CSV (validation_results_{dtype}_{layout}.csv).

    Args:
        dtype: Data type of the runs to validate (e.g. "fp16", "bf16", "fp8").
        layout: Matrix layout tag (e.g. "rcr"). NOTE(review): the layout filter
            is only applied on the combined-parquet fallback path; the
            dtype-specific parquet is used as-is — confirm it is single-layout.
        model_dir: Directory containing the trained LightGBM model files.
        data_dir: Directory containing the benchmark parquet files.
    """

    print("=" * 100)
    print(f" ML Heuristic Validation: {dtype.upper()} {layout.upper()}")
    print("=" * 100)
    print()

    # Load training data
    print(f"Loading training data from {data_dir}...")

    # Try dtype-specific parquet first, then fall back to combined
    dtype_specific = (
        Path(data_dir) / f"{dtype}_original" / f"{dtype}_training_data.parquet"
    )
    combined = Path(data_dir) / "all_training_data_fixed.parquet"

    if dtype_specific.exists():
        training_data = pd.read_parquet(dtype_specific)
        print(f"✓ Loaded {len(training_data):,} benchmark runs from {dtype_specific}")
    elif combined.exists():
        # Combined file mixes dtypes/layouts; narrow it down here.
        training_data = pd.read_parquet(combined)
        training_data = training_data[
            (training_data["dtype"] == dtype) & (training_data["layout"] == layout)
        ]
        print(f"✓ Loaded {len(training_data):,} benchmark runs from {combined}")
    else:
        print(f"❌ Error: No training data found at {dtype_specific} or {combined}")
        return

    if len(training_data) == 0:
        print(f"❌ Error: No data found for dtype={dtype}, layout={layout}")
        return

    # Get unique shapes with oracle-best
    shape_groups = training_data.groupby(["m", "n", "k"])
    print(f"Unique shapes: {len(shape_groups)}")
    print()

    # Load ML predictor
    print(f"Loading ML predictor from {model_dir}...")
    try:
        predictor = Predictor(model_dir)
        print("✓ Loaded ML predictor")
        # NOTE(review): reaches into a private attribute of Predictor.
        print(f" Log targets: {predictor._log_targets}")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return

    print()
    print("=" * 100)
    print(" Computing Oracle-Best Efficiency for Each Shape")
    print("=" * 100)
    print()

    results = []

    for shape_idx, ((m, n, k), group) in enumerate(shape_groups):
        # Find oracle-best (max TFLOPS across all kernels tested)
        oracle_best_row = group.loc[group["measured_tflops"].idxmax()]
        oracle_best_tflops = oracle_best_row["measured_tflops"]
        oracle_best_kernel = oracle_best_row["kernel_name"]

        # Get all kernel configs tested for this shape
        kernel_configs = []
        for _, row in group.iterrows():
            kernel_dict = {
                "tile_m": row["tile_m"],
                "tile_n": row["tile_n"],
                "tile_k": row["tile_k"],
                "warp_m": row["warp_m"],
                "warp_n": row["warp_n"],
                "warp_k": row["warp_k"],
                "warp_tile_m": row["warp_tile_m"],
                "warp_tile_n": row["warp_tile_n"],
                "warp_tile_k": row["warp_tile_k"],
                "pipeline": row["pipeline"],
                "scheduler": row["scheduler"],
                "epilogue": row["epilogue"],
                "pad_m": row["pad_m"],
                "pad_n": row["pad_n"],
                "pad_k": row["pad_k"],
                "persistent": row["persistent"],
                "kernel_name": row["kernel_name"],
            }
            kernel_configs.append(kernel_dict)

        # Use ML model to rank kernels
        problem = {
            "m": m,
            "n": n,
            "k": k,
            "dtype": dtype,
            "layout": layout,
            # Validation always scores the split_k == 1 configuration.
            "split_k": 1,
        }

        try:
            ranked = predictor.rank_kernels(problem, kernel_configs)

            if ranked:
                ml_best_kernel, ml_predicted_tflops = ranked[0]

                # Find actual TFLOPS for the ML-predicted kernel
                ml_kernel_row = group[group["kernel_name"] == ml_best_kernel]
                if len(ml_kernel_row) > 0:
                    ml_actual_tflops = ml_kernel_row["measured_tflops"].values[0]

                    # Calculate efficiency
                    efficiency_pct = 100.0 * (ml_actual_tflops / oracle_best_tflops)

                    # Determine if ML picked oracle-best
                    is_oracle_best = ml_best_kernel == oracle_best_kernel

                    results.append(
                        {
                            "m": m,
                            "n": n,
                            "k": k,
                            "oracle_best_tflops": oracle_best_tflops,
                            "oracle_best_kernel": oracle_best_kernel,
                            "ml_predicted_tflops": ml_predicted_tflops,
                            "ml_selected_kernel": ml_best_kernel,
                            "ml_actual_tflops": ml_actual_tflops,
                            "efficiency_pct": efficiency_pct,
                            "is_oracle_best": is_oracle_best,
                            "num_kernels": len(group),
                        }
                    )

                    # Lightweight progress indicator every 20 shapes.
                    if (shape_idx + 1) % 20 == 0:
                        status = "✓" if is_oracle_best else f"{efficiency_pct:.1f}%"
                        print(
                            f" [{shape_idx + 1:3d}/{len(shape_groups)}] "
                            f"M={m:4d} N={n:5d} K={k:5d}: {status}"
                        )
        except Exception as e:
            # Best-effort validation: keep going on per-shape failures.
            print(f" Error on shape M={m} N={n} K={k}: {e}")
            continue

    print()
    print("=" * 100)
    print(" Results Summary")
    print("=" * 100)
    print()

    if results:
        df_results = pd.DataFrame(results)
        efficiencies = df_results["efficiency_pct"].values
        oracle_matches = df_results["is_oracle_best"].sum()

        print(f"Total shapes tested: {len(results)}")
        print()
        print("Efficiency Statistics (% of Oracle-Best TFLOPS):")
        print(f" Mean: {np.mean(efficiencies):.2f}%")
        print(f" Median: {np.median(efficiencies):.2f}%")
        print(f" Min: {np.min(efficiencies):.2f}%")
        print(f" Max: {np.max(efficiencies):.2f}%")
        print(f" P10: {np.percentile(efficiencies, 10):.2f}%")
        print(f" P50: {np.percentile(efficiencies, 50):.2f}%")
        print(f" P90: {np.percentile(efficiencies, 90):.2f}%")
        print()
        print(
            f"Oracle-best matches: {oracle_matches}/{len(results)} ({100 * oracle_matches / len(results):.1f}%)"
        )
        print()

        # Classify by M size
        df_results["m_class"] = pd.cut(
            df_results["m"],
            bins=[0, 8, 128, 1024, float("inf")],
            labels=[
                "Tiny (M<8)",
                "Small (8≤M<128)",
                "Medium (128≤M<1024)",
                "Large (M≥1024)",
            ],
        )

        print("Efficiency by M size:")
        for m_class in [
            "Tiny (M<8)",
            "Small (8≤M<128)",
            "Medium (128≤M<1024)",
            "Large (M≥1024)",
        ]:
            subset = df_results[df_results["m_class"] == m_class]
            if len(subset) > 0:
                print(
                    f" {m_class:25s}: {subset['efficiency_pct'].mean():6.2f}% "
                    f"(n={len(subset)}, P10={subset['efficiency_pct'].quantile(0.1):.2f}%)"
                )

        print()

        # Save results
        output_file = f"validation_results_{dtype}_{layout}.csv"
        df_results.to_csv(output_file, index=False)
        print(f"✓ Results saved to {output_file}")

        # Show best and worst shapes
        print()
        print("Top 5 shapes (best efficiency):")
        top5 = df_results.nlargest(5, "efficiency_pct")[
            ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
        ]
        for idx, row in top5.iterrows():
            match = "✓" if row["is_oracle_best"] else " "
            print(
                f" {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
                f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)"
            )

        print()
        print("Bottom 5 shapes (worst efficiency):")
        bottom5 = df_results.nsmallest(5, "efficiency_pct")[
            ["m", "n", "k", "efficiency_pct", "oracle_best_tflops", "is_oracle_best"]
        ]
        for idx, row in bottom5.iterrows():
            match = "✓" if row["is_oracle_best"] else " "
            print(
                f" {match} M={row['m']:5d} N={row['n']:5d} K={row['k']:5d}: "
                f"{row['efficiency_pct']:.2f}% ({row['oracle_best_tflops']:.2f} TFLOPS)"
            )

    else:
        print("No results to display")

    print()
    print("=" * 100)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, locate model/data dirs, run validation."""
    parser = argparse.ArgumentParser(
        description="Validate ML heuristic predictions against oracle-best from training data"
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp8"],
        help="Data type to validate",
    )
    parser.add_argument(
        "--layout",
        default="rcr",
        choices=["rcr", "rrr", "crr", "ccr"],
        help="Matrix layout",
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path to model directory (auto-detect if not specified)",
    )
    parser.add_argument(
        "--data_dir",
        default=None,
        help="Path to training data directory (auto-detect if not specified)",
    )

    args = parser.parse_args()

    base_dir = Path(__file__).parent

    # Auto-detect model directory if not specified: prefer gfx950, then gfx942.
    if args.model_dir is None:
        model_candidates = [
            base_dir / "models" / f"gemm_universal_{args.dtype}_gfx950",
            base_dir / "models" / f"gemm_universal_{args.dtype}_gfx942",
        ]
        args.model_dir = next(
            (str(c) for c in model_candidates if c.exists()), None
        )
        if args.model_dir is None:
            print(f"❌ Error: Could not find model directory for {args.dtype}")
            print(f" Searched: {[str(c) for c in model_candidates]}")
            print(" Please specify --model_dir explicitly")
            return 1

    # Auto-detect data directory if not specified.
    if args.data_dir is None:
        args.data_dir = str(base_dir / "data")

    validate_ml_heuristic(args.dtype, args.layout, args.model_dir, args.data_dir)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -140,6 +140,11 @@ struct KernelKey
|
||||
bool preshuffle; // Preshuffle (for weight preshuffle variants)
|
||||
bool transpose_c; // TransposeC
|
||||
std::uint8_t num_wave_groups; // NumWaveGroups
|
||||
|
||||
// Padding support flags (kPadM, kPadN, kPadK in generated kernels)
|
||||
bool pad_m = true; // Support arbitrary M dimensions via padding
|
||||
bool pad_n = true; // Support arbitrary N dimensions via padding
|
||||
bool pad_k = true; // Support arbitrary K dimensions via padding
|
||||
} algorithm;
|
||||
|
||||
std::string gfx_arch; // e.g. "gfx942", "gfx90a", "gfx908"
|
||||
@@ -185,7 +190,10 @@ struct KernelKey
|
||||
algorithm.double_buffer,
|
||||
algorithm.preshuffle,
|
||||
algorithm.transpose_c,
|
||||
algorithm.num_wave_groups);
|
||||
algorithm.num_wave_groups,
|
||||
algorithm.pad_m,
|
||||
algorithm.pad_n,
|
||||
algorithm.pad_k);
|
||||
}
|
||||
|
||||
/// Equality comparison
|
||||
@@ -397,8 +405,14 @@ inline std::string KernelKey::encode_identifier() const
|
||||
|
||||
// Include pipeline, scheduler, epilogue for uniqueness
|
||||
oss << to_string(algorithm.pipeline) << "_";
|
||||
oss << to_string(algorithm.scheduler) << "_";
|
||||
oss << to_string(algorithm.epilogue) << "_";
|
||||
oss << to_string(algorithm.scheduler) << "_";
|
||||
|
||||
// Match tile_engine naming: padding flags (True/False) then persistent flag
|
||||
oss << (algorithm.pad_m ? "True" : "False") << "_";
|
||||
oss << (algorithm.pad_n ? "True" : "False") << "_";
|
||||
oss << (algorithm.pad_k ? "True" : "False") << "_";
|
||||
oss << (algorithm.persistent ? "True" : "False") << "_";
|
||||
|
||||
// Match tile_engine naming: tile_m x tile_n x tile_k _ warp_m x warp_n x warp_k _
|
||||
// warp_tile_m x warp_tile_n x warp_tile_k
|
||||
@@ -407,9 +421,6 @@ inline std::string KernelKey::encode_identifier() const
|
||||
<< unsigned(algorithm.wave_shape.k) << "_" << unsigned(algorithm.warp_tile_shape.m) << "x"
|
||||
<< unsigned(algorithm.warp_tile_shape.n) << "x" << unsigned(algorithm.warp_tile_shape.k);
|
||||
|
||||
// Add trait flags
|
||||
oss << "_" << (algorithm.persistent ? "persist" : "nopers");
|
||||
|
||||
if(signature.split_k > 1)
|
||||
oss << "_splitk" << unsigned(signature.split_k);
|
||||
if(!signature.elementwise_op.empty() && signature.elementwise_op != "PassThrough")
|
||||
|
||||
379
dispatcher/include/ck_tile/dispatcher/ml_heuristic.hpp
Normal file
379
dispatcher/include/ck_tile/dispatcher/ml_heuristic.hpp
Normal file
@@ -0,0 +1,379 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "ck_tile/dispatcher/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_key.hpp"
|
||||
#include "ck_tile/dispatcher/registry.hpp"
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
namespace ck_tile {
|
||||
namespace dispatcher {
|
||||
extern "C" {
|
||||
int LGBM_BoosterCreateFromModelfile(const char*, int*, void**);
|
||||
int LGBM_BoosterPredictForMat(
|
||||
void*, const void*, int, int, int, int, int, int, int, const char*, int64_t*, double*);
|
||||
int LGBM_BoosterFree(void*);
|
||||
}
|
||||
/// Map a Pipeline enumerator to the categorical code used during model
/// training. Unrecognized values fall back to the CompV3 code (0).
inline int encode_pipeline(Pipeline p)
{
    if(p == Pipeline::CompV4)
        return 1;
    if(p == Pipeline::CompV5)
        return 2;
    if(p == Pipeline::Mem)
        return 3;
    if(p == Pipeline::PreShuffleV2)
        return 4;
    return 0; // Pipeline::CompV3 and anything unknown
}
|
||||
/// Map a Scheduler enumerator to its training-side categorical code.
/// Interwave -> 1; Intrawave (and any unknown value) -> 0.
inline int encode_scheduler(Scheduler s)
{
    return (s == Scheduler::Interwave) ? 1 : 0;
}
|
||||
/// Map an Epilogue enumerator to its training-side categorical code.
/// CShuffle -> 1; Default (and any unknown value) -> 0.
inline int encode_epilogue(Epilogue e)
{
    return (e == Epilogue::CShuffle) ? 1 : 0;
}
|
||||
/// Encode the (A, B) layout combination as the categorical code used in
/// training: (row, col) -> 0, (row, row) -> 1, (col, row) -> 2,
/// (col, col) -> 3. The C layout parameter is accepted for interface
/// symmetry but does not influence the code.
inline int encode_layout(LayoutTag a, LayoutTag b, LayoutTag c)
{
    const bool a_row = (a == LayoutTag::RowMajor);
    const bool b_row = (b == LayoutTag::RowMajor);
    if(a_row)
        return b_row ? 1 : 0;
    return b_row ? 2 : 3;
}
|
||||
/// Bytes per element for a DataType, as used by the ML feature extractor.
/// Sub-byte types return fractional sizes; unknown types default to 2 bytes.
inline double dtype_bytes_ml(DataType dt)
{
    if(dt == DataType::FP32)
        return 4;
    if(dt == DataType::FP8 || dt == DataType::BF8 || dt == DataType::INT8)
        return 1;
    if(dt == DataType::INT4)
        return 0.5; // packed 4-bit elements
    return 2;       // FP16, BF16, and any unlisted type
}
|
||||
/// Static GPU hardware parameters fed to the ML feature extractor.
/// Defaults describe a large CDNA-class accelerator; override per device.
struct HardwareProfile
{
    int num_cus          = 256;    // compute unit count
    int simds_per_cu     = 4;      // SIMDs per compute unit
    int shader_engines   = 32;     // shader engine count
    int max_clock_mhz    = 2400;   // peak clock (MHz)
    int max_waves_per_cu = 32;     // wavefront occupancy limit per CU
    int wavefront_size   = 64;     // threads per wavefront
    int lds_capacity     = 65536;  // LDS capacity in bytes
    int l1_cache_kb      = 32;     // L1 cache size (KiB)
    int l2_cache_kb      = 4096;   // L2 cache size (KiB)
    int l3_cache_kb      = 262144; // L3 cache size (KiB)
    int num_xcd          = 8;      // XCD (die) count

    /// Total SIMD count across the whole device.
    int total_simds() const { return num_cus * simds_per_cu; }
};
|
||||
|
||||
// CRITICAL: Feature count MUST match feature_spec.json
|
||||
// Python training uses 72 features - this header MUST extract exactly 72 features in the same order
|
||||
static constexpr int NUM_FEATURES = 72;
|
||||
|
||||
/// Build the NUM_FEATURES-element feature vector for one (problem, kernel)
/// pair.
///
/// The entry order is a hard contract with the trained LightGBM model: it
/// must match feature_spec.json (Python feature_engine.py::get_feature_names())
/// exactly, or predictions become silently meaningless. Each entry is
/// annotated with its index and training-side name in the return statement.
///
/// @param prob GEMM problem (reads M, N, K, k_batch).
/// @param key  Candidate kernel configuration (tiles, pipeline, padding, ...).
/// @param hw   Hardware profile supplying capacity/occupancy features.
/// @return     Feature vector in training order.
inline std::array<double, NUM_FEATURES>
extract_features(const Problem& prob, const KernelKey& key, const HardwareProfile& hw)
{
    // Problem dimensions
    double M = prob.M, N = prob.N, K = prob.K;
    // Split-K factor; k_batch <= 0 is treated as no split (1).
    double sk = (prob.k_batch > 0 ? prob.k_batch : 1);
    // Bytes per element, taken from the A operand's data type.
    double bpe = dtype_bytes_ml(key.signature.dtype_a);

    // Log-scale features (arguments clamped to >= 1 to avoid log2(0))
    double l2M = std::log2(std::max(M, 1.0));
    double l2N = std::log2(std::max(N, 1.0));
    double l2K = std::log2(std::max(K, 1.0));
    double l2MNK = std::log2(std::max(M * N * K, 1.0));

    // Arithmetic intensity: 2*M*N*K FLOPs per byte moved (A + B + C)
    double mem = (M * K + K * N + M * N) * bpe;
    double ai = 2.0 * M * N * K / std::max(mem, 1.0);

    // Aspect ratios
    double ar_mn = M / std::max(N, 1.0);
    double ar_mk = M / std::max(K, 1.0);
    double ar_nk = N / std::max(K, 1.0);

    // Layout encoding (categorical code; see encode_layout)
    double layout = (double)encode_layout(
        key.signature.layout_a, key.signature.layout_b, key.signature.layout_c);

    // Tile dimensions
    double tm = key.algorithm.tile_shape.m;
    double tn = key.algorithm.tile_shape.n;
    double tk = key.algorithm.tile_shape.k;

    // Wave/warp dimensions
    double wm = key.algorithm.wave_shape.m;
    double wn = key.algorithm.wave_shape.n;
    double wk = key.algorithm.wave_shape.k;

    // Warp tile dimensions
    double wtm = key.algorithm.warp_tile_shape.m;
    double wtn = key.algorithm.warp_tile_shape.n;
    double wtk = key.algorithm.warp_tile_shape.k;

    // Algorithm encoding (categorical codes matching the Python training side)
    double pipeline = (double)encode_pipeline(key.algorithm.pipeline);
    double scheduler = (double)encode_scheduler(key.algorithm.scheduler);
    double epilogue = (double)encode_epilogue(key.algorithm.epilogue);

    // Padding flags - read from KernelKey
    double pad_m = key.algorithm.pad_m ? 1.0 : 0.0;
    double pad_n = key.algorithm.pad_n ? 1.0 : 0.0;
    double pad_k = key.algorithm.pad_k ? 1.0 : 0.0;

    // Persistent kernel flag
    double persistent = key.algorithm.persistent ? 1.0 : 0.0;

    // Derived features
    double num_warps = wm * wn * wk;
    double tile_volume = tm * tn * tk;
    double tile_mn = tm * tn;

    // LDS usage estimation (A-tile + B-tile footprint).
    // NOTE(review): CompV4 is modeled with a reduced 32 KiB LDS budget here —
    // presumably reflecting that pipeline's buffering scheme; confirm against
    // the pipeline implementation.
    double lest = (tm * tk + tn * tk) * bpe;
    double lcap = (key.algorithm.pipeline == Pipeline::CompV4) ? 32768.0 : (double)hw.lds_capacity;
    double lds_ratio = lest / std::max(lcap, 1.0);

    // Tile counts (ceil-divided per dimension)
    double ntm = std::ceil(M / std::max(tm, 1.0));
    double ntn = std::ceil(N / std::max(tn, 1.0));
    double ntk = std::ceil(K / std::max(tk, 1.0));
    double total_output_tiles = ntm * ntn;

    // Tile efficiency (fractional remainder utilization):
    // 1.0 when the dimension divides evenly, else the used fraction of the
    // final partial tile.
    auto ef = [](double d, double t) -> double {
        if(t <= 0)
            return 1.0;
        double r = std::fmod(d, t);
        return r > 0 ? r / t : 1.0;
    };
    double tile_eff_m = ef(M, tm);
    double tile_eff_n = ef(N, tn);
    double tile_eff_k = ef(K, tk);
    double overall_tile_efficiency = tile_eff_m * tile_eff_n * tile_eff_k;

    // CU utilization (output tiles per compute unit)
    double cu_utilization = total_output_tiles / std::max((double)hw.num_cus, 1.0);

    // P0 FIX: Problem-to-tile ratio features (critical for small problems)
    double ratio_M_to_tile_m = M / std::max(tm, 1.0);
    double ratio_N_to_tile_n = N / std::max(tn, 1.0);
    double ratio_K_to_tile_k = K / std::max(tk, 1.0);

    // Binary features: is problem dimension smaller than tile?
    double problem_smaller_than_tile_m = (M < tm) ? 1.0 : 0.0;
    double problem_smaller_than_tile_n = (N < tn) ? 1.0 : 0.0;
    double problem_smaller_than_tile_k = (K < tk) ? 1.0 : 0.0;
    double any_dim_too_small = ((M < tm) || (N < tn) || (K < tk)) ? 1.0 : 0.0;

    // P1 FIX: Padding requirement features (dimension not a tile multiple)
    double needs_padding_m = (tm > 0 && std::fmod(M, tm) != 0.0) ? 1.0 : 0.0;
    double needs_padding_n = (tn > 0 && std::fmod(N, tn) != 0.0) ? 1.0 : 0.0;
    double needs_padding_k = (tk > 0 && std::fmod(K, tk) != 0.0) ? 1.0 : 0.0;

    // Interaction features: kernel has padding when problem needs it
    double has_padding_when_needed_m = (needs_padding_m && pad_m) ? 1.0 : 0.0;
    double has_padding_when_needed_n = (needs_padding_n && pad_n) ? 1.0 : 0.0;
    double has_padding_when_needed_k = (needs_padding_k && pad_k) ? 1.0 : 0.0;

    // Critical feature: missing required padding (kernel will likely fail)
    double missing_required_padding_m = (needs_padding_m && !pad_m) ? 1.0 : 0.0;
    double missing_required_padding_n = (needs_padding_n && !pad_n) ? 1.0 : 0.0;
    double missing_required_padding_k = (needs_padding_k && !pad_k) ? 1.0 : 0.0;
    double missing_any_required_padding =
        (missing_required_padding_m || missing_required_padding_n || missing_required_padding_k)
            ? 1.0
            : 0.0;

    // Hardware features (passed through verbatim from the profile)
    double hw_num_cus = (double)hw.num_cus;
    double hw_simds_per_cu = (double)hw.simds_per_cu;
    double hw_total_simds = (double)hw.total_simds();
    double hw_shader_engines = (double)hw.shader_engines;
    double hw_max_clock_mhz = (double)hw.max_clock_mhz;
    double hw_max_waves_per_cu = (double)hw.max_waves_per_cu;
    double hw_wavefront_size = (double)hw.wavefront_size;
    double hw_lds_capacity = (double)hw.lds_capacity;
    double hw_l1_cache_kb = (double)hw.l1_cache_kb;
    double hw_l2_cache_kb = (double)hw.l2_cache_kb;
    double hw_l3_cache_kb = (double)hw.l3_cache_kb;
    double hw_num_xcd = (double)hw.num_xcd;

    // Feature vector in EXACT order from feature_spec.json
    // This order MUST match Python feature_engine.py::get_feature_names()
    return {{
        M,                            // 0
        N,                            // 1
        K,                            // 2
        sk,                           // 3 (split_k)
        l2M,                          // 4 (log2_M)
        l2N,                          // 5 (log2_N)
        l2K,                          // 6 (log2_K)
        l2MNK,                        // 7 (log2_MNK)
        ai,                           // 8 (arithmetic_intensity)
        ar_mn,                        // 9 (aspect_ratio_mn)
        ar_mk,                        // 10 (aspect_ratio_mk)
        ar_nk,                        // 11 (aspect_ratio_nk)
        layout,                       // 12 (layout)
        tm,                           // 13 (tile_m)
        tn,                           // 14 (tile_n)
        tk,                           // 15 (tile_k)
        wm,                           // 16 (warp_m)
        wn,                           // 17 (warp_n)
        wk,                           // 18 (warp_k)
        wtm,                          // 19 (warp_tile_m)
        wtn,                          // 20 (warp_tile_n)
        wtk,                          // 21 (warp_tile_k)
        pipeline,                     // 22 (pipeline)
        scheduler,                    // 23 (scheduler)
        epilogue,                     // 24 (epilogue)
        pad_m,                        // 25 (pad_m)
        pad_n,                        // 26 (pad_n)
        pad_k,                        // 27 (pad_k)
        persistent,                   // 28 (persistent)
        num_warps,                    // 29 (num_warps)
        tile_volume,                  // 30 (tile_volume)
        tile_mn,                      // 31 (tile_mn)
        lest,                         // 32 (lds_usage_estimate)
        lds_ratio,                    // 33 (lds_usage_ratio)
        ntm,                          // 34 (num_tiles_m)
        ntn,                          // 35 (num_tiles_n)
        ntk,                          // 36 (num_tiles_k)
        total_output_tiles,           // 37 (total_output_tiles)
        tile_eff_m,                   // 38 (tile_eff_m)
        tile_eff_n,                   // 39 (tile_eff_n)
        tile_eff_k,                   // 40 (tile_eff_k)
        overall_tile_efficiency,      // 41 (overall_tile_efficiency)
        cu_utilization,               // 42 (cu_utilization)
        ratio_M_to_tile_m,            // 43 (ratio_M_to_tile_m)
        ratio_N_to_tile_n,            // 44 (ratio_N_to_tile_n)
        ratio_K_to_tile_k,            // 45 (ratio_K_to_tile_k)
        problem_smaller_than_tile_m,  // 46 (problem_smaller_than_tile_m)
        problem_smaller_than_tile_n,  // 47 (problem_smaller_than_tile_n)
        problem_smaller_than_tile_k,  // 48 (problem_smaller_than_tile_k)
        any_dim_too_small,            // 49 (any_dim_too_small)
        needs_padding_m,              // 50 (needs_padding_m)
        needs_padding_n,              // 51 (needs_padding_n)
        needs_padding_k,              // 52 (needs_padding_k)
        has_padding_when_needed_m,    // 53 (has_padding_when_needed_m)
        has_padding_when_needed_n,    // 54 (has_padding_when_needed_n)
        has_padding_when_needed_k,    // 55 (has_padding_when_needed_k)
        missing_required_padding_m,   // 56 (missing_required_padding_m)
        missing_required_padding_n,   // 57 (missing_required_padding_n)
        missing_required_padding_k,   // 58 (missing_required_padding_k)
        missing_any_required_padding, // 59 (missing_any_required_padding)
        hw_num_cus,                   // 60 (hw_num_cus)
        hw_simds_per_cu,              // 61 (hw_simds_per_cu)
        hw_total_simds,               // 62 (hw_total_simds)
        hw_shader_engines,            // 63 (hw_shader_engines)
        hw_max_clock_mhz,             // 64 (hw_max_clock_mhz)
        hw_max_waves_per_cu,          // 65 (hw_max_waves_per_cu)
        hw_wavefront_size,            // 66 (hw_wavefront_size)
        hw_lds_capacity,              // 67 (hw_lds_capacity)
        hw_l1_cache_kb,               // 68 (hw_l1_cache_kb)
        hw_l2_cache_kb,               // 69 (hw_l2_cache_kb)
        hw_l3_cache_kb,               // 70 (hw_l3_cache_kb)
        hw_num_xcd,                   // 71 (hw_num_xcd)
    }};
}
|
||||
|
||||
class MLHeuristic
|
||||
{
|
||||
public:
|
||||
MLHeuristic(const std::string& path,
|
||||
const Registry* reg,
|
||||
HardwareProfile hw = {},
|
||||
bool log_t = false)
|
||||
: registry_(reg), hw_(hw), log_t_(log_t)
|
||||
{
|
||||
int iters = 0;
|
||||
if(LGBM_BoosterCreateFromModelfile(path.c_str(), &iters, &b_) != 0 || !b_)
|
||||
{
|
||||
std::cerr << "MLHeuristic: Failed to load " << path << std::endl;
|
||||
|
||||
// Check if a compressed .gz version exists
|
||||
std::string gz_path = path + ".gz";
|
||||
std::ifstream gz_check(gz_path);
|
||||
if(gz_check.good())
|
||||
{
|
||||
std::cerr << "MLHeuristic: Found compressed model at " << gz_path << std::endl;
|
||||
std::cerr << "MLHeuristic: Please decompress it first:" << std::endl;
|
||||
std::cerr << " gunzip " << gz_path << std::endl;
|
||||
}
|
||||
|
||||
b_ = nullptr;
|
||||
}
|
||||
else
|
||||
std::cout << "MLHeuristic: Loaded (" << iters << " iters)" << std::endl;
|
||||
}
|
||||
~MLHeuristic()
|
||||
{
|
||||
if(b_)
|
||||
LGBM_BoosterFree(b_);
|
||||
}
|
||||
MLHeuristic(const MLHeuristic&) = delete;
|
||||
MLHeuristic& operator=(const MLHeuristic&) = delete;
|
||||
bool is_loaded() const { return b_ != nullptr; }
|
||||
double predict_tflops(const Problem& prob, const KernelKey& key) const
|
||||
{
|
||||
if(!b_)
|
||||
return 0;
|
||||
auto f = extract_features(prob, key, hw_);
|
||||
int64_t ol = 0;
|
||||
double pred = 0;
|
||||
if(LGBM_BoosterPredictForMat(
|
||||
b_, f.data(), 0, 1, NUM_FEATURES, 1, 0, 0, 0, "", &ol, &pred) != 0)
|
||||
return 0;
|
||||
return log_t_ ? std::expm1(pred) : pred;
|
||||
}
|
||||
std::vector<std::string> operator()(const Problem& prob) const
|
||||
{
|
||||
if(!b_ || !registry_)
|
||||
return {};
|
||||
auto insts = registry_->get_all();
|
||||
struct C
|
||||
{
|
||||
std::string id;
|
||||
double t;
|
||||
};
|
||||
std::vector<C> cs;
|
||||
cs.reserve(insts.size());
|
||||
for(auto& i : insts)
|
||||
{
|
||||
auto& k = i->get_key();
|
||||
cs.push_back({k.encode_identifier(), predict_tflops(prob, k)});
|
||||
}
|
||||
std::sort(cs.begin(), cs.end(), [](auto& a, auto& b) { return a.t > b.t; });
|
||||
std::vector<std::string> r;
|
||||
r.reserve(cs.size());
|
||||
for(auto& c : cs)
|
||||
r.push_back(std::move(c.id));
|
||||
return r;
|
||||
}
|
||||
|
||||
private:
|
||||
void* b_ = nullptr;
|
||||
const Registry* registry_ = nullptr;
|
||||
HardwareProfile hw_;
|
||||
bool log_t_ = false;
|
||||
};
|
||||
} // namespace dispatcher
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +1,16 @@
|
||||
# Core dependencies
|
||||
numpy>=1.19.0
|
||||
|
||||
# ML Heuristic dependencies (OPTIONAL - large dependencies)
|
||||
# For ML-based kernel selection, install separately:
|
||||
# pip install -r ../requirements-ml.txt
|
||||
# This avoids mandatory large dependencies (pyarrow, lightgbm) for users who don't need ML features
|
||||
|
||||
# Optional dependencies (install with pip install -e ".[torch]")
|
||||
# torch>=2.0.0
|
||||
|
||||
# Development dependencies (install with pip install -e ".[dev]")
|
||||
# pytest>=6.0.0
|
||||
pytest>=6.0.0
|
||||
# pytest-cov>=2.0.0
|
||||
# black>=21.0
|
||||
# flake8>=3.9.0
|
||||
|
||||
6
dispatcher/requirements-ml.txt
Normal file
6
dispatcher/requirements-ml.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
# ML Heuristic dependencies for ML-based kernel selection
|
||||
# Install with: pip install -r requirements-ml.txt
|
||||
lightgbm>=3.0.0
|
||||
pandas>=1.3.0
|
||||
pyarrow>=6.0.0
|
||||
scikit-learn>=0.24.0
|
||||
@@ -538,6 +538,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
/* BlockSize = M0 = */ {F_bm0},
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
|
||||
@@ -1209,7 +1209,8 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 8192")),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
} # fmt: skip
|
||||
@@ -1244,7 +1245,7 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
@@ -1303,7 +1304,23 @@ class Product:
|
||||
|
||||
def get_product(receipt: int) -> Product:
|
||||
# Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
if receipt == 2:
|
||||
|
||||
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
cond = problem_ctx.dtype in ["fp16", "bf16"]
|
||||
cond &= kernel_ctx.pipeline.F_vlayout == "row"
|
||||
cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= kernel_ctx.pipeline.F_qscale == "no"
|
||||
cond &= kernel_ctx.pipeline.F_skip == "f"
|
||||
cond &= kernel_ctx.pipeline.F_sink == "f"
|
||||
# FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled.
|
||||
cond &= kernel_ctx.pipeline.F_logits == "f"
|
||||
cond &= kernel_ctx.pipeline.F_lse == "t"
|
||||
return cond
|
||||
|
||||
return Product(name="Flash attention integration", rule=fit)
|
||||
# Receipt 3 forward coverage used by CK library / smoke tests
|
||||
elif receipt == 3:
|
||||
|
||||
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
cond = problem_ctx.dtype in ["fp16", "bf16"]
|
||||
|
||||
@@ -939,6 +939,8 @@ def get_fwd_splitkv_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
# FlashAttention splitkv paths use softcap-disabled kernels only.
|
||||
cond &= pipeline.F_logits == "f"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
@@ -1142,4 +1144,7 @@ def list_blobs(
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + "\n")
|
||||
f.write(
|
||||
(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix()
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
@@ -87,6 +87,7 @@ auto create_args(int argc, char* argv[])
|
||||
"0",
|
||||
"if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
|
||||
"will not be used")
|
||||
.insert("sink_grad", "0", "if set to 1, compute and validate sink token gradient")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
|
||||
|
||||
@@ -122,6 +123,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
bool deterministic = arg_parser.get_bool("deterministic");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
bool sink_grad = arg_parser.get_bool("sink_grad");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -154,6 +156,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
sink_grad,
|
||||
deterministic,
|
||||
init_method,
|
||||
seed,
|
||||
|
||||
@@ -117,6 +117,9 @@ struct fmha_bwd_args
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* workspace_ptr;
|
||||
const void*
|
||||
sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink
|
||||
void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient
|
||||
|
||||
// Usage notes for sequence length pointer parameters:
|
||||
//
|
||||
@@ -353,11 +356,15 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.lse_ptr,
|
||||
args.sink_ptr,
|
||||
args.d_sink_ptr,
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
@@ -369,9 +376,13 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.lse_ptr,
|
||||
args.sink_ptr,
|
||||
args.d_sink_ptr,
|
||||
args.p_undrop,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
|
||||
@@ -77,6 +77,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
uint64_t drop_offset,
|
||||
bool drop_prefs,
|
||||
std::string mask_str,
|
||||
bool sink_grad, // if true, compute and validate sink gradient
|
||||
bool deterministic,
|
||||
std::string init_method,
|
||||
uint32_t seed,
|
||||
@@ -285,6 +286,16 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<LSEDataType> sink_host(
|
||||
sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
|
||||
: std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
|
||||
if(sink_grad)
|
||||
{
|
||||
std::uniform_real_distribution<float> sink_dist(30.0f, 60.0f);
|
||||
sink_host.ForEach([&](auto& self, auto i) {
|
||||
self(i) = static_cast<LSEDataType>(sink_dist(random_engine));
|
||||
});
|
||||
}
|
||||
ck_tile::HostTensor<DDataType> d_host(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host(
|
||||
@@ -302,6 +313,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
use_dbias
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<LSEDataType> d_sink_host(sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
|
||||
: std::array<ck_tile::index_t, 1>{0});
|
||||
if(sink_grad)
|
||||
{
|
||||
d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; });
|
||||
}
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
@@ -360,11 +377,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sink_buf(sink_grad ? sink_host.get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_sink_buf(sink_grad ? d_sink_host.get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
@@ -396,6 +415,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
if(sink_grad)
|
||||
{
|
||||
sink_buf.ToDevice(sink_host.data());
|
||||
d_sink_buf.ToDevice(d_sink_host.data());
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
auto layout_str = [&](bool permute){
|
||||
@@ -415,7 +439,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
|
||||
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
|
||||
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
|
||||
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
|
||||
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
|
||||
<< ", deterministic:" << deterministic
|
||||
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
|
||||
<< ", mask:" << mask << std::flush;
|
||||
|
||||
@@ -474,7 +499,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
|
||||
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
|
||||
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
@@ -490,6 +514,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dv_buf.GetDeviceBuffer(),
|
||||
dbias_buf.GetDeviceBuffer(),
|
||||
ws_ptr,
|
||||
sink_buf.GetDeviceBuffer(),
|
||||
d_sink_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
seqlen_q_ptr_dev,
|
||||
@@ -580,6 +606,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
|
||||
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
|
||||
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
|
||||
std::vector<ck_tile::HostTensor<AccDataType>> p_sink_host_refs;
|
||||
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
@@ -756,6 +783,46 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
|
||||
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
|
||||
|
||||
// Incorporate sink token into the softmax distribution (reference computation).
|
||||
// The sink acts as an extra key whose score is sink_host(wb, i_h) (in log-space),
|
||||
// which is a per-head random value in [30, 60].
|
||||
// lse_new = log(exp(lse_old) + exp(sink))
|
||||
// P_new = P_old * exp(lse_old - lse_new) (rescaled token attention)
|
||||
// P_sink = exp(sink - lse_new) (sink attention weight)
|
||||
ck_tile::HostTensor<AccDataType> p_sink_host_ref(
|
||||
sink_grad ? std::array<ck_tile::index_t, 2>{nhead, real_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 2>{0, 0});
|
||||
if(sink_grad)
|
||||
{
|
||||
for(int i_h = 0; i_h < nhead; ++i_h)
|
||||
{
|
||||
AccDataType sink_val = sink_host(wb, i_h);
|
||||
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
|
||||
{
|
||||
// Use numerically stable log-sum-exp: lse_new = log(exp(lse_old)+exp(sink))
|
||||
// = max(lse_old, sink) + log(1 + exp(min - max))
|
||||
// This handles lse_old = -inf (fully-masked rows) without producing NaN:
|
||||
// if lse_old=-inf: max=sink, min=-inf, exp(-inf-sink)=0, lse_new=sink
|
||||
// It also avoids exp(lse_old) overflow when lse_old is large.
|
||||
// p_scale = exp(lse_old - lse_new) [fraction kept by regular tokens]
|
||||
// p_sink = exp(sink - lse_new) [sink attention weight]
|
||||
AccDataType lse_old = lse_host_ref(i_h, i_q);
|
||||
AccDataType hi = lse_old > sink_val ? lse_old : sink_val;
|
||||
AccDataType lo = lse_old > sink_val ? sink_val : lse_old;
|
||||
AccDataType lse_new =
|
||||
hi + ck_tile::log(AccDataType(1) + ck_tile::exp(lo - hi));
|
||||
AccDataType p_scale = ck_tile::exp(lse_old - lse_new);
|
||||
|
||||
lse_host_ref(i_h, i_q) = lse_new;
|
||||
|
||||
for(int i_k = 0; i_k < real_seqlen_k; ++i_k)
|
||||
p_hp_host_ref(i_h, i_q, i_k) *= p_scale;
|
||||
|
||||
p_sink_host_ref(i_h, i_q) = ck_tile::exp(sink_val - lse_new);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
@@ -814,6 +881,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
o_host_refs.push_back(o_host_ref);
|
||||
p_hp_host_refs.push_back(p_hp_host_ref);
|
||||
p_lp_host_refs.push_back(p_lp_host_ref);
|
||||
p_sink_host_refs.push_back(p_sink_host_ref);
|
||||
if(p_drop > 0)
|
||||
{
|
||||
randval_host_refs.push_back(randval_host_ref);
|
||||
@@ -833,6 +901,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dbias_buf.SetZero();
|
||||
if(sink_grad)
|
||||
d_sink_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
|
||||
launcher(fmha_args, stream_config_v);
|
||||
@@ -841,10 +911,19 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dk_buf.FromDevice(dk_host.data());
|
||||
dv_buf.FromDevice(dv_host.data());
|
||||
dbias_buf.FromDevice(dbias_host.data());
|
||||
if(sink_grad)
|
||||
d_sink_buf.FromDevice(d_sink_host.data());
|
||||
|
||||
// Track the index into reference vectors (may differ from wb if batches were skipped)
|
||||
ck_tile::index_t ref_idx = 0;
|
||||
|
||||
// validation sink accumulator: global over batch, shape [nhead]
|
||||
ck_tile::HostTensor<AccDataType> d_sink_host_ref(
|
||||
sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
|
||||
: std::array<ck_tile::index_t, 1>{0});
|
||||
if(sink_grad)
|
||||
d_sink_host_ref.ForEach([&](auto& self, auto i) { self(i) = 0; });
|
||||
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
// When padding is enabled, use logical lengths instead of computing from padded
|
||||
@@ -920,6 +999,30 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
|
||||
if(sink_grad)
|
||||
{
|
||||
// Reference: dSink[h] = -sum_q( P_sink[h,q] * D[h,q] )
|
||||
// where D[h,q] = sum_j(dO[h,q,j] * O[h,q,j]) * p_undrop
|
||||
for(int i_h = 0; i_h < nhead; ++i_h)
|
||||
{
|
||||
AccDataType d_sink_head_acc = 0;
|
||||
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
|
||||
{
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o +=
|
||||
ck_tile::type_convert<AccDataType>(do_host_ref(i_h, i_q, o)) *
|
||||
ck_tile::type_convert<AccDataType>(
|
||||
o_host_refs[ref_idx](i_h, i_q, o)) *
|
||||
p_undrop;
|
||||
}
|
||||
d_sink_head_acc += -p_sink_host_refs[ref_idx](i_h, i_q) * do_dot_o;
|
||||
}
|
||||
d_sink_host_ref(i_h) += d_sink_head_acc;
|
||||
}
|
||||
}
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
@@ -1032,6 +1135,17 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ref_idx++;
|
||||
}
|
||||
|
||||
if(pass && sink_grad)
|
||||
{
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
|
||||
bool dsink_pass = ck_tile::check_err(d_sink_host,
|
||||
d_sink_host_ref,
|
||||
std::string("Error: SinkGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
pass &= dsink_pass;
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,6 @@ function print_log_header(){
|
||||
#run verification tests
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
|
||||
|
||||
#run performance benchmarks
|
||||
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
|
||||
|
||||
@@ -69,6 +69,28 @@ test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
|
||||
# sink gradient tests: same coverage as main tests but with -sink_grad=1
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
for mode in 0 1 ; do
|
||||
for bias in "n" "a" ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=0 -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=0 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -sink_grad=1
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# sink gradient additional cases: non-standard hdim
|
||||
for hdim in 40 48 72 96 ; do
|
||||
test_h_s_mask -prec=fp16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS -sink_grad=1
|
||||
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
|
||||
done
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
|
||||
@@ -235,6 +235,64 @@ run_padding_basic_boundary_tests() {
|
||||
done
|
||||
}
|
||||
|
||||
# Sink-specific mask pattern tests (sliding window + sink token).
|
||||
run_sink_mask_tests() {
|
||||
# window_size[2,0], sink_size=2 (top-left causal + sink)
|
||||
# before: after:
|
||||
# 1 * * * * * * * 1 * * * * * * *
|
||||
# 1 1 * * * * * * 1 1 * * * * * *
|
||||
# 1 1 1 * * * * * 1 1 1 * * * * *
|
||||
# * 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
|
||||
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
|
||||
# * * * * * 1 1 1 1 1 * * * 1 1 1
|
||||
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
|
||||
run_exe -prec=bf16 -mode=0 -b=2 -h=2 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
|
||||
|
||||
# window_size[0,3], sink_size=2 (top-left + sink)
|
||||
# before: after:
|
||||
# 1 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * 1 1 1 1 * * 1 1 1 1 1 1 * *
|
||||
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
|
||||
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
|
||||
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
|
||||
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
|
||||
|
||||
# window_size[1,0], sink_size=2 (bottom-right + sink)
|
||||
# before: after:
|
||||
# * * 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * * 1 1 * * * 1 1 * 1 1 * * *
|
||||
# * * * * 1 1 * * 1 1 * * 1 1 * *
|
||||
# * * * * * 1 1 * 1 1 * * * 1 1 *
|
||||
# * * * * * * 1 1 1 1 * * * * 1 1
|
||||
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
|
||||
run_exe -prec=bf16 -mode=0 -b=2 -h=4 -d=128 -d_v=128 -s=2048 -s_k=2048 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
|
||||
|
||||
# window_size[2,0], sink_size=2 (bottom-right, group mode + sink)
|
||||
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
|
||||
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
|
||||
|
||||
# window_size[-1,1], sink_size=2 (bottom-right, large seqlen + sink)
|
||||
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
|
||||
run_exe -prec=bf16 -mode=1 -b=1 -h=2 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
|
||||
}
|
||||
|
||||
# init_sink tests: validate sink token initialization across prec/hdim/mode.
|
||||
run_sink_init_tests() {
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
for mode in 0 1 ; do
|
||||
for mask in 0 1 ; do
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -d=$hdim -d_v=$hdim -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=4 -d=$hdim -d_v=$hdim -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_bf16_tests
|
||||
@@ -242,6 +300,8 @@ run_padding_smoke_tests
|
||||
run_padding_basic_boundary_tests
|
||||
run_fp8bf16_tests
|
||||
run_fp8fp32_tests
|
||||
run_sink_mask_tests
|
||||
run_sink_init_tests
|
||||
|
||||
if [ $TEST_APPENDKV -eq 1 ] ; then
|
||||
run_fp16_appendkv_tests
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_fwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=$GPU_arch
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
set -x
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
|
||||
|
||||
# window_size[2,0], sink_size = 2
|
||||
|
||||
# x=1/y=3
|
||||
# 1 * * * * * * * 1 * * * * * * *
|
||||
# 1 1 * * * * * * 1 1 * * * * * *
|
||||
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
|
||||
# * 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
|
||||
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
|
||||
# * * * * * 1 1 1 1 1 * * * 1 1 1
|
||||
# l=2/r=0(tl) l=2/r=0/s=2(tl)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
|
||||
|
||||
# x=4/y=1
|
||||
# 1 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
|
||||
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
|
||||
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
|
||||
# l=0/r=3(tl) l=0/r=3/s=2(tl)
|
||||
# l=3/r=0(br) l=3/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
|
||||
|
||||
# x=4/y=-1
|
||||
# * * 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * * 1 1 * * * 1 1 * 1 1 * * *
|
||||
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
|
||||
# * * * * * 1 1 * 1 1 * * * 1 1 *
|
||||
# * * * * * * 1 1 1 1 * * * * 1 1
|
||||
# l=1/r=0(br) l=1/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
|
||||
|
||||
# x=-1/y=5
|
||||
|
||||
# * * * * * * * * * * * *
|
||||
# * * * * * * * * * * * *
|
||||
# 1 * * * * * 1 * * * * *
|
||||
# 1 1 * * * * 1 1 * * * *
|
||||
# 1 1 1 * * * ----> 1 1 1 * * *
|
||||
# * 1 1 1 * * 1 1 1 1 * *
|
||||
# * * 1 1 1 * 1 1 1 1 1 *
|
||||
# * * * 1 1 1 1 1 * 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
|
||||
# x=-1/y=8
|
||||
# * * * * * * * * * *
|
||||
# * * * * * * * * * *
|
||||
# 1 * * * * ----> 1 * * * *
|
||||
# 1 1 * * * 1 1 * * *
|
||||
# 1 1 1 * * 1 1 1 * *
|
||||
# 1 1 1 1 * 1 1 1 1 *
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
@@ -63,8 +63,8 @@
|
||||
#define __gfx101__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
|
||||
defined(__gfx10_3_generic__)
|
||||
defined(__gfx1033__) || defined(__gfx1034__) || defined(__gfx1035__) || \
|
||||
defined(__gfx1036__) || defined(__gfx10_3_generic__)
|
||||
#define __gfx103__
|
||||
#endif
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
|
||||
@@ -125,8 +125,9 @@ inline bool is_gfx101_supported()
|
||||
inline bool is_gfx103_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
|
||||
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
|
||||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
|
||||
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1033" ||
|
||||
ck::get_device_name() == "gfx1034" || ck::get_device_name() == "gfx1035" ||
|
||||
ck::get_device_name() == "gfx1036";
|
||||
}
|
||||
|
||||
inline bool is_wmma_supported()
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -492,6 +493,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
c_grid_desc_m_n_container_.push_back(descs[I2]);
|
||||
}
|
||||
}
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
input_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()) *
|
||||
Conv_N_ * Conv_C_ * sizeof(CDataType);
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
@@ -512,6 +517,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
std::vector<ck::index_t> conv_filter_dilations_;
|
||||
std::vector<ck::index_t> input_left_pads_;
|
||||
std::vector<ck::index_t> input_right_pads_;
|
||||
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -571,18 +578,47 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
DeviceOp::BGridDesc_K0_N_K1,
|
||||
DeviceOp::CGridDesc_M_N,
|
||||
true>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
// Clear input only for perf measurement.
|
||||
// For non-grouped solver user has to clear input on his own.
|
||||
const auto clear_input = [&]() {
|
||||
if(i == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_space_size_bytes,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
ave_time += launch_and_time_kernel_with_preprocess_flush_cache(
|
||||
stream_config,
|
||||
clear_input,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -594,18 +630,47 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
DeviceOp::BGridDesc_K0_N_K1,
|
||||
DeviceOp::CGridDesc_M_N,
|
||||
false>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
// Clear input only for perf measurement.
|
||||
// For non-grouped solver user has to clear input on his own.
|
||||
const auto clear_input = [&]() {
|
||||
if(i == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_space_size_bytes,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
ave_time += launch_and_time_kernel_with_preprocess_flush_cache(
|
||||
stream_config,
|
||||
clear_input,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ave_time;
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1050,6 +1051,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
CreateABCDesc<NDimSpatial>();
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
input_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()) *
|
||||
Conv_N_ * Conv_C_ * sizeof(CDataType);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -1216,6 +1221,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
|
||||
std::vector<ck::index_t> conv_filter_dilations_;
|
||||
std::vector<ck::index_t> input_left_pads_;
|
||||
std::vector<ck::index_t> input_right_pads_;
|
||||
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -1273,18 +1280,47 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
|
||||
DeviceOp::BGridDesc_K0_N_K1,
|
||||
DeviceOp::CGridDesc_M_N,
|
||||
true>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
// Clear input only for perf measurement.
|
||||
// For non-grouped solver user has to clear input on his own.
|
||||
const auto clear_input = [&]() {
|
||||
if(i == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_space_size_bytes,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
ave_time += launch_and_time_kernel_with_preprocess_flush_cache(
|
||||
stream_config,
|
||||
clear_input,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1296,18 +1332,47 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
|
||||
DeviceOp::BGridDesc_K0_N_K1,
|
||||
DeviceOp::CGridDesc_M_N,
|
||||
false>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
// Clear input only for perf measurement.
|
||||
// For non-grouped solver user has to clear input on his own.
|
||||
const auto clear_input = [&]() {
|
||||
if(i == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_space_size_bytes,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
ave_time += launch_and_time_kernel_with_preprocess_flush_cache(
|
||||
stream_config,
|
||||
clear_input,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ave_time;
|
||||
|
||||
@@ -1225,26 +1225,50 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
has_main_loop,
|
||||
no_main_loop,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
return launch_and_time_kernel_with_preprocess_flush_cache(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1264,26 +1288,50 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
has_main_loop,
|
||||
no_main_loop,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
return launch_and_time_kernel_with_preprocess_flush_cache(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
}
|
||||
}
|
||||
};
|
||||
if(has_loop_in_all_gemm)
|
||||
|
||||
@@ -88,6 +88,7 @@ enum struct amdgcn_target_id
|
||||
GFX1030 = 0x1030,
|
||||
GFX1031 = 0x1031,
|
||||
GFX1032 = 0x1032,
|
||||
GFX1033 = 0x1033,
|
||||
GFX1034 = 0x1034,
|
||||
GFX1035 = 0x1035,
|
||||
GFX1036 = 0x1036,
|
||||
@@ -284,6 +285,7 @@ constexpr auto get_compiler_target()
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1033, GFX1033);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
|
||||
@@ -351,6 +353,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const*
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1030", GFX1030);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1031", GFX1031);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1032", GFX1032);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1033", GFX1033);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1034", GFX1034);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1035", GFX1035);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1036", GFX1036);
|
||||
@@ -607,6 +610,7 @@ CK_TILE_HOST_DEVICE constexpr auto get_compiler_target()
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1030, GFX1030);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1031, GFX1031);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1032, GFX1032);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1033, GFX1033);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1034, GFX1034);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1035, GFX1035);
|
||||
MAP_COMPILER_STATE_TO_GFX10_3_TARGET(CK_TILE_ARCH_GFX1036, GFX1036);
|
||||
@@ -688,6 +692,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* tes
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1030", GFX1030);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1031", GFX1031);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1032", GFX1032);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1033", GFX1033);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1034", GFX1034);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1035", GFX1035);
|
||||
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX10_3_TARGET("gfx1036", GFX1036);
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
#define __gfx101__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
|
||||
defined(__gfx10_3_generic__)
|
||||
defined(__gfx1033__) || defined(__gfx1034__) || defined(__gfx1035__) || \
|
||||
defined(__gfx1036__) || defined(__gfx10_3_generic__)
|
||||
#define __gfx103__
|
||||
#endif
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
@@ -405,6 +405,12 @@ struct amdgcn_compiler_target_state
|
||||
static constexpr bool CK_TILE_ARCH_GFX1032 = false;
|
||||
#endif // __gfx1032__
|
||||
|
||||
#if defined(__gfx1033__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1033 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1033 = false;
|
||||
#endif // __gfx1033__
|
||||
|
||||
#if defined(__gfx1034__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1034 = true;
|
||||
#else
|
||||
@@ -537,6 +543,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1033, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1034, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1035, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1036, \
|
||||
|
||||
@@ -39,7 +39,7 @@ CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<T
|
||||
{
|
||||
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return make_tuple(old_tuple[number<IRs>{}]...);
|
||||
}
|
||||
@@ -109,7 +109,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is..
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
|
||||
}
|
||||
@@ -120,7 +120,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is..
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
|
||||
|
||||
|
||||
@@ -144,9 +144,11 @@ struct sequence
|
||||
static_assert(MapOld2New::size() == size(),
|
||||
"wrong! reorder map should have the same size as sequence to be rerodered");
|
||||
|
||||
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<remove_cvref_t<MapOld2New>>::value,
|
||||
"wrong! invalid reorder map");
|
||||
|
||||
return reorder_new_to_old(typename sequence_map_inverse<MapOld2New>::type{});
|
||||
return reorder_new_to_old(
|
||||
typename sequence_map_inverse<remove_cvref_t<MapOld2New>>::type{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto reverse()
|
||||
@@ -548,163 +550,59 @@ struct sequence_reduce<Reduce, Seq>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename Values, typename Ids, typename Compare>
|
||||
struct sequence_sort_impl
|
||||
// Sorts a sequence using constexpr insertion sort. O(1) template instantiation
|
||||
// depth, replacing the recursive merge sort that created O(N log N) intermediate types.
|
||||
namespace detail {
|
||||
|
||||
template <typename Values, typename Compare, typename IndexSeq>
|
||||
struct sequence_sort_helper;
|
||||
|
||||
template <index_t... Vs, typename Compare, index_t... Idx>
|
||||
struct sequence_sort_helper<sequence<Vs...>, Compare, sequence<Idx...>>
|
||||
{
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl
|
||||
struct sort_result
|
||||
{
|
||||
static constexpr bool choose_left = LeftValues::front() < RightValues::front();
|
||||
|
||||
static constexpr index_t chosen_value =
|
||||
choose_left ? LeftValues::front() : RightValues::front();
|
||||
static constexpr index_t chosen_id = choose_left ? LeftIds::front() : RightIds::front();
|
||||
|
||||
using new_merged_values = decltype(MergedValues::push_back(number<chosen_value>{}));
|
||||
using new_merged_ids = decltype(MergedIds::push_back(number<chosen_id>{}));
|
||||
|
||||
using new_left_values = typename std::
|
||||
conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
|
||||
using new_left_ids =
|
||||
typename std::conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;
|
||||
|
||||
using new_right_values = typename std::
|
||||
conditional<choose_left, RightValues, decltype(RightValues::pop_front())>::type;
|
||||
using new_right_ids =
|
||||
typename std::conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;
|
||||
|
||||
using merge = sorted_sequence_merge_impl<new_left_values,
|
||||
new_left_ids,
|
||||
new_right_values,
|
||||
new_right_ids,
|
||||
new_merged_values,
|
||||
new_merged_ids,
|
||||
Comp>;
|
||||
// this is output
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
static_array<index_t, sizeof...(Vs)> values;
|
||||
static_array<index_t, sizeof...(Vs)> ids;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
sequence<>,
|
||||
sequence<>,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
static constexpr sort_result compute()
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
|
||||
};
|
||||
constexpr index_t n = sizeof...(Vs);
|
||||
sort_result r{{{Vs...}}, {{Idx...}}};
|
||||
// insertion sort — O(N^2) constexpr steps, O(1) template depth
|
||||
for(index_t i = 1; i < n; ++i)
|
||||
{
|
||||
for(index_t j = i; j > 0 && Compare{}(r.values[j], r.values[j - 1]); --j)
|
||||
{
|
||||
auto tv = r.values[j];
|
||||
r.values[j] = r.values[j - 1];
|
||||
r.values[j - 1] = tv;
|
||||
auto ti = r.ids[j];
|
||||
r.ids[j] = r.ids[j - 1];
|
||||
r.ids[j - 1] = ti;
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<sequence<>,
|
||||
sequence<>,
|
||||
RightValues,
|
||||
RightIds,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge
|
||||
{
|
||||
using merge = sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
RightValues,
|
||||
RightIds,
|
||||
sequence<>,
|
||||
sequence<>,
|
||||
Comp>;
|
||||
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
static constexpr index_t nsize = Values::size();
|
||||
|
||||
using split_unsorted_values = sequence_split<Values, nsize / 2>;
|
||||
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
|
||||
|
||||
using left_unsorted_values = typename split_unsorted_values::left_type;
|
||||
using left_unsorted_ids = typename split_unsorted_ids::left_type;
|
||||
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
|
||||
using left_sorted_values = typename left_sort::sorted_values;
|
||||
using left_sorted_ids = typename left_sort::sorted_ids;
|
||||
|
||||
using right_unsorted_values = typename split_unsorted_values::right_type;
|
||||
using right_unsorted_ids = typename split_unsorted_ids::right_type;
|
||||
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
|
||||
using right_sorted_values = typename right_sort::sorted_values;
|
||||
using right_sorted_ids = typename right_sort::sorted_ids;
|
||||
|
||||
using merged_sorted = sorted_sequence_merge<left_sorted_values,
|
||||
left_sorted_ids,
|
||||
right_sorted_values,
|
||||
right_sorted_ids,
|
||||
Compare>;
|
||||
|
||||
using sorted_values = typename merged_sorted::merged_values;
|
||||
using sorted_ids = typename merged_sorted::merged_ids;
|
||||
static constexpr sort_result sorted = compute();
|
||||
using sorted_values = sequence<sorted.values[Idx]...>;
|
||||
using sorted_ids = sequence<sorted.ids[Idx]...>;
|
||||
};
|
||||
|
||||
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
|
||||
struct sequence_sort_impl<sequence<ValueX, ValueY>, sequence<IdX, IdY>, Compare>
|
||||
{
|
||||
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
|
||||
|
||||
using sorted_values = typename std::
|
||||
conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
|
||||
using sorted_ids =
|
||||
typename std::conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
|
||||
};
|
||||
|
||||
template <index_t Value, index_t Id, typename Compare>
|
||||
struct sequence_sort_impl<sequence<Value>, sequence<Id>, Compare>
|
||||
{
|
||||
using sorted_values = sequence<Value>;
|
||||
using sorted_ids = sequence<Id>;
|
||||
};
|
||||
|
||||
template <typename Compare>
|
||||
struct sequence_sort_impl<sequence<>, sequence<>, Compare>
|
||||
{
|
||||
using sorted_values = sequence<>;
|
||||
using sorted_ids = sequence<>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename Values, typename Compare>
|
||||
struct sequence_sort
|
||||
{
|
||||
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type;
|
||||
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
|
||||
static constexpr index_t n = Values::size();
|
||||
using idx_seq = make_index_sequence<n>;
|
||||
|
||||
// this is output
|
||||
using type = typename sort::sorted_values;
|
||||
using sorted2unsorted_map = typename sort::sorted_ids;
|
||||
using helper = detail::sequence_sort_helper<remove_cvref_t<Values>, Compare, idx_seq>;
|
||||
|
||||
using type = typename helper::sorted_values;
|
||||
using sorted2unsorted_map = typename helper::sorted_ids;
|
||||
};
|
||||
|
||||
template <typename Values, typename Less, typename Equal>
|
||||
@@ -782,10 +680,42 @@ struct sequence_unique_sort
|
||||
using sorted2unsorted_map = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
// Validates that a sequence is a permutation of {0, 1, ..., N-1}.
|
||||
// Uses a constexpr loop instead of instantiating sequence_sort.
|
||||
namespace detail {
|
||||
|
||||
template <index_t... Is>
|
||||
constexpr bool check_valid_sequence_map()
|
||||
{
|
||||
constexpr index_t n = sizeof...(Is);
|
||||
if constexpr(n == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t vals[] = {Is...};
|
||||
static_array<bool, n> seen{};
|
||||
for(index_t i = 0; i < n; ++i)
|
||||
{
|
||||
if(vals[i] < 0 || vals[i] >= n || seen[vals[i]])
|
||||
return false;
|
||||
seen[vals[i]] = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename SeqMap>
|
||||
struct is_valid_sequence_map
|
||||
: std::is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
|
||||
typename sequence_sort<SeqMap, less<index_t>>::type>
|
||||
struct is_valid_sequence_map : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
struct is_valid_sequence_map<sequence<Is...>>
|
||||
: std::integral_constant<bool, detail::check_valid_sequence_map<Is...>()>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
@@ -376,9 +376,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transf
|
||||
constexpr auto all_up_dim_new_top_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
|
||||
"wrong!");
|
||||
static_assert(
|
||||
is_valid_sequence_map<remove_cvref_t<decltype(all_low_dim_old_top_ids)>>::value &&
|
||||
is_valid_sequence_map<remove_cvref_t<decltype(all_up_dim_new_top_ids)>>::value,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
|
||||
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();
|
||||
@@ -443,8 +444,8 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
|
||||
static_assert(is_valid_sequence_map<remove_cvref_t<decltype(all_old_top_ids)>>::value &&
|
||||
is_valid_sequence_map<remove_cvref_t<decltype(all_new_top_ids)>>::value,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
|
||||
@@ -135,65 +135,147 @@ struct idx_identity
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
struct static_ford_impl
|
||||
// Computes the inverse of a permutation as a constexpr array.
|
||||
// Avoids the sequence_map_inverse -> is_valid_sequence_map -> sequence_sort chain.
|
||||
template <class Perm>
|
||||
struct inverse_perm;
|
||||
|
||||
template <index_t... Ps>
|
||||
struct inverse_perm<sequence<Ps...>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
|
||||
static constexpr auto compute()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
constexpr index_t n = sizeof...(Ps);
|
||||
static_array<index_t, n> result{};
|
||||
constexpr index_t input[] = {Ps...};
|
||||
for(index_t i = 0; i < n; ++i)
|
||||
{
|
||||
result[input[i]] = i;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
static constexpr auto value = compute();
|
||||
};
|
||||
|
||||
// Decomposes a linear index into multi-dimensional indices using pre-computed
|
||||
// strides. Uses a single flat static_for instead of recursive nesting, which
|
||||
// eliminates intermediate lambda closure instantiations.
|
||||
template <class OrderedLengths, class IndexSeq>
|
||||
struct index_decomposer;
|
||||
|
||||
template <index_t... Ls, index_t... Is>
|
||||
struct index_decomposer<sequence<Ls...>, sequence<Is...>>
|
||||
{
|
||||
static constexpr index_t n_dim = sizeof...(Ls);
|
||||
static constexpr static_array<index_t, n_dim> lengths = {{Ls...}};
|
||||
|
||||
static constexpr static_array<index_t, n_dim> compute_all_strides()
|
||||
{
|
||||
static_array<index_t, n_dim> result{};
|
||||
if constexpr(n_dim > 0)
|
||||
{
|
||||
result[n_dim - 1] = 1;
|
||||
for(index_t i = n_dim - 1; i > 0; --i)
|
||||
{
|
||||
result[i - 1] = result[i] * lengths[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...>)
|
||||
// CurrentOrderedId: sequence<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
|
||||
static constexpr static_array<index_t, n_dim> strides = compute_all_strides();
|
||||
|
||||
// Compile-time decomposition: linear index -> sequence of per-dimension indices
|
||||
template <index_t LinearIdx>
|
||||
using decompose = sequence<((LinearIdx / strides[Is]) % lengths[Is])...>;
|
||||
|
||||
// Decompose AND reorder in one step using a pre-computed inverse permutation.
|
||||
// Produces the unordered multi-index directly, avoiding per-iteration
|
||||
// reorder_old_to_new member function instantiations on each unique sequence type.
|
||||
template <index_t LinearIdx, class New2Old>
|
||||
using decompose_reordered = sequence<((LinearIdx / strides[inverse_perm<New2Old>::value[Is]]) %
|
||||
lengths[inverse_perm<New2Old>::value[Is]])...>;
|
||||
};
|
||||
|
||||
// Calls f(decompose<I>{}) for each linear index I in the pack, using a single
|
||||
// fold expression. Bypasses the static_for lambda entirely, eliminating M*N
|
||||
// intermediate lambda closure instantiations that the lambda-based approach creates.
|
||||
template <class Decomposer, class LinearIdxSeq>
|
||||
struct ford_applier;
|
||||
|
||||
template <class Decomposer, index_t... LinearIds>
|
||||
struct ford_applier<Decomposer, sequence<LinearIds...>>
|
||||
{
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
|
||||
f, CurrentOrderedId::push_back(I));
|
||||
});
|
||||
if constexpr(sizeof...(LinearIds) > 0)
|
||||
{
|
||||
(f(typename Decomposer::template decompose<LinearIds>{}), ...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_ford_impl<sequence<>, Orders>
|
||||
// Same as ford_applier but applies reordering during decomposition.
|
||||
template <class Decomposer, class New2Old, class LinearIdxSeq>
|
||||
struct ford_applier_reordered;
|
||||
|
||||
template <class Decomposer, class New2Old, index_t... LinearIds>
|
||||
struct ford_applier_reordered<Decomposer, New2Old, sequence<LinearIds...>>
|
||||
{
|
||||
// F signature: F(sequence<...>)
|
||||
// OrderedId: sequence<...>
|
||||
template <class F, class OrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
// retrive unordered Id
|
||||
f(OrderedId::reorder_old_to_new(Orders{}));
|
||||
if constexpr(sizeof...(LinearIds) > 0)
|
||||
{
|
||||
(f(typename Decomposer::template decompose_reordered<LinearIds, New2Old>{}), ...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is sequence<...>, it is the length of each dimension for
|
||||
// N-dimensional loop
|
||||
// Orders is sequence<...>, it is the order of dimension in which static_ford
|
||||
// will loop over each
|
||||
// dimension
|
||||
// Compile-time N-dimensional loop with static multi-indices.
|
||||
// Uses direct fold expansion with index decomposition, producing zero
|
||||
// intermediate lambda closures. Each iteration calls f with a compile-time
|
||||
// sequence<i0, i1, ...> containing the multi-dimensional index.
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
static constexpr index_t n_dim = Lengths::size();
|
||||
static constexpr index_t total_size =
|
||||
reduce_on_sequence(Lengths{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr bool is_identity_order = std::is_same_v<Orders, make_index_sequence<n_dim>>;
|
||||
|
||||
// For identity order, OrderedLengths == Lengths (no reorder needed).
|
||||
// For non-identity, reorder lengths according to iteration order.
|
||||
// Both branches must be valid types, but only the active one is used.
|
||||
using OrderedLengths =
|
||||
std::conditional_t<is_identity_order,
|
||||
Lengths,
|
||||
remove_cvref_t<decltype(Lengths::reorder_new_to_old(Orders{}))>>;
|
||||
using Decomposer = detail::index_decomposer<OrderedLengths, make_index_sequence<n_dim>>;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
|
||||
if constexpr(is_identity_order)
|
||||
{
|
||||
detail::ford_applier<Decomposer, make_index_sequence<total_size>>{}(f);
|
||||
}
|
||||
else
|
||||
{
|
||||
detail::ford_applier_reordered<Decomposer, Orders, make_index_sequence<total_size>>{}(
|
||||
f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -103,6 +103,42 @@ struct Max
|
||||
}
|
||||
};
|
||||
|
||||
struct Min
|
||||
{
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::max();
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return min(y, x);
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_min = min(y, x);
|
||||
if(x < y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_min;
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
{
|
||||
template <
|
||||
|
||||
@@ -129,7 +129,13 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr bool EightWave = (MWave * NWave == 8);
|
||||
|
||||
#if defined(__gfx9__)
|
||||
static constexpr bool EightWave = (MWave * NWave == 8);
|
||||
#else
|
||||
static constexpr bool EightWave = false;
|
||||
#endif
|
||||
|
||||
static constexpr index_t BlockedXDLN_PerWarp =
|
||||
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
@@ -1516,6 +1516,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
using DDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
|
||||
using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::LSEDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
|
||||
@@ -1557,25 +1558,31 @@ struct FmhaBwdOGradDotOKernel
|
||||
const void* o_ptr;
|
||||
const void* do_ptr;
|
||||
void* d_ptr;
|
||||
const void* lse_ptr; // log-sum-exp from forward pass, shape [batch, nhead, seqlen_q]
|
||||
const LSEDataType* sink_ptr; // sink scores, shape [batch, nhead]; nullptr disables sink
|
||||
LSEDataType* d_sink_ptr; // sink gradient output, shape [nhead]; nullptr disables sink grad
|
||||
|
||||
float p_undrop;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead; // used to index sink_ptr / d_sink_ptr
|
||||
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_o;
|
||||
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_d;
|
||||
// LSE and D always share the same layout; this stride covers both.
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
};
|
||||
|
||||
struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_d;
|
||||
// LSE and D always share the same layout; this stride covers both.
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
};
|
||||
|
||||
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
|
||||
@@ -1593,32 +1600,40 @@ struct FmhaBwdOGradDotOKernel
|
||||
MakeKargs(const void* o_ptr,
|
||||
const void* do_ptr,
|
||||
void* d_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* sink_ptr,
|
||||
void* d_sink_ptr,
|
||||
float p_undrop,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_d,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_d)
|
||||
ck_tile::index_t batch_stride_lsed)
|
||||
{
|
||||
Kargs kargs{{o_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
lse_ptr,
|
||||
reinterpret_cast<const LSEDataType*>(sink_ptr),
|
||||
reinterpret_cast<LSEDataType*>(d_sink_ptr),
|
||||
p_undrop,
|
||||
seqlen_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
stride_do,
|
||||
stride_o,
|
||||
nhead_stride_do,
|
||||
nhead_stride_o,
|
||||
nhead_stride_d},
|
||||
nhead_stride_lsed},
|
||||
batch_stride_do,
|
||||
batch_stride_o,
|
||||
batch_stride_d};
|
||||
batch_stride_lsed};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -1628,28 +1643,36 @@ struct FmhaBwdOGradDotOKernel
|
||||
MakeKargs(const void* o_ptr,
|
||||
const void* do_ptr,
|
||||
void* d_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* sink_ptr,
|
||||
void* d_sink_ptr,
|
||||
float p_undrop,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_d)
|
||||
ck_tile::index_t nhead_stride_lsed)
|
||||
{
|
||||
Kargs kargs{{o_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
lse_ptr,
|
||||
reinterpret_cast<const LSEDataType*>(sink_ptr),
|
||||
reinterpret_cast<LSEDataType*>(d_sink_ptr),
|
||||
p_undrop,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_v,
|
||||
nhead,
|
||||
stride_do,
|
||||
stride_o,
|
||||
nhead_stride_do,
|
||||
nhead_stride_o,
|
||||
nhead_stride_d},
|
||||
nhead_stride_lsed},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
|
||||
@@ -1683,18 +1706,18 @@ struct FmhaBwdOGradDotOKernel
|
||||
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
|
||||
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_do = 0;
|
||||
long_index_t batch_offset_d = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_do = 0;
|
||||
long_index_t batch_offset_lsed = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
batch_offset_do = query_start * kargs.stride_do;
|
||||
batch_offset_d = query_start;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
batch_offset_do = query_start * kargs.stride_do;
|
||||
batch_offset_lsed = query_start;
|
||||
|
||||
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
@@ -1722,11 +1745,20 @@ struct FmhaBwdOGradDotOKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
|
||||
batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
|
||||
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
|
||||
}
|
||||
|
||||
// Read per-head sink score and convert to log2 domain so the pipeline can use exp2.
|
||||
// Pre-multiply by log2e so that exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
|
||||
// -inf is left unchanged (log2e * -inf == -inf) to keep P_sink -> 0 when sink is disabled.
|
||||
const LSEDataType sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? log2e_v<LSEDataType> *
|
||||
kargs.sink_ptr[static_cast<long_index_t>(i_batch) * kargs.nhead + i_nhead]
|
||||
: -numeric<LSEDataType>::infinity();
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
@@ -1734,9 +1766,13 @@ struct FmhaBwdOGradDotOKernel
|
||||
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
|
||||
batch_offset_do;
|
||||
const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
|
||||
batch_offset_lsed;
|
||||
|
||||
DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
|
||||
batch_offset_d;
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
|
||||
batch_offset_lsed;
|
||||
|
||||
// O/dO/D DRAM and DRAM window
|
||||
const auto o_dram = [&]() {
|
||||
@@ -1770,13 +1806,31 @@ struct FmhaBwdOGradDotOKernel
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
|
||||
|
||||
auto do_dram_window =
|
||||
make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
|
||||
|
||||
auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
|
||||
|
||||
FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
|
||||
// nullptr when sink grad is disabled; the pipeline checks this to skip the sink path
|
||||
LSEDataType* atomic_sink_grad_ptr =
|
||||
kargs.d_sink_ptr == nullptr ? nullptr : kargs.d_sink_ptr + i_nhead;
|
||||
|
||||
// lse_ptr is always valid (also needed by the main bwd kernel).
|
||||
// The actual load happens inside the pipeline only when atomic_sink_grad_ptr != nullptr.
|
||||
auto lse_dram = [&]() {
|
||||
const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
|
||||
return pad_tensor_view(
|
||||
lse_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
|
||||
}();
|
||||
auto lse_dram_window = make_tile_window(lse_dram, make_tuple(number<kM0>{}), {i_m0});
|
||||
|
||||
FmhaBwdOGradDotO{}(o_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
sink_value,
|
||||
kargs.p_undrop,
|
||||
atomic_sink_grad_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ struct BlockFmhaBwdOGradDotO
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; // needed for sink gradient
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
@@ -32,11 +33,18 @@ struct BlockFmhaBwdOGradDotO
|
||||
|
||||
template <typename ODramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp>
|
||||
// Computes D = diag(dO * O) and optionally accumulates the sink token gradient.
|
||||
// sink_value: log-space sink score; pass -inf and atomic_sink_grad_ptr=nullptr to skip sink.
|
||||
// atomic_sink_grad_ptr: per-head accumulator in global memory; nullptr disables sink path.
|
||||
CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
float p_undrop) const
|
||||
const LSEDataType sink_value,
|
||||
float p_undrop,
|
||||
LSEDataType* atomic_sink_grad_ptr = nullptr) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
|
||||
@@ -44,6 +52,10 @@ struct BlockFmhaBwdOGradDotO
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
// atomic_sink_grad_ptr is reinterpret_cast to float* in the sink path;
|
||||
// ensure LSEDataType is float so the cast is well-defined.
|
||||
static_assert(std::is_same_v<LSEDataType, float>,
|
||||
"sink gradient atomicAdd requires LSEDataType == float");
|
||||
|
||||
static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kBlockSize ==
|
||||
@@ -67,14 +79,13 @@ struct BlockFmhaBwdOGradDotO
|
||||
|
||||
auto do_ = load_tile(do_dram_window);
|
||||
|
||||
// declare d
|
||||
// D[q] = sum_j(O[q,j] * dO[q,j]), used in softmax backward
|
||||
constexpr auto d_dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
|
||||
|
||||
auto d = make_static_distributed_tensor<DDataType>(d_dstr);
|
||||
|
||||
clear_tile(d); // Initialize D
|
||||
clear_tile(d);
|
||||
|
||||
constexpr auto o_spans = decltype(o)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -86,9 +97,67 @@ struct BlockFmhaBwdOGradDotO
|
||||
});
|
||||
});
|
||||
|
||||
// Scale by p_undrop (=1 when dropout is disabled)
|
||||
tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);
|
||||
|
||||
store_tile(d_dram_block_window_tmp, d);
|
||||
|
||||
// Sink gradient path: skipped entirely when atomic_sink_grad_ptr is nullptr
|
||||
if(atomic_sink_grad_ptr != nullptr)
|
||||
{
|
||||
// Load LSE only on the sink path to avoid unnecessary global memory reads
|
||||
constexpr auto lse_dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
o.get_tile_distribution().get_static_tile_distribution_encoding(),
|
||||
sequence<1>{}));
|
||||
auto lse_dram_window =
|
||||
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
lse_dram_block_window_tmp.get_window_origin(),
|
||||
lse_dstr);
|
||||
auto lse_ = load_tile(lse_dram_window);
|
||||
|
||||
// Compute per-query contribution: -P_sink[q] * D[q]
|
||||
// where P_sink[q] = exp2(sink_value - log2e*lse[q])
|
||||
// sink_value has already been pre-multiplied by log2e at the kernel call site,
|
||||
// so exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
|
||||
// exp2 maps directly to the v_exp_f32 hardware instruction on AMD GPUs.
|
||||
// Always accumulate in float regardless of DDataType to avoid precision loss
|
||||
// and to ensure atomicAdd works correctly on all architectures.
|
||||
auto sink_val_tensor = make_static_distributed_tensor<float>(d_dstr);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& s_out, const auto& l_in, const auto& d_in) {
|
||||
float p_sink = exp2(type_convert<float>(sink_value) -
|
||||
log2e_v<float> * type_convert<float>(l_in));
|
||||
s_out = -p_sink * type_convert<float>(d_in);
|
||||
},
|
||||
sink_val_tensor,
|
||||
lse_,
|
||||
d);
|
||||
|
||||
// Reduce contributions held by this thread
|
||||
float thread_sum = 0.f;
|
||||
constexpr auto s_spans = decltype(sink_val_tensor)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
thread_sum += sink_val_tensor(i_idx);
|
||||
});
|
||||
|
||||
// Warp-level reduction: fold thread_sum across lanes so only one
|
||||
// atomicAdd per warp is issued instead of one per thread.
|
||||
#if defined(__HIP_DEVICE_COMPILE__) || defined(__CUDA_ARCH__)
|
||||
const index_t warp_sz = get_warp_size();
|
||||
for(index_t offset = warp_sz >> 1; offset > 0; offset >>= 1)
|
||||
thread_sum += warp_shuffle_down(thread_sum, offset);
|
||||
|
||||
// Only lane 0 of each warp writes to global memory.
|
||||
// Note: this atomicAdd is non-deterministic across runs regardless of the
|
||||
// -deterministic flag, because d_sink is a single scalar per head accumulated
|
||||
// across all thread-blocks. The practical impact is negligible for this value.
|
||||
if(get_lane_id() == 0)
|
||||
atomicAdd(reinterpret_cast<float*>(atomic_sink_grad_ptr), thread_sum);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -163,7 +163,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
amd_wave_read_first_lane(integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0));
|
||||
|
||||
// check early exit if no work to do.
|
||||
if(num_total_loop <= 0)
|
||||
// __builtin_expect is load-bearing: omitting it causes incorrect AGPR allocation in
|
||||
// the dK/dV accumulation loop on some compiler versions, leading to wrong results.
|
||||
if(__builtin_expect(num_total_loop <= 0, 0))
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleared, return it
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
|
||||
@@ -67,6 +67,7 @@ struct BlockFmhaBwdPipelineProblem
|
||||
template <typename ODataType_,
|
||||
typename OGradDataType_,
|
||||
typename DDataType_,
|
||||
typename LSEDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kVHeaddim_,
|
||||
bool kIsGroupMode_,
|
||||
@@ -76,6 +77,7 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using OGradDataType = remove_cvref_t<OGradDataType_>;
|
||||
using DDataType = remove_cvref_t<DDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
|
||||
|
||||
@@ -1226,23 +1226,37 @@ struct UniversalGemmKernel
|
||||
s_waitcnt_barrier();
|
||||
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
// Apply pivot to M tile index first, then use the same pivoted index
|
||||
// for both data-tile selection and chunk-signal wait.
|
||||
auto iM_eff = amd_wave_read_first_lane(iM);
|
||||
|
||||
if(kargs.async_input_scheduler.chunk_signals != nullptr)
|
||||
{
|
||||
const auto tile_idx_pivot =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
|
||||
const auto tiles_m = amd_wave_read_first_lane(
|
||||
integer_divide_ceil(kargs.M, TilePartitioner::MPerBlock));
|
||||
if(tiles_m > 0)
|
||||
{
|
||||
iM_eff = amd_wave_read_first_lane((iM_eff + tile_idx_pivot) % tiles_m);
|
||||
}
|
||||
}
|
||||
|
||||
const index_t i_m = amd_wave_read_first_lane(iM_eff * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
// Synchronize with producer to ensure input data is ready before processing tile
|
||||
if(kargs.async_input_scheduler.chunk_signals != nullptr)
|
||||
{
|
||||
const auto tiles_per_chunk =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
|
||||
const auto tile_idx_pivot =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
|
||||
const auto num_chunks =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.num_chunks);
|
||||
if(tiles_per_chunk > 0 && num_chunks > 0)
|
||||
{
|
||||
// Pivot allows rotating chunk assignments for load balancing
|
||||
const auto chunk_idx = amd_wave_read_first_lane(
|
||||
((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
|
||||
const auto chunk_idx =
|
||||
amd_wave_read_first_lane((iM_eff / tiles_per_chunk) % num_chunks);
|
||||
workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
|
||||
chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
|
||||
}
|
||||
|
||||
@@ -903,13 +903,11 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const auto& d_block_window =
|
||||
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
|
||||
@@ -65,7 +65,9 @@ run_grouped_conv_backward_data_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
const ckt::Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config& s_conf)
|
||||
{
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
// Run first instance as dummy to get proper time from the first instance
|
||||
bool dummy_run_executed = false;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
std::string best_op_name, op_name;
|
||||
int best_split_k = 0;
|
||||
ck::index_t best_instance_index = -1;
|
||||
@@ -121,6 +123,13 @@ run_grouped_conv_backward_data_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
run_alg_func(args_k_batch, inputs, outputs, s_conf);
|
||||
if(is_supported)
|
||||
{
|
||||
if((s_conf.time_kernel_ || s_conf.flush_cache_) && !dummy_run_executed)
|
||||
{
|
||||
// Run first instance twice
|
||||
std::tie(is_supported, avg_time, op_name) =
|
||||
run_alg_func(args_k_batch, inputs, outputs, s_conf);
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
ckt::ValidationReport report;
|
||||
auto&& [rtol, atol] =
|
||||
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
|
||||
|
||||
@@ -106,17 +106,17 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
{
|
||||
ckt::Args<SIGNATURE> args_k_batch = args;
|
||||
args_k_batch.k_batch = k_batch;
|
||||
if((s_conf.time_kernel_ || s_conf.flush_cache_) && !dummy_run_executed)
|
||||
{
|
||||
// Run first instance twice when profiling to stabilize timing
|
||||
std::tie(is_supported, avg_time, op_name) =
|
||||
run_alg_func(args_k_batch, inputs, outputs, s_conf);
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
std::tie(is_supported, avg_time, op_name) =
|
||||
run_alg_func(args_k_batch, inputs, outputs, s_conf);
|
||||
if(is_supported)
|
||||
{
|
||||
if((s_conf.time_kernel_ || s_conf.flush_cache_) && !dummy_run_executed)
|
||||
{
|
||||
// Run first instance twice when profiling to stabilize timing
|
||||
std::tie(is_supported, avg_time, op_name) =
|
||||
run_alg_func(args_k_batch, inputs, outputs, s_conf);
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
ckt::ValidationReport report;
|
||||
auto&& [rtol, atol] =
|
||||
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
|
||||
|
||||
@@ -86,15 +86,16 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
auto run_alg = [&](auto&& run_alg_func) {
|
||||
if(!dummy_run_executed)
|
||||
{
|
||||
// Run first instance twice
|
||||
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
|
||||
if(is_supported)
|
||||
{
|
||||
if((s_conf.time_kernel_ || s_conf.flush_cache_) && !dummy_run_executed)
|
||||
{
|
||||
// Run first instance twice
|
||||
std::tie(is_supported, avg_time, op_name) =
|
||||
run_alg_func(args, inputs, outputs, s_conf);
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
best_avg_time = std::min(best_avg_time, avg_time);
|
||||
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name
|
||||
|
||||
@@ -197,10 +197,11 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
int num_kernel = 0;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
int num_kernel = 0;
|
||||
bool dummy_run_executed = false;
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
@@ -230,16 +231,38 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
// skip test if instance_index is specified
|
||||
continue;
|
||||
}
|
||||
// for conv bwd data, some input tensor element are zero, but not written by kernel,
|
||||
// need to set zero
|
||||
in_device_buf.SetZero();
|
||||
if(!time_kernel)
|
||||
{
|
||||
// Don't clear for perf measurement.
|
||||
// For non-grouped solver user has to clear input on his own.
|
||||
// for conv bwd data, some input tensor element are zero, but not written by kernel,
|
||||
// need to set zero
|
||||
in_device_buf.SetZero();
|
||||
}
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
float avg_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
// Run first instance twice to get proper time
|
||||
if(time_kernel && !dummy_run_executed)
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr,
|
||||
time_kernel,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr,
|
||||
time_kernel,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
@@ -287,6 +287,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
bool pass = true;
|
||||
index_t num_kernel = 0;
|
||||
index_t valid_instances = 0;
|
||||
bool dummy_run_executed = false;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
@@ -317,8 +319,25 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
float avg_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
// Run first instance twice to get proper time
|
||||
if(time_kernel && !dummy_run_executed)
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr,
|
||||
time_kernel,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr,
|
||||
time_kernel,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
@@ -495,7 +514,6 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
{
|
||||
std::cout << "\nValid instances for this problem:" << std::endl;
|
||||
}
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
|
||||
|
||||
@@ -296,7 +296,8 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
index_t best_instance_index = 0;
|
||||
|
||||
// profile device op instances
|
||||
bool pass = true;
|
||||
bool pass = true;
|
||||
bool dummy_run_executed = false;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
@@ -331,6 +332,19 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
// Run first instance twice to get proper time
|
||||
if(time_kernel && !dummy_run_executed)
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr,
|
||||
time_kernel,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr,
|
||||
time_kernel,
|
||||
@@ -437,30 +451,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
std::cout << "\nValid instances for this problem:" << std::endl;
|
||||
}
|
||||
|
||||
// Run first instance twice to get proper time
|
||||
{
|
||||
auto argument_ptr = op_ptrs[0]->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
{},
|
||||
{},
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
run_impl(op_ptrs[0], argument_ptr);
|
||||
}
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
|
||||
|
||||
@@ -89,7 +89,8 @@ int call_profiler(const ckt::Args<SIGNATURE>& args,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
true /*is_gpu_timer_*/});
|
||||
true /*is_gpu_timer_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "\nBest configuration parameters:" << "\n\tname: " << op_name << " (instance "
|
||||
|
||||
@@ -69,4 +69,4 @@ add_subdirectory(fmha)
|
||||
add_subdirectory(gemm_tile_engine)
|
||||
add_subdirectory(pooling)
|
||||
add_subdirectory(grouped_conv)
|
||||
add_subdirectory(gemm_streamk_tile_engine)
|
||||
add_subdirectory(pooling_tile_engine)
|
||||
|
||||
@@ -355,6 +355,102 @@ TEST(SequenceSort, SortSingleElement)
|
||||
EXPECT_TRUE((std::is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test sequence_sort sorted2unsorted_map (index tracking)
|
||||
TEST(SequenceSort, SortedMapUnsorted)
|
||||
{
|
||||
using Seq = sequence<5, 2, 8, 1, 9>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
// sorted = <1,2,5,8,9>, original indices = <3,1,0,2,4>
|
||||
using Expected = sequence<3, 1, 0, 2, 4>;
|
||||
EXPECT_TRUE((std::is_same<Map, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapAlreadySorted)
|
||||
{
|
||||
using Seq = sequence<1, 2, 3, 4, 5>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
// Already sorted: map should be identity
|
||||
using Expected = sequence<0, 1, 2, 3, 4>;
|
||||
EXPECT_TRUE((std::is_same<Map, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapRoundTrip)
|
||||
{
|
||||
// Verify: sorted_values[i] == original[sorted2unsorted_map[i]]
|
||||
using Seq = sequence<5, 2, 8, 1, 9>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
// sorted = <1,2,5,8,9>, map = <3,1,0,2,4>
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(0)), Sort::type::at(0));
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(1)), Sort::type::at(1));
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(2)), Sort::type::at(2));
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(3)), Sort::type::at(3));
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(4)), Sort::type::at(4));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapWithDuplicates)
|
||||
{
|
||||
using Seq = sequence<3, 1, 3, 1>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
using Sorted = typename Sort::type;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
// sorted = <1,1,3,3>
|
||||
using ExpectedSorted = sequence<1, 1, 3, 3>;
|
||||
EXPECT_TRUE((std::is_same<Sorted, ExpectedSorted>::value));
|
||||
// Verify round-trip: original[map[i]] == sorted[i] for all i
|
||||
// (don't assert specific index order for duplicates — sort stability may vary)
|
||||
EXPECT_EQ(Seq::at(Map::at(0)), Sorted::at(0));
|
||||
EXPECT_EQ(Seq::at(Map::at(1)), Sorted::at(1));
|
||||
EXPECT_EQ(Seq::at(Map::at(2)), Sorted::at(2));
|
||||
EXPECT_EQ(Seq::at(Map::at(3)), Sorted::at(3));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapReverseSorted)
|
||||
{
|
||||
using Seq = sequence<5, 4, 3, 2, 1>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
using Sorted = typename Sort::type;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
using ExpSorted = sequence<1, 2, 3, 4, 5>;
|
||||
using ExpMap = sequence<4, 3, 2, 1, 0>;
|
||||
EXPECT_TRUE((std::is_same<Sorted, ExpSorted>::value));
|
||||
EXPECT_TRUE((std::is_same<Map, ExpMap>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapEmpty)
|
||||
{
|
||||
using Sort = sequence_sort<sequence<>, less<index_t>>;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
EXPECT_TRUE((std::is_same<Map, sequence<>>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapSingleElement)
|
||||
{
|
||||
using Sort = sequence_sort<sequence<42>, less<index_t>>;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
EXPECT_TRUE((std::is_same<Map, sequence<0>>::value));
|
||||
}
|
||||
|
||||
// Test sequence_unique_sort sorted2unsorted_map
|
||||
TEST(SequenceUniqueSort, UniqueSortMap)
|
||||
{
|
||||
using Seq = sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>;
|
||||
using Result = sequence_unique_sort<Seq, less<index_t>, equal<index_t>>;
|
||||
using Map = typename Result::sorted2unsorted_map;
|
||||
// sorted unique = <1,2,3,4,5,6,9>
|
||||
// The map should reference the first occurrence of each unique value in the original
|
||||
// Verify round-trip: for each i, original[map[i]] == sorted_unique[i]
|
||||
using Values = typename Result::type;
|
||||
EXPECT_EQ(Seq::at(Map::at(0)), Values::at(0)); // 1
|
||||
EXPECT_EQ(Seq::at(Map::at(1)), Values::at(1)); // 2
|
||||
EXPECT_EQ(Seq::at(Map::at(2)), Values::at(2)); // 3
|
||||
EXPECT_EQ(Seq::at(Map::at(3)), Values::at(3)); // 4
|
||||
EXPECT_EQ(Seq::at(Map::at(4)), Values::at(4)); // 5
|
||||
EXPECT_EQ(Seq::at(Map::at(5)), Values::at(5)); // 6
|
||||
EXPECT_EQ(Seq::at(Map::at(6)), Values::at(6)); // 9
|
||||
}
|
||||
|
||||
// Test sequence_unique_sort
|
||||
TEST(SequenceUniqueSort, UniqueSort)
|
||||
{
|
||||
@@ -405,6 +501,24 @@ TEST(SequenceMap, InvalidMapMissing)
|
||||
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, InvalidMapNegative)
|
||||
{
|
||||
using Map = sequence<0, -1, 2>;
|
||||
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, ValidMapSingleElement)
|
||||
{
|
||||
EXPECT_TRUE((is_valid_sequence_map<sequence<0>>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, InvalidMapSingleElement)
|
||||
{
|
||||
EXPECT_FALSE((is_valid_sequence_map<sequence<1>>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, ValidMapEmpty) { EXPECT_TRUE((is_valid_sequence_map<sequence<>>::value)); }
|
||||
|
||||
// Test sequence_map_inverse
|
||||
// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new
|
||||
// position j The inverse gives us new position i came from old position inverse[i]
|
||||
|
||||
@@ -91,7 +91,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det, // deterministic
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -333,7 +334,8 @@ TEST_P(BasicQPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -419,7 +421,8 @@ TEST_P(BasicKVPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -513,7 +516,8 @@ TEST_P(QKVPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -620,7 +624,8 @@ TEST_P(ZeroLengthPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -741,7 +746,8 @@ TEST_P(VariedPaddingRatios, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -843,7 +849,8 @@ TEST_P(PaddingWithMask, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -977,7 +984,8 @@ TEST_P(MultiBatchPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
|
||||
@@ -19,55 +19,93 @@ set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4
|
||||
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
|
||||
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
set(STREAMK_EXTENDED_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent_mem.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
|
||||
# We only test fp8 and bf8 on gfx942 and gfx950 since these types are not natively supported on gfx90a
|
||||
if(GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
list(APPEND STREAMK_EXTENDED_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent_mem.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv3.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_compv4.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent_mem.cpp)
|
||||
endif()
|
||||
# ---- Code-generate test .cpp files from types header ----
|
||||
set(STREAMK_TYPES_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_types.hpp)
|
||||
set(STREAMK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/generate_test_files.py)
|
||||
|
||||
add_gtest_executable(test_ck_tile_streamk_extended ${STREAMK_EXTENDED_SOURCES})
|
||||
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# Re-run configure automatically if the types header changes (e.g. types added/removed)
|
||||
# or if the generation script changes
|
||||
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT})
|
||||
|
||||
# Define the targets and their corresponding executable names
|
||||
set(STREAMK_GEN_TARGETS extended atomic_smoke linear_smoke tree_smoke pipelines_smoke)
|
||||
set(STREAMK_GEN_EXEC_EXTENDED test_ck_tile_streamk_extended)
|
||||
set(STREAMK_GEN_EXEC_ATOMIC_SMOKE test_ck_tile_streamk_atomic_smoke)
|
||||
set(STREAMK_GEN_EXEC_LINEAR_SMOKE test_ck_tile_streamk_linear_smoke)
|
||||
set(STREAMK_GEN_EXEC_TREE_SMOKE test_ck_tile_streamk_tree_smoke)
|
||||
set(STREAMK_GEN_EXEC_PIPELINES_SMOKE test_ck_tile_streamk_pipelines_smoke)
|
||||
|
||||
# Collect all test targets for umbrella label
|
||||
set(CK_TILE_GEMM_STREAMK_TEST_TARGETS
|
||||
test_ck_tile_streamk_tile_partitioner
|
||||
test_ck_tile_streamk_extended
|
||||
test_ck_tile_streamk_tile_partitioner)
|
||||
|
||||
foreach(target IN LISTS STREAMK_GEN_TARGETS)
|
||||
string(TOUPPER ${target} TARGET_UPPER)
|
||||
set(GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/${target})
|
||||
set(EXEC_NAME ${STREAMK_GEN_EXEC_${TARGET_UPPER}})
|
||||
set(LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/${target}_files.txt)
|
||||
|
||||
# Phase 1 (configure time): discover the list of files that will be generated
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
|
||||
--types_header ${STREAMK_TYPES_HEADER}
|
||||
--output_dir ${GEN_DIR}
|
||||
--target ${target}
|
||||
--list_files ${LIST_FILE}
|
||||
RESULT_VARIABLE ret
|
||||
ERROR_VARIABLE list_files_stderr)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR
|
||||
"Failed to list ${target} test files via Python: ${ret}\n"
|
||||
"stderr: ${list_files_stderr}"
|
||||
)
|
||||
endif()
|
||||
file(STRINGS ${LIST_FILE} ALL_SOURCES_${target})
|
||||
|
||||
# Phase 2 (build time): generate the .cpp files when the types header changes
|
||||
add_custom_command(
|
||||
OUTPUT ${ALL_SOURCES_${target}}
|
||||
COMMAND ${Python3_EXECUTABLE} ${STREAMK_GEN_SCRIPT}
|
||||
--types_header ${STREAMK_TYPES_HEADER}
|
||||
--output_dir ${GEN_DIR}
|
||||
--target ${target}
|
||||
--gen_files
|
||||
DEPENDS ${STREAMK_TYPES_HEADER} ${STREAMK_GEN_SCRIPT}
|
||||
COMMENT "Generating StreamK ${target} test sources from types header")
|
||||
|
||||
# Filter out fp8/bf8 sources on gfx90a since those types are not natively supported
|
||||
set(FILTERED_SOURCES)
|
||||
foreach(src IN LISTS ALL_SOURCES_${target})
|
||||
if(NOT src MATCHES "_(fp8|bf8)_" OR GPU_TARGETS MATCHES "gfx942|gfx950")
|
||||
list(APPEND FILTERED_SOURCES ${src})
|
||||
endif()
|
||||
endforeach()
|
||||
list(APPEND FILTERED_SOURCES test_gemm_streamk_util.cpp)
|
||||
|
||||
add_gtest_executable(${EXEC_NAME} ${FILTERED_SOURCES})
|
||||
target_compile_options(${EXEC_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
list(APPEND CK_TILE_GEMM_STREAMK_TEST_TARGETS ${EXEC_NAME})
|
||||
endforeach()
|
||||
|
||||
# Add python unit tests to validate the code gen logic in generate_test_files.py
|
||||
add_test(
|
||||
NAME test_ck_tile_streamk_generate_test_files
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_generate_test_files.py -v
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
)
|
||||
|
||||
# Label all ck_tile gemm_streamk tests with CK_TILE_GEMM_STREAMK_TESTS for selective execution
|
||||
foreach(test_target ${CK_TILE_GEMM_STREAMK_TEST_TARGETS})
|
||||
set_tests_properties(${test_target} PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
|
||||
endforeach()
|
||||
# Also label the Python test
|
||||
set_tests_properties(test_ck_tile_streamk_generate_test_files PROPERTIES LABELS "CK_TILE_GEMM_STREAMK_TESTS")
|
||||
|
||||
# Umbrella target to build and run all ck_tile gemm_streamk tests
|
||||
# Usage: ninja ck_tile_gemm_streamk_tests
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV3,
|
||||
KernelTypesStreamKBf16NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentCompV4,
|
||||
KernelTypesStreamKBf16NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistentMem, KernelTypesStreamKBf16NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV3, KernelTypesStreamKBf16PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentCompV4, KernelTypesStreamKBf16PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16PersistentMem, KernelTypesStreamKBf16PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV3, KernelTypesStreamKBf8NonPersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentCompV4, KernelTypesStreamKBf8NonPersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistentMem, KernelTypesStreamKBf8NonPersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentCompV3 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV3, KernelTypesStreamKBf8PersistentCompV3);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentCompV4 : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentCompV4, KernelTypesStreamKBf8PersistentCompV4);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8PersistentMem : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8PersistentMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8PersistentMem, KernelTypesStreamKBf8PersistentMem);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user