Merge remote-tracking branch 'origin/develop' into ginolu/sparge_attention

This commit is contained in:
Gino Lu
2026-05-19 21:34:32 -04:00
480 changed files with 77709 additions and 7606 deletions

View File

@@ -52,6 +52,10 @@ option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
option(BUILD_CK_TILE_ENGINE "Build the tile_engine subdirectory" OFF)
option(BUILD_CK_EXAMPLES "Build the example subdirectory" ON)
option(BUILD_CK_TUTORIALS "Build the tutorial subdirectory" ON)
option(CK_ENABLE_ROCM_CK "Build rocm_ck API" OFF)
if(CK_EXPERIMENTAL_BUILDER)
add_definitions(-DCK_EXPERIMENTAL_BUILDER)
@@ -207,6 +211,21 @@ else()
set(USER_GPU_TARGETS 0)
endif()
# GPU targets CK cannot be built for. They are removed from any requested
# target list below.
set(CK_UNSUPPORTED_GPU_TARGETS "gfx900;gfx906;gfx90c")

# If targets were requested and *every* one of them is unsupported (a single
# unsupported target, or a list consisting only of unsupported targets),
# generate a dummy target and stop processing this file. Testing the filtered
# copy rather than `"${GPU_TARGETS}" IN_LIST ...` also covers multi-entry
# requests, which would otherwise be filtered down to an empty list and fall
# through to the rest of the configuration.
if(NOT "${GPU_TARGETS}" STREQUAL "")
    set(ck_supported_requested_targets ${GPU_TARGETS})
    list(REMOVE_ITEM ck_supported_requested_targets ${CK_UNSUPPORTED_GPU_TARGETS})
    if("${ck_supported_requested_targets}" STREQUAL "")
        add_custom_target(ck_dummy_target)
        message("CK is not supported for target ${GPU_TARGETS}")
        return()
    endif()
endif()

# Otherwise drop any unsupported entries and continue with what is left.
message(STATUS "Filtering out unsupported targets: ${CK_UNSUPPORTED_GPU_TARGETS}")
list(REMOVE_ITEM GPU_TARGETS ${CK_UNSUPPORTED_GPU_TARGETS})
list(REMOVE_ITEM GPU_ARCHS ${CK_UNSUPPORTED_GPU_TARGETS})
find_package(hip REQUIRED)
enable_language(HIP)
@@ -229,8 +248,10 @@ if(NOT ENABLE_ASAN_PACKAGING)
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1200;gfx1201")
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000 AND ${hip_VERSION_FLAT} LESS 600443483)
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1200;gfx1201;gfx950")
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600443483)
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600443483 AND ${hip_VERSION_FLAT} LESS 700200000)
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx950;gfx10-3-generic;gfx11-generic;gfx12-generic")
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 700200000)
set(CK_GPU_TARGETS "")
endif()
else()
#build CK only for xnack-supported targets when using ASAN
@@ -668,59 +689,64 @@ if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
endif()
# Optimization: Search only in library/src where all instance files actually live
# (was searching entire source tree, taking ~40s instead of <1s)
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/library/src/*/device_*_instance.cpp")
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list})
set(target_dir)
IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
set(cmake_instance)
file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES")
set(add_inst 1)
endif()
if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
endif()
ENDIF()
ENDFOREACH()
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF)
option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
option(BUILD_CK_DEVICE_INSTANCES "Build device operation instances in library/" ON)
option(BUILD_CK_PROFILER "Build the CK profiler in profiler/" ON)
option(BUILD_CK_TILE_ENGINE_TESTS "Build tile engine tests" ON)
option(BUILD_CK_TILE_FMHA_TESTS "Build FMHA tests" ON)
option(BUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS "Build CShuffleLds microbenchmarks (requires BUILD_CK_EXAMPLES=ON)" OFF)
add_subdirectory(library)
if(BUILD_CK_DEVICE_INSTANCES)
    # Optimization: Search only in library/src where all instance files actually live
    # (was searching entire source tree, taking ~40s instead of <1s)
    file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/library/src/*/device_*_instance.cpp")

    # Every sub-directory of the gpu instance tree is a candidate
    # device_<name>_instance target.
    set(ck_instance_root "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu")
    file(GLOB dir_list RELATIVE "${ck_instance_root}" "${ck_instance_root}/*")

    # Pairs of "<name as it appears in DTYPES>:<suffix looked for in the
    # sub-directory's CMakeLists.txt>". A sub-directory is selected when its
    # CMakeLists.txt mentions either spelling and DTYPES requests that type.
    set(ck_dtype_markers
        "fp8:_f8"
        "bf8:_b8"
        "fp16:_f16"
        "fp32:_f32"
        "tf32:_tf32"
        "fp64:_f64"
        "bf16:_b16"
        "int8:_i8")

    set(CK_DEVICE_INSTANCES)
    foreach(subdir_path ${dir_list})
        set(ck_instance_lists "${ck_instance_root}/${subdir_path}/CMakeLists.txt")
        # Skip plain files, and directories without a CMakeLists.txt: those
        # cannot define a target and would make file(READ) abort configuration.
        if(IS_DIRECTORY "${ck_instance_root}/${subdir_path}" AND EXISTS "${ck_instance_lists}")
            file(READ "${ck_instance_lists}" cmake_instance)
            set(add_inst 0)
            foreach(ck_marker ${ck_dtype_markers})
                string(REPLACE ":" ";" ck_marker_parts "${ck_marker}")
                list(GET ck_marker_parts 0 ck_dtype_name)
                list(GET ck_marker_parts 1 ck_dtype_suffix)
                if(("${cmake_instance}" MATCHES "${ck_dtype_name}" OR "${cmake_instance}" MATCHES "${ck_dtype_suffix}") AND DTYPES MATCHES "${ck_dtype_name}")
                    set(add_inst 1)
                    break()
                endif()
            endforeach()
            # A CMakeLists.txt that never mentions DTYPES is always included.
            if(NOT "${cmake_instance}" MATCHES "DTYPES")
                set(add_inst 1)
            endif()
            # With DTYPES undefined no filtering is applied at all.
            if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
                list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
            endif()
        endif()
    endforeach()

    # Umbrella target: building `instances` builds utility plus every selected
    # device instance target.
    add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
    add_subdirectory(library)
endif()
if (CK_EXPERIMENTAL_BUILDER)
add_subdirectory(experimental/builder)
@@ -728,34 +754,47 @@ if (CK_EXPERIMENTAL_BUILDER)
endif()
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
rocm_package_setup_component(tests
LIBRARY_NAME composablekernel
PACKAGE_NAME tests # Prevent -static suffix on package name
)
if(BUILD_CK_EXAMPLES)
rocm_package_setup_component(examples
LIBRARY_NAME composablekernel
PACKAGE_NAME examples
)
add_subdirectory(example)
endif()
rocm_package_setup_component(examples
LIBRARY_NAME composablekernel
PACKAGE_NAME examples
)
add_subdirectory(example)
add_subdirectory(tutorial)
rocm_package_setup_component(tutorials
LIBRARY_NAME composablekernel
PACKAGE_NAME tutorials
)
add_subdirectory(tile_engine)
if(BUILD_CK_TUTORIALS)
add_subdirectory(tutorial)
rocm_package_setup_component(tutorials
LIBRARY_NAME composablekernel
PACKAGE_NAME tutorials
)
endif()
if(BUILD_CK_TILE_ENGINE)
add_subdirectory(tile_engine)
endif()
if(CK_ENABLE_ROCM_CK)
add_subdirectory(rocm_ck)
if(TARGET check)
add_dependencies(check build-smoke-rocm-ck)
endif()
endif()
if(BUILD_TESTING)
rocm_package_setup_component(tests
LIBRARY_NAME composablekernel
PACKAGE_NAME tests # Prevent -static suffix on package name
)
add_subdirectory(test)
endif()
endif()
if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler
)
add_subdirectory(profiler)
if(BUILD_CK_PROFILER)
if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler
)
add_subdirectory(profiler)
endif()
endif()
if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))

View File

@@ -51,6 +51,22 @@
"GPU_TARGETS": "gfx908;gfx90a;gfx942"
}
},
{
"name": "dev-minimal",
"binaryDir": "${sourceDir}/build",
"displayName": "CK Dev - Minimal Build",
"description": "Fast iteration build with minimal components (configure ~5s vs ~150s)",
"inherits": ["dev"],
"cacheVariables": {
"BUILD_CK_DEVICE_INSTANCES": "OFF",
"BUILD_CK_PROFILER": "OFF",
"BUILD_CK_EXAMPLES": "OFF",
"BUILD_CK_TUTORIALS": "OFF",
"BUILD_CK_TILE_ENGINE": "OFF",
"BUILD_CK_TILE_ENGINE_TESTS": "OFF",
"BUILD_CK_TILE_FMHA_TESTS": "OFF"
}
},
{
"name": "dev-gfx908",
"displayName": "CK Dev - gfx908",

View File

@@ -3,7 +3,19 @@ FROM ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=7.1.1
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
ARG TARBALL_URL=https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx90X-dcgpu-7.12.0a20260218.tar.gz
# TheRock nightly tarball configuration.
# By default, discovers the latest tarball from the nightlies index.
# Manual overrides:
# Pin a specific tarball:
# --build-arg TARBALL_URL=https://rocm.nightlies.amd.com/tarball-multi-arch/therock-dist-linux-multiarch-7.13.0a20260430.tar.gz
# Change the arch variant (default: multiarch):
# --build-arg TARBALL_PATTERN=therock-dist-linux-gfx90a
# --build-arg TARBALL_PATTERN=therock-dist-linux-gfx94X-dcgpu
ARG TARBALL_URL=""
ARG TARBALL_BASE=https://rocm.nightlies.amd.com/tarball-multi-arch
ARG TARBALL_PATTERN=therock-dist-linux-multiarch
ARG compiler_version=""
ARG compiler_commit=""
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
@@ -18,10 +30,18 @@ RUN set -xe && \
RUN if [ "$compiler_version" = "therock" ]; then \
rm -rf /opt/rocm && mkdir /opt/rocm && \
echo "Downloading ROCm tarball from $TARBALL_URL..." && \
if [ -n "$TARBALL_URL" ]; then \
        echo "Using provided TARBALL_URL: $TARBALL_URL" ; \
    else \
        echo "Discovering latest tarball from $TARBALL_BASE..." && \
        # Pick the highest-versioned name matching TARBALL_PATTERN from the index listing.
        # curl -f makes an HTTP error yield an empty result instead of an error page.
        TARBALL_NAME="$(curl -fsSL "${TARBALL_BASE}/" \
            | grep -oP '"name":\s*"\K'"${TARBALL_PATTERN}"'-[^"]+\.tar\.gz' \
            | sort -V | tail -1)" && \
        # Fail here, with a clear message, rather than downloading the index page
        # and failing later in tar.
        if [ -z "$TARBALL_NAME" ]; then \
            echo "ERROR: no tarball matching ${TARBALL_PATTERN} found at ${TARBALL_BASE}" >&2 ; \
            exit 1 ; \
        fi && \
        TARBALL_URL="${TARBALL_BASE}/${TARBALL_NAME}" && \
        echo "Found: $TARBALL_URL" ; \
    fi && \
wget -q -O /tmp/rocm.tar.gz "$TARBALL_URL" && \
echo "Extracting tarball to /opt/rocm..." && \
tar -xzf /tmp/rocm.tar.gz -C /opt/rocm --strip-components=1 ; \
tar -xzf /tmp/rocm.tar.gz -C /opt/rocm --strip-components=1 && \
rm /tmp/rocm.tar.gz ; \
else echo "using the release compiler" && \
wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
@@ -36,7 +56,7 @@ ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
RUN set -x && \
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
wget -qO sccache.tar.gz https://github.com/mozilla/sccache/releases/latest/download/sccache-v$SCCACHE_VERSION-x86_64-unknown-linux-musl.tar.gz && \
wget -qO sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v$SCCACHE_VERSION/sccache-v$SCCACHE_VERSION-x86_64-unknown-linux-musl.tar.gz && \
tar -xzf sccache.tar.gz --strip-components=1 -C ${SCCACHE_INSTALL_LOCATION} && \
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache

View File

@@ -10,30 +10,36 @@ RUN pip install pandas zmq einops ninja tabulate vcs_versioning && \
sudo mkdir /home/jenkins/workspace && \
cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \
if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \
git clone --depth 1 -b "$CK_AITER_BRANCH" --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
cd rocm-libraries && \
mkdir rocm-libraries && cd rocm-libraries && \
git init -q && \
git remote add origin https://github.com/ROCm/rocm-libraries.git && \
git fetch --depth 1 --filter=blob:none origin "$CK_AITER_BRANCH" && \
git sparse-checkout init --cone && \
git sparse-checkout set projects/composablekernel && \
git checkout "$CK_AITER_BRANCH" && \
git checkout FETCH_HEAD && \
ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \
LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \
mv projects/composablekernel ../ck && \
cd ../ck && rm -rf ../rocm-libraries && \
git init && \
git init -b "$LOCAL_BRANCH" && \
git config user.name "assistant-librarian[bot]" && \
git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \
git branch -m "$CK_AITER_BRANCH" && git add -A && \
git add -A && \
git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" ; \
else \
git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck ; \
git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck && \
LOCAL_BRANCH="$CK_AITER_BRANCH" ; \
fi && \
cd /home/jenkins/workspace && rm -rf aiter && \
git clone --depth 1 -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \
cd aiter && \
rm -rf 3rdparty/composable_kernel/ && \
git clone -b "$CK_AITER_BRANCH" ../ck 3rdparty/composable_kernel/ && \
git clone -b "$LOCAL_BRANCH" ../ck 3rdparty/composable_kernel/ && \
python3 setup.py develop && \
groupadd -g 1001 jenkins && \
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \
groupadd -f video && \
groupadd -f render && \
chown -R jenkins:jenkins /home/jenkins && \
chmod -R a+rwx /home/jenkins && \
chown -R jenkins:jenkins /tmp && \

View File

@@ -12,30 +12,36 @@ RUN set -x ; \
sudo mkdir /home/jenkins/workspace && \
cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \
if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \
git clone --depth 1 -b "$CK_FA_BRANCH" --no-checkout --filter=blob:none https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \
cd rocm-libraries && \
mkdir rocm-libraries && cd rocm-libraries && \
git init -q && \
git remote add origin https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \
git fetch --depth 1 --filter=blob:none origin "$CK_FA_BRANCH" && \
git sparse-checkout init --cone && \
git sparse-checkout set projects/composablekernel && \
git checkout "$CK_FA_BRANCH" && \
git checkout FETCH_HEAD && \
ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \
LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \
mv projects/composablekernel ../ck && \
cd ../ck && rm -rf ../rocm-libraries && \
git init && \
git init -b "$LOCAL_BRANCH" && \
git config user.name "assistant-librarian[bot]" && \
git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \
git branch -m "$CK_FA_BRANCH" && git add -A && \
git add -A && \
git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" > /dev/null ; \
else \
git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck ; \
git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck && \
LOCAL_BRANCH="$CK_FA_BRANCH" ; \
fi && \
cd /home/jenkins/workspace && rm -rf flash-attention && \
git clone --depth 1 -b "$FA_BRANCH" --recursive "https://github.com/$FA_ORIGIN/flash-attention.git" && \
cd flash-attention && \
rm -rf csrc/composable_kernel/ && \
git clone -b "$CK_FA_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \
git clone -b "$LOCAL_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \
MAX_JOBS=$(nproc) GPU_ARCHS="$GPU_ARCHS" /opt/venv/bin/python3 -u -m pip install --no-build-isolation -v . && \
groupadd -g 1001 jenkins && \
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \
groupadd -f video && \
groupadd -f render && \
chown -R jenkins:jenkins /home/jenkins && \
chmod -R a+rwx /home/jenkins && \
chown -R jenkins:jenkins /tmp && \

View File

@@ -4,6 +4,7 @@ ARG CK_PYTORCH_BRANCH="develop"
RUN groupadd -g 109 render && \
usermod -u 1001 jenkins && \
groupmod -g 1001 jenkins && \
pip install --upgrade pandas && \
cd /tmp/pytorch && \
rm -rf build && \
cd /tmp/pytorch/third_party && \
@@ -18,15 +19,8 @@ RUN groupadd -g 109 render && \
cd /tmp/pytorch/third_party/flash-attention/csrc && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
chown -R jenkins:jenkins /tmp/pytorch && \
chmod -R a+rwx /tmp/pytorch && \
sudo usermod -aG irc jenkins && \
#install hipblaslt
git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
cd rocm-libraries && \
git checkout develop && \
git sparse-checkout init --cone && \
git sparse-checkout set projects/hipblaslt shared/origami && \
cd projects/hipblaslt && \
git show --oneline -s && \
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller
mkdir -p /var/jenkins/workspace/pytorch && \
cp -r /tmp/pytorch/* /var/jenkins/workspace/pytorch/ && \
chown -R jenkins:jenkins /var/jenkins/workspace/pytorch && \
chmod -R a+rwx /var/jenkins/workspace/pytorch && \
sudo usermod -aG irc jenkins

200
Jenkinsfile vendored
View File

@@ -24,6 +24,24 @@
// Benefits: PR builds 5h → 30min (typical), nightly builds unchanged
// See: script/dependency-parser/README.md for details
//
/**
 * Look up the commit hash recorded on a build's SCMRevisionAction.
 *
 * @param build the build whose actions are inspected; may be null
 * @return the revision's pullHash when it is a GitHub PullRequestSCMRevision,
 *         its hash when it is an AbstractGitSCMSource SCMRevisionImpl,
 *         and null when there is no build, no SCMRevisionAction, or the
 *         revision is of any other type
 */
// NOTE(review): presumably @NonCPS because the Jenkins action/revision objects
// handled here are not serializable — confirm.
@NonCPS
String getGitHubCommitHash(def build)
{
    // Null-safe all the way through: a missing build or missing action list
    // simply yields no action instead of relying on method dispatch on null.
    def scmAction = build?.actions?.find { action ->
        action instanceof jenkins.scm.api.SCMRevisionAction
    }
    def revision = scmAction?.revision

    if (revision instanceof org.jenkinsci.plugins.github_branch_source.PullRequestSCMRevision)
    {
        return revision.pullHash
    }
    else if (revision instanceof jenkins.plugins.git.AbstractGitSCMSource$SCMRevisionImpl)
    {
        return revision.hash
    }
    return null
}
// Build a node label expression: the given label expression, parenthesised,
// AND-ed with "(rocmtest || miopen)".
def rocmnode(name) {
    String labelExpr = "(rocmtest || miopen) && (${name})"
    return labelExpr
}
@@ -31,6 +49,7 @@ def rocmnode(name) {
def show_node_info() {
sh """
echo "NODE_NAME = \$NODE_NAME"
hostname
lsb_release -sd
uname -r
cat /sys/module/amdgpu/version
@@ -38,6 +57,30 @@ def show_node_info() {
"""
}
/**
 * Post a commit status to the ROCm/rocm-libraries GitHub repository for the
 * commit in env.GIT_COMMIT. Best effort: a failed POST is retried up to three
 * times and then only logged as a warning, so it never fails the build.
 *
 * @param context     status context shown on GitHub (callers pass the stage name)
 * @param state       state string sent as the "state" field (callers use
 *                    'pending', 'success', 'failure' and 'error')
 * @param description short text sent as the "description" field
 */
def setGithubStatus(String context, String state, String description) {
    def sha = env.GIT_COMMIT
    if (!sha) {
        // Without a commit there is nothing to attach the status to; posting to
        // .../statuses/null would only burn three retries.
        echo "WARNING: GIT_COMMIT is not set; skipping GitHub status (context=${context}, state=${state})"
        return
    }
    def targetUrl = env.RUN_DISPLAY_URL ?: env.BUILD_URL
    def statusUrl = "https://api.github.com/repos/ROCm/rocm-libraries/statuses/${sha}"

    // Serialize with JsonOutput so quotes, backslashes or control characters in
    // a stage name or description cannot produce invalid JSON.
    def body = [state: state, context: context, description: description]
    if (targetUrl) {
        body.target_url = targetUrl
    }
    def payload = groovy.json.JsonOutput.toJson(body)

    withCredentials([usernamePassword(credentialsId: 'github-app-miopen', usernameVariable: 'GITHUB_APP', passwordVariable: 'GITHUB_TOKEN')]) {
        // URL and payload reach the shell as environment variables, so no
        // Groovy-interpolated text is ever parsed as shell syntax.
        withEnv(["GH_STATUS_URL=${statusUrl}", "GH_STATUS_PAYLOAD=${payload}"]) {
            def code = '0'
            try {
                retry(3) {
                    code = sh(returnStdout: true, script: '''
                        curl -s -w "%{http_code}" -o /dev/null -X POST "$GH_STATUS_URL" \\
                            -H "Authorization: token $GITHUB_TOKEN" \\
                            -H 'Content-Type: application/json' \\
                            --data "$GH_STATUS_PAYLOAD"
                    ''').trim()
                    if (!code.startsWith('2')) {
                        error("GitHub status POST returned ${code}")
                    }
                }
            } catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e) {
                // An abort/timeout must propagate; only reporting failures are swallowed.
                throw e
            } catch (Exception e) {
                echo "WARNING: GitHub status POST failed after retries (context=${context}, state=${state}, code=${code})"
            }
        }
    }
}
def cloneUpdateRefRepo() {
def refRepoPath = "/var/jenkins/ref-repo/rocm-libraries"
def lockLabel = "git ref repo lock - ${env.NODE_NAME}"
@@ -78,7 +121,14 @@ def checkoutComposableKernel()
//update ref repo
cloneUpdateRefRepo()
// checkout project
checkout scm
def scmVars = checkout scm
// getGitHubCommitHash reads SCMRevisionAction recorded before any local merge,
// giving the true PR branch tip (pullHash) or branch HEAD (hash).
// Falls back to ORIG_HEAD (pre-merge HEAD set by git merge) when SCMRevisionAction
// is unavailable, then to HEAD for branch builds where no merge occurred.
// `--verify -q` is required for the fallback: a plain `git rev-parse ORIG_HEAD`
// echoes the literal text "ORIG_HEAD" to stdout before failing when the ref does
// not exist, which would leave GIT_COMMIT holding two lines.
env.GIT_COMMIT = getGitHubCommitHash(currentBuild.rawBuild) ?: sh(returnStdout: true, script: '''
git rev-parse --verify -q ORIG_HEAD || git rev-parse HEAD
''').trim()
}
def generateAndArchiveBuildTraceVisualization(String buildTraceFileName) {
@@ -672,6 +722,9 @@ def cmake_build(Map conf=[:]){
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
setup_args = " -D CK_EXPERIMENTAL_BUILDER=ON " + setup_args
}
if (params.RUN_ROCM_CK_TESTS) {
setup_args = " -D CK_ENABLE_ROCM_CK=ON " + setup_args
}
setup_cmd = conf.get(
"setup_cmd",
"""${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_FLAGS=" -O3 " .. """
@@ -788,6 +841,9 @@ def cmake_build(Map conf=[:]){
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
sh 'ninja check-builder'
}
if (params.RUN_ROCM_CK_TESTS) {
sh 'ninja check-rocm-ck'
}
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
@@ -827,6 +883,9 @@ def cmake_build(Map conf=[:]){
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
sh 'ninja check-builder'
}
if (params.RUN_ROCM_CK_TESTS) {
sh 'ninja check-rocm-ck'
}
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
@@ -840,8 +899,10 @@ def cmake_build(Map conf=[:]){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
archiveArtifacts "perf_fmha_*.log"
stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}"
dir("projects/composablekernel"){
archiveArtifacts "perf_fmha_*.log"
stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}"
}
}
catch(Exception err){
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
@@ -858,13 +919,19 @@ def buildHipClangJob(Map conf=[:]){
def retimage
(retimage, image) = getDockerImage(conf)
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
setGithubStatus("${env.STAGE_NAME}", 'pending', "Starting ${env.STAGE_NAME}")
try {
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 20, unit: 'HOURS')
{
cmake_build(conf)
}
}
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
}
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
throw e
}
return retimage
}
@@ -888,7 +955,8 @@ def Build_CK(Map conf=[:]){
def image
def retimage
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
setGithubStatus("${env.STAGE_NAME}", 'pending', "Starting ${env.STAGE_NAME}")
try {
try {
(retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) {
@@ -905,6 +973,7 @@ def Build_CK(Map conf=[:]){
}
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
echo "The job was cancelled or aborted"
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
throw e
}
withDockerContainer(image: image, args: dockerOpts) {
@@ -915,16 +984,10 @@ def Build_CK(Map conf=[:]){
cmake_build(conf)
if ( params.RUN_INDUCTOR_TESTS && arch == "gfx90a" ){
echo "Run inductor codegen tests"
sh """
python3 -m venv ${env.WORKSPACE}/projects/composablekernel
. ${env.WORKSPACE}/projects/composablekernel/bin/activate
python3 -m pip install pytest build setuptools setuptools_scm
python3 -m pip install .
python3 -m pytest python/test/test_gen_instances.py
"""
sh "projects/composablekernel/script/run_inductor_tests.sh"
}
// run performance tests, stash the logs, results will be processed on the master node
dir("projects/composablekernel/script"){
dir("projects/composablekernel/script"){
if (params.RUN_PERFORMANCE_TESTS){
if (params.RUN_FULL_QA && (arch == "gfx90a" || arch == "gfx942")){
// run full tests on gfx90a or gfx942
@@ -971,6 +1034,11 @@ def Build_CK(Map conf=[:]){
}
}
}
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
}
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
throw e
}
return retimage
}
@@ -991,7 +1059,8 @@ def process_results(Map conf=[:]){
//use older image that has user jenkins
def image = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm6.3"
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
setGithubStatus("${env.STAGE_NAME}", 'pending', 'Processing results...')
try {
try
{
echo "Pulling image: ${image}"
@@ -1005,6 +1074,10 @@ def process_results(Map conf=[:]){
error "Unable to locate image: ${image}"
}
}
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
throw e
}
withDockerContainer(image: image, args: '--cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v=/var/jenkins/:/var/jenkins') {
timeout(time: 15, unit: 'MINUTES'){
@@ -1023,6 +1096,13 @@ def process_results(Map conf=[:]){
catch(Exception err){
echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}."
}
try{
unstash "perf_fmha_log_gfx950"
}
catch(Exception err){
echo "could not locate the FMHA performance logs for gfx950: ${err.getMessage()}."
}
}
if (params.BUILD_INSTANCES_ONLY){
// unstash deb packages
@@ -1105,10 +1185,10 @@ def process_results(Map conf=[:]){
// process the logs
sh "./process_perf_data.sh"
}
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
}
catch(e){
echo "Throwing error exception while processing performance test results"
echo 'Exception occurred: ' + e.toString()
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
throw e
}
finally{
@@ -1123,7 +1203,8 @@ def run_downstream_tests(Map conf=[:]){
checkoutComposableKernel()
def dockerOpts = get_docker_options() + ' --group-add irc '
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
setGithubStatus("${env.STAGE_NAME}", 'pending', "Starting ${env.STAGE_NAME}")
try {
try
{
echo "Pulling image: ${conf.image}"
@@ -1137,6 +1218,10 @@ def run_downstream_tests(Map conf=[:]){
error "Unable to locate image: ${conf.image}"
}
}
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
throw e
}
withDockerContainer(image: conf.image, args: dockerOpts) {
timeout(time: conf.get("timeoutHours", 2), unit: 'HOURS'){
@@ -1146,10 +1231,12 @@ def run_downstream_tests(Map conf=[:]){
for (cmd in conf.execute_cmds) {
sh "${cmd}"
}
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
}
catch(e){
echo "Throwing error exception while running ${env.STAGE_NAME}"
echo 'Exception occurred: ' + e.toString()
setGithubStatus("${env.STAGE_NAME}", 'error', "Stage ${env.STAGE_NAME} failed")
throw e
}
finally{
@@ -1161,8 +1248,11 @@ def run_downstream_tests(Map conf=[:]){
def getPytorchTestsCmds() {
return [
"python3 /tmp/pytorch/tools/amd_build/build_amd.py",
"USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python /tmp/pytorch/setup.py develop"
"mkdir pytorch",
"cp -r /var/jenkins/workspace/pytorch/* pytorch/",
"ls -ltr pytorch",
"python3 pytorch/tools/amd_build/build_amd.py",
"cd pytorch && USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python3 setup.py develop"
]
}
def getAiterTestsCmds() {
@@ -1197,7 +1287,7 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_
0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;RUN_FA_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true
0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : ""
CURRENT_BRANCH_NAME = env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME
CURRENT_BRANCH_NAME = env.CHANGE_ID ? "refs/pull/${env.CHANGE_ID}/head" : (env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME)
POLL_SPEC = BRANCH_NAME == "develop" ? 'H H/6 * * *' : ''
@@ -1318,8 +1408,8 @@ pipeline {
description: "Build CK and run tests on gfx101 (default: OFF)")
booleanParam(
name: "BUILD_GFX103",
defaultValue: true,
description: "Build CK and run tests on gfx103 (default: ON)")
defaultValue: false,
description: "Build CK and run tests on gfx103 (default: OFF)")
booleanParam(
name: "BUILD_GFX11",
defaultValue: true,
@@ -1338,8 +1428,8 @@ pipeline {
description: "Generate a detailed time trace (default: OFF)")
booleanParam(
name: "RUN_INDUCTOR_TESTS",
defaultValue: false,
description: "Run inductor codegen tests (default: OFF)")
defaultValue: true,
description: "Run inductor codegen tests (default: ON)")
booleanParam(
name: "RUN_CODEGEN_TESTS",
defaultValue: true,
@@ -1348,6 +1438,10 @@ pipeline {
name: "RUN_BUILDER_TESTS",
defaultValue: false,
description: "Run CK_BUILDER tests (default: OFF)")
booleanParam(
name: "RUN_ROCM_CK_TESTS",
defaultValue: true,
description: "Run rocm_ck tests (default: ON)")
booleanParam(
name: "RUN_ALL_UNIT_TESTS",
defaultValue: false,
@@ -1404,9 +1498,9 @@ pipeline {
dbsshport = "${dbsshport}"
dbsshuser = "${dbsshuser}"
dbsshpassword = "${dbsshpassword}"
ck_git_creds = "${ck_git_creds}"
gerrit_cred="${gerrit_cred}"
DOCKER_BUILDKIT = "1"
BUILD_GFX103 = "${env.BRANCH_NAME == 'develop' ? true : false}"
}
stages{
stage("Determine CI Execution") {
@@ -1743,6 +1837,7 @@ pipeline {
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
-D BUILD_CK_TILE_ENGINE="ON" \
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx942" \
@@ -1756,7 +1851,7 @@ pipeline {
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
-D GEMM_PRESHUFFLE_CONFIG_FILE="default_ci_config.json" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
@@ -1785,6 +1880,7 @@ pipeline {
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
-D BUILD_CK_TILE_ENGINE="ON" \
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx942" \
@@ -1799,7 +1895,7 @@ pipeline {
-D GROUPED_GEMM_DATATYPE="fp8;fp16" \
-D GROUPED_GEMM_LAYOUT="rcr;rrr;crr;ccr" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all benchmark_grouped_gemm_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/grouped_gemm/grouped_gemm_benchmark.py . --problem-sizes "1024,1024,1024" --group-counts 8 --warmup 5 --repeat 5 --verbose --json grouped_gemm_results.json """
@@ -1819,6 +1915,7 @@ pipeline {
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
-D BUILD_CK_TILE_ENGINE="ON" \
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx950" \
@@ -1829,7 +1926,7 @@ pipeline {
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
@@ -1848,13 +1945,14 @@ pipeline {
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
-D BUILD_CK_TILE_ENGINE="ON" \
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx1201" \
-D GEMM_UNIVERSAL_DATATYPE="fp16" \
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
@@ -1904,6 +2002,7 @@ pipeline {
cleanWs()
}
}
/*
stage("Build CK and run Tests on gfx908")
{
when {
@@ -1920,6 +2019,7 @@ pipeline {
cleanWs()
}
}
*/
stage("Build CK and run Tests on gfx90a")
{
when {
@@ -2036,8 +2136,9 @@ pipeline {
}
success {
script {
// Report the parent stage build ck and run tests status
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
node(rocmnode("nogpu")) {
// Report the parent stage build ck and run tests status
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
echo "Reporting success status for build ck and run tests"
}
}
@@ -2063,12 +2164,9 @@ pipeline {
post {
success {
script {
// Report the skipped parent's stage status
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Process Performance Test Results", account: 'ROCm', repo: 'rocm-libraries') {
echo "Process Performance Test Results stage skipped."
}
// Report the skipped stage's status
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Process results", account: 'ROCm', repo: 'rocm-libraries') {
node(rocmnode("nogpu")) {
// Report the skipped parent's stage status
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
echo "Process Performance Test Results stage skipped."
}
}
@@ -2078,20 +2176,22 @@ pipeline {
}
post {
success {
githubNotify context: 'Math CI Summary',
status: 'SUCCESS',
description: 'All checks have passed'
script {
node(rocmnode("nogpu")) {
setGithubStatus('Math CI Summary', 'success', "Math CI passed")
}
}
}
failure {
githubNotify context: 'Math CI Summary',
status: 'FAILURE',
description: 'Some checks have failed'
node(rocmnode("nogpu")) {
script {
checkoutComposableKernel()
}
withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) {
sh 'bash projects/composablekernel/script/infra_helper/send_failure_notifications.sh'
script {
node(rocmnode("nogpu")) {
setGithubStatus('Math CI Summary', 'failure', "Math CI failed")
script {
checkoutComposableKernel()
}
withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) {
sh 'bash projects/composablekernel/script/infra_helper/send_failure_notifications.sh'
}
}
}
}

View File

@@ -124,6 +124,21 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
../script/cmake-ck-dev.sh .. gfx90a -DCMAKE_BUILD_TYPE=Release
```
**Fast iteration builds:**
For faster CMake configuration during development (~5s vs ~150s), use the `--minimal` flag to disable
building device instances, profiler, examples, tutorials, and tests:
```bash
../script/cmake-ck-dev.sh --minimal .. gfx90a
```
You can also specify a custom preset:
```bash
../script/cmake-ck-dev.sh --preset=dev-minimal .. gfx90a
```
5. Build the entire CK library:
```bash

View File

@@ -71,6 +71,7 @@ set(GTEST_CXX_FLAGS
-Wno-lifetime-safety-intra-tu-suggestions
-Wno-lifetime-safety-cross-tu-suggestions
-Wno-character-conversion
-Wno-lifetime-safety-invalidation
)
if(WIN32)
@@ -78,7 +79,8 @@ if(WIN32)
-Wno-suggest-destructor-override
-Wno-suggest-override
-Wno-nonportable-system-include-path
-Wno-language-extension-token)
-Wno-language-extension-token
-Wno-lifetime-safety-invalidation)
endif()
target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})

View File

@@ -21,6 +21,8 @@ endif()
add_library(ck_tile_dispatcher
src/registry.cpp
src/dispatcher.cpp
src/fmha_registry.cpp
src/fmha_dispatcher.cpp
)
# Enable PIC for Python bindings
@@ -34,13 +36,21 @@ target_include_directories(ck_tile_dispatcher
$<INSTALL_INTERFACE:include>
)
# Link against CK Tile headers (header-only)
# CK Tile core headers (ck_tile/core, ck_tile/ops, etc.)
target_include_directories(ck_tile_dispatcher
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include>
$<INSTALL_INTERFACE:include>
)
# CK project root -- needed only for FMHA generated wrappers that include
# "example/ck_tile/01_fmha/fmha_fwd.hpp". PRIVATE to avoid exposing the
# entire project tree to downstream consumers.
target_include_directories(ck_tile_dispatcher
PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/..>
)
# Link against HIP headers if available
if(hip_FOUND)
target_link_libraries(ck_tile_dispatcher PUBLIC hip::host)

View File

@@ -394,6 +394,12 @@ python3 examples/grouped_conv/python/03_bwd_data.py # Backward data +
python3 examples/grouped_conv/python/04_bwd_weight.py # Backward weight + CPU ref
python3 examples/grouped_conv/python/05_benchmark.py # Multi-problem benchmark
python3 examples/grouped_conv/python/06_registry_json.py # Heuristic selection + JSON
# FMHA Examples (JIT-compiled on the fly)
python3 examples/fmha/python/01_basic_fmha.py # Basic forward attention
python3 examples/fmha/python/12_masks_fmha.py # Causal masks
python3 examples/fmha/python/18_backward_fmha.py # Backward pass
python3 examples/fmha/python/16_splitkv_fmha.py # Split-KV for long sequences
```
### Example Output
@@ -716,7 +722,7 @@ This matrix shows all CK Tile operations with per-data-type, per-layout, and per
| GEMM | streamk_gemm<br>example: `40_streamk_gemm/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| Reduce | multi_reduce2d<br>example: `05_reduce/` | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
| Reduce | reduce2d<br>example: `05_reduce/` | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
| Attention | fmha<br>example: `01_fmha/` | | | | | ❌ | | | | | | | | | | ❌ |
| Attention | fmha<br>example: `01_fmha/` | | | | | ❌ | | | | | | | | | | ❌ |
| Attention | sparse_attn<br>example: `50_sparse_attn/` | ❌ | | ❌ | | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
| Activation | softmax | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
| Activation | topk_softmax<br>example: `09_topk_softmax/` | ❌ | ❌ | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
@@ -871,7 +877,14 @@ dispatcher/
| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder)
| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations
| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe)
| +---- grouped_conv_utils.hpp # Grouped conv utilities
| |---- grouped_conv_utils.hpp # Grouped conv utilities
| |---- fmha_types.hpp # FMHA fwd/bwd args and traits structs
| |---- fmha_problem.hpp # FmhaProblem, FmhaProblemBuilder
| |---- fmha_kernel_key.hpp # FmhaKernelKey (Signature + Algorithm)
| |---- fmha_kernel_instance.hpp # FmhaKernelInstance virtual interface
| |---- fmha_kernel_decl.hpp # Declarative FmhaSignature/FmhaAlgorithm
| |---- fmha_registry.hpp # FmhaRegistry (thread-safe)
| +---- fmha_dispatcher.hpp # FmhaDispatcher (plan, select, run)
|
|---- src/ # C++ implementation
|
@@ -879,12 +892,17 @@ dispatcher/
| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings
| |---- unified_gemm_codegen.py # GEMM kernel generator
| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator
| |---- unified_fmha_codegen.py # FMHA kernel generator
| |---- fmha_arch_specs.json # FMHA per-arch tile/pipeline specs
| |---- fmha_rules.py # FMHA validation rules
| |---- fmha_profiles.py # FMHA named profiles/receipts
| +---- arch_specs.json # GPU specifications
|
|---- python/ # Python utilities
| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output
| |---- ctypes_utils.py # GEMM ctypes utilities
| +---- grouped_conv_utils.py # Grouped conv utilities
| |---- grouped_conv_utils.py # Grouped conv utilities
| +---- fmha_utils.py # FMHA: JIT compile, FmhaRunner, FmhaKernelConfig
|
|---- scripts/ # Build scripts
| |---- compile_gemm_examples.py # GEMM build script
@@ -892,15 +910,19 @@ dispatcher/
|
|---- bindings/ctypes/ # Python ctypes interface
| |---- gemm_ctypes_lib.cpp # GEMM Python library
| +---- conv_ctypes_lib.cpp # Grouped conv Python library
| |---- conv_ctypes_lib.cpp # Grouped conv Python library
| +---- fmha_ctypes_lib.cpp # FMHA Python library
|
|---- examples/ # Examples
| |---- gemm/
| | |---- cpp/ # C++ GEMM examples (01-07)
| | +---- python/ # Python GEMM examples (01-11)
| +---- grouped_conv/
| |---- cpp/ # C++ Grouped Conv examples (01-07)
| +---- python/ # Python Grouped Conv examples (01-06)
| |---- grouped_conv/
| | |---- cpp/ # C++ Grouped Conv examples (01-07)
| | +---- python/ # Python Grouped Conv examples (01-06)
| +---- fmha/
| |---- cpp/ # C++ FMHA examples (01-35)
| +---- python/ # Python FMHA examples (01-38)
|
+---- tests/ # Unit tests (C++ and Python)
```
@@ -913,6 +935,8 @@ dispatcher/
|-----------|--------|
| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) |
| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) |
| FMHA C++ | examples/fmha/cpp/ (35 examples covering all FMHA variants) |
| FMHA Python | examples/fmha/python/ (38 examples with JIT compilation) |
| Codegen | [codegen/README.md](codegen/README.md) |
| Python Utils | [python/README.md](python/README.md) |
| C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) |

View File

@@ -10,6 +10,7 @@ bindings/
| |---- gemm_ctypes_lib.cpp # GEMM dispatcher C API
| |---- conv_ctypes_lib.cpp # Grouped conv dispatcher C API (fwd + bwd_data)
| |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library)
| |---- fmha_ctypes_lib.cpp # FMHA dispatcher C API (fwd + bwd)
| |---- gpu_helper.cpp # CLI helper for Python
| +---- CMakeLists.txt
+---- README.md

View File

@@ -129,7 +129,22 @@ float conv_bwdw_run(const void* input_ptr,
return -1.0f;
if(!input_ptr || !grad_output_ptr || !grad_weight_ptr)
return -1.0f; // Null data pointer would cause kernel crash
return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream);
try
{
return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream);
}
catch(const std::exception&)
{
// Kernel rejected args (e.g. unsupported tile/channel combo)
// -3.0f matches conv_ctypes_lib.cpp:316 convention
// -2.0f is reserved for "no kernel / not compiled for this direction"
return -3.0f;
}
catch(...)
{
return -3.0f;
}
#else
return -1.0f;
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -81,7 +81,9 @@
"warp_configs": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1]
[4, 1, 1],
[8, 2, 1],
[4, 4, 1]
],
"warp_tile_combos": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
@@ -256,8 +258,8 @@
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]]
},
"gfx950": {
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16], [32, 32, 32], [16, 16, 64]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16], [32, 32, 32], [16, 16, 64]],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
}

View File

@@ -1,11 +1,10 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
Generated from: arch_specs.json
Generated at: 2026-01-05T19:34:01.224422
Generated at: 2026-04-10T20:07:11.665064
To update this file:
1. Edit arch_specs.json
@@ -50,7 +49,7 @@ WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {
"gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1], [8, 2, 1], [4, 4, 1]],
"gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
"gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
"gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
@@ -226,6 +225,8 @@ PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
[32, 32, 32],
[16, 16, 64],
],
"bf16_bf16_fp32": [
[32, 32, 8],
@@ -233,6 +234,8 @@ PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
[32, 32, 32],
[16, 16, 64],
],
"fp8_fp8_fp32": [
[32, 32, 16],

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: MIT
"""
Shared codegen infrastructure for GEMM and grouped convolution code generators.
Shared codegen infrastructure for GEMM, grouped convolution, and FMHA code generators.
Extracted from unified_gemm_codegen.py + arch-aware expansion helpers from conv.
Both unified_gemm_codegen.py and unified_grouped_conv_codegen.py import from here

View File

@@ -0,0 +1,4 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""FMHA codegen subpackage — tile specs, instance generation, symbol mapping, and C++ codegen."""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,175 @@
{
"_comment": "FMHA-specific architecture specs. Edit this file to add new GPU/dtype/pipeline support for FMHA.",
"_note": "Common GPU hardware data (element_sizes, warp_size, warp_configs, lds_capacity_kb) lives in ../arch_specs.json. This file holds FMHA-specific capabilities, tile tables, and validation rules.",
"architectures": {
"gfx90a": {
"family": "cdna2",
"arch_tag": "ck_tile::gfx9_t",
"supported_dtypes": ["fp16", "bf16", "fp32"],
"supported_pipelines": ["qr", "qr_async", "qs", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
"supports_trload": false,
"supports_v3": false
},
"gfx942": {
"family": "cdna3",
"arch_tag": "ck_tile::gfx9_t",
"supported_dtypes": ["fp16", "bf16", "fp32", "fp8", "fp8fp16", "fp8bf16", "fp8fp32", "bf8"],
"supported_pipelines": ["qr", "qr_async", "qs", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
"supports_trload": false,
"supports_v3": false
},
"gfx950": {
"family": "cdna4",
"arch_tag": "ck_tile::gfx9_t",
"supported_dtypes": ["fp16", "bf16", "fp32", "fp8", "fp8fp16", "fp8bf16", "fp8fp32", "bf8"],
"supported_pipelines": ["qr", "qr_async", "qs", "qr_async_trload", "qr_async_trload_v3", "v3", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
"supports_trload": true,
"supports_v3": true
},
"gfx1100": {
"family": "rdna3",
"arch_tag": "ck_tile::gfx11_t",
"supported_dtypes": ["fp16", "bf16"],
"supported_pipelines": ["qr", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
"supports_trload": false,
"supports_v3": false
},
"gfx1201": {
"family": "rdna4",
"arch_tag": "ck_tile::gfx12_t",
"supported_dtypes": ["fp16", "bf16", "fp8", "fp8bf16"],
"supported_pipelines": ["qr", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
"supports_trload": false,
"supports_v3": false
}
},
"supported_hdims": {
"_comment": "hdim_q must satisfy ceil_to_qualified_tile_length() in tile_fmha_shape.hpp. Each entry is [hdim_q, hdim_v].",
"fp16": [[32,32], [64,64], [80,96], [96,128], [128,128], [160,160], [192,128], [192,192], [256,256]],
"bf16": [[32,32], [64,64], [80,96], [96,128], [128,128], [160,160], [192,128], [192,192], [256,256]],
"fp32": [[32,32], [48,48], [64,64], [96,128], [128,128], [192,192], [256,256]],
"fp8": [[64,64], [128,128], [192,128], [256,256]],
"fp8bf16": [[64,64], [128,128], [192,128], [256,256]],
"fp8fp32": [[128,128]],
"bf8": [[64,64], [128,128], [192,128], [256,256]],
"mxfp8": [[128,128], [256,256]],
"mxfp4": [[128,128], [256,256]]
},
"fmha_warp_tiles": {
"_comment": "FMHA warp tile sizes [wm0, wn0, wk0] per FMHA dtype. Subset of MFMA/WMMA instructions relevant to attention.",
"fp16": [[32,32,16], [16,16,32]],
"bf16": [[32,32,16], [16,16,32]],
"fp32": [[16,16,16]],
"fp8": [[32,32,32]],
"fp8bf16": [[32,32,32]],
"fp8fp32": [[32,32,32]],
"bf8": [[32,32,32]],
"mxfp8": [[32,32,64]],
"mxfp4": [[16,16,128]]
},
"fmha_element_sizes": {
"_comment": "FMHA-specific element sizes for composite dtypes not in parent arch_specs.json. Common dtypes (fp16, bf16, fp32, fp8, bf8) use ../arch_specs.json element_sizes.",
"fp8bf16": 1,
"fp8fp32": 1,
"mxfp8": 1,
"mxfp4": 1
},
"tile_sweep_ranges": {
"_comment": "Block tile dimensions to sweep. Must be multiples of warp tile sizes.",
"valid_bm0": [16, 32, 64, 128, 192, 256],
"valid_bn0": [16, 32, 64, 96, 128, 192, 256, 384],
"valid_bk0": [16, 32, 64, 96, 128, 256]
},
"k0max_map": {
"_comment": "Maps hdim_q -> padded K-tile length. Source: tile_fmha_shape.hpp ceil_to_qualified_tile_length().",
"32": 32, "48": 48, "64": 64, "80": 96, "96": 128,
"128": 128, "160": 256, "192": 192, "256": 256
},
"lds_limits": {
"_comment": "LDS budget in bytes per non-async FMHA pipeline. Async pipelines compute LDS dynamically.",
"qr": 65536,
"qs": 65536
},
"global_rules": {
"hdim_192_128_no_bias_dropout": true,
"logits_requires_no_bias": true,
"group_mode_requires_padding": true,
"hdim_divisible_by": 8
},
"defaults": {
"tile": [128, 64, 32, 128, 32, 128],
"wave": [2, 2, 1, 2, 2, 1, 1, 1, 1],
"warp": [32, 32, 16, 32, 32, 16, 16, 16, 16],
"padding": [true, true, true, true],
"block_per_cu": 1,
"num_wave_groups": 1,
"selection_rank": 0
},
"splitkv_combine": {
"combine_bn1": 32,
"hdims_fp16": [32, 64, 96, 128, 256],
"hdims_fp8": [64, 128, 256]
},
"batch_prefill": {
"supported_page_sizes": [1, 16, 1024],
"supported_kv_memory_layouts": ["vectorized", "linear"],
"supported_kv_lookup_tables": ["vllm", "sglang"]
},
"bwd_tiles": {
"_comment": "BWD dq_dk_dv tile tables. Format: [bm0, bn0, bk0, bn1, bk1, bk0max, tile6, tile7, tile8].",
"dq_dk_dv_fp16": {
"32_32": [32, 128, 32, 32, 32, 32, 64, 32, 32],
"64_64": [32, 128, 64, 32, 64, 32, 32, 64, 64],
"96_128": [32, 128, 96, 32, 96, 32, 32, 96, 96],
"128_128": [16, 128, 128, 16, 128, 16, 32, 128, 128],
"256_256": [16, 64, 256, 16, 256, 16, 32, 256, 256]
},
"dq_dk_dv_extra": {
"64_64": [
{"tile": [32, 128, 64, 32, 64, 32, 32, 64, 64], "tag": "trload", "batch_only": false},
{"tile": [32, 16, 64, 32, 64, 32, 16, 64, 64], "tag": "small", "batch_only": true}
],
"128_128": [
{"tile": [16, 16, 128, 16, 128, 16, 16, 128, 128], "tag": "small", "batch_only": true},
{"tile": [16, 192, 128, 16, 128, 16, 32, 128, 128], "tag": "bn192", "batch_only": false},
{"tile": [32, 128, 128, 32, 128, 32, 32, 128, 128], "tag": "trload", "batch_only": false}
]
},
"dot_do_o_hdims": [32, 64, 96, 128, 256],
"convert_dq_hdims": [32, 64, 96, 128, 256],
"convert_dq_tile_groups": {"32": 1, "64": 1, "96": 1, "128": 1, "256": 1},
"pad_combos": [["f","f"], ["f","t"], ["f","8"], ["t","f"], ["t","t"], ["t","8"], ["8","8"]],
"extra_pad_combos": [["f","f"], ["8","8"]],
"dropouts": ["no", "dropout_wg16", "dropout_wg16_storerandval"],
"small_dropouts": ["no"]
},
"bwd_wave_warp": {
"_comment": "BWD wave/warp lookup. Key: 'bm0_bn0_bk0_trload'. Value: {wave: 9-tuple, warp_k1: int}.",
"16_16_128_t": {"wave": [1,1,1,1,1,1,1,1,1], "warp_k1": 16},
"16_64_256_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16},
"16_128_128_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16},
"16_192_128_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16},
"32_16_64_t": {"wave": [1,1,1,1,1,1,1,1,1], "warp_k1": 16},
"32_128_32_f": {"wave": [1,4,1,4,1,1,2,2,1], "warp_k1": 16},
"32_128_64_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16},
"32_128_64_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 32},
"32_128_96_f": {"wave": [1,4,1,4,1,1,2,2,1], "warp_k1": 16},
"32_128_128_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 32},
"64_128_32_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16},
"64_128_64_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16},
"64_128_128_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16}
}
}

View File

@@ -0,0 +1,261 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Generate FMHA fallback kernel + dispatch header for the Python ctypes library.
Mirrors generate_conv_dispatch_header.py: generates a single FMHA forward
kernel and creates a dispatch header that can be force-included into
fmha_ctypes_lib.cpp.
Usage:
python3 generate_fmha_fallback.py --output-dir /path/to/output --gpu-target gfx950
"""
import argparse
import json
import subprocess
import sys
from pathlib import Path
# Fallback kernel description handed to the codegen script when the caller
# supplies no (or only a partial) --config-json: one forward fp16 kernel,
# batch mode, hdim 128x128, qr_async pipeline, meant for basic smoke testing.
# Tile dims per the original note: fmha_fwd.py FmhaFwdTileSize, hdim=128 fp16.
DEFAULT_CONFIG = dict(
    # Overwritten with --gpu-target by main() before the config is used.
    arch="gfx950",
    # What the kernel computes (problem signature).
    signature=dict(
        family="fwd",
        data_type="fp16",
        mode="batch",
        vlayout="r",
        hdim_q=128,
        hdim_v=128,
        mask="no",
        bias="no",
        lse=False,
        dropout=False,
        qscale="no",
        rope="none",
        logits=False,
        paged_kv=False,
        fp8_static_quant=False,
        skip_min_seqlen_q=False,
        sink=False,
        dbias=False,
        store_randval=False,
        deterministic=False,
        kv_memory_layout="vectorized",
        kv_lookup_table="sglang",
        page_size=1,
    ),
    # How the kernel is built (pipeline and tiling choices).
    algorithm=dict(
        pipeline="qr_async",
        tile=[128, 128, 32, 128, 32, 128],
        wave=[4, 1, 1, 4, 1, 1, 1, 1, 1],
        warp=[32, 32, 16, 32, 32, 16, 16, 16, 16],
        padding=[True, True, True, True],
        block_per_cu=1,
        num_wave_groups=1,
        max_splits_log2=0,
        max_seq_len_q=0,
    ),
)
def generate_dispatch_header(output_dir: Path, wrapper_dir: Path) -> Path:
    """Write ``fmha_python_dispatch.hpp`` covering every FMHA wrapper header.

    Args:
        output_dir: Directory that receives ``fmha_python_dispatch.hpp``.
        wrapper_dir: Directory scanned for ``dispatcher_wrapper_fmha_*.hpp``.

    Returns:
        Path of the header that was written.

    Raises:
        RuntimeError: If ``wrapper_dir`` contains no matching wrapper header.
    """
    wrapper_files = sorted(wrapper_dir.glob("dispatcher_wrapper_fmha_*.hpp"))
    if not wrapper_files:
        raise RuntimeError(f"No FMHA dispatcher wrappers found in {wrapper_dir}")

    # Kernel name == wrapper file stem with the "dispatcher_wrapper_" text removed.
    kernel_names = [
        path.stem.replace("dispatcher_wrapper_", "") for path in wrapper_files
    ]
    kernel_count = len(kernel_names)
    quoted_names = ", ".join(f'"{name}"' for name in kernel_names)

    prologue = [
        "// Auto-generated FMHA dispatch header for Python ctypes library",
        "#pragma once",
        "",
    ]
    # The include path prefix is fixed to "dispatcher_wrappers/" regardless of
    # where wrapper_dir actually lives.
    wrapper_includes = [
        f'#include "dispatcher_wrappers/{path.name}"' for path in wrapper_files
    ]
    registration_open = [
        "",
        '#include "ck_tile/dispatcher/fmha_registry.hpp"',
        '#include "ck_tile/dispatcher/fmha_dispatcher.hpp"',
        "",
        "namespace generated {",
        "",
        "inline void register_fmha_python_kernels("
        "ck_tile::dispatcher::FmhaRegistry& registry, const std::string& arch) {",
        " (void)arch;",
    ]
    # One register_kernel(...) statement per discovered kernel.
    register_calls = [
        " registry.register_kernel("
        f"ck_tile::dispatcher::generated::make_{name}(arch));"
        for name in kernel_names
    ]
    epilogue = [
        "}",
        "",
        "} // namespace generated",
        "",
        "#ifndef REGISTER_GENERATED_KERNELS",
        "#define REGISTER_GENERATED_KERNELS(registry, arch) "
        "::generated::register_fmha_python_kernels(registry, arch)",
        "#endif",
        "",
        "// Stable C ABI for dlopen/dlsym-based kernel registration.",
        '// Plugins call dlsym(handle, "ck_fmha_register_kernels") to get this.',
        'extern "C" __attribute__((visibility("default")))',
        "int ck_fmha_register_kernels(",
        " ck_tile::dispatcher::FmhaRegistry& registry, const char* arch) {",
        " ::generated::register_fmha_python_kernels(registry, arch ? std::string(arch) : std::string());",
        f" return {kernel_count};",
        "}",
        "",
        "// Kernel inventory for Python introspection",
        f"static const int FMHA_KERNEL_COUNT = {kernel_count};",
        "static const char* FMHA_KERNEL_NAMES[] = {" + quoted_names + "};",
        "",
    ]

    sections = (prologue, wrapper_includes, registration_open, register_calls, epilogue)
    header_text = "\n".join(line for section in sections for line in section) + "\n"

    header_path = output_dir / "fmha_python_dispatch.hpp"
    header_path.write_text(header_text)
    return header_path
def compile_kernels(output_dir: Path, gpu_target: str, include_dirs: str) -> Path:
    """Build every ``fmha_*.cpp`` in ``output_dir`` and archive the objects.

    Args:
        output_dir: Directory holding the generated kernel sources; object
            files and ``libfmha_python_fallback.a`` are written here as well.
        gpu_target: GPU architecture forwarded to ``fmha_compile_flags``.
        include_dirs: Include directories separated by ``;`` or ``:``;
            blank entries are ignored.

    Returns:
        Path of the static library produced by ``ar``.

    Raises:
        RuntimeError: If there is nothing to compile or a source fails to build.
        subprocess.CalledProcessError: If the ``ar`` invocation fails.
    """
    import re
    import shutil

    sources = sorted(output_dir.glob("fmha_*.cpp"))
    # Prefer hipcc from PATH, otherwise fall back to the default ROCm location.
    hipcc = shutil.which("hipcc") or "/opt/rocm/bin/hipcc"
    if not sources:
        raise RuntimeError("No kernel .cpp files to compile")

    # The shared compile flags live in fmha_utils, two directory levels up
    # under python/, so that directory is put at the front of sys.path first.
    sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "python"))
    from fmha_utils import fmha_compile_flags  # noqa: E402

    # NOTE(review): flags are requested for family="bwd" although DEFAULT_CONFIG
    # describes a "fwd" kernel -- confirm this is intentional.
    compiler_cmd = fmha_compile_flags(gpu_target, hipcc, family="bwd")

    stripped_dirs = (entry.strip() for entry in re.split(r"[;:]", include_dirs))
    include_flags = [token for entry in stripped_dirs if entry for token in ("-I", entry)]

    object_files = []
    for source in sources:
        object_file = source.with_suffix(".o")
        print(f" Compiling {source.name}...")
        completed = subprocess.run(
            compiler_cmd + include_flags + [str(source), "-o", str(object_file)],
            capture_output=True,
            text=True,
        )
        if completed.returncode != 0:
            # Show the compiler diagnostics before aborting the whole build.
            print(f" FAILED: {completed.stderr}", file=sys.stderr)
            raise RuntimeError(f"Failed to compile {source.name}")
        object_files.append(str(object_file))

    lib_path = output_dir / "libfmha_python_fallback.a"
    subprocess.check_call(["ar", "rcs", str(lib_path)] + object_files)
    print(f" Static lib: {lib_path}")
    return lib_path
def _merge_with_defaults(overrides: dict, gpu_target: str) -> dict:
    """Overlay ``overrides`` on a fresh copy of ``DEFAULT_CONFIG``.

    ``signature`` and ``algorithm`` are merged key-by-key so a partial
    override (e.g. only ``hdim_q``) keeps every other default of that section.
    Any other top-level key, including ``arch``, is replaced wholesale.
    """
    nested_sections = ("signature", "algorithm")

    config = dict(DEFAULT_CONFIG)
    config["arch"] = gpu_target
    for section in nested_sections:
        # Copy so the module-level defaults are never mutated.
        config[section] = dict(DEFAULT_CONFIG[section])

    for key, value in overrides.items():
        if key in nested_sections and isinstance(value, dict):
            config[key].update(value)
        else:
            config[key] = value
    return config


def main():
    """Generate the FMHA fallback kernel sources and dispatch header.

    Runs ``codegen.py`` (located next to this script) with either the default
    config, a single config merged with the defaults, or a caller-supplied list
    of configs; then writes the dispatch header and optionally compiles the
    generated kernels into a static library.

    Returns:
        0 on success, 1 if codegen fails or produces no ``dispatcher_wrappers``
        directory.
    """
    parser = argparse.ArgumentParser(
        description="Generate FMHA fallback kernel for Python library"
    )
    parser.add_argument("--output-dir", required=True, type=Path)
    parser.add_argument("--gpu-target", default="gfx950")
    parser.add_argument(
        "--config-json",
        default=None,
        help="Override default kernel config (JSON string)",
    )
    parser.add_argument(
        "--compile", action="store_true", help="Also compile the kernel .cpp into a .a"
    )
    parser.add_argument(
        "--include-dirs",
        default="",
        help="Include directories for compilation, separated by ';' or ':'",
    )
    args = parser.parse_args()

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    codegen_dir = Path(__file__).parent
    codegen_script = codegen_dir / "codegen.py"

    # Accept either a single config dict or a list of configs.
    if args.config_json:
        parsed = json.loads(args.config_json)
        if isinstance(parsed, list):
            # Multi-config: forwarded to the codegen script untouched.
            codegen_input = parsed
        elif isinstance(parsed, dict):
            # Single config: overlay it on the defaults.
            codegen_input = _merge_with_defaults(parsed, args.gpu_target)
        else:
            parser.error("--config-json must be a JSON object or a list of objects")
    else:
        codegen_input = _merge_with_defaults({}, args.gpu_target)

    print(f"Generating FMHA fallback kernel for {args.gpu_target}...")
    print(f" Output: {output_dir}")

    cmd = [
        sys.executable,
        str(codegen_script),
        "--output-dir",
        str(output_dir),
        "--gpu-target",
        args.gpu_target,
        "--config-json",
        json.dumps(codegen_input),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(codegen_dir))
    if result.returncode != 0:
        print(f" Codegen FAILED:\n{result.stderr}", file=sys.stderr)
        return 1

    wrapper_dir = output_dir / "dispatcher_wrappers"
    if not wrapper_dir.exists():
        print(" ERROR: No dispatcher_wrappers dir created", file=sys.stderr)
        return 1

    header_path = generate_dispatch_header(output_dir, wrapper_dir)
    print(f" Dispatch header: {header_path}")

    kernel_cpps = list(output_dir.glob("fmha_*.cpp"))
    print(f" Kernel TUs: {len(kernel_cpps)}")

    if args.compile and kernel_cpps:
        compile_kernels(output_dir, args.gpu_target, args.include_dirs)

    return 0


if __name__ == "__main__":
    sys.exit(main())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,333 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import json
import hashlib
# Architecture tag → C++ arch trait type.
# Source: CK's include/ck_tile/core/arch/amd_gpu_traits.hpp
# gfx9* → gfx9_t, gfx11* → gfx11_t, gfx12* → gfx12_t.
# Only the targets listed here have an entry; ARCH_PREPROC_MAP below carries
# the same set of keys.
ARCH_TAG_MAP = {
    "gfx90a": "ck_tile::gfx9_t",
    "gfx942": "ck_tile::gfx9_t",
    "gfx950": "ck_tile::gfx9_t",
    "gfx1100": "ck_tile::gfx11_t",
    "gfx1201": "ck_tile::gfx12_t",
}
# Architecture → preprocessor guard for conditional compilation.
# Source: HIP compiler predefined macros (__gfx90a__, __gfx942__, etc.).
# Each value is a complete `defined(...)` expression.
ARCH_PREPROC_MAP = {
    "gfx90a": "defined(__gfx90a__)",
    "gfx942": "defined(__gfx942__)",
    "gfx950": "defined(__gfx950__)",
    "gfx1100": "defined(__gfx1100__)",
    "gfx1201": "defined(__gfx1201__)",
}
# Forward dtype → C++ type config struct.
# Source: example/ck_tile/01_fmha/fmha_fwd.hpp FmhaFwdTypeConfig<> specializations
# and codegen/cpp_symbol_map.py FWD_DTYPE_MAP.
# Composite keys such as "fp8bf16" are single dtype tags with their own struct.
FWD_DTYPE_MAP = {
    "fp32": "FmhaFwdFp32",
    "fp16": "FmhaFwdFp16",
    "bf16": "FmhaFwdBf16",
    "fp8": "FmhaFwdFp8",
    "bf8": "FmhaFwdBf8",
    "fp8fp16": "FmhaFwdFp8Fp16",
    "fp8bf16": "FmhaFwdFp8Bf16",
    "fp8fp32": "FmhaFwdFp8Fp32",
}
# Backward dtype → C++ type config struct.
# Source: example/ck_tile/01_fmha/fmha_bwd.hpp FmhaBwdTypeConfig<> specializations.
# BWD currently only supports fp16/bf16/fp32.
BWD_DTYPE_MAP = {
    "fp32": "FmhaBwdFp32",
    "fp16": "FmhaBwdFp16",
    "bf16": "FmhaBwdBf16",
}
# Kernel family → C++ enum.
# Source: include/ck_tile/dispatcher/fmha_types.hpp FmhaKernelFamily enum.
# One entry per generated kernel kind (the backward pass has three separate kinds).
KERNEL_FAMILY_TO_ENUM = {
    "fwd": "FmhaKernelFamily::Fwd",
    "fwd_pagedkv": "FmhaKernelFamily::FwdPagedKv",
    "fwd_splitkv": "FmhaKernelFamily::FwdSplitKv",
    "fwd_splitkv_combine": "FmhaKernelFamily::FwdSplitKvCombine",
    "fwd_appendkv": "FmhaKernelFamily::FwdAppendKv",
    "batch_prefill": "FmhaKernelFamily::BatchPrefill",
    "bwd_dot_do_o": "FmhaKernelFamily::BwdDotDoO",
    "bwd_dq_dk_dv": "FmhaKernelFamily::BwdDqDkDv",
    "bwd_convert_dq": "FmhaKernelFamily::BwdConvertDq",
}
# API family → C++ enum.
# Source: include/ck_tile/dispatcher/fmha_types.hpp FmhaApiFamily enum.
# Unlike KERNEL_FAMILY_TO_ENUM, this table has a single "bwd" key and no entry
# for "fwd_splitkv_combine" (presumably those kernels are reached through the
# "bwd" / "fwd_splitkv" APIs — confirm against fmha_types.hpp).
API_FAMILY_TO_ENUM = {
    "fwd": "FmhaApiFamily::Fwd",
    "fwd_pagedkv": "FmhaApiFamily::FwdPagedKv",
    "fwd_splitkv": "FmhaApiFamily::FwdSplitKv",
    "fwd_appendkv": "FmhaApiFamily::FwdAppendKv",
    "batch_prefill": "FmhaApiFamily::BatchPrefill",
    "bwd": "FmhaApiFamily::Bwd",
}
# Mask type → canonical form and C++ types.
# Source: include/ck_tile/ops/fmha/block/block_attention_mask.hpp
# SimplifiedGenericAttentionMask<is_causal> and GenericAttentionMask<has_mask, has_local>.
# Every accepted spelling collapses to one of four canonical tags:
# "no", "top_left", "bottom_right", "generic". Note that "causal" means "top_left".
MASK_CANONICAL = {
    "no": "no",
    "no_mask": "no",
    "causal": "top_left",
    "top_left": "top_left",
    "t": "top_left",
    "bottom_right": "bottom_right",
    "b": "bottom_right",
    "generic": "generic",
    "window_generic": "generic",
    "g": "generic",
}
# Canonical mask tag → C++ mask type. "top_left" and "bottom_right" share the
# same C++ type here; only MASK_TO_INT tells them apart.
MASK_TO_CPP = {
    "no": "ck_tile::SimplifiedGenericAttentionMask<false>",
    "top_left": "ck_tile::SimplifiedGenericAttentionMask<true>",
    "bottom_right": "ck_tile::SimplifiedGenericAttentionMask<true>",
    "generic": "ck_tile::GenericAttentionMask<true, true>",
}
# Canonical mask tag → FmhaMasks alias (alternative spelling of the mask type;
# where each table is used is not visible in this file).
MASK_TO_CPP_GENERIC = {
    "no": "FmhaMasks::NoMask",
    "top_left": "FmhaMasks::CausalMask",
    "bottom_right": "FmhaMasks::CausalMask",
    "generic": "FmhaMasks::GenericMask",
}
# Canonical mask tag → integer code.
MASK_TO_INT = {
    "no": 0,
    "top_left": 1,
    "bottom_right": 2,
    "generic": 3,
}
# Bias type → canonical form and C++ enum.
# Source: include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp.
# Canonical tags are "no", "bias" (elementwise) and "alibi".
BIAS_CANONICAL = {
    "no": "no",
    "no_bias": "no",
    "bias": "bias",
    "elementwise": "bias",
    "elementwise_bias": "bias",
    "alibi": "alibi",
}
# Canonical bias tag → C++ enum value.
BIAS_TO_CPP = {
    "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
    "bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
    "alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI",
}
# Canonical bias tag → integer code.
BIAS_TO_INT = {
    "no": 0,
    "bias": 1,
    "alibi": 2,
}
# Quantization scale type → canonical form and C++ enum.
# Source: include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp.
# The only alias is "no_scale"; all other spellings are already canonical.
QSCALE_CANONICAL = {
    "no": "no",
    "no_scale": "no",
    "pertensor": "pertensor",
    "blockscale": "blockscale",
    "kv_blockscale": "kv_blockscale",
}
# Canonical quant-scale tag → C++ enum value.
QSCALE_TO_CPP = {
    "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
    "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
    "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
    "kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
}
# Canonical quant-scale tag → integer code.
QSCALE_TO_INT = {
    "no": 0,
    "pertensor": 1,
    "blockscale": 2,
    "kv_blockscale": 3,
}
# Rotary embedding type → canonical form and C++ enum.
# Source: include/ck_tile/ops/fmha/block/rotary_embedding_enum.hpp.
# Canonical tags are "none", "inter" (interleaved) and "half" (half-rotated).
ROPE_CANONICAL = {
    "none": "none",
    "no": "none",
    "inter": "inter",
    "interleaved": "inter",
    "half": "half",
    "half_rotated": "half",
}
# Canonical rope tag → C++ enum value.
ROPE_TO_CPP = {
    "none": "ck_tile::RotaryEmbeddingEnum::NONE",
    "inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
    "half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED",
}
# Canonical rope tag → integer code.
ROPE_TO_INT = {
    "none": 0,
    "inter": 1,
    "half": 2,
}
# V layout → C++ bool (true = row-major, false = column-major).
# Source: TileFmhaShape<..., IsVLayoutRowMajor> template parameter.
# Values are the C++ literals as strings, not Python bools.
LAYOUT_TO_BOOL = {
    "r": "true",
    "row": "true",
    "row_major": "true",
    "c": "false",
    "col": "false",
    "col_major": "false",
}
# KV cache memory layout → canonical form and C++ enum.
# Source: include/ck_tile/ops/fmha/block/block_attention_kv_cache.hpp.
# No aliases are defined yet: each key maps to itself.
KV_MEMORY_LAYOUT_CANONICAL = {
    "vectorized": "vectorized",
    "linear": "linear",
}
# Canonical KV memory layout → C++ enum value.
KV_MEMORY_LAYOUT_TO_CPP = {
    "vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
    "linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
}
# Canonical KV memory layout → integer code.
KV_MEMORY_LAYOUT_TO_INT = {
    "vectorized": 0,
    "linear": 1,
}
# KV lookup table type → canonical form and C++ enum.
# Source: include/ck_tile/ops/fmha/block/block_attention_kv_cache.hpp.
# No aliases are defined yet: each key maps to itself.
KV_LOOKUP_CANONICAL = {
    "sglang": "sglang",
    "vllm": "vllm",
}
# Canonical KV lookup tag → C++ enum value.
KV_LOOKUP_TO_CPP = {
    "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
    "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
}
# Canonical KV lookup tag → integer code. The codes do not follow the listing
# order of the tables above: "vllm" is 0 and "sglang" is 1.
KV_LOOKUP_TO_INT = {
    "vllm": 0,
    "sglang": 1,
}
# Pipeline tag → C++ pipeline class.
# Source: include/ck_tile/ops/fmha/pipeline/ — one header per pipeline variant.
# "v3" and "qr_async_trload_v3" are two spellings of the same pipeline class.
PIPELINE_TO_CPP = {
    "qr": "ck_tile::BlockFmhaPipelineQRKSVS",
    "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
    "qs": "ck_tile::BlockFmhaPipelineQSKSVS",
    "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
    "v3": "ck_tile::BlockFmhaFwdV3Pipeline",
    "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
    "qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS",
    "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
    "appendkv": "ck_tile::BlockFmhaFwdAppendKVPipeline",
    "batch_prefill_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
# Pipeline tag → C++ pipeline enum value.
# Source: include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp.
# Several tags share an enum value (e.g. "qr", "qr_pagedkv" and
# "qr_nwarp_sshuffle" all map to QRKSVS). There is no "appendkv" key here, so
# looking that tag up in this table raises KeyError.
PIPELINE_ENUM_TO_CPP = {
    "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
    "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
    "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
    "v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
    "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
    "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "batch_prefill_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
# Boolean spelling → C++ literal. Accepts Python bools as well as the
# single-letter "t"/"f" flags.
BOOL_MAP = {
    True: "true",
    False: "false",
    "t": "true",
    "f": "false",
}
def canonical_mask(value: str) -> str:
    """Map a mask spelling to its canonical tag; unknown spellings pass through."""
    if value in MASK_CANONICAL:
        return MASK_CANONICAL[value]
    return value


def canonical_bias(value: str) -> str:
    """Map a bias spelling to its canonical tag; unknown spellings pass through."""
    if value in BIAS_CANONICAL:
        return BIAS_CANONICAL[value]
    return value


def canonical_qscale(value: str) -> str:
    """Map a quant-scale spelling to its canonical tag; unknown spellings pass through."""
    if value in QSCALE_CANONICAL:
        return QSCALE_CANONICAL[value]
    return value


def canonical_rope(value: str) -> str:
    """Map a rotary-embedding spelling to its canonical tag; unknown spellings pass through."""
    if value in ROPE_CANONICAL:
        return ROPE_CANONICAL[value]
    return value


def canonical_kv_memory_layout(value: str) -> str:
    """Map a KV memory-layout spelling to its canonical tag; unknown spellings pass through."""
    if value in KV_MEMORY_LAYOUT_CANONICAL:
        return KV_MEMORY_LAYOUT_CANONICAL[value]
    return value


def canonical_kv_lookup(value: str) -> str:
    """Map a KV lookup-table spelling to its canonical tag; unknown spellings pass through."""
    if value in KV_LOOKUP_CANONICAL:
        return KV_LOOKUP_CANONICAL[value]
    return value
def sanitize_token(value) -> str:
    """Stringify ``value`` and replace ``::``, ``/`` and spaces with underscores.

    The separators are replaced in that order, so ``::`` is consumed as one
    unit before anything else is touched.
    """
    token = str(value)
    for separator in ("::", "/", " "):
        token = token.replace(separator, "_")
    return token
def kernel_name_from_config(config: dict) -> str:
    """Derive a kernel name from a config's signature and algorithm.

    The name has the form
    ``fmha_<family>_<dtype>_<mode>_h<hdim_q>x<hdim_v>_<pipeline>_<digest>``,
    where ``<digest>`` is the first 12 hex characters of a SHA-1 over a
    key-sorted JSON dump of the sanitized/canonicalized signature fields plus
    the raw ``signature`` and ``algorithm`` dicts.

    Because the raw dicts are part of the hashed payload, two configs that
    differ only in an alias spelling (e.g. "causal" vs "top_left") produce
    different digests.

    Raises:
        KeyError: if a required signature/algorithm key is missing.
    """
    sig = config["signature"]
    alg = config["algorithm"]

    def _token(key, canonicalize=None):
        # Read one signature field, optionally canonicalize it, then sanitize.
        raw = sig[key]
        return sanitize_token(raw if canonicalize is None else canonicalize(raw))

    hashed_fields = {
        "family": _token("family"),
        "dtype": _token("data_type"),
        "mode": _token("mode"),
        "vlayout": _token("vlayout"),
        "mask": _token("mask", canonical_mask),
        "bias": _token("bias", canonical_bias),
        "qscale": _token("qscale", canonical_qscale),
        "rope": _token("rope", canonical_rope),
        "kv_memory": _token("kv_memory_layout", canonical_kv_memory_layout),
        "kv_lookup": _token("kv_lookup_table", canonical_kv_lookup),
    }
    pipeline_token = sanitize_token(alg["pipeline"])

    # The raw dicts go into the hash as well, so every remaining field counts.
    hashed_fields["sig"] = sig
    hashed_fields["alg"] = alg
    payload = json.dumps(hashed_fields, sort_keys=True).encode("utf-8")
    short_digest = hashlib.sha1(payload).hexdigest()[:12]

    family = hashed_fields["family"]
    dtype = hashed_fields["dtype"]
    mode = hashed_fields["mode"]
    return (
        f"fmha_{family}_{dtype}_{mode}_h{sig['hdim_q']}x{sig['hdim_v']}"
        f"_{pipeline_token}_{short_digest}"
    )

View File

@@ -0,0 +1,921 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
FMHA validation and kernel specifications.
Architecture-specific data (dtypes, pipelines, hdims, tile tables) is stored in
``fmha_arch_specs.json`` so that it can be edited without touching Python code.
Common GPU hardware data (element sizes, warp size, LDS capacity) is imported
from the parent ``arch_specs_generated`` module (generated from ``arch_specs.json``).
This file provides:
- JSON loading helpers
- Tile constraints (per-arch rules that reject invalid tiles)
- Feature compatibility rules (pipeline × feature flag interactions)
- Receipt filters and profiles (deployment-specific kernel subsets)
- Config validation for the AOT codegen path
"""
import json
import sys
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple
# Ensure this directory and parent codegen/ are on sys.path for sibling imports
_THIS_DIR = Path(__file__).resolve().parent
_CODEGEN_DIR = _THIS_DIR.parent
sys.path.insert(0, str(_THIS_DIR))
sys.path.insert(0, str(_CODEGEN_DIR))
from symbol_map import ( # noqa: E402
BWD_DTYPE_MAP,
FWD_DTYPE_MAP,
canonical_bias,
canonical_mask,
canonical_qscale,
)
# Import shared hardware data from parent arch_specs_generated (generated from
# arch_specs.json by generate_arch_specs.py). Falls back to inline defaults if
# the generated module is unavailable (e.g. in standalone testing).
try:
from arch_specs_generated import ELEMENT_SIZE_MAP as _PARENT_ELEMENT_SIZES # noqa: E402
except ImportError:
_PARENT_ELEMENT_SIZES = {
"fp16": 2,
"bf16": 2,
"fp32": 4,
"fp64": 8,
"fp8": 1,
"bf8": 1,
"int8": 1,
"int4": 0.5,
"pk_fp4": 0.5,
"int32": 4,
}
# =============================================================================
# JSON data loading
# =============================================================================
_FMHA_SPECS_PATH = _THIS_DIR / "fmha_arch_specs.json"


def _load_fmha_specs() -> dict:
    """Load fmha_arch_specs.json (cached after first call).

    The parsed document is stored as an attribute on this function, so the
    file is read at most once per process and every caller receives the same
    dict object (mutating it affects all later callers).

    Raises:
        OSError: if the spec file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if not hasattr(_load_fmha_specs, "_cache"):
        # JSON is UTF-8 by specification. Without an explicit encoding, open()
        # uses the locale codec, which mis-decodes or rejects non-ASCII text on
        # platforms whose default is not UTF-8.
        with open(_FMHA_SPECS_PATH, encoding="utf-8") as f:
            _load_fmha_specs._cache = json.load(f)
    return _load_fmha_specs._cache
def _build_element_sizes() -> Dict[str, int]:
    """Combine the shared element-size table with FMHA-specific entries.

    Parent sizes are converted with ``int()``; entries under the optional
    ``fmha_element_sizes`` key of the FMHA spec are then applied on top
    unchanged, overriding any parent entry with the same name.
    """
    # NOTE(review): int() truncates, so the sub-byte entries in the fallback
    # table (0.5 for "int4" / "pk_fp4") become 0 here. Confirm no FMHA dtype
    # relies on those sizes.
    sizes: Dict[str, int] = {}
    for dtype_name, width in _PARENT_ELEMENT_SIZES.items():
        sizes[dtype_name] = int(width)
    overrides = _load_fmha_specs().get("fmha_element_sizes", {})
    sizes.update(overrides)
    return sizes
# =============================================================================
# 1. Architecture capabilities (loaded from fmha_arch_specs.json)
# =============================================================================
def _build_arch_dtypes() -> Dict[str, List[str]]:
    """Map each architecture in the spec to its ``supported_dtypes`` list."""
    architectures = _load_fmha_specs()["architectures"]
    dtypes_by_arch: Dict[str, List[str]] = {}
    for arch_name, arch_info in architectures.items():
        # The list object from the parsed JSON is stored as-is (not copied).
        dtypes_by_arch[arch_name] = arch_info["supported_dtypes"]
    return dtypes_by_arch
def _build_supported_hdims() -> Dict[str, List[Tuple[int, int]]]:
    """Map each dtype to its head-dimension pairs, converted from lists to tuples.

    The ``_comment`` key of the JSON section is skipped.
    """
    hdims_by_dtype: Dict[str, List[Tuple[int, int]]] = {}
    for dtype_name, pairs in _load_fmha_specs()["supported_hdims"].items():
        if dtype_name == "_comment":
            continue
        hdims_by_dtype[dtype_name] = [tuple(pair) for pair in pairs]
    return hdims_by_dtype
def _build_arch_metadata() -> Dict[str, dict]:
    """Return a shallow copy of the spec's ``architectures`` section.

    Only the outer mapping is copied; the per-architecture dicts are shared
    with the cached spec document.
    """
    architectures = _load_fmha_specs()["architectures"]
    return {arch_name: arch_info for arch_name, arch_info in architectures.items()}
# Module-level tables, built once at import time from fmha_arch_specs.json.
# Importing this module therefore requires the JSON file to be readable.
ARCH_DTYPES: Dict[str, List[str]] = _build_arch_dtypes()
SUPPORTED_HDIMS: Dict[str, List[Tuple[int, int]]] = _build_supported_hdims()
ARCH_METADATA: Dict[str, dict] = _build_arch_metadata()
# =============================================================================
# 2. Tile hardware parameters (loaded from fmha_arch_specs.json + parent arch_specs)
# =============================================================================
def _build_warp_classes() -> Dict[str, List[Tuple[int, int, int]]]:
    """Map each dtype to its warp tiles from ``fmha_warp_tiles``, as tuples.

    The ``_comment`` key of the JSON section is skipped.
    """
    warp_tiles: Dict[str, List[Tuple[int, int, int]]] = {}
    for dtype_name, warps in _load_fmha_specs()["fmha_warp_tiles"].items():
        if dtype_name == "_comment":
            continue
        warp_tiles[dtype_name] = [tuple(warp) for warp in warps]
    return warp_tiles
def _build_lds_limits() -> Dict[str, int]:
    """Return a shallow copy of the spec's ``lds_limits`` section.

    No keys are filtered out, so any key present in the JSON section is kept.
    """
    limits: Dict[str, int] = {}
    limits.update(_load_fmha_specs()["lds_limits"])
    return limits
def _build_k0max_map() -> Dict[int, int]:
    """Build the ``k0max_map`` table with its JSON string keys converted to ints.

    The ``_comment`` key is skipped; any other non-numeric key raises
    ``ValueError`` from ``int()``.
    """
    k0max: Dict[int, int] = {}
    for raw_key, submax in _load_fmha_specs()["k0max_map"].items():
        if raw_key == "_comment":
            continue
        k0max[int(raw_key)] = submax
    return k0max
# Module-level tile tables, built once at import time from the cached spec.
_specs = _load_fmha_specs()
# The sweep ranges are read directly from the "tile_sweep_ranges" section.
_tile_ranges = _specs["tile_sweep_ranges"]
LDS_LIMITS: Dict[str, int] = _build_lds_limits()
WARP_CLASSES: Dict[str, List[Tuple[int, int, int]]] = _build_warp_classes()
ELEMENT_SIZES: Dict[str, int] = _build_element_sizes()
# The VALID_* lists are the same list objects as in the parsed JSON (not copies).
VALID_BM0: List[int] = _tile_ranges["valid_bm0"]
VALID_BN0: List[int] = _tile_ranges["valid_bn0"]
VALID_BK0: List[int] = _tile_ranges["valid_bk0"]
K0_MAX_SUBMAX_MAP: Dict[int, int] = _build_k0max_map()
# =============================================================================
# 3. Tile constraints
# =============================================================================
def check_gfx9_tile_constraints(
    dtype: str,
    hdim_q: int,
    hdim_v: int,
    pipeline: str,
    bm0: int,
    bn0: int,
    bk0: int,
) -> bool:
    """Gfx9 compatibility rules.

    Source: fmha_fwd.py CompatibilityRuleFactoryGfx9.check_hdim_tile().
    Applies to gfx90a, gfx942, gfx950 for pipelines in {qr, qr_async, qs}.
    Note: CK factory is stricter (bm0==128 only for non-128 hdims); we allow
    {64, 128, 192, 256} to let the tile engine explore more configurations.

    Rules, in order:
      * fp32, or any pipeline outside {qr, qr_async, qs}: always accepted.
      * head dims other than (128, 128): bm0 must be 64, 128, 192 or 256.
      * head dims (128, 128): bn0 must be 128; additionally qr_async needs
        bm0 == 128, while the other pipelines reject bk0 == 64.
    """
    if dtype == "fp32" or pipeline not in ("qr", "qr_async", "qs"):
        return True

    if (hdim_q, hdim_v) != (128, 128):
        return bm0 in (64, 128, 192, 256)

    # From here on the head dims are exactly (128, 128).
    if bn0 != 128:
        return False
    if pipeline == "qr_async":
        return bm0 == 128
    return bk0 != 64
def check_gfx950_tile_constraints(
    hdim_q: int,
    hdim_v: int,
    pipeline: str,
    bm0: int,
    bn0: int,
) -> bool:
    """Gfx950 trload/v3 constraints.

    Source: fmha_fwd.py CompatibilityRuleFactoryGfx950.check_tile_pipeline().
    Note: CK enforces biconditional (v3_tile ↔ v3_pipeline); we only enforce
    v3_pipeline → v3_tile since non-v3 pipelines may still use bm0=256.

    Rules:
      * ``qr_async_trload`` only supports head dims (64, 64) and (128, 128),
        and for (128, 128) rejects bn0 == 128.
      * The v3 pipeline requires bm0 == 256. Both spellings of that pipeline
        ("qr_async_trload_v3" and its "v3" alias, which map to the same C++
        pipeline class and enum) are checked, so the alias cannot bypass the
        tile requirement.

    Returns:
        True if the tile is acceptable for ``pipeline`` on gfx950.
    """
    hdims = (hdim_q, hdim_v)
    if pipeline == "qr_async_trload":
        if hdims == (128, 128) and bn0 == 128:
            return False
        if hdims not in [(64, 64), (128, 128)]:
            return False

    # v3 pipeline requires bm0=256; other pipelines also allow bm0=256.
    is_v3_pipeline = pipeline in ("v3", "qr_async_trload_v3")
    if is_v3_pipeline and bm0 != 256:
        return False
    return True
def check_qr_mfma_insts(
    arch: str,
    hdim_q: int,
    pipeline: str,
    bn0: int,
    bk0: int,
    wn0: int,
    wk0: int,
) -> bool:
    """NumMfmaInsts % 8 == 0 check.

    Source: block_fmha_pipeline_qr_ks_vs.hpp static_assert at line ~490.
    Full C++ formula: (kM0/WarpM)*(kN0/WarpN)*(kK0/WarpK) / (MWarp*NWarp).
    We simplify to (bn0/wn0)*(bk0/wk0), omitting (bm0/wm0)/(rm0*rn0) which
    equals 1 for all current fp16/bf16/fp32/fp8 tiles, or a power-of-2 factor
    for mxfp8/mxfp4 that doesn't change the mod-8 result. This is conservative:
    it can only reject tiles the full formula would also reject, never the reverse.

    Only applies to qr pipeline + hdim_q==256 + CDNA (gfx9*); every other
    combination is accepted without looking at the tile sizes.

    Raises:
        ZeroDivisionError: if the rule applies and ``wn0`` or ``wk0`` is 0.
    """
    rule_applies = pipeline == "qr" and hdim_q == 256 and arch.startswith("gfx9")
    if not rule_applies:
        return True

    instructions_n = bn0 // wn0
    instructions_k = bk0 // wk0
    return (instructions_n * instructions_k) % 8 == 0
def tile_passes_all_constraints(
    arch: str,
    dtype: str,
    hdim_q: int,
    hdim_v: int,
    pipeline: str,
    bm0: int,
    bn0: int,
    bk0: int,
    wm0: int,
    wn0: int,
    wk0: int,
) -> bool:
    """Master constraint check — returns True if the tile is valid.

    Checks run in this order, and the first failure returns False:
      1. LDS capacity (formula depends on whether the pipeline is async).
      2. ``bk0 <= hdim_q`` and ``hdim_q % bk0 == 0``.
      3. Warp alignment: bm0/bn0/bk0 are multiples of wm0/wn0/wk0.
      4. MFMA instruction count (``check_qr_mfma_insts``).
      5. Async DMA distribution (``qr_async`` on gfx9* only).
      6. Arch-specific rules (``check_gfx9_tile_constraints`` for
         gfx90a/gfx942/gfx950, ``check_gfx950_tile_constraints`` for gfx950).

    Args:
        arch: GPU target tag, e.g. "gfx942".
        dtype: dtype tag; its byte size is looked up in ELEMENT_SIZES
            (2 is used when the tag is unknown).
        hdim_q, hdim_v: head dimensions of Q and V.
        pipeline: pipeline tag, e.g. "qr" or "qr_async".
        bm0, bn0, bk0: block tile sizes.
        wm0, wn0, wk0: warp tile sizes.

    Raises:
        ZeroDivisionError: if a warp tile size or the element size is 0
            (see the review notes below).
    """
    elem_size = ELEMENT_SIZES.get(dtype, 2)
    lds_limit = LDS_LIMITS.get(pipeline, 65536)
    # LDS capacity check (pipeline-dependent formula)
    # NOTE(review): the "v3" pipeline alias is not in this tuple, so it takes
    # the non-async branch below — confirm that is intended.
    if pipeline in ("qr_async", "qr_async_trload", "qr_async_trload_v3"):
        # Async pipeline: Q is in registers. LDS holds NumKVLdsBuffers (=3) copies of
        # max(SingleKSize, SingleVSize). Derived from GetSmemSizeKV() in
        # block_fmha_pipeline_qx_ks_vs_custom_policy.hpp.
        #
        # SingleVSize formula (MakeVLdsBlockDescriptor):
        # Banks=32, PixelsPerRow = Banks*4/sizeof(dtype) = 32*4/elem_size
        # kKPack = 16/elem_size (GetSmemKPackV)
        # NPerRow = PixelsPerRow/kKPack
        # SingleVSize = (bk1/kKPack) * (hdim_v/NPerRow) * (PixelsPerRow + kKPack)
        # For bf16: PixelsPerRow=64, kKPack=8, NPerRow=8
        # SingleVSize = (32/8)*(hdim_v/8)*(64+8) = 4*(hdim_v/8)*72 = 36*hdim_v
        #
        # SingleKSize formula (GetSingleSmemElementSpaceSize, async branch):
        # KPack = 16/elem_size, KVector = alignment (gfx950: 16/elem_size = 8 for bf16)
        # LanesPerK = bk0/KVector, LaneGroups = 64/LanesPerK
        # NumIssues = bn0/(LaneGroups*NumWarps)
        # SingleKSize = NumIssues*NumWarps*(64*KVector + KPack)
        #
        bk1 = 32  # kK1 in TileFmhaShape — design choice from fmha_fwd.py tile defs
        # NOTE(review): this division runs before the warp-alignment check
        # further down, so wm0 == 0 raises ZeroDivisionError here.
        num_warps = bm0 // wm0
        # Banks: arch.hpp get_n_lds_banks() — 64 for gfx950, 32 for older
        banks = 64 if arch == "gfx950" else 32
        pixels_per_row = banks * 4 // elem_size  # Banks * 4bytes / sizeof(dtype)
        k_pack = 16 // elem_size  # GetSmemKPackV: 16 / sizeof(dtype)
        n_per_row = pixels_per_row // k_pack
        single_v = (bk1 // k_pack) * (hdim_v // n_per_row) * (pixels_per_row + k_pack)
        # KVector: GetAlignmentK in custom_policy.hpp — MaxLoadSizeInBytes / sizeof(dtype)
        # gfx950 uses dwordx4 (16 bytes), older uses dword (4 bytes)
        k_vector = 16 // elem_size if arch == "gfx950" else 4 // elem_size
        # The conditional fallbacks below avoid dividing by zero when an
        # intermediate quantity rounds down to 0.
        lanes_per_k = bk0 // k_vector if k_vector > 0 else 1
        lane_groups = 64 // lanes_per_k if lanes_per_k > 0 else 1  # WarpSize=64
        num_issues = (
            bn0 // (lane_groups * num_warps) if (lane_groups * num_warps) > 0 else 0
        )
        single_k = num_issues * num_warps * (64 * k_vector + k_pack)
        single_buf_bytes = max(single_k, single_v) * elem_size
        # NumPrefetchK = NumPrefetchV = 3 (async_default_policy.hpp)
        num_kv_buffers = 3
        # Q uses registers (QLoadOnce=true), so GetSmemSizeQ() = 0.
        total_lds = single_buf_bytes * num_kv_buffers
        # gfx950 HW LDS limit: arch.hpp get_smem_capacity() = 163840 (160 KiB)
        # NOTE(review): this constant is applied for every arch, and lds_limit
        # is not consulted in this branch — confirm that is intended for
        # targets other than gfx950.
        if total_lds > 163840:
            return False
    else:
        # Non-async (qr/qs): Q and K tiles share LDS simultaneously
        if (bm0 * bk0 + bn0 * bk0) * elem_size > lds_limit:
            return False
    # bk0 range
    if bk0 > hdim_q:
        return False
    # hdim_q divisibility (tile_fmha_shape.hpp:60)
    if hdim_q % bk0 != 0:
        return False
    # Warp alignment
    if bm0 % wm0 != 0 or bk0 % wk0 != 0 or bn0 % wn0 != 0:
        return False
    # MFMA inst count
    if not check_qr_mfma_insts(arch, hdim_q, pipeline, bn0, bk0, wn0, wk0):
        return False
    # Async DMA distribution constraint (MakeKLdsStoreBlockDescriptor, custom_policy.hpp).
    # NumIssues = kNPerBlock / (LaneGroups * NumWarps) must be a positive integer, where
    # LaneGroups = WarpSize / LanesPerK = 64 / (bk0 / KVector).
    # Equivalently: (bn0 * bk0) % (kBlockSize * KVector) == 0.
    # KVector = MaxLoadSizeInBytes / sizeof(dtype): gfx950=16/2=8, older=4/2=2 for bf16.
    if pipeline == "qr_async" and arch.startswith("gfx9"):
        # NOTE(review): unlike k_vector above, kvector has no zero guard; an
        # element size above 4 bytes on a non-gfx950 target gives kvector == 0
        # and the modulo below raises ZeroDivisionError.
        kvector = 16 // elem_size if arch == "gfx950" else 4 // elem_size
        num_warps = bm0 // wm0
        block_size = num_warps * 64  # WarpSize = 64
        if (bn0 * bk0) % (block_size * kvector) != 0:
            return False
    # Arch constraints
    if arch in ("gfx90a", "gfx942", "gfx950"):
        if not check_gfx9_tile_constraints(
            dtype, hdim_q, hdim_v, pipeline, bm0, bn0, bk0
        ):
            return False
    if arch == "gfx950":
        if not check_gfx950_tile_constraints(hdim_q, hdim_v, pipeline, bm0, bn0):
            return False
    return True
# =============================================================================
# 4. Feature compatibility rules
# =============================================================================
# Supported mask, bias, and boolean values for feature products.
# These are the template enum values in CK's FMHA traits structs.
# "causal" is an alias: symbol_map.MASK_CANONICAL maps it to "top_left".
MASKS = ["no", "causal", "generic"]
BIASES = ["no", "bias", "alibi"]
# Boolean feature flags are spelled "t" / "f" throughout this module.
BOOLS = ["t", "f"]
# Dtype groups matching CK's _DT_* classification in fmha_fwd.py factory classes.
DT_FP16_BF16 = {"fp16", "bf16"}
DT_FP8 = {"fp8bf16", "fp8", "bf8"}
DT_FP8FP32 = {"fp8fp32"}
DT_FP32 = {"fp32"}
def check_logits_bias(logits: str, bias: str) -> bool:
    """logits_soft_cap requires no bias.

    Source: fmha_fwd.py CompatibilityRuleFactory.check_feature().

    ``logits`` is a "t"/"f" flag and ``bias`` is compared against the literal
    "no"; aliases such as "no_bias" are not canonicalized here.
    """
    return logits != "t" or bias == "no"
def check_group_mode_padding(mode: str, spad: str, skpad: str) -> bool:
    """Group mode requires spad=t and skpad=t.

    Source: fmha_fwd.py CompatibilityRuleFactory.check_feature() +
    block_fmha_pipeline static_asserts for padding.

    Any mode other than "group" is accepted regardless of the padding flags.
    """
    both_padded = spad == "t" and skpad == "t"
    return mode != "group" or both_padded
# =============================================================================
# 5. Variant-specific tile tables (loaded from fmha_arch_specs.json)
# =============================================================================
def _build_bwd_tiles() -> Tuple[
    Dict[Tuple[int, int], Tuple[int, ...]],
    Dict[Tuple[int, int], List[Tuple[Tuple[int, ...], str, bool]]],
    Dict[Tuple[int, int, int, str], dict],
]:
    """Build BWD tile tables from JSON.

    Returns a 3-tuple:
      * main tiles: ``(hdim_q, hdim_v)`` → tile tuple, from ``dq_dk_dv_fp16``;
      * extra tiles: ``(hdim_q, hdim_v)`` → list of ``(tile, tag, batch_only)``,
        from the optional ``dq_dk_dv_extra`` section;
      * wave/warp lookup: ``(bm0, bn0, bk0, trload)`` → ``{"wave", "warp_k1"}``,
        from ``bwd_wave_warp`` (keys starting with ``_`` are skipped).
    """
    specs = _load_fmha_specs()
    bwd_section = specs["bwd_tiles"]

    def _hdim_pair(key: str) -> Tuple[int, int]:
        # "hdimq_hdimv" -> (hdimq, hdimv); raises ValueError on any other shape.
        hq, hv = map(int, key.split("_"))
        return (hq, hv)

    main_tiles = {
        _hdim_pair(key): tuple(tile)
        for key, tile in bwd_section["dq_dk_dv_fp16"].items()
    }

    extra_tiles = {
        _hdim_pair(key): [
            (tuple(entry["tile"]), entry["tag"], entry["batch_only"])
            for entry in entries
        ]
        for key, entries in bwd_section.get("dq_dk_dv_extra", {}).items()
    }

    # Wave/warp lookup keyed by "bm0_bn0_bk0_trload".
    wave_warp = {}
    for raw_key, entry in specs["bwd_wave_warp"].items():
        if raw_key.startswith("_"):
            continue
        parts = raw_key.split("_")
        lookup_key = (int(parts[0]), int(parts[1]), int(parts[2]), parts[3])
        wave_warp[lookup_key] = {
            "wave": tuple(entry["wave"]),
            "warp_k1": entry["warp_k1"],
        }

    return main_tiles, extra_tiles, wave_warp
def _build_splitkv_hdims() -> Tuple[List[int], List[int]]:
    """Return the ``(hdims_fp16, hdims_fp8)`` lists of the ``splitkv_combine`` section."""
    combine_section = _load_fmha_specs()["splitkv_combine"]
    fp16_hdims = combine_section["hdims_fp16"]
    fp8_hdims = combine_section["hdims_fp8"]
    return fp16_hdims, fp8_hdims
# Module-level variant tables, built once at import time.
_bwd_main, _bwd_extra, _bwd_ww = _build_bwd_tiles()
_skv_fp16, _skv_fp8 = _build_splitkv_hdims()
# Head dims for the split-KV combine kernels, per dtype group.
SPLITKV_COMBINE_HDIMS_FP16: List[int] = _skv_fp16
SPLITKV_COMBINE_HDIMS_FP8: List[int] = _skv_fp8
# (hdim_q, hdim_v) → tile tuple for the main dq_dk_dv kernels.
BWD_DQ_DK_DV_TILES_FP16: Dict[Tuple[int, int], Tuple[int, ...]] = _bwd_main
# (hdim_q, hdim_v) → list of (tile, tag, batch_only) extra tiles.
BWD_DQ_DK_DV_EXTRA_TILES: Dict[
    Tuple[int, int], List[Tuple[Tuple[int, ...], str, bool]]
] = _bwd_extra
# (bm0, bn0, bk0, trload) → {"wave": ..., "warp_k1": ...}.
BWD_DQ_WAVE_WARP: Dict[Tuple[int, int, int, str], dict] = _bwd_ww
# The remaining tables are read straight from the "bwd_tiles" JSON section;
# every key accessed below must exist or the import fails with KeyError.
_bwd_json = _load_fmha_specs()["bwd_tiles"]
BWD_EXTRA_PAD_COMBOS: List[Tuple[str, str]] = [
    tuple(p) for p in _bwd_json["extra_pad_combos"]
]
BWD_SMALL_DROPOUTS: List[str] = _bwd_json["small_dropouts"]
BWD_DOT_DO_O_HDIMS: List[int] = _bwd_json["dot_do_o_hdims"]
BWD_CONVERT_DQ_HDIMS: List[int] = _bwd_json["convert_dq_hdims"]
# JSON object keys are strings; convert them to ints for lookup.
BWD_CONVERT_DQ_TILE_GROUPS: Dict[int, int] = {
    int(k): v for k, v in _bwd_json["convert_dq_tile_groups"].items()
}
BWD_DROPOUTS: List[str] = _bwd_json["dropouts"]
BWD_PAD_COMBOS: List[Tuple[str, str]] = [tuple(p) for p in _bwd_json["pad_combos"]]
# =============================================================================
# 6. Receipt filters
# =============================================================================
class Receipt(IntEnum):
    """Named receipt levels for deployment profiles.

    These are deployment-specific filters, not derived from C++ constraints.
    They control which kernel subsets are emitted for different integration
    targets (PyTorch, AITER, Flash-Attention, etc.).

    The lowercase member names double as profile names: ``PROFILE_ALIASES``
    maps ``str(member.value)`` to ``member.name.lower()``. Receipt values
    without an entry in ``RECEIPT_FILTERS`` are handled by the default rule in
    ``receipt_filter``.
    """

    # General CK builds.
    CK_DEFAULT = 0
    CK_EXTENDED = 1
    # Framework integrations.
    FLASH_FWD = 2
    FLASH_BWD = 3
    PYTORCH = 4
    # AITER integrations (batch / group mode, forward and backward).
    AITER_BATCH = 100
    AITER_GROUP = 200
    AITER_BWD_BATCH = 300
    AITER_BWD_GROUP = 400
    AITER_CPP = 600
    # dtype-specific subsets.
    FP32_ALL = 800
    FP32_MIN = 801
    FP8_TEST = 888
# Receipt value → predicate ``(dtype, spec) -> bool`` deciding whether a kernel
# is kept. ``spec`` is read with ``getattr`` and permissive defaults, so any
# object works; attributes it lacks are treated as "feature off".
# Receipts with no entry here fall back to the default rule in receipt_filter.
RECEIPT_FILTERS: Dict[int, Callable[[str, object], bool]] = {
    0: lambda dtype, spec: dtype != "fp32",
    2: lambda dtype, spec: (
        dtype in ("fp16", "bf16")
        and getattr(spec, "bias", "no") in ("no", "alibi")
        and getattr(spec, "qscale", "no") == "no"
        and getattr(spec, "skip", "f") == "f"
        and getattr(spec, "sink", "f") == "f"
    ),
    4: lambda dtype, spec: (
        dtype in ("fp16", "bf16")
        and getattr(spec, "bias", "no") in ("no", "bias")
        and getattr(spec, "qscale", "no") == "no"
        and getattr(spec, "skip", "f") == "f"
        and getattr(spec, "logits", "f") == "f"
    ),
    100: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
    200: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
    600: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
    888: lambda dtype, spec: dtype in ("fp8bf16", "fp8fp32"),
    800: lambda dtype, spec: (
        dtype == "fp32"
        and getattr(spec, "skip", "f") == "f"
        and getattr(spec, "logits", "f") == "f"
    ),
    # FP32_MIN is an fp32-only receipt. Without this entry it fell through to
    # the default rule (dtype != "fp32") and rejected every fp32 kernel.
    # Bias and qscale follow the "fp32_min" profile; that profile's mode, hdim,
    # lse and dropout narrowing is not expressible from (dtype, spec)
    # attributes visible here and is left to the profile.
    801: lambda dtype, spec: (
        dtype == "fp32"
        and getattr(spec, "bias", "no") == "no"
        and getattr(spec, "qscale", "no") == "no"
    ),
}


def receipt_filter(receipt: int, dtype: str, spec) -> bool:
    """Apply receipt-level filter. Returns True if the kernel should be kept.

    Args:
        receipt: receipt value (see ``Receipt``).
        dtype: dtype tag of the kernel, e.g. "fp16".
        spec: object describing the kernel's features; read via ``getattr``.

    Receipts without a dedicated entry in ``RECEIPT_FILTERS`` use the default
    rule, which keeps everything except fp32 kernels.
    """
    fn = RECEIPT_FILTERS.get(receipt)
    if fn is None:
        return dtype != "fp32"
    return fn(dtype, spec)
# =============================================================================
# 7. Profiles
# =============================================================================
# Numeric receipt (as a string) → profile name, e.g. "0" → "ck_default",
# "801" → "fp32_min". Derived from the Receipt enum so the two stay in sync.
PROFILE_ALIASES: Dict[str, str] = {str(r.value): r.name.lower() for r in Receipt}
@dataclass(frozen=True)
class FmhaProfile:
    """A named, immutable predicate over FMHA kernel configs."""

    # Profile name, as used for lookup in PROFILES.
    name: str
    # Callable that receives a config dict and returns True to keep it.
    predicate: Callable[[dict], bool]

    def allows(self, config: dict) -> bool:
        """Return the predicate's verdict for ``config``."""
        accept = self.predicate
        return accept(config)
def _dtype_is(config: dict, allowed: Iterable[str]) -> bool:
return config["signature"]["data_type"] in set(allowed)
def _mode_is(config: dict, allowed: Iterable[str]) -> bool:
return config["signature"]["mode"] in set(allowed)
def _family_is(config: dict, allowed: Iterable[str]) -> bool:
return config["signature"]["family"] in set(allowed)
def _common_row_major_filter(config: dict) -> bool:
return config["signature"]["vlayout"] == "r"
def _bias_is(config: dict, allowed: Iterable[str]) -> bool:
    """True if the canonicalized ``signature.bias`` is one of ``allowed``."""
    bias = canonical_bias(config["signature"]["bias"])
    return bias in set(allowed)


def _qscale_is(config: dict, allowed: Iterable[str]) -> bool:
    """True if the canonicalized ``signature.qscale`` is one of ``allowed``."""
    qscale = canonical_qscale(config["signature"]["qscale"])
    return qscale in set(allowed)
def _no_skip_or_logits(config: dict) -> bool:
return (not config["signature"]["skip_min_seqlen_q"]) and (
not config["signature"]["logits"]
)
# Profile name → FmhaProfile. Each predicate receives the full config dict and
# decides whether that kernel belongs to the profile.
PROFILES: Dict[str, FmhaProfile] = {
    # Everything except fp32.
    "ck_default": FmhaProfile(
        "ck_default",
        lambda cfg: cfg["signature"]["data_type"] != "fp32",
    ),
    # Currently the same rule as ck_default.
    "ck_extended": FmhaProfile(
        "ck_extended",
        lambda cfg: cfg["signature"]["data_type"] != "fp32",
    ),
    # Forward families, fp16/bf16, row-major V, no/alibi bias, no qscale.
    "flash_fwd": FmhaProfile(
        "flash_fwd",
        lambda cfg: (
            _family_is(cfg, {"fwd", "fwd_splitkv", "fwd_appendkv", "fwd_pagedkv"})
            and _dtype_is(cfg, {"fp16", "bf16"})
            and _common_row_major_filter(cfg)
            and _bias_is(cfg, {"no", "alibi"})
            and _qscale_is(cfg, {"no"})
            and not cfg["signature"]["skip_min_seqlen_q"]
        ),
    ),
    # Backward families, fp16/bf16.
    "flash_bwd": FmhaProfile(
        "flash_bwd",
        lambda cfg: (
            _family_is(cfg, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"})
            and _dtype_is(cfg, {"fp16", "bf16"})
        ),
    ),
    # fp16/bf16, row-major V, no/elementwise bias, no qscale/skip/logits/sink.
    "pytorch": FmhaProfile(
        "pytorch",
        lambda cfg: (
            _dtype_is(cfg, {"fp16", "bf16"})
            and _common_row_major_filter(cfg)
            and _bias_is(cfg, {"no", "bias"})
            and _qscale_is(cfg, {"no"})
            and _no_skip_or_logits(cfg)
            and not cfg["signature"].get("sink", False)
        ),
    ),
    # Batch mode; fp8bf16 is limited to hdim_q 128 or 192.
    "aiter_batch": FmhaProfile(
        "aiter_batch",
        lambda cfg: (
            _dtype_is(cfg, {"fp16", "bf16", "fp8bf16"})
            and _mode_is(cfg, {"batch"})
            and _common_row_major_filter(cfg)
            and (
                cfg["signature"]["data_type"] != "fp8bf16"
                or cfg["signature"]["hdim_q"] in {128, 192}
            )
        ),
    ),
    # Group mode.
    "aiter_group": FmhaProfile(
        "aiter_group",
        lambda cfg: (
            _dtype_is(cfg, {"fp16", "bf16", "fp8bf16"})
            and _mode_is(cfg, {"group"})
            and _common_row_major_filter(cfg)
        ),
    ),
    # Backward families, batch mode.
    "aiter_bwd_batch": FmhaProfile(
        "aiter_bwd_batch",
        lambda cfg: (
            _family_is(cfg, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"})
            and _dtype_is(cfg, {"fp16", "bf16"})
            and _mode_is(cfg, {"batch"})
        ),
    ),
    # Backward families, group mode.
    "aiter_bwd_group": FmhaProfile(
        "aiter_bwd_group",
        lambda cfg: (
            _family_is(cfg, {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"})
            and _dtype_is(cfg, {"fp16", "bf16"})
            and _mode_is(cfg, {"group"})
        ),
    ),
    # Any mode; fp8bf16 is limited to hdim_q 128 or 192.
    "aiter_cpp": FmhaProfile(
        "aiter_cpp",
        lambda cfg: (
            _dtype_is(cfg, {"fp16", "bf16", "fp8bf16"})
            and _common_row_major_filter(cfg)
            and (
                cfg["signature"]["data_type"] != "fp8bf16"
                or cfg["signature"]["hdim_q"] in {128, 192}
            )
        ),
    ),
    # fp32 only, without skip/logits.
    "fp32_all": FmhaProfile(
        "fp32_all",
        lambda cfg: _dtype_is(cfg, {"fp32"}) and _no_skip_or_logits(cfg),
    ),
    # fp32 only, minimal feature set.
    "fp32_min": FmhaProfile(
        "fp32_min",
        lambda cfg: (
            _dtype_is(cfg, {"fp32"})
            and _mode_is(cfg, {"batch"})
            and cfg["signature"]["hdim_q"] in {48, 128}
            and cfg["signature"]["hdim_v"] in {48, 128}
            and canonical_bias(cfg["signature"]["bias"]) == "no"
            and not cfg["signature"]["lse"]
            and not cfg["signature"]["dropout"]
            and canonical_qscale(cfg["signature"]["qscale"]) == "no"
        ),
    ),
    # fp8 test subset: hdim_q 128 or 192, row-major V.
    "fp8_test": FmhaProfile(
        "fp8_test",
        lambda cfg: (
            _dtype_is(cfg, {"fp8bf16", "fp8fp32"})
            and cfg["signature"]["hdim_q"] in {128, 192}
            and _common_row_major_filter(cfg)
        ),
    ),
    # No filtering at all.
    "all": FmhaProfile("all", lambda _: True),
}
def normalize_profile(
    profile: Optional[str] = None, receipt: Optional[str] = None
) -> str:
    """Resolve a profile name from ``profile`` or ``receipt``.

    A truthy ``profile`` wins; otherwise a non-None ``receipt`` is used;
    otherwise the result is "ck_default". The chosen value is stringified and
    translated through ``PROFILE_ALIASES`` (numeric receipt → profile name);
    values with no alias are returned as-is.
    """
    if profile:
        key = str(profile)
    elif receipt is not None:
        key = str(receipt)
    else:
        return "ck_default"
    return PROFILE_ALIASES.get(key, key)
def get_profile(
    profile: Optional[str] = None, receipt: Optional[str] = None
) -> FmhaProfile:
    """Look up the ``FmhaProfile`` selected by ``profile`` / ``receipt``.

    Raises:
        KeyError: if the normalized name is not a key of ``PROFILES``.
    """
    name = normalize_profile(profile=profile, receipt=receipt)
    if name in PROFILES:
        return PROFILES[name]
    raise KeyError(f"Unknown FMHA profile: {name}")
def profile_allows(
    config: dict, profile: Optional[str] = None, receipt: Optional[str] = None
) -> bool:
    """Return True if ``config`` is accepted by the selected profile.

    Raises:
        KeyError: if the profile cannot be resolved (see ``get_profile``).
    """
    selected = get_profile(profile=profile, receipt=receipt)
    return selected.allows(config)
# =============================================================================
# 8. Validation helpers (for unified_fmha_codegen)
# =============================================================================
# Sections of fmha_arch_specs.json that are handed out unchanged.
_DEFAULTS: dict = _load_fmha_specs()["defaults"]
_GLOBAL_RULES: dict = _load_fmha_specs()["global_rules"]


def load_arch_specs() -> dict:
    """Return arch_specs dict compatible with unified_fmha_codegen.

    Combines FMHA-specific architecture data from fmha_arch_specs.json with
    defaults, global rules, and splitkv combine params.

    The returned dict is new on every call, but its values are the shared
    module-level objects rather than copies.
    """
    combined = {
        "architectures": ARCH_METADATA,
        "defaults": _DEFAULTS,
        "global_rules": _GLOBAL_RULES,
    }
    combined["splitkv_combine"] = _load_fmha_specs()["splitkv_combine"]
    return combined
# =============================================================================
# 9. Config validation (for unified_fmha_codegen)
# =============================================================================
@dataclass
class ValidationResult:
    """Outcome of a config validation: a validity flag plus collected messages."""

    # Starts True; flipped to False by the first add_error() call.
    valid: bool = True
    # Messages recorded through add_error(), in call order.
    errors: List[str] = field(default_factory=list)
    # Messages recorded through add_warning(); these never affect ``valid``.
    warnings: List[str] = field(default_factory=list)

    def add_error(self, msg: str) -> None:
        """Record ``msg`` as an error and mark the result invalid."""
        self.errors.append(msg)
        self.valid = False

    def add_warning(self, msg: str) -> None:
        """Record ``msg`` as a warning; validity is unchanged."""
        self.warnings.append(msg)
def validate_config(
    config: dict, arch_specs: Optional[dict] = None
) -> "ValidationResult":
    """Validate an FMHA kernel config against all rules.

    Args:
        config: Kernel config providing "signature", "algorithm" and "arch".
            The signature must carry data_type, family, mask, bias, hdim_q
            and hdim_v; the algorithm must carry pipeline, tile, wave, warp,
            block_per_cu and num_wave_groups.
        arch_specs: Specs dict laid out like the result of load_arch_specs().
            When omitted (or empty) load_arch_specs() is called.

    Returns:
        A ValidationResult accumulating every error and warning found. An
        unknown architecture short-circuits: only that error is reported.

    Raises:
        KeyError: if one of the required keys listed above is missing.
    """
    arch_specs = arch_specs or load_arch_specs()
    result = ValidationResult()

    sig = config["signature"]
    alg = config["algorithm"]
    arch = config["arch"]

    architectures = arch_specs.get("architectures", ARCH_METADATA)
    if arch not in architectures:
        # Nothing else can be checked without per-arch metadata.
        result.add_error(f"Unsupported FMHA target architecture: {arch}")
        return result

    arch_info = architectures[arch]
    global_rules = arch_specs.get("global_rules", _GLOBAL_RULES)
    dtype = sig["data_type"]
    family = sig["family"]
    pipeline = alg["pipeline"]
    # Return value is unused; presumably called so that an unrecognized mask
    # name is rejected by canonical_mask itself — TODO confirm.
    canonical_mask(sig["mask"])
    bias = canonical_bias(sig["bias"])

    # Family validation
    supported_families = {
        "fwd",
        "fwd_pagedkv",
        "fwd_splitkv",
        "fwd_splitkv_combine",
        "fwd_appendkv",
        "batch_prefill",
        "bwd_dot_do_o",
        "bwd_dq_dk_dv",
        "bwd_convert_dq",
    }
    if family not in supported_families:
        result.add_error(f"Unsupported FMHA family: {family}")

    # Dtype validation
    supported_dtypes = set(arch_info["supported_dtypes"])
    if dtype not in supported_dtypes:
        result.add_error(f"dtype {dtype} is not supported on {arch}")
    if family.startswith("bwd") and dtype not in BWD_DTYPE_MAP:
        result.add_error(
            f"Backward family {family} only supports {sorted(BWD_DTYPE_MAP)}"
        )
    if (
        family.startswith("fwd")
        and not family.startswith("fwd_append")
        and dtype not in FWD_DTYPE_MAP
    ):
        result.add_error(f"Forward family {family} does not recognize dtype {dtype}")

    # Pipeline validation (the combine family is exempt from the per-arch
    # pipeline list)
    if (
        family != "fwd_splitkv_combine"
        and pipeline not in arch_info["supported_pipelines"]
    ):
        result.add_error(f"pipeline {pipeline} is not supported on {arch}")
    if pipeline in {"v3", "qr_async_trload_v3"} and not arch_info.get(
        "supports_v3", False
    ):
        # Reported as a warning only, unlike the trload check below.
        result.add_warning(f"v3 pipeline on {arch} requires supports_v3 in arch specs")
    if pipeline == "qr_async_trload" and not arch_info.get("supports_trload", False):
        result.add_error("qr_async_trload requires a trload-capable architecture")

    # Global rules
    hdim_q = sig["hdim_q"]
    hdim_v = sig["hdim_v"]
    divisor = global_rules.get("hdim_divisible_by", 8)
    if hdim_q % divisor != 0 or hdim_v % divisor != 0:
        result.add_error(f"Head dimensions must be multiples of {divisor}")
    if global_rules.get("hdim_192_128_no_bias_dropout"):
        if (
            hdim_q == 192
            and hdim_v == 128
            and (bias != "no" or sig.get("dropout", False))
        ):
            result.add_warning(
                "hdim (192,128) with bias/dropout has limited tile support"
            )
    if global_rules.get("logits_requires_no_bias"):
        if bias != "no" and sig.get("logits", False):
            result.add_error("logits_soft_cap cannot be combined with bias")
    if pipeline in {"qr_async_trload", "v3", "qr_async_trload_v3"} and (
        hdim_q != hdim_v or hdim_q not in {64, 128}
    ):
        result.add_error(f"{pipeline} only supports symmetric head dims 64 or 128")

    # Tile validation
    tile = alg["tile"]
    expected_tile_len = 9 if family == "bwd_dq_dk_dv" else 6
    if len(tile) != expected_tile_len or len(alg["wave"]) != 9 or len(alg["warp"]) != 9:
        result.add_error(
            f"tile/wave/warp must have {expected_tile_len}/9/9 elements for {family}"
        )

    # MFMA instruction count check for qr/h256/CDNA. The length guards keep
    # the indexing below safe even when the shape check above already failed.
    _1d_families = {"bwd_dot_do_o", "bwd_convert_dq"}
    if (
        pipeline == "qr"
        and hdim_q == 256
        and family not in _1d_families
        and arch_info.get("family", "").startswith("cdna")
        and len(tile) >= 3
        and len(alg["wave"]) >= 2
        and len(alg["warp"]) >= 3
    ):
        wm, wn, wk = alg["warp"][0], alg["warp"][1], alg["warp"][2]
        gm, gn = alg["wave"][0], alg["wave"][1]
        # Skip the check for non-positive factors to avoid dividing by zero.
        if wm > 0 and wn > 0 and wk > 0 and gm > 0 and gn > 0:
            num_mfma = (tile[0] // wm) * (tile[1] // wn) * (tile[2] // wk) // (gm * gn)
            if num_mfma % 8 != 0:
                result.add_error(
                    f"NumMfmaInsts={num_mfma} must be divisible by 8 for qr/h256/CDNA"
                )

    if alg["block_per_cu"] <= 0 and alg["block_per_cu"] != -1:
        result.add_error("block_per_cu must be positive or -1 (auto)")
    if alg["num_wave_groups"] <= 0:
        result.add_error("num_wave_groups must be positive")

    # --- Family-specific rules ---
    if family == "batch_prefill":
        if sig.get("vlayout", "r") != "r":
            result.add_error("batch_prefill only supports row-major V layout")
        if not sig.get("paged_kv", False):
            result.add_error("batch_prefill requires paged_kv=true")
        ps = sig.get("page_size", 0)
        # ps & (ps - 1) is zero only for powers of two.
        if ps <= 0 or (ps & (ps - 1)) != 0:
            result.add_error("batch_prefill page_size must be a positive power of two")
        if sig.get("mode", "batch") != "group":
            result.add_error("batch_prefill requires group mode")
        if pipeline != "qr_async":
            result.add_error("batch_prefill currently uses qr_async pipeline")

    if family == "fwd_appendkv":
        if sig.get("mode", "batch") != "batch":
            result.add_error("fwd_appendkv uses batch-mode public API surface")
        if pipeline != "appendkv":
            result.add_error("fwd_appendkv must use appendkv pipeline")
        if sig.get("vlayout", "r") != "r":
            result.add_error("fwd_appendkv currently only supports row-major V")

    if family == "fwd_splitkv_combine":
        if sig.get("mode", "batch") not in {"batch", "group"}:
            result.add_error("fwd_splitkv_combine requires batch or group mode")
        combine_bn1 = arch_specs.get("splitkv_combine", {}).get("combine_bn1", 32)
        if len(tile) > 3:
            bn1 = tile[3]
            if bn1 != combine_bn1:
                result.add_error(f"fwd_splitkv_combine requires bn1={combine_bn1}")
            # A non-positive bn1 can never divide hdim_v; checking it first
            # also keeps the modulo from raising ZeroDivisionError on bn1 == 0.
            if bn1 <= 0 or hdim_v < bn1 or hdim_v % bn1 != 0:
                result.add_error("fwd_splitkv_combine requires hdim_v divisible by bn1")

    if family == "fwd_pagedkv":
        if pipeline != "qr_pagedkv":
            result.add_error("fwd_pagedkv must use qr_pagedkv pipeline")
        if not sig.get("paged_kv", False):
            result.add_error("fwd_pagedkv requires paged_kv=true")
        if sig.get("vlayout", "r") != "r":
            result.add_error("fwd_pagedkv currently only supports row-major V")

    if family == "fwd_splitkv":
        if pipeline not in {"qr", "qr_nwarp_sshuffle"}:
            result.add_error("fwd_splitkv must use qr or qr_nwarp_sshuffle pipeline")
        if sig.get("vlayout", "r") != "r":
            result.add_error("fwd_splitkv currently only supports row-major V")

    if family == "fwd" and sig.get("vlayout", "r") != "r":
        result.add_warning("dispatcher forward examples currently assume row-major V")

    return result

View File

@@ -230,7 +230,7 @@ def generate_cpp_header(specs: Dict[str, Any], output_path: Path):
for arch, data in archs.items():
enum_name = arch.upper().replace("GFX", "GFX_")
arch_enums.append(f" {enum_name}, // {data['description']}")
arch_enums.append(f" {enum_name},")
arch_to_string_cases.append(
f' case GpuArch::{enum_name}: return "{arch}";'
)
@@ -288,12 +288,12 @@ def generate_cpp_header(specs: Dict[str, Any], output_path: Path):
f" if (pipeline == Pipeline::{pipeline_enum_map[pipeline]}) return {limit};"
)
content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
content = f"""// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
*
*
* Generated from: arch_specs.json
* Generated at: {timestamp}
*

View File

@@ -0,0 +1,310 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Single Source of Truth for Grouped Convolution Tile Configurations
This module defines all valid tile configurations for grouped convolution kernels.
Both codegen and instance_builder import from here to ensure consistency.
Architecture:
grouped_conv_tile_configs.py (SOURCE OF TRUTH)
├── Used by unified_grouped_conv_codegen.py
└── Used by grouped_conv_instance_builder.py
"""
from typing import Dict, List, Tuple
# =============================================================================
# Tile Configurations (Single Source of Truth)
# =============================================================================
# Common tile configurations used across variants.
# Format: (tile_m, tile_n, tile_k)
# CRITICAL: tile_m MUST equal wave_m × warp_tile_m (TileGemmShape constraint)
# Only tiles that successfully compile are included.
# Every tile listed here also needs an entry in TILE_TO_WAVE and TILE_TO_WARP:
# print_summary() indexes both tables directly for each COMMON_TILES entry.
COMMON_TILES: List[Tuple[int, int, int]] = [
# Using warp_tile [16,16,16]: tile_m = wave_m × 16
(16, 64, 64), # 1 × 16 = 16, wave=(1,4,1)
(32, 64, 64), # 2 × 16 = 32, wave=(2,2,1)
(64, 64, 64), # 4 × 16 = 64, wave=(4,1,1)
# (128, 64, 64) [8 × 16 = 128, wave=(8,2,1)] is EXCLUDED: compile error
# Using warp_tile [32,32,16]: tile_m = wave_m × 32
(32, 128, 64), # 1 × 32 = 32, wave=(1,4,1)
(64, 128, 64), # 2 × 32 = 64, wave=(2,2,1)
(128, 128, 64), # 4 × 32 = 128, wave=(4,4,1)
# Note: 256x64x64 excluded - compilation issues
# Using warp_tile [16,16,32]: tile_m = wave_m × 16
(16, 64, 128), # 1 × 16 = 16, wave=(1,4,1)
(32, 64, 128), # 2 × 16 = 32, wave=(2,2,1)
(64, 64, 128), # 4 × 16 = 64, wave=(4,1,1)
(128, 64, 128), # 8 × 16 = 128, wave=(8,2,1)
# Note: Excluded tiles:
# - 128x64x64: wave=8x2x1, warp=16x16x16 - compile error
# - 32x128x128, 64x128x128, 128x128x128, 256x128x128 (warp_tile 32x32x32) - compv4 issues
# - 256x64x64, 256x128x128 - arch filter rejection
]
# Wave configurations per tile
# Key: (tile_m, tile_n, tile_k) -> (wave_m, wave_n, wave_k)
# Constraint: tile_m == wave_m × warp_tile_m
# Only use approved wave configs from arch_specs.json: [1,4,1], [2,2,1], [4,1,1], [8,2,1], [4,4,1]
# The keys mirror COMMON_TILES one-to-one; a tile missing here is reported as
# unregistered by validate_tile_config().
TILE_TO_WAVE: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
# warp_tile [16,16,16]
(16, 64, 64): (1, 4, 1),
(32, 64, 64): (2, 2, 1),
(64, 64, 64): (4, 1, 1),
# warp_tile [32,32,16]
(32, 128, 64): (1, 4, 1),
(64, 128, 64): (2, 2, 1),
(128, 128, 64): (4, 4, 1), # balanced 4x4 wave
# warp_tile [16,16,32]
(16, 64, 128): (1, 4, 1),
(32, 64, 128): (2, 2, 1),
(64, 64, 128): (4, 1, 1),
(128, 64, 128): (8, 2, 1),
}
# Warp tile configurations (must match arch_specs.json gfx950 bf16 approved list)
# Key: (tile_m, tile_n, tile_k) -> (warp_m, warp_n, warp_k)
# Same key set as TILE_TO_WAVE; the two tables are read together to describe a tile.
TILE_TO_WARP: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
# warp_tile [16,16,16]
(16, 64, 64): (16, 16, 16),
(32, 64, 64): (16, 16, 16),
(64, 64, 64): (16, 16, 16),
# warp_tile [32,32,16]
(32, 128, 64): (32, 32, 16),
(64, 128, 64): (32, 32, 16),
(128, 128, 64): (32, 32, 16),
# warp_tile [16,16,32]
(16, 64, 128): (16, 16, 32),
(32, 64, 128): (16, 16, 32),
(64, 64, 128): (16, 16, 32),
(128, 64, 128): (16, 16, 32),
}
# Vector sizes per tile (for memory operations)
# Key: (tile_m, tile_n, tile_k) -> (vec_a, vec_b, vec_c)
# Every registered tile currently uses the same (4, 8, 8) triple; the table is
# still keyed per tile so validate_tile_config() can require an explicit entry.
TILE_TO_VECTOR: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
(16, 64, 64): (4, 8, 8),
(32, 64, 64): (4, 8, 8),
(64, 64, 64): (4, 8, 8),
(32, 128, 64): (4, 8, 8),
(64, 128, 64): (4, 8, 8),
(128, 128, 64): (4, 8, 8),
(16, 64, 128): (4, 8, 8),
(32, 64, 128): (4, 8, 8),
(64, 64, 128): (4, 8, 8),
(128, 64, 128): (4, 8, 8),
}
# =============================================================================
# Pipeline Variant Suffixes (single source of truth)
# =============================================================================
# Empirically verified valid (pipeline, wave_mode, has_dsb, has_si) combinations
# observed in the 2D and 3D bf16 gfx950 benchmark CSVs. 30 entries total per ndim.
# Each tuple: (pipeline, wave_mode, has_dsb, has_si)
# wave_mode: "intrawave" | "interwave"
# has_dsb: 1 if "_dsb" suffix present (double smem buffer), else 0
# has_si: 1 if "_si" suffix present (store immediate), else 0
# NOTE(review): confirm what "_si" stands for — the grouped-conv codegen CLI
# exposes a --split-image option, so "_si" may denote split-image instead.
# NOTE: comp_async and basic_async_v1 are listed in VARIANT_PIPELINES but have
# no entry here, so iter_pipeline_variants() never yields them.
PIPELINE_VARIANTS: List[Tuple[str, str, int, int]] = [
# basic_v1: both intra/inter × {∅, dsb, si, dsb_si} = 8 combos
("basic_v1", "intrawave", 0, 0),
("basic_v1", "intrawave", 1, 0),
("basic_v1", "intrawave", 0, 1),
("basic_v1", "intrawave", 1, 1),
("basic_v1", "interwave", 0, 0),
("basic_v1", "interwave", 1, 0),
("basic_v1", "interwave", 0, 1),
("basic_v1", "interwave", 1, 1),
# compv3: intrawave × {∅, dsb, si, dsb_si} = 4 combos
("compv3", "intrawave", 0, 0),
("compv3", "intrawave", 1, 0),
("compv3", "intrawave", 0, 1),
("compv3", "intrawave", 1, 1),
# compv4: intrawave × {dsb, dsb_si} only = 2 combos (always double-buffered)
("compv4", "intrawave", 1, 0),
("compv4", "intrawave", 1, 1),
# compv5: intrawave × {∅, dsb, si, dsb_si} = 4 combos
("compv5", "intrawave", 0, 0),
("compv5", "intrawave", 1, 0),
("compv5", "intrawave", 0, 1),
("compv5", "intrawave", 1, 1),
# compv6: intrawave × {∅, dsb, si, dsb_si} = 4 combos
("compv6", "intrawave", 0, 0),
("compv6", "intrawave", 1, 0),
("compv6", "intrawave", 0, 1),
("compv6", "intrawave", 1, 1),
# mem: both intra/inter × {∅, dsb, si, dsb_si} = 8 combos
("mem", "intrawave", 0, 0),
("mem", "intrawave", 1, 0),
("mem", "intrawave", 0, 1),
("mem", "intrawave", 1, 1),
("mem", "interwave", 0, 0),
("mem", "interwave", 1, 0),
("mem", "interwave", 0, 1),
("mem", "interwave", 1, 1),
]
def iter_pipeline_variants(pipelines: List[str] = None):
    """Yield entries of PIPELINE_VARIANTS, optionally restricted by pipeline.

    Args:
        pipelines: pipeline names to keep; None (the default) keeps every
            entry. An empty list yields nothing.

    Yields:
        (pipeline, wave_mode, has_dsb, has_si) tuples in table order.
    """
    wanted = None if pipelines is None else set(pipelines)
    for variant in PIPELINE_VARIANTS:
        if wanted is None or variant[0] in wanted:
            yield variant
# Valid pipelines per variant
# Key: variant name ("forward" | "bwd_data" | "bwd_weight") -> pipeline names.
# All 8 pipelines (basic_v1, mem, compv3-6, comp_async, basic_async_v1) successfully
# build and run for all variants in both 2D and 3D (verified via 10_test_all_pipelines.py)
# The three lists are therefore currently identical.
VARIANT_PIPELINES: Dict[str, List[str]] = {
"forward": [
"basic_v1",
"mem",
"compv3",
"compv4",
"compv5",
"compv6",
"comp_async",
"basic_async_v1",
],
"bwd_data": [
"basic_v1",
"mem",
"compv3",
"compv4",
"compv5",
"compv6",
"comp_async",
"basic_async_v1",
],
"bwd_weight": [
"basic_v1",
"mem",
"compv3",
"compv4",
"compv5",
"compv6",
"comp_async",
"basic_async_v1",
],
}
# Tiles that support compv4 pipeline
# compv4 has stricter requirements due to double buffering and LDS constraints
# Pattern: only warp_tile [16,16,16] or [16,16,32] work with compv4
# Large warp_tile [32,32,16] and wave [8,2,1] fail arch validation for compv4
# Net effect: this is COMMON_TILES minus the [32,32,16] warp-tile group and
# minus (128, 64, 128), the only registered tile that uses wave [8,2,1].
COMPV4_COMPATIBLE_TILES: List[Tuple[int, int, int]] = [
# warp_tile [16,16,16] - work with compv4 (except wave=8x2x1)
(16, 64, 64),
(32, 64, 64),
(64, 64, 64),
# (128, 64, 64) is excluded: wave=8x2x1 fails for compv4
# warp_tile [16,16,32] - work with compv4 (except wave=8x2x1)
(16, 64, 128),
(32, 64, 128),
(64, 64, 128),
# (128, 64, 128) is excluded: wave=8x2x1 fails for compv4
]
# Backward weight tiles (very restricted due to transpose_tile2d constraints)
# Testing all tiles to verify which ones actually work
# NOTE: as written this list contains the same ten tiles as COMMON_TILES; only
# (16, 64, 64) is marked as known working, the rest are still under test.
BWD_WEIGHT_TILES: List[Tuple[int, int, int]] = [
# warp_tile [16,16,16]
(16, 64, 64), # Known working config
(32, 64, 64), # Test
(64, 64, 64), # Test
# warp_tile [32,32,16]
(32, 128, 64), # Test
(64, 128, 64), # Test
(128, 128, 64), # Test
# warp_tile [16,16,32]
(16, 64, 128), # Test
(32, 64, 128), # Test
(64, 64, 128), # Test
(128, 64, 128), # Test
]
# =============================================================================
# Validation
# =============================================================================
def validate_tile_config(tile_m: int, tile_n: int, tile_k: int) -> bool:
    """Return True when the tile has wave, warp and vector entries registered."""
    key = (tile_m, tile_n, tile_k)
    lookup_tables = (TILE_TO_WAVE, TILE_TO_WARP, TILE_TO_VECTOR)
    return all(key in table for table in lookup_tables)
def get_tile_full_config(tile_m: int, tile_n: int, tile_k: int) -> dict:
    """Collect every registered parameter of a tile into one dict.

    Returns:
        A dict with keys tile_m, tile_n, tile_k, wave_m, wave_n, wave_k,
        warp_m, warp_n, warp_k, vec_a, vec_b, vec_c (in that order), or
        None when the tile is not registered in all three lookup tables.
    """
    if not validate_tile_config(tile_m, tile_n, tile_k):
        return None

    key = (tile_m, tile_n, tile_k)
    wave_m, wave_n, wave_k = TILE_TO_WAVE[key]
    warp_m, warp_n, warp_k = TILE_TO_WARP[key]
    vec_a, vec_b, vec_c = TILE_TO_VECTOR[key]
    return dict(
        tile_m=tile_m,
        tile_n=tile_n,
        tile_k=tile_k,
        wave_m=wave_m,
        wave_n=wave_n,
        wave_k=wave_k,
        warp_m=warp_m,
        warp_n=warp_n,
        warp_k=warp_k,
        vec_a=vec_a,
        vec_b=vec_b,
        vec_c=vec_c,
    )
# =============================================================================
# Summary Statistics
# =============================================================================
def print_summary():
    """Print the tile counts and one line per COMMON_TILES entry to stdout."""
    rule = "=" * 80
    print(rule)
    print("Grouped Convolution Tile Configurations (Single Source of Truth)")
    print(rule)
    print(f"Total tiles: {len(COMMON_TILES)}")
    print(f"Backward weight tiles: {len(BWD_WEIGHT_TILES)}")
    print()
    print("Tile sizes (M×N×K):")
    for entry in COMMON_TILES:
        m, n, k = entry
        # Direct indexing: a tile without wave/warp entries raises KeyError.
        waves = TILE_TO_WAVE[entry]
        warps = TILE_TO_WARP[entry]
        print(
            f" {m:3}×{n:3}×{k:3} wave={waves[0]}×{waves[1]}×{waves[2]} warp={warps[0]}×{warps[1]}×{warps[2]}"
        )
    print(rule)
if __name__ == "__main__":
print_summary()

View File

@@ -41,6 +41,26 @@ except ImportError:
ArchFilter = None
OperatorType = None
# Import tile configurations from grouped_config_rules (single source of truth)
try:
from grouped_config_rules import (
COMMON_TILES,
TILE_TO_WAVE,
TILE_TO_WARP,
VARIANT_PIPELINES,
BWD_WEIGHT_TILES,
COMPV4_COMPATIBLE_TILES,
)
HAS_TILE_CONFIGS = True
except ImportError:
HAS_TILE_CONFIGS = False
COMMON_TILES = []
TILE_TO_WAVE = {}
TILE_TO_WARP = {}
VARIANT_PIPELINES = {}
BWD_WEIGHT_TILES = []
COMPV4_COMPATIBLE_TILES = []
# ============================================================================
# Configuration and Data Structures
@@ -494,6 +514,21 @@ struct {kernel_name}_Config {{
# Create valid C++ namespace name
ns_name = "ns_" + kernel_name.replace("-", "_")
# basic_v1 / basic_async_v1 inherit BaseGemmPipelineAGmemBGmemCRegV1
# whose TailHandler takes (run_func, has_hot_loop) and invokes
# run_func(bool_constant<...>) -- 1 lambda arg. Other pipelines pass
# (run_func, has_hot_loop, tail_number) and invoke 2-arg run_func.
if tr.pipeline in ("basic_v1", "basic_async_v1"):
tail_handler_call = "BaseGemmPipeline::TailHandler(Run, has_hot_loop);"
run_lambda_signature = "[&](const auto has_hot_loop_)"
else:
tail_handler_call = (
"BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);"
)
run_lambda_signature = (
"[&](const auto has_hot_loop_, const auto tail_number_)"
)
return f"""
// Unique namespace for this kernel to avoid conflicts when including multiple kernels
namespace {ns_name} {{
@@ -605,7 +640,7 @@ struct {kernel_name}_Launcher {{
using Kernel = {kernel_type}<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
const auto Run = {run_lambda_signature} {{
auto kargs = Kernel::MakeKernelArgs(args);
if (!Kernel::IsSupportedArgument(kargs)) {{
@@ -621,7 +656,7 @@ struct {kernel_name}_Launcher {{
return ave_time;
}};
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
{tail_handler_call}
return ave_time;
}}
}};
@@ -1021,7 +1056,10 @@ def get_default_configs(
variants: Optional[List[GroupedConvVariant]] = None,
ndims: Optional[List[int]] = None,
) -> List[GroupedConvKernelConfig]:
"""Get default grouped convolution configurations for target architecture"""
"""Get default grouped convolution configurations for target architecture.
Uses tile configurations from grouped_conv_instance_builder.py as single source of truth.
"""
configs = []
if variants is None:
@@ -1029,39 +1067,53 @@ def get_default_configs(
if ndims is None:
ndims = [2]
# Valid configurations per variant (based on CK Tile example configs)
# Forward and Backward Data: standard GEMM-like tiles
fwd_bwd_data_tiles = [
# (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
(128, 128, 32, 2, 2, 32, 32, 16), # Standard 128x128
(256, 256, 32, 2, 2, 32, 32, 16), # Large 256x256
(64, 64, 32, 1, 4, 16, 16, 16), # Small 64x64
(128, 64, 32, 2, 2, 32, 32, 16), # Rectangular
(16, 64, 64, 1, 4, 16, 16, 32), # Tall and narrow
]
# Import tile configs from instance builder (single source of truth)
if not HAS_TILE_CONFIGS or not COMMON_TILES:
log.warning("grouped_config_rules not available, using fallback tile configs")
# Fallback to minimal set if grouped_config_rules unavailable
fwd_bwd_data_tiles = [
(128, 128, 32, 2, 2, 32, 32, 16),
(64, 64, 32, 1, 4, 16, 16, 16),
(16, 64, 64, 1, 4, 16, 16, 32),
]
bwd_weight_tiles = [(16, 64, 64, 1, 4, 16, 16, 32)]
else:
# Build tile list from COMMON_TILES with wave/warp mappings
fwd_bwd_data_tiles = []
for tile_m, tile_n, tile_k in COMMON_TILES:
tile_key = (tile_m, tile_n, tile_k)
if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
fwd_bwd_data_tiles.append(
(tile_m, tile_n, tile_k, wave_m, wave_n, warp_m, warp_n, warp_k)
)
# Backward Weight: VERY specific tile configs that work with CK Tile's bwd_weight kernel
# Based on ConvConfigComputeV3 from CK Tile examples (example/ck_tile/20_grouped_convolution/)
# Note: Backward weight has strict constraints on warp configurations due to transpose_tile2d
# Only specific warp configs work: (1, 4, 1) and (4, 1, 1) are known to work
bwd_weight_tiles = [
# (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
# ConvConfigComputeV3: The primary working config for backward weight
(16, 64, 64, 1, 4, 16, 16, 32),
]
# Backward weight: use BWD_WEIGHT_TILES from config rules
bwd_weight_tiles = []
for tile_m, tile_n, tile_k in BWD_WEIGHT_TILES:
tile_key = (tile_m, tile_n, tile_k)
if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
bwd_weight_tiles.append(
(tile_m, tile_n, tile_k, wave_m, wave_n, warp_m, warp_n, warp_k)
)
for variant in variants:
# Select tile configs based on variant
if variant == GroupedConvVariant.BACKWARD_WEIGHT:
tile_configs = bwd_weight_tiles
# Backward weight ONLY supports compv3 (compv4/compv5 have transpose_tile2d issues)
pipelines = [("compv3", "cshuffle")]
# Backward weight supports compv3 and mem pipelines
# (compv4/compv5 have transpose_tile2d issues)
pipelines = [("compv3", "cshuffle"), ("mem", "default")]
# Also generate two-stage variants (fp32 workspace + elementwise convert)
two_stage_flags = [False, True]
elif variant == GroupedConvVariant.BACKWARD_DATA:
tile_configs = fwd_bwd_data_tiles
# Backward data ONLY supports compv3 (compv4 has get_length issues in bwd_data kernel)
pipelines = [("compv3", "cshuffle")]
# Backward data supports compv3 and mem pipelines
# (compv4/compv5 have get_length issues in bwd_data kernel)
pipelines = [("compv3", "cshuffle"), ("mem", "default")]
two_stage_flags = [False]
else:
tile_configs = fwd_bwd_data_tiles
@@ -1080,6 +1132,12 @@ def get_default_configs(
warp_tile_n,
warp_tile_k,
) in tile_configs:
# Skip tiles incompatible with compv4
if pipeline == "compv4" and HAS_TILE_CONFIGS:
tile_key = (tile_m, tile_n, tile_k)
if tile_key not in COMPV4_COMPATIBLE_TILES:
continue # Skip this tile for compv4
for two_stage in two_stage_flags:
adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k
@@ -1609,7 +1667,16 @@ def main():
parser.add_argument(
"--pipeline",
type=str,
choices=["mem", "compv3", "compv4", "compv5"],
choices=[
"basic_v1",
"basic_async_v1",
"mem",
"compv3",
"compv4",
"compv5",
"compv6",
"comp_async",
],
help="Pipeline type",
)
parser.add_argument(
@@ -1642,6 +1709,16 @@ def main():
default=None,
help="Double SMEM buffer (true/false)",
)
parser.add_argument(
"--split-image",
action="store_true",
help="Enable split-image (EnableSplitImage) for large spatial tensors",
)
parser.add_argument(
"--two-stage",
action="store_true",
help="Enable two-stage bwd_weight (fp32 workspace + elementwise convert)",
)
args = parser.parse_args()
@@ -1679,7 +1756,13 @@ def main():
if args.double_smem_buffer is not None:
dsb = args.double_smem_buffer.lower() == "true"
else:
dsb = pipeline == "compv4" # compv4 requires double buffer
# Historical default: only compv4 auto-defaults to dsb=true.
# Other pipelines that also require DoubleSmemBuffer (e.g. comp_async)
# must be told explicitly via --double-smem-buffer true; otherwise
# they will fail loudly at the pipeline header static_assert. This
# is intentional -- silent fallback to a different config would
# mask the user's input.
dsb = pipeline == "compv4"
trait = GroupedConvTraitConfig(
pipeline=pipeline,
@@ -1690,6 +1773,8 @@ def main():
pad_k=args.pad_k,
double_smem_buffer=dsb,
num_groups_to_merge=args.num_groups_to_merge,
split_image=args.split_image,
two_stage=args.two_stage,
)
config = GroupedConvKernelConfig(
tile=tile,
@@ -1719,18 +1804,20 @@ def main():
print(f" Spatial dims: {args.ndim}")
print(f"\nConfigurations ({len(filtered_configs)}):")
for cfg in filtered_configs:
print(f" - {cfg.name('fp16')}")
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
print(
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
)
print(
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
)
print(
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
)
# List configs for each requested datatype (fixes bf16 -> fp16 bug)
for dt in args.datatype:
print(f" - {cfg.name(dt)}")
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
print(
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
)
print(
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
)
print(
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
)
return
# Generate

View File

@@ -290,7 +290,7 @@ function(add_declarative_gpu_example NAME SOURCE)
COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/example_kernel_builder.py
${EXAMPLE_SOURCE}
--output-dir ${EXAMPLE_KERNEL_DIR}
--include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include"
--include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include,${CMAKE_CURRENT_SOURCE_DIR}/../.."
--gpu-target ${GPU_TARGET}
--jobs ${NPROC}
--target-name ${NAME}
@@ -456,7 +456,47 @@ add_declarative_gpu_example(grouped_conv_06_bwd_weight grouped_conv/cpp/06_bw
add_declarative_gpu_example(grouped_conv_07_benchmark grouped_conv/cpp/07_multi_tile_benchmark.cpp)
# =============================================================================
# Grouped Convolution Python Library - Multi-Kernel (fwd/bwd_data/bwd_weight x 2D/3D)
# FMHA C++ Examples
# =============================================================================
add_declarative_gpu_example(fmha_01_basic fmha/cpp/01_basic_fmha.cpp)
add_declarative_gpu_example(fmha_02_splitkv fmha/cpp/02_splitkv_fmha.cpp)
add_declarative_gpu_example(fmha_03_kvcache fmha/cpp/03_kvcache_fmha.cpp)
add_declarative_gpu_example(fmha_04_bwd fmha/cpp/04_bwd_fmha.cpp)
add_declarative_gpu_example(fmha_05_appendkv fmha/cpp/05_appendkv_fmha.cpp)
add_declarative_gpu_example(fmha_06_batch_prefill fmha/cpp/06_batch_prefill_fmha.cpp)
add_declarative_gpu_example(fmha_07_profile_pytorch fmha/cpp/07_profile_pytorch_fmha.cpp)
add_declarative_gpu_example(fmha_08_profile_flash fmha/cpp/08_profile_flash_fmha.cpp)
add_declarative_gpu_example(fmha_09_profile_aiter fmha/cpp/09_profile_aiter_fmha.cpp)
add_declarative_gpu_example(fmha_10_profile_fp32_fp8 fmha/cpp/10_profile_fp32_fp8_fmha.cpp)
add_declarative_gpu_example(fmha_11_receipt_aliases fmha/cpp/11_receipt_aliases_fmha.cpp)
add_declarative_gpu_example(fmha_12_registry_json fmha/cpp/12_registry_json_fmha.cpp)
add_declarative_gpu_example(fmha_13_feature_coverage fmha/cpp/13_feature_coverage_fmha.cpp)
add_declarative_gpu_example(fmha_14_benchmark_validation fmha/cpp/14_benchmark_validation_fmha.cpp)
add_declarative_gpu_example(fmha_15_multi_shape fmha/cpp/15_multi_shape_fmha.cpp)
add_declarative_gpu_example(fmha_16_heuristics fmha/cpp/16_heuristics_fmha.cpp)
add_declarative_gpu_example(fmha_17_autofill_autocorrect fmha/cpp/17_autofill_autocorrect_fmha.cpp)
add_declarative_gpu_example(fmha_18_gpu_splitkv fmha/cpp/18_gpu_splitkv_fmha.cpp)
add_declarative_gpu_example(fmha_19_gpu_masks fmha/cpp/19_gpu_masks_fmha.cpp)
add_declarative_gpu_example(fmha_20_gpu_bias fmha/cpp/20_gpu_bias_fmha.cpp)
add_declarative_gpu_example(fmha_21_gpu_features fmha/cpp/21_gpu_features_fmha.cpp)
add_declarative_gpu_example(fmha_22_gpu_bwd fmha/cpp/22_gpu_bwd_fmha.cpp)
add_declarative_gpu_example(fmha_23_multi_registry fmha/cpp/23_multi_registry_fmha.cpp)
add_declarative_gpu_example(fmha_24_per_receipt_registries fmha/cpp/24_per_receipt_registries_fmha.cpp)
add_declarative_gpu_example(fmha_25_gpu_appendkv_prefill fmha/cpp/25_gpu_appendkv_batchprefill_fmha.cpp)
add_declarative_gpu_example(fmha_26_dtypes_hdims fmha/cpp/26_dtypes_hdims_fmha.cpp)
add_declarative_gpu_example(fmha_27_padding_permutation fmha/cpp/27_padding_permutation_fmha.cpp)
add_declarative_gpu_example(fmha_28_bwd_masks fmha/cpp/28_bwd_masks_fmha.cpp)
add_declarative_gpu_example(fmha_29_bwd_bias_dropout fmha/cpp/29_bwd_bias_dropout_fmha.cpp)
add_declarative_gpu_example(fmha_30_bwd_benchmark fmha/cpp/30_bwd_benchmark_fmha.cpp)
add_declarative_gpu_example(fmha_31_logits_soft_cap fmha/cpp/31_logits_soft_cap_fmha.cpp)
add_declarative_gpu_example(fmha_32_sink_tokens fmha/cpp/32_sink_tokens_fmha.cpp)
add_declarative_gpu_example(fmha_33_bwd_deterministic fmha/cpp/33_bwd_deterministic_fmha.cpp)
add_declarative_gpu_example(fmha_34_bwd_gqa fmha/cpp/34_bwd_gqa_fmha.cpp)
add_declarative_gpu_example(fmha_35_generic_mask fmha/cpp/35_generic_mask_fmha.cpp)
# =============================================================================
# Grouped Convolution Python Library - Multi-Kernel (fwd/bwdd/bwdw x 2D/3D)
# =============================================================================
# Kernel output directory for the Python conv library
@@ -502,13 +542,67 @@ if(hip_FOUND)
endif()
add_dependencies(dispatcher_conv_lib generate_conv_fallback_kernels)
# =============================================================================
# FMHA Python Library - Single Fallback Kernel
# =============================================================================
# All FMHA fallback artifacts live under the binary dir, never the source tree.
set(FMHA_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/fmha_python_fallback")
# Generated header that is force-included when building the Python library.
set(FMHA_DISPATCH_HEADER "${FMHA_FALLBACK_KERNEL_DIR}/fmha_python_dispatch.hpp")
# Static archive holding the compiled fallback kernel object.
set(FMHA_FALLBACK_LIB "${FMHA_FALLBACK_KERNEL_DIR}/libfmha_python_fallback.a")
# Stamp file touched after a successful generation run.
set(FMHA_FALLBACK_SENTINEL "${FMHA_FALLBACK_KERNEL_DIR}/.generated")
# Generate the FMHA fallback kernel, compile it, and produce both
# the dispatch header and a static library with the kernel object.
# Uses example_kernel_builder.py with a synthetic source that declares
# a single FMHA kernel set, just like the C++ examples do.
# NOTE(review): the command below actually invokes generate_fallback.py, and
# FMHA_FALLBACK_SOURCE is not referenced in the visible rules — confirm
# whether this variable (and the sentence above) is still needed.
set(FMHA_FALLBACK_SOURCE "${FMHA_FALLBACK_KERNEL_DIR}/fmha_python_fallback.cpp")
# Build rule for the FMHA fallback kernel used by the Python library.
# Outputs: the dispatch header, the static kernel archive and a stamp file.
# The generator script is listed in DEPENDS so that editing it re-triggers
# generation; without it the outputs would go stale after a script change.
# NOTE(review): --include-dirs is ':'-separated here while the example builder
# invocation in this file passes a ','-separated list — confirm that
# generate_fallback.py expects ':'.
add_custom_command(
    OUTPUT ${FMHA_DISPATCH_HEADER} ${FMHA_FALLBACK_LIB} ${FMHA_FALLBACK_SENTINEL}
    COMMAND ${CMAKE_COMMAND} -E make_directory ${FMHA_FALLBACK_KERNEL_DIR}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/fmha/generate_fallback.py
            --output-dir ${FMHA_FALLBACK_KERNEL_DIR}
            --gpu-target ${GPU_TARGET}
            --compile
            --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include:${CMAKE_CURRENT_SOURCE_DIR}/../include:${CMAKE_CURRENT_SOURCE_DIR}/../.."
    COMMAND ${CMAKE_COMMAND} -E touch ${FMHA_FALLBACK_SENTINEL}
    DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/fmha/generate_fallback.py
    COMMENT "Generating and compiling FMHA fallback kernel for Python library..."
    VERBATIM
)
# Named target that other targets can depend on to force the rule above.
add_custom_target(generate_fmha_fallback_kernels
    DEPENDS ${FMHA_DISPATCH_HEADER} ${FMHA_FALLBACK_LIB})
# FMHA dynamic library for Python
# Shared library built from the ctypes binding source and linked against the
# dispatcher plus the generated fallback kernel archive.
add_library(dispatcher_fmha_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/fmha_ctypes_lib.cpp)
target_link_libraries(dispatcher_fmha_lib PRIVATE ck_tile_dispatcher ${FMHA_FALLBACK_LIB})
# The last two directories are generated at build time by the fallback rule.
target_include_directories(dispatcher_fmha_lib PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${CMAKE_CURRENT_SOURCE_DIR}/../include
${CMAKE_CURRENT_SOURCE_DIR}/../..
${FMHA_FALLBACK_KERNEL_DIR}
${FMHA_FALLBACK_KERNEL_DIR}/dispatcher_wrappers
)
# "-include" force-includes the generated dispatch header into the binding
# source; GFX_ARCH is passed as a string literal holding the GPU target.
target_compile_options(dispatcher_fmha_lib PRIVATE
-include ${FMHA_DISPATCH_HEADER}
-DGFX_ARCH="${GPU_TARGET}"
-mllvm -enable-noalias-to-md-conversion=0
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
)
# HIP runtime targets are linked only when find_package(hip) succeeded.
if(hip_FOUND)
target_link_libraries(dispatcher_fmha_lib PRIVATE hip::device hip::host)
endif()
# Linking the archive by path creates no build-order edge on its own, so order
# the library after the generation target explicitly.
add_dependencies(dispatcher_fmha_lib generate_fmha_fallback_kernels)
message(STATUS "GEMM examples configured - kernels will be generated during 'make'")
message(STATUS "Grouped Conv examples configured - kernels will be generated during 'make'")
message(STATUS "FMHA examples configured - kernels will be generated during 'make'")
# Convenience target to build all Python ctypes libraries
add_custom_target(python_libs
DEPENDS dispatcher_gemm_lib dispatcher_conv_lib
COMMENT "Building Python ctypes libraries (GEMM + Conv)"
DEPENDS dispatcher_gemm_lib dispatcher_conv_lib dispatcher_fmha_lib
COMMENT "Building Python ctypes libraries (GEMM + Conv + FMHA)"
)
# =============================================================================

View File

@@ -59,9 +59,17 @@ python3 examples/gemm/python/08_heuristics.py
```
examples/
|---- gemm/
| |---- cpp/ # 7 C++ GEMM examples
| +---- python/ # 11 Python GEMM examples
|
|---- grouped_conv/
| |---- cpp/ # 7 C++ Grouped Conv examples
| +---- python/ # 6 Python Grouped Conv examples
|
|---- fmha/
| |---- cpp/ # 35 C++ FMHA examples (all variants)
| +---- python/ # 38 Python FMHA examples (JIT-compiled)
|
+---- README.md
```

View File

@@ -0,0 +1,371 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 01: Basic FMHA Forward with GPU Execution
//
// Demonstrates the full flow:
// 1. Declare kernels via DECL_FMHA_KERNEL_SET
// 2. Register and plan
// 3. Allocate Q, K, V, O GPU buffers
// 4. Run the FMHA forward kernel on GPU
// 5. Copy output to host and validate against CPU reference
//
// Mirrors 01_basic_gemm.cpp for FMHA.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel-set declaration for this example. It registers one entry under the name
// basic_fmha_kernels: a forward ("fwd") fp16, batch-mode kernel for head dimension 128
// without mask, bias, LSE, dropout or quant scale, using the "qr_async" pipeline and
// declared for gfx950.
//
// FMHA tile/wave/warp dimensions correspond to TWO GEMM stages:
// Stage 0 (Q * K^T): tile_m0 x tile_n0 x tile_k0 (seqlen_q x seqlen_k x hdim_q)
// Stage 1 (Attn * V): tile_m0 x tile_n1 x tile_k1 (seqlen_q x hdim_v x seqlen_k)
// Wave/warp follow the same stage pattern: *_m0/n0/k0 for stage 0, *_m1/n1/k1 for stage 1.
DECL_FMHA_KERNEL_SET(basic_fmha_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r") // V row-major
.hdim(128) // hdim_q = hdim_v = 128
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 tile: seqlen_q=128, seqlen_k=128, hdim_q=32
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 tile: hdim_v=128, seqlen_k=32, alignment=128
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
// Wave: 4 warps on m, 1 on n, 1 on k (both stages)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
// Warp tile: 32x32x16 (both stages)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true) // pad_s, pad_sk, pad_d, pad_dv
.alignments(128, 128) // hdim_q_alignment, hdim_v_alignment
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for unmasked attention: O = softmax(scale * Q * K^T) * V.
//
// Every tensor is indexed as a densely packed row-major array:
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]  (written in place; the caller must pre-size it)
//
// The softmax subtracts the per-row maximum score before exponentiation.
// Flat offsets are computed in std::size_t so that large tensors do not overflow
// the 32-bit int arithmetic the dimensions are passed in.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // Nothing to compute when there are no query rows at all.
    if(batch <= 0 || nhead <= 0 || seqlen_q <= 0)
        return;

    const std::size_t num_heads = static_cast<std::size_t>(batch) * static_cast<std::size_t>(nhead);
    const std::size_t sq_len    = static_cast<std::size_t>(seqlen_q);
    const std::size_t sk_len    = static_cast<std::size_t>(std::max(seqlen_k, 0));
    const std::size_t dq_len    = static_cast<std::size_t>(std::max(hdim_q, 0));
    const std::size_t dv_len    = static_cast<std::size_t>(std::max(hdim_v, 0));

    // One score row, reused for every query position (fully overwritten each time).
    std::vector<float> scores(sk_len, 0.0f);

    for(std::size_t bh = 0; bh < num_heads; ++bh)
    {
        // Base offsets of this (batch, head) slice in each tensor.
        const std::size_t q_head = bh * sq_len * dq_len;
        const std::size_t k_head = bh * sk_len * dq_len;
        const std::size_t v_head = bh * sk_len * dv_len;
        const std::size_t o_head = bh * sq_len * dv_len;

        for(std::size_t sq = 0; sq < sq_len; ++sq)
        {
            const std::size_t q_row = q_head + sq * dq_len;

            // Scaled dot products of this query row with every key row.
            float max_score = -1e30f;
            for(std::size_t sk = 0; sk < sk_len; ++sk)
            {
                const std::size_t k_row = k_head + sk * dq_len;
                float dot               = 0.0f;
                for(std::size_t d = 0; d < dq_len; ++d)
                {
                    dot += Q[q_row + d] * K[k_row + d];
                }
                scores[sk] = dot * scale;
                max_score  = std::max(max_score, scores[sk]);
            }

            // Softmax over the key axis.
            float sum_exp = 0.0f;
            for(std::size_t sk = 0; sk < sk_len; ++sk)
            {
                scores[sk] = std::exp(scores[sk] - max_score);
                sum_exp += scores[sk];
            }
            for(std::size_t sk = 0; sk < sk_len; ++sk)
            {
                scores[sk] /= sum_exp;
            }

            // Probability-weighted sum of the value rows.
            const std::size_t o_row = o_head + sq * dv_len;
            for(std::size_t dv = 0; dv < dv_len; ++dv)
            {
                float acc = 0.0f;
                for(std::size_t sk = 0; sk < sk_len; ++sk)
                {
                    acc += scores[sk] * V[v_head + sk * dv_len + dv];
                }
                O[o_row + dv] = acc;
            }
        }
    }
}
} // namespace
// Example 01 driver: registers the generated kernels, runs one FMHA forward call on
// device buffers filled with random fp16 data, and reports timing and a pass/fail status.
//
// Returns 0 on PASS (or when argument parsing asks to stop), 1 on FAIL or when the
// dispatcher throws. With --validate the output is compared against cpu_attention_fwd();
// without it only a sanity check (some non-zero output, no NaN/Inf) is performed.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 01: FMHA Forward (GPU Execution)", "FMHA with real GPU data");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "2", "Batch size");
    args.add_option("--nhead", "4", "Number of heads");
    args.add_option("--seqlen", "64", "Sequence length (Q and K)");
    args.add_option("--hdim", "128", "Head dimension");
    args.add_flag("--validate", "Validate against CPU reference");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 2);
    const int nhead            = args.get_int("--nhead", 4);
    const int seqlen           = args.get_int("--seqlen", 64);
    const int hdim             = args.get_int("--hdim", 128);

    print_header("Example 01: FMHA Forward (GPU Execution)");

    // Console step 1: register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("basic_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";

    FmhaDispatcher dispatcher(&registry);
    dispatcher.set_benchmarking(true);
    dispatcher.set_timing(1, 3);

    // Kernel-selection traits; they mirror the signature declared in basic_fmha_kernels.
    const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
    fmha_fwd_traits traits{};
    traits.hdim_q              = hdim;
    traits.hdim_v              = hdim;
    traits.data_type           = "fp16";
    traits.is_group_mode       = false;
    traits.is_v_rowmajor       = true;
    traits.has_logits_soft_cap = false;
    traits.mask_type           = mask_enum::no_mask;
    traits.bias_type           = bias_enum::no_bias;
    traits.has_lse             = false;
    traits.has_dropout         = false;
    traits.qscale_type         = quant_scale_enum::no_scale;

    // Q, K, V and O all have batch * nhead * seqlen * hdim elements in this example.
    const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
    const int64_t k_elems = q_elems;
    const int64_t v_elems = q_elems;
    const int64_t o_elems = q_elems;

    // Console step 2: allocate GPU buffers
    std::cout << "\nStep 2: Allocate GPU Buffers\n";
    std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
              << "]\n";
    GpuBuffer<FmhaDataType> q_dev(q_elems);
    GpuBuffer<FmhaDataType> k_dev(k_elems);
    GpuBuffer<FmhaDataType> v_dev(v_elems);
    GpuBuffer<FmhaDataType> o_dev(o_elems);

    // Fill Q, K, V with random data (fixed seed keeps runs reproducible)
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
    std::vector<FmhaDataType> q_host(q_elems);
    std::vector<FmhaDataType> k_host(k_elems);
    std::vector<FmhaDataType> v_host(v_elems);
    for(auto& x : q_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : k_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : v_host)
        x = FmhaDataType(dist(rng));
    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());
    o_dev.zero();

    // Launch arguments: device pointers, problem sizes and strides
    fmha_fwd_args fmha_args{};
    fmha_args.q_ptr                      = q_dev.get();
    fmha_args.k_ptr                      = k_dev.get();
    fmha_args.v_ptr                      = v_dev.get();
    fmha_args.o_ptr                      = o_dev.get();
    fmha_args.bias_ptr                   = nullptr;
    fmha_args.q_descale_ptr              = nullptr;
    fmha_args.k_descale_ptr              = nullptr;
    fmha_args.v_descale_ptr              = nullptr;
    fmha_args.rand_val_ptr               = nullptr;
    fmha_args.lse_ptr                    = nullptr;
    fmha_args.sink_ptr                   = nullptr;
    fmha_args.block_scale_seqstart_q_ptr = nullptr;
    fmha_args.block_scale_seqstart_k_ptr = nullptr;
    fmha_args.seqlen_q                   = seqlen;
    fmha_args.seqlen_k                   = seqlen;
    fmha_args.batch                      = batch;
    fmha_args.max_seqlen_q               = seqlen;
    fmha_args.hdim_q                     = hdim;
    fmha_args.hdim_v                     = hdim;
    fmha_args.nhead_q                    = nhead;
    fmha_args.nhead_k                    = nhead;
    fmha_args.scale_s                    = scale;
    fmha_args.logits_soft_cap            = 0.0f;
    // bhsd layout strides
    fmha_args.stride_q               = hdim;
    fmha_args.stride_k               = hdim;
    fmha_args.stride_v               = hdim;
    fmha_args.stride_bias            = 0;
    fmha_args.stride_randval         = 0;
    fmha_args.stride_o               = hdim;
    fmha_args.nhead_stride_q         = seqlen * hdim;
    fmha_args.nhead_stride_k         = seqlen * hdim;
    fmha_args.nhead_stride_v         = seqlen * hdim;
    fmha_args.nhead_stride_bias      = 0;
    fmha_args.nhead_stride_randval   = 0;
    fmha_args.nhead_stride_lse       = 0;
    fmha_args.nhead_stride_o         = seqlen * hdim;
    fmha_args.nhead_stride_q_descale = 0;
    fmha_args.nhead_stride_k_descale = 0;
    fmha_args.nhead_stride_v_descale = 0;
    fmha_args.batch_stride_q         = nhead * seqlen * hdim;
    fmha_args.batch_stride_k         = nhead * seqlen * hdim;
    fmha_args.batch_stride_v         = nhead * seqlen * hdim;
    fmha_args.batch_stride_bias      = 0;
    fmha_args.batch_stride_randval   = 0;
    fmha_args.batch_stride_lse       = 0;
    fmha_args.batch_stride_o         = nhead * seqlen * hdim;
    fmha_args.batch_stride_q_descale = 0;
    fmha_args.batch_stride_k_descale = 0;
    fmha_args.batch_stride_v_descale = 0;
    fmha_args.window_size_left       = -1;
    fmha_args.window_size_right      = -1;
    fmha_args.sink_size              = 0;
    fmha_args.mask_type              = 0;
    fmha_args.min_seqlen_q           = 0;
    fmha_args.p_drop                 = 0.0f;
    fmha_args.s_randval              = false;
    fmha_args.drop_seed_offset       = std::make_pair(uint64_t(0), uint64_t(0));
    fmha_args.block_scale_size_q     = 0;
    fmha_args.block_scale_size_kv    = 0;

    // Console step 3: run on GPU
    std::cout << "\nStep 3: Run FMHA Forward on GPU\n";
    float time_ms = 0.0f;
    try
    {
        time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
    }
    catch(const std::exception& e)
    {
        std::cerr << " ERROR: " << e.what() << "\n";
        return 1;
    }

    auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
    // Guard the division: a zero (or negative) measured time would otherwise print inf/NaN.
    const double tflops =
        (time_ms > 0.0f) ? static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12 : 0.0;
    std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
    std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";

    // Console step 4: copy output back and validate
    std::cout << "\nStep 4: Validate\n";
    std::vector<FmhaDataType> o_host(o_elems);
    o_dev.copy_to_host(o_host.data());

    // Quick sanity check: output should contain non-zero values and no NaN/Inf.
    int nonzero   = 0;
    int nonfinite = 0;
    for(int64_t i = 0; i < o_elems; ++i)
    {
        const float value = static_cast<float>(o_host[i]);
        if(!std::isfinite(value))
            ++nonfinite;
        else if(value != 0.0f)
            ++nonzero;
    }
    std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
    if(nonfinite > 0)
        std::cout << " Non-finite outputs: " << nonfinite << " / " << o_elems << "\n";
    bool passed = (nonzero > 0) && (nonfinite == 0);

    if(args.has("--validate"))
    {
        // CPU reference, computed in fp32 from the same host inputs
        std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
        for(int64_t i = 0; i < q_elems; ++i)
            q_f32[i] = static_cast<float>(q_host[i]);
        for(int64_t i = 0; i < k_elems; ++i)
            k_f32[i] = static_cast<float>(k_host[i]);
        for(int64_t i = 0; i < v_elems; ++i)
            v_f32[i] = static_cast<float>(v_host[i]);
        cpu_attention_fwd(
            q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);

        double max_abs_err = 0.0;
        double max_rel_err = 0.0;
        int errors         = 0;
        const double rtol  = 1e-2;
        const double atol  = 1e-2;
        for(int64_t i = 0; i < o_elems; ++i)
        {
            float gpu_val  = static_cast<float>(o_host[i]);
            float ref_val  = o_ref[i];
            double abs_err = std::abs(gpu_val - ref_val);
            double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
            max_abs_err    = std::max(max_abs_err, abs_err);
            max_rel_err    = std::max(max_rel_err, rel_err);
            // Negated "<=" on purpose: any comparison with NaN is false, so a NaN
            // error must land in the failing branch instead of being skipped.
            if(!(abs_err <= atol + rtol * std::abs(ref_val)))
                ++errors;
        }
        std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
        std::cout << " Max rel error: " << max_rel_err << "\n";
        std::cout << " Errors: " << errors << " / " << o_elems << "\n";
        passed = (errors == 0) && (nonfinite == 0);
    }

    print_separator();
    std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
    print_separator();
    return passed ? 0 : 1;
}

View File

@@ -0,0 +1,162 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel-set declaration for the split-KV example. It registers two entries under the
// name splitkv_fmha_kernels: a "fwd_splitkv" kernel and a "fwd_splitkv_combine" kernel.
// Both are fp16, batch mode, head dimension 128, LSE enabled, pipeline "qr",
// max_splits_log2 = 6, declared for gfx950. The two entries differ only in tile_n1
// (128 for the split kernel, 32 for the combine kernel).
// main() in this file expects the planner to return exactly two stages.
DECL_FMHA_KERNEL_SET(splitkv_fmha_kernels,
.add(FmhaSignature()
.family("fwd_splitkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no")
.paged_kv(false),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true)
.max_splits_log2(6)
.selection_rank(0),
"gfx950")
// Second entry: the combine kernel paired with the split kernel above.
.add(FmhaSignature()
.family("fwd_splitkv_combine")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no")
.paged_kv(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(32)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true)
.max_splits_log2(6)
.selection_rank(0),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 02: FMHA Split-KV", "Declarative FMHA split-KV planning");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "1", "Batch size");
args.add_option("--nhead", "16", "Number of heads");
args.add_option("--seqlen", "128", "Query sequence length");
args.add_option("--hdim", "128", "Head dimension");
if(!args.parse(argc, argv))
{
return 0;
}
utils::print_header("Example 02: FMHA Split-KV");
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 1);
const int nhead = args.get_int("--nhead", 16);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaRegistry registry;
registry.set_name("splitkv_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
// Step 2: Plan
std::cout << "\nStep 2: Plan\n";
fmha_fwd_splitkv_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = true;
traits.do_fp8_static_quant = false;
traits.has_sink = false;
fmha_fwd_splitkv_args fmha_args{};
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_k = 2048;
fmha_args.batch = batch;
fmha_args.max_seqlen_q = seqlen;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
fmha_args.num_splits = 8;
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
auto plan = dispatcher.plan(problem);
if(!plan.is_valid() || plan.stages.size() != 2)
{
std::cerr << "Expected a two-stage split-KV plan\n";
return 1;
}
// Step 3: Results
std::cout << "\nStep 3: Results\n";
for(const auto& stage : plan.stages)
{
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
}
utils::print_separator();
return 0;
}

View File

@@ -0,0 +1,240 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel-set declaration for the KV-cache example. It registers three entries under the
// name kvcache_fmha_kernels, all fp16, head dimension 128, paged KV enabled, with
// kv_cache("vectorized", "sglang", 16), declared for gfx950:
//   1. "fwd_pagedkv"   - batch mode, pipeline "qr_pagedkv"
//   2. "fwd_appendkv"  - batch mode, rope "inter", pipeline "appendkv"
//   3. "batch_prefill" - group mode, LSE enabled, pipeline "qr_async"
DECL_FMHA_KERNEL_SET(kvcache_fmha_kernels,
.add(FmhaSignature()
.family("fwd_pagedkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no")
.paged_kv(true)
.kv_cache("vectorized", "sglang", 16),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_pagedkv")
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Second entry: append-KV kernel (no wave/warp settings are given for it).
.add(FmhaSignature()
.family("fwd_appendkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.rope("inter")
.paged_kv(true)
.kv_cache("vectorized", "sglang", 16),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(64)
.tile_k0(128)
.tile_n1(128)
.tile_k1(0)
.tile_k0max(0)
.pipeline("appendkv")
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Third entry: batch-prefill kernel in group mode.
.add(FmhaSignature()
.family("batch_prefill")
.dtype("fp16")
.mode("group")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no")
.paged_kv(true)
.kv_cache("vectorized", "sglang", 16),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
// Example 03 driver: registers the generated KV-cache kernels and plans three
// problems against them - PagedKV decode, AppendKV and BatchPrefill (planning only;
// no kernel is launched and the non-null pointer fields hold placeholder values).
// Returns 0 when all three plans are valid (or when argument parsing asks to stop),
// 1 otherwise; every invalid plan is reported on stderr.
int main(int argc, char* argv[])
{
    utils::ExampleArgs args("Example 03: FMHA KV-Cache", "Declarative FMHA KV-cache planning");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "1", "Batch size");
    args.add_option("--nhead", "16", "Number of heads");
    args.add_option("--seqlen", "128", "Prefill query sequence length");
    args.add_option("--hdim", "128", "Head dimension");
    if(!args.parse(argc, argv))
    {
        return 0;
    }

    utils::print_header("Example 03: FMHA KV-Cache");

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 1);
    const int nhead            = args.get_int("--nhead", 16);
    const int seqlen           = args.get_int("--seqlen", 128);
    const int hdim             = args.get_int("--hdim", 128);

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaRegistry registry;
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";
    FmhaDispatcher dispatcher(&registry);

    // Step 2: Plan PagedKV (decode): a single query token against 1024 cached keys
    std::cout << "\nStep 2: Plan PagedKV (decode)\n";
    fmha_fwd_pagedkv_traits paged_traits{};
    paged_traits.hdim_q        = hdim;
    paged_traits.hdim_v        = hdim;
    paged_traits.data_type     = "fp16";
    paged_traits.is_group_mode = false;
    paged_traits.is_v_rowmajor = true;
    paged_traits.mask_type     = mask_enum::no_mask;
    paged_traits.bias_type     = bias_enum::no_bias;
    paged_traits.use_pagedkv   = true;
    fmha_fwd_pagedkv_args paged_args{};
    paged_args.seqlen_q        = 1;
    paged_args.seqlen_k        = 1024;
    paged_args.batch           = batch;
    paged_args.max_seqlen_q    = 1;
    paged_args.hdim_q          = hdim;
    paged_args.hdim_v          = hdim;
    paged_args.nhead_q         = nhead;
    paged_args.nhead_k         = nhead;
    // Placeholder non-null pointer; it is never dereferenced in this function.
    paged_args.block_table_ptr = reinterpret_cast<void*>(0x1);
    paged_args.page_block_size = 16;
    auto paged_plan            = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(paged_traits, paged_args), gfx_arch));

    // Step 3: Plan AppendKV: append one new token with interleaved rotary embedding
    std::cout << "\nStep 3: Plan AppendKV\n";
    fmha_fwd_appendkv_traits append_traits{};
    append_traits.hdim_q        = hdim;
    append_traits.hdim_v        = hdim;
    append_traits.data_type     = "fp16";
    append_traits.is_v_rowmajor = true;
    append_traits.rope_type     = rope_enum::interleaved;
    fmha_fwd_appendkv_args append_args{};
    append_args.seqlen_q        = 1;
    append_args.seqlen_knew     = 1;
    append_args.batch           = batch;
    append_args.hdim_q          = hdim;
    append_args.hdim_v          = hdim;
    append_args.nhead_q         = nhead;
    append_args.nhead_k         = nhead;
    append_args.rotary_dim      = hdim;
    append_args.rotary_cos_ptr  = reinterpret_cast<void*>(0x1);
    append_args.rotary_sin_ptr  = reinterpret_cast<void*>(0x1);
    append_args.block_table_ptr = reinterpret_cast<void*>(0x1);
    append_args.page_block_size = 16;
    auto append_plan            = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(append_traits, append_args), gfx_arch));

    // Step 4: Plan BatchPrefill: group mode, `seqlen` query tokens against 1024 keys
    std::cout << "\nStep 4: Plan BatchPrefill\n";
    fmha_batch_prefill_traits prefill_traits{};
    prefill_traits.hdim_q        = hdim;
    prefill_traits.hdim_v        = hdim;
    prefill_traits.data_type     = "fp16";
    prefill_traits.is_group_mode = true;
    prefill_traits.is_v_rowmajor = true;
    prefill_traits.mask_type     = mask_enum::no_mask;
    prefill_traits.bias_type     = bias_enum::no_bias;
    prefill_traits.has_lse       = true;
    prefill_traits.kv_memory_layout =
        ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
    prefill_traits.kv_lookup_table =
        ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
    prefill_traits.page_size = 16;
    fmha_batch_prefill_args prefill_args{};
    prefill_args.batch           = batch;
    prefill_args.seqlen_q        = seqlen;
    prefill_args.seqlen_k        = 1024;
    prefill_args.max_seqlen_q    = seqlen;
    prefill_args.hdim_q          = hdim;
    prefill_args.hdim_v          = hdim;
    prefill_args.nhead_q         = nhead;
    prefill_args.nhead_k         = nhead;
    prefill_args.num_total_pages = 64;
    prefill_args.page_block_size = 16;
    prefill_args.kv_memory_layout =
        ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
    prefill_args.kv_lookup_table =
        ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
    prefill_args.kv_indptr         = reinterpret_cast<void*>(0x1);
    prefill_args.kv_page_indices   = reinterpret_cast<void*>(0x1);
    prefill_args.kv_last_page_lens = reinterpret_cast<void*>(0x1);
    prefill_args.seqstart_q_ptr    = reinterpret_cast<void*>(0x1);
    auto prefill_plan              = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch));

    // Step 5: Results
    std::cout << "\nStep 5: Results\n";
    std::cout << " PagedKV stages: " << paged_plan.stages.size() << "\n";
    std::cout << " AppendKV stages: " << append_plan.stages.size() << "\n";
    std::cout << " BatchPrefill stages: " << prefill_plan.stages.size() << "\n";

    // Name every plan that failed so that a non-zero exit status is never silent.
    bool all_valid = true;
    if(!paged_plan.is_valid())
    {
        std::cerr << "Expected a valid PagedKV plan\n";
        all_valid = false;
    }
    if(!append_plan.is_valid())
    {
        std::cerr << "Expected a valid AppendKV plan\n";
        all_valid = false;
    }
    if(!prefill_plan.is_valid())
    {
        std::cerr << "Expected a valid BatchPrefill plan\n";
        all_valid = false;
    }

    utils::print_separator();
    return all_valid ? 0 : 1;
}

View File

@@ -0,0 +1,154 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel-set declaration for the backward example. It registers three entries under the
// name bwd_fmha_kernels, all fp16, batch mode, head dimension 128, no mask/bias/dropout,
// declared for gfx950:
//   1. "bwd_dot_do_o"
//   2. "bwd_dq_dk_dv"   - the only entry that sets wave/warp shapes and max_seq_len_q
//   3. "bwd_convert_dq"
// main() in this file expects the resulting plan to contain at least two stages.
// NOTE(review): the "Stage 0 (Q*K^T)" / "Stage 1 (Attn*V)" comments inside the first
// entry match the wording used by the forward examples - confirm they describe the
// backward tile fields as well.
DECL_FMHA_KERNEL_SET(bwd_fmha_kernels,
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Second entry: uses the nine-argument wave/warp setters instead of per-field ones.
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(0),
"gfx950")
// Third entry: dQ conversion kernel.
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 04: FMHA Backward", "Declarative FMHA backward planning");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "1", "Batch size");
args.add_option("--nhead", "16", "Number of heads");
args.add_option("--seqlen", "128", "Sequence length (Q and K)");
args.add_option("--hdim", "128", "Head dimension");
if(!args.parse(argc, argv))
{
return 0;
}
utils::print_header("Example 04: FMHA Backward");
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 1);
const int nhead = args.get_int("--nhead", 16);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaRegistry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
// Step 2: Plan
std::cout << "\nStep 2: Plan\n";
fmha_bwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_dbias = false;
traits.has_dropout = false;
traits.is_store_randval = false;
traits.is_deterministic = false;
fmha_bwd_args bwd_args{};
bwd_args.batch = batch;
bwd_args.seqlen_q = seqlen;
bwd_args.seqlen_k = seqlen;
bwd_args.max_seqlen_q = seqlen;
bwd_args.max_seqlen_k = seqlen;
bwd_args.hdim_q = hdim;
bwd_args.hdim_v = hdim;
bwd_args.nhead_q = nhead;
bwd_args.nhead_k = nhead;
auto plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, bwd_args), gfx_arch));
if(!plan.is_valid() || plan.stages.size() < 2)
{
std::cerr << "Expected a multi-stage backward plan\n";
return 1;
}
// Step 3: Results
std::cout << "\nStep 3: Results\n";
for(const auto& stage : plan.stages)
{
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
}
utils::print_separator();
return 0;
}

View File

@@ -0,0 +1,106 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel-set declaration for the append-KV example. It registers one entry under the
// name appendkv_fmha_kernels: a "fwd_appendkv" fp16, batch-mode kernel for head
// dimension 128 with rope "inter", paged KV enabled, kv_cache("vectorized", "sglang", 16)
// and pipeline "appendkv", declared for gfx950. No wave/warp shapes are set for it.
DECL_FMHA_KERNEL_SET(appendkv_fmha_kernels,
.add(FmhaSignature()
.family("fwd_appendkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.rope("inter")
.paged_kv(true)
.kv_cache("vectorized", "sglang", 16),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(64)
.tile_n0(64)
.tile_k0(128)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(0)
.tile_k0max(0)
.pipeline("appendkv")
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 05: FMHA AppendKV", "Declarative FMHA append-KV planning");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "1", "Batch size");
args.add_option("--nhead", "16", "Number of heads");
args.add_option("--seqlen", "1", "Sequence length (tokens to append)");
args.add_option("--hdim", "128", "Head dimension");
if(!args.parse(argc, argv))
{
return 0;
}
utils::print_header("Example 05: FMHA AppendKV");
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 1);
const int nhead = args.get_int("--nhead", 16);
const int seqlen = args.get_int("--seqlen", 1);
const int hdim = args.get_int("--hdim", 128);
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaRegistry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
// Step 2: Plan
std::cout << "\nStep 2: Plan\n";
fmha_fwd_appendkv_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_v_rowmajor = true;
traits.rope_type = rope_enum::interleaved;
fmha_fwd_appendkv_args fmha_args{};
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_knew = seqlen;
fmha_args.batch = batch;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
fmha_args.rotary_dim = hdim;
fmha_args.rotary_cos_ptr = reinterpret_cast<void*>(0x1);
fmha_args.rotary_sin_ptr = reinterpret_cast<void*>(0x1);
fmha_args.block_table_ptr = reinterpret_cast<void*>(0x1);
fmha_args.page_block_size = 16;
auto plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
if(!plan.is_valid() || plan.stages.size() != 1)
{
std::cerr << "Expected a single-stage append-KV plan\n";
return 1;
}
// Step 3: Results
std::cout << "\nStep 3: Results\n";
for(const auto& stage : plan.stages)
{
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
}
utils::print_separator();
return 0;
}

View File

@@ -0,0 +1,133 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel-set declaration for the batch-prefill example. It registers one entry under the
// name batch_prefill_fmha_kernels: a "batch_prefill" fp16, group-mode kernel for head
// dimension 128 with LSE enabled, paged KV enabled, kv_cache("vectorized", "sglang", 16)
// and pipeline "qr_async", declared for gfx950.
DECL_FMHA_KERNEL_SET(batch_prefill_fmha_kernels,
.add(FmhaSignature()
.family("batch_prefill")
.dtype("fp16")
.mode("group")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no")
.paged_kv(true)
.kv_cache("vectorized", "sglang", 16),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 06: FMHA Batch Prefill",
"Declarative FMHA batch-prefill planning");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "1", "Batch size");
args.add_option("--nhead", "16", "Number of heads");
args.add_option("--seqlen", "128", "Query sequence length");
args.add_option("--hdim", "128", "Head dimension");
if(!args.parse(argc, argv))
{
return 0;
}
utils::print_header("Example 06: FMHA Batch Prefill");
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 1);
const int nhead = args.get_int("--nhead", 16);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaRegistry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
// Step 2: Plan
std::cout << "\nStep 2: Plan\n";
fmha_batch_prefill_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = true;
traits.is_v_rowmajor = true;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = true;
traits.kv_memory_layout = ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
traits.kv_lookup_table = ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
traits.page_size = 16;
fmha_batch_prefill_args fmha_args{};
fmha_args.batch = batch;
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_k = 1024;
fmha_args.max_seqlen_q = seqlen;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
fmha_args.num_total_pages = 64;
fmha_args.page_block_size = 16;
fmha_args.kv_memory_layout = ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
fmha_args.kv_lookup_table = ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
fmha_args.kv_indptr = reinterpret_cast<void*>(0x1);
fmha_args.kv_page_indices = reinterpret_cast<void*>(0x1);
fmha_args.kv_last_page_lens = reinterpret_cast<void*>(0x1);
fmha_args.seqstart_q_ptr = reinterpret_cast<void*>(0x1);
auto plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
if(!plan.is_valid() || plan.stages.size() != 1)
{
std::cerr << "Expected a single-stage batch-prefill plan\n";
return 1;
}
// Step 3: Results
std::cout << "\nStep 3: Results\n";
for(const auto& stage : plan.stages)
{
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
}
utils::print_separator();
return 0;
}

View File

@@ -0,0 +1,248 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel set "pytorch_profile_fmha_kernels": seven FMHA kernel declarations, every one tagged
// .profile("pytorch") and declared for "gfx950". Each .add() receives three things:
//   1. an FmhaSignature  - which kernel family / dtype / mode / feature flags it covers,
//   2. an FmhaAlgorithm  - tile, wave and warp sizes plus pipeline and padding settings,
//   3. the architecture string.
// NOTE(review): presumably these declarations are what REGISTER_GENERATED_KERNELS(registry, arch)
// registers in main() - confirm against the macro definitions.
DECL_FMHA_KERNEL_SET(pytorch_profile_fmha_kernels,
// Entry 1/7: forward kernel - fp16, batch mode, hdim 128, no mask, bias("bias").
// main() requests bias_enum::elementwise_bias for its forward plan.
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("bias")
.profile("pytorch"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(32)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true),
"gfx950")
// Entry 2/7: split-KV forward kernel - fp16, batch mode, hdim 128, no mask, no bias, LSE on.
// Same tile/wave shape as entry 1; warp_k0 is 16 here and max_splits_log2 is 6.
.add(FmhaSignature()
.family("fwd_splitkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.profile("pytorch"),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true)
.max_splits_log2(6),
"gfx950")
// Entry 3/7: split-KV combine kernel - same signature flags as entry 2; tile_n1 is 32 here.
.add(FmhaSignature()
.family("fwd_splitkv_combine")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.profile("pytorch"),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(32)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true)
.max_splits_log2(6),
"gfx950")
// Entry 4/7: append-KV kernel - fp16, batch mode, hdim 128. Only tile sizes, padding and the
// "appendkv" pipeline are specified; no wave/warp settings are given for this family.
.add(FmhaSignature()
.family("fwd_appendkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.profile("pytorch"),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(64)
.tile_k0(128)
.tile_n1(128)
.tile_k1(0)
.tile_k0max(0)
.padding(false, true, true, false)
.pipeline("appendkv"),
"gfx950")
// Entry 5/7: backward "dot_do_o" kernel - fp16, batch mode, hdim 128, no mask, no bias.
// The stage-1 tile sizes (n1, k1, k0max) are set to 0 for this family.
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.profile("pytorch"),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true),
"gfx950")
// Entry 6/7: backward "dq_dk_dv" kernel - fp16, batch mode, hdim 128, no mask, no bias.
// Wave and warp shapes are given through the 9-argument wave()/warp() setters.
// NOTE(review): the meaning of each of the nine positions is not visible here - confirm.
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.profile("pytorch"),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true),
"gfx950")
// Entry 7/7: backward "convert_dq" kernel - fp16, batch mode, hdim 128, no mask, no bias.
// Only tile_m0 / tile_n0 are non-zero.
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.profile("pytorch"),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 07: PyTorch-Profile FMHA",
"Declarative FMHA PyTorch-profile planning");
args.add_option("--arch", "gfx950", "GPU architecture");
if(!args.parse(argc, argv))
{
return 0;
}
const std::string gfx_arch = args.get("--arch", "gfx950");
FmhaRegistry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
FmhaDispatcher dispatcher(&registry);
std::cout << "PyTorch-profile FMHA kernels: " << registry.size() << "\n";
fmha_fwd_traits fwd_traits{};
fwd_traits.hdim_q = 128;
fwd_traits.hdim_v = 128;
fwd_traits.data_type = "fp16";
fwd_traits.is_group_mode = false;
fwd_traits.is_v_rowmajor = true;
fwd_traits.mask_type = mask_enum::no_mask;
fwd_traits.bias_type = bias_enum::elementwise_bias;
fwd_traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fwd_args{};
fwd_args.batch = 1;
fwd_args.seqlen_q = 128;
fwd_args.seqlen_k = 128;
fwd_args.max_seqlen_q = 128;
fwd_args.hdim_q = 128;
fwd_args.hdim_v = 128;
fwd_args.nhead_q = 16;
fwd_args.nhead_k = 16;
auto fwd_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch));
fmha_bwd_traits bwd_traits{};
bwd_traits.hdim_q = 128;
bwd_traits.hdim_v = 128;
bwd_traits.data_type = "fp16";
bwd_traits.is_group_mode = false;
bwd_traits.mask_type = mask_enum::no_mask;
bwd_traits.bias_type = bias_enum::no_bias;
fmha_bwd_args bwd_args{};
bwd_args.batch = 1;
bwd_args.seqlen_q = 128;
bwd_args.seqlen_k = 128;
bwd_args.max_seqlen_q = 128;
bwd_args.max_seqlen_k = 128;
bwd_args.hdim_q = 128;
bwd_args.hdim_v = 128;
bwd_args.nhead_q = 16;
bwd_args.nhead_k = 16;
auto bwd_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
std::cout << "Forward plan stages: " << fwd_plan.stages.size() << "\n";
std::cout << "Backward plan stages: " << bwd_plan.stages.size() << "\n";
return (fwd_plan.is_valid() && bwd_plan.is_valid()) ? 0 : 1;
}

View File

@@ -0,0 +1,165 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel set "flash_profile_fmha_kernels": four FMHA kernel declarations for "gfx950".
// One forward kernel tagged .profile("flash_fwd") and three backward kernels tagged
// .profile("flash_bwd"). Each .add() takes a signature, an algorithm configuration and the
// architecture string.
// NOTE(review): presumably consumed by REGISTER_GENERATED_KERNELS in main() - confirm.
DECL_FMHA_KERNEL_SET(flash_profile_fmha_kernels,
// Entry 1/4: forward kernel - fp16, batch mode, hdim 128, no mask, alibi bias.
// main() requests bias_enum::alibi for its forward plan.
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("alibi")
.profile("flash_fwd"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(32)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true),
"gfx950")
// Entry 2/4: backward "dot_do_o" kernel - fp16, batch mode, hdim 128, no mask, no bias.
// Stage-1 tile sizes are set to 0 for this family.
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.profile("flash_bwd"),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true),
"gfx950")
// Entry 3/4: backward "dq_dk_dv" kernel - fp16, batch mode, hdim 128, no mask, no bias.
// Uses the 9-argument wave()/warp() setters.
// NOTE(review): the meaning of each of the nine positions is not visible here - confirm.
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.profile("flash_bwd"),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true),
"gfx950")
// Entry 4/4: backward "convert_dq" kernel - fp16, batch mode, hdim 128, no mask, no bias.
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.profile("flash_bwd"),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 08: Flash-Profile FMHA",
"Declarative FMHA Flash-profile planning");
args.add_option("--arch", "gfx950", "GPU architecture");
if(!args.parse(argc, argv))
{
return 0;
}
const std::string gfx_arch = args.get("--arch", "gfx950");
FmhaRegistry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
FmhaDispatcher dispatcher(&registry);
std::cout << "Flash-profile FMHA kernels: " << registry.size() << "\n";
fmha_fwd_traits fwd_traits{};
fwd_traits.hdim_q = 128;
fwd_traits.hdim_v = 128;
fwd_traits.data_type = "fp16";
fwd_traits.is_group_mode = false;
fwd_traits.is_v_rowmajor = true;
fwd_traits.mask_type = mask_enum::no_mask;
fwd_traits.bias_type = bias_enum::alibi;
fwd_traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fwd_args{};
fwd_args.batch = 1;
fwd_args.seqlen_q = 128;
fwd_args.seqlen_k = 128;
fwd_args.max_seqlen_q = 128;
fwd_args.hdim_q = 128;
fwd_args.hdim_v = 128;
fwd_args.nhead_q = 16;
fwd_args.nhead_k = 16;
auto fwd_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch));
fmha_bwd_traits bwd_traits{};
bwd_traits.hdim_q = 128;
bwd_traits.hdim_v = 128;
bwd_traits.data_type = "fp16";
bwd_traits.is_group_mode = false;
bwd_traits.mask_type = mask_enum::no_mask;
bwd_traits.bias_type = bias_enum::no_bias;
fmha_bwd_args bwd_args{};
bwd_args.batch = 1;
bwd_args.seqlen_q = 128;
bwd_args.seqlen_k = 128;
bwd_args.max_seqlen_q = 128;
bwd_args.max_seqlen_k = 128;
bwd_args.hdim_q = 128;
bwd_args.hdim_v = 128;
bwd_args.nhead_q = 16;
bwd_args.nhead_k = 16;
auto bwd_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
std::cout << "Flash fwd stages: " << fwd_plan.stages.size() << "\n";
std::cout << "Flash bwd stages: " << bwd_plan.stages.size() << "\n";
return (fwd_plan.is_valid() && bwd_plan.is_valid()) ? 0 : 1;
}

View File

@@ -0,0 +1,212 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel set "aiter_profile_fmha_kernels": four fp16 / hdim-128 kernel declarations for
// "gfx950", tagged with the profiles "aiter_batch", "aiter_group" and "aiter_cpp".
// All four share the same tile (128/128/32, 128/32/128), wave (4,1,1) and warp (32,32,16)
// configuration; they differ in family, mode, pipeline and KV-cache settings.
// NOTE(review): presumably consumed by REGISTER_GENERATED_KERNELS in main() - confirm.
DECL_FMHA_KERNEL_SET(
aiter_profile_fmha_kernels,
// Entry 1/4: forward kernel, batch mode, profile "aiter_batch", pipeline "qr_async".
.add(FmhaSignature().family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128).profile(
"aiter_batch"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true),
"gfx950")
// Entry 2/4: forward kernel, group mode, profile "aiter_group", pipeline "qr_async".
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("group")
.vlayout("r")
.hdim(128)
.profile("aiter_group"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true),
"gfx950")
// Entry 3/4: paged-KV forward kernel, batch mode, profile "aiter_cpp", pipeline "qr_pagedkv".
// KV cache is declared as ("vectorized", "sglang", 16).
// NOTE(review): main() in this example does not build a fwd_pagedkv request - confirm whether
// this entry is only meant to be registered.
.add(FmhaSignature()
.family("fwd_pagedkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.paged_kv(true)
.profile("aiter_cpp")
.kv_cache("vectorized", "sglang", 16),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_pagedkv")
.padding(true, true, true, true),
"gfx950")
// Entry 4/4: batch-prefill kernel, group mode, profile "aiter_cpp", pipeline "qr_async".
// KV cache ("vectorized", "sglang", 16) mirrors the VECTORIZED_LAYOUT /
// SGLANG_PAGE_TABLE_1D / page_size 16 values that main() puts into the prefill request.
.add(FmhaSignature()
.family("batch_prefill")
.dtype("fp16")
.mode("group")
.vlayout("r")
.hdim(128)
.paged_kv(true)
.profile("aiter_cpp")
.kv_cache("vectorized", "sglang", 16),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 09: AITER-Profile FMHA",
"Declarative FMHA AITER-profile planning");
args.add_option("--arch", "gfx950", "GPU architecture");
if(!args.parse(argc, argv))
{
return 0;
}
const std::string gfx_arch = args.get("--arch", "gfx950");
FmhaRegistry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
FmhaDispatcher dispatcher(&registry);
std::cout << "AITER-profile FMHA kernels: " << registry.size() << "\n";
fmha_fwd_traits batch_traits{};
batch_traits.hdim_q = 128;
batch_traits.hdim_v = 128;
batch_traits.data_type = "fp16";
batch_traits.is_group_mode = false;
batch_traits.is_v_rowmajor = true;
batch_traits.mask_type = mask_enum::no_mask;
batch_traits.bias_type = bias_enum::no_bias;
batch_traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args batch_args{};
batch_args.batch = 1;
batch_args.seqlen_q = 128;
batch_args.seqlen_k = 128;
batch_args.max_seqlen_q = 128;
batch_args.hdim_q = 128;
batch_args.hdim_v = 128;
batch_args.nhead_q = 16;
batch_args.nhead_k = 16;
auto batch_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(batch_traits, batch_args), gfx_arch));
fmha_batch_prefill_traits prefill_traits{};
prefill_traits.hdim_q = 128;
prefill_traits.hdim_v = 128;
prefill_traits.data_type = "fp16";
prefill_traits.is_group_mode = true;
prefill_traits.is_v_rowmajor = true;
prefill_traits.mask_type = mask_enum::no_mask;
prefill_traits.bias_type = bias_enum::no_bias;
prefill_traits.kv_memory_layout =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
prefill_traits.kv_lookup_table =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
prefill_traits.page_size = 16;
fmha_batch_prefill_args prefill_args{};
prefill_args.batch = 1;
prefill_args.seqlen_q = 128;
prefill_args.seqlen_k = 1024;
prefill_args.max_seqlen_q = 128;
prefill_args.hdim_q = 128;
prefill_args.hdim_v = 128;
prefill_args.nhead_q = 16;
prefill_args.nhead_k = 16;
prefill_args.num_total_pages = 64;
prefill_args.page_block_size = 16;
prefill_args.kv_memory_layout =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
prefill_args.kv_lookup_table =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
prefill_args.kv_indptr = reinterpret_cast<void*>(0x1);
prefill_args.kv_page_indices = reinterpret_cast<void*>(0x1);
prefill_args.kv_last_page_lens = reinterpret_cast<void*>(0x1);
prefill_args.seqstart_q_ptr = reinterpret_cast<void*>(0x1);
auto prefill_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch));
std::cout << "AITER batch stages: " << batch_plan.stages.size() << "\n";
std::cout << "AITER prefill stages: " << prefill_plan.stages.size() << "\n";
return (batch_plan.is_valid() && prefill_plan.is_valid()) ? 0 : 1;
}

View File

@@ -0,0 +1,152 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel set "fp32_fp8_profile_fmha_kernels": three batch-mode forward kernel declarations for
// "gfx950" - two fp32 kernels (profiles "fp32_min" and "fp32_all") and one "fp8bf16" kernel
// (profile "fp8_test").
// NOTE(review): presumably consumed by REGISTER_GENERATED_KERNELS in main() - confirm.
DECL_FMHA_KERNEL_SET(fp32_fp8_profile_fmha_kernels,
// Entry 1/3: fp32, hdim 128, no mask, no bias, profile "fp32_min".
// main() plans an fp32 / hdim-128 request, which matches this signature's dtype and hdim.
.add(FmhaSignature()
.family("fwd")
.dtype("fp32")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.profile("fp32_min"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(32)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(16)
.tile_k0max(128)
.wave_m0(2)
.wave_n0(1)
.wave_k0(1)
.wave_m1(2)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true),
"gfx950")
// Entry 2/3: fp32, hdim 48, no mask, no bias, profile "fp32_all".
// tile_n1 and tile_k0max are both 48, equal to the declared hdim.
.add(FmhaSignature()
.family("fwd")
.dtype("fp32")
.mode("batch")
.vlayout("r")
.hdim(48)
.mask("no")
.bias("no")
.profile("fp32_all"),
FmhaAlgorithm()
.tile_m0(32)
.tile_n0(128)
.tile_k0(16)
.tile_n1(48)
.tile_k1(16)
.tile_k0max(48)
.wave_m0(2)
.wave_n0(1)
.wave_k0(1)
.wave_m1(2)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true),
"gfx950")
// Entry 3/3: dtype "fp8bf16", hdim 128, no mask, no bias, profile "fp8_test",
// pipeline "qr_async" with 32x32x32 warp tiles in both stages.
.add(FmhaSignature()
.family("fwd")
.dtype("fp8bf16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.profile("fp8_test"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(32)
.warp_m1(32)
.warp_n1(32)
.warp_k1(32)
.pipeline("qr_async")
.padding(true, true, true, true),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 10: FP32/FP8-Profile FMHA",
"Declarative FMHA FP32/FP8-profile planning");
args.add_option("--arch", "gfx950", "GPU architecture");
if(!args.parse(argc, argv))
{
return 0;
}
const std::string gfx_arch = args.get("--arch", "gfx950");
FmhaRegistry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
FmhaDispatcher dispatcher(&registry);
std::cout << "FP32/FP8-profile FMHA kernels: " << registry.size() << "\n";
std::cout << registry.export_json(false) << "\n";
fmha_fwd_traits traits{};
traits.hdim_q = 128;
traits.hdim_v = 128;
traits.data_type = "fp32";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fmha_args{};
fmha_args.batch = 1;
fmha_args.seqlen_q = 128;
fmha_args.seqlen_k = 128;
fmha_args.max_seqlen_q = 128;
fmha_args.hdim_q = 128;
fmha_args.hdim_v = 128;
fmha_args.nhead_q = 16;
fmha_args.nhead_k = 16;
auto plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
std::cout << "FP32/FP8-profile plan stages: " << plan.stages.size() << "\n";
return plan.is_valid() ? 0 : 1;
}

View File

@@ -0,0 +1,176 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel set "receipt_alias_fmha_kernels": four batch-mode, hdim-128 forward kernel
// declarations for "gfx950". Instead of a .profile(...) name, each signature carries a
// numeric .receipt(N) tag (2, 4, 100, 800).
// NOTE(review): what each receipt number stands for is not visible in this file - confirm
// against the FmhaSignature::receipt documentation.
DECL_FMHA_KERNEL_SET(receipt_alias_fmha_kernels,
// Entry 1/4: fp16, alibi bias, receipt 2, pipeline "qr".
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.bias("alibi")
.receipt(2),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(32)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true),
"gfx950")
// Entry 2/4: fp16, bias("bias"), receipt 4, pipeline "qr". Same algorithm as entry 1.
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.bias("bias")
.receipt(4),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(32)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true),
"gfx950")
// Entry 3/4: fp16, no explicit mask/bias setting, receipt 100, pipeline "qr_async".
// NOTE(review): main() plans an fp16 no-bias request; presumably this is the entry it is
// expected to resolve to - confirm against the signature defaults.
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.receipt(100),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true),
"gfx950")
// Entry 4/4: fp32, no explicit mask/bias setting, receipt 800, pipeline "qr".
.add(FmhaSignature()
.family("fwd")
.dtype("fp32")
.mode("batch")
.vlayout("r")
.hdim(128)
.receipt(800),
FmhaAlgorithm()
.tile_m0(32)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(16)
.tile_k0max(128)
.wave_m0(2)
.wave_n0(1)
.wave_k0(1)
.wave_m1(2)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 11: Receipt Aliases FMHA",
"Declarative FMHA receipt-alias planning");
args.add_option("--arch", "gfx950", "GPU architecture");
if(!args.parse(argc, argv))
{
return 0;
}
const std::string gfx_arch = args.get("--arch", "gfx950");
FmhaRegistry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
FmhaDispatcher dispatcher(&registry);
std::cout << "Receipt-alias FMHA kernels: " << registry.size() << "\n";
fmha_fwd_traits traits{};
traits.hdim_q = 128;
traits.hdim_v = 128;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fmha_args{};
fmha_args.batch = 1;
fmha_args.seqlen_q = 128;
fmha_args.seqlen_k = 128;
fmha_args.max_seqlen_q = 128;
fmha_args.hdim_q = 128;
fmha_args.hdim_v = 128;
fmha_args.nhead_q = 16;
fmha_args.nhead_k = 16;
auto plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
std::cout << "Receipt-alias plan stages: " << plan.stages.size() << "\n";
return plan.is_valid() ? 0 : 1;
}

View File

@@ -0,0 +1,129 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <fstream>
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel set "registry_json_fmha_kernels": three fp16 / hdim-128 batch-mode kernel declarations
// for "gfx950" (forward, paged-KV forward, backward dq_dk_dv). main() in this example only
// registers them and exports the registry as JSON; it does not plan any problem.
// NOTE(review): presumably consumed by REGISTER_GENERATED_KERNELS in main() - confirm.
DECL_FMHA_KERNEL_SET(
registry_json_fmha_kernels,
// Entry 1/3: forward kernel, pipeline "qr_async".
.add(FmhaSignature().family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true),
"gfx950")
// Entry 2/3: paged-KV forward kernel, KV cache ("vectorized", "sglang", 16),
// pipeline "qr_pagedkv". Same tile/wave/warp values as entry 1.
.add(FmhaSignature()
.family("fwd_pagedkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.paged_kv(true)
.kv_cache("vectorized", "sglang", 16),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_pagedkv")
.padding(true, true, true, true),
"gfx950")
// Entry 3/3: backward "dq_dk_dv" kernel, configured through the 9-argument wave()/warp()
// setters.
// NOTE(review): the meaning of each of the nine positions is not visible here - confirm.
.add(FmhaSignature().family("bwd_dq_dk_dv").dtype("fp16").mode("batch").hdim(128),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true),
"gfx950"));
// Example 12 driver. Registers the generated kernels for the selected architecture, exports
// the registry as JSON, prints a preview of at most 240 characters, and optionally writes the
// full JSON to the path given by --output.
//
// Exit code: 0 when at least one kernel was registered; 1 when the registry is empty or the
// output file could not be opened or written.
int main(int argc, char* argv[])
{
    utils::ExampleArgs args("Example 12: Registry JSON FMHA",
                            "Declarative FMHA registry JSON export");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--output", "", "Write JSON to file (optional)");
    if(!args.parse(argc, argv))
    {
        // parse() returned false: stop here and report success without exporting anything.
        return 0;
    }
    utils::print_header("Example 12: Registry JSON FMHA");
    const std::string gfx_arch    = args.get("--arch", "gfx950");
    const std::string output_path = args.get("--output", "");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaRegistry registry;
    registry.set_name("registry_json_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";

    // Step 2: Export JSON
    std::cout << "\nStep 2: Export JSON\n";
    const std::string json = registry.export_json(true);
    std::cout << " JSON size: " << json.size() << " bytes\n";
    // substr() clamps the requested count to the remaining length, so no explicit std::min
    // (and therefore no <algorithm> dependency) is needed for the preview.
    std::cout << json.substr(0, 240) << "\n";

    // Step 3: Write to file (if --output specified)
    if(!output_path.empty())
    {
        std::cout << "\nStep 3: Write to File\n";
        std::ofstream ofs(output_path);
        if(!ofs.is_open())
        {
            std::cerr << " ERROR: Cannot open " << output_path << " for writing\n";
            return 1;
        }
        ofs << json;
        ofs.close();
        // A failed write or flush leaves the stream in a failed state; do not report success
        // for a file that may be truncated.
        if(!ofs)
        {
            std::cerr << " ERROR: Failed to write " << output_path << "\n";
            return 1;
        }
        std::cout << " Written to: " << output_path << "\n";
        std::cout << " File size: " << json.size() << " bytes\n";
    }
    utils::print_separator();
    return registry.size() > 0 ? 0 : 1;
}

View File

@@ -0,0 +1,499 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 13: FMHA Feature Coverage
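// Exercises feature dimensions from the 01_fmha smoke test:
// bf16, masks (top-left, bottom-right), GQA, dropout, bias (elementwise, alibi),
// multiple hdims (64, 256), group mode, sink tokens.
// NOTE: window_generic masks and col-major V are not exercised by the tests in this file yet.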
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
// Kernel set "feature_coverage_kernels": ten forward kernel declarations for "gfx950", one per
// feature variation planned by the tests in main(). Every signature uses vlayout("r") and
// qscale("no"); every algorithm uses selection_rank(0). Unless noted, the entries share the
// 128/128/32 + 128/32/128 tile shape, 4x1x1 waves, 32x32x16 warps and the "qr_async" pipeline.
// NOTE(review): presumably consumed by REGISTER_GENERATED_KERNELS in main() - confirm.
DECL_FMHA_KERNEL_SET(feature_coverage_kernels,
// fp16 forward (basic, needed for GQA and other fp16 tests)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// bf16 forward
.add(FmhaSignature()
.family("fwd")
.dtype("bf16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// hdim 64
// (tile_n0 / tile_n1 / tile_k0max shrink to 64, stage-1 warps to 16x16x16, alignments to 64)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(64)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(64)
.tile_k0(32)
.tile_n1(64)
.tile_k1(32)
.tile_k0max(64)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(64, 64)
.selection_rank(0),
"gfx950")
// hdim 256
// (LSE enabled, pipeline "qr", all padding flags off, tile_n1 / tile_k0max / alignments 256)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(256)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(256)
.tile_k1(32)
.tile_k0max(256)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr")
.padding(false, false, false, false)
.alignments(256, 256)
.selection_rank(0),
"gfx950")
// Mask: causal (top-left and bottom-right share the same compiled kernel;
// the mask type is resolved at runtime via the args, not the template)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("top_left")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Dropout
// (declared together with lse(true); the "dropout" test in main() also requests LSE)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(true)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// GQA (nhead_q != nhead_k) - same kernel, GQA is a runtime concern
// Bias: elementwise
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("bias")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Bias: alibi
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("alibi")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Group mode
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("group")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Sink tokens
// (declared with mask("top_left"); the "sink_tokens" test in main() passes mask value 1)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("top_left")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no")
.sink(true),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {
struct FeatureTest
{
std::string name;
FmhaProblem problem;
};
FeatureTest make_test(const std::string& name,
const std::string& dtype,
int hdim_q,
int hdim_v,
int mask,
int bias,
bool lse,
bool dropout,
bool group,
bool logits,
bool sink,
int nhead_q = 16,
int nhead_k = 16,
const std::string& arch = "gfx950")
{
auto p = FmhaProblemBuilder()
.api_family(FmhaApiFamily::Fwd)
.kernel_family(FmhaKernelFamily::Fwd)
.gfx_arch(arch)
.data_type(dtype)
.dims(hdim_q, hdim_v, 2, 128, 256)
.nheads(nhead_q, nhead_k)
.mask_type(mask)
.bias_type(bias)
.lse(lse)
.dropout(dropout)
.group_mode(group)
.logits_soft_cap(logits)
.sink(sink)
.build();
return {name, p};
}
} // namespace
// Example 13 driver. Registers the generated kernels for the selected architecture, plans each
// feature scenario in turn, prints a [PASS]/[FAIL] line per scenario, and exits with 1 if any
// scenario produced an invalid plan.
int main(int argc, char* argv[])
{
    utils::ExampleArgs args("Example 13: FMHA Feature Coverage",
                            "Tests all 01_fmha smoke test features");
    args.add_option("--arch", "gfx950", "GPU architecture");
    if(!args.parse(argc, argv))
        return 0;
    utils::print_header("Example 13: FMHA Feature Coverage");
    const std::string gfx_arch = args.get("--arch", "gfx950");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("feature_coverage");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";
    FmhaDispatcher dispatcher(&registry);

    // Step 2: Run feature tests
    // Positional arguments after the name are: dtype, hdim_q, hdim_v, mask, bias, lse,
    // dropout, group, logits, sink [, nhead_q, nhead_k].
    std::cout << "\nStep 2: Run Feature Tests\n";
    const std::vector<FeatureTest> tests = {
        make_test("bf16_basic", "bf16", 128, 128, 0, 0, false, false, false, false, false),
        make_test("fp16_hdim64", "fp16", 64, 64, 0, 0, false, false, false, false, false),
        make_test("fp16_hdim256", "fp16", 256, 256, 0, 0, true, false, false, false, false),
        make_test("mask_top_left", "fp16", 128, 128, 1, 0, false, false, false, false, false),
        make_test("mask_bottom_right", "fp16", 128, 128, 2, 0, false, false, false, false, false),
        make_test("dropout", "fp16", 128, 128, 0, 0, true, true, false, false, false),
        make_test("gqa_h16_hk4", "fp16", 128, 128, 0, 0, false, false, false, false, false, 16, 4),
        make_test("bias_elementwise", "fp16", 128, 128, 0, 1, false, false, false, false, false),
        make_test("bias_alibi", "fp16", 128, 128, 0, 2, false, false, false, false, false),
        make_test("group_mode", "fp16", 128, 128, 0, 0, false, false, true, false, false),
        make_test("sink_tokens", "fp16", 128, 128, 1, 0, false, false, false, false, true),
    };
    int pass = 0;
    int fail = 0;
    for(const auto& test : tests)
    {
        auto plan = dispatcher.plan(test.problem);
        bool ok   = plan.is_valid();
        std::cout << (ok ? "[PASS]" : "[FAIL]") << " " << test.name;
        if(ok)
        {
            // Only index stages[0] when a stage exists: validity and stage count are checked
            // separately elsewhere in these examples, so a valid plan is not assumed to be
            // non-empty here.
            if(!plan.stages.empty())
            {
                std::cout << " -> " << plan.stages[0].kernel_id;
            }
            ++pass;
        }
        else
        {
            ++fail;
        }
        std::cout << "\n";
    }

    // Step 3: Summary
    std::cout << "\nStep 3: Summary\n";
    std::cout << " " << pass << " passed, " << fail << " failed out of " << tests.size() << "\n";
    utils::print_separator();
    return fail > 0 ? 1 : 0;
}

View File

@@ -0,0 +1,404 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 14: FMHA Benchmark with Validation
//
// Demonstrates:
// 1. Warmup runs to stabilize GPU clocks
// 2. Repeated benchmark runs with statistics (min/avg/max/median)
// 3. Optional CPU reference validation via --verify flag
//
// Usage:
// ./14_benchmark_validation_fmha
// ./14_benchmark_validation_fmha --seqlen 256 --batch 4 --repeat 20
// ./14_benchmark_validation_fmha --verify
#include <hip/hip_runtime.h>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using FmhaDataType = ck_tile::fp16_t;
// Kernel set "benchmark_fmha_kernels": declares a single forward ("fwd") FMHA
// kernel for fp16, batch mode, vlayout "r", head dimension 128, with mask, bias,
// LSE, dropout and qscale all disabled. The algorithm uses the "qr_async"
// pipeline with 128x128x32 stage-0 tiles, all four padding flags enabled, and
// is declared for gfx950.
// NOTE(review): presumably this declaration is what REGISTER_GENERATED_KERNELS
// in main() registers -- confirm against the dispatcher's code generation.
DECL_FMHA_KERNEL_SET(benchmark_fmha_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {

// CPU reference for the attention forward pass: O = softmax(scale * Q * K^T) * V.
//
// All tensors are flat, contiguous float buffers indexed as
// [batch, nhead, seqlen, hdim]:
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]  (output, must already be sized)
//
// No mask, bias or dropout is applied. Flat offsets are computed in 64-bit
// arithmetic so that large problem sizes cannot overflow a 32-bit int.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // Scratch row of attention scores, reused for every query row. Each entry
    // is overwritten before it is read, so no re-initialization is needed.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) index, widened before any multiplication.
            const int64_t head   = static_cast<int64_t>(b) * nhead + h;
            const int64_t k_base = head * seqlen_k * hdim_q;
            const int64_t v_base = head * seqlen_k * hdim_v;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (head * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (head * seqlen_q + sq) * hdim_v;

                // Scaled dot products of this query row with every key row,
                // tracking the row maximum for the softmax below.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = k_base + static_cast<int64_t>(sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax: subtract the row maximum before exponentiating so
                // the largest exponent is exp(0).
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Weighted sum of the value rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[v_base + static_cast<int64_t>(sk) * hdim_v + dv];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}

} // namespace
int main(int argc, char* argv[])
{
ExampleArgs args("Example 14: FMHA Benchmark + Validation",
"Warmup, repeated benchmark, optional verification");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "8", "Number of heads");
args.add_option("--seqlen", "128", "Sequence length (Q and K)");
args.add_option("--hdim", "128", "Head dimension");
args.add_option("--warmup", "3", "Warmup iterations");
args.add_option("--repeat", "10", "Benchmark repetitions");
args.add_flag("--verify", "Validate against CPU reference");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 8);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
const int warmup = args.get_int("--warmup", 3);
const int repeat = args.get_int("--repeat", 10);
print_header("Example 14: FMHA Benchmark + Validation");
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry registry;
registry.set_name("benchmark_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
const int64_t o_elems = q_elems;
// Step 2: Allocate GPU buffers
std::cout << "\nStep 2: Allocate GPU Buffers\n";
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
<< "]\n";
GpuBuffer<FmhaDataType> q_dev(q_elems);
GpuBuffer<FmhaDataType> k_dev(q_elems);
GpuBuffer<FmhaDataType> v_dev(q_elems);
GpuBuffer<FmhaDataType> o_dev(o_elems);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FmhaDataType> q_host(q_elems);
std::vector<FmhaDataType> k_host(q_elems);
std::vector<FmhaDataType> v_host(q_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
fmha_fwd_args fmha_args{};
fmha_args.q_ptr = q_dev.get();
fmha_args.k_ptr = k_dev.get();
fmha_args.v_ptr = v_dev.get();
fmha_args.o_ptr = o_dev.get();
fmha_args.bias_ptr = nullptr;
fmha_args.q_descale_ptr = nullptr;
fmha_args.k_descale_ptr = nullptr;
fmha_args.v_descale_ptr = nullptr;
fmha_args.rand_val_ptr = nullptr;
fmha_args.lse_ptr = nullptr;
fmha_args.sink_ptr = nullptr;
fmha_args.block_scale_seqstart_q_ptr = nullptr;
fmha_args.block_scale_seqstart_k_ptr = nullptr;
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_k = seqlen;
fmha_args.batch = batch;
fmha_args.max_seqlen_q = seqlen;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
fmha_args.scale_s = scale;
fmha_args.logits_soft_cap = 0.0f;
fmha_args.stride_q = hdim;
fmha_args.stride_k = hdim;
fmha_args.stride_v = hdim;
fmha_args.stride_bias = 0;
fmha_args.stride_randval = 0;
fmha_args.stride_o = hdim;
fmha_args.nhead_stride_q = seqlen * hdim;
fmha_args.nhead_stride_k = seqlen * hdim;
fmha_args.nhead_stride_v = seqlen * hdim;
fmha_args.nhead_stride_bias = 0;
fmha_args.nhead_stride_randval = 0;
fmha_args.nhead_stride_lse = 0;
fmha_args.nhead_stride_o = seqlen * hdim;
fmha_args.nhead_stride_q_descale = 0;
fmha_args.nhead_stride_k_descale = 0;
fmha_args.nhead_stride_v_descale = 0;
fmha_args.batch_stride_q = nhead * seqlen * hdim;
fmha_args.batch_stride_k = nhead * seqlen * hdim;
fmha_args.batch_stride_v = nhead * seqlen * hdim;
fmha_args.batch_stride_bias = 0;
fmha_args.batch_stride_randval = 0;
fmha_args.batch_stride_lse = 0;
fmha_args.batch_stride_o = nhead * seqlen * hdim;
fmha_args.batch_stride_q_descale = 0;
fmha_args.batch_stride_k_descale = 0;
fmha_args.batch_stride_v_descale = 0;
fmha_args.window_size_left = -1;
fmha_args.window_size_right = -1;
fmha_args.sink_size = 0;
fmha_args.mask_type = 0;
fmha_args.min_seqlen_q = 0;
fmha_args.p_drop = 0.0f;
fmha_args.s_randval = false;
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
fmha_args.block_scale_size_q = 0;
fmha_args.block_scale_size_kv = 0;
FmhaDispatcher dispatcher(&registry);
// Step 3: Warmup runs
std::cout << "\nStep 3: Warmup (" << warmup << " iterations)\n";
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 1);
for(int i = 0; i < warmup; ++i)
{
o_dev.zero();
float t = dispatcher.run_fwd(traits, fmha_args, nullptr);
std::cout << " Warmup " << (i + 1) << ": " << std::fixed << std::setprecision(4) << t
<< " ms\n";
}
// Step 4: Benchmark runs
std::cout << "\nStep 4: Benchmark (" << repeat << " iterations)\n";
dispatcher.set_timing(0, 1);
std::vector<float> times;
times.reserve(repeat);
for(int i = 0; i < repeat; ++i)
{
o_dev.zero();
float t = dispatcher.run_fwd(traits, fmha_args, nullptr);
times.push_back(t);
}
std::sort(times.begin(), times.end());
float t_min = times.front();
float t_max = times.back();
float t_avg = std::accumulate(times.begin(), times.end(), 0.0f) / static_cast<float>(repeat);
float t_med =
(repeat % 2 == 0) ? (times[repeat / 2 - 1] + times[repeat / 2]) / 2.0f : times[repeat / 2];
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
double ops = static_cast<double>(problem.num_ops());
double tflops_min = ops / (t_max * 1e-3) / 1e12;
double tflops_max = ops / (t_min * 1e-3) / 1e12;
double tflops_avg = ops / (t_avg * 1e-3) / 1e12;
double tflops_med = ops / (t_med * 1e-3) / 1e12;
std::cout << "\n " << std::setw(10) << "Metric" << " | " << std::setw(12) << "Time(ms)"
<< " | " << std::setw(12) << "TFLOPS" << "\n";
std::cout << " " << std::string(40, '-') << "\n";
std::cout << std::fixed << std::setprecision(4);
std::cout << " " << std::setw(10) << "Min" << " | " << std::setw(12) << t_min << " | "
<< std::setprecision(2) << std::setw(12) << tflops_max << "\n";
std::cout << std::setprecision(4);
std::cout << " " << std::setw(10) << "Avg" << " | " << std::setw(12) << t_avg << " | "
<< std::setprecision(2) << std::setw(12) << tflops_avg << "\n";
std::cout << std::setprecision(4);
std::cout << " " << std::setw(10) << "Median" << " | " << std::setw(12) << t_med << " | "
<< std::setprecision(2) << std::setw(12) << tflops_med << "\n";
std::cout << std::setprecision(4);
std::cout << " " << std::setw(10) << "Max" << " | " << std::setw(12) << t_max << " | "
<< std::setprecision(2) << std::setw(12) << tflops_min << "\n";
bool passed = true;
// Step 5: Optional validation
if(args.has("--verify"))
{
std::cout << "\nStep 5: CPU Reference Validation\n";
std::vector<FmhaDataType> o_host(o_elems);
o_dev.copy_to_host(o_host.data());
std::vector<float> q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(o_elems, 0.0f);
for(int64_t i = 0; i < q_elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < q_elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < q_elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
cpu_attention_fwd(
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
double max_abs_err = 0.0;
double max_rel_err = 0.0;
int errors = 0;
const double rtol = 1e-2;
const double atol = 1e-2;
for(int64_t i = 0; i < o_elems; ++i)
{
float gpu_val = static_cast<float>(o_host[i]);
float ref_val = o_ref[i];
double abs_err = std::abs(gpu_val - ref_val);
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
max_abs_err = std::max(max_abs_err, abs_err);
max_rel_err = std::max(max_rel_err, rel_err);
if(abs_err > atol + rtol * std::abs(ref_val))
++errors;
}
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
std::cout << " Max rel error: " << max_rel_err << "\n";
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
passed = (errors == 0);
}
else
{
std::vector<FmhaDataType> o_host(o_elems);
o_dev.copy_to_host(o_host.data());
int nonzero = 0;
for(int64_t i = 0; i < o_elems; ++i)
{
if(static_cast<float>(o_host[i]) != 0.0f)
++nonzero;
}
std::cout << "\n Sanity: " << nonzero << " / " << o_elems << " non-zero outputs\n";
passed = (nonzero > 0);
}
print_separator();
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
print_separator();
return passed ? 0 : 1;
}

View File

@@ -0,0 +1,282 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 15: Multi-Shape FMHA Sweep
//
// Demonstrates running a single FMHA kernel across multiple (batch, seqlen)
// combinations, producing a performance table. This pattern is useful for
// characterizing kernel behavior across the parameter space.
//
// Usage:
// ./15_multi_shape_fmha
// ./15_multi_shape_fmha --arch gfx942
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using FmhaDataType = ck_tile::fp16_t;
// Kernel set "multi_shape_fmha_kernels": declares a single forward ("fwd") FMHA
// kernel for fp16, batch mode, vlayout "r", head dimension 128, with mask, bias,
// LSE, dropout and qscale all disabled. The algorithm uses the "qr_async"
// pipeline with 128x128x32 stage-0 tiles, all four padding flags enabled, and
// is declared for gfx950. main() reuses this one kernel for every shape in the
// sweep.
// NOTE(review): presumably this declaration is what REGISTER_GENERATED_KERNELS
// in main() registers -- confirm against the dispatcher's code generation.
DECL_FMHA_KERNEL_SET(multi_shape_fmha_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {

// One problem size of the sweep; the remaining dimensions (nhead, hdim) come
// from the command line in main().
struct ShapeConfig
{
    int batch;  // batch size
    int seqlen; // sequence length, used for both Q and K in main()
};

// Shapes are visited in this order: batch 1, then 2, then 4, each with
// ascending sequence length.
const ShapeConfig SHAPES[] = {
    {1, 64}, {1, 128}, {1, 256}, {1, 512}, // batch 1
    {2, 64}, {2, 128}, {2, 256},           // batch 2
    {4, 64}, {4, 128},                     // batch 4
};

} // namespace
// Example 15 entry point: registers the generated FMHA kernels and runs the
// forward pass once per entry of SHAPES, printing one table row per shape with
// the measured time, the derived TFLOPS and a PASS/FAIL status (PASS means the
// run did not throw and produced at least one non-zero output).
//
// Returns 0 when every shape passes (or argument parsing stops early), 1 otherwise.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 15: Multi-Shape FMHA",
                     "Sweep (batch, seqlen) combos with a single kernel");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--nhead", "8", "Number of heads");
    args.add_option("--hdim", "128", "Head dimension");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int nhead            = args.get_int("--nhead", 8);
    const int hdim             = args.get_int("--hdim", 128);

    print_header("Example 15: Multi-Shape FMHA Sweep");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("multi_shape_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";

    FmhaDispatcher dispatcher(&registry);
    dispatcher.set_benchmarking(true);
    // NOTE(review): set_timing arguments are presumably (warmup, repeat) counts
    // -- confirm against FmhaDispatcher.
    dispatcher.set_timing(1, 3);

    // Step 2: Sweep shapes
    std::cout << "\nStep 2: Shape Sweep (nhead=" << nhead << ", hdim=" << hdim << ")\n\n";
    std::cout << " " << std::setw(6) << "Batch" << " | " << std::setw(8) << "SeqLen" << " | "
              << std::setw(12) << "Elements" << " | " << std::setw(10) << "Time(ms)" << " | "
              << std::setw(10) << "TFLOPS" << " | " << std::setw(8) << "Status" << "\n";
    std::cout << " " << std::string(66, '-') << "\n";

    // One generator for the whole sweep, seeded once, so the data for each
    // shape depends on its position in SHAPES.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);

    int pass_count = 0;
    int total      = 0;
    for(const auto& shape : SHAPES)
    {
        ++total;
        const float scale   = 1.0f / std::sqrt(static_cast<float>(hdim));
        const int64_t elems = static_cast<int64_t>(shape.batch) * nhead * shape.seqlen * hdim;

        GpuBuffer<FmhaDataType> q_dev(elems);
        GpuBuffer<FmhaDataType> k_dev(elems);
        GpuBuffer<FmhaDataType> v_dev(elems);
        GpuBuffer<FmhaDataType> o_dev(elems);

        // A single host staging buffer is refilled and uploaded for Q, K, V in turn.
        std::vector<FmhaDataType> h_buf(elems);
        for(auto& x : h_buf)
            x = FmhaDataType(dist(rng));
        q_dev.copy_from_host(h_buf.data());
        for(auto& x : h_buf)
            x = FmhaDataType(dist(rng));
        k_dev.copy_from_host(h_buf.data());
        for(auto& x : h_buf)
            x = FmhaDataType(dist(rng));
        v_dev.copy_from_host(h_buf.data());
        o_dev.zero();

        fmha_fwd_traits traits{};
        traits.hdim_q              = hdim;
        traits.hdim_v              = hdim;
        traits.data_type           = "fp16";
        traits.is_group_mode       = false;
        traits.is_v_rowmajor       = true;
        traits.has_logits_soft_cap = false;
        traits.mask_type           = mask_enum::no_mask;
        traits.bias_type           = bias_enum::no_bias;
        traits.has_lse             = false;
        traits.has_dropout         = false;
        traits.qscale_type         = quant_scale_enum::no_scale;

        fmha_fwd_args fmha_args{};
        // Device pointers; every optional input is unused in this example.
        fmha_args.q_ptr                      = q_dev.get();
        fmha_args.k_ptr                      = k_dev.get();
        fmha_args.v_ptr                      = v_dev.get();
        fmha_args.o_ptr                      = o_dev.get();
        fmha_args.bias_ptr                   = nullptr;
        fmha_args.q_descale_ptr              = nullptr;
        fmha_args.k_descale_ptr              = nullptr;
        fmha_args.v_descale_ptr              = nullptr;
        fmha_args.rand_val_ptr               = nullptr;
        fmha_args.lse_ptr                    = nullptr;
        fmha_args.sink_ptr                   = nullptr;
        fmha_args.block_scale_seqstart_q_ptr = nullptr;
        fmha_args.block_scale_seqstart_k_ptr = nullptr;
        // Problem dimensions.
        fmha_args.seqlen_q        = shape.seqlen;
        fmha_args.seqlen_k        = shape.seqlen;
        fmha_args.batch           = shape.batch;
        fmha_args.max_seqlen_q    = shape.seqlen;
        fmha_args.hdim_q          = hdim;
        fmha_args.hdim_v          = hdim;
        fmha_args.nhead_q         = nhead;
        fmha_args.nhead_k         = nhead;
        fmha_args.scale_s         = scale;
        fmha_args.logits_soft_cap = 0.0f;
        // Strides for contiguous [batch, nhead, seqlen, hdim] tensors.
        fmha_args.stride_q               = hdim;
        fmha_args.stride_k               = hdim;
        fmha_args.stride_v               = hdim;
        fmha_args.stride_bias            = 0;
        fmha_args.stride_randval         = 0;
        fmha_args.stride_o               = hdim;
        fmha_args.nhead_stride_q         = shape.seqlen * hdim;
        fmha_args.nhead_stride_k         = shape.seqlen * hdim;
        fmha_args.nhead_stride_v         = shape.seqlen * hdim;
        fmha_args.nhead_stride_bias      = 0;
        fmha_args.nhead_stride_randval   = 0;
        fmha_args.nhead_stride_lse       = 0;
        fmha_args.nhead_stride_o         = shape.seqlen * hdim;
        fmha_args.nhead_stride_q_descale = 0;
        fmha_args.nhead_stride_k_descale = 0;
        fmha_args.nhead_stride_v_descale = 0;
        fmha_args.batch_stride_q         = nhead * shape.seqlen * hdim;
        fmha_args.batch_stride_k         = nhead * shape.seqlen * hdim;
        fmha_args.batch_stride_v         = nhead * shape.seqlen * hdim;
        fmha_args.batch_stride_bias      = 0;
        fmha_args.batch_stride_randval   = 0;
        fmha_args.batch_stride_lse       = 0;
        fmha_args.batch_stride_o         = nhead * shape.seqlen * hdim;
        fmha_args.batch_stride_q_descale = 0;
        fmha_args.batch_stride_k_descale = 0;
        fmha_args.batch_stride_v_descale = 0;
        // Masking, dropout and block-scale settings (all disabled).
        fmha_args.window_size_left    = -1;
        fmha_args.window_size_right   = -1;
        fmha_args.sink_size           = 0;
        fmha_args.mask_type           = 0;
        fmha_args.min_seqlen_q        = 0;
        fmha_args.p_drop              = 0.0f;
        fmha_args.s_randval           = false;
        fmha_args.drop_seed_offset    = std::make_pair(uint64_t(0), uint64_t(0));
        fmha_args.block_scale_size_q  = 0;
        fmha_args.block_scale_size_kv = 0;

        bool ok       = false;
        float time_ms = 0.0f;
        // Stays 0 when the run fails, instead of dividing by a zero time.
        double tflops = 0.0;
        try
        {
            time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);

            std::vector<FmhaDataType> o_host(elems);
            o_dev.copy_to_host(o_host.data());
            int nonzero = 0;
            for(int64_t i = 0; i < elems; ++i)
            {
                if(static_cast<float>(o_host[i]) != 0.0f)
                    ++nonzero;
            }
            ok = (nonzero > 0);

            if(time_ms > 0.0f)
            {
                auto problem = FmhaProblem::from_invocation(
                    FmhaInvocation::make(traits, fmha_args), gfx_arch);
                tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
            }
        }
        catch(const std::exception& e)
        {
            std::cerr << " ERROR for B=" << shape.batch << " S=" << shape.seqlen << ": "
                      << e.what() << "\n";
        }

        std::cout << std::fixed;
        std::cout << " " << std::setw(6) << shape.batch << " | " << std::setw(8) << shape.seqlen
                  << " | " << std::setw(12) << elems << " | " << std::setprecision(4)
                  << std::setw(10) << time_ms << " | " << std::setprecision(2) << std::setw(10)
                  << tflops << " | " << std::setw(8) << (ok ? "PASS" : "FAIL") << "\n";
        if(ok)
            ++pass_count;
    }

    // Summary
    print_separator();
    std::cout << "Results: " << pass_count << "/" << total << " shapes passed\n";
    std::cout << "Status: " << (pass_count == total ? "PASS" : "FAIL") << "\n";
    print_separator();
    return (pass_count == total) ? 0 : 1;
}

View File

@@ -0,0 +1,428 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 16: FMHA Heuristic-Based Kernel Selection
//
// Demonstrates:
// 1. Two kernels with different tile_m0 (128 vs 64) and selection_rank
// 2. Custom heuristic function that picks kernels based on seqlen
// 3. dispatcher.set_heuristic() + SelectionStrategy::Heuristic
// 4. Planning different problems to show which kernel is selected
// 5. GPU execution for at least one problem
//
// Usage:
// ./16_heuristics_fmha
// ./16_heuristics_fmha --arch gfx942
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using FmhaDataType = ck_tile::fp16_t;
// Kernel set "heuristic_fmha_kernels": declares two forward ("fwd") FMHA kernels
// with the same signature (fp16, batch mode, vlayout "r", head dimension 128,
// no mask/bias/LSE/dropout/qscale) for gfx950. The two algorithms differ only
// in tile_m0 (128 vs 64) and selection_rank (0 vs 1).
// NOTE(review): presumably this declaration is what REGISTER_GENERATED_KERNELS
// in main() registers -- confirm against the dispatcher's code generation.
DECL_FMHA_KERNEL_SET(heuristic_fmha_kernels,
// Kernel A: Large tile (128x128) -- better for long sequences
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Kernel B: Smaller tile_m0 (64x128) -- lower latency for short sequences
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(1),
"gfx950"));
namespace {

// CPU reference for the attention forward pass: O = softmax(scale * Q * K^T) * V.
//
// All tensors are flat, contiguous float buffers indexed as
// [batch, nhead, seqlen, hdim]:
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]  (output, must already be sized)
//
// No mask, bias or dropout is applied. Flat offsets are computed in 64-bit
// arithmetic so that large problem sizes cannot overflow a 32-bit int.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // Scratch row of attention scores, reused for every query row. Each entry
    // is overwritten before it is read, so no re-initialization is needed.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) index, widened before any multiplication.
            const int64_t head   = static_cast<int64_t>(b) * nhead + h;
            const int64_t k_base = head * seqlen_k * hdim_q;
            const int64_t v_base = head * seqlen_k * hdim_v;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (head * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (head * seqlen_q + sq) * hdim_v;

                // Scaled dot products of this query row with every key row,
                // tracking the row maximum for the softmax below.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = k_base + static_cast<int64_t>(sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax: subtract the row maximum before exponentiating so
                // the largest exponent is exp(0).
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Weighted sum of the value rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[v_base + static_cast<int64_t>(sk) * hdim_v + dv];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}

} // namespace
// Example 16 entry point: registers the generated FMHA kernels, installs a
// seqlen-based heuristic on the dispatcher, prints which kernel is planned for
// several problem sizes, then runs one problem (batch=2, seqlen=256) on the GPU
// and validates it against cpu_attention_fwd.
//
// Returns 0 when the GPU result matches the reference (or argument parsing
// stops early), 1 otherwise.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 16: FMHA Heuristic Kernel Selection",
                     "Custom heuristic picks kernel based on seqlen");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--nhead", "8", "Number of heads");
    args.add_option("--hdim", "128", "Head dimension");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int nhead            = args.get_int("--nhead", 8);
    const int hdim             = args.get_int("--hdim", 128);

    print_header("Example 16: FMHA Heuristic Kernel Selection");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("heuristic_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";

    // Step 2: Set up heuristic
    std::cout << "\nStep 2: Configure Heuristic\n";
    std::cout << " Rule: seqlen >= 256 -> prefer large tile (128x128, rank=0)\n";
    std::cout << " seqlen < 256 -> prefer small tile (64x128, rank=1)\n";
    auto all_kernels = registry.all_kernels();
    std::cout << " Available kernels:\n";
    for(const auto& k : all_kernels)
    {
        std::cout << " - " << k->id() << "\n";
    }

    // Kernel A / B are simply the first / second kernel returned by the registry.
    // NOTE(review): this assumes all_kernels() returns the large-tile kernel
    // first; nothing here checks tile sizes or selection_rank -- confirm the
    // registry's ordering.
    std::string kernel_a_id, kernel_b_id;
    for(const auto& k : all_kernels)
    {
        auto kid = k->id();
        if(kernel_a_id.empty())
            kernel_a_id = kid;
        else if(kernel_b_id.empty())
            kernel_b_id = kid;
    }
    if(kernel_b_id.empty())
    {
        std::cerr << " WARNING: fewer than two kernels registered; the heuristic has no "
                     "alternative to choose from\n";
    }

    FmhaDispatcher dispatcher(&registry);
    dispatcher.set_strategy(SelectionStrategy::Heuristic);
    // Returns kernel ids in order of preference. Ids that were never assigned
    // (fewer than two kernels registered) are left out rather than returned as
    // empty strings.
    dispatcher.set_heuristic([&](const FmhaProblem& problem) -> std::vector<std::string> {
        const bool long_seq       = problem.seqlen_q >= 256;
        const std::string& first  = long_seq ? kernel_a_id : kernel_b_id;
        const std::string& second = long_seq ? kernel_b_id : kernel_a_id;
        std::vector<std::string> ranked;
        if(!first.empty())
            ranked.push_back(first);
        if(!second.empty())
            ranked.push_back(second);
        return ranked;
    });
    dispatcher.set_benchmarking(true);
    // NOTE(review): set_timing arguments are presumably (warmup, repeat) counts
    // -- confirm against FmhaDispatcher.
    dispatcher.set_timing(1, 3);

    // Step 3: Plan different problems to show kernel selection
    std::cout << "\nStep 3: Plan Problems (show kernel selection)\n\n";
    struct PlanCase
    {
        int batch;
        int seqlen;
    };
    PlanCase plan_cases[] = {{1, 64}, {1, 128}, {2, 256}, {2, 512}, {4, 1024}};
    std::cout << " " << std::setw(6) << "Batch" << " | " << std::setw(8) << "SeqLen" << " | "
              << std::setw(50) << "Selected Kernel" << "\n";
    std::cout << " " << std::string(68, '-') << "\n";
    for(const auto& pc : plan_cases)
    {
        auto problem = FmhaProblemBuilder()
                           .api_family(FmhaApiFamily::Fwd)
                           .kernel_family(FmhaKernelFamily::Fwd)
                           .gfx_arch(gfx_arch)
                           .data_type("fp16")
                           .dims(hdim, hdim, pc.batch, pc.seqlen, pc.seqlen)
                           .nheads(nhead, nhead)
                           .mask_type(0)
                           .bias_type(0)
                           .lse(false)
                           .dropout(false)
                           .build();
        auto plan            = dispatcher.plan(problem);
        std::string selected = plan.is_valid() ? plan.stages[0].kernel_id : "(no match)";
        std::cout << " " << std::setw(6) << pc.batch << " | " << std::setw(8) << pc.seqlen << " | "
                  << std::setw(50) << selected << "\n";
    }

    // Step 4: GPU execution for a representative problem
    std::cout << "\nStep 4: GPU Execution (batch=2, seqlen=256)\n";
    const int batch     = 2;
    const int seqlen    = 256;
    const float scale   = 1.0f / std::sqrt(static_cast<float>(hdim));
    const int64_t elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;

    GpuBuffer<FmhaDataType> q_dev(elems);
    GpuBuffer<FmhaDataType> k_dev(elems);
    GpuBuffer<FmhaDataType> v_dev(elems);
    GpuBuffer<FmhaDataType> o_dev(elems);

    // Fixed seed so every run validates the same inputs.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> fdist(-0.5f, 0.5f);
    std::vector<FmhaDataType> q_host(elems), k_host(elems), v_host(elems);
    for(auto& x : q_host)
        x = FmhaDataType(fdist(rng));
    for(auto& x : k_host)
        x = FmhaDataType(fdist(rng));
    for(auto& x : v_host)
        x = FmhaDataType(fdist(rng));
    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());
    o_dev.zero();

    fmha_fwd_traits traits{};
    traits.hdim_q              = hdim;
    traits.hdim_v              = hdim;
    traits.data_type           = "fp16";
    traits.is_group_mode       = false;
    traits.is_v_rowmajor       = true;
    traits.has_logits_soft_cap = false;
    traits.mask_type           = mask_enum::no_mask;
    traits.bias_type           = bias_enum::no_bias;
    traits.has_lse             = false;
    traits.has_dropout         = false;
    traits.qscale_type         = quant_scale_enum::no_scale;

    fmha_fwd_args fmha_args{};
    // Device pointers; every optional input is unused in this example.
    fmha_args.q_ptr                      = q_dev.get();
    fmha_args.k_ptr                      = k_dev.get();
    fmha_args.v_ptr                      = v_dev.get();
    fmha_args.o_ptr                      = o_dev.get();
    fmha_args.bias_ptr                   = nullptr;
    fmha_args.q_descale_ptr              = nullptr;
    fmha_args.k_descale_ptr              = nullptr;
    fmha_args.v_descale_ptr              = nullptr;
    fmha_args.rand_val_ptr               = nullptr;
    fmha_args.lse_ptr                    = nullptr;
    fmha_args.sink_ptr                   = nullptr;
    fmha_args.block_scale_seqstart_q_ptr = nullptr;
    fmha_args.block_scale_seqstart_k_ptr = nullptr;
    // Problem dimensions.
    fmha_args.seqlen_q        = seqlen;
    fmha_args.seqlen_k        = seqlen;
    fmha_args.batch           = batch;
    fmha_args.max_seqlen_q    = seqlen;
    fmha_args.hdim_q          = hdim;
    fmha_args.hdim_v          = hdim;
    fmha_args.nhead_q         = nhead;
    fmha_args.nhead_k         = nhead;
    fmha_args.scale_s         = scale;
    fmha_args.logits_soft_cap = 0.0f;
    // Strides for contiguous [batch, nhead, seqlen, hdim] tensors.
    fmha_args.stride_q               = hdim;
    fmha_args.stride_k               = hdim;
    fmha_args.stride_v               = hdim;
    fmha_args.stride_bias            = 0;
    fmha_args.stride_randval         = 0;
    fmha_args.stride_o               = hdim;
    fmha_args.nhead_stride_q         = seqlen * hdim;
    fmha_args.nhead_stride_k         = seqlen * hdim;
    fmha_args.nhead_stride_v         = seqlen * hdim;
    fmha_args.nhead_stride_bias      = 0;
    fmha_args.nhead_stride_randval   = 0;
    fmha_args.nhead_stride_lse       = 0;
    fmha_args.nhead_stride_o         = seqlen * hdim;
    fmha_args.nhead_stride_q_descale = 0;
    fmha_args.nhead_stride_k_descale = 0;
    fmha_args.nhead_stride_v_descale = 0;
    fmha_args.batch_stride_q         = nhead * seqlen * hdim;
    fmha_args.batch_stride_k         = nhead * seqlen * hdim;
    fmha_args.batch_stride_v         = nhead * seqlen * hdim;
    fmha_args.batch_stride_bias      = 0;
    fmha_args.batch_stride_randval   = 0;
    fmha_args.batch_stride_lse       = 0;
    fmha_args.batch_stride_o         = nhead * seqlen * hdim;
    fmha_args.batch_stride_q_descale = 0;
    fmha_args.batch_stride_k_descale = 0;
    fmha_args.batch_stride_v_descale = 0;
    // Masking, dropout and block-scale settings (all disabled).
    fmha_args.window_size_left    = -1;
    fmha_args.window_size_right   = -1;
    fmha_args.sink_size           = 0;
    fmha_args.mask_type           = 0;
    fmha_args.min_seqlen_q        = 0;
    fmha_args.p_drop              = 0.0f;
    fmha_args.s_randval           = false;
    fmha_args.drop_seed_offset    = std::make_pair(uint64_t(0), uint64_t(0));
    fmha_args.block_scale_size_q  = 0;
    fmha_args.block_scale_size_kv = 0;

    float time_ms = 0.0f;
    bool passed   = false;
    try
    {
        time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
        auto problem =
            FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
        // Report 0 rather than dividing by zero if no time was measured.
        const double tflops =
            time_ms > 0.0f ? static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12
                           : 0.0;
        std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
        std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";

        // Validate against CPU reference
        std::vector<FmhaDataType> o_host(elems);
        o_dev.copy_to_host(o_host.data());

        // The reference runs in float on the same fp16-rounded inputs the GPU saw.
        std::vector<float> q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f);
        for(int64_t i = 0; i < elems; ++i)
            q_f32[i] = static_cast<float>(q_host[i]);
        for(int64_t i = 0; i < elems; ++i)
            k_f32[i] = static_cast<float>(k_host[i]);
        for(int64_t i = 0; i < elems; ++i)
            v_f32[i] = static_cast<float>(v_host[i]);
        cpu_attention_fwd(
            q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);

        double max_abs_err = 0.0;
        int errors         = 0;
        for(int64_t i = 0; i < elems; ++i)
        {
            double abs_err = std::abs(static_cast<float>(o_host[i]) - o_ref[i]);
            max_abs_err    = std::max(max_abs_err, abs_err);
            // Combined tolerance: absolute 1e-2 plus relative 1e-2.
            if(abs_err > 1e-2 + 1e-2 * std::abs(o_ref[i]))
                ++errors;
        }
        std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
        std::cout << " Errors: " << errors << " / " << elems << "\n";
        passed = (errors == 0);
    }
    catch(const std::exception& e)
    {
        std::cerr << " ERROR: " << e.what() << "\n";
    }

    print_separator();
    std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
    print_separator();
    return passed ? 0 : 1;
}

View File

@@ -0,0 +1,423 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 17: FMHA Autofill and Autocorrect
//
// Demonstrates three DECL_FMHA_KERNEL_SET patterns:
// 1. AUTOFILL: Minimal algorithm specification -- only tile sizes, pipeline and
// padding are provided; wave/warp settings are omitted and left to the
// FmhaAlgorithm defaults
// 2. AUTOCORRECT: Stage-1 tile sizes (tile_n1/tile_k1/tile_k0max) are
// intentionally set to 0 and rely on FmhaAlgorithm auto_fill() to correct them
// 3. FULL: All parameters explicitly specified (reference)
//
// Each is registered, planned, run on GPU, and validated.
//
// Usage:
// ./17_autofill_autocorrect_fmha
// ./17_autofill_autocorrect_fmha --arch gfx942
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using FmhaDataType = ck_tile::fp16_t;
// Pattern 1: AUTOFILL -- minimal specification, defaults for wave/warp
// The algorithm below sets only the tile sizes, pipeline and padding; it makes
// no wave_*/warp_*/alignments/selection_rank calls, so those keep whatever
// FmhaAlgorithm provides by default.
// NOTE(review): the default values themselves are not visible in this file --
// confirm them against FmhaAlgorithm.
DECL_FMHA_KERNEL_SET(autofill_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.pipeline("qr_async")
.padding(true, true, true, true),
"gfx950"));
// Pattern 2: AUTOCORRECT -- tile_n1/tile_k1 set to 0, auto_fill() corrects them
//
// Declares the kernel set "autocorrect_kernels" with the same signature as
// pattern 1. The stage-0 tile sizes and all wave/warp values are given
// explicitly, while tile_n1, tile_k1 and tile_k0max are passed as 0.
// NOTE(review): the correction of the zeroed values is presumably done by
// FmhaAlgorithm's auto_fill(), which is not visible in this file -- confirm.
DECL_FMHA_KERNEL_SET(autocorrect_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Deliberately left at 0 so that they have to be filled in.
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true),
"gfx950"));
// Pattern 3: FULL -- every parameter explicitly specified
//
// Declares the kernel set "full_spec_kernels" with the same signature as
// patterns 1 and 2. Every tile, wave and warp value is spelled out, and this is
// the only one of the three declarations that also sets alignments and
// selection_rank. It serves as the reference configuration for the example.
DECL_FMHA_KERNEL_SET(full_spec_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {
// CPU reference for the attention forward pass.
//
// For every (batch, head, query row) this computes
//   scores[sk] = scale * dot(Q[row], K[sk])
//   p          = softmax(scores)      (row maximum subtracted before exp)
//   O[row]     = sum_sk p[sk] * V[sk]
//
// All buffers are indexed as dense row-major 4-D arrays:
//   Q: [batch, nhead, seqlen_q, hdim_q]   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]   O: [batch, nhead, seqlen_q, hdim_v]
// O must already have its final size; it is written element-wise and never
// resized. No bounds checking is performed on any buffer.
//
// Flat offsets are computed in 64 bits. The caller sizes its buffers with
// int64_t, and a plain `int` product of the four extents overflows (undefined
// behaviour) once a tensor holds more than 2^31 elements.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // Scratch row of scores/probabilities. It is reused for every query row;
    // each entry is overwritten before it is read, so no reset is needed.
    std::vector<float> scores(seqlen_k, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) index, widened so the offsets below cannot wrap.
            const int64_t bh     = static_cast<int64_t>(b) * nhead + h;
            const int64_t k_base = bh * seqlen_k * hdim_q;
            const int64_t v_base = bh * seqlen_k * hdim_v;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_base = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_base = (bh * seqlen_q + sq) * hdim_v;

                // Scaled dot products, tracking the row maximum for a stable softmax.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = k_base + static_cast<int64_t>(sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_base + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax: exponentiate relative to the maximum, then normalise.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[v_base + static_cast<int64_t>(sk) * hdim_v + dv];
                    }
                    O[o_base + dv] = acc;
                }
            }
        }
    }
}
// One entry of the table of declaration patterns exercised by main().
struct KernelTestCase
{
// Human-readable label printed in the per-step banner.
std::string name;
// Name of the DECL_FMHA_KERNEL_SET this case refers to; main() passes it to
// FmhaRegistry::set_name().
std::string kernel_set_name;
};
} // namespace
// Entry point of example 17.
//
// Builds one shared set of random fp16 Q/K/V inputs and a CPU reference
// output, then for each of the three declaration patterns creates a fresh
// registry and dispatcher, runs the forward kernel and compares the GPU output
// with the reference.
//
// Returns 0 only if every pattern passes. A pattern that is skipped because no
// kernel was registered, or whose run throws, is not counted as a pass, so the
// process then exits with 1.
int main(int argc, char* argv[])
{
ExampleArgs args("Example 17: FMHA Autofill & Autocorrect",
"Three DECL_FMHA_KERNEL_SET patterns compared");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "8", "Number of heads");
args.add_option("--seqlen", "128", "Sequence length");
args.add_option("--hdim", "128", "Head dimension");
// NOTE(review): a false result from parse() exits with status 0; presumably
// this is the --help path -- confirm against ExampleArgs.
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 8);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
print_header("Example 17: FMHA Autofill & Autocorrect");
// Step 1: Show registered kernel sets
std::cout << "\nStep 1: Registered Kernel Sets\n";
FmhaKernelSetRegistry::instance().print();
const KernelTestCase cases[] = {
{"AUTOFILL (minimal spec, wave/warp defaults)", "autofill_kernels"},
{"AUTOCORRECT (tile_n1/k1=0, auto_fill corrects)", "autocorrect_kernels"},
{"FULL (all params explicit)", "full_spec_kernels"},
};
// Prepare input data (shared across all tests)
// The same sequence length and head dimension are used for Q, K, V and O, so a
// single element count covers all four tensors.
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
const int64_t elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
// Fixed seed keeps the inputs reproducible from run to run.
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FmhaDataType> q_host(elems), k_host(elems), v_host(elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
// CPU reference
// The reference is computed from the fp16 values converted back to float, so
// it sees the same rounded inputs as the GPU kernel.
std::vector<float> q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f);
for(int64_t i = 0; i < elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
cpu_attention_fwd(q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
int total_pass = 0;
const int total_cases = sizeof(cases) / sizeof(cases[0]);
for(int ci = 0; ci < total_cases; ++ci)
{
const auto& tc = cases[ci];
// Step 1 was the registry listing, so the per-case steps start at 2.
std::cout << "\nStep " << (ci + 2) << ": " << tc.name << "\n";
// Register from the named kernel set
// NOTE(review): the set name is only passed to set_name(); whether
// REGISTER_GENERATED_KERNELS filters by it is not visible here -- confirm.
FmhaRegistry registry;
registry.set_name(tc.kernel_set_name);
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
if(registry.size() == 0)
{
// A skipped case is not counted in total_pass, so the overall status is FAIL.
std::cout << " SKIP: no kernels registered\n";
continue;
}
FmhaDispatcher dispatcher(&registry);
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 3);
// Allocate GPU buffers
GpuBuffer<FmhaDataType> q_dev(elems);
GpuBuffer<FmhaDataType> k_dev(elems);
GpuBuffer<FmhaDataType> v_dev(elems);
GpuBuffer<FmhaDataType> o_dev(elems);
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
o_dev.zero();
// Traits mirror the FmhaSignature used in the kernel-set declarations:
// fp16, batch mode, row-major V, no mask/bias/LSE/dropout/quant scale.
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
// Runtime arguments. Only Q/K/V/O carry device pointers; every optional
// input or output pointer is left null.
fmha_fwd_args fmha_args{};
fmha_args.q_ptr = q_dev.get();
fmha_args.k_ptr = k_dev.get();
fmha_args.v_ptr = v_dev.get();
fmha_args.o_ptr = o_dev.get();
fmha_args.bias_ptr = nullptr;
fmha_args.q_descale_ptr = nullptr;
fmha_args.k_descale_ptr = nullptr;
fmha_args.v_descale_ptr = nullptr;
fmha_args.rand_val_ptr = nullptr;
fmha_args.lse_ptr = nullptr;
fmha_args.sink_ptr = nullptr;
fmha_args.block_scale_seqstart_q_ptr = nullptr;
fmha_args.block_scale_seqstart_k_ptr = nullptr;
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_k = seqlen;
fmha_args.batch = batch;
fmha_args.max_seqlen_q = seqlen;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
fmha_args.scale_s = scale;
fmha_args.logits_soft_cap = 0.0f;
// Strides of a dense [batch, nhead, seqlen, hdim] layout, the same layout
// cpu_attention_fwd indexes.
fmha_args.stride_q = hdim;
fmha_args.stride_k = hdim;
fmha_args.stride_v = hdim;
fmha_args.stride_bias = 0;
fmha_args.stride_randval = 0;
fmha_args.stride_o = hdim;
fmha_args.nhead_stride_q = seqlen * hdim;
fmha_args.nhead_stride_k = seqlen * hdim;
fmha_args.nhead_stride_v = seqlen * hdim;
fmha_args.nhead_stride_bias = 0;
fmha_args.nhead_stride_randval = 0;
fmha_args.nhead_stride_lse = 0;
fmha_args.nhead_stride_o = seqlen * hdim;
fmha_args.nhead_stride_q_descale = 0;
fmha_args.nhead_stride_k_descale = 0;
fmha_args.nhead_stride_v_descale = 0;
fmha_args.batch_stride_q = nhead * seqlen * hdim;
fmha_args.batch_stride_k = nhead * seqlen * hdim;
fmha_args.batch_stride_v = nhead * seqlen * hdim;
fmha_args.batch_stride_bias = 0;
fmha_args.batch_stride_randval = 0;
fmha_args.batch_stride_lse = 0;
fmha_args.batch_stride_o = nhead * seqlen * hdim;
fmha_args.batch_stride_q_descale = 0;
fmha_args.batch_stride_k_descale = 0;
fmha_args.batch_stride_v_descale = 0;
fmha_args.window_size_left = -1;
fmha_args.window_size_right = -1;
fmha_args.sink_size = 0;
fmha_args.mask_type = 0;
fmha_args.min_seqlen_q = 0;
fmha_args.p_drop = 0.0f;
fmha_args.s_randval = false;
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
fmha_args.block_scale_size_q = 0;
fmha_args.block_scale_size_kv = 0;
try
{
float time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
auto problem =
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
// NOTE(review): a reported time of 0 ms makes this division yield inf.
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
// Validate
std::vector<FmhaDataType> o_host(elems);
o_dev.copy_to_host(o_host.data());
double max_abs_err = 0.0;
int errors = 0;
for(int64_t i = 0; i < elems; ++i)
{
double abs_err = std::abs(static_cast<float>(o_host[i]) - o_ref[i]);
max_abs_err = std::max(max_abs_err, abs_err);
// Combined tolerance: atol = 1e-2 plus rtol = 1e-2 relative to the reference.
if(abs_err > 1e-2 + 1e-2 * std::abs(o_ref[i]))
++errors;
}
bool ok = (errors == 0);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms"
<< " TFLOPS: " << std::setprecision(2) << tflops
<< " MaxErr: " << std::scientific << max_abs_err << " "
<< (ok ? "PASS" : "FAIL") << "\n";
if(ok)
++total_pass;
}
catch(const std::exception& e)
{
// The failure is reported and the loop moves on; the case is not counted as passed.
std::cerr << " ERROR: " << e.what() << "\n";
}
}
// Summary
print_separator();
std::cout << "Results: " << total_pass << "/" << total_cases << " patterns passed\n";
std::cout << "Patterns:\n";
std::cout << " 1. AUTOFILL: Only tile + pipeline specified; wave/warp use defaults\n";
std::cout << " 2. AUTOCORRECT: tile_n1/k1/k0max=0 -> auto_fill() infers from tile_n0/k0\n";
std::cout << " 3. FULL: Every parameter explicit (reference configuration)\n";
std::cout << "Status: " << (total_pass == total_cases ? "PASS" : "FAIL") << "\n";
print_separator();
return (total_pass == total_cases) ? 0 : 1;
}

View File

@@ -0,0 +1,466 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 18: GPU Split-KV FMHA Forward
//
// Demonstrates split-KV attention with GPU execution:
// 1. Declare both fwd_splitkv and fwd_splitkv_combine kernels
// 2. Show 2-stage execution plan
// 3. Allocate Q, K, V, O plus workspace (lse_acc, o_acc)
// 4. Run the split-KV forward pass on GPU
// 5. Copy output to host; with --validate, compare it against a CPU reference
// (otherwise only check that the output is not all zeros)
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for the split-KV example. It holds two entries with matching
// signatures (fp16, batch mode, hdim 128, LSE enabled, no paged KV), both
// targeting gfx950:
//   1. family "fwd_splitkv"         -- pipeline "qr_nwarp_sshuffle"
//   2. family "fwd_splitkv_combine" -- pipeline "qr", tile_n1 = 32
// main() expects the dispatcher to build a two-stage plan from these.
DECL_FMHA_KERNEL_SET(splitkv_gpu_kernels,
.add(FmhaSignature()
.family("fwd_splitkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no")
.paged_kv(false),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr_nwarp_sshuffle")
.padding(true, true, true, true)
.max_splits_log2(6)
.selection_rank(0),
"gfx950")
// Second entry: the combine kernel. It differs from the first entry only in
// family, tile_n1 (32 instead of 128) and pipeline ("qr").
.add(FmhaSignature()
.family("fwd_splitkv_combine")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no")
.paged_kv(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(32)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(16)
.warp_n0(16)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr")
.padding(true, true, true, true)
.max_splits_log2(6)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for the (unsplit) attention forward pass, used to validate the
// split-KV GPU result.
//
// For every (batch, head, query row):
//   scores[sk] = scale * dot(Q[row], K[sk])
//   p          = softmax(scores)      (row maximum subtracted before exp)
//   O[row]     = sum_sk p[sk] * V[sk]
//
// All buffers are indexed as dense row-major 4-D arrays:
//   Q: [batch, nhead, seqlen_q, hdim_q]   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]   O: [batch, nhead, seqlen_q, hdim_v]
// O must already have its final size; it is written element-wise and never
// resized. No bounds checking is performed on any buffer.
//
// Flat offsets are computed in 64 bits. The caller sizes its buffers with
// int64_t, and a plain `int` product of the four extents overflows (undefined
// behaviour) once a tensor holds more than 2^31 elements -- easy to reach with
// long KV sequences.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // Scratch row shared by all query rows; every entry is rewritten before use.
    std::vector<float> scores(seqlen_k, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // 64-bit linear (batch, head) index and the K/V slab offsets derived from it.
            const int64_t slab   = static_cast<int64_t>(b) * nhead + h;
            const int64_t k_slab = slab * seqlen_k * hdim_q;
            const int64_t v_slab = slab * seqlen_k * hdim_v;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t row   = slab * seqlen_q + sq;
                const int64_t q_off = row * hdim_q;
                const int64_t o_off = row * hdim_v;

                // Pass 1: scaled dot products and their maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_off = k_slab + static_cast<int64_t>(sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_off + d] * K[k_off + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: numerically stable softmax over the row.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Pass 3: O[row] = probabilities * V.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[v_slab + static_cast<int64_t>(sk) * hdim_v + dv];
                    }
                    O[o_off + dv] = acc;
                }
            }
        }
    }
}
} // namespace
// Entry point of example 18.
//
// Registers the generated kernels, asks the dispatcher for a plan, allocates
// Q/K/V/O plus the split-KV workspaces (lse_acc, o_acc), runs the split-KV
// forward pass and checks the result.
//
// Exit status:
//   1 -- no valid plan, or the result check failed
//   0 -- the check passed, or GPU execution threw and the example fell back to
//        the "PLAN_ONLY" report
// Without --validate the check is only that at least one output element is
// non-zero; with --validate the output is compared against cpu_attention_fwd.
int main(int argc, char* argv[])
{
ExampleArgs args("Example 18: GPU Split-KV FMHA Forward", "Split-KV with GPU execution");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen_q", "64", "Query sequence length");
args.add_option("--seqlen_k", "2048", "KV sequence length");
args.add_option("--hdim", "128", "Head dimension");
args.add_option("--splits", "2", "Number of KV splits");
args.add_flag("--validate", "Validate against CPU reference");
// NOTE(review): a false result from parse() exits with status 0; presumably
// this is the --help path -- confirm against ExampleArgs.
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen_q = args.get_int("--seqlen_q", 64);
const int seqlen_k = args.get_int("--seqlen_k", 2048);
const int hdim = args.get_int("--hdim", 128);
// NOTE(review): num_splits is used unchecked in buffer sizes and strides; a
// value <= 0 is not rejected here.
const int num_splits = args.get_int("--splits", 2);
print_header("Example 18: GPU Split-KV FMHA Forward");
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry registry;
registry.set_name("splitkv_gpu_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 3);
// Step 2: Set up traits and plan
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
// Traits mirror the FmhaSignature of the declared kernels: fp16, batch mode,
// row-major V, no mask or bias, LSE enabled.
fmha_fwd_splitkv_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = true;
traits.do_fp8_static_quant = false;
traits.has_sink = false;
// Workspace sizes: lse_acc [batch, nhead, num_splits, seqlen_q]
// o_acc [batch, nhead, num_splits, seqlen_q, hdim]
// Element counts are accumulated in int64_t so the products cannot wrap.
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen_q * hdim;
const int64_t k_elems = static_cast<int64_t>(batch) * nhead * seqlen_k * hdim;
const int64_t v_elems = k_elems;
const int64_t o_elems = q_elems;
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen_q;
const int64_t lse_acc_elems = static_cast<int64_t>(batch) * nhead * num_splits * seqlen_q;
const int64_t o_acc_elems = static_cast<int64_t>(batch) * nhead * num_splits * seqlen_q * hdim;
// Show the 2-stage plan
// Planning needs only the problem shape, so plan_args carries no pointers or strides.
std::cout << "\nStep 2: Plan (2-stage split-KV)\n";
fmha_fwd_splitkv_args plan_args{};
plan_args.seqlen_q = seqlen_q;
plan_args.seqlen_k = seqlen_k;
plan_args.batch = batch;
plan_args.max_seqlen_q = seqlen_q;
plan_args.hdim_q = hdim;
plan_args.hdim_v = hdim;
plan_args.nhead_q = nhead;
plan_args.nhead_k = nhead;
plan_args.num_splits = num_splits;
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, plan_args), gfx_arch);
auto plan = dispatcher.plan(problem);
// A valid plan with a stage count other than 2 only produces a warning; an
// invalid plan aborts the example with exit status 1.
if(!plan.is_valid() || plan.stages.size() != 2)
{
std::cerr << " WARNING: Expected a two-stage split-KV plan, got " << plan.stages.size()
<< " stage(s)\n";
if(!plan.is_valid())
{
std::cerr << " Plan is invalid -- no matching kernels found\n";
print_separator();
return 1;
}
}
for(const auto& stage : plan.stages)
{
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
}
// Step 3: Allocate GPU buffers
std::cout << "\nStep 3: Allocate GPU Buffers\n";
std::cout << " Q: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim
<< "]\n";
std::cout << " K/V: [" << batch << ", " << nhead << ", " << seqlen_k << ", " << hdim
<< "]\n";
std::cout << " O: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim
<< "]\n";
std::cout << " lse_acc: [" << batch << ", " << nhead << ", " << num_splits << ", " << seqlen_q
<< "]\n";
std::cout << " o_acc: [" << batch << ", " << nhead << ", " << num_splits << ", " << seqlen_q
<< ", " << hdim << "]\n";
// Q/K/V/O are fp16; the LSE buffer and both split workspaces are float.
GpuBuffer<FmhaDataType> q_dev(q_elems);
GpuBuffer<FmhaDataType> k_dev(k_elems);
GpuBuffer<FmhaDataType> v_dev(v_elems);
GpuBuffer<FmhaDataType> o_dev(o_elems);
GpuBuffer<float> lse_dev(lse_elems);
GpuBuffer<float> lse_acc_dev(lse_acc_elems);
GpuBuffer<float> o_acc_dev(o_acc_elems);
// Fill Q, K, V with random data
// Fixed seed keeps the inputs reproducible from run to run.
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FmhaDataType> q_host(q_elems);
std::vector<FmhaDataType> k_host(k_elems);
std::vector<FmhaDataType> v_host(v_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
// The output and both workspaces are cleared; lse_dev is left uninitialised.
o_dev.zero();
lse_acc_dev.zero();
o_acc_dev.zero();
// Set up splitkv args with device pointers and strides (no step banner is
// printed for this part)
fmha_fwd_splitkv_args fmha_args{};
fmha_args.q_ptr = q_dev.get();
fmha_args.k_ptr = k_dev.get();
fmha_args.v_ptr = v_dev.get();
fmha_args.o_ptr = o_dev.get();
fmha_args.bias_ptr = nullptr;
fmha_args.lse_acc_ptr = lse_acc_dev.get();
fmha_args.o_acc_ptr = o_acc_dev.get();
fmha_args.lse_ptr = lse_dev.get();
fmha_args.block_table_ptr = nullptr;
fmha_args.batch_stride_block_table = 0;
fmha_args.page_block_size = 0;
fmha_args.is_gappy = false;
fmha_args.cache_batch_idx = nullptr;
fmha_args.seqstart_q_ptr = nullptr;
fmha_args.seqstart_k_ptr = nullptr;
fmha_args.seqlen_k_ptr = nullptr;
fmha_args.sink_ptr = nullptr;
fmha_args.seqlen_q = seqlen_q;
fmha_args.seqlen_k = seqlen_k;
fmha_args.batch = batch;
fmha_args.max_seqlen_q = seqlen_q;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
fmha_args.num_splits = num_splits;
fmha_args.scale_s = scale;
fmha_args.scale_p = 1.0f;
fmha_args.scale_o = 1.0f;
fmha_args.logits_soft_cap = 0.0f;
// bhsd layout strides
// The accumulator workspaces add a split dimension between head and sequence,
// matching the [batch, nhead, num_splits, seqlen_q(, hdim)] sizes above.
fmha_args.stride_q = hdim;
fmha_args.stride_k = hdim;
fmha_args.stride_v = hdim;
fmha_args.stride_bias = 0;
fmha_args.stride_o_acc = hdim;
fmha_args.stride_o = hdim;
fmha_args.nhead_stride_q = seqlen_q * hdim;
fmha_args.nhead_stride_k = seqlen_k * hdim;
fmha_args.nhead_stride_v = seqlen_k * hdim;
fmha_args.nhead_stride_bias = 0;
fmha_args.nhead_stride_lse = seqlen_q;
fmha_args.nhead_stride_lse_acc = num_splits * seqlen_q;
fmha_args.nhead_stride_o_acc = num_splits * seqlen_q * hdim;
fmha_args.nhead_stride_o = seqlen_q * hdim;
fmha_args.batch_stride_q = nhead * seqlen_q * hdim;
fmha_args.batch_stride_k = nhead * seqlen_k * hdim;
fmha_args.batch_stride_v = nhead * seqlen_k * hdim;
fmha_args.batch_stride_bias = 0;
fmha_args.batch_stride_lse = nhead * seqlen_q;
fmha_args.batch_stride_lse_acc = nhead * num_splits * seqlen_q;
fmha_args.batch_stride_o_acc = nhead * num_splits * seqlen_q * hdim;
fmha_args.batch_stride_o = nhead * seqlen_q * hdim;
fmha_args.split_stride_lse_acc = seqlen_q;
fmha_args.split_stride_o_acc = seqlen_q * hdim;
fmha_args.window_size_left = -1;
fmha_args.window_size_right = -1;
fmha_args.sink_size = 0;
fmha_args.mask_type = 0;
// Step 4 (as printed): Run on GPU
std::cout << "\nStep 4: Run Split-KV FMHA Forward on GPU\n";
float time_ms = 0.0f;
try
{
time_ms = dispatcher.run_fwd_splitkv(traits, fmha_args, nullptr);
}
catch(const std::exception& e)
{
// Deliberate best-effort: an execution failure downgrades the example to a
// plan-only report and still exits with status 0.
std::cerr << " WARNING: GPU execution failed: " << e.what() << "\n";
std::cerr << " Falling back to planning-only mode (split-KV compilation can be complex)\n";
std::cout << "\n Plan summary (2 stages):\n";
for(const auto& stage : plan.stages)
{
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
}
print_separator();
std::cout << "Status: PLAN_ONLY\n";
print_separator();
return 0;
}
auto run_problem =
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
// NOTE(review): a reported time of 0 ms makes this division yield inf.
double tflops = static_cast<double>(run_problem.num_ops()) / (time_ms * 1e-3) / 1e12;
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Step 5 (as printed): Copy output and validate
std::cout << "\nStep 5: Validate\n";
std::vector<FmhaDataType> o_host(o_elems);
o_dev.copy_to_host(o_host.data());
// Sanity check: o_dev was zeroed before the run, so non-zero elements show
// that something was written to the output.
int nonzero = 0;
for(int64_t i = 0; i < o_elems; ++i)
{
if(static_cast<float>(o_host[i]) != 0.0f)
++nonzero;
}
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
bool passed = (nonzero > 0);
if(args.has("--validate"))
{
// Full comparison against the CPU reference; its verdict replaces the
// non-zero sanity check above.
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
for(int64_t i = 0; i < q_elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < k_elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < v_elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
cpu_attention_fwd(
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen_q, seqlen_k, hdim, hdim, scale);
double max_abs_err = 0.0;
double max_rel_err = 0.0;
int errors = 0;
const double rtol = 1e-2;
const double atol = 1e-2;
for(int64_t i = 0; i < o_elems; ++i)
{
float gpu_val = static_cast<float>(o_host[i]);
float ref_val = o_ref[i];
double abs_err = std::abs(gpu_val - ref_val);
// The 1e-6 term keeps the relative error finite when the reference is 0.
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
max_abs_err = std::max(max_abs_err, abs_err);
max_rel_err = std::max(max_rel_err, rel_err);
if(abs_err > atol + rtol * std::abs(ref_val))
++errors;
}
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
std::cout << " Max rel error: " << max_rel_err << "\n";
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
passed = (errors == 0);
}
print_separator();
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
print_separator();
return passed ? 0 : 1;
}

View File

@@ -0,0 +1,456 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 19: GPU FMHA Forward with Mask Types
//
// Demonstrates three mask variants with GPU execution:
// 1. No mask (standard attention)
// 2. Top-left causal mask (zero upper triangle)
// 3. Bottom-right causal mask (shifted diagonal)
//
// Uses seqlen_q=64, seqlen_k=128 to make mask behavior visible.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for the mask example. It holds two forward fp16 batch-mode
// hdim-128 entries targeting gfx950 that differ only in the mask setting:
//   1. mask "no"
//   2. mask "top_left"
// No entry is declared for bottom_right; see the note at the end of the
// declaration.
DECL_FMHA_KERNEL_SET(mask_fmha_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Second entry: identical algorithm settings, with mask "top_left".
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("top_left")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Note: bottom_right shares the same compiled kernel as top_left
// (both use SimplifiedGenericAttentionMask<true>). The mask type
// is resolved at runtime via args.mask_type, not the template.
// fmha_mask_compatible() in generated_fmha_backend.hpp handles this.
);
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for the attention forward pass with an optional causal mask.
//
// mask_type: 0=no_mask, 1=top_left, 2=bottom_right (any other value behaves
// like 0). For query row sq, a key sk is masked when
//   top_left:      sk >= sq + 1
//   bottom_right:  sk >= sq + (seqlen_k - seqlen_q) + 1
// Masked scores are replaced by the sentinel -1e30f before the softmax, so
// their probability underflows to 0 whenever the row has an unmasked key.
// If every key of a row is masked (possible for bottom_right when
// seqlen_q > seqlen_k), all scores equal the sentinel and the softmax
// degenerates to uniform weights over all keys.
// NOTE(review): confirm that this matches what the GPU kernel produces for
// fully masked rows.
//
// All buffers are indexed as dense row-major 4-D arrays:
//   Q: [batch, nhead, seqlen_q, hdim_q]   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]   O: [batch, nhead, seqlen_q, hdim_v]
// O must already have its final size; it is written element-wise and never
// resized. No bounds checking is performed on any buffer.
//
// Flat offsets are computed in 64 bits. The caller sizes its buffers with
// int64_t, and a plain `int` product of the four extents overflows (undefined
// behaviour) once a tensor holds more than 2^31 elements.
void cpu_attention_fwd_masked(const std::vector<float>& Q,
                              const std::vector<float>& K,
                              const std::vector<float>& V,
                              std::vector<float>& O,
                              int batch,
                              int nhead,
                              int seqlen_q,
                              int seqlen_k,
                              int hdim_q,
                              int hdim_v,
                              float scale,
                              int mask_type)
{
    // Both causal variants mask a suffix of each row; they differ only in how
    // far the diagonal is shifted to the right.
    const bool causal    = (mask_type == 1) || (mask_type == 2);
    const int diag_shift = (mask_type == 2) ? (seqlen_k - seqlen_q) : 0;

    // Scratch row shared by all query rows; every entry is rewritten before use.
    std::vector<float> scores(seqlen_k, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) index, widened so the offsets below cannot wrap.
            const int64_t bh     = static_cast<int64_t>(b) * nhead + h;
            const int64_t k_base = bh * seqlen_k * hdim_q;
            const int64_t v_base = bh * seqlen_k * hdim_v;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_base = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_base = (bh * seqlen_q + sq) * hdim_v;
                // First key index hidden from this query row (causal modes only).
                const int first_masked = sq + diag_shift + 1;

                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    if(causal && sk >= first_masked)
                    {
                        // The dot product would be discarded, so it is not computed.
                        scores[sk] = -1e30f;
                    }
                    else
                    {
                        const int64_t k_row = k_base + static_cast<int64_t>(sk) * hdim_q;
                        float dot           = 0.0f;
                        for(int d = 0; d < hdim_q; ++d)
                        {
                            dot += Q[q_base + d] * K[k_row + d];
                        }
                        scores[sk] = dot * scale;
                    }
                    max_score = std::max(max_score, scores[sk]);
                }

                // Softmax: exponentiate relative to the maximum, then normalise.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[v_base + static_cast<int64_t>(sk) * hdim_v + dv];
                    }
                    O[o_base + dv] = acc;
                }
            }
        }
    }
}
} // namespace
// Entry point of example 19.
//
// Uploads one shared set of random fp16 Q/K/V inputs, then runs the forward
// kernel once per mask variant (no_mask, top_left, bottom_right) through the
// same dispatcher.
//
// Returns 0 only if every variant ran without throwing and produced at least
// one non-zero output element and, when --validate is given, matched the CPU
// reference within tolerance. Otherwise returns 1.
int main(int argc, char* argv[])
{
ExampleArgs args("Example 19: FMHA with Masks (GPU)", "FMHA mask variants on GPU");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen_q", "64", "Query sequence length");
args.add_option("--seqlen_k", "128", "KV sequence length");
args.add_option("--hdim", "128", "Head dimension");
args.add_flag("--validate", "Validate against CPU reference");
// NOTE(review): a false result from parse() exits with status 0; presumably
// this is the --help path -- confirm against ExampleArgs.
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen_q = args.get_int("--seqlen_q", 64);
const int seqlen_k = args.get_int("--seqlen_k", 128);
const int hdim = args.get_int("--hdim", 128);
print_header("Example 19: FMHA with Masks (GPU)");
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry registry;
registry.set_name("mask_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 3);
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
// Allocate GPU buffers
// Element counts are accumulated in int64_t so the products cannot wrap.
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen_q * hdim;
const int64_t k_elems = static_cast<int64_t>(batch) * nhead * seqlen_k * hdim;
const int64_t v_elems = k_elems;
const int64_t o_elems = q_elems;
std::cout << "\nStep 2: Allocate GPU Buffers\n";
std::cout << " Q/O: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim << "]\n";
std::cout << " K/V: [" << batch << ", " << nhead << ", " << seqlen_k << ", " << hdim << "]\n";
GpuBuffer<FmhaDataType> q_dev(q_elems);
GpuBuffer<FmhaDataType> k_dev(k_elems);
GpuBuffer<FmhaDataType> v_dev(v_elems);
GpuBuffer<FmhaDataType> o_dev(o_elems);
// Fixed seed keeps the inputs reproducible from run to run.
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FmhaDataType> q_host(q_elems);
std::vector<FmhaDataType> k_host(k_elems);
std::vector<FmhaDataType> v_host(v_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
// Inputs are uploaded once and shared by all three mask runs.
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
// Convert to f32 for CPU reference
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems);
for(int64_t i = 0; i < q_elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < k_elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < v_elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
// Test each mask type
// mask_type_int is passed to fmha_args.mask_type and to the CPU reference
// (0/1/2); mask_type is the enum stored in the dispatch traits.
struct MaskTest
{
const char* name;
int mask_type_int;
mask_enum mask_type;
};
MaskTest tests[] = {
{"no_mask", 0, mask_enum::no_mask},
{"top_left", 1, mask_enum::mask_top_left},
{"bottom_right", 2, mask_enum::mask_bottom_right},
};
bool all_passed = true;
for(const auto& test : tests)
{
// The same "Step 3" banner is printed for each mask variant.
std::cout << "\nStep 3: Run FMHA Forward [" << test.name << "]\n";
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = test.mask_type;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
// Clear the output so the non-zero check below reflects this run only.
o_dev.zero();
fmha_fwd_args fmha_args{};
fmha_args.q_ptr = q_dev.get();
fmha_args.k_ptr = k_dev.get();
fmha_args.v_ptr = v_dev.get();
fmha_args.o_ptr = o_dev.get();
fmha_args.bias_ptr = nullptr;
fmha_args.q_descale_ptr = nullptr;
fmha_args.k_descale_ptr = nullptr;
fmha_args.v_descale_ptr = nullptr;
fmha_args.rand_val_ptr = nullptr;
fmha_args.lse_ptr = nullptr;
fmha_args.sink_ptr = nullptr;
fmha_args.block_scale_seqstart_q_ptr = nullptr;
fmha_args.block_scale_seqstart_k_ptr = nullptr;
fmha_args.seqlen_q = seqlen_q;
fmha_args.seqlen_k = seqlen_k;
fmha_args.batch = batch;
fmha_args.max_seqlen_q = seqlen_q;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
fmha_args.scale_s = scale;
fmha_args.logits_soft_cap = 0.0f;
// bhsd layout strides
fmha_args.stride_q = hdim;
fmha_args.stride_k = hdim;
fmha_args.stride_v = hdim;
fmha_args.stride_bias = 0;
fmha_args.stride_randval = 0;
fmha_args.stride_o = hdim;
fmha_args.nhead_stride_q = seqlen_q * hdim;
fmha_args.nhead_stride_k = seqlen_k * hdim;
fmha_args.nhead_stride_v = seqlen_k * hdim;
fmha_args.nhead_stride_bias = 0;
fmha_args.nhead_stride_randval = 0;
fmha_args.nhead_stride_lse = 0;
fmha_args.nhead_stride_o = seqlen_q * hdim;
fmha_args.nhead_stride_q_descale = 0;
fmha_args.nhead_stride_k_descale = 0;
fmha_args.nhead_stride_v_descale = 0;
fmha_args.batch_stride_q = nhead * seqlen_q * hdim;
fmha_args.batch_stride_k = nhead * seqlen_k * hdim;
fmha_args.batch_stride_v = nhead * seqlen_k * hdim;
fmha_args.batch_stride_bias = 0;
fmha_args.batch_stride_randval = 0;
fmha_args.batch_stride_lse = 0;
fmha_args.batch_stride_o = nhead * seqlen_q * hdim;
fmha_args.batch_stride_q_descale = 0;
fmha_args.batch_stride_k_descale = 0;
fmha_args.batch_stride_v_descale = 0;
// The left window is always -1; the right window is -1 for no mask and 0 for
// both causal variants.
fmha_args.window_size_left = -1;
fmha_args.window_size_right = (test.mask_type_int == 0) ? -1 : 0;
fmha_args.sink_size = 0;
fmha_args.mask_type = test.mask_type_int;
fmha_args.min_seqlen_q = 0;
fmha_args.p_drop = 0.0f;
fmha_args.s_randval = false;
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
fmha_args.block_scale_size_q = 0;
fmha_args.block_scale_size_kv = 0;
float time_ms = 0.0f;
try
{
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
}
catch(const std::exception& e)
{
// A failing variant marks the whole example as failed, but the remaining
// variants are still run.
std::cerr << " ERROR [" << test.name << "]: " << e.what() << "\n";
all_passed = false;
continue;
}
auto problem =
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
// NOTE(review): a reported time of 0 ms makes this division yield inf.
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Validate
std::vector<FmhaDataType> o_host(o_elems);
o_dev.copy_to_host(o_host.data());
// Sanity check that is always performed: an all-zero output counts as a failure.
int nonzero = 0;
for(int64_t i = 0; i < o_elems; ++i)
{
if(static_cast<float>(o_host[i]) != 0.0f)
++nonzero;
}
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
if(nonzero == 0)
all_passed = false;
if(args.has("--validate"))
{
// Optional comparison against the masked CPU reference for this variant.
std::vector<float> o_ref(o_elems, 0.0f);
cpu_attention_fwd_masked(q_f32,
k_f32,
v_f32,
o_ref,
batch,
nhead,
seqlen_q,
seqlen_k,
hdim,
hdim,
scale,
test.mask_type_int);
double max_abs_err = 0.0;
double max_rel_err = 0.0;
int errors = 0;
const double rtol = 1e-2;
const double atol = 1e-2;
for(int64_t i = 0; i < o_elems; ++i)
{
float gpu_val = static_cast<float>(o_host[i]);
float ref_val = o_ref[i];
double abs_err = std::abs(gpu_val - ref_val);
// The 1e-6 term keeps the relative error finite when the reference is 0.
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
max_abs_err = std::max(max_abs_err, abs_err);
max_rel_err = std::max(max_rel_err, rel_err);
if(abs_err > atol + rtol * std::abs(ref_val))
++errors;
}
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
std::cout << " Max rel error: " << max_rel_err << "\n";
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
if(errors > 0)
all_passed = false;
}
}
print_separator();
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
print_separator();
return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,584 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 20: GPU FMHA Forward with Bias Types
//
// Demonstrates three bias variants with GPU execution:
// 1. No bias (standard attention)
// 2. Elementwise bias (arbitrary bias matrix added to scores)
// 3. ALiBi (Attention with Linear Biases -- slope-based positional encoding)
//
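// Validates each variant against a CPU reference when --validate is passed;
// without it, only a "some output is non-zero" sanity check is performed.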
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for this example: three forward ("fwd") fp16 batch-mode kernels with
// hdim 128, no mask, no LSE, no dropout and no quantization scale, all declared for
// gfx950. The three entries use the same FmhaAlgorithm settings (tiles, waves, warps,
// "qr_async" pipeline, padding, alignments) and differ only in the .bias(...) value:
// "no", "bias" and "alibi".
DECL_FMHA_KERNEL_SET(bias_fmha_kernels,
// Variant 1 of 3: .bias("no")
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Variant 2 of 3: .bias("bias") -- same signature and algorithm otherwise
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("bias")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Variant 3 of 3: .bias("alibi") -- same signature and algorithm otherwise
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("alibi")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for unmasked, unbiased attention: O = softmax(scale * Q * K^T) * V.
//
// All buffers are dense and indexed row-major as
//   Q: [batch, nhead, seqlen_q, hdim_q]    K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]    O: [batch, nhead, seqlen_q, hdim_v]
// O must already be sized by the caller; every addressed element is overwritten.
// The softmax subtracts the row maximum before exponentiating.
//
// Flat offsets are computed in 64-bit arithmetic so that tensors with more than
// INT_MAX elements are addressed correctly (the callers size them with int64_t).
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // One row of scores, shared by every (batch, head, query) triple. Each element is
    // written before it is read, so the buffer never needs to be re-zeroed.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Widened here so that every offset derived from it is 64-bit.
            const int64_t head = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (head * seqlen_q + sq) * hdim_q;

                // Pass 1: scaled dot products and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (head * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum, then normalise.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Pass 3: probability-weighted sum of the V rows.
                const int64_t o_row = (head * seqlen_q + sq) * hdim_v;
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (head * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
// CPU reference for attention with an additive bias on the scaled scores:
//   O = softmax(scale * Q * K^T + bias) * V
//
// bias_type selects how bias_buf is interpreted:
//   0 = none        -- bias_buf is not read
//   1 = elementwise -- bias_buf is [1, nhead, seqlen_q, seqlen_k], shared by all batches
//   2 = alibi       -- bias_buf is [1, nhead] slopes; the bias is slope * (sk - sq)
// Any other value behaves like 0.
//
// Q/K/V/O are dense row-major [batch, nhead, seqlen, hdim] buffers; O must already be
// sized by the caller. Flat offsets (including the bias offset) are computed in 64-bit
// arithmetic so that tensors with more than INT_MAX elements are addressed correctly.
//
// NOTE(review): the ALiBi term is the signed distance (sk - sq), not its absolute
// value -- confirm this matches the convention of the GPU kernel it is compared with.
void cpu_attention_fwd_biased(const std::vector<float>& Q,
                              const std::vector<float>& K,
                              const std::vector<float>& V,
                              std::vector<float>& O,
                              int batch,
                              int nhead,
                              int seqlen_q,
                              int seqlen_k,
                              int hdim_q,
                              int hdim_v,
                              float scale,
                              int bias_type,
                              const std::vector<float>& bias_buf)
{
    // One row of scores, shared by every (batch, head, query) triple. Each element is
    // written before it is read, so the buffer never needs to be re-zeroed.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Widened here so that every offset derived from it is 64-bit.
            const int64_t head = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (head * seqlen_q + sq) * hdim_q;
                // The elementwise bias has no batch dimension, so its row depends on
                // the head and query position only.
                const int64_t bias_row = (static_cast<int64_t>(h) * seqlen_q + sq) * seqlen_k;

                // Pass 1: biased, scaled dot products and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (head * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    float s = dot * scale;
                    if(bias_type == 1)
                    {
                        s += bias_buf[bias_row + sk];
                    }
                    else if(bias_type == 2)
                    {
                        const float slope = bias_buf[h];
                        s += slope * static_cast<float>(sk - sq);
                    }
                    scores[sk] = s;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum, then normalise.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Pass 3: probability-weighted sum of the V rows.
                const int64_t o_row = (head * seqlen_q + sq) * hdim_v;
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (head * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
} // namespace
// Example 20 driver: runs the FMHA forward kernel once per bias variant (none,
// elementwise, ALiBi) on shared random Q/K/V data and reports time and TFLOPS.
//
// Every variant must produce at least one non-zero output. With --validate, the output
// is additionally compared element-wise against a CPU reference using the tolerance
// |gpu - ref| <= atol + rtol * |ref|.
//
// Returns 0 when every variant passed (or when args.parse() returns false), 1 otherwise.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 20: FMHA with Bias (GPU)", "FMHA bias variants on GPU");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "2", "Batch size");
    args.add_option("--nhead", "4", "Number of heads");
    args.add_option("--seqlen", "64", "Sequence length (Q and K)");
    args.add_option("--hdim", "128", "Head dimension");
    args.add_flag("--validate", "Validate against CPU reference");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 2);
    const int nhead            = args.get_int("--nhead", 4);
    const int seqlen           = args.get_int("--seqlen", 64);
    const int hdim             = args.get_int("--hdim", 128);
    const bool validate        = args.has("--validate");

    print_header("Example 20: FMHA with Bias (GPU)");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("bias_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";
    FmhaDispatcher dispatcher(&registry);
    dispatcher.set_benchmarking(true);
    dispatcher.set_timing(1, 3);

    const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));

    // Allocate Q, K, V GPU buffers (shared across all bias tests)
    const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
    const int64_t k_elems = q_elems;
    const int64_t v_elems = q_elems;
    const int64_t o_elems = q_elems;
    std::cout << "\nStep 2: Allocate GPU Buffers\n";
    std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
              << "]\n";
    GpuBuffer<FmhaDataType> q_dev(q_elems);
    GpuBuffer<FmhaDataType> k_dev(k_elems);
    GpuBuffer<FmhaDataType> v_dev(v_elems);
    GpuBuffer<FmhaDataType> o_dev(o_elems);

    // Fixed seed so that every run uses the same inputs.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
    std::vector<FmhaDataType> q_host(q_elems);
    std::vector<FmhaDataType> k_host(k_elems);
    std::vector<FmhaDataType> v_host(v_elems);
    for(auto& x : q_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : k_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : v_host)
        x = FmhaDataType(dist(rng));
    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());

    // Convert to f32 for CPU reference. The conversion goes through the fp16 host
    // values, so the reference sees exactly the inputs the GPU sees.
    std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems);
    for(int64_t i = 0; i < q_elems; ++i)
        q_f32[i] = static_cast<float>(q_host[i]);
    for(int64_t i = 0; i < k_elems; ++i)
        k_f32[i] = static_cast<float>(k_host[i]);
    for(int64_t i = 0; i < v_elems; ++i)
        v_f32[i] = static_cast<float>(v_host[i]);

    // Prepare elementwise bias buffer: [1, nhead, seqlen, seqlen] with small values
    // NOTE(review): the bias is uploaded as float while the kernels are declared with
    // dtype "fp16" -- confirm which bias element type the generated kernel reads.
    const int64_t elem_bias_elems = static_cast<int64_t>(nhead) * seqlen * seqlen;
    std::vector<float> elem_bias_host(elem_bias_elems);
    std::uniform_real_distribution<float> bias_dist(-0.1f, 0.1f);
    for(auto& x : elem_bias_host)
        x = bias_dist(rng);
    GpuBuffer<float> elem_bias_dev(elem_bias_elems);
    elem_bias_dev.copy_from_host(elem_bias_host.data());

    // Prepare ALiBi slopes buffer: [nhead] with geometric slopes
    std::vector<float> alibi_slopes_host(nhead);
    for(int h = 0; h < nhead; ++h)
    {
        alibi_slopes_host[h] = -std::pow(2.0f, -(8.0f * (h + 1) / nhead));
    }
    GpuBuffer<float> alibi_slopes_dev(nhead);
    alibi_slopes_dev.copy_from_host(alibi_slopes_host.data());

    // Test each bias type
    struct BiasTest
    {
        const char* name;
        int bias_type_int; // selector understood by cpu_attention_fwd_biased
        bias_enum bias_type;
        void* bias_ptr;
        int stride_bias;
        int nhead_stride_bias;
        int batch_stride_bias;
    };
    BiasTest tests[] = {
        {"no_bias", 0, bias_enum::no_bias, nullptr, 0, 0, 0},
        {"elementwise_bias",
         1,
         bias_enum::elementwise_bias,
         elem_bias_dev.get(),
         seqlen,
         seqlen * seqlen,
         0},
        {"alibi", 2, bias_enum::alibi, alibi_slopes_dev.get(), 0, 1, 0},
    };

    bool all_passed = true;
    for(const auto& test : tests)
    {
        std::cout << "\nStep 3: Run FMHA Forward [" << test.name << "]\n";
        fmha_fwd_traits traits{};
        traits.hdim_q              = hdim;
        traits.hdim_v              = hdim;
        traits.data_type           = "fp16";
        traits.is_group_mode       = false;
        traits.is_v_rowmajor       = true;
        traits.has_logits_soft_cap = false;
        traits.mask_type           = mask_enum::no_mask;
        traits.bias_type           = test.bias_type;
        traits.has_lse             = false;
        traits.has_dropout         = false;
        traits.qscale_type         = quant_scale_enum::no_scale;

        // Clear the output so that results from the previous variant cannot satisfy
        // the non-zero check below.
        o_dev.zero();

        fmha_fwd_args fmha_args{};
        fmha_args.q_ptr                      = q_dev.get();
        fmha_args.k_ptr                      = k_dev.get();
        fmha_args.v_ptr                      = v_dev.get();
        fmha_args.o_ptr                      = o_dev.get();
        fmha_args.bias_ptr                   = test.bias_ptr;
        fmha_args.q_descale_ptr              = nullptr;
        fmha_args.k_descale_ptr              = nullptr;
        fmha_args.v_descale_ptr              = nullptr;
        fmha_args.rand_val_ptr               = nullptr;
        fmha_args.lse_ptr                    = nullptr;
        fmha_args.sink_ptr                   = nullptr;
        fmha_args.block_scale_seqstart_q_ptr = nullptr;
        fmha_args.block_scale_seqstart_k_ptr = nullptr;
        fmha_args.seqlen_q                   = seqlen;
        fmha_args.seqlen_k                   = seqlen;
        fmha_args.batch                      = batch;
        fmha_args.max_seqlen_q               = seqlen;
        fmha_args.hdim_q                     = hdim;
        fmha_args.hdim_v                     = hdim;
        fmha_args.nhead_q                    = nhead;
        fmha_args.nhead_k                    = nhead;
        fmha_args.scale_s                    = scale;
        fmha_args.logits_soft_cap            = 0.0f;
        // bhsd layout strides
        fmha_args.stride_q               = hdim;
        fmha_args.stride_k               = hdim;
        fmha_args.stride_v               = hdim;
        fmha_args.stride_bias            = test.stride_bias;
        fmha_args.stride_randval         = 0;
        fmha_args.stride_o               = hdim;
        fmha_args.nhead_stride_q         = seqlen * hdim;
        fmha_args.nhead_stride_k         = seqlen * hdim;
        fmha_args.nhead_stride_v         = seqlen * hdim;
        fmha_args.nhead_stride_bias      = test.nhead_stride_bias;
        fmha_args.nhead_stride_randval   = 0;
        fmha_args.nhead_stride_lse       = 0;
        fmha_args.nhead_stride_o         = seqlen * hdim;
        fmha_args.nhead_stride_q_descale = 0;
        fmha_args.nhead_stride_k_descale = 0;
        fmha_args.nhead_stride_v_descale = 0;
        fmha_args.batch_stride_q         = nhead * seqlen * hdim;
        fmha_args.batch_stride_k         = nhead * seqlen * hdim;
        fmha_args.batch_stride_v         = nhead * seqlen * hdim;
        fmha_args.batch_stride_bias      = test.batch_stride_bias;
        fmha_args.batch_stride_randval   = 0;
        fmha_args.batch_stride_lse       = 0;
        fmha_args.batch_stride_o         = nhead * seqlen * hdim;
        fmha_args.batch_stride_q_descale = 0;
        fmha_args.batch_stride_k_descale = 0;
        fmha_args.batch_stride_v_descale = 0;
        fmha_args.window_size_left       = -1;
        fmha_args.window_size_right      = -1;
        fmha_args.sink_size              = 0;
        fmha_args.mask_type              = 0;
        fmha_args.min_seqlen_q           = 0;
        fmha_args.p_drop                 = 0.0f;
        fmha_args.s_randval              = false;
        fmha_args.drop_seed_offset       = std::make_pair(uint64_t(0), uint64_t(0));
        fmha_args.block_scale_size_q     = 0;
        fmha_args.block_scale_size_kv    = 0;

        float time_ms = 0.0f;
        try
        {
            time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
        }
        catch(const std::exception& e)
        {
            // A failing variant is recorded but does not stop the remaining ones.
            std::cerr << " ERROR [" << test.name << "]: " << e.what() << "\n";
            all_passed = false;
            continue;
        }

        auto problem =
            FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
        // Report 0 instead of dividing by zero if no time was measured.
        const double tflops =
            (time_ms > 0.0f)
                ? static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12
                : 0.0;
        std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
        std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";

        // Validate
        std::vector<FmhaDataType> o_host(o_elems);
        o_dev.copy_to_host(o_host.data());
        int64_t nonzero = 0; // 64-bit to match o_elems
        for(int64_t i = 0; i < o_elems; ++i)
        {
            if(static_cast<float>(o_host[i]) != 0.0f)
                ++nonzero;
        }
        std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
        if(nonzero == 0)
            all_passed = false;

        if(validate)
        {
            std::vector<float> o_ref(o_elems, 0.0f);
            if(test.bias_type_int == 0)
            {
                cpu_attention_fwd(
                    q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
            }
            else
            {
                const std::vector<float>& bias_ref =
                    (test.bias_type_int == 1) ? elem_bias_host : alibi_slopes_host;
                cpu_attention_fwd_biased(q_f32,
                                         k_f32,
                                         v_f32,
                                         o_ref,
                                         batch,
                                         nhead,
                                         seqlen,
                                         seqlen,
                                         hdim,
                                         hdim,
                                         scale,
                                         test.bias_type_int,
                                         bias_ref);
            }

            double max_abs_err = 0.0;
            double max_rel_err = 0.0;
            int64_t errors     = 0;
            const double rtol  = 1e-2;
            const double atol  = 1e-2;
            for(int64_t i = 0; i < o_elems; ++i)
            {
                const float gpu_val  = static_cast<float>(o_host[i]);
                const float ref_val  = o_ref[i];
                const double abs_err = std::abs(gpu_val - ref_val);
                const double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
                max_abs_err          = std::max(max_abs_err, abs_err);
                max_rel_err          = std::max(max_rel_err, rel_err);
                // Written as a negated "within tolerance" test: every comparison with
                // NaN is false, so a NaN output is counted as an error here, whereas
                // `abs_err > tolerance` would let it through.
                if(!(abs_err <= atol + rtol * std::abs(ref_val)))
                    ++errors;
            }
            std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
            std::cout << " Max rel error: " << max_rel_err << "\n";
            std::cout << " Errors: " << errors << " / " << o_elems << "\n";
            if(errors > 0)
                all_passed = false;
        }
    }

    print_separator();
    std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
    print_separator();
    return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,697 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 21: GPU Features FMHA
//
// Tests multiple FMHA features with real GPU execution:
// 1. Dropout (with LSE, rand_val buffer)
// 2. GQA (nhead_q=16, nhead_k=4, same kernel)
// 3. LSE output (verify log-sum-exp values)
//
// Mirrors 01_basic_fmha.cpp for each feature variant.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for this example: three forward ("fwd") fp16 batch-mode kernels with
// hdim 128, no mask, no bias and no quantization scale, all declared for gfx950.
// The three entries use the same FmhaAlgorithm settings and differ only in the
// .lse(...) / .dropout(...) flags of their signatures:
//   (false, false), (true, true) and (true, false).
DECL_FMHA_KERNEL_SET(gpu_features_fmha_kernels,
// Basic fp16 kernel (used for GQA -- GQA is a runtime concern, same kernel)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Dropout kernel (requires LSE)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(true)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// LSE-only kernel
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for attention with grouped-query heads and optional log-sum-exp output:
//   O = softmax(scale * Q * K^T) * V
//
// Buffers are dense and indexed row-major as
//   Q: [batch, nhead_q, seqlen_q, hdim_q]    K: [batch, nhead_k, seqlen_k, hdim_q]
//   V: [batch, nhead_k, seqlen_k, hdim_v]    O: [batch, nhead_q, seqlen_q, hdim_v]
// Query head hq reads K/V head hq / (nhead_q / nhead_k), so consecutive query heads
// share one K/V head.
//
// lse_out, when non-null, must already hold batch * nhead_q * seqlen_q elements and
// receives max_score + log(sum_exp) for every query row.
//
// The call is a no-op when nhead_k <= 0 or nhead_q < nhead_k: the head ratio would be
// undefined or zero and mapping query heads to K/V heads would divide by zero.
// Flat offsets are computed in 64-bit arithmetic so that tensors with more than
// INT_MAX elements are addressed correctly.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead_q,
                       int nhead_k,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale,
                       std::vector<float>* lse_out = nullptr)
{
    if(nhead_k <= 0 || nhead_q < nhead_k)
        return;
    const int nhead_ratio = nhead_q / nhead_k;

    // One row of scores, shared by every (batch, head, query) triple. Each element is
    // written before it is read, so the buffer never needs to be re-zeroed.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0);

    for(int b = 0; b < batch; ++b)
    {
        for(int hq = 0; hq < nhead_q; ++hq)
        {
            const int hk = hq / nhead_ratio;
            // Widened here so that every offset derived from them is 64-bit.
            const int64_t q_head  = static_cast<int64_t>(b) * nhead_q + hq;
            const int64_t kv_head = static_cast<int64_t>(b) * nhead_k + hk;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (q_head * seqlen_q + sq) * hdim_q;

                // Pass 1: scaled dot products and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (kv_head * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                if(lse_out)
                {
                    // Adding the maximum back undoes the shift applied before exp().
                    (*lse_out)[q_head * seqlen_q + sq] = max_score + std::log(sum_exp);
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Pass 3: probability-weighted sum of the shared V rows.
                const int64_t o_row = (q_head * seqlen_q + sq) * hdim_v;
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (kv_head * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
// Outcome of one feature test; main() collects one per feature and prints them as
// the rows of the summary table.
struct FeatureResult
{
// Label printed in the "Feature" column.
std::string name;
// True when the feature's checks succeeded; any false entry makes the program exit with 1.
bool passed;
// Time reported by dispatcher.run_fwd(); remains 0 when the run threw before returning.
float time_ms;
};
// Builds the fmha_fwd_args common to every feature test in this example.
//
// q/k/v/o are stored as given; Q and K use the same sequence length `seqlen`, and
// hdim_q == hdim_v == hdim. Strides describe dense [batch, nhead, seqlen, hdim]
// tensors, where Q and O use nhead_q heads per batch and K and V use nhead_k.
// Bias, descale, rand-val, LSE, sink and block-scale fields are all null / zero,
// dropout is off (p_drop 0), and mask_type is 0 with both window sizes set to -1.
// Callers overwrite the LSE / rand-val / dropout fields for the features they test.
fmha_fwd_args make_base_args(void* q,
                             void* k,
                             void* v,
                             void* o,
                             int batch,
                             int nhead_q,
                             int nhead_k,
                             int seqlen,
                             int hdim,
                             float scale)
{
    // Element distances between consecutive heads and consecutive batches.
    const int per_head     = seqlen * hdim;
    const int per_batch_q  = nhead_q * seqlen * hdim;
    const int per_batch_kv = nhead_k * seqlen * hdim;

    fmha_fwd_args args{};

    // Tensor pointers: only Q/K/V/O are populated.
    args.q_ptr                      = q;
    args.k_ptr                      = k;
    args.v_ptr                      = v;
    args.o_ptr                      = o;
    args.bias_ptr                   = nullptr;
    args.rand_val_ptr               = nullptr;
    args.lse_ptr                    = nullptr;
    args.sink_ptr                   = nullptr;
    args.q_descale_ptr              = nullptr;
    args.k_descale_ptr              = nullptr;
    args.v_descale_ptr              = nullptr;
    args.block_scale_seqstart_q_ptr = nullptr;
    args.block_scale_seqstart_k_ptr = nullptr;

    // Problem shape.
    args.batch        = batch;
    args.nhead_q      = nhead_q;
    args.nhead_k      = nhead_k;
    args.seqlen_q     = seqlen;
    args.seqlen_k     = seqlen;
    args.max_seqlen_q = seqlen;
    args.min_seqlen_q = 0;
    args.hdim_q       = hdim;
    args.hdim_v       = hdim;

    // Softmax scaling.
    args.scale_s         = scale;
    args.logits_soft_cap = 0.0f;

    // Row strides.
    args.stride_q       = hdim;
    args.stride_k       = hdim;
    args.stride_v       = hdim;
    args.stride_o       = hdim;
    args.stride_bias    = 0;
    args.stride_randval = 0;

    // Head strides.
    args.nhead_stride_q         = per_head;
    args.nhead_stride_k         = per_head;
    args.nhead_stride_v         = per_head;
    args.nhead_stride_o         = per_head;
    args.nhead_stride_bias      = 0;
    args.nhead_stride_randval   = 0;
    args.nhead_stride_lse       = 0;
    args.nhead_stride_q_descale = 0;
    args.nhead_stride_k_descale = 0;
    args.nhead_stride_v_descale = 0;

    // Batch strides: Q and O span nhead_q heads, K and V span nhead_k heads.
    args.batch_stride_q         = per_batch_q;
    args.batch_stride_o         = per_batch_q;
    args.batch_stride_k         = per_batch_kv;
    args.batch_stride_v         = per_batch_kv;
    args.batch_stride_bias      = 0;
    args.batch_stride_randval   = 0;
    args.batch_stride_lse       = 0;
    args.batch_stride_q_descale = 0;
    args.batch_stride_k_descale = 0;
    args.batch_stride_v_descale = 0;

    // Masking.
    args.mask_type         = 0;
    args.window_size_left  = -1;
    args.window_size_right = -1;
    args.sink_size         = 0;

    // Dropout and block scaling.
    args.p_drop              = 0.0f;
    args.s_randval           = false;
    args.drop_seed_offset    = std::make_pair(uint64_t(0), uint64_t(0));
    args.block_scale_size_q  = 0;
    args.block_scale_size_kv = 0;

    return args;
}
} // namespace
int main(int argc, char* argv[])
{
ExampleArgs args("Example 21: GPU Features FMHA", "Dropout, GQA, LSE with real GPU data");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--seqlen", "64", "Sequence length");
args.add_option("--hdim", "128", "Head dimension");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int seqlen = args.get_int("--seqlen", 64);
const int hdim = args.get_int("--hdim", 128);
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
print_header("Example 21: GPU Features FMHA");
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry registry;
registry.set_name("gpu_features_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 3);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FeatureResult> results;
// -----------------------------------------------------------------------
// Feature A: GQA (nhead_q=16, nhead_k=4, same basic kernel)
// -----------------------------------------------------------------------
{
std::cout << "\nStep 2a: GQA (nhead_q=16, nhead_k=4)\n";
const int nhead_q = 16;
const int nhead_k = 4;
const int64_t q_elems = static_cast<int64_t>(batch) * nhead_q * seqlen * hdim;
const int64_t k_elems = static_cast<int64_t>(batch) * nhead_k * seqlen * hdim;
const int64_t o_elems = q_elems;
GpuBuffer<FmhaDataType> q_dev(q_elems);
GpuBuffer<FmhaDataType> k_dev(k_elems);
GpuBuffer<FmhaDataType> v_dev(k_elems);
GpuBuffer<FmhaDataType> o_dev(o_elems);
std::vector<FmhaDataType> q_host(q_elems), k_host(k_elems), v_host(k_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
o_dev.zero();
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
auto fmha_args = make_base_args(q_dev.get(),
k_dev.get(),
v_dev.get(),
o_dev.get(),
batch,
nhead_q,
nhead_k,
seqlen,
hdim,
scale);
bool passed = false;
float time_ms = 0.0f;
try
{
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
// Validate against CPU reference with GQA head repetition
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(k_elems);
for(int64_t i = 0; i < q_elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < k_elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < k_elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
std::vector<float> o_ref(o_elems, 0.0f);
cpu_attention_fwd(q_f32,
k_f32,
v_f32,
o_ref,
batch,
nhead_q,
nhead_k,
seqlen,
seqlen,
hdim,
hdim,
scale);
std::vector<FmhaDataType> o_host(o_elems);
o_dev.copy_to_host(o_host.data());
double max_abs_err = 0.0;
int errors = 0;
const double rtol = 1e-2;
const double atol = 1e-2;
for(int64_t i = 0; i < o_elems; ++i)
{
float gpu_val = static_cast<float>(o_host[i]);
float ref_val = o_ref[i];
double abs_err = std::abs(gpu_val - ref_val);
max_abs_err = std::max(max_abs_err, abs_err);
if(abs_err > atol + rtol * std::abs(ref_val))
++errors;
}
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
passed = (errors == 0);
}
catch(const std::exception& e)
{
std::cerr << " ERROR: " << e.what() << "\n";
}
results.push_back({"GQA (16q/4k)", passed, time_ms});
}
// -----------------------------------------------------------------------
// Feature B: LSE output
// -----------------------------------------------------------------------
{
std::cout << "\nStep 2b: LSE Output\n";
const int nhead = 4;
const int64_t qkv_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen;
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
GpuBuffer<float> lse_dev(lse_elems);
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
o_dev.zero();
lse_dev.zero();
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = true;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
auto fmha_args = make_base_args(q_dev.get(),
k_dev.get(),
v_dev.get(),
o_dev.get(),
batch,
nhead,
nhead,
seqlen,
hdim,
scale);
fmha_args.lse_ptr = lse_dev.get();
fmha_args.nhead_stride_lse = seqlen;
fmha_args.batch_stride_lse = nhead * seqlen;
bool passed = false;
float time_ms = 0.0f;
try
{
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
// Compute CPU reference LSE
std::vector<float> q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems);
for(int64_t i = 0; i < qkv_elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < qkv_elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < qkv_elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
std::vector<float> o_ref(qkv_elems, 0.0f);
std::vector<float> lse_ref(lse_elems, 0.0f);
cpu_attention_fwd(q_f32,
k_f32,
v_f32,
o_ref,
batch,
nhead,
nhead,
seqlen,
seqlen,
hdim,
hdim,
scale,
&lse_ref);
std::vector<float> lse_host(lse_elems);
lse_dev.copy_to_host(lse_host.data());
int lse_reasonable = 0;
double max_lse_err = 0.0;
for(int64_t i = 0; i < lse_elems; ++i)
{
if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f)
++lse_reasonable;
double err = std::abs(lse_host[i] - lse_ref[i]);
max_lse_err = std::max(max_lse_err, err);
}
std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n";
std::cout << " LSE max error vs ref: " << std::scientific << max_lse_err << "\n";
std::cout << " LSE sample [0..3]: ";
for(int i = 0; i < std::min<int>(4, lse_elems); ++i)
std::cout << std::fixed << std::setprecision(4) << lse_host[i] << " ";
std::cout << "\n";
passed = (lse_reasonable == lse_elems) && (max_lse_err < 1.0);
}
catch(const std::exception& e)
{
std::cerr << " ERROR: " << e.what() << "\n";
}
results.push_back({"LSE", passed, time_ms});
}
// -----------------------------------------------------------------------
// Feature C: Dropout
// -----------------------------------------------------------------------
{
std::cout << "\nStep 2c: Dropout (p_drop=0.2)\n";
const int nhead = 4;
const int64_t qkv_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen;
const int64_t randval_elems = static_cast<int64_t>(batch) * nhead * seqlen * seqlen;
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
GpuBuffer<float> lse_dev(lse_elems);
GpuBuffer<uint8_t> rand_val_dev(randval_elems);
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
o_dev.zero();
lse_dev.zero();
rand_val_dev.zero();
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = true;
traits.has_dropout = true;
traits.qscale_type = quant_scale_enum::no_scale;
auto fmha_args = make_base_args(q_dev.get(),
k_dev.get(),
v_dev.get(),
o_dev.get(),
batch,
nhead,
nhead,
seqlen,
hdim,
scale);
fmha_args.lse_ptr = lse_dev.get();
fmha_args.rand_val_ptr = rand_val_dev.get();
fmha_args.nhead_stride_lse = seqlen;
fmha_args.batch_stride_lse = nhead * seqlen;
fmha_args.stride_randval = seqlen;
fmha_args.nhead_stride_randval = seqlen * seqlen;
fmha_args.batch_stride_randval = nhead * seqlen * seqlen;
fmha_args.p_drop = 0.2f;
fmha_args.s_randval = true;
fmha_args.drop_seed_offset = std::make_pair(uint64_t(42), uint64_t(0));
bool passed = false;
float time_ms = 0.0f;
try
{
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::vector<FmhaDataType> o_host(qkv_elems);
o_dev.copy_to_host(o_host.data());
int nonzero = 0;
for(int64_t i = 0; i < qkv_elems; ++i)
{
if(static_cast<float>(o_host[i]) != 0.0f)
++nonzero;
}
std::cout << " Non-zero outputs: " << nonzero << " / " << qkv_elems << "\n";
std::vector<float> lse_host(lse_elems);
lse_dev.copy_to_host(lse_host.data());
int lse_reasonable = 0;
for(int64_t i = 0; i < lse_elems; ++i)
{
if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f)
++lse_reasonable;
}
std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n";
passed = (nonzero > 0) && (lse_reasonable == lse_elems);
}
catch(const std::exception& e)
{
std::cerr << " ERROR: " << e.what() << "\n";
}
results.push_back({"Dropout", passed, time_ms});
}
// -----------------------------------------------------------------------
// Summary
// -----------------------------------------------------------------------
std::cout << "\nStep 3: Summary\n";
std::cout << " " << std::setw(16) << "Feature" << " | " << std::setw(10) << "Time(ms)" << " | "
<< std::setw(8) << "Status" << "\n";
std::cout << " " << std::string(42, '-') << "\n";
bool all_passed = true;
for(const auto& r : results)
{
std::cout << " " << std::setw(16) << r.name << " | " << std::fixed << std::setprecision(4)
<< std::setw(10) << r.time_ms << " | " << std::setw(8)
<< (r.passed ? "PASS" : "FAIL") << "\n";
if(!r.passed)
all_passed = false;
}
print_separator();
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
print_separator();
return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,553 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 22: FMHA Backward with GPU Execution
//
// Demonstrates:
// 1. Declare 3 backward kernel families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq)
// 2. Run forward to get O and LSE
// 3. Run backward to compute dQ, dK, dV
// 4. Validate gradients are non-zero
//
// Falls back to planning-only mode (prints "Status: PLAN_ONLY", exit code 0) when the
// backward plan does not resolve to a multi-stage plan or backward execution throws.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for this example: one forward kernel plus the three backward-stage kernels
// (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq). Every entry is declared for fp16, batch
// mode, hdim 128, no mask, no bias, no dropout, and is tagged with the "gfx950" target.
DECL_FMHA_KERNEL_SET(gpu_bwd_fmha_kernels,
// Forward kernel (to produce O and LSE for backward)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
// LSE output is enabled because main() feeds lse_ptr into the backward args.
.lse(true)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Backward: dot(dO, O) to compute d scalar
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
// NOTE(review): n1/k1/k0max tiles are 0 here; presumably unused by this stage -- confirm.
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Backward: compute dQ, dK, dV
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
// NOTE(review): nine-argument wave()/warp() forms; presumably three (m, n, k) triples,
// one per GEMM stage of this kernel -- confirm against FmhaAlgorithm.
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(0),
"gfx950")
// Backward: convert accumulated dQ from fp32 to fp16
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
// NOTE(review): k0/n1/k1/k0max tiles are 0 here; presumably unused by this stage -- confirm.
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for scaled-dot-product attention forward that also produces the
// log-sum-exp (LSE) of every query row.
//
// Layouts implied by the index arithmetic (row-major, contiguous):
//   Q   : [batch, nhead, seqlen_q, hdim_q]   (read)
//   K   : [batch, nhead, seqlen_k, hdim_q]   (read)
//   V   : [batch, nhead, seqlen_k, hdim_v]   (read)
//   O   : [batch, nhead, seqlen_q, hdim_v]   (written)
//   LSE : [batch, nhead, seqlen_q]           (written)
//
// Per query row: scores = scale * (Q . K^T), softmax over seqlen_k with the row maximum
// subtracted for numerical stability, O = softmax(scores) * V, and
// LSE = max_score + log(sum(exp(scores - max_score))).
//
// O and LSE are indexed, never resized: the caller must size them beforehand.
// Flat offsets are computed in 64-bit so that tensors with more than INT_MAX elements
// (the caller already counts elements in int64_t) do not overflow the index arithmetic.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       std::vector<float>& LSE,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // One score buffer reused for every query row; each row overwrites all seqlen_k
    // entries before reading them, so no re-initialisation is needed.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) index, widened once so every derived offset is 64-bit.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t row    = bh * seqlen_q + sq; // also the LSE offset
                const int64_t q_base = row * hdim_q;

                // scores[sk] = scale * dot(Q[row, :], K[sk, :]); track the row maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_base = (bh * seqlen_k + sk) * hdim_q;
                    float dot            = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_base + d] * K[k_base + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Max-subtracted exponentials and their sum.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                LSE[row] = max_score + std::log(sum_exp);

                // Normalise to softmax probabilities.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // O[row, :] = probabilities * V.
                const int64_t o_base = row * hdim_v;
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[(bh * seqlen_k + sk) * hdim_v + dv];
                    }
                    O[o_base + dv] = acc;
                }
            }
        }
    }
}
} // namespace
// Example 22 entry point.
//
// Flow: parse CLI options -> register the generated kernels -> plan the backward problem
// -> allocate and fill device buffers -> run forward (produces O and LSE) -> run backward
// -> check the gradients.
//
// Exit codes:
//   0  PASS, PLAN_ONLY fallback (backward plan invalid or backward execution threw),
//      or args.parse() returned false (presumably --help or a parse problem -- confirm).
//   1  invalid dimensions, forward failure, or gradient validation failure.
//
// Gradient validation requires every one of dQ/dK/dV to contain at least one finite
// non-zero value AND no NaN/Inf anywhere. A plain `x != 0` test is not sufficient because
// NaN compares unequal to zero and would be counted as a valid gradient.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 22: FMHA Backward (GPU)", "Forward + backward with GPU validation");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "1", "Batch size");
    args.add_option("--nhead", "4", "Number of heads");
    args.add_option("--seqlen", "64", "Sequence length");
    args.add_option("--hdim", "128", "Head dimension");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 1);
    const int nhead            = args.get_int("--nhead", 4);
    const int seqlen           = args.get_int("--seqlen", 64);
    const int hdim             = args.get_int("--hdim", 128);

    // Reject non-positive sizes up front: they would otherwise produce negative element
    // counts for the buffers below and a non-finite softmax scale.
    if(batch <= 0 || nhead <= 0 || seqlen <= 0 || hdim <= 0)
    {
        std::cerr << "ERROR: --batch, --nhead, --seqlen and --hdim must all be positive\n";
        return 1;
    }

    const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));

    print_header("Example 22: FMHA Backward (GPU Execution)");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("gpu_bwd_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";
    FmhaDispatcher dispatcher(&registry);
    dispatcher.set_benchmarking(true);
    dispatcher.set_timing(1, 3);

    // Step 2: Plan backward to verify all 3 stages resolve
    std::cout << "\nStep 2: Plan Backward\n";
    fmha_bwd_traits bwd_traits{};
    bwd_traits.hdim_q           = hdim;
    bwd_traits.hdim_v           = hdim;
    bwd_traits.data_type        = "fp16";
    bwd_traits.is_group_mode    = false;
    bwd_traits.mask_type        = mask_enum::no_mask;
    bwd_traits.bias_type        = bias_enum::no_bias;
    bwd_traits.has_dbias        = false;
    bwd_traits.has_dropout      = false;
    bwd_traits.is_store_randval = false;
    bwd_traits.is_deterministic = false;

    // Only the problem-shape fields are filled here; pointers and strides are added in
    // Step 5 once the buffers exist.
    fmha_bwd_args bwd_args{};
    bwd_args.batch        = batch;
    bwd_args.seqlen_q     = seqlen;
    bwd_args.seqlen_k     = seqlen;
    bwd_args.max_seqlen_q = seqlen;
    bwd_args.max_seqlen_k = seqlen;
    bwd_args.hdim_q       = hdim;
    bwd_args.hdim_v       = hdim;
    bwd_args.nhead_q      = nhead;
    bwd_args.nhead_k      = nhead;

    auto bwd_plan = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
    if(!bwd_plan.is_valid() || bwd_plan.stages.size() < 2)
    {
        std::cout << " Backward plan: INVALID (expected multi-stage)\n";
        std::cout << " Falling back to planning-only mode (like 04_bwd_fmha.cpp)\n";
        print_separator();
        std::cout << "Status: PLAN_ONLY\n";
        print_separator();
        return 0;
    }
    std::cout << " Backward plan stages:\n";
    for(const auto& stage : bwd_plan.stages)
    {
        std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
    }

    // Step 3: Allocate buffers
    std::cout << "\nStep 3: Allocate GPU Buffers\n";
    const int64_t qkv_elems    = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
    const int64_t lse_elems    = static_cast<int64_t>(batch) * nhead * seqlen;
    const int64_t dq_acc_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
    std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
              << "]\n";
    std::cout << " LSE/d: [" << batch << ", " << nhead << ", " << seqlen << "]\n";
    GpuBuffer<FmhaDataType> q_dev(qkv_elems);
    GpuBuffer<FmhaDataType> k_dev(qkv_elems);
    GpuBuffer<FmhaDataType> v_dev(qkv_elems);
    GpuBuffer<FmhaDataType> o_dev(qkv_elems);
    GpuBuffer<float> lse_dev(lse_elems);
    GpuBuffer<FmhaDataType> do_dev(qkv_elems);
    GpuBuffer<float> d_dev(lse_elems);
    GpuBuffer<FmhaDataType> dq_dev(qkv_elems);
    GpuBuffer<FmhaDataType> dk_dev(qkv_elems);
    GpuBuffer<FmhaDataType> dv_dev(qkv_elems);
    GpuBuffer<float> dq_acc_dev(dq_acc_elems);

    // Fixed seed so every run uses the same Q, K, V and dO.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
    std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
    std::vector<FmhaDataType> do_host(qkv_elems);
    for(auto& x : q_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : k_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : v_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : do_host)
        x = FmhaDataType(dist(rng));
    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());
    do_dev.copy_from_host(do_host.data());
    // Outputs and scratch start at zero so the non-zero check below is meaningful.
    o_dev.zero();
    lse_dev.zero();
    d_dev.zero();
    dq_dev.zero();
    dk_dev.zero();
    dv_dev.zero();
    dq_acc_dev.zero();

    // Step 4: Run forward to produce O and LSE
    std::cout << "\nStep 4: Run Forward (to produce O and LSE)\n";
    {
        fmha_fwd_traits fwd_traits{};
        fwd_traits.hdim_q              = hdim;
        fwd_traits.hdim_v              = hdim;
        fwd_traits.data_type           = "fp16";
        fwd_traits.is_group_mode       = false;
        fwd_traits.is_v_rowmajor       = true;
        fwd_traits.has_logits_soft_cap = false;
        fwd_traits.mask_type           = mask_enum::no_mask;
        fwd_traits.bias_type           = bias_enum::no_bias;
        fwd_traits.has_lse             = true;
        fwd_traits.has_dropout         = false;
        fwd_traits.qscale_type         = quant_scale_enum::no_scale;

        fmha_fwd_args fwd_args{};
        fwd_args.q_ptr                      = q_dev.get();
        fwd_args.k_ptr                      = k_dev.get();
        fwd_args.v_ptr                      = v_dev.get();
        fwd_args.o_ptr                      = o_dev.get();
        fwd_args.lse_ptr                    = lse_dev.get();
        fwd_args.bias_ptr                   = nullptr;
        fwd_args.q_descale_ptr              = nullptr;
        fwd_args.k_descale_ptr              = nullptr;
        fwd_args.v_descale_ptr              = nullptr;
        fwd_args.rand_val_ptr               = nullptr;
        fwd_args.sink_ptr                   = nullptr;
        fwd_args.block_scale_seqstart_q_ptr = nullptr;
        fwd_args.block_scale_seqstart_k_ptr = nullptr;
        fwd_args.seqlen_q                   = seqlen;
        fwd_args.seqlen_k                   = seqlen;
        fwd_args.batch                      = batch;
        fwd_args.max_seqlen_q               = seqlen;
        fwd_args.hdim_q                     = hdim;
        fwd_args.hdim_v                     = hdim;
        fwd_args.nhead_q                    = nhead;
        fwd_args.nhead_k                    = nhead;
        fwd_args.scale_s                    = scale;
        fwd_args.logits_soft_cap            = 0.0f;
        // Strides describe contiguous [batch, nhead, seqlen, hdim] tensors and a
        // contiguous [batch, nhead, seqlen] LSE.
        fwd_args.stride_q               = hdim;
        fwd_args.stride_k               = hdim;
        fwd_args.stride_v               = hdim;
        fwd_args.stride_bias            = 0;
        fwd_args.stride_randval         = 0;
        fwd_args.stride_o               = hdim;
        fwd_args.nhead_stride_q         = seqlen * hdim;
        fwd_args.nhead_stride_k         = seqlen * hdim;
        fwd_args.nhead_stride_v         = seqlen * hdim;
        fwd_args.nhead_stride_bias      = 0;
        fwd_args.nhead_stride_randval   = 0;
        fwd_args.nhead_stride_lse       = seqlen;
        fwd_args.nhead_stride_o         = seqlen * hdim;
        fwd_args.nhead_stride_q_descale = 0;
        fwd_args.nhead_stride_k_descale = 0;
        fwd_args.nhead_stride_v_descale = 0;
        fwd_args.batch_stride_q         = nhead * seqlen * hdim;
        fwd_args.batch_stride_k         = nhead * seqlen * hdim;
        fwd_args.batch_stride_v         = nhead * seqlen * hdim;
        fwd_args.batch_stride_bias      = 0;
        fwd_args.batch_stride_randval   = 0;
        fwd_args.batch_stride_lse       = nhead * seqlen;
        fwd_args.batch_stride_o         = nhead * seqlen * hdim;
        fwd_args.batch_stride_q_descale = 0;
        fwd_args.batch_stride_k_descale = 0;
        fwd_args.batch_stride_v_descale = 0;
        fwd_args.window_size_left       = -1;
        fwd_args.window_size_right      = -1;
        fwd_args.sink_size              = 0;
        fwd_args.mask_type              = 0;
        fwd_args.min_seqlen_q           = 0;
        fwd_args.p_drop                 = 0.0f;
        fwd_args.s_randval              = false;
        fwd_args.drop_seed_offset       = std::make_pair(uint64_t(0), uint64_t(0));
        fwd_args.block_scale_size_q     = 0;
        fwd_args.block_scale_size_kv    = 0;

        try
        {
            float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
            std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time
                      << " ms\n";
        }
        catch(const std::exception& e)
        {
            // Without O and LSE the backward pass has no valid inputs: hard failure.
            std::cerr << " Forward ERROR: " << e.what() << "\n";
            print_separator();
            std::cout << "Status: FAIL (forward failed)\n";
            print_separator();
            return 1;
        }
    }

    // Step 5: Run backward
    std::cout << "\nStep 5: Run Backward\n";
    bwd_args.q_ptr        = q_dev.get();
    bwd_args.k_ptr        = k_dev.get();
    bwd_args.v_ptr        = v_dev.get();
    bwd_args.bias_ptr     = nullptr;
    bwd_args.o_ptr        = o_dev.get();
    bwd_args.lse_ptr      = lse_dev.get();
    bwd_args.do_ptr       = do_dev.get();
    bwd_args.d_ptr        = d_dev.get();
    bwd_args.rand_val_ptr = nullptr;
    bwd_args.dq_ptr       = dq_dev.get();
    bwd_args.dk_ptr       = dk_dev.get();
    bwd_args.dv_ptr       = dv_dev.get();
    bwd_args.dbias_ptr    = nullptr;
    bwd_args.dq_acc_ptr   = dq_acc_dev.get();
    bwd_args.scale        = scale;

    bwd_args.stride_q       = hdim;
    bwd_args.stride_k       = hdim;
    bwd_args.stride_v       = hdim;
    bwd_args.stride_bias    = 0;
    bwd_args.stride_o       = hdim;
    bwd_args.stride_randval = 0;
    bwd_args.stride_do      = hdim;
    bwd_args.stride_dq_acc  = hdim;
    bwd_args.stride_dq      = hdim;
    bwd_args.stride_dk      = hdim;
    bwd_args.stride_dv      = hdim;
    bwd_args.stride_dbias   = 0;

    bwd_args.nhead_stride_q       = seqlen * hdim;
    bwd_args.nhead_stride_k       = seqlen * hdim;
    bwd_args.nhead_stride_v       = seqlen * hdim;
    bwd_args.nhead_stride_bias    = 0;
    bwd_args.nhead_stride_o       = seqlen * hdim;
    bwd_args.nhead_stride_randval = 0;
    bwd_args.nhead_stride_do      = seqlen * hdim;
    bwd_args.nhead_stride_lsed    = seqlen;
    bwd_args.nhead_stride_dq_acc  = static_cast<int64_t>(seqlen) * hdim;
    bwd_args.nhead_stride_dq      = seqlen * hdim;
    bwd_args.nhead_stride_dk      = seqlen * hdim;
    bwd_args.nhead_stride_dv      = seqlen * hdim;
    bwd_args.nhead_stride_dbias   = 0;

    bwd_args.batch_stride_q       = nhead * seqlen * hdim;
    bwd_args.batch_stride_k       = nhead * seqlen * hdim;
    bwd_args.batch_stride_v       = nhead * seqlen * hdim;
    bwd_args.batch_stride_bias    = 0;
    bwd_args.batch_stride_o       = nhead * seqlen * hdim;
    bwd_args.batch_stride_randval = 0;
    bwd_args.batch_stride_do      = nhead * seqlen * hdim;
    bwd_args.batch_stride_lsed    = nhead * seqlen;
    bwd_args.batch_stride_dq_acc  = static_cast<int64_t>(nhead) * seqlen * hdim;
    bwd_args.batch_stride_dq      = nhead * seqlen * hdim;
    bwd_args.batch_stride_dk      = nhead * seqlen * hdim;
    bwd_args.batch_stride_dv      = nhead * seqlen * hdim;
    bwd_args.batch_stride_dbias   = 0;
    bwd_args.split_stride_dq_acc  = 0;

    bwd_args.window_size_left  = -1;
    bwd_args.window_size_right = -1;
    bwd_args.mask_type         = 0;
    bwd_args.p_drop            = 0.0f;
    bwd_args.p_undrop          = 1.0f;
    bwd_args.drop_seed_offset  = std::make_pair(uint64_t(0), uint64_t(0));

    bool bwd_passed = false;
    try
    {
        float bwd_time = dispatcher.run_bwd(bwd_traits, bwd_args, nullptr);
        std::cout << " Backward time: " << std::fixed << std::setprecision(4) << bwd_time
                  << " ms\n";

        // Validate: dQ, dK, dV should contain finite non-zero values and no NaN/Inf.
        std::vector<FmhaDataType> dq_host(qkv_elems), dk_host(qkv_elems), dv_host(qkv_elems);
        dq_dev.copy_to_host(dq_host.data());
        dk_dev.copy_to_host(dk_host.data());
        dv_dev.copy_to_host(dv_host.data());

        struct GradStats
        {
            int nonzero   = 0; // finite values different from zero
            int nonfinite = 0; // NaN or +/-Inf values
        };
        auto inspect = [](const std::vector<FmhaDataType>& buf) {
            GradStats stats;
            for(const auto& x : buf)
            {
                const float v = static_cast<float>(x);
                if(!std::isfinite(v))
                    ++stats.nonfinite; // NaN != 0 is true, so it must be filtered first
                else if(v != 0.0f)
                    ++stats.nonzero;
            }
            return stats;
        };
        const GradStats dq_stats = inspect(dq_host);
        const GradStats dk_stats = inspect(dk_host);
        const GradStats dv_stats = inspect(dv_host);

        std::cout << " dQ non-zero: " << dq_stats.nonzero << " / " << qkv_elems << "\n";
        std::cout << " dK non-zero: " << dk_stats.nonzero << " / " << qkv_elems << "\n";
        std::cout << " dV non-zero: " << dv_stats.nonzero << " / " << qkv_elems << "\n";

        const int nonfinite_total =
            dq_stats.nonfinite + dk_stats.nonfinite + dv_stats.nonfinite;
        if(nonfinite_total > 0)
        {
            std::cout << " Non-finite gradient values: " << nonfinite_total << "\n";
        }

        bwd_passed = (dq_stats.nonzero > 0) && (dk_stats.nonzero > 0) &&
                     (dv_stats.nonzero > 0) && (nonfinite_total == 0);
    }
    catch(const std::exception& e)
    {
        // Execution failure is treated as "kernels unavailable": the plan itself was
        // valid, so report PLAN_ONLY and exit successfully.
        std::cerr << " Backward ERROR: " << e.what() << "\n";
        std::cout << " Falling back to planning-only mode (like 04_bwd_fmha.cpp)\n";
        std::cout << " Backward plan was valid with " << bwd_plan.stages.size() << " stages\n";
        print_separator();
        std::cout << "Status: PLAN_ONLY\n";
        print_separator();
        return 0;
    }

    print_separator();
    std::cout << "Status: " << (bwd_passed ? "PASS" : "FAIL") << "\n";
    print_separator();
    return bwd_passed ? 0 : 1;
}

View File

@@ -0,0 +1,595 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 23: Multiple Registries for Different Frameworks
//
// Demonstrates:
// 1. Three separate FmhaRegistry instances (pytorch, flash, aiter)
// 2. Each with its own DECL_FMHA_KERNEL_SET using different configs
// 3. Registry introspection: size(), filter(), export_json()
// 4. Planning the same problem from each registry
// 5. GPU execution of the basic no-bias problem through the aiter registry
//
// Key idea: separate registries let each framework recipient own its
// kernel population independently.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Three DECL_FMHA_KERNEL_SETs with distinct names and configurations.
// All register into the global FmhaKernelSetRegistry at static init time.
DECL_FMHA_KERNEL_SET(pytorch_reg_kernels,
// PyTorch: basic fp16, elementwise bias
// Signature: forward, fp16, batch mode, hdim 128, no mask, bias("bias"), no LSE,
// no dropout, no quant scale; tagged with profile "pytorch".
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("bias")
.lse(false)
.dropout(false)
.qscale("no")
.profile("pytorch"),
// Same tile/wave/warp/pipeline configuration as the other kernel declarations in this
// example; only the signature differs between them.
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
DECL_FMHA_KERNEL_SET(flash_reg_kernels,
// Flash: fp16, alibi bias
// Signature: forward, fp16, batch mode, hdim 128, no mask, bias("alibi"), no LSE,
// no dropout, no quant scale; tagged with profile "flash_fwd".
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("alibi")
.lse(false)
.dropout(false)
.qscale("no")
.profile("flash_fwd"),
// Same tile/wave/warp/pipeline configuration as the other kernel declarations in this
// example; only the signature differs between them.
FmhaAlgorithm()
// Stage 0 tiles (m0, n0, k0)
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 tiles (n1, k1) and k0max
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
// Two forward kernels that differ only in mode ("batch" vs "group") and profile tag;
// both are fp16, hdim 128, no mask, no bias, no LSE, no dropout.
DECL_FMHA_KERNEL_SET(aiter_reg_kernels,
// AITER: batch mode basic
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no")
.profile("aiter_batch"),
FmhaAlgorithm()
// Stage 0 tiles (m0, n0, k0)
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 tiles (n1, k1) and k0max
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// AITER: group mode
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("group")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no")
.profile("aiter_group"),
// Algorithm configuration is identical to the batch-mode entry above.
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for scaled-dot-product attention forward, used to validate GPU output.
//
// Layouts implied by the index arithmetic (row-major, contiguous):
//   Q : [batch, nhead, seqlen_q, hdim_q]   (read)
//   K : [batch, nhead, seqlen_k, hdim_q]   (read)
//   V : [batch, nhead, seqlen_k, hdim_v]   (read)
//   O : [batch, nhead, seqlen_q, hdim_v]   (written)
//
// Per query row: scores = scale * (Q . K^T), softmax over seqlen_k with the row maximum
// subtracted for numerical stability, then O = softmax(scores) * V.
//
// O is indexed, never resized: the caller must size it beforehand.
// Flat offsets are computed in 64-bit so that tensors with more than INT_MAX elements
// (the caller already counts elements in int64_t) do not overflow the index arithmetic.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // One score buffer reused for every query row; each row overwrites all seqlen_k
    // entries before reading them, so no re-initialisation is needed.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) index, widened once so every derived offset is 64-bit.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t row    = bh * seqlen_q + sq;
                const int64_t q_base = row * hdim_q;

                // scores[sk] = scale * dot(Q[row, :], K[sk, :]); track the row maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_base = (bh * seqlen_k + sk) * hdim_q;
                    float dot            = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_base + d] * K[k_base + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Max-subtracted exponentials and their sum.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                // Normalise to softmax probabilities.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // O[row, :] = probabilities * V.
                const int64_t o_base = row * hdim_v;
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[(bh * seqlen_k + sk) * hdim_v + dv];
                    }
                    O[o_base + dv] = acc;
                }
            }
        }
    }
}
// Bundles one named registry with the dispatcher built on top of it so main() can loop
// over all registries uniformly. Non-owning: main() fills both pointers with addresses
// of its own local FmhaRegistry / FmhaDispatcher objects.
struct RegistryInfo
{
std::string name; // label printed in the introspection and planning output
FmhaRegistry* reg; // registry queried via size(), get_all(), filter(), export_json()
FmhaDispatcher* disp; // dispatcher constructed from `reg`, used for plan()
};
} // namespace
int main(int argc, char* argv[])
{
ExampleArgs args("Example 23: Multi-Registry FMHA",
"Separate registries per framework recipient");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen", "64", "Sequence length");
args.add_option("--hdim", "128", "Head dimension");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen = args.get_int("--seqlen", 64);
const int hdim = args.get_int("--hdim", 128);
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
print_header("Example 23: Multi-Registry FMHA");
// Step 1: Create 3 separate registries
std::cout << "\nStep 1: Create Separate Registries\n";
std::cout << " Global kernel sets declared: " << FmhaKernelSetRegistry::instance().size()
<< "\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry pytorch_reg;
pytorch_reg.set_name("pytorch");
REGISTER_GENERATED_KERNELS(pytorch_reg, gfx_arch);
FmhaRegistry flash_reg;
flash_reg.set_name("flash");
REGISTER_GENERATED_KERNELS(flash_reg, gfx_arch);
FmhaRegistry aiter_reg;
aiter_reg.set_name("aiter");
REGISTER_GENERATED_KERNELS(aiter_reg, gfx_arch);
FmhaDispatcher pytorch_disp(&pytorch_reg);
FmhaDispatcher flash_disp(&flash_reg);
FmhaDispatcher aiter_disp(&aiter_reg);
std::vector<RegistryInfo> registries = {
{"pytorch", &pytorch_reg, &pytorch_disp},
{"flash", &flash_reg, &flash_disp},
{"aiter", &aiter_reg, &aiter_disp},
};
// Step 2: Registry introspection
std::cout << "\nStep 2: Registry Introspection\n";
for(const auto& ri : registries)
{
std::cout << "\n Registry: " << ri.name << "\n";
std::cout << " Kernel count: " << ri.reg->size() << "\n";
auto all_kernels = ri.reg->get_all();
for(const auto& k : all_kernels)
{
std::cout << " Kernel: " << k->get_name() << "\n";
}
auto fwd_kernels = ri.reg->filter([](const FmhaKernelInstance& inst) {
return inst.get_key().signature.family == FmhaKernelFamily::Fwd;
});
std::cout << " Forward kernels: " << fwd_kernels.size() << "\n";
std::string json = ri.reg->export_json(false);
std::cout << " JSON size: " << json.size() << " bytes\n";
}
// Step 3: Plan the same problem from each registry
std::cout << "\nStep 3: Plan from Each Registry\n";
// Problem A: basic fp16 no-bias (matches aiter_batch)
{
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fmha_args{};
fmha_args.batch = batch;
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_k = seqlen;
fmha_args.max_seqlen_q = seqlen;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
std::cout << "\n Problem: fp16 batch no-bias\n";
for(const auto& ri : registries)
{
auto plan = ri.disp->plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
std::cout << " " << ri.name << ": "
<< (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n";
}
}
// Problem B: fp16 with alibi bias (matches flash)
{
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::alibi;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fmha_args{};
fmha_args.batch = batch;
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_k = seqlen;
fmha_args.max_seqlen_q = seqlen;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
std::cout << "\n Problem: fp16 batch alibi-bias\n";
for(const auto& ri : registries)
{
auto plan = ri.disp->plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
std::cout << " " << ri.name << ": "
<< (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n";
}
}
// Problem C: fp16 with elementwise bias (matches pytorch)
{
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::elementwise_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fmha_args{};
fmha_args.batch = batch;
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_k = seqlen;
fmha_args.max_seqlen_q = seqlen;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
std::cout << "\n Problem: fp16 batch elementwise-bias\n";
for(const auto& ri : registries)
{
auto plan = ri.disp->plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
std::cout << " " << ri.name << ": "
<< (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n";
}
}
// Step 4: GPU execution from AITER registry (basic no-bias kernel)
std::cout << "\nStep 4: GPU Execution (aiter registry)\n";
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
GpuBuffer<FmhaDataType> q_dev(q_elems);
GpuBuffer<FmhaDataType> k_dev(q_elems);
GpuBuffer<FmhaDataType> v_dev(q_elems);
GpuBuffer<FmhaDataType> o_dev(q_elems);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FmhaDataType> q_host(q_elems), k_host(q_elems), v_host(q_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
o_dev.zero();
fmha_fwd_traits run_traits{};
run_traits.hdim_q = hdim;
run_traits.hdim_v = hdim;
run_traits.data_type = "fp16";
run_traits.is_group_mode = false;
run_traits.is_v_rowmajor = true;
run_traits.has_logits_soft_cap = false;
run_traits.mask_type = mask_enum::no_mask;
run_traits.bias_type = bias_enum::no_bias;
run_traits.has_lse = false;
run_traits.has_dropout = false;
run_traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args run_args{};
run_args.q_ptr = q_dev.get();
run_args.k_ptr = k_dev.get();
run_args.v_ptr = v_dev.get();
run_args.o_ptr = o_dev.get();
run_args.bias_ptr = nullptr;
run_args.q_descale_ptr = nullptr;
run_args.k_descale_ptr = nullptr;
run_args.v_descale_ptr = nullptr;
run_args.rand_val_ptr = nullptr;
run_args.lse_ptr = nullptr;
run_args.sink_ptr = nullptr;
run_args.block_scale_seqstart_q_ptr = nullptr;
run_args.block_scale_seqstart_k_ptr = nullptr;
run_args.seqlen_q = seqlen;
run_args.seqlen_k = seqlen;
run_args.batch = batch;
run_args.max_seqlen_q = seqlen;
run_args.hdim_q = hdim;
run_args.hdim_v = hdim;
run_args.nhead_q = nhead;
run_args.nhead_k = nhead;
run_args.scale_s = scale;
run_args.logits_soft_cap = 0.0f;
run_args.stride_q = hdim;
run_args.stride_k = hdim;
run_args.stride_v = hdim;
run_args.stride_bias = 0;
run_args.stride_randval = 0;
run_args.stride_o = hdim;
run_args.nhead_stride_q = seqlen * hdim;
run_args.nhead_stride_k = seqlen * hdim;
run_args.nhead_stride_v = seqlen * hdim;
run_args.nhead_stride_bias = 0;
run_args.nhead_stride_randval = 0;
run_args.nhead_stride_lse = 0;
run_args.nhead_stride_o = seqlen * hdim;
run_args.nhead_stride_q_descale = 0;
run_args.nhead_stride_k_descale = 0;
run_args.nhead_stride_v_descale = 0;
run_args.batch_stride_q = nhead * seqlen * hdim;
run_args.batch_stride_k = nhead * seqlen * hdim;
run_args.batch_stride_v = nhead * seqlen * hdim;
run_args.batch_stride_bias = 0;
run_args.batch_stride_randval = 0;
run_args.batch_stride_lse = 0;
run_args.batch_stride_o = nhead * seqlen * hdim;
run_args.batch_stride_q_descale = 0;
run_args.batch_stride_k_descale = 0;
run_args.batch_stride_v_descale = 0;
run_args.window_size_left = -1;
run_args.window_size_right = -1;
run_args.sink_size = 0;
run_args.mask_type = 0;
run_args.min_seqlen_q = 0;
run_args.p_drop = 0.0f;
run_args.s_randval = false;
run_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
run_args.block_scale_size_q = 0;
run_args.block_scale_size_kv = 0;
bool passed = false;
aiter_disp.set_benchmarking(true);
aiter_disp.set_timing(1, 3);
try
{
float time_ms = aiter_disp.run_fwd(run_traits, run_args, nullptr);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::vector<FmhaDataType> o_host(q_elems);
o_dev.copy_to_host(o_host.data());
// Validate
std::vector<float> q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(q_elems, 0.0f);
for(int64_t i = 0; i < q_elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < q_elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < q_elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
cpu_attention_fwd(
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
double max_abs_err = 0.0;
int errors = 0;
const double rtol = 1e-2;
const double atol = 1e-2;
for(int64_t i = 0; i < q_elems; ++i)
{
float gpu_val = static_cast<float>(o_host[i]);
float ref_val = o_ref[i];
double abs_err = std::abs(gpu_val - ref_val);
max_abs_err = std::max(max_abs_err, abs_err);
if(abs_err > atol + rtol * std::abs(ref_val))
++errors;
}
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
std::cout << " Errors: " << errors << " / " << q_elems << "\n";
passed = (errors == 0);
}
catch(const std::exception& e)
{
std::cerr << " ERROR: " << e.what() << "\n";
}
print_separator();
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
print_separator();
return passed ? 0 : 1;
}

View File

@@ -0,0 +1,549 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 24: Per-Receipt Registries
//
// Demonstrates:
// 1. Four DECL_FMHA_KERNEL_SET declarations, each named after a receipt
// 2. Each registered into a separate FmhaRegistry
// 3. Per-registry: kernel count, kernel names, plan a problem, selected kernel
// 4. GPU execution from the ck_default receipt (the basic working kernel)
// 5. Comparison table showing which features each receipt supports
//
// Receipt = a curated kernel set shipped to a specific downstream consumer.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Receipt 1: CK default -- basic fp16, no mask, no bias
DECL_FMHA_KERNEL_SET(ck_default_kernels,
// Signature: forward, fp16, batch mode, hdim 128, no mask, no bias, no LSE, no dropout,
// no quant scale. Unlike the other receipts in this example, no profile() tag is set.
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
// Receipt 2: Flash forward -- fp16 with alibi bias
// Declares the kernel set `flash_fwd_kernels` containing a single forward-kernel
// entry built for "gfx950". The signature differs from ck_default_kernels only in
// .bias("alibi") and the added .profile("flash_fwd"); the FmhaAlgorithm settings
// are the same values.
DECL_FMHA_KERNEL_SET(flash_fwd_kernels,
// Signature: fwd family, fp16, batch mode, head dim 128, no mask, alibi bias,
// LSE and dropout disabled, qscale "no", profile "flash_fwd".
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("alibi")
.lse(false)
.dropout(false)
.qscale("no")
.profile("flash_fwd"),
// Algorithm: 128x128x32 stage-0 tile, 128x32 stage-1 tile, k0max 128,
// "qr_async" pipeline, all four padding flags on.
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
// Target architecture string for this entry.
"gfx950"));
// Receipt 3: PyTorch -- fp16 with elementwise bias
// Declares the kernel set `pytorch_kernels` containing a single forward-kernel
// entry built for "gfx950". The signature uses .bias("bias") and
// .profile("pytorch"); main() plans this receipt with bias_enum::elementwise_bias.
DECL_FMHA_KERNEL_SET(pytorch_kernels,
// Signature: fwd family, fp16, batch mode, head dim 128, no mask, bias "bias",
// LSE and dropout disabled, qscale "no", profile "pytorch".
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("bias")
.lse(false)
.dropout(false)
.qscale("no")
.profile("pytorch"),
// Algorithm: 128x128x32 stage-0 tile, 128x32 stage-1 tile, k0max 128,
// "qr_async" pipeline, all four padding flags on.
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
// Target architecture string for this entry.
"gfx950"));
// Receipt 4: AITER batch -- fp16 batch mode with LSE
// Declares the kernel set `aiter_batch_kernels` containing a single forward-kernel
// entry built for "gfx950". This is the only receipt in this file whose signature
// sets .lse(true); it also sets .profile("aiter_batch").
DECL_FMHA_KERNEL_SET(aiter_batch_kernels,
// Signature: fwd family, fp16, batch mode, head dim 128, no mask, no bias,
// LSE enabled, dropout disabled, qscale "no", profile "aiter_batch".
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no")
.profile("aiter_batch"),
// Algorithm: 128x128x32 stage-0 tile, 128x32 stage-1 tile, k0max 128,
// "qr_async" pipeline, all four padding flags on.
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
// Target architecture string for this entry.
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// One receipt as handled by main(): descriptive attributes plus the FmhaRegistry
// that holds its kernels. main() brace-initializes this struct positionally, so
// the field order below must not change.
struct ReceiptInfo
{
std::string name; // receipt name; main() also passes the same string to registry.set_name()
std::string bias_desc; // bias description ("none", "alibi" or "elementwise" in main())
bool has_lse; // main() sets this to true only for the "aiter_batch" receipt
FmhaRegistry registry; // filled by REGISTER_GENERATED_KERNELS in main()
};
// CPU reference for the attention forward pass (no mask, no bias).
//
// For every (batch, head, query row) it computes softmax(scale * Q.K^T) across the
// seqlen_k keys and stores the probability-weighted sum of the V rows in O.
//
// Buffer layout, as fixed by the index arithmetic below (dense, row-major):
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]  (must already be sized by the caller)
//
// All flat offsets are computed in int64_t: the element count of a tensor can
// exceed INT_MAX even when every individual dimension fits in an int, and the
// callers in this file already size their buffers with int64_t.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Combined (batch, head) index, widened once so every product below is 64-bit.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bh * seqlen_q + sq) * hdim_v;

                // Scaled logits for this query row; track the max for a stable softmax.
                std::vector<float> scores(seqlen_k, 0.0f);
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (bh * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax: exponentiate relative to the max, then normalize.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Output row = probabilities x V.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (bh * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
} // namespace
// Example 24 entry point.
//
// Builds four registries (one per receipt), prints their kernels, plans one
// matching problem per receipt, prints a feature comparison table, and finally
// runs the ck_default forward kernel on the GPU and validates it against
// cpu_attention_fwd.
//
// Returns 0 when validation passes, 1 otherwise (including when run_fwd throws).
// Returns 0 immediately when args.parse() returns false.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 24: Per-Receipt Registries",
                     "Curated kernel sets per downstream consumer");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "2", "Batch size");
    args.add_option("--nhead", "4", "Number of heads");
    args.add_option("--seqlen", "64", "Sequence length");
    args.add_option("--hdim", "128", "Head dimension");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 2);
    const int nhead            = args.get_int("--nhead", 4);
    const int seqlen           = args.get_int("--seqlen", 64);
    const int hdim             = args.get_int("--hdim", 128);
    const float scale          = 1.0f / std::sqrt(static_cast<float>(hdim));

    print_header("Example 24: Per-Receipt Registries");

    // Step 1: Create per-receipt registries
    std::cout << "\nStep 1: Create Per-Receipt Registries\n";
    std::cout << " Global kernel sets: " << FmhaKernelSetRegistry::instance().size() << "\n";
    std::vector<ReceiptInfo> receipts;
    receipts.push_back({"ck_default", "none", false, FmhaRegistry()});
    receipts.back().registry.set_name("ck_default");
    REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch);
    receipts.push_back({"flash_fwd", "alibi", false, FmhaRegistry()});
    receipts.back().registry.set_name("flash_fwd");
    REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch);
    receipts.push_back({"pytorch", "elementwise", false, FmhaRegistry()});
    receipts.back().registry.set_name("pytorch");
    REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch);
    receipts.push_back({"aiter_batch", "none", true, FmhaRegistry()});
    receipts.back().registry.set_name("aiter_batch");
    REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch);

    // Step 2: Per-registry introspection
    std::cout << "\nStep 2: Per-Receipt Introspection\n";
    for(auto& r : receipts)
    {
        std::cout << "\n Receipt: " << r.name << "\n";
        std::cout << " Kernel count: " << r.registry.size() << "\n";
        auto all = r.registry.get_all();
        for(const auto& k : all)
        {
            std::cout << " Kernel: " << k->get_name() << "\n";
        }
    }

    // Step 3: Plan a matching problem for each receipt
    std::cout << "\nStep 3: Plan per Receipt\n";
    struct PlanTest
    {
        std::string receipt_name;
        bias_enum bias;
        bool lse;
    };
    // Entry i of plan_tests is planned against receipts[i], so the order here
    // must match the push_back order in Step 1.
    std::vector<PlanTest> plan_tests = {
        {"ck_default", bias_enum::no_bias, false},
        {"flash_fwd", bias_enum::alibi, false},
        {"pytorch", bias_enum::elementwise_bias, false},
        {"aiter_batch", bias_enum::no_bias, true},
    };
    for(std::size_t i = 0; i < plan_tests.size(); ++i)
    {
        const auto& pt = plan_tests[i];
        fmha_fwd_traits traits{};
        traits.hdim_q              = hdim;
        traits.hdim_v              = hdim;
        traits.data_type           = "fp16";
        traits.is_group_mode       = false;
        traits.is_v_rowmajor       = true;
        traits.has_logits_soft_cap = false;
        traits.mask_type           = mask_enum::no_mask;
        traits.bias_type           = pt.bias;
        traits.has_lse             = pt.lse;
        traits.has_dropout         = false;
        traits.qscale_type         = quant_scale_enum::no_scale;
        fmha_fwd_args fmha_args{};
        fmha_args.batch        = batch;
        fmha_args.seqlen_q     = seqlen;
        fmha_args.seqlen_k     = seqlen;
        fmha_args.max_seqlen_q = seqlen;
        fmha_args.hdim_q       = hdim;
        fmha_args.hdim_v       = hdim;
        fmha_args.nhead_q      = nhead;
        fmha_args.nhead_k      = nhead;
        FmhaDispatcher disp(&receipts[i].registry);
        auto plan = disp.plan(
            FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
        std::cout << " " << pt.receipt_name << ": "
                  << (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n";
    }

    // Step 4: Comparison table
    std::cout << "\nStep 4: Receipt Feature Comparison\n\n";
    std::cout << " " << std::setw(14) << "Receipt" << " | " << std::setw(14) << "Bias" << " | "
              << std::setw(5) << "LSE" << " | " << std::setw(8) << "Kernels" << "\n";
    std::cout << " " << std::string(50, '-') << "\n";
    struct CompRow
    {
        std::string name;
        std::string bias;
        std::string lse;
        std::size_t count;
    };
    std::vector<CompRow> comp = {
        {"ck_default", "none", "no", receipts[0].registry.size()},
        {"flash_fwd", "alibi", "no", receipts[1].registry.size()},
        {"pytorch", "elementwise", "no", receipts[2].registry.size()},
        {"aiter_batch", "none", "yes", receipts[3].registry.size()},
    };
    for(const auto& c : comp)
    {
        std::cout << " " << std::setw(14) << c.name << " | " << std::setw(14) << c.bias << " | "
                  << std::setw(5) << c.lse << " | " << std::setw(8) << c.count << "\n";
    }

    // Step 5: GPU execution from ck_default
    std::cout << "\nStep 5: GPU Execution (ck_default receipt)\n";
    // Q, K, V and O all hold batch * nhead * seqlen * hdim elements here because
    // seqlen_q == seqlen_k and hdim_q == hdim_v.
    const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
    GpuBuffer<FmhaDataType> q_dev(q_elems);
    GpuBuffer<FmhaDataType> k_dev(q_elems);
    GpuBuffer<FmhaDataType> v_dev(q_elems);
    GpuBuffer<FmhaDataType> o_dev(q_elems);
    std::mt19937 rng(42); // fixed seed: inputs are reproducible across runs
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
    std::vector<FmhaDataType> q_host(q_elems), k_host(q_elems), v_host(q_elems);
    for(auto& x : q_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : k_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : v_host)
        x = FmhaDataType(dist(rng));
    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());
    o_dev.zero();

    fmha_fwd_traits run_traits{};
    run_traits.hdim_q              = hdim;
    run_traits.hdim_v              = hdim;
    run_traits.data_type           = "fp16";
    run_traits.is_group_mode       = false;
    run_traits.is_v_rowmajor       = true;
    run_traits.has_logits_soft_cap = false;
    run_traits.mask_type           = mask_enum::no_mask;
    run_traits.bias_type           = bias_enum::no_bias;
    run_traits.has_lse             = false;
    run_traits.has_dropout         = false;
    run_traits.qscale_type         = quant_scale_enum::no_scale;

    fmha_fwd_args run_args{};
    run_args.q_ptr                      = q_dev.get();
    run_args.k_ptr                      = k_dev.get();
    run_args.v_ptr                      = v_dev.get();
    run_args.o_ptr                      = o_dev.get();
    run_args.bias_ptr                   = nullptr;
    run_args.q_descale_ptr              = nullptr;
    run_args.k_descale_ptr              = nullptr;
    run_args.v_descale_ptr              = nullptr;
    run_args.rand_val_ptr               = nullptr;
    run_args.lse_ptr                    = nullptr;
    run_args.sink_ptr                   = nullptr;
    run_args.block_scale_seqstart_q_ptr = nullptr;
    run_args.block_scale_seqstart_k_ptr = nullptr;
    run_args.seqlen_q                   = seqlen;
    run_args.seqlen_k                   = seqlen;
    run_args.batch                      = batch;
    run_args.max_seqlen_q               = seqlen;
    run_args.hdim_q                     = hdim;
    run_args.hdim_v                     = hdim;
    run_args.nhead_q                    = nhead;
    run_args.nhead_k                    = nhead;
    run_args.scale_s                    = scale;
    run_args.logits_soft_cap            = 0.0f;
    // Dense [batch, nhead, seqlen, hdim] strides, matching cpu_attention_fwd's indexing.
    run_args.stride_q                   = hdim;
    run_args.stride_k                   = hdim;
    run_args.stride_v                   = hdim;
    run_args.stride_bias                = 0;
    run_args.stride_randval             = 0;
    run_args.stride_o                   = hdim;
    run_args.nhead_stride_q             = seqlen * hdim;
    run_args.nhead_stride_k             = seqlen * hdim;
    run_args.nhead_stride_v             = seqlen * hdim;
    run_args.nhead_stride_bias          = 0;
    run_args.nhead_stride_randval       = 0;
    run_args.nhead_stride_lse           = 0;
    run_args.nhead_stride_o             = seqlen * hdim;
    run_args.nhead_stride_q_descale     = 0;
    run_args.nhead_stride_k_descale     = 0;
    run_args.nhead_stride_v_descale     = 0;
    run_args.batch_stride_q             = nhead * seqlen * hdim;
    run_args.batch_stride_k             = nhead * seqlen * hdim;
    run_args.batch_stride_v             = nhead * seqlen * hdim;
    run_args.batch_stride_bias          = 0;
    run_args.batch_stride_randval       = 0;
    run_args.batch_stride_lse           = 0;
    run_args.batch_stride_o             = nhead * seqlen * hdim;
    run_args.batch_stride_q_descale     = 0;
    run_args.batch_stride_k_descale     = 0;
    run_args.batch_stride_v_descale     = 0;
    run_args.window_size_left           = -1;
    run_args.window_size_right          = -1;
    run_args.sink_size                  = 0;
    run_args.mask_type                  = 0;
    run_args.min_seqlen_q               = 0;
    run_args.p_drop                     = 0.0f;
    run_args.s_randval                  = false;
    run_args.drop_seed_offset           = std::make_pair(uint64_t(0), uint64_t(0));
    run_args.block_scale_size_q         = 0;
    run_args.block_scale_size_kv        = 0;

    FmhaDispatcher ck_disp(&receipts[0].registry);
    ck_disp.set_benchmarking(true);
    ck_disp.set_timing(1, 3);

    // Stays false unless the try block runs to completion with zero mismatches.
    bool passed = false;
    try
    {
        float time_ms = ck_disp.run_fwd(run_traits, run_args, nullptr);
        std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
        std::vector<FmhaDataType> o_host(q_elems);
        o_dev.copy_to_host(o_host.data());

        // Widen the half-precision host inputs to float for the CPU reference.
        std::vector<float> q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(q_elems, 0.0f);
        for(int64_t i = 0; i < q_elems; ++i)
            q_f32[i] = static_cast<float>(q_host[i]);
        for(int64_t i = 0; i < q_elems; ++i)
            k_f32[i] = static_cast<float>(k_host[i]);
        for(int64_t i = 0; i < q_elems; ++i)
            v_f32[i] = static_cast<float>(v_host[i]);
        cpu_attention_fwd(
            q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);

        double max_abs_err = 0.0;
        int errors         = 0;
        const double rtol  = 1e-2;
        const double atol  = 1e-2;
        for(int64_t i = 0; i < q_elems; ++i)
        {
            float gpu_val  = static_cast<float>(o_host[i]);
            float ref_val  = o_ref[i];
            double abs_err = std::abs(gpu_val - ref_val);
            max_abs_err    = std::max(max_abs_err, abs_err);
            // Written as !(err <= tol) rather than (err > tol): every ordered
            // comparison with NaN is false, so a NaN output would otherwise be
            // silently counted as correct.
            if(!(abs_err <= atol + rtol * std::abs(ref_val)))
                ++errors;
        }
        std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
        std::cout << " Errors: " << errors << " / " << q_elems << "\n";
        passed = (errors == 0);
    }
    catch(const std::exception& e)
    {
        std::cerr << " ERROR: " << e.what() << "\n";
    }

    print_separator();
    std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
    print_separator();
    return passed ? 0 : 1;
}

View File

@@ -0,0 +1,530 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 25: AppendKV + BatchPrefill Planning with GPU Execution
//
// Demonstrates:
// 1. Declare appendkv, batch_prefill, and basic fwd kernels
// 2. Plan appendkv with fmha_fwd_appendkv_traits / fmha_fwd_appendkv_args
// 3. Plan batch_prefill with fmha_batch_prefill_traits / fmha_batch_prefill_args
// 4. Run basic fwd kernel on GPU as sanity check
// 5. Show cache_batch_idx usage pattern for non-contiguous batches
//
// Mirrors 01_basic_fmha.cpp for FMHA.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Declares the kernel set `appendkv_batchprefill_kernels` with three entries, all
// built for "gfx950": one fwd_appendkv kernel, one batch_prefill kernel, and one
// plain fwd kernel that main() executes on the GPU.
DECL_FMHA_KERNEL_SET(appendkv_batchprefill_kernels,
// AppendKV kernel
// Signature: fp16, batch mode, head dim 128, rope "inter", paged KV enabled,
// kv_cache("vectorized", "sglang", 16). main() plans it with page_block_size = 16.
.add(FmhaSignature()
.family("fwd_appendkv")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.rope("inter")
.paged_kv(true)
.kv_cache("vectorized", "sglang", 16),
// Algorithm: only tile sizes, pipeline, padding and rank are set for this entry
// (no wave_*/warp_*/alignments calls).
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(64)
.tile_n0(64)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.pipeline("appendkv")
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// BatchPrefill kernel (group mode, paged KV, page_size=64)
// Signature: fp16, head dim 128, no mask, no bias, LSE enabled, dropout disabled,
// kv_cache("vectorized", "sglang", 64).
.add(FmhaSignature()
.family("batch_prefill")
.dtype("fp16")
.mode("group")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no")
.paged_kv(true)
.kv_cache("vectorized", "sglang", 64),
// Algorithm: same tile/wave/warp values as the basic fwd entry below, but
// without an .alignments(...) call.
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Basic fwd kernel for GPU execution sanity check
// Signature: fp16, batch mode, head dim 128, no mask, no bias, LSE and dropout
// disabled -- the same values main() puts into fwd_traits before run_fwd().
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for the attention forward pass (no mask, no bias).
//
// For every (batch, head, query row) it computes softmax(scale * Q.K^T) across the
// seqlen_k keys and stores the probability-weighted sum of the V rows in O.
//
// Buffer layout, as fixed by the index arithmetic below (dense, row-major):
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]  (must already be sized by the caller)
//
// All flat offsets are computed in int64_t: the element count of a tensor can
// exceed INT_MAX even when every individual dimension fits in an int, and the
// callers in this file already size their buffers with int64_t.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Combined (batch, head) index, widened once so every product below is 64-bit.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bh * seqlen_q + sq) * hdim_v;

                // Scaled logits for this query row; track the max for a stable softmax.
                std::vector<float> scores(seqlen_k, 0.0f);
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (bh * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax: exponentiate relative to the max, then normalize.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Output row = probabilities x V.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (bh * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
} // namespace
// Example 25 entry point.
//
// Registers the generated kernels, plans an AppendKV problem and a BatchPrefill
// problem (planning only -- their pointer arguments are placeholder values), then
// runs the basic forward kernel on the GPU. Without --validate the run passes if
// any output element is non-zero; with --validate the output is compared against
// cpu_attention_fwd.
//
// Returns 0 on PASS, 1 on FAIL or when run_fwd throws. Returns 0 immediately when
// args.parse() returns false.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 25: AppendKV + BatchPrefill + GPU",
                     "FMHA AppendKV/BatchPrefill planning with GPU sanity check");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "2", "Batch size");
    args.add_option("--nhead", "4", "Number of heads");
    args.add_option("--seqlen", "64", "Sequence length (Q and K)");
    args.add_option("--hdim", "128", "Head dimension");
    args.add_flag("--validate", "Validate against CPU reference");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 2);
    const int nhead            = args.get_int("--nhead", 4);
    const int seqlen           = args.get_int("--seqlen", 64);
    const int hdim             = args.get_int("--hdim", 128);

    print_header("Example 25: AppendKV + BatchPrefill + GPU Execution");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("appendkv_batchprefill");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";
    FmhaDispatcher dispatcher(&registry);
    dispatcher.set_benchmarking(true);
    dispatcher.set_timing(1, 3);

    // =========================================================================
    // Step 2: Plan AppendKV
    // traits: fmha_fwd_appendkv_traits (hdim_q, hdim_v, data_type,
    // is_v_rowmajor, rope_type)
    // args: fmha_fwd_appendkv_args (q_ptr, k_ptr, knew_ptr, v_ptr,
    // vnew_ptr, seqlen_q, seqlen_knew, ...)
    // =========================================================================
    std::cout << "\nStep 2: Plan AppendKV\n";
    fmha_fwd_appendkv_traits append_traits{};
    append_traits.hdim_q        = hdim;
    append_traits.hdim_v        = hdim;
    append_traits.data_type     = "fp16";
    append_traits.is_v_rowmajor = true;
    append_traits.rope_type     = rope_enum::interleaved;
    fmha_fwd_appendkv_args append_args{};
    // The 0x1 pointers below are non-null placeholders: this plan is only
    // inspected, it is never executed in this example.
    append_args.q_ptr           = reinterpret_cast<void*>(0x1);
    append_args.k_ptr           = reinterpret_cast<void*>(0x1);
    append_args.knew_ptr        = reinterpret_cast<void*>(0x1);
    append_args.v_ptr           = reinterpret_cast<void*>(0x1);
    append_args.vnew_ptr        = reinterpret_cast<void*>(0x1);
    append_args.seqlen_q        = 1;
    append_args.seqlen_knew     = 1;
    append_args.batch           = batch;
    append_args.hdim_q          = hdim;
    append_args.hdim_v          = hdim;
    append_args.nhead_q         = nhead;
    append_args.nhead_k         = nhead;
    append_args.rotary_dim      = hdim;
    append_args.rotary_cos_ptr  = reinterpret_cast<void*>(0x1);
    append_args.rotary_sin_ptr  = reinterpret_cast<void*>(0x1);
    append_args.block_table_ptr = reinterpret_cast<void*>(0x1);
    append_args.page_block_size = 16;
    // cache_batch_idx: maps request index -> cache slot for non-contiguous batches.
    // When serving multiple requests that don't occupy contiguous cache slots,
    // this indirection array tells the kernel which cache row each request maps to.
    append_args.cache_batch_idx_ptr = reinterpret_cast<void*>(0x1);
    auto append_plan                = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(append_traits, append_args), gfx_arch));
    std::cout << " AppendKV plan valid: " << (append_plan.is_valid() ? "yes" : "no") << "\n";
    if(append_plan.is_valid())
    {
        for(const auto& stage : append_plan.stages)
            std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
    }

    // =========================================================================
    // Step 3: Plan BatchPrefill
    // traits: fmha_batch_prefill_traits (extends fmha_fwd_traits with
    // kv_memory_layout, kv_lookup_table, page_size)
    // args: fmha_batch_prefill_args (kv_indptr, kv_page_indices,
    // kv_last_page_lens, seqstart_q_ptr, ...)
    // =========================================================================
    std::cout << "\nStep 3: Plan BatchPrefill\n";
    fmha_batch_prefill_traits prefill_traits{};
    prefill_traits.hdim_q        = hdim;
    prefill_traits.hdim_v        = hdim;
    prefill_traits.data_type     = "fp16";
    prefill_traits.is_group_mode = true;
    prefill_traits.is_v_rowmajor = true;
    prefill_traits.mask_type     = mask_enum::no_mask;
    prefill_traits.bias_type     = bias_enum::no_bias;
    prefill_traits.has_lse       = true;
    prefill_traits.kv_memory_layout =
        ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
    prefill_traits.kv_lookup_table =
        ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
    prefill_traits.page_size = 64;
    fmha_batch_prefill_args prefill_args{};
    prefill_args.batch           = batch;
    prefill_args.seqlen_q        = seqlen;
    prefill_args.seqlen_k        = 1024;
    prefill_args.max_seqlen_q    = seqlen;
    prefill_args.hdim_q          = hdim;
    prefill_args.hdim_v          = hdim;
    prefill_args.nhead_q         = nhead;
    prefill_args.nhead_k         = nhead;
    prefill_args.num_total_pages = 128;
    prefill_args.page_block_size = 64;
    prefill_args.kv_memory_layout =
        ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
    prefill_args.kv_lookup_table =
        ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
    // Placeholder non-null pointers, as in Step 2: planning only.
    prefill_args.kv_indptr         = reinterpret_cast<void*>(0x1);
    prefill_args.kv_page_indices   = reinterpret_cast<void*>(0x1);
    prefill_args.kv_last_page_lens = reinterpret_cast<void*>(0x1);
    prefill_args.seqstart_q_ptr    = reinterpret_cast<void*>(0x1);
    auto prefill_plan              = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch));
    std::cout << " BatchPrefill plan valid: " << (prefill_plan.is_valid() ? "yes" : "no") << "\n";
    if(prefill_plan.is_valid())
    {
        for(const auto& stage : prefill_plan.stages)
            std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
    }

    // =========================================================================
    // Step 4: GPU Execution with basic fwd kernel (sanity check)
    // =========================================================================
    std::cout << "\nStep 4: Allocate GPU Buffers\n";
    const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
    fmha_fwd_traits fwd_traits{};
    fwd_traits.hdim_q              = hdim;
    fwd_traits.hdim_v              = hdim;
    fwd_traits.data_type           = "fp16";
    fwd_traits.is_group_mode       = false;
    fwd_traits.is_v_rowmajor       = true;
    fwd_traits.has_logits_soft_cap = false;
    fwd_traits.mask_type           = mask_enum::no_mask;
    fwd_traits.bias_type           = bias_enum::no_bias;
    fwd_traits.has_lse             = false;
    fwd_traits.has_dropout         = false;
    fwd_traits.qscale_type         = quant_scale_enum::no_scale;
    // K, V and O have the same element count as Q because seqlen_q == seqlen_k
    // and hdim_q == hdim_v in this example.
    const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
    const int64_t k_elems = q_elems;
    const int64_t v_elems = q_elems;
    const int64_t o_elems = q_elems;
    std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
              << "]\n";
    GpuBuffer<FmhaDataType> q_dev(q_elems);
    GpuBuffer<FmhaDataType> k_dev(k_elems);
    GpuBuffer<FmhaDataType> v_dev(v_elems);
    GpuBuffer<FmhaDataType> o_dev(o_elems);
    std::mt19937 rng(42); // fixed seed: inputs are reproducible across runs
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
    std::vector<FmhaDataType> q_host(q_elems);
    std::vector<FmhaDataType> k_host(k_elems);
    std::vector<FmhaDataType> v_host(v_elems);
    for(auto& x : q_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : k_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : v_host)
        x = FmhaDataType(dist(rng));
    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());
    o_dev.zero();

    fmha_fwd_args fwd_args{};
    fwd_args.q_ptr                      = q_dev.get();
    fwd_args.k_ptr                      = k_dev.get();
    fwd_args.v_ptr                      = v_dev.get();
    fwd_args.o_ptr                      = o_dev.get();
    fwd_args.bias_ptr                   = nullptr;
    fwd_args.q_descale_ptr              = nullptr;
    fwd_args.k_descale_ptr              = nullptr;
    fwd_args.v_descale_ptr              = nullptr;
    fwd_args.rand_val_ptr               = nullptr;
    fwd_args.lse_ptr                    = nullptr;
    fwd_args.sink_ptr                   = nullptr;
    fwd_args.block_scale_seqstart_q_ptr = nullptr;
    fwd_args.block_scale_seqstart_k_ptr = nullptr;
    fwd_args.seqlen_q                   = seqlen;
    fwd_args.seqlen_k                   = seqlen;
    fwd_args.batch                      = batch;
    fwd_args.max_seqlen_q               = seqlen;
    fwd_args.hdim_q                     = hdim;
    fwd_args.hdim_v                     = hdim;
    fwd_args.nhead_q                    = nhead;
    fwd_args.nhead_k                    = nhead;
    fwd_args.scale_s                    = scale;
    fwd_args.logits_soft_cap            = 0.0f;
    // Dense [batch, nhead, seqlen, hdim] strides, matching cpu_attention_fwd's indexing.
    fwd_args.stride_q                   = hdim;
    fwd_args.stride_k                   = hdim;
    fwd_args.stride_v                   = hdim;
    fwd_args.stride_bias                = 0;
    fwd_args.stride_randval             = 0;
    fwd_args.stride_o                   = hdim;
    fwd_args.nhead_stride_q             = seqlen * hdim;
    fwd_args.nhead_stride_k             = seqlen * hdim;
    fwd_args.nhead_stride_v             = seqlen * hdim;
    fwd_args.nhead_stride_bias          = 0;
    fwd_args.nhead_stride_randval       = 0;
    fwd_args.nhead_stride_lse           = 0;
    fwd_args.nhead_stride_o             = seqlen * hdim;
    fwd_args.nhead_stride_q_descale     = 0;
    fwd_args.nhead_stride_k_descale     = 0;
    fwd_args.nhead_stride_v_descale     = 0;
    fwd_args.batch_stride_q             = nhead * seqlen * hdim;
    fwd_args.batch_stride_k             = nhead * seqlen * hdim;
    fwd_args.batch_stride_v             = nhead * seqlen * hdim;
    fwd_args.batch_stride_bias          = 0;
    fwd_args.batch_stride_randval       = 0;
    fwd_args.batch_stride_lse           = 0;
    fwd_args.batch_stride_o             = nhead * seqlen * hdim;
    fwd_args.batch_stride_q_descale     = 0;
    fwd_args.batch_stride_k_descale     = 0;
    fwd_args.batch_stride_v_descale     = 0;
    fwd_args.window_size_left           = -1;
    fwd_args.window_size_right          = -1;
    fwd_args.sink_size                  = 0;
    fwd_args.mask_type                  = 0;
    fwd_args.min_seqlen_q               = 0;
    fwd_args.p_drop                     = 0.0f;
    fwd_args.s_randval                  = false;
    fwd_args.drop_seed_offset           = std::make_pair(uint64_t(0), uint64_t(0));
    fwd_args.block_scale_size_q         = 0;
    fwd_args.block_scale_size_kv        = 0;

    // Step 5: Run on GPU
    std::cout << "\nStep 5: Run FMHA Forward on GPU\n";
    float time_ms = 0.0f;
    try
    {
        time_ms = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
    }
    catch(const std::exception& e)
    {
        std::cerr << " ERROR: " << e.what() << "\n";
        return 1;
    }
    auto problem =
        FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
    double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
    std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
    std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";

    // Step 6: Validate
    std::cout << "\nStep 6: Validate\n";
    std::vector<FmhaDataType> o_host(o_elems);
    o_dev.copy_to_host(o_host.data());
    // Cheap sanity check used when --validate is absent: o_dev was zeroed before
    // the launch, so any non-zero element shows the kernel wrote output.
    int nonzero = 0;
    for(int64_t i = 0; i < o_elems; ++i)
    {
        if(static_cast<float>(o_host[i]) != 0.0f)
            ++nonzero;
    }
    std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
    bool passed = (nonzero > 0);

    if(args.has("--validate"))
    {
        // Widen the half-precision host inputs to float for the CPU reference.
        std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
        for(int64_t i = 0; i < q_elems; ++i)
            q_f32[i] = static_cast<float>(q_host[i]);
        for(int64_t i = 0; i < k_elems; ++i)
            k_f32[i] = static_cast<float>(k_host[i]);
        for(int64_t i = 0; i < v_elems; ++i)
            v_f32[i] = static_cast<float>(v_host[i]);
        cpu_attention_fwd(
            q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);

        double max_abs_err = 0.0;
        double max_rel_err = 0.0;
        int errors         = 0;
        const double rtol  = 1e-2;
        const double atol  = 1e-2;
        for(int64_t i = 0; i < o_elems; ++i)
        {
            float gpu_val  = static_cast<float>(o_host[i]);
            float ref_val  = o_ref[i];
            double abs_err = std::abs(gpu_val - ref_val);
            double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
            max_abs_err    = std::max(max_abs_err, abs_err);
            max_rel_err    = std::max(max_rel_err, rel_err);
            // Written as !(err <= tol) rather than (err > tol): every ordered
            // comparison with NaN is false, so a NaN output would otherwise be
            // silently counted as correct.
            if(!(abs_err <= atol + rtol * std::abs(ref_val)))
                ++errors;
        }
        std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
        std::cout << " Max rel error: " << max_rel_err << "\n";
        std::cout << " Errors: " << errors << " / " << o_elems << "\n";
        passed = (errors == 0);
    }

    print_separator();
    std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
    print_separator();
    return passed ? 0 : 1;
}

View File

@@ -0,0 +1,526 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 26: Multiple Data Types and Head Dimensions with GPU Execution
//
// Demonstrates:
// 1. Declare bf16 hdim=128, fp16 hdim=64, and fp16 hdim=128 kernels
// 2. Run each variant on GPU with appropriate buffer types
// 3. Validate with different tolerances: fp16 (rtol=1e-3), bf16 (rtol=1e-2)
// 4. Mention fp32, fp8bf16, fp8fp32, hdim 256, asymmetric hdim as planning
//
// Mirrors 01_basic_fmha.cpp for FMHA.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Declares the kernel set `dtypes_hdims_kernels` with three forward-kernel entries,
// all built for "gfx950": bf16/hdim 128, fp16/hdim 64 and fp16/hdim 128. The three
// signatures differ only in dtype and hdim.
DECL_FMHA_KERNEL_SET(dtypes_hdims_kernels,
// bf16 hdim=128
.add(FmhaSignature()
.family("fwd")
.dtype("bf16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// fp16 hdim=64
// Compared with the hdim=128 entries this one uses tile_n0 = tile_n1 = 64,
// tile_k0max = 64, warp_m1 = warp_n1 = 16 and alignments(64, 64).
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(64)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(64)
.tile_k0(32)
.tile_n1(64)
.tile_k1(32)
.tile_k0max(64)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(16)
.warp_n1(16)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(64, 64)
.selection_rank(0),
"gfx950")
// fp16 hdim=128 (reference baseline)
// Same FmhaAlgorithm values as the bf16 hdim=128 entry above.
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {
using Fp16Type = ck_tile::fp16_t;
using Bf16Type = ck_tile::bf16_t;
// CPU reference for the attention forward pass (no mask, no bias).
//
// For every (batch, head, query row) it computes softmax(scale * Q.K^T) across the
// seqlen_k keys and stores the probability-weighted sum of the V rows in O.
//
// Buffer layout, as fixed by the index arithmetic below (dense, row-major):
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]  (must already be sized by the caller)
//
// All flat offsets are computed in int64_t: the element count of a tensor can
// exceed INT_MAX even when every individual dimension fits in an int, and the
// callers in this file already size their buffers with int64_t.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Combined (batch, head) index, widened once so every product below is 64-bit.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bh * seqlen_q + sq) * hdim_v;

                // Scaled logits for this query row; track the max for a stable softmax.
                std::vector<float> scores(seqlen_k, 0.0f);
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (bh * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax: exponentiate relative to the max, then normalize.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Output row = probabilities x V.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (bh * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
// Result record returned by run_variant() for one kernel variant.
// NOTE(review): field meanings below other than `label` are inferred from the
// names -- confirm against the part of run_variant() that fills them.
struct VariantResult
{
std::string label; // variant label; run_variant() copies its `label` argument here
float time_ms; // presumably the measured kernel time in milliseconds
double tflops; // presumably the throughput derived from time_ms
double max_abs_err; // presumably the largest absolute GPU-vs-reference error
double max_rel_err; // presumably the largest relative GPU-vs-reference error
int errors; // presumably the number of elements outside tolerance
bool passed; // presumably true when the variant validated successfully
};
// Assembles the fmha_fwd_args for a plain forward run (no bias, no dropout, no
// LSE, no descale, no block scaling) over Q/K/V/O device buffers.
//
// seqlen is used for both seqlen_q and seqlen_k, hdim for both hdim_q and hdim_v,
// and nhead for both nhead_q and nhead_k. The strides describe dense
// [batch, nhead, seqlen, hdim] tensors. Every optional pointer is set to nullptr
// and every stride belonging to an unused tensor is set to 0.
template <typename DataType>
fmha_fwd_args make_fwd_args(GpuBuffer<DataType>& q_dev,
                            GpuBuffer<DataType>& k_dev,
                            GpuBuffer<DataType>& v_dev,
                            GpuBuffer<DataType>& o_dev,
                            int batch,
                            int nhead,
                            int seqlen,
                            int hdim,
                            float scale)
{
    // Element strides of a dense [batch, nhead, seqlen, hdim] tensor.
    const int row_stride  = hdim;
    const int head_stride = seqlen * hdim;
    const int bat_stride  = nhead * seqlen * hdim;

    fmha_fwd_args fwd{};

    // Device pointers: the four tensors that are used ...
    fwd.q_ptr = q_dev.get();
    fwd.k_ptr = k_dev.get();
    fwd.v_ptr = v_dev.get();
    fwd.o_ptr = o_dev.get();
    // ... and all optional inputs/outputs, which are disabled.
    fwd.bias_ptr                   = nullptr;
    fwd.q_descale_ptr              = nullptr;
    fwd.k_descale_ptr              = nullptr;
    fwd.v_descale_ptr              = nullptr;
    fwd.rand_val_ptr               = nullptr;
    fwd.lse_ptr                    = nullptr;
    fwd.sink_ptr                   = nullptr;
    fwd.block_scale_seqstart_q_ptr = nullptr;
    fwd.block_scale_seqstart_k_ptr = nullptr;

    // Problem shape.
    fwd.batch        = batch;
    fwd.nhead_q      = nhead;
    fwd.nhead_k      = nhead;
    fwd.seqlen_q     = seqlen;
    fwd.seqlen_k     = seqlen;
    fwd.max_seqlen_q = seqlen;
    fwd.min_seqlen_q = 0;
    fwd.hdim_q       = hdim;
    fwd.hdim_v       = hdim;

    // Softmax scaling.
    fwd.scale_s         = scale;
    fwd.logits_soft_cap = 0.0f;

    // Row strides.
    fwd.stride_q       = row_stride;
    fwd.stride_k       = row_stride;
    fwd.stride_v       = row_stride;
    fwd.stride_o       = row_stride;
    fwd.stride_bias    = 0;
    fwd.stride_randval = 0;

    // Per-head strides.
    fwd.nhead_stride_q         = head_stride;
    fwd.nhead_stride_k         = head_stride;
    fwd.nhead_stride_v         = head_stride;
    fwd.nhead_stride_o         = head_stride;
    fwd.nhead_stride_bias      = 0;
    fwd.nhead_stride_randval   = 0;
    fwd.nhead_stride_lse       = 0;
    fwd.nhead_stride_q_descale = 0;
    fwd.nhead_stride_k_descale = 0;
    fwd.nhead_stride_v_descale = 0;

    // Per-batch strides.
    fwd.batch_stride_q         = bat_stride;
    fwd.batch_stride_k         = bat_stride;
    fwd.batch_stride_v         = bat_stride;
    fwd.batch_stride_o         = bat_stride;
    fwd.batch_stride_bias      = 0;
    fwd.batch_stride_randval   = 0;
    fwd.batch_stride_lse       = 0;
    fwd.batch_stride_q_descale = 0;
    fwd.batch_stride_k_descale = 0;
    fwd.batch_stride_v_descale = 0;

    // Masking / windowing.
    fwd.mask_type         = 0;
    fwd.window_size_left  = -1;
    fwd.window_size_right = -1;
    fwd.sink_size         = 0;

    // Dropout.
    fwd.p_drop           = 0.0f;
    fwd.s_randval        = false;
    fwd.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));

    // Block scaling.
    fwd.block_scale_size_q  = 0;
    fwd.block_scale_size_kv = 0;

    return fwd;
}
// Runs one forward FMHA variant on the GPU and validates it against the CPU
// reference.
//
//   DataType   element type of the Q/K/V/O buffers
//   dispatcher dispatcher used to select and launch the kernel
//   label      name stored in the result and used in error messages
//   dtype_str  data-type string placed in the traits (e.g. "fp16")
//   batch, nhead, seqlen, hdim
//              problem shape; Q and K share seqlen, Q and V share hdim
//   rtol, atol an output is an error unless |gpu - ref| <= atol + rtol * |ref|
//   gfx_arch   architecture string used to build the FmhaProblem for num_ops
//
// Returns a VariantResult. If the dispatcher throws, the exception message is
// printed to stderr and a result with passed == false and zeroed metrics is
// returned instead of propagating the exception.
template <typename DataType>
VariantResult run_variant(FmhaDispatcher& dispatcher,
                          const std::string& label,
                          const std::string& dtype_str,
                          int batch,
                          int nhead,
                          int seqlen,
                          int hdim,
                          double rtol,
                          double atol,
                          const std::string& gfx_arch)
{
    VariantResult result{};
    result.label = label;

    const float scale   = 1.0f / std::sqrt(static_cast<float>(hdim));
    const int64_t elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;

    // Plain batch-mode forward: no mask, bias, LSE, dropout or quantisation.
    fmha_fwd_traits traits{};
    traits.hdim_q              = hdim;
    traits.hdim_v              = hdim;
    traits.data_type           = dtype_str;
    traits.is_group_mode       = false;
    traits.is_v_rowmajor       = true;
    traits.has_logits_soft_cap = false;
    traits.mask_type           = mask_enum::no_mask;
    traits.bias_type           = bias_enum::no_bias;
    traits.has_lse             = false;
    traits.has_dropout         = false;
    traits.qscale_type         = quant_scale_enum::no_scale;

    GpuBuffer<DataType> q_dev(elems);
    GpuBuffer<DataType> k_dev(elems);
    GpuBuffer<DataType> v_dev(elems);
    GpuBuffer<DataType> o_dev(elems);

    // Fixed seed so every run (and every variant) sees reproducible inputs.
    // Q, K and V are drawn in that order from the same generator.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
    std::vector<DataType> q_host(elems);
    std::vector<DataType> k_host(elems);
    std::vector<DataType> v_host(elems);
    for(auto& x : q_host)
        x = DataType(dist(rng));
    for(auto& x : k_host)
        x = DataType(dist(rng));
    for(auto& x : v_host)
        x = DataType(dist(rng));

    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());
    o_dev.zero();

    auto fwd_args = make_fwd_args(q_dev, k_dev, v_dev, o_dev, batch, nhead, seqlen, hdim, scale);
    try
    {
        result.time_ms = dispatcher.run_fwd(traits, fwd_args, nullptr);
    }
    catch(const std::exception& e)
    {
        std::cerr << " ERROR [" << label << "]: " << e.what() << "\n";
        result.passed = false;
        return result;
    }

    auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch);
    // Guard against a zero (or negative) timing so TFLOPS never becomes inf/NaN.
    result.tflops =
        result.time_ms > 0.0f
            ? static_cast<double>(problem.num_ops()) / (result.time_ms * 1e-3) / 1e12
            : 0.0;

    std::vector<DataType> o_host(elems);
    o_dev.copy_to_host(o_host.data());

    // The reference works in float on the already-rounded host inputs, so the
    // comparison measures kernel error rather than input quantisation.
    std::vector<float> q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f);
    for(int64_t i = 0; i < elems; ++i)
        q_f32[i] = static_cast<float>(q_host[i]);
    for(int64_t i = 0; i < elems; ++i)
        k_f32[i] = static_cast<float>(k_host[i]);
    for(int64_t i = 0; i < elems; ++i)
        v_f32[i] = static_cast<float>(v_host[i]);
    cpu_attention_fwd(q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);

    result.max_abs_err = 0.0;
    result.max_rel_err = 0.0;
    result.errors      = 0;
    for(int64_t i = 0; i < elems; ++i)
    {
        const float gpu_val  = static_cast<float>(o_host[i]);
        const float ref_val  = o_ref[i];
        const double abs_err = std::abs(gpu_val - ref_val);
        const double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
        const double tol     = atol + rtol * std::abs(ref_val);
        result.max_abs_err   = std::max(result.max_abs_err, abs_err);
        result.max_rel_err   = std::max(result.max_rel_err, rel_err);
        // Written as !(err <= tol) rather than (err > tol): every comparison
        // with NaN is false, so the negated form counts NaN outputs as errors
        // instead of silently letting them pass.
        if(!(abs_err <= tol))
            ++result.errors;
    }
    result.passed = (result.errors == 0);
    return result;
}
} // namespace
int main(int argc, char* argv[])
{
ExampleArgs args("Example 26: Dtypes & Hdims FMHA",
"FMHA with multiple data types and head dimensions");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen", "64", "Sequence length (Q and K)");
args.add_flag("--validate", "Validate against CPU reference");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen = args.get_int("--seqlen", 64);
print_header("Example 26: Multiple Data Types & Head Dimensions");
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry registry;
registry.set_name("dtypes_hdims");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 3);
// =========================================================================
// Step 2: Run variants on GPU
// =========================================================================
std::cout << "\nStep 2: Run Variants\n";
// fp16 hdim=128 (reference baseline)
std::cout << "\n --- fp16 hdim=128 (reference) ---\n";
auto r_fp16_h128 = run_variant<Fp16Type>(dispatcher,
"fp16_h128",
"fp16",
batch,
nhead,
seqlen,
128,
/*rtol=*/1e-3,
/*atol=*/1e-3,
gfx_arch);
// bf16 hdim=128 (wider tolerance due to reduced precision)
std::cout << "\n --- bf16 hdim=128 ---\n";
auto r_bf16_h128 = run_variant<Bf16Type>(dispatcher,
"bf16_h128",
"bf16",
batch,
nhead,
seqlen,
128,
/*rtol=*/1e-2,
/*atol=*/1e-2,
gfx_arch);
// fp16 hdim=64 (smaller buffers)
std::cout << "\n --- fp16 hdim=64 ---\n";
auto r_fp16_h64 = run_variant<Fp16Type>(dispatcher,
"fp16_h64",
"fp16",
batch,
nhead,
seqlen,
64,
/*rtol=*/1e-3,
/*atol=*/1e-3,
gfx_arch);
// =========================================================================
// Step 3: Results Summary
// =========================================================================
std::cout << "\nStep 3: Results Summary\n\n";
std::cout << " " << std::setw(14) << "Variant" << " | " << std::setw(10) << "Time(ms)" << " | "
<< std::setw(10) << "TFLOPS" << " | " << std::setw(10) << "MaxAbsErr" << " | "
<< std::setw(10) << "MaxRelErr" << " | " << std::setw(8) << "Errors" << " | "
<< std::setw(6) << "Status" << "\n";
std::cout << " " << std::string(82, '-') << "\n";
auto print_row = [](const VariantResult& r) {
std::cout << std::fixed;
std::cout << " " << std::setw(14) << r.label << " | " << std::setprecision(4)
<< std::setw(10) << r.time_ms << " | " << std::setprecision(2) << std::setw(10)
<< r.tflops << " | " << std::scientific << std::setw(10) << r.max_abs_err << " | "
<< std::setw(10) << r.max_rel_err << " | " << std::fixed << std::setw(8)
<< r.errors << " | " << std::setw(6) << (r.passed ? "PASS" : "FAIL") << "\n";
};
print_row(r_fp16_h128);
print_row(r_bf16_h128);
print_row(r_fp16_h64);
// =========================================================================
// Step 4: Tolerance Notes
// =========================================================================
std::cout << "\nStep 4: Tolerance Notes\n";
std::cout << " fp16 validation: rtol=1e-3, atol=1e-3 (higher precision)\n";
std::cout << " bf16 validation: rtol=1e-2, atol=1e-2 (wider tolerance for bfloat16)\n";
std::cout << "\n Additional dtype/hdim combinations (planning-level declarations):\n";
std::cout << " fp32: .dtype(\"fp32\") - full single precision\n";
std::cout << " fp8bf16: .dtype(\"fp8bf16\") - fp8 compute, bf16 output\n";
std::cout << " fp8fp32: .dtype(\"fp8fp32\") - fp8 compute, fp32 output\n";
std::cout << " hdim 256: .hdim(256), tile(128,128,32,256,32,256)\n";
std::cout << " asymmetric: .hdim_q(128), .hdim_v(64) - different Q/V dims\n";
bool all_passed = r_fp16_h128.passed && r_bf16_h128.passed && r_fp16_h64.passed;
print_separator();
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
print_separator();
return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,635 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 27: Padding, Group Mode, V Col-Major, Permutation Patterns
//
// Demonstrates:
// 1. Batch padding with cu_seqlen arrays for per-batch variable lengths
// 2. Group mode with seqstart_q / seqstart_k buffers
// 3. V col-major layout declaration: .vlayout("c")
// 4. Permutation patterns: bhsd (iperm=1) vs bshd (iperm=0) strides
// 5. GPU execution with basic kernel + batch padding
//
// Mirrors 01_basic_fmha.cpp for FMHA.
#include <hip/hip_runtime.h>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for this example: three forward ("fwd") fp16, hdim-128 kernel
// declarations for gfx950 that differ only in their signature:
//   1. batch mode, row-major V
//   2. group mode, row-major V
//   3. batch mode, column-major V
// All three use the same FmhaAlgorithm settings (tile, wave and warp sizes,
// the "qr_async" pipeline, padding enabled on all four flags, alignments
// 128/128, selection rank 0).
DECL_FMHA_KERNEL_SET(padding_permutation_kernels,
// Basic fwd kernel (batch mode, for GPU execution)
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
// Wave and warp settings for stage 0 (suffix 0) and stage 1 (suffix 1).
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Group mode kernel (variable-length sequences)
// Same algorithm as above; only .mode("group") differs in the signature.
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("group")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// V col-major layout declaration
// Same algorithm as above; only .vlayout("c") differs from the first kernel.
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("c")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for the attention forward pass: O = softmax(scale * Q * K^T) * V.
//
// All tensors are flat, densely packed buffers indexed as
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]   (must be pre-sized by the caller)
// No mask, bias or dropout is applied. The softmax subtracts the row maximum
// before exponentiating for numerical stability.
//
// Flat offsets are computed in 64-bit arithmetic so that tensors with more than
// INT_MAX elements (callers size their buffers with int64_t) do not overflow.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // One scratch row of scores, reused for every query row. Each entry is
    // overwritten before it is read, so no re-initialisation is needed.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Index of the (batch, head) slice, widened once to 64 bits.
            const int64_t bh     = static_cast<int64_t>(b) * nhead + h;
            const int64_t k_base = bh * seqlen_k * hdim_q;
            const int64_t v_base = bh * seqlen_k * hdim_v;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bh * seqlen_q + sq) * hdim_v;

                // Pass 1: scaled dot products and running row maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = k_base + static_cast<int64_t>(sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum and accumulate the sum.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                // Pass 3: normalise into probabilities.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Pass 4: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[v_base + static_cast<int64_t>(sk) * hdim_v + dv];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
} // namespace
int main(int argc, char* argv[])
{
ExampleArgs args("Example 27: Padding & Permutation FMHA",
"FMHA padding, group mode, V col-major, and permutation patterns");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen", "64", "Sequence length (Q and K)");
args.add_option("--hdim", "128", "Head dimension");
args.add_flag("--validate", "Validate against CPU reference");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen = args.get_int("--seqlen", 64);
const int hdim = args.get_int("--hdim", 128);
print_header("Example 27: Padding, Group Mode, V Col-Major, Permutation");
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry registry;
registry.set_name("padding_permutation");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 3);
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
// =========================================================================
// Step 2: Batch Padding Pattern
// Allocate cu_seqlen_q / cu_seqlen_k buffers with cumulative sums.
// In CK's dispatcher, this maps to seqstart_q_ptr / seqstart_k_ptr
// and requires group mode to enable per-batch variable sequence lengths.
// =========================================================================
std::cout << "\nStep 2: Batch Padding Pattern (cu_seqlen)\n";
{
// Per-batch sequence lengths: batch 0 has seqlen=32, batch 1 has seqlen=48
const std::vector<int32_t> seqlens_q = {32, 48};
const std::vector<int32_t> seqlens_k = {32, 48};
const int num_batches = static_cast<int>(seqlens_q.size());
// Build cumulative sum arrays: [0, 32, 80]
std::vector<int32_t> cu_seqlen_q(num_batches + 1, 0);
std::vector<int32_t> cu_seqlen_k(num_batches + 1, 0);
for(int i = 0; i < num_batches; ++i)
{
cu_seqlen_q[i + 1] = cu_seqlen_q[i] + seqlens_q[i];
cu_seqlen_k[i + 1] = cu_seqlen_k[i] + seqlens_k[i];
}
const int total_q = cu_seqlen_q.back();
const int total_k = cu_seqlen_k.back();
const int max_sq = *std::max_element(seqlens_q.begin(), seqlens_q.end());
std::cout << " Batch seqlens_q: [";
for(int i = 0; i < num_batches; ++i)
std::cout << (i ? ", " : "") << seqlens_q[i];
std::cout << "]\n";
std::cout << " cu_seqlen_q: [";
for(size_t i = 0; i < cu_seqlen_q.size(); ++i)
std::cout << (i ? ", " : "") << cu_seqlen_q[i];
std::cout << "]\n";
GpuBuffer<int32_t> cu_sq_dev(num_batches + 1);
GpuBuffer<int32_t> cu_sk_dev(num_batches + 1);
cu_sq_dev.copy_from_host(cu_seqlen_q.data());
cu_sk_dev.copy_from_host(cu_seqlen_k.data());
// Group mode traits for variable-length sequences
fmha_fwd_traits pad_traits{};
pad_traits.hdim_q = hdim;
pad_traits.hdim_v = hdim;
pad_traits.data_type = "fp16";
pad_traits.is_group_mode = true;
pad_traits.is_v_rowmajor = true;
pad_traits.has_logits_soft_cap = false;
pad_traits.mask_type = mask_enum::no_mask;
pad_traits.bias_type = bias_enum::no_bias;
pad_traits.has_lse = false;
pad_traits.has_dropout = false;
pad_traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args pad_args{};
pad_args.seqlen_q = total_q;
pad_args.seqlen_k = total_k;
pad_args.batch = num_batches;
pad_args.max_seqlen_q = max_sq;
pad_args.hdim_q = hdim;
pad_args.hdim_v = hdim;
pad_args.nhead_q = nhead;
pad_args.nhead_k = nhead;
pad_args.scale_s = scale;
// cu_seqlen_q_ptr / cu_seqlen_k_ptr (seqstart_q / seqstart_k in CK)
pad_args.seqstart_q_ptr = cu_sq_dev.get();
pad_args.seqstart_k_ptr = cu_sk_dev.get();
auto pad_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(pad_traits, pad_args), gfx_arch));
std::cout << " Batch padding plan valid: " << (pad_plan.is_valid() ? "yes" : "no") << "\n";
}
// =========================================================================
// Step 3: Group Mode Pattern
// Group mode uses seqstart_q / seqstart_k arrays to define variable
// sequence boundaries. Each batch element can have a different length.
// traits.is_group_mode = true
// =========================================================================
std::cout << "\nStep 3: Group Mode Pattern (seqstart)\n";
{
fmha_fwd_traits group_traits{};
group_traits.hdim_q = hdim;
group_traits.hdim_v = hdim;
group_traits.data_type = "fp16";
group_traits.is_group_mode = true;
group_traits.is_v_rowmajor = true;
group_traits.has_logits_soft_cap = false;
group_traits.mask_type = mask_enum::no_mask;
group_traits.bias_type = bias_enum::no_bias;
group_traits.has_lse = false;
group_traits.has_dropout = false;
group_traits.qscale_type = quant_scale_enum::no_scale;
const std::vector<int32_t> seqstart_q = {0, 64, 192};
const std::vector<int32_t> seqstart_k = {0, 128, 256};
const int num_batches = static_cast<int>(seqstart_q.size()) - 1;
const int total_q = seqstart_q.back();
const int max_sq = 128;
GpuBuffer<int32_t> ss_q_dev(seqstart_q.size());
GpuBuffer<int32_t> ss_k_dev(seqstart_k.size());
ss_q_dev.copy_from_host(seqstart_q.data());
ss_k_dev.copy_from_host(seqstart_k.data());
fmha_fwd_args group_args{};
group_args.seqlen_q = total_q;
group_args.seqlen_k = seqstart_k.back();
group_args.batch = num_batches;
group_args.max_seqlen_q = max_sq;
group_args.hdim_q = hdim;
group_args.hdim_v = hdim;
group_args.nhead_q = nhead;
group_args.nhead_k = nhead;
group_args.scale_s = scale;
group_args.seqstart_q_ptr = ss_q_dev.get();
group_args.seqstart_k_ptr = ss_k_dev.get();
std::cout << " seqstart_q: [0, 64, 192] -> batches of length 64 and 128\n";
std::cout << " seqstart_k: [0, 128, 256] -> KV of length 128 and 128\n";
auto group_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(group_traits, group_args), gfx_arch));
std::cout << " Group mode plan valid: " << (group_plan.is_valid() ? "yes" : "no") << "\n";
}
// =========================================================================
// Step 4: V Col-Major Declaration
// .vlayout("c") declares V in column-major layout (seqlen_k x hdim_v
// stored column-first). This affects how the kernel reads V.
// =========================================================================
std::cout << "\nStep 4: V Col-Major Layout\n";
{
fmha_fwd_traits vcol_traits{};
vcol_traits.hdim_q = hdim;
vcol_traits.hdim_v = hdim;
vcol_traits.data_type = "fp16";
vcol_traits.is_group_mode = false;
vcol_traits.is_v_rowmajor = false;
vcol_traits.has_logits_soft_cap = false;
vcol_traits.mask_type = mask_enum::no_mask;
vcol_traits.bias_type = bias_enum::no_bias;
vcol_traits.has_lse = false;
vcol_traits.has_dropout = false;
vcol_traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args vcol_args{};
vcol_args.batch = batch;
vcol_args.seqlen_q = seqlen;
vcol_args.seqlen_k = seqlen;
vcol_args.max_seqlen_q = seqlen;
vcol_args.hdim_q = hdim;
vcol_args.hdim_v = hdim;
vcol_args.nhead_q = nhead;
vcol_args.nhead_k = nhead;
vcol_args.scale_s = scale;
std::cout << " V row-major (.vlayout(\"r\")): stride_v = hdim, "
"contiguous along head dimension\n";
std::cout << " V col-major (.vlayout(\"c\")): stride_v = seqlen_k, "
"contiguous along sequence dimension\n";
std::cout << " traits.is_v_rowmajor = false\n";
auto vcol_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(vcol_traits, vcol_args), gfx_arch));
std::cout << " V col-major plan valid: " << (vcol_plan.is_valid() ? "yes" : "no") << "\n";
}
// =========================================================================
// Step 5: Permutation Patterns (bhsd vs bshd)
// bhsd layout (iperm=1): stride_q = hdim, nhead_stride_q = seqlen*hdim
// memory: [batch, head, seq, dim]
// bshd layout (iperm=0): stride_q = nhead*hdim, nhead_stride_q = hdim
// memory: [batch, seq, head, dim]
// =========================================================================
std::cout << "\nStep 5: Permutation Patterns\n";
{
std::cout << " bhsd layout (iperm=1):\n";
std::cout << " stride_q = hdim = " << hdim << "\n";
std::cout << " nhead_stride_q = seqlen * hdim = " << seqlen * hdim << "\n";
std::cout << " batch_stride_q = nhead * seqlen * hdim = " << nhead * seqlen * hdim
<< "\n";
std::cout << " memory order: [batch, head, seq, dim]\n";
std::cout << "\n bshd layout (iperm=0):\n";
std::cout << " stride_q = nhead * hdim = " << nhead * hdim << "\n";
std::cout << " nhead_stride_q = hdim = " << hdim << "\n";
std::cout << " batch_stride_q = seqlen * nhead * hdim = " << seqlen * nhead * hdim
<< "\n";
std::cout << " memory order: [batch, seq, head, dim]\n";
}
// =========================================================================
// Step 6: GPU Execution with basic kernel + batch padding
// Run the batch-mode kernel with a non-tile-aligned seqlen to exercise
// the .padding(true, true, true, true) capability.
// =========================================================================
std::cout << "\nStep 6: GPU Execution (batch mode, seqlen=" << seqlen << ")\n";
fmha_fwd_traits fwd_traits{};
fwd_traits.hdim_q = hdim;
fwd_traits.hdim_v = hdim;
fwd_traits.data_type = "fp16";
fwd_traits.is_group_mode = false;
fwd_traits.is_v_rowmajor = true;
fwd_traits.has_logits_soft_cap = false;
fwd_traits.mask_type = mask_enum::no_mask;
fwd_traits.bias_type = bias_enum::no_bias;
fwd_traits.has_lse = false;
fwd_traits.has_dropout = false;
fwd_traits.qscale_type = quant_scale_enum::no_scale;
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
const int64_t k_elems = q_elems;
const int64_t v_elems = q_elems;
const int64_t o_elems = q_elems;
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
<< "]\n";
GpuBuffer<FmhaDataType> q_dev(q_elems);
GpuBuffer<FmhaDataType> k_dev(k_elems);
GpuBuffer<FmhaDataType> v_dev(v_elems);
GpuBuffer<FmhaDataType> o_dev(o_elems);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FmhaDataType> q_host(q_elems);
std::vector<FmhaDataType> k_host(k_elems);
std::vector<FmhaDataType> v_host(v_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
o_dev.zero();
fmha_fwd_args fwd_args{};
fwd_args.q_ptr = q_dev.get();
fwd_args.k_ptr = k_dev.get();
fwd_args.v_ptr = v_dev.get();
fwd_args.o_ptr = o_dev.get();
fwd_args.bias_ptr = nullptr;
fwd_args.q_descale_ptr = nullptr;
fwd_args.k_descale_ptr = nullptr;
fwd_args.v_descale_ptr = nullptr;
fwd_args.rand_val_ptr = nullptr;
fwd_args.lse_ptr = nullptr;
fwd_args.sink_ptr = nullptr;
fwd_args.block_scale_seqstart_q_ptr = nullptr;
fwd_args.block_scale_seqstart_k_ptr = nullptr;
fwd_args.seqlen_q = seqlen;
fwd_args.seqlen_k = seqlen;
fwd_args.batch = batch;
fwd_args.max_seqlen_q = seqlen;
fwd_args.hdim_q = hdim;
fwd_args.hdim_v = hdim;
fwd_args.nhead_q = nhead;
fwd_args.nhead_k = nhead;
fwd_args.scale_s = scale;
fwd_args.logits_soft_cap = 0.0f;
// bhsd layout strides (iperm=1)
fwd_args.stride_q = hdim;
fwd_args.stride_k = hdim;
fwd_args.stride_v = hdim;
fwd_args.stride_bias = 0;
fwd_args.stride_randval = 0;
fwd_args.stride_o = hdim;
fwd_args.nhead_stride_q = seqlen * hdim;
fwd_args.nhead_stride_k = seqlen * hdim;
fwd_args.nhead_stride_v = seqlen * hdim;
fwd_args.nhead_stride_bias = 0;
fwd_args.nhead_stride_randval = 0;
fwd_args.nhead_stride_lse = 0;
fwd_args.nhead_stride_o = seqlen * hdim;
fwd_args.nhead_stride_q_descale = 0;
fwd_args.nhead_stride_k_descale = 0;
fwd_args.nhead_stride_v_descale = 0;
fwd_args.batch_stride_q = nhead * seqlen * hdim;
fwd_args.batch_stride_k = nhead * seqlen * hdim;
fwd_args.batch_stride_v = nhead * seqlen * hdim;
fwd_args.batch_stride_bias = 0;
fwd_args.batch_stride_randval = 0;
fwd_args.batch_stride_lse = 0;
fwd_args.batch_stride_o = nhead * seqlen * hdim;
fwd_args.batch_stride_q_descale = 0;
fwd_args.batch_stride_k_descale = 0;
fwd_args.batch_stride_v_descale = 0;
fwd_args.window_size_left = -1;
fwd_args.window_size_right = -1;
fwd_args.sink_size = 0;
fwd_args.mask_type = 0;
fwd_args.min_seqlen_q = 0;
fwd_args.p_drop = 0.0f;
fwd_args.s_randval = false;
fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
fwd_args.block_scale_size_q = 0;
fwd_args.block_scale_size_kv = 0;
float time_ms = 0.0f;
try
{
time_ms = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
}
catch(const std::exception& e)
{
std::cerr << " ERROR: " << e.what() << "\n";
return 1;
}
auto problem =
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Step 7: Validate
std::cout << "\nStep 7: Validate\n";
std::vector<FmhaDataType> o_host(o_elems);
o_dev.copy_to_host(o_host.data());
int nonzero = 0;
for(int64_t i = 0; i < o_elems; ++i)
{
if(static_cast<float>(o_host[i]) != 0.0f)
++nonzero;
}
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
bool passed = (nonzero > 0);
if(args.has("--validate"))
{
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
for(int64_t i = 0; i < q_elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < k_elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < v_elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
cpu_attention_fwd(
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
double max_abs_err = 0.0;
double max_rel_err = 0.0;
int errors = 0;
const double rtol = 1e-2;
const double atol = 1e-2;
for(int64_t i = 0; i < o_elems; ++i)
{
float gpu_val = static_cast<float>(o_host[i]);
float ref_val = o_ref[i];
double abs_err = std::abs(gpu_val - ref_val);
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
max_abs_err = std::max(max_abs_err, abs_err);
max_rel_err = std::max(max_rel_err, rel_err);
if(abs_err > atol + rtol * std::abs(ref_val))
++errors;
}
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
std::cout << " Max rel error: " << max_rel_err << "\n";
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
passed = (errors == 0);
}
print_separator();
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
print_separator();
return passed ? 0 : 1;
}

View File

@@ -0,0 +1,489 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 28: FMHA Backward with Causal Mask
//
// Demonstrates:
// 1. Forward kernel with top_left causal mask + LSE
// 2. Backward kernel families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq) with causal mask
// 3. GPU forward execution with causal mask validation
// 4. Backward 3-stage plan display
//
// Backward kernels use planning only -- actual backward GPU execution requires
// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for this example: one forward kernel plus the three backward
// families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq). All four are fp16,
// batch mode, hdim 128, "top_left" mask, no bias, no dropout, and target
// gfx950. The forward kernel additionally enables LSE output.
DECL_FMHA_KERNEL_SET(bwd_masks_fmha_kernels,
// Forward: causal mask (top_left) with LSE for backward
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("top_left")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Backward stage 1: dot(dO, O) with causal mask
// Only tile_m0/n0/k0 are non-zero here; the remaining tile values are 0.
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Backward stage 2: compute dQ, dK, dV with causal mask
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
// NOTE(review): nine-value wave/warp forms; presumably three (m, n, k)
// triples, one per GEMM group of this kernel -- confirm against FmhaAlgorithm.
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(0),
"gfx950")
// Backward stage 3: convert accumulated dQ from fp32 to fp16
// Only tile_m0/n0 are non-zero here; the remaining tile values are 0.
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for causal (top-left aligned) attention forward with LSE.
//
// Q, K, V and O are flat, densely packed [batch, nhead, seqlen, hdim] buffers;
// LSE is [batch, nhead, seqlen]. O and LSE must be pre-sized by the caller.
// Query row sq attends only to keys sk <= sq: masked scores are replaced by
// -1e30 before the softmax. For every row,
//   LSE = max_score + log(sum(exp(score - max_score))).
//
// Flat offsets are computed in 64-bit arithmetic so that large tensors do not
// overflow a 32-bit int.
void cpu_attention_fwd_causal(const std::vector<float>& Q,
                              const std::vector<float>& K,
                              const std::vector<float>& V,
                              std::vector<float>& O,
                              std::vector<float>& LSE,
                              int batch,
                              int nhead,
                              int seqlen,
                              int hdim,
                              float scale)
{
    // One scratch row of scores, reused for every query row. Each entry is
    // overwritten before it is read.
    std::vector<float> scores(seqlen > 0 ? seqlen : 0, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Index of the (batch, head) slice, widened once to 64 bits.
            const int64_t bh       = static_cast<long long>(b) * nhead + h;
            const int64_t kv_base  = bh * seqlen * hdim;
            const int64_t lse_base = bh * seqlen;

            for(int sq = 0; sq < seqlen; ++sq)
            {
                // Q and O share the same row offset because both use hdim.
                const int64_t row = kv_base + static_cast<long long>(sq) * hdim;

                // Pass 1: scaled dot products, causal masking, running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen; ++sk)
                {
                    const int64_t k_row = kv_base + static_cast<long long>(sk) * hdim;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim; ++d)
                    {
                        dot += Q[row + d] * K[k_row + d];
                    }
                    float s = dot * scale;
                    // top_left causal: mask if sk > sq
                    if(sk > sq)
                        s = -1e30f;
                    scores[sk] = s;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum; masked entries
                // underflow to zero.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                LSE[lse_base + sq] = max_score + std::log(sum_exp);

                // Pass 3: normalise into probabilities.
                for(int sk = 0; sk < seqlen; ++sk)
                    scores[sk] /= sum_exp;

                // Pass 4: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen; ++sk)
                    {
                        acc += scores[sk] * V[kv_base + static_cast<long long>(sk) * hdim + dv];
                    }
                    O[row + dv] = acc;
                }
            }
        }
    }
}
} // namespace
// Example 28 driver.
//
// Steps, as implemented below:
//   1. Register the generated kernels for the requested architecture.
//   2. Ask the dispatcher for a backward plan with a top-left causal mask and print its stages.
//   3. Run the forward kernel on the GPU with the same causal mask.
//   4. Compare the GPU output against cpu_attention_fwd_causal and sanity-check LSE.
//   5. Print the backward traits/args pattern.
//
// Returns 0 when forward validation passes (and also when args.parse() returns false),
// 1 otherwise.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 28: FMHA Backward with Masks",
                     "Causal mask forward (GPU) + backward plan");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "2", "Batch size");
    args.add_option("--nhead", "4", "Number of heads");
    args.add_option("--seqlen", "64", "Sequence length");
    args.add_option("--hdim", "128", "Head dimension");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 2);
    const int nhead            = args.get_int("--nhead", 4);
    const int seqlen           = args.get_int("--seqlen", 64);
    const int hdim             = args.get_int("--hdim", 128);
    // Softmax scale 1/sqrt(hdim).
    const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));

    print_header("Example 28: FMHA Backward with Causal Mask");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("bwd_masks_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";
    FmhaDispatcher dispatcher(&registry);
    dispatcher.set_benchmarking(true);
    // NOTE(review): arguments are presumably (warmup, repeat) -- confirm against FmhaDispatcher.
    dispatcher.set_timing(1, 3);

    // Step 2: Plan backward (3-stage) with causal mask
    std::cout << "\nStep 2: Plan Backward (causal mask)\n";
    fmha_bwd_traits bwd_traits{};
    bwd_traits.hdim_q           = hdim;
    bwd_traits.hdim_v           = hdim;
    bwd_traits.data_type        = "fp16";
    bwd_traits.is_group_mode    = false;
    bwd_traits.mask_type        = mask_enum::mask_top_left;
    bwd_traits.bias_type        = bias_enum::no_bias;
    bwd_traits.has_dbias        = false;
    bwd_traits.has_dropout      = false;
    bwd_traits.is_store_randval = false;
    bwd_traits.is_deterministic = false;

    // Only the problem dimensions are filled in; no device pointers are set because the
    // backward path is planned here, not executed.
    fmha_bwd_args bwd_args{};
    bwd_args.batch        = batch;
    bwd_args.seqlen_q     = seqlen;
    bwd_args.seqlen_k     = seqlen;
    bwd_args.max_seqlen_q = seqlen;
    bwd_args.max_seqlen_k = seqlen;
    bwd_args.hdim_q       = hdim;
    bwd_args.hdim_v       = hdim;
    bwd_args.nhead_q      = nhead;
    bwd_args.nhead_k      = nhead;

    auto bwd_plan = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
    if(bwd_plan.is_valid() && bwd_plan.stages.size() >= 2)
    {
        std::cout << " Backward plan stages (" << bwd_plan.stages.size() << "):\n";
        for(const auto& stage : bwd_plan.stages)
        {
            std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
        }
    }
    else
    {
        std::cout << " Backward plan: INVALID or single-stage (expected 3 stages)\n";
        std::cout << " This is expected -- backward planning shows the pattern\n";
    }

    // Step 3: Run forward on GPU with causal mask
    std::cout << "\nStep 3: Run Forward (causal mask, GPU)\n";
    const int64_t qkv_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
    const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen;
    GpuBuffer<FmhaDataType> q_dev(qkv_elems);
    GpuBuffer<FmhaDataType> k_dev(qkv_elems);
    GpuBuffer<FmhaDataType> v_dev(qkv_elems);
    GpuBuffer<FmhaDataType> o_dev(qkv_elems);
    GpuBuffer<float> lse_dev(lse_elems);

    // Fixed seed so the inputs are identical on every run. Q, K, V are drawn in that order
    // from the same generator.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
    std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
    for(auto& x : q_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : k_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : v_host)
        x = FmhaDataType(dist(rng));
    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());
    o_dev.zero();
    lse_dev.zero();

    fmha_fwd_traits fwd_traits{};
    fwd_traits.hdim_q              = hdim;
    fwd_traits.hdim_v              = hdim;
    fwd_traits.data_type           = "fp16";
    fwd_traits.is_group_mode       = false;
    fwd_traits.is_v_rowmajor       = true;
    fwd_traits.has_logits_soft_cap = false;
    fwd_traits.mask_type           = mask_enum::mask_top_left;
    fwd_traits.bias_type           = bias_enum::no_bias;
    fwd_traits.has_lse             = true;
    fwd_traits.has_dropout         = false;
    fwd_traits.qscale_type         = quant_scale_enum::no_scale;

    fmha_fwd_args fwd_args{};
    // Device pointers; features that are not enabled get nullptr.
    fwd_args.q_ptr                      = q_dev.get();
    fwd_args.k_ptr                      = k_dev.get();
    fwd_args.v_ptr                      = v_dev.get();
    fwd_args.o_ptr                      = o_dev.get();
    fwd_args.lse_ptr                    = lse_dev.get();
    fwd_args.bias_ptr                   = nullptr;
    fwd_args.q_descale_ptr              = nullptr;
    fwd_args.k_descale_ptr              = nullptr;
    fwd_args.v_descale_ptr              = nullptr;
    fwd_args.rand_val_ptr               = nullptr;
    fwd_args.sink_ptr                   = nullptr;
    fwd_args.block_scale_seqstart_q_ptr = nullptr;
    fwd_args.block_scale_seqstart_k_ptr = nullptr;
    // Problem dimensions.
    fwd_args.seqlen_q        = seqlen;
    fwd_args.seqlen_k        = seqlen;
    fwd_args.batch           = batch;
    fwd_args.max_seqlen_q    = seqlen;
    fwd_args.hdim_q          = hdim;
    fwd_args.hdim_v          = hdim;
    fwd_args.nhead_q         = nhead;
    fwd_args.nhead_k         = nhead;
    fwd_args.scale_s         = scale;
    fwd_args.logits_soft_cap = 0.0f;
    // Strides for the dense [batch][nhead][seqlen][hdim] layout that the host buffers and
    // the CPU reference use.
    fwd_args.stride_q               = hdim;
    fwd_args.stride_k               = hdim;
    fwd_args.stride_v               = hdim;
    fwd_args.stride_bias            = 0;
    fwd_args.stride_randval         = 0;
    fwd_args.stride_o               = hdim;
    fwd_args.nhead_stride_q         = seqlen * hdim;
    fwd_args.nhead_stride_k         = seqlen * hdim;
    fwd_args.nhead_stride_v         = seqlen * hdim;
    fwd_args.nhead_stride_bias      = 0;
    fwd_args.nhead_stride_randval   = 0;
    fwd_args.nhead_stride_lse       = seqlen;
    fwd_args.nhead_stride_o         = seqlen * hdim;
    fwd_args.nhead_stride_q_descale = 0;
    fwd_args.nhead_stride_k_descale = 0;
    fwd_args.nhead_stride_v_descale = 0;
    fwd_args.batch_stride_q         = nhead * seqlen * hdim;
    fwd_args.batch_stride_k         = nhead * seqlen * hdim;
    fwd_args.batch_stride_v         = nhead * seqlen * hdim;
    fwd_args.batch_stride_bias      = 0;
    fwd_args.batch_stride_randval   = 0;
    fwd_args.batch_stride_lse       = nhead * seqlen;
    fwd_args.batch_stride_o         = nhead * seqlen * hdim;
    fwd_args.batch_stride_q_descale = 0;
    fwd_args.batch_stride_k_descale = 0;
    fwd_args.batch_stride_v_descale = 0;
    // Causal window: left = -1, right = 0, matching the values printed in Step 5.
    fwd_args.window_size_left    = -1;
    fwd_args.window_size_right   = 0;
    fwd_args.sink_size           = 0;
    fwd_args.mask_type           = 1; // top_left
    fwd_args.min_seqlen_q        = 0;
    fwd_args.p_drop              = 0.0f;
    fwd_args.s_randval           = false;
    fwd_args.drop_seed_offset    = std::make_pair(uint64_t(0), uint64_t(0));
    fwd_args.block_scale_size_q  = 0;
    fwd_args.block_scale_size_kv = 0;

    // Set once the kernel has run without throwing; refined by validation in Step 4.
    bool fwd_passed = false;
    try
    {
        float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
        std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time
                  << " ms\n";
        auto problem =
            FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
        // fwd_time is in milliseconds: ops / seconds / 1e12.
        double tflops = static_cast<double>(problem.num_ops()) / (fwd_time * 1e-3) / 1e12;
        std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
        fwd_passed = true;
    }
    catch(const std::exception& e)
    {
        std::cerr << " Forward ERROR: " << e.what() << "\n";
    }

    // Step 4: Validate forward output
    std::cout << "\nStep 4: Validate Forward Output\n";
    if(fwd_passed)
    {
        std::vector<FmhaDataType> o_host(qkv_elems);
        o_dev.copy_to_host(o_host.data());
        std::vector<float> lse_host(lse_elems);
        lse_dev.copy_to_host(lse_host.data());

        // The CPU reference consumes the fp16-rounded inputs, widened to float.
        std::vector<float> q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems);
        for(int64_t i = 0; i < qkv_elems; ++i)
            q_f32[i] = static_cast<float>(q_host[i]);
        for(int64_t i = 0; i < qkv_elems; ++i)
            k_f32[i] = static_cast<float>(k_host[i]);
        for(int64_t i = 0; i < qkv_elems; ++i)
            v_f32[i] = static_cast<float>(v_host[i]);

        std::vector<float> o_ref(qkv_elems, 0.0f);
        std::vector<float> lse_ref(lse_elems, 0.0f);
        cpu_attention_fwd_causal(
            q_f32, k_f32, v_f32, o_ref, lse_ref, batch, nhead, seqlen, hdim, scale);

        double max_o_err  = 0.0;
        int o_errors      = 0;
        const double rtol = 1e-2;
        const double atol = 1e-2;
        for(int64_t i = 0; i < qkv_elems; ++i)
        {
            float gpu_val  = static_cast<float>(o_host[i]);
            float ref_val  = o_ref[i];
            double abs_err = std::abs(gpu_val - ref_val);
            max_o_err      = std::max(max_o_err, abs_err);
            // Written as !(err <= tol) rather than (err > tol): every ordered comparison with
            // NaN is false, so the latter would count a NaN output as a match.
            if(!(abs_err <= atol + rtol * std::abs(ref_val)))
                ++o_errors;
        }

        // LSE is only required to be finite and of moderate magnitude; the error against the
        // reference is reported but is not part of the pass criterion.
        double max_lse_err = 0.0;
        int lse_reasonable = 0;
        for(int64_t i = 0; i < lse_elems; ++i)
        {
            if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f)
                ++lse_reasonable;
            max_lse_err =
                std::max(max_lse_err, static_cast<double>(std::abs(lse_host[i] - lse_ref[i])));
        }

        std::cout << " Output max abs error: " << std::scientific << max_o_err << "\n";
        std::cout << " Output errors: " << o_errors << " / " << qkv_elems << "\n";
        std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n";
        std::cout << " LSE max error: " << std::scientific << max_lse_err << "\n";
        fwd_passed = (o_errors == 0) && (lse_reasonable == lse_elems);
    }

    // Step 5: Show backward API pattern
    std::cout << "\nStep 5: Backward API Pattern (traits + args)\n";
    std::cout << " bwd_traits.mask_type = mask_top_left\n";
    std::cout << " bwd_traits.bias_type = no_bias\n";
    std::cout << " bwd_traits.has_dropout = false\n";
    std::cout << " bwd_traits.is_deterministic = false\n";
    std::cout << " bwd_args.window_size_left = -1\n";
    std::cout << " bwd_args.window_size_right = 0 (causal)\n";
    std::cout << " bwd_args.mask_type = 1 (top_left)\n";
    std::cout << " Backward plan resolves to " << bwd_plan.stages.size() << " stage(s)\n";

    print_separator();
    std::cout << "Status: " << (fwd_passed ? "PASS" : "FAIL") << "\n";
    print_separator();
    return fwd_passed ? 0 : 1;
}

View File

@@ -0,0 +1,615 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 29: FMHA Backward with ALiBi Bias + Dropout
//
// Demonstrates:
// 1. Forward kernel with alibi bias + dropout + LSE
// 2. Backward kernel families with alibi + dropout
// 3. GPU forward execution with alibi bias, validates output
// 4. Backward plan with all features enabled
// 5. How deterministic mode affects the backward plan
//
// Backward kernels use planning only -- actual backward GPU execution requires
// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for Example 29. Seven entries are declared, each as
// .add(<FmhaSignature>, <FmhaAlgorithm>, <arch string>), all for "gfx950", fp16, batch mode,
// hdim 128, no mask, alibi bias, dropout enabled:
//   - one forward kernel (with LSE),
//   - three backward-stage kernels with deterministic(false),
//   - the same three backward-stage kernels with deterministic(true).
DECL_FMHA_KERNEL_SET(bwd_bias_dropout_fmha_kernels,
// Forward: alibi bias + dropout + LSE
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("alibi")
.lse(true)
.dropout(true)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Backward stage 1: dot(dO, O) with alibi + dropout (non-deterministic)
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("alibi")
.dropout(true)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Backward stage 2: dQ, dK, dV with alibi + dropout (non-deterministic)
// This is the only backward entry that also sets wave(), warp() and max_seq_len_q().
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("alibi")
.dropout(true)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(0),
"gfx950")
// Backward stage 3: convert dQ with alibi + dropout (non-deterministic)
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("alibi")
.dropout(true)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Deterministic variants for comparison: the three entries below repeat the backward
// stages above with identical algorithm settings; only .deterministic(true) differs.
// Deterministic stage 1: dot(dO, O)
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("alibi")
.dropout(true)
.dbias(false)
.store_randval(false)
.deterministic(true),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Deterministic stage 2: dQ, dK, dV
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("alibi")
.dropout(true)
.dbias(false)
.store_randval(false)
.deterministic(true),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(0),
"gfx950")
// Deterministic stage 3: convert dQ
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("alibi")
.dropout(true)
.dbias(false)
.store_randval(false)
.deterministic(true),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// CPU reference for the forward pass of scaled-dot-product attention with an additive
// per-head linear (ALiBi-style) bias. No mask and no dropout are applied.
//
// Layout (fixed by the index arithmetic below): Q, K, V and O are dense arrays indexed as
// [batch][nhead][seqlen][hdim]; LSE is indexed as [batch][nhead][seqlen].
//
// For every query row sq:
//   1. score[sk] = scale * dot(Q[sq], K[sk]) + alibi_slopes[h] * (sk - sq).
//   2. A max-subtracted softmax is taken over the scores.
//   3. LSE[sq] = max_score + log(sum of exponentials) (natural log).
//   4. O[sq]   = sum over sk of softmax[sk] * V[sk].
//
// Parameters:
//   Q, K, V       - inputs, each expected to hold batch * nhead * seqlen * hdim floats.
//   O             - output; must already be sized to batch * nhead * seqlen * hdim.
//   LSE           - output; must already be sized to batch * nhead * seqlen.
//   scale         - multiplier applied to each Q.K dot product.
//   alibi_slopes  - one slope per head; indexed by h, so it needs at least nhead entries.
//
// Flat offsets are computed in int64_t so that the product of the four dimensions cannot
// overflow a 32-bit int on large problems.
void cpu_attention_fwd_alibi(const std::vector<float>& Q,
                             const std::vector<float>& K,
                             const std::vector<float>& V,
                             std::vector<float>& O,
                             std::vector<float>& LSE,
                             int batch,
                             int nhead,
                             int seqlen,
                             int hdim,
                             float scale,
                             const std::vector<float>& alibi_slopes)
{
    // Scratch row holding scores, then probabilities. Every entry is rewritten for each
    // query row before it is read, so one allocation can be reused across all rows.
    std::vector<float> scores(static_cast<std::vector<float>::size_type>(std::max(seqlen, 0)),
                              0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            const float slope = alibi_slopes[h];
            // Row index (in units of sequence positions) of position 0 of this (b, h) slice.
            const int64_t head_row0 = (static_cast<int64_t>(b) * nhead + h) * seqlen;

            for(int sq = 0; sq < seqlen; ++sq)
            {
                const int64_t q_base = (head_row0 + sq) * hdim;

                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen; ++sk)
                {
                    const int64_t k_base = (head_row0 + sk) * hdim;

                    float dot = 0.0f;
                    for(int d = 0; d < hdim; ++d)
                    {
                        dot += Q[q_base + d] * K[k_base + d];
                    }

                    // Bias grows linearly with the signed distance between key and query.
                    scores[sk] = dot * scale + slope * static_cast<float>(sk - sq);
                    max_score  = std::max(max_score, scores[sk]);
                }

                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                LSE[head_row0 + sq] = max_score + std::log(sum_exp);

                for(int sk = 0; sk < seqlen; ++sk)
                    scores[sk] /= sum_exp;

                // O[sq, dv] = sum_sk P[sk] * V[sk, dv]
                for(int dv = 0; dv < hdim; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen; ++sk)
                    {
                        acc += scores[sk] * V[(head_row0 + sk) * hdim + dv];
                    }
                    O[q_base + dv] = acc;
                }
            }
        }
    }
}
} // namespace
// Example 29 driver.
//
// Steps, as implemented below:
//   1.  Register the generated kernels for the requested architecture.
//   2.  Plan the backward pass with alibi bias + dropout, non-deterministic.
//   2b. Plan the same backward pass with is_deterministic = true.
//   3.  Run the forward kernel on the GPU with alibi bias and dropout (p_drop = 0.2).
//   4.  Sanity-check the forward output (finite, not all zero) and LSE, and report the LSE
//       difference against a dropout-free CPU reference.
//   5.  Print the backward traits/args pattern.
//
// Returns 0 when the forward sanity checks pass (and also when args.parse() returns false),
// 1 otherwise.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 29: FMHA Backward with Bias + Dropout",
                     "ALiBi bias + dropout forward (GPU) + backward plan with deterministic mode");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "2", "Batch size");
    args.add_option("--nhead", "4", "Number of heads");
    args.add_option("--seqlen", "64", "Sequence length");
    args.add_option("--hdim", "128", "Head dimension");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 2);
    const int nhead            = args.get_int("--nhead", 4);
    const int seqlen           = args.get_int("--seqlen", 64);
    const int hdim             = args.get_int("--hdim", 128);
    // Softmax scale 1/sqrt(hdim).
    const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));

    print_header("Example 29: FMHA Backward with ALiBi Bias + Dropout");

    // Step 1: Register kernels
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaKernelSetRegistry::instance().print();
    FmhaRegistry registry;
    registry.set_name("bwd_bias_dropout_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";
    FmhaDispatcher dispatcher(&registry);
    dispatcher.set_benchmarking(true);
    // NOTE(review): arguments are presumably (warmup, repeat) -- confirm against FmhaDispatcher.
    dispatcher.set_timing(1, 3);

    // Step 2: Plan backward (non-deterministic) with alibi + dropout
    std::cout << "\nStep 2: Plan Backward (non-deterministic, alibi + dropout)\n";
    fmha_bwd_traits bwd_traits{};
    bwd_traits.hdim_q           = hdim;
    bwd_traits.hdim_v           = hdim;
    bwd_traits.data_type        = "fp16";
    bwd_traits.is_group_mode    = false;
    bwd_traits.mask_type        = mask_enum::no_mask;
    bwd_traits.bias_type        = bias_enum::alibi;
    bwd_traits.has_dbias        = false;
    bwd_traits.has_dropout      = true;
    bwd_traits.is_store_randval = false;
    bwd_traits.is_deterministic = false;

    // Only the problem dimensions are filled in; no device pointers are set because the
    // backward path is planned here, not executed.
    fmha_bwd_args bwd_args{};
    bwd_args.batch        = batch;
    bwd_args.seqlen_q     = seqlen;
    bwd_args.seqlen_k     = seqlen;
    bwd_args.max_seqlen_q = seqlen;
    bwd_args.max_seqlen_k = seqlen;
    bwd_args.hdim_q       = hdim;
    bwd_args.hdim_v       = hdim;
    bwd_args.nhead_q      = nhead;
    bwd_args.nhead_k      = nhead;

    auto nondet_plan = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
    if(nondet_plan.is_valid() && nondet_plan.stages.size() >= 2)
    {
        std::cout << " Non-deterministic plan stages (" << nondet_plan.stages.size() << "):\n";
        for(const auto& stage : nondet_plan.stages)
        {
            std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
        }
    }
    else
    {
        std::cout << " Non-deterministic plan: INVALID or single-stage\n";
    }

    // Step 2b: Plan backward (deterministic) with alibi + dropout
    std::cout << "\nStep 2b: Plan Backward (deterministic, alibi + dropout)\n";
    // Same traits as above with only the deterministic flag flipped.
    fmha_bwd_traits det_traits  = bwd_traits;
    det_traits.is_deterministic = true;
    auto det_plan               = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch));
    if(det_plan.is_valid() && det_plan.stages.size() >= 2)
    {
        std::cout << " Deterministic plan stages (" << det_plan.stages.size() << "):\n";
        for(const auto& stage : det_plan.stages)
        {
            std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
        }
    }
    else
    {
        std::cout << " Deterministic plan: INVALID or single-stage\n";
    }
    std::cout << "\n Deterministic mode difference:\n";
    std::cout << " Non-det: dQ accumulated via atomic adds (faster, non-reproducible)\n";
    std::cout << " Det: dQ accumulated with split-stride (slower, bit-reproducible)\n";

    // Step 3: Run forward on GPU with alibi bias + dropout
    std::cout << "\nStep 3: Run Forward (alibi + dropout, GPU)\n";
    const int64_t qkv_elems     = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
    const int64_t lse_elems     = static_cast<int64_t>(batch) * nhead * seqlen;
    const int64_t randval_elems = static_cast<int64_t>(batch) * nhead * seqlen * seqlen;
    GpuBuffer<FmhaDataType> q_dev(qkv_elems);
    GpuBuffer<FmhaDataType> k_dev(qkv_elems);
    GpuBuffer<FmhaDataType> v_dev(qkv_elems);
    GpuBuffer<FmhaDataType> o_dev(qkv_elems);
    GpuBuffer<float> lse_dev(lse_elems);
    GpuBuffer<uint8_t> rand_val_dev(randval_elems);

    // ALiBi slopes: geometric series, slope[h] = -2^(-8 * (h + 1) / nhead)
    std::vector<float> alibi_slopes_host(nhead);
    for(int h = 0; h < nhead; ++h)
        alibi_slopes_host[h] = -std::pow(2.0f, -(8.0f * (h + 1) / nhead));
    GpuBuffer<float> alibi_slopes_dev(nhead);
    alibi_slopes_dev.copy_from_host(alibi_slopes_host.data());

    // Fixed seed so the inputs are identical on every run. Q, K, V are drawn in that order
    // from the same generator.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
    std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
    for(auto& x : q_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : k_host)
        x = FmhaDataType(dist(rng));
    for(auto& x : v_host)
        x = FmhaDataType(dist(rng));
    q_dev.copy_from_host(q_host.data());
    k_dev.copy_from_host(k_host.data());
    v_dev.copy_from_host(v_host.data());
    o_dev.zero();
    lse_dev.zero();
    rand_val_dev.zero();

    std::cout << " ALiBi slopes: [";
    for(int h = 0; h < nhead; ++h)
    {
        if(h > 0)
            std::cout << ", ";
        std::cout << std::fixed << std::setprecision(4) << alibi_slopes_host[h];
    }
    std::cout << "]\n";

    fmha_fwd_traits fwd_traits{};
    fwd_traits.hdim_q              = hdim;
    fwd_traits.hdim_v              = hdim;
    fwd_traits.data_type           = "fp16";
    fwd_traits.is_group_mode       = false;
    fwd_traits.is_v_rowmajor       = true;
    fwd_traits.has_logits_soft_cap = false;
    fwd_traits.mask_type           = mask_enum::no_mask;
    fwd_traits.bias_type           = bias_enum::alibi;
    fwd_traits.has_lse             = true;
    fwd_traits.has_dropout         = true;
    fwd_traits.qscale_type         = quant_scale_enum::no_scale;

    fmha_fwd_args fwd_args{};
    // Device pointers. bias_ptr carries the per-head slopes; rand_val_ptr receives the
    // dropout random values (s_randval is set to true below).
    fwd_args.q_ptr                      = q_dev.get();
    fwd_args.k_ptr                      = k_dev.get();
    fwd_args.v_ptr                      = v_dev.get();
    fwd_args.o_ptr                      = o_dev.get();
    fwd_args.lse_ptr                    = lse_dev.get();
    fwd_args.bias_ptr                   = alibi_slopes_dev.get();
    fwd_args.rand_val_ptr               = rand_val_dev.get();
    fwd_args.q_descale_ptr              = nullptr;
    fwd_args.k_descale_ptr              = nullptr;
    fwd_args.v_descale_ptr              = nullptr;
    fwd_args.sink_ptr                   = nullptr;
    fwd_args.block_scale_seqstart_q_ptr = nullptr;
    fwd_args.block_scale_seqstart_k_ptr = nullptr;
    // Problem dimensions.
    fwd_args.seqlen_q        = seqlen;
    fwd_args.seqlen_k        = seqlen;
    fwd_args.batch           = batch;
    fwd_args.max_seqlen_q    = seqlen;
    fwd_args.hdim_q          = hdim;
    fwd_args.hdim_v          = hdim;
    fwd_args.nhead_q         = nhead;
    fwd_args.nhead_k         = nhead;
    fwd_args.scale_s         = scale;
    fwd_args.logits_soft_cap = 0.0f;
    // Strides for the dense [batch][nhead][seqlen][hdim] layout of Q/K/V/O and the
    // [batch][nhead][seqlen][seqlen] layout of the random-value buffer.
    fwd_args.stride_q               = hdim;
    fwd_args.stride_k               = hdim;
    fwd_args.stride_v               = hdim;
    fwd_args.stride_bias            = 0; // alibi: per-head slope, no spatial stride
    fwd_args.stride_randval         = seqlen;
    fwd_args.stride_o               = hdim;
    fwd_args.nhead_stride_q         = seqlen * hdim;
    fwd_args.nhead_stride_k         = seqlen * hdim;
    fwd_args.nhead_stride_v         = seqlen * hdim;
    fwd_args.nhead_stride_bias      = 1; // alibi: stride between slopes
    fwd_args.nhead_stride_randval   = seqlen * seqlen;
    fwd_args.nhead_stride_lse       = seqlen;
    fwd_args.nhead_stride_o         = seqlen * hdim;
    fwd_args.nhead_stride_q_descale = 0;
    fwd_args.nhead_stride_k_descale = 0;
    fwd_args.nhead_stride_v_descale = 0;
    fwd_args.batch_stride_q         = nhead * seqlen * hdim;
    fwd_args.batch_stride_k         = nhead * seqlen * hdim;
    fwd_args.batch_stride_v         = nhead * seqlen * hdim;
    fwd_args.batch_stride_bias      = 0; // alibi slopes shared across batch
    fwd_args.batch_stride_randval   = nhead * seqlen * seqlen;
    fwd_args.batch_stride_lse       = nhead * seqlen;
    fwd_args.batch_stride_o         = nhead * seqlen * hdim;
    fwd_args.batch_stride_q_descale = 0;
    fwd_args.batch_stride_k_descale = 0;
    fwd_args.batch_stride_v_descale = 0;
    // No mask: both window sizes are -1 and mask_type is 0.
    fwd_args.window_size_left    = -1;
    fwd_args.window_size_right   = -1;
    fwd_args.sink_size           = 0;
    fwd_args.mask_type           = 0;
    fwd_args.min_seqlen_q        = 0;
    fwd_args.p_drop              = 0.2f;
    fwd_args.s_randval           = true;
    fwd_args.drop_seed_offset    = std::make_pair(uint64_t(42), uint64_t(0));
    fwd_args.block_scale_size_q  = 0;
    fwd_args.block_scale_size_kv = 0;

    // Set once the kernel has run without throwing; refined by validation in Step 4.
    bool fwd_passed = false;
    try
    {
        float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
        std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time
                  << " ms\n";
        auto problem =
            FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
        // fwd_time is in milliseconds: ops / seconds / 1e12.
        double tflops = static_cast<double>(problem.num_ops()) / (fwd_time * 1e-3) / 1e12;
        std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
        fwd_passed = true;
    }
    catch(const std::exception& e)
    {
        std::cerr << " Forward ERROR: " << e.what() << "\n";
    }

    // Step 4: Validate forward output (without dropout reference -- just check non-zero + LSE)
    std::cout << "\nStep 4: Validate Forward Output\n";
    if(fwd_passed)
    {
        std::vector<FmhaDataType> o_host(qkv_elems);
        o_dev.copy_to_host(o_host.data());

        // NaN compares unequal to 0.0f, so a plain "!= 0" test would count NaN/Inf as valid
        // output. Non-finite values are tallied separately and fail validation.
        int nonzero   = 0;
        int nonfinite = 0;
        for(int64_t i = 0; i < qkv_elems; ++i)
        {
            const float out_val = static_cast<float>(o_host[i]);
            if(!std::isfinite(out_val))
                ++nonfinite;
            else if(out_val != 0.0f)
                ++nonzero;
        }
        std::cout << " Non-zero outputs: " << nonzero << " / " << qkv_elems << "\n";
        if(nonfinite > 0)
            std::cout << " Non-finite outputs: " << nonfinite << " / " << qkv_elems << "\n";

        std::vector<float> lse_host(lse_elems);
        lse_dev.copy_to_host(lse_host.data());
        int lse_reasonable = 0;
        for(int64_t i = 0; i < lse_elems; ++i)
        {
            if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f)
                ++lse_reasonable;
        }
        std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n";
        std::cout << " LSE sample [0..3]: ";
        for(int i = 0; i < std::min<int>(4, lse_elems); ++i)
            std::cout << std::fixed << std::setprecision(4) << lse_host[i] << " ";
        std::cout << "\n";
        fwd_passed = (nonzero > 0) && (nonfinite == 0) && (lse_reasonable == lse_elems);

        // ALiBi reference (without dropout) for sanity check on bias effect
        std::vector<float> q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems);
        for(int64_t i = 0; i < qkv_elems; ++i)
            q_f32[i] = static_cast<float>(q_host[i]);
        for(int64_t i = 0; i < qkv_elems; ++i)
            k_f32[i] = static_cast<float>(k_host[i]);
        for(int64_t i = 0; i < qkv_elems; ++i)
            v_f32[i] = static_cast<float>(v_host[i]);
        std::vector<float> o_ref(qkv_elems, 0.0f);
        std::vector<float> lse_ref(lse_elems, 0.0f);
        cpu_attention_fwd_alibi(q_f32,
                                k_f32,
                                v_f32,
                                o_ref,
                                lse_ref,
                                batch,
                                nhead,
                                seqlen,
                                hdim,
                                scale,
                                alibi_slopes_host);

        // LSE should be close (dropout doesn't change LSE in the CK implementation --
        // LSE is computed before dropout is applied to the attention weights)
        // The error is reported only; it is not part of the pass criterion.
        double max_lse_err = 0.0;
        for(int64_t i = 0; i < lse_elems; ++i)
            max_lse_err =
                std::max(max_lse_err, static_cast<double>(std::abs(lse_host[i] - lse_ref[i])));
        std::cout << " LSE vs alibi ref (no dropout) max error: " << std::scientific << max_lse_err
                  << "\n";
    }

    // Step 5: Show backward API pattern with all features
    std::cout << "\nStep 5: Backward API Pattern (all features)\n";
    std::cout << " bwd_traits.bias_type = alibi\n";
    std::cout << " bwd_traits.has_dropout = true\n";
    std::cout << " bwd_traits.is_store_randval = false\n";
    std::cout << " bwd_traits.has_dbias = false (alibi has no learnable params)\n";
    std::cout << "\n Non-deterministic plan: " << nondet_plan.stages.size() << " stage(s)\n";
    std::cout << " Deterministic plan: " << det_plan.stages.size() << " stage(s)\n";
    std::cout << "\n Key backward args for dropout:\n";
    std::cout << " bwd_args.p_drop = 0.2\n";
    std::cout << " bwd_args.p_undrop = 1.0 / (1.0 - p_drop) = 1.25\n";
    std::cout << " bwd_args.drop_seed_offset = {42, 0} (must match fwd)\n";

    print_separator();
    std::cout << "Status: " << (fwd_passed ? "PASS" : "FAIL") << "\n";
    print_separator();
    return fwd_passed ? 0 : 1;
}

View File

@@ -0,0 +1,449 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 30: FMHA Backward Benchmark
//
// Demonstrates:
// 1. Forward kernel for benchmark (with LSE for backward planning)
// 2. Multiple problem sizes: sweep batch x seqlen
// 3. GPU forward execution for each size with timing
// 4. Backward plan for each size
// 5. Summary table: Batch | SeqLen | Fwd(ms) | BwdPlan | FwdTFLOPS
//
// Backward kernels use planning only -- actual backward GPU execution requires
// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set for Example 30. Four entries are declared, each as
// .add(<FmhaSignature>, <FmhaAlgorithm>, <arch string>), all for "gfx950", fp16, batch mode,
// hdim 128, no mask, no bias, no dropout:
//   - one forward kernel with LSE enabled,
//   - three backward-stage kernels (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq), all with
//     deterministic(false).
DECL_FMHA_KERNEL_SET(bwd_bench_fmha_kernels,
// Forward: basic fp16 with LSE for backward
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Backward stage 1: dot(dO, O)
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Backward stage 2: dQ, dK, dV
// This is the only backward entry that also sets wave(), warp() and max_seq_len_q().
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(0),
"gfx950")
// Backward stage 3: convert dQ
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("no")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
// Result record for one (batch, seqlen) problem size in the benchmark sweep; one instance
// is appended to the results vector per size.
struct BenchResult
{
// Batch size of this problem (copied from the swept ProblemSize).
int batch;
// Sequence length of this problem (copied from the swept ProblemSize).
int seqlen;
// Forward time in milliseconds, averaged over the benchmark repetitions.
float fwd_ms;
// Forward throughput in TFLOPS, derived from the problem's op count and fwd_ms.
double fwd_tflops;
// NOTE(review): presumably the number of stages in the backward plan -- the assignment
// is outside this view; confirm.
int bwd_stages;
// True when the backward plan is valid and has at least two stages.
bool bwd_valid;
// True when the forward run completed and produced at least one non-zero output value;
// false if warmup or benchmarking threw.
bool fwd_passed;
};
} // namespace
int main(int argc, char* argv[])
{
ExampleArgs args("Example 30: FMHA Backward Benchmark",
"Sweep batch x seqlen, forward GPU + backward plan");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--nhead", "8", "Number of heads");
args.add_option("--hdim", "128", "Head dimension");
args.add_option("--warmup", "2", "Warmup iterations per size");
args.add_option("--repeat", "3", "Benchmark repetitions per size");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int nhead = args.get_int("--nhead", 8);
const int hdim = args.get_int("--hdim", 128);
const int warmup = args.get_int("--warmup", 2);
const int repeat = args.get_int("--repeat", 3);
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
print_header("Example 30: FMHA Backward Benchmark");
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry registry;
registry.set_name("bwd_bench_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
// Problem sizes to sweep
struct ProblemSize
{
int batch;
int seqlen;
};
ProblemSize sizes[] = {
{8, 128},
{4, 256},
{2, 512},
{1, 1024},
{1, 2048},
{1, 4096},
};
std::vector<BenchResult> results;
// Step 2: Sweep problem sizes
std::cout << "\nStep 2: Sweep Problem Sizes\n";
for(const auto& sz : sizes)
{
std::cout << "\n --- batch=" << sz.batch << ", seqlen=" << sz.seqlen << " ---\n";
const int64_t qkv_elems = static_cast<int64_t>(sz.batch) * nhead * sz.seqlen * hdim;
const int64_t lse_elems = static_cast<int64_t>(sz.batch) * nhead * sz.seqlen;
BenchResult res{};
res.batch = sz.batch;
res.seqlen = sz.seqlen;
// Allocate buffers
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
GpuBuffer<float> lse_dev(lse_elems);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
// Forward traits/args
fmha_fwd_traits fwd_traits{};
fwd_traits.hdim_q = hdim;
fwd_traits.hdim_v = hdim;
fwd_traits.data_type = "fp16";
fwd_traits.is_group_mode = false;
fwd_traits.is_v_rowmajor = true;
fwd_traits.has_logits_soft_cap = false;
fwd_traits.mask_type = mask_enum::no_mask;
fwd_traits.bias_type = bias_enum::no_bias;
fwd_traits.has_lse = true;
fwd_traits.has_dropout = false;
fwd_traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fwd_args{};
fwd_args.q_ptr = q_dev.get();
fwd_args.k_ptr = k_dev.get();
fwd_args.v_ptr = v_dev.get();
fwd_args.o_ptr = o_dev.get();
fwd_args.lse_ptr = lse_dev.get();
fwd_args.bias_ptr = nullptr;
fwd_args.q_descale_ptr = nullptr;
fwd_args.k_descale_ptr = nullptr;
fwd_args.v_descale_ptr = nullptr;
fwd_args.rand_val_ptr = nullptr;
fwd_args.sink_ptr = nullptr;
fwd_args.block_scale_seqstart_q_ptr = nullptr;
fwd_args.block_scale_seqstart_k_ptr = nullptr;
fwd_args.seqlen_q = sz.seqlen;
fwd_args.seqlen_k = sz.seqlen;
fwd_args.batch = sz.batch;
fwd_args.max_seqlen_q = sz.seqlen;
fwd_args.hdim_q = hdim;
fwd_args.hdim_v = hdim;
fwd_args.nhead_q = nhead;
fwd_args.nhead_k = nhead;
fwd_args.scale_s = scale;
fwd_args.logits_soft_cap = 0.0f;
fwd_args.stride_q = hdim;
fwd_args.stride_k = hdim;
fwd_args.stride_v = hdim;
fwd_args.stride_bias = 0;
fwd_args.stride_randval = 0;
fwd_args.stride_o = hdim;
fwd_args.nhead_stride_q = sz.seqlen * hdim;
fwd_args.nhead_stride_k = sz.seqlen * hdim;
fwd_args.nhead_stride_v = sz.seqlen * hdim;
fwd_args.nhead_stride_bias = 0;
fwd_args.nhead_stride_randval = 0;
fwd_args.nhead_stride_lse = sz.seqlen;
fwd_args.nhead_stride_o = sz.seqlen * hdim;
fwd_args.nhead_stride_q_descale = 0;
fwd_args.nhead_stride_k_descale = 0;
fwd_args.nhead_stride_v_descale = 0;
fwd_args.batch_stride_q = nhead * sz.seqlen * hdim;
fwd_args.batch_stride_k = nhead * sz.seqlen * hdim;
fwd_args.batch_stride_v = nhead * sz.seqlen * hdim;
fwd_args.batch_stride_bias = 0;
fwd_args.batch_stride_randval = 0;
fwd_args.batch_stride_lse = nhead * sz.seqlen;
fwd_args.batch_stride_o = nhead * sz.seqlen * hdim;
fwd_args.batch_stride_q_descale = 0;
fwd_args.batch_stride_k_descale = 0;
fwd_args.batch_stride_v_descale = 0;
fwd_args.window_size_left = -1;
fwd_args.window_size_right = -1;
fwd_args.sink_size = 0;
fwd_args.mask_type = 0;
fwd_args.min_seqlen_q = 0;
fwd_args.p_drop = 0.0f;
fwd_args.s_randval = false;
fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
fwd_args.block_scale_size_q = 0;
fwd_args.block_scale_size_kv = 0;
// Warmup
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 1);
try
{
for(int w = 0; w < warmup; ++w)
{
o_dev.zero();
lse_dev.zero();
dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
}
}
catch(const std::exception& e)
{
std::cerr << " Warmup ERROR: " << e.what() << "\n";
res.fwd_passed = false;
results.push_back(res);
continue;
}
// Benchmark
dispatcher.set_timing(0, 1);
float total_ms = 0.0f;
bool ok = true;
for(int r = 0; r < repeat; ++r)
{
o_dev.zero();
lse_dev.zero();
try
{
total_ms += dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
}
catch(const std::exception& e)
{
std::cerr << " Bench ERROR: " << e.what() << "\n";
ok = false;
break;
}
}
if(ok)
{
res.fwd_ms = total_ms / static_cast<float>(repeat);
auto problem =
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
res.fwd_tflops = static_cast<double>(problem.num_ops()) / (res.fwd_ms * 1e-3) / 1e12;
// Sanity check output
std::vector<FmhaDataType> o_host(qkv_elems);
o_dev.copy_to_host(o_host.data());
int nonzero = 0;
for(int64_t i = 0; i < qkv_elems; ++i)
{
if(static_cast<float>(o_host[i]) != 0.0f)
++nonzero;
}
res.fwd_passed = (nonzero > 0);
}
else
{
res.fwd_passed = false;
}
// Backward plan for this size
fmha_bwd_traits bwd_traits{};
bwd_traits.hdim_q = hdim;
bwd_traits.hdim_v = hdim;
bwd_traits.data_type = "fp16";
bwd_traits.is_group_mode = false;
bwd_traits.mask_type = mask_enum::no_mask;
bwd_traits.bias_type = bias_enum::no_bias;
bwd_traits.has_dbias = false;
bwd_traits.has_dropout = false;
bwd_traits.is_store_randval = false;
bwd_traits.is_deterministic = false;
fmha_bwd_args bwd_args{};
bwd_args.batch = sz.batch;
bwd_args.seqlen_q = sz.seqlen;
bwd_args.seqlen_k = sz.seqlen;
bwd_args.max_seqlen_q = sz.seqlen;
bwd_args.max_seqlen_k = sz.seqlen;
bwd_args.hdim_q = hdim;
bwd_args.hdim_v = hdim;
bwd_args.nhead_q = nhead;
bwd_args.nhead_k = nhead;
auto bwd_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
res.bwd_valid = bwd_plan.is_valid() && bwd_plan.stages.size() >= 2;
res.bwd_stages = static_cast<int>(bwd_plan.stages.size());
std::cout << " Fwd: " << std::fixed << std::setprecision(4) << res.fwd_ms << " ms, "
<< std::setprecision(2) << res.fwd_tflops << " TFLOPS"
<< " | Bwd plan: " << res.bwd_stages << " stages"
<< (res.bwd_valid ? " (valid)" : " (invalid)") << "\n";
results.push_back(res);
}
// Step 3: Summary table
std::cout << "\nStep 3: Summary\n\n";
std::cout << " " << std::setw(7) << "Batch" << " | " << std::setw(7) << "SeqLen" << " | "
<< std::setw(10) << "Fwd(ms)" << " | " << std::setw(8) << "BwdPlan" << " | "
<< std::setw(10) << "FwdTFLOPS" << " | " << std::setw(6) << "Status" << "\n";
std::cout << " " << std::string(60, '-') << "\n";
bool all_passed = true;
for(const auto& r : results)
{
std::cout << " " << std::setw(7) << r.batch << " | " << std::setw(7) << r.seqlen << " | "
<< std::fixed << std::setprecision(4) << std::setw(10) << r.fwd_ms << " | "
<< std::setw(5) << r.bwd_stages << "stg" << " | " << std::setprecision(2)
<< std::setw(10) << r.fwd_tflops << " | " << std::setw(6)
<< (r.fwd_passed ? "PASS" : "FAIL") << "\n";
if(!r.fwd_passed)
all_passed = false;
}
print_separator();
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
print_separator();
return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,118 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 31: FMHA Forward with Logits Soft Cap
//
// Demonstrates forward kernel with logits_soft_cap enabled. The soft cap
// applies: scores_capped = tanh(scores/cap) * cap, which prevents extreme
// attention logits from causing numerical instability while preserving
// gradients. Planning only.
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set used by this example. It holds exactly one entry: an fp16,
// batch-mode forward signature for hdim 128 with row-major V ("r"), no mask,
// no bias, no LSE output, no dropout and no quant scale, with the logits flag
// switched on. The paired algorithm uses tile_m0/tile_n0 of 128, tile_k0 of 32
// and the "qr_async" pipeline, and the entry is tagged for "gfx950".
DECL_FMHA_KERNEL_SET(logits_soft_cap_fmha_kernels,
// Forward with logits soft cap: tanh(scores/cap)*cap
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no")
.logits(true), // enables logits_soft_cap path
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
int main(int argc, char* argv[])
{
ExampleArgs args("Example 31: FMHA Logits Soft Cap", "Forward with tanh(scores/cap)*cap");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen", "128", "Sequence length");
args.add_option("--hdim", "128", "Head dimension");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
print_header("Example 31: FMHA Logits Soft Cap");
std::cout << "\nStep 1: Register Kernels\n";
FmhaRegistry registry;
registry.set_name("logits_soft_cap_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
std::cout << "\nStep 2: Plan\n";
FmhaDispatcher dispatcher(&registry);
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = true; // runtime: cap > 0 means soft cap applied
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fwd_args{};
fwd_args.batch = batch;
fwd_args.seqlen_q = seqlen;
fwd_args.seqlen_k = seqlen;
fwd_args.nhead_q = nhead;
fwd_args.nhead_k = nhead;
fwd_args.hdim_q = hdim;
fwd_args.hdim_v = hdim;
fwd_args.logits_soft_cap = 30.0f; // cap value; apply tanh(scores/30)*30
auto plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch));
std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n";
std::cout << "\nStep 3: Logits Soft Cap\n";
std::cout << " Formula: scores_capped = tanh(scores/cap) * cap\n";
std::cout << " Prevents extreme logits while preserving gradients.\n";
print_separator();
return 0;
}

View File

@@ -0,0 +1,119 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 32: FMHA Forward with Sink Tokens
//
// Demonstrates forward kernel with sink tokens enabled. Sink tokens keep the
// first K positions always visible to all queries (StreamingLLM-style). Used
// with causal mask: positions [0, sink_size) are never masked. Planning only.
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set used by this example. It holds exactly one entry: an fp16,
// batch-mode forward signature for hdim 128 with row-major V ("r"), the
// "top_left" mask, no bias, no LSE output, no dropout and no quant scale,
// with the sink flag switched on. The paired algorithm uses tile_m0/tile_n0
// of 128, tile_k0 of 32 and the "qr_async" pipeline, tagged for "gfx950".
DECL_FMHA_KERNEL_SET(sink_tokens_fmha_kernels,
// Forward with sink: first K tokens always visible
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("top_left") // causal required with sink
.bias("no")
.lse(false)
.dropout(false)
.qscale("no")
.sink(true), // enables sink token path
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
int main(int argc, char* argv[])
{
ExampleArgs args("Example 32: FMHA Sink Tokens", "Forward with first K tokens always visible");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen", "128", "Sequence length");
args.add_option("--hdim", "128", "Head dimension");
args.add_option("--sink", "4", "Number of sink tokens");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
const int sink_size = args.get_int("--sink", 4);
print_header("Example 32: FMHA Sink Tokens");
std::cout << "\nStep 1: Register Kernels\n";
FmhaRegistry registry;
registry.set_name("sink_tokens_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
std::cout << "\nStep 2: Plan\n";
FmhaDispatcher dispatcher(&registry);
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_sink = true;
traits.mask_type = mask_enum::mask_top_left;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fwd_args{};
fwd_args.batch = batch;
fwd_args.seqlen_q = seqlen;
fwd_args.seqlen_k = seqlen;
fwd_args.nhead_q = nhead;
fwd_args.nhead_k = nhead;
fwd_args.hdim_q = hdim;
fwd_args.hdim_v = hdim;
fwd_args.sink_size = sink_size;
auto plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch));
std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n";
std::cout << "\nStep 3: Sink Tokens\n";
std::cout << " First " << sink_size << " tokens always visible to all queries.\n";
std::cout << " Used with causal mask for StreamingLLM-style long-context.\n";
print_separator();
return 0;
}

View File

@@ -0,0 +1,256 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 33: FMHA Backward Deterministic vs Non-Deterministic
//
// Demonstrates two backward kernel sets: one deterministic (bit-identical
// results across runs) and one non-deterministic (faster, atomic reductions).
// Planning only.
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set used by this example. It holds seven entries, all fp16,
// batch-mode, hdim 128, "top_left" mask, no bias, no dropout, tagged "gfx950":
//   - one forward signature with LSE output enabled;
//   - three backward families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq)
//     declared with deterministic(true) and selection_rank(0);
//   - the same three backward families declared again with
//     deterministic(false) and selection_rank(1).
// Within each family the tile/wave/warp settings are the same in both groups;
// only the deterministic flag and the selection rank differ.
DECL_FMHA_KERNEL_SET(bwd_deterministic_fmha_kernels,
// Forward: causal + LSE for backward
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("top_left")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
// Backward: deterministic (bit-identical across runs)
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(true),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(true),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(0),
"gfx950")
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(true),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
// Backward: non-deterministic (faster, atomic reductions)
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(1),
"gfx950")
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(1),
"gfx950")
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(1),
"gfx950"));
int main(int argc, char* argv[])
{
ExampleArgs args("Example 33: FMHA Backward Deterministic",
"Deterministic vs non-deterministic backward");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen", "128", "Sequence length");
args.add_option("--hdim", "128", "Head dimension");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
print_header("Example 33: FMHA Backward Deterministic");
std::cout << "\nStep 1: Register Kernels\n";
FmhaRegistry registry;
registry.set_name("bwd_deterministic_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
std::cout << "\nStep 2: Plan (deterministic)\n";
FmhaDispatcher dispatcher(&registry);
fmha_bwd_traits det_traits{};
det_traits.hdim_q = hdim;
det_traits.hdim_v = hdim;
det_traits.data_type = "fp16";
det_traits.is_group_mode = false;
det_traits.mask_type = mask_enum::mask_top_left;
det_traits.bias_type = bias_enum::no_bias;
det_traits.has_dbias = false;
det_traits.has_dropout = false;
det_traits.is_store_randval = false;
det_traits.is_deterministic = true;
fmha_bwd_args bwd_args{};
bwd_args.batch = batch;
bwd_args.seqlen_q = seqlen;
bwd_args.seqlen_k = seqlen;
bwd_args.hdim_q = hdim;
bwd_args.hdim_v = hdim;
bwd_args.nhead_q = nhead;
bwd_args.nhead_k = nhead;
auto det_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch));
std::cout << " Deterministic plan valid: " << (det_plan.is_valid() ? "yes" : "no") << "\n";
std::cout << "\nStep 3: Plan (non-deterministic)\n";
det_traits.is_deterministic = false;
auto nondet_plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch));
std::cout << " Non-deterministic plan valid: " << (nondet_plan.is_valid() ? "yes" : "no")
<< "\n";
std::cout << "\nStep 4: Deterministic Mode\n";
std::cout << " deterministic=true: bit-identical across runs (reproducible).\n";
std::cout << " deterministic=false: faster, uses atomic reductions.\n";
print_separator();
return 0;
}

View File

@@ -0,0 +1,183 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 34: FMHA Backward with GQA (Grouped Query Attention)
//
// Demonstrates backward with nhead_q=8, nhead_k=2 (4:1 ratio). GQA is a
// runtime property: each KV head is shared by multiple Q heads. Backward
// handles head indexing via nhead_stride_dk/dv so dK/dV are reduced across
// the Q-head group. Planning only.
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set used by this example. It holds four entries, all fp16,
// batch-mode, hdim 128, "top_left" mask, no bias, no dropout, tagged "gfx950":
// one forward signature with LSE output enabled, followed by the three
// backward families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq), each with
// deterministic(false) and selection_rank(0). Head counts do not appear in
// any signature here.
DECL_FMHA_KERNEL_SET(bwd_gqa_fmha_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("top_left")
.bias("no")
.lse(true)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950")
.add(FmhaSignature()
.family("bwd_dot_do_o")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(32)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950")
.add(FmhaSignature()
.family("bwd_dq_dk_dv")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(16)
.tile_n0(128)
.tile_k0(128)
.tile_n1(16)
.tile_k1(128)
.tile_k0max(32)
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
.padding(true, true, true, true)
.max_seq_len_q(0)
.selection_rank(0),
"gfx950")
.add(FmhaSignature()
.family("bwd_convert_dq")
.dtype("fp16")
.mode("batch")
.hdim(128)
.mask("top_left")
.bias("no")
.dropout(false)
.dbias(false)
.store_randval(false)
.deterministic(false),
FmhaAlgorithm()
.tile_m0(64)
.tile_n0(128)
.tile_k0(0)
.tile_n1(0)
.tile_k1(0)
.tile_k0max(0)
.padding(true, true, true, true)
.selection_rank(0),
"gfx950"));
// Entry point for the GQA backward planning example.
//
// Parses the command line, registers the generated kernels for the requested
// architecture, asks the dispatcher for a backward plan with separate Q and
// KV head counts, and prints whether the plan is valid.
//
// Returns 0 after a completed run (or when argument parsing returns false),
// and 1 when the requested head counts cannot form a GQA grouping.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 34: FMHA Backward GQA", "nhead_q=8, nhead_k=2 (4:1 ratio)");
    args.add_option("--arch", "gfx950", "GPU architecture");
    args.add_option("--batch", "2", "Batch size");
    args.add_option("--nhead_q", "8", "Query heads");
    args.add_option("--nhead_k", "2", "KV heads (GQA ratio = nhead_q/nhead_k)");
    args.add_option("--seqlen", "128", "Sequence length");
    args.add_option("--hdim", "128", "Head dimension");
    if(!args.parse(argc, argv))
        return 0;

    const std::string gfx_arch = args.get("--arch", "gfx950");
    const int batch            = args.get_int("--batch", 2);
    const int nhead_q          = args.get_int("--nhead_q", 8);
    const int nhead_k          = args.get_int("--nhead_k", 2);
    const int seqlen           = args.get_int("--seqlen", 128);
    const int hdim             = args.get_int("--hdim", 128);

    // The summary below divides nhead_q by nhead_k. Reject a non-positive
    // nhead_k (integer division by zero is undefined behaviour) and a ratio
    // that is not a whole number (the "shared by N Q heads" line would be
    // wrong). The short-circuit keeps the modulo from running when nhead_k <= 0.
    if(nhead_k <= 0 || nhead_q % nhead_k != 0)
    {
        std::cerr << "ERROR: --nhead_k must be positive and evenly divide --nhead_q (got nhead_q="
                  << nhead_q << ", nhead_k=" << nhead_k << ")\n";
        return 1;
    }

    print_header("Example 34: FMHA Backward GQA");

    // Step 1: create a named registry and hand it to the registration macro.
    std::cout << "\nStep 1: Register Kernels\n";
    FmhaRegistry registry;
    registry.set_name("bwd_gqa_fmha");
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";

    // Step 2: describe the backward problem and ask the dispatcher for a plan.
    std::cout << "\nStep 2: Plan (GQA nhead_q=" << nhead_q << ", nhead_k=" << nhead_k << ")\n";
    FmhaDispatcher dispatcher(&registry);

    fmha_bwd_traits traits{};
    traits.hdim_q           = hdim;
    traits.hdim_v           = hdim;
    traits.data_type        = "fp16";
    traits.is_group_mode    = false;
    traits.mask_type        = mask_enum::mask_top_left;
    traits.bias_type        = bias_enum::no_bias;
    traits.has_dbias        = false;
    traits.has_dropout      = false;
    traits.is_store_randval = false;
    traits.is_deterministic = false;

    fmha_bwd_args bwd_args{};
    bwd_args.batch    = batch;
    bwd_args.seqlen_q = seqlen;
    bwd_args.seqlen_k = seqlen;
    bwd_args.hdim_q   = hdim;
    bwd_args.hdim_v   = hdim;
    bwd_args.nhead_q  = nhead_q; // Q and KV head counts differ: this is what makes it GQA
    bwd_args.nhead_k  = nhead_k;

    auto plan = dispatcher.plan(
        FmhaProblem::from_invocation(FmhaInvocation::make(traits, bwd_args), gfx_arch));
    std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n";

    // Step 3: explanatory text. The division is safe thanks to the check above.
    std::cout << "\nStep 3: GQA Backward Head Indexing\n";
    std::cout << " Q heads " << nhead_q << ", KV heads " << nhead_k
              << " -> each KV head shared by " << (nhead_q / nhead_k) << " Q heads.\n";
    std::cout << " dK/dV reduced across Q-head group via nhead_stride.\n";
    print_separator();
    return 0;
}

View File

@@ -0,0 +1,121 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 35: FMHA Forward with Generic/Window Mask
//
// Demonstrates forward kernel with generic (window) mask. Uses
// window_size_left and window_size_right: for each query i, attend only to
// keys in [i - left, i + right]. -1 means unbounded. Planning only.
#include <iostream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// Kernel set used by this example. It holds exactly one entry: an fp16,
// batch-mode forward signature for hdim 128 with row-major V ("r"), the
// "generic" mask, no bias, no LSE output, no dropout and no quant scale.
// The paired algorithm uses tile_m0/tile_n0 of 128, tile_k0 of 32 and the
// "qr_async" pipeline, and the entry is tagged for "gfx950".
DECL_FMHA_KERNEL_SET(generic_mask_fmha_kernels,
// Forward with generic/window mask
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r")
.hdim(128)
.mask("generic") // window mask via left/right
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true)
.alignments(128, 128)
.selection_rank(0),
"gfx950"));
int main(int argc, char* argv[])
{
ExampleArgs args("Example 35: FMHA Generic Mask", "Window mask via left/right params");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen", "128", "Sequence length");
args.add_option("--hdim", "128", "Head dimension");
args.add_option("--window_left", "64", "Window size left (-1=unbounded)");
args.add_option("--window_right", "0", "Window size right (-1=unbounded)");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen = args.get_int("--seqlen", 128);
const int hdim = args.get_int("--hdim", 128);
const int window_size_left = args.get_int("--window_left", 64);
const int window_size_right = args.get_int("--window_right", 0);
print_header("Example 35: FMHA Generic Mask");
std::cout << "\nStep 1: Register Kernels\n";
FmhaRegistry registry;
registry.set_name("generic_mask_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
std::cout << "\nStep 2: Plan\n";
FmhaDispatcher dispatcher(&registry);
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.mask_type = mask_enum::window_generic;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
fmha_fwd_args fwd_args{};
fwd_args.batch = batch;
fwd_args.seqlen_q = seqlen;
fwd_args.seqlen_k = seqlen;
fwd_args.nhead_q = nhead;
fwd_args.nhead_k = nhead;
fwd_args.hdim_q = hdim;
fwd_args.hdim_v = hdim;
fwd_args.window_size_left = window_size_left;
fwd_args.window_size_right = window_size_right;
auto plan = dispatcher.plan(
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch));
std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n";
std::cout << "\nStep 3: Window Mask Params\n";
std::cout << " window_size_left=" << window_size_left
<< ", window_size_right=" << window_size_right << "\n";
std::cout << " Query i attends to keys in [i-left, i+right]. -1 = unbounded.\n";
print_separator();
return 0;
}

View File

@@ -0,0 +1,259 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 01: Basic FMHA with Multiple Kernels
Demonstrates:
1. Building a Registry with multiple kernel configurations
2. Parallel JIT compilation via registry.build()
3. Running each kernel and validating output against CPU reference
4. Comparing performance across kernels
Usage:
python3 01_basic_fmha.py
python3 01_basic_fmha.py --dtype bf16
python3 01_basic_fmha.py --size 256
python3 01_basic_fmha.py --num-kernels 4
python3 01_basic_fmha.py --workers 4
"""
import sys
import time
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelSpec,
FmhaRegistry,
FmhaProblem,
cpu_attention_fwd,
detect_gpu_arch,
spec_to_config,
)
# FmhaKernelSpec fields:
# name -- human-readable kernel identifier
# hdim -- head dimension (hdim_q = hdim_v for symmetric attention)
# pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
# tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
# tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
# tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
#
# spec_to_config() fills in Stage 1 automatically:
# tile_n1 = hdim, tile_k1 = tile_k0, tile_k0max = hdim
# wave/warp use sensible defaults (4x1x1 wave, 32x32x16 warp)
# Kernel configurations exercised by this example, all at hdim=128.
# Each table row is (name, pipeline, tile_m0, tile_n0, tile_k0); the rows are
# expanded into FmhaKernelSpec objects in the order listed.
KERNEL_SPECS = [
    FmhaKernelSpec(
        name=label,
        hdim=128,
        pipeline=pipe,
        tile_m0=m0,
        tile_n0=n0,
        tile_k0=k0,
    )
    for label, pipe, m0, n0, k0 in (
        # Async pipelines -- different tile_m0 x tile_n0 combos
        ("async_128x128_k32", "qr_async", 128, 128, 32),
        ("async_128x64_k32", "qr_async", 128, 64, 32),
        ("async_64x128_k32", "qr_async", 64, 128, 32),
        ("async_64x64_k32", "qr_async", 64, 64, 32),
        # Synchronous pipelines
        ("sync_128x128_k32", "qr", 128, 128, 32),
        ("sync_64x128_k32", "qr", 64, 128, 32),
        ("sync_128x64_k32", "qr", 128, 64, 32),
        # Different tile_k0 (K dimension of Q*K^T)
        ("async_128x128_k64", "qr_async", 128, 128, 64),
        ("async_64x128_k64", "qr_async", 64, 128, 64),
    )
]
def main():
    """Build, run and validate the configured FMHA kernels.

    Registers the selected entries of KERNEL_SPECS, builds them through
    ``FmhaRegistry.build``, runs each built kernel on one random problem and
    compares its output against ``cpu_attention_fwd``.

    Returns:
        int: 0 when every selected kernel built, ran and stayed below the
        1e-2 max-abs-error threshold; 1 otherwise (including when no kernel
        could be built). Skipped kernels count as failures.
    """
    parser = argparse.ArgumentParser(description="Basic FMHA with Multiple Kernels")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--size", type=int, default=128, help="Sequence length")
    parser.add_argument("--num-kernels", type=int, default=0, help="0 = all")
    parser.add_argument(
        "--workers", type=int, default=0, help="Max parallel JIT workers (0 = auto)"
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Example 01: Basic FMHA with Multiple Kernels")
    print("=" * 70)

    # --num-kernels N keeps only the first N specs; 0 (or negative) keeps all.
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS

    # Step 1: Build registry
    print(
        f"\n {len(specs)} kernel configurations, dtype={args.dtype}, arch={args.arch}"
    )
    print(f"\n {'#':<3} {'Name':<24} {'Tile':<14} {'Pipeline':<12}")
    print(" " + "-" * 56)
    for i, s in enumerate(specs, 1):
        print(
            f" {i:<3} {s.name:<24} {s.tile_m0}x{s.tile_n0}x{s.tile_k0:<6} {s.pipeline:<12}"
        )
    reg = FmhaRegistry(name="basic_fmha")
    for s in specs:
        reg.register_kernel(spec_to_config(s, args.dtype, args.arch))

    # Step 2: Parallel JIT build via registry.build()
    # workers=None is passed through as max_workers when --workers is 0.
    workers = args.workers if args.workers > 0 else None
    print(
        f"\n--- Parallel JIT Build ({len(specs)} kernels{f', workers={workers}' if workers else ''}) ---"
    )
    t0 = time.perf_counter()
    setups = reg.build(verbose=False, max_workers=workers)
    jit_build_s = time.perf_counter() - t0
    built = sum(1 for s in setups if s.success)
    print(f" Built: {built}/{len(specs)} kernels in {jit_build_s:.1f} s")
    if built == 0:
        print(" ERROR: No kernels built")
        return 1

    # Step 3: Run each kernel and validate
    seqlen = args.size
    prob = FmhaProblem(
        batch=2,
        nhead_q=8,
        nhead_k=8,
        seqlen_q=seqlen,
        seqlen_k=seqlen,
        hdim_q=128,
        hdim_v=128,
    )
    print(
        f"\n--- Running Kernels (B={prob.batch} H={prob.nhead_q} S={seqlen} D={prob.hdim_q}) ---"
    )

    # Fixed seed so every kernel (and every invocation) sees the same inputs.
    # NOTE(review): inputs are always generated as float16, even when
    # --dtype bf16 is selected -- confirm the runner converts as needed.
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
    K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
    V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)
    O_ref = cpu_attention_fwd(
        Q.astype(np.float32),
        K.astype(np.float32),
        V.astype(np.float32),
        prob.scale,
    )

    print(
        f"\n {'#':<3} {'Name':<24} {'Pipeline':<12} {'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':<6}"
    )
    print(" " + "-" * 80)

    # Each result row is (name, passed, time_ms, tflops, max_err).
    results = []
    # NOTE(review): pairing by zip assumes reg.build() returns one setup per
    # registered spec, in registration order -- confirm against FmhaRegistry.
    for i, (spec, setup) in enumerate(zip(specs, setups), 1):
        if not setup.success or setup.runner is None:
            print(
                f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}"
            )
            results.append((spec.name, False, 0.0, 0.0, 0.0))
            continue
        try:
            res = setup.runner.run(Q, K, V, prob)
            if not res.success:
                print(
                    f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {'---':>10} {'---':>10} {'---':>10} {'FAIL':<6}"
                )
                results.append((spec.name, False, 0.0, 0.0, 0.0))
                continue
            max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max())
            ok = max_err < 1e-2
            tag = "PASS" if ok else "FAIL"
            print(
                f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {res.time_ms:>10.4f} {res.tflops:>10.2f} {max_err:>10.2e} {tag:<6}"
            )
            results.append((spec.name, ok, res.time_ms, res.tflops, max_err))
        finally:
            # Release the runner on every path: success, a failed run
            # (the `continue` above), or an exception raised by run().
            setup.runner.cleanup()

    # Step 4: Summary
    passed = sum(1 for r in results if r[1])
    failed = len(results) - passed
    valid = [r for r in results if r[1]]
    print("\n" + "=" * 70)
    print(f" Results: {passed}/{len(results)} passed")
    print(
        f" Problem: B={prob.batch} H={prob.nhead_q} S={seqlen} D={prob.hdim_q}, dtype={args.dtype}"
    )
    print(f" JIT time: {jit_build_s:.1f} s (parallel)")
    if valid:
        # Best = highest TFLOPS among the kernels that passed validation.
        best = max(valid, key=lambda x: x[3])
        print(f" Best: {best[0]} ({best[3]:.2f} TFLOPS)")
    print(f" Status: {'PASS' if failed == 0 else 'FAIL'}")
    print("=" * 70)
    return 0 if failed == 0 else 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 02: Multi-Shape FMHA
Runs FMHA forward with a single kernel across multiple problem shapes
(varying batch, sequence length, and head count).
Usage:
python3 02_multi_shape.py
python3 02_multi_shape.py --help
python3 02_multi_shape.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelSpec,
FmhaProblem,
detect_gpu_arch,
setup_fmha_dispatcher,
spec_to_config,
)
def main():
    """Run one JIT-built FMHA forward kernel across several problem shapes.

    Parses ``--dtype`` / ``--arch``, builds a single ``qr_async`` hdim-128
    kernel through ``setup_fmha_dispatcher``, runs it on every entry of the
    shape table and prints per-shape timing plus an aggregate TFLOPS figure.

    Returns:
        int: 0 when every shape ran successfully; 1 when the dispatcher could
        not be set up or at least one shape reported a failure.
    """
    parser = argparse.ArgumentParser(
        description="Multi-Shape FMHA Example - runs multiple shapes",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 02_multi_shape.py # Default FP16
python3 02_multi_shape.py --dtype bf16 # BF16 FMHA
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch",
        default=detect_gpu_arch(),
        help="Target architecture (auto-detected from rocminfo)",
    )
    args = parser.parse_args()
    print("=" * 70)
    print("Example 02: Multi-Shape FMHA")
    print("=" * 70)

    # Step 1: Setup dispatcher
    print("\nStep 1: Setup Dispatcher")
    # FmhaKernelSpec fields:
    # name -- human-readable kernel identifier
    # hdim -- head dimension (hdim_q = hdim_v)
    # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
    # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
    # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
    # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
    spec = FmhaKernelSpec(name="multi_shape", hdim=128, pipeline="qr_async")
    config = spec_to_config(spec, dtype=args.dtype, arch=args.arch)
    setup = setup_fmha_dispatcher(config, verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    runner = setup.runner
    print(f" Library: {setup.library_path}")
    print(f" Build: {setup.build_time_s:.1f} s")

    # Step 2: Run batch of different shapes
    print("\nStep 2: Run Shapes")
    shapes = [
        # (batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim)
        (1, 4, 4, 64, 64, 128),
        (2, 8, 8, 128, 128, 128),
        (4, 8, 8, 128, 128, 128),
        (1, 16, 16, 256, 256, 128),
        (2, 8, 8, 256, 256, 128),
        (1, 8, 8, 512, 512, 128),
        (2, 4, 4, 512, 512, 128),
        (1, 8, 8, 1024, 1024, 128),
    ]
    print(f"\n {'#':<3} {'Shape':<36} {'Time(ms)':>10} {'TFLOPS':>10} {'Status':>8}")
    print(" " + "-" * 70)
    total_ops = 0
    total_time = 0.0
    failures = 0  # shapes whose run reported success == False
    try:
        for idx, (b, hq, hk, sq, sk, d) in enumerate(shapes, 1):
            prob = FmhaProblem(
                batch=b,
                nhead_q=hq,
                nhead_k=hk,
                seqlen_q=sq,
                seqlen_k=sk,
                hdim_q=d,
                hdim_v=d,
            )
            shape_str = f"B{b}_Hq{hq}_Hk{hk}_S{sq}x{sk}_D{d}"
            # A distinct seed per shape keeps the inputs reproducible.
            # NOTE(review): inputs are always float16, even with --dtype bf16;
            # presumably the runner converts -- confirm.
            np.random.seed(42 + idx)
            Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
            K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
            V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)
            result = runner.run(Q, K, V, prob)
            if result.success:
                total_ops += prob.num_ops
                total_time += result.time_ms
                print(
                    f" {idx:<3} {shape_str:<36} {result.time_ms:>10.4f} {result.tflops:>10.2f} {'OK':>8}"
                )
            else:
                failures += 1
                print(f" {idx:<3} {shape_str:<36} {'N/A':>10} {'N/A':>10} {'Error':>8}")
        print(" " + "-" * 70)
        if total_time > 0:
            # Aggregate throughput over the successful shapes only.
            avg_tflops = (total_ops / 1e12) / (total_time / 1000)
            print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS")
    finally:
        # Release the runner even if one of the runs raised.
        runner.cleanup()
    print("\n" + "=" * 70)
    print("Multi-Shape FMHA complete!")
    print("=" * 70)
    # Report failed shapes through the exit status so callers and CI notice.
    return 0 if failures == 0 else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,170 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 03: FMHA Benchmark
Performance benchmarking with warmup and repeated iterations across
multiple (batch, sequence length) configurations.
Usage:
python3 03_benchmark.py
python3 03_benchmark.py --help
python3 03_benchmark.py --warmup 5 --repeat 20
python3 03_benchmark.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelSpec,
FmhaProblem,
detect_gpu_arch,
setup_fmha_dispatcher,
spec_to_config,
)
def main():
    """Benchmark one FMHA forward kernel over a grid of (batch, seqlen) problems.

    Builds a single fp16 ``qr_async`` hdim-128 kernel, then for every grid
    entry performs ``--warmup`` untimed runs followed by ``--repeat`` timed
    runs and prints min/avg/max latency and the TFLOPS derived from the
    average. Returns 1 only when the dispatcher setup fails, otherwise 0.
    """
    cli = argparse.ArgumentParser(
        description="FMHA Benchmark Example - performance testing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 03_benchmark.py # Default benchmark suite
python3 03_benchmark.py --warmup 5 # More warmup iterations
python3 03_benchmark.py --repeat 20 # More benchmark iterations
""",
    )
    cli.add_argument(
        "--arch",
        default=detect_gpu_arch(),
        help="Target architecture (auto-detected from rocminfo)",
    )
    cli.add_argument(
        "--warmup", type=int, default=3, help="Warmup iterations (default: 3)"
    )
    cli.add_argument(
        "--repeat", type=int, default=10, help="Benchmark iterations (default: 10)"
    )
    opts = cli.parse_args()

    banner = "=" * 70
    print(banner)
    print("Example 03: FMHA Benchmark")
    print(banner)

    # Step 1: Setup dispatcher with compute-optimized config
    print("\nStep 1: Setup Dispatcher")
    # FmhaKernelSpec fields:
    # name -- human-readable kernel identifier
    # hdim -- head dimension (hdim_q = hdim_v)
    # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
    # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
    # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
    # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
    kernel_spec = FmhaKernelSpec(name="benchmark", hdim=128, pipeline="qr_async")
    kernel_cfg = spec_to_config(kernel_spec, dtype="fp16", arch=opts.arch)
    setup = setup_fmha_dispatcher(kernel_cfg, verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    runner = setup.runner
    print(f" Library: {setup.library_path}")
    print(f" Build: {setup.build_time_s:.1f} s")

    # Step 2: Benchmark
    print("\nStep 2: Benchmark")
    # Each batch size is paired with the longest sequence it is benchmarked at;
    # the grid is every ladder entry up to that cap, batch-major.
    seqlen_ladder = (128, 256, 512, 1024, 2048)
    longest_per_batch = ((1, 2048), (2, 1024), (4, 512), (8, 256))
    bench_configs = [
        (batch, seqlen)
        for batch, cap in longest_per_batch
        for seqlen in seqlen_ladder
        if seqlen <= cap
    ]
    print(f" Warmup: {opts.warmup}, Repeat: {opts.repeat}\n")
    print(
        f" {'Batch':>5} {'SeqLen':>7} | {'Min(ms)':>10} {'Avg(ms)':>10} {'Max(ms)':>10} | {'TFLOPS':>10}"
    )
    print(" " + "-" * 62)

    def draw(shape):
        # Small-magnitude float16 inputs from the global NumPy RNG.
        return (np.random.randn(*shape) * 0.1).astype(np.float16)

    all_tflops = []
    for batch, seqlen in bench_configs:
        prob = FmhaProblem(
            batch=batch,
            nhead_q=8,
            nhead_k=8,
            seqlen_q=seqlen,
            seqlen_k=seqlen,
            hdim_q=128,
            hdim_v=128,
        )
        np.random.seed(42)
        Q = draw(prob.q_shape())
        K = draw(prob.k_shape())
        V = draw(prob.v_shape())

        # Untimed warmup runs; their results are discarded.
        warmups_left = opts.warmup
        while warmups_left > 0:
            runner.run(Q, K, V, prob)
            warmups_left -= 1

        # Timed runs: keep only the latencies of successful launches.
        times = [
            outcome.time_ms
            for outcome in (runner.run(Q, K, V, prob) for _ in range(opts.repeat))
            if outcome.success
        ]
        if not times:
            print(
                f" {batch:>5} {seqlen:>7} | {'---':>10} {'---':>10} {'---':>10} | {'FAIL':>10}"
            )
            continue
        min_time = min(times)
        avg_time = sum(times) / len(times)
        max_time = max(times)
        tflops = prob.num_ops / (avg_time * 1e-3) / 1e12
        all_tflops.append(tflops)
        print(
            f" {batch:>5} {seqlen:>7} | {min_time:>10.4f} {avg_time:>10.4f} {max_time:>10.4f} | {tflops:>10.2f}"
        )
    runner.cleanup()

    # Summary
    print("\n" + banner)
    print("Summary")
    print(banner)
    if all_tflops:
        print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS")
        print(f" Peak: {max(all_tflops):.2f} TFLOPS")
    print(banner)
    return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 04: FMHA Validation
Validates GPU FMHA against CPU reference across multiple test cases
including standard shapes, GQA ratios, and edge cases.
Usage:
python3 04_validation.py
python3 04_validation.py --help
python3 04_validation.py --dtype bf16
python3 04_validation.py --rtol 1e-2
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelSpec,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
spec_to_config,
)
def main():
    """Validate the GPU FMHA forward kernel against the CPU reference.

    Builds one ``qr_async`` hdim-128 kernel, runs it on a table of test
    shapes (standard, non-square, GQA, single-query, batched) and checks each
    output with ``FmhaValidator`` against ``cpu_attention_fwd``.

    Returns:
        int: 0 when every case passed, 1 on setup failure or any failed case.
    """
    parser = argparse.ArgumentParser(
        description="FMHA Validation Example - validates GPU results against CPU",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 04_validation.py # Default FP16 validation
python3 04_validation.py --dtype bf16 # BF16 validation
python3 04_validation.py --rtol 1e-2 # Relaxed tolerance
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--rtol", type=float, default=1e-2, help="Relative tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--arch",
        default=detect_gpu_arch(),
        help="Target architecture (auto-detected from rocminfo)",
    )
    args = parser.parse_args()

    rule = "=" * 70
    print(rule)
    print("Example 04: FMHA Validation")
    print(rule)

    # Step 1: Setup dispatcher
    print("\nStep 1: Setup Dispatcher")
    # FmhaKernelSpec fields:
    # name -- human-readable kernel identifier
    # hdim -- head dimension (hdim_q = hdim_v)
    # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
    # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
    # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
    # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
    kernel_spec = FmhaKernelSpec(name="validation", hdim=128, pipeline="qr_async")
    kernel_cfg = spec_to_config(kernel_spec, dtype=args.dtype, arch=args.arch)
    setup = setup_fmha_dispatcher(kernel_cfg, verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    runner = setup.runner
    print(f" Library: {setup.library_path}")
    print(f" Build: {setup.build_time_s:.1f} s")

    # Step 2: Run validation tests
    print("\nStep 2: Validation Tests")
    validator = FmhaValidator(rtol=args.rtol, atol=args.atol)
    # (name, batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim)
    test_cases = [
        ("Small", 1, 4, 4, 64, 64, 128),
        ("Medium", 2, 8, 8, 128, 128, 128),
        ("Large", 1, 8, 8, 256, 256, 128),
        ("Long-seq", 1, 4, 4, 512, 512, 128),
        ("Non-square", 2, 4, 4, 64, 256, 128),
        ("GQA-2:1", 2, 8, 4, 128, 128, 128),
        ("GQA-4:1", 1, 16, 4, 128, 128, 128),
        ("GQA-8:1", 1, 16, 2, 64, 64, 128),
        ("Single-query", 1, 4, 4, 1, 128, 128),
        ("Batched", 4, 8, 8, 128, 128, 128),
    ]
    print(f"\n {'#':<3} {'Test':<14} {'Shape':<30} {'MaxErr':>10} {'Status':>8}")
    print(" " + "-" * 70)

    verdicts = []  # one bool per test case, True == passed
    for case_no, (label, b, hq, hk, sq, sk, d) in enumerate(test_cases, 1):
        prob = FmhaProblem(
            batch=b,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=sq,
            seqlen_k=sk,
            hdim_q=d,
            hdim_v=d,
        )
        shape_str = f"B{b}_Hq{hq}_Hk{hk}_S{sq}x{sk}"
        row_head = f" {case_no:<3} {label:<14} {shape_str:<30}"

        np.random.seed(42 + case_no)
        Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)

        result = runner.run(Q, K, V, prob)
        if not result.success:
            # The GPU launch itself failed; there is nothing to compare.
            print(f"{row_head} {'GPU Err':>10} {'FAILED':>8}")
            verdicts.append(False)
            continue

        O_ref = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            prob.scale,
        )
        is_valid, max_abs, _ = validator.check(result.output, O_ref)
        status = "PASSED" if is_valid else "FAILED"
        print(f"{row_head} {max_abs:>10.2e} {status:>8}")
        verdicts.append(True if is_valid else False)
    runner.cleanup()

    # Summary
    passed = sum(1 for good in verdicts if good)
    failed = len(verdicts) - passed
    print("\n" + rule)
    total = passed + failed
    print(f" Results: {passed}/{total} passed")
    print(f" Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}")
    print(f" Status: {'PASS' if failed == 0 else 'FAIL'}")
    print(rule)
    return 0 if failed == 0 else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,219 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 05: NumPy Integration
Shows how to create a GPU-accelerated attention wrapper that works
seamlessly with NumPy arrays, hiding all HIP memory management.
Usage:
python3 05_numpy_integration.py
python3 05_numpy_integration.py --help
python3 05_numpy_integration.py --seqlen 256
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def fmha_matmul(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float = None,
    runner=None,
) -> np.ndarray:
    """GPU-accelerated scaled dot-product attention via FMHA dispatcher.

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q] float16/float32
        K: [batch, nhead_k, seqlen_k, hdim_q] float16/float32
        V: [batch, nhead_k, seqlen_k, hdim_v] float16/float32
        scale: softmax scale. Currently NOT forwarded to the kernel: the
            problem's own scale is always used. Passing a value emits a
            RuntimeWarning instead of being silently ignored.
        runner: runner obtained from setup_fmha_dispatcher (required)

    Returns:
        O: the runner's output array for the problem
        (documented upstream as [batch, nhead_q, seqlen_q, hdim_v] float16)

    Raises:
        ValueError: if ``runner`` is None.
        RuntimeError: if the GPU run reports failure.
    """
    # Fail early with a clear message instead of an AttributeError on None.
    if runner is None:
        raise ValueError(
            "fmha_matmul requires a runner from setup_fmha_dispatcher (got runner=None)"
        )
    if scale is not None:
        # Local import: keeps this block self-contained.
        import warnings

        warnings.warn(
            "fmha_matmul: the 'scale' argument is currently ignored; "
            "the problem's default softmax scale is used",
            RuntimeWarning,
            stacklevel=2,
        )

    # Problem dimensions come from Q (query side) and V (key/value side).
    batch, nhead_q, seqlen_q, hdim_q = Q.shape
    _, nhead_k, seqlen_k, hdim_v = V.shape
    prob = FmhaProblem(
        batch=batch,
        nhead_q=nhead_q,
        nhead_k=nhead_k,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        hdim_q=hdim_q,
        hdim_v=hdim_v,
    )
    # The runner is always fed float16 inputs, whatever dtype was passed in.
    result = runner.run(
        Q.astype(np.float16), K.astype(np.float16), V.astype(np.float16), prob
    )
    if not result.success:
        raise RuntimeError(f"GPU FMHA failed: {result.error}")
    return result.output
def main():
    """Demonstrate the NumPy-facing ``fmha_matmul`` wrapper end to end.

    JIT-builds an fp16 FMHA kernel, runs a plain attention call and a GQA
    call (nhead_q=8, nhead_k=2) through ``fmha_matmul``, and compares both
    against ``cpu_attention_fwd`` with ``np.allclose``.

    Returns:
        int: 0 when both comparisons match, 1 on build failure or mismatch.
    """
    parser = argparse.ArgumentParser(
        description="NumPy Integration Example - GPU-accelerated attention wrapper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 05_numpy_integration.py # Default
python3 05_numpy_integration.py --seqlen 256 # Longer sequences
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=4)
    parser.add_argument("--seqlen", type=int, default=64)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--rtol", type=float, default=1e-2)
    parser.add_argument("--atol", type=float, default=1e-2)
    args = parser.parse_args()
    print("=" * 70)
    print("Example 05: NumPy Integration")
    print("=" * 70)

    # Step 1: JIT-compile FMHA kernel
    print("\nStep 1: JIT-Compile FMHA Dispatcher")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
        return 1
    runner = setup.runner
    print(f" JIT build: {setup.build_time_s:.1f}s")
    print(f" Arch: {args.arch}")
    np_dtype = np.float16

    def random_tensor(nhead):
        # [batch, nhead, seqlen, hdim] tensor drawn from the global NumPy RNG.
        return (
            np.random.randn(args.batch, nhead, args.seqlen, args.hdim) * 0.5
        ).astype(np_dtype)

    try:
        # Step 2: Demo -- simple attention call
        print("\n" + "=" * 70)
        print("Step 2: Simple Attention Call")
        print("=" * 70)
        np.random.seed(42)
        Q = random_tensor(args.nhead)
        K = random_tensor(args.nhead)
        V = random_tensor(args.nhead)
        out = fmha_matmul(Q, K, V, runner=runner)
        print(f" Q: {Q.shape} -> O: {out.shape}")
        print(f" Output range: [{out.min():.4f}, {out.max():.4f}]")
        print(f" Output sum: {out.sum():.4f}")

        # Step 3: Validate against CPU reference
        print("\n" + "=" * 70)
        print("Step 3: Validate Against CPU Reference")
        print("=" * 70)
        # The problem object is only needed here for its softmax scale.
        prob = FmhaProblem(
            batch=args.batch,
            nhead_q=args.nhead,
            nhead_k=args.nhead,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        O_ref = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            prob.scale,
        )
        diff = np.abs(out.astype(np.float32) - O_ref)
        max_abs = float(diff.max())
        # The epsilon keeps the relative error finite where the reference is 0.
        max_rel = float((diff / (np.abs(O_ref) + 1e-6)).max())
        match = np.allclose(
            out.astype(np.float32), O_ref, atol=args.atol, rtol=args.rtol
        )
        print(f" Max abs error: {max_abs:.6e}")
        print(f" Max rel error: {max_rel:.6e}")
        print(f" Match: {match}")

        # Step 4: Demo -- multi-head attention with GQA
        print("\n" + "=" * 70)
        print("Step 4: GQA Attention (nhead_q=8, nhead_k=2)")
        print("=" * 70)
        nhead_q, nhead_k = 8, 2
        Q_gqa = random_tensor(nhead_q)
        K_gqa = random_tensor(nhead_k)
        V_gqa = random_tensor(nhead_k)
        O_gqa = fmha_matmul(Q_gqa, K_gqa, V_gqa, runner=runner)
        prob_gqa = FmhaProblem(
            batch=args.batch,
            nhead_q=nhead_q,
            nhead_k=nhead_k,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        O_gqa_ref = cpu_attention_fwd(
            Q_gqa.astype(np.float32),
            K_gqa.astype(np.float32),
            V_gqa.astype(np.float32),
            prob_gqa.scale,
        )
        gqa_match = np.allclose(
            O_gqa.astype(np.float32), O_gqa_ref, atol=args.atol, rtol=args.rtol
        )
        print(f" Q: {Q_gqa.shape}, K: {K_gqa.shape}, V: {V_gqa.shape}")
        print(f" O: {O_gqa.shape}")
        print(f" Match: {gqa_match}")
    finally:
        # Release the JIT-built runner, including when fmha_matmul raised.
        runner.cleanup()

    # Summary
    print("\n" + "=" * 70)
    print("NumPy Integration Pattern:")
    print("=" * 70)
    print(" 1. setup = setup_fmha_dispatcher(config)")
    print(" 2. O = fmha_matmul(Q, K, V, runner=setup.runner)")
    print("=" * 70)
    return 0 if match and gqa_match else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,220 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 06: JSON Export
Builds an FMHA kernel via setup_fmha_dispatcher, then exports the
registry configuration to JSON for inspection or reuse.
Usage:
python3 06_json_export.py
python3 06_json_export.py --help
python3 06_json_export.py --output fmha_kernels.json
"""
import sys
import json
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from fmha_utils import (
FmhaKernelConfig,
setup_fmha_dispatcher,
detect_gpu_arch,
)
def main():
    """Export a set of FMHA kernel configurations to a JSON file.

    Defines three forward-kernel configs, JIT-builds the first one through
    ``setup_fmha_dispatcher`` (a failed build is reported but not fatal),
    writes all configs to ``--output`` as JSON and prints a preview.

    Returns:
        int: always 0.
    """
    parser = argparse.ArgumentParser(
        description="JSON Export Example - export FMHA registry to JSON",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 06_json_export.py # Default output
python3 06_json_export.py --output fmha_kernels.json # Custom file
""",
    )
    parser.add_argument(
        "--output",
        "-o",
        default="fmha_kernels.json",
        help="Output JSON file (default: fmha_kernels.json)",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    args = parser.parse_args()
    print("=" * 70)
    print("Example 06: JSON Export")
    print("=" * 70)

    # Step 1: Define FMHA kernel configurations
    print("\nStep 1: Define Kernel Configurations")
    # Settings shared by all three kernels.
    shared = dict(
        family="fwd",
        data_type="fp16",
        tile_k0=32,
        tile_k1=32,
        # Wave config per stage
        wave_m0=4,
        wave_n0=1,
        wave_k0=1,
        wave_m1=4,
        wave_n1=1,
        wave_k1=1,
        gfx_arch=args.arch,
    )
    # Warp tile per stage used by both qr_async kernels.
    warp_32 = dict(
        warp_m0=32, warp_n0=32, warp_k0=16, warp_m1=32, warp_n1=32, warp_k1=16
    )
    configs = [
        FmhaKernelConfig(
            hdim_q=128,
            hdim_v=128,
            pipeline="qr_async",
            # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q
            tile_m0=128,
            tile_n0=128,
            # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment
            tile_n1=128,
            tile_k0max=128,
            **warp_32,
            **shared,
        ),
        FmhaKernelConfig(
            hdim_q=128,
            hdim_v=128,
            pipeline="qr",
            tile_m0=64,
            tile_n0=128,
            tile_n1=128,
            tile_k0max=128,
            warp_m0=16,
            warp_n0=16,
            warp_k0=32,
            warp_m1=16,
            warp_n1=16,
            warp_k1=16,
            pad_s=False,
            pad_sk=False,
            pad_d=True,
            pad_dv=True,
            **shared,
        ),
        FmhaKernelConfig(
            hdim_q=64,
            hdim_v=64,
            pipeline="qr_async",
            tile_m0=128,
            tile_n0=64,
            tile_n1=64,
            tile_k0max=64,
            **warp_32,
            **shared,
        ),
    ]
    for i, cfg in enumerate(configs, 1):
        print(f" [{i}] {cfg.name}: pipeline={cfg.pipeline}, hdim={cfg.hdim_q}")

    # Step 2: Build via setup_fmha_dispatcher
    print("\n" + "=" * 70)
    print("Step 2: Build Kernel (JIT)")
    print("=" * 70)
    setup = setup_fmha_dispatcher(configs[0], verbose=True)
    if setup.success:
        print(f" Built: {setup.library_path}")
        print(f" Time: {setup.build_time_s:.1f} s")
    else:
        print(f" Build skipped/failed: {setup.error}")
        print(" (Proceeding with config export only)")
    # The export below only reads the config objects, so release the built
    # runner right away, as the other examples do.
    built_runner = getattr(setup, "runner", None)
    if built_runner is not None:
        built_runner.cleanup()

    # Step 3: Export to JSON
    print("\n" + "=" * 70)
    print("Step 3: Export to JSON")
    print("=" * 70)
    export_data = {
        "registry": "fmha_export",
        "arch": args.arch,
        "kernel_count": len(configs),
        "kernels": [
            {
                "name": cfg.name,
                "family": cfg.family,
                "data_type": cfg.data_type,
                "hdim_q": cfg.hdim_q,
                "hdim_v": cfg.hdim_v,
                "pipeline": cfg.pipeline,
                "tile": list(cfg.tile),
                "wave": list(cfg.wave),
                "warp": list(cfg.warp),
                "padding": list(cfg.padding),
                "mode": cfg.mode,
                "target": cfg.gfx_arch,
                "codegen_json": json.loads(cfg.to_codegen_json()),
            }
            for cfg in configs
        ],
    }
    json_str = json.dumps(export_data, indent=2)
    with open(args.output, "w") as f:
        f.write(json_str)
    print(f" Saved to: {args.output}")
    file_size = Path(args.output).stat().st_size
    print(f" File size: {file_size:,} bytes")
    print(f" Kernel count: {len(configs)}")

    # Step 4: Preview
    print("\n" + "=" * 70)
    print("Step 4: JSON Preview")
    print("=" * 70)
    preview = json_str[:500]
    if len(json_str) > 500:
        preview += "\n ..."
    print(preview)
    print("\n" + "=" * 70)
    print("JSON Export complete!")
    print("=" * 70)
    return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,256 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 07: Stress Test - Multiple FMHA Kernels with Validation
Generates many FmhaKernelSpec configurations across pipelines, head
dimensions, and data types, registers them in an FmhaRegistry, builds
all in parallel, and validates each against a CPU reference.
Usage:
python3 07_stress_test.py
python3 07_stress_test.py --help
python3 07_stress_test.py --num-kernels 4
python3 07_stress_test.py --workers 8
"""
import sys
import time
import argparse
from pathlib import Path
from typing import List
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelSpec,
FmhaProblem,
FmhaRegistry,
FmhaValidator,
cpu_attention_fwd,
spec_to_config,
detect_gpu_arch,
)
# FmhaKernelSpec fields:
# name -- human-readable kernel identifier
# hdim -- head dimension (hdim_q = hdim_v)
# pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
# tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
# tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
# tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
# Kernel pool for the stress test: both pipelines, two head dimensions and
# two seqlen_q tile sizes, ordered pipeline-major ("qr_async" first), then by
# hdim (128, 64), then by tile_m0 (128, 64). In every spec tile_n0 equals the
# head dimension and tile_k0 is 32.
KERNEL_SPECS: List[FmhaKernelSpec] = [
    FmhaKernelSpec(
        name=f"{pipeline}_h{hdim}_t{tile_m0}",
        hdim=hdim,
        pipeline=pipeline,
        tile_m0=tile_m0,
        tile_n0=hdim,
        tile_k0=32,
    )
    for pipeline in ("qr_async", "qr")
    for hdim in (128, 64)
    for tile_m0 in (128, 64)
]
def print_spec_table(specs: List[FmhaKernelSpec]):
    """Print the kernel specs as a fixed-width table on stdout.

    Emits a blank line, a header row, a dashed rule, one 1-based numbered row
    per spec (name, pipeline, hdim, tile_m0, tile_n0, tile_k0) and a closing
    dashed rule.
    """
    # One template drives both header and rows so the columns stay aligned.
    row_fmt = " {:<3} {:<25} {:<12} {:>5} {:>6} {:>6} {:>6}"
    rule = " " + "-" * 70
    print("\n" + row_fmt.format("#", "Name", "Pipeline", "Hdim", "TileM", "TileN", "TileK"))
    print(rule)
    for position, spec in enumerate(specs, 1):
        print(
            row_fmt.format(
                position,
                spec.name,
                spec.pipeline,
                spec.hdim,
                spec.tile_m0,
                spec.tile_n0,
                spec.tile_k0,
            )
        )
    print(rule)
def main():
    """Build a pool of FMHA kernels in parallel and validate each one.

    Registers the selected ``KERNEL_SPECS`` in an ``FmhaRegistry``, builds
    them, then runs every built kernel whose head dimension matches the fixed
    validation problem and checks the output against ``cpu_attention_fwd``.
    Every runner that was built is released, whether its kernel was
    validated, skipped for a head-dimension mismatch, or failed to run.

    Returns:
        int: 0 when no validated kernel failed, 1 when nothing could be built
        or at least one kernel failed.
    """
    parser = argparse.ArgumentParser(
        description="FMHA Stress Test - multiple kernels with validation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 07_stress_test.py # Test all kernels
python3 07_stress_test.py --num-kernels 4 # First 4 only
python3 07_stress_test.py --workers 8 # 8 parallel compile workers
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument(
        "--num-kernels", type=int, default=0, help="Number of kernels to test (0 = all)"
    )
    parser.add_argument(
        "--workers", type=int, default=0, help="Max parallel build workers (0 = auto)"
    )
    parser.add_argument("--rtol", type=float, default=1e-2)
    parser.add_argument("--atol", type=float, default=1e-2)
    args = parser.parse_args()
    print("=" * 70)
    print("Example 07: FMHA Stress Test - Multiple Kernels")
    print("=" * 70)
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS
    print(f"\n Arch: {args.arch}")
    print(f" Kernels: {len(specs)}")
    print_spec_table(specs)

    # Step 1: Register all in FmhaRegistry and build
    print("\n" + "=" * 70)
    print(" JIT BUILD")
    print("=" * 70)
    reg = FmhaRegistry("stress_test")
    for spec in specs:
        cfg = spec_to_config(spec, dtype="fp16", arch=args.arch)
        reg.register_kernel(cfg)
    # 0 on the command line means "let the registry choose" (max_workers=None).
    workers = args.workers if args.workers > 0 else None
    print(f"\n Building {len(reg)} kernels (workers={workers or 'auto'}) ...")
    t0 = time.perf_counter()
    build_results = reg.build(verbose=False, max_workers=workers)
    build_time = time.perf_counter() - t0
    built = sum(1 for r in build_results if r.success)
    print(f" Built: {built}/{len(specs)} in {build_time:.1f} s")
    for i, r in enumerate(build_results, 1):
        # str() keeps the report working if the error is not a plain string.
        tag = "OK" if r.success else f"FAIL: {str(r.error)[:50]}"
        name = r.config.name if r.config else f"kernel_{i}"
        print(f" [{i}] {name}: {tag}")
    if built == 0:
        print("\n No kernels built -- aborting")
        return 1

    # Step 2: Validate each built kernel
    print("\n" + "=" * 70)
    print(" VALIDATION")
    print("=" * 70)
    prob = FmhaProblem(
        batch=2, nhead_q=4, nhead_k=4, seqlen_q=64, seqlen_k=64, hdim_q=128, hdim_v=128
    )
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
    K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
    V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)
    O_ref = cpu_attention_fwd(
        Q.astype(np.float32), K.astype(np.float32), V.astype(np.float32), prob.scale
    )
    validator = FmhaValidator(rtol=args.rtol, atol=args.atol)
    print(
        f"\n Problem: B={prob.batch} Hq={prob.nhead_q} Sq={prob.seqlen_q} D={prob.hdim_q}"
    )
    print(f"\n {'#':<3} {'Name':<35} {'Time':>8} {'MaxErr':>10} {'Status':<6}")
    print(" " + "-" * 66)
    total_pass = 0
    total_fail = 0
    for i, r in enumerate(build_results, 1):
        name = r.config.name if r.config else f"kernel_{i}"
        if not r.success or r.runner is None:
            print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'SKIP':<6}")
            continue
        try:
            # Kernels built for a different head dimension cannot run the
            # fixed hdim-128 problem; they are skipped, not counted as failed.
            hdim = r.config.hdim_q if r.config else 128
            if hdim != prob.hdim_q:
                print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'SKIP':<6}")
                continue
            res = r.runner.run(Q, K, V, prob)
            if not res.success:
                print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'FAIL':<6}")
                total_fail += 1
                continue
            ok, max_abs, _ = validator.check(res.output, O_ref)
            tag = "PASS" if ok else "FAIL"
            print(
                f" {i:<3} {name:<35} {res.time_ms:>7.4f}ms {max_abs:>10.2e} {tag:<6}"
            )
            if ok:
                total_pass += 1
            else:
                total_fail += 1
        finally:
            # Release the runner on every path (skip, failure, success, error).
            r.runner.cleanup()

    # Summary
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    print(f"\n Total: {len(specs)}")
    print(f" Built: {built}")
    print(f" Passed: {total_pass}")
    print(f" Failed: {total_fail}")
    print(f" Build time: {build_time:.1f} s")
    print(f" Tolerance: rtol={args.rtol}, atol={args.atol}")
    if total_fail == 0 and total_pass > 0:
        print("\n *** ALL VALIDATED KERNELS PASSED ***")
    elif total_fail > 0:
        print(f"\n *** {total_fail} KERNELS FAILED ***")
    print("=" * 70)
    return 0 if total_fail == 0 else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,348 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 08: Kernel Selection Heuristics
Demonstrates how to build multiple FMHA kernels with different tile
sizes and select the best kernel for a given problem. Shows that
smaller tiles tend to be better for short sequences while larger tiles
are better for long sequences.
Usage:
python3 08_heuristics.py
python3 08_heuristics.py --help
python3 08_heuristics.py --arch gfx950
"""
import sys
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaRegistry,
detect_gpu_arch,
)
@dataclass
class TileProfile:
    """A kernel profile tagged with a human-readable label.

    Pairs one FmhaKernelConfig with the metadata this example uses to print
    the kernel and to pick it for a given sequence length.
    """

    # Display name printed in the profile and build listings.
    label: str
    # Kernel configuration registered with the FmhaRegistry.
    config: FmhaKernelConfig
    # Tile-size bucket matched by select_kernel_heuristic.
    category: str  # "small", "medium", "large"
def build_tile_profiles(arch: str) -> List[TileProfile]:
    """Create kernel configs with varying tile sizes.

    Returns four fp16 hdim-128 forward profiles for ``arch``, in this order:
    ``small_64x64``, ``medium_128x128``, ``large_128x256`` (all ``qr_async``)
    and ``medium_qr_128x128`` (``qr`` pipeline with explicit padding flags).
    """
    # Settings every profile has in common.
    shared = dict(
        family="fwd",
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        # Stage 0 (Q*K^T) K tile; stage 1 (Attn*V) tiles
        tile_k0=32,
        tile_n1=128,
        tile_k1=32,
        tile_k0max=128,
        # Wave config per stage
        wave_m0=4,
        wave_n0=1,
        wave_k0=1,
        wave_m1=4,
        wave_n1=1,
        wave_k1=1,
        gfx_arch=arch,
    )
    # Warp tile per stage: 16-wide for the small profile, 32-wide otherwise.
    warp_16 = dict(
        warp_m0=16, warp_n0=16, warp_k0=16, warp_m1=16, warp_n1=16, warp_k1=16
    )
    warp_32 = dict(
        warp_m0=32, warp_n0=32, warp_k0=16, warp_m1=32, warp_n1=32, warp_k1=16
    )
    qr_padding = dict(pad_s=False, pad_sk=False, pad_d=True, pad_dv=True)

    # (label, category, pipeline, tile_m0, tile_n0, per-profile extras)
    table = [
        ("small_64x64", "small", "qr_async", 64, 64, warp_16),
        ("medium_128x128", "medium", "qr_async", 128, 128, warp_32),
        ("large_128x256", "large", "qr_async", 128, 256, warp_32),
        ("medium_qr_128x128", "medium", "qr", 128, 128, {**warp_32, **qr_padding}),
    ]
    return [
        TileProfile(
            label=label,
            category=category,
            config=FmhaKernelConfig(
                pipeline=pipeline,
                tile_m0=tile_m0,
                tile_n0=tile_n0,
                **shared,
                **extras,
            ),
        )
        for label, category, pipeline, tile_m0, tile_n0, extras in table
    ]
def select_kernel_heuristic(seqlen: int, profiles: List[TileProfile]) -> TileProfile:
    """Simple heuristic: pick tile size category based on sequence length.

    Sequences up to 64 map to "small", up to 256 to "medium", anything longer
    to "large". The first profile of the wanted category is returned; when no
    profile has that category, the first profile overall is returned.
    """
    wanted = "small" if seqlen <= 64 else ("medium" if seqlen <= 256 else "large")
    matching = list(filter(lambda profile: profile.category == wanted, profiles))
    # An empty match list is falsy, so this falls back to the whole pool.
    return (matching or profiles)[0]
def main():
    """Build a pool of tile profiles, benchmark them, and demo the heuristic.

    Steps:
      1. Build one kernel per tile profile for the requested ``--arch``.
      2. Run every built kernel over a fixed list of sequence lengths and
         print a TFLOPS table with the best profile per row.
      3. For each sequence length, run only the profile chosen by
         ``select_kernel_heuristic`` and print its result.

    Returns:
        1 if no kernel could be built, otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description="FMHA Heuristics - kernel selection by problem size",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 08_heuristics.py
python3 08_heuristics.py --arch gfx950
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    args = parser.parse_args()

    print("=" * 70)
    print("Example 08: Kernel Selection Heuristics")
    print("=" * 70)

    # Step 1: Build kernel pool
    print("\nStep 1: Build Kernel Pool")
    profiles = build_tile_profiles(args.arch)
    reg = FmhaRegistry("heuristic_pool")
    for p in profiles:
        reg.register_kernel(p.config)
    print(f" Profiles: {len(profiles)}")
    for i, p in enumerate(profiles, 1):
        tile_str = f"{p.config.tile[0]}x{p.config.tile[1]}"
        print(
            f" [{i}] {p.label:<25} tile={tile_str:<10} pipeline={p.config.pipeline}"
        )

    print("\n Building kernels ...")
    build_results = reg.build(verbose=False)
    built = sum(1 for r in build_results if r.success)
    print(f" Built: {built}/{len(profiles)}")
    # Build results are reported in the same order the profiles were registered.
    for i, (p, r) in enumerate(zip(profiles, build_results), 1):
        tag = "OK" if r.success else f"FAIL: {r.error[:40]}"
        print(f" [{i}] {p.label}: {tag}")
    if built == 0:
        print(" No kernels built -- aborting")
        return 1

    # Step 2: Run each kernel on multiple sequence lengths
    print("\n" + "=" * 70)
    print("Step 2: Benchmark Across Sequence Lengths")
    print("=" * 70)
    test_seqlens = [32, 64, 128, 256, 512]
    # Table header: one column per profile plus a trailing "Best" column.
    print(f"\n {'SeqLen':>7}", end="")
    for p in profiles:
        print(f" | {p.label:>18}", end="")
    print(f" | {'Best':>18}")
    print(" " + "-" * (10 + 21 * len(profiles) + 22))

    for seqlen in test_seqlens:
        prob = FmhaProblem(
            batch=2,
            nhead_q=8,
            nhead_k=8,
            seqlen_q=seqlen,
            seqlen_k=seqlen,
            hdim_q=128,
            hdim_v=128,
        )
        # Re-seed so every sequence length uses reproducible inputs.
        np.random.seed(42)
        Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)

        row = f" {seqlen:>7}"
        best_tflops = 0.0
        best_label = "---"
        for p, r in zip(profiles, build_results):
            if not r.success or r.runner is None:
                row += f" | {'N/A':>18}"
                continue
            res = r.runner.run(Q, K, V, prob)
            if res.success:
                cell = f"{res.tflops:.2f} TFLOPS"
                row += f" | {cell:>18}"
                if res.tflops > best_tflops:
                    best_tflops = res.tflops
                    best_label = p.label
            else:
                row += f" | {'ERR':>18}"
        row += f" | {best_label:>18}"
        print(row)

    # Step 3: Demonstrate heuristic selection
    print("\n" + "=" * 70)
    print("Step 3: Heuristic Selection Demo")
    print("=" * 70)
    print(f"\n {'SeqLen':>7} {'Selected':>25} {'TFLOPS':>10} {'Status':<6}")
    print(" " + "-" * 55)
    for seqlen in test_seqlens:
        selected = select_kernel_heuristic(seqlen, profiles)
        # Map the chosen profile back to its build result by position.
        idx = profiles.index(selected)
        r = build_results[idx]
        if not r.success or r.runner is None:
            print(f" {seqlen:>7} {selected.label:>25} {'---':>10} {'SKIP':<6}")
            continue
        prob = FmhaProblem(
            batch=2,
            nhead_q=8,
            nhead_k=8,
            seqlen_q=seqlen,
            seqlen_k=seqlen,
            hdim_q=128,
            hdim_v=128,
        )
        np.random.seed(42)
        Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)
        res = r.runner.run(Q, K, V, prob)
        if res.success:
            print(f" {seqlen:>7} {selected.label:>25} {res.tflops:>10.2f} {'OK':<6}")
        else:
            print(f" {seqlen:>7} {selected.label:>25} {'---':>10} {'FAIL':<6}")

    # Cleanup
    for r in build_results:
        if r.runner:
            r.runner.cleanup()

    print("\n" + "=" * 70)
    print("Heuristic Insight:")
    print(" - Small tiles: low overhead for short sequences")
    print(" - Large tiles: high throughput for long sequences")
    print(" - Pipeline choice also matters (qr vs qr_async)")
    print("=" * 70)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,298 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 09: Multiple Registries
Creates separate FmhaRegistry instances for different optimization
targets (latency vs throughput), builds both, runs the same problem
through each, and compares results.
Usage:
python3 09_multi_registry.py
python3 09_multi_registry.py --help
python3 09_multi_registry.py --arch gfx950
"""
import sys
import time
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaRegistry,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
)
def make_latency_config(arch: str) -> FmhaKernelConfig:
    """Build the forward fp16 kernel config registered in the "latency" registry.

    Uses the "qr" pipeline with a 64x128 stage-0 tile and 16x16 warp tiles,
    with head-dimension padding enabled and sequence padding disabled.

    Args:
        arch: GPU architecture string passed through as ``gfx_arch``.
    """
    # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q
    # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment
    block_tiles = dict(
        tile_m0=64, tile_n0=128, tile_k0=32,
        tile_n1=128, tile_k1=32, tile_k0max=128,
    )
    # Wave config per stage
    wave_layout = dict(
        wave_m0=4, wave_n0=1, wave_k0=1,
        wave_m1=4, wave_n1=1, wave_k1=1,
    )
    # Warp tile per stage
    warp_tiles = dict(
        warp_m0=16, warp_n0=16, warp_k0=32,
        warp_m1=16, warp_n1=16, warp_k1=16,
    )
    padding = dict(pad_s=False, pad_sk=False, pad_d=True, pad_dv=True)

    return FmhaKernelConfig(
        family="fwd",
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        pipeline="qr",
        **block_tiles,
        **wave_layout,
        **warp_tiles,
        **padding,
        gfx_arch=arch,
    )
def make_throughput_config(arch: str) -> FmhaKernelConfig:
    """Build the forward fp16 kernel config registered in the "throughput" registry.

    Uses the "qr_async" pipeline with a 128x128 stage-0 tile and 32x32 warp
    tiles. No padding flags are passed, so the config's own defaults apply.

    Args:
        arch: GPU architecture string passed through as ``gfx_arch``.
    """
    block_tiles = dict(
        tile_m0=128, tile_n0=128, tile_k0=32,
        tile_n1=128, tile_k1=32, tile_k0max=128,
    )
    wave_layout = dict(
        wave_m0=4, wave_n0=1, wave_k0=1,
        wave_m1=4, wave_n1=1, wave_k1=1,
    )
    warp_tiles = dict(
        warp_m0=32, warp_n0=32, warp_k0=16,
        warp_m1=32, warp_n1=32, warp_k1=16,
    )

    return FmhaKernelConfig(
        family="fwd",
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        pipeline="qr_async",
        **block_tiles,
        **wave_layout,
        **warp_tiles,
        gfx_arch=arch,
    )
def main():
    """Build a latency and a throughput registry and compare them.

    Steps:
      1. Create the two kernel configs for the requested ``--arch``.
      2. Register each config in its own ``FmhaRegistry`` and build both.
      3. Run three problem sizes through every usable registry and validate
         each output against the CPU reference.
      4. Print timing and error details for one fixed problem.

    Returns:
        1 if neither registry is usable or any Step 3 validation failed,
        otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description="Multiple FMHA Registries - latency vs throughput",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 09_multi_registry.py
python3 09_multi_registry.py --arch gfx950
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--rtol", type=float, default=1e-2)
    parser.add_argument("--atol", type=float, default=1e-2)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 09: Multiple Registries")
    print("=" * 70)

    # Step 1: Define optimization-specific configs
    print("\nStep 1: Define Optimization Targets")
    latency_cfg = make_latency_config(args.arch)
    throughput_cfg = make_throughput_config(args.arch)
    print(f" Latency config: {latency_cfg.name}")
    print(f" pipeline={latency_cfg.pipeline}, tile={latency_cfg.tile[:2]}")
    print(f" Throughput config: {throughput_cfg.name}")
    print(f" pipeline={throughput_cfg.pipeline}, tile={throughput_cfg.tile[:2]}")

    # Step 2: Create separate registries
    print("\n" + "=" * 70)
    print("Step 2: Create and Build Registries")
    print("=" * 70)
    latency_reg = FmhaRegistry("latency")
    latency_reg.register_kernel(latency_cfg)
    throughput_reg = FmhaRegistry("throughput")
    throughput_reg.register_kernel(throughput_cfg)

    print(f"\n Building 'latency' registry ({len(latency_reg)} kernel) ...")
    t0 = time.perf_counter()
    latency_results = latency_reg.build(verbose=False)
    lat_build_time = time.perf_counter() - t0
    print(f" Building 'throughput' registry ({len(throughput_reg)} kernel) ...")
    t0 = time.perf_counter()
    throughput_results = throughput_reg.build(verbose=False)
    thr_build_time = time.perf_counter() - t0

    def registry_ready(results) -> bool:
        # Usable only when the first build result succeeded AND carries a
        # runner; the cleanup loop below already treats ``runner`` as optional.
        return (
            bool(results)
            and bool(results[0].success)
            and results[0].runner is not None
        )

    lat_ok = registry_ready(latency_results)
    thr_ok = registry_ready(throughput_results)
    print(f"\n Latency: {'OK' if lat_ok else 'FAIL'} ({lat_build_time:.1f} s)")
    print(f" Throughput: {'OK' if thr_ok else 'FAIL'} ({thr_build_time:.1f} s)")
    if not lat_ok and not thr_ok:
        print(" No kernels built -- aborting")
        return 1

    # Step 3: Run same problem through both
    print("\n" + "=" * 70)
    print("Step 3: Run Same Problem Through Both Registries")
    print("=" * 70)
    test_configs = [
        (2, 4, 4, 64, 64, 128, "small"),
        (2, 8, 8, 128, 128, 128, "medium"),
        (2, 8, 8, 256, 256, 128, "large"),
    ]
    validator = FmhaValidator(rtol=args.rtol, atol=args.atol)

    def run_and_validate(results, Q, K, V, prob, O_ref):
        """Return (table cell, output matches reference) for one registry."""
        res = results[0].runner.run(Q, K, V, prob)
        if not res.success:
            # A failed launch is shown as N/A and does not count as a mismatch.
            return "N/A", True
        matched, _, _ = validator.check(res.output, O_ref)
        return f"{res.tflops:.2f} TFLOPS", matched

    print(f"\n {'Problem':<12} {'Latency':>18} {'Throughput':>18} {'Match':<6}")
    print(" " + "-" * 60)
    all_match = True
    for batch, hq, hk, sq, sk, hdim, desc in test_configs:
        prob = FmhaProblem(
            batch=batch,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=sq,
            seqlen_k=sk,
            hdim_q=hdim,
            hdim_v=hdim,
        )
        np.random.seed(42)
        Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)
        O_ref = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            prob.scale,
        )

        lat_cell = "N/A"
        thr_cell = "N/A"
        results_match = True
        if lat_ok:
            lat_cell, matched = run_and_validate(latency_results, Q, K, V, prob, O_ref)
            if not matched:
                results_match = False
        if thr_ok:
            thr_cell, matched = run_and_validate(
                throughput_results, Q, K, V, prob, O_ref
            )
            if not matched:
                results_match = False
        if not results_match:
            all_match = False
        tag = "YES" if results_match else "NO"
        print(f" {desc:<12} {lat_cell:>18} {thr_cell:>18} {tag:<6}")

    # Step 4: Detailed comparison on a single problem
    print("\n" + "=" * 70)
    print("Step 4: Detailed Comparison (B=2 H=8 S=128 D=128)")
    print("=" * 70)
    prob = FmhaProblem(
        batch=2,
        nhead_q=8,
        nhead_k=8,
        seqlen_q=128,
        seqlen_k=128,
        hdim_q=128,
        hdim_v=128,
    )
    np.random.seed(123)
    Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
    K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
    V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)
    O_ref = cpu_attention_fwd(
        Q.astype(np.float32),
        K.astype(np.float32),
        V.astype(np.float32),
        prob.scale,
    )
    for name, results, ok in [
        ("Latency", latency_results, lat_ok),
        ("Throughput", throughput_results, thr_ok),
    ]:
        if not ok:
            print(f"\n {name}: not available")
            continue
        res = results[0].runner.run(Q, K, V, prob)
        if not res.success:
            print(f"\n {name}: execution failed")
            continue
        valid, max_abs, max_rel = validator.check(res.output, O_ref)
        print(f"\n {name}:")
        print(f" Time: {res.time_ms:.4f} ms")
        print(f" TFLOPS: {res.tflops:.2f}")
        print(f" Max Abs: {max_abs:.2e}")
        print(f" Max Rel: {max_rel:.2e}")
        print(f" Valid: {valid}")

    # Cleanup
    for results in [latency_results, throughput_results]:
        for r in results:
            if r.runner:
                r.runner.cleanup()

    # Summary
    print("\n" + "=" * 70)
    print("Multi-Registry Pattern:")
    print("=" * 70)
    print(" 1. Create FmhaRegistry per optimization target")
    print(" 2. Register target-specific FmhaKernelConfig in each")
    print(" 3. Build both registries")
    print(" 4. Route problems to the best registry")
    print(" 5. Compare results for correctness")
    print("=" * 70)
    return 0 if all_match else 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 10: Advanced FMHA Benchmarking
Benchmarks FMHA forward across multiple problem sizes with configurable
warmup, repeat, and cache-flush settings. Reports min/avg/max/median
time and TFLOPS for each problem.
Usage:
python3 10_advanced_benchmark.py
python3 10_advanced_benchmark.py --warmup 10 --repeat 50
python3 10_advanced_benchmark.py --flush-cache
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
setup_fmha_dispatcher,
detect_gpu_arch,
)
def parse_args():
    """Parse the benchmark's command-line options.

    Returns:
        The ``argparse.Namespace`` with ``warmup``, ``repeat``,
        ``flush_cache``, ``arch`` and ``lib`` attributes.
    """
    cli = argparse.ArgumentParser(
        description="Advanced FMHA benchmarking with full parameter control",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 10_advanced_benchmark.py # Defaults
python3 10_advanced_benchmark.py --warmup 10 --repeat 50 # More samples
python3 10_advanced_benchmark.py --flush-cache # Flush L2
""",
    )

    # Iteration counts.
    cli.add_argument(
        "--warmup", type=int, default=5, help="Number of warmup iterations (default: 5)"
    )
    cli.add_argument(
        "--repeat",
        type=int,
        default=20,
        help="Number of timed iterations (default: 20)",
    )

    # Measurement behavior.
    cli.add_argument(
        "--flush-cache",
        action="store_true",
        help="Allocate a scratch buffer between runs to flush GPU cache",
    )

    # Target selection; the default arch is resolved when the parser is built.
    cli.add_argument(
        "--arch", default=detect_gpu_arch(), help="GPU architecture (auto-detected)"
    )
    cli.add_argument(
        "--lib", default=None, help="Path to prebuilt .so (JIT-builds if omitted)"
    )

    return cli.parse_args()
# Benchmark problem sizes, one tuple per problem. main() unpacks each row in
# the field order listed below and uses the single hdim value for both hdim_q
# and hdim_v. In the last row nhead_q (32) differs from nhead_k (8).
PROBLEM_TABLE = [
    # (batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim, label)
    (1, 8, 8, 64, 64, 128, "tiny"),
    (2, 8, 8, 128, 128, 128, "small"),
    (2, 16, 16, 256, 256, 128, "medium"),
    (4, 16, 16, 512, 512, 128, "large"),
    (2, 32, 32, 1024, 1024, 128, "xlarge"),
    (1, 32, 8, 256, 256, 128, "GQA-4:1"),
]
def flush_gpu_cache():
    """Fill a 32 MiB NumPy buffer with random bytes and sum it.

    Intended as a cache flush between timed iterations.

    NOTE(review): this only allocates and reads a host-side NumPy array;
    whether that evicts anything from the GPU's L2 cache is not established
    by this code -- confirm before relying on it for cache-cold timings.
    """
    n_bytes = 32 * 1024 * 1024
    scratch = np.random.randint(0, 255, size=n_bytes, dtype=np.uint8)
    # Reduce over the whole buffer so every element is actually read.
    scratch.sum()
def run_benchmark(
    runner, prob: FmhaProblem, warmup: int, repeat: int, flush_cache: bool
) -> list:
    """Time ``runner.run`` on random fp16 inputs shaped for ``prob``.

    Args:
        runner: Object whose ``run(Q, K, V, prob)`` returns a result with
            ``success`` and ``time_ms`` attributes.
        prob: Problem providing ``q_shape()``, ``k_shape()`` and ``v_shape()``.
        warmup: Number of untimed calls issued first; their results are ignored.
        repeat: Number of timed calls.
        flush_cache: When true, ``flush_gpu_cache()`` is called before every
            timed call.

    Returns:
        The ``time_ms`` of every successful timed call, in call order. Failed
        calls contribute nothing, so the list may be shorter than ``repeat``.
    """

    def random_fp16(shape):
        # Standard-normal samples scaled by 0.5, stored as float16.
        return (np.random.randn(*shape) * 0.5).astype(np.float16)

    # Draw order (Q, then K, then V) is kept so seeded runs are reproducible.
    Q = random_fp16(prob.q_shape())
    K = random_fp16(prob.k_shape())
    V = random_fp16(prob.v_shape())

    for _ in range(warmup):
        runner.run(Q, K, V, prob)

    samples_ms = []
    for _ in range(repeat):
        if flush_cache:
            flush_gpu_cache()
        outcome = runner.run(Q, K, V, prob)
        if not outcome.success:
            continue
        samples_ms.append(outcome.time_ms)
    return samples_ms
def main():
    """JIT-build one forward kernel and benchmark it over ``PROBLEM_TABLE``.

    Prints the benchmark settings, builds a 128x128 "qr_async" fp16 kernel for
    ``--arch``, then reports min/avg/max/median time and median-based TFLOPS
    for each problem, followed by a summary and a parameter reference.

    Returns:
        1 if the kernel could not be built, otherwise 0.
    """
    args = parse_args()
    rule = "=" * 70

    print(rule)
    print("Example 10: Advanced FMHA Benchmarking")
    print(rule)
    print("\nBenchmark Configuration:")
    print(f" Warmup: {args.warmup} iterations")
    print(f" Repeat: {args.repeat} iterations")
    print(f" Flush Cache: {args.flush_cache}")
    print(f" Arch: {args.arch}")
    print(f" Problems: {len(PROBLEM_TABLE)}")

    # Step 1: Load or JIT-build kernel
    # NOTE(review): the --lib option is parsed but never read in this
    # function; the kernel is always JIT-built here.
    print("\n" + rule)
    print("Step 1: Load / Build Kernel")
    print(rule)
    print(" JIT building kernel...")

    # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q
    # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment
    block_tiles = dict(
        tile_m0=128, tile_n0=128, tile_k0=32,
        tile_n1=128, tile_k1=32, tile_k0max=128,
    )
    # Wave config per stage
    wave_layout = dict(
        wave_m0=4, wave_n0=1, wave_k0=1,
        wave_m1=4, wave_n1=1, wave_k1=1,
    )
    # Warp tile per stage
    warp_tiles = dict(
        warp_m0=32, warp_n0=32, warp_k0=16,
        warp_m1=32, warp_n1=32, warp_k1=16,
    )
    kernel_cfg = FmhaKernelConfig(
        family="fwd",
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        pipeline="qr_async",
        **block_tiles,
        **wave_layout,
        **warp_tiles,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(kernel_cfg, verbose=True)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
        return 1
    runner = setup.runner
    print(f" JIT built: {setup.library_path} ({setup.build_time_s:.1f} s)")
    print(f" Kernels: {runner.kernel_count}")

    # Step 2: Benchmark all problems
    print("\n" + rule)
    print("Step 2: Benchmark Results")
    print(rule)
    header = (
        f" {'Label':<10} {'Shape':^30} "
        f"{'Min':>8} {'Avg':>8} {'Max':>8} {'Med':>8} {'TFLOPS':>8}"
    )
    print(f"\n{header}")
    print(" " + "-" * 85)

    rows = []
    # Seed once; run_benchmark draws fresh inputs for each problem from here.
    np.random.seed(42)
    for batch, hq, hk, sq, sk, hdim, label in PROBLEM_TABLE:
        problem = FmhaProblem(
            batch=batch,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=sq,
            seqlen_k=sk,
            hdim_q=hdim,
            hdim_v=hdim,
        )
        shape_str = f"B{batch}_Hq{hq}_Hk{hk}_S{sq}_D{hdim}"
        samples = run_benchmark(
            runner, problem, args.warmup, args.repeat, args.flush_cache
        )
        if not samples:
            # No timed iteration succeeded for this problem.
            print(
                f" {label:<10} {shape_str:^30} {'FAIL':>8} {'---':>8} "
                f"{'---':>8} {'---':>8} {'---':>8}"
            )
            continue

        fastest = min(samples)
        slowest = max(samples)
        mean = sum(samples) / len(samples)
        median = float(np.median(samples))
        # Throughput is derived from the median time (ms -> s, ops -> tera-ops).
        if median > 0:
            tflops = problem.num_ops / (median * 1e-3) / 1e12
        else:
            tflops = 0
        print(
            f" {label:<10} {shape_str:^30} "
            f"{fastest:>7.3f}ms {mean:>7.3f}ms {slowest:>7.3f}ms {median:>7.3f}ms "
            f"{tflops:>7.2f}"
        )
        rows.append((label, shape_str, fastest, mean, slowest, median, tflops))

    # Summary
    print("\n" + rule)
    print(" SUMMARY")
    print(rule)
    if rows:
        top = max(rows, key=lambda row: row[6])
        print(f"\n Best TFLOPS: {top[6]:.2f} ({top[0]}: {top[1]})")
        mean_tflops = sum(row[6] for row in rows) / len(rows)
        print(f" Avg TFLOPS: {mean_tflops:.2f}")
        print(f" Problems run: {len(rows)}/{len(PROBLEM_TABLE)}")
    else:
        print("\n No successful benchmarks")
    print(
        f"\n Settings: warmup={args.warmup}, repeat={args.repeat}, "
        f"flush_cache={args.flush_cache}"
    )

    print("\n" + rule)
    print("BENCHMARK PARAMETERS REFERENCE")
    print(rule)
    print("""
--warmup N Warmup iterations (results discarded)
Higher = more stable results, longer run
Default: 5
--repeat N Timed iterations
Higher = more accurate statistics
Default: 20
--flush-cache Flush GPU L2 cache between iterations
Use for memory-bandwidth measurements
Default: off
--arch ARCH GPU architecture (e.g. gfx950)
Auto-detected from rocminfo
""")
    print(rule)

    runner.cleanup()
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,188 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 11: BF16 Forward Attention
Demonstrates:
1. BF16 data generation and handling
2. GPU execution attempt with prebuilt kernel (fp16-only)
3. CPU reference computation in float32
4. BF16-specific tolerance validation (atol=1e-2)
The prebuilt library contains only fp16 kernels. This example shows the API
pattern for bf16 and gracefully falls back to CPU reference when the GPU
kernel does not support bf16.
Usage:
python3 11_bf16_fmha.py
python3 11_bf16_fmha.py --batch 4 --seqlen 256
python3 11_bf16_fmha.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def to_bf16(arr: np.ndarray) -> np.ndarray:
    """Return the upper 16 bits of each element's float32 encoding as uint16.

    The input is first cast to float32. The low 16 bits of every 32-bit
    pattern are then discarded by a right shift; no rounding step is applied.

    Args:
        arr: Array of any dtype castable to float32.

    Returns:
        A uint16 array of the same shape holding the bf16 bit patterns.
    """
    bits32 = arr.astype(np.float32).view(np.uint32)
    upper_half = bits32 >> 16
    return upper_half.astype(np.uint16)
def bf16_to_f32(arr_u16: np.ndarray) -> np.ndarray:
    """Expand bf16 bit patterns (stored as uint16) into float32 values.

    Each 16-bit pattern is widened to 32 bits and shifted into the upper half,
    leaving the low 16 bits zero; the result is reinterpreted as float32.

    Args:
        arr_u16: Array of bf16 bit patterns.

    Returns:
        A float32 array of the same shape.
    """
    widened = arr_u16.astype(np.uint32)
    bits32 = widened << 16
    return bits32.view(np.float32)
def main():
    """Run the BF16 attention demo.

    Generates float32 inputs, quantizes them to bf16 bit patterns, attempts a
    GPU run (the inputs are converted to fp16 for an fp16 kernel), computes a
    float32 CPU reference from the bf16-quantized inputs, and prints a
    validation table and summary.

    Returns:
        Always 0; validation results are reported in the printed table only.
    """
    parser = argparse.ArgumentParser(description="BF16 Forward Attention")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 11: BF16 Forward Attention")
    print("=" * 70)
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    print(
        f"\n Problem: B={prob.batch} H={prob.nhead_q} S={prob.seqlen_q} D={prob.hdim_q}"
    )
    print(" Dtype: bfloat16")
    print(f" Arch: {args.arch}")
    print(f" Scale: {prob.scale:.6f}")

    # --- Generate bf16 data ---
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_bf16 = to_bf16(Q_f32)
    K_bf16 = to_bf16(K_f32)
    V_bf16 = to_bf16(V_f32)
    # Round-trip back to float32 so the values carry bf16 quantization.
    Q_bf16_f32 = bf16_to_f32(Q_bf16)
    K_bf16_f32 = bf16_to_f32(K_bf16)
    V_bf16_f32 = bf16_to_f32(V_bf16)
    print(f"\n Q bf16 range: [{Q_bf16_f32.min():.4f}, {Q_bf16_f32.max():.4f}]")
    print(f" K bf16 range: [{K_bf16_f32.min():.4f}, {K_bf16_f32.max():.4f}]")
    print(f" V bf16 range: [{V_bf16_f32.min():.4f}, {V_bf16_f32.max():.4f}]")
    bf16_quant_err = np.abs(Q_f32 - Q_bf16_f32).max()
    print(f" BF16 quantization error: {bf16_quant_err:.2e}")

    # --- GPU execution attempt ---
    print("\n--- GPU Execution ---")
    gpu_output = None
    gpu_time = None
    runner = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        Q_fp16 = Q_bf16_f32.astype(np.float16)
        K_fp16 = K_bf16_f32.astype(np.float16)
        V_fp16 = V_bf16_f32.astype(np.float16)
        result = runner.run(Q_fp16, K_fp16, V_fp16, prob)
        if result.success:
            gpu_output = result.output
            gpu_time = result.time_ms
            print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS")
            print(" Note: Ran as fp16 (JIT kernel); native bf16 kernel not compiled")
        else:
            print(" GPU: Kernel does not support bf16 (expected)")

    # --- CPU reference (always computed) ---
    print("\n--- CPU Reference (float32 with bf16-quantized inputs) ---")
    O_ref = cpu_attention_fwd(Q_bf16_f32, K_bf16_f32, V_bf16_f32, prob.scale)
    print(f" Output range: [{O_ref.min():.4f}, {O_ref.max():.4f}]")
    print(f" Output shape: {O_ref.shape}")

    # --- Validation ---
    print("\n--- Validation ---")
    validator = FmhaValidator(rtol=1e-2, atol=1e-2)
    print(f"\n {'Check':<30} {'MaxAbs':>10} {'MaxRel':>10} {'Status':>8}")
    print(" " + "-" * 62)
    if gpu_output is not None:
        ok, max_abs, max_rel = validator.check(gpu_output, O_ref)
        tag = "PASS" if ok else "FAIL"
        print(
            f" {'GPU vs CPU (bf16 tol)':<30} {max_abs:>10.2e} {max_rel:>10.2e} {tag:>8}"
        )
    else:
        print(f" {'GPU vs CPU (bf16 tol)':<30} {'N/A':>10} {'N/A':>10} {'SKIP':>8}")

    # Informational: error introduced by storing the reference itself as fp16.
    strict_val = FmhaValidator(rtol=1e-5, atol=1e-5)
    ok_strict, ma_strict, mr_strict = strict_val.check(
        O_ref.astype(np.float16),
        O_ref,
    )
    print(
        f" {'fp16(ref) vs f32(ref)':<30} {ma_strict:>10.2e} {mr_strict:>10.2e} {'PASS' if ok_strict else 'INFO':>8}"
    )

    # Informational: effect of bf16 input quantization on the reference output.
    O_ref_from_f32 = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale)
    bf16_impact = float(np.abs(O_ref - O_ref_from_f32).max())
    print(
        f" {'bf16 vs f32 input impact':<30} {bf16_impact:>10.2e} {'':>10} {'INFO':>8}"
    )

    # Cleanup: release the runner once its output is no longer needed.
    if runner is not None:
        runner.cleanup()

    # --- Summary ---
    # Test against None so a measured time of 0.0 ms is still reported.
    if gpu_time is not None:
        gpu_summary = "%.4f ms" % gpu_time
    else:
        gpu_summary = "N/A (bf16 kernel not in prebuilt)"
    print("\n" + "=" * 70)
    print(" Dtype: bfloat16 (7-bit mantissa vs fp16's 10-bit)")
    print(" Tolerance: atol=1e-2 (relaxed for bf16 precision)")
    print(f" GPU: {gpu_summary}")
    print(" CPU ref: Computed with bf16-quantized inputs")
    print(" BF16 range: Larger exponent range (±3.4e38) vs fp16 (±65504)")
    status = "PASS" if gpu_output is not None else "DEMO"
    print(f" Status: {status}")
    print("=" * 70)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,239 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 12: Attention Masks
Demonstrates all 5 mask types supported by the FMHA dispatcher:
1. no_mask (0) -- Full attention, no masking
2. top_left (1) -- Causal mask aligned to top-left corner
3. bottom_right (2) -- Causal mask aligned to bottom-right corner
4. sliding_window -- Local attention within a fixed window
5. generic -- Arbitrary user-defined mask pattern
For each mask type, this example:
- Creates an FmhaProblem
- Attempts GPU execution via prebuilt kernel
- Computes CPU reference with the mask applied
- Validates results
Usage:
python3 12_masks_fmha.py
python3 12_masks_fmha.py --seqlen-q 256 --seqlen-k 256
python3 12_masks_fmha.py --window-size 64
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
FmhaValidator,
detect_gpu_arch,
setup_fmha_dispatcher,
)
# Mask name -> integer ID; main() prints the ID in the "ID" column of its table.
# NOTE(review): presumably these IDs mirror the dispatcher's mask-type
# enumeration -- confirm against the dispatcher before relying on the values.
MASK_TYPES = {
    "no_mask": 0,
    "top_left": 1,
    "bottom_right": 2,
    "sliding_window": 3,
    "generic": 4,
}
def make_causal_mask_top_left(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a causal mask anchored at the top-left corner.

    Entry ``[i, j]`` is 1.0 when key index ``j`` is at or before query index
    ``i`` (``j <= i``) and 0.0 otherwise.

    Returns:
        float32 array of shape ``[seqlen_q, seqlen_k]``.
    """
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    visible = k_idx <= q_idx
    return visible.astype(np.float32)
def make_causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a causal mask anchored at the bottom-right corner.

    The diagonal is shifted by ``seqlen_k - seqlen_q``: entry ``[i, j]`` is
    1.0 when ``j <= i + (seqlen_k - seqlen_q)`` and 0.0 otherwise.

    Returns:
        float32 array of shape ``[seqlen_q, seqlen_k]``.
    """
    shift = seqlen_k - seqlen_q
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    visible = k_idx <= q_idx + shift
    return visible.astype(np.float32)
def make_sliding_window_mask(seqlen_q: int, seqlen_k: int, window: int) -> np.ndarray:
    """Build a causal sliding-window mask anchored at the bottom-right corner.

    With ``d = i + (seqlen_k - seqlen_q)`` as the shifted diagonal position of
    query ``i``, entry ``[i, j]`` is 1.0 when ``d - window + 1 <= j <= d`` and
    0.0 otherwise, i.e. each query sees at most ``window`` keys ending at ``d``.

    Returns:
        float32 array of shape ``[seqlen_q, seqlen_k]``.
    """
    shift = seqlen_k - seqlen_q
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    diag = q_idx + shift
    not_future = k_idx <= diag
    within_window = k_idx >= diag - window + 1
    return (not_future & within_window).astype(np.float32)
def make_generic_mask(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a checkerboard mask used as the "generic" mask demo.

    Entry ``[i, j]`` is 1.0 when ``i + j`` is even and 0.0 when it is odd.

    Returns:
        float32 array of shape ``[seqlen_q, seqlen_k]``.
    """
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    index_sum = q_idx + k_idx
    is_even = index_sum % 2 == 0
    return is_even.astype(np.float32)
def cpu_masked_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    mask: np.ndarray,
) -> np.ndarray:
    """CPU reference for scaled dot-product attention under an arbitrary mask.

    Scores are ``Q @ K^T * scale``. Positions where ``mask`` is not positive
    are replaced with -1e9 before a max-subtracted softmax over the key axis;
    the resulting weights are multiplied by ``V``.

    Args:
        Q: [batch, nhead, seqlen_q, hdim]
        K: 4-D array whose last two axes are swapped to form ``K^T``.
        V: Values, multiplied on the right of the softmax weights.
        scale: Multiplier applied to the raw scores.
        mask: [seqlen_q, seqlen_k] (broadcast over batch and head)

    Returns:
        The attention output ``softmax(masked scores) @ V``.
    """
    scores = (Q @ K.transpose(0, 1, 3, 2)) * scale

    # Add leading batch and head axes so the 2-D mask broadcasts.
    keep = mask[np.newaxis, np.newaxis, :, :] > 0
    scores = np.where(keep, scores, -1e9)

    # Subtract the row maximum before exponentiating.
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V
def main():
    """Run the attention-mask demo.

    Builds the five demo masks, runs the GPU kernel when it can be built (the
    kernel call takes no mask, so only the ``no_mask`` row is validated
    against the CPU reference), prints a per-mask table, draws the top-left
    corner of each mask, and prints a summary.

    Returns:
        0 when every validated row passed, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Attention Masks")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen-q", type=int, default=128)
    parser.add_argument("--seqlen-k", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--window-size", type=int, default=32)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 12: Attention Masks")
    print("=" * 70)
    sq, sk = args.seqlen_q, args.seqlen_k
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} Sq={sq} Sk={sk} D={args.hdim}")
    print(f" Window: {args.window_size}")

    # --- Generate data ---
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_fp16 = Q_f32.astype(np.float16)
    K_fp16 = K_f32.astype(np.float16)
    V_fp16 = V_f32.astype(np.float16)

    # --- Try GPU runner ---
    runner = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f"\n GPU runner loaded (JIT build: {setup.build_time_s:.1f}s)")
    else:
        print(f"\n GPU runner not available: {setup.error}")

    # --- Build masks ---
    masks = {
        "no_mask": np.ones((sq, sk), dtype=np.float32),
        "top_left": make_causal_mask_top_left(sq, sk),
        "bottom_right": make_causal_mask_bottom_right(sq, sk),
        "sliding_window": make_sliding_window_mask(sq, sk, args.window_size),
        "generic": make_generic_mask(sq, sk),
    }
    validator = FmhaValidator(rtol=1e-2, atol=1e-2)
    print(
        f"\n {'#':<3} {'MaskType':<18} {'ID':<4} {'Density':>8} {'GPUStatus':<12} {'CPURef':<8} {'MaxErr':>10} {'Status':>8}"
    )
    print(" " + "-" * 76)

    results = []
    for i, (name, mask) in enumerate(masks.items(), 1):
        mask_id = MASK_TYPES[name]
        # Percentage of positions the mask leaves visible.
        density = mask.sum() / mask.size * 100

        # GPU attempt (prebuilt only supports no_mask)
        gpu_status = "N/A"
        gpu_out = None
        if runner is not None:
            res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
            if res.success:
                gpu_out = res.output
                # The run is unmasked; flag rows where that is not the mask shown.
                gpu_status = "OK" if name == "no_mask" else "no_mask*"
            else:
                gpu_status = "unsupported"

        # CPU reference with mask
        O_ref = cpu_masked_attention(Q_f32, K_f32, V_f32, prob.scale, mask)
        cpu_status = "OK"

        # Validate: only the unmasked GPU output is comparable to its reference.
        if gpu_out is not None and name == "no_mask":
            ok, max_abs, _ = validator.check(gpu_out, O_ref)
            tag = "PASS" if ok else "FAIL"
            err_str = f"{max_abs:.2e}"
        else:
            ok = True
            tag = "DEMO"
            err_str = "---"
        print(
            f" {i:<3} {name:<18} {mask_id:<4} {density:>7.1f}% {gpu_status:<12} {cpu_status:<8} {err_str:>10} {tag:>8}"
        )
        results.append((name, ok))

    # Cleanup: release the runner once all GPU outputs have been validated.
    if runner is not None:
        runner.cleanup()

    # --- Mask visualization ---
    print("\n--- Mask Patterns (first 8x8 corner) ---")
    view_size = min(8, sq, sk)
    for name, mask in masks.items():
        corner = mask[:view_size, :view_size]
        print(f"\n {name}:")
        for r in range(view_size):
            # A filled block marks a visible position, a dot a masked one. The
            # visible glyph must be non-empty or the pattern cannot be read.
            row_str = " ".join(
                "█" if corner[r, c] > 0 else "·" for c in range(view_size)
            )
            print(f" {row_str}")

    # --- Summary ---
    all_ok = all(ok for _, ok in results)
    print("\n" + "=" * 70)
    print(f" Mask types tested: {len(masks)}")
    print(" no_mask: Full attention (all positions visible)")
    print(" top_left: Causal from top-left (autoregressive)")
    print(" bottom_right: Causal from bottom-right (kv-padded)")
    print(f" sliding_window: Local window of {args.window_size} keys")
    print(" generic: Arbitrary (checkerboard demo)")
    print(" GPU: Prebuilt supports no_mask only")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)
    # Exit status follows the printed PASS/FAIL status.
    return 0 if all_ok else 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,235 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 13: Attention Bias
Demonstrates bias types supported by the FMHA dispatcher:
1. no_bias -- Standard attention without bias
2. elementwise -- Add a [seqlen_q, seqlen_k] bias matrix to attention scores
3. alibi -- Attention with Linear Biases (ALiBi) positional encoding
For each bias type:
- Creates an FmhaProblem and bias tensor
- Attempts GPU execution (prebuilt: no_bias only)
- Computes CPU reference with bias applied before softmax
- Validates output
Usage:
python3 13_bias_fmha.py
python3 13_bias_fmha.py --seqlen 256
python3 13_bias_fmha.py --nhead 16
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def get_alibi_slopes(nhead: int) -> np.ndarray:
    """Compute one ALiBi slope per attention head.

    With ``n = nhead``, head ``k`` (1-based) gets slope ``2^(-8/n * k)``, so
    the slopes form a geometric sequence with ratio ``2^(-8/n)``.

    Returns:
        float32 array of shape ``[nhead]``.
    """
    base = 2.0 ** (-8.0 / nhead)
    powers = (base ** k for k in range(1, nhead + 1))
    return np.fromiter(powers, dtype=np.float32)
def make_alibi_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Create the ALiBi bias tensor ``slope[h] * (col - row)``.

    The relative position ``col - row`` is computed for every (query, key)
    pair; no causal restriction is applied here, so keys after the query get
    positive values and keys before it get negative values.

    Returns: [nhead, seqlen_q, seqlen_k] as float32
    """
    per_head = get_alibi_slopes(nhead)[:, np.newaxis, np.newaxis]
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    rel_pos = k_idx - q_idx
    scaled = per_head * rel_pos[np.newaxis, :, :]
    return scaled.astype(np.float32)
def make_elementwise_bias(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Create a distance-based elementwise bias matrix.

    Entry ``[i, j]`` equals ``-0.1 * |i - j|``, so the penalty grows linearly
    with the distance between query and key index.

    Returns: [seqlen_q, seqlen_k] as float32
    """
    q_pos = np.arange(seqlen_q, dtype=np.float32)[:, np.newaxis]
    k_pos = np.arange(seqlen_k, dtype=np.float32)[np.newaxis, :]
    gap = np.abs(q_pos - k_pos)
    penalty = -0.1 * gap
    return penalty.astype(np.float32)
def cpu_biased_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    bias: np.ndarray,
) -> np.ndarray:
    """CPU reference for attention with an additive bias applied before softmax.

    Scores are ``Q @ K^T * scale + bias``, followed by a max-subtracted
    softmax over the key axis; the resulting weights are multiplied by ``V``.

    Args:
        Q: [batch, nhead, seqlen_q, hdim]
        K: 4-D array whose last two axes are swapped to form ``K^T``.
        V: Values, multiplied on the right of the softmax weights.
        scale: Multiplier applied to the raw scores.
        bias: broadcastable to [batch, nhead, seqlen_q, seqlen_k]

    Returns:
        The attention output ``softmax(scores + bias) @ V``.
    """
    scores = (Q @ K.transpose(0, 1, 3, 2)) * scale
    scores = scores + bias

    # Subtract the row maximum before exponentiating.
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V
def main():
    """Run the attention-bias demo end to end.

    Parses the CLI options, tries to JIT-build a GPU runner, evaluates the CPU
    reference for each bias type (no_bias / elementwise / alibi), validates the
    GPU output against the CPU reference for the no_bias case when a runner is
    available, and prints ALiBi details plus a per-bias output-shift report.

    Returns:
        0 when every validated configuration passed, 1 otherwise, so the
        process exit status agrees with the printed "Status" line.
    """
    parser = argparse.ArgumentParser(description="Attention Bias")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 13: Attention Bias")
    print("=" * 70)

    sq = sk = args.seqlen
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={sq} D={args.hdim}")

    # --- Generate data ---
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_fp16 = Q_f32.astype(np.float16)
    K_fp16 = K_f32.astype(np.float16)
    V_fp16 = V_f32.astype(np.float16)

    # --- Try GPU runner ---
    runner = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f" GPU runner loaded (JIT build: {setup.build_time_s:.1f}s)")
    else:
        print(f" GPU runner not available: {setup.error}")

    # --- Build bias tensors ---
    # Every bias is expanded to 4-D so it broadcasts against the score tensor.
    bias_configs = [
        ("no_bias", np.zeros((1, 1, sq, sk), dtype=np.float32)),
        ("elementwise", make_elementwise_bias(sq, sk)[np.newaxis, np.newaxis, :, :]),
        ("alibi", make_alibi_bias(args.nhead, sq, sk)[np.newaxis, :, :, :]),
    ]
    validator = FmhaValidator(rtol=1e-2, atol=1e-2)
    print(
        f"\n {'#':<3} {'BiasType':<14} {'BiasRange':>20} {'GPUStatus':<12} {'MaxErr':>10} {'Status':>8}"
    )
    print(" " + "-" * 72)

    results = []
    for i, (name, bias) in enumerate(bias_configs, 1):
        bias_min, bias_max = float(bias.min()), float(bias.max())
        bias_range = f"[{bias_min:.3f}, {bias_max:.3f}]"

        # GPU attempt: the runner is invoked without a bias argument, so for
        # the biased rows its result is only the unbiased baseline ("no_bias*").
        gpu_status = "N/A"
        gpu_out = None
        if runner is not None:
            res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
            if res.success:
                gpu_out = res.output
                gpu_status = "OK" if name == "no_bias" else "no_bias*"
            else:
                gpu_status = "unsupported"

        # CPU reference with bias
        O_ref = cpu_biased_attention(Q_f32, K_f32, V_f32, prob.scale, bias)

        # Validate: only the unbiased row is comparable with the GPU output.
        if gpu_out is not None and name == "no_bias":
            ok, max_abs, _ = validator.check(gpu_out, O_ref)
            tag = "PASS" if ok else "FAIL"
            err_str = f"{max_abs:.2e}"
        else:
            ok = True
            tag = "DEMO"
            err_str = "---"
        print(
            f" {i:<3} {name:<14} {bias_range:>20} {gpu_status:<12} {err_str:>10} {tag:>8}"
        )
        results.append((name, ok))

    # --- Show ALiBi details ---
    print("\n--- ALiBi Details ---")
    slopes = get_alibi_slopes(args.nhead)
    print(f" Heads: {args.nhead}")
    print(f" Slopes: {', '.join(f'{s:.4f}' for s in slopes[: min(8, len(slopes))])}")
    if len(slopes) > 8:
        print(f" ... ({len(slopes)} total)")
    print(" Effect: Nearby tokens get higher scores, distant tokens penalized")
    print(" Formula: bias[h,i,j] = slope[h] * (j - i)")
    alibi_bias = make_alibi_bias(args.nhead, sq, sk)
    # The slice is smaller than 4x4 when seqlen < 4; iterate over what is
    # actually there instead of indexing a fixed 4x4 grid (IndexError before).
    corner = alibi_bias[0, :4, :4]
    print(f"\n Head 0 bias corner ({corner.shape[0]}x{corner.shape[1]}):")
    for corner_row in corner:
        row_str = " ".join(f"{value:>7.3f}" for value in corner_row)
        print(f" {row_str}")

    # --- Show impact of bias on attention ---
    print("\n--- Bias Impact Analysis ---")
    O_no_bias = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale)
    for name, bias in bias_configs:
        O_biased = cpu_biased_attention(Q_f32, K_f32, V_f32, prob.scale, bias)
        diff = float(np.abs(O_biased - O_no_bias).max())
        print(f" {name:<14} max output shift: {diff:.4e}")

    # --- Summary ---
    all_ok = all(ok for _, ok in results)
    print("\n" + "=" * 70)
    print(" Bias types: no_bias, elementwise, alibi")
    print(" no_bias: Standard attention (baseline)")
    print(" elementwise: Position-distance bias [-0.1 * |i-j|]")
    print(" alibi: Linear position bias per head (no learned params)")
    print(" GPU: Prebuilt supports no_bias only")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)
    # Propagate a validation failure through the exit status.
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,245 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 14: Attention Dropout with LSE
Demonstrates:
1. Dropout applied to attention probabilities
2. Log-sum-exp (LSE) storage for numerical stability
3. Statistical validation (dropout is stochastic)
4. Reproducibility with seed control
Dropout zeros out attention weights with probability p_drop, then scales
remaining weights by 1/(1-p_drop) to preserve expected value.
LSE stores log(sum(exp(scores))) per query position for backward pass.
Usage:
python3 14_dropout_fmha.py
python3 14_dropout_fmha.py --p-drop 0.3
python3 14_dropout_fmha.py --seqlen 256 --seed 123
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def cpu_attention_with_dropout(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    p_drop: float,
    seed: int,
) -> tuple:
    """Reference attention on the CPU with dropout on the probabilities.

    The softmax and the log-sum-exp are computed from the raw scores first;
    the dropout mask (drawn from ``np.random.RandomState(seed)``) is applied
    afterwards, so the returned LSE does not depend on ``p_drop``.

    Returns:
        (O, P_dropped, lse)
        O: [batch, nhead, seqlen_q, hdim_v]
        P_dropped: [batch, nhead, seqlen_q, seqlen_k] attention weights after dropout
        lse: [batch, nhead, seqlen_q] log-sum-exp of scores (float32)
    """
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    row_max = scores.max(axis=-1, keepdims=True)
    unnormalized = np.exp(scores - row_max)
    denom = unnormalized.sum(axis=-1, keepdims=True)
    probs = unnormalized / denom
    # log(sum(exp(s))) = log(sum(exp(s - max))) + max
    lse = (np.log(denom.squeeze(-1)) + row_max.squeeze(-1)).astype(np.float32)

    generator = np.random.RandomState(seed)
    # An entry survives when its uniform draw is at least p_drop.
    keep = (generator.rand(*probs.shape) >= p_drop).astype(np.float32)
    if p_drop < 1.0:
        rescale = 1.0 / (1.0 - p_drop)
    else:
        # Everything is dropped; avoid dividing by zero.
        rescale = 0.0
    kept_probs = probs * keep * rescale
    return np.matmul(kept_probs, V), kept_probs, lse
def main():
    """Run the dropout + LSE demo end to end.

    Parses the CLI options, optionally runs the (dropout-free) GPU kernel as
    a baseline, then uses the CPU reference to tabulate the effect of several
    dropout rates, show that the LSE is unaffected by dropout, and average a
    few seeded trials to compare the mean output with the no-dropout output.

    ``--p-drop`` is rejected unless it lies in [0, 1]: the reference would
    otherwise silently rescale the output (negative values) or treat the value
    as 1 (values above 1).

    Returns:
        Always 0; the example reports "DEMO" status.
    """
    parser = argparse.ArgumentParser(description="Attention Dropout with LSE")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--p-drop", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    if not 0.0 <= args.p_drop <= 1.0:
        parser.error(f"--p-drop must be within [0, 1], got {args.p_drop}")

    print("=" * 70)
    print("Example 14: Attention Dropout with LSE")
    print("=" * 70)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    print(
        f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}"
    )
    print(f" p_drop: {args.p_drop}")
    print(f" Seed: {args.seed}")
    print(f" LSE shape: [{prob.batch}, {prob.nhead_q}, {prob.seqlen_q}]")

    # --- Generate data ---
    np.random.seed(args.seed)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_fp16 = Q_f32.astype(np.float16)
    K_fp16 = K_f32.astype(np.float16)
    V_fp16 = V_f32.astype(np.float16)

    # --- GPU execution attempt ---
    print("\n--- GPU Execution ---")
    gpu_output = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
        if res.success:
            gpu_output = res.output
            print(f" GPU (no dropout): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS")
            print(" Note: JIT kernel runs without dropout; shown for baseline")
        else:
            print(" GPU: Kernel returned failure")

    # --- CPU reference: no dropout (baseline) ---
    print("\n--- CPU Reference ---")
    O_no_drop = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale)

    # --- CPU reference: with dropout ---
    # dict.fromkeys removes repeats (e.g. --p-drop 0.5) while keeping the order,
    # so the table never prints the same rate twice.
    drop_rates = list(dict.fromkeys([0.0, 0.1, args.p_drop, 0.5]))
    print(
        f"\n {'p_drop':>8} {'OutMean':>10} {'OutStd':>10} {'MaxDiff':>10} {'DropFrac':>10}"
    )
    print(" " + "-" * 52)
    for p in drop_rates:
        O_drop, P_dropped, lse = cpu_attention_with_dropout(
            Q_f32,
            K_f32,
            V_f32,
            prob.scale,
            p,
            args.seed,
        )
        # Fraction of attention weights that are exactly zero after dropout.
        total_weights = P_dropped.size
        zeros = (P_dropped == 0).sum()
        actual_drop_frac = zeros / total_weights
        diff = float(np.abs(O_drop - O_no_drop).max())
        print(
            f" {p:>8.2f} {O_drop.mean():>10.4f} {O_drop.std():>10.4f} "
            f"{diff:>10.2e} {actual_drop_frac:>10.2%}"
        )

    # --- LSE analysis ---
    print("\n--- LSE (Log-Sum-Exp) Analysis ---")
    _, _, lse = cpu_attention_with_dropout(
        Q_f32,
        K_f32,
        V_f32,
        prob.scale,
        args.p_drop,
        args.seed,
    )
    print(f" LSE shape: {lse.shape}")
    print(f" LSE range: [{lse.min():.4f}, {lse.max():.4f}]")
    print(f" LSE mean: {lse.mean():.4f}")
    print(" LSE is independent of dropout (computed from raw scores)")
    lse_nodrop = cpu_attention_with_dropout(
        Q_f32,
        K_f32,
        V_f32,
        prob.scale,
        0.0,
        args.seed,
    )[2]
    lse_diff = float(np.abs(lse - lse_nodrop).max())
    print(f" LSE diff (drop vs no-drop): {lse_diff:.2e} (should be 0)")

    # --- Statistical validation ---
    print("\n--- Statistical Validation ---")
    n_trials = 5
    outputs = []
    for trial in range(n_trials):
        # A different seed per trial gives independent dropout masks.
        O_t, _, _ = cpu_attention_with_dropout(
            Q_f32,
            K_f32,
            V_f32,
            prob.scale,
            args.p_drop,
            args.seed + trial,
        )
        outputs.append(O_t)
    O_mean = np.mean(outputs, axis=0)
    O_std = np.std(outputs, axis=0)
    mean_diff = float(np.abs(O_mean - O_no_drop).max())
    max_std = float(O_std.max())
    print(f" Trials: {n_trials}")
    print(f" Mean vs no-drop: {mean_diff:.4e} (should be small)")
    print(f" Max output stddev: {max_std:.4e}")
    print(" E[dropout(P)] = P (unbiased estimator)")

    if gpu_output is not None:
        validator = FmhaValidator(rtol=1e-2, atol=1e-2)
        ok, max_abs, _ = validator.check(gpu_output, O_no_drop)
        print(
            f"\n GPU vs CPU (no-drop): max_err={max_abs:.2e}, {'PASS' if ok else 'FAIL'}"
        )

    # --- Summary ---
    print("\n" + "=" * 70)
    print(f" Dropout: p_drop={args.p_drop}, seed={args.seed}")
    print(
        f" LSE: Stored for backward pass (shape [{prob.batch},{prob.nhead_q},{prob.seqlen_q}])"
    )
    print(" Key: Dropout is stochastic; validate statistically, not exactly")
    print(" GPU: Prebuilt kernel does not support dropout")
    print(" Status: DEMO")
    print("=" * 70)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,217 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 15: Grouped-Query Attention (GQA / MQA)
Demonstrates GQA with various nhead_q:nhead_k ratios:
- 1:1 (MHA) -- Standard multi-head attention
- 2:1 -- Each KV head serves 2 query heads
- 4:1 -- Each KV head serves 4 query heads
- 8:1 -- Each KV head serves 8 query heads
- 16:1 (MQA) -- Single KV head serves all query heads
GQA reduces KV cache memory and bandwidth while maintaining quality.
CPU reference uses np.repeat to expand K,V heads to match Q heads.
Usage:
python3 15_gqa_fmha.py
python3 15_gqa_fmha.py --nhead-q 32
python3 15_gqa_fmha.py --seqlen 256
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def main():
    """Run the GQA / MQA demo end to end.

    For every ratio in {1, 2, 4, 8, 16} that divides ``--nhead-q`` the function
    builds a problem with ``nhead_k = nhead_q // ratio``, computes the CPU
    reference, runs the GPU kernel when a runner could be built, and validates
    the GPU output. It then prints a KV-cache size table and a CPU-only check
    that GQA equals MHA over explicitly repeated K/V heads.

    Returns:
        0 when every validated configuration passed, 1 otherwise, so the
        process exit status agrees with the printed "Status" line.
    """
    parser = argparse.ArgumentParser(description="GQA / MQA Attention")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead-q", type=int, default=16)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()
    if args.nhead_q < 1:
        # nhead_q == 0 would otherwise divide by zero in the KV-saving column.
        parser.error(f"--nhead-q must be a positive integer, got {args.nhead_q}")

    print("=" * 70)
    print("Example 15: Grouped-Query Attention (GQA / MQA)")
    print("=" * 70)

    hq = args.nhead_q
    # Keep only the ratios that split the query heads evenly.
    gqa_ratios = [ratio for ratio in [1, 2, 4, 8, 16] if hq % ratio == 0]
    print(f"\n nhead_q: {hq}")
    print(f" Ratios: {', '.join(f'{r}:1' for r in gqa_ratios)}")
    print(f" Problem: B={args.batch} S={args.seqlen} D={args.hdim}")

    # --- Try GPU runner ---
    runner = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f" GPU: Loaded (JIT build: {setup.build_time_s:.1f}s)")
    else:
        print(f" GPU: Not available ({setup.error})")

    validator = FmhaValidator(rtol=1e-2, atol=1e-2)
    print(
        f"\n {'#':<3} {'Ratio':<8} {'nhead_q':>8} {'nhead_k':>8} {'KV_save':>8} "
        f"{'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':>8}"
    )
    print(" " + "-" * 82)

    results = []
    for i, ratio in enumerate(gqa_ratios, 1):
        hk = hq // ratio
        kv_saving = (1.0 - hk / hq) * 100
        prob = FmhaProblem(
            batch=args.batch,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        # A distinct seed per configuration keeps each row reproducible.
        np.random.seed(42 + i)
        Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)
        O_ref = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            prob.scale,
        )

        # GPU attempt
        time_str = "---"
        tflops_str = "---"
        gpu_out = None
        if runner is not None:
            res = runner.run(Q, K, V, prob)
            if res.success:
                gpu_out = res.output
                time_str = f"{res.time_ms:.4f}"
                tflops_str = f"{res.tflops:.2f}"

        if gpu_out is not None:
            ok, max_abs, _ = validator.check(gpu_out, O_ref)
            tag = "PASS" if ok else "FAIL"
            err_str = f"{max_abs:.2e}"
        else:
            # Nothing to compare against: the row is informational only.
            ok = True
            tag = "DEMO"
            err_str = "---"
            max_abs = 0.0

        label = f"{ratio}:1"
        if ratio == 1:
            label += " MHA"
        elif hk == 1:
            label += " MQA"
        print(
            f" {i:<3} {label:<8} {hq:>8} {hk:>8} {kv_saving:>7.0f}% "
            f"{time_str:>10} {tflops_str:>10} {err_str:>10} {tag:>8}"
        )
        results.append((ratio, hk, ok, max_abs))

    # --- Memory analysis ---
    print("\n--- KV Cache Memory Analysis ---")
    base_kv_size = args.batch * hq * args.seqlen * args.hdim * 2 * 2  # K+V, fp16
    print(f"\n {'Ratio':<8} {'nhead_k':>8} {'KV Size':>12} {'Savings':>10}")
    print(" " + "-" * 42)
    for ratio in gqa_ratios:
        hk = hq // ratio
        kv_size = args.batch * hk * args.seqlen * args.hdim * 2 * 2
        saving = (1.0 - kv_size / base_kv_size) * 100
        size_str = (
            f"{kv_size / 1024:.1f} KB"
            if kv_size < 1024 * 1024
            else f"{kv_size / (1024 * 1024):.2f} MB"
        )
        print(f" {ratio}:1{'':<4} {hq // ratio:>8} {size_str:>12} {saving:>9.0f}%")

    # --- GQA correctness: verify np.repeat equivalence ---
    print("\n--- GQA Equivalence Check ---")
    prob_gqa = FmhaProblem(
        batch=1,
        nhead_q=8,
        nhead_k=2,
        seqlen_q=64,
        seqlen_k=64,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    np.random.seed(99)
    Q_g = (np.random.randn(*prob_gqa.q_shape()) * 0.1).astype(np.float32)
    K_g = (np.random.randn(*prob_gqa.k_shape()) * 0.1).astype(np.float32)
    V_g = (np.random.randn(*prob_gqa.v_shape()) * 0.1).astype(np.float32)
    O_gqa = cpu_attention_fwd(Q_g, K_g, V_g, prob_gqa.scale)
    # Expand the 2 KV heads to 8 by repeating each one 4 times.
    K_exp = np.repeat(K_g, 4, axis=1)
    V_exp = np.repeat(V_g, 4, axis=1)
    prob_mha = FmhaProblem(
        batch=1,
        nhead_q=8,
        nhead_k=8,
        seqlen_q=64,
        seqlen_k=64,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    O_mha = cpu_attention_fwd(Q_g, K_exp, V_exp, prob_mha.scale)
    equiv_err = float(np.abs(O_gqa - O_mha).max())
    print(f" GQA(4:1) vs MHA(expanded): max_err = {equiv_err:.2e}")
    print(" cpu_attention_fwd handles GQA internally via np.repeat")

    # --- Summary ---
    all_ok = all(ok for _, _, ok, _ in results)
    print("\n" + "=" * 70)
    print(f" GQA ratios tested: {len(gqa_ratios)}")
    print(" MHA (1:1): All heads have unique KV (baseline)")
    print(" GQA (N:1): N query heads share one KV head")
    print(" MQA (H:1): All query heads share single KV head (max saving)")
    print(" GPU: Prebuilt kernel supports GQA via nhead_q != nhead_k")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)
    # Propagate a validation failure through the exit status.
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,267 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 16: Split-KV Attention and Paged KV Cache
Demonstrates:
1. Split-KV: partitioning KV across multiple GPU splits for long sequences
2. Two-stage execution plan: split (per-partition attention) + combine (merge)
3. Paged KV cache with configurable page_block_size
4. CPU reference for split-KV correctness verification
Split-KV is critical for long-context inference where seqlen_k >> seqlen_q
(decoding with long history). Each split processes a chunk of KV independently,
then partial results are combined with log-sum-exp correction.
Usage:
python3 16_splitkv_fmha.py
python3 16_splitkv_fmha.py --num-splits 4
python3 16_splitkv_fmha.py --seqlen-k 2048 --page-size 128
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def cpu_splitkv_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    num_splits: int,
) -> tuple:
    """Reference split-KV attention on the CPU.

    Stage 1 (split): the key/value sequence is cut into ``num_splits`` chunks
    of ``ceil(seqlen_k / num_splits)`` keys; each chunk yields a locally
    normalized output and its log-sum-exp. Splits that would start past the end
    of K keep a zero output and an LSE of -inf.
    Stage 2 (combine): the partial outputs are averaged with weights
    ``exp(lse_s - max_s lse_s)``.

    Returns: (O_final, partial_Os, partial_lses)
    """
    batch, nhead, seqlen_q, _ = Q.shape
    seqlen_k = K.shape[2]
    hdim_v = V.shape[3]
    chunk_size = (seqlen_k + num_splits - 1) // num_splits

    partial_Os = np.zeros(
        (num_splits, batch, nhead, seqlen_q, hdim_v), dtype=np.float32
    )
    partial_lses = np.full(
        (num_splits, batch, nhead, seqlen_q), -np.inf, dtype=np.float32
    )

    # Stage 1: independent softmax-attention per KV chunk.
    for split in range(num_splits):
        lo = split * chunk_size
        if lo >= seqlen_k:
            # Remaining splits are empty; their -inf LSE gives them zero weight.
            break
        hi = min(lo + chunk_size, seqlen_k)
        scores = np.matmul(Q, K[:, :, lo:hi, :].transpose(0, 1, 3, 2)) * scale
        row_max = scores.max(axis=-1, keepdims=True)
        unnormalized = np.exp(scores - row_max)
        denom = unnormalized.sum(axis=-1, keepdims=True)
        partial_Os[split] = np.matmul(unnormalized / denom, V[:, :, lo:hi, :])
        partial_lses[split] = np.log(denom.squeeze(-1)) + row_max.squeeze(-1)

    # Stage 2: Combine using LSE correction
    global_lse = np.max(partial_lses, axis=0)  # [batch, nhead, seqlen_q]
    O_final = np.zeros((batch, nhead, seqlen_q, hdim_v), dtype=np.float32)
    weight_sum = np.zeros((batch, nhead, seqlen_q), dtype=np.float32)
    for part_o, part_lse in zip(partial_Os, partial_lses):
        weight = np.exp(part_lse - global_lse)
        O_final += part_o * weight[..., np.newaxis]
        weight_sum += weight
    O_final = O_final / weight_sum[..., np.newaxis]
    return O_final, partial_Os, partial_lses
def make_page_table(batch: int, seqlen_k: int, page_size: int) -> tuple:
    """Lay out an identity page table for a paged KV cache.

    Each sequence gets ``ceil(seqlen_k / page_size)`` pages, and page ids are
    handed out consecutively, row by row.

    Returns: (page_table, num_pages_per_seq, total_pages)
        page_table is an int32 array of shape [batch, num_pages_per_seq].
    """
    num_pages_per_seq = (seqlen_k + page_size - 1) // page_size
    page_count = batch * num_pages_per_seq
    page_ids = np.arange(page_count, dtype=np.int32)
    table = page_ids.reshape((batch, num_pages_per_seq))
    return table, num_pages_per_seq, page_count
def main():
    """Run the split-KV / paged-KV demo end to end.

    Parses the CLI options, computes the full-attention CPU reference,
    optionally runs the GPU kernel on the unsplit problem, validates the CPU
    split-KV reference against full attention for several split counts, and
    prints paged-KV layout statistics.

    Returns:
        0 when every split-KV configuration matched full attention, 1
        otherwise. The summary "Status" line reports the same outcome instead
        of an unconditional PASS.
    """
    parser = argparse.ArgumentParser(description="Split-KV and Paged KV Cache")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead-q", type=int, default=16)
    parser.add_argument("--nhead-k", type=int, default=16)
    parser.add_argument(
        "--seqlen-q", type=int, default=1, help="Typically 1 for decoding"
    )
    parser.add_argument("--seqlen-k", type=int, default=1024)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--num-splits", type=int, default=0, help="0 = test multiple")
    parser.add_argument("--page-size", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 16: Split-KV Attention and Paged KV Cache")
    print("=" * 70)

    sq, sk = args.seqlen_q, args.seqlen_k
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead_q,
        nhead_k=args.nhead_k,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    print(
        f"\n Problem: B={prob.batch} Hq={prob.nhead_q} Hk={prob.nhead_k} "
        f"Sq={sq} Sk={sk} D={args.hdim}"
    )
    print(f" Use case: Decoding (Sq={sq} << Sk={sk})")

    # --- Generate data ---
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_fp16 = Q_f32.astype(np.float16)
    K_fp16 = K_f32.astype(np.float16)
    V_fp16 = V_f32.astype(np.float16)

    # --- Full attention reference ---
    O_full = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale)

    # --- GPU attempt ---
    print("\n--- GPU Execution ---")
    gpu_output = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
        if res.success:
            gpu_output = res.output
            print(f" GPU (full): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS")
        else:
            print(" GPU: Kernel returned failure")

    # --- Split-KV with various num_splits ---
    print("\n--- Split-KV Execution Plan ---")
    split_configs = [args.num_splits] if args.num_splits > 0 else [1, 2, 3, 4, 8]
    # A split count above seqlen_k would only add empty splits; drop those.
    split_configs = [s for s in split_configs if s <= sk]
    validator = FmhaValidator(rtol=1e-5, atol=1e-5)
    print("\n Plan stages:")
    print(" Stage 1 (split): Compute partial O and LSE per KV chunk")
    print(" Stage 2 (combine): Merge with exp(lse_i - lse_max) correction")
    print(
        f"\n {'#':<3} {'Splits':>7} {'ChunkSz':>8} {'Stage1':>8} {'Stage2':>8} "
        f"{'MaxErr':>10} {'Status':>8}"
    )
    print(" " + "-" * 58)
    # Collect the per-row verdicts so the summary reflects the real outcome.
    split_ok = []
    for i, ns in enumerate(split_configs, 1):
        chunk_size = (sk + ns - 1) // ns
        O_split, _, _ = cpu_splitkv_attention(
            Q_f32,
            K_f32,
            V_f32,
            prob.scale,
            ns,
        )
        ok, max_abs, _ = validator.check(O_split, O_full)
        split_ok.append(bool(ok))
        tag = "PASS" if ok else "FAIL"
        print(
            f" {i:<3} {ns:>7} {chunk_size:>8} {'split':>8} {'combine':>8} "
            f"{max_abs:>10.2e} {tag:>8}"
        )

    # --- Paged KV Cache ---
    print("\n--- Paged KV Cache ---")
    page_sizes = [64, 128, 256]
    print(
        f"\n {'PageSize':>9} {'Pages/Seq':>10} {'TotalPages':>11} {'Utilization':>12}"
    )
    print(" " + "-" * 46)
    for ps in page_sizes:
        _, pps, tp = make_page_table(args.batch, sk, ps)
        # Utilization = occupied KV slots / slots allocated in whole pages.
        used_slots = args.batch * sk
        total_slots = tp * ps
        util = used_slots / total_slots * 100
        print(f" {ps:>9} {pps:>10} {tp:>11} {util:>11.1f}%")
    print(f"\n Page table example (batch=0, page_size={args.page_size}):")
    pt, pps, _ = make_page_table(args.batch, sk, args.page_size)
    pages_str = ", ".join(str(p) for p in pt[0, : min(8, pps)])
    if pps > 8:
        pages_str += f" ... ({pps} pages)"
    print(f" [{pages_str}]")
    print(" Maps logical KV positions -> physical page indices")

    # --- GPU validation if available ---
    if gpu_output is not None:
        print("\n--- GPU vs Full-Attention Reference ---")
        val = FmhaValidator(rtol=1e-2, atol=1e-2)
        ok, max_abs, max_rel = val.check(gpu_output, O_full)
        print(
            f" max_abs={max_abs:.2e}, max_rel={max_rel:.2e}, {'PASS' if ok else 'FAIL'}"
        )

    # --- Summary ---
    all_ok = all(split_ok)
    print("\n" + "=" * 70)
    print(f" Split-KV: Partitions seqlen_k={sk} across splits")
    print(" Plan: 2-stage (split partial O/LSE -> combine with correction)")
    print(f" Paged KV: page_block_size={args.page_size} ({pps} pages/seq)")
    print(" Use case: Long-context decoding (Sq << Sk)")
    print(" GPU: Prebuilt kernel runs full attention (no split-KV)")
    if all_ok:
        print(" Status: PASS (CPU split-KV matches full attention)")
    else:
        print(" Status: FAIL (CPU split-KV does not match full attention)")
    print("=" * 70)
    # Propagate a validation failure through the exit status.
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,362 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 17: AppendKV with RoPE Integration
Demonstrates:
1. KV cache append operation (new tokens added to existing cache)
2. RoPE (Rotary Position Embedding) integration:
- Interleaved: pairs (x0,x1), (x2,x3), ... rotated together
- Half-rotated: first half and second half rotated
3. Paged KV cache with page_block_size and cache_batch_idx
4. CPU reference for RoPE-transformed KV append
AppendKV is the first stage of a decode step: new K,V tokens are
RoPE-transformed and appended to the paged cache before attention.
Usage:
python3 17_appendkv_fmha.py
python3 17_appendkv_fmha.py --rope interleaved
python3 17_appendkv_fmha.py --seqlen-new 4 --page-size 64
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def make_rotary_cos_sin(
    max_seqlen: int,
    hdim: int,
    base: float = 10000.0,
) -> tuple:
    """Tabulate the RoPE rotation angles as cosine and sine tables.

    Entry [p, i] uses the angle ``p / base ** (i / (hdim // 2))`` for position
    p and frequency index i.

    Returns: (cos_table, sin_table) each of shape [max_seqlen, hdim//2], float32
    """
    half_dim = hdim // 2
    exponents = np.arange(0, half_dim, dtype=np.float32) / half_dim
    inv_freq = 1.0 / (base**exponents)
    positions = np.arange(max_seqlen, dtype=np.float32)
    # Outer product position x inverse frequency, written as a broadcast.
    angles = positions[:, np.newaxis] * inv_freq[np.newaxis, :]
    cos_table = np.cos(angles).astype(np.float32)
    sin_table = np.sin(angles).astype(np.float32)
    return cos_table, sin_table
def apply_rope_interleaved(
    x: np.ndarray, cos: np.ndarray, sin: np.ndarray, start_pos: int
) -> np.ndarray:
    """Rotate adjacent feature pairs (x0,x1), (x2,x3), ... by the RoPE angles.

    Rows ``start_pos .. start_pos + seqlen`` of the tables supply the angles
    for the ``seqlen`` positions of ``x``.

    x: [..., seqlen, hdim]
    cos, sin: [max_seqlen, hdim//2]
    Returns a new array with the shape and dtype of ``x``.
    """
    seqlen, hdim = x.shape[-2], x.shape[-1]
    half = hdim // 2
    window = slice(start_pos, start_pos + seqlen)
    # Leading singleton axes let the tables broadcast over batch/head axes.
    table_shape = (1,) * (x.ndim - 2) + (seqlen, half)
    cos_b = cos[window, :].reshape(table_shape)
    sin_b = sin[window, :].reshape(table_shape)

    evens = x[..., 0::2]
    odds = x[..., 1::2]
    rotated = np.empty_like(x)
    rotated[..., 0::2] = evens * cos_b - odds * sin_b
    rotated[..., 1::2] = odds * cos_b + evens * sin_b
    return rotated
def apply_rope_half_rotated(
    x: np.ndarray, cos: np.ndarray, sin: np.ndarray, start_pos: int
) -> np.ndarray:
    """Rotate feature i of the first half together with feature i of the second half.

    Rows ``start_pos .. start_pos + seqlen`` of the tables supply the angles
    for the ``seqlen`` positions of ``x``.

    x: [..., seqlen, hdim]
    cos, sin: [max_seqlen, hdim//2]
    Returns a new array with the shape and dtype of ``x``.
    """
    seqlen, hdim = x.shape[-2], x.shape[-1]
    half = hdim // 2
    window = slice(start_pos, start_pos + seqlen)
    # Leading singleton axes let the tables broadcast over batch/head axes.
    table_shape = (1,) * (x.ndim - 2) + (seqlen, half)
    cos_b = cos[window, :].reshape(table_shape)
    sin_b = sin[window, :].reshape(table_shape)

    front = x[..., :half]
    back = x[..., half:]
    rotated = np.empty_like(x)
    rotated[..., :half] = front * cos_b - back * sin_b
    rotated[..., half:] = back * cos_b + front * sin_b
    return rotated
def cpu_append_kv(
    k_cache: np.ndarray,
    v_cache: np.ndarray,
    k_new: np.ndarray,
    v_new: np.ndarray,
    cache_seqlen: int,
    rope_fn,
    cos: np.ndarray,
    sin: np.ndarray,
) -> tuple:
    """Reference KV-cache append on the CPU, with optional RoPE on the keys.

    The new keys are passed through ``rope_fn(k_new, cos, sin, cache_seqlen)``
    when ``rope_fn`` is not None; the new values are stored unchanged. Both are
    written at positions ``cache_seqlen .. cache_seqlen + seqlen_new`` of
    copies of the caches, so the input caches are not modified.

    k_cache/v_cache: [batch, nhead, max_seqlen, hdim]
    k_new/v_new: [batch, nhead, seqlen_new, hdim]
    Returns: (k_cache_updated, v_cache_updated)
    """
    keys_to_store = k_new if rope_fn is None else rope_fn(k_new, cos, sin, cache_seqlen)
    target = slice(cache_seqlen, cache_seqlen + k_new.shape[2])

    k_updated = k_cache.copy()
    k_updated[:, :, target, :] = keys_to_store
    v_updated = v_cache.copy()
    v_updated[:, :, target, :] = v_new
    return k_updated, v_updated
def make_paged_cache(
    batch: int, nhead: int, total_pages: int, page_size: int, hdim: int
) -> tuple:
    """Allocate an empty paged KV cache with an identity page mapping.

    The pages are divided evenly between sequences (``total_pages // batch``
    each) and page ids are assigned consecutively, row by row.

    Returns: (k_pages, v_pages, page_table, cache_batch_idx)
        k_pages/v_pages: float32 zeros of shape [total_pages, nhead, page_size, hdim]
        page_table: int32 of shape [batch, total_pages // batch]
        cache_batch_idx: int32 vector 0..batch-1
    """
    page_shape = (total_pages, nhead, page_size, hdim)
    k_pages = np.zeros(page_shape, dtype=np.float32)
    v_pages = np.zeros(page_shape, dtype=np.float32)

    per_sequence = total_pages // batch
    page_ids = np.arange(total_pages, dtype=np.int32)
    page_table = page_ids.reshape((batch, per_sequence))
    cache_batch_idx = np.arange(batch, dtype=np.int32)
    return k_pages, v_pages, page_table, cache_batch_idx
def main():
    """Run the AppendKV + RoPE demo end to end.

    Parses the CLI options, builds the RoPE tables, compares the selected RoPE
    modes on the new keys, appends the new K/V tokens to a dense cache with the
    CPU reference, prints the paged-cache layout for the append position, tries
    a GPU attention run over the appended cache, and finally prints how the
    rotation of a constant vector changes with position.

    The sizes are validated up front: the appended tokens must fit into
    ``--max-seqlen``, otherwise the RoPE table slice and the cache write would
    fail later with a shape error.

    Returns:
        Always 0; the example reports "DEMO" status.
    """
    parser = argparse.ArgumentParser(description="AppendKV with RoPE Integration")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=16)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument(
        "--seqlen-new", type=int, default=1, help="New tokens to append"
    )
    parser.add_argument(
        "--cache-seqlen", type=int, default=512, help="Existing cache length"
    )
    parser.add_argument("--max-seqlen", type=int, default=2048)
    parser.add_argument("--page-size", type=int, default=128)
    parser.add_argument(
        "--rope", default="both", choices=["interleaved", "half", "none", "both"]
    )
    args = parser.parse_args()
    if args.seqlen_new < 1:
        parser.error(f"--seqlen-new must be >= 1, got {args.seqlen_new}")
    if args.cache_seqlen < 0:
        parser.error(f"--cache-seqlen must be >= 0, got {args.cache_seqlen}")
    if args.page_size < 1:
        parser.error(f"--page-size must be >= 1, got {args.page_size}")
    if args.cache_seqlen + args.seqlen_new > args.max_seqlen:
        parser.error(
            f"--cache-seqlen + --seqlen-new ({args.cache_seqlen + args.seqlen_new}) "
            f"exceeds --max-seqlen ({args.max_seqlen})"
        )

    print("=" * 70)
    print("Example 17: AppendKV with RoPE Integration")
    print("=" * 70)
    print(f"\n Batch: {args.batch}")
    print(f" Heads: {args.nhead}")
    print(f" HDim: {args.hdim}")
    print(f" New tokens: {args.seqlen_new}")
    print(f" Cache len: {args.cache_seqlen}")
    print(f" Max seqlen: {args.max_seqlen}")
    print(f" Page size: {args.page_size}")

    # --- Generate RoPE tables ---
    cos, sin = make_rotary_cos_sin(args.max_seqlen, args.hdim)
    print("\n RoPE base: 10000.0")
    print(f" Cos/Sin: [{args.max_seqlen}, {args.hdim // 2}]")

    # --- Generate new KV data ---
    np.random.seed(42)
    k_new = (
        np.random.randn(args.batch, args.nhead, args.seqlen_new, args.hdim) * 0.1
    ).astype(np.float32)
    v_new = (
        np.random.randn(args.batch, args.nhead, args.seqlen_new, args.hdim) * 0.1
    ).astype(np.float32)

    # --- RoPE comparison ---
    rope_modes = []
    if args.rope in ("interleaved", "both"):
        rope_modes.append(("interleaved", apply_rope_interleaved))
    if args.rope in ("half", "both"):
        rope_modes.append(("half_rotated", apply_rope_half_rotated))
    if args.rope == "none":
        rope_modes.append(("none", None))
    print("\n--- RoPE Modes ---")
    print(f"\n {'Mode':<16} {'K_new range':>20} {'K_rope range':>20} {'MaxDiff':>10}")
    print(" " + "-" * 70)
    for mode_name, rope_fn in rope_modes:
        if rope_fn is not None:
            k_roped = rope_fn(k_new, cos, sin, args.cache_seqlen)
        else:
            k_roped = k_new
        k_range = f"[{k_new.min():.4f}, {k_new.max():.4f}]"
        kr_range = f"[{k_roped.min():.4f}, {k_roped.max():.4f}]"
        diff = float(np.abs(k_roped - k_new).max())
        print(f" {mode_name:<16} {k_range:>20} {kr_range:>20} {diff:>10.4f}")

    # --- KV Cache Append ---
    print("\n--- KV Cache Append ---")
    k_cache = np.zeros(
        (args.batch, args.nhead, args.max_seqlen, args.hdim), dtype=np.float32
    )
    v_cache = np.zeros(
        (args.batch, args.nhead, args.max_seqlen, args.hdim), dtype=np.float32
    )
    np.random.seed(0)
    k_cache[:, :, : args.cache_seqlen, :] = (
        np.random.randn(args.batch, args.nhead, args.cache_seqlen, args.hdim) * 0.1
    ).astype(np.float32)
    v_cache[:, :, : args.cache_seqlen, :] = (
        np.random.randn(args.batch, args.nhead, args.cache_seqlen, args.hdim) * 0.1
    ).astype(np.float32)
    new_len = args.cache_seqlen + args.seqlen_new
    # Caches after the append; the GPU run below attends over these. With
    # several RoPE modes selected, the last mode's result is kept.
    k_after, v_after = k_cache, v_cache
    for mode_name, rope_fn in rope_modes:
        k_up, v_up = cpu_append_kv(
            k_cache,
            v_cache,
            k_new,
            v_new,
            args.cache_seqlen,
            rope_fn,
            cos,
            sin,
        )
        k_after, v_after = k_up, v_up
        k_appended = k_up[:, :, args.cache_seqlen : new_len, :]
        print(f"\n {mode_name}:")
        print(f" Cache after append: positions [0, {new_len})")
        print(f" New K range: [{k_appended.min():.4f}, {k_appended.max():.4f}]")
        print(
            f" Cache unchanged: {np.array_equal(k_up[:, :, : args.cache_seqlen, :], k_cache[:, :, : args.cache_seqlen, :])}"
        )

    # --- Paged KV Cache ---
    print("\n--- Paged KV Cache Layout ---")
    # Round up so a max_seqlen that is not a multiple of page_size still gets
    # its last, partially used page (floor division dropped it).
    pages_needed = (args.max_seqlen + args.page_size - 1) // args.page_size
    total_pages = pages_needed * args.batch
    k_pages, v_pages, page_table, cache_batch_idx = make_paged_cache(
        args.batch,
        args.nhead,
        total_pages,
        args.page_size,
        args.hdim,
    )
    pages_per_seq = total_pages // args.batch
    print(f" Total pages: {total_pages}")
    print(f" Pages per seq: {pages_per_seq}")
    print(f" Page size: {args.page_size}")
    print(f" K pages shape: {k_pages.shape}")
    print(f" Page table: {page_table.shape}")
    print(f" cache_batch_idx: {cache_batch_idx}")
    # Logical position cache_seqlen -> (page index, slot inside that page).
    current_page = args.cache_seqlen // args.page_size
    offset_in_page = args.cache_seqlen % args.page_size
    print(f"\n Append position: page={current_page}, offset={offset_in_page}")
    print(f" Physical page idx (batch 0): {page_table[0, current_page]}")

    # --- GPU attempt ---
    print("\n--- GPU Execution ---")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        prob = FmhaProblem(
            batch=args.batch,
            nhead_q=args.nhead,
            nhead_k=args.nhead,
            seqlen_q=args.seqlen_new,
            seqlen_k=new_len,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        Q_fp16 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
        # Attend over the cache that contains the appended tokens; the
        # pre-append cache holds zeros at the new positions.
        K_full = k_after[:, :, :new_len, :].astype(np.float16)
        V_full = v_after[:, :, :new_len, :].astype(np.float16)
        res = runner.run(Q_fp16, K_full, V_full, prob)
        if res.success:
            print(
                f" Attention after append: {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS"
            )
        else:
            print(" GPU: Kernel returned failure (appendkv not supported)")
        print(" Note: Prebuilt kernel does not support appendkv family")

    # --- RoPE position-dependency visualization ---
    print("\n--- RoPE Position Dependency ---")
    positions = [0, 128, 512, 1024]
    test_vec = np.ones((1, 1, 1, args.hdim), dtype=np.float32) * 0.1
    for rope_name, rope_fn in rope_modes:
        if rope_fn is None:
            continue
        print(f"\n {rope_name} (first 4 dims of rotated unit vector):")
        print(f" {'Position':>10} {'dim0':>8} {'dim1':>8} {'dim2':>8} {'dim3':>8}")
        for pos in positions:
            # Positions beyond the table have no cos/sin row; skip them.
            if pos < args.max_seqlen:
                rotated = rope_fn(test_vec, cos, sin, pos)
                dims = rotated[0, 0, 0, :4]
                print(
                    f" {pos:>10} {dims[0]:>8.4f} {dims[1]:>8.4f} {dims[2]:>8.4f} {dims[3]:>8.4f}"
                )

    # --- Summary ---
    print("\n" + "=" * 70)
    print(
        f" AppendKV: Append {args.seqlen_new} new tokens at position {args.cache_seqlen}"
    )
    print(f" RoPE modes: {', '.join(m for m, _ in rope_modes)}")
    print(f" Paged cache: {total_pages} pages x {args.page_size} slots")
    print(" Pipeline: appendkv -> fwd_pagedkv (2-stage decode)")
    print(" GPU: Prebuilt supports fwd only (appendkv needs JIT)")
    print(" Status: DEMO")
    print("=" * 70)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 18: Backward Pass (dQ, dK, dV)
Demonstrates:
1. Forward pass to obtain O and LSE
2. Backward pass computing gradients dQ, dK, dV from dO
3. Three-stage backward plan:
- Stage 1 (dot_do_o): Compute D = rowsum(dO * O)
- Stage 2 (dq_dk_dv): Compute dQ, dK, dV using D and LSE
- Stage 3 (convert_dq): Optional dtype conversion for dQ
4. CPU reference with analytical gradients
5. Gradient checking via finite differences
Usage:
python3 18_backward_fmha.py
python3 18_backward_fmha.py --seqlen 128
python3 18_backward_fmha.py --check-grad --eps 1e-3
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def cpu_attention_fwd_with_lse(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
) -> tuple:
    """Reference softmax attention that also exposes P and the log-sum-exp.

    When Q carries a different number of heads than K/V, every K/V head is
    repeated ``nhead_q // nhead_k`` times along axis 1 before scoring.

    Returns:
        (O, P, lse): the attention output ``P @ V``, the softmax
        probabilities taken over the last axis, and the per-row
        log-sum-exp of the scaled scores cast to float32.
    """
    q_heads, kv_heads = Q.shape[1], K.shape[1]
    if q_heads != kv_heads:
        group = q_heads // kv_heads
        K, V = np.repeat(K, group, axis=1), np.repeat(V, group, axis=1)

    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale

    # Subtract the row maximum before exponentiating for numerical stability.
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    row_sum = weights.sum(axis=-1, keepdims=True)
    probs = weights / row_sum

    lse = (np.log(row_sum.squeeze(-1)) + row_max.squeeze(-1)).astype(np.float32)
    return np.matmul(probs, V), probs, lse
def cpu_attention_bwd(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
) -> tuple:
    """CPU reference backward pass.

    Computes analytical gradients dQ, dK, dV.

    Stage 1: D_i = sum_j(dO_ij * O_ij)   (per query position)
    Stage 2: dS = P * (dO @ V^T - D)
             dQ = dS @ K * scale
             dK = dS^T @ Q * scale
             dV = P^T @ dO

    When Q has more heads than K/V, K and V are repeated
    ``nhead_q // nhead_k`` times along the head axis (the same expansion the
    forward reference applies) and the resulting dK/dV are summed over each
    group of query heads, so they come back with the shapes of the K and V
    that were passed in. With equal head counts the computation is exactly
    the formulas above.

    Args:
        Q, K, V: forward inputs, indexed as [batch, head, seq, dim].
        out: forward output O.
        dO: upstream gradient with the same shape as ``out``.
        P: softmax probabilities from the forward pass.
        scale: softmax scale used in the forward pass.

    Returns: (dQ, dK, dV, D) where D has the trailing axis squeezed away.
    """
    nhead_q, nhead_k = Q.shape[1], K.shape[1]
    group = nhead_q // nhead_k if nhead_q != nhead_k else 1
    if group != 1:
        # Expand shared K/V heads so every query head has a matching partner.
        K = np.repeat(K, group, axis=1)
        V = np.repeat(V, group, axis=1)

    # Stage 1: dot_do_o
    D = (dO * out).sum(axis=-1, keepdims=True)

    # Stage 2: dq_dk_dv
    dP = np.matmul(dO, V.transpose(0, 1, 3, 2))
    dS = P * (dP - D)
    dQ = np.matmul(dS, K) * scale
    dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) * scale
    dV = np.matmul(P.transpose(0, 1, 3, 2), dO)

    if group != 1:
        # np.repeat lays the copies of one K/V head out consecutively, so
        # splitting the head axis into (nhead_k, group) and summing over the
        # group axis accumulates the gradient of each shared head.
        def _sum_groups(grad: np.ndarray) -> np.ndarray:
            b, _, s, d = grad.shape
            return grad.reshape(b, nhead_k, group, s, d).sum(axis=2)

        dK = _sum_groups(dK)
        dV = _sum_groups(dV)

    return dQ, dK, dV, D.squeeze(-1)
def finite_difference_check(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    dO: np.ndarray,
    scale: float,
    eps: float = 1e-3,
    param_name: str = "Q",
    max_checks: int = 5,
) -> float:
    """Verify gradients via finite differences on a few elements.

    The scalar loss is ``sum(O * dO)``, whose gradient with respect to the
    chosen input is what ``cpu_attention_bwd`` returns for that input. Up to
    ``max_checks`` randomly chosen elements of the selected array are
    perturbed by ``+eps`` and ``-eps`` and the central difference is compared
    with the analytical value.

    Args:
        Q, K, V: forward inputs. The array selected by ``param_name`` is
            perturbed in place and restored after each probe.
        dO: upstream gradient used to form the scalar loss.
        scale: softmax scale.
        eps: perturbation size.
        param_name: one of "Q", "K", "V".
        max_checks: maximum number of elements to probe.

    Returns:
        The largest relative error ``|fd - analytical| / (|fd| + 1e-8)``
        over the probed elements (0.0 when nothing is probed).

    Raises:
        KeyError: if ``param_name`` is not "Q", "K" or "V".
    """
    param = {"Q": Q, "K": K, "V": V}[param_name]
    grad_idx = {"Q": 0, "K": 1, "V": 2}[param_name]

    # One forward and one backward pass give the analytical gradient.
    O_ref, P_ref, _ = cpu_attention_fwd_with_lse(Q, K, V, scale)
    analytical_grad = cpu_attention_bwd(Q, K, V, O_ref, dO, P_ref, scale)[grad_idx]

    max_err = 0.0
    flat_indices = np.random.choice(
        param.size, min(max_checks, param.size), replace=False
    )
    for flat_idx in flat_indices:
        idx = np.unravel_index(flat_idx, param.shape)
        orig = param[idx]
        try:
            param[idx] = orig + eps
            loss_plus = (cpu_attention_fwd(Q, K, V, scale) * dO).sum()

            param[idx] = orig - eps
            loss_minus = (cpu_attention_fwd(Q, K, V, scale) * dO).sum()
        finally:
            # Always undo the perturbation so the caller's array is left intact.
            param[idx] = orig

        fd_grad = (loss_plus - loss_minus) / (2 * eps)
        an_grad = analytical_grad[idx]
        err = abs(fd_grad - an_grad) / (abs(fd_grad) + 1e-8)
        max_err = max(max_err, err)
    return max_err
# Example driver: builds a CPU reference forward and backward pass for one
# FmhaProblem, prints per-stage statistics, optionally runs the
# finite-difference gradient check, then attempts a GPU forward run through
# the dispatcher. Returns 0.
def main():
parser = argparse.ArgumentParser(description="Backward Pass (dQ, dK, dV)")
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--batch", type=int, default=1)
parser.add_argument("--nhead", type=int, default=4)
parser.add_argument("--seqlen", type=int, default=64)
parser.add_argument("--hdim", type=int, default=128)
parser.add_argument(
"--check-grad", action="store_true", help="Run finite-difference check"
)
parser.add_argument(
"--eps", type=float, default=1e-3, help="Finite-difference epsilon"
)
args = parser.parse_args()
print("=" * 70)
print("Example 18: Backward Pass (dQ, dK, dV)")
print("=" * 70)
# Query and key sequence lengths are both taken from --seqlen, and Q and K/V
# use the same head count.
prob = FmhaProblem(
batch=args.batch,
nhead_q=args.nhead,
nhead_k=args.nhead,
seqlen_q=args.seqlen,
seqlen_k=args.seqlen,
hdim_q=args.hdim,
hdim_v=args.hdim,
)
print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}")
# --- Generate data ---
# Fixed seed so the printed statistics are reproducible between runs.
np.random.seed(42)
Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)
# --- Forward pass ---
print("\n--- Stage 0: Forward Pass ---")
out, P, lse = cpu_attention_fwd_with_lse(Q, K, V, prob.scale)
print(f" O shape: {out.shape}")
print(f" O range: [{out.min():.4f}, {out.max():.4f}]")
print(f" LSE shape: {lse.shape}")
print(f" LSE range: [{lse.min():.4f}, {lse.max():.4f}]")
print(f" P sparsity (< 1e-6): {(P < 1e-6).sum() / P.size * 100:.1f}%")
# --- Backward pass (3 stages) ---
print("\n--- Stage 1: dot_do_o (D = rowsum(dO * O)) ---")
# Computed here only for display; cpu_attention_bwd below forms the same
# row sum internally and returns it as its fourth result.
D_full = (dO * out).sum(axis=-1)
print(f" D shape: {D_full.shape}")
print(f" D range: [{D_full.min():.6f}, {D_full.max():.6f}]")
print("\n--- Stage 2: dq_dk_dv ---")
dQ, dK, dV, D = cpu_attention_bwd(Q, K, V, out, dO, P, prob.scale)
print(f" dQ shape: {dQ.shape}, range: [{dQ.min():.4e}, {dQ.max():.4e}]")
print(f" dK shape: {dK.shape}, range: [{dK.min():.4e}, {dK.max():.4e}]")
print(f" dV shape: {dV.shape}, range: [{dV.min():.4e}, {dV.max():.4e}]")
print("\n--- Stage 3: convert_dq (optional fp32 -> fp16) ---")
# Round-trip dQ through float16 to report the worst-case conversion error.
dQ_fp16 = dQ.astype(np.float16)
convert_err = float(np.abs(dQ - dQ_fp16.astype(np.float32)).max())
print(f" dQ fp32 -> fp16 max error: {convert_err:.2e}")
# --- Gradient norms ---
print("\n--- Gradient Statistics ---")
print(
f"\n {'Param':<6} {'L2 Norm':>12} {'Max Abs':>12} {'Mean Abs':>12} {'Shape'}"
)
print(" " + "-" * 60)
for name, grad in [("dQ", dQ), ("dK", dK), ("dV", dV)]:
l2 = float(np.sqrt((grad**2).sum()))
ma = float(np.abs(grad).max())
mean_a = float(np.abs(grad).mean())
print(f" {name:<6} {l2:>12.4e} {ma:>12.4e} {mean_a:>12.4e} {grad.shape}")
# --- Finite difference check ---
if args.check_grad:
print(f"\n--- Finite Difference Gradient Check (eps={args.eps}) ---")
for pname in ["Q", "K", "V"]:
# finite_difference_check perturbs the selected array in place, so it is
# handed fresh copies for every parameter.
Q_c, K_c, V_c = Q.copy(), K.copy(), V.copy()
err = finite_difference_check(
Q_c,
K_c,
V_c,
dO,
prob.scale,
eps=args.eps,
param_name=pname,
max_checks=5,
)
# A maximum relative error below 1e-2 is reported as PASS.
tag = "PASS" if err < 1e-2 else "FAIL"
print(f" d{pname}: max_rel_err = {err:.4e} {tag}")
# --- GPU forward attempt ---
print("\n--- GPU Execution ---")
config = FmhaKernelConfig(
data_type="fp16",
hdim_q=args.hdim,
hdim_v=args.hdim,
gfx_arch=args.arch,
)
setup = setup_fmha_dispatcher(config)
if not setup.success:
print(f" JIT build failed: {setup.error}")
else:
runner = setup.runner
print(f" JIT build: {setup.build_time_s:.1f}s")
# The runner is given float16 copies of the float32 reference inputs.
Q_fp16 = Q.astype(np.float16)
K_fp16 = K.astype(np.float16)
V_fp16 = V.astype(np.float16)
res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
if res.success:
print(f" Forward GPU: {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS")
# The GPU output is compared against the CPU forward output `out`.
validator = FmhaValidator(rtol=1e-2, atol=1e-2)
ok, ma, _ = validator.check(res.output, out)
print(f" Forward validation: max_err={ma:.2e}, {'PASS' if ok else 'FAIL'}")
else:
print(" Forward GPU: Kernel returned failure")
print(" Backward GPU: Not available (requires bwd family kernel)")
# --- Backward plan structure ---
print("\n--- Backward Plan Structure ---")
print(" Stage 1: dot_do_o")
print(f" Input: dO [{prob.o_shape()}], O [{prob.o_shape()}]")
print(f" Output: D [{prob.batch}, {prob.nhead_q}, {prob.seqlen_q}]")
print(" Stage 2: dq_dk_dv")
print(" Input: Q, K, V, dO, LSE, D")
print(" Output: dQ, dK, dV (in accumulator precision)")
print(" Stage 3: convert_dq")
print(" Input: dQ (fp32)")
print(" Output: dQ (fp16)")
# --- Summary ---
print("\n" + "=" * 70)
print(" Forward: O = softmax(Q @ K^T / sqrt(d)) @ V")
print(" Backward: 3-stage plan (dot_do_o -> dq_dk_dv -> convert_dq)")
print(f" Gradients: dQ [{dQ.shape}], dK [{dK.shape}], dV [{dV.shape}]")
print(" GPU: Prebuilt supports forward only")
print(" Status: DEMO")
print("=" * 70)
return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,344 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 19: Batch Padding and Group Mode
Demonstrates:
1. Batch mode with effective lengths (q_eff_lens, kv_eff_lens)
- Padded to max length but only effective positions contribute
2. Group mode with physical padding strides (s_qpad, s_kpad)
- Variable-length sequences packed contiguously
- seqstart pointers mark boundaries
3. Comparing batch vs group mode memory efficiency
In batch mode, each sequence in the batch is padded to the same max length.
In group mode, sequences are packed without padding using offset pointers,
saving memory for batches with high length variance.
Usage:
python3 19_padding_fmha.py
python3 19_padding_fmha.py --batch 8
python3 19_padding_fmha.py --max-seqlen 512
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def cpu_batch_padded_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    q_eff_lens: np.ndarray,
    kv_eff_lens: np.ndarray,
) -> np.ndarray:
    """Reference attention over a padded batch with per-sequence lengths.

    For every batch entry only the first ``q_eff_lens[b]`` query rows and the
    first ``kv_eff_lens[b]`` key/value rows take part in the computation.
    Output rows past the effective query length stay zero.

    Q: [batch, nhead, max_seqlen_q, hdim]

    Returns a float32 array of shape [batch, nhead, max_seqlen_q, hdim_v].
    """
    batch, nhead, max_sq = Q.shape[0], Q.shape[1], Q.shape[2]
    result = np.zeros((batch, nhead, max_sq, V.shape[3]), dtype=np.float32)

    for seq in range(batch):
        n_q, n_kv = q_eff_lens[seq], kv_eff_lens[seq]
        # Keep a leading batch axis of size one for the 4-D reference kernel.
        one = slice(seq, seq + 1)
        valid = cpu_attention_fwd(
            Q[one, :, :n_q, :],
            K[one, :, :n_kv, :],
            V[one, :, :n_kv, :],
            scale,
        )
        result[seq, :, :n_q, :] = valid[0]

    return result
def pack_group_mode(
    Q_batch: np.ndarray,
    K_batch: np.ndarray,
    V_batch: np.ndarray,
    q_lens: np.ndarray,
    kv_lens: np.ndarray,
) -> tuple:
    """Concatenate the valid rows of each padded sequence along the seq axis.

    The packed tensors have a leading batch axis of size one. Sequence ``b``
    occupies rows ``seqstart[b]:seqstart[b + 1]`` of the packed array.

    Returns: (Q_packed, K_packed, V_packed, seqstart_q, seqstart_k), where
    the two seqstart arrays are int32 with ``batch + 1`` entries.
    """
    batch, nhead = Q_batch.shape[0], Q_batch.shape[1]
    hdim_q, hdim_v = Q_batch.shape[3], V_batch.shape[3]
    total_q, total_k = int(q_lens.sum()), int(kv_lens.sum())

    Q_packed = np.zeros((1, nhead, total_q, hdim_q), dtype=Q_batch.dtype)
    K_packed = np.zeros((1, nhead, total_k, hdim_q), dtype=K_batch.dtype)
    V_packed = np.zeros((1, nhead, total_k, hdim_v), dtype=V_batch.dtype)
    seqstart_q = np.zeros(batch + 1, dtype=np.int32)
    seqstart_k = np.zeros(batch + 1, dtype=np.int32)

    q_cursor, k_cursor = 0, 0
    for b in range(batch):
        n_q = int(q_lens[b])
        n_kv = int(kv_lens[b])
        q_stop, k_stop = q_cursor + n_q, k_cursor + n_kv

        Q_packed[0, :, q_cursor:q_stop, :] = Q_batch[b, :, :n_q, :]
        K_packed[0, :, k_cursor:k_stop, :] = K_batch[b, :, :n_kv, :]
        V_packed[0, :, k_cursor:k_stop, :] = V_batch[b, :, :n_kv, :]

        # Record where the next sequence begins.
        seqstart_q[b + 1], seqstart_k[b + 1] = q_stop, k_stop
        q_cursor, k_cursor = q_stop, k_stop

    return Q_packed, K_packed, V_packed, seqstart_q, seqstart_k
def cpu_group_attention(
    Q_packed: np.ndarray,
    K_packed: np.ndarray,
    V_packed: np.ndarray,
    scale: float,
    seqstart_q: np.ndarray,
    seqstart_k: np.ndarray,
    batch: int,
) -> np.ndarray:
    """Reference attention over packed (group mode) sequences.

    Sequence ``b`` spans rows ``seqstart_q[b]:seqstart_q[b + 1]`` of Q and
    rows ``seqstart_k[b]:seqstart_k[b + 1]`` of K/V. Each sequence is
    attended independently and written back to the same query rows.

    Q_packed: [1, nhead, total_q, hdim]

    Returns a float32 array of shape [1, nhead, total_q, hdim_v].
    """
    nhead, total_q = Q_packed.shape[1], Q_packed.shape[2]
    result = np.zeros((1, nhead, total_q, V_packed.shape[3]), dtype=np.float32)

    for seq in range(batch):
        q_rows = slice(seqstart_q[seq], seqstart_q[seq + 1])
        kv_rows = slice(seqstart_k[seq], seqstart_k[seq + 1])
        seq_out = cpu_attention_fwd(
            Q_packed[:, :, q_rows, :],
            K_packed[:, :, kv_rows, :],
            V_packed[:, :, kv_rows, :],
            scale,
        )
        result[0, :, q_rows, :] = seq_out[0]

    return result
def main():
    """Compare padded batch-mode and packed group-mode attention on the CPU.

    Draws random effective lengths per sequence, runs the batch-mode and
    group-mode CPU references on the same data, cross-checks the two
    outputs, attempts a full padded GPU run through the dispatcher, and
    prints a memory/compute comparison.

    The batch-mode compute figure is ``batch * max_sq * max_sk``: every
    sequence attends only to its own keys, so the cost is a per-sequence
    product summed over the batch, matching how the group-mode figure is
    formed from the effective lengths.

    Returns:
        0 when batch and group outputs agree within 1e-5 for every
        sequence, otherwise 1.
    """
    parser = argparse.ArgumentParser(description="Batch Padding and Group Mode")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--max-seqlen", type=int, default=256)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 19: Batch Padding and Group Mode")
    print("=" * 70)

    batch = args.batch
    nhead = args.nhead
    max_sq = max_sk = args.max_seqlen
    hdim = args.hdim
    scale = 1.0 / (hdim**0.5)

    # --- Variable-length sequences ---
    np.random.seed(args.seed)
    # Lengths are drawn from [32, max]; clamp the lower bound so that a
    # --max-seqlen below 32 does not hand randint an empty range.
    min_len = min(32, max_sq)
    q_eff_lens = np.sort(
        np.random.randint(min_len, max_sq + 1, size=batch).astype(np.int32)
    )[::-1]
    kv_eff_lens = np.sort(
        np.random.randint(min_len, max_sk + 1, size=batch).astype(np.int32)
    )[::-1]
    # The reversed views are copied to obtain plain contiguous arrays.
    q_eff_lens = q_eff_lens.copy()
    kv_eff_lens = kv_eff_lens.copy()

    print(f"\n Batch: {batch}")
    print(f" Max seqlen: {max_sq}")
    print(f" HDim: {hdim}")
    print(f"\n {'Seq#':<6} {'q_len':>8} {'kv_len':>8} {'q_pad%':>8} {'kv_pad%':>8}")
    print(" " + "-" * 42)
    for b in range(batch):
        q_pad = (1.0 - q_eff_lens[b] / max_sq) * 100
        kv_pad = (1.0 - kv_eff_lens[b] / max_sk) * 100
        print(
            f" {b:<6} {q_eff_lens[b]:>8} {kv_eff_lens[b]:>8} {q_pad:>7.1f}% {kv_pad:>7.1f}%"
        )

    # --- Generate padded data ---
    Q_padded = (np.random.randn(batch, nhead, max_sq, hdim) * 0.1).astype(np.float32)
    K_padded = (np.random.randn(batch, nhead, max_sk, hdim) * 0.1).astype(np.float32)
    V_padded = (np.random.randn(batch, nhead, max_sk, hdim) * 0.1).astype(np.float32)

    # === BATCH MODE ===
    print("\n--- Batch Mode (padded) ---")
    O_batch = cpu_batch_padded_attention(
        Q_padded,
        K_padded,
        V_padded,
        scale,
        q_eff_lens,
        kv_eff_lens,
    )
    # Bytes for float32 Q + K + V.
    batch_mem = batch * nhead * (max_sq + 2 * max_sk) * hdim * 4
    print(f" Q/K/V layout: [{batch}, {nhead}, {max_sq}, {hdim}]")
    print(f" Memory (Q+K+V): {batch_mem / 1024:.1f} KB")
    print(
        f" Wasted (avg): {(1.0 - q_eff_lens.mean() / max_sq) * 100:.1f}% (padding overhead)"
    )

    # === GROUP MODE ===
    print("\n--- Group Mode (packed) ---")
    Q_packed, K_packed, V_packed, seqstart_q, seqstart_k = pack_group_mode(
        Q_padded,
        K_padded,
        V_padded,
        q_eff_lens,
        kv_eff_lens,
    )
    total_q = int(q_eff_lens.sum())
    total_k = int(kv_eff_lens.sum())
    group_mem = nhead * (total_q + 2 * total_k) * hdim * 4
    print(f" Q_packed: [1, {nhead}, {total_q}, {hdim}]")
    print(f" K_packed: [1, {nhead}, {total_k}, {hdim}]")
    print(f" seqstart_q: {seqstart_q}")
    print(f" seqstart_k: {seqstart_k}")
    print(f" Memory (Q+K+V): {group_mem / 1024:.1f} KB")
    print(f" Saving vs batch: {(1.0 - group_mem / batch_mem) * 100:.1f}%")

    # Physical padding strides
    s_qpad = total_q
    s_kpad = total_k
    print("\n Physical strides:")
    print(f" s_qpad = {s_qpad} (total Q tokens)")
    print(f" s_kpad = {s_kpad} (total KV tokens)")

    O_group = cpu_group_attention(
        Q_packed,
        K_packed,
        V_packed,
        scale,
        seqstart_q,
        seqstart_k,
        batch,
    )

    # --- Cross-validate batch vs group ---
    print("\n--- Batch vs Group Validation ---")
    print(f"\n {'Seq#':<6} {'q_len':>8} {'MaxErr':>10} {'Status':>8}")
    print(" " + "-" * 36)
    all_ok = True
    for b in range(batch):
        ql = q_eff_lens[b]
        qs = seqstart_q[b]
        O_b_batch = O_batch[b, :, :ql, :]
        O_b_group = O_group[0, :, qs : qs + ql, :]
        max_err = float(np.abs(O_b_batch - O_b_group).max())
        ok = max_err < 1e-5
        all_ok = all_ok and ok
        print(f" {b:<6} {ql:>8} {max_err:>10.2e} {'PASS' if ok else 'FAIL':>8}")

    # --- GPU attempt ---
    print("\n--- GPU Execution ---")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        prob = FmhaProblem(
            batch=batch,
            nhead_q=nhead,
            nhead_k=nhead,
            seqlen_q=max_sq,
            seqlen_k=max_sk,
            hdim_q=hdim,
            hdim_v=hdim,
        )
        Q_fp16 = Q_padded.astype(np.float16)
        K_fp16 = K_padded.astype(np.float16)
        V_fp16 = V_padded.astype(np.float16)
        res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
        if res.success:
            print(f" GPU (full padded): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS")
            print(
                " Note: GPU runs full padded attention; effective-length masking needs kernel support"
            )
        else:
            print(" GPU: Kernel returned failure")

    # --- Memory analysis ---
    print("\n--- Memory Efficiency Analysis ---")
    print(f"\n {'Metric':<24} {'Batch Mode':>14} {'Group Mode':>14} {'Ratio':>8}")
    print(" " + "-" * 64)
    batch_tokens_q = batch * max_sq
    group_tokens_q = total_q
    batch_tokens_k = batch * max_sk
    group_tokens_k = total_k
    # Query-key pairs actually scored: one max_sq x max_sk tile per sequence
    # in batch mode, one q_len x kv_len tile per sequence in group mode.
    batch_compute = batch * max_sq * max_sk
    group_compute = sum(
        int(ql) * int(kl) for ql, kl in zip(q_eff_lens, kv_eff_lens)
    )
    print(
        f" {'Q tokens':<24} {batch_tokens_q:>14} {group_tokens_q:>14} {group_tokens_q / batch_tokens_q:>7.2f}x"
    )
    print(
        f" {'KV tokens':<24} {batch_tokens_k:>14} {group_tokens_k:>14} {group_tokens_k / batch_tokens_k:>7.2f}x"
    )
    print(
        f" {'Memory (KB)':<24} {batch_mem / 1024:>14.1f} {group_mem / 1024:>14.1f} {group_mem / batch_mem:>7.2f}x"
    )
    print(
        f" {'Compute (tokens)':<24} {batch_compute:>14} {group_compute:>14} "
        f"{group_compute / batch_compute:>7.2f}x"
    )

    # --- Summary ---
    print("\n" + "=" * 70)
    print(" Batch mode: Padded to max_seqlen, uses q_eff_lens/kv_eff_lens")
    print(" Group mode: Packed contiguously, uses seqstart pointers")
    print(f" Strides: s_qpad={s_qpad}, s_kpad={s_kpad}")
    print(f" Memory save: {(1.0 - group_mem / batch_mem) * 100:.1f}% with group mode")
    print(f" Batch==Group: {'PASS' if all_ok else 'FAIL'} (identical results)")
    print(" GPU: Prebuilt supports batch mode only")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,120 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 20: FP8 FMHA Forward
Demonstrates FP8 data types (fp8bf16, fp8fp32) for FMHA forward
with quantization scale (pertensor, blockscale).
Note: FP8 requires a kernel compiled with fp8bf16/fp8fp32 dtype.
The prebuilt library has fp16 only, so this example shows the
API pattern and CPU reference.
Usage:
python3 20_fp8_fmha.py
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaKernelConfig,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
# FP8 kernel variants listed by this example. Each entry is a
# (data_type, qscale, description) tuple: the first two fields are passed to
# FmhaKernelConfig as data_type= and qscale=, the third is only printed.
FP8_CONFIGS = [
("fp8bf16", "pertensor", "FP8 with BF16 output, per-tensor scale"),
("fp8fp32", "pertensor", "FP8 with FP32 output, per-tensor scale"),
("fp8bf16", "blockscale", "FP8 with BF16 output, block scale"),
]
def main():
    """List the FP8 kernel configurations and run an FP16 baseline.

    A kernel config object is constructed for every FP8_CONFIGS entry and a
    table row is printed for it; no FP8 kernel is executed. The FP16 kernel
    is then built through the dispatcher and, when it runs successfully, its
    output is compared with the float32 CPU reference. Always returns 0.
    """
    cli = argparse.ArgumentParser(description="FP8 FMHA Example")
    cli.add_argument("--arch", default=detect_gpu_arch())
    args = cli.parse_args()

    print("=" * 70)
    print("Example 20: FP8 FMHA Forward")
    print("=" * 70)

    prob = FmhaProblem(
        batch=2, nhead_q=4, nhead_k=4, seqlen_q=64, seqlen_k=64, hdim_q=128, hdim_v=128
    )
    print(f"\n Arch: {args.arch}")
    print(f" Shape: B={prob.batch} H={prob.nhead_q} S={prob.seqlen_q} D={prob.hdim_q}")

    # CPU reference (fp32 baseline); Q, K, V are drawn in this order from a
    # fixed seed.
    np.random.seed(42)
    q_ref = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    k_ref = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    v_ref = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    O_ref = cpu_attention_fwd(q_ref, k_ref, v_ref, prob.scale)

    print("\n--- FP8 Configurations ---\n")
    print(f" {'#':<3} {'Dtype':<12} {'QScale':<12} {'Description':<45} {'Status':<6}")
    print(" " + "-" * 80)
    for i, (dtype, qscale, desc) in enumerate(FP8_CONFIGS, 1):
        # The config is only constructed to show the API; it is not built or run.
        _unused_cfg = FmhaKernelConfig(
            data_type=dtype,
            hdim_q=128,
            hdim_v=128,
            qscale=qscale,
            gfx_arch=args.arch,
        )
        # FP8 kernels need dedicated compilation
        status = "CPU-OK"
        print(f" {i:<3} {dtype:<12} {qscale:<12} {desc:<45} {status:<6}")

    # Show FP8 tolerance expectations
    print("\n--- FP8 Tolerance Reference ---")
    print(" fp8bf16: rtol=1e-2, atol=1.8e-1")
    print(" fp8fp32: rtol=1e-2, atol=1.8e-1")
    print(" fp8 raw: rtol=0, atol=16 (or 32 for >240 range)")

    # Run basic fp16 for comparison if prebuilt available
    print("\n--- FP16 Baseline (prebuilt) ---")
    baseline_cfg = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(baseline_cfg)
    if setup.success:
        gpu_runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        q_half, k_half, v_half = (
            q_ref.astype(np.float16),
            k_ref.astype(np.float16),
            v_ref.astype(np.float16),
        )
        result = gpu_runner.run(q_half, k_half, v_half, prob)
        if result.success:
            max_err = float(np.abs(result.output.astype(np.float32) - O_ref).max())
            print(f" FP16 baseline: {result.time_ms:.4f} ms, max_err={max_err:.2e}")
    else:
        print(f" JIT build failed: {setup.error}")

    print(f"\n{'=' * 70}")
    print(f" FP8 kernel configs demonstrated: {len(FP8_CONFIGS)}")
    print(" Note: Build fp8bf16/fp8fp32 kernels for GPU execution")
    print(" Status: PASS")
    print(f"{'=' * 70}")
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,235 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 21: Logits Soft Cap FMHA
Demonstrates the logits soft cap feature, which prevents attention logits
from growing unboundedly by applying: tanh(scores / soft_cap) * soft_cap
before the softmax. This technique is used in models like Gemma-2 to
stabilize training at large scale.
The prebuilt library does not include a logits_soft_cap kernel, so this
example validates the CPU reference implementation and shows the API
pattern for when a compiled kernel with logits=True is available.
Usage:
python3 21_logits_soft_cap_fmha.py
python3 21_logits_soft_cap_fmha.py --soft-cap 30.0
python3 21_logits_soft_cap_fmha.py --seqlen 256
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def cpu_attention_fwd_logits_soft_cap(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    soft_cap: float,
) -> np.ndarray:
    """Reference attention whose scaled scores pass through a tanh soft cap.

    The scaled scores are replaced by ``tanh(scores / soft_cap) * soft_cap``
    before the softmax is taken over the key axis. When Q has a different
    head count than K/V, each K/V head is repeated ``nhead_q // nhead_k``
    times along axis 1 first.

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q]
        K: [batch, nhead_k, seqlen_k, hdim_q]
        V: [batch, nhead_k, seqlen_k, hdim_v]
        scale: multiplier applied to ``Q @ K^T``
        soft_cap: cap value used in the tanh transformation

    Returns:
        O: [batch, nhead_q, seqlen_q, hdim_v]
    """
    q_heads, kv_heads = Q.shape[1], K.shape[1]
    if q_heads != kv_heads:
        group = q_heads // kv_heads
        K, V = np.repeat(K, group, axis=1), np.repeat(V, group, axis=1)

    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    capped = np.tanh(scores / soft_cap) * soft_cap

    # Softmax over keys, shifted by the row maximum for stability.
    shifted = capped - capped.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    probs = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(probs, V)
def show_soft_cap_effect(scale: float, soft_cap: float):
    """Print a table showing how the soft cap shrinks a fixed set of scores.

    Each row lists a raw score, the score after multiplying by ``scale``,
    the value after ``tanh(x / soft_cap) * soft_cap``, and the drop in
    magnitude caused by the cap.
    """
    raw_scores = np.array(
        [-100, -50, -20, -10, -5, 0, 5, 10, 20, 50, 100], dtype=np.float32
    )
    scaled = raw_scores * scale
    capped = np.tanh(scaled / soft_cap) * soft_cap

    print(f"\n Soft cap effect (scale={scale:.4f}, soft_cap={soft_cap:.1f}):")
    print(
        f" {'Raw Score':>12} {'After Scale':>14} {'After Cap':>12} {'Reduction':>12}"
    )
    print(" " + "-" * 54)

    for row in range(len(raw_scores)):
        r, s, c = raw_scores[row], scaled[row], capped[row]
        # A zero score is reported with zero reduction.
        if abs(s) > 0:
            reduction = abs(s) - abs(c)
        else:
            reduction = 0
        print(f" {r:>12.1f} {s:>14.4f} {c:>12.4f} {reduction:>12.4f}")
# Example driver: prints the soft-cap transformation table, compares the CPU
# reference with and without the cap, sweeps several cap values, checks that a
# very large cap reproduces the uncapped result, prints the GPU API pattern,
# and runs the standard fp16 kernel as a baseline. Returns 0.
def main():
parser = argparse.ArgumentParser(
description="Logits Soft Cap FMHA Example",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python3 21_logits_soft_cap_fmha.py # Default soft_cap=50
python3 21_logits_soft_cap_fmha.py --soft-cap 30.0 # Tighter cap
python3 21_logits_soft_cap_fmha.py --seqlen 256
""",
)
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--batch", type=int, default=2)
parser.add_argument("--nhead", type=int, default=8)
parser.add_argument("--seqlen", type=int, default=128)
parser.add_argument("--hdim", type=int, default=128)
parser.add_argument(
"--soft-cap", type=float, default=50.0, help="Logits soft cap value"
)
args = parser.parse_args()
print("=" * 70)
print("Example 21: Logits Soft Cap FMHA")
print("=" * 70)
# Query and key lengths both come from --seqlen; Q and K/V share a head count.
prob = FmhaProblem(
batch=args.batch,
nhead_q=args.nhead,
nhead_k=args.nhead,
seqlen_q=args.seqlen,
seqlen_k=args.seqlen,
hdim_q=args.hdim,
hdim_v=args.hdim,
)
# Step 1: Demonstrate the soft cap transformation
print("\nStep 1: Soft Cap Transformation")
show_soft_cap_effect(prob.scale, args.soft_cap)
# Step 2: CPU reference comparison -- with vs without soft cap
print("\nStep 2: CPU Reference (with vs without soft cap)")
# Fixed seed so the printed differences are reproducible between runs.
np.random.seed(42)
Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float32)
K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float32)
V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float32)
O_no_cap = cpu_attention_fwd(Q, K, V, prob.scale)
O_capped = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, args.soft_cap)
diff = np.abs(O_no_cap - O_capped)
print(f"\n Shape: {prob.q_shape()}")
print(f" Soft cap: {args.soft_cap}")
print(f" Output range (no cap): [{O_no_cap.min():.4f}, {O_no_cap.max():.4f}]")
print(f" Output range (capped): [{O_capped.min():.4f}, {O_capped.max():.4f}]")
print(f" Max diff (cap effect): {diff.max():.6e}")
print(f" Mean diff (cap effect): {diff.mean():.6e}")
# Step 3: Validate across different soft_cap values
print("\nStep 3: Soft Cap Sweep")
soft_cap_values = [10.0, 20.0, 30.0, 50.0, 100.0, 500.0]
# This validator is not used by the sweep itself; Step 4 and Step 6 use it.
validator = FmhaValidator(rtol=1e-4, atol=1e-4)
print(
f"\n {'SoftCap':>10} {'OutRange':>20} {'vs NoCap MaxDiff':>18} {'vs NoCap MeanDiff':>18}"
)
print(" " + "-" * 70)
for sc in soft_cap_values:
O_sc = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, sc)
d = np.abs(O_no_cap - O_sc)
out_range = f"[{O_sc.min():.4f}, {O_sc.max():.4f}]"
print(f" {sc:>10.1f} {out_range:>20} {d.max():>18.6e} {d.mean():>18.6e}")
# Step 4: Self-consistency -- large soft_cap should approach no-cap result
print("\nStep 4: Self-Consistency Check")
O_large_cap = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, 1e6)
ok, max_abs, _ = validator.check(O_large_cap, O_no_cap)
print(
f" soft_cap=1e6 vs no_cap: max_err={max_abs:.2e} -> {'PASS' if ok else 'FAIL'}"
)
# Step 5: GPU API pattern (requires logits=True kernel)
print("\nStep 5: GPU Kernel Pattern")
print(" NOTE: The prebuilt library does not include a logits_soft_cap kernel.")
print(" To run on GPU, compile a kernel with logits=True in the signature:")
print()
print(" config = FmhaKernelConfig(")
print(" family='fwd', data_type='fp16', hdim_q=128, hdim_v=128,")
print(" pipeline='qr_async',")
print(" )")
print(' # In codegen JSON, set: "logits": true')
print()
print(" The dispatcher will pass logits_soft_cap to the kernel arguments.")
# Step 6: GPU run with standard kernel (no soft cap) for baseline
print("\nStep 6: GPU Baseline (standard kernel, no soft cap)")
config = FmhaKernelConfig(
data_type="fp16",
hdim_q=args.hdim,
hdim_v=args.hdim,
gfx_arch=args.arch,
)
setup = setup_fmha_dispatcher(config)
if not setup.success:
print(f" JIT build failed: {setup.error}")
else:
runner = setup.runner
print(f" JIT build: {setup.build_time_s:.1f}s")
# The runner is given float16 copies of the float32 reference inputs.
Q_f16 = Q.astype(np.float16)
K_f16 = K.astype(np.float16)
V_f16 = V.astype(np.float16)
result = runner.run(Q_f16, K_f16, V_f16, prob)
if result.success:
# NOTE(review): this reuses the rtol=1e-4/atol=1e-4 validator created for the
# float32 CPU checks although the kernel ran on float16 inputs -- confirm the
# tolerance is intended for this comparison.
ok_gpu, max_abs_gpu, _ = validator.check(result.output, O_no_cap)
print(
f" GPU (no cap): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}"
)
else:
print(f" GPU error: {result.error}")
# Summary
print("\n" + "=" * 70)
print(" Logits soft cap: tanh(scores / cap) * cap before softmax")
# max_abs is the error from the Step 4 self-consistency check.
print(f" Large cap -> standard attention (verified: max_err={max_abs:.2e})")
print(" Small cap -> output variance reduced, stabilizes training")
print("=" * 70)
return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,315 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 22: Sink Token Attention FMHA
Demonstrates sink token attention where the first N "sink" tokens are
always attended to regardless of the causal mask. This technique is used
in StreamingLLM and similar approaches to keep a few initial tokens as
attention anchors during long-context generation.
Mask format: t:left,right,sink -- a causal mask (top-left or bottom-right)
where the first 'sink' positions are always unmasked.
The prebuilt library does not include a sink token kernel, so this
example validates the CPU reference and shows the API pattern.
Usage:
python3 22_sink_tokens_fmha.py
python3 22_sink_tokens_fmha.py --sink-tokens 8
python3 22_sink_tokens_fmha.py --seqlen 256 --window 64
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def make_causal_mask(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a top-left causal mask: query i may attend to keys j <= i.

    Returns a float32 array of shape (seqlen_q, seqlen_k) holding 1.0 where
    attention is allowed and 0.0 elsewhere.
    """
    mask = np.zeros((seqlen_q, seqlen_k), dtype=np.float32)
    q_idx = np.arange(seqlen_q).reshape(-1, 1)
    k_idx = np.arange(seqlen_k).reshape(1, -1)
    mask[k_idx <= q_idx] = 1.0
    return mask
def make_causal_sink_mask(
    seqlen_q: int,
    seqlen_k: int,
    num_sink: int,
) -> np.ndarray:
    """Build a causal mask in which the first num_sink keys are always visible.

    Query i may attend to key j when either condition holds:
      - j < num_sink (sink token)
      - j <= i       (standard causal)

    Returns a float32 array of shape (seqlen_q, seqlen_k) with 1.0 for
    visible positions and 0.0 otherwise.
    """
    mask = np.zeros((seqlen_q, seqlen_k), dtype=np.float32)
    q_idx = np.arange(seqlen_q).reshape(-1, 1)
    k_idx = np.arange(seqlen_k).reshape(1, -1)
    visible = (k_idx < num_sink) | (k_idx <= q_idx)
    mask[visible] = 1.0
    return mask
def make_sliding_window_sink_mask(
    seqlen_q: int,
    seqlen_k: int,
    window: int,
    num_sink: int,
) -> np.ndarray:
    """Build a sliding-window mask in which sink keys are always visible.

    Query i may attend to key j when either condition holds:
      - j < num_sink                 (sink token)
      - i - window + 1 <= j <= i     (inside the sliding window)

    Returns a float32 array of shape (seqlen_q, seqlen_k) with 1.0 for
    visible positions and 0.0 otherwise.
    """
    mask = np.zeros((seqlen_q, seqlen_k), dtype=np.float32)
    q_idx = np.arange(seqlen_q).reshape(-1, 1)
    k_idx = np.arange(seqlen_k).reshape(1, -1)
    in_window = (q_idx - window + 1 <= k_idx) & (k_idx <= q_idx)
    mask[(k_idx < num_sink) | in_window] = 1.0
    return mask
def cpu_attention_fwd_masked(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    mask: np.ndarray,
) -> np.ndarray:
    """Reference attention with an explicit 2-D visibility mask.

    Scores at positions where ``mask`` is not positive are replaced by the
    most negative float32 value before the softmax. The mask is broadcast
    over the batch and head axes. When Q has a different head count than
    K/V, each K/V head is repeated ``nhead_q // nhead_k`` times first.

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q]
        K: [batch, nhead_k, seqlen_k, hdim_q]
        V: [batch, nhead_k, seqlen_k, hdim_v]
        scale: softmax scale
        mask: [seqlen_q, seqlen_k] binary mask (1=attend, 0=ignore)

    Returns:
        O: [batch, nhead_q, seqlen_q, hdim_v]
    """
    q_heads, kv_heads = Q.shape[1], K.shape[1]
    if q_heads != kv_heads:
        group = q_heads // kv_heads
        K, V = np.repeat(K, group, axis=1), np.repeat(V, group, axis=1)

    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale

    # Hidden positions get the lowest finite float32 so exp() sends them to 0.
    visible = mask[np.newaxis, np.newaxis, :, :] > 0
    lowest = np.finfo(np.float32).min
    scores = np.where(visible, scores, lowest)

    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(probs, V)
def print_mask(mask: np.ndarray, name: str, max_display: int = 16):
    """Print the top-left corner of a mask, one text row per query.

    At most ``max_display`` rows and columns are shown. Positive entries are
    drawn as "1" and everything else as ".".
    """
    rows, cols = mask.shape
    rows_show, cols_show = min(rows, max_display), min(cols, max_display)
    print(f"\n {name} ({rows}x{cols}, showing {rows_show}x{cols_show}):")

    for i in range(rows_show):
        cells = ["1" if value > 0 else "." for value in mask[i, :cols_show]]
        row_str = "".join(cells)
        print(f" q{i:02d}: {row_str}")
def main():
    """Walk through sink-token attention on CPU, then run a GPU baseline.

    Steps:
      1. Visualize causal, causal+sink and window+sink masks.
      2. Compute CPU reference outputs for each mask type.
      3. Quantify the effect of sink tokens and the attention density.
      4. Sweep the number of sink tokens.
      5. Print the kernel-configuration pattern for a sink-enabled kernel.
      6. Run the standard (unmasked) fp16 kernel and validate it.

    Returns:
        0 when the walkthrough completes; validation results are only printed.
    """
    parser = argparse.ArgumentParser(
        description="Sink Token Attention FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument(
        "--sink-tokens", type=int, default=4, help="Number of sink tokens"
    )
    parser.add_argument("--window", type=int, default=32, help="Sliding window size")
    args = parser.parse_args()

    print("=" * 70)
    print("Example 22: Sink Token Attention FMHA")
    print("=" * 70)

    sq = sk = args.seqlen
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    # Step 1: Visualize mask patterns
    print("\nStep 1: Mask Patterns")
    causal = make_causal_mask(sq, sk)
    causal_sink = make_causal_sink_mask(sq, sk, args.sink_tokens)
    window_sink = make_sliding_window_sink_mask(sq, sk, args.window, args.sink_tokens)
    vis_size = min(16, sq)
    print_mask(causal[:vis_size, :vis_size], "Causal (standard)", vis_size)
    print_mask(
        causal_sink[:vis_size, :vis_size],
        f"Causal + {args.sink_tokens} sink tokens",
        vis_size,
    )
    print_mask(
        window_sink[:vis_size, :vis_size],
        f"Window({args.window}) + {args.sink_tokens} sink tokens",
        vis_size,
    )

    # Step 2: CPU reference for each mask type
    print("\n\nStep 2: CPU Reference Comparison")
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)
    V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)
    O_no_mask = cpu_attention_fwd(Q, K, V, prob.scale)
    O_causal = cpu_attention_fwd_masked(Q, K, V, prob.scale, causal)
    O_causal_sink = cpu_attention_fwd_masked(Q, K, V, prob.scale, causal_sink)
    O_window_sink = cpu_attention_fwd_masked(Q, K, V, prob.scale, window_sink)
    masks_and_outputs = [
        ("No mask", O_no_mask),
        ("Causal", O_causal),
        (f"Causal+sink({args.sink_tokens})", O_causal_sink),
        (f"Window({args.window})+sink({args.sink_tokens})", O_window_sink),
    ]
    print(f"\n {'Mask Type':<30} {'Output Range':>20} {'vs NoMask MaxDiff':>18}")
    print(" " + "-" * 70)
    for name, out in masks_and_outputs:
        d = np.abs(out - O_no_mask).max()
        out_range = f"[{out.min():.4f}, {out.max():.4f}]"
        print(f" {name:<30} {out_range:>20} {d:>18.6e}")

    # Step 3: Verify sink tokens effect
    print("\nStep 3: Sink Token Effect Analysis")
    diff_causal_vs_sink = np.abs(O_causal - O_causal_sink)
    print(" Causal vs Causal+Sink:")
    print(f" Max diff: {diff_causal_vs_sink.max():.6e}")
    print(f" Mean diff: {diff_causal_vs_sink.mean():.6e}")
    n_attend_causal = causal.sum()
    n_attend_sink = causal_sink.sum()
    n_attend_window = window_sink.sum()
    print("\n Attention density:")
    print(
        f" Causal: {n_attend_causal:>8.0f} / {sq * sk} ({100 * n_attend_causal / (sq * sk):.1f}%)"
    )
    print(
        f" Causal+sink: {n_attend_sink:>8.0f} / {sq * sk} ({100 * n_attend_sink / (sq * sk):.1f}%)"
    )
    print(
        f" Window+sink: {n_attend_window:>8.0f} / {sq * sk} ({100 * n_attend_window / (sq * sk):.1f}%)"
    )

    # Step 4: Sweep sink token count
    print("\nStep 4: Sink Token Sweep")
    sink_counts = [0, 1, 2, 4, 8, 16]
    print(
        f"\n {'Sinks':>6} {'Density':>10} {'vs Causal MaxDiff':>20} {'vs NoMask MaxDiff':>20}"
    )
    print(" " + "-" * 60)
    for ns in sink_counts:
        # A sink count larger than the key length is skipped.
        if ns > sk:
            continue
        m = make_causal_sink_mask(sq, sk, ns)
        O_s = cpu_attention_fwd_masked(Q, K, V, prob.scale, m)
        d_causal = np.abs(O_s - O_causal).max()
        d_nomask = np.abs(O_s - O_no_mask).max()
        density = 100 * m.sum() / (sq * sk)
        print(f" {ns:>6} {density:>9.1f}% {d_causal:>20.6e} {d_nomask:>20.6e}")

    # Step 5: GPU API pattern
    print("\nStep 5: GPU Kernel Pattern")
    print(" NOTE: The prebuilt library does not include a sink token kernel.")
    print(" To compile a sink-enabled kernel, use:")
    print()
    print(" FmhaSignature()")
    print(" .mask('top_left') // causal mask required with sink")
    print(" .sink(true) // enable sink tokens")
    print()
    print(" At runtime, pass sink count via the mask spec: 't:left,right,sink'")
    print(
        f" Example: 't:0,0,{args.sink_tokens}' for causal + {args.sink_tokens} sink tokens"
    )

    # Step 6: GPU baseline (no mask, no sink)
    print("\nStep 6: GPU Baseline (standard kernel, no mask)")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        Q_f16 = Q.astype(np.float16)
        K_f16 = K.astype(np.float16)
        V_f16 = V.astype(np.float16)
        result = runner.run(Q_f16, K_f16, V_f16, prob)
        if result.success:
            # Build the reference from the same fp16-rounded inputs the kernel
            # consumed, so the comparison measures kernel error rather than
            # input quantization error.
            O_ref_f16 = cpu_attention_fwd(
                Q_f16.astype(np.float32),
                K_f16.astype(np.float32),
                V_f16.astype(np.float32),
                prob.scale,
            )
            # fp16 output cannot meet fp32-level tolerances; use the looser
            # tolerance that the hdim-variety example uses.
            validator = FmhaValidator(rtol=1e-2, atol=1e-2)
            ok, max_abs, _ = validator.check(result.output, O_ref_f16)
            print(
                f" GPU (no mask): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs:.2e} {'PASS' if ok else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")

    # Summary
    print("\n" + "=" * 70)
    print(" Sink token attention: first N tokens always attended regardless of mask")
    print(" Use case: StreamingLLM, long-context generation with attention anchors")
    print(" Sink tokens preserve global context that causal masking would discard")
    print("=" * 70)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,406 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 23: Batch Prefill FMHA for SGLang/vLLM
Demonstrates batch prefill with paged KV-cache, as used in serving
frameworks like SGLang and vLLM. Shows the KV page table configuration
(kv_indptr, kv_page_indices, kv_last_page_lens) for both:
- SGLang: 1D page table with indirect page lookup
- vLLM: 2D block table with per-sequence page arrays
This example builds the page table metadata on CPU and validates the
attention computation. The prebuilt library only supports the basic
forward kernel, so the page table logic is demonstrated via CPU reference.
Usage:
python3 23_batch_prefill_fmha.py
python3 23_batch_prefill_fmha.py --page-size 64
python3 23_batch_prefill_fmha.py --num-seqs 8
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def build_sglang_page_table(
    seq_lens_k: list,
    page_size: int,
    nhead_k: int,
    hdim: int,
) -> dict:
    """Build SGLang-style 1D page table for paged KV-cache.

    SGLang uses a flat 1D array of page indices. Each sequence's pages are
    stored contiguously in the page_indices array, with indptr marking
    boundaries. Pages are numbered sequentially in the order the sequences
    are given.

    Args:
        seq_lens_k: key/value length of every sequence in the batch
        page_size: tokens per page (must be positive)
        nhead_k: number of KV heads (only used for kv_data_shape)
        hdim: head dimension (only used for kv_data_shape)

    Returns dict with:
        kv_indptr: [num_seqs + 1] cumulative page counts
        kv_page_indices: [total_pages] global page IDs
        kv_last_page_lens: [num_seqs] tokens in last page of each seq
            (0 for a sequence that owns no pages)
        num_total_pages: total pages allocated
        kv_data_shape: shape of the paged KV pool
        layout: the string "sglang_1d"

    Raises:
        ValueError: if page_size is not positive or a sequence length is negative.
    """
    if page_size <= 0:
        raise ValueError(f"page_size must be positive, got {page_size}")

    num_seqs = len(seq_lens_k)
    kv_indptr = np.zeros(num_seqs + 1, dtype=np.int32)
    last_page_lens = np.zeros(num_seqs, dtype=np.int32)
    page_counter = 0
    for i, seqlen in enumerate(seq_lens_k):
        if seqlen < 0:
            raise ValueError(f"sequence length must be non-negative, got {seqlen}")
        num_pages = (seqlen + page_size - 1) // page_size  # ceil division
        kv_indptr[i + 1] = kv_indptr[i] + num_pages
        # An empty sequence has no last page; without this guard the formula
        # below would report a full page (page_size tokens) for it.
        if num_pages > 0:
            last_page_lens[i] = seqlen - (num_pages - 1) * page_size
        page_counter += num_pages

    # Pages are handed out sequentially, so the flat index list is 0..total-1.
    kv_page_indices = np.arange(page_counter, dtype=np.int32)
    total_pages = page_counter
    return {
        "kv_indptr": kv_indptr,
        "kv_page_indices": kv_page_indices,
        "kv_last_page_lens": last_page_lens,
        "num_total_pages": total_pages,
        "kv_data_shape": (total_pages, 2, nhead_k, page_size, hdim),
        "layout": "sglang_1d",
    }
def build_vllm_block_table(
    seq_lens_k: list,
    page_size: int,
    nhead_k: int,
    hdim: int,
) -> dict:
    """Build vLLM-style 2D block table for paged KV-cache.

    vLLM uses a 2D array [num_seqs, max_blocks_per_seq] where each entry
    is a block (page) index into the global KV pool. Pages are numbered
    sequentially in the order the sequences are given.

    Args:
        seq_lens_k: key/value length of every sequence in the batch
        page_size: tokens per page (must be positive)
        nhead_k: number of KV heads (only used for kv_data_shape)
        hdim: head dimension (only used for kv_data_shape)

    Returns dict with:
        block_table: [num_seqs, max_blocks] page IDs (-1 = unused)
        kv_last_page_lens: [num_seqs] tokens in last page of each seq
            (0 for a sequence that owns no pages)
        num_total_pages: total pages allocated
        kv_data_shape: shape of the paged KV pool
        layout: the string "vllm_2d"

    Raises:
        ValueError: if page_size is not positive or a sequence length is negative.
    """
    if page_size <= 0:
        raise ValueError(f"page_size must be positive, got {page_size}")
    for seqlen in seq_lens_k:
        if seqlen < 0:
            raise ValueError(f"sequence length must be non-negative, got {seqlen}")

    num_seqs = len(seq_lens_k)
    pages_per_seq = [(s + page_size - 1) // page_size for s in seq_lens_k]
    # default=0 keeps an empty batch valid (a 0x0 table) instead of raising.
    max_blocks = max(pages_per_seq, default=0)
    block_table = np.full((num_seqs, max_blocks), -1, dtype=np.int32)
    last_page_lens = np.zeros(num_seqs, dtype=np.int32)
    page_counter = 0
    for i, (seqlen, num_pages) in enumerate(zip(seq_lens_k, pages_per_seq)):
        # Fill the leading entries of the row; the remainder keeps the -1 padding.
        block_table[i, :num_pages] = np.arange(
            page_counter, page_counter + num_pages, dtype=np.int32
        )
        page_counter += num_pages
        # An empty sequence has no last page; leave its entry at 0.
        if num_pages > 0:
            last_page_lens[i] = seqlen - (num_pages - 1) * page_size
    return {
        "block_table": block_table,
        "kv_last_page_lens": last_page_lens,
        "num_total_pages": page_counter,
        "kv_data_shape": (page_counter, 2, nhead_k, page_size, hdim),
        "layout": "vllm_2d",
    }
def scatter_kv_to_pages(
    K: np.ndarray,
    V: np.ndarray,
    page_table: dict,
    page_size: int,
) -> np.ndarray:
    """Split one sequence's contiguous K and V into fixed-size pages.

    Args:
        K: [nhead_k, seqlen_k, hdim] float32 (single sequence)
        V: [nhead_k, seqlen_k, hdim] float32
        page_table: page indices for this sequence. This function does not
            read it; the returned pages are in local order and the caller
            places them into the global pool.
        page_size: tokens per page

    Returns:
        float32 array [num_pages, 2, nhead_k, page_size, hdim] where index 0
        of the second axis holds K and index 1 holds V. The unused tail of
        the last page stays zero.
    """
    nhead_k, seqlen_k, hdim = K.shape
    num_pages = (seqlen_k + page_size - 1) // page_size
    pages = np.zeros((num_pages, 2, nhead_k, page_size, hdim), dtype=np.float32)
    for page in range(num_pages):
        lo = page * page_size
        hi = min(lo + page_size, seqlen_k)
        # Slot 0 receives the K slice, slot 1 the V slice.
        for slot, src in enumerate((K, V)):
            pages[page, slot, :, : hi - lo, :] = src[:, lo:hi, :]
    return pages
def gather_kv_from_pages(
    kv_pool: np.ndarray,
    page_indices: np.ndarray,
    seqlen_k: int,
    page_size: int,
) -> tuple:
    """Gather K,V from paged KV pool back to contiguous arrays.

    Args:
        kv_pool: paged pool indexed as [page, 0=K / 1=V, nhead_k, slot, hdim]
        page_indices: pool page IDs of this sequence, in token order. Entries
            beyond the pages needed for ``seqlen_k`` tokens (for example the
            -1 padding of a 2D block-table row) are ignored.
        seqlen_k: number of valid tokens in the sequence
        page_size: tokens per page

    Returns:
        K: [nhead_k, seqlen_k, hdim]
        V: [nhead_k, seqlen_k, hdim]
        Both are float32; positions not covered by any page remain zero.
    """
    nhead_k = kv_pool.shape[2]
    hdim = kv_pool.shape[4]
    K = np.zeros((nhead_k, seqlen_k, hdim), dtype=np.float32)
    V = np.zeros((nhead_k, seqlen_k, hdim), dtype=np.float32)
    for p, page_idx in enumerate(page_indices):
        start = p * page_size
        # All tokens are already gathered. Without this guard a surplus entry
        # yields a negative slice length, and the assignment below fails with
        # a broadcast error.
        if start >= seqlen_k:
            break
        end = min(start + page_size, seqlen_k)
        length = end - start
        K[:, start:end, :] = kv_pool[page_idx, 0, :, :length, :]
        V[:, start:end, :] = kv_pool[page_idx, 1, :, :length, :]
    return K, V
def main():
    """Demonstrate paged-KV batch prefill metadata and run a GPU baseline.

    Steps:
      1. Build and print the SGLang 1D page table.
      2. Build and print the vLLM 2D block table.
      3. Scatter K/V into a paged pool and verify the gather round-trip.
      4. Compute CPU attention for every sequence.
      5. Compare paged versus contiguous memory footprint.
      6. Print the kernel configuration for batch prefill.
      7. Run the standard contiguous fp16 kernel and validate it.

    Returns:
        0 when the walkthrough completes; validation results are only printed.
    """
    parser = argparse.ArgumentParser(
        description="Batch Prefill FMHA for SGLang/vLLM",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--nhead-q", type=int, default=16)
    parser.add_argument("--nhead-k", type=int, default=4, help="KV heads (GQA)")
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--page-size", type=int, default=16)
    parser.add_argument("--num-seqs", type=int, default=4, help="Sequences in batch")
    args = parser.parse_args()

    # Reject values that would otherwise fail later with an unrelated error
    # (division by zero, empty max(), mismatched head counts).
    if args.num_seqs < 1:
        parser.error("--num-seqs must be at least 1")
    if args.page_size < 1:
        parser.error("--page-size must be at least 1")
    if args.nhead_k < 1 or args.nhead_q % args.nhead_k != 0:
        parser.error("--nhead-q must be a positive multiple of --nhead-k")

    print("=" * 70)
    print("Example 23: Batch Prefill FMHA (SGLang/vLLM)")
    print("=" * 70)

    # Repeat the base length pattern so any --num-seqs works. Up to 4
    # sequences this yields exactly the first num_seqs base entries.
    base_lens_q = [32, 64, 16, 48]
    base_lens_k = [256, 512, 128, 384]
    seq_lens_q = [base_lens_q[i % len(base_lens_q)] for i in range(args.num_seqs)]
    seq_lens_k = [base_lens_k[i % len(base_lens_k)] for i in range(args.num_seqs)]

    # Step 1: SGLang page table
    print("\nStep 1: SGLang 1D Page Table")
    sglang_pt = build_sglang_page_table(
        seq_lens_k,
        args.page_size,
        args.nhead_k,
        args.hdim,
    )
    print(f" Page size: {args.page_size}")
    print(f" Total pages: {sglang_pt['num_total_pages']}")
    print(f" KV pool shape: {sglang_pt['kv_data_shape']}")
    print(f" kv_indptr: {sglang_pt['kv_indptr']}")
    print(
        f" kv_page_indices: {sglang_pt['kv_page_indices'][:20]}{'...' if len(sglang_pt['kv_page_indices']) > 20 else ''}"
    )
    print(f" last_page_lens: {sglang_pt['kv_last_page_lens']}")
    print("\n Per-sequence breakdown:")
    print(f" {'Seq':>5} {'SeqQ':>6} {'SeqK':>6} {'Pages':>6} {'LastLen':>8}")
    print(" " + "-" * 35)
    for i in range(args.num_seqs):
        n_pages = sglang_pt["kv_indptr"][i + 1] - sglang_pt["kv_indptr"][i]
        print(
            f" {i:>5} {seq_lens_q[i]:>6} {seq_lens_k[i]:>6} {n_pages:>6} {sglang_pt['kv_last_page_lens'][i]:>8}"
        )

    # Step 2: vLLM block table
    print("\nStep 2: vLLM 2D Block Table")
    vllm_pt = build_vllm_block_table(
        seq_lens_k,
        args.page_size,
        args.nhead_k,
        args.hdim,
    )
    print(f" Block table shape: {vllm_pt['block_table'].shape}")
    print(f" Total pages: {vllm_pt['num_total_pages']}")
    for i in range(args.num_seqs):
        row = vllm_pt["block_table"][i]
        valid = row[row >= 0]  # drop the -1 padding
        print(f" Seq {i}: pages={valid.tolist()}")

    # Step 3: Validate scatter/gather round-trip
    print("\nStep 3: KV Page Scatter/Gather Validation")
    np.random.seed(42)
    total_pages = sglang_pt["num_total_pages"]
    kv_pool = np.zeros(
        (total_pages, 2, args.nhead_k, args.page_size, args.hdim),
        dtype=np.float32,
    )
    all_Q, all_K, all_V, all_O_ref = [], [], [], []
    for i in range(args.num_seqs):
        sq, sk = seq_lens_q[i], seq_lens_k[i]
        Q_i = np.random.randn(args.nhead_q, sq, args.hdim).astype(np.float32) * 0.3
        K_i = np.random.randn(args.nhead_k, sk, args.hdim).astype(np.float32) * 0.3
        V_i = np.random.randn(args.nhead_k, sk, args.hdim).astype(np.float32) * 0.3
        start_page = sglang_pt["kv_indptr"][i]
        end_page = sglang_pt["kv_indptr"][i + 1]
        page_indices = sglang_pt["kv_page_indices"][start_page:end_page]
        pages = scatter_kv_to_pages(K_i, V_i, page_indices, args.page_size)
        # Place this sequence's local pages at their global pool positions.
        for p_local, p_global in enumerate(page_indices):
            kv_pool[p_global] = pages[p_local]
        K_rt, V_rt = gather_kv_from_pages(kv_pool, page_indices, sk, args.page_size)
        k_ok = np.allclose(K_i, K_rt, atol=1e-7)
        v_ok = np.allclose(V_i, V_rt, atol=1e-7)
        print(
            f" Seq {i}: K round-trip={'OK' if k_ok else 'FAIL'} "
            f"V round-trip={'OK' if v_ok else 'FAIL'}"
        )
        all_Q.append(Q_i)
        all_K.append(K_i)
        all_V.append(V_i)

    # Step 4: CPU attention per-sequence
    print("\nStep 4: CPU Attention per Sequence (from Paged KV)")
    print(f"\n {'Seq':>5} {'SeqQ':>6} {'SeqK':>6} {'OutRange':>22} {'Scale':>10}")
    print(" " + "-" * 50)
    for i in range(args.num_seqs):
        sq, sk = seq_lens_q[i], seq_lens_k[i]
        Q_i = all_Q[i][np.newaxis]  # [1, nhead_q, sq, hdim]
        K_i = all_K[i][np.newaxis]  # [1, nhead_k, sk, hdim]
        V_i = all_V[i][np.newaxis]  # [1, nhead_k, sk, hdim]
        if args.nhead_q != args.nhead_k:
            # GQA: replicate each KV head across its group of query heads.
            ratio = args.nhead_q // args.nhead_k
            K_i_exp = np.repeat(K_i, ratio, axis=1)
            V_i_exp = np.repeat(V_i, ratio, axis=1)
        else:
            K_i_exp, V_i_exp = K_i, V_i
        scale = 1.0 / (args.hdim**0.5)
        O_i = cpu_attention_fwd(Q_i, K_i_exp, V_i_exp, scale)
        all_O_ref.append(O_i)
        out_range = f"[{O_i.min():.4f}, {O_i.max():.4f}]"
        print(f" {i:>5} {sq:>6} {sk:>6} {out_range:>22} {scale:>10.4f}")

    # Step 5: Memory layout comparison
    print("\nStep 5: Memory Layout Analysis")
    # Bytes for K and V (factor 2) in float32 (factor 4).
    contiguous_bytes = sum(2 * args.nhead_k * sk * args.hdim * 4 for sk in seq_lens_k)
    paged_bytes = total_pages * 2 * args.nhead_k * args.page_size * args.hdim * 4
    overhead = (paged_bytes - contiguous_bytes) / contiguous_bytes * 100
    print(f" Contiguous KV: {contiguous_bytes / 1024:.1f} KB")
    print(f" Paged KV pool: {paged_bytes / 1024:.1f} KB")
    print(f" Overhead: {overhead:.1f}% (due to page padding)")
    print(f" Pages used: {total_pages}")
    print(f" Avg tokens/seq: {sum(seq_lens_k) / args.num_seqs:.0f}")

    # Step 6: GPU API pattern
    print("\nStep 6: GPU Kernel Configuration")
    print(" NOTE: The prebuilt library uses basic forward kernels.")
    print(" For batch prefill, compile a kernel with:")
    print()
    print(" FmhaSignature()")
    print(" .family('batch_prefill')")
    print(" .mode('group')")
    print(" .paged_kv(true)")
    print(" .kv_cache('vectorized', 'sglang', page_size)")
    print(" .lse(true)")
    print()
    print(" FmhaKernelConfig codegen JSON:")
    print(" 'family': 'batch_prefill',")
    print(" 'mode': 'group',")
    print(" 'paged_kv': true,")
    print(" 'kv_memory_layout': 'vectorized',")
    print(" 'kv_lookup_table': 'sglang' or 'vllm',")
    print(f" 'page_size': {args.page_size}")

    # Step 7: GPU baseline (contiguous, no paging)
    print("\nStep 7: GPU Baseline (contiguous KV, single sequence)")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        prob = FmhaProblem(
            batch=1,
            nhead_q=args.nhead_q,
            nhead_k=args.nhead_k,
            seqlen_q=64,
            seqlen_k=256,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        Q_gpu = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float16)
        K_gpu = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float16)
        V_gpu = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float16)
        result = runner.run(Q_gpu, K_gpu, V_gpu, prob)
        if result.success:
            O_ref = cpu_attention_fwd(
                Q_gpu.astype(np.float32),
                K_gpu.astype(np.float32),
                V_gpu.astype(np.float32),
                prob.scale,
            )
            # fp16 output cannot meet fp32-level tolerances; use the looser
            # tolerance that the hdim-variety example uses.
            gpu_validator = FmhaValidator(rtol=1e-2, atol=1e-2)
            ok, max_abs, _ = gpu_validator.check(result.output, O_ref)
            print(
                f" GPU baseline: time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs:.2e} {'PASS' if ok else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")

    # Summary
    print("\n" + "=" * 70)
    print(" Batch prefill: serves multiple prefill requests in a single kernel launch")
    print(" SGLang: 1D page table (kv_indptr + kv_page_indices)")
    print(" vLLM: 2D block table [num_seqs, max_blocks]")
    print(
        f" Page size {args.page_size} -> {overhead:.1f}% memory overhead vs contiguous"
    )
    print("=" * 70)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 24: Column-Major V Layout FMHA
Demonstrates column-major (vlayout="c") vs row-major (vlayout="r") for
the V tensor. In row-major, V is [batch, nhead, seqlen_k, hdim_v]; in
column-major, V is [batch, nhead, hdim_v, seqlen_k].
Column-major V can improve performance when hdim_v access patterns
benefit from the transposed layout (e.g., certain tile sizes or memory
coalescing characteristics on specific GPU architectures).
The prebuilt library uses row-major V. This example shows both layouts
with CPU reference and validates correctness.
Usage:
python3 24_vlayout_col_fmha.py
python3 24_vlayout_col_fmha.py --seqlen 512
python3 24_vlayout_col_fmha.py --batch 4
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def cpu_attention_fwd_vlayout_col(
    Q: np.ndarray,
    K: np.ndarray,
    V_col: np.ndarray,
    scale: float,
) -> np.ndarray:
    """CPU reference for attention when V is supplied in column-major layout.

    The last two axes of ``V_col`` are swapped to obtain the row-major layout,
    and the computation is delegated to ``cpu_attention_fwd``.

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q] float32 (row-major)
        K: [batch, nhead_k, seqlen_k, hdim_q] float32 (row-major)
        V_col: [batch, nhead_k, hdim_v, seqlen_k] float32 (column-major)
        scale: softmax scale

    Returns:
        O: [batch, nhead_q, seqlen_q, hdim_v] float32
    """
    to_row_major = (0, 1, 3, 2)
    return cpu_attention_fwd(Q, K, V_col.transpose(to_row_major), scale)
def analyze_strides(name: str, arr: np.ndarray, dim_names: list):
    """Print an array's shape and its per-dimension strides in elements.

    Byte strides are divided by the item size. The final dimension is tagged
    "(contiguous)" when its element stride is 1.
    """
    elem_size = arr.itemsize
    strides_elems = tuple(nbytes // elem_size for nbytes in arr.strides)
    print(f" {name}:")
    print(f" Shape: {arr.shape}")
    print(f" Strides: {strides_elems} (elements)")
    last = len(dim_names) - 1
    for idx, (dname, stride) in enumerate(zip(dim_names, strides_elems)):
        if idx == last and stride == 1:
            contiguous = "(contiguous)"
        else:
            contiguous = ""
        print(f" {dname}: stride={stride} {contiguous}")
def main():
    """Compare row-major and column-major V layouts, then run a GPU baseline.

    Steps:
      1. Print the strides of V in both layouts.
      2. Check that the CPU reference gives the same output for both layouts.
      3. Print a memory access pattern summary.
      4. Sweep several shapes and compare both layouts.
      5. Print the kernel configuration for column-major V.
      6. Run the standard row-major fp16 kernel and validate it.

    Returns:
        0 when the walkthrough completes; validation results are only printed.
    """
    parser = argparse.ArgumentParser(
        description="Column-Major V Layout FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 24: Column-Major V Layout FMHA")
    print("=" * 70)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    # Step 1: Layout comparison
    print("\nStep 1: V Tensor Layouts")
    np.random.seed(42)
    V_row = np.ascontiguousarray(
        (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)
    )
    # Materialize the transposed copy so its strides reflect a real
    # column-major buffer rather than a view of V_row.
    V_col = np.ascontiguousarray(V_row.transpose(0, 1, 3, 2))
    analyze_strides(
        "V row-major [B, H, SeqK, Hdim]",
        V_row,
        ["batch", "nhead", "seqlen_k", "hdim_v"],
    )
    analyze_strides(
        "V col-major [B, H, Hdim, SeqK]",
        V_col,
        ["batch", "nhead", "hdim_v", "seqlen_k"],
    )
    print("\n Row-major: last dim is hdim_v -> sequential hdim access per token")
    print(" Col-major: last dim is seqlen_k -> sequential token access per hdim")

    # Step 2: CPU reference for both layouts
    print("\nStep 2: CPU Reference (both layouts)")
    Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)
    O_from_row = cpu_attention_fwd(Q, K, V_row, prob.scale)
    O_from_col = cpu_attention_fwd_vlayout_col(Q, K, V_col, prob.scale)
    validator = FmhaValidator(rtol=1e-5, atol=1e-5)
    ok, max_abs, max_rel = validator.check(O_from_row, O_from_col)
    print(
        f" O from row-major V: shape={O_from_row.shape} "
        f"range=[{O_from_row.min():.4f}, {O_from_row.max():.4f}]"
    )
    print(
        f" O from col-major V: shape={O_from_col.shape} "
        f"range=[{O_from_col.min():.4f}, {O_from_col.max():.4f}]"
    )
    print(f" Max abs error: {max_abs:.2e}")
    print(f" Match: {'PASS' if ok else 'FAIL'}")

    # Step 3: Memory access pattern analysis
    print("\nStep 3: Memory Access Pattern Analysis")
    tile_sizes = [(128, 128), (64, 128), (128, 64)]
    print("\n For P @ V matmul (P: [sq, sk] x V: [sk, hdim_v]):")
    print(f" {'Tile(M,N)':>12} {'V Row Accesses':>18} {'V Col Accesses':>18}")
    print(" " + "-" * 52)
    for tm, tn in tile_sizes:
        row_access = f"sk_stride={args.hdim}"
        col_access = "sk_stride=1"
        print(f" {f'{tm}x{tn}':>12} {row_access:>18} {col_access:>18}")
    print("\n Row-major V: coalesced reads when accessing hdim_v (inner loop)")
    print(" Col-major V: coalesced reads when accessing seqlen_k (inner loop)")
    print(" Optimal layout depends on tile shape and GPU memory subsystem")

    # Step 4: Shape sweep with both layouts
    print("\nStep 4: Correctness Sweep")
    shapes = [
        (1, 4, 64, 64, 64),
        (2, 8, 128, 128, 128),
        (1, 8, 256, 256, 128),
        (2, 4, 128, 128, 64),
        (1, 16, 64, 64, 128),
    ]
    print(f"\n {'Shape':<32} {'MaxErr':>12} {'Status':>8}")
    print(" " + "-" * 55)
    # Start from the Step 2 result so the final verdict covers every
    # row/column comparison made in this run.
    all_ok = ok
    for b, h, sq, sk, d in shapes:
        Q_t = (np.random.randn(b, h, sq, d) * 0.3).astype(np.float32)
        K_t = (np.random.randn(b, h, sk, d) * 0.3).astype(np.float32)
        V_r = (np.random.randn(b, h, sk, d) * 0.3).astype(np.float32)
        V_c = np.ascontiguousarray(V_r.transpose(0, 1, 3, 2))
        scale = 1.0 / (d**0.5)
        O_r = cpu_attention_fwd(Q_t, K_t, V_r, scale)
        O_c = cpu_attention_fwd_vlayout_col(Q_t, K_t, V_c, scale)
        ok_t, max_abs_t, _ = validator.check(O_r, O_c)
        all_ok = all_ok and ok_t
        shape_str = f"B{b}_H{h}_S{sq}x{sk}_D{d}"
        print(f" {shape_str:<32} {max_abs_t:>12.2e} {'PASS' if ok_t else 'FAIL':>8}")

    # Step 5: GPU API pattern
    print("\nStep 5: GPU Kernel Configuration")
    print(" NOTE: The prebuilt library uses row-major V (vlayout='r').")
    print(" For column-major V, compile a kernel with vlayout='c':")
    print()
    print(" FmhaSignature()")
    print(" .vlayout('c') // column-major V: [B, H, Hdim, SeqK]")
    print()
    print(" FmhaKernelConfig(vlayout='c', ...)")

    # Step 6: GPU baseline (row-major)
    print("\nStep 6: GPU Baseline (row-major V)")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        Q_f16 = Q.astype(np.float16)
        K_f16 = K.astype(np.float16)
        V_f16 = V_row.astype(np.float16)
        result = runner.run(Q_f16, K_f16, V_f16, prob)
        if result.success:
            # Build the reference from the same fp16-rounded inputs the kernel
            # consumed, and compare at an fp16 tolerance. The 1e-5 validator
            # above is only meaningful for the fp32 CPU-vs-CPU checks.
            O_ref_f16 = cpu_attention_fwd(
                Q_f16.astype(np.float32),
                K_f16.astype(np.float32),
                V_f16.astype(np.float32),
                prob.scale,
            )
            gpu_validator = FmhaValidator(rtol=1e-2, atol=1e-2)
            ok_gpu, max_abs_gpu, _ = gpu_validator.check(result.output, O_ref_f16)
            print(
                f" GPU (row-major V): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")

    # Summary
    print("\n" + "=" * 70)
    print(" vlayout='r': V is [B, H, SeqK, Hdim] (default, row-major)")
    print(" vlayout='c': V is [B, H, Hdim, SeqK] (column-major)")
    print(
        f" Both layouts produce identical results (verified: {'PASS' if all_ok else 'FAIL'})"
    )
    print(" Choice depends on upstream memory layout and GPU tile access patterns")
    print("=" * 70)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 25: Input/Output Permutation FMHA
Demonstrates different memory layouts for Q/K/V/O tensors via
input permutation (iperm) and output permutation (operm):
iperm=0 (bshd): [batch, seqlen, nhead, hdim] -- used by some frameworks
iperm=1 (bhsd): [batch, nhead, seqlen, hdim] -- standard/default
operm=0 (bshd): O is [batch, seqlen, nhead, hdim]
operm=1 (bhsd): O is [batch, nhead, seqlen, hdim]
The prebuilt library uses bhsd layout (iperm=1, operm=1). This example
shows how to convert between layouts and validates correctness.
Usage:
python3 25_permutation_fmha.py
python3 25_permutation_fmha.py --seqlen 256
python3 25_permutation_fmha.py --batch 4
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def bhsd_to_bshd(x: np.ndarray) -> np.ndarray:
    """Reorder [batch, nhead, seqlen, hdim] to [batch, seqlen, nhead, hdim].

    Only the axis order changes; no contiguous copy is made here.
    """
    swap_head_and_seq = (0, 2, 1, 3)
    return x.transpose(swap_head_and_seq)
def bshd_to_bhsd(x: np.ndarray) -> np.ndarray:
    """Reorder [batch, seqlen, nhead, hdim] to [batch, nhead, seqlen, hdim].

    Only the axis order changes; no contiguous copy is made here.
    """
    swap_seq_and_head = (0, 2, 1, 3)
    return x.transpose(swap_seq_and_head)
def cpu_attention_fwd_bshd(
    Q_bshd: np.ndarray,
    K_bshd: np.ndarray,
    V_bshd: np.ndarray,
    scale: float,
    operm: int = 0,
) -> np.ndarray:
    """CPU reference taking bshd inputs and producing a selectable output layout.

    The inputs are reordered to bhsd, passed to ``cpu_attention_fwd``, and the
    result is reordered back to bshd only when ``operm == 0``.

    Args:
        Q_bshd: [batch, seqlen_q, nhead_q, hdim_q] float32
        K_bshd: [batch, seqlen_k, nhead_k, hdim_q] float32
        V_bshd: [batch, seqlen_k, nhead_k, hdim_v] float32
        scale: softmax scale
        operm: 0 -> output bshd, any other value -> output bhsd

    Returns:
        O: float32 in requested layout
    """
    # Swapping axes 1 and 2 converts between bshd and bhsd in either direction.
    swap = (0, 2, 1, 3)
    O_bhsd = cpu_attention_fwd(
        Q_bshd.transpose(swap),
        K_bshd.transpose(swap),
        V_bshd.transpose(swap),
        scale,
    )
    return O_bhsd.transpose(swap) if operm == 0 else O_bhsd
def describe_layout(arr: np.ndarray, layout_name: str, dim_names: list):
    """Print shape, element strides and C-contiguity of an array.

    Byte strides are divided by the item size, then listed once as a tuple
    and once per named dimension.
    """
    elem_size = arr.itemsize
    strides_elems = tuple(nbytes // elem_size for nbytes in arr.strides)
    is_contiguous = arr.flags.c_contiguous
    print(f" {layout_name}:")
    print(f" Shape: {arr.shape}")
    print(f" Strides: {strides_elems} (elements)")
    print(f" Contiguous: {is_contiguous}")
    for idx in range(min(len(dim_names), len(strides_elems))):
        dname = dim_names[idx]
        stride = strides_elems[idx]
        print(f" {dname:>8}: stride={stride}")
def main():
    """Demonstrate bhsd/bshd input and output permutations, then run on GPU.

    Steps:
      1. Print the strides of Q in both layouts.
      2. Validate all four iperm/operm combinations on CPU.
      3. Print a stride comparison table.
      4. Print the layout conversion cost.
      5. Run the standard bhsd fp16 kernel and validate it.
      6. Print the kernel configuration notes for bshd.

    Returns:
        0 when the walkthrough completes; validation results are only printed.
    """
    parser = argparse.ArgumentParser(
        description="Input/Output Permutation FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 25: Input/Output Permutation FMHA")
    print("=" * 70)

    B, H, S, D = args.batch, args.nhead, args.seqlen, args.hdim
    prob = FmhaProblem(
        batch=B,
        nhead_q=H,
        nhead_k=H,
        seqlen_q=S,
        seqlen_k=S,
        hdim_q=D,
        hdim_v=D,
    )

    # Step 1: Layout definitions
    print("\nStep 1: Layout Definitions")
    np.random.seed(42)
    Q_bhsd = np.ascontiguousarray(
        (np.random.randn(B, H, S, D) * 0.3).astype(np.float32)
    )
    Q_bshd = np.ascontiguousarray(bhsd_to_bshd(Q_bhsd))
    describe_layout(Q_bhsd, "bhsd (iperm=1)", ["batch", "nhead", "seqlen", "hdim"])
    describe_layout(Q_bshd, "bshd (iperm=0)", ["batch", "seqlen", "nhead", "hdim"])
    print("\n Key difference:")
    print(" bhsd: heads are contiguous -> good for per-head parallelism")
    print(" bshd: tokens are contiguous -> good for sequence parallelism")

    # Step 2: All permutation combinations
    print("\nStep 2: All Permutation Combinations (CPU Reference)")
    K_bhsd = (np.random.randn(B, H, S, D) * 0.3).astype(np.float32)
    V_bhsd = (np.random.randn(B, H, S, D) * 0.3).astype(np.float32)
    K_bshd = np.ascontiguousarray(bhsd_to_bshd(K_bhsd))
    V_bshd = np.ascontiguousarray(bhsd_to_bshd(V_bhsd))
    O_ref_bhsd = cpu_attention_fwd(Q_bhsd, K_bhsd, V_bhsd, prob.scale)
    O_ref_bshd = bhsd_to_bshd(O_ref_bhsd)
    validator = FmhaValidator(rtol=1e-5, atol=1e-5)
    # Each entry carries its input layout explicitly (iperm). Inferring the
    # layout from the tensor shape is ambiguous whenever seqlen == nhead.
    combos = [
        ("iperm=1 operm=1", "bhsd->bhsd", 1, Q_bhsd, K_bhsd, V_bhsd, 1, O_ref_bhsd),
        ("iperm=1 operm=0", "bhsd->bshd", 1, Q_bhsd, K_bhsd, V_bhsd, 0, O_ref_bshd),
        ("iperm=0 operm=1", "bshd->bhsd", 0, Q_bshd, K_bshd, V_bshd, 1, O_ref_bhsd),
        ("iperm=0 operm=0", "bshd->bshd", 0, Q_bshd, K_bshd, V_bshd, 0, O_ref_bshd),
    ]
    print(
        f"\n {'Config':<18} {'Transform':<14} {'OutShape':>24} {'MaxErr':>12} {'Status':>8}"
    )
    print(" " + "-" * 80)
    all_ok = True
    for name, transform, iperm, Q_in, K_in, V_in, operm, O_expected in combos:
        if iperm == 1:
            O_out = cpu_attention_fwd(Q_in, K_in, V_in, prob.scale)
            if operm == 0:
                O_out = bhsd_to_bshd(O_out)
        else:
            O_out = cpu_attention_fwd_bshd(Q_in, K_in, V_in, prob.scale, operm)
        ok, max_abs, _ = validator.check(O_out, O_expected)
        all_ok = all_ok and ok
        print(
            f" {name:<18} {transform:<14} {str(O_out.shape):>24} {max_abs:>12.2e} {'PASS' if ok else 'FAIL':>8}"
        )

    # Step 3: Stride comparison table
    print("\nStep 3: Stride Comparison")
    print(f"\n For B={B}, H={H}, S={S}, D={D}:")
    print(f" {'Layout':>8} {'Dim Order':>16} {'Strides':>28} {'hdim contiguous':>18}")
    print(" " + "-" * 74)
    # Element strides of a contiguous buffer, listed in each layout's own
    # dimension order.
    bhsd_strides = (H * S * D, S * D, D, 1)
    bshd_strides = (S * H * D, H * D, D, 1)
    print(f" {'bhsd':>8} {'B,H,S,D':>16} {str(bhsd_strides):>28} {'Yes':>18}")
    print(f" {'bshd':>8} {'B,S,H,D':>16} {str(bshd_strides):>28} {'Yes':>18}")
    print("\n Stride analysis:")
    print(f" bhsd: advancing 1 token = skip {D} elements (hdim)")
    print(f" bshd: advancing 1 token = skip {H * D} elements (nhead * hdim)")
    print(f" bhsd: advancing 1 head = skip {S * D} elements (seqlen * hdim)")
    print(f" bshd: advancing 1 head = skip {D} elements (hdim)")

    # Step 4: Conversion cost
    print("\nStep 4: Layout Conversion Cost")
    tensor_bytes = B * H * S * D * 4
    print(f" Tensor size: {tensor_bytes / 1024:.1f} KB (float32)")
    print(" bhsd <-> bshd conversion: transpose(0,2,1,3) + contiguous copy")
    print(
        " If upstream provides bshd and kernel wants bhsd, conversion costs ~2x memory bandwidth"
    )
    print(" Using iperm parameter avoids this copy by adjusting kernel strides")

    # Step 5: GPU run (bhsd, default layout)
    print("\nStep 5: GPU Run (bhsd layout, iperm=1)")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        Q_f16 = Q_bhsd.astype(np.float16)
        K_f16 = K_bhsd.astype(np.float16)
        V_f16 = V_bhsd.astype(np.float16)
        result = runner.run(Q_f16, K_f16, V_f16, prob)
        if result.success:
            # Build the reference from the same fp16-rounded inputs the kernel
            # consumed, and compare at an fp16 tolerance. The 1e-5 validator
            # above is only meaningful for the fp32 CPU-vs-CPU checks.
            O_ref_f16 = cpu_attention_fwd(
                Q_f16.astype(np.float32),
                K_f16.astype(np.float32),
                V_f16.astype(np.float32),
                prob.scale,
            )
            gpu_validator = FmhaValidator(rtol=1e-2, atol=1e-2)
            ok_gpu, max_abs_gpu, _ = gpu_validator.check(result.output, O_ref_f16)
            print(
                f" GPU (bhsd): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")

    # Step 6: Kernel configuration for bshd
    print("\nStep 6: GPU Kernel Configuration for bshd")
    print(" The prebuilt library uses bhsd (iperm=1, operm=1).")
    print(" For bshd input/output, the kernel adjusts internal strides:")
    print()
    print(" iperm=0: kernel reads Q,K,V as [B, S, H, D] with stride_head=D")
    print(" iperm=1: kernel reads Q,K,V as [B, H, S, D] with stride_seq=D")
    print(" operm=0: kernel writes O as [B, S, H, D]")
    print(" operm=1: kernel writes O as [B, H, S, D]")

    # Summary
    print("\n" + "=" * 70)
    print(" iperm=0 (bshd): [B, S, H, D] -- sequence-first layout")
    print(" iperm=1 (bhsd): [B, H, S, D] -- head-first layout (default)")
    print(f" All 4 combinations validated: {'PASS' if all_ok else 'FAIL'}")
    print(" Use iperm/operm to match upstream/downstream layout without copies")
    print("=" * 70)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,268 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 26: Head Dimension Variety FMHA
Demonstrates FMHA with multiple head dimensions (32, 64, 128, 256) and
asymmetric hdim (hdim_q != hdim_v). Different head dimensions require
different tile sizes and kernel configurations for optimal performance.
The prebuilt library supports hdim=128 only. This example validates all
head dimensions via CPU reference and runs GPU for hdim=128.
Usage:
python3 26_hdim_variety_fmha.py
python3 26_hdim_variety_fmha.py --seqlen 256
python3 26_hdim_variety_fmha.py --batch 4
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
def recommended_tile(hdim: int) -> str:
    """Return the tile-configuration string suggested for ``hdim``.

    Head dimensions 32, 64, 128 and 256 have a fixed entry; any other value
    yields the placeholder ``"auto (hdim=<hdim>)"``.
    """
    known_tiles = {
        32: "128x128x32x32x32x32",
        64: "128x64x32x64x32x64",
        128: "128x128x32x128x32x128",
        256: "128x128x32x256x32x256",
    }
    if hdim in known_tiles:
        return known_tiles[hdim]
    # No table entry for this head dimension: report it back to the caller.
    return f"auto (hdim={hdim})"
def compute_flops(
    batch: int, nhead_q: int, sq: int, sk: int, hdim_q: int, hdim_v: int
) -> int:
    """Count FMHA forward FLOPs, allowing hdim_q and hdim_v to differ.

    Each of the ``batch * nhead_q * sq * sk`` score elements costs
    ``2 * hdim_q`` operations for Q @ K^T and ``2 * hdim_v`` for P @ V.
    """
    score_elements = batch * nhead_q * sq * sk
    qk_flops = 2 * score_elements * hdim_q
    pv_flops = 2 * score_elements * hdim_v
    return qk_flops + pv_flops
def main():
# Entry point for Example 26. Prints tile/FLOP tables for several head
# dimensions, runs the CPU reference for symmetric and asymmetric hdim,
# tries one GPU run for hdim=128, and always returns 0 (GPU build/run
# failures are reported but do not change the exit status).
parser = argparse.ArgumentParser(
description="Head Dimension Variety FMHA Example",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--batch", type=int, default=2)
parser.add_argument("--nhead", type=int, default=8)
parser.add_argument("--seqlen", type=int, default=128)
args = parser.parse_args()
print("=" * 70)
print("Example 26: Head Dimension Variety FMHA")
print("=" * 70)
validator = FmhaValidator(rtol=1e-2, atol=1e-2)
# Step 1: Symmetric head dimensions
print("\nStep 1: Symmetric Head Dimensions (hdim_q == hdim_v)")
hdims = [32, 64, 128, 256]
print(f"\n {'hdim':>6} {'Shape':>30} {'Tile Config':>30} {'FLOPs':>14}")
print(" " + "-" * 84)
for hdim in hdims:
shape = f"B{args.batch}_H{args.nhead}_S{args.seqlen}_D{hdim}"
tile = recommended_tile(hdim)
flops = compute_flops(
args.batch, args.nhead, args.seqlen, args.seqlen, hdim, hdim
)
print(f" {hdim:>6} {shape:>30} {tile:>30} {flops:>14,}")
# Step 2: CPU validation for each hdim
print("\nStep 2: CPU Validation")
# Fixed seed so the random Q/K/V drawn below are reproducible across runs.
np.random.seed(42)
print(
f"\n {'hdim_q':>7} {'hdim_v':>7} {'Scale':>10} {'OutRange':>22} {'SelfCheck':>10}"
)
print(" " + "-" * 60)
# Maps hdim -> (Q, K, V, O_ref, prob); Step 4 reuses the hdim=128 entry.
cpu_results = {}
for hdim in hdims:
prob = FmhaProblem(
batch=args.batch,
nhead_q=args.nhead,
nhead_k=args.nhead,
seqlen_q=args.seqlen,
seqlen_k=args.seqlen,
hdim_q=hdim,
hdim_v=hdim,
)
Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)
V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)
O_ref = cpu_attention_fwd(Q, K, V, prob.scale)
# The "self check" only verifies that the reference output has no NaN/Inf.
self_ok = np.all(np.isfinite(O_ref))
out_range = f"[{O_ref.min():.4f}, {O_ref.max():.4f}]"
print(
f" {hdim:>7} {hdim:>7} {prob.scale:>10.4f} {out_range:>22} {'OK' if self_ok else 'NaN!':>10}"
)
cpu_results[hdim] = (Q, K, V, O_ref, prob)
# Step 3: Asymmetric head dimensions
print("\nStep 3: Asymmetric Head Dimensions (hdim_q != hdim_v)")
# Each entry is (hdim_q, hdim_v, description printed in the notes below).
asymmetric_configs = [
(128, 64, "Large Q, small V: more attention capacity, compact output"),
(64, 128, "Small Q, large V: compact attention, rich output"),
(128, 256, "Standard Q, very large V: high-capacity value projection"),
(256, 128, "Large Q, standard V: wide attention field"),
(32, 128, "Tiny Q, standard V: minimal attention compute"),
]
print(
f"\n {'hdim_q':>7} {'hdim_v':>7} {'Q Shape':>22} {'O Shape':>22} {'MaxErr vs self':>16}"
)
print(" " + "-" * 78)
for hdim_q, hdim_v, desc in asymmetric_configs:
prob = FmhaProblem(
batch=args.batch,
nhead_q=args.nhead,
nhead_k=args.nhead,
seqlen_q=args.seqlen,
seqlen_k=args.seqlen,
hdim_q=hdim_q,
hdim_v=hdim_v,
)
Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)
V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)
# The same CPU reference is evaluated twice on identical inputs, so
# max_err compares the reference against itself rather than against an
# independent implementation.
out = cpu_attention_fwd(Q, K, V, prob.scale)
O2 = cpu_attention_fwd(Q, K, V, prob.scale)
max_err = float(np.abs(out - O2).max())
print(
f" {hdim_q:>7} {hdim_v:>7} {str(prob.q_shape()):>22} {str(prob.o_shape()):>22} {max_err:>16.2e}"
)
print("\n Asymmetric hdim notes:")
for hdim_q, hdim_v, desc in asymmetric_configs:
print(f" hdim_q={hdim_q}, hdim_v={hdim_v}: {desc}")
# Step 4: GPU validation (hdim=128)
print("\nStep 4: GPU Validation (hdim=128)")
config = FmhaKernelConfig(
data_type="fp16",
hdim_q=128,
hdim_v=128,
gfx_arch=args.arch,
)
setup = setup_fmha_dispatcher(config)
# These stay 0.0 unless the GPU run succeeds; the summary prints the GPU
# baseline line only when gpu_tflops > 0.
gpu_tflops = 0.0
gpu_time = 0.0
if not setup.success:
print(f" JIT build failed: {setup.error}")
else:
runner = setup.runner
print(f" JIT build: {setup.build_time_s:.1f}s")
# Reuse the float32 inputs and reference from Step 2, cast to fp16 for the GPU run.
Q, K, V, O_ref, prob = cpu_results[128]
Q_f16 = Q.astype(np.float16)
K_f16 = K.astype(np.float16)
V_f16 = V.astype(np.float16)
result = runner.run(Q_f16, K_f16, V_f16, prob)
if result.success:
ok, max_abs, _ = validator.check(result.output, O_ref)
print(
f" GPU hdim=128: time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
f"max_err={max_abs:.2e} {'PASS' if ok else 'FAIL'}"
)
gpu_tflops = result.tflops
gpu_time = result.time_ms
else:
print(f" GPU error: {result.error}")
# Step 5: Performance projection table
print("\nStep 5: Performance Summary Table")
print(
f"\n {'hdim_q':>7} | {'hdim_v':>7} | {'FLOPs':>14} | {'Tile':>24} | {'GPU Support':>12}"
)
print(" " + "-" * 78)
for hdim in hdims:
flops = compute_flops(
args.batch, args.nhead, args.seqlen, args.seqlen, hdim, hdim
)
tile = recommended_tile(hdim)
gpu_ok = "prebuilt" if hdim == 128 else "needs JIT"
print(f" {hdim:>7} | {hdim:>7} | {flops:>14,} | {tile:>24} | {gpu_ok:>12}")
print(" " + "-" * 78)
# Only the first three asymmetric configurations are listed in this table.
for hdim_q, hdim_v, _ in asymmetric_configs[:3]:
flops = compute_flops(
args.batch, args.nhead, args.seqlen, args.seqlen, hdim_q, hdim_v
)
gpu_ok = "needs JIT"
print(
f" {hdim_q:>7} | {hdim_v:>7} | {flops:>14,} | {'asymmetric':>24} | {gpu_ok:>12}"
)
# Step 6: Kernel configuration per hdim
# This step only prints example FmhaKernelConfig calls; nothing is built or run.
print("\nStep 6: Kernel Configuration Per Head Dimension")
print(" Each hdim requires a dedicated compiled kernel:")
print()
print(
" hdim=32: FmhaKernelConfig(hdim_q=32, hdim_v=32, "
"tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=32, tile_k1=32, tile_k0max=32)"
)
print(
" hdim=64: FmhaKernelConfig(hdim_q=64, hdim_v=64, "
"tile_m0=128, tile_n0=64, tile_k0=32, tile_n1=64, tile_k1=32, tile_k0max=64)"
)
print(
" hdim=128: FmhaKernelConfig(hdim_q=128, hdim_v=128, "
"tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=128, tile_k1=32, tile_k0max=128)"
)
print(
" hdim=256: FmhaKernelConfig(hdim_q=256, hdim_v=256, "
"tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=256, tile_k1=32, tile_k0max=256)"
)
print()
print(" Asymmetric: FmhaKernelConfig(hdim_q=128, hdim_v=64, ...)")
print(" tile_n1 tracks hdim_v; tile_k0max tracks hdim_q")
# Summary
print("\n" + "=" * 70)
print(f" Supported symmetric hdims: {hdims}")
print(" Asymmetric hdim (hdim_q != hdim_v): fully supported")
print(" Tile sizes scale with hdim; larger hdim needs wider tiles")
if gpu_tflops > 0:
print(f" GPU baseline (hdim=128): {gpu_tflops:.2f} TFLOPS @ {gpu_time:.4f} ms")
print("=" * 70)
return 0
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,373 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 27: Backward Pass with Dropout FMHA
Demonstrates the FMHA backward pass with dropout. The backward pass
computes dQ, dK, dV given dO (gradient of the output). When dropout is
applied during forward, the same dropout mask must be replayed during
backward for correctness.
Key concepts:
- Deterministic mode (no atomics): reproducible gradients, may be slower
- Non-deterministic mode: uses atomicAdd for dQ, faster but non-reproducible
- store_randval: optionally store the dropout random values for debugging
The prebuilt library only has a forward kernel. This example validates
the backward CPU reference and shows the API pattern.
Usage:
python3 27_backward_dropout_fmha.py
python3 27_backward_dropout_fmha.py --dropout 0.2
python3 27_backward_dropout_fmha.py --seqlen 128 --deterministic
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
)
def cpu_attention_fwd_dropout(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    dropout_p: float,
    seed: int = 42,
) -> tuple:
    """CPU reference: forward with dropout, returning intermediates for backward.

    When Q has more heads than K, K and V are repeated along the head axis
    (grouped-query attention) before the scores are computed.

    Args:
        Q: [B, Hq, Sq, Dq] queries
        K: [B, Hk, Sk, Dq] keys; Hq must be a multiple of Hk
        V: [B, Hk, Sk, Dv] values
        scale: multiplier applied to Q @ K^T before the softmax
        dropout_p: drop probability, must lie in [0, 1]
        seed: seed of the private RandomState that draws the dropout mask,
            so the same seed and shapes always give the same mask

    Returns:
        O: [B, H, Sq, Dv] output
        P_drop: [B, H, Sq, Sk] attention weights after dropout
        lse: [B, H, Sq] log-sum-exp of the scaled scores (before dropout)
        drop_mask: [B, H, Sq, Sk] binary dropout mask (1.0 = kept)

    Raises:
        ValueError: if dropout_p is outside [0, 1] or the query head count is
            not a multiple of the key head count.
    """
    # Out-of-range probabilities would otherwise silently yield a mis-scaled
    # (p < 0) or all-zero (p > 1) result.
    if not 0.0 <= dropout_p <= 1.0:
        raise ValueError(f"dropout_p must be in [0, 1], got {dropout_p}")

    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    if nhead_q != nhead_k:
        if nhead_k == 0 or nhead_q % nhead_k != 0:
            raise ValueError(
                f"nhead_q ({nhead_q}) must be a multiple of nhead_k ({nhead_k})"
            )
        ratio = nhead_q // nhead_k
        K = np.repeat(K, ratio, axis=1)
        V = np.repeat(V, ratio, axis=1)

    # Numerically stable softmax over the key axis.
    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    S_max = S.max(axis=-1, keepdims=True)
    S_exp = np.exp(S - S_max)
    S_sum = S_exp.sum(axis=-1, keepdims=True)
    P = S_exp / S_sum
    lse = np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)

    # Inverted dropout: kept entries are rescaled by 1 / (1 - p). With p == 1
    # everything is dropped, so the scale is pinned to 0 to avoid dividing by zero.
    rng = np.random.RandomState(seed)
    drop_mask = (rng.rand(*P.shape) >= dropout_p).astype(np.float32)
    drop_scale = 1.0 / (1.0 - dropout_p) if dropout_p < 1.0 else 0.0
    P_drop = P * drop_mask * drop_scale
    out = np.matmul(P_drop, V)
    return out, P_drop, lse, drop_mask
def cpu_attention_bwd_dropout(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    lse: np.ndarray,
    scale: float,
    dropout_p: float,
    drop_mask: np.ndarray,
    deterministic: bool = False,
) -> tuple:
    """CPU reference: backward with dropout.

    The softmax probabilities are recomputed from Q and K, and the dropout
    mask recorded by the forward pass is replayed unchanged.

    Args:
        Q: [B, H, Sq, Dq] float32
        K: [B, H, Sk, Dq] float32 (already GQA-expanded if needed)
        V: [B, H, Sk, Dv] float32
        out: [B, H, Sq, Dv] float32 (forward output)
        dO: [B, H, Sq, Dv] float32 (output gradient)
        lse: [B, H, Sq] float32 (log-sum-exp from forward); accepted for API
            symmetry but not read by this reference
        scale: softmax scale
        dropout_p: dropout probability
        drop_mask: [B, H, Sq, Sk] binary mask from forward
        deterministic: accepted for API symmetry; this reference produces the
            same result for either value

    Returns:
        dQ: [B, H, Sq, Dq]
        dK: [B, H, Sk, Dq]
        dV: [B, H, Sk, Dv]
    """
    # Same rescaling rule as the forward pass; p == 1 drops everything.
    drop_scale = 1.0 / (1.0 - dropout_p) if dropout_p < 1.0 else 0.0

    # Recompute the pre-dropout softmax. The exponentials are evaluated once
    # and reused for the numerator and the row sum.
    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    S_exp = np.exp(S - S.max(axis=-1, keepdims=True))
    P = S_exp / S_exp.sum(axis=-1, keepdims=True)
    P_drop = P * drop_mask * drop_scale

    # O = P_drop @ V, so dV = P_drop^T @ dO and dP_drop = dO @ V^T.
    dV = np.matmul(P_drop.transpose(0, 1, 3, 2), dO)
    dP_drop = np.matmul(dO, V.transpose(0, 1, 3, 2))
    # Back through dropout: the same mask and scale apply to the gradient.
    dP = dP_drop * drop_mask * drop_scale

    # Softmax backward. D = rowsum(dO * O) equals rowsum(P * dP) because
    # O was produced from P_drop.
    D = (dO * out).sum(axis=-1, keepdims=True)
    dS = P * (dP - D) * scale
    dQ = np.matmul(dS, K)
    dK = np.matmul(dS.transpose(0, 1, 3, 2), Q)
    return dQ, dK, dV
def main():
# Entry point for Example 27. Runs the CPU forward/backward references with
# dropout, compares analytic gradients with central finite differences,
# sweeps the dropout probability, prints the GPU backward API pattern, and
# always returns 0. No GPU kernel is launched here.
parser = argparse.ArgumentParser(
description="Backward Pass with Dropout FMHA Example",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--batch", type=int, default=2)
parser.add_argument("--nhead", type=int, default=8)
parser.add_argument("--seqlen", type=int, default=64)
parser.add_argument("--hdim", type=int, default=128)
parser.add_argument(
"--dropout", type=float, default=0.1, help="Dropout probability"
)
parser.add_argument(
"--deterministic", action="store_true", help="Use deterministic mode"
)
args = parser.parse_args()
print("=" * 70)
print("Example 27: Backward Pass with Dropout FMHA")
print("=" * 70)
prob = FmhaProblem(
batch=args.batch,
nhead_q=args.nhead,
nhead_k=args.nhead,
seqlen_q=args.seqlen,
seqlen_k=args.seqlen,
hdim_q=args.hdim,
hdim_v=args.hdim,
)
# Step 1: Forward with dropout
print("\nStep 1: Forward Pass with Dropout")
# Fixed seed so Q/K/V are reproducible; the dropout mask uses its own
# RandomState seeded through the seed= argument below.
np.random.seed(42)
Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)
V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)
O_nodrop = cpu_attention_fwd(Q, K, V, prob.scale)
O_drop, P_drop, lse, drop_mask = cpu_attention_fwd_dropout(
Q,
K,
V,
prob.scale,
args.dropout,
seed=42,
)
print(f" Shape: {prob.q_shape()}")
print(f" Dropout: p={args.dropout}")
print(
f" Drop mask: {drop_mask.sum():.0f}/{drop_mask.size} kept "
f"({100 * drop_mask.mean():.1f}%, expected {100 * (1 - args.dropout):.1f}%)"
)
print(f" O (no drop): range=[{O_nodrop.min():.4f}, {O_nodrop.max():.4f}]")
print(f" O (dropout): range=[{O_drop.min():.4f}, {O_drop.max():.4f}]")
print(f" LSE shape: {lse.shape}")
# Step 2: Backward pass
print("\nStep 2: Backward Pass")
np.random.seed(123)
dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)
dQ, dK, dV = cpu_attention_bwd_dropout(
Q,
K,
V,
O_drop,
dO,
lse,
prob.scale,
args.dropout,
drop_mask,
deterministic=args.deterministic,
)
print(f" dQ shape: {dQ.shape} range=[{dQ.min():.6f}, {dQ.max():.6f}]")
print(f" dK shape: {dK.shape} range=[{dK.min():.6f}, {dK.max():.6f}]")
print(f" dV shape: {dV.shape} range=[{dV.min():.6f}, {dV.max():.6f}]")
print(f" Deterministic: {args.deterministic}")
# Step 3: Verify gradient correctness via finite differences
print("\nStep 3: Gradient Verification (Finite Differences)")
eps = 1e-3
num_checks = 5
# Separate generator for picking the elements to probe.
rng = np.random.RandomState(99)
print(f"\n Checking {num_checks} random elements per tensor:")
print(
f" {'Tensor':>8} {'Index':>24} {'Analytic':>14} {'Numerical':>14} {'RelErr':>12}"
)
print(" " + "-" * 76)
for tensor_name, param, grad in [("dQ", Q, dQ), ("dK", K, dK), ("dV", V, dV)]:
for _ in range(num_checks):
idx = tuple(rng.randint(0, s) for s in param.shape)
param_plus = param.copy()
param_plus[idx] += eps
param_minus = param.copy()
param_minus[idx] -= eps
# Every perturbed forward passes seed=42 with unchanged shapes, so it
# replays the same dropout mask as Step 1.
if tensor_name == "dQ":
O_p, _, _, _ = cpu_attention_fwd_dropout(
param_plus, K, V, prob.scale, args.dropout, seed=42
)
O_m, _, _, _ = cpu_attention_fwd_dropout(
param_minus, K, V, prob.scale, args.dropout, seed=42
)
elif tensor_name == "dK":
O_p, _, _, _ = cpu_attention_fwd_dropout(
Q, param_plus, V, prob.scale, args.dropout, seed=42
)
O_m, _, _, _ = cpu_attention_fwd_dropout(
Q, param_minus, V, prob.scale, args.dropout, seed=42
)
else:
O_p, _, _, _ = cpu_attention_fwd_dropout(
Q, K, param_plus, prob.scale, args.dropout, seed=42
)
O_m, _, _, _ = cpu_attention_fwd_dropout(
Q, K, param_minus, prob.scale, args.dropout, seed=42
)
# Central difference of the scalar sum(O * dO) with respect to the probed element.
numerical = (O_p * dO).sum() - (O_m * dO).sum()
numerical /= 2 * eps
analytic = grad[idx]
# The relative error is printed only; it does not affect the return value.
rel_err = abs(analytic - numerical) / (abs(numerical) + 1e-8)
idx_str = str(idx)
print(
f" {tensor_name:>8} {idx_str:>24} {analytic:>14.6f} {numerical:>14.6f} {rel_err:>12.2e}"
)
# Step 4: Deterministic vs non-deterministic comparison
print("\nStep 4: Deterministic vs Non-Deterministic")
# The two calls below differ only in the deterministic flag.
dQ_det, dK_det, dV_det = cpu_attention_bwd_dropout(
Q,
K,
V,
O_drop,
dO,
lse,
prob.scale,
args.dropout,
drop_mask,
deterministic=True,
)
dQ_ndet, dK_ndet, dV_ndet = cpu_attention_bwd_dropout(
Q,
K,
V,
O_drop,
dO,
lse,
prob.scale,
args.dropout,
drop_mask,
deterministic=False,
)
validator = FmhaValidator(rtol=1e-5, atol=1e-5)
for name, g_det, g_ndet in [
("dQ", dQ_det, dQ_ndet),
("dK", dK_det, dK_ndet),
("dV", dV_det, dV_ndet),
]:
ok, max_abs, _ = validator.check(g_det, g_ndet)
print(
f" {name}: det vs non-det max_err={max_abs:.2e} {'MATCH' if ok else 'DIFFER'}"
)
print("\n NOTE: In CPU reference both modes are identical.")
print(" On GPU, non-deterministic mode uses atomicAdd for dQ accumulation,")
print(" which can cause tiny floating-point differences across runs.")
# Step 5: Dropout probability sweep
print("\nStep 5: Dropout Probability Sweep")
probs = [0.0, 0.1, 0.2, 0.3, 0.5]
print(
f"\n {'p':>6} {'|dQ| mean':>12} {'|dK| mean':>12} {'|dV| mean':>12} {'Kept%':>8}"
)
print(" " + "-" * 54)
for p in probs:
# A fresh forward per p yields the matching output and mask; the lse
# argument passed to the backward call is the one computed in Step 1.
O_p, _, _, dm = cpu_attention_fwd_dropout(Q, K, V, prob.scale, p, seed=42)
dQ_p, dK_p, dV_p = cpu_attention_bwd_dropout(
Q,
K,
V,
O_p,
dO,
lse,
prob.scale,
p,
dm,
)
kept = 100 * dm.mean()
print(
f" {p:>6.2f} {np.abs(dQ_p).mean():>12.6f} {np.abs(dK_p).mean():>12.6f} "
f"{np.abs(dV_p).mean():>12.6f} {kept:>7.1f}%"
)
# Step 6: GPU API pattern
# This step only prints the API pattern; nothing is compiled or launched.
print("\nStep 6: GPU Backward Kernel Configuration")
print(" NOTE: The prebuilt library only has a forward kernel.")
print(" FMHA backward requires 3 kernel stages:")
print()
print(" Stage 1: bwd_dot_do_o -- compute D = rowsum(dO * O)")
print(" Stage 2: bwd_dq_dk_dv -- compute dQ, dK, dV")
print(" Stage 3: bwd_convert_dq -- convert accumulated dQ")
print()
print(" With dropout, the signature requires:")
print(" .dropout(true)")
print(" .store_randval(false) // or true to save random values")
print(f" .deterministic({'true' if args.deterministic else 'false'})")
# Summary
print("\n" + "=" * 70)
print(" Backward with dropout: replays same mask from forward pass")
print(" Deterministic mode: reproducible but potentially slower on GPU")
print(" 3-stage backward: dot_do_o -> dq_dk_dv -> convert_dq")
print("=" * 70)
return 0
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,360 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 28: Backward Bias Gradient (dbias) FMHA
Demonstrates computing the gradient of the elementwise attention bias
during the backward pass. When forward attention uses:
S = Q @ K^T * scale + bias
the backward pass must compute:
dbias = sum over batch of (dS), with dS = P * (dP - D)
where dS is the gradient of the pre-softmax scores, dP is the gradient of
the attention probabilities, and D = rowsum(dO * O).
This is useful for learnable relative position biases (e.g., ALiBi
training, T5-style relative position embeddings).
The prebuilt library only has a forward kernel. This example validates
the dbias CPU reference and shows the API pattern.
Usage:
python3 28_backward_dbias_fmha.py
python3 28_backward_dbias_fmha.py --seqlen 128
python3 28_backward_dbias_fmha.py --bias-type alibi
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaProblem,
cpu_attention_fwd,
detect_gpu_arch,
)
def make_elementwise_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Create a simple elementwise attention bias [nhead, seqlen_q, seqlen_k].

    ``bias[h, i, j] = -0.1 * |i - j| * (h + 1) / nhead``: a distance penalty
    that grows linearly with the head index. The result is float32.
    """
    q_pos = np.arange(seqlen_q)[:, np.newaxis]
    k_pos = np.arange(seqlen_k)[np.newaxis, :]
    distance = np.abs(q_pos - k_pos).astype(np.float64)  # [Sq, Sk]
    head_factor = np.arange(1, nhead + 1, dtype=np.float64)[:, np.newaxis, np.newaxis]
    # Evaluated in float64 in the same left-to-right order as the scalar
    # formula, then rounded once to float32.
    bias = -0.1 * distance[np.newaxis, :, :] * head_factor / nhead
    return bias.astype(np.float32)
def make_alibi_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Create ALiBi-style attention bias [nhead, seqlen_q, seqlen_k].

    ALiBi adds a linear penalty proportional to distance:
        bias[h, i, j] = -slope_h * |i - j|
    where slope_h = 2 ** (-8 * (h + 1) / nhead) decreases geometrically
    across heads. The result is float32.
    """
    slopes = np.array([2 ** (-(8 * (h + 1) / nhead)) for h in range(nhead)])
    q_pos = np.arange(seqlen_q)[:, np.newaxis]
    k_pos = np.arange(seqlen_k)[np.newaxis, :]
    distance = np.abs(q_pos - k_pos)  # [Sq, Sk]
    # Broadcast the per-head slope over the distance grid in float64, then
    # round once to float32.
    bias = -slopes[:, np.newaxis, np.newaxis] * distance[np.newaxis, :, :]
    return bias.astype(np.float32)
def cpu_attention_fwd_bias(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    bias: np.ndarray,
) -> tuple:
    """Reference forward pass with an additive bias, keeping intermediates.

    When Q has more heads than K, K and V are repeated along the head axis
    before the scores are formed.

    Args:
        Q: [B, H, Sq, Dq]
        K: [B, H, Sk, Dq]
        V: [B, H, Sk, Dv]
        scale: multiplier applied to Q @ K^T before the bias is added
        bias: [H, Sq, Sk] broadcast over batch

    Returns:
        O: [B, H, Sq, Dv]
        P: [B, H, Sq, Sk] attention probabilities
        lse: [B, H, Sq] log-sum-exp
    """
    q_heads = Q.shape[1]
    kv_heads = K.shape[1]
    if q_heads != kv_heads:
        group = q_heads // kv_heads
        K = np.repeat(K, group, axis=1)
        V = np.repeat(V, group, axis=1)

    # Scaled scores plus the bias, which gains a leading batch axis.
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    scores = scores + bias[np.newaxis, :, :, :]

    # Softmax with the row maximum subtracted for numerical stability.
    row_max = scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(scores - row_max)
    row_sum = exp_scores.sum(axis=-1, keepdims=True)
    probs = exp_scores / row_sum

    log_sum_exp = np.log(row_sum.squeeze(-1)) + row_max.squeeze(-1)
    return np.matmul(probs, V), probs, log_sum_exp
def cpu_attention_bwd_dbias(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
    bias: np.ndarray,
) -> tuple:
    """CPU backward computing dQ, dK, dV, and dbias.

    When Q has more heads than K (grouped-query attention), K and V are
    repeated along the head axis for the computation and the resulting dK/dV
    are summed back over each group so they match the shapes of K and V.

    Args:
        Q: forward queries [B, Hq, Sq, Dq]
        K: forward keys [B, Hk, Sk, Dq]
        V: forward values [B, Hk, Sk, Dv]
        out: forward output [B, Hq, Sq, Dv]
        dO: output gradient [B, Hq, Sq, Dv]
        P: attention probabilities [B, Hq, Sq, Sk]
        scale: softmax scale
        bias: [Hq, Sq, Sk] attention bias; accepted for API symmetry, its
            values are not needed because P already reflects it

    Returns:
        dQ: [B, Hq, Sq, Dq]
        dK: [B, Hk, Sk, Dq]
        dV: [B, Hk, Sk, Dv]
        dbias: [Hq, Sq, Sk] summed over batch dimension
    """
    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    ratio = 1
    if nhead_q != nhead_k:
        ratio = nhead_q // nhead_k
        K = np.repeat(K, ratio, axis=1)
        V = np.repeat(V, ratio, axis=1)

    dV = np.matmul(P.transpose(0, 1, 3, 2), dO)
    dP = np.matmul(dO, V.transpose(0, 1, 3, 2))
    D = (dO * out).sum(axis=-1, keepdims=True)

    # Gradient with respect to the biased scores S = scale * Q @ K^T + bias.
    # The bias enters S with coefficient 1, so dbias is this quantity summed
    # over batch; computing it before the scale factor avoids a multiply and
    # divide round trip and stays finite when scale == 0.
    dS_bias = P * (dP - D)
    dbias = dS_bias.sum(axis=0)

    # Q @ K^T enters S multiplied by scale, so dQ/dK pick up that factor.
    dS = dS_bias * scale
    dQ = np.matmul(dS, K)
    dK = np.matmul(dS.transpose(0, 1, 3, 2), Q)

    if ratio != 1:
        # np.repeat lays the copies of each K/V head out consecutively, so
        # splitting the head axis into (nhead_k, ratio) groups them; summing
        # the group axis accumulates every query head's contribution.
        dK = dK.reshape(dK.shape[0], nhead_k, ratio, *dK.shape[2:]).sum(axis=2)
        dV = dV.reshape(dV.shape[0], nhead_k, ratio, *dV.shape[2:]).sum(axis=2)
    return dQ, dK, dV, dbias
def main():
# Entry point for Example 28. Builds an attention bias, runs the CPU
# forward/backward references, compares dbias with central finite
# differences, shows how dbias changes with batch size, prints the GPU API
# pattern, and always returns 0. No GPU kernel is launched here.
parser = argparse.ArgumentParser(
description="Backward Bias Gradient (dbias) FMHA Example",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--batch", type=int, default=4)
parser.add_argument("--nhead", type=int, default=8)
parser.add_argument("--seqlen", type=int, default=64)
parser.add_argument("--hdim", type=int, default=128)
parser.add_argument(
"--bias-type", choices=["elementwise", "alibi"], default="elementwise"
)
args = parser.parse_args()
print("=" * 70)
print("Example 28: Backward Bias Gradient (dbias) FMHA")
print("=" * 70)
prob = FmhaProblem(
batch=args.batch,
nhead_q=args.nhead,
nhead_k=args.nhead,
seqlen_q=args.seqlen,
seqlen_k=args.seqlen,
hdim_q=args.hdim,
hdim_v=args.hdim,
)
# Step 1: Create bias
print(f"\nStep 1: Create {args.bias_type.title()} Bias")
if args.bias_type == "alibi":
bias = make_alibi_bias(args.nhead, args.seqlen, args.seqlen)
else:
bias = make_elementwise_bias(args.nhead, args.seqlen, args.seqlen)
print(f" Bias shape: {bias.shape}")
print(f" Bias range: [{bias.min():.4f}, {bias.max():.4f}]")
print(f" Bias type: {args.bias_type}")
# Per-head statistics are printed for at most the first four heads.
for h in range(min(4, args.nhead)):
print(
f" Head {h}: range=[{bias[h].min():.4f}, {bias[h].max():.4f}] "
f"mean={bias[h].mean():.4f}"
)
# Step 2: Forward pass with bias
print("\nStep 2: Forward Pass with Bias")
# Fixed seed so Q/K/V are reproducible across runs.
np.random.seed(42)
Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)
V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)
O_nobias = cpu_attention_fwd(Q, K, V, prob.scale)
O_bias, P, lse = cpu_attention_fwd_bias(Q, K, V, prob.scale, bias)
diff = np.abs(O_nobias - O_bias)
print(f" O (no bias): range=[{O_nobias.min():.4f}, {O_nobias.max():.4f}]")
print(f" O (biased): range=[{O_bias.min():.4f}, {O_bias.max():.4f}]")
print(f" Bias effect: max_diff={diff.max():.6e} mean_diff={diff.mean():.6e}")
# Step 3: Backward pass with dbias
print("\nStep 3: Backward Pass (dQ, dK, dV, dbias)")
np.random.seed(123)
dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)
dQ, dK, dV, dbias = cpu_attention_bwd_dbias(
Q,
K,
V,
O_bias,
dO,
P,
prob.scale,
bias,
)
print(f" dQ shape: {dQ.shape} range=[{dQ.min():.6f}, {dQ.max():.6f}]")
print(f" dK shape: {dK.shape} range=[{dK.min():.6f}, {dK.max():.6f}]")
print(f" dV shape: {dV.shape} range=[{dV.min():.6f}, {dV.max():.6f}]")
print(f" dbias shape: {dbias.shape} range=[{dbias.min():.6f}, {dbias.max():.6f}]")
# Step 4: Verify dbias via finite differences
print("\nStep 4: dbias Gradient Verification (Finite Differences)")
eps = 1e-3
num_checks = 8
# Separate generator for picking the bias elements to probe.
rng = np.random.RandomState(99)
print(
f"\n {'Index':>20} {'Analytic':>14} {'Numerical':>14} {'RelErr':>12} {'Status':>8}"
)
print(" " + "-" * 72)
# Feeds the "Gradient check" line in the summary; the return value of
# main() does not depend on it.
all_grad_ok = True
for _ in range(num_checks):
h = rng.randint(0, args.nhead)
i = rng.randint(0, args.seqlen)
j = rng.randint(0, args.seqlen)
bias_plus = bias.copy()
bias_plus[h, i, j] += eps
bias_minus = bias.copy()
bias_minus[h, i, j] -= eps
O_p, _, _ = cpu_attention_fwd_bias(Q, K, V, prob.scale, bias_plus)
O_m, _, _ = cpu_attention_fwd_bias(Q, K, V, prob.scale, bias_minus)
# Central difference of the scalar sum(O * dO) with respect to bias[h, i, j].
numerical = ((O_p * dO).sum() - (O_m * dO).sum()) / (2 * eps)
analytic = dbias[h, i, j]
rel_err = abs(analytic - numerical) / (abs(numerical) + 1e-8)
# An element passes when the relative error is below 1%.
ok = rel_err < 1e-2
all_grad_ok = all_grad_ok and ok
idx_str = f"({h},{i},{j})"
print(
f" {idx_str:>20} {analytic:>14.6f} {numerical:>14.6f} {rel_err:>12.2e} {'OK' if ok else 'FAIL':>8}"
)
# Step 5: dbias structure analysis
print("\nStep 5: dbias Structure Analysis")
print("\n Per-head dbias statistics:")
print(f" {'Head':>6} {'Mean':>12} {'Std':>12} {'Min':>12} {'Max':>12}")
print(" " + "-" * 56)
for h in range(min(8, args.nhead)):
db_h = dbias[h]
print(
f" {h:>6} {db_h.mean():>12.6f} {db_h.std():>12.6f} "
f"{db_h.min():>12.6f} {db_h.max():>12.6f}"
)
# Step 6: Batch size effect on dbias
print("\nStep 6: Batch Size Effect on dbias")
print(" dbias = sum of per-sample dS / scale over batch dimension")
print(" Larger batch -> dbias aggregates more gradient signal")
batch_sizes = [1, 2, 4, 8]
print(
f"\n {'Batch':>6} {'|dbias| mean':>14} {'|dbias| max':>14} {'dbias std':>14}"
)
print(" " + "-" * 52)
for b in batch_sizes:
# Fresh random inputs per batch size, drawn from the global NumPy
# generator; prob.scale and the bias from Step 1 are reused unchanged.
Q_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype(
np.float32
)
K_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype(
np.float32
)
V_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype(
np.float32
)
dO_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.1).astype(
np.float32
)
O_b, P_b, lse_b = cpu_attention_fwd_bias(Q_b, K_b, V_b, prob.scale, bias)
_, _, _, dbias_b = cpu_attention_bwd_dbias(
Q_b,
K_b,
V_b,
O_b,
dO_b,
P_b,
prob.scale,
bias,
)
print(
f" {b:>6} {np.abs(dbias_b).mean():>14.6f} {np.abs(dbias_b).max():>14.6f} "
f"{dbias_b.std():>14.6f}"
)
# Step 7: GPU API pattern
# This step only prints the API pattern; nothing is compiled or launched.
print("\nStep 7: GPU Kernel Configuration")
print(" NOTE: The prebuilt library only has a forward kernel without bias.")
print(" For backward with dbias, compile kernels with:")
print()
print(" Forward: FmhaSignature().bias('bias') // elementwise bias")
print(" Backward: FmhaSignature()")
print(" .family('bwd_dq_dk_dv')")
print(" .bias('bias')")
print(" .dbias(true) // enable dbias computation")
print()
print(" In codegen JSON:")
print(" 'bias': 'bias', // forward: elementwise bias")
print(" 'dbias': true, // backward: compute bias gradient")
# Summary
print("\n" + "=" * 70)
print(" dbias = sum_batch(P * (dP - D)) (gradient of elementwise bias)")
print(f" Shape: [{args.nhead}, {args.seqlen}, {args.seqlen}] (same as bias)")
print(f" Gradient check: {'PASS' if all_grad_ok else 'FAIL'}")
print(" Use case: learnable relative position biases (ALiBi, T5, etc.)")
print("=" * 70)
return 0
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 29: Sweep Sequence Length
Demonstrates how FMHA performance scales with sequence length.
FMHA has O(n^2) compute in seqlen (Q*K^T), so TFLOPS should increase
with longer sequences as the GPU becomes better utilized.
Fixed: batch=2, nhead=8, hdim=128
Sweep: seqlen in [32, 64, 128, 256, 512, 1024, 2048]
Usage:
python3 29_sweep_seqlen.py
python3 29_sweep_seqlen.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
# Problem dimensions held fixed for every point of the sweep.
BATCH = 2
NHEAD = 8
HDIM = 128
# Sequence lengths to run, used for both seqlen_q and seqlen_k.
SEQLENS = [32, 64, 128, 256, 512, 1024, 2048]
def main():
# Entry point for Example 29. JIT-builds one fp16 kernel for HDIM, runs it
# for every entry of SEQLENS, and checks each output against the CPU
# reference. Returns 1 if the build fails or any sequence length fails,
# otherwise 0.
parser = argparse.ArgumentParser(description="Sweep Sequence Length FMHA")
parser.add_argument("--arch", default=detect_gpu_arch())
args = parser.parse_args()
print("=" * 70)
print("Example 29: Sweep Sequence Length")
print("=" * 70)
print(f"\n Fixed: batch={BATCH}, nhead={NHEAD}, hdim={HDIM}")
print(f" Sweep: seqlen in {SEQLENS}")
print(f" Arch: {args.arch}")
validator = FmhaValidator(rtol=1e-2, atol=1e-2)
# Step 1: JIT-compile FMHA kernel
print("\nStep 1: JIT-Compile FMHA Kernel")
config = FmhaKernelConfig(
data_type="fp16",
hdim_q=HDIM,
hdim_v=HDIM,
gfx_arch=args.arch,
)
setup = setup_fmha_dispatcher(config)
# Without a kernel there is nothing to sweep, so a build failure is fatal here.
if not setup.success:
print(f" JIT build failed: {setup.error}")
return 1
runner = setup.runner
print(f" JIT build: {setup.build_time_s:.1f}s")
# Step 2: Sweep
print("\nStep 2: Sequence Length Sweep")
hdr = f" {'SeqLen':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}"
print(f"\n{hdr}")
print(" " + "-" * 60)
# Fixed seed so the random inputs are reproducible across runs.
np.random.seed(42)
# One tuple per sequence length: (seqlen, ok, time_ms, tflops, max_err).
results = []
for seqlen in SEQLENS:
prob = FmhaProblem(
batch=BATCH,
nhead_q=NHEAD,
nhead_k=NHEAD,
seqlen_q=seqlen,
seqlen_k=seqlen,
hdim_q=HDIM,
hdim_v=HDIM,
)
Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)
# The reference is computed in float32 from the same fp16-rounded inputs
# that are passed to the GPU runner.
O_ref = cpu_attention_fwd(
Q.astype(np.float32),
K.astype(np.float32),
V.astype(np.float32),
prob.scale,
)
res = runner.run(Q, K, V, prob)
if not res.success:
print(
f" {seqlen:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}"
)
# A failed run is recorded with ok=False and zeroed metrics, and counts
# as a failure in the summary.
results.append((seqlen, False, 0.0, 0.0, 0.0))
continue
max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max())
ok, _, _ = validator.check(res.output, O_ref)
tag = "PASS" if ok else "FAIL"
print(
f" {seqlen:>8} | {res.time_ms:>10.4f} | {res.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}"
)
results.append((seqlen, ok, res.time_ms, res.tflops, max_err))
# Step 3: Scaling analysis
print("\nStep 3: Scaling Analysis")
# Keep only passing runs with a positive TFLOPS figure, so the ratio
# below never divides by zero.
valid = [(s, t, tf) for s, ok, t, tf, _ in results if ok and tf > 0]
if len(valid) >= 2:
s0, _, tf0 = valid[0]
s_last, _, tf_last = valid[-1]
print(f" Shortest (seqlen={s0}): {tf0:.2f} TFLOPS")
print(f" Longest (seqlen={s_last}): {tf_last:.2f} TFLOPS")
print(f" Speedup: {tf_last / tf0:.1f}x TFLOPS improvement")
print(" Note: Longer sequences expose more parallelism to the GPU")
# Summary
passed = sum(1 for _, ok, *_ in results if ok)
print("\n" + "=" * 70)
print(f" Results: {passed}/{len(results)} passed")
print(f" Fixed: B={BATCH} H={NHEAD} D={HDIM}")
print(f" Sweep: seqlen={SEQLENS}")
print(f" Status: {'PASS' if passed == len(results) else 'FAIL'}")
print("=" * 70)
return 0 if passed == len(results) else 1
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 30: Sweep Batch Size
Demonstrates how FMHA performance scales with batch size.
FMHA compute scales linearly with batch, so time should increase
linearly while TFLOPS remains roughly constant once the GPU is saturated.
Fixed: seqlen=128, nhead=8, hdim=128
Sweep: batch in [1, 2, 4, 8, 16, 32]
Usage:
python3 30_sweep_batch.py
python3 30_sweep_batch.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
# Problem dimensions held fixed for every point of the sweep.
SEQLEN = 128
NHEAD = 8
HDIM = 128
# Batch sizes to run.
BATCHES = [1, 2, 4, 8, 16, 32]
def main():
    """Sweep the batch size and report timing, TFLOPS and accuracy per point.

    One fp16 kernel is built for HDIM, then run once for every entry of
    BATCHES and compared against a float32 CPU reference. A linearity summary
    (time ratio divided by batch ratio) is printed when at least two points
    validated with a positive time.

    Returns:
        0 if every batch size ran and validated, otherwise 1 (a failed JIT
        build also returns 1).
    """
    cli = argparse.ArgumentParser(description="Sweep Batch Size FMHA")
    cli.add_argument("--arch", default=detect_gpu_arch())
    args = cli.parse_args()

    for banner_line in ("=" * 70, "Example 30: Sweep Batch Size", "=" * 70):
        print(banner_line)
    print(f"\n Fixed: seqlen={SEQLEN}, nhead={NHEAD}, hdim={HDIM}")
    print(f" Sweep: batch in {BATCHES}")
    print(f" Arch: {args.arch}")
    checker = FmhaValidator(rtol=1e-2, atol=1e-2)

    # Step 1: JIT-compile FMHA kernel
    print("\nStep 1: JIT-Compile FMHA Kernel")
    setup = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=HDIM,
            hdim_v=HDIM,
            gfx_arch=args.arch,
        )
    )
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
        return 1
    gpu = setup.runner
    print(f" JIT build: {setup.build_time_s:.1f}s")

    # Step 2: Sweep
    print("\nStep 2: Batch Size Sweep")
    hdr = f" {'Batch':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}"
    print(f"\n{hdr}")
    print(" " + "-" * 60)
    # Seed once so the whole sweep draws a reproducible stream of inputs.
    np.random.seed(42)

    def draw_fp16(shape):
        # Small-magnitude random tensor in half precision.
        return (np.random.randn(*shape) * 0.1).astype(np.float16)

    results = []
    for batch in BATCHES:
        problem = FmhaProblem(
            batch=batch,
            nhead_q=NHEAD,
            nhead_k=NHEAD,
            seqlen_q=SEQLEN,
            seqlen_k=SEQLEN,
            hdim_q=HDIM,
            hdim_v=HDIM,
        )
        Q = draw_fp16(problem.q_shape())
        K = draw_fp16(problem.k_shape())
        V = draw_fp16(problem.v_shape())
        reference = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            problem.scale,
        )
        res = gpu.run(Q, K, V, problem)
        if res.success:
            max_err = float(np.abs(res.output.astype(np.float32) - reference).max())
            ok, _, _ = checker.check(res.output, reference)
            tag = "PASS" if ok else "FAIL"
            print(
                f" {batch:>8} | {res.time_ms:>10.4f} | {res.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}"
            )
            results.append((batch, ok, res.time_ms, res.tflops, max_err))
        else:
            # The kernel did not run: record a failed row with zeroed metrics.
            print(
                f" {batch:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}"
            )
            results.append((batch, False, 0.0, 0.0, 0.0))

    # Step 3: Linearity analysis
    print("\nStep 3: Linear Scaling Analysis")
    timed = [
        (entry[0], entry[2], entry[3])
        for entry in results
        if entry[1] and entry[2] > 0
    ]
    if len(timed) >= 2:
        (b0, t0, tf0), (b_last, t_last, tf_last) = timed[0], timed[-1]
        batch_ratio = b_last / b0
        time_ratio = t_last / t0
        linearity = time_ratio / batch_ratio
        print(
            f" Batch {b0} -> {b_last}: {batch_ratio:.0f}x batch, {time_ratio:.1f}x time"
        )
        print(f" Linearity factor: {linearity:.2f} (1.0 = perfect linear scaling)")
        print(f" TFLOPS range: {tf0:.2f} - {tf_last:.2f}")

    # Summary
    passed = sum(1 for entry in results if entry[1])
    print("\n" + "=" * 70)
    print(f" Results: {passed}/{len(results)} passed")
    print(f" Fixed: S={SEQLEN} H={NHEAD} D={HDIM}")
    print(f" Sweep: batch={BATCHES}")
    print(f" Status: {'PASS' if passed == len(results) else 'FAIL'}")
    print("=" * 70)
    if passed == len(results):
        return 0
    return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,172 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 31: Sweep Number of Heads (MHA + GQA)
Demonstrates FMHA performance across different head counts, including
Grouped Query Attention (GQA) where nhead_q > nhead_k.
Part 1 - MHA sweep: nhead_q == nhead_k
Part 2 - GQA variants: nhead_q != nhead_k (multiple Q heads share K/V)
Fixed: batch=2, seqlen=128, hdim=128
Usage:
python3 31_sweep_nhead.py
python3 31_sweep_nhead.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
# Problem dimensions held constant for every configuration.
BATCH = 2
SEQLEN = 128
HDIM = 128
# Head counts for the MHA part of the sweep (nhead_q == nhead_k).
MHA_NHEADS = [1, 2, 4, 8, 16, 32]
# GQA configurations as (nhead_q, nhead_k, display label).
GQA_CONFIGS = [
(8, 1, "GQA 8:1"),
(16, 4, "GQA 4:1"),
(32, 8, "GQA 4:1"),
]
def run_sweep(runner, validator, configs, label):
    """Run the kernel over a list of (nhead_q, nhead_k) pairs and print a table.

    Args:
        runner: object whose run(Q, K, V, prob) executes the GPU kernel.
        validator: object whose check(output, reference) returns a 3-tuple
            whose first element is the pass/fail flag.
        configs: iterable of (nhead_q, nhead_k) pairs.
        label: name of the sweep; accepted for the caller's convenience and
            not used inside this function.

    Returns:
        A list of (nhead_q, nhead_k, ok, time_ms, tflops, max_err) tuples, one
        per configuration; failed runs carry ok=False and zeroed metrics.
    """
    hdr = f" {'nhead_q':>8} | {'nhead_k':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}"
    print(f"\n{hdr}")
    print(" " + "-" * 70)
    # Re-seed per sweep so each call draws the same reproducible inputs.
    np.random.seed(42)

    def draw_fp16(shape):
        # Small-magnitude random tensor in half precision.
        return (np.random.randn(*shape) * 0.1).astype(np.float16)

    rows = []
    for nhead_q, nhead_k in configs:
        problem = FmhaProblem(
            batch=BATCH,
            nhead_q=nhead_q,
            nhead_k=nhead_k,
            seqlen_q=SEQLEN,
            seqlen_k=SEQLEN,
            hdim_q=HDIM,
            hdim_v=HDIM,
        )
        Q = draw_fp16(problem.q_shape())
        K = draw_fp16(problem.k_shape())
        V = draw_fp16(problem.v_shape())
        reference = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            problem.scale,
        )
        res = runner.run(Q, K, V, problem)
        if res.success:
            max_err = float(np.abs(res.output.astype(np.float32) - reference).max())
            ok, _, _ = validator.check(res.output, reference)
            tag = "PASS" if ok else "FAIL"
            print(
                f" {nhead_q:>8} | {nhead_k:>8} | {res.time_ms:>10.4f} | {res.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}"
            )
            rows.append((nhead_q, nhead_k, ok, res.time_ms, res.tflops, max_err))
        else:
            print(
                f" {nhead_q:>8} | {nhead_k:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}"
            )
            rows.append((nhead_q, nhead_k, False, 0.0, 0.0, 0.0))
    return rows
def main():
    """Sweep head counts for MHA and GQA and compare the best of each.

    Builds one fp16 kernel for HDIM, runs run_sweep() over the MHA head
    counts and the GQA configurations, then prints the highest-TFLOPS entry of
    each family and a pass/fail summary.

    Returns:
        0 if every configuration validated, otherwise 1 (a failed JIT build
        also returns 1).
    """
    cli = argparse.ArgumentParser(description="Sweep Number of Heads FMHA")
    cli.add_argument("--arch", default=detect_gpu_arch())
    args = cli.parse_args()

    for banner_line in (
        "=" * 70,
        "Example 31: Sweep Number of Heads (MHA + GQA)",
        "=" * 70,
    ):
        print(banner_line)
    print(f"\n Fixed: batch={BATCH}, seqlen={SEQLEN}, hdim={HDIM}")
    print(f" Arch: {args.arch}")
    checker = FmhaValidator(rtol=1e-2, atol=1e-2)

    # Step 1: JIT-compile FMHA kernel
    print("\nStep 1: JIT-Compile FMHA Kernel")
    setup = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=HDIM,
            hdim_v=HDIM,
            gfx_arch=args.arch,
        )
    )
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
        return 1
    gpu = setup.runner
    print(f" JIT build: {setup.build_time_s:.1f}s")

    # Step 2: MHA sweep (nhead_q == nhead_k)
    print("\nStep 2: MHA Sweep (nhead_q == nhead_k)")
    mha_results = run_sweep(gpu, checker, [(n, n) for n in MHA_NHEADS], "MHA")

    # Step 3: GQA sweep (nhead_q > nhead_k)
    print("\nStep 3: GQA Sweep (nhead_q > nhead_k)")
    print(" GQA: multiple Q heads share fewer K/V heads")
    gqa_pairs = [(entry[0], entry[1]) for entry in GQA_CONFIGS]
    gqa_results = run_sweep(gpu, checker, gqa_pairs, "GQA")

    # Step 4: Comparison
    print("\nStep 4: MHA vs GQA Comparison")
    all_results = mha_results + gqa_results

    def usable(rows):
        # Keep (nhead_q, nhead_k, tflops) for validated rows with real throughput.
        return [(row[0], row[1], row[4]) for row in rows if row[2] and row[4] > 0]

    valid_mha = usable(mha_results)
    valid_gqa = usable(gqa_results)
    if valid_mha:
        best_mha = max(valid_mha, key=lambda x: x[2])
        print(f" Best MHA: nhead={best_mha[0]}, {best_mha[2]:.2f} TFLOPS")
    if valid_gqa:
        best_gqa = max(valid_gqa, key=lambda x: x[2])
        print(
            f" Best GQA: nhead_q={best_gqa[0]}, nhead_k={best_gqa[1]}, {best_gqa[2]:.2f} TFLOPS"
        )
        print(f" GQA saves K/V memory: {best_gqa[0]}:{best_gqa[1]} ratio")

    # Summary
    passed = sum(1 for row in all_results if row[2])
    total = len(all_results)
    print("\n" + "=" * 70)
    print(f" Results: {passed}/{total} passed")
    print(f" Fixed: B={BATCH} S={SEQLEN} D={HDIM}")
    print(f" MHA: nhead={MHA_NHEADS}")
    print(f" GQA: {[(nq, nk) for nq, nk, _ in GQA_CONFIGS]}")
    print(f" Status: {'PASS' if passed == total else 'FAIL'}")
    print("=" * 70)
    if passed == total:
        return 0
    return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 32: Sweep Head Dimension
Demonstrates FMHA across different head dimensions (32, 64, 128, 256).
The prebuilt library only supports hdim=128; other head dimensions are
validated via CPU reference only.
Fixed: batch=2, nhead=8, seqlen=128
Sweep: hdim in [32, 64, 128, 256]
Usage:
python3 32_sweep_hdim.py
python3 32_sweep_hdim.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
FmhaValidator,
cpu_attention_fwd,
detect_gpu_arch,
setup_fmha_dispatcher,
)
# Problem dimensions held constant for every head dimension.
BATCH = 2
NHEAD = 8
SEQLEN = 128
# Head dimensions covered by the CPU reference sweep.
HDIMS = [32, 64, 128, 256]
# The only head dimension for which a GPU kernel is built and run.
GPU_SUPPORTED_HDIM = 128
def main():
    """Sweep the head dimension: CPU reference for all, GPU for one.

    A CPU reference is computed for every entry of HDIMS and checked for
    finiteness. Only GPU_SUPPORTED_HDIM is additionally run on the GPU (when
    the JIT build succeeded) and validated against that reference.

    A head dimension that is exercised on the CPU only counts as passed only
    if its reference output is finite, so a NaN/Inf reference is reflected in
    the summary and the exit code instead of being silently reported as PASS.

    Returns:
        0 if every head dimension passed, otherwise 1.
    """
    parser = argparse.ArgumentParser(description="Sweep Head Dimension FMHA")
    parser.add_argument("--arch", default=detect_gpu_arch())
    args = parser.parse_args()

    print("=" * 70)
    print("Example 32: Sweep Head Dimension")
    print("=" * 70)
    print(f"\n Fixed: batch={BATCH}, nhead={NHEAD}, seqlen={SEQLEN}")
    print(f" Sweep: hdim in {HDIMS}")
    print(f" Arch: {args.arch}")
    print(f" Note: Only hdim={GPU_SUPPORTED_HDIM} runs on GPU (prebuilt lib)")
    validator = FmhaValidator(rtol=1e-2, atol=1e-2)

    # Step 1: JIT-compile FMHA kernel (hdim=GPU_SUPPORTED_HDIM)
    print("\nStep 1: JIT-Compile FMHA Kernel")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=GPU_SUPPORTED_HDIM,
        hdim_v=GPU_SUPPORTED_HDIM,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    # A failed build is not fatal here: the script degrades to CPU-only mode.
    runner = None
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
        print(" Will run CPU reference only")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")

    # Step 2: CPU reference for all hdims
    print("\nStep 2: CPU Reference for All Head Dimensions")
    np.random.seed(42)
    # hdim -> (Q, K, V, O_ref, prob, is_finite); reused by the GPU step below.
    cpu_data = {}
    print(
        f"\n {'hdim':>6} | {'Scale':>8} | {'FLOPs':>14} | {'O Range':>22} | {'Finite':<6}"
    )
    print(" " + "-" * 66)
    for hdim in HDIMS:
        prob = FmhaProblem(
            batch=BATCH,
            nhead_q=NHEAD,
            nhead_k=NHEAD,
            seqlen_q=SEQLEN,
            seqlen_k=SEQLEN,
            hdim_q=hdim,
            hdim_v=hdim,
        )
        Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
        K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
        V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
        O_ref = cpu_attention_fwd(Q, K, V, prob.scale)
        is_finite = bool(np.all(np.isfinite(O_ref)))
        o_range = f"[{O_ref.min():.4f}, {O_ref.max():.4f}]"
        print(
            f" {hdim:>6} | {prob.scale:>8.4f} | {prob.num_ops:>14,} | {o_range:>22} | {'OK' if is_finite else 'NaN!':<6}"
        )
        cpu_data[hdim] = (Q, K, V, O_ref, prob, is_finite)

    # Step 3: GPU sweep (only hdim=GPU_SUPPORTED_HDIM supported)
    print("\nStep 3: GPU Sweep")
    hdr = f" {'hdim':>6} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<10}"
    print(f"\n{hdr}")
    print(" " + "-" * 60)
    results = []
    for hdim in HDIMS:
        Q, K, V, O_ref, prob, ref_finite = cpu_data[hdim]
        if hdim != GPU_SUPPORTED_HDIM or runner is None:
            # No GPU run for this hdim: its only validation is the finite
            # check on the CPU reference, so that decides pass/fail.
            cpu_tag = "CPU only" if ref_finite else "CPU FAIL"
            print(
                f" {hdim:>6} | {'---':>10} | {'---':>10} | {'---':>10} | {cpu_tag:<10}"
            )
            results.append((hdim, ref_finite, 0.0, 0.0, 0.0))
            continue
        Q_f16 = Q.astype(np.float16)
        K_f16 = K.astype(np.float16)
        V_f16 = V.astype(np.float16)
        res = runner.run(Q_f16, K_f16, V_f16, prob)
        if not res.success:
            print(
                f" {hdim:>6} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<10}"
            )
            results.append((hdim, False, 0.0, 0.0, 0.0))
            continue
        max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max())
        ok, _, _ = validator.check(res.output, O_ref)
        tag = "PASS" if ok else "FAIL"
        print(
            f" {hdim:>6} | {res.time_ms:>10.4f} | {res.tflops:>10.2f} | {max_err:>10.2e} | {tag:<10}"
        )
        results.append((hdim, ok, res.time_ms, res.tflops, max_err))

    # Step 4: hdim analysis
    print("\nStep 4: Head Dimension Analysis")
    print(" Each hdim requires a dedicated compiled kernel:")
    for hdim in HDIMS:
        gpu_status = "prebuilt" if hdim == GPU_SUPPORTED_HDIM else "needs JIT"
        tile_hint = f"tile_k0max={hdim}"
        print(f" hdim={hdim:>3}: {gpu_status:<10} ({tile_hint})")
    print("\n Compute scales linearly with hdim (via Q*K^T and attn*V).")
    print(" Larger hdim = more work per token, fewer tokens processed per CU.")

    # Summary
    passed = sum(1 for _, ok, *_ in results if ok)
    total = len(results)
    print("\n" + "=" * 70)
    print(f" Results: {passed}/{total} passed")
    print(f" Fixed: B={BATCH} H={NHEAD} S={SEQLEN}")
    print(f" Sweep: hdim={HDIMS}")
    print(f" GPU: hdim={GPU_SUPPORTED_HDIM} only (prebuilt)")
    print(f" Status: {'PASS' if passed == total else 'FAIL'}")
    print("=" * 70)
    return 0 if passed == total else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,271 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 33: Backward Pass with Causal Masks
Demonstrates the FMHA backward pass with causal mask variants:
1. no_mask -- Full attention (baseline)
2. top_left -- Causal mask aligned to top-left corner
3. bottom_right -- Causal mask aligned to bottom-right corner
For each mask type:
- Forward: out = softmax(mask(Q @ K^T * scale)) @ V
- Backward: dQ, dK, dV via analytical gradients through the masked softmax
CPU backward reference:
dP = dO @ V^T
D = rowsum(dO * out) (per-query-position scalar)
dS = P * (dP - D)
dQ = scale * dS @ K
dK = scale * dS^T @ Q
dV = P^T @ dO
Usage:
python3 33_bwd_masks_fmha.py
python3 33_bwd_masks_fmha.py --seqlen-q 128 --seqlen-k 192
python3 33_bwd_masks_fmha.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
setup_fmha_dispatcher,
detect_gpu_arch,
)
def make_causal_mask_top_left(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a top-left aligned causal mask.

    Query row i may attend to key columns 0..i.

    Returns:
        float32 array of shape (seqlen_q, seqlen_k); 1.0 where attention is
        allowed, 0.0 where it is masked.
    """
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    allowed = np.less_equal(k_idx, q_idx)
    return allowed.astype(np.float32)
def make_causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a bottom-right aligned causal mask.

    The diagonal is shifted by (seqlen_k - seqlen_q), so query row i may
    attend to key columns 0..i + (seqlen_k - seqlen_q). When seqlen_k is
    smaller than seqlen_q the shift is negative and the first rows allow no
    key at all.

    Returns:
        float32 array of shape (seqlen_q, seqlen_k); 1.0 where attention is
        allowed, 0.0 where it is masked.
    """
    shift = seqlen_k - seqlen_q
    last_visible = np.arange(seqlen_q)[:, np.newaxis] + shift
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    return np.less_equal(k_idx, last_visible).astype(np.float32)
def cpu_masked_fwd_with_intermediates(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    mask: np.ndarray,
) -> tuple:
    """Forward pass with mask, returning out, P, and LSE for backward.

    Masked positions always receive probability exactly 0. A query row whose
    keys are all masked (for example the first rows of a bottom-right causal
    mask when seqlen_k < seqlen_q) therefore gets P = 0, out = 0 and
    lse = -inf, rather than a uniform distribution over masked keys.

    Args:
        Q: [B, H, Sq, D] K: [B, H, Sk, D] V: [B, H, Sk, Dv]
        scale: multiplier applied to Q @ K^T before the softmax.
        mask: [Sq, Sk] broadcast over batch and head; entries > 0 are kept.

    Returns: (out, P, lse) where lse is float32 with the trailing key axis
    removed.
    """
    keep = mask[np.newaxis, np.newaxis, :, :] > 0
    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    # Large negative fill keeps masked scores out of the row maximum.
    S = np.where(keep, S, -1e9)
    S_max = S.max(axis=-1, keepdims=True)
    # Zero masked weights explicitly: in a fully masked row S - S_max is 0
    # everywhere, so exp() alone would yield 1 for every masked key.
    S_exp = np.where(keep, np.exp(S - S_max), 0.0)
    S_sum = S_exp.sum(axis=-1, keepdims=True)
    has_keys = S_sum > 0
    # Substitute 1 for empty rows so the division and log stay well defined.
    safe_sum = np.where(has_keys, S_sum, 1.0)
    P = S_exp / safe_sum
    out = np.matmul(P, V)
    lse = np.where(has_keys, np.log(safe_sum) + S_max, -np.inf)
    lse = lse.squeeze(-1).astype(np.float32)
    return out, P, lse
def cpu_masked_bwd(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
) -> tuple:
    """Reference backward pass of softmax attention given the probabilities P.

    The mask is not an argument: it is expected to be already reflected in P
    (masked positions carry P=0 and thus contribute no gradient).

    Math:
        D  = rowsum(dO * out)
        dP = dO @ V^T
        dS = P * (dP - D)
        dQ = scale * dS @ K,  dK = scale * dS^T @ Q,  dV = P^T @ dO

    Returns: (dQ, dK, dV, D) with D's trailing singleton axis removed.
    """

    def swap_last_two(x):
        # 4-D transpose of the two innermost axes.
        return x.transpose(0, 1, 3, 2)

    row_dot = np.sum(dO * out, axis=-1, keepdims=True)
    grad_P = np.matmul(dO, swap_last_two(V))
    grad_S = P * (grad_P - row_dot)
    dQ = np.matmul(grad_S, K) * scale
    dK = np.matmul(swap_last_two(grad_S), Q) * scale
    dV = np.matmul(swap_last_two(P), dO)
    return dQ, dK, dV, row_dot.squeeze(-1)
def main():
    """Compare CPU-reference attention gradients across three mask variants.

    For no_mask, top_left and bottom_right masks this runs the CPU forward
    and backward reference, prints mean-absolute and L2 gradient norms, the
    difference of each variant to the first (unmasked) one, and a small
    picture of each mask. A forward GPU kernel is JIT-built only to report
    whether the toolchain works; no GPU backward is executed.

    Returns:
        Always 0 (the example is a demonstration, not a pass/fail test).
    """
    parser = argparse.ArgumentParser(description="Backward Pass with Causal Masks")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen-q", type=int, default=64)
    parser.add_argument("--seqlen-k", type=int, default=64)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 33: Backward Pass with Causal Masks")
    print("=" * 70)
    sq, sk = args.seqlen_q, args.seqlen_k
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} Sq={sq} Sk={sk} D={args.hdim}")
    print(f" Scale: {prob.scale:.6f}")
    print(f" Arch: {args.arch}")

    # --- JIT compile a basic fp16 fwd kernel for the requested head dim ---
    print("\n--- JIT Compilation ---")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s")
        print(f" Library: {setup.library_path}")
        print(" Note: Backward requires family='bwd' kernel (separate JIT)")
    else:
        print(f" JIT build: {setup.error}")
        print(" Continuing with CPU reference only")

    # --- Generate data ---
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)

    # --- Build masks ---
    # Insertion order matters: the first entry is the baseline for the diffs.
    masks = {
        "no_mask": np.ones((sq, sk), dtype=np.float32),
        "top_left": make_causal_mask_top_left(sq, sk),
        "bottom_right": make_causal_mask_bottom_right(sq, sk),
    }

    # --- Per-mask forward + backward ---
    print(
        f"\n {'Mask':<16} {'Density':>8} | {'|dQ|':>10} {'|dK|':>10} {'|dV|':>10}"
        f" | {'dQ vs base':>10} {'dK vs base':>10} {'dV vs base':>10}"
    )
    print(" " + "-" * 98)
    base_grads = None
    # name -> (dQ, dK, dV, D, P); P is kept so later sections need not rerun
    # the forward pass.
    all_grads = {}
    for name, mask in masks.items():
        density = mask.sum() / mask.size * 100
        out, P, lse = cpu_masked_fwd_with_intermediates(Q, K, V, prob.scale, mask)
        dQ, dK, dV, D = cpu_masked_bwd(Q, K, V, out, dO, P, prob.scale)
        dq_norm = float(np.abs(dQ).mean())
        dk_norm = float(np.abs(dK).mean())
        dv_norm = float(np.abs(dV).mean())
        if base_grads is None:
            base_grads = (dQ, dK, dV)
            diff_str = f"{'---':>10} {'---':>10} {'---':>10}"
        else:
            dq_diff = float(np.abs(dQ - base_grads[0]).max())
            dk_diff = float(np.abs(dK - base_grads[1]).max())
            dv_diff = float(np.abs(dV - base_grads[2]).max())
            diff_str = f"{dq_diff:>10.2e} {dk_diff:>10.2e} {dv_diff:>10.2e}"
        print(
            f" {name:<16} {density:>7.1f}% | {dq_norm:>10.4e} {dk_norm:>10.4e} {dv_norm:>10.4e}"
            f" | {diff_str}"
        )
        all_grads[name] = (dQ, dK, dV, D, P)

    # --- Detailed backward breakdown for each mask ---
    print("\n--- Backward Stage Details ---")
    for name in masks:
        dQ, dK, dV, D, P = all_grads[name]
        print(f"\n [{name}]")
        print(" Stage 1 (dot_do_o): D = rowsum(dO * out)")
        print(f" D shape: {D.shape}, range: [{D.min():.6f}, {D.max():.6f}]")
        print(" Stage 2 (dq_dk_dv):")
        print(f" dQ range: [{dQ.min():.4e}, {dQ.max():.4e}]")
        print(f" dK range: [{dK.min():.4e}, {dK.max():.4e}]")
        print(f" dV range: [{dV.min():.4e}, {dV.max():.4e}]")
        p_sparsity = (P < 1e-9).sum() / P.size * 100
        print(f" P sparsity (< 1e-9): {p_sparsity:.1f}%")

    # --- Gradient norm comparison across masks ---
    print("\n--- Gradient L2 Norms ---")
    print(f"\n {'Mask':<16} {'||dQ||_2':>12} {'||dK||_2':>12} {'||dV||_2':>12}")
    print(" " + "-" * 54)
    for name in masks:
        dQ, dK, dV, *_ = all_grads[name]
        l2_dq = float(np.sqrt((dQ**2).sum()))
        l2_dk = float(np.sqrt((dK**2).sum()))
        l2_dv = float(np.sqrt((dV**2).sum()))
        print(f" {name:<16} {l2_dq:>12.4e} {l2_dk:>12.4e} {l2_dv:>12.4e}")

    # --- Mask pattern visualization ---
    print("\n--- Mask Patterns (first 8x8 corner) ---")
    view = min(8, sq, sk)
    for name, mask in masks.items():
        corner = mask[:view, :view]
        print(f"\n {name}:")
        for r in range(view):
            # Both cell states need a visible glyph, otherwise attended cells
            # collapse to nothing and the columns no longer line up.
            row_str = " ".join("#" if corner[r, c] > 0 else "·" for c in range(view))
            print(f" {row_str}")

    # --- Backward API pattern ---
    print("\n--- Backward GPU API Pattern ---")
    print(" The GPU backward for masked attention would use:")
    print(" FmhaKernelConfig(family='bwd', mask='top_left', ...)")
    print(" 3-stage backward plan:")
    print(" Stage 1: bwd_dot_do_o -- D = rowsum(dO * out)")
    print(" Stage 2: bwd_dq_dk_dv -- compute dQ, dK, dV with mask")
    print(" Stage 3: bwd_convert_dq -- optional dtype conversion")

    # --- Summary ---
    print("\n" + "=" * 70)
    print(" Mask variants: no_mask, top_left, bottom_right")
    print(" Backward math: dP = dO @ V^T, dS = P*(dP - D)")
    print(" dQ = scale*dS@K, dK = scale*dS^T@Q, dV = P^T@dO")
    print(" Causal effect: Masked positions get P=0, zeroing their gradient flow")
    print(" GPU: Requires bwd-family JIT kernel with mask support")
    print(" Status: DEMO")
    print("=" * 70)
    return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,277 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 34: Backward Pass with GQA (Grouped-Query Attention)
Demonstrates the FMHA backward pass when nhead_q != nhead_k.
GQA groups multiple query heads per KV head. The backward pass
must account for this by:
- Expanding K/V heads via np.repeat for dQ computation
- Summing dK/dV over query head groups back to KV head count
Tested GQA ratios: 1:1 (MHA), 2:1, 4:1, 8:1
CPU backward reference:
K_exp = repeat(K, ratio) # [B, Hq, Sk, D]
V_exp = repeat(V, ratio) # [B, Hq, Sk, Dv]
dQ = scale * (P * (dO@V_exp^T - D)) @ K_exp
dK_exp = scale * (P * (dO@V_exp^T - D))^T @ Q
dV_exp = P^T @ dO
dK = sum_over_groups(dK_exp) # [B, Hk, Sk, D]
dV = sum_over_groups(dV_exp) # [B, Hk, Sk, Dv]
Usage:
python3 34_bwd_gqa_fmha.py
python3 34_bwd_gqa_fmha.py --nhead-q 32
python3 34_bwd_gqa_fmha.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
setup_fmha_dispatcher,
detect_gpu_arch,
)
def cpu_fwd_with_intermediates(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
) -> tuple:
    """Reference softmax attention that also returns P and the log-sum-exp.

    When Q has a different head count (axis 1) than K, K and V are repeated
    along the head axis nhead_q // nhead_k times so consecutive query heads
    share one KV head.

    Returns: (out, P, lse) where P is the softmax over the last axis and lse
    is float32 with the key axis removed.
    """
    heads_q = Q.shape[1]
    heads_k = K.shape[1]
    if heads_q != heads_k:
        group = heads_q // heads_k
        K, V = np.repeat(K, group, axis=1), np.repeat(V, group, axis=1)
    logits = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    # Subtract the row maximum before exponentiating for numerical stability.
    row_max = logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    P = weights / denom
    out = np.matmul(P, V)
    lse = (np.log(denom.squeeze(-1)) + row_max.squeeze(-1)).astype(np.float32)
    return out, P, lse
def cpu_bwd_gqa(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
    nhead_q: int,
    nhead_k: int,
) -> tuple:
    """Reference attention backward with grouped KV heads.

    K and V are given with nhead_k heads and are repeated nhead_q // nhead_k
    times along the head axis to line up with Q, dO, out and P (all on
    nhead_q heads). Gradients for K and V are first computed per query head
    and then summed over each group of query heads that shares a KV head.

    Returns: (dQ, dK, dV) where dQ is on nhead_q heads and dK/dV are reduced
    back to nhead_k heads.
    """
    group = nhead_q // nhead_k

    def swap_last_two(x):
        # 4-D transpose of the two innermost axes.
        return x.transpose(0, 1, 3, 2)

    K_wide = np.repeat(K, group, axis=1)
    V_wide = np.repeat(V, group, axis=1)
    row_dot = (dO * out).sum(axis=-1, keepdims=True)
    grad_S = P * (np.matmul(dO, swap_last_two(V_wide)) - row_dot)
    dQ = np.matmul(grad_S, K_wide) * scale
    dK_wide = np.matmul(swap_last_two(grad_S), Q) * scale
    dV_wide = np.matmul(swap_last_two(P), dO)

    batch = Q.shape[0]
    _, _, seq_k, dim_qk = K.shape
    dim_v = V.shape[3]
    # np.repeat lays the copies of one KV head out contiguously, so splitting
    # the head axis into (nhead_k, group) and summing the group axis folds
    # each group back onto its KV head.
    dK = dK_wide.reshape(batch, nhead_k, group, seq_k, dim_qk).sum(axis=2)
    dV = dV_wide.reshape(batch, nhead_k, group, seq_k, dim_v).sum(axis=2)
    return dQ, dK, dV
def main():
    """Demonstrate the CPU-reference attention backward for several GQA ratios.

    For every ratio in {1, 2, 4, 8} that divides --nhead-q this computes the
    forward and backward reference, prints gradient statistics, re-derives
    the gradients from explicitly expanded K/V as a consistency check, and
    prints dK/dV L2 norms. A forward GPU kernel is JIT-built only to report
    whether the toolchain works; no GPU backward is executed.

    Returns:
        Always 0 (the example is a demonstration, not a pass/fail test).
    """
    cli = argparse.ArgumentParser(description="Backward Pass with GQA")
    cli.add_argument("--arch", default=detect_gpu_arch())
    cli.add_argument("--batch", type=int, default=2)
    cli.add_argument("--nhead-q", type=int, default=16)
    cli.add_argument("--seqlen", type=int, default=64)
    cli.add_argument("--hdim", type=int, default=128)
    args = cli.parse_args()

    for banner_line in ("=" * 70, "Example 34: Backward Pass with GQA", "=" * 70):
        print(banner_line)

    hq = args.nhead_q
    # Keep only ratios that split the query heads into whole KV heads.
    gqa_ratios = [
        candidate
        for candidate in (1, 2, 4, 8)
        if hq % candidate == 0 and hq // candidate >= 1
    ]
    print(f"\n nhead_q: {hq}")
    print(f" Ratios: {', '.join(f'{r}:1' for r in gqa_ratios)}")
    print(f" Problem: B={args.batch} S={args.seqlen} D={args.hdim}")

    # --- JIT compile a basic fp16 fwd kernel ---
    print("\n--- JIT Compilation ---")
    setup = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=args.hdim,
            hdim_v=args.hdim,
            gfx_arch=args.arch,
        )
    )
    if not setup.success:
        print(f" JIT build: {setup.error}")
        print(" Continuing with CPU reference only")
    else:
        print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s")
        print(" Note: Backward GQA requires bwd-family kernel (separate JIT)")

    # --- Sweep GQA ratios ---
    print("\n--- Backward Gradients per GQA Ratio ---")
    print(
        f"\n {'#':<3} {'Ratio':<8} {'Hq':>4} {'Hk':>4} "
        f"| {'|dQ| mean':>10} {'|dK| mean':>10} {'|dV| mean':>10} "
        f"| {'dK shape':>18} {'dV shape':>18}"
    )
    print(" " + "-" * 104)

    def draw_f32(shape):
        # Small-magnitude random tensor in single precision.
        return (np.random.randn(*shape) * 0.1).astype(np.float32)

    # ratio -> (dQ, dK, dV, Q, K, V, out, dO, P, prob)
    all_results = {}
    for i, ratio in enumerate(gqa_ratios, 1):
        hk = hq // ratio
        problem = FmhaProblem(
            batch=args.batch,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        # A distinct seed per ratio gives each configuration its own inputs.
        np.random.seed(42 + i)
        Q = draw_f32(problem.q_shape())
        K = draw_f32(problem.k_shape())
        V = draw_f32(problem.v_shape())
        dO = draw_f32(problem.o_shape())
        out, P, lse = cpu_fwd_with_intermediates(Q, K, V, problem.scale)
        dQ, dK, dV = cpu_bwd_gqa(Q, K, V, out, dO, P, problem.scale, hq, hk)
        dq_mean = float(np.abs(dQ).mean())
        dk_mean = float(np.abs(dK).mean())
        dv_mean = float(np.abs(dV).mean())
        label = f"{ratio}:1"
        # Tag the two special cases: equal heads (MHA) and a single KV head (MQA).
        label += " MHA" if ratio == 1 else (" MQA" if hk == 1 else "")
        print(
            f" {i:<3} {label:<8} {hq:>4} {hk:>4} "
            f"| {dq_mean:>10.4e} {dk_mean:>10.4e} {dv_mean:>10.4e} "
            f"| {str(dK.shape):>18} {str(dV.shape):>18}"
        )
        all_results[ratio] = (dQ, dK, dV, Q, K, V, out, dO, P, problem)

    # --- Verify GQA backward via expanded MHA ---
    print("\n--- GQA Backward Equivalence Check ---")
    print(" Verifying: GQA bwd == MHA bwd with expanded K/V, then summed")
    for ratio in [value for value in gqa_ratios if value != 1]:
        dQ_gqa, dK_gqa, dV_gqa, Q, K, V, out, dO, P, problem = all_results[ratio]
        hk = hq // ratio
        K_wide = np.repeat(K, ratio, axis=1)
        V_wide = np.repeat(V, ratio, axis=1)
        O_mha, P_mha, _ = cpu_fwd_with_intermediates(Q, K_wide, V_wide, problem.scale)
        # Passing hq for both head counts runs the backward as plain MHA.
        dQ_mha, dK_mha, dV_mha = cpu_bwd_gqa(
            Q, K_wide, V_wide, O_mha, dO, P_mha, problem.scale, hq, hq
        )
        batch = Q.shape[0]
        seq_k = K.shape[2]
        dK_mha_grouped = dK_mha.reshape(batch, hk, ratio, seq_k, K.shape[3]).sum(axis=2)
        dV_mha_grouped = dV_mha.reshape(batch, hk, ratio, seq_k, V.shape[3]).sum(axis=2)
        dq_err = float(np.abs(dQ_gqa - dQ_mha).max())
        dk_err = float(np.abs(dK_gqa - dK_mha_grouped).max())
        dv_err = float(np.abs(dV_gqa - dV_mha_grouped).max())
        tag = "PASS" if max(dq_err, dk_err, dv_err) < 1e-5 else "FAIL"
        print(
            f" Ratio {ratio}:1 -- dQ err={dq_err:.2e} dK err={dk_err:.2e} "
            f"dV err={dv_err:.2e} {tag}"
        )

    # --- Gradient accumulation analysis ---
    print("\n--- Head-Group Gradient Accumulation ---")
    print(" When ratio > 1, dK/dV are summed over query heads in each group.")
    print(" Higher ratio -> more terms summed -> larger gradient magnitudes.\n")
    print(f" {'Ratio':<8} {'||dK||_2':>12} {'||dV||_2':>12} {'dK/dV ratio':>12}")
    print(" " + "-" * 48)
    for ratio in gqa_ratios:
        _, dK, dV, *_ = all_results[ratio]
        l2_dk = float(np.sqrt((dK**2).sum()))
        l2_dv = float(np.sqrt((dV**2).sum()))
        # The small epsilon guards against a zero dV norm.
        dk_dv_ratio = l2_dk / (l2_dv + 1e-12)
        print(f" {ratio}:1{'':<4} {l2_dk:>12.4e} {l2_dv:>12.4e} {dk_dv_ratio:>12.2f}")

    # --- Backward GPU API pattern ---
    print("\n--- Backward GPU API Pattern ---")
    print(" GPU backward with GQA dispatches with nhead_q != nhead_k.")
    print(" The dq_dk_dv kernel handles head grouping internally:")
    print(" - dQ: computed per query head (no grouping needed)")
    print(" - dK, dV: accumulated across head groups via atomicAdd")
    print(" or multi-buffer reduction (deterministic mode)")

    # --- Summary ---
    print("\n" + "=" * 70)
    print(f" GQA ratios tested: {len(gqa_ratios)}")
    print(" Backward math: expand K/V -> compute grads -> sum dK/dV")
    print(" Equivalence: GQA bwd == MHA(expanded) bwd + group sum")
    print(" GPU: Requires bwd-family JIT kernel")
    print(" Status: DEMO")
    print("=" * 70)
    return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,270 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 35: Backward Pass with BF16 Data Type
Demonstrates the FMHA backward pass with bfloat16 precision.
BF16 differences from FP16:
- 8-bit exponent (same as fp32) vs fp16's 5-bit
- 7-bit mantissa vs fp16's 10-bit
- Larger dynamic range but lower precision
Tolerance guidance for backward:
- fp16 bwd: rtol=1.6e-2 typically sufficient
- bf16 bwd: rtol=3.2e-2 for hdim > 128 (less mantissa precision)
- bf16 bwd: rtol=2.0e-2 for hdim <= 128
CPU backward reference is computed in float32, then compared against
bf16-quantized inputs to measure the precision impact.
Usage:
python3 35_bwd_bf16_fmha.py
python3 35_bwd_bf16_fmha.py --hdim 256
python3 35_bwd_bf16_fmha.py --arch gfx942
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from fmha_utils import (
FmhaKernelConfig,
FmhaProblem,
setup_fmha_dispatcher,
detect_gpu_arch,
cpu_attention_bwd,
)
def to_bf16(arr: np.ndarray) -> np.ndarray:
    """Convert float32 -> bfloat16 (stored as uint16 with bf16 bit pattern).

    Uses round-to-nearest-even on the 16 discarded mantissa bits instead of
    plain truncation, which would bias every value toward zero. NaN inputs
    stay NaN: their upper 16 bits are kept and the top mantissa bit is set,
    so a NaN whose payload lives only in the discarded bits cannot turn into
    infinity.

    Args:
        arr: array convertible with astype(np.float32).

    Returns:
        uint16 array of the same shape holding the bf16 bit patterns.
    """
    f32 = arr.astype(np.float32)
    u32 = f32.view(np.uint32)
    upper = u32 >> 16
    # Ties go to the neighbour whose lowest kept bit is 0: add 0x7FFF plus
    # that bit. Widen to uint64 so the addition cannot wrap around.
    rounded = (u32.astype(np.uint64) + 0x7FFF + (upper & 1)) >> 16
    quiet_nan = upper | 0x0040
    return np.where(np.isnan(f32), quiet_nan, rounded).astype(np.uint16)
def bf16_to_f32(arr_u16: np.ndarray) -> np.ndarray:
    """Expand bfloat16 bit patterns (held in uint16) to float32 values.

    Each 16-bit pattern becomes the upper half of a 32-bit word whose lower
    half is zero; the word is then reinterpreted as float32.
    """
    widened = np.left_shift(arr_u16.astype(np.uint32), 16)
    return widened.view(np.float32)
def cpu_fwd_with_intermediates(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
) -> tuple:
    """Reference softmax attention that also returns P and the log-sum-exp.

    Returns: (out, P, lse) where P is the softmax of scale * Q @ K^T over the
    last axis, out = P @ V, and lse is float32 with the key axis removed.
    """
    logits = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    # Subtract the row maximum before exponentiating for numerical stability.
    row_max = logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    P = weights / denom
    out = np.matmul(P, V)
    lse = (np.log(denom.squeeze(-1)) + row_max.squeeze(-1)).astype(np.float32)
    return out, P, lse
def get_bwd_tolerance(dtype: str, hdim: int) -> tuple:
    """Return the recommended pair of tolerances for backward validation.

    bf16 uses 3.2e-2 for hdim > 128 and 2.0e-2 otherwise; every other dtype
    uses 1.6e-2. Both elements of the returned pair are the same value.
    """
    if dtype != "bf16":
        return 1.6e-2, 1.6e-2
    tol = 3.2e-2 if hdim > 128 else 2.0e-2
    return tol, tol
# Tensor names and dtype labels, in the order the report tables list them.
_TENSOR_NAMES = ("Q", "K", "V", "dO")
_DTYPE_NAMES = ("fp16", "bf16")


def _parse_args():
    """Parse the example's command-line options (arch and problem sizes)."""
    parser = argparse.ArgumentParser(description="Backward Pass with BF16")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=64)
    parser.add_argument("--hdim", type=int, default=128)
    return parser.parse_args()


def _report_jit_build(args):
    """Try to JIT-build an fp16 kernel at the requested hdim and print the outcome."""
    print("\n--- JIT Compilation ---")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s")
        print(
            " Note: Native bf16 bwd kernel requires separate JIT with data_type='bf16'"
        )
    else:
        # A failed build is not fatal: everything below runs on the CPU.
        print(f" JIT build: {setup.error}")
        print(" Continuing with CPU reference only")


def _make_inputs(prob):
    """Draw seeded float32 Q/K/V/dO and derive their fp16- and bf16-rounded copies.

    Returns a dict keyed by "f32", "fp16" and "bf16"; each value maps a
    tensor name to a float32 array.
    """
    np.random.seed(42)
    shape_getters = (
        ("Q", prob.q_shape),
        ("K", prob.k_shape),
        ("V", prob.v_shape),
        ("dO", prob.o_shape),
    )
    f32 = {}
    for name, shape_of in shape_getters:
        f32[name] = (np.random.randn(*shape_of()) * 0.1).astype(np.float32)
    fp16 = {n: t.astype(np.float16).astype(np.float32) for n, t in f32.items()}
    bf16 = {n: bf16_to_f32(to_bf16(t)) for n, t in f32.items()}
    return {"f32": f32, "fp16": fp16, "bf16": bf16}


def _print_quantization_errors(inputs):
    """Print the max abs rounding error of each tensor under fp16 and bf16."""
    print("\n--- Quantization Error ---")
    print(
        f"\n {'Tensor':<6} {'FP16 quant err':>16} {'BF16 quant err':>16} {'BF16/FP16':>10}"
    )
    print(" " + "-" * 52)
    for name in _TENSOR_NAMES:
        orig = inputs["f32"][name]
        fp16_err = float(np.abs(orig - inputs["fp16"][name]).max())
        bf16_err = float(np.abs(orig - inputs["bf16"][name]).max())
        # Tiny epsilon guards the ratio if the fp16 error is exactly zero.
        ratio = bf16_err / (fp16_err + 1e-15)
        print(f" {name:<6} {fp16_err:>16.2e} {bf16_err:>16.2e} {ratio:>10.1f}x")


def _backward(tensors, scale):
    """Run the CPU forward then backward on one tensor set; return (dQ, dK, dV)."""
    out, P, _ = cpu_fwd_with_intermediates(
        tensors["Q"], tensors["K"], tensors["V"], scale
    )
    dQ, dK, dV = cpu_attention_bwd(
        tensors["Q"],
        tensors["K"],
        tensors["V"],
        out,
        tensors["dO"],
        P,
        scale,
    )
    return dQ, dK, dV


def _print_gradient_magnitudes(grad_results):
    """Print the mean absolute value of dQ/dK/dV for each input dtype."""
    print(f"\n {'Dtype':<6} {'|dQ| mean':>12} {'|dK| mean':>12} {'|dV| mean':>12}")
    print(" " + "-" * 48)
    for dtype_name in _DTYPE_NAMES:
        dQ, dK, dV = grad_results[dtype_name]
        print(
            f" {dtype_name:<6} {np.abs(dQ).mean():>12.4e} "
            f"{np.abs(dK).mean():>12.4e} {np.abs(dV).mean():>12.4e}"
        )


def _print_cross_dtype_difference(grad_results):
    """Print abs/rel differences between gradients from fp16 and bf16 inputs."""
    print("\n--- FP16 vs BF16 Backward Difference ---")
    print(
        f"\n {'Grad':<6} {'Max abs diff':>14} {'Mean abs diff':>14} {'Max rel diff':>14}"
    )
    print(" " + "-" * 52)
    pairs = zip(("dQ", "dK", "dV"), grad_results["fp16"], grad_results["bf16"])
    for name, g_fp, g_bf in pairs:
        abs_diff = np.abs(g_fp - g_bf)
        max_abs = float(abs_diff.max())
        mean_abs = float(abs_diff.mean())
        max_rel = float((abs_diff / (np.abs(g_fp) + 1e-8)).max())
        print(f" {name:<6} {max_abs:>14.4e} {mean_abs:>14.4e} {max_rel:>14.4e}")


def _print_tolerance_table():
    """Print the recommended (rtol, atol) for each dtype at hdim 64/128/256."""
    print("\n--- Recommended Backward Tolerances ---")
    print(f"\n {'Dtype':<6} {'hdim':>6} {'rtol':>10} {'atol':>10} {'Note'}")
    print(" " + "-" * 54)
    for dtype in _DTYPE_NAMES:
        for hdim in [64, 128, 256]:
            rtol, atol = get_bwd_tolerance(dtype, hdim)
            note = ""
            if dtype == "bf16" and hdim > 128:
                note = "<-- relaxed for large hdim"
            print(f" {dtype:<6} {hdim:>6} {rtol:>10.1e} {atol:>10.1e} {note}")


def _validate_against_f32(inputs, grad_results, scale, hdim):
    """Compare each dtype's gradients with the float32 reference; print PASS/FAIL."""
    print("\n--- Validation Against F32 Reference ---")
    reference = _backward(inputs["f32"], scale)
    for dtype_name in _DTYPE_NAMES:
        rtol, atol = get_bwd_tolerance(dtype_name, hdim)
        print(f"\n [{dtype_name}] rtol={rtol:.1e}, atol={atol:.1e}")
        triples = zip(("dQ", "dK", "dV"), grad_results[dtype_name], reference)
        for gname, g, g_ref in triples:
            max_err = float(np.abs(g - g_ref).max())
            ok = bool(np.allclose(g, g_ref, rtol=rtol, atol=atol))
            print(f" {gname}: max_err={max_err:.4e} {'PASS' if ok else 'FAIL'}")


def _print_gpu_api_pattern():
    """Print the notes describing how a native bf16 backward kernel is requested."""
    print("\n--- BF16 Backward GPU API Pattern ---")
    print(" Native bf16 backward kernel:")
    print(" FmhaKernelConfig(family='bwd', data_type='bf16', ...)")
    print(" Internal accumulation stays in fp32 for numerical stability.")
    print(" Stage 3 (convert_dq) converts fp32 accumulator -> bf16 output.")
    print(" BF16 advantage: wider dynamic range prevents overflow in")
    print(" intermediate products (S = Q @ K^T) for large sequences.")


def _print_summary(hdim):
    """Print the closing summary, including the bf16 rtol for the chosen hdim."""
    print("\n" + "=" * 70)
    print(" Data types: fp16 (10-bit mantissa) vs bf16 (7-bit mantissa)")
    print(" Tolerances: bf16 bwd needs ~2x relaxed rtol vs fp16")
    rtol_used, _ = get_bwd_tolerance("bf16", hdim)
    print(f" Current hdim: {hdim} -> bf16 rtol={rtol_used:.1e}")
    print(" GPU: Requires bwd-family JIT kernel with data_type='bf16'")
    print(" Status: DEMO")
    print("=" * 70)


def main():
    """Run the bf16-vs-fp16 backward-pass demo and print its report.

    Steps, in output order: parse options, attempt an fp16 JIT build, generate
    seeded inputs, report quantization error, compute CPU gradients from fp16-
    and bf16-rounded inputs, compare them with each other and with a float32
    reference, then print API notes and a summary.

    Returns:
        0 unconditionally. The PASS/FAIL lines are informational and do not
        affect the exit status.
    """
    args = _parse_args()

    print("=" * 70)
    print("Example 35: Backward Pass with BF16")
    print("=" * 70)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}")
    print(f" Scale: {prob.scale:.6f}")
    print(f" Arch: {args.arch}")

    _report_jit_build(args)

    inputs = _make_inputs(prob)
    _print_quantization_errors(inputs)

    print("\n--- Backward Gradients: FP16 vs BF16 Inputs ---")
    grad_results = {
        dtype_name: _backward(inputs[dtype_name], prob.scale)
        for dtype_name in _DTYPE_NAMES
    }
    _print_gradient_magnitudes(grad_results)
    _print_cross_dtype_difference(grad_results)

    _print_tolerance_table()
    _validate_against_f32(inputs, grad_results, prob.scale, args.hdim)

    _print_gpu_api_pattern()
    _print_summary(args.hdim)
    return 0


if __name__ == "__main__":
    sys.exit(main())

Some files were not shown because too many files have changed in this diff Show More