mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Merge remote-tracking branch 'origin/develop' into ginolu/sparge_attention
This commit is contained in:
181
CMakeLists.txt
181
CMakeLists.txt
@@ -52,6 +52,10 @@ option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
|
||||
option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
|
||||
option(BUILD_CK_TILE_ENGINE "Build the tile_engine subdirectory" OFF)
|
||||
option(BUILD_CK_EXAMPLES "Build the example subdirectory" ON)
|
||||
option(BUILD_CK_TUTORIALS "Build the tutorial subdirectory" ON)
|
||||
option(CK_ENABLE_ROCM_CK "Build rocm_ck API" OFF)
|
||||
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
add_definitions(-DCK_EXPERIMENTAL_BUILDER)
|
||||
@@ -207,6 +211,21 @@ else()
|
||||
set(USER_GPU_TARGETS 0)
|
||||
endif()
|
||||
|
||||
#Unsupported GPU targets to be filtered from the list:
|
||||
set(CK_UNSUPPORTED_GPU_TARGETS "gfx900;gfx906;gfx90c")
|
||||
|
||||
#If only one of the unsupported targets is requested, generate dummy target and exit here.
|
||||
if("${GPU_TARGETS}" IN_LIST CK_UNSUPPORTED_GPU_TARGETS)
|
||||
add_custom_target(ck_dummy_target)
|
||||
message("CK is not supported for target ${GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
#If multiple targets are requested, filter out any targets currently on the unsupported list:
|
||||
message(STATUS "Filtering out unsupported targets: ${CK_UNSUPPORTED_GPU_TARGETS}")
|
||||
list(REMOVE_ITEM GPU_TARGETS ${CK_UNSUPPORTED_GPU_TARGETS})
|
||||
list(REMOVE_ITEM GPU_ARCHS ${CK_UNSUPPORTED_GPU_TARGETS})
|
||||
|
||||
find_package(hip REQUIRED)
|
||||
enable_language(HIP)
|
||||
|
||||
@@ -229,8 +248,10 @@ if(NOT ENABLE_ASAN_PACKAGING)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1200;gfx1201")
|
||||
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000 AND ${hip_VERSION_FLAT} LESS 600443483)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1200;gfx1201;gfx950")
|
||||
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600443483)
|
||||
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600443483 AND ${hip_VERSION_FLAT} LESS 700200000)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx950;gfx10-3-generic;gfx11-generic;gfx12-generic")
|
||||
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 700200000)
|
||||
set(CK_GPU_TARGETS "")
|
||||
endif()
|
||||
else()
|
||||
#build CK only for xnack-supported targets when using ASAN
|
||||
@@ -668,59 +689,64 @@ if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
# Optimization: Search only in library/src where all instance files actually live
|
||||
# (was searching entire source tree, taking ~40s instead of <1s)
|
||||
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/library/src/*/device_*_instance.cpp")
|
||||
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
|
||||
set(CK_DEVICE_INSTANCES)
|
||||
FOREACH(subdir_path ${dir_list})
|
||||
set(target_dir)
|
||||
IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
|
||||
set(cmake_instance)
|
||||
file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
|
||||
set(add_inst 0)
|
||||
if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(NOT "${cmake_instance}" MATCHES "DTYPES")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
|
||||
list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
|
||||
endif()
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
|
||||
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF)
|
||||
option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
option(BUILD_CK_DEVICE_INSTANCES "Build device operation instances in library/" ON)
|
||||
option(BUILD_CK_PROFILER "Build the CK profiler in profiler/" ON)
|
||||
option(BUILD_CK_TILE_ENGINE_TESTS "Build tile engine tests" ON)
|
||||
option(BUILD_CK_TILE_FMHA_TESTS "Build FMHA tests" ON)
|
||||
option(BUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS "Build CShuffleLds microbenchmarks (requires BUILD_CK_EXAMPLES=ON)" OFF)
|
||||
|
||||
add_subdirectory(library)
|
||||
if(BUILD_CK_DEVICE_INSTANCES)
|
||||
# Optimization: Search only in library/src where all instance files actually live
|
||||
# (was searching entire source tree, taking ~40s instead of <1s)
|
||||
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/library/src/*/device_*_instance.cpp")
|
||||
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
|
||||
set(CK_DEVICE_INSTANCES)
|
||||
FOREACH(subdir_path ${dir_list})
|
||||
set(target_dir)
|
||||
IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
|
||||
set(cmake_instance)
|
||||
file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
|
||||
set(add_inst 0)
|
||||
if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(NOT "${cmake_instance}" MATCHES "DTYPES")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
|
||||
list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
|
||||
endif()
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
add_subdirectory(library)
|
||||
endif()
|
||||
|
||||
if (CK_EXPERIMENTAL_BUILDER)
|
||||
add_subdirectory(experimental/builder)
|
||||
@@ -728,34 +754,47 @@ if (CK_EXPERIMENTAL_BUILDER)
|
||||
endif()
|
||||
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
)
|
||||
if(BUILD_CK_EXAMPLES)
|
||||
rocm_package_setup_component(examples
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME examples
|
||||
)
|
||||
add_subdirectory(example)
|
||||
endif()
|
||||
|
||||
rocm_package_setup_component(examples
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME examples
|
||||
)
|
||||
add_subdirectory(example)
|
||||
|
||||
add_subdirectory(tutorial)
|
||||
rocm_package_setup_component(tutorials
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tutorials
|
||||
)
|
||||
add_subdirectory(tile_engine)
|
||||
if(BUILD_CK_TUTORIALS)
|
||||
add_subdirectory(tutorial)
|
||||
rocm_package_setup_component(tutorials
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tutorials
|
||||
)
|
||||
endif()
|
||||
if(BUILD_CK_TILE_ENGINE)
|
||||
add_subdirectory(tile_engine)
|
||||
endif()
|
||||
if(CK_ENABLE_ROCM_CK)
|
||||
add_subdirectory(rocm_ck)
|
||||
if(TARGET check)
|
||||
add_dependencies(check build-smoke-rocm-ck)
|
||||
endif()
|
||||
endif()
|
||||
if(BUILD_TESTING)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
if(BUILD_CK_PROFILER)
|
||||
if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
|
||||
|
||||
@@ -51,6 +51,22 @@
|
||||
"GPU_TARGETS": "gfx908;gfx90a;gfx942"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "dev-minimal",
|
||||
"binaryDir": "${sourceDir}/build",
|
||||
"displayName": "CK Dev - Minimal Build",
|
||||
"description": "Fast iteration build with minimal components (configure ~5s vs ~150s)",
|
||||
"inherits": ["dev"],
|
||||
"cacheVariables": {
|
||||
"BUILD_CK_DEVICE_INSTANCES": "OFF",
|
||||
"BUILD_CK_PROFILER": "OFF",
|
||||
"BUILD_CK_EXAMPLES": "OFF",
|
||||
"BUILD_CK_TUTORIALS": "OFF",
|
||||
"BUILD_CK_TILE_ENGINE": "OFF",
|
||||
"BUILD_CK_TILE_ENGINE_TESTS": "OFF",
|
||||
"BUILD_CK_TILE_FMHA_TESTS": "OFF"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "dev-gfx908",
|
||||
"displayName": "CK Dev - gfx908",
|
||||
|
||||
30
Dockerfile
30
Dockerfile
@@ -3,7 +3,19 @@ FROM ubuntu:24.04
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG ROCMVERSION=7.1.1
|
||||
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
|
||||
ARG TARBALL_URL=https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx90X-dcgpu-7.12.0a20260218.tar.gz
|
||||
|
||||
# TheRock nightly tarball configuration.
|
||||
# By default, discovers the latest tarball from the nightlies index.
|
||||
# Manual overrides:
|
||||
# Pin a specific tarball:
|
||||
# --build-arg TARBALL_URL=https://rocm.nightlies.amd.com/tarball-multi-arch/therock-dist-linux-multiarch-7.13.0a20260430.tar.gz
|
||||
# Change the arch variant (default: multiarch):
|
||||
# --build-arg TARBALL_PATTERN=therock-dist-linux-gfx90a
|
||||
# --build-arg TARBALL_PATTERN=therock-dist-linux-gfx94X-dcgpu
|
||||
ARG TARBALL_URL=""
|
||||
ARG TARBALL_BASE=https://rocm.nightlies.amd.com/tarball-multi-arch
|
||||
ARG TARBALL_PATTERN=therock-dist-linux-multiarch
|
||||
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
|
||||
@@ -18,10 +30,18 @@ RUN set -xe && \
|
||||
|
||||
RUN if [ "$compiler_version" = "therock" ]; then \
|
||||
rm -rf /opt/rocm && mkdir /opt/rocm && \
|
||||
echo "Downloading ROCm tarball from $TARBALL_URL..." && \
|
||||
if [ -n "$TARBALL_URL" ]; then \
|
||||
echo "Using provided TARBALL_URL: $TARBALL_URL" ; \
|
||||
else \
|
||||
echo "Discovering latest tarball from $TARBALL_BASE..." && \
|
||||
TARBALL_URL="${TARBALL_BASE}/$(curl -sL "${TARBALL_BASE}/" \
|
||||
| grep -oP '"name":\s*"\K'"${TARBALL_PATTERN}"'-[^"]+\.tar\.gz' \
|
||||
| sort -V | tail -1)" && \
|
||||
echo "Found: $TARBALL_URL" ; \
|
||||
fi && \
|
||||
wget -q -O /tmp/rocm.tar.gz "$TARBALL_URL" && \
|
||||
echo "Extracting tarball to /opt/rocm..." && \
|
||||
tar -xzf /tmp/rocm.tar.gz -C /opt/rocm --strip-components=1 ; \
|
||||
tar -xzf /tmp/rocm.tar.gz -C /opt/rocm --strip-components=1 && \
|
||||
rm /tmp/rocm.tar.gz ; \
|
||||
else echo "using the release compiler" && \
|
||||
wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
|
||||
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
|
||||
@@ -36,7 +56,7 @@ ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
|
||||
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
|
||||
RUN set -x && \
|
||||
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
|
||||
wget -qO sccache.tar.gz https://github.com/mozilla/sccache/releases/latest/download/sccache-v$SCCACHE_VERSION-x86_64-unknown-linux-musl.tar.gz && \
|
||||
wget -qO sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v$SCCACHE_VERSION/sccache-v$SCCACHE_VERSION-x86_64-unknown-linux-musl.tar.gz && \
|
||||
tar -xzf sccache.tar.gz --strip-components=1 -C ${SCCACHE_INSTALL_LOCATION} && \
|
||||
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache
|
||||
|
||||
|
||||
@@ -10,30 +10,36 @@ RUN pip install pandas zmq einops ninja tabulate vcs_versioning && \
|
||||
sudo mkdir /home/jenkins/workspace && \
|
||||
cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \
|
||||
if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \
|
||||
git clone --depth 1 -b "$CK_AITER_BRANCH" --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
|
||||
cd rocm-libraries && \
|
||||
mkdir rocm-libraries && cd rocm-libraries && \
|
||||
git init -q && \
|
||||
git remote add origin https://github.com/ROCm/rocm-libraries.git && \
|
||||
git fetch --depth 1 --filter=blob:none origin "$CK_AITER_BRANCH" && \
|
||||
git sparse-checkout init --cone && \
|
||||
git sparse-checkout set projects/composablekernel && \
|
||||
git checkout "$CK_AITER_BRANCH" && \
|
||||
git checkout FETCH_HEAD && \
|
||||
ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \
|
||||
LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \
|
||||
mv projects/composablekernel ../ck && \
|
||||
cd ../ck && rm -rf ../rocm-libraries && \
|
||||
git init && \
|
||||
git init -b "$LOCAL_BRANCH" && \
|
||||
git config user.name "assistant-librarian[bot]" && \
|
||||
git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \
|
||||
git branch -m "$CK_AITER_BRANCH" && git add -A && \
|
||||
git add -A && \
|
||||
git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" ; \
|
||||
else \
|
||||
git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck ; \
|
||||
git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck && \
|
||||
LOCAL_BRANCH="$CK_AITER_BRANCH" ; \
|
||||
fi && \
|
||||
cd /home/jenkins/workspace && rm -rf aiter && \
|
||||
git clone --depth 1 -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \
|
||||
cd aiter && \
|
||||
rm -rf 3rdparty/composable_kernel/ && \
|
||||
git clone -b "$CK_AITER_BRANCH" ../ck 3rdparty/composable_kernel/ && \
|
||||
git clone -b "$LOCAL_BRANCH" ../ck 3rdparty/composable_kernel/ && \
|
||||
python3 setup.py develop && \
|
||||
groupadd -g 1001 jenkins && \
|
||||
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \
|
||||
groupadd -f video && \
|
||||
groupadd -f render && \
|
||||
chown -R jenkins:jenkins /home/jenkins && \
|
||||
chmod -R a+rwx /home/jenkins && \
|
||||
chown -R jenkins:jenkins /tmp && \
|
||||
|
||||
@@ -12,30 +12,36 @@ RUN set -x ; \
|
||||
sudo mkdir /home/jenkins/workspace && \
|
||||
cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \
|
||||
if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \
|
||||
git clone --depth 1 -b "$CK_FA_BRANCH" --no-checkout --filter=blob:none https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \
|
||||
cd rocm-libraries && \
|
||||
mkdir rocm-libraries && cd rocm-libraries && \
|
||||
git init -q && \
|
||||
git remote add origin https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \
|
||||
git fetch --depth 1 --filter=blob:none origin "$CK_FA_BRANCH" && \
|
||||
git sparse-checkout init --cone && \
|
||||
git sparse-checkout set projects/composablekernel && \
|
||||
git checkout "$CK_FA_BRANCH" && \
|
||||
git checkout FETCH_HEAD && \
|
||||
ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \
|
||||
LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \
|
||||
mv projects/composablekernel ../ck && \
|
||||
cd ../ck && rm -rf ../rocm-libraries && \
|
||||
git init && \
|
||||
git init -b "$LOCAL_BRANCH" && \
|
||||
git config user.name "assistant-librarian[bot]" && \
|
||||
git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \
|
||||
git branch -m "$CK_FA_BRANCH" && git add -A && \
|
||||
git add -A && \
|
||||
git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" > /dev/null ; \
|
||||
else \
|
||||
git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck ; \
|
||||
git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck && \
|
||||
LOCAL_BRANCH="$CK_FA_BRANCH" ; \
|
||||
fi && \
|
||||
cd /home/jenkins/workspace && rm -rf flash-attention && \
|
||||
git clone --depth 1 -b "$FA_BRANCH" --recursive "https://github.com/$FA_ORIGIN/flash-attention.git" && \
|
||||
cd flash-attention && \
|
||||
rm -rf csrc/composable_kernel/ && \
|
||||
git clone -b "$CK_FA_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \
|
||||
git clone -b "$LOCAL_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \
|
||||
MAX_JOBS=$(nproc) GPU_ARCHS="$GPU_ARCHS" /opt/venv/bin/python3 -u -m pip install --no-build-isolation -v . && \
|
||||
groupadd -g 1001 jenkins && \
|
||||
useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \
|
||||
groupadd -f video && \
|
||||
groupadd -f render && \
|
||||
chown -R jenkins:jenkins /home/jenkins && \
|
||||
chmod -R a+rwx /home/jenkins && \
|
||||
chown -R jenkins:jenkins /tmp && \
|
||||
|
||||
@@ -4,6 +4,7 @@ ARG CK_PYTORCH_BRANCH="develop"
|
||||
RUN groupadd -g 109 render && \
|
||||
usermod -u 1001 jenkins && \
|
||||
groupmod -g 1001 jenkins && \
|
||||
pip install --upgrade pandas && \
|
||||
cd /tmp/pytorch && \
|
||||
rm -rf build && \
|
||||
cd /tmp/pytorch/third_party && \
|
||||
@@ -18,15 +19,8 @@ RUN groupadd -g 109 render && \
|
||||
cd /tmp/pytorch/third_party/flash-attention/csrc && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
chown -R jenkins:jenkins /tmp/pytorch && \
|
||||
chmod -R a+rwx /tmp/pytorch && \
|
||||
sudo usermod -aG irc jenkins && \
|
||||
#install hipblaslt
|
||||
git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
|
||||
cd rocm-libraries && \
|
||||
git checkout develop && \
|
||||
git sparse-checkout init --cone && \
|
||||
git sparse-checkout set projects/hipblaslt shared/origami && \
|
||||
cd projects/hipblaslt && \
|
||||
git show --oneline -s && \
|
||||
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller
|
||||
mkdir -p /var/jenkins/workspace/pytorch && \
|
||||
cp -r /tmp/pytorch/* /var/jenkins/workspace/pytorch/ && \
|
||||
chown -R jenkins:jenkins /var/jenkins/workspace/pytorch && \
|
||||
chmod -R a+rwx /var/jenkins/workspace/pytorch && \
|
||||
sudo usermod -aG irc jenkins
|
||||
|
||||
200
Jenkinsfile
vendored
200
Jenkinsfile
vendored
@@ -24,6 +24,24 @@
|
||||
// Benefits: PR builds 5h → 30min (typical), nightly builds unchanged
|
||||
// See: script/dependency-parser/README.md for details
|
||||
//
|
||||
|
||||
@NonCPS
|
||||
String getGitHubCommitHash(def build)
|
||||
{
|
||||
def scmAction = build?.actions.find { action ->
|
||||
action instanceof jenkins.scm.api.SCMRevisionAction
|
||||
}
|
||||
if (scmAction?.revision instanceof org.jenkinsci.plugins.github_branch_source.PullRequestSCMRevision)
|
||||
{
|
||||
return scmAction.revision.pullHash
|
||||
}
|
||||
else if (scmAction?.revision instanceof jenkins.plugins.git.AbstractGitSCMSource$SCMRevisionImpl)
|
||||
{
|
||||
return scmAction.revision.hash
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
def rocmnode(name) {
|
||||
return '(rocmtest || miopen) && (' + name + ')'
|
||||
}
|
||||
@@ -31,6 +49,7 @@ def rocmnode(name) {
|
||||
def show_node_info() {
|
||||
sh """
|
||||
echo "NODE_NAME = \$NODE_NAME"
|
||||
hostname
|
||||
lsb_release -sd
|
||||
uname -r
|
||||
cat /sys/module/amdgpu/version
|
||||
@@ -38,6 +57,30 @@ def show_node_info() {
|
||||
"""
|
||||
}
|
||||
|
||||
def setGithubStatus(String context, String state, String description) {
|
||||
def sha = env.GIT_COMMIT
|
||||
def targetUrl = env.RUN_DISPLAY_URL ?: env.BUILD_URL
|
||||
def statusUrl = "https://api.github.com/repos/ROCm/rocm-libraries/statuses/${sha}"
|
||||
withCredentials([usernamePassword(credentialsId: 'github-app-miopen', usernameVariable: 'GITHUB_APP', passwordVariable: 'GITHUB_TOKEN')]) {
|
||||
def code = '0'
|
||||
try {
|
||||
retry(3) {
|
||||
code = sh(returnStdout: true, script: """
|
||||
curl -s -w "%{http_code}" -o /dev/null -X POST '${statusUrl}' \\
|
||||
-H "Authorization: token \$GITHUB_TOKEN" \\
|
||||
-H 'Content-Type: application/json' \\
|
||||
-d '{"state":"${state}","context":"${context}","description":"${description}","target_url":"${targetUrl}"}'
|
||||
""").trim()
|
||||
if (!code.startsWith('2')) {
|
||||
error("GitHub status POST returned ${code}")
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
echo "WARNING: GitHub status POST failed after retries (context=${context}, state=${state}, code=${code})"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def cloneUpdateRefRepo() {
|
||||
def refRepoPath = "/var/jenkins/ref-repo/rocm-libraries"
|
||||
def lockLabel = "git ref repo lock - ${env.NODE_NAME}"
|
||||
@@ -78,7 +121,14 @@ def checkoutComposableKernel()
|
||||
//update ref repo
|
||||
cloneUpdateRefRepo()
|
||||
// checkout project
|
||||
checkout scm
|
||||
def scmVars = checkout scm
|
||||
// getGitHubCommitHash reads SCMRevisionAction recorded before any local merge,
|
||||
// giving the true PR branch tip (pullHash) or branch HEAD (hash).
|
||||
// Falls back to ORIG_HEAD (pre-merge HEAD set by git merge) when SCMRevisionAction
|
||||
// is unavailable, then to HEAD for branch builds where no merge occurred.
|
||||
env.GIT_COMMIT = getGitHubCommitHash(currentBuild.rawBuild) ?: sh(returnStdout: true, script: '''
|
||||
git rev-parse ORIG_HEAD 2>/dev/null || git rev-parse HEAD
|
||||
''').trim()
|
||||
}
|
||||
|
||||
def generateAndArchiveBuildTraceVisualization(String buildTraceFileName) {
|
||||
@@ -672,6 +722,9 @@ def cmake_build(Map conf=[:]){
|
||||
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
|
||||
setup_args = " -D CK_EXPERIMENTAL_BUILDER=ON " + setup_args
|
||||
}
|
||||
if (params.RUN_ROCM_CK_TESTS) {
|
||||
setup_args = " -D CK_ENABLE_ROCM_CK=ON " + setup_args
|
||||
}
|
||||
setup_cmd = conf.get(
|
||||
"setup_cmd",
|
||||
"""${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_FLAGS=" -O3 " .. """
|
||||
@@ -788,6 +841,9 @@ def cmake_build(Map conf=[:]){
|
||||
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
|
||||
sh 'ninja check-builder'
|
||||
}
|
||||
if (params.RUN_ROCM_CK_TESTS) {
|
||||
sh 'ninja check-rocm-ck'
|
||||
}
|
||||
if(params.BUILD_PACKAGES){
|
||||
echo "Build ckProfiler packages"
|
||||
sh 'ninja -j64 package'
|
||||
@@ -827,6 +883,9 @@ def cmake_build(Map conf=[:]){
|
||||
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
|
||||
sh 'ninja check-builder'
|
||||
}
|
||||
if (params.RUN_ROCM_CK_TESTS) {
|
||||
sh 'ninja check-rocm-ck'
|
||||
}
|
||||
if(params.BUILD_PACKAGES){
|
||||
echo "Build ckProfiler packages"
|
||||
sh 'ninja -j64 package'
|
||||
@@ -840,8 +899,10 @@ def cmake_build(Map conf=[:]){
|
||||
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_fmha_*.log"
|
||||
stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}"
|
||||
dir("projects/composablekernel"){
|
||||
archiveArtifacts "perf_fmha_*.log"
|
||||
stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}"
|
||||
}
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
||||
@@ -858,13 +919,19 @@ def buildHipClangJob(Map conf=[:]){
|
||||
def retimage
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
|
||||
setGithubStatus("${env.STAGE_NAME}", 'pending', "Starting ${env.STAGE_NAME}")
|
||||
try {
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 20, unit: 'HOURS')
|
||||
{
|
||||
cmake_build(conf)
|
||||
}
|
||||
}
|
||||
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
|
||||
}
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
|
||||
throw e
|
||||
}
|
||||
return retimage
|
||||
}
|
||||
@@ -888,7 +955,8 @@ def Build_CK(Map conf=[:]){
|
||||
def image
|
||||
def retimage
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
|
||||
setGithubStatus("${env.STAGE_NAME}", 'pending', "Starting ${env.STAGE_NAME}")
|
||||
try {
|
||||
try {
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
@@ -905,6 +973,7 @@ def Build_CK(Map conf=[:]){
|
||||
}
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
echo "The job was cancelled or aborted"
|
||||
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
|
||||
throw e
|
||||
}
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
@@ -915,16 +984,10 @@ def Build_CK(Map conf=[:]){
|
||||
cmake_build(conf)
|
||||
if ( params.RUN_INDUCTOR_TESTS && arch == "gfx90a" ){
|
||||
echo "Run inductor codegen tests"
|
||||
sh """
|
||||
python3 -m venv ${env.WORKSPACE}/projects/composablekernel
|
||||
. ${env.WORKSPACE}/projects/composablekernel/bin/activate
|
||||
python3 -m pip install pytest build setuptools setuptools_scm
|
||||
python3 -m pip install .
|
||||
python3 -m pytest python/test/test_gen_instances.py
|
||||
"""
|
||||
sh "projects/composablekernel/script/run_inductor_tests.sh"
|
||||
}
|
||||
// run performance tests, stash the logs, results will be processed on the master node
|
||||
dir("projects/composablekernel/script"){
|
||||
dir("projects/composablekernel/script"){
|
||||
if (params.RUN_PERFORMANCE_TESTS){
|
||||
if (params.RUN_FULL_QA && (arch == "gfx90a" || arch == "gfx942")){
|
||||
// run full tests on gfx90a or gfx942
|
||||
@@ -971,6 +1034,11 @@ def Build_CK(Map conf=[:]){
|
||||
}
|
||||
}
|
||||
}
|
||||
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
|
||||
}
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
|
||||
throw e
|
||||
}
|
||||
return retimage
|
||||
}
|
||||
@@ -991,7 +1059,8 @@ def process_results(Map conf=[:]){
|
||||
//use older image that has user jenkins
|
||||
def image = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm6.3"
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
|
||||
setGithubStatus("${env.STAGE_NAME}", 'pending', 'Processing results...')
|
||||
try {
|
||||
try
|
||||
{
|
||||
echo "Pulling image: ${image}"
|
||||
@@ -1005,6 +1074,10 @@ def process_results(Map conf=[:]){
|
||||
error "Unable to locate image: ${image}"
|
||||
}
|
||||
}
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
|
||||
throw e
|
||||
}
|
||||
|
||||
withDockerContainer(image: image, args: '--cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v=/var/jenkins/:/var/jenkins') {
|
||||
timeout(time: 15, unit: 'MINUTES'){
|
||||
@@ -1023,6 +1096,13 @@ def process_results(Map conf=[:]){
|
||||
catch(Exception err){
|
||||
echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_fmha_log_gfx950"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the FMHA performance logs for gfx950: ${err.getMessage()}."
|
||||
}
|
||||
|
||||
}
|
||||
if (params.BUILD_INSTANCES_ONLY){
|
||||
// unstash deb packages
|
||||
@@ -1105,10 +1185,10 @@ def process_results(Map conf=[:]){
|
||||
// process the logs
|
||||
sh "./process_perf_data.sh"
|
||||
}
|
||||
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
|
||||
}
|
||||
catch(e){
|
||||
echo "Throwing error exception while processing performance test results"
|
||||
echo 'Exception occurred: ' + e.toString()
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
|
||||
throw e
|
||||
}
|
||||
finally{
|
||||
@@ -1123,7 +1203,8 @@ def run_downstream_tests(Map conf=[:]){
|
||||
checkoutComposableKernel()
|
||||
def dockerOpts = get_docker_options() + ' --group-add irc '
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
|
||||
setGithubStatus("${env.STAGE_NAME}", 'pending', "Starting ${env.STAGE_NAME}")
|
||||
try {
|
||||
try
|
||||
{
|
||||
echo "Pulling image: ${conf.image}"
|
||||
@@ -1137,6 +1218,10 @@ def run_downstream_tests(Map conf=[:]){
|
||||
error "Unable to locate image: ${conf.image}"
|
||||
}
|
||||
}
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
setGithubStatus("${env.STAGE_NAME}", 'failure', "Stage ${env.STAGE_NAME} failed")
|
||||
throw e
|
||||
}
|
||||
|
||||
withDockerContainer(image: conf.image, args: dockerOpts) {
|
||||
timeout(time: conf.get("timeoutHours", 2), unit: 'HOURS'){
|
||||
@@ -1146,10 +1231,12 @@ def run_downstream_tests(Map conf=[:]){
|
||||
for (cmd in conf.execute_cmds) {
|
||||
sh "${cmd}"
|
||||
}
|
||||
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
|
||||
}
|
||||
catch(e){
|
||||
echo "Throwing error exception while running ${env.STAGE_NAME}"
|
||||
echo 'Exception occurred: ' + e.toString()
|
||||
setGithubStatus("${env.STAGE_NAME}", 'error', "Stage ${env.STAGE_NAME} failed")
|
||||
throw e
|
||||
}
|
||||
finally{
|
||||
@@ -1161,8 +1248,11 @@ def run_downstream_tests(Map conf=[:]){
|
||||
|
||||
def getPytorchTestsCmds() {
|
||||
return [
|
||||
"python3 /tmp/pytorch/tools/amd_build/build_amd.py",
|
||||
"USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python /tmp/pytorch/setup.py develop"
|
||||
"mkdir pytorch",
|
||||
"cp -r /var/jenkins/workspace/pytorch/* pytorch/",
|
||||
"ls -ltr pytorch",
|
||||
"python3 pytorch/tools/amd_build/build_amd.py",
|
||||
"cd pytorch && USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python3 setup.py develop"
|
||||
]
|
||||
}
|
||||
def getAiterTestsCmds() {
|
||||
@@ -1197,7 +1287,7 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_
|
||||
0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
|
||||
0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;RUN_FA_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true
|
||||
0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : ""
|
||||
CURRENT_BRANCH_NAME = env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME
|
||||
CURRENT_BRANCH_NAME = env.CHANGE_ID ? "refs/pull/${env.CHANGE_ID}/head" : (env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME)
|
||||
|
||||
POLL_SPEC = BRANCH_NAME == "develop" ? 'H H/6 * * *' : ''
|
||||
|
||||
@@ -1318,8 +1408,8 @@ pipeline {
|
||||
description: "Build CK and run tests on gfx101 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX103",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx103 (default: ON)")
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx103 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX11",
|
||||
defaultValue: true,
|
||||
@@ -1338,8 +1428,8 @@ pipeline {
|
||||
description: "Generate a detailed time trace (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_INDUCTOR_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run inductor codegen tests (default: OFF)")
|
||||
defaultValue: true,
|
||||
description: "Run inductor codegen tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "RUN_CODEGEN_TESTS",
|
||||
defaultValue: true,
|
||||
@@ -1348,6 +1438,10 @@ pipeline {
|
||||
name: "RUN_BUILDER_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run CK_BUILDER tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_ROCM_CK_TESTS",
|
||||
defaultValue: true,
|
||||
description: "Run rocm_ck tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "RUN_ALL_UNIT_TESTS",
|
||||
defaultValue: false,
|
||||
@@ -1404,9 +1498,9 @@ pipeline {
|
||||
dbsshport = "${dbsshport}"
|
||||
dbsshuser = "${dbsshuser}"
|
||||
dbsshpassword = "${dbsshpassword}"
|
||||
ck_git_creds = "${ck_git_creds}"
|
||||
gerrit_cred="${gerrit_cred}"
|
||||
DOCKER_BUILDKIT = "1"
|
||||
BUILD_GFX103 = "${env.BRANCH_NAME == 'develop' ? true : false}"
|
||||
}
|
||||
stages{
|
||||
stage("Determine CI Execution") {
|
||||
@@ -1743,6 +1837,7 @@ pipeline {
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D BUILD_CK_TILE_ENGINE="ON" \
|
||||
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
@@ -1756,7 +1851,7 @@ pipeline {
|
||||
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
|
||||
-D GEMM_PRESHUFFLE_CONFIG_FILE="default_ci_config.json" .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
|
||||
}
|
||||
@@ -1785,6 +1880,7 @@ pipeline {
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D BUILD_CK_TILE_ENGINE="ON" \
|
||||
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
@@ -1799,7 +1895,7 @@ pipeline {
|
||||
-D GROUPED_GEMM_DATATYPE="fp8;fp16" \
|
||||
-D GROUPED_GEMM_LAYOUT="rcr;rrr;crr;ccr" .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all benchmark_grouped_gemm_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/grouped_gemm/grouped_gemm_benchmark.py . --problem-sizes "1024,1024,1024" --group-counts 8 --warmup 5 --repeat 5 --verbose --json grouped_gemm_results.json """
|
||||
@@ -1819,6 +1915,7 @@ pipeline {
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D BUILD_CK_TILE_ENGINE="ON" \
|
||||
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx950" \
|
||||
@@ -1829,7 +1926,7 @@ pipeline {
|
||||
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
|
||||
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
|
||||
}
|
||||
@@ -1848,13 +1945,14 @@ pipeline {
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D BUILD_CK_TILE_ENGINE="ON" \
|
||||
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx1201" \
|
||||
-D GEMM_UNIVERSAL_DATATYPE="fp16" \
|
||||
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_universal_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1904,6 +2002,7 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
/*
|
||||
stage("Build CK and run Tests on gfx908")
|
||||
{
|
||||
when {
|
||||
@@ -1920,6 +2019,7 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
*/
|
||||
stage("Build CK and run Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
@@ -2036,8 +2136,9 @@ pipeline {
|
||||
}
|
||||
success {
|
||||
script {
|
||||
// Report the parent stage build ck and run tests status
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${env.STAGE_NAME}", account: 'ROCm', repo: 'rocm-libraries') {
|
||||
node(rocmnode("nogpu")) {
|
||||
// Report the parent stage build ck and run tests status
|
||||
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
|
||||
echo "Reporting success status for build ck and run tests"
|
||||
}
|
||||
}
|
||||
@@ -2063,12 +2164,9 @@ pipeline {
|
||||
post {
|
||||
success {
|
||||
script {
|
||||
// Report the skipped parent's stage status
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Process Performance Test Results", account: 'ROCm', repo: 'rocm-libraries') {
|
||||
echo "Process Performance Test Results stage skipped."
|
||||
}
|
||||
// Report the skipped stage's status
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Process results", account: 'ROCm', repo: 'rocm-libraries') {
|
||||
node(rocmnode("nogpu")) {
|
||||
// Report the skipped parent's stage status
|
||||
setGithubStatus("${env.STAGE_NAME}", 'success', "Stage ${env.STAGE_NAME} passed")
|
||||
echo "Process Performance Test Results stage skipped."
|
||||
}
|
||||
}
|
||||
@@ -2078,20 +2176,22 @@ pipeline {
|
||||
}
|
||||
post {
|
||||
success {
|
||||
githubNotify context: 'Math CI Summary',
|
||||
status: 'SUCCESS',
|
||||
description: 'All checks have passed'
|
||||
script {
|
||||
node(rocmnode("nogpu")) {
|
||||
setGithubStatus('Math CI Summary', 'success', "Math CI passed")
|
||||
}
|
||||
}
|
||||
}
|
||||
failure {
|
||||
githubNotify context: 'Math CI Summary',
|
||||
status: 'FAILURE',
|
||||
description: 'Some checks have failed'
|
||||
node(rocmnode("nogpu")) {
|
||||
script {
|
||||
checkoutComposableKernel()
|
||||
}
|
||||
withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) {
|
||||
sh 'bash projects/composablekernel/script/infra_helper/send_failure_notifications.sh'
|
||||
script {
|
||||
node(rocmnode("nogpu")) {
|
||||
setGithubStatus('Math CI Summary', 'failure', "Math CI failed")
|
||||
script {
|
||||
checkoutComposableKernel()
|
||||
}
|
||||
withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) {
|
||||
sh 'bash projects/composablekernel/script/infra_helper/send_failure_notifications.sh'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
15
README.md
15
README.md
@@ -124,6 +124,21 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
||||
../script/cmake-ck-dev.sh .. gfx90a -DCMAKE_BUILD_TYPE=Release
|
||||
```
|
||||
|
||||
**Fast iteration builds:**
|
||||
|
||||
For faster CMake configuration during development (~5s vs ~150s), use the `--minimal` flag to disable
|
||||
building device instances, profiler, examples, tutorials, and tests:
|
||||
|
||||
```bash
|
||||
../script/cmake-ck-dev.sh --minimal .. gfx90a
|
||||
```
|
||||
|
||||
You can also specify a custom preset:
|
||||
|
||||
```bash
|
||||
../script/cmake-ck-dev.sh --preset=dev-minimal .. gfx90a
|
||||
```
|
||||
|
||||
5. Build the entire CK library:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -71,6 +71,7 @@ set(GTEST_CXX_FLAGS
|
||||
-Wno-lifetime-safety-intra-tu-suggestions
|
||||
-Wno-lifetime-safety-cross-tu-suggestions
|
||||
-Wno-character-conversion
|
||||
-Wno-lifetime-safety-invalidation
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
@@ -78,7 +79,8 @@ if(WIN32)
|
||||
-Wno-suggest-destructor-override
|
||||
-Wno-suggest-override
|
||||
-Wno-nonportable-system-include-path
|
||||
-Wno-language-extension-token)
|
||||
-Wno-language-extension-token
|
||||
-Wno-lifetime-safety-invalidation)
|
||||
endif()
|
||||
|
||||
target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})
|
||||
|
||||
@@ -21,6 +21,8 @@ endif()
|
||||
add_library(ck_tile_dispatcher
|
||||
src/registry.cpp
|
||||
src/dispatcher.cpp
|
||||
src/fmha_registry.cpp
|
||||
src/fmha_dispatcher.cpp
|
||||
)
|
||||
|
||||
# Enable PIC for Python bindings
|
||||
@@ -34,13 +36,21 @@ target_include_directories(ck_tile_dispatcher
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
|
||||
# Link against CK Tile headers (header-only)
|
||||
# CK Tile core headers (ck_tile/core, ck_tile/ops, etc.)
|
||||
target_include_directories(ck_tile_dispatcher
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
|
||||
# CK project root -- needed only for FMHA generated wrappers that include
|
||||
# "example/ck_tile/01_fmha/fmha_fwd.hpp". PRIVATE to avoid exposing the
|
||||
# entire project tree to downstream consumers.
|
||||
target_include_directories(ck_tile_dispatcher
|
||||
PRIVATE
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/..>
|
||||
)
|
||||
|
||||
# Link against HIP headers if available
|
||||
if(hip_FOUND)
|
||||
target_link_libraries(ck_tile_dispatcher PUBLIC hip::host)
|
||||
|
||||
@@ -394,6 +394,12 @@ python3 examples/grouped_conv/python/03_bwd_data.py # Backward data +
|
||||
python3 examples/grouped_conv/python/04_bwd_weight.py # Backward weight + CPU ref
|
||||
python3 examples/grouped_conv/python/05_benchmark.py # Multi-problem benchmark
|
||||
python3 examples/grouped_conv/python/06_registry_json.py # Heuristic selection + JSON
|
||||
|
||||
# FMHA Examples (JIT-compiled on the fly)
|
||||
python3 examples/fmha/python/01_basic_fmha.py # Basic forward attention
|
||||
python3 examples/fmha/python/12_masks_fmha.py # Causal masks
|
||||
python3 examples/fmha/python/18_backward_fmha.py # Backward pass
|
||||
python3 examples/fmha/python/16_splitkv_fmha.py # Split-KV for long sequences
|
||||
```
|
||||
|
||||
### Example Output
|
||||
@@ -716,7 +722,7 @@ This matrix shows all CK Tile operations with per-data-type, per-layout, and per
|
||||
| GEMM | streamk_gemm<br>example: `40_streamk_gemm/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| Reduce | multi_reduce2d<br>example: `05_reduce/` | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Reduce | reduce2d<br>example: `05_reduce/` | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Attention | fmha<br>example: `01_fmha/` | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Attention | fmha<br>example: `01_fmha/` | ✅ | ✅ | ✅ | ✅ | ❌ | | | | | | | ✅ | ✅ | ✅ | ❌ |
|
||||
| Attention | sparse_attn<br>example: `50_sparse_attn/` | ❌ | | ❌ | | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Activation | softmax | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
| Activation | topk_softmax<br>example: `09_topk_softmax/` | ❌ | ❌ | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -871,7 +877,14 @@ dispatcher/
|
||||
| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder)
|
||||
| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations
|
||||
| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe)
|
||||
| +---- grouped_conv_utils.hpp # Grouped conv utilities
|
||||
| |---- grouped_conv_utils.hpp # Grouped conv utilities
|
||||
| |---- fmha_types.hpp # FMHA fwd/bwd args and traits structs
|
||||
| |---- fmha_problem.hpp # FmhaProblem, FmhaProblemBuilder
|
||||
| |---- fmha_kernel_key.hpp # FmhaKernelKey (Signature + Algorithm)
|
||||
| |---- fmha_kernel_instance.hpp # FmhaKernelInstance virtual interface
|
||||
| |---- fmha_kernel_decl.hpp # Declarative FmhaSignature/FmhaAlgorithm
|
||||
| |---- fmha_registry.hpp # FmhaRegistry (thread-safe)
|
||||
| +---- fmha_dispatcher.hpp # FmhaDispatcher (plan, select, run)
|
||||
|
|
||||
|---- src/ # C++ implementation
|
||||
|
|
||||
@@ -879,12 +892,17 @@ dispatcher/
|
||||
| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings
|
||||
| |---- unified_gemm_codegen.py # GEMM kernel generator
|
||||
| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator
|
||||
| |---- unified_fmha_codegen.py # FMHA kernel generator
|
||||
| |---- fmha_arch_specs.json # FMHA per-arch tile/pipeline specs
|
||||
| |---- fmha_rules.py # FMHA validation rules
|
||||
| |---- fmha_profiles.py # FMHA named profiles/receipts
|
||||
| +---- arch_specs.json # GPU specifications
|
||||
|
|
||||
|---- python/ # Python utilities
|
||||
| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output
|
||||
| |---- ctypes_utils.py # GEMM ctypes utilities
|
||||
| +---- grouped_conv_utils.py # Grouped conv utilities
|
||||
| |---- grouped_conv_utils.py # Grouped conv utilities
|
||||
| +---- fmha_utils.py # FMHA: JIT compile, FmhaRunner, FmhaKernelConfig
|
||||
|
|
||||
|---- scripts/ # Build scripts
|
||||
| |---- compile_gemm_examples.py # GEMM build script
|
||||
@@ -892,15 +910,19 @@ dispatcher/
|
||||
|
|
||||
|---- bindings/ctypes/ # Python ctypes interface
|
||||
| |---- gemm_ctypes_lib.cpp # GEMM Python library
|
||||
| +---- conv_ctypes_lib.cpp # Grouped conv Python library
|
||||
| |---- conv_ctypes_lib.cpp # Grouped conv Python library
|
||||
| +---- fmha_ctypes_lib.cpp # FMHA Python library
|
||||
|
|
||||
|---- examples/ # Examples
|
||||
| |---- gemm/
|
||||
| | |---- cpp/ # C++ GEMM examples (01-07)
|
||||
| | +---- python/ # Python GEMM examples (01-11)
|
||||
| +---- grouped_conv/
|
||||
| |---- cpp/ # C++ Grouped Conv examples (01-07)
|
||||
| +---- python/ # Python Grouped Conv examples (01-06)
|
||||
| |---- grouped_conv/
|
||||
| | |---- cpp/ # C++ Grouped Conv examples (01-07)
|
||||
| | +---- python/ # Python Grouped Conv examples (01-06)
|
||||
| +---- fmha/
|
||||
| |---- cpp/ # C++ FMHA examples (01-35)
|
||||
| +---- python/ # Python FMHA examples (01-38)
|
||||
|
|
||||
+---- tests/ # Unit tests (C++ and Python)
|
||||
```
|
||||
@@ -913,6 +935,8 @@ dispatcher/
|
||||
|-----------|--------|
|
||||
| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) |
|
||||
| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) |
|
||||
| FMHA C++ | examples/fmha/cpp/ (35 examples covering all FMHA variants) |
|
||||
| FMHA Python | examples/fmha/python/ (38 examples with JIT compilation) |
|
||||
| Codegen | [codegen/README.md](codegen/README.md) |
|
||||
| Python Utils | [python/README.md](python/README.md) |
|
||||
| C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) |
|
||||
|
||||
@@ -10,6 +10,7 @@ bindings/
|
||||
| |---- gemm_ctypes_lib.cpp # GEMM dispatcher C API
|
||||
| |---- conv_ctypes_lib.cpp # Grouped conv dispatcher C API (fwd + bwd_data)
|
||||
| |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library)
|
||||
| |---- fmha_ctypes_lib.cpp # FMHA dispatcher C API (fwd + bwd)
|
||||
| |---- gpu_helper.cpp # CLI helper for Python
|
||||
| +---- CMakeLists.txt
|
||||
+---- README.md
|
||||
|
||||
@@ -129,7 +129,22 @@ float conv_bwdw_run(const void* input_ptr,
|
||||
return -1.0f;
|
||||
if(!input_ptr || !grad_output_ptr || !grad_weight_ptr)
|
||||
return -1.0f; // Null data pointer would cause kernel crash
|
||||
return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream);
|
||||
|
||||
try
|
||||
{
|
||||
return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream);
|
||||
}
|
||||
catch(const std::exception&)
|
||||
{
|
||||
// Kernel rejected args (e.g. unsupported tile/channel combo)
|
||||
// -3.0f matches conv_ctypes_lib.cpp:316 convention
|
||||
// -2.0f is reserved for "no kernel / not compiled for this direction"
|
||||
return -3.0f;
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
return -3.0f;
|
||||
}
|
||||
#else
|
||||
return -1.0f;
|
||||
#endif
|
||||
|
||||
1685
dispatcher/bindings/ctypes/fmha_ctypes_lib.cpp
Normal file
1685
dispatcher/bindings/ctypes/fmha_ctypes_lib.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
@@ -81,7 +81,9 @@
|
||||
"warp_configs": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1]
|
||||
[4, 1, 1],
|
||||
[8, 2, 1],
|
||||
[4, 4, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
@@ -256,8 +258,8 @@
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]]
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16], [32, 32, 32], [16, 16, 64]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16], [32, 32, 32], [16, 16, 64]],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
|
||||
|
||||
Generated from: arch_specs.json
|
||||
Generated at: 2026-01-05T19:34:01.224422
|
||||
Generated at: 2026-04-10T20:07:11.665064
|
||||
|
||||
To update this file:
|
||||
1. Edit arch_specs.json
|
||||
@@ -50,7 +49,7 @@ WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {
|
||||
"gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1], [8, 2, 1], [4, 4, 1]],
|
||||
"gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
|
||||
"gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
|
||||
"gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
|
||||
@@ -226,6 +225,8 @@ PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
],
|
||||
"bf16_bf16_fp32": [
|
||||
[32, 32, 8],
|
||||
@@ -233,6 +234,8 @@ PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
],
|
||||
"fp8_fp8_fp32": [
|
||||
[32, 32, 16],
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Shared codegen infrastructure for GEMM and grouped convolution code generators.
|
||||
Shared codegen infrastructure for GEMM, grouped convolution, and FMHA code generators.
|
||||
|
||||
Extracted from unified_gemm_codegen.py + arch-aware expansion helpers from conv.
|
||||
Both unified_gemm_codegen.py and unified_grouped_conv_codegen.py import from here
|
||||
|
||||
4
dispatcher/codegen/fmha/__init__.py
Normal file
4
dispatcher/codegen/fmha/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""FMHA codegen subpackage — tile specs, instance generation, symbol mapping, and C++ codegen."""
|
||||
1385
dispatcher/codegen/fmha/codegen.py
Normal file
1385
dispatcher/codegen/fmha/codegen.py
Normal file
File diff suppressed because it is too large
Load Diff
175
dispatcher/codegen/fmha/fmha_arch_specs.json
Normal file
175
dispatcher/codegen/fmha/fmha_arch_specs.json
Normal file
@@ -0,0 +1,175 @@
|
||||
{
|
||||
"_comment": "FMHA-specific architecture specs. Edit this file to add new GPU/dtype/pipeline support for FMHA.",
|
||||
"_note": "Common GPU hardware data (element_sizes, warp_size, warp_configs, lds_capacity_kb) lives in ../arch_specs.json. This file holds FMHA-specific capabilities, tile tables, and validation rules.",
|
||||
|
||||
"architectures": {
|
||||
"gfx90a": {
|
||||
"family": "cdna2",
|
||||
"arch_tag": "ck_tile::gfx9_t",
|
||||
"supported_dtypes": ["fp16", "bf16", "fp32"],
|
||||
"supported_pipelines": ["qr", "qr_async", "qs", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
|
||||
"supports_trload": false,
|
||||
"supports_v3": false
|
||||
},
|
||||
"gfx942": {
|
||||
"family": "cdna3",
|
||||
"arch_tag": "ck_tile::gfx9_t",
|
||||
"supported_dtypes": ["fp16", "bf16", "fp32", "fp8", "fp8fp16", "fp8bf16", "fp8fp32", "bf8"],
|
||||
"supported_pipelines": ["qr", "qr_async", "qs", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
|
||||
"supports_trload": false,
|
||||
"supports_v3": false
|
||||
},
|
||||
"gfx950": {
|
||||
"family": "cdna4",
|
||||
"arch_tag": "ck_tile::gfx9_t",
|
||||
"supported_dtypes": ["fp16", "bf16", "fp32", "fp8", "fp8fp16", "fp8bf16", "fp8fp32", "bf8"],
|
||||
"supported_pipelines": ["qr", "qr_async", "qs", "qr_async_trload", "qr_async_trload_v3", "v3", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
|
||||
"supports_trload": true,
|
||||
"supports_v3": true
|
||||
},
|
||||
"gfx1100": {
|
||||
"family": "rdna3",
|
||||
"arch_tag": "ck_tile::gfx11_t",
|
||||
"supported_dtypes": ["fp16", "bf16"],
|
||||
"supported_pipelines": ["qr", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
|
||||
"supports_trload": false,
|
||||
"supports_v3": false
|
||||
},
|
||||
"gfx1201": {
|
||||
"family": "rdna4",
|
||||
"arch_tag": "ck_tile::gfx12_t",
|
||||
"supported_dtypes": ["fp16", "bf16", "fp8", "fp8bf16"],
|
||||
"supported_pipelines": ["qr", "qr_pagedkv", "qr_nwarp_sshuffle", "appendkv", "bwd"],
|
||||
"supports_trload": false,
|
||||
"supports_v3": false
|
||||
}
|
||||
},
|
||||
|
||||
"supported_hdims": {
|
||||
"_comment": "hdim_q must satisfy ceil_to_qualified_tile_length() in tile_fmha_shape.hpp. Each entry is [hdim_q, hdim_v].",
|
||||
"fp16": [[32,32], [64,64], [80,96], [96,128], [128,128], [160,160], [192,128], [192,192], [256,256]],
|
||||
"bf16": [[32,32], [64,64], [80,96], [96,128], [128,128], [160,160], [192,128], [192,192], [256,256]],
|
||||
"fp32": [[32,32], [48,48], [64,64], [96,128], [128,128], [192,192], [256,256]],
|
||||
"fp8": [[64,64], [128,128], [192,128], [256,256]],
|
||||
"fp8bf16": [[64,64], [128,128], [192,128], [256,256]],
|
||||
"fp8fp32": [[128,128]],
|
||||
"bf8": [[64,64], [128,128], [192,128], [256,256]],
|
||||
"mxfp8": [[128,128], [256,256]],
|
||||
"mxfp4": [[128,128], [256,256]]
|
||||
},
|
||||
|
||||
"fmha_warp_tiles": {
|
||||
"_comment": "FMHA warp tile sizes [wm0, wn0, wk0] per FMHA dtype. Subset of MFMA/WMMA instructions relevant to attention.",
|
||||
"fp16": [[32,32,16], [16,16,32]],
|
||||
"bf16": [[32,32,16], [16,16,32]],
|
||||
"fp32": [[16,16,16]],
|
||||
"fp8": [[32,32,32]],
|
||||
"fp8bf16": [[32,32,32]],
|
||||
"fp8fp32": [[32,32,32]],
|
||||
"bf8": [[32,32,32]],
|
||||
"mxfp8": [[32,32,64]],
|
||||
"mxfp4": [[16,16,128]]
|
||||
},
|
||||
|
||||
"fmha_element_sizes": {
|
||||
"_comment": "FMHA-specific element sizes for composite dtypes not in parent arch_specs.json. Common dtypes (fp16, bf16, fp32, fp8, bf8) use ../arch_specs.json element_sizes.",
|
||||
"fp8bf16": 1,
|
||||
"fp8fp32": 1,
|
||||
"mxfp8": 1,
|
||||
"mxfp4": 1
|
||||
},
|
||||
|
||||
"tile_sweep_ranges": {
|
||||
"_comment": "Block tile dimensions to sweep. Must be multiples of warp tile sizes.",
|
||||
"valid_bm0": [16, 32, 64, 128, 192, 256],
|
||||
"valid_bn0": [16, 32, 64, 96, 128, 192, 256, 384],
|
||||
"valid_bk0": [16, 32, 64, 96, 128, 256]
|
||||
},
|
||||
|
||||
"k0max_map": {
|
||||
"_comment": "Maps hdim_q -> padded K-tile length. Source: tile_fmha_shape.hpp ceil_to_qualified_tile_length().",
|
||||
"32": 32, "48": 48, "64": 64, "80": 96, "96": 128,
|
||||
"128": 128, "160": 256, "192": 192, "256": 256
|
||||
},
|
||||
|
||||
"lds_limits": {
|
||||
"_comment": "LDS budget in bytes per non-async FMHA pipeline. Async pipelines compute LDS dynamically.",
|
||||
"qr": 65536,
|
||||
"qs": 65536
|
||||
},
|
||||
|
||||
"global_rules": {
|
||||
"hdim_192_128_no_bias_dropout": true,
|
||||
"logits_requires_no_bias": true,
|
||||
"group_mode_requires_padding": true,
|
||||
"hdim_divisible_by": 8
|
||||
},
|
||||
|
||||
"defaults": {
|
||||
"tile": [128, 64, 32, 128, 32, 128],
|
||||
"wave": [2, 2, 1, 2, 2, 1, 1, 1, 1],
|
||||
"warp": [32, 32, 16, 32, 32, 16, 16, 16, 16],
|
||||
"padding": [true, true, true, true],
|
||||
"block_per_cu": 1,
|
||||
"num_wave_groups": 1,
|
||||
"selection_rank": 0
|
||||
},
|
||||
|
||||
"splitkv_combine": {
|
||||
"combine_bn1": 32,
|
||||
"hdims_fp16": [32, 64, 96, 128, 256],
|
||||
"hdims_fp8": [64, 128, 256]
|
||||
},
|
||||
|
||||
"batch_prefill": {
|
||||
"supported_page_sizes": [1, 16, 1024],
|
||||
"supported_kv_memory_layouts": ["vectorized", "linear"],
|
||||
"supported_kv_lookup_tables": ["vllm", "sglang"]
|
||||
},
|
||||
|
||||
"bwd_tiles": {
|
||||
"_comment": "BWD dq_dk_dv tile tables. Format: [bm0, bn0, bk0, bn1, bk1, bk0max, tile6, tile7, tile8].",
|
||||
"dq_dk_dv_fp16": {
|
||||
"32_32": [32, 128, 32, 32, 32, 32, 64, 32, 32],
|
||||
"64_64": [32, 128, 64, 32, 64, 32, 32, 64, 64],
|
||||
"96_128": [32, 128, 96, 32, 96, 32, 32, 96, 96],
|
||||
"128_128": [16, 128, 128, 16, 128, 16, 32, 128, 128],
|
||||
"256_256": [16, 64, 256, 16, 256, 16, 32, 256, 256]
|
||||
},
|
||||
"dq_dk_dv_extra": {
|
||||
"64_64": [
|
||||
{"tile": [32, 128, 64, 32, 64, 32, 32, 64, 64], "tag": "trload", "batch_only": false},
|
||||
{"tile": [32, 16, 64, 32, 64, 32, 16, 64, 64], "tag": "small", "batch_only": true}
|
||||
],
|
||||
"128_128": [
|
||||
{"tile": [16, 16, 128, 16, 128, 16, 16, 128, 128], "tag": "small", "batch_only": true},
|
||||
{"tile": [16, 192, 128, 16, 128, 16, 32, 128, 128], "tag": "bn192", "batch_only": false},
|
||||
{"tile": [32, 128, 128, 32, 128, 32, 32, 128, 128], "tag": "trload", "batch_only": false}
|
||||
]
|
||||
},
|
||||
"dot_do_o_hdims": [32, 64, 96, 128, 256],
|
||||
"convert_dq_hdims": [32, 64, 96, 128, 256],
|
||||
"convert_dq_tile_groups": {"32": 1, "64": 1, "96": 1, "128": 1, "256": 1},
|
||||
"pad_combos": [["f","f"], ["f","t"], ["f","8"], ["t","f"], ["t","t"], ["t","8"], ["8","8"]],
|
||||
"extra_pad_combos": [["f","f"], ["8","8"]],
|
||||
"dropouts": ["no", "dropout_wg16", "dropout_wg16_storerandval"],
|
||||
"small_dropouts": ["no"]
|
||||
},
|
||||
|
||||
"bwd_wave_warp": {
|
||||
"_comment": "BWD wave/warp lookup. Key: 'bm0_bn0_bk0_trload'. Value: {wave: 9-tuple, warp_k1: int}.",
|
||||
"16_16_128_t": {"wave": [1,1,1,1,1,1,1,1,1], "warp_k1": 16},
|
||||
"16_64_256_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16},
|
||||
"16_128_128_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16},
|
||||
"16_192_128_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16},
|
||||
"32_16_64_t": {"wave": [1,1,1,1,1,1,1,1,1], "warp_k1": 16},
|
||||
"32_128_32_f": {"wave": [1,4,1,4,1,1,2,2,1], "warp_k1": 16},
|
||||
"32_128_64_f": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 16},
|
||||
"32_128_64_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 32},
|
||||
"32_128_96_f": {"wave": [1,4,1,4,1,1,2,2,1], "warp_k1": 16},
|
||||
"32_128_128_t": {"wave": [1,4,1,4,1,1,1,4,1], "warp_k1": 32},
|
||||
"64_128_32_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16},
|
||||
"64_128_64_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16},
|
||||
"64_128_128_f": {"wave": [2,4,1,4,1,1,2,4,1], "warp_k1": 16}
|
||||
}
|
||||
}
|
||||
261
dispatcher/codegen/fmha/generate_fallback.py
Normal file
261
dispatcher/codegen/fmha/generate_fallback.py
Normal file
@@ -0,0 +1,261 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Generate FMHA fallback kernel + dispatch header for the Python ctypes library.
|
||||
|
||||
Mirrors generate_conv_dispatch_header.py: generates a single FMHA forward
|
||||
kernel and creates a dispatch header that can be force-included into
|
||||
fmha_ctypes_lib.cpp.
|
||||
|
||||
Usage:
|
||||
python3 generate_fallback.py --output-dir /path/to/output --gpu-target gfx950
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Default kernel config for fallback — a single fwd fp16 kernel with
|
||||
# known-good tile (128x128x32, qr_async) for basic smoke-test capability.
|
||||
# Source: tile dims from fmha_fwd.py FmhaFwdTileSize for hdim=128 fp16.
|
||||
DEFAULT_CONFIG = {
    # Target architecture; main() overwrites this with --gpu-target.
    "arch": "gfx950",
    # What the kernel computes.
    "signature": {
        "family": "fwd",
        "data_type": "fp16",
        "mode": "batch",
        "vlayout": "r",
        "hdim_q": 128,
        "hdim_v": 128,
        # Optional features: everything is switched off for the fallback.
        "mask": "no",
        "bias": "no",
        "lse": False,
        "dropout": False,
        "qscale": "no",
        "rope": "none",
        "logits": False,
        "paged_kv": False,
        "fp8_static_quant": False,
        "skip_min_seqlen_q": False,
        "sink": False,
        "dbias": False,
        "store_randval": False,
        "deterministic": False,
        # KV-cache related settings.
        "kv_memory_layout": "vectorized",
        "kv_lookup_table": "sglang",
        "page_size": 1,
    },
    # How the kernel is tiled and scheduled.
    "algorithm": {
        "pipeline": "qr_async",
        "tile": [128, 128, 32, 128, 32, 128],
        "wave": [4, 1, 1, 4, 1, 1, 1, 1, 1],
        "warp": [32, 32, 16, 32, 32, 16, 16, 16, 16],
        "padding": [True, True, True, True],
        "block_per_cu": 1,
        "num_wave_groups": 1,
        "max_splits_log2": 0,
        "max_seq_len_q": 0,
    },
}
|
||||
|
||||
|
||||
def generate_dispatch_header(output_dir: Path, wrapper_dir: Path) -> Path:
    """Generate fmha_python_dispatch.hpp from the wrapper headers.

    Every ``dispatcher_wrapper_fmha_*.hpp`` file found in ``wrapper_dir`` is
    included by the generated header and contributes one
    ``registry.register_kernel(...)`` call; the kernel name is the wrapper's
    file stem without the ``dispatcher_wrapper_`` prefix.

    Args:
        output_dir: Directory that receives ``fmha_python_dispatch.hpp``.
        wrapper_dir: Directory scanned (non-recursively) for wrapper headers.

    Returns:
        Path of the header that was written.

    Raises:
        RuntimeError: If ``wrapper_dir`` contains no matching wrapper header.
    """
    prefix = "dispatcher_wrapper_"
    wrappers = sorted(wrapper_dir.glob(f"{prefix}fmha_*.hpp"))
    if not wrappers:
        raise RuntimeError(f"No FMHA dispatcher wrappers found in {wrapper_dir}")

    # The glob guarantees every stem starts with the prefix, so slice it off
    # instead of str.replace(), which would also remove later occurrences.
    kernel_names = [w.stem[len(prefix):] for w in wrappers]
    make_calls = [
        f" registry.register_kernel("
        f"ck_tile::dispatcher::generated::make_{name}(arch));"
        for name in kernel_names
    ]

    lines = [
        "// Auto-generated FMHA dispatch header for Python ctypes library",
        "#pragma once",
        "",
        # The generated code below uses std::string directly; include its
        # header explicitly instead of relying on transitive includes.
        "#include <string>",
        "",
    ]
    for w in wrappers:
        lines.append(f'#include "dispatcher_wrappers/{w.name}"')

    lines += [
        "",
        '#include "ck_tile/dispatcher/fmha_registry.hpp"',
        '#include "ck_tile/dispatcher/fmha_dispatcher.hpp"',
        "",
        "namespace generated {",
        "",
        "inline void register_fmha_python_kernels("
        "ck_tile::dispatcher::FmhaRegistry& registry, const std::string& arch) {",
        " (void)arch;",
    ]
    lines += make_calls
    lines += [
        "}",
        "",
        "} // namespace generated",
        "",
        "#ifndef REGISTER_GENERATED_KERNELS",
        "#define REGISTER_GENERATED_KERNELS(registry, arch) "
        "::generated::register_fmha_python_kernels(registry, arch)",
        "#endif",
        "",
        "// Stable C ABI for dlopen/dlsym-based kernel registration.",
        '// Plugins call dlsym(handle, "ck_fmha_register_kernels") to get this.',
        'extern "C" __attribute__((visibility("default")))',
        "int ck_fmha_register_kernels(",
        " ck_tile::dispatcher::FmhaRegistry& registry, const char* arch) {",
        " ::generated::register_fmha_python_kernels(registry, arch ? std::string(arch) : std::string());",
        f" return {len(kernel_names)};",
        "}",
        "",
        "// Kernel inventory for Python introspection",
        f"static const int FMHA_KERNEL_COUNT = {len(kernel_names)};",
        "static const char* FMHA_KERNEL_NAMES[] = {"
        + ", ".join(f'"{n}"' for n in kernel_names)
        + "};",
        "",
    ]

    header_path = output_dir / "fmha_python_dispatch.hpp"
    header_path.write_text("\n".join(lines) + "\n")
    return header_path
|
||||
|
||||
|
||||
def compile_kernels(output_dir: Path, gpu_target: str, include_dirs: str) -> Path:
    """Build every ``fmha_*.cpp`` in ``output_dir`` and archive the objects.

    Args:
        output_dir: Directory holding the generated kernel sources; object
            files and the archive are written next to them.
        gpu_target: Architecture string forwarded to ``fmha_compile_flags``.
        include_dirs: Include directories separated by ``;`` or ``:``; blank
            entries are ignored.

    Returns:
        Path to ``libfmha_python_fallback.a`` inside ``output_dir``.

    Raises:
        RuntimeError: If there is nothing to compile or a compile step fails.
        subprocess.CalledProcessError: If the ``ar`` invocation fails.
    """
    import re
    import shutil

    # Prefer hipcc from PATH, otherwise fall back to the default ROCm location.
    compiler = shutil.which("hipcc") or "/opt/rocm/bin/hipcc"
    sources = sorted(output_dir.glob("fmha_*.cpp"))
    if not sources:
        raise RuntimeError("No kernel .cpp files to compile")

    # Use the shared compile flags from fmha_utils
    sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "python"))
    from fmha_utils import fmha_compile_flags  # noqa: E402

    # NOTE(review): family="bwd" is passed although the default fallback config
    # is a forward kernel -- confirm this is the intended flag set.
    base_flags = fmha_compile_flags(gpu_target, compiler, family="bwd")

    include_flags = []
    for raw_entry in re.split(r"[;:]", include_dirs):
        entry = raw_entry.strip()
        if not entry:
            continue
        include_flags += ["-I", entry]

    object_files = []
    for source in sources:
        target = source.with_suffix(".o")
        print(f" Compiling {source.name}...")
        proc = subprocess.run(
            base_flags + include_flags + [str(source), "-o", str(target)],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            print(f" FAILED: {proc.stderr}", file=sys.stderr)
            raise RuntimeError(f"Failed to compile {source.name}")
        object_files.append(str(target))

    archive = output_dir / "libfmha_python_fallback.a"
    subprocess.check_call(["ar", "rcs", str(archive)] + object_files)
    print(f" Static lib: {archive}")
    return archive
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: run codegen, write the dispatch header, optionally compile.

    Returns:
        0 on success; 1 when the codegen subprocess fails or produces no
        ``dispatcher_wrappers`` directory.
    """
    cli = argparse.ArgumentParser(
        description="Generate FMHA fallback kernel for Python library"
    )
    cli.add_argument("--output-dir", required=True, type=Path)
    cli.add_argument("--gpu-target", default="gfx950")
    cli.add_argument(
        "--config-json",
        default=None,
        help="Override default kernel config (JSON string)",
    )
    cli.add_argument(
        "--compile", action="store_true", help="Also compile the kernel .cpp into a .a"
    )
    cli.add_argument(
        "--include-dirs",
        default="",
        help="Semicolon-separated include directories for compilation",
    )
    args = cli.parse_args()

    out_root = args.output_dir
    out_root.mkdir(parents=True, exist_ok=True)

    here = Path(__file__).parent

    # --config-json may hold one config dict or a list of configs.
    override = json.loads(args.config_json) if args.config_json else None
    if isinstance(override, list):
        # Multi-config: the list is handed to codegen untouched.
        codegen_input = override
    else:
        # Start from a copy of the defaults so DEFAULT_CONFIG is never mutated.
        codegen_input = dict(DEFAULT_CONFIG)
        codegen_input["arch"] = args.gpu_target
        for section in ("signature", "algorithm"):
            codegen_input[section] = dict(DEFAULT_CONFIG[section])
        if args.config_json:
            # NOTE(review): top-level (shallow) update -- an override that
            # supplies "signature" or "algorithm" replaces that whole section.
            codegen_input.update(override)

    print(f"Generating FMHA fallback kernel for {args.gpu_target}...")
    print(f" Output: {out_root}")

    codegen_cmd = [
        sys.executable,
        str(here / "codegen.py"),
        "--output-dir",
        str(out_root),
        "--gpu-target",
        args.gpu_target,
        "--config-json",
        json.dumps(codegen_input),
    ]
    outcome = subprocess.run(
        codegen_cmd, capture_output=True, text=True, cwd=str(here)
    )
    if outcome.returncode != 0:
        print(f" Codegen FAILED:\n{outcome.stderr}", file=sys.stderr)
        return 1

    wrappers_root = out_root / "dispatcher_wrappers"
    if not wrappers_root.exists():
        print(" ERROR: No dispatcher_wrappers dir created", file=sys.stderr)
        return 1

    dispatch_header = generate_dispatch_header(out_root, wrappers_root)
    print(f" Dispatch header: {dispatch_header}")

    translation_units = list(out_root.glob("fmha_*.cpp"))
    print(f" Kernel TUs: {len(translation_units)}")

    if args.compile and translation_units:
        compile_kernels(out_root, args.gpu_target, args.include_dirs)

    return 0
|
||||
|
||||
|
||||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
2692
dispatcher/codegen/fmha/instance_gen.py
Normal file
2692
dispatcher/codegen/fmha/instance_gen.py
Normal file
File diff suppressed because it is too large
Load Diff
333
dispatcher/codegen/fmha/symbol_map.py
Normal file
333
dispatcher/codegen/fmha/symbol_map.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
# GPU target name → C++ arch trait type.
# Source: CK's include/ck_tile/core/arch/amd_gpu_traits.hpp
# (gfx9* → gfx9_t, gfx11* → gfx11_t, gfx12* → gfx12_t).
ARCH_TAG_MAP = {
    **dict.fromkeys(("gfx90a", "gfx942", "gfx950"), "ck_tile::gfx9_t"),
    "gfx1100": "ck_tile::gfx11_t",
    "gfx1201": "ck_tile::gfx12_t",
}

# GPU target name → preprocessor guard for conditional compilation.
# Source: HIP compiler predefined macros (__gfx90a__, __gfx942__, etc.).
ARCH_PREPROC_MAP = {
    arch: f"defined(__{arch}__)"
    for arch in ("gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1201")
}
|
||||
|
||||
# Forward data-type tag → C++ type config struct name.
# Source: example/ck_tile/01_fmha/fmha_fwd.hpp FmhaFwdTypeConfig<> specializations
# and codegen/cpp_symbol_map.py FWD_DTYPE_MAP.
FWD_DTYPE_MAP = {
    # single-precision / half-precision inputs
    "fp32": "FmhaFwdFp32",
    "fp16": "FmhaFwdFp16",
    "bf16": "FmhaFwdBf16",
    # 8-bit inputs
    "fp8": "FmhaFwdFp8",
    "bf8": "FmhaFwdBf8",
    # composite fp8 tags
    "fp8fp16": "FmhaFwdFp8Fp16",
    "fp8bf16": "FmhaFwdFp8Bf16",
    "fp8fp32": "FmhaFwdFp8Fp32",
}

# Backward data-type tag → C++ type config struct name.
# Source: example/ck_tile/01_fmha/fmha_bwd.hpp FmhaBwdTypeConfig<> specializations.
# Only fp32, fp16 and bf16 have entries here.
BWD_DTYPE_MAP = {
    "fp32": "FmhaBwdFp32",
    "fp16": "FmhaBwdFp16",
    "bf16": "FmhaBwdBf16",
}
|
||||
|
||||
# Kernel family tag → C++ FmhaKernelFamily enumerator.
# Source: include/ck_tile/dispatcher/fmha_types.hpp FmhaKernelFamily enum.
KERNEL_FAMILY_TO_ENUM = {
    # forward families
    "fwd": "FmhaKernelFamily::Fwd",
    "fwd_pagedkv": "FmhaKernelFamily::FwdPagedKv",
    "fwd_splitkv": "FmhaKernelFamily::FwdSplitKv",
    "fwd_splitkv_combine": "FmhaKernelFamily::FwdSplitKvCombine",
    "fwd_appendkv": "FmhaKernelFamily::FwdAppendKv",
    "batch_prefill": "FmhaKernelFamily::BatchPrefill",
    # backward families
    "bwd_dot_do_o": "FmhaKernelFamily::BwdDotDoO",
    "bwd_dq_dk_dv": "FmhaKernelFamily::BwdDqDkDv",
    "bwd_convert_dq": "FmhaKernelFamily::BwdConvertDq",
}

# API family tag → C++ FmhaApiFamily enumerator.
# Source: include/ck_tile/dispatcher/fmha_types.hpp FmhaApiFamily enum.
API_FAMILY_TO_ENUM = {
    "fwd": "FmhaApiFamily::Fwd",
    "fwd_pagedkv": "FmhaApiFamily::FwdPagedKv",
    "fwd_splitkv": "FmhaApiFamily::FwdSplitKv",
    "fwd_appendkv": "FmhaApiFamily::FwdAppendKv",
    "batch_prefill": "FmhaApiFamily::BatchPrefill",
    "bwd": "FmhaApiFamily::Bwd",
}
|
||||
|
||||
# Mask tables: alias → canonical tag, then canonical tag → C++ type / int code.
# Source: include/ck_tile/ops/fmha/block/block_attention_mask.hpp
# SimplifiedGenericAttentionMask<is_causal> and GenericAttentionMask<has_mask, has_local>.
MASK_CANONICAL = {
    # no mask
    "no": "no",
    "no_mask": "no",
    # top-left aligned
    "causal": "top_left",
    "top_left": "top_left",
    "t": "top_left",
    # bottom-right aligned
    "bottom_right": "bottom_right",
    "b": "bottom_right",
    # generic
    "generic": "generic",
    "window_generic": "generic",
    "g": "generic",
}

# Canonical mask tag → ck_tile mask type; both aligned variants share one type.
MASK_TO_CPP = {
    "no": "ck_tile::SimplifiedGenericAttentionMask<false>",
    "top_left": "ck_tile::SimplifiedGenericAttentionMask<true>",
    "bottom_right": "ck_tile::SimplifiedGenericAttentionMask<true>",
    "generic": "ck_tile::GenericAttentionMask<true, true>",
}

# Canonical mask tag → FmhaMasks alias.
MASK_TO_CPP_GENERIC = {
    "no": "FmhaMasks::NoMask",
    "top_left": "FmhaMasks::CausalMask",
    "bottom_right": "FmhaMasks::CausalMask",
    "generic": "FmhaMasks::GenericMask",
}

# Canonical mask tag → integer code.
MASK_TO_INT = {
    "no": 0,
    "top_left": 1,
    "bottom_right": 2,
    "generic": 3,
}
|
||||
|
||||
# Bias tables: alias → canonical tag, canonical tag → C++ enumerator / int code.
# Source: include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp.
BIAS_CANONICAL = {
    "no": "no",
    "no_bias": "no",
    "bias": "bias",
    "elementwise": "bias",
    "elementwise_bias": "bias",
    "alibi": "alibi",
}

BIAS_TO_CPP = {
    "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
    "bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
    "alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI",
}

BIAS_TO_INT = {"no": 0, "bias": 1, "alibi": 2}

# Quantization-scale tables: alias → canonical tag, canonical tag → C++ / int.
# Source: include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp.
QSCALE_CANONICAL = {
    "no": "no",
    "no_scale": "no",
    "pertensor": "pertensor",
    "blockscale": "blockscale",
    "kv_blockscale": "kv_blockscale",
}

QSCALE_TO_CPP = {
    "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
    "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
    "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
    "kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
}

QSCALE_TO_INT = {"no": 0, "pertensor": 1, "blockscale": 2, "kv_blockscale": 3}

# Rotary-embedding tables: alias → canonical tag, canonical tag → C++ / int.
# Source: include/ck_tile/ops/fmha/block/rotary_embedding_enum.hpp.
ROPE_CANONICAL = {
    "none": "none",
    "no": "none",
    "inter": "inter",
    "interleaved": "inter",
    "half": "half",
    "half_rotated": "half",
}

ROPE_TO_CPP = {
    "none": "ck_tile::RotaryEmbeddingEnum::NONE",
    "inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
    "half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED",
}

ROPE_TO_INT = {"none": 0, "inter": 1, "half": 2}
|
||||
|
||||
# V layout alias → C++ bool literal ("true" = row-major, "false" = column-major).
# Source: TileFmhaShape<..., IsVLayoutRowMajor> template parameter.
LAYOUT_TO_BOOL = {
    **dict.fromkeys(("r", "row", "row_major"), "true"),
    **dict.fromkeys(("c", "col", "col_major"), "false"),
}

# KV-cache memory layout: canonical tag, C++ enumerator and integer code.
# Source: include/ck_tile/ops/fmha/block/block_attention_kv_cache.hpp.
KV_MEMORY_LAYOUT_CANONICAL = {
    "vectorized": "vectorized",
    "linear": "linear",
}

KV_MEMORY_LAYOUT_TO_CPP = {
    "vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
    "linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
}

KV_MEMORY_LAYOUT_TO_INT = {"vectorized": 0, "linear": 1}

# KV lookup table type: canonical tag, C++ enumerator and integer code.
# Source: include/ck_tile/ops/fmha/block/block_attention_kv_cache.hpp.
KV_LOOKUP_CANONICAL = {
    "sglang": "sglang",
    "vllm": "vllm",
}

KV_LOOKUP_TO_CPP = {
    "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
    "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
}

# Unlike the other tables, "vllm" is listed first here and carries code 0.
KV_LOOKUP_TO_INT = {"vllm": 0, "sglang": 1}
|
||||
|
||||
# Pipeline tag → C++ pipeline class.
# Source: include/ck_tile/ops/fmha/pipeline/ — one header per pipeline variant.
PIPELINE_TO_CPP = {
    "qr": "ck_tile::BlockFmhaPipelineQRKSVS",
    "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
    "qs": "ck_tile::BlockFmhaPipelineQSKSVS",
    "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
    # "v3" and "qr_async_trload_v3" resolve to the same class.
    "v3": "ck_tile::BlockFmhaFwdV3Pipeline",
    "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
    "qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS",
    "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
    "appendkv": "ck_tile::BlockFmhaFwdAppendKVPipeline",
    "batch_prefill_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}

# Pipeline tag → C++ pipeline enum value.
# Source: include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp.
# NOTE(review): "appendkv" has a class entry above but no enum entry here —
# confirm that is intentional.
PIPELINE_ENUM_TO_CPP = {
    "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
    "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
    "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
    "v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
    "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
    "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "batch_prefill_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
|
||||
|
||||
# Python bools and the short "t"/"f" tags → C++ bool literal text.
BOOL_MAP = {True: "true", False: "false", "t": "true", "f": "false"}
|
||||
|
||||
|
||||
def canonical_mask(value: str) -> str:
    """Return the MASK_CANONICAL entry for ``value``; unknown values are returned unchanged."""
    if value in MASK_CANONICAL:
        return MASK_CANONICAL[value]
    return value
|
||||
|
||||
|
||||
def canonical_bias(value: str) -> str:
    """Return the BIAS_CANONICAL entry for ``value``; unknown values are returned unchanged."""
    if value in BIAS_CANONICAL:
        return BIAS_CANONICAL[value]
    return value
|
||||
|
||||
|
||||
def canonical_qscale(value: str) -> str:
    """Return the QSCALE_CANONICAL entry for ``value``; unknown values are returned unchanged."""
    if value in QSCALE_CANONICAL:
        return QSCALE_CANONICAL[value]
    return value
|
||||
|
||||
|
||||
def canonical_rope(value: str) -> str:
    """Return the ROPE_CANONICAL entry for ``value``; unknown values are returned unchanged."""
    if value in ROPE_CANONICAL:
        return ROPE_CANONICAL[value]
    return value
|
||||
|
||||
|
||||
def canonical_kv_memory_layout(value: str) -> str:
    """Return the KV_MEMORY_LAYOUT_CANONICAL entry for ``value``; unknown values pass through."""
    if value in KV_MEMORY_LAYOUT_CANONICAL:
        return KV_MEMORY_LAYOUT_CANONICAL[value]
    return value
|
||||
|
||||
|
||||
def canonical_kv_lookup(value: str) -> str:
    """Return the KV_LOOKUP_CANONICAL entry for ``value``; unknown values are returned unchanged."""
    if value in KV_LOOKUP_CANONICAL:
        return KV_LOOKUP_CANONICAL[value]
    return value
|
||||
|
||||
|
||||
def sanitize_token(value) -> str:
    """Stringify ``value`` and replace ``::``, ``/`` and spaces with underscores.

    The substitutions run in that order; a lone ``:`` is left untouched.
    """
    token = str(value)
    for separator in ("::", "/", " "):
        token = token.replace(separator, "_")
    return token
|
||||
|
||||
|
||||
def kernel_name_from_config(config: dict) -> str:
    """Build the kernel name for a config dict.

    The name has the form
    ``fmha_<family>_<dtype>_<mode>_h<hdim_q>x<hdim_v>_<pipeline>_<digest>``,
    where ``digest`` is the first 12 hex characters of the SHA-1 of a
    key-sorted JSON dump holding the sanitized/canonicalized fields together
    with the raw ``signature`` and ``algorithm`` sections.

    Args:
        config: Mapping with ``"signature"`` and ``"algorithm"`` sub-dicts.

    Returns:
        The kernel name string.

    Raises:
        KeyError: If a field read below is missing from the config.
    """
    sig = config["signature"]
    alg = config["algorithm"]

    # Sanitized (and, where an alias table exists, canonicalized) name parts.
    # Values are evaluated top to bottom.
    parts = {
        "family": sanitize_token(sig["family"]),
        "dtype": sanitize_token(sig["data_type"]),
        "mode": sanitize_token(sig["mode"]),
        "vlayout": sanitize_token(sig["vlayout"]),
        "mask": sanitize_token(canonical_mask(sig["mask"])),
        "bias": sanitize_token(canonical_bias(sig["bias"])),
        "qscale": sanitize_token(canonical_qscale(sig["qscale"])),
        "rope": sanitize_token(canonical_rope(sig["rope"])),
        "kv_memory": sanitize_token(canonical_kv_memory_layout(sig["kv_memory_layout"])),
        "kv_lookup": sanitize_token(canonical_kv_lookup(sig["kv_lookup_table"])),
    }
    pipeline = sanitize_token(alg["pipeline"])

    # The digest covers the name parts plus the full raw sections; key order
    # is irrelevant because the dump is key-sorted.
    payload = dict(parts)
    payload["sig"] = sig
    payload["alg"] = alg
    encoded = json.dumps(payload, sort_keys=True).encode("utf-8")
    digest = hashlib.sha1(encoded).hexdigest()[:12]

    prefix = f"fmha_{parts['family']}_{parts['dtype']}_{parts['mode']}"
    hdims = f"h{sig['hdim_q']}x{sig['hdim_v']}"
    return f"{prefix}_{hdims}_{pipeline}_{digest}"
|
||||
921
dispatcher/codegen/fmha/validation.py
Normal file
921
dispatcher/codegen/fmha/validation.py
Normal file
@@ -0,0 +1,921 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
FMHA validation and kernel specifications.
|
||||
|
||||
Architecture-specific data (dtypes, pipelines, hdims, tile tables) is stored in
|
||||
``fmha_arch_specs.json`` so that it can be edited without touching Python code.
|
||||
Common GPU hardware data (element sizes, warp size, LDS capacity) is imported
|
||||
from the parent ``arch_specs_generated`` module (generated from ``arch_specs.json``).
|
||||
|
||||
This file provides:
|
||||
- JSON loading helpers
|
||||
- Tile constraints (per-arch rules that reject invalid tiles)
|
||||
- Feature compatibility rules (pipeline × feature flag interactions)
|
||||
- Receipt filters and profiles (deployment-specific kernel subsets)
|
||||
- Config validation for the AOT codegen path
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
# Ensure this directory and parent codegen/ are on sys.path for sibling imports
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_CODEGEN_DIR = _THIS_DIR.parent
|
||||
sys.path.insert(0, str(_THIS_DIR))
|
||||
sys.path.insert(0, str(_CODEGEN_DIR))
|
||||
|
||||
from symbol_map import ( # noqa: E402
|
||||
BWD_DTYPE_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
canonical_bias,
|
||||
canonical_mask,
|
||||
canonical_qscale,
|
||||
)
|
||||
|
||||
# Import shared hardware data from parent arch_specs_generated (generated from
|
||||
# arch_specs.json by generate_arch_specs.py). Falls back to inline defaults if
|
||||
# the generated module is unavailable (e.g. in standalone testing).
|
||||
try:
|
||||
from arch_specs_generated import ELEMENT_SIZE_MAP as _PARENT_ELEMENT_SIZES # noqa: E402
|
||||
except ImportError:
|
||||
_PARENT_ELEMENT_SIZES = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"fp32": 4,
|
||||
"fp64": 8,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int8": 1,
|
||||
"int4": 0.5,
|
||||
"pk_fp4": 0.5,
|
||||
"int32": 4,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JSON data loading
|
||||
# =============================================================================
|
||||
|
||||
# The FMHA spec file lives next to this module.
_FMHA_SPECS_PATH = _THIS_DIR.joinpath("fmha_arch_specs.json")
|
||||
|
||||
|
||||
def _load_fmha_specs() -> dict:
    """Load fmha_arch_specs.json (cached after first call).

    The parsed document is stored on the function object as ``_cache`` so the
    file is read at most once per process.

    Returns:
        The parsed JSON document; every call returns the same object.
    """
    if not hasattr(_load_fmha_specs, "_cache"):
        # Decode as UTF-8 explicitly instead of relying on the locale's
        # default text encoding.
        with open(_FMHA_SPECS_PATH, encoding="utf-8") as f:
            _load_fmha_specs._cache = json.load(f)
    return _load_fmha_specs._cache
|
||||
|
||||
|
||||
def _build_element_sizes() -> Dict[str, int]:
    """Combine the parent element-size table with the FMHA-specific entries.

    Parent sizes are converted with ``int()``; entries under the optional
    ``"fmha_element_sizes"`` JSON key are applied afterwards unconverted and
    win on key collisions.
    """
    # NOTE(review): int() truncates fractional sizes, so the 0.5-byte entries
    # of the inline fallback table (int4, pk_fp4) become 0 here.
    sizes = {}
    for dtype, nbytes in _PARENT_ELEMENT_SIZES.items():
        sizes[dtype] = int(nbytes)
    fmha_specific = _load_fmha_specs().get("fmha_element_sizes", {})
    sizes.update(fmha_specific)
    return sizes
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 1. Architecture capabilities (loaded from fmha_arch_specs.json)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _build_arch_dtypes() -> Dict[str, List[str]]:
    """Map each arch under the JSON ``"architectures"`` key to its ``"supported_dtypes"``."""
    dtypes_by_arch = {}
    for arch, info in _load_fmha_specs()["architectures"].items():
        dtypes_by_arch[arch] = info["supported_dtypes"]
    return dtypes_by_arch
|
||||
|
||||
|
||||
def _build_supported_hdims() -> Dict[str, List[Tuple[int, int]]]:
    """Read ``"supported_hdims"`` from the JSON, turning each pair list into a tuple.

    The ``"_comment"`` entry is skipped.
    """
    hdims_by_dtype = {}
    for dtype, pairs in _load_fmha_specs()["supported_hdims"].items():
        if dtype == "_comment":
            continue
        hdims_by_dtype[dtype] = [tuple(pair) for pair in pairs]
    return hdims_by_dtype
|
||||
|
||||
|
||||
def _build_arch_metadata() -> Dict[str, dict]:
    """Return a shallow copy of the JSON ``"architectures"`` table."""
    architectures = _load_fmha_specs()["architectures"]
    return dict(architectures)
|
||||
|
||||
|
||||
# Import-time snapshots of the per-architecture data in fmha_arch_specs.json.
ARCH_DTYPES: Dict[str, List[str]] = _build_arch_dtypes()  # arch → supported dtypes
SUPPORTED_HDIMS: Dict[str, List[Tuple[int, int]]] = (
    _build_supported_hdims()  # dtype → list of hdim pairs
)
ARCH_METADATA: Dict[str, dict] = _build_arch_metadata()  # arch → JSON entry
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 2. Tile hardware parameters (loaded from fmha_arch_specs.json + parent arch_specs)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _build_warp_classes() -> Dict[str, List[Tuple[int, int, int]]]:
    """Read ``"fmha_warp_tiles"`` from the JSON, turning each warp entry into a tuple.

    The ``"_comment"`` entry is skipped.
    """
    warps_by_dtype = {}
    for dtype, warps in _load_fmha_specs()["fmha_warp_tiles"].items():
        if dtype == "_comment":
            continue
        warps_by_dtype[dtype] = [tuple(w) for w in warps]
    return warps_by_dtype
|
||||
|
||||
|
||||
def _build_lds_limits() -> Dict[str, int]:
    """Return a shallow copy of the JSON ``"lds_limits"`` table."""
    # NOTE(review): unlike the sibling builders, a "_comment" key is not
    # filtered out here -- confirm the JSON section carries none.
    limits = _load_fmha_specs()["lds_limits"]
    return dict(limits)
|
||||
|
||||
|
||||
def _build_k0max_map() -> Dict[int, int]:
    """Read ``"k0max_map"`` from the JSON, converting its string keys to ints.

    The ``"_comment"`` entry is skipped.
    """
    k0max = {}
    for key, value in _load_fmha_specs()["k0max_map"].items():
        if key == "_comment":
            continue
        k0max[int(key)] = value
    return k0max
|
||||
|
||||
|
||||
# Import-time snapshots of the tile-related tables in fmha_arch_specs.json.
_specs = _load_fmha_specs()
_tile_ranges = _specs["tile_sweep_ranges"]

LDS_LIMITS: Dict[str, int] = _build_lds_limits()
WARP_CLASSES: Dict[str, List[Tuple[int, int, int]]] = _build_warp_classes()
ELEMENT_SIZES: Dict[str, int] = _build_element_sizes()

# Candidate values for the bm0 / bn0 / bk0 tile dimensions.
VALID_BM0: List[int] = _tile_ranges["valid_bm0"]
VALID_BN0: List[int] = _tile_ranges["valid_bn0"]
VALID_BK0: List[int] = _tile_ranges["valid_bk0"]

K0_MAX_SUBMAX_MAP: Dict[int, int] = _build_k0max_map()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 3. Tile constraints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def check_gfx9_tile_constraints(
    dtype: str,
    hdim_q: int,
    hdim_v: int,
    pipeline: str,
    bm0: int,
    bn0: int,
    bk0: int,
) -> bool:
    """Return True when the tile passes the gfx9 hdim/tile rules.

    Source: fmha_fwd.py CompatibilityRuleFactoryGfx9.check_hdim_tile().

    Nothing is rejected for ``fp32`` or for pipelines other than ``qr``,
    ``qr_async`` and ``qs``. Otherwise:

    * hdim (128, 128): ``bn0`` must be 128; ``qr_async`` additionally needs
      ``bm0 == 128``, while the other pipelines must not use ``bk0 == 64``.
    * any other hdim pair: ``bm0`` must be one of 64, 128, 192, 256 (per the
      original note this is deliberately looser than the CK factory, which
      only accepts 128).
    """
    if dtype == "fp32" or pipeline not in ("qr", "qr_async", "qs"):
        return True

    if (hdim_q, hdim_v) != (128, 128):
        return bm0 in (64, 128, 192, 256)

    # hdim 128x128 from here on.
    if bn0 != 128:
        return False
    if pipeline == "qr_async":
        return bm0 == 128
    return bk0 != 64
|
||||
|
||||
|
||||
def check_gfx950_tile_constraints(
    hdim_q: int,
    hdim_v: int,
    pipeline: str,
    bm0: int,
    bn0: int,
) -> bool:
    """Return True when the tile passes the gfx950 trload/v3 rules.

    Source: fmha_fwd.py CompatibilityRuleFactoryGfx950.check_tile_pipeline().

    * ``qr_async_trload`` only accepts hdim (64, 64) or (128, 128), and for
      (128, 128) it rejects ``bn0 == 128``.
    * ``qr_async_trload_v3`` requires ``bm0 == 256``. The implication is
      one-way: other pipelines may also use ``bm0 == 256`` (the original note
      says CK itself enforces the rule in both directions).
    """
    if pipeline == "qr_async_trload":
        hdims = (hdim_q, hdim_v)
        if hdims == (128, 128):
            if bn0 == 128:
                return False
        elif hdims != (64, 64):
            return False

    if pipeline == "qr_async_trload_v3":
        return bm0 == 256
    return True
|
||||
|
||||
|
||||
def check_qr_mfma_insts(
    arch: str,
    hdim_q: int,
    pipeline: str,
    bn0: int,
    bk0: int,
    wn0: int,
    wk0: int,
) -> bool:
    """Check that ``(bn0 // wn0) * (bk0 // wk0)`` is a multiple of 8.

    The check only applies when ``pipeline == "qr"``, ``hdim_q == 256`` and
    ``arch`` starts with ``"gfx9"``; every other combination passes.

    Per the original note, this stands in for the NumMfmaInsts % 8 == 0
    static_assert in block_fmha_pipeline_qr_ks_vs.hpp, whose full formula is
    (kM0/WarpM)*(kN0/WarpN)*(kK0/WarpK) / (MWarp*NWarp); the M-dimension and
    warp-count factors are left out here.
    """
    in_scope = pipeline == "qr" and hdim_q == 256 and arch.startswith("gfx9")
    if not in_scope:
        return True
    mfma_count = (bn0 // wn0) * (bk0 // wk0)
    return mfma_count % 8 == 0
|
||||
|
||||
|
||||
def _async_k_vector(arch: str, elem_size: int) -> int:
    """Elements per async K load (GetAlignmentK in custom_policy.hpp).

    KVector = MaxLoadSizeInBytes / sizeof(dtype): gfx950 uses dwordx4
    (16 bytes), every other arch uses dword (4 bytes).
    """
    return 16 // elem_size if arch == "gfx950" else 4 // elem_size


def _async_kv_lds_bytes(
    arch: str,
    elem_size: int,
    hdim_v: int,
    bm0: int,
    bn0: int,
    bk0: int,
    wm0: int,
) -> int:
    """LDS bytes used by the K/V buffers of an async pipeline.

    Q stays in registers (QLoadOnce=true), so LDS holds NumKVLdsBuffers (=3)
    copies of max(SingleKSize, SingleVSize). Derived from GetSmemSizeKV() in
    block_fmha_pipeline_qx_ks_vs_custom_policy.hpp.

    SingleVSize (MakeVLdsBlockDescriptor):
        PixelsPerRow = Banks * 4 / sizeof(dtype)
        kKPack       = 16 / sizeof(dtype)            (GetSmemKPackV)
        NPerRow      = PixelsPerRow / kKPack
        SingleVSize  = (bk1/kKPack) * (hdim_v/NPerRow) * (PixelsPerRow + kKPack)

    SingleKSize (GetSingleSmemElementSpaceSize, async branch):
        LanesPerK   = bk0 / KVector, LaneGroups = 64 / LanesPerK
        NumIssues   = bn0 / (LaneGroups * NumWarps)
        SingleKSize = NumIssues * NumWarps * (64 * KVector + KPack)
    """
    bk1 = 32  # kK1 in TileFmhaShape — design choice from fmha_fwd.py tile defs
    num_warps = bm0 // wm0
    # Banks: arch.hpp get_n_lds_banks() — 64 for gfx950, 32 for older
    banks = 64 if arch == "gfx950" else 32
    pixels_per_row = banks * 4 // elem_size
    k_pack = 16 // elem_size
    n_per_row = pixels_per_row // k_pack
    single_v = (bk1 // k_pack) * (hdim_v // n_per_row) * (pixels_per_row + k_pack)

    k_vector = _async_k_vector(arch, elem_size)
    lanes_per_k = bk0 // k_vector if k_vector > 0 else 1
    lane_groups = 64 // lanes_per_k if lanes_per_k > 0 else 1  # WarpSize=64
    issuers = lane_groups * num_warps
    num_issues = bn0 // issuers if issuers > 0 else 0
    single_k = num_issues * num_warps * (64 * k_vector + k_pack)

    # NumPrefetchK = NumPrefetchV = 3 (async_default_policy.hpp)
    num_kv_buffers = 3
    return max(single_k, single_v) * elem_size * num_kv_buffers


def tile_passes_all_constraints(
    arch: str,
    dtype: str,
    hdim_q: int,
    hdim_v: int,
    pipeline: str,
    bm0: int,
    bn0: int,
    bk0: int,
    wm0: int,
    wn0: int,
    wk0: int,
) -> bool:
    """Master constraint check — returns True if the tile is valid.

    Checks, in order: positive tile/warp dimensions, LDS capacity
    (pipeline-dependent formula), bk0 <= hdim_q, hdim_q divisible by bk0,
    warp alignment of the block tile, the MFMA instruction-count rule, the
    async DMA distribution rule for ``qr_async`` on gfx9*, and finally the
    gfx9 / gfx950 arch-specific rules.

    Args:
        arch: target architecture name (e.g. ``"gfx942"``, ``"gfx950"``).
        dtype: data type name; looked up in ELEMENT_SIZES (default 2 bytes).
        hdim_q, hdim_v: head dimensions.
        pipeline: pipeline name; looked up in LDS_LIMITS (default 65536).
        bm0, bn0, bk0: block tile sizes.
        wm0, wn0, wk0: warp tile sizes.

    Returns:
        True when every constraint holds, False otherwise. Non-positive
        tile or warp dimensions are rejected up front: they are never legal
        and would otherwise raise ZeroDivisionError in the divisions and
        modulo operations below.
    """
    if min(bm0, bn0, bk0, wm0, wn0, wk0) <= 0:
        return False

    elem_size = ELEMENT_SIZES.get(dtype, 2)
    lds_limit = LDS_LIMITS.get(pipeline, 65536)

    # LDS capacity check (pipeline-dependent formula)
    if pipeline in ("qr_async", "qr_async_trload", "qr_async_trload_v3"):
        total_lds = _async_kv_lds_bytes(arch, elem_size, hdim_v, bm0, bn0, bk0, wm0)
        # gfx950 HW LDS limit: arch.hpp get_smem_capacity() = 163840 (160 KiB)
        # NOTE(review): this bound is applied to every arch and lds_limit is
        # not consulted on this branch — confirm that is intended for
        # non-gfx950 targets.
        if total_lds > 163840:
            return False
    elif (bm0 * bk0 + bn0 * bk0) * elem_size > lds_limit:
        # Non-async (qr/qs): Q and K tiles share LDS simultaneously
        return False

    # bk0 range, then hdim_q divisibility (tile_fmha_shape.hpp:60)
    if bk0 > hdim_q or hdim_q % bk0 != 0:
        return False

    # Warp alignment
    if bm0 % wm0 != 0 or bk0 % wk0 != 0 or bn0 % wn0 != 0:
        return False

    # MFMA inst count
    if not check_qr_mfma_insts(arch, hdim_q, pipeline, bn0, bk0, wn0, wk0):
        return False

    # Async DMA distribution constraint (MakeKLdsStoreBlockDescriptor,
    # custom_policy.hpp). NumIssues = kNPerBlock / (LaneGroups * NumWarps)
    # must be a positive integer, where LaneGroups = 64 / (bk0 / KVector).
    # Equivalently: (bn0 * bk0) % (kBlockSize * KVector) == 0.
    if pipeline == "qr_async" and arch.startswith("gfx9"):
        block_size = (bm0 // wm0) * 64  # NumWarps * WarpSize
        if (bn0 * bk0) % (block_size * _async_k_vector(arch, elem_size)) != 0:
            return False

    # Arch constraints
    if arch in ("gfx90a", "gfx942", "gfx950") and not check_gfx9_tile_constraints(
        dtype, hdim_q, hdim_v, pipeline, bm0, bn0, bk0
    ):
        return False
    if arch == "gfx950" and not check_gfx950_tile_constraints(
        hdim_q, hdim_v, pipeline, bm0, bn0
    ):
        return False
    return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 4. Feature compatibility rules
|
||||
# =============================================================================
|
||||
|
||||
# Supported mask, bias, and boolean values for feature products.
# These are the template enum values in CK's FMHA traits structs.
MASKS = ["no", "causal", "generic"]
BIASES = ["no", "bias", "alibi"]
# Boolean features are spelled as the strings "t" / "f" (the same spelling
# that check_logits_bias and check_group_mode_padding compare against).
BOOLS = ["t", "f"]

# Dtype groups matching CK's _DT_* classification in fmha_fwd.py factory classes.
DT_FP16_BF16 = {"fp16", "bf16"}
DT_FP8 = {"fp8bf16", "fp8", "bf8"}
DT_FP8FP32 = {"fp8fp32"}
DT_FP32 = {"fp32"}
|
||||
|
||||
|
||||
def check_logits_bias(logits: str, bias: str) -> bool:
    """Reject logits_soft_cap combined with any bias.

    Source: fmha_fwd.py CompatibilityRuleFactory.check_feature().

    Returns:
        False only when ``logits`` is ``"t"`` and ``bias`` is anything other
        than ``"no"``; True otherwise.
    """
    return logits != "t" or bias == "no"
|
||||
|
||||
|
||||
def check_group_mode_padding(mode: str, spad: str, skpad: str) -> bool:
    """Group mode is only valid with both spad and skpad set to "t".

    Source: fmha_fwd.py CompatibilityRuleFactory.check_feature() +
    block_fmha_pipeline static_asserts for padding.

    Returns:
        True for any mode other than ``"group"``; for ``"group"``, True only
        when ``spad == "t"`` and ``skpad == "t"``.
    """
    return mode != "group" or (spad == "t" and skpad == "t")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 5. Variant-specific tile tables (loaded from fmha_arch_specs.json)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _build_bwd_tiles() -> Tuple[
    Dict[Tuple[int, int], Tuple[int, ...]],
    Dict[Tuple[int, int], List[Tuple[Tuple[int, ...], str, bool]]],
    Dict[Tuple[int, int, int, str], dict],
]:
    """Build the three BWD lookup tables from the loaded FMHA specs.

    Returns:
        A ``(main, extra, wave_warp)`` triple:
          * main: ``(hdim_q, hdim_v)`` -> tile tuple, from
            ``bwd_tiles["dq_dk_dv_fp16"]``;
          * extra: ``(hdim_q, hdim_v)`` -> list of
            ``(tile, tag, batch_only)``, from the optional
            ``bwd_tiles["dq_dk_dv_extra"]``;
          * wave_warp: ``(bm0, bn0, bk0, trload)`` ->
            ``{"wave": tuple, "warp_k1": ...}``, from ``bwd_wave_warp``
            (keys starting with ``"_"`` are ignored).
    """

    def _hdim_pair(text: str) -> Tuple[int, int]:
        # "hdimq_hdimv" -> (hdimq, hdimv)
        head_q, head_v = map(int, text.split("_"))
        return (head_q, head_v)

    bwd_spec = _load_fmha_specs()["bwd_tiles"]

    main_tiles = {
        _hdim_pair(name): tuple(tile)
        for name, tile in bwd_spec["dq_dk_dv_fp16"].items()
    }

    extra_tiles = {
        _hdim_pair(name): [
            (tuple(item["tile"]), item["tag"], item["batch_only"]) for item in items
        ]
        for name, items in bwd_spec.get("dq_dk_dv_extra", {}).items()
    }

    # Wave/warp lookup: "bm0_bn0_bk0_trload" -> {wave, warp_k1}
    wave_warp = {}
    for name, spec in _load_fmha_specs()["bwd_wave_warp"].items():
        if not name.startswith("_"):
            fields = name.split("_")
            lookup = (int(fields[0]), int(fields[1]), int(fields[2]), fields[3])
            wave_warp[lookup] = {
                "wave": tuple(spec["wave"]),
                "warp_k1": spec["warp_k1"],
            }

    return main_tiles, extra_tiles, wave_warp
|
||||
|
||||
|
||||
def _build_splitkv_hdims() -> Tuple[List[int], List[int]]:
    """Return the SplitKV combine hdim lists as ``(hdims_fp16, hdims_fp8)``.

    Both lists are read from the ``splitkv_combine`` section of the loaded
    FMHA specs.
    """
    combine_spec = _load_fmha_specs()["splitkv_combine"]
    fp16_hdims = combine_spec["hdims_fp16"]
    fp8_hdims = combine_spec["hdims_fp8"]
    return fp16_hdims, fp8_hdims
|
||||
|
||||
|
||||
# Materialize the JSON-backed tables once at import time.
_bwd_main, _bwd_extra, _bwd_ww = _build_bwd_tiles()
_skv_fp16, _skv_fp8 = _build_splitkv_hdims()

# Public, typed names for the tables built above.
SPLITKV_COMBINE_HDIMS_FP16: List[int] = _skv_fp16
SPLITKV_COMBINE_HDIMS_FP8: List[int] = _skv_fp8
BWD_DQ_DK_DV_TILES_FP16: Dict[Tuple[int, int], Tuple[int, ...]] = _bwd_main
BWD_DQ_DK_DV_EXTRA_TILES: Dict[
    Tuple[int, int], List[Tuple[Tuple[int, ...], str, bool]]
] = _bwd_extra
BWD_DQ_WAVE_WARP: Dict[Tuple[int, int, int, str], dict] = _bwd_ww

# Remaining BWD settings are read straight from the "bwd_tiles" section.
_bwd_json = _load_fmha_specs()["bwd_tiles"]
# JSON arrays are converted to tuples so each pad combo is a (str, str) pair.
BWD_EXTRA_PAD_COMBOS: List[Tuple[str, str]] = [
    tuple(p) for p in _bwd_json["extra_pad_combos"]
]
BWD_SMALL_DROPOUTS: List[str] = _bwd_json["small_dropouts"]
BWD_DOT_DO_O_HDIMS: List[int] = _bwd_json["dot_do_o_hdims"]
BWD_CONVERT_DQ_HDIMS: List[int] = _bwd_json["convert_dq_hdims"]
# Keys are converted with int() because they arrive as strings.
BWD_CONVERT_DQ_TILE_GROUPS: Dict[int, int] = {
    int(k): v for k, v in _bwd_json["convert_dq_tile_groups"].items()
}
BWD_DROPOUTS: List[str] = _bwd_json["dropouts"]
BWD_PAD_COMBOS: List[Tuple[str, str]] = [tuple(p) for p in _bwd_json["pad_combos"]]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 6. Receipt filters
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class Receipt(IntEnum):
    """Named receipt levels for deployment profiles.

    These are deployment-specific filters, not derived from C++ constraints.
    They control which kernel subsets are emitted for different integration
    targets (PyTorch, AITER, Flash-Attention, etc.).

    Members are ints, so they compare equal to the raw receipt numbers used
    as RECEIPT_FILTERS keys, and PROFILE_ALIASES maps ``str(value)`` to the
    lower-cased member name.
    """

    CK_DEFAULT = 0
    CK_EXTENDED = 1
    FLASH_FWD = 2
    FLASH_BWD = 3
    PYTORCH = 4
    AITER_BATCH = 100
    AITER_GROUP = 200
    AITER_BWD_BATCH = 300
    AITER_BWD_GROUP = 400
    AITER_CPP = 600
    FP32_ALL = 800
    FP32_MIN = 801
    FP8_TEST = 888
|
||||
|
||||
|
||||
# Per-receipt kernel filters, keyed by the numeric receipt value.
# Each filter takes (dtype, spec) and returns True when the kernel is kept.
# Spec attributes are read with getattr() and a permissive default, so a spec
# object that lacks an attribute is treated as having that feature switched off.
RECEIPT_FILTERS: Dict[int, Callable[[str, object], bool]] = {
    0: lambda dtype, spec: dtype != "fp32",
    2: lambda dtype, spec: (
        dtype in ("fp16", "bf16")
        and getattr(spec, "bias", "no") in ("no", "alibi")
        and getattr(spec, "qscale", "no") == "no"
        and getattr(spec, "skip", "f") == "f"
        and getattr(spec, "sink", "f") == "f"
    ),
    4: lambda dtype, spec: (
        dtype in ("fp16", "bf16")
        and getattr(spec, "bias", "no") in ("no", "bias")
        and getattr(spec, "qscale", "no") == "no"
        and getattr(spec, "skip", "f") == "f"
        and getattr(spec, "logits", "f") == "f"
    ),
    100: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
    200: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
    600: lambda dtype, spec: dtype in ("fp16", "bf16", "fp8bf16"),
    888: lambda dtype, spec: dtype in ("fp8bf16", "fp8fp32"),
    800: lambda dtype, spec: (
        dtype == "fp32"
        and getattr(spec, "skip", "f") == "f"
        and getattr(spec, "logits", "f") == "f"
    ),
    # 801 (FP32_MIN) is an fp32-only receipt. Without an entry it fell through
    # to receipt_filter's "dtype != fp32" default, which kept everything
    # except the fp32 kernels it is meant to select. It uses the fp32_all
    # conditions plus "no bias"; the hdim/mode/lse/dropout restrictions of the
    # "fp32_min" profile are not expressed here because this block does not
    # show how they are spelled on the spec object.
    801: lambda dtype, spec: (
        dtype == "fp32"
        and getattr(spec, "bias", "no") == "no"
        and getattr(spec, "skip", "f") == "f"
        and getattr(spec, "logits", "f") == "f"
    ),
}


def receipt_filter(receipt: int, dtype: str, spec) -> bool:
    """Apply receipt-level filter. Returns True if the kernel should be kept.

    Args:
        receipt: numeric receipt value (a Receipt member also works, since
            it is an int).
        dtype: kernel data type name.
        spec: object whose feature attributes (bias, qscale, skip, sink,
            logits) are read with getattr() by the filters.

    Returns:
        The result of the matching RECEIPT_FILTERS entry. Receipts without
        an entry use the default rule, which keeps every dtype except
        ``"fp32"``.
    """
    fn = RECEIPT_FILTERS.get(receipt)
    if fn is None:
        return dtype != "fp32"
    return fn(dtype, spec)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 7. Profiles
|
||||
# =============================================================================
|
||||
|
||||
# Maps the decimal string of each Receipt value to its lower-cased member
# name (e.g. "0" -> "ck_default"), so numeric receipts can select a profile.
PROFILE_ALIASES: Dict[str, str] = {str(r.value): r.name.lower() for r in Receipt}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FmhaProfile:
    """A named kernel-selection profile backed by a predicate over a config dict."""

    # Profile name; in PROFILES it matches the dictionary key.
    name: str
    # Callable that receives a config dict and returns True to keep it.
    predicate: Callable[[dict], bool]

    def allows(self, config: dict) -> bool:
        """Return the predicate's verdict for ``config``."""
        return self.predicate(config)
|
||||
|
||||
|
||||
def _dtype_is(config: dict, allowed: Iterable[str]) -> bool:
|
||||
return config["signature"]["data_type"] in set(allowed)
|
||||
|
||||
|
||||
def _mode_is(config: dict, allowed: Iterable[str]) -> bool:
|
||||
return config["signature"]["mode"] in set(allowed)
|
||||
|
||||
|
||||
def _family_is(config: dict, allowed: Iterable[str]) -> bool:
|
||||
return config["signature"]["family"] in set(allowed)
|
||||
|
||||
|
||||
def _common_row_major_filter(config: dict) -> bool:
|
||||
return config["signature"]["vlayout"] == "r"
|
||||
|
||||
|
||||
def _bias_is(config: dict, allowed: Iterable[str]) -> bool:
|
||||
return canonical_bias(config["signature"]["bias"]) in set(allowed)
|
||||
|
||||
|
||||
def _qscale_is(config: dict, allowed: Iterable[str]) -> bool:
|
||||
return canonical_qscale(config["signature"]["qscale"]) in set(allowed)
|
||||
|
||||
|
||||
def _no_skip_or_logits(config: dict) -> bool:
|
||||
return (not config["signature"]["skip_min_seqlen_q"]) and (
|
||||
not config["signature"]["logits"]
|
||||
)
|
||||
|
||||
|
||||
# Value sets shared by several profile predicates.
_PROFILE_HALF_DTYPES = {"fp16", "bf16"}
_PROFILE_AITER_DTYPES = {"fp16", "bf16", "fp8bf16"}
_PROFILE_BWD_FAMILIES = {"bwd_dot_do_o", "bwd_dq_dk_dv", "bwd_convert_dq"}


def _profile_not_fp32(c: dict) -> bool:
    """Everything except fp32 (ck_default / ck_extended)."""
    return c["signature"]["data_type"] != "fp32"


def _profile_fp8bf16_hdim_ok(c: dict) -> bool:
    """fp8bf16 configs are limited to hdim_q 128 or 192; other dtypes pass."""
    return c["signature"]["data_type"] != "fp8bf16" or c["signature"][
        "hdim_q"
    ] in {128, 192}


def _profile_flash_fwd(c: dict) -> bool:
    return (
        _family_is(c, {"fwd", "fwd_splitkv", "fwd_appendkv", "fwd_pagedkv"})
        and _dtype_is(c, _PROFILE_HALF_DTYPES)
        and _common_row_major_filter(c)
        and _bias_is(c, {"no", "alibi"})
        and _qscale_is(c, {"no"})
        and not c["signature"]["skip_min_seqlen_q"]
    )


def _profile_flash_bwd(c: dict) -> bool:
    return _family_is(c, _PROFILE_BWD_FAMILIES) and _dtype_is(c, _PROFILE_HALF_DTYPES)


def _profile_pytorch(c: dict) -> bool:
    return (
        _dtype_is(c, _PROFILE_HALF_DTYPES)
        and _common_row_major_filter(c)
        and _bias_is(c, {"no", "bias"})
        and _qscale_is(c, {"no"})
        and _no_skip_or_logits(c)
        and not c["signature"].get("sink", False)
    )


def _profile_aiter_batch(c: dict) -> bool:
    return (
        _dtype_is(c, _PROFILE_AITER_DTYPES)
        and _mode_is(c, {"batch"})
        and _common_row_major_filter(c)
        and _profile_fp8bf16_hdim_ok(c)
    )


def _profile_aiter_group(c: dict) -> bool:
    return (
        _dtype_is(c, _PROFILE_AITER_DTYPES)
        and _mode_is(c, {"group"})
        and _common_row_major_filter(c)
    )


def _profile_aiter_bwd_batch(c: dict) -> bool:
    return (
        _family_is(c, _PROFILE_BWD_FAMILIES)
        and _dtype_is(c, _PROFILE_HALF_DTYPES)
        and _mode_is(c, {"batch"})
    )


def _profile_aiter_bwd_group(c: dict) -> bool:
    return (
        _family_is(c, _PROFILE_BWD_FAMILIES)
        and _dtype_is(c, _PROFILE_HALF_DTYPES)
        and _mode_is(c, {"group"})
    )


def _profile_aiter_cpp(c: dict) -> bool:
    return (
        _dtype_is(c, _PROFILE_AITER_DTYPES)
        and _common_row_major_filter(c)
        and _profile_fp8bf16_hdim_ok(c)
    )


def _profile_fp32_all(c: dict) -> bool:
    return _dtype_is(c, {"fp32"}) and _no_skip_or_logits(c)


def _profile_fp32_min(c: dict) -> bool:
    return (
        _dtype_is(c, {"fp32"})
        and _mode_is(c, {"batch"})
        and c["signature"]["hdim_q"] in {48, 128}
        and c["signature"]["hdim_v"] in {48, 128}
        and canonical_bias(c["signature"]["bias"]) == "no"
        and not c["signature"]["lse"]
        and not c["signature"]["dropout"]
        and canonical_qscale(c["signature"]["qscale"]) == "no"
    )


def _profile_fp8_test(c: dict) -> bool:
    return (
        _dtype_is(c, {"fp8bf16", "fp8fp32"})
        and c["signature"]["hdim_q"] in {128, 192}
        and _common_row_major_filter(c)
    )


def _profile_all(_: dict) -> bool:
    return True


# Registry of profiles keyed by name. Each profile's name doubles as its key,
# so the mapping is derived from the profile objects themselves.
PROFILES: Dict[str, FmhaProfile] = {
    entry.name: entry
    for entry in (
        FmhaProfile("ck_default", _profile_not_fp32),
        FmhaProfile("ck_extended", _profile_not_fp32),
        FmhaProfile("flash_fwd", _profile_flash_fwd),
        FmhaProfile("flash_bwd", _profile_flash_bwd),
        FmhaProfile("pytorch", _profile_pytorch),
        FmhaProfile("aiter_batch", _profile_aiter_batch),
        FmhaProfile("aiter_group", _profile_aiter_group),
        FmhaProfile("aiter_bwd_batch", _profile_aiter_bwd_batch),
        FmhaProfile("aiter_bwd_group", _profile_aiter_bwd_group),
        FmhaProfile("aiter_cpp", _profile_aiter_cpp),
        FmhaProfile("fp32_all", _profile_fp32_all),
        FmhaProfile("fp32_min", _profile_fp32_min),
        FmhaProfile("fp8_test", _profile_fp8_test),
        FmhaProfile("all", _profile_all),
    )
}
|
||||
|
||||
|
||||
def normalize_profile(
    profile: Optional[str] = None, receipt: Optional[str] = None
) -> str:
    """Resolve a profile name from an explicit profile or a receipt.

    A truthy ``profile`` wins; otherwise a ``receipt`` that is not None is
    used; otherwise the result is ``"ck_default"``. The chosen value is
    converted to ``str`` and translated through PROFILE_ALIASES, falling back
    to the string itself when no alias exists.
    """
    if profile:
        selected = profile
    elif receipt is not None:
        selected = receipt
    else:
        return "ck_default"
    key = str(selected)
    return PROFILE_ALIASES.get(key, key)
|
||||
|
||||
|
||||
def get_profile(
    profile: Optional[str] = None, receipt: Optional[str] = None
) -> FmhaProfile:
    """Look up the FmhaProfile selected by ``profile`` / ``receipt``.

    Raises:
        KeyError: when the normalized name is not a key of PROFILES.
    """
    name = normalize_profile(profile=profile, receipt=receipt)
    try:
        return PROFILES[name]
    except KeyError:
        raise KeyError(f"Unknown FMHA profile: {name}") from None
|
||||
|
||||
|
||||
def profile_allows(
    config: dict, profile: Optional[str] = None, receipt: Optional[str] = None
) -> bool:
    """Return whether the selected profile accepts ``config``.

    The profile is resolved with get_profile(), so an unknown name raises
    KeyError.
    """
    selected = get_profile(profile=profile, receipt=receipt)
    return selected.allows(config)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 8. Validation helpers (for unified_fmha_codegen)
|
||||
# =============================================================================
|
||||
|
||||
# Sections of the loaded FMHA specs that load_arch_specs() re-exports.
_DEFAULTS: dict = _load_fmha_specs()["defaults"]
_GLOBAL_RULES: dict = _load_fmha_specs()["global_rules"]


def load_arch_specs() -> dict:
    """Return arch_specs dict compatible with unified_fmha_codegen.

    The result has four keys: ``architectures`` (ARCH_METADATA),
    ``defaults``, ``global_rules`` and ``splitkv_combine``; the last three
    come from the loaded FMHA specs.
    """
    splitkv_combine = _load_fmha_specs()["splitkv_combine"]
    return dict(
        architectures=ARCH_METADATA,
        defaults=_DEFAULTS,
        global_rules=_GLOBAL_RULES,
        splitkv_combine=splitkv_combine,
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 9. Config validation (for unified_fmha_codegen)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Outcome of a config validation: a validity flag plus collected messages."""

    # Starts True and is cleared by the first add_error() call.
    valid: bool = True
    # Each instance gets its own list via default_factory.
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def add_error(self, msg: str):
        """Record ``msg`` as an error and mark the result invalid."""
        self.errors.append(msg)
        self.valid = False

    def add_warning(self, msg: str):
        """Record ``msg`` as a warning; validity is left untouched."""
        self.warnings.append(msg)
|
||||
|
||||
|
||||
def _validate_family_rules(
    result: "ValidationResult",
    family: str,
    sig: dict,
    pipeline: str,
    tile,
    hdim_v: int,
    arch_specs: dict,
) -> None:
    """Append the family-specific errors/warnings for ``family`` to ``result``."""
    row_major_v = sig.get("vlayout", "r") == "r"
    mode = sig.get("mode", "batch")

    if family == "batch_prefill":
        if not row_major_v:
            result.add_error("batch_prefill only supports row-major V layout")
        if not sig.get("paged_kv", False):
            result.add_error("batch_prefill requires paged_kv=true")
        ps = sig.get("page_size", 0)
        # (ps & (ps - 1)) == 0 holds exactly for powers of two when ps > 0.
        if ps <= 0 or (ps & (ps - 1)) != 0:
            result.add_error("batch_prefill page_size must be a positive power of two")
        if mode != "group":
            result.add_error("batch_prefill requires group mode")
        if pipeline != "qr_async":
            result.add_error("batch_prefill currently uses qr_async pipeline")

    if family == "fwd_appendkv":
        if mode != "batch":
            result.add_error("fwd_appendkv uses batch-mode public API surface")
        if pipeline != "appendkv":
            result.add_error("fwd_appendkv must use appendkv pipeline")
        if not row_major_v:
            result.add_error("fwd_appendkv currently only supports row-major V")

    if family == "fwd_splitkv_combine":
        if mode not in {"batch", "group"}:
            result.add_error("fwd_splitkv_combine requires batch or group mode")
        combine_bn1 = arch_specs.get("splitkv_combine", {}).get("combine_bn1", 32)
        if len(tile) > 3:
            bn1 = tile[3]
            if bn1 != combine_bn1:
                result.add_error(f"fwd_splitkv_combine requires bn1={combine_bn1}")
            # bn1 <= 0 is reported as a divisibility failure; a zero bn1 would
            # otherwise raise ZeroDivisionError in the modulo below.
            if bn1 <= 0 or hdim_v < bn1 or hdim_v % bn1 != 0:
                result.add_error("fwd_splitkv_combine requires hdim_v divisible by bn1")

    if family == "fwd_pagedkv":
        if pipeline != "qr_pagedkv":
            result.add_error("fwd_pagedkv must use qr_pagedkv pipeline")
        if not sig.get("paged_kv", False):
            result.add_error("fwd_pagedkv requires paged_kv=true")
        if not row_major_v:
            result.add_error("fwd_pagedkv currently only supports row-major V")

    if family == "fwd_splitkv":
        if pipeline not in {"qr", "qr_nwarp_sshuffle"}:
            result.add_error("fwd_splitkv must use qr or qr_nwarp_sshuffle pipeline")
        if not row_major_v:
            result.add_error("fwd_splitkv currently only supports row-major V")

    if family == "fwd" and not row_major_v:
        result.add_warning("dispatcher forward examples currently assume row-major V")


def validate_config(
    config: dict, arch_specs: Optional[dict] = None
) -> "ValidationResult":
    """Validate an FMHA kernel config against all rules.

    Args:
        config: dict with ``"signature"``, ``"algorithm"`` and ``"arch"``
            entries. The signature supplies data_type, family, mask, bias,
            hdim_q and hdim_v (plus optional feature flags); the algorithm
            supplies pipeline, tile, wave, warp, block_per_cu and
            num_wave_groups.
        arch_specs: optional specs dict; when falsy, load_arch_specs() is
            used.

    Returns:
        A ValidationResult. An unknown architecture short-circuits with a
        single error; otherwise every rule is evaluated and all errors and
        warnings are accumulated.

    Raises:
        KeyError: when a required key is missing from ``config``.
    """
    arch_specs = arch_specs or load_arch_specs()
    result = ValidationResult()

    sig = config["signature"]
    alg = config["algorithm"]
    arch = config["arch"]

    architectures = arch_specs.get("architectures", ARCH_METADATA)
    if arch not in architectures:
        result.add_error(f"Unsupported FMHA target architecture: {arch}")
        return result

    arch_info = architectures[arch]
    global_rules = arch_specs.get("global_rules", _GLOBAL_RULES)
    dtype = sig["data_type"]
    family = sig["family"]
    pipeline = alg["pipeline"]
    # Return value intentionally unused; the call is kept for whatever
    # checking canonical_mask performs on the mask name.
    canonical_mask(sig["mask"])
    bias = canonical_bias(sig["bias"])

    # Family validation
    supported_families = {
        "fwd",
        "fwd_pagedkv",
        "fwd_splitkv",
        "fwd_splitkv_combine",
        "fwd_appendkv",
        "batch_prefill",
        "bwd_dot_do_o",
        "bwd_dq_dk_dv",
        "bwd_convert_dq",
    }
    if family not in supported_families:
        result.add_error(f"Unsupported FMHA family: {family}")

    # Dtype validation
    supported_dtypes = set(arch_info["supported_dtypes"])
    if dtype not in supported_dtypes:
        result.add_error(f"dtype {dtype} is not supported on {arch}")

    if family.startswith("bwd") and dtype not in BWD_DTYPE_MAP:
        result.add_error(
            f"Backward family {family} only supports {sorted(BWD_DTYPE_MAP)}"
        )

    if (
        family.startswith("fwd")
        and not family.startswith("fwd_append")
        and dtype not in FWD_DTYPE_MAP
    ):
        result.add_error(f"Forward family {family} does not recognize dtype {dtype}")

    # Pipeline validation
    if (
        family != "fwd_splitkv_combine"
        and pipeline not in arch_info["supported_pipelines"]
    ):
        result.add_error(f"pipeline {pipeline} is not supported on {arch}")

    if pipeline in {"v3", "qr_async_trload_v3"} and not arch_info.get(
        "supports_v3", False
    ):
        result.add_warning(f"v3 pipeline on {arch} requires supports_v3 in arch specs")

    if pipeline == "qr_async_trload" and not arch_info.get("supports_trload", False):
        result.add_error("qr_async_trload requires a trload-capable architecture")

    # Global rules
    hdim_q = sig["hdim_q"]
    hdim_v = sig["hdim_v"]
    divisor = global_rules.get("hdim_divisible_by", 8)
    if hdim_q % divisor != 0 or hdim_v % divisor != 0:
        result.add_error(f"Head dimensions must be multiples of {divisor}")

    if global_rules.get("hdim_192_128_no_bias_dropout"):
        if (
            hdim_q == 192
            and hdim_v == 128
            and (bias != "no" or sig.get("dropout", False))
        ):
            result.add_warning(
                "hdim (192,128) with bias/dropout has limited tile support"
            )

    if global_rules.get("logits_requires_no_bias"):
        if bias != "no" and sig.get("logits", False):
            result.add_error("logits_soft_cap cannot be combined with bias")

    if pipeline in {"qr_async_trload", "v3", "qr_async_trload_v3"} and (
        hdim_q != hdim_v or hdim_q not in {64, 128}
    ):
        result.add_error(f"{pipeline} only supports symmetric head dims 64 or 128")

    # Tile validation
    tile = alg["tile"]
    expected_tile_len = 9 if family == "bwd_dq_dk_dv" else 6
    if len(tile) != expected_tile_len or len(alg["wave"]) != 9 or len(alg["warp"]) != 9:
        result.add_error(
            f"tile/wave/warp must have {expected_tile_len}/9/9 elements for {family}"
        )

    # MFMA instruction count check for qr/h256/CDNA. The length guards keep
    # the indexing safe even when the tile/wave/warp length error fired above.
    _1d_families = {"bwd_dot_do_o", "bwd_convert_dq"}
    if (
        pipeline == "qr"
        and hdim_q == 256
        and family not in _1d_families
        and arch_info.get("family", "").startswith("cdna")
        and len(tile) >= 3
        and len(alg["wave"]) >= 2
        and len(alg["warp"]) >= 3
    ):
        wm, wn, wk = alg["warp"][0], alg["warp"][1], alg["warp"][2]
        gm, gn = alg["wave"][0], alg["wave"][1]
        # Skip the check when a divisor would be zero or negative.
        if wm > 0 and wn > 0 and wk > 0 and gm > 0 and gn > 0:
            num_mfma = (tile[0] // wm) * (tile[1] // wn) * (tile[2] // wk) // (gm * gn)
            if num_mfma % 8 != 0:
                result.add_error(
                    f"NumMfmaInsts={num_mfma} must be divisible by 8 for qr/h256/CDNA"
                )

    if alg["block_per_cu"] <= 0 and alg["block_per_cu"] != -1:
        result.add_error("block_per_cu must be positive or -1 (auto)")
    if alg["num_wave_groups"] <= 0:
        result.add_error("num_wave_groups must be positive")

    # --- Family-specific rules ---
    _validate_family_rules(result, family, sig, pipeline, tile, hdim_v, arch_specs)

    return result
|
||||
@@ -230,7 +230,7 @@ def generate_cpp_header(specs: Dict[str, Any], output_path: Path):
|
||||
|
||||
for arch, data in archs.items():
|
||||
enum_name = arch.upper().replace("GFX", "GFX_")
|
||||
arch_enums.append(f" {enum_name}, // {data['description']}")
|
||||
arch_enums.append(f" {enum_name},")
|
||||
arch_to_string_cases.append(
|
||||
f' case GpuArch::{enum_name}: return "{arch}";'
|
||||
)
|
||||
@@ -288,12 +288,12 @@ def generate_cpp_header(specs: Dict[str, Any], output_path: Path):
|
||||
f" if (pipeline == Pipeline::{pipeline_enum_map[pipeline]}) return {limit};"
|
||||
)
|
||||
|
||||
content = f"""// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
content = f"""// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
|
||||
*
|
||||
*
|
||||
* Generated from: arch_specs.json
|
||||
* Generated at: {timestamp}
|
||||
*
|
||||
|
||||
310
dispatcher/codegen/grouped_config_rules.py
Normal file
310
dispatcher/codegen/grouped_config_rules.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Single Source of Truth for Grouped Convolution Tile Configurations
|
||||
|
||||
This module defines all valid tile configurations for grouped convolution kernels.
|
||||
Both codegen and instance_builder import from here to ensure consistency.
|
||||
|
||||
Architecture:
|
||||
grouped_config_rules.py (SOURCE OF TRUTH)
|
||||
├── Used by unified_grouped_conv_codegen.py
|
||||
└── Used by grouped_conv_instance_builder.py
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
# =============================================================================
|
||||
# Tile Configurations (Single Source of Truth)
|
||||
# =============================================================================
|
||||
|
||||
# Common tile configurations used across variants.
# Format: (tile_m, tile_n, tile_k)
# CRITICAL: tile_m MUST equal wave_m × warp_tile_m (TileGemmShape constraint).
# Only tiles that successfully compile are included; the wave and warp-tile
# choice noted beside each entry is the one recorded in TILE_TO_WAVE and
# TILE_TO_WARP.
COMMON_TILES: List[Tuple[int, int, int]] = [
    # Using warp_tile [16,16,16]: tile_m = wave_m × 16
    (16, 64, 64),  # 1 × 16 = 16, wave=(1,4,1)
    (32, 64, 64),  # 2 × 16 = 32, wave=(2,2,1)
    (64, 64, 64),  # 4 × 16 = 64, wave=(4,1,1)
    # (128, 64, 64) with wave=(8,2,1) is excluded: compile error
    # Using warp_tile [32,32,16]: tile_m = wave_m × 32
    (32, 128, 64),  # 1 × 32 = 32, wave=(1,4,1)
    (64, 128, 64),  # 2 × 32 = 64, wave=(2,2,1)
    (128, 128, 64),  # 4 × 32 = 128, wave=(4,4,1)
    # Note: 256x64x64 excluded - compilation issues
    # Using warp_tile [16,16,32]: tile_m = wave_m × 16
    (16, 64, 128),  # 1 × 16 = 16, wave=(1,4,1)
    (32, 64, 128),  # 2 × 16 = 32, wave=(2,2,1)
    (64, 64, 128),  # 4 × 16 = 64, wave=(4,1,1)
    (128, 64, 128),  # 8 × 16 = 128, wave=(8,2,1)
    # Note: Excluded tiles:
    # - 128x64x64: wave=8x2x1, warp=16x16x16 - compile error
    # - 32x128x128, 64x128x128, 128x128x128, 256x128x128 (warp_tile 32x32x32) - compv4 issues
    # - 256x64x64, 256x128x128 - arch filter rejection
]
|
||||
|
||||
# Per-tile wave and warp-tile assignments, one row per supported tile:
#   ((tile_m, tile_n, tile_k), (wave_m, wave_n, wave_k), (warp_m, warp_n, warp_k))
# Constraint: tile_m == wave_m × warp_tile_m
# Only use approved wave configs from arch_specs.json:
#   [1,4,1], [2,2,1], [4,1,1], [8,2,1], [4,4,1]
# Warp tiles must match the arch_specs.json gfx950 bf16 approved list.
_TILE_WAVE_WARP_ROWS: List[
    Tuple[Tuple[int, int, int], Tuple[int, int, int], Tuple[int, int, int]]
] = [
    # warp_tile [16,16,16]
    ((16, 64, 64), (1, 4, 1), (16, 16, 16)),
    ((32, 64, 64), (2, 2, 1), (16, 16, 16)),
    ((64, 64, 64), (4, 1, 1), (16, 16, 16)),
    # warp_tile [32,32,16]
    ((32, 128, 64), (1, 4, 1), (32, 32, 16)),
    ((64, 128, 64), (2, 2, 1), (32, 32, 16)),
    ((128, 128, 64), (4, 4, 1), (32, 32, 16)),  # balanced 4x4 wave
    # warp_tile [16,16,32]
    ((16, 64, 128), (1, 4, 1), (16, 16, 32)),
    ((32, 64, 128), (2, 2, 1), (16, 16, 32)),
    ((64, 64, 128), (4, 1, 1), (16, 16, 32)),
    ((128, 64, 128), (8, 2, 1), (16, 16, 32)),
]

# Key: (tile_m, tile_n, tile_k) -> (wave_m, wave_n, wave_k)
TILE_TO_WAVE: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    tile: wave for tile, wave, _warp in _TILE_WAVE_WARP_ROWS
}

# Key: (tile_m, tile_n, tile_k) -> (warp_m, warp_n, warp_k)
TILE_TO_WARP: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    tile: warp for tile, _wave, warp in _TILE_WAVE_WARP_ROWS
}

# Vector sizes per tile (for memory operations); every tile uses (4, 8, 8).
# Key: (tile_m, tile_n, tile_k) -> (vec_a, vec_b, vec_c)
TILE_TO_VECTOR: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    tile: (4, 8, 8) for tile, _wave, _warp in _TILE_WAVE_WARP_ROWS
}
|
||||
|
||||
# =============================================================================
|
||||
# Pipeline Variant Suffixes (single source of truth)
|
||||
# =============================================================================
|
||||
# Empirically verified valid (pipeline, wave_mode, has_dsb, has_si) combinations
|
||||
# observed in the 2D and 3D bf16 gfx950 benchmark CSVs. 30 entries total per ndim.
|
||||
# Each tuple: (pipeline, wave_mode, has_dsb, has_si)
|
||||
# wave_mode: "intrawave" | "interwave"
|
||||
# has_dsb: 1 if "_dsb" suffix present (double smem buffer), else 0
|
||||
# has_si: 1 if "_si" suffix present (store immediate), else 0
|
||||
PIPELINE_VARIANTS: List[Tuple[str, str, int, int]] = [
|
||||
# basic_v1: both intra/inter × {∅, dsb, si, dsb_si} = 8 combos
|
||||
("basic_v1", "intrawave", 0, 0),
|
||||
("basic_v1", "intrawave", 1, 0),
|
||||
("basic_v1", "intrawave", 0, 1),
|
||||
("basic_v1", "intrawave", 1, 1),
|
||||
("basic_v1", "interwave", 0, 0),
|
||||
("basic_v1", "interwave", 1, 0),
|
||||
("basic_v1", "interwave", 0, 1),
|
||||
("basic_v1", "interwave", 1, 1),
|
||||
# compv3: intrawave × {∅, dsb, si, dsb_si} = 4 combos
|
||||
("compv3", "intrawave", 0, 0),
|
||||
("compv3", "intrawave", 1, 0),
|
||||
("compv3", "intrawave", 0, 1),
|
||||
("compv3", "intrawave", 1, 1),
|
||||
# compv4: intrawave × {dsb, dsb_si} only = 2 combos
|
||||
("compv4", "intrawave", 1, 0),
|
||||
("compv4", "intrawave", 1, 1),
|
||||
# compv5: intrawave × {∅, dsb, si, dsb_si} = 4 combos
|
||||
("compv5", "intrawave", 0, 0),
|
||||
("compv5", "intrawave", 1, 0),
|
||||
("compv5", "intrawave", 0, 1),
|
||||
("compv5", "intrawave", 1, 1),
|
||||
# compv6: intrawave × {∅, dsb, si, dsb_si} = 4 combos
|
||||
("compv6", "intrawave", 0, 0),
|
||||
("compv6", "intrawave", 1, 0),
|
||||
("compv6", "intrawave", 0, 1),
|
||||
("compv6", "intrawave", 1, 1),
|
||||
# mem: both intra/inter × {∅, dsb, si, dsb_si} = 8 combos
|
||||
("mem", "intrawave", 0, 0),
|
||||
("mem", "intrawave", 1, 0),
|
||||
("mem", "intrawave", 0, 1),
|
||||
("mem", "intrawave", 1, 1),
|
||||
("mem", "interwave", 0, 0),
|
||||
("mem", "interwave", 1, 0),
|
||||
("mem", "interwave", 0, 1),
|
||||
("mem", "interwave", 1, 1),
|
||||
]
|
||||
|
||||
|
||||
def iter_pipeline_variants(pipelines: List[str] = None):
    """Iterate (pipeline, wave_mode, has_dsb, has_si) tuples, optionally filtered.

    Args:
        pipelines: optional collection of pipeline names to keep. If None,
            every entry of PIPELINE_VARIANTS is yielded. A single name given
            as a bare string is treated as a one-element list.

    Yields:
        Entries of PIPELINE_VARIANTS, in table order, whose first element
        (the pipeline name) is one of the requested names.
    """
    if pipelines is None:
        yield from PIPELINE_VARIANTS
        return

    if isinstance(pipelines, str):
        # set("mem") would become {"m", "e"} and silently match nothing,
        # so promote a lone name to a list before building the filter.
        pipelines = [pipelines]

    keep = set(pipelines)
    for entry in PIPELINE_VARIANTS:
        if entry[0] in keep:
            yield entry
|
||||
|
||||
|
||||
# Valid pipelines per variant
|
||||
# All 8 pipelines (basic_v1, mem, compv3-6, comp_async, basic_async_v1) successfully
|
||||
# build and run for all variants in both 2D and 3D (verified via 10_test_all_pipelines.py)
|
||||
VARIANT_PIPELINES: Dict[str, List[str]] = {
|
||||
"forward": [
|
||||
"basic_v1",
|
||||
"mem",
|
||||
"compv3",
|
||||
"compv4",
|
||||
"compv5",
|
||||
"compv6",
|
||||
"comp_async",
|
||||
"basic_async_v1",
|
||||
],
|
||||
"bwd_data": [
|
||||
"basic_v1",
|
||||
"mem",
|
||||
"compv3",
|
||||
"compv4",
|
||||
"compv5",
|
||||
"compv6",
|
||||
"comp_async",
|
||||
"basic_async_v1",
|
||||
],
|
||||
"bwd_weight": [
|
||||
"basic_v1",
|
||||
"mem",
|
||||
"compv3",
|
||||
"compv4",
|
||||
"compv5",
|
||||
"compv6",
|
||||
"comp_async",
|
||||
"basic_async_v1",
|
||||
],
|
||||
}
|
||||
|
||||
# Tiles that support compv4 pipeline
|
||||
# compv4 has stricter requirements due to double buffering and LDS constraints
|
||||
# Pattern: only warp_tile [16,16,16] or [16,16,32] work with compv4
|
||||
# Large warp_tile [32,32,16] and wave [8,2,1] fail arch validation for compv4
|
||||
COMPV4_COMPATIBLE_TILES: List[Tuple[int, int, int]] = [
|
||||
# warp_tile [16,16,16] - all work with compv4
|
||||
(16, 64, 64),
|
||||
(32, 64, 64),
|
||||
(64, 64, 64),
|
||||
# (128, 64, 64), # Excluded: wave=8x2x1 fails for compv4
|
||||
# warp_tile [16,16,32] - all work with compv4
|
||||
(16, 64, 128),
|
||||
(32, 64, 128),
|
||||
(64, 64, 128),
|
||||
# (128, 64, 128), # Excluded: wave=8x2x1 fails for compv4
|
||||
]
|
||||
|
||||
# Backward weight tiles (very restricted due to transpose_tile2d constraints)
|
||||
# Testing all tiles to verify which ones actually work
|
||||
BWD_WEIGHT_TILES: List[Tuple[int, int, int]] = [
|
||||
# warp_tile [16,16,16]
|
||||
(16, 64, 64), # Known working config
|
||||
(32, 64, 64), # Test
|
||||
(64, 64, 64), # Test
|
||||
# warp_tile [32,32,16]
|
||||
(32, 128, 64), # Test
|
||||
(64, 128, 64), # Test
|
||||
(128, 128, 64), # Test
|
||||
# warp_tile [16,16,32]
|
||||
(16, 64, 128), # Test
|
||||
(32, 64, 128), # Test
|
||||
(64, 64, 128), # Test
|
||||
(128, 64, 128), # Test
|
||||
]
|
||||
|
||||
# =============================================================================
|
||||
# Validation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def validate_tile_config(tile_m: int, tile_n: int, tile_k: int) -> bool:
    """Report whether a tile size is registered in every lookup table.

    Returns True only when (tile_m, tile_n, tile_k) is a key of
    TILE_TO_WAVE, TILE_TO_WARP and TILE_TO_VECTOR; False otherwise.
    """
    candidate = (tile_m, tile_n, tile_k)
    for table in (TILE_TO_WAVE, TILE_TO_WARP, TILE_TO_VECTOR):
        if candidate not in table:
            return False
    return True
|
||||
|
||||
|
||||
def get_tile_full_config(tile_m: int, tile_n: int, tile_k: int) -> dict:
    """Collect the wave, warp and vector settings registered for a tile size.

    Returns:
        None if the tile fails validate_tile_config; otherwise a dict with
        keys tile_m, tile_n, tile_k, wave_m, wave_n, wave_k, warp_m, warp_n,
        warp_k, vec_a, vec_b, vec_c (in that order).
    """
    if not validate_tile_config(tile_m, tile_n, tile_k):
        return None

    key = (tile_m, tile_n, tile_k)

    # Unpack all three tables up front so a malformed entry fails before
    # any part of the result is built.
    wave = wave_m, wave_n, wave_k = TILE_TO_WAVE[key]
    warp = warp_m, warp_n, warp_k = TILE_TO_WARP[key]
    vec = vec_a, vec_b, vec_c = TILE_TO_VECTOR[key]

    config = {"tile_m": tile_m, "tile_n": tile_n, "tile_k": tile_k}
    config.update(zip(("wave_m", "wave_n", "wave_k"), wave))
    config.update(zip(("warp_m", "warp_n", "warp_k"), warp))
    config.update(zip(("vec_a", "vec_b", "vec_c"), vec))
    return config
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Summary Statistics
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def print_summary():
    """Write an overview of the registered tile configurations to stdout.

    Prints the tile counts followed by one line per entry of COMMON_TILES
    showing its size and the wave / warp settings looked up for it.
    """
    banner = "=" * 80

    print(banner)
    print("Grouped Convolution Tile Configurations (Single Source of Truth)")
    print(banner)
    print(f"Total tiles: {len(COMMON_TILES)}")
    print(f"Backward weight tiles: {len(BWD_WEIGHT_TILES)}")
    print()
    print("Tile sizes (M×N×K):")

    for tile in COMMON_TILES:
        m, n, k = tile
        wave = TILE_TO_WAVE[tile]
        warp = TILE_TO_WARP[tile]
        wave_text = f"{wave[0]}×{wave[1]}×{wave[2]}"
        warp_text = f"{warp[0]}×{warp[1]}×{warp[2]}"
        print(f" {m:3}×{n:3}×{k:3} wave={wave_text} warp={warp_text}")

    print(banner)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_summary()
|
||||
@@ -41,6 +41,26 @@ except ImportError:
|
||||
ArchFilter = None
|
||||
OperatorType = None
|
||||
|
||||
# Import tile configurations from grouped_config_rules (single source of truth)
|
||||
try:
|
||||
from grouped_config_rules import (
|
||||
COMMON_TILES,
|
||||
TILE_TO_WAVE,
|
||||
TILE_TO_WARP,
|
||||
VARIANT_PIPELINES,
|
||||
BWD_WEIGHT_TILES,
|
||||
COMPV4_COMPATIBLE_TILES,
|
||||
)
|
||||
HAS_TILE_CONFIGS = True
|
||||
except ImportError:
|
||||
HAS_TILE_CONFIGS = False
|
||||
COMMON_TILES = []
|
||||
TILE_TO_WAVE = {}
|
||||
TILE_TO_WARP = {}
|
||||
VARIANT_PIPELINES = {}
|
||||
BWD_WEIGHT_TILES = []
|
||||
COMPV4_COMPATIBLE_TILES = []
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration and Data Structures
|
||||
@@ -494,6 +514,21 @@ struct {kernel_name}_Config {{
|
||||
# Create valid C++ namespace name
|
||||
ns_name = "ns_" + kernel_name.replace("-", "_")
|
||||
|
||||
# basic_v1 / basic_async_v1 inherit BaseGemmPipelineAGmemBGmemCRegV1
|
||||
# whose TailHandler takes (run_func, has_hot_loop) and invokes
|
||||
# run_func(bool_constant<...>) -- 1 lambda arg. Other pipelines pass
|
||||
# (run_func, has_hot_loop, tail_number) and invoke 2-arg run_func.
|
||||
if tr.pipeline in ("basic_v1", "basic_async_v1"):
|
||||
tail_handler_call = "BaseGemmPipeline::TailHandler(Run, has_hot_loop);"
|
||||
run_lambda_signature = "[&](const auto has_hot_loop_)"
|
||||
else:
|
||||
tail_handler_call = (
|
||||
"BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);"
|
||||
)
|
||||
run_lambda_signature = (
|
||||
"[&](const auto has_hot_loop_, const auto tail_number_)"
|
||||
)
|
||||
|
||||
return f"""
|
||||
// Unique namespace for this kernel to avoid conflicts when including multiple kernels
|
||||
namespace {ns_name} {{
|
||||
@@ -605,7 +640,7 @@ struct {kernel_name}_Launcher {{
|
||||
using Kernel = {kernel_type}<
|
||||
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
|
||||
const auto Run = {run_lambda_signature} {{
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
if (!Kernel::IsSupportedArgument(kargs)) {{
|
||||
@@ -621,7 +656,7 @@ struct {kernel_name}_Launcher {{
|
||||
return ave_time;
|
||||
}};
|
||||
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
{tail_handler_call}
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
@@ -1021,7 +1056,10 @@ def get_default_configs(
|
||||
variants: Optional[List[GroupedConvVariant]] = None,
|
||||
ndims: Optional[List[int]] = None,
|
||||
) -> List[GroupedConvKernelConfig]:
|
||||
"""Get default grouped convolution configurations for target architecture"""
|
||||
"""Get default grouped convolution configurations for target architecture.
|
||||
|
||||
Uses tile configurations imported from grouped_config_rules as the single source of truth.
|
||||
"""
|
||||
configs = []
|
||||
|
||||
if variants is None:
|
||||
@@ -1029,39 +1067,53 @@ def get_default_configs(
|
||||
if ndims is None:
|
||||
ndims = [2]
|
||||
|
||||
# Valid configurations per variant (based on CK Tile example configs)
|
||||
# Forward and Backward Data: standard GEMM-like tiles
|
||||
fwd_bwd_data_tiles = [
|
||||
# (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
|
||||
(128, 128, 32, 2, 2, 32, 32, 16), # Standard 128x128
|
||||
(256, 256, 32, 2, 2, 32, 32, 16), # Large 256x256
|
||||
(64, 64, 32, 1, 4, 16, 16, 16), # Small 64x64
|
||||
(128, 64, 32, 2, 2, 32, 32, 16), # Rectangular
|
||||
(16, 64, 64, 1, 4, 16, 16, 32), # Tall and narrow
|
||||
]
|
||||
# Import tile configs from instance builder (single source of truth)
|
||||
if not HAS_TILE_CONFIGS or not COMMON_TILES:
|
||||
log.warning("grouped_config_rules not available, using fallback tile configs")
|
||||
# Fallback to minimal set if grouped_config_rules unavailable
|
||||
fwd_bwd_data_tiles = [
|
||||
(128, 128, 32, 2, 2, 32, 32, 16),
|
||||
(64, 64, 32, 1, 4, 16, 16, 16),
|
||||
(16, 64, 64, 1, 4, 16, 16, 32),
|
||||
]
|
||||
bwd_weight_tiles = [(16, 64, 64, 1, 4, 16, 16, 32)]
|
||||
else:
|
||||
# Build tile list from COMMON_TILES with wave/warp mappings
|
||||
fwd_bwd_data_tiles = []
|
||||
for tile_m, tile_n, tile_k in COMMON_TILES:
|
||||
tile_key = (tile_m, tile_n, tile_k)
|
||||
if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
|
||||
wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
|
||||
warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
|
||||
fwd_bwd_data_tiles.append(
|
||||
(tile_m, tile_n, tile_k, wave_m, wave_n, warp_m, warp_n, warp_k)
|
||||
)
|
||||
|
||||
# Backward Weight: VERY specific tile configs that work with CK Tile's bwd_weight kernel
|
||||
# Based on ConvConfigComputeV3 from CK Tile examples (example/ck_tile/20_grouped_convolution/)
|
||||
# Note: Backward weight has strict constraints on warp configurations due to transpose_tile2d
|
||||
# Only specific warp configs work: (1, 4, 1) and (4, 1, 1) are known to work
|
||||
bwd_weight_tiles = [
|
||||
# (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
|
||||
# ConvConfigComputeV3: The primary working config for backward weight
|
||||
(16, 64, 64, 1, 4, 16, 16, 32),
|
||||
]
|
||||
# Backward weight: use BWD_WEIGHT_TILES from config rules
|
||||
bwd_weight_tiles = []
|
||||
for tile_m, tile_n, tile_k in BWD_WEIGHT_TILES:
|
||||
tile_key = (tile_m, tile_n, tile_k)
|
||||
if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
|
||||
wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
|
||||
warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
|
||||
bwd_weight_tiles.append(
|
||||
(tile_m, tile_n, tile_k, wave_m, wave_n, warp_m, warp_n, warp_k)
|
||||
)
|
||||
|
||||
for variant in variants:
|
||||
# Select tile configs based on variant
|
||||
if variant == GroupedConvVariant.BACKWARD_WEIGHT:
|
||||
tile_configs = bwd_weight_tiles
|
||||
# Backward weight ONLY supports compv3 (compv4/compv5 have transpose_tile2d issues)
|
||||
pipelines = [("compv3", "cshuffle")]
|
||||
# Backward weight supports compv3 and mem pipelines
|
||||
# (compv4/compv5 have transpose_tile2d issues)
|
||||
pipelines = [("compv3", "cshuffle"), ("mem", "default")]
|
||||
# Also generate two-stage variants (fp32 workspace + elementwise convert)
|
||||
two_stage_flags = [False, True]
|
||||
elif variant == GroupedConvVariant.BACKWARD_DATA:
|
||||
tile_configs = fwd_bwd_data_tiles
|
||||
# Backward data ONLY supports compv3 (compv4 has get_length issues in bwd_data kernel)
|
||||
pipelines = [("compv3", "cshuffle")]
|
||||
# Backward data supports compv3 and mem pipelines
|
||||
# (compv4/compv5 have get_length issues in bwd_data kernel)
|
||||
pipelines = [("compv3", "cshuffle"), ("mem", "default")]
|
||||
two_stage_flags = [False]
|
||||
else:
|
||||
tile_configs = fwd_bwd_data_tiles
|
||||
@@ -1080,6 +1132,12 @@ def get_default_configs(
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) in tile_configs:
|
||||
# Skip tiles incompatible with compv4
|
||||
if pipeline == "compv4" and HAS_TILE_CONFIGS:
|
||||
tile_key = (tile_m, tile_n, tile_k)
|
||||
if tile_key not in COMPV4_COMPATIBLE_TILES:
|
||||
continue # Skip this tile for compv4
|
||||
|
||||
for two_stage in two_stage_flags:
|
||||
adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k
|
||||
|
||||
@@ -1609,7 +1667,16 @@ def main():
|
||||
parser.add_argument(
|
||||
"--pipeline",
|
||||
type=str,
|
||||
choices=["mem", "compv3", "compv4", "compv5"],
|
||||
choices=[
|
||||
"basic_v1",
|
||||
"basic_async_v1",
|
||||
"mem",
|
||||
"compv3",
|
||||
"compv4",
|
||||
"compv5",
|
||||
"compv6",
|
||||
"comp_async",
|
||||
],
|
||||
help="Pipeline type",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -1642,6 +1709,16 @@ def main():
|
||||
default=None,
|
||||
help="Double SMEM buffer (true/false)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-image",
|
||||
action="store_true",
|
||||
help="Enable split-image (EnableSplitImage) for large spatial tensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--two-stage",
|
||||
action="store_true",
|
||||
help="Enable two-stage bwd_weight (fp32 workspace + elementwise convert)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1679,7 +1756,13 @@ def main():
|
||||
if args.double_smem_buffer is not None:
|
||||
dsb = args.double_smem_buffer.lower() == "true"
|
||||
else:
|
||||
dsb = pipeline == "compv4" # compv4 requires double buffer
|
||||
# Historical default: only compv4 auto-defaults to dsb=true.
|
||||
# Other pipelines that also require DoubleSmemBuffer (e.g. comp_async)
|
||||
# must be told explicitly via --double-smem-buffer true; otherwise
|
||||
# they will fail loudly at the pipeline header static_assert. This
|
||||
# is intentional -- silent fallback to a different config would
|
||||
# mask the user's input.
|
||||
dsb = pipeline == "compv4"
|
||||
|
||||
trait = GroupedConvTraitConfig(
|
||||
pipeline=pipeline,
|
||||
@@ -1690,6 +1773,8 @@ def main():
|
||||
pad_k=args.pad_k,
|
||||
double_smem_buffer=dsb,
|
||||
num_groups_to_merge=args.num_groups_to_merge,
|
||||
split_image=args.split_image,
|
||||
two_stage=args.two_stage,
|
||||
)
|
||||
config = GroupedConvKernelConfig(
|
||||
tile=tile,
|
||||
@@ -1719,18 +1804,20 @@ def main():
|
||||
print(f" Spatial dims: {args.ndim}")
|
||||
print(f"\nConfigurations ({len(filtered_configs)}):")
|
||||
for cfg in filtered_configs:
|
||||
print(f" - {cfg.name('fp16')}")
|
||||
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
|
||||
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
|
||||
print(
|
||||
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
|
||||
)
|
||||
print(
|
||||
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
|
||||
)
|
||||
print(
|
||||
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
|
||||
)
|
||||
# List configs for each requested datatype (fixes bf16 -> fp16 bug)
|
||||
for dt in args.datatype:
|
||||
print(f" - {cfg.name(dt)}")
|
||||
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
|
||||
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
|
||||
print(
|
||||
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
|
||||
)
|
||||
print(
|
||||
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
|
||||
)
|
||||
print(
|
||||
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
|
||||
)
|
||||
return
|
||||
|
||||
# Generate
|
||||
|
||||
@@ -290,7 +290,7 @@ function(add_declarative_gpu_example NAME SOURCE)
|
||||
COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/example_kernel_builder.py
|
||||
${EXAMPLE_SOURCE}
|
||||
--output-dir ${EXAMPLE_KERNEL_DIR}
|
||||
--include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include"
|
||||
--include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include,${CMAKE_CURRENT_SOURCE_DIR}/../.."
|
||||
--gpu-target ${GPU_TARGET}
|
||||
--jobs ${NPROC}
|
||||
--target-name ${NAME}
|
||||
@@ -456,7 +456,47 @@ add_declarative_gpu_example(grouped_conv_06_bwd_weight grouped_conv/cpp/06_bw
|
||||
add_declarative_gpu_example(grouped_conv_07_benchmark grouped_conv/cpp/07_multi_tile_benchmark.cpp)
|
||||
|
||||
# =============================================================================
|
||||
# Grouped Convolution Python Library - Multi-Kernel (fwd/bwd_data/bwd_weight x 2D/3D)
|
||||
# FMHA C++ Examples
|
||||
# =============================================================================
|
||||
|
||||
add_declarative_gpu_example(fmha_01_basic fmha/cpp/01_basic_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_02_splitkv fmha/cpp/02_splitkv_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_03_kvcache fmha/cpp/03_kvcache_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_04_bwd fmha/cpp/04_bwd_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_05_appendkv fmha/cpp/05_appendkv_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_06_batch_prefill fmha/cpp/06_batch_prefill_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_07_profile_pytorch fmha/cpp/07_profile_pytorch_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_08_profile_flash fmha/cpp/08_profile_flash_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_09_profile_aiter fmha/cpp/09_profile_aiter_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_10_profile_fp32_fp8 fmha/cpp/10_profile_fp32_fp8_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_11_receipt_aliases fmha/cpp/11_receipt_aliases_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_12_registry_json fmha/cpp/12_registry_json_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_13_feature_coverage fmha/cpp/13_feature_coverage_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_14_benchmark_validation fmha/cpp/14_benchmark_validation_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_15_multi_shape fmha/cpp/15_multi_shape_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_16_heuristics fmha/cpp/16_heuristics_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_17_autofill_autocorrect fmha/cpp/17_autofill_autocorrect_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_18_gpu_splitkv fmha/cpp/18_gpu_splitkv_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_19_gpu_masks fmha/cpp/19_gpu_masks_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_20_gpu_bias fmha/cpp/20_gpu_bias_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_21_gpu_features fmha/cpp/21_gpu_features_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_22_gpu_bwd fmha/cpp/22_gpu_bwd_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_23_multi_registry fmha/cpp/23_multi_registry_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_24_per_receipt_registries fmha/cpp/24_per_receipt_registries_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_25_gpu_appendkv_prefill fmha/cpp/25_gpu_appendkv_batchprefill_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_26_dtypes_hdims fmha/cpp/26_dtypes_hdims_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_27_padding_permutation fmha/cpp/27_padding_permutation_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_28_bwd_masks fmha/cpp/28_bwd_masks_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_29_bwd_bias_dropout fmha/cpp/29_bwd_bias_dropout_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_30_bwd_benchmark fmha/cpp/30_bwd_benchmark_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_31_logits_soft_cap fmha/cpp/31_logits_soft_cap_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_32_sink_tokens fmha/cpp/32_sink_tokens_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_33_bwd_deterministic fmha/cpp/33_bwd_deterministic_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_34_bwd_gqa fmha/cpp/34_bwd_gqa_fmha.cpp)
|
||||
add_declarative_gpu_example(fmha_35_generic_mask fmha/cpp/35_generic_mask_fmha.cpp)
|
||||
|
||||
# =============================================================================
|
||||
# Grouped Convolution Python Library - Multi-Kernel (fwd/bwdd/bwdw x 2D/3D)
|
||||
# =============================================================================
|
||||
|
||||
# Kernel output directory for the Python conv library
|
||||
@@ -502,13 +542,67 @@ if(hip_FOUND)
|
||||
endif()
|
||||
add_dependencies(dispatcher_conv_lib generate_conv_fallback_kernels)
|
||||
|
||||
# =============================================================================
|
||||
# FMHA Python Library - Single Fallback Kernel
|
||||
# =============================================================================
|
||||
|
||||
set(FMHA_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/fmha_python_fallback")
|
||||
set(FMHA_DISPATCH_HEADER "${FMHA_FALLBACK_KERNEL_DIR}/fmha_python_dispatch.hpp")
|
||||
set(FMHA_FALLBACK_LIB "${FMHA_FALLBACK_KERNEL_DIR}/libfmha_python_fallback.a")
|
||||
set(FMHA_FALLBACK_SENTINEL "${FMHA_FALLBACK_KERNEL_DIR}/.generated")
|
||||
|
||||
# Generate the FMHA fallback kernel, compile it, and produce both
|
||||
# the dispatch header and a static library with the kernel object.
|
||||
# Uses example_kernel_builder.py with a synthetic source that declares
|
||||
# a single FMHA kernel set, just like the C++ examples do.
|
||||
set(FMHA_FALLBACK_SOURCE "${FMHA_FALLBACK_KERNEL_DIR}/fmha_python_fallback.cpp")
|
||||
# Build rule for the FMHA fallback artifacts used by the Python library.
# Runs generate_fallback.py with --compile into FMHA_FALLBACK_KERNEL_DIR and
# then touches the sentinel file. The rule declares the dispatch header, the
# static library and the sentinel as its outputs.
# DEPENDS on the generator script so that editing it re-runs this rule
# instead of leaving stale outputs in the build tree.
add_custom_command(
    OUTPUT ${FMHA_DISPATCH_HEADER} ${FMHA_FALLBACK_LIB} ${FMHA_FALLBACK_SENTINEL}
    COMMAND ${CMAKE_COMMAND} -E make_directory ${FMHA_FALLBACK_KERNEL_DIR}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/fmha/generate_fallback.py
            --output-dir ${FMHA_FALLBACK_KERNEL_DIR}
            --gpu-target ${GPU_TARGET}
            --compile
            --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include:${CMAKE_CURRENT_SOURCE_DIR}/../include:${CMAKE_CURRENT_SOURCE_DIR}/../.."
    COMMAND ${CMAKE_COMMAND} -E touch ${FMHA_FALLBACK_SENTINEL}
    DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/fmha/generate_fallback.py
    COMMENT "Generating and compiling FMHA fallback kernel for Python library..."
    VERBATIM
)
|
||||
|
||||
add_custom_target(generate_fmha_fallback_kernels
|
||||
DEPENDS ${FMHA_DISPATCH_HEADER} ${FMHA_FALLBACK_LIB})
|
||||
|
||||
# FMHA dynamic library for Python
|
||||
add_library(dispatcher_fmha_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/fmha_ctypes_lib.cpp)
|
||||
target_link_libraries(dispatcher_fmha_lib PRIVATE ck_tile_dispatcher ${FMHA_FALLBACK_LIB})
|
||||
target_include_directories(dispatcher_fmha_lib PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../..
|
||||
${FMHA_FALLBACK_KERNEL_DIR}
|
||||
${FMHA_FALLBACK_KERNEL_DIR}/dispatcher_wrappers
|
||||
)
|
||||
target_compile_options(dispatcher_fmha_lib PRIVATE
|
||||
-include ${FMHA_DISPATCH_HEADER}
|
||||
-DGFX_ARCH="${GPU_TARGET}"
|
||||
-mllvm -enable-noalias-to-md-conversion=0
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
)
|
||||
if(hip_FOUND)
|
||||
target_link_libraries(dispatcher_fmha_lib PRIVATE hip::device hip::host)
|
||||
endif()
|
||||
add_dependencies(dispatcher_fmha_lib generate_fmha_fallback_kernels)
|
||||
|
||||
message(STATUS "GEMM examples configured - kernels will be generated during 'make'")
|
||||
message(STATUS "Grouped Conv examples configured - kernels will be generated during 'make'")
|
||||
message(STATUS "FMHA examples configured - kernels will be generated during 'make'")
|
||||
|
||||
# Convenience target to build all Python ctypes libraries
|
||||
add_custom_target(python_libs
|
||||
DEPENDS dispatcher_gemm_lib dispatcher_conv_lib
|
||||
COMMENT "Building Python ctypes libraries (GEMM + Conv)"
|
||||
DEPENDS dispatcher_gemm_lib dispatcher_conv_lib dispatcher_fmha_lib
|
||||
COMMENT "Building Python ctypes libraries (GEMM + Conv + FMHA)"
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -59,9 +59,17 @@ python3 examples/gemm/python/08_heuristics.py
|
||||
```
|
||||
examples/
|
||||
|---- gemm/
|
||||
| |---- cpp/ # 6 C++ GEMM examples
|
||||
| |---- cpp/ # 7 C++ GEMM examples
|
||||
| +---- python/ # 11 Python GEMM examples
|
||||
|
|
||||
|---- grouped_conv/
|
||||
| |---- cpp/ # 7 C++ Grouped Conv examples
|
||||
| +---- python/ # 6 Python Grouped Conv examples
|
||||
|
|
||||
|---- fmha/
|
||||
| |---- cpp/ # 35 C++ FMHA examples (all variants)
|
||||
| +---- python/ # 38 Python FMHA examples (JIT-compiled)
|
||||
|
|
||||
+---- README.md
|
||||
```
|
||||
|
||||
|
||||
371
dispatcher/examples/fmha/cpp/01_basic_fmha.cpp
Normal file
371
dispatcher/examples/fmha/cpp/01_basic_fmha.cpp
Normal file
@@ -0,0 +1,371 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 01: Basic FMHA Forward with GPU Execution
|
||||
//
|
||||
// Demonstrates the full flow:
|
||||
// 1. Declare kernels via DECL_FMHA_KERNEL_SET
|
||||
// 2. Register and plan
|
||||
// 3. Allocate Q, K, V, O GPU buffers
|
||||
// 4. Run the FMHA forward kernel on GPU
|
||||
// 5. Copy output to host and validate against CPU reference
|
||||
//
|
||||
// Mirrors 01_basic_gemm.cpp for FMHA.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// FMHA tile/wave/warp dimensions correspond to TWO GEMM stages:
|
||||
// Stage 0 (Q * K^T): tile_m0 x tile_n0 x tile_k0 (seqlen_q x seqlen_k x hdim_q)
|
||||
// Stage 1 (Attn * V): tile_m0 x tile_n1 x tile_k1 (seqlen_q x hdim_v x seqlen_k)
|
||||
// Wave/warp follow the same stage pattern: *_m0/n0/k0 for stage 0, *_m1/n1/k1 for stage 1.
|
||||
DECL_FMHA_KERNEL_SET(basic_fmha_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r") // V row-major
|
||||
.hdim(128) // hdim_q = hdim_v = 128
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(false)
|
||||
.dropout(false)
|
||||
.qscale("no"),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 tile: seqlen_q=128, seqlen_k=128, hdim_q=32
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
// Stage 1 tile: hdim_v=128, seqlen_k=32, alignment=128
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
// Wave: 4 warps on m, 1 on n, 1 on k (both stages)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
// Warp tile: 32x32x16 (both stages)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true) // pad_s, pad_sk, pad_d, pad_dv
|
||||
.alignments(128, 128) // hdim_q_alignment, hdim_v_alignment
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// CPU reference for the attention forward pass: O = softmax(scale * Q * K^T) * V.
//
// All tensors are flat, contiguous float arrays indexed as
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]  (overwritten)
// The softmax is taken over seqlen_k for every (batch, head, query) row and is
// stabilised by subtracting the row maximum before exponentiation.
//
// Flat offsets are computed in std::size_t rather than int so that tensors with
// more than 2^31 elements do not overflow the index arithmetic.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    const auto to_index = [](int v) { return static_cast<std::size_t>(v); };

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) position and the start of this head's K and V slabs.
            const std::size_t head   = to_index(b) * to_index(nhead) + to_index(h);
            const std::size_t k_head = head * to_index(seqlen_k) * to_index(hdim_q);
            const std::size_t v_head = head * to_index(seqlen_k) * to_index(hdim_v);

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const std::size_t row   = head * to_index(seqlen_q) + to_index(sq);
                const std::size_t q_row = row * to_index(hdim_q);
                const std::size_t o_row = row * to_index(hdim_v);

                std::vector<float> scores(to_index(seqlen_k), 0.0f);
                float max_score = -1e30f;

                // Scaled dot products of this query row with every key row.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const std::size_t k_row = k_head + to_index(sk) * to_index(hdim_q);
                    float dot               = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + to_index(d)] * K[k_row + to_index(d)];
                    }
                    scores[to_index(sk)] = dot * scale;
                    max_score            = std::max(max_score, scores[to_index(sk)]);
                }

                // Softmax over the key dimension, shifted by the row maximum.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[to_index(sk)] = std::exp(scores[to_index(sk)] - max_score);
                    sum_exp += scores[to_index(sk)];
                }

                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[to_index(sk)] /= sum_exp;
                }

                // Probability-weighted sum of the value rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const std::size_t v_idx =
                            v_head + to_index(sk) * to_index(hdim_v) + to_index(dv);
                        acc += scores[to_index(sk)] * V[v_idx];
                    }
                    O[o_row + to_index(dv)] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 01: FMHA Forward (GPU Execution)", "FMHA with real GPU data");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length (Q and K)");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_flag("--validate", "Validate against CPU reference");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 01: FMHA Forward (GPU Execution)");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("basic_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// Step 2: Plan
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t k_elems = q_elems;
|
||||
const int64_t v_elems = q_elems;
|
||||
const int64_t o_elems = q_elems;
|
||||
|
||||
// Step 3: Allocate GPU buffers
|
||||
std::cout << "\nStep 2: Allocate GPU Buffers\n";
|
||||
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
|
||||
<< "]\n";
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(k_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(v_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(o_elems);
|
||||
|
||||
// Fill Q, K, V with random data
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems);
|
||||
std::vector<FmhaDataType> k_host(k_elems);
|
||||
std::vector<FmhaDataType> v_host(v_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
// Step 4: Set up args with device pointers and strides
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.q_ptr = q_dev.get();
|
||||
fmha_args.k_ptr = k_dev.get();
|
||||
fmha_args.v_ptr = v_dev.get();
|
||||
fmha_args.o_ptr = o_dev.get();
|
||||
|
||||
fmha_args.bias_ptr = nullptr;
|
||||
fmha_args.q_descale_ptr = nullptr;
|
||||
fmha_args.k_descale_ptr = nullptr;
|
||||
fmha_args.v_descale_ptr = nullptr;
|
||||
fmha_args.rand_val_ptr = nullptr;
|
||||
fmha_args.lse_ptr = nullptr;
|
||||
fmha_args.sink_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.scale_s = scale;
|
||||
fmha_args.logits_soft_cap = 0.0f;
|
||||
|
||||
// bhsd layout strides
|
||||
fmha_args.stride_q = hdim;
|
||||
fmha_args.stride_k = hdim;
|
||||
fmha_args.stride_v = hdim;
|
||||
fmha_args.stride_bias = 0;
|
||||
fmha_args.stride_randval = 0;
|
||||
fmha_args.stride_o = hdim;
|
||||
|
||||
fmha_args.nhead_stride_q = seqlen * hdim;
|
||||
fmha_args.nhead_stride_k = seqlen * hdim;
|
||||
fmha_args.nhead_stride_v = seqlen * hdim;
|
||||
fmha_args.nhead_stride_bias = 0;
|
||||
fmha_args.nhead_stride_randval = 0;
|
||||
fmha_args.nhead_stride_lse = 0;
|
||||
fmha_args.nhead_stride_o = seqlen * hdim;
|
||||
fmha_args.nhead_stride_q_descale = 0;
|
||||
fmha_args.nhead_stride_k_descale = 0;
|
||||
fmha_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fmha_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_bias = 0;
|
||||
fmha_args.batch_stride_randval = 0;
|
||||
fmha_args.batch_stride_lse = 0;
|
||||
fmha_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_q_descale = 0;
|
||||
fmha_args.batch_stride_k_descale = 0;
|
||||
fmha_args.batch_stride_v_descale = 0;
|
||||
|
||||
fmha_args.window_size_left = -1;
|
||||
fmha_args.window_size_right = -1;
|
||||
fmha_args.sink_size = 0;
|
||||
fmha_args.mask_type = 0;
|
||||
fmha_args.min_seqlen_q = 0;
|
||||
fmha_args.p_drop = 0.0f;
|
||||
fmha_args.s_randval = false;
|
||||
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fmha_args.block_scale_size_q = 0;
|
||||
fmha_args.block_scale_size_kv = 0;
|
||||
|
||||
// Step 5: Run on GPU
|
||||
std::cout << "\nStep 3: Run FMHA Forward on GPU\n";
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
|
||||
// Step 6: Copy output and validate
|
||||
std::cout << "\nStep 4: Validate\n";
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
// Quick sanity check: output should be non-zero
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
|
||||
|
||||
bool passed = (nonzero > 0);
|
||||
|
||||
if(args.has("--validate"))
|
||||
{
|
||||
// CPU reference
|
||||
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < k_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < v_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
double max_rel_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
max_rel_err = std::max(max_rel_err, rel_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Max rel error: " << max_rel_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
162
dispatcher/examples/fmha/cpp/02_splitkv_fmha.cpp
Normal file
162
dispatcher/examples/fmha/cpp/02_splitkv_fmha.cpp
Normal file
@@ -0,0 +1,162 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel-set declaration for the split-KV example. Two entries are added; each
// passes a signature, an algorithm description and the string "gfx950" to .add().
//   1. family "fwd_splitkv"         - tile_n1 = 128
//   2. family "fwd_splitkv_combine" - tile_n1 = 32
// Apart from the family name and tile_n1, both entries use identical settings
// (fp16, batch mode, hdim 128, LSE on, "qr" pipeline, max_splits_log2 = 6).
// NOTE(review): presumably these correspond to the two plan stages that main()
// requires (split pass followed by combine) - confirm against the planner.
DECL_FMHA_KERNEL_SET(splitkv_fmha_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd_splitkv")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(true)
|
||||
.dropout(false)
|
||||
.qscale("no")
|
||||
.paged_kv(false),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(16)
|
||||
.warp_n0(16)
|
||||
.warp_k0(16)
|
||||
.warp_m1(16)
|
||||
.warp_n1(16)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr")
|
||||
.padding(true, true, true, true)
|
||||
.max_splits_log2(6)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
.add(FmhaSignature()
|
||||
.family("fwd_splitkv_combine")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(true)
|
||||
.dropout(false)
|
||||
.qscale("no")
|
||||
.paged_kv(false),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(32)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(16)
|
||||
.warp_n0(16)
|
||||
.warp_k0(16)
|
||||
.warp_m1(16)
|
||||
.warp_n1(16)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr")
|
||||
.padding(true, true, true, true)
|
||||
.max_splits_log2(6)
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 02: FMHA Split-KV", "Declarative FMHA split-KV planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "1", "Batch size");
|
||||
args.add_option("--nhead", "16", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Query sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
utils::print_header("Example 02: FMHA Split-KV");
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 1);
|
||||
const int nhead = args.get_int("--nhead", 16);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("splitkv_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
// Step 2: Plan
|
||||
std::cout << "\nStep 2: Plan\n";
|
||||
|
||||
fmha_fwd_splitkv_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = true;
|
||||
traits.do_fp8_static_quant = false;
|
||||
traits.has_sink = false;
|
||||
|
||||
fmha_fwd_splitkv_args fmha_args{};
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = 2048;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.num_splits = 8;
|
||||
|
||||
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
auto plan = dispatcher.plan(problem);
|
||||
|
||||
if(!plan.is_valid() || plan.stages.size() != 2)
|
||||
{
|
||||
std::cerr << "Expected a two-stage split-KV plan\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Step 3: Results
|
||||
std::cout << "\nStep 3: Results\n";
|
||||
for(const auto& stage : plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
|
||||
utils::print_separator();
|
||||
return 0;
|
||||
}
|
||||
240
dispatcher/examples/fmha/cpp/03_kvcache_fmha.cpp
Normal file
240
dispatcher/examples/fmha/cpp/03_kvcache_fmha.cpp
Normal file
@@ -0,0 +1,240 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel-set declaration for the KV-cache example. Three entries are added; each
// passes a signature, an algorithm description and the string "gfx950" to .add().
//   1. family "fwd_pagedkv"   - batch mode, LSE off, pipeline "qr_pagedkv"
//   2. family "fwd_appendkv"  - batch mode, rope "inter", pipeline "appendkv"
//   3. family "batch_prefill" - group mode, LSE on, pipeline "qr_async"
// All three enable paged_kv and use kv_cache("vectorized", "sglang", 16).
// NOTE(review): the trailing 16 presumably is the page size (main() sets
// page_block_size / page_size to 16) - confirm against the kv_cache() signature.
DECL_FMHA_KERNEL_SET(kvcache_fmha_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd_pagedkv")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(false)
|
||||
.dropout(false)
|
||||
.qscale("no")
|
||||
.paged_kv(true)
|
||||
.kv_cache("vectorized", "sglang", 16),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_pagedkv")
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
.add(FmhaSignature()
|
||||
.family("fwd_appendkv")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.rope("inter")
|
||||
.paged_kv(true)
|
||||
.kv_cache("vectorized", "sglang", 16),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(64)
|
||||
.tile_n0(64)
|
||||
.tile_k0(128)
|
||||
.tile_n1(128)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.pipeline("appendkv")
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
.add(FmhaSignature()
|
||||
.family("batch_prefill")
|
||||
.dtype("fp16")
|
||||
.mode("group")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(true)
|
||||
.dropout(false)
|
||||
.qscale("no")
|
||||
.paged_kv(true)
|
||||
.kv_cache("vectorized", "sglang", 16),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 03: FMHA KV-Cache", "Declarative FMHA KV-cache planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "1", "Batch size");
|
||||
args.add_option("--nhead", "16", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Prefill query sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
utils::print_header("Example 03: FMHA KV-Cache");
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 1);
|
||||
const int nhead = args.get_int("--nhead", 16);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
// Step 2: Plan PagedKV (decode)
|
||||
std::cout << "\nStep 2: Plan PagedKV (decode)\n";
|
||||
|
||||
fmha_fwd_pagedkv_traits paged_traits{};
|
||||
paged_traits.hdim_q = hdim;
|
||||
paged_traits.hdim_v = hdim;
|
||||
paged_traits.data_type = "fp16";
|
||||
paged_traits.is_group_mode = false;
|
||||
paged_traits.is_v_rowmajor = true;
|
||||
paged_traits.mask_type = mask_enum::no_mask;
|
||||
paged_traits.bias_type = bias_enum::no_bias;
|
||||
paged_traits.use_pagedkv = true;
|
||||
|
||||
fmha_fwd_pagedkv_args paged_args{};
|
||||
paged_args.seqlen_q = 1;
|
||||
paged_args.seqlen_k = 1024;
|
||||
paged_args.batch = batch;
|
||||
paged_args.max_seqlen_q = 1;
|
||||
paged_args.hdim_q = hdim;
|
||||
paged_args.hdim_v = hdim;
|
||||
paged_args.nhead_q = nhead;
|
||||
paged_args.nhead_k = nhead;
|
||||
paged_args.block_table_ptr = reinterpret_cast<void*>(0x1);
|
||||
paged_args.page_block_size = 16;
|
||||
|
||||
auto paged_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(paged_traits, paged_args), gfx_arch));
|
||||
|
||||
// Step 3: Plan AppendKV
|
||||
std::cout << "\nStep 3: Plan AppendKV\n";
|
||||
|
||||
fmha_fwd_appendkv_traits append_traits{};
|
||||
append_traits.hdim_q = hdim;
|
||||
append_traits.hdim_v = hdim;
|
||||
append_traits.data_type = "fp16";
|
||||
append_traits.is_v_rowmajor = true;
|
||||
append_traits.rope_type = rope_enum::interleaved;
|
||||
|
||||
fmha_fwd_appendkv_args append_args{};
|
||||
append_args.seqlen_q = 1;
|
||||
append_args.seqlen_knew = 1;
|
||||
append_args.batch = batch;
|
||||
append_args.hdim_q = hdim;
|
||||
append_args.hdim_v = hdim;
|
||||
append_args.nhead_q = nhead;
|
||||
append_args.nhead_k = nhead;
|
||||
append_args.rotary_dim = hdim;
|
||||
append_args.rotary_cos_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.rotary_sin_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.block_table_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.page_block_size = 16;
|
||||
|
||||
auto append_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(append_traits, append_args), gfx_arch));
|
||||
|
||||
// Step 4: Plan BatchPrefill
|
||||
std::cout << "\nStep 4: Plan BatchPrefill\n";
|
||||
|
||||
fmha_batch_prefill_traits prefill_traits{};
|
||||
prefill_traits.hdim_q = hdim;
|
||||
prefill_traits.hdim_v = hdim;
|
||||
prefill_traits.data_type = "fp16";
|
||||
prefill_traits.is_group_mode = true;
|
||||
prefill_traits.is_v_rowmajor = true;
|
||||
prefill_traits.mask_type = mask_enum::no_mask;
|
||||
prefill_traits.bias_type = bias_enum::no_bias;
|
||||
prefill_traits.has_lse = true;
|
||||
prefill_traits.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
prefill_traits.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
prefill_traits.page_size = 16;
|
||||
|
||||
fmha_batch_prefill_args prefill_args{};
|
||||
prefill_args.batch = batch;
|
||||
prefill_args.seqlen_q = seqlen;
|
||||
prefill_args.seqlen_k = 1024;
|
||||
prefill_args.max_seqlen_q = seqlen;
|
||||
prefill_args.hdim_q = hdim;
|
||||
prefill_args.hdim_v = hdim;
|
||||
prefill_args.nhead_q = nhead;
|
||||
prefill_args.nhead_k = nhead;
|
||||
prefill_args.num_total_pages = 64;
|
||||
prefill_args.page_block_size = 16;
|
||||
prefill_args.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
prefill_args.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
prefill_args.kv_indptr = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.kv_page_indices = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.kv_last_page_lens = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.seqstart_q_ptr = reinterpret_cast<void*>(0x1);
|
||||
|
||||
auto prefill_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch));
|
||||
|
||||
// Step 5: Results
|
||||
std::cout << "\nStep 5: Results\n";
|
||||
std::cout << " PagedKV stages: " << paged_plan.stages.size() << "\n";
|
||||
std::cout << " AppendKV stages: " << append_plan.stages.size() << "\n";
|
||||
std::cout << " BatchPrefill stages: " << prefill_plan.stages.size() << "\n";
|
||||
|
||||
utils::print_separator();
|
||||
return (paged_plan.is_valid() && append_plan.is_valid() && prefill_plan.is_valid()) ? 0 : 1;
|
||||
}
|
||||
154
dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp
Normal file
154
dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel-set declaration for the backward example. Three entries are added; each
// passes a signature, an algorithm description and the string "gfx950" to .add().
//   1. family "bwd_dot_do_o"   - tiles 64/128/32, remaining tile sizes 0
//   2. family "bwd_dq_dk_dv"   - tiles 16/128/128/16/128/32, 9-argument wave()/warp()
//   3. family "bwd_convert_dq" - tiles 64/128, remaining tile sizes 0
// All three share the same signature settings (fp16, batch mode, hdim 128, no
// mask/bias/dropout/dbias, randval not stored, non-deterministic).
// NOTE(review): the "Stage 0 (Q*K^T)" / "Stage 1 (Attn*V)" comments in the first
// entry read like they were carried over from the forward examples - confirm
// that these tile meanings apply to the bwd_dot_do_o family.
DECL_FMHA_KERNEL_SET(bwd_fmha_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_dot_do_o")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.dropout(false)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(false),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
|
||||
.tile_n1(0)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_dq_dk_dv")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.dropout(false)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(false),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(16)
|
||||
.tile_n0(128)
|
||||
.tile_k0(128)
|
||||
.tile_n1(16)
|
||||
.tile_k1(128)
|
||||
.tile_k0max(32)
|
||||
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
|
||||
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
|
||||
.padding(true, true, true, true)
|
||||
.max_seq_len_q(0)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_convert_dq")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.dropout(false)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(false),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(0)
|
||||
.tile_n1(0)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 04: FMHA Backward", "Declarative FMHA backward planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "1", "Batch size");
|
||||
args.add_option("--nhead", "16", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Sequence length (Q and K)");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
utils::print_header("Example 04: FMHA Backward");
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 1);
|
||||
const int nhead = args.get_int("--nhead", 16);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
// Step 2: Plan
|
||||
std::cout << "\nStep 2: Plan\n";
|
||||
|
||||
fmha_bwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_dbias = false;
|
||||
traits.has_dropout = false;
|
||||
traits.is_store_randval = false;
|
||||
traits.is_deterministic = false;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = batch;
|
||||
bwd_args.seqlen_q = seqlen;
|
||||
bwd_args.seqlen_k = seqlen;
|
||||
bwd_args.max_seqlen_q = seqlen;
|
||||
bwd_args.max_seqlen_k = seqlen;
|
||||
bwd_args.hdim_q = hdim;
|
||||
bwd_args.hdim_v = hdim;
|
||||
bwd_args.nhead_q = nhead;
|
||||
bwd_args.nhead_k = nhead;
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, bwd_args), gfx_arch));
|
||||
|
||||
if(!plan.is_valid() || plan.stages.size() < 2)
|
||||
{
|
||||
std::cerr << "Expected a multi-stage backward plan\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Step 3: Results
|
||||
std::cout << "\nStep 3: Results\n";
|
||||
for(const auto& stage : plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
|
||||
utils::print_separator();
|
||||
return 0;
|
||||
}
|
||||
106
dispatcher/examples/fmha/cpp/05_appendkv_fmha.cpp
Normal file
106
dispatcher/examples/fmha/cpp/05_appendkv_fmha.cpp
Normal file
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel-set declaration for the append-KV example: a single entry with family
// "fwd_appendkv" (fp16, batch mode, hdim 128, rope "inter", paged KV enabled,
// kv_cache("vectorized", "sglang", 16)), pipeline "appendkv", passed to .add()
// together with the string "gfx950".
// NOTE(review): the "Stage 0 (Q*K^T)" / "Stage 1 (Attn*V)" comments read like
// they were carried over from the forward examples - confirm that these tile
// meanings apply to the fwd_appendkv family.
DECL_FMHA_KERNEL_SET(appendkv_fmha_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd_appendkv")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.rope("inter")
|
||||
.paged_kv(true)
|
||||
.kv_cache("vectorized", "sglang", 16),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
|
||||
.tile_m0(64)
|
||||
.tile_n0(64)
|
||||
.tile_k0(128)
|
||||
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
|
||||
.tile_n1(128)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.pipeline("appendkv")
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 05: FMHA AppendKV", "Declarative FMHA append-KV planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "1", "Batch size");
|
||||
args.add_option("--nhead", "16", "Number of heads");
|
||||
args.add_option("--seqlen", "1", "Sequence length (tokens to append)");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
utils::print_header("Example 05: FMHA AppendKV");
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 1);
|
||||
const int nhead = args.get_int("--nhead", 16);
|
||||
const int seqlen = args.get_int("--seqlen", 1);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
// Step 2: Plan
|
||||
std::cout << "\nStep 2: Plan\n";
|
||||
|
||||
fmha_fwd_appendkv_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.rope_type = rope_enum::interleaved;
|
||||
|
||||
fmha_fwd_appendkv_args fmha_args{};
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_knew = seqlen;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.rotary_dim = hdim;
|
||||
fmha_args.rotary_cos_ptr = reinterpret_cast<void*>(0x1);
|
||||
fmha_args.rotary_sin_ptr = reinterpret_cast<void*>(0x1);
|
||||
fmha_args.block_table_ptr = reinterpret_cast<void*>(0x1);
|
||||
fmha_args.page_block_size = 16;
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
|
||||
|
||||
if(!plan.is_valid() || plan.stages.size() != 1)
|
||||
{
|
||||
std::cerr << "Expected a single-stage append-KV plan\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Step 3: Results
|
||||
std::cout << "\nStep 3: Results\n";
|
||||
for(const auto& stage : plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
|
||||
utils::print_separator();
|
||||
return 0;
|
||||
}
|
||||
133
dispatcher/examples/fmha/cpp/06_batch_prefill_fmha.cpp
Normal file
133
dispatcher/examples/fmha/cpp/06_batch_prefill_fmha.cpp
Normal file
@@ -0,0 +1,133 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel-set declaration for the batch-prefill example: a single entry with
// family "batch_prefill" (fp16, group mode, hdim 128, LSE on, paged KV enabled,
// kv_cache("vectorized", "sglang", 16)), pipeline "qr_async", passed to .add()
// together with the string "gfx950".
// NOTE(review): the trailing 16 in kv_cache() presumably is the page size
// (main() sets page_size / page_block_size to 16) - confirm against the
// kv_cache() signature.
DECL_FMHA_KERNEL_SET(batch_prefill_fmha_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("batch_prefill")
|
||||
.dtype("fp16")
|
||||
.mode("group")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(true)
|
||||
.dropout(false)
|
||||
.qscale("no")
|
||||
.paged_kv(true)
|
||||
.kv_cache("vectorized", "sglang", 16),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 06: FMHA Batch Prefill",
|
||||
"Declarative FMHA batch-prefill planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "1", "Batch size");
|
||||
args.add_option("--nhead", "16", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Query sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
utils::print_header("Example 06: FMHA Batch Prefill");
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 1);
|
||||
const int nhead = args.get_int("--nhead", 16);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
// Step 2: Plan
|
||||
std::cout << "\nStep 2: Plan\n";
|
||||
|
||||
fmha_batch_prefill_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = true;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = true;
|
||||
traits.kv_memory_layout = ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
traits.kv_lookup_table = ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
traits.page_size = 16;
|
||||
|
||||
fmha_batch_prefill_args fmha_args{};
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = 1024;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.num_total_pages = 64;
|
||||
fmha_args.page_block_size = 16;
|
||||
fmha_args.kv_memory_layout = ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
fmha_args.kv_lookup_table = ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
fmha_args.kv_indptr = reinterpret_cast<void*>(0x1);
|
||||
fmha_args.kv_page_indices = reinterpret_cast<void*>(0x1);
|
||||
fmha_args.kv_last_page_lens = reinterpret_cast<void*>(0x1);
|
||||
fmha_args.seqstart_q_ptr = reinterpret_cast<void*>(0x1);
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
|
||||
|
||||
if(!plan.is_valid() || plan.stages.size() != 1)
|
||||
{
|
||||
std::cerr << "Expected a single-stage batch-prefill plan\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Step 3: Results
|
||||
std::cout << "\nStep 3: Results\n";
|
||||
for(const auto& stage : plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
|
||||
utils::print_separator();
|
||||
return 0;
|
||||
}
|
||||
248
dispatcher/examples/fmha/cpp/07_profile_pytorch_fmha.cpp
Normal file
248
dispatcher/examples/fmha/cpp/07_profile_pytorch_fmha.cpp
Normal file
@@ -0,0 +1,248 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel declarations for Example 07. Every entry is tagged with the "pytorch" profile,
// uses fp16 in batch mode with hdim 128, and is declared for gfx950.
DECL_FMHA_KERNEL_SET(
    pytorch_profile_fmha_kernels,
    // ---- fwd ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("bias")
             .profile("pytorch"),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(64).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // wave / warp settings, stage 0 then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(32)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true),
         "gfx950")
    // ---- fwd_splitkv (LSE enabled, max_splits_log2 = 6) ----
    .add(FmhaSignature()
             .family("fwd_splitkv").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("no")
             .lse(true)
             .profile("pytorch"),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(16)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true)
             .max_splits_log2(6),
         "gfx950")
    // ---- fwd_splitkv_combine (same as above except tile_n1 = 32) ----
    .add(FmhaSignature()
             .family("fwd_splitkv_combine").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("no")
             .lse(true)
             .profile("pytorch"),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(32)
             .tile_n1(32).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(16)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true)
             .max_splits_log2(6),
         "gfx950")
    // ---- fwd_appendkv (no wave/warp settings; k1 and k0max are zero) ----
    .add(FmhaSignature()
             .family("fwd_appendkv").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .profile("pytorch"),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(64).tile_k0(128)
             .tile_n1(128).tile_k1(0).tile_k0max(0)
             .padding(false, true, true, false)
             .pipeline("appendkv"),
         "gfx950")
    // ---- bwd_dot_do_o ----
    .add(FmhaSignature()
             .family("bwd_dot_do_o").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .profile("pytorch"),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(32)
             .tile_n1(0).tile_k1(0).tile_k0max(0)
             .padding(true, true, true, true),
         "gfx950")
    // ---- bwd_dq_dk_dv (uses the nine-argument wave/warp setters) ----
    .add(FmhaSignature()
             .family("bwd_dq_dk_dv").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .profile("pytorch"),
         FmhaAlgorithm()
             .tile_m0(16).tile_n0(128).tile_k0(128)
             .tile_n1(16).tile_k1(128).tile_k0max(32)
             .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
             .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
             .padding(true, true, true, true),
         "gfx950")
    // ---- bwd_convert_dq ----
    .add(FmhaSignature()
             .family("bwd_convert_dq").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .profile("pytorch"),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(0)
             .tile_n1(0).tile_k1(0).tile_k0max(0)
             .padding(true, true, true, true),
         "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 07: PyTorch-Profile FMHA",
|
||||
"Declarative FMHA PyTorch-profile planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
std::cout << "PyTorch-profile FMHA kernels: " << registry.size() << "\n";
|
||||
|
||||
fmha_fwd_traits fwd_traits{};
|
||||
fwd_traits.hdim_q = 128;
|
||||
fwd_traits.hdim_v = 128;
|
||||
fwd_traits.data_type = "fp16";
|
||||
fwd_traits.is_group_mode = false;
|
||||
fwd_traits.is_v_rowmajor = true;
|
||||
fwd_traits.mask_type = mask_enum::no_mask;
|
||||
fwd_traits.bias_type = bias_enum::elementwise_bias;
|
||||
fwd_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.batch = 1;
|
||||
fwd_args.seqlen_q = 128;
|
||||
fwd_args.seqlen_k = 128;
|
||||
fwd_args.max_seqlen_q = 128;
|
||||
fwd_args.hdim_q = 128;
|
||||
fwd_args.hdim_v = 128;
|
||||
fwd_args.nhead_q = 16;
|
||||
fwd_args.nhead_k = 16;
|
||||
|
||||
auto fwd_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch));
|
||||
|
||||
fmha_bwd_traits bwd_traits{};
|
||||
bwd_traits.hdim_q = 128;
|
||||
bwd_traits.hdim_v = 128;
|
||||
bwd_traits.data_type = "fp16";
|
||||
bwd_traits.is_group_mode = false;
|
||||
bwd_traits.mask_type = mask_enum::no_mask;
|
||||
bwd_traits.bias_type = bias_enum::no_bias;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = 1;
|
||||
bwd_args.seqlen_q = 128;
|
||||
bwd_args.seqlen_k = 128;
|
||||
bwd_args.max_seqlen_q = 128;
|
||||
bwd_args.max_seqlen_k = 128;
|
||||
bwd_args.hdim_q = 128;
|
||||
bwd_args.hdim_v = 128;
|
||||
bwd_args.nhead_q = 16;
|
||||
bwd_args.nhead_k = 16;
|
||||
|
||||
auto bwd_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
|
||||
|
||||
std::cout << "Forward plan stages: " << fwd_plan.stages.size() << "\n";
|
||||
std::cout << "Backward plan stages: " << bwd_plan.stages.size() << "\n";
|
||||
return (fwd_plan.is_valid() && bwd_plan.is_valid()) ? 0 : 1;
|
||||
}
|
||||
165
dispatcher/examples/fmha/cpp/08_profile_flash_fmha.cpp
Normal file
165
dispatcher/examples/fmha/cpp/08_profile_flash_fmha.cpp
Normal file
@@ -0,0 +1,165 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel declarations for Example 08: one forward kernel tagged with the "flash_fwd" profile
// and three backward kernels tagged with the "flash_bwd" profile. All are fp16, batch mode,
// hdim 128, declared for gfx950.
DECL_FMHA_KERNEL_SET(
    flash_profile_fmha_kernels,
    // ---- fwd (alibi bias) ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("alibi")
             .profile("flash_fwd"),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(64).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // wave / warp settings, stage 0 then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(32)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true),
         "gfx950")
    // ---- bwd_dot_do_o ----
    .add(FmhaSignature()
             .family("bwd_dot_do_o").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .profile("flash_bwd"),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(32)
             .tile_n1(0).tile_k1(0).tile_k0max(0)
             .padding(true, true, true, true),
         "gfx950")
    // ---- bwd_dq_dk_dv (uses the nine-argument wave/warp setters) ----
    .add(FmhaSignature()
             .family("bwd_dq_dk_dv").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .profile("flash_bwd"),
         FmhaAlgorithm()
             .tile_m0(16).tile_n0(128).tile_k0(128)
             .tile_n1(16).tile_k1(128).tile_k0max(32)
             .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
             .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
             .padding(true, true, true, true),
         "gfx950")
    // ---- bwd_convert_dq ----
    .add(FmhaSignature()
             .family("bwd_convert_dq").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .profile("flash_bwd"),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(0)
             .tile_n1(0).tile_k1(0).tile_k0max(0)
             .padding(true, true, true, true),
         "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 08: Flash-Profile FMHA",
|
||||
"Declarative FMHA Flash-profile planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
std::cout << "Flash-profile FMHA kernels: " << registry.size() << "\n";
|
||||
|
||||
fmha_fwd_traits fwd_traits{};
|
||||
fwd_traits.hdim_q = 128;
|
||||
fwd_traits.hdim_v = 128;
|
||||
fwd_traits.data_type = "fp16";
|
||||
fwd_traits.is_group_mode = false;
|
||||
fwd_traits.is_v_rowmajor = true;
|
||||
fwd_traits.mask_type = mask_enum::no_mask;
|
||||
fwd_traits.bias_type = bias_enum::alibi;
|
||||
fwd_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.batch = 1;
|
||||
fwd_args.seqlen_q = 128;
|
||||
fwd_args.seqlen_k = 128;
|
||||
fwd_args.max_seqlen_q = 128;
|
||||
fwd_args.hdim_q = 128;
|
||||
fwd_args.hdim_v = 128;
|
||||
fwd_args.nhead_q = 16;
|
||||
fwd_args.nhead_k = 16;
|
||||
|
||||
auto fwd_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch));
|
||||
|
||||
fmha_bwd_traits bwd_traits{};
|
||||
bwd_traits.hdim_q = 128;
|
||||
bwd_traits.hdim_v = 128;
|
||||
bwd_traits.data_type = "fp16";
|
||||
bwd_traits.is_group_mode = false;
|
||||
bwd_traits.mask_type = mask_enum::no_mask;
|
||||
bwd_traits.bias_type = bias_enum::no_bias;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = 1;
|
||||
bwd_args.seqlen_q = 128;
|
||||
bwd_args.seqlen_k = 128;
|
||||
bwd_args.max_seqlen_q = 128;
|
||||
bwd_args.max_seqlen_k = 128;
|
||||
bwd_args.hdim_q = 128;
|
||||
bwd_args.hdim_v = 128;
|
||||
bwd_args.nhead_q = 16;
|
||||
bwd_args.nhead_k = 16;
|
||||
|
||||
auto bwd_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
|
||||
|
||||
std::cout << "Flash fwd stages: " << fwd_plan.stages.size() << "\n";
|
||||
std::cout << "Flash bwd stages: " << bwd_plan.stages.size() << "\n";
|
||||
return (fwd_plan.is_valid() && bwd_plan.is_valid()) ? 0 : 1;
|
||||
}
|
||||
212
dispatcher/examples/fmha/cpp/09_profile_aiter_fmha.cpp
Normal file
212
dispatcher/examples/fmha/cpp/09_profile_aiter_fmha.cpp
Normal file
@@ -0,0 +1,212 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel declarations for Example 09. Four fp16 / hdim-128 kernels declared for gfx950, tagged
// with the "aiter_batch", "aiter_group" and "aiter_cpp" profiles. All four share the same
// tile / wave / warp values; they differ in signature and pipeline.
DECL_FMHA_KERNEL_SET(
    aiter_profile_fmha_kernels,
    // ---- fwd, batch mode ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .profile("aiter_batch"),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(128).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // wave / warp settings, stage 0 then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true),
         "gfx950")
    // ---- fwd, group mode ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("group").vlayout("r").hdim(128)
             .profile("aiter_group"),
         FmhaAlgorithm()
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true),
         "gfx950")
    // ---- fwd_pagedkv, batch mode, paged KV cache ----
    .add(FmhaSignature()
             .family("fwd_pagedkv").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .paged_kv(true)
             .profile("aiter_cpp")
             .kv_cache("vectorized", "sglang", 16),
         FmhaAlgorithm()
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_pagedkv")
             .padding(true, true, true, true),
         "gfx950")
    // ---- batch_prefill, group mode, paged KV cache ----
    .add(FmhaSignature()
             .family("batch_prefill").dtype("fp16").mode("group").vlayout("r").hdim(128)
             .paged_kv(true)
             .profile("aiter_cpp")
             .kv_cache("vectorized", "sglang", 16),
         FmhaAlgorithm()
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true),
         "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 09: AITER-Profile FMHA",
|
||||
"Declarative FMHA AITER-profile planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
std::cout << "AITER-profile FMHA kernels: " << registry.size() << "\n";
|
||||
|
||||
fmha_fwd_traits batch_traits{};
|
||||
batch_traits.hdim_q = 128;
|
||||
batch_traits.hdim_v = 128;
|
||||
batch_traits.data_type = "fp16";
|
||||
batch_traits.is_group_mode = false;
|
||||
batch_traits.is_v_rowmajor = true;
|
||||
batch_traits.mask_type = mask_enum::no_mask;
|
||||
batch_traits.bias_type = bias_enum::no_bias;
|
||||
batch_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args batch_args{};
|
||||
batch_args.batch = 1;
|
||||
batch_args.seqlen_q = 128;
|
||||
batch_args.seqlen_k = 128;
|
||||
batch_args.max_seqlen_q = 128;
|
||||
batch_args.hdim_q = 128;
|
||||
batch_args.hdim_v = 128;
|
||||
batch_args.nhead_q = 16;
|
||||
batch_args.nhead_k = 16;
|
||||
|
||||
auto batch_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(batch_traits, batch_args), gfx_arch));
|
||||
|
||||
fmha_batch_prefill_traits prefill_traits{};
|
||||
prefill_traits.hdim_q = 128;
|
||||
prefill_traits.hdim_v = 128;
|
||||
prefill_traits.data_type = "fp16";
|
||||
prefill_traits.is_group_mode = true;
|
||||
prefill_traits.is_v_rowmajor = true;
|
||||
prefill_traits.mask_type = mask_enum::no_mask;
|
||||
prefill_traits.bias_type = bias_enum::no_bias;
|
||||
prefill_traits.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
prefill_traits.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
prefill_traits.page_size = 16;
|
||||
|
||||
fmha_batch_prefill_args prefill_args{};
|
||||
prefill_args.batch = 1;
|
||||
prefill_args.seqlen_q = 128;
|
||||
prefill_args.seqlen_k = 1024;
|
||||
prefill_args.max_seqlen_q = 128;
|
||||
prefill_args.hdim_q = 128;
|
||||
prefill_args.hdim_v = 128;
|
||||
prefill_args.nhead_q = 16;
|
||||
prefill_args.nhead_k = 16;
|
||||
prefill_args.num_total_pages = 64;
|
||||
prefill_args.page_block_size = 16;
|
||||
prefill_args.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
prefill_args.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
prefill_args.kv_indptr = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.kv_page_indices = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.kv_last_page_lens = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.seqstart_q_ptr = reinterpret_cast<void*>(0x1);
|
||||
|
||||
auto prefill_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch));
|
||||
|
||||
std::cout << "AITER batch stages: " << batch_plan.stages.size() << "\n";
|
||||
std::cout << "AITER prefill stages: " << prefill_plan.stages.size() << "\n";
|
||||
return (batch_plan.is_valid() && prefill_plan.is_valid()) ? 0 : 1;
|
||||
}
|
||||
152
dispatcher/examples/fmha/cpp/10_profile_fp32_fp8_fmha.cpp
Normal file
152
dispatcher/examples/fmha/cpp/10_profile_fp32_fp8_fmha.cpp
Normal file
@@ -0,0 +1,152 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel declarations for Example 10: two fp32 forward kernels (hdim 128 and hdim 48) and one
// fp8bf16 forward kernel, tagged with the "fp32_min", "fp32_all" and "fp8_test" profiles and
// declared for gfx950.
DECL_FMHA_KERNEL_SET(
    fp32_fp8_profile_fmha_kernels,
    // ---- fwd, fp32, hdim 128 ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp32").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("no")
             .profile("fp32_min"),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(32).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(16).tile_k0max(128)
             // wave / warp settings, stage 0 then stage 1
             .wave_m0(2).wave_n0(1).wave_k0(1)
             .wave_m1(2).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(16)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true),
         "gfx950")
    // ---- fwd, fp32, hdim 48 ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp32").mode("batch").vlayout("r").hdim(48)
             .mask("no").bias("no")
             .profile("fp32_all"),
         FmhaAlgorithm()
             .tile_m0(32).tile_n0(128).tile_k0(16)
             .tile_n1(48).tile_k1(16).tile_k0max(48)
             .wave_m0(2).wave_n0(1).wave_k0(1)
             .wave_m1(2).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(16)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true),
         "gfx950")
    // ---- fwd, fp8bf16, hdim 128 ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp8bf16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("no")
             .profile("fp8_test"),
         FmhaAlgorithm()
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(32)
             .warp_m1(32).warp_n1(32).warp_k1(32)
             .pipeline("qr_async")
             .padding(true, true, true, true),
         "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 10: FP32/FP8-Profile FMHA",
|
||||
"Declarative FMHA FP32/FP8-profile planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
std::cout << "FP32/FP8-profile FMHA kernels: " << registry.size() << "\n";
|
||||
std::cout << registry.export_json(false) << "\n";
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = 128;
|
||||
traits.hdim_v = 128;
|
||||
traits.data_type = "fp32";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.batch = 1;
|
||||
fmha_args.seqlen_q = 128;
|
||||
fmha_args.seqlen_k = 128;
|
||||
fmha_args.max_seqlen_q = 128;
|
||||
fmha_args.hdim_q = 128;
|
||||
fmha_args.hdim_v = 128;
|
||||
fmha_args.nhead_q = 16;
|
||||
fmha_args.nhead_k = 16;
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
|
||||
|
||||
std::cout << "FP32/FP8-profile plan stages: " << plan.stages.size() << "\n";
|
||||
return plan.is_valid() ? 0 : 1;
|
||||
}
|
||||
176
dispatcher/examples/fmha/cpp/11_receipt_aliases_fmha.cpp
Normal file
176
dispatcher/examples/fmha/cpp/11_receipt_aliases_fmha.cpp
Normal file
@@ -0,0 +1,176 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel declarations for Example 11: four forward kernels (batch mode, hdim 128, gfx950),
// each selected through a numeric receipt value (2, 4, 100, 800) instead of a named profile.
DECL_FMHA_KERNEL_SET(
    receipt_alias_fmha_kernels,
    // ---- receipt 2: fp16, alibi bias ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .bias("alibi")
             .receipt(2),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(64).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // wave / warp settings, stage 0 then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(32)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true),
         "gfx950")
    // ---- receipt 4: fp16, "bias" bias ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .bias("bias")
             .receipt(4),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(32)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true),
         "gfx950")
    // ---- receipt 100: fp16, async pipeline ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .receipt(100),
         FmhaAlgorithm()
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true),
         "gfx950")
    // ---- receipt 800: fp32 ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp32").mode("batch").vlayout("r").hdim(128)
             .receipt(800),
         FmhaAlgorithm()
             .tile_m0(32).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(16).tile_k0max(128)
             .wave_m0(2).wave_n0(1).wave_k0(1)
             .wave_m1(2).wave_n1(1).wave_k1(1)
             .warp_m0(16).warp_n0(16).warp_k0(16)
             .warp_m1(16).warp_n1(16).warp_k1(16)
             .pipeline("qr")
             .padding(true, true, true, true),
         "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 11: Receipt Aliases FMHA",
|
||||
"Declarative FMHA receipt-alias planning");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
|
||||
FmhaRegistry registry;
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
std::cout << "Receipt-alias FMHA kernels: " << registry.size() << "\n";
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = 128;
|
||||
traits.hdim_v = 128;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.batch = 1;
|
||||
fmha_args.seqlen_q = 128;
|
||||
fmha_args.seqlen_k = 128;
|
||||
fmha_args.max_seqlen_q = 128;
|
||||
fmha_args.hdim_q = 128;
|
||||
fmha_args.hdim_v = 128;
|
||||
fmha_args.nhead_q = 16;
|
||||
fmha_args.nhead_k = 16;
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
|
||||
|
||||
std::cout << "Receipt-alias plan stages: " << plan.stages.size() << "\n";
|
||||
return plan.is_valid() ? 0 : 1;
|
||||
}
|
||||
129
dispatcher/examples/fmha/cpp/12_registry_json_fmha.cpp
Normal file
129
dispatcher/examples/fmha/cpp/12_registry_json_fmha.cpp
Normal file
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel declarations for Example 12: a forward kernel, a paged-KV forward kernel and a
// backward dq/dk/dv kernel (all fp16, batch mode, hdim 128, gfx950). No profile or receipt
// is attached to any of them.
DECL_FMHA_KERNEL_SET(
    registry_json_fmha_kernels,
    // ---- fwd ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(128).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // wave / warp settings, stage 0 then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true),
         "gfx950")
    // ---- fwd_pagedkv, paged KV cache ----
    .add(FmhaSignature()
             .family("fwd_pagedkv").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .paged_kv(true)
             .kv_cache("vectorized", "sglang", 16),
         FmhaAlgorithm()
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_pagedkv")
             .padding(true, true, true, true),
         "gfx950")
    // ---- bwd_dq_dk_dv (uses the nine-argument wave/warp setters) ----
    .add(FmhaSignature()
             .family("bwd_dq_dk_dv").dtype("fp16").mode("batch").hdim(128),
         FmhaAlgorithm()
             .tile_m0(16).tile_n0(128).tile_k0(128)
             .tile_n1(16).tile_k1(128).tile_k0max(32)
             .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
             .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
             .padding(true, true, true, true),
         "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 12: Registry JSON FMHA",
|
||||
"Declarative FMHA registry JSON export");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--output", "", "Write JSON to file (optional)");
|
||||
if(!args.parse(argc, argv))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
utils::print_header("Example 12: Registry JSON FMHA");
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const std::string output_path = args.get("--output", "");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("registry_json_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
// Step 2: Export JSON
|
||||
std::cout << "\nStep 2: Export JSON\n";
|
||||
std::string json = registry.export_json(true);
|
||||
std::cout << " JSON size: " << json.size() << " bytes\n";
|
||||
std::cout << json.substr(0, std::min<std::size_t>(json.size(), 240)) << "\n";
|
||||
|
||||
// Step 3: Write to file (if --output specified)
|
||||
if(!output_path.empty())
|
||||
{
|
||||
std::cout << "\nStep 3: Write to File\n";
|
||||
std::ofstream ofs(output_path);
|
||||
if(!ofs.is_open())
|
||||
{
|
||||
std::cerr << " ERROR: Cannot open " << output_path << " for writing\n";
|
||||
return 1;
|
||||
}
|
||||
ofs << json;
|
||||
ofs.close();
|
||||
std::cout << " Written to: " << output_path << "\n";
|
||||
std::cout << " File size: " << json.size() << " bytes\n";
|
||||
}
|
||||
|
||||
utils::print_separator();
|
||||
return registry.size() > 0 ? 0 : 1;
|
||||
}
|
||||
499
dispatcher/examples/fmha/cpp/13_feature_coverage_fmha.cpp
Normal file
499
dispatcher/examples/fmha/cpp/13_feature_coverage_fmha.cpp
Normal file
@@ -0,0 +1,499 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 13: FMHA Feature Coverage
|
||||
// Exercises feature dimensions from the 01_fmha smoke test:
//   bf16, causal masks (top-left, bottom-right), GQA, dropout, bias
//   (elementwise, alibi), multiple hdims (64, 256), group mode, sink tokens.
//   Not yet covered in this file: window_generic masks and col-major V.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Kernel set used by the feature-coverage example. Every entry is a forward
// ("fwd") FMHA kernel with row-major V ("r") declared for gfx950; the entries
// differ in dtype, head dimension, mask, bias, LSE/dropout, batch/group mode
// and sink-token support. The builder calls are kept in their original order.
DECL_FMHA_KERNEL_SET(feature_coverage_kernels,
                     // fp16 forward (basic, needed for GQA and other fp16 tests)
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("no").bias("no")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // bf16 forward
                     .add(FmhaSignature()
                              .family("fwd").dtype("bf16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("no").bias("no")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // hdim 64 (smaller n0/n1 tiles and a 16x16x16 stage-1 warp tile)
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(64)
                              .mask("no").bias("no")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(64).tile_k0(32)
                              .tile_n1(64).tile_k1(32).tile_k0max(64)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(16).warp_n1(16).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(64, 64)
                              .selection_rank(0),
                          "gfx950")

                     // hdim 256 (LSE enabled, "qr" pipeline, no padding)
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(256)
                              .mask("no").bias("no")
                              .lse(true).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              .tile_n1(256).tile_k1(32).tile_k0max(256)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr")
                              .padding(false, false, false, false)
                              .alignments(256, 256)
                              .selection_rank(0),
                          "gfx950")

                     // Mask: causal (top-left and bottom-right share the same compiled kernel;
                     // the mask type is resolved at runtime via the args, not the template)
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("top_left").bias("no")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // Dropout (declared together with LSE)
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("no").bias("no")
                              .lse(true).dropout(true)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // GQA (nhead_q != nhead_k) needs no dedicated entry: it uses the same
                     // kernel, GQA is a runtime concern.

                     // Bias: elementwise
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("no").bias("bias")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // Bias: alibi
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("no").bias("alibi")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // Group mode
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("group").vlayout("r")
                              .hdim(128)
                              .mask("no").bias("no")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // Sink tokens (declared on top of the top-left mask)
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("top_left").bias("no")
                              .lse(false).dropout(false)
                              .qscale("no")
                              .sink(true),
                          FmhaAlgorithm()
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
struct FeatureTest
|
||||
{
|
||||
std::string name;
|
||||
FmhaProblem problem;
|
||||
};
|
||||
|
||||
FeatureTest make_test(const std::string& name,
|
||||
const std::string& dtype,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
int mask,
|
||||
int bias,
|
||||
bool lse,
|
||||
bool dropout,
|
||||
bool group,
|
||||
bool logits,
|
||||
bool sink,
|
||||
int nhead_q = 16,
|
||||
int nhead_k = 16,
|
||||
const std::string& arch = "gfx950")
|
||||
{
|
||||
auto p = FmhaProblemBuilder()
|
||||
.api_family(FmhaApiFamily::Fwd)
|
||||
.kernel_family(FmhaKernelFamily::Fwd)
|
||||
.gfx_arch(arch)
|
||||
.data_type(dtype)
|
||||
.dims(hdim_q, hdim_v, 2, 128, 256)
|
||||
.nheads(nhead_q, nhead_k)
|
||||
.mask_type(mask)
|
||||
.bias_type(bias)
|
||||
.lse(lse)
|
||||
.dropout(dropout)
|
||||
.group_mode(group)
|
||||
.logits_soft_cap(logits)
|
||||
.sink(sink)
|
||||
.build();
|
||||
return {name, p};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
utils::ExampleArgs args("Example 13: FMHA Feature Coverage",
|
||||
"Tests all 01_fmha smoke test features");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
utils::print_header("Example 13: FMHA Feature Coverage");
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("feature_coverage");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
// Step 2: Run feature tests
|
||||
std::cout << "\nStep 2: Run Feature Tests\n";
|
||||
std::vector<FeatureTest> tests = {
|
||||
make_test("bf16_basic", "bf16", 128, 128, 0, 0, false, false, false, false, false),
|
||||
make_test("fp16_hdim64", "fp16", 64, 64, 0, 0, false, false, false, false, false),
|
||||
make_test("fp16_hdim256", "fp16", 256, 256, 0, 0, true, false, false, false, false),
|
||||
make_test("mask_top_left", "fp16", 128, 128, 1, 0, false, false, false, false, false),
|
||||
make_test("mask_bottom_right", "fp16", 128, 128, 2, 0, false, false, false, false, false),
|
||||
make_test("dropout", "fp16", 128, 128, 0, 0, true, true, false, false, false),
|
||||
make_test("gqa_h16_hk4", "fp16", 128, 128, 0, 0, false, false, false, false, false, 16, 4),
|
||||
make_test("bias_elementwise", "fp16", 128, 128, 0, 1, false, false, false, false, false),
|
||||
make_test("bias_alibi", "fp16", 128, 128, 0, 2, false, false, false, false, false),
|
||||
make_test("group_mode", "fp16", 128, 128, 0, 0, false, false, true, false, false),
|
||||
make_test("sink_tokens", "fp16", 128, 128, 1, 0, false, false, false, false, true),
|
||||
};
|
||||
|
||||
int pass = 0;
|
||||
int fail = 0;
|
||||
for(const auto& test : tests)
|
||||
{
|
||||
auto plan = dispatcher.plan(test.problem);
|
||||
bool ok = plan.is_valid();
|
||||
std::cout << (ok ? "[PASS]" : "[FAIL]") << " " << test.name;
|
||||
if(ok)
|
||||
{
|
||||
std::cout << " -> " << plan.stages[0].kernel_id;
|
||||
++pass;
|
||||
}
|
||||
else
|
||||
{
|
||||
++fail;
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
// Step 3: Summary
|
||||
std::cout << "\nStep 3: Summary\n";
|
||||
std::cout << " " << pass << " passed, " << fail << " failed out of " << tests.size() << "\n";
|
||||
|
||||
utils::print_separator();
|
||||
return fail > 0 ? 1 : 0;
|
||||
}
|
||||
404
dispatcher/examples/fmha/cpp/14_benchmark_validation_fmha.cpp
Normal file
404
dispatcher/examples/fmha/cpp/14_benchmark_validation_fmha.cpp
Normal file
@@ -0,0 +1,404 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 14: FMHA Benchmark with Validation
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Warmup runs to stabilize GPU clocks
|
||||
// 2. Repeated benchmark runs with statistics (min/avg/max/median)
|
||||
// 3. Optional CPU reference validation via --verify flag
|
||||
//
|
||||
// Usage:
|
||||
// ./14_benchmark_validation_fmha
|
||||
// ./14_benchmark_validation_fmha --seqlen 256 --batch 4 --repeat 20
|
||||
// ./14_benchmark_validation_fmha --verify
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// Kernel set for the benchmark example: a single fp16 forward, batch-mode
// kernel for hdim 128 with row-major V and no mask, bias, LSE, dropout or
// quantization scale, declared for gfx950.
DECL_FMHA_KERNEL_SET(benchmark_fmha_kernels,
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("no").bias("no")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
// CPU reference for the attention forward pass:
//   O = softmax(scale * Q * K^T) * V, computed independently per (batch, head).
//
// All tensors are dense float buffers addressed in [batch, nhead, seq, hdim]
// order:
//   Q: [batch, nhead, seqlen_q, hdim_q]    K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]    O: [batch, nhead, seqlen_q, hdim_v]
// O must already be sized by the caller; every addressed element is
// overwritten. No mask, bias or dropout is applied, and Q and K/V use the
// same head count.
//
// Flat offsets are computed in std::size_t so that tensors with more than
// INT_MAX elements do not overflow 32-bit index arithmetic.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // One scratch row of attention scores, reused for every query position.
    // Each element is written before it is read, so no per-row reset is needed.
    std::vector<float> scores(static_cast<std::size_t>(std::max(seqlen_k, 0)));

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Index of this (batch, head) slice; shared by all four tensors.
            const std::size_t bh = static_cast<std::size_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const std::size_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const std::size_t o_row = (bh * seqlen_q + sq) * hdim_v;

                // Scaled dot products plus the running maximum for a
                // numerically stable softmax.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const std::size_t k_row = (bh * seqlen_k + sk) * hdim_q;
                    float dot               = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Exponentiate relative to the maximum and accumulate the normalizer.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                // Normalize to probabilities.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const std::size_t v_idx = (bh * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 14: FMHA Benchmark + Validation",
|
||||
"Warmup, repeated benchmark, optional verification");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "8", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Sequence length (Q and K)");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_option("--warmup", "3", "Warmup iterations");
|
||||
args.add_option("--repeat", "10", "Benchmark repetitions");
|
||||
args.add_flag("--verify", "Validate against CPU reference");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 8);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const int warmup = args.get_int("--warmup", 3);
|
||||
const int repeat = args.get_int("--repeat", 10);
|
||||
|
||||
print_header("Example 14: FMHA Benchmark + Validation");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("benchmark_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t o_elems = q_elems;
|
||||
|
||||
// Step 2: Allocate GPU buffers
|
||||
std::cout << "\nStep 2: Allocate GPU Buffers\n";
|
||||
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
|
||||
<< "]\n";
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(o_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems);
|
||||
std::vector<FmhaDataType> k_host(q_elems);
|
||||
std::vector<FmhaDataType> v_host(q_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.q_ptr = q_dev.get();
|
||||
fmha_args.k_ptr = k_dev.get();
|
||||
fmha_args.v_ptr = v_dev.get();
|
||||
fmha_args.o_ptr = o_dev.get();
|
||||
|
||||
fmha_args.bias_ptr = nullptr;
|
||||
fmha_args.q_descale_ptr = nullptr;
|
||||
fmha_args.k_descale_ptr = nullptr;
|
||||
fmha_args.v_descale_ptr = nullptr;
|
||||
fmha_args.rand_val_ptr = nullptr;
|
||||
fmha_args.lse_ptr = nullptr;
|
||||
fmha_args.sink_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.scale_s = scale;
|
||||
fmha_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fmha_args.stride_q = hdim;
|
||||
fmha_args.stride_k = hdim;
|
||||
fmha_args.stride_v = hdim;
|
||||
fmha_args.stride_bias = 0;
|
||||
fmha_args.stride_randval = 0;
|
||||
fmha_args.stride_o = hdim;
|
||||
|
||||
fmha_args.nhead_stride_q = seqlen * hdim;
|
||||
fmha_args.nhead_stride_k = seqlen * hdim;
|
||||
fmha_args.nhead_stride_v = seqlen * hdim;
|
||||
fmha_args.nhead_stride_bias = 0;
|
||||
fmha_args.nhead_stride_randval = 0;
|
||||
fmha_args.nhead_stride_lse = 0;
|
||||
fmha_args.nhead_stride_o = seqlen * hdim;
|
||||
fmha_args.nhead_stride_q_descale = 0;
|
||||
fmha_args.nhead_stride_k_descale = 0;
|
||||
fmha_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fmha_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_bias = 0;
|
||||
fmha_args.batch_stride_randval = 0;
|
||||
fmha_args.batch_stride_lse = 0;
|
||||
fmha_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_q_descale = 0;
|
||||
fmha_args.batch_stride_k_descale = 0;
|
||||
fmha_args.batch_stride_v_descale = 0;
|
||||
|
||||
fmha_args.window_size_left = -1;
|
||||
fmha_args.window_size_right = -1;
|
||||
fmha_args.sink_size = 0;
|
||||
fmha_args.mask_type = 0;
|
||||
fmha_args.min_seqlen_q = 0;
|
||||
fmha_args.p_drop = 0.0f;
|
||||
fmha_args.s_randval = false;
|
||||
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fmha_args.block_scale_size_q = 0;
|
||||
fmha_args.block_scale_size_kv = 0;
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
// Step 3: Warmup runs
|
||||
std::cout << "\nStep 3: Warmup (" << warmup << " iterations)\n";
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 1);
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
o_dev.zero();
|
||||
float t = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
std::cout << " Warmup " << (i + 1) << ": " << std::fixed << std::setprecision(4) << t
|
||||
<< " ms\n";
|
||||
}
|
||||
|
||||
// Step 4: Benchmark runs
|
||||
std::cout << "\nStep 4: Benchmark (" << repeat << " iterations)\n";
|
||||
dispatcher.set_timing(0, 1);
|
||||
std::vector<float> times;
|
||||
times.reserve(repeat);
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
o_dev.zero();
|
||||
float t = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
times.push_back(t);
|
||||
}
|
||||
|
||||
std::sort(times.begin(), times.end());
|
||||
float t_min = times.front();
|
||||
float t_max = times.back();
|
||||
float t_avg = std::accumulate(times.begin(), times.end(), 0.0f) / static_cast<float>(repeat);
|
||||
float t_med =
|
||||
(repeat % 2 == 0) ? (times[repeat / 2 - 1] + times[repeat / 2]) / 2.0f : times[repeat / 2];
|
||||
|
||||
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
double ops = static_cast<double>(problem.num_ops());
|
||||
double tflops_min = ops / (t_max * 1e-3) / 1e12;
|
||||
double tflops_max = ops / (t_min * 1e-3) / 1e12;
|
||||
double tflops_avg = ops / (t_avg * 1e-3) / 1e12;
|
||||
double tflops_med = ops / (t_med * 1e-3) / 1e12;
|
||||
|
||||
std::cout << "\n " << std::setw(10) << "Metric" << " | " << std::setw(12) << "Time(ms)"
|
||||
<< " | " << std::setw(12) << "TFLOPS" << "\n";
|
||||
std::cout << " " << std::string(40, '-') << "\n";
|
||||
std::cout << std::fixed << std::setprecision(4);
|
||||
std::cout << " " << std::setw(10) << "Min" << " | " << std::setw(12) << t_min << " | "
|
||||
<< std::setprecision(2) << std::setw(12) << tflops_max << "\n";
|
||||
std::cout << std::setprecision(4);
|
||||
std::cout << " " << std::setw(10) << "Avg" << " | " << std::setw(12) << t_avg << " | "
|
||||
<< std::setprecision(2) << std::setw(12) << tflops_avg << "\n";
|
||||
std::cout << std::setprecision(4);
|
||||
std::cout << " " << std::setw(10) << "Median" << " | " << std::setw(12) << t_med << " | "
|
||||
<< std::setprecision(2) << std::setw(12) << tflops_med << "\n";
|
||||
std::cout << std::setprecision(4);
|
||||
std::cout << " " << std::setw(10) << "Max" << " | " << std::setw(12) << t_max << " | "
|
||||
<< std::setprecision(2) << std::setw(12) << tflops_min << "\n";
|
||||
|
||||
bool passed = true;
|
||||
|
||||
// Step 5: Optional validation
|
||||
if(args.has("--verify"))
|
||||
{
|
||||
std::cout << "\nStep 5: CPU Reference Validation\n";
|
||||
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
std::vector<float> q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(o_elems, 0.0f);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
double max_rel_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
max_rel_err = std::max(max_rel_err, rel_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Max rel error: " << max_rel_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << "\n Sanity: " << nonzero << " / " << o_elems << " non-zero outputs\n";
|
||||
passed = (nonzero > 0);
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
282
dispatcher/examples/fmha/cpp/15_multi_shape_fmha.cpp
Normal file
282
dispatcher/examples/fmha/cpp/15_multi_shape_fmha.cpp
Normal file
@@ -0,0 +1,282 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 15: Multi-Shape FMHA Sweep
|
||||
//
|
||||
// Demonstrates running a single FMHA kernel across multiple (batch, seqlen)
|
||||
// combinations, producing a performance table. This pattern is useful for
|
||||
// characterizing kernel behavior across the parameter space.
|
||||
//
|
||||
// Usage:
|
||||
// ./15_multi_shape_fmha
|
||||
// ./15_multi_shape_fmha --arch gfx942
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// Kernel set for the multi-shape sweep: one fp16 forward, batch-mode kernel
// for hdim 128 (row-major V, no mask/bias/LSE/dropout/qscale) declared for
// gfx950. The same kernel is reused for every (batch, seqlen) combination.
DECL_FMHA_KERNEL_SET(multi_shape_fmha_kernels,
                     .add(FmhaSignature()
                              .family("fwd").dtype("fp16").mode("batch").vlayout("r")
                              .hdim(128)
                              .mask("no").bias("no")
                              .lse(false).dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
                              .tile_m0(128).tile_n0(128).tile_k0(32)
                              // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
                              .tile_n1(128).tile_k1(32).tile_k0max(128)
                              .wave_m0(4).wave_n0(1).wave_k0(1)
                              .wave_m1(4).wave_n1(1).wave_k1(1)
                              .warp_m0(32).warp_n0(32).warp_k0(16)
                              .warp_m1(32).warp_n1(32).warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {

// One point of the sweep: batch size and sequence length (used for both Q and K).
struct ShapeConfig
{
    int batch;
    int seqlen;
};

// Shapes visited by the sweep, grouped by batch size. Kept as a plain array so
// that the element count can be derived with sizeof at the point of use.
const ShapeConfig SHAPES[] = {
    // batch 1
    {1, 64}, {1, 128}, {1, 256}, {1, 512},
    // batch 2
    {2, 64}, {2, 128}, {2, 256},
    // batch 4
    {4, 64}, {4, 128},
};

} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 15: Multi-Shape FMHA",
|
||||
"Sweep (batch, seqlen) combos with a single kernel");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--nhead", "8", "Number of heads");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int nhead = args.get_int("--nhead", 8);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 15: Multi-Shape FMHA Sweep");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("multi_shape_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// Step 2: Sweep shapes
|
||||
std::cout << "\nStep 2: Shape Sweep (nhead=" << nhead << ", hdim=" << hdim << ")\n\n";
|
||||
|
||||
std::cout << " " << std::setw(6) << "Batch" << " | " << std::setw(8) << "SeqLen" << " | "
|
||||
<< std::setw(12) << "Elements" << " | " << std::setw(10) << "Time(ms)" << " | "
|
||||
<< std::setw(10) << "TFLOPS" << " | " << std::setw(8) << "Status" << "\n";
|
||||
std::cout << " " << std::string(66, '-') << "\n";
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
int pass_count = 0;
|
||||
int total = 0;
|
||||
const int num_shapes = sizeof(SHAPES) / sizeof(SHAPES[0]);
|
||||
|
||||
for(int si = 0; si < num_shapes; ++si)
|
||||
{
|
||||
const auto& shape = SHAPES[si];
|
||||
++total;
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
const int64_t elems = static_cast<int64_t>(shape.batch) * nhead * shape.seqlen * hdim;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(elems);
|
||||
|
||||
std::vector<FmhaDataType> h_buf(elems);
|
||||
for(auto& x : h_buf)
|
||||
x = FmhaDataType(dist(rng));
|
||||
q_dev.copy_from_host(h_buf.data());
|
||||
for(auto& x : h_buf)
|
||||
x = FmhaDataType(dist(rng));
|
||||
k_dev.copy_from_host(h_buf.data());
|
||||
for(auto& x : h_buf)
|
||||
x = FmhaDataType(dist(rng));
|
||||
v_dev.copy_from_host(h_buf.data());
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.q_ptr = q_dev.get();
|
||||
fmha_args.k_ptr = k_dev.get();
|
||||
fmha_args.v_ptr = v_dev.get();
|
||||
fmha_args.o_ptr = o_dev.get();
|
||||
|
||||
fmha_args.bias_ptr = nullptr;
|
||||
fmha_args.q_descale_ptr = nullptr;
|
||||
fmha_args.k_descale_ptr = nullptr;
|
||||
fmha_args.v_descale_ptr = nullptr;
|
||||
fmha_args.rand_val_ptr = nullptr;
|
||||
fmha_args.lse_ptr = nullptr;
|
||||
fmha_args.sink_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fmha_args.seqlen_q = shape.seqlen;
|
||||
fmha_args.seqlen_k = shape.seqlen;
|
||||
fmha_args.batch = shape.batch;
|
||||
fmha_args.max_seqlen_q = shape.seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.scale_s = scale;
|
||||
fmha_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fmha_args.stride_q = hdim;
|
||||
fmha_args.stride_k = hdim;
|
||||
fmha_args.stride_v = hdim;
|
||||
fmha_args.stride_bias = 0;
|
||||
fmha_args.stride_randval = 0;
|
||||
fmha_args.stride_o = hdim;
|
||||
|
||||
fmha_args.nhead_stride_q = shape.seqlen * hdim;
|
||||
fmha_args.nhead_stride_k = shape.seqlen * hdim;
|
||||
fmha_args.nhead_stride_v = shape.seqlen * hdim;
|
||||
fmha_args.nhead_stride_bias = 0;
|
||||
fmha_args.nhead_stride_randval = 0;
|
||||
fmha_args.nhead_stride_lse = 0;
|
||||
fmha_args.nhead_stride_o = shape.seqlen * hdim;
|
||||
fmha_args.nhead_stride_q_descale = 0;
|
||||
fmha_args.nhead_stride_k_descale = 0;
|
||||
fmha_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fmha_args.batch_stride_q = nhead * shape.seqlen * hdim;
|
||||
fmha_args.batch_stride_k = nhead * shape.seqlen * hdim;
|
||||
fmha_args.batch_stride_v = nhead * shape.seqlen * hdim;
|
||||
fmha_args.batch_stride_bias = 0;
|
||||
fmha_args.batch_stride_randval = 0;
|
||||
fmha_args.batch_stride_lse = 0;
|
||||
fmha_args.batch_stride_o = nhead * shape.seqlen * hdim;
|
||||
fmha_args.batch_stride_q_descale = 0;
|
||||
fmha_args.batch_stride_k_descale = 0;
|
||||
fmha_args.batch_stride_v_descale = 0;
|
||||
|
||||
fmha_args.window_size_left = -1;
|
||||
fmha_args.window_size_right = -1;
|
||||
fmha_args.sink_size = 0;
|
||||
fmha_args.mask_type = 0;
|
||||
fmha_args.min_seqlen_q = 0;
|
||||
fmha_args.p_drop = 0.0f;
|
||||
fmha_args.s_randval = false;
|
||||
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fmha_args.block_scale_size_q = 0;
|
||||
fmha_args.block_scale_size_kv = 0;
|
||||
|
||||
bool ok = false;
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
|
||||
std::vector<FmhaDataType> o_host(elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
ok = (nonzero > 0);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR for B=" << shape.batch << " S=" << shape.seqlen << ": "
|
||||
<< e.what() << "\n";
|
||||
}
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::cout << std::fixed;
|
||||
std::cout << " " << std::setw(6) << shape.batch << " | " << std::setw(8) << shape.seqlen
|
||||
<< " | " << std::setw(12) << elems << " | " << std::setprecision(4)
|
||||
<< std::setw(10) << time_ms << " | " << std::setprecision(2) << std::setw(10)
|
||||
<< tflops << " | " << std::setw(8) << (ok ? "PASS" : "FAIL") << "\n";
|
||||
|
||||
if(ok)
|
||||
++pass_count;
|
||||
}
|
||||
|
||||
// Summary
|
||||
print_separator();
|
||||
std::cout << "Results: " << pass_count << "/" << total << " shapes passed\n";
|
||||
std::cout << "Status: " << (pass_count == total ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return (pass_count == total) ? 0 : 1;
|
||||
}
|
||||
428
dispatcher/examples/fmha/cpp/16_heuristics_fmha.cpp
Normal file
428
dispatcher/examples/fmha/cpp/16_heuristics_fmha.cpp
Normal file
@@ -0,0 +1,428 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 16: FMHA Heuristic-Based Kernel Selection
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Two kernels with different tile_m0 (128 vs 64) and selection_rank
|
||||
// 2. Custom heuristic function that picks kernels based on seqlen
|
||||
// 3. dispatcher.set_heuristic() + SelectionStrategy::Heuristic
|
||||
// 4. Planning different problems to show which kernel is selected
|
||||
// 5. GPU execution for at least one problem
|
||||
//
|
||||
// Usage:
|
||||
// ./16_heuristics_fmha
|
||||
// ./16_heuristics_fmha --arch gfx942
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// Kernel set for the heuristic example: two forward kernels with the same
// signature (fp16, batch mode, hdim 128, no mask/bias/lse/dropout/qscale) that
// differ only in tile_m0 (128 vs 64) and selection_rank (0 vs 1).
DECL_FMHA_KERNEL_SET(
    heuristic_fmha_kernels,
    // Kernel A: Large tile (128x128) -- better for long sequences
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r")
             .hdim(128)
             .mask("no").bias("no").lse(false).dropout(false).qscale("no"),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(128).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // Wave layout per stage
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             // Warp tile per stage
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true)
             .alignments(128, 128)
             .selection_rank(0),
         "gfx950")
    // Kernel B: Smaller tile_m0 (64x128) -- lower latency for short sequences
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r")
             .hdim(128)
             .mask("no").bias("no").lse(false).dropout(false).qscale("no"),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true)
             .alignments(128, 128)
             .selection_rank(1),
         "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
// CPU reference for the attention forward pass: O = softmax(scale * Q * K^T) * V,
// evaluated independently for every (batch, head, query row).
//
// All tensors are flat, densely packed and indexed as [batch][nhead][seqlen][hdim]:
//   Q: batch * nhead * seqlen_q * hdim_q   (read)
//   K: batch * nhead * seqlen_k * hdim_q   (read)
//   V: batch * nhead * seqlen_k * hdim_v   (read)
//   O: batch * nhead * seqlen_q * hdim_v   (written; must already have this size)
//
// The row maximum is subtracted before exponentiation so std::exp cannot overflow.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // Flat offsets are formed in 64-bit arithmetic: the product
    // batch * nhead * seqlen * hdim can exceed the range of int.
    using idx_t = long long;

    // Scratch row of attention scores, fully rewritten for every query row.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            const idx_t head = static_cast<idx_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const idx_t q_row = (head * seqlen_q + sq) * hdim_q;
                const idx_t o_row = (head * seqlen_q + sq) * hdim_v;

                // Scaled dot product of this query row with every key row.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const idx_t k_row = (head * seqlen_k + sk) * hdim_q;
                    float dot         = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax over the key dimension.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Weighted sum of the value rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const idx_t v_idx = (head * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 16: FMHA Heuristic Kernel Selection",
|
||||
"Custom heuristic picks kernel based on seqlen");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--nhead", "8", "Number of heads");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int nhead = args.get_int("--nhead", 8);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 16: FMHA Heuristic Kernel Selection");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("heuristic_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
// Step 2: Set up heuristic
|
||||
std::cout << "\nStep 2: Configure Heuristic\n";
|
||||
std::cout << " Rule: seqlen >= 256 -> prefer large tile (128x128, rank=0)\n";
|
||||
std::cout << " seqlen < 256 -> prefer small tile (64x128, rank=1)\n";
|
||||
|
||||
auto all_kernels = registry.all_kernels();
|
||||
std::cout << " Available kernels:\n";
|
||||
for(const auto& k : all_kernels)
|
||||
{
|
||||
std::cout << " - " << k->id() << "\n";
|
||||
}
|
||||
|
||||
std::string kernel_a_id, kernel_b_id;
|
||||
for(const auto& k : all_kernels)
|
||||
{
|
||||
auto kid = k->id();
|
||||
if(kernel_a_id.empty())
|
||||
kernel_a_id = kid;
|
||||
else if(kernel_b_id.empty())
|
||||
kernel_b_id = kid;
|
||||
}
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_strategy(SelectionStrategy::Heuristic);
|
||||
dispatcher.set_heuristic([&](const FmhaProblem& problem) -> std::vector<std::string> {
|
||||
if(problem.seqlen_q >= 256)
|
||||
return {kernel_a_id, kernel_b_id};
|
||||
else
|
||||
return {kernel_b_id, kernel_a_id};
|
||||
});
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// Step 3: Plan different problems to show kernel selection
|
||||
std::cout << "\nStep 3: Plan Problems (show kernel selection)\n\n";
|
||||
|
||||
struct PlanCase
|
||||
{
|
||||
int batch;
|
||||
int seqlen;
|
||||
};
|
||||
PlanCase plan_cases[] = {{1, 64}, {1, 128}, {2, 256}, {2, 512}, {4, 1024}};
|
||||
|
||||
std::cout << " " << std::setw(6) << "Batch" << " | " << std::setw(8) << "SeqLen" << " | "
|
||||
<< std::setw(50) << "Selected Kernel" << "\n";
|
||||
std::cout << " " << std::string(68, '-') << "\n";
|
||||
|
||||
for(const auto& pc : plan_cases)
|
||||
{
|
||||
auto problem = FmhaProblemBuilder()
|
||||
.api_family(FmhaApiFamily::Fwd)
|
||||
.kernel_family(FmhaKernelFamily::Fwd)
|
||||
.gfx_arch(gfx_arch)
|
||||
.data_type("fp16")
|
||||
.dims(hdim, hdim, pc.batch, pc.seqlen, pc.seqlen)
|
||||
.nheads(nhead, nhead)
|
||||
.mask_type(0)
|
||||
.bias_type(0)
|
||||
.lse(false)
|
||||
.dropout(false)
|
||||
.build();
|
||||
|
||||
auto plan = dispatcher.plan(problem);
|
||||
std::string selected = plan.is_valid() ? plan.stages[0].kernel_id : "(no match)";
|
||||
std::cout << " " << std::setw(6) << pc.batch << " | " << std::setw(8) << pc.seqlen << " | "
|
||||
<< std::setw(50) << selected << "\n";
|
||||
}
|
||||
|
||||
// Step 4: GPU execution for a representative problem
|
||||
std::cout << "\nStep 4: GPU Execution (batch=2, seqlen=256)\n";
|
||||
|
||||
const int batch = 2;
|
||||
const int seqlen = 256;
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
const int64_t elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> fdist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(elems), k_host(elems), v_host(elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(fdist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(fdist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(fdist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.q_ptr = q_dev.get();
|
||||
fmha_args.k_ptr = k_dev.get();
|
||||
fmha_args.v_ptr = v_dev.get();
|
||||
fmha_args.o_ptr = o_dev.get();
|
||||
|
||||
fmha_args.bias_ptr = nullptr;
|
||||
fmha_args.q_descale_ptr = nullptr;
|
||||
fmha_args.k_descale_ptr = nullptr;
|
||||
fmha_args.v_descale_ptr = nullptr;
|
||||
fmha_args.rand_val_ptr = nullptr;
|
||||
fmha_args.lse_ptr = nullptr;
|
||||
fmha_args.sink_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.scale_s = scale;
|
||||
fmha_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fmha_args.stride_q = hdim;
|
||||
fmha_args.stride_k = hdim;
|
||||
fmha_args.stride_v = hdim;
|
||||
fmha_args.stride_bias = 0;
|
||||
fmha_args.stride_randval = 0;
|
||||
fmha_args.stride_o = hdim;
|
||||
|
||||
fmha_args.nhead_stride_q = seqlen * hdim;
|
||||
fmha_args.nhead_stride_k = seqlen * hdim;
|
||||
fmha_args.nhead_stride_v = seqlen * hdim;
|
||||
fmha_args.nhead_stride_bias = 0;
|
||||
fmha_args.nhead_stride_randval = 0;
|
||||
fmha_args.nhead_stride_lse = 0;
|
||||
fmha_args.nhead_stride_o = seqlen * hdim;
|
||||
fmha_args.nhead_stride_q_descale = 0;
|
||||
fmha_args.nhead_stride_k_descale = 0;
|
||||
fmha_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fmha_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_bias = 0;
|
||||
fmha_args.batch_stride_randval = 0;
|
||||
fmha_args.batch_stride_lse = 0;
|
||||
fmha_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_q_descale = 0;
|
||||
fmha_args.batch_stride_k_descale = 0;
|
||||
fmha_args.batch_stride_v_descale = 0;
|
||||
|
||||
fmha_args.window_size_left = -1;
|
||||
fmha_args.window_size_right = -1;
|
||||
fmha_args.sink_size = 0;
|
||||
fmha_args.mask_type = 0;
|
||||
fmha_args.min_seqlen_q = 0;
|
||||
fmha_args.p_drop = 0.0f;
|
||||
fmha_args.s_randval = false;
|
||||
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fmha_args.block_scale_size_q = 0;
|
||||
fmha_args.block_scale_size_kv = 0;
|
||||
|
||||
float time_ms = 0.0f;
|
||||
bool passed = false;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
|
||||
// Validate against CPU reference
|
||||
std::vector<FmhaDataType> o_host(elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
std::vector<float> q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
int errors = 0;
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
{
|
||||
double abs_err = std::abs(static_cast<float>(o_host[i]) - o_ref[i]);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
if(abs_err > 1e-2 + 1e-2 * std::abs(o_ref[i]))
|
||||
++errors;
|
||||
}
|
||||
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
423
dispatcher/examples/fmha/cpp/17_autofill_autocorrect_fmha.cpp
Normal file
423
dispatcher/examples/fmha/cpp/17_autofill_autocorrect_fmha.cpp
Normal file
@@ -0,0 +1,423 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 17: FMHA Autofill and Autocorrect
|
||||
//
|
||||
// Demonstrates three DECL_FMHA_KERNEL_SET patterns:
|
||||
// 1. AUTOFILL: Minimal specification -- only family/dtype/hdim/pipeline/tile
|
||||
// are provided; wave/warp use defaults from FmhaAlgorithm constructor
|
||||
// 2. AUTOCORRECT: tile_n1/tile_k1/tile_k0max are intentionally left at 0; the
// declaration still works because FmhaAlgorithm auto_fill() corrects them
|
||||
// 3. FULL: All parameters explicitly specified (reference)
|
||||
//
|
||||
// Each is registered, planned, run on GPU, and validated.
|
||||
//
|
||||
// Usage:
|
||||
// ./17_autofill_autocorrect_fmha
|
||||
// ./17_autofill_autocorrect_fmha --arch gfx942
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// Pattern 1: AUTOFILL -- minimal specification, defaults for wave/warp
|
||||
DECL_FMHA_KERNEL_SET(autofill_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(false)
|
||||
.dropout(false)
|
||||
.qscale("no"),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true),
|
||||
"gfx950"));
|
||||
|
||||
// Pattern 2: AUTOCORRECT -- tile_n1/tile_k1 set to 0, auto_fill() corrects them
|
||||
DECL_FMHA_KERNEL_SET(autocorrect_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(false)
|
||||
.dropout(false)
|
||||
.qscale("no"),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(0)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true),
|
||||
"gfx950"));
|
||||
|
||||
// Pattern 3: FULL -- every parameter explicitly specified
|
||||
DECL_FMHA_KERNEL_SET(full_spec_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(false)
|
||||
.dropout(false)
|
||||
.qscale("no"),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true)
|
||||
.alignments(128, 128)
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
// CPU reference for the attention forward pass: O = softmax(scale * Q * K^T) * V,
// evaluated independently for every (batch, head, query row).
//
// All tensors are flat, densely packed and indexed as [batch][nhead][seqlen][hdim]:
//   Q: batch * nhead * seqlen_q * hdim_q   (read)
//   K: batch * nhead * seqlen_k * hdim_q   (read)
//   V: batch * nhead * seqlen_k * hdim_v   (read)
//   O: batch * nhead * seqlen_q * hdim_v   (written; must already have this size)
//
// The row maximum is subtracted before exponentiation so std::exp cannot overflow.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // Flat offsets are formed in 64-bit arithmetic: the product
    // batch * nhead * seqlen * hdim can exceed the range of int.
    using idx_t = long long;

    // Scratch row of attention scores, fully rewritten for every query row.
    std::vector<float> scores(seqlen_k > 0 ? seqlen_k : 0, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            const idx_t head = static_cast<idx_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const idx_t q_row = (head * seqlen_q + sq) * hdim_q;
                const idx_t o_row = (head * seqlen_q + sq) * hdim_v;

                // Scaled dot product of this query row with every key row.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const idx_t k_row = (head * seqlen_k + sk) * hdim_q;
                    float dot         = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax over the key dimension.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Weighted sum of the value rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const idx_t v_idx = (head * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
// One row of the test table: a display label plus the kernel-set name to use.
// Left as a plain aggregate so rows can be written as {"label", "set_name"}.
struct KernelTestCase
{
    // Label printed in the "Step N: ..." heading.
    std::string name;
    // Name given to the registry via set_name() for this case.
    std::string kernel_set_name;
};
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 17: FMHA Autofill & Autocorrect",
|
||||
"Three DECL_FMHA_KERNEL_SET patterns compared");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "8", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 8);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 17: FMHA Autofill & Autocorrect");
|
||||
|
||||
// Step 1: Show registered kernel sets
|
||||
std::cout << "\nStep 1: Registered Kernel Sets\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
const KernelTestCase cases[] = {
|
||||
{"AUTOFILL (minimal spec, wave/warp defaults)", "autofill_kernels"},
|
||||
{"AUTOCORRECT (tile_n1/k1=0, auto_fill corrects)", "autocorrect_kernels"},
|
||||
{"FULL (all params explicit)", "full_spec_kernels"},
|
||||
};
|
||||
|
||||
// Prepare input data (shared across all tests)
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
const int64_t elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(elems), k_host(elems), v_host(elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
// CPU reference
|
||||
std::vector<float> q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
cpu_attention_fwd(q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
int total_pass = 0;
|
||||
const int total_cases = sizeof(cases) / sizeof(cases[0]);
|
||||
|
||||
for(int ci = 0; ci < total_cases; ++ci)
|
||||
{
|
||||
const auto& tc = cases[ci];
|
||||
std::cout << "\nStep " << (ci + 2) << ": " << tc.name << "\n";
|
||||
|
||||
// Register from the named kernel set
|
||||
FmhaRegistry registry;
|
||||
registry.set_name(tc.kernel_set_name);
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
if(registry.size() == 0)
|
||||
{
|
||||
std::cout << " SKIP: no kernels registered\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// Allocate GPU buffers
|
||||
GpuBuffer<FmhaDataType> q_dev(elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(elems);
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.q_ptr = q_dev.get();
|
||||
fmha_args.k_ptr = k_dev.get();
|
||||
fmha_args.v_ptr = v_dev.get();
|
||||
fmha_args.o_ptr = o_dev.get();
|
||||
|
||||
fmha_args.bias_ptr = nullptr;
|
||||
fmha_args.q_descale_ptr = nullptr;
|
||||
fmha_args.k_descale_ptr = nullptr;
|
||||
fmha_args.v_descale_ptr = nullptr;
|
||||
fmha_args.rand_val_ptr = nullptr;
|
||||
fmha_args.lse_ptr = nullptr;
|
||||
fmha_args.sink_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.scale_s = scale;
|
||||
fmha_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fmha_args.stride_q = hdim;
|
||||
fmha_args.stride_k = hdim;
|
||||
fmha_args.stride_v = hdim;
|
||||
fmha_args.stride_bias = 0;
|
||||
fmha_args.stride_randval = 0;
|
||||
fmha_args.stride_o = hdim;
|
||||
|
||||
fmha_args.nhead_stride_q = seqlen * hdim;
|
||||
fmha_args.nhead_stride_k = seqlen * hdim;
|
||||
fmha_args.nhead_stride_v = seqlen * hdim;
|
||||
fmha_args.nhead_stride_bias = 0;
|
||||
fmha_args.nhead_stride_randval = 0;
|
||||
fmha_args.nhead_stride_lse = 0;
|
||||
fmha_args.nhead_stride_o = seqlen * hdim;
|
||||
fmha_args.nhead_stride_q_descale = 0;
|
||||
fmha_args.nhead_stride_k_descale = 0;
|
||||
fmha_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fmha_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_bias = 0;
|
||||
fmha_args.batch_stride_randval = 0;
|
||||
fmha_args.batch_stride_lse = 0;
|
||||
fmha_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_q_descale = 0;
|
||||
fmha_args.batch_stride_k_descale = 0;
|
||||
fmha_args.batch_stride_v_descale = 0;
|
||||
|
||||
fmha_args.window_size_left = -1;
|
||||
fmha_args.window_size_right = -1;
|
||||
fmha_args.sink_size = 0;
|
||||
fmha_args.mask_type = 0;
|
||||
fmha_args.min_seqlen_q = 0;
|
||||
fmha_args.p_drop = 0.0f;
|
||||
fmha_args.s_randval = false;
|
||||
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fmha_args.block_scale_size_q = 0;
|
||||
fmha_args.block_scale_size_kv = 0;
|
||||
|
||||
try
|
||||
{
|
||||
float time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
// Validate
|
||||
std::vector<FmhaDataType> o_host(elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
int errors = 0;
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
{
|
||||
double abs_err = std::abs(static_cast<float>(o_host[i]) - o_ref[i]);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
if(abs_err > 1e-2 + 1e-2 * std::abs(o_ref[i]))
|
||||
++errors;
|
||||
}
|
||||
|
||||
bool ok = (errors == 0);
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms"
|
||||
<< " TFLOPS: " << std::setprecision(2) << tflops
|
||||
<< " MaxErr: " << std::scientific << max_abs_err << " "
|
||||
<< (ok ? "PASS" : "FAIL") << "\n";
|
||||
if(ok)
|
||||
++total_pass;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Summary
|
||||
print_separator();
|
||||
std::cout << "Results: " << total_pass << "/" << total_cases << " patterns passed\n";
|
||||
std::cout << "Patterns:\n";
|
||||
std::cout << " 1. AUTOFILL: Only tile + pipeline specified; wave/warp use defaults\n";
|
||||
std::cout << " 2. AUTOCORRECT: tile_n1/k1/k0max=0 -> auto_fill() infers from tile_n0/k0\n";
|
||||
std::cout << " 3. FULL: Every parameter explicit (reference configuration)\n";
|
||||
std::cout << "Status: " << (total_pass == total_cases ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return (total_pass == total_cases) ? 0 : 1;
|
||||
}
|
||||
466
dispatcher/examples/fmha/cpp/18_gpu_splitkv_fmha.cpp
Normal file
466
dispatcher/examples/fmha/cpp/18_gpu_splitkv_fmha.cpp
Normal file
@@ -0,0 +1,466 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 18: GPU Split-KV FMHA Forward
|
||||
//
|
||||
// Demonstrates split-KV attention with GPU execution:
|
||||
// 1. Declare both fwd_splitkv and fwd_splitkv_combine kernels
|
||||
// 2. Show 2-stage execution plan
|
||||
// 3. Allocate Q, K, V, O plus workspace (lse_acc, o_acc)
|
||||
// 4. Run the split-KV forward pass on GPU
|
||||
// 5. Copy output to host and validate against CPU reference
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for Example 18: the two kernel declarations that make up a split-KV
// forward pass.
//   - family "fwd_splitkv":         pipeline "qr_nwarp_sshuffle", tile_n1 = 128
//   - family "fwd_splitkv_combine": pipeline "qr",                tile_n1 = 32
// Both entries use the same signature (fp16, batch mode, vlayout "r", hdim 128,
// no mask, no bias, LSE enabled, no dropout, no qscale, no paged KV), the same
// wave/warp layout, max_splits_log2(6), selection_rank(0) and target "gfx950".
// NOTE(review): max_splits_log2(6) presumably caps the split count at 2^6 = 64 --
// confirm against the FmhaAlgorithm builder.
// NOTE(review): this declaration may be read textually by build-time code
// generation; keep any edits inside the macro conservative -- confirm.
DECL_FMHA_KERNEL_SET(splitkv_gpu_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd_splitkv")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(true)
|
||||
.dropout(false)
|
||||
.qscale("no")
|
||||
.paged_kv(false),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(16)
|
||||
.warp_n0(16)
|
||||
.warp_k0(16)
|
||||
.warp_m1(16)
|
||||
.warp_n1(16)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_nwarp_sshuffle")
|
||||
.padding(true, true, true, true)
|
||||
.max_splits_log2(6)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
.add(FmhaSignature()
|
||||
.family("fwd_splitkv_combine")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(true)
|
||||
.dropout(false)
|
||||
.qscale("no")
|
||||
.paged_kv(false),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(32)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(16)
|
||||
.warp_n0(16)
|
||||
.warp_k0(16)
|
||||
.warp_m1(16)
|
||||
.warp_n1(16)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr")
|
||||
.padding(true, true, true, true)
|
||||
.max_splits_log2(6)
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// CPU reference for unmasked scaled dot-product attention:
//     O = softmax(scale * Q * K^T) * V
//
// All tensors are dense float buffers indexed as [batch, nhead, seqlen, hdim]:
//     Q: [batch, nhead, seqlen_q, hdim_q]
//     K: [batch, nhead, seqlen_k, hdim_q]
//     V: [batch, nhead, seqlen_k, hdim_v]
//     O: [batch, nhead, seqlen_q, hdim_v]   (output)
//
// O is grown (zero-filled) if it is smaller than the required element count;
// a correctly sized O is written in place. The softmax subtracts the row
// maximum before exponentiating. Flat offsets are computed in 64 bits so that
// shapes whose element count exceeds INT_MAX do not overflow.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // With any of these non-positive no output element exists.
    if(batch <= 0 || nhead <= 0 || seqlen_q <= 0 || hdim_v <= 0)
        return;

    const int64_t o_required = static_cast<int64_t>(batch) * nhead * seqlen_q * hdim_v;
    if(static_cast<int64_t>(O.size()) < o_required)
        O.resize(static_cast<std::size_t>(o_required), 0.0f);

    // One scratch row of scores/probabilities, reused for every query row.
    // A non-positive seqlen_k is treated as "no keys" (every output becomes 0).
    std::vector<float> scores(seqlen_k > 0 ? static_cast<std::size_t>(seqlen_k) : std::size_t{0},
                              0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // 64-bit base offsets of this (batch, head) slice in each tensor.
            const int64_t bh     = static_cast<int64_t>(b) * nhead + h;
            const int64_t q_head = bh * seqlen_q * hdim_q;
            const int64_t k_head = bh * seqlen_k * hdim_q;
            const int64_t v_head = bh * seqlen_k * hdim_v;
            const int64_t o_head = bh * seqlen_q * hdim_v;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = q_head + static_cast<int64_t>(sq) * hdim_q;
                const int64_t o_row = o_head + static_cast<int64_t>(sq) * hdim_v;

                // Scaled scores s[sk] = scale * <Q[sq, :], K[sk, :]> and their maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = k_head + static_cast<int64_t>(sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax over the key axis, shifted by the row maximum.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[v_head + static_cast<int64_t>(sk) * hdim_v + dv];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 18: GPU Split-KV FMHA Forward", "Split-KV with GPU execution");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen_q", "64", "Query sequence length");
|
||||
args.add_option("--seqlen_k", "2048", "KV sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_option("--splits", "2", "Number of KV splits");
|
||||
args.add_flag("--validate", "Validate against CPU reference");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen_q = args.get_int("--seqlen_q", 64);
|
||||
const int seqlen_k = args.get_int("--seqlen_k", 2048);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const int num_splits = args.get_int("--splits", 2);
|
||||
|
||||
print_header("Example 18: GPU Split-KV FMHA Forward");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("splitkv_gpu_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// Step 2: Set up traits and plan
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
fmha_fwd_splitkv_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = true;
|
||||
traits.do_fp8_static_quant = false;
|
||||
traits.has_sink = false;
|
||||
|
||||
// Workspace sizes: lse_acc [batch, nhead, num_splits, seqlen_q]
|
||||
// o_acc [batch, nhead, num_splits, seqlen_q, hdim]
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen_q * hdim;
|
||||
const int64_t k_elems = static_cast<int64_t>(batch) * nhead * seqlen_k * hdim;
|
||||
const int64_t v_elems = k_elems;
|
||||
const int64_t o_elems = q_elems;
|
||||
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen_q;
|
||||
const int64_t lse_acc_elems = static_cast<int64_t>(batch) * nhead * num_splits * seqlen_q;
|
||||
const int64_t o_acc_elems = static_cast<int64_t>(batch) * nhead * num_splits * seqlen_q * hdim;
|
||||
|
||||
// Show the 2-stage plan
|
||||
std::cout << "\nStep 2: Plan (2-stage split-KV)\n";
|
||||
|
||||
fmha_fwd_splitkv_args plan_args{};
|
||||
plan_args.seqlen_q = seqlen_q;
|
||||
plan_args.seqlen_k = seqlen_k;
|
||||
plan_args.batch = batch;
|
||||
plan_args.max_seqlen_q = seqlen_q;
|
||||
plan_args.hdim_q = hdim;
|
||||
plan_args.hdim_v = hdim;
|
||||
plan_args.nhead_q = nhead;
|
||||
plan_args.nhead_k = nhead;
|
||||
plan_args.num_splits = num_splits;
|
||||
|
||||
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, plan_args), gfx_arch);
|
||||
auto plan = dispatcher.plan(problem);
|
||||
|
||||
if(!plan.is_valid() || plan.stages.size() != 2)
|
||||
{
|
||||
std::cerr << " WARNING: Expected a two-stage split-KV plan, got " << plan.stages.size()
|
||||
<< " stage(s)\n";
|
||||
if(!plan.is_valid())
|
||||
{
|
||||
std::cerr << " Plan is invalid -- no matching kernels found\n";
|
||||
print_separator();
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
for(const auto& stage : plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
|
||||
// Step 3: Allocate GPU buffers
|
||||
std::cout << "\nStep 3: Allocate GPU Buffers\n";
|
||||
std::cout << " Q: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim
|
||||
<< "]\n";
|
||||
std::cout << " K/V: [" << batch << ", " << nhead << ", " << seqlen_k << ", " << hdim
|
||||
<< "]\n";
|
||||
std::cout << " O: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim
|
||||
<< "]\n";
|
||||
std::cout << " lse_acc: [" << batch << ", " << nhead << ", " << num_splits << ", " << seqlen_q
|
||||
<< "]\n";
|
||||
std::cout << " o_acc: [" << batch << ", " << nhead << ", " << num_splits << ", " << seqlen_q
|
||||
<< ", " << hdim << "]\n";
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(k_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(v_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(o_elems);
|
||||
GpuBuffer<float> lse_dev(lse_elems);
|
||||
GpuBuffer<float> lse_acc_dev(lse_acc_elems);
|
||||
GpuBuffer<float> o_acc_dev(o_acc_elems);
|
||||
|
||||
// Fill Q, K, V with random data
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems);
|
||||
std::vector<FmhaDataType> k_host(k_elems);
|
||||
std::vector<FmhaDataType> v_host(v_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
lse_acc_dev.zero();
|
||||
o_acc_dev.zero();
|
||||
|
||||
// Step 4: Set up splitkv args with device pointers and strides
|
||||
fmha_fwd_splitkv_args fmha_args{};
|
||||
fmha_args.q_ptr = q_dev.get();
|
||||
fmha_args.k_ptr = k_dev.get();
|
||||
fmha_args.v_ptr = v_dev.get();
|
||||
fmha_args.o_ptr = o_dev.get();
|
||||
|
||||
fmha_args.bias_ptr = nullptr;
|
||||
fmha_args.lse_acc_ptr = lse_acc_dev.get();
|
||||
fmha_args.o_acc_ptr = o_acc_dev.get();
|
||||
fmha_args.lse_ptr = lse_dev.get();
|
||||
|
||||
fmha_args.block_table_ptr = nullptr;
|
||||
fmha_args.batch_stride_block_table = 0;
|
||||
fmha_args.page_block_size = 0;
|
||||
fmha_args.is_gappy = false;
|
||||
fmha_args.cache_batch_idx = nullptr;
|
||||
fmha_args.seqstart_q_ptr = nullptr;
|
||||
fmha_args.seqstart_k_ptr = nullptr;
|
||||
fmha_args.seqlen_k_ptr = nullptr;
|
||||
fmha_args.sink_ptr = nullptr;
|
||||
|
||||
fmha_args.seqlen_q = seqlen_q;
|
||||
fmha_args.seqlen_k = seqlen_k;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.max_seqlen_q = seqlen_q;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.num_splits = num_splits;
|
||||
|
||||
fmha_args.scale_s = scale;
|
||||
fmha_args.scale_p = 1.0f;
|
||||
fmha_args.scale_o = 1.0f;
|
||||
fmha_args.logits_soft_cap = 0.0f;
|
||||
|
||||
// bhsd layout strides
|
||||
fmha_args.stride_q = hdim;
|
||||
fmha_args.stride_k = hdim;
|
||||
fmha_args.stride_v = hdim;
|
||||
fmha_args.stride_bias = 0;
|
||||
fmha_args.stride_o_acc = hdim;
|
||||
fmha_args.stride_o = hdim;
|
||||
|
||||
fmha_args.nhead_stride_q = seqlen_q * hdim;
|
||||
fmha_args.nhead_stride_k = seqlen_k * hdim;
|
||||
fmha_args.nhead_stride_v = seqlen_k * hdim;
|
||||
fmha_args.nhead_stride_bias = 0;
|
||||
fmha_args.nhead_stride_lse = seqlen_q;
|
||||
fmha_args.nhead_stride_lse_acc = num_splits * seqlen_q;
|
||||
fmha_args.nhead_stride_o_acc = num_splits * seqlen_q * hdim;
|
||||
fmha_args.nhead_stride_o = seqlen_q * hdim;
|
||||
|
||||
fmha_args.batch_stride_q = nhead * seqlen_q * hdim;
|
||||
fmha_args.batch_stride_k = nhead * seqlen_k * hdim;
|
||||
fmha_args.batch_stride_v = nhead * seqlen_k * hdim;
|
||||
fmha_args.batch_stride_bias = 0;
|
||||
fmha_args.batch_stride_lse = nhead * seqlen_q;
|
||||
fmha_args.batch_stride_lse_acc = nhead * num_splits * seqlen_q;
|
||||
fmha_args.batch_stride_o_acc = nhead * num_splits * seqlen_q * hdim;
|
||||
fmha_args.batch_stride_o = nhead * seqlen_q * hdim;
|
||||
|
||||
fmha_args.split_stride_lse_acc = seqlen_q;
|
||||
fmha_args.split_stride_o_acc = seqlen_q * hdim;
|
||||
|
||||
fmha_args.window_size_left = -1;
|
||||
fmha_args.window_size_right = -1;
|
||||
fmha_args.sink_size = 0;
|
||||
fmha_args.mask_type = 0;
|
||||
|
||||
// Step 5: Run on GPU
|
||||
std::cout << "\nStep 4: Run Split-KV FMHA Forward on GPU\n";
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd_splitkv(traits, fmha_args, nullptr);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " WARNING: GPU execution failed: " << e.what() << "\n";
|
||||
std::cerr << " Falling back to planning-only mode (split-KV compilation can be complex)\n";
|
||||
std::cout << "\n Plan summary (2 stages):\n";
|
||||
for(const auto& stage : plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
print_separator();
|
||||
std::cout << "Status: PLAN_ONLY\n";
|
||||
print_separator();
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto run_problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
double tflops = static_cast<double>(run_problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
|
||||
// Step 6: Copy output and validate
|
||||
std::cout << "\nStep 5: Validate\n";
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
|
||||
|
||||
bool passed = (nonzero > 0);
|
||||
|
||||
if(args.has("--validate"))
|
||||
{
|
||||
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < k_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < v_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen_q, seqlen_k, hdim, hdim, scale);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
double max_rel_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
max_rel_err = std::max(max_rel_err, rel_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Max rel error: " << max_rel_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
456
dispatcher/examples/fmha/cpp/19_gpu_masks_fmha.cpp
Normal file
456
dispatcher/examples/fmha/cpp/19_gpu_masks_fmha.cpp
Normal file
@@ -0,0 +1,456 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 19: GPU FMHA Forward with Mask Types
|
||||
//
|
||||
// Demonstrates three mask variants with GPU execution:
|
||||
// 1. No mask (standard attention)
|
||||
// 2. Top-left causal mask (zero upper triangle)
|
||||
// 3. Bottom-right causal mask (shifted diagonal)
|
||||
//
|
||||
// Uses seqlen_q=64, seqlen_k=128 to make mask behavior visible.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for Example 19: two "fwd" kernel declarations that differ only in the
// mask setting -- .mask("no") and .mask("top_left"). Both use fp16, batch mode,
// vlayout "r", hdim 128, no bias, no LSE, no dropout, no qscale, the "qr_async"
// pipeline, alignments(128, 128) and target "gfx950". There is no separate
// bottom_right entry; see the note at the end of the declaration.
// NOTE(review): this declaration may be read textually by build-time code
// generation; keep any edits inside the macro conservative -- confirm.
DECL_FMHA_KERNEL_SET(mask_fmha_kernels,
|
||||
.add(FmhaSignature()
|
||||
.family("fwd")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("no")
|
||||
.lse(false)
|
||||
.dropout(false)
|
||||
.qscale("no"),
|
||||
FmhaAlgorithm()
|
||||
// Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
// Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true)
|
||||
.alignments(128, 128)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
.add(FmhaSignature()
|
||||
.family("fwd")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("top_left")
|
||||
.bias("no")
|
||||
.lse(false)
|
||||
.dropout(false)
|
||||
.qscale("no"),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true)
|
||||
.alignments(128, 128)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
// Note: bottom_right shares the same compiled kernel as top_left
|
||||
// (both use SimplifiedGenericAttentionMask<true>). The mask type
|
||||
// is resolved at runtime via args.mask_type, not the template.
|
||||
// fmha_mask_compatible() in generated_fmha_backend.hpp handles this.
|
||||
);
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// CPU reference for scaled dot-product attention with an optional causal mask:
//     O = softmax(mask(scale * Q * K^T)) * V
//
// mask_type: 0=no_mask, 1=top_left, 2=bottom_right
//   top_left:     key sk is masked for query sq when sk >= sq + 1
//   bottom_right: key sk is masked when sk >= sq + (seqlen_k - seqlen_q) + 1
//
// All tensors are dense float buffers indexed as [batch, nhead, seqlen, hdim]:
//     Q: [batch, nhead, seqlen_q, hdim_q]
//     K: [batch, nhead, seqlen_k, hdim_q]
//     V: [batch, nhead, seqlen_k, hdim_v]
//     O: [batch, nhead, seqlen_q, hdim_v]   (output)
//
// A query row whose keys are all masked (possible with bottom_right when
// seqlen_q > seqlen_k) produces an all-zero output row rather than a uniform
// average of V. O is grown (zero-filled) if it is smaller than required, and
// flat offsets are computed in 64 bits to avoid overflow on large shapes.
void cpu_attention_fwd_masked(const std::vector<float>& Q,
                              const std::vector<float>& K,
                              const std::vector<float>& V,
                              std::vector<float>& O,
                              int batch,
                              int nhead,
                              int seqlen_q,
                              int seqlen_k,
                              int hdim_q,
                              int hdim_v,
                              float scale,
                              int mask_type)
{
    // With any of these non-positive no output element exists.
    if(batch <= 0 || nhead <= 0 || seqlen_q <= 0 || hdim_v <= 0)
        return;

    const int64_t o_required = static_cast<int64_t>(batch) * nhead * seqlen_q * hdim_v;
    if(static_cast<int64_t>(O.size()) < o_required)
        O.resize(static_cast<std::size_t>(o_required), 0.0f);

    // Sentinel assigned to masked scores; exp(sentinel - max) underflows to 0.
    const float masked_score = -1e30f;

    // One scratch row of scores/probabilities, reused for every query row.
    std::vector<float> scores(seqlen_k > 0 ? static_cast<std::size_t>(seqlen_k) : std::size_t{0},
                              0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // 64-bit base offsets of this (batch, head) slice in each tensor.
            const int64_t bh     = static_cast<int64_t>(b) * nhead + h;
            const int64_t q_head = bh * seqlen_q * hdim_q;
            const int64_t k_head = bh * seqlen_k * hdim_q;
            const int64_t v_head = bh * seqlen_k * hdim_v;
            const int64_t o_head = bh * seqlen_q * hdim_v;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = q_head + static_cast<int64_t>(sq) * hdim_q;
                const int64_t o_row = o_head + static_cast<int64_t>(sq) * hdim_v;

                float max_score = masked_score;
                int visible     = 0; // number of unmasked keys for this query row

                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const bool masked =
                        (mask_type == 1 && sk >= sq + 1) ||
                        (mask_type == 2 && sk >= sq + (seqlen_k - seqlen_q) + 1);
                    if(masked)
                    {
                        // The dot product is irrelevant for a masked key.
                        scores[sk] = masked_score;
                        continue;
                    }

                    const int64_t k_row = k_head + static_cast<int64_t>(sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                    ++visible;
                }

                if(visible == 0)
                {
                    // Fully masked row: softmax is undefined; emit zeros instead of
                    // letting every sentinel turn into exp(0) = 1.
                    for(int dv = 0; dv < hdim_v; ++dv)
                    {
                        O[o_row + dv] = 0.0f;
                    }
                    continue;
                }

                // Softmax over the key axis, shifted by the row maximum.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[v_head + static_cast<int64_t>(sk) * hdim_v + dv];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 19: FMHA with Masks (GPU)", "FMHA mask variants on GPU");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen_q", "64", "Query sequence length");
|
||||
args.add_option("--seqlen_k", "128", "KV sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_flag("--validate", "Validate against CPU reference");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen_q = args.get_int("--seqlen_q", 64);
|
||||
const int seqlen_k = args.get_int("--seqlen_k", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 19: FMHA with Masks (GPU)");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("mask_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
// Allocate GPU buffers
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen_q * hdim;
|
||||
const int64_t k_elems = static_cast<int64_t>(batch) * nhead * seqlen_k * hdim;
|
||||
const int64_t v_elems = k_elems;
|
||||
const int64_t o_elems = q_elems;
|
||||
|
||||
std::cout << "\nStep 2: Allocate GPU Buffers\n";
|
||||
std::cout << " Q/O: [" << batch << ", " << nhead << ", " << seqlen_q << ", " << hdim << "]\n";
|
||||
std::cout << " K/V: [" << batch << ", " << nhead << ", " << seqlen_k << ", " << hdim << "]\n";
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(k_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(v_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(o_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems);
|
||||
std::vector<FmhaDataType> k_host(k_elems);
|
||||
std::vector<FmhaDataType> v_host(v_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
|
||||
// Convert to f32 for CPU reference
|
||||
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < k_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < v_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
// Test each mask type
|
||||
struct MaskTest
|
||||
{
|
||||
const char* name;
|
||||
int mask_type_int;
|
||||
mask_enum mask_type;
|
||||
};
|
||||
|
||||
MaskTest tests[] = {
|
||||
{"no_mask", 0, mask_enum::no_mask},
|
||||
{"top_left", 1, mask_enum::mask_top_left},
|
||||
{"bottom_right", 2, mask_enum::mask_bottom_right},
|
||||
};
|
||||
|
||||
bool all_passed = true;
|
||||
|
||||
for(const auto& test : tests)
|
||||
{
|
||||
std::cout << "\nStep 3: Run FMHA Forward [" << test.name << "]\n";
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = test.mask_type;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.q_ptr = q_dev.get();
|
||||
fmha_args.k_ptr = k_dev.get();
|
||||
fmha_args.v_ptr = v_dev.get();
|
||||
fmha_args.o_ptr = o_dev.get();
|
||||
|
||||
fmha_args.bias_ptr = nullptr;
|
||||
fmha_args.q_descale_ptr = nullptr;
|
||||
fmha_args.k_descale_ptr = nullptr;
|
||||
fmha_args.v_descale_ptr = nullptr;
|
||||
fmha_args.rand_val_ptr = nullptr;
|
||||
fmha_args.lse_ptr = nullptr;
|
||||
fmha_args.sink_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fmha_args.seqlen_q = seqlen_q;
|
||||
fmha_args.seqlen_k = seqlen_k;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.max_seqlen_q = seqlen_q;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.scale_s = scale;
|
||||
fmha_args.logits_soft_cap = 0.0f;
|
||||
|
||||
// bhsd layout strides
|
||||
fmha_args.stride_q = hdim;
|
||||
fmha_args.stride_k = hdim;
|
||||
fmha_args.stride_v = hdim;
|
||||
fmha_args.stride_bias = 0;
|
||||
fmha_args.stride_randval = 0;
|
||||
fmha_args.stride_o = hdim;
|
||||
|
||||
fmha_args.nhead_stride_q = seqlen_q * hdim;
|
||||
fmha_args.nhead_stride_k = seqlen_k * hdim;
|
||||
fmha_args.nhead_stride_v = seqlen_k * hdim;
|
||||
fmha_args.nhead_stride_bias = 0;
|
||||
fmha_args.nhead_stride_randval = 0;
|
||||
fmha_args.nhead_stride_lse = 0;
|
||||
fmha_args.nhead_stride_o = seqlen_q * hdim;
|
||||
fmha_args.nhead_stride_q_descale = 0;
|
||||
fmha_args.nhead_stride_k_descale = 0;
|
||||
fmha_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fmha_args.batch_stride_q = nhead * seqlen_q * hdim;
|
||||
fmha_args.batch_stride_k = nhead * seqlen_k * hdim;
|
||||
fmha_args.batch_stride_v = nhead * seqlen_k * hdim;
|
||||
fmha_args.batch_stride_bias = 0;
|
||||
fmha_args.batch_stride_randval = 0;
|
||||
fmha_args.batch_stride_lse = 0;
|
||||
fmha_args.batch_stride_o = nhead * seqlen_q * hdim;
|
||||
fmha_args.batch_stride_q_descale = 0;
|
||||
fmha_args.batch_stride_k_descale = 0;
|
||||
fmha_args.batch_stride_v_descale = 0;
|
||||
|
||||
fmha_args.window_size_left = -1;
|
||||
fmha_args.window_size_right = (test.mask_type_int == 0) ? -1 : 0;
|
||||
fmha_args.sink_size = 0;
|
||||
fmha_args.mask_type = test.mask_type_int;
|
||||
fmha_args.min_seqlen_q = 0;
|
||||
fmha_args.p_drop = 0.0f;
|
||||
fmha_args.s_randval = false;
|
||||
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fmha_args.block_scale_size_q = 0;
|
||||
fmha_args.block_scale_size_kv = 0;
|
||||
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR [" << test.name << "]: " << e.what() << "\n";
|
||||
all_passed = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
|
||||
// Validate
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
|
||||
|
||||
if(nonzero == 0)
|
||||
all_passed = false;
|
||||
|
||||
if(args.has("--validate"))
|
||||
{
|
||||
std::vector<float> o_ref(o_elems, 0.0f);
|
||||
cpu_attention_fwd_masked(q_f32,
|
||||
k_f32,
|
||||
v_f32,
|
||||
o_ref,
|
||||
batch,
|
||||
nhead,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim,
|
||||
hdim,
|
||||
scale,
|
||||
test.mask_type_int);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
double max_rel_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
max_rel_err = std::max(max_rel_err, rel_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Max rel error: " << max_rel_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
|
||||
if(errors > 0)
|
||||
all_passed = false;
|
||||
}
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return all_passed ? 0 : 1;
|
||||
}
|
||||
584
dispatcher/examples/fmha/cpp/20_gpu_bias_fmha.cpp
Normal file
584
dispatcher/examples/fmha/cpp/20_gpu_bias_fmha.cpp
Normal file
@@ -0,0 +1,584 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 20: GPU FMHA Forward with Bias Types
|
||||
//
|
||||
// Demonstrates three bias variants with GPU execution:
|
||||
// 1. No bias (standard attention)
|
||||
// 2. Elementwise bias (arbitrary bias matrix added to scores)
|
||||
// 3. ALiBi (Attention with Linear Biases -- slope-based positional encoding)
|
||||
//
|
||||
// Validates each variant against a CPU reference.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set used by this example. It declares three forward kernels that share the same
// signature (fp16, batch mode, hdim 128, no mask, no LSE, no dropout, no quant scale) and the
// same tile/wave/warp configuration; they differ only in the value passed to .bias():
// "no", "bias" and "alibi". The trailing "gfx950" argument of each .add() is presumably the
// target architecture the kernel is declared for.
DECL_FMHA_KERNEL_SET(bias_fmha_kernels,
                     // Variant 1: bias("no").
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Stage 0 tiles (Q*K^T): m0 ~ seqlen_q, n0 ~ seqlen_k, k0 ~ hdim_q
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              // Stage 1 tiles (Attn*V): n1 ~ hdim_v, k1 ~ seqlen_k, k0max ~ alignment
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")
                     // Variant 2: bias("bias"); same algorithm configuration as variant 1.
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("bias")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")
                     // Variant 3: bias("alibi"); same algorithm configuration as variant 1.
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("alibi")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// CPU reference for forward attention without bias or mask.
//
// For every (batch, head, query row) it computes
//     O[row, :] = softmax(scale * Q[row, :] . K^T) * V
// using a max-subtracted softmax.
//
// The indexing below treats every tensor as a dense, row-major buffer:
//     Q: [batch, nhead, seqlen_q, hdim_q]    K: [batch, nhead, seqlen_k, hdim_q]
//     V: [batch, nhead, seqlen_k, hdim_v]    O: [batch, nhead, seqlen_q, hdim_v]
//
// Parameters:
//     Q, K, V   input tensors (read only); must hold at least the element counts above.
//     O         output tensor; must already be sized by the caller, every element is written.
//     scale     factor applied to each Q.K dot product before the softmax.
//
// Flat offsets are computed in 64-bit arithmetic so that large problem sizes cannot overflow
// `int`. When seqlen_k == 0 the output rows are written as zeros.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // One scratch row of scores, reused for every query row. Each entry is overwritten before
    // it is read, so no re-initialisation is needed between rows.
    std::vector<float> scores(std::max(seqlen_k, 0), 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) index shared by Q, K, V and O.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bh * seqlen_q + sq) * hdim_v;

                // Pass 1: scaled dot products and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (bh * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum for numerical stability.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                // Pass 3: normalise into probabilities.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Pass 4: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (bh * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
// CPU reference for forward attention with an optional additive bias on the scores.
//
// For every (batch, head, query row) it computes
//     O[row, :] = softmax(scale * Q[row, :] . K^T + bias) * V
// using a max-subtracted softmax.
//
// bias_type selects how bias_buf is interpreted:
//     0  no bias; bias_buf is not read.
//     1  elementwise; bias_buf is indexed as [nhead, seqlen_q, seqlen_k] (no batch dimension,
//        i.e. the same bias is applied to every batch entry) and added to the scaled score.
//     2  ALiBi; bias_buf holds one slope per head and slope * (sk - sq) is added to the
//        scaled score.
// Any other value behaves like 0.
//
// Tensor layouts implied by the indexing (dense, row-major):
//     Q: [batch, nhead, seqlen_q, hdim_q]    K: [batch, nhead, seqlen_k, hdim_q]
//     V: [batch, nhead, seqlen_k, hdim_v]    O: [batch, nhead, seqlen_q, hdim_v]
//
// O must already be sized by the caller; every element is written. Flat offsets are computed
// in 64-bit arithmetic so that large problem sizes cannot overflow `int`.
void cpu_attention_fwd_biased(const std::vector<float>& Q,
                              const std::vector<float>& K,
                              const std::vector<float>& V,
                              std::vector<float>& O,
                              int batch,
                              int nhead,
                              int seqlen_q,
                              int seqlen_k,
                              int hdim_q,
                              int hdim_v,
                              float scale,
                              int bias_type,
                              const std::vector<float>& bias_buf)
{
    // One scratch row of scores, reused for every query row. Each entry is overwritten before
    // it is read, so no re-initialisation is needed between rows.
    std::vector<float> scores(std::max(seqlen_k, 0), 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) index shared by Q, K, V and O.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bh * seqlen_q + sq) * hdim_v;
                // Start of this (head, query row) inside the elementwise bias tensor.
                const int64_t bias_row = (static_cast<int64_t>(h) * seqlen_q + sq) * seqlen_k;

                // Pass 1: biased, scaled dot products and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (bh * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }

                    float s = dot * scale;
                    if(bias_type == 1)
                    {
                        s += bias_buf[bias_row + sk];
                    }
                    else if(bias_type == 2)
                    {
                        const float slope = bias_buf[h];
                        s += slope * static_cast<float>(sk - sq);
                    }

                    scores[sk] = s;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum for numerical stability.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                // Pass 3: normalise into probabilities.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Pass 4: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (bh * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 20: FMHA with Bias (GPU)", "FMHA bias variants on GPU");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length (Q and K)");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_flag("--validate", "Validate against CPU reference");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 20: FMHA with Bias (GPU)");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("bias_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
// Allocate Q, K, V GPU buffers (shared across all bias tests)
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t k_elems = q_elems;
|
||||
const int64_t v_elems = q_elems;
|
||||
const int64_t o_elems = q_elems;
|
||||
|
||||
std::cout << "\nStep 2: Allocate GPU Buffers\n";
|
||||
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
|
||||
<< "]\n";
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(k_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(v_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(o_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems);
|
||||
std::vector<FmhaDataType> k_host(k_elems);
|
||||
std::vector<FmhaDataType> v_host(v_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
|
||||
// Convert to f32 for CPU reference
|
||||
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < k_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < v_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
// Prepare elementwise bias buffer: [1, nhead, seqlen, seqlen] with small values
|
||||
const int64_t elem_bias_elems = static_cast<int64_t>(nhead) * seqlen * seqlen;
|
||||
std::vector<float> elem_bias_host(elem_bias_elems);
|
||||
std::uniform_real_distribution<float> bias_dist(-0.1f, 0.1f);
|
||||
for(auto& x : elem_bias_host)
|
||||
x = bias_dist(rng);
|
||||
|
||||
GpuBuffer<float> elem_bias_dev(elem_bias_elems);
|
||||
elem_bias_dev.copy_from_host(elem_bias_host.data());
|
||||
|
||||
// Prepare ALiBi slopes buffer: [nhead] with geometric slopes
|
||||
std::vector<float> alibi_slopes_host(nhead);
|
||||
for(int h = 0; h < nhead; ++h)
|
||||
{
|
||||
alibi_slopes_host[h] = -std::pow(2.0f, -(8.0f * (h + 1) / nhead));
|
||||
}
|
||||
|
||||
GpuBuffer<float> alibi_slopes_dev(nhead);
|
||||
alibi_slopes_dev.copy_from_host(alibi_slopes_host.data());
|
||||
|
||||
// Test each bias type
|
||||
struct BiasTest
|
||||
{
|
||||
const char* name;
|
||||
int bias_type_int;
|
||||
bias_enum bias_type;
|
||||
void* bias_ptr;
|
||||
int stride_bias;
|
||||
int nhead_stride_bias;
|
||||
int batch_stride_bias;
|
||||
};
|
||||
|
||||
BiasTest tests[] = {
|
||||
{"no_bias", 0, bias_enum::no_bias, nullptr, 0, 0, 0},
|
||||
{"elementwise_bias",
|
||||
1,
|
||||
bias_enum::elementwise_bias,
|
||||
elem_bias_dev.get(),
|
||||
seqlen,
|
||||
seqlen * seqlen,
|
||||
0},
|
||||
{"alibi", 2, bias_enum::alibi, alibi_slopes_dev.get(), 0, 1, 0},
|
||||
};
|
||||
|
||||
bool all_passed = true;
|
||||
|
||||
for(const auto& test : tests)
|
||||
{
|
||||
std::cout << "\nStep 3: Run FMHA Forward [" << test.name << "]\n";
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = test.bias_type;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.q_ptr = q_dev.get();
|
||||
fmha_args.k_ptr = k_dev.get();
|
||||
fmha_args.v_ptr = v_dev.get();
|
||||
fmha_args.o_ptr = o_dev.get();
|
||||
|
||||
fmha_args.bias_ptr = test.bias_ptr;
|
||||
fmha_args.q_descale_ptr = nullptr;
|
||||
fmha_args.k_descale_ptr = nullptr;
|
||||
fmha_args.v_descale_ptr = nullptr;
|
||||
fmha_args.rand_val_ptr = nullptr;
|
||||
fmha_args.lse_ptr = nullptr;
|
||||
fmha_args.sink_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fmha_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
fmha_args.scale_s = scale;
|
||||
fmha_args.logits_soft_cap = 0.0f;
|
||||
|
||||
// bhsd layout strides
|
||||
fmha_args.stride_q = hdim;
|
||||
fmha_args.stride_k = hdim;
|
||||
fmha_args.stride_v = hdim;
|
||||
fmha_args.stride_bias = test.stride_bias;
|
||||
fmha_args.stride_randval = 0;
|
||||
fmha_args.stride_o = hdim;
|
||||
|
||||
fmha_args.nhead_stride_q = seqlen * hdim;
|
||||
fmha_args.nhead_stride_k = seqlen * hdim;
|
||||
fmha_args.nhead_stride_v = seqlen * hdim;
|
||||
fmha_args.nhead_stride_bias = test.nhead_stride_bias;
|
||||
fmha_args.nhead_stride_randval = 0;
|
||||
fmha_args.nhead_stride_lse = 0;
|
||||
fmha_args.nhead_stride_o = seqlen * hdim;
|
||||
fmha_args.nhead_stride_q_descale = 0;
|
||||
fmha_args.nhead_stride_k_descale = 0;
|
||||
fmha_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fmha_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_bias = test.batch_stride_bias;
|
||||
fmha_args.batch_stride_randval = 0;
|
||||
fmha_args.batch_stride_lse = 0;
|
||||
fmha_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fmha_args.batch_stride_q_descale = 0;
|
||||
fmha_args.batch_stride_k_descale = 0;
|
||||
fmha_args.batch_stride_v_descale = 0;
|
||||
|
||||
fmha_args.window_size_left = -1;
|
||||
fmha_args.window_size_right = -1;
|
||||
fmha_args.sink_size = 0;
|
||||
fmha_args.mask_type = 0;
|
||||
fmha_args.min_seqlen_q = 0;
|
||||
fmha_args.p_drop = 0.0f;
|
||||
fmha_args.s_randval = false;
|
||||
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fmha_args.block_scale_size_q = 0;
|
||||
fmha_args.block_scale_size_kv = 0;
|
||||
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR [" << test.name << "]: " << e.what() << "\n";
|
||||
all_passed = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
|
||||
// Validate
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
|
||||
|
||||
if(nonzero == 0)
|
||||
all_passed = false;
|
||||
|
||||
if(args.has("--validate"))
|
||||
{
|
||||
std::vector<float> o_ref(o_elems, 0.0f);
|
||||
|
||||
if(test.bias_type_int == 0)
|
||||
{
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
const std::vector<float>& bias_ref =
|
||||
(test.bias_type_int == 1) ? elem_bias_host : alibi_slopes_host;
|
||||
cpu_attention_fwd_biased(q_f32,
|
||||
k_f32,
|
||||
v_f32,
|
||||
o_ref,
|
||||
batch,
|
||||
nhead,
|
||||
seqlen,
|
||||
seqlen,
|
||||
hdim,
|
||||
hdim,
|
||||
scale,
|
||||
test.bias_type_int,
|
||||
bias_ref);
|
||||
}
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
double max_rel_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
max_rel_err = std::max(max_rel_err, rel_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Max rel error: " << max_rel_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
|
||||
if(errors > 0)
|
||||
all_passed = false;
|
||||
}
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return all_passed ? 0 : 1;
|
||||
}
|
||||
697
dispatcher/examples/fmha/cpp/21_gpu_features_fmha.cpp
Normal file
697
dispatcher/examples/fmha/cpp/21_gpu_features_fmha.cpp
Normal file
@@ -0,0 +1,697 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 21: GPU Features FMHA
|
||||
//
|
||||
// Tests multiple FMHA features with real GPU execution:
|
||||
// 1. Dropout (with LSE, rand_val buffer)
|
||||
// 2. GQA (nhead_q=16, nhead_k=4, same kernel)
|
||||
// 3. LSE output (verify log-sum-exp values)
|
||||
//
|
||||
// Mirrors 01_basic_fmha.cpp for each feature variant.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set used by this example. It declares three fp16 batch-mode forward kernels
// (hdim 128, no mask, no bias, no quant scale) that share one tile/wave/warp configuration
// and differ only in the lse()/dropout() flags: (false, false), (true, true) and (true, false).
// The trailing "gfx950" argument of each .add() is presumably the target architecture.
DECL_FMHA_KERNEL_SET(gpu_features_fmha_kernels,
                     // Plain fp16 kernel: lse(false), dropout(false). Also used for the GQA
                     // test, since GQA is a runtime concern handled by the same kernel.
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Stage 0 tiles (Q*K^T): m0 ~ seqlen_q, n0 ~ seqlen_k, k0 ~ hdim_q
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              // Stage 1 tiles (Attn*V): n1 ~ hdim_v, k1 ~ seqlen_k, k0max ~ alignment
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // Dropout kernel: dropout(true), declared together with lse(true).
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(true)
                              .dropout(true)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // LSE-only kernel: lse(true), dropout(false).
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(true)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// CPU reference for forward attention with grouped-query attention (GQA) support and an
// optional log-sum-exp (LSE) output.
//
// Each query head hq reads the K/V head hk = hq / (nhead_q / nhead_k), so consecutive groups
// of nhead_q / nhead_k query heads share one K/V head. For every (batch, query head, query
// row) it computes
//     O[row, :] = softmax(scale * Q[row, :] . K^T) * V
// using a max-subtracted softmax.
//
// Tensor layouts implied by the indexing (dense, row-major):
//     Q: [batch, nhead_q, seqlen_q, hdim_q]    K: [batch, nhead_k, seqlen_k, hdim_q]
//     V: [batch, nhead_k, seqlen_k, hdim_v]    O: [batch, nhead_q, seqlen_q, hdim_v]
//     *lse_out (optional): [batch, nhead_q, seqlen_q]
//
// Parameters:
//     O         must already be sized by the caller; every element is written.
//     lse_out   when non-null, must already be sized; receives max_score + log(sum_exp) for
//               each query row.
//
// If nhead_k is not positive or does not evenly divide nhead_q, the head mapping is undefined
// (it would divide by zero or read past the end of K/V), so the function returns without
// touching O or *lse_out. Flat offsets are computed in 64-bit arithmetic so that large problem
// sizes cannot overflow `int`.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead_q,
                       int nhead_k,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale,
                       std::vector<float>* lse_out = nullptr)
{
    // Reject head counts for which no valid query-head -> key-head mapping exists.
    if(nhead_k <= 0 || nhead_q % nhead_k != 0)
        return;

    // Number of query heads sharing one K/V head; at least 1 whenever the hq loop runs.
    const int nhead_ratio = nhead_q / nhead_k;

    // One scratch row of scores, reused for every query row. Each entry is overwritten before
    // it is read, so no re-initialisation is needed between rows.
    std::vector<float> scores(std::max(seqlen_k, 0), 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int hq = 0; hq < nhead_q; ++hq)
        {
            const int hk = hq / nhead_ratio;

            // Linear (batch, head) indices on the query side and on the key/value side.
            const int64_t bhq = static_cast<int64_t>(b) * nhead_q + hq;
            const int64_t bhk = static_cast<int64_t>(b) * nhead_k + hk;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bhq * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bhq * seqlen_q + sq) * hdim_v;

                // Pass 1: scaled dot products and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (bhk * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum for numerical stability.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                // log(sum_j exp(s_j)) = max + log(sum_j exp(s_j - max)).
                if(lse_out)
                {
                    (*lse_out)[bhq * seqlen_q + sq] = max_score + std::log(sum_exp);
                }

                // Pass 3: normalise into probabilities.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Pass 4: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (bhk * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
// Outcome of one feature test. It is a plain aggregate and is brace-initialised in field
// order (name, passed, time_ms), so the order of the members must not change.
struct FeatureResult
{
    std::string name; // label of the feature under test
    bool passed;      // whether the feature's check succeeded
    float time_ms;    // time recorded for the feature's kernel run, in milliseconds
};
|
||||
|
||||
// Builds the fmha_fwd_args shared by every feature test in this example.
//
// The returned arguments describe a batch-mode problem where Q and K have the same sequence
// length and Q and V the same head dimension:
//   - q/k/v/o are stored as-is in the corresponding *_ptr fields;
//   - row stride = hdim, head stride = seqlen * hdim, batch stride = nhead * seqlen * hdim,
//     where Q and O use nhead_q and K and V use nhead_k;
//   - every optional feature is switched off: all auxiliary pointers (bias, descale, rand_val,
//     lse, sink, block-scale seqstart) are null with zero strides, window sizes are -1,
//     mask_type is 0 and p_drop is 0.
// Callers enable a feature by overwriting the relevant fields on the returned value.
fmha_fwd_args make_base_args(void* q,
                             void* k,
                             void* v,
                             void* o,
                             int batch,
                             int nhead_q,
                             int nhead_k,
                             int seqlen,
                             int hdim,
                             float scale)
{
    // Stride values shared by several fields below.
    const int head_stride     = seqlen * hdim;
    const int q_batch_stride  = nhead_q * seqlen * hdim;
    const int kv_batch_stride = nhead_k * seqlen * hdim;

    fmha_fwd_args fa{};

    // Main tensors.
    fa.q_ptr = q;
    fa.k_ptr = k;
    fa.v_ptr = v;
    fa.o_ptr = o;

    // Optional buffers: all absent by default.
    fa.bias_ptr                   = nullptr;
    fa.q_descale_ptr              = nullptr;
    fa.k_descale_ptr              = nullptr;
    fa.v_descale_ptr              = nullptr;
    fa.rand_val_ptr               = nullptr;
    fa.lse_ptr                    = nullptr;
    fa.sink_ptr                   = nullptr;
    fa.block_scale_seqstart_q_ptr = nullptr;
    fa.block_scale_seqstart_k_ptr = nullptr;

    // Problem shape.
    fa.batch           = batch;
    fa.seqlen_q        = seqlen;
    fa.seqlen_k        = seqlen;
    fa.max_seqlen_q    = seqlen;
    fa.hdim_q          = hdim;
    fa.hdim_v          = hdim;
    fa.nhead_q         = nhead_q;
    fa.nhead_k         = nhead_k;
    fa.scale_s         = scale;
    fa.logits_soft_cap = 0.0f;

    // Q: row / head / batch strides.
    fa.stride_q       = hdim;
    fa.nhead_stride_q = head_stride;
    fa.batch_stride_q = q_batch_stride;

    // K.
    fa.stride_k       = hdim;
    fa.nhead_stride_k = head_stride;
    fa.batch_stride_k = kv_batch_stride;

    // V.
    fa.stride_v       = hdim;
    fa.nhead_stride_v = head_stride;
    fa.batch_stride_v = kv_batch_stride;

    // O (same head count as Q).
    fa.stride_o       = hdim;
    fa.nhead_stride_o = head_stride;
    fa.batch_stride_o = q_batch_stride;

    // Strides of the unused optional buffers.
    fa.stride_bias            = 0;
    fa.nhead_stride_bias      = 0;
    fa.batch_stride_bias      = 0;
    fa.stride_randval         = 0;
    fa.nhead_stride_randval   = 0;
    fa.batch_stride_randval   = 0;
    fa.nhead_stride_lse       = 0;
    fa.batch_stride_lse       = 0;
    fa.nhead_stride_q_descale = 0;
    fa.nhead_stride_k_descale = 0;
    fa.nhead_stride_v_descale = 0;
    fa.batch_stride_q_descale = 0;
    fa.batch_stride_k_descale = 0;
    fa.batch_stride_v_descale = 0;

    // Masking, dropout and block-scale settings: everything disabled.
    fa.window_size_left    = -1;
    fa.window_size_right   = -1;
    fa.sink_size           = 0;
    fa.mask_type           = 0;
    fa.min_seqlen_q        = 0;
    fa.p_drop              = 0.0f;
    fa.s_randval           = false;
    fa.drop_seed_offset    = std::make_pair(uint64_t(0), uint64_t(0));
    fa.block_scale_size_q  = 0;
    fa.block_scale_size_kv = 0;

    return fa;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 21: GPU Features FMHA", "Dropout, GQA, LSE with real GPU data");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--seqlen", "64", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
print_header("Example 21: GPU Features FMHA");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("gpu_features_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FeatureResult> results;
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Feature A: GQA (nhead_q=16, nhead_k=4, same basic kernel)
|
||||
// -----------------------------------------------------------------------
|
||||
{
|
||||
std::cout << "\nStep 2a: GQA (nhead_q=16, nhead_k=4)\n";
|
||||
const int nhead_q = 16;
|
||||
const int nhead_k = 4;
|
||||
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead_q * seqlen * hdim;
|
||||
const int64_t k_elems = static_cast<int64_t>(batch) * nhead_k * seqlen * hdim;
|
||||
const int64_t o_elems = q_elems;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(k_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(k_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(o_elems);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems), k_host(k_elems), v_host(k_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
auto fmha_args = make_base_args(q_dev.get(),
|
||||
k_dev.get(),
|
||||
v_dev.get(),
|
||||
o_dev.get(),
|
||||
batch,
|
||||
nhead_q,
|
||||
nhead_k,
|
||||
seqlen,
|
||||
hdim,
|
||||
scale);
|
||||
|
||||
bool passed = false;
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
|
||||
// Validate against CPU reference with GQA head repetition
|
||||
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(k_elems);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < k_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < k_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
std::vector<float> o_ref(o_elems, 0.0f);
|
||||
cpu_attention_fwd(q_f32,
|
||||
k_f32,
|
||||
v_f32,
|
||||
o_ref,
|
||||
batch,
|
||||
nhead_q,
|
||||
nhead_k,
|
||||
seqlen,
|
||||
seqlen,
|
||||
hdim,
|
||||
hdim,
|
||||
scale);
|
||||
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
}
|
||||
results.push_back({"GQA (16q/4k)", passed, time_ms});
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Feature B: LSE output
|
||||
// -----------------------------------------------------------------------
|
||||
{
|
||||
std::cout << "\nStep 2b: LSE Output\n";
|
||||
const int nhead = 4;
|
||||
|
||||
const int64_t qkv_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
|
||||
GpuBuffer<float> lse_dev(lse_elems);
|
||||
|
||||
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
lse_dev.zero();
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = true;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
auto fmha_args = make_base_args(q_dev.get(),
|
||||
k_dev.get(),
|
||||
v_dev.get(),
|
||||
o_dev.get(),
|
||||
batch,
|
||||
nhead,
|
||||
nhead,
|
||||
seqlen,
|
||||
hdim,
|
||||
scale);
|
||||
fmha_args.lse_ptr = lse_dev.get();
|
||||
fmha_args.nhead_stride_lse = seqlen;
|
||||
fmha_args.batch_stride_lse = nhead * seqlen;
|
||||
|
||||
bool passed = false;
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
|
||||
// Compute CPU reference LSE
|
||||
std::vector<float> q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
std::vector<float> o_ref(qkv_elems, 0.0f);
|
||||
std::vector<float> lse_ref(lse_elems, 0.0f);
|
||||
cpu_attention_fwd(q_f32,
|
||||
k_f32,
|
||||
v_f32,
|
||||
o_ref,
|
||||
batch,
|
||||
nhead,
|
||||
nhead,
|
||||
seqlen,
|
||||
seqlen,
|
||||
hdim,
|
||||
hdim,
|
||||
scale,
|
||||
&lse_ref);
|
||||
|
||||
std::vector<float> lse_host(lse_elems);
|
||||
lse_dev.copy_to_host(lse_host.data());
|
||||
|
||||
int lse_reasonable = 0;
|
||||
double max_lse_err = 0.0;
|
||||
for(int64_t i = 0; i < lse_elems; ++i)
|
||||
{
|
||||
if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f)
|
||||
++lse_reasonable;
|
||||
double err = std::abs(lse_host[i] - lse_ref[i]);
|
||||
max_lse_err = std::max(max_lse_err, err);
|
||||
}
|
||||
std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n";
|
||||
std::cout << " LSE max error vs ref: " << std::scientific << max_lse_err << "\n";
|
||||
std::cout << " LSE sample [0..3]: ";
|
||||
for(int i = 0; i < std::min<int>(4, lse_elems); ++i)
|
||||
std::cout << std::fixed << std::setprecision(4) << lse_host[i] << " ";
|
||||
std::cout << "\n";
|
||||
passed = (lse_reasonable == lse_elems) && (max_lse_err < 1.0);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
}
|
||||
results.push_back({"LSE", passed, time_ms});
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Feature C: Dropout
|
||||
// -----------------------------------------------------------------------
|
||||
{
|
||||
std::cout << "\nStep 2c: Dropout (p_drop=0.2)\n";
|
||||
const int nhead = 4;
|
||||
|
||||
const int64_t qkv_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen;
|
||||
const int64_t randval_elems = static_cast<int64_t>(batch) * nhead * seqlen * seqlen;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
|
||||
GpuBuffer<float> lse_dev(lse_elems);
|
||||
GpuBuffer<uint8_t> rand_val_dev(randval_elems);
|
||||
|
||||
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
lse_dev.zero();
|
||||
rand_val_dev.zero();
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = true;
|
||||
traits.has_dropout = true;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
auto fmha_args = make_base_args(q_dev.get(),
|
||||
k_dev.get(),
|
||||
v_dev.get(),
|
||||
o_dev.get(),
|
||||
batch,
|
||||
nhead,
|
||||
nhead,
|
||||
seqlen,
|
||||
hdim,
|
||||
scale);
|
||||
fmha_args.lse_ptr = lse_dev.get();
|
||||
fmha_args.rand_val_ptr = rand_val_dev.get();
|
||||
fmha_args.nhead_stride_lse = seqlen;
|
||||
fmha_args.batch_stride_lse = nhead * seqlen;
|
||||
fmha_args.stride_randval = seqlen;
|
||||
fmha_args.nhead_stride_randval = seqlen * seqlen;
|
||||
fmha_args.batch_stride_randval = nhead * seqlen * seqlen;
|
||||
fmha_args.p_drop = 0.2f;
|
||||
fmha_args.s_randval = true;
|
||||
fmha_args.drop_seed_offset = std::make_pair(uint64_t(42), uint64_t(0));
|
||||
|
||||
bool passed = false;
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
|
||||
std::vector<FmhaDataType> o_host(qkv_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << " Non-zero outputs: " << nonzero << " / " << qkv_elems << "\n";
|
||||
|
||||
std::vector<float> lse_host(lse_elems);
|
||||
lse_dev.copy_to_host(lse_host.data());
|
||||
int lse_reasonable = 0;
|
||||
for(int64_t i = 0; i < lse_elems; ++i)
|
||||
{
|
||||
if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f)
|
||||
++lse_reasonable;
|
||||
}
|
||||
std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n";
|
||||
passed = (nonzero > 0) && (lse_reasonable == lse_elems);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
}
|
||||
results.push_back({"Dropout", passed, time_ms});
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Summary
|
||||
// -----------------------------------------------------------------------
|
||||
std::cout << "\nStep 3: Summary\n";
|
||||
std::cout << " " << std::setw(16) << "Feature" << " | " << std::setw(10) << "Time(ms)" << " | "
|
||||
<< std::setw(8) << "Status" << "\n";
|
||||
std::cout << " " << std::string(42, '-') << "\n";
|
||||
|
||||
bool all_passed = true;
|
||||
for(const auto& r : results)
|
||||
{
|
||||
std::cout << " " << std::setw(16) << r.name << " | " << std::fixed << std::setprecision(4)
|
||||
<< std::setw(10) << r.time_ms << " | " << std::setw(8)
|
||||
<< (r.passed ? "PASS" : "FAIL") << "\n";
|
||||
if(!r.passed)
|
||||
all_passed = false;
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return all_passed ? 0 : 1;
|
||||
}
|
||||
553
dispatcher/examples/fmha/cpp/22_gpu_bwd_fmha.cpp
Normal file
553
dispatcher/examples/fmha/cpp/22_gpu_bwd_fmha.cpp
Normal file
@@ -0,0 +1,553 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 22: FMHA Backward with GPU Execution
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Declare 3 backward kernel families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq)
|
||||
// 2. Run forward to get O and LSE
|
||||
// 3. Run backward to compute dQ, dK, dV
|
||||
// 4. Validate gradients are non-zero
|
||||
//
|
||||
// Falls back to planning-only mode if the backward plan is invalid or run_bwd() throws.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set used by this example: one forward kernel plus the three backward
// families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq). Every entry is declared
// for fp16, batch mode, hdim 128, no mask, no bias, no dropout, targeting gfx950.
DECL_FMHA_KERNEL_SET(
    gpu_bwd_fmha_kernels,
    // ---- Forward kernel (to produce O and LSE for backward) ----
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("no")
             .lse(true) // LSE output is enabled for this forward kernel
             .dropout(false).qscale("no"),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(128).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // Wave layout for stage 0, then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             // Warp tile for stage 0, then stage 1
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true)
             .alignments(128, 128)
             .selection_rank(0),
         "gfx950")

    // ---- Backward: dot(dO, O) to compute d scalar ----
    .add(FmhaSignature()
             .family("bwd_dot_do_o").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .dropout(false).dbias(false).store_randval(false).deterministic(false),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(32)
             // NOTE(review): zeros presumably mean "not used by this family" - confirm.
             .tile_n1(0).tile_k1(0).tile_k0max(0)
             .padding(true, true, true, true)
             .selection_rank(0),
         "gfx950")

    // ---- Backward: compute dQ, dK, dV ----
    .add(FmhaSignature()
             .family("bwd_dq_dk_dv").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .dropout(false).dbias(false).store_randval(false).deterministic(false),
         FmhaAlgorithm()
             .tile_m0(16).tile_n0(128).tile_k0(128)
             .tile_n1(16).tile_k1(128).tile_k0max(32)
             // This family uses the aggregate 9-argument wave()/warp() setters.
             .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
             .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
             .padding(true, true, true, true)
             .max_seq_len_q(0)
             .selection_rank(0),
         "gfx950")

    // ---- Backward: convert accumulated dQ from fp32 to fp16 ----
    .add(FmhaSignature()
             .family("bwd_convert_dq").dtype("fp16").mode("batch").hdim(128)
             .mask("no").bias("no")
             .dropout(false).dbias(false).store_randval(false).deterministic(false),
         FmhaAlgorithm()
             .tile_m0(64).tile_n0(128).tile_k0(0)
             .tile_n1(0).tile_k1(0).tile_k0max(0)
             .padding(true, true, true, true)
             .selection_rank(0),
         "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// CPU reference for the attention forward pass (no mask, no bias, no dropout).
//
// For every (batch, head, query row) it computes
//   s[sk]  = scale * dot(Q[row, :], K[sk, :])
//   LSE    = max(s) + log(sum_sk exp(s[sk] - max(s)))      (natural log)
//   O[row] = sum_sk softmax(s)[sk] * V[sk, :]
//
// Buffer layouts, as fixed by the index arithmetic below (contiguous, row-major):
//   Q   : [batch, nhead, seqlen_q, hdim_q]   (read)
//   K   : [batch, nhead, seqlen_k, hdim_q]   (read)
//   V   : [batch, nhead, seqlen_k, hdim_v]   (read)
//   O   : [batch, nhead, seqlen_q, hdim_v]   (written)
//   LSE : [batch, nhead, seqlen_q]           (written)
//
// The caller must size every vector accordingly; no bounds checks are performed.
// Flat offsets are computed in 64-bit arithmetic so that large problem sizes do
// not overflow 32-bit int.
//
// [[maybe_unused]]: main() in this example does not currently call this helper.
[[maybe_unused]] void cpu_attention_fwd(const std::vector<float>& Q,
                                        const std::vector<float>& K,
                                        const std::vector<float>& V,
                                        std::vector<float>& O,
                                        std::vector<float>& LSE,
                                        int batch,
                                        int nhead,
                                        int seqlen_q,
                                        int seqlen_k,
                                        int hdim_q,
                                        int hdim_v,
                                        float scale)
{
    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Flat (batch, head) slice index, widened once so every offset
            // derived from it is evaluated in 64 bits.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row   = bh * seqlen_q + sq; // also the LSE index
                const int64_t q_base  = q_row * hdim_q;
                const int64_t o_base  = q_row * hdim_v;
                const int64_t kv_row0 = bh * seqlen_k; // first K/V row of this head

                std::vector<float> scores(seqlen_k, 0.0f);
                float max_score = -1e30f;

                // Scaled Q*K^T for this query row, tracking the running maximum.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_base = (kv_row0 + sk) * hdim_q;
                    float dot            = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_base + d] * K[k_base + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Subtract the maximum before exponentiating for numerical stability.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                LSE[q_row] = max_score + std::log(sum_exp);

                // Normalize to softmax probabilities.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // O[row, :] = P * V
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        acc += scores[sk] * V[(kv_row0 + sk) * hdim_v + dv];
                    }
                    O[o_base + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 22: FMHA Backward (GPU)", "Forward + backward with GPU validation");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "1", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 1);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
print_header("Example 22: FMHA Backward (GPU Execution)");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("gpu_bwd_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// Step 2: Plan backward to verify all 3 stages resolve
|
||||
std::cout << "\nStep 2: Plan Backward\n";
|
||||
|
||||
fmha_bwd_traits bwd_traits{};
|
||||
bwd_traits.hdim_q = hdim;
|
||||
bwd_traits.hdim_v = hdim;
|
||||
bwd_traits.data_type = "fp16";
|
||||
bwd_traits.is_group_mode = false;
|
||||
bwd_traits.mask_type = mask_enum::no_mask;
|
||||
bwd_traits.bias_type = bias_enum::no_bias;
|
||||
bwd_traits.has_dbias = false;
|
||||
bwd_traits.has_dropout = false;
|
||||
bwd_traits.is_store_randval = false;
|
||||
bwd_traits.is_deterministic = false;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = batch;
|
||||
bwd_args.seqlen_q = seqlen;
|
||||
bwd_args.seqlen_k = seqlen;
|
||||
bwd_args.max_seqlen_q = seqlen;
|
||||
bwd_args.max_seqlen_k = seqlen;
|
||||
bwd_args.hdim_q = hdim;
|
||||
bwd_args.hdim_v = hdim;
|
||||
bwd_args.nhead_q = nhead;
|
||||
bwd_args.nhead_k = nhead;
|
||||
|
||||
auto bwd_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
|
||||
|
||||
if(!bwd_plan.is_valid() || bwd_plan.stages.size() < 2)
|
||||
{
|
||||
std::cout << " Backward plan: INVALID (expected multi-stage)\n";
|
||||
std::cout << " Falling back to planning-only mode (like 04_bwd_fmha.cpp)\n";
|
||||
print_separator();
|
||||
std::cout << "Status: PLAN_ONLY\n";
|
||||
print_separator();
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::cout << " Backward plan stages:\n";
|
||||
for(const auto& stage : bwd_plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
|
||||
// Step 3: Allocate buffers
|
||||
std::cout << "\nStep 3: Allocate GPU Buffers\n";
|
||||
const int64_t qkv_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen;
|
||||
const int64_t dq_acc_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
|
||||
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
|
||||
<< "]\n";
|
||||
std::cout << " LSE/d: [" << batch << ", " << nhead << ", " << seqlen << "]\n";
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
|
||||
GpuBuffer<float> lse_dev(lse_elems);
|
||||
GpuBuffer<FmhaDataType> do_dev(qkv_elems);
|
||||
GpuBuffer<float> d_dev(lse_elems);
|
||||
GpuBuffer<FmhaDataType> dq_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> dk_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> dv_dev(qkv_elems);
|
||||
GpuBuffer<float> dq_acc_dev(dq_acc_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
|
||||
std::vector<FmhaDataType> do_host(qkv_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : do_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
do_dev.copy_from_host(do_host.data());
|
||||
o_dev.zero();
|
||||
lse_dev.zero();
|
||||
d_dev.zero();
|
||||
dq_dev.zero();
|
||||
dk_dev.zero();
|
||||
dv_dev.zero();
|
||||
dq_acc_dev.zero();
|
||||
|
||||
// Step 4: Run forward to produce O and LSE
|
||||
std::cout << "\nStep 4: Run Forward (to produce O and LSE)\n";
|
||||
{
|
||||
fmha_fwd_traits fwd_traits{};
|
||||
fwd_traits.hdim_q = hdim;
|
||||
fwd_traits.hdim_v = hdim;
|
||||
fwd_traits.data_type = "fp16";
|
||||
fwd_traits.is_group_mode = false;
|
||||
fwd_traits.is_v_rowmajor = true;
|
||||
fwd_traits.has_logits_soft_cap = false;
|
||||
fwd_traits.mask_type = mask_enum::no_mask;
|
||||
fwd_traits.bias_type = bias_enum::no_bias;
|
||||
fwd_traits.has_lse = true;
|
||||
fwd_traits.has_dropout = false;
|
||||
fwd_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.q_ptr = q_dev.get();
|
||||
fwd_args.k_ptr = k_dev.get();
|
||||
fwd_args.v_ptr = v_dev.get();
|
||||
fwd_args.o_ptr = o_dev.get();
|
||||
fwd_args.lse_ptr = lse_dev.get();
|
||||
|
||||
fwd_args.bias_ptr = nullptr;
|
||||
fwd_args.q_descale_ptr = nullptr;
|
||||
fwd_args.k_descale_ptr = nullptr;
|
||||
fwd_args.v_descale_ptr = nullptr;
|
||||
fwd_args.rand_val_ptr = nullptr;
|
||||
fwd_args.sink_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fwd_args.seqlen_q = seqlen;
|
||||
fwd_args.seqlen_k = seqlen;
|
||||
fwd_args.batch = batch;
|
||||
fwd_args.max_seqlen_q = seqlen;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.scale_s = scale;
|
||||
fwd_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fwd_args.stride_q = hdim;
|
||||
fwd_args.stride_k = hdim;
|
||||
fwd_args.stride_v = hdim;
|
||||
fwd_args.stride_bias = 0;
|
||||
fwd_args.stride_randval = 0;
|
||||
fwd_args.stride_o = hdim;
|
||||
|
||||
fwd_args.nhead_stride_q = seqlen * hdim;
|
||||
fwd_args.nhead_stride_k = seqlen * hdim;
|
||||
fwd_args.nhead_stride_v = seqlen * hdim;
|
||||
fwd_args.nhead_stride_bias = 0;
|
||||
fwd_args.nhead_stride_randval = 0;
|
||||
fwd_args.nhead_stride_lse = seqlen;
|
||||
fwd_args.nhead_stride_o = seqlen * hdim;
|
||||
fwd_args.nhead_stride_q_descale = 0;
|
||||
fwd_args.nhead_stride_k_descale = 0;
|
||||
fwd_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fwd_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_bias = 0;
|
||||
fwd_args.batch_stride_randval = 0;
|
||||
fwd_args.batch_stride_lse = nhead * seqlen;
|
||||
fwd_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_q_descale = 0;
|
||||
fwd_args.batch_stride_k_descale = 0;
|
||||
fwd_args.batch_stride_v_descale = 0;
|
||||
|
||||
fwd_args.window_size_left = -1;
|
||||
fwd_args.window_size_right = -1;
|
||||
fwd_args.sink_size = 0;
|
||||
fwd_args.mask_type = 0;
|
||||
fwd_args.min_seqlen_q = 0;
|
||||
fwd_args.p_drop = 0.0f;
|
||||
fwd_args.s_randval = false;
|
||||
fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fwd_args.block_scale_size_q = 0;
|
||||
fwd_args.block_scale_size_kv = 0;
|
||||
|
||||
try
|
||||
{
|
||||
float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
|
||||
std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time
|
||||
<< " ms\n";
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " Forward ERROR: " << e.what() << "\n";
|
||||
print_separator();
|
||||
std::cout << "Status: FAIL (forward failed)\n";
|
||||
print_separator();
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Run backward
|
||||
std::cout << "\nStep 5: Run Backward\n";
|
||||
|
||||
bwd_args.q_ptr = q_dev.get();
|
||||
bwd_args.k_ptr = k_dev.get();
|
||||
bwd_args.v_ptr = v_dev.get();
|
||||
bwd_args.bias_ptr = nullptr;
|
||||
bwd_args.o_ptr = o_dev.get();
|
||||
bwd_args.lse_ptr = lse_dev.get();
|
||||
bwd_args.do_ptr = do_dev.get();
|
||||
bwd_args.d_ptr = d_dev.get();
|
||||
bwd_args.rand_val_ptr = nullptr;
|
||||
bwd_args.dq_ptr = dq_dev.get();
|
||||
bwd_args.dk_ptr = dk_dev.get();
|
||||
bwd_args.dv_ptr = dv_dev.get();
|
||||
bwd_args.dbias_ptr = nullptr;
|
||||
bwd_args.dq_acc_ptr = dq_acc_dev.get();
|
||||
bwd_args.scale = scale;
|
||||
|
||||
bwd_args.stride_q = hdim;
|
||||
bwd_args.stride_k = hdim;
|
||||
bwd_args.stride_v = hdim;
|
||||
bwd_args.stride_bias = 0;
|
||||
bwd_args.stride_o = hdim;
|
||||
bwd_args.stride_randval = 0;
|
||||
bwd_args.stride_do = hdim;
|
||||
bwd_args.stride_dq_acc = hdim;
|
||||
bwd_args.stride_dq = hdim;
|
||||
bwd_args.stride_dk = hdim;
|
||||
bwd_args.stride_dv = hdim;
|
||||
bwd_args.stride_dbias = 0;
|
||||
|
||||
bwd_args.nhead_stride_q = seqlen * hdim;
|
||||
bwd_args.nhead_stride_k = seqlen * hdim;
|
||||
bwd_args.nhead_stride_v = seqlen * hdim;
|
||||
bwd_args.nhead_stride_bias = 0;
|
||||
bwd_args.nhead_stride_o = seqlen * hdim;
|
||||
bwd_args.nhead_stride_randval = 0;
|
||||
bwd_args.nhead_stride_do = seqlen * hdim;
|
||||
bwd_args.nhead_stride_lsed = seqlen;
|
||||
bwd_args.nhead_stride_dq_acc = static_cast<int64_t>(seqlen) * hdim;
|
||||
bwd_args.nhead_stride_dq = seqlen * hdim;
|
||||
bwd_args.nhead_stride_dk = seqlen * hdim;
|
||||
bwd_args.nhead_stride_dv = seqlen * hdim;
|
||||
bwd_args.nhead_stride_dbias = 0;
|
||||
|
||||
bwd_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
bwd_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
bwd_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
bwd_args.batch_stride_bias = 0;
|
||||
bwd_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
bwd_args.batch_stride_randval = 0;
|
||||
bwd_args.batch_stride_do = nhead * seqlen * hdim;
|
||||
bwd_args.batch_stride_lsed = nhead * seqlen;
|
||||
bwd_args.batch_stride_dq_acc = static_cast<int64_t>(nhead) * seqlen * hdim;
|
||||
bwd_args.batch_stride_dq = nhead * seqlen * hdim;
|
||||
bwd_args.batch_stride_dk = nhead * seqlen * hdim;
|
||||
bwd_args.batch_stride_dv = nhead * seqlen * hdim;
|
||||
bwd_args.batch_stride_dbias = 0;
|
||||
bwd_args.split_stride_dq_acc = 0;
|
||||
|
||||
bwd_args.window_size_left = -1;
|
||||
bwd_args.window_size_right = -1;
|
||||
bwd_args.mask_type = 0;
|
||||
bwd_args.p_drop = 0.0f;
|
||||
bwd_args.p_undrop = 1.0f;
|
||||
bwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
|
||||
bool bwd_passed = false;
|
||||
try
|
||||
{
|
||||
float bwd_time = dispatcher.run_bwd(bwd_traits, bwd_args, nullptr);
|
||||
std::cout << " Backward time: " << std::fixed << std::setprecision(4) << bwd_time
|
||||
<< " ms\n";
|
||||
|
||||
// Validate: dQ, dK, dV should be non-zero
|
||||
std::vector<FmhaDataType> dq_host(qkv_elems), dk_host(qkv_elems), dv_host(qkv_elems);
|
||||
dq_dev.copy_to_host(dq_host.data());
|
||||
dk_dev.copy_to_host(dk_host.data());
|
||||
dv_dev.copy_to_host(dv_host.data());
|
||||
|
||||
auto count_nonzero = [](const std::vector<FmhaDataType>& buf) {
|
||||
int nz = 0;
|
||||
for(const auto& x : buf)
|
||||
{
|
||||
if(static_cast<float>(x) != 0.0f)
|
||||
++nz;
|
||||
}
|
||||
return nz;
|
||||
};
|
||||
|
||||
int dq_nz = count_nonzero(dq_host);
|
||||
int dk_nz = count_nonzero(dk_host);
|
||||
int dv_nz = count_nonzero(dv_host);
|
||||
|
||||
std::cout << " dQ non-zero: " << dq_nz << " / " << qkv_elems << "\n";
|
||||
std::cout << " dK non-zero: " << dk_nz << " / " << qkv_elems << "\n";
|
||||
std::cout << " dV non-zero: " << dv_nz << " / " << qkv_elems << "\n";
|
||||
|
||||
bwd_passed = (dq_nz > 0) && (dk_nz > 0) && (dv_nz > 0);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " Backward ERROR: " << e.what() << "\n";
|
||||
std::cout << " Falling back to planning-only mode (like 04_bwd_fmha.cpp)\n";
|
||||
std::cout << " Backward plan was valid with " << bwd_plan.stages.size() << " stages\n";
|
||||
print_separator();
|
||||
std::cout << "Status: PLAN_ONLY\n";
|
||||
print_separator();
|
||||
return 0;
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (bwd_passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return bwd_passed ? 0 : 1;
|
||||
}
|
||||
595
dispatcher/examples/fmha/cpp/23_multi_registry_fmha.cpp
Normal file
595
dispatcher/examples/fmha/cpp/23_multi_registry_fmha.cpp
Normal file
@@ -0,0 +1,595 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 23: Multiple Registries for Different Frameworks
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Three separate FmhaRegistry instances (pytorch, flash, aiter)
|
||||
// 2. Each with its own DECL_FMHA_KERNEL_SET using different configs
|
||||
// 3. Registry introspection: size(), filter(), export_json()
|
||||
// 4. Planning the same problem from each registry
|
||||
// 5. GPU execution from the basic kernel registry
|
||||
//
|
||||
// Key idea: separate registries let each consuming framework own its
|
||||
// kernel population independently.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Three DECL_FMHA_KERNEL_SETs with distinct names and configurations.
|
||||
// All register into the global FmhaKernelSetRegistry at static init time.
|
||||
|
||||
// Kernel set tagged with profile "pytorch": a single fp16 batch-mode forward
// kernel with bias("bias"), declared for gfx950.
DECL_FMHA_KERNEL_SET(
    pytorch_reg_kernels,
    // PyTorch: basic fp16, elementwise bias
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("bias")
             .lse(false).dropout(false).qscale("no")
             .profile("pytorch"),
         FmhaAlgorithm()
             // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
             .tile_m0(128).tile_n0(128).tile_k0(32)
             // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // Wave layout for stage 0, then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             // Warp tile for stage 0, then stage 1
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true)
             .alignments(128, 128)
             .selection_rank(0),
         "gfx950"));
|
||||
|
||||
// Kernel set tagged with profile "flash_fwd": a single fp16 batch-mode forward
// kernel with bias("alibi"), declared for gfx950. The algorithm settings are
// the same values used by the other forward sets in this file.
DECL_FMHA_KERNEL_SET(
    flash_reg_kernels,
    // Flash: fp16, alibi bias
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("alibi")
             .lse(false).dropout(false).qscale("no")
             .profile("flash_fwd"),
         FmhaAlgorithm()
             // Block tiles
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // Wave layout for stage 0, then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             // Warp tile for stage 0, then stage 1
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true)
             .alignments(128, 128)
             .selection_rank(0),
         "gfx950"));
|
||||
|
||||
// Kernel set for the AITER profiles: two fp16 forward kernels that differ only
// in mode ("batch" vs "group") and profile tag ("aiter_batch" vs "aiter_group").
// Both are declared for gfx950 with identical algorithm settings.
DECL_FMHA_KERNEL_SET(
    aiter_reg_kernels,
    // AITER: batch mode basic
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("batch").vlayout("r").hdim(128)
             .mask("no").bias("no")
             .lse(false).dropout(false).qscale("no")
             .profile("aiter_batch"),
         FmhaAlgorithm()
             // Block tiles
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // Wave layout for stage 0, then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             // Warp tile for stage 0, then stage 1
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true)
             .alignments(128, 128)
             .selection_rank(0),
         "gfx950")
    // AITER: group mode
    .add(FmhaSignature()
             .family("fwd").dtype("fp16").mode("group").vlayout("r").hdim(128)
             .mask("no").bias("no")
             .lse(false).dropout(false).qscale("no")
             .profile("aiter_group"),
         FmhaAlgorithm()
             // Block tiles
             .tile_m0(128).tile_n0(128).tile_k0(32)
             .tile_n1(128).tile_k1(32).tile_k0max(128)
             // Wave layout for stage 0, then stage 1
             .wave_m0(4).wave_n0(1).wave_k0(1)
             .wave_m1(4).wave_n1(1).wave_k1(1)
             // Warp tile for stage 0, then stage 1
             .warp_m0(32).warp_n0(32).warp_k0(16)
             .warp_m1(32).warp_n1(32).warp_k1(16)
             .pipeline("qr_async")
             .padding(true, true, true, true)
             .alignments(128, 128)
             .selection_rank(0),
         "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// Host reference for unmasked, unbiased attention: O = softmax(scale * Q * K^T) * V.
//
// The index arithmetic treats every tensor as a dense row-major array:
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]   (must already be sized by the caller)
//
// The softmax subtracts the row maximum before exponentiating for numerical stability.
// Flat offsets are computed in 64-bit so they cannot overflow `int` for large tensors.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // One scratch row of scores shared by all (b, h, sq); every entry is overwritten
    // before it is read, so no per-row re-initialisation is required.
    std::vector<float> scores(static_cast<std::size_t>(std::max(seqlen_k, 0)));

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Combined (batch, head) index, widened once so all derived offsets are 64-bit.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bh * seqlen_q + sq) * hdim_v;

                // Pass 1: scaled dot products and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (bh * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum and normalise.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                    scores[sk] /= sum_exp;

                // Pass 3: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (bh * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
// Pairs a display name with non-owning pointers to a registry and the dispatcher built
// on top of it, so main() can iterate over all registries uniformly.
// The pointed-to objects are locals of main() and must outlive every RegistryInfo.
struct RegistryInfo
{
    std::string name;     // label printed in the introspection / planning output
    FmhaRegistry* reg;    // registry queried for size(), get_all(), filter(), export_json()
    FmhaDispatcher* disp; // dispatcher used to plan problems against `reg`
};
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 23: Multi-Registry FMHA",
|
||||
"Separate registries per framework recipient");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
print_header("Example 23: Multi-Registry FMHA");
|
||||
|
||||
// Step 1: Create 3 separate registries
|
||||
std::cout << "\nStep 1: Create Separate Registries\n";
|
||||
std::cout << " Global kernel sets declared: " << FmhaKernelSetRegistry::instance().size()
|
||||
<< "\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry pytorch_reg;
|
||||
pytorch_reg.set_name("pytorch");
|
||||
REGISTER_GENERATED_KERNELS(pytorch_reg, gfx_arch);
|
||||
|
||||
FmhaRegistry flash_reg;
|
||||
flash_reg.set_name("flash");
|
||||
REGISTER_GENERATED_KERNELS(flash_reg, gfx_arch);
|
||||
|
||||
FmhaRegistry aiter_reg;
|
||||
aiter_reg.set_name("aiter");
|
||||
REGISTER_GENERATED_KERNELS(aiter_reg, gfx_arch);
|
||||
|
||||
FmhaDispatcher pytorch_disp(&pytorch_reg);
|
||||
FmhaDispatcher flash_disp(&flash_reg);
|
||||
FmhaDispatcher aiter_disp(&aiter_reg);
|
||||
|
||||
std::vector<RegistryInfo> registries = {
|
||||
{"pytorch", &pytorch_reg, &pytorch_disp},
|
||||
{"flash", &flash_reg, &flash_disp},
|
||||
{"aiter", &aiter_reg, &aiter_disp},
|
||||
};
|
||||
|
||||
// Step 2: Registry introspection
|
||||
std::cout << "\nStep 2: Registry Introspection\n";
|
||||
for(const auto& ri : registries)
|
||||
{
|
||||
std::cout << "\n Registry: " << ri.name << "\n";
|
||||
std::cout << " Kernel count: " << ri.reg->size() << "\n";
|
||||
|
||||
auto all_kernels = ri.reg->get_all();
|
||||
for(const auto& k : all_kernels)
|
||||
{
|
||||
std::cout << " Kernel: " << k->get_name() << "\n";
|
||||
}
|
||||
|
||||
auto fwd_kernels = ri.reg->filter([](const FmhaKernelInstance& inst) {
|
||||
return inst.get_key().signature.family == FmhaKernelFamily::Fwd;
|
||||
});
|
||||
std::cout << " Forward kernels: " << fwd_kernels.size() << "\n";
|
||||
|
||||
std::string json = ri.reg->export_json(false);
|
||||
std::cout << " JSON size: " << json.size() << " bytes\n";
|
||||
}
|
||||
|
||||
// Step 3: Plan the same problem from each registry
|
||||
std::cout << "\nStep 3: Plan from Each Registry\n";
|
||||
|
||||
// Problem A: basic fp16 no-bias (matches aiter_batch)
|
||||
{
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
|
||||
std::cout << "\n Problem: fp16 batch no-bias\n";
|
||||
for(const auto& ri : registries)
|
||||
{
|
||||
auto plan = ri.disp->plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
|
||||
std::cout << " " << ri.name << ": "
|
||||
<< (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Problem B: fp16 with alibi bias (matches flash)
|
||||
{
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::alibi;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
|
||||
std::cout << "\n Problem: fp16 batch alibi-bias\n";
|
||||
for(const auto& ri : registries)
|
||||
{
|
||||
auto plan = ri.disp->plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
|
||||
std::cout << " " << ri.name << ": "
|
||||
<< (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Problem C: fp16 with elementwise bias (matches pytorch)
|
||||
{
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::elementwise_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
|
||||
std::cout << "\n Problem: fp16 batch elementwise-bias\n";
|
||||
for(const auto& ri : registries)
|
||||
{
|
||||
auto plan = ri.disp->plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
|
||||
std::cout << " " << ri.name << ": "
|
||||
<< (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: GPU execution from AITER registry (basic no-bias kernel)
|
||||
std::cout << "\nStep 4: GPU Execution (aiter registry)\n";
|
||||
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(q_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems), k_host(q_elems), v_host(q_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_traits run_traits{};
|
||||
run_traits.hdim_q = hdim;
|
||||
run_traits.hdim_v = hdim;
|
||||
run_traits.data_type = "fp16";
|
||||
run_traits.is_group_mode = false;
|
||||
run_traits.is_v_rowmajor = true;
|
||||
run_traits.has_logits_soft_cap = false;
|
||||
run_traits.mask_type = mask_enum::no_mask;
|
||||
run_traits.bias_type = bias_enum::no_bias;
|
||||
run_traits.has_lse = false;
|
||||
run_traits.has_dropout = false;
|
||||
run_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args run_args{};
|
||||
run_args.q_ptr = q_dev.get();
|
||||
run_args.k_ptr = k_dev.get();
|
||||
run_args.v_ptr = v_dev.get();
|
||||
run_args.o_ptr = o_dev.get();
|
||||
|
||||
run_args.bias_ptr = nullptr;
|
||||
run_args.q_descale_ptr = nullptr;
|
||||
run_args.k_descale_ptr = nullptr;
|
||||
run_args.v_descale_ptr = nullptr;
|
||||
run_args.rand_val_ptr = nullptr;
|
||||
run_args.lse_ptr = nullptr;
|
||||
run_args.sink_ptr = nullptr;
|
||||
run_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
run_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
run_args.seqlen_q = seqlen;
|
||||
run_args.seqlen_k = seqlen;
|
||||
run_args.batch = batch;
|
||||
run_args.max_seqlen_q = seqlen;
|
||||
run_args.hdim_q = hdim;
|
||||
run_args.hdim_v = hdim;
|
||||
run_args.nhead_q = nhead;
|
||||
run_args.nhead_k = nhead;
|
||||
run_args.scale_s = scale;
|
||||
run_args.logits_soft_cap = 0.0f;
|
||||
|
||||
run_args.stride_q = hdim;
|
||||
run_args.stride_k = hdim;
|
||||
run_args.stride_v = hdim;
|
||||
run_args.stride_bias = 0;
|
||||
run_args.stride_randval = 0;
|
||||
run_args.stride_o = hdim;
|
||||
|
||||
run_args.nhead_stride_q = seqlen * hdim;
|
||||
run_args.nhead_stride_k = seqlen * hdim;
|
||||
run_args.nhead_stride_v = seqlen * hdim;
|
||||
run_args.nhead_stride_bias = 0;
|
||||
run_args.nhead_stride_randval = 0;
|
||||
run_args.nhead_stride_lse = 0;
|
||||
run_args.nhead_stride_o = seqlen * hdim;
|
||||
run_args.nhead_stride_q_descale = 0;
|
||||
run_args.nhead_stride_k_descale = 0;
|
||||
run_args.nhead_stride_v_descale = 0;
|
||||
|
||||
run_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
run_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
run_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
run_args.batch_stride_bias = 0;
|
||||
run_args.batch_stride_randval = 0;
|
||||
run_args.batch_stride_lse = 0;
|
||||
run_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
run_args.batch_stride_q_descale = 0;
|
||||
run_args.batch_stride_k_descale = 0;
|
||||
run_args.batch_stride_v_descale = 0;
|
||||
|
||||
run_args.window_size_left = -1;
|
||||
run_args.window_size_right = -1;
|
||||
run_args.sink_size = 0;
|
||||
run_args.mask_type = 0;
|
||||
run_args.min_seqlen_q = 0;
|
||||
run_args.p_drop = 0.0f;
|
||||
run_args.s_randval = false;
|
||||
run_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
run_args.block_scale_size_q = 0;
|
||||
run_args.block_scale_size_kv = 0;
|
||||
|
||||
bool passed = false;
|
||||
aiter_disp.set_benchmarking(true);
|
||||
aiter_disp.set_timing(1, 3);
|
||||
try
|
||||
{
|
||||
float time_ms = aiter_disp.run_fwd(run_traits, run_args, nullptr);
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
|
||||
std::vector<FmhaDataType> o_host(q_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
// Validate
|
||||
std::vector<float> q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(q_elems, 0.0f);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << q_elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
549
dispatcher/examples/fmha/cpp/24_per_receipt_registries_fmha.cpp
Normal file
549
dispatcher/examples/fmha/cpp/24_per_receipt_registries_fmha.cpp
Normal file
@@ -0,0 +1,549 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 24: Per-Receipt Registries
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Four DECL_FMHA_KERNEL_SET declarations, each named after a receipt
|
||||
// 2. Each registered into a separate FmhaRegistry
|
||||
// 3. Per-registry: kernel count, kernel names, plan a problem, selected kernel
|
||||
// 4. GPU execution from the ck_default receipt (the basic working kernel)
|
||||
// 5. Comparison table showing which features each receipt supports
|
||||
//
|
||||
// Receipt = a curated kernel set shipped to a specific downstream consumer.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Receipt 1: CK default -- basic fp16, no mask, no bias
// Forward, batch mode, hdim 128, no LSE, no dropout, no quantisation scale.
// Unlike the other three receipts in this file, this signature sets no .profile() tag.
DECL_FMHA_KERNEL_SET(ck_default_kernels,
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          // Architecture string; same value as the --arch default in main().
                          "gfx950"));
|
||||
|
||||
// Receipt 2: Flash forward -- fp16 with alibi bias
// Same signature and algorithm as the ck_default receipt except for bias("alibi") and
// the profile tag "flash_fwd".
DECL_FMHA_KERNEL_SET(flash_fwd_kernels,
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("alibi")
                              .lse(false)
                              .dropout(false)
                              .qscale("no")
                              .profile("flash_fwd"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
// Receipt 3: PyTorch -- fp16 with elementwise bias
// The signature spells the bias kind as "bias"; main() plans this receipt with
// bias_enum::elementwise_bias.
DECL_FMHA_KERNEL_SET(pytorch_kernels,
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("bias")
                              .lse(false)
                              .dropout(false)
                              .qscale("no")
                              .profile("pytorch"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
// Receipt 4: AITER batch -- fp16 batch mode with LSE
// The only receipt in this file declared with lse(true); no mask, no bias,
// profile tag "aiter_batch".
DECL_FMHA_KERNEL_SET(aiter_batch_kernels,
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(true)
                              .dropout(false)
                              .qscale("no")
                              .profile("aiter_batch"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// One receipt: a named registry plus the feature labels main() records for it.
// Unlike a pointer-based record, the registry is owned by value here.
struct ReceiptInfo
{
    std::string name;      // receipt label, e.g. "ck_default"; also passed to registry.set_name()
    std::string bias_desc; // human-readable bias kind: "none", "alibi" or "elementwise"
    bool has_lse;          // true for the receipt whose kernel is declared with lse(true)
    FmhaRegistry registry; // registry populated via REGISTER_GENERATED_KERNELS in main()
};
|
||||
|
||||
// Host reference for unmasked, unbiased attention: O = softmax(scale * Q * K^T) * V.
//
// The index arithmetic treats every tensor as a dense row-major array:
//   Q: [batch, nhead, seqlen_q, hdim_q]
//   K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]
//   O: [batch, nhead, seqlen_q, hdim_v]   (must already be sized by the caller)
//
// The softmax subtracts the row maximum before exponentiating for numerical stability.
// Flat offsets are computed in 64-bit so they cannot overflow `int` for large tensors.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    // One scratch row of scores shared by all (b, h, sq); every entry is overwritten
    // before it is read, so no per-row re-initialisation is required.
    std::vector<float> scores(static_cast<std::size_t>(std::max(seqlen_k, 0)));

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Combined (batch, head) index, widened once so all derived offsets are 64-bit.
            const int64_t bh = static_cast<int64_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const int64_t q_row = (bh * seqlen_q + sq) * hdim_q;
                const int64_t o_row = (bh * seqlen_q + sq) * hdim_v;

                // Pass 1: scaled dot products and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const int64_t k_row = (bh * seqlen_k + sk) * hdim_q;
                    float dot           = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum and normalise.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }
                for(int sk = 0; sk < seqlen_k; ++sk)
                    scores[sk] /= sum_exp;

                // Pass 3: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const int64_t v_idx = (bh * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 24: Per-Receipt Registries",
|
||||
"Curated kernel sets per downstream consumer");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
print_header("Example 24: Per-Receipt Registries");
|
||||
|
||||
// Step 1: Create per-receipt registries
|
||||
std::cout << "\nStep 1: Create Per-Receipt Registries\n";
|
||||
std::cout << " Global kernel sets: " << FmhaKernelSetRegistry::instance().size() << "\n";
|
||||
|
||||
std::vector<ReceiptInfo> receipts;
|
||||
|
||||
receipts.push_back({"ck_default", "none", false, FmhaRegistry()});
|
||||
receipts.back().registry.set_name("ck_default");
|
||||
REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch);
|
||||
|
||||
receipts.push_back({"flash_fwd", "alibi", false, FmhaRegistry()});
|
||||
receipts.back().registry.set_name("flash_fwd");
|
||||
REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch);
|
||||
|
||||
receipts.push_back({"pytorch", "elementwise", false, FmhaRegistry()});
|
||||
receipts.back().registry.set_name("pytorch");
|
||||
REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch);
|
||||
|
||||
receipts.push_back({"aiter_batch", "none", true, FmhaRegistry()});
|
||||
receipts.back().registry.set_name("aiter_batch");
|
||||
REGISTER_GENERATED_KERNELS(receipts.back().registry, gfx_arch);
|
||||
|
||||
// Step 2: Per-registry introspection
|
||||
std::cout << "\nStep 2: Per-Receipt Introspection\n";
|
||||
for(auto& r : receipts)
|
||||
{
|
||||
std::cout << "\n Receipt: " << r.name << "\n";
|
||||
std::cout << " Kernel count: " << r.registry.size() << "\n";
|
||||
|
||||
auto all = r.registry.get_all();
|
||||
for(const auto& k : all)
|
||||
{
|
||||
std::cout << " Kernel: " << k->get_name() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Plan a matching problem for each receipt
|
||||
std::cout << "\nStep 3: Plan per Receipt\n";
|
||||
|
||||
struct PlanTest
|
||||
{
|
||||
std::string receipt_name;
|
||||
bias_enum bias;
|
||||
bool lse;
|
||||
};
|
||||
std::vector<PlanTest> plan_tests = {
|
||||
{"ck_default", bias_enum::no_bias, false},
|
||||
{"flash_fwd", bias_enum::alibi, false},
|
||||
{"pytorch", bias_enum::elementwise_bias, false},
|
||||
{"aiter_batch", bias_enum::no_bias, true},
|
||||
};
|
||||
|
||||
for(std::size_t i = 0; i < plan_tests.size(); ++i)
|
||||
{
|
||||
const auto& pt = plan_tests[i];
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = pt.bias;
|
||||
traits.has_lse = pt.lse;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fmha_args{};
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.seqlen_q = seqlen;
|
||||
fmha_args.seqlen_k = seqlen;
|
||||
fmha_args.max_seqlen_q = seqlen;
|
||||
fmha_args.hdim_q = hdim;
|
||||
fmha_args.hdim_v = hdim;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead;
|
||||
|
||||
FmhaDispatcher disp(&receipts[i].registry);
|
||||
auto plan = disp.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch));
|
||||
|
||||
std::cout << " " << pt.receipt_name << ": "
|
||||
<< (plan.is_valid() ? plan.stages[0].kernel_id : "NO MATCH") << "\n";
|
||||
}
|
||||
|
||||
// Step 4: Comparison table
|
||||
std::cout << "\nStep 4: Receipt Feature Comparison\n\n";
|
||||
std::cout << " " << std::setw(14) << "Receipt" << " | " << std::setw(14) << "Bias" << " | "
|
||||
<< std::setw(5) << "LSE" << " | " << std::setw(8) << "Kernels" << "\n";
|
||||
std::cout << " " << std::string(50, '-') << "\n";
|
||||
|
||||
struct CompRow
|
||||
{
|
||||
std::string name;
|
||||
std::string bias;
|
||||
std::string lse;
|
||||
std::size_t count;
|
||||
};
|
||||
std::vector<CompRow> comp = {
|
||||
{"ck_default", "none", "no", receipts[0].registry.size()},
|
||||
{"flash_fwd", "alibi", "no", receipts[1].registry.size()},
|
||||
{"pytorch", "elementwise", "no", receipts[2].registry.size()},
|
||||
{"aiter_batch", "none", "yes", receipts[3].registry.size()},
|
||||
};
|
||||
|
||||
for(const auto& c : comp)
|
||||
{
|
||||
std::cout << " " << std::setw(14) << c.name << " | " << std::setw(14) << c.bias << " | "
|
||||
<< std::setw(5) << c.lse << " | " << std::setw(8) << c.count << "\n";
|
||||
}
|
||||
|
||||
// Step 5: GPU execution from ck_default
|
||||
std::cout << "\nStep 5: GPU Execution (ck_default receipt)\n";
|
||||
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(q_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems), k_host(q_elems), v_host(q_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_traits run_traits{};
|
||||
run_traits.hdim_q = hdim;
|
||||
run_traits.hdim_v = hdim;
|
||||
run_traits.data_type = "fp16";
|
||||
run_traits.is_group_mode = false;
|
||||
run_traits.is_v_rowmajor = true;
|
||||
run_traits.has_logits_soft_cap = false;
|
||||
run_traits.mask_type = mask_enum::no_mask;
|
||||
run_traits.bias_type = bias_enum::no_bias;
|
||||
run_traits.has_lse = false;
|
||||
run_traits.has_dropout = false;
|
||||
run_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args run_args{};
|
||||
run_args.q_ptr = q_dev.get();
|
||||
run_args.k_ptr = k_dev.get();
|
||||
run_args.v_ptr = v_dev.get();
|
||||
run_args.o_ptr = o_dev.get();
|
||||
|
||||
run_args.bias_ptr = nullptr;
|
||||
run_args.q_descale_ptr = nullptr;
|
||||
run_args.k_descale_ptr = nullptr;
|
||||
run_args.v_descale_ptr = nullptr;
|
||||
run_args.rand_val_ptr = nullptr;
|
||||
run_args.lse_ptr = nullptr;
|
||||
run_args.sink_ptr = nullptr;
|
||||
run_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
run_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
run_args.seqlen_q = seqlen;
|
||||
run_args.seqlen_k = seqlen;
|
||||
run_args.batch = batch;
|
||||
run_args.max_seqlen_q = seqlen;
|
||||
run_args.hdim_q = hdim;
|
||||
run_args.hdim_v = hdim;
|
||||
run_args.nhead_q = nhead;
|
||||
run_args.nhead_k = nhead;
|
||||
run_args.scale_s = scale;
|
||||
run_args.logits_soft_cap = 0.0f;
|
||||
|
||||
run_args.stride_q = hdim;
|
||||
run_args.stride_k = hdim;
|
||||
run_args.stride_v = hdim;
|
||||
run_args.stride_bias = 0;
|
||||
run_args.stride_randval = 0;
|
||||
run_args.stride_o = hdim;
|
||||
|
||||
run_args.nhead_stride_q = seqlen * hdim;
|
||||
run_args.nhead_stride_k = seqlen * hdim;
|
||||
run_args.nhead_stride_v = seqlen * hdim;
|
||||
run_args.nhead_stride_bias = 0;
|
||||
run_args.nhead_stride_randval = 0;
|
||||
run_args.nhead_stride_lse = 0;
|
||||
run_args.nhead_stride_o = seqlen * hdim;
|
||||
run_args.nhead_stride_q_descale = 0;
|
||||
run_args.nhead_stride_k_descale = 0;
|
||||
run_args.nhead_stride_v_descale = 0;
|
||||
|
||||
run_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
run_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
run_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
run_args.batch_stride_bias = 0;
|
||||
run_args.batch_stride_randval = 0;
|
||||
run_args.batch_stride_lse = 0;
|
||||
run_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
run_args.batch_stride_q_descale = 0;
|
||||
run_args.batch_stride_k_descale = 0;
|
||||
run_args.batch_stride_v_descale = 0;
|
||||
|
||||
run_args.window_size_left = -1;
|
||||
run_args.window_size_right = -1;
|
||||
run_args.sink_size = 0;
|
||||
run_args.mask_type = 0;
|
||||
run_args.min_seqlen_q = 0;
|
||||
run_args.p_drop = 0.0f;
|
||||
run_args.s_randval = false;
|
||||
run_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
run_args.block_scale_size_q = 0;
|
||||
run_args.block_scale_size_kv = 0;
|
||||
|
||||
FmhaDispatcher ck_disp(&receipts[0].registry);
|
||||
ck_disp.set_benchmarking(true);
|
||||
ck_disp.set_timing(1, 3);
|
||||
|
||||
bool passed = false;
|
||||
try
|
||||
{
|
||||
float time_ms = ck_disp.run_fwd(run_traits, run_args, nullptr);
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
|
||||
std::vector<FmhaDataType> o_host(q_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
std::vector<float> q_f32(q_elems), k_f32(q_elems), v_f32(q_elems), o_ref(q_elems, 0.0f);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << q_elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
@@ -0,0 +1,530 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 25: AppendKV + BatchPrefill Planning with GPU Execution
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Declare appendkv, batch_prefill, and basic fwd kernels
|
||||
// 2. Plan appendkv with fmha_fwd_appendkv_traits / fmha_fwd_appendkv_args
|
||||
// 3. Plan batch_prefill with fmha_batch_prefill_traits / fmha_batch_prefill_args
|
||||
// 4. Run basic fwd kernel on GPU as sanity check
|
||||
// 5. Show cache_batch_idx usage pattern for non-contiguous batches
|
||||
//
|
||||
// Mirrors 01_basic_fmha.cpp for FMHA.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set used by this example. Three kernels are declared, all for "gfx950",
// fp16 data and head dimension 128:
//   - an AppendKV kernel     (family "fwd_appendkv", batch mode, paged KV, page size 16)
//   - a BatchPrefill kernel  (family "batch_prefill", group mode, paged KV, page size 64)
//   - a plain forward kernel (family "fwd", batch mode) that is actually run on the GPU
DECL_FMHA_KERNEL_SET(appendkv_batchprefill_kernels,

                     // AppendKV kernel
                     .add(FmhaSignature()
                              .family("fwd_appendkv")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .rope("inter")
                              .paged_kv(true)
                              .kv_cache("vectorized", "sglang", 16),
                          FmhaAlgorithm()
                              // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
                              .tile_m0(64)
                              .tile_n0(64)
                              .tile_k0(32)
                              // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .pipeline("appendkv")
                              .padding(true, true, true, true)
                              .selection_rank(0),
                          "gfx950")

                     // BatchPrefill kernel (group mode, paged KV, page_size=64)
                     .add(FmhaSignature()
                              .family("batch_prefill")
                              .dtype("fp16")
                              .mode("group")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(true)
                              .dropout(false)
                              .qscale("no")
                              .paged_kv(true)
                              .kv_cache("vectorized", "sglang", 64),
                          FmhaAlgorithm()
                              // Block tile sizes (same m0/n0/k0/n1/k1/k0max roles as above)
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              // Wave counts per stage
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              // Warp tile sizes per stage
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .selection_rank(0),
                          "gfx950")

                     // Basic fwd kernel for GPU execution sanity check
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Block tile sizes
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              // Wave counts per stage
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              // Warp tile sizes per stage
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// CPU reference for scaled dot-product attention, used to validate the GPU output.
//
// For every (batch, head, query row) it computes
//   O[sq, :] = softmax(scale * Q[sq, :] . K[:, :]^T) * V
// with a max-subtracted (numerically stable) softmax over the seqlen_k axis.
//
// All tensors are dense, row-major and indexed as
//   Q: [batch, nhead, seqlen_q, hdim_q]    K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]    O: [batch, nhead, seqlen_q, hdim_v]
// O must already be sized by the caller; every element of O is overwritten.
//
// Flattened offsets are computed in std::size_t so that large problem sizes
// (more than INT_MAX elements) do not overflow the index arithmetic.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    const auto wide = [](int v) { return static_cast<std::size_t>(v); };

    // Scratch row of attention scores / probabilities, reused for every query row.
    std::vector<float> scores;

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Flattened (batch, head) index shared by all four tensors.
            const std::size_t bh = wide(b) * wide(nhead) + wide(h);

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const std::size_t q_row  = bh * wide(seqlen_q) + wide(sq);
                const std::size_t q_base = q_row * wide(hdim_q);
                const std::size_t o_base = q_row * wide(hdim_v);

                scores.assign(wide(seqlen_k), 0.0f);
                float max_score = -1e30f;

                // Scaled logits scores[sk] = scale * <Q[sq], K[sk]>, tracking the row maximum.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const std::size_t k_base = (bh * wide(seqlen_k) + wide(sk)) * wide(hdim_q);
                    float dot                = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_base + wide(d)] * K[k_base + wide(d)];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax: exponentiate relative to the maximum, then normalize.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Output row: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const std::size_t v_idx =
                            (bh * wide(seqlen_k) + wide(sk)) * wide(hdim_v) + wide(dv);
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_base + wide(dv)] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 25: AppendKV + BatchPrefill + GPU",
|
||||
"FMHA AppendKV/BatchPrefill planning with GPU sanity check");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length (Q and K)");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_flag("--validate", "Validate against CPU reference");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 25: AppendKV + BatchPrefill + GPU Execution");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("appendkv_batchprefill");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// =========================================================================
|
||||
// Step 2: Plan AppendKV
|
||||
// traits: fmha_fwd_appendkv_traits (hdim_q, hdim_v, data_type,
|
||||
// is_v_rowmajor, rope_type)
|
||||
// args: fmha_fwd_appendkv_args (q_ptr, k_ptr, knew_ptr, v_ptr,
|
||||
// vnew_ptr, seqlen_q, seqlen_knew, ...)
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 2: Plan AppendKV\n";
|
||||
|
||||
fmha_fwd_appendkv_traits append_traits{};
|
||||
append_traits.hdim_q = hdim;
|
||||
append_traits.hdim_v = hdim;
|
||||
append_traits.data_type = "fp16";
|
||||
append_traits.is_v_rowmajor = true;
|
||||
append_traits.rope_type = rope_enum::interleaved;
|
||||
|
||||
fmha_fwd_appendkv_args append_args{};
|
||||
append_args.q_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.k_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.knew_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.v_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.vnew_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.seqlen_q = 1;
|
||||
append_args.seqlen_knew = 1;
|
||||
append_args.batch = batch;
|
||||
append_args.hdim_q = hdim;
|
||||
append_args.hdim_v = hdim;
|
||||
append_args.nhead_q = nhead;
|
||||
append_args.nhead_k = nhead;
|
||||
append_args.rotary_dim = hdim;
|
||||
append_args.rotary_cos_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.rotary_sin_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.block_table_ptr = reinterpret_cast<void*>(0x1);
|
||||
append_args.page_block_size = 16;
|
||||
|
||||
// cache_batch_idx: maps request index -> cache slot for non-contiguous batches.
|
||||
// When serving multiple requests that don't occupy contiguous cache slots,
|
||||
// this indirection array tells the kernel which cache row each request maps to.
|
||||
append_args.cache_batch_idx_ptr = reinterpret_cast<void*>(0x1);
|
||||
|
||||
auto append_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(append_traits, append_args), gfx_arch));
|
||||
|
||||
std::cout << " AppendKV plan valid: " << (append_plan.is_valid() ? "yes" : "no") << "\n";
|
||||
if(append_plan.is_valid())
|
||||
{
|
||||
for(const auto& stage : append_plan.stages)
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Step 3: Plan BatchPrefill
|
||||
// traits: fmha_batch_prefill_traits (extends fmha_fwd_traits with
|
||||
// kv_memory_layout, kv_lookup_table, page_size)
|
||||
// args: fmha_batch_prefill_args (kv_indptr, kv_page_indices,
|
||||
// kv_last_page_lens, seqstart_q_ptr, ...)
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 3: Plan BatchPrefill\n";
|
||||
|
||||
fmha_batch_prefill_traits prefill_traits{};
|
||||
prefill_traits.hdim_q = hdim;
|
||||
prefill_traits.hdim_v = hdim;
|
||||
prefill_traits.data_type = "fp16";
|
||||
prefill_traits.is_group_mode = true;
|
||||
prefill_traits.is_v_rowmajor = true;
|
||||
prefill_traits.mask_type = mask_enum::no_mask;
|
||||
prefill_traits.bias_type = bias_enum::no_bias;
|
||||
prefill_traits.has_lse = true;
|
||||
prefill_traits.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
prefill_traits.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
prefill_traits.page_size = 64;
|
||||
|
||||
fmha_batch_prefill_args prefill_args{};
|
||||
prefill_args.batch = batch;
|
||||
prefill_args.seqlen_q = seqlen;
|
||||
prefill_args.seqlen_k = 1024;
|
||||
prefill_args.max_seqlen_q = seqlen;
|
||||
prefill_args.hdim_q = hdim;
|
||||
prefill_args.hdim_v = hdim;
|
||||
prefill_args.nhead_q = nhead;
|
||||
prefill_args.nhead_k = nhead;
|
||||
prefill_args.num_total_pages = 128;
|
||||
prefill_args.page_block_size = 64;
|
||||
prefill_args.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
prefill_args.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
prefill_args.kv_indptr = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.kv_page_indices = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.kv_last_page_lens = reinterpret_cast<void*>(0x1);
|
||||
prefill_args.seqstart_q_ptr = reinterpret_cast<void*>(0x1);
|
||||
|
||||
auto prefill_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(prefill_traits, prefill_args), gfx_arch));
|
||||
|
||||
std::cout << " BatchPrefill plan valid: " << (prefill_plan.is_valid() ? "yes" : "no") << "\n";
|
||||
if(prefill_plan.is_valid())
|
||||
{
|
||||
for(const auto& stage : prefill_plan.stages)
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Step 4: GPU Execution with basic fwd kernel (sanity check)
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 4: Allocate GPU Buffers\n";
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
fmha_fwd_traits fwd_traits{};
|
||||
fwd_traits.hdim_q = hdim;
|
||||
fwd_traits.hdim_v = hdim;
|
||||
fwd_traits.data_type = "fp16";
|
||||
fwd_traits.is_group_mode = false;
|
||||
fwd_traits.is_v_rowmajor = true;
|
||||
fwd_traits.has_logits_soft_cap = false;
|
||||
fwd_traits.mask_type = mask_enum::no_mask;
|
||||
fwd_traits.bias_type = bias_enum::no_bias;
|
||||
fwd_traits.has_lse = false;
|
||||
fwd_traits.has_dropout = false;
|
||||
fwd_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t k_elems = q_elems;
|
||||
const int64_t v_elems = q_elems;
|
||||
const int64_t o_elems = q_elems;
|
||||
|
||||
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
|
||||
<< "]\n";
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(k_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(v_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(o_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems);
|
||||
std::vector<FmhaDataType> k_host(k_elems);
|
||||
std::vector<FmhaDataType> v_host(v_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.q_ptr = q_dev.get();
|
||||
fwd_args.k_ptr = k_dev.get();
|
||||
fwd_args.v_ptr = v_dev.get();
|
||||
fwd_args.o_ptr = o_dev.get();
|
||||
|
||||
fwd_args.bias_ptr = nullptr;
|
||||
fwd_args.q_descale_ptr = nullptr;
|
||||
fwd_args.k_descale_ptr = nullptr;
|
||||
fwd_args.v_descale_ptr = nullptr;
|
||||
fwd_args.rand_val_ptr = nullptr;
|
||||
fwd_args.lse_ptr = nullptr;
|
||||
fwd_args.sink_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fwd_args.seqlen_q = seqlen;
|
||||
fwd_args.seqlen_k = seqlen;
|
||||
fwd_args.batch = batch;
|
||||
fwd_args.max_seqlen_q = seqlen;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.scale_s = scale;
|
||||
fwd_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fwd_args.stride_q = hdim;
|
||||
fwd_args.stride_k = hdim;
|
||||
fwd_args.stride_v = hdim;
|
||||
fwd_args.stride_bias = 0;
|
||||
fwd_args.stride_randval = 0;
|
||||
fwd_args.stride_o = hdim;
|
||||
|
||||
fwd_args.nhead_stride_q = seqlen * hdim;
|
||||
fwd_args.nhead_stride_k = seqlen * hdim;
|
||||
fwd_args.nhead_stride_v = seqlen * hdim;
|
||||
fwd_args.nhead_stride_bias = 0;
|
||||
fwd_args.nhead_stride_randval = 0;
|
||||
fwd_args.nhead_stride_lse = 0;
|
||||
fwd_args.nhead_stride_o = seqlen * hdim;
|
||||
fwd_args.nhead_stride_q_descale = 0;
|
||||
fwd_args.nhead_stride_k_descale = 0;
|
||||
fwd_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fwd_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_bias = 0;
|
||||
fwd_args.batch_stride_randval = 0;
|
||||
fwd_args.batch_stride_lse = 0;
|
||||
fwd_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_q_descale = 0;
|
||||
fwd_args.batch_stride_k_descale = 0;
|
||||
fwd_args.batch_stride_v_descale = 0;
|
||||
|
||||
fwd_args.window_size_left = -1;
|
||||
fwd_args.window_size_right = -1;
|
||||
fwd_args.sink_size = 0;
|
||||
fwd_args.mask_type = 0;
|
||||
fwd_args.min_seqlen_q = 0;
|
||||
fwd_args.p_drop = 0.0f;
|
||||
fwd_args.s_randval = false;
|
||||
fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fwd_args.block_scale_size_q = 0;
|
||||
fwd_args.block_scale_size_kv = 0;
|
||||
|
||||
// Step 5: Run on GPU
|
||||
std::cout << "\nStep 5: Run FMHA Forward on GPU\n";
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
|
||||
// Step 6: Validate
|
||||
std::cout << "\nStep 6: Validate\n";
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
|
||||
|
||||
bool passed = (nonzero > 0);
|
||||
|
||||
if(args.has("--validate"))
|
||||
{
|
||||
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < k_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < v_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
double max_rel_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
max_rel_err = std::max(max_rel_err, rel_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Max rel error: " << max_rel_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
526
dispatcher/examples/fmha/cpp/26_dtypes_hdims_fmha.cpp
Normal file
526
dispatcher/examples/fmha/cpp/26_dtypes_hdims_fmha.cpp
Normal file
@@ -0,0 +1,526 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 26: Multiple Data Types and Head Dimensions with GPU Execution
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Declare bf16 hdim=128, fp16 hdim=64, and fp16 hdim=128 kernels
|
||||
// 2. Run each variant on GPU with appropriate buffer types
|
||||
// 3. Validate with different tolerances: fp16 (rtol=1e-3), bf16 (rtol=1e-2)
|
||||
// 4. Mention fp32, fp8bf16, fp8fp32, hdim 256, asymmetric hdim as planning
|
||||
//
|
||||
// Mirrors 01_basic_fmha.cpp for FMHA.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set used by this example: three forward ("fwd", batch mode) kernels for "gfx950"
// that differ only in data type and head dimension:
//   - bf16, hdim 128
//   - fp16, hdim 64   (smaller n0/n1/k0max tiles and 16x16 stage-1 warp tiles)
//   - fp16, hdim 128  (reference baseline)
DECL_FMHA_KERNEL_SET(dtypes_hdims_kernels,

                     // bf16 hdim=128
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("bf16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              // Wave counts per stage
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              // Warp tile sizes per stage
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // fp16 hdim=64
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(64)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Block tile sizes
                              .tile_m0(128)
                              .tile_n0(64)
                              .tile_k0(32)
                              .tile_n1(64)
                              .tile_k1(32)
                              .tile_k0max(64)
                              // Wave counts per stage
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              // Warp tile sizes per stage
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(16)
                              .warp_n1(16)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(64, 64)
                              .selection_rank(0),
                          "gfx950")

                     // fp16 hdim=128 (reference baseline)
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Block tile sizes
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              // Wave counts per stage
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              // Warp tile sizes per stage
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using Fp16Type = ck_tile::fp16_t;
|
||||
using Bf16Type = ck_tile::bf16_t;
|
||||
|
||||
// CPU reference for scaled dot-product attention, used to validate each GPU variant.
//
// For every (batch, head, query row) it computes
//   O[sq, :] = softmax(scale * Q[sq, :] . K[:, :]^T) * V
// with a max-subtracted (numerically stable) softmax over the seqlen_k axis.
//
// All tensors are dense, row-major and indexed as
//   Q: [batch, nhead, seqlen_q, hdim_q]    K: [batch, nhead, seqlen_k, hdim_q]
//   V: [batch, nhead, seqlen_k, hdim_v]    O: [batch, nhead, seqlen_q, hdim_v]
// O must already be sized by the caller; every element of O is overwritten.
//
// Flattened offsets are computed in std::size_t so that large problem sizes
// (more than INT_MAX elements) do not overflow the index arithmetic.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    const auto wide = [](int v) { return static_cast<std::size_t>(v); };

    // Scratch row of attention scores / probabilities, reused for every query row.
    std::vector<float> scores;

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Flattened (batch, head) index shared by all four tensors.
            const std::size_t bh = wide(b) * wide(nhead) + wide(h);

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const std::size_t q_row  = bh * wide(seqlen_q) + wide(sq);
                const std::size_t q_base = q_row * wide(hdim_q);
                const std::size_t o_base = q_row * wide(hdim_v);

                scores.assign(wide(seqlen_k), 0.0f);
                float max_score = -1e30f;

                // Scaled logits scores[sk] = scale * <Q[sq], K[sk]>, tracking the row maximum.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const std::size_t k_base = (bh * wide(seqlen_k) + wide(sk)) * wide(hdim_q);
                    float dot                = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_base + wide(d)] * K[k_base + wide(d)];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Softmax: exponentiate relative to the maximum, then normalize.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Output row: probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const std::size_t v_idx =
                            (bh * wide(seqlen_k) + wide(sk)) * wide(hdim_v) + wide(dv);
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_base + wide(dv)] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
// Outcome of running and validating one dtype / head-dimension variant.
struct VariantResult
{
    std::string label;  // name of the variant this result belongs to
    float time_ms;      // time returned by the dispatcher's run_fwd call
    double tflops;      // throughput derived from the problem's op count and time_ms
    double max_abs_err; // largest |gpu - reference| over all output elements
    double max_rel_err; // largest |gpu - reference| / (|reference| + 1e-6)
    int errors;         // number of elements outside the atol/rtol tolerance
    bool passed;        // true when no element was out of tolerance
};
|
||||
|
||||
template <typename DataType>
|
||||
fmha_fwd_args make_fwd_args(GpuBuffer<DataType>& q_dev,
|
||||
GpuBuffer<DataType>& k_dev,
|
||||
GpuBuffer<DataType>& v_dev,
|
||||
GpuBuffer<DataType>& o_dev,
|
||||
int batch,
|
||||
int nhead,
|
||||
int seqlen,
|
||||
int hdim,
|
||||
float scale)
|
||||
{
|
||||
fmha_fwd_args a{};
|
||||
a.q_ptr = q_dev.get();
|
||||
a.k_ptr = k_dev.get();
|
||||
a.v_ptr = v_dev.get();
|
||||
a.o_ptr = o_dev.get();
|
||||
|
||||
a.bias_ptr = nullptr;
|
||||
a.q_descale_ptr = nullptr;
|
||||
a.k_descale_ptr = nullptr;
|
||||
a.v_descale_ptr = nullptr;
|
||||
a.rand_val_ptr = nullptr;
|
||||
a.lse_ptr = nullptr;
|
||||
a.sink_ptr = nullptr;
|
||||
a.block_scale_seqstart_q_ptr = nullptr;
|
||||
a.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
a.seqlen_q = seqlen;
|
||||
a.seqlen_k = seqlen;
|
||||
a.batch = batch;
|
||||
a.max_seqlen_q = seqlen;
|
||||
a.hdim_q = hdim;
|
||||
a.hdim_v = hdim;
|
||||
a.nhead_q = nhead;
|
||||
a.nhead_k = nhead;
|
||||
a.scale_s = scale;
|
||||
a.logits_soft_cap = 0.0f;
|
||||
|
||||
a.stride_q = hdim;
|
||||
a.stride_k = hdim;
|
||||
a.stride_v = hdim;
|
||||
a.stride_bias = 0;
|
||||
a.stride_randval = 0;
|
||||
a.stride_o = hdim;
|
||||
|
||||
a.nhead_stride_q = seqlen * hdim;
|
||||
a.nhead_stride_k = seqlen * hdim;
|
||||
a.nhead_stride_v = seqlen * hdim;
|
||||
a.nhead_stride_bias = 0;
|
||||
a.nhead_stride_randval = 0;
|
||||
a.nhead_stride_lse = 0;
|
||||
a.nhead_stride_o = seqlen * hdim;
|
||||
a.nhead_stride_q_descale = 0;
|
||||
a.nhead_stride_k_descale = 0;
|
||||
a.nhead_stride_v_descale = 0;
|
||||
|
||||
a.batch_stride_q = nhead * seqlen * hdim;
|
||||
a.batch_stride_k = nhead * seqlen * hdim;
|
||||
a.batch_stride_v = nhead * seqlen * hdim;
|
||||
a.batch_stride_bias = 0;
|
||||
a.batch_stride_randval = 0;
|
||||
a.batch_stride_lse = 0;
|
||||
a.batch_stride_o = nhead * seqlen * hdim;
|
||||
a.batch_stride_q_descale = 0;
|
||||
a.batch_stride_k_descale = 0;
|
||||
a.batch_stride_v_descale = 0;
|
||||
|
||||
a.window_size_left = -1;
|
||||
a.window_size_right = -1;
|
||||
a.sink_size = 0;
|
||||
a.mask_type = 0;
|
||||
a.min_seqlen_q = 0;
|
||||
a.p_drop = 0.0f;
|
||||
a.s_randval = false;
|
||||
a.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
a.block_scale_size_q = 0;
|
||||
a.block_scale_size_kv = 0;
|
||||
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
VariantResult run_variant(FmhaDispatcher& dispatcher,
|
||||
const std::string& label,
|
||||
const std::string& dtype_str,
|
||||
int batch,
|
||||
int nhead,
|
||||
int seqlen,
|
||||
int hdim,
|
||||
double rtol,
|
||||
double atol,
|
||||
const std::string& gfx_arch)
|
||||
{
|
||||
VariantResult result{};
|
||||
result.label = label;
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
const int64_t elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = dtype_str;
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = false;
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
GpuBuffer<DataType> q_dev(elems);
|
||||
GpuBuffer<DataType> k_dev(elems);
|
||||
GpuBuffer<DataType> v_dev(elems);
|
||||
GpuBuffer<DataType> o_dev(elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<DataType> q_host(elems);
|
||||
std::vector<DataType> k_host(elems);
|
||||
std::vector<DataType> v_host(elems);
|
||||
for(auto& x : q_host)
|
||||
x = DataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = DataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = DataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
auto fwd_args = make_fwd_args(q_dev, k_dev, v_dev, o_dev, batch, nhead, seqlen, hdim, scale);
|
||||
|
||||
try
|
||||
{
|
||||
result.time_ms = dispatcher.run_fwd(traits, fwd_args, nullptr);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR [" << label << "]: " << e.what() << "\n";
|
||||
result.passed = false;
|
||||
return result;
|
||||
}
|
||||
|
||||
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch);
|
||||
result.tflops = static_cast<double>(problem.num_ops()) / (result.time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::vector<DataType> o_host(elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
std::vector<float> q_f32(elems), k_f32(elems), v_f32(elems), o_ref(elems, 0.0f);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
result.max_abs_err = 0.0;
|
||||
result.max_rel_err = 0.0;
|
||||
result.errors = 0;
|
||||
|
||||
for(int64_t i = 0; i < elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
|
||||
result.max_abs_err = std::max(result.max_abs_err, abs_err);
|
||||
result.max_rel_err = std::max(result.max_rel_err, rel_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++result.errors;
|
||||
}
|
||||
|
||||
result.passed = (result.errors == 0);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 26: Dtypes & Hdims FMHA",
|
||||
"FMHA with multiple data types and head dimensions");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length (Q and K)");
|
||||
args.add_flag("--validate", "Validate against CPU reference");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
|
||||
print_header("Example 26: Multiple Data Types & Head Dimensions");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("dtypes_hdims");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// =========================================================================
|
||||
// Step 2: Run variants on GPU
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 2: Run Variants\n";
|
||||
|
||||
// fp16 hdim=128 (reference baseline)
|
||||
std::cout << "\n --- fp16 hdim=128 (reference) ---\n";
|
||||
auto r_fp16_h128 = run_variant<Fp16Type>(dispatcher,
|
||||
"fp16_h128",
|
||||
"fp16",
|
||||
batch,
|
||||
nhead,
|
||||
seqlen,
|
||||
128,
|
||||
/*rtol=*/1e-3,
|
||||
/*atol=*/1e-3,
|
||||
gfx_arch);
|
||||
|
||||
// bf16 hdim=128 (wider tolerance due to reduced precision)
|
||||
std::cout << "\n --- bf16 hdim=128 ---\n";
|
||||
auto r_bf16_h128 = run_variant<Bf16Type>(dispatcher,
|
||||
"bf16_h128",
|
||||
"bf16",
|
||||
batch,
|
||||
nhead,
|
||||
seqlen,
|
||||
128,
|
||||
/*rtol=*/1e-2,
|
||||
/*atol=*/1e-2,
|
||||
gfx_arch);
|
||||
|
||||
// fp16 hdim=64 (smaller buffers)
|
||||
std::cout << "\n --- fp16 hdim=64 ---\n";
|
||||
auto r_fp16_h64 = run_variant<Fp16Type>(dispatcher,
|
||||
"fp16_h64",
|
||||
"fp16",
|
||||
batch,
|
||||
nhead,
|
||||
seqlen,
|
||||
64,
|
||||
/*rtol=*/1e-3,
|
||||
/*atol=*/1e-3,
|
||||
gfx_arch);
|
||||
|
||||
// =========================================================================
|
||||
// Step 3: Results Summary
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 3: Results Summary\n\n";
|
||||
|
||||
std::cout << " " << std::setw(14) << "Variant" << " | " << std::setw(10) << "Time(ms)" << " | "
|
||||
<< std::setw(10) << "TFLOPS" << " | " << std::setw(10) << "MaxAbsErr" << " | "
|
||||
<< std::setw(10) << "MaxRelErr" << " | " << std::setw(8) << "Errors" << " | "
|
||||
<< std::setw(6) << "Status" << "\n";
|
||||
std::cout << " " << std::string(82, '-') << "\n";
|
||||
|
||||
auto print_row = [](const VariantResult& r) {
|
||||
std::cout << std::fixed;
|
||||
std::cout << " " << std::setw(14) << r.label << " | " << std::setprecision(4)
|
||||
<< std::setw(10) << r.time_ms << " | " << std::setprecision(2) << std::setw(10)
|
||||
<< r.tflops << " | " << std::scientific << std::setw(10) << r.max_abs_err << " | "
|
||||
<< std::setw(10) << r.max_rel_err << " | " << std::fixed << std::setw(8)
|
||||
<< r.errors << " | " << std::setw(6) << (r.passed ? "PASS" : "FAIL") << "\n";
|
||||
};
|
||||
|
||||
print_row(r_fp16_h128);
|
||||
print_row(r_bf16_h128);
|
||||
print_row(r_fp16_h64);
|
||||
|
||||
// =========================================================================
|
||||
// Step 4: Tolerance Notes
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 4: Tolerance Notes\n";
|
||||
std::cout << " fp16 validation: rtol=1e-3, atol=1e-3 (higher precision)\n";
|
||||
std::cout << " bf16 validation: rtol=1e-2, atol=1e-2 (wider tolerance for bfloat16)\n";
|
||||
std::cout << "\n Additional dtype/hdim combinations (planning-level declarations):\n";
|
||||
std::cout << " fp32: .dtype(\"fp32\") - full single precision\n";
|
||||
std::cout << " fp8bf16: .dtype(\"fp8bf16\") - fp8 compute, bf16 output\n";
|
||||
std::cout << " fp8fp32: .dtype(\"fp8fp32\") - fp8 compute, fp32 output\n";
|
||||
std::cout << " hdim 256: .hdim(256), tile(128,128,32,256,32,256)\n";
|
||||
std::cout << " asymmetric: .hdim_q(128), .hdim_v(64) - different Q/V dims\n";
|
||||
|
||||
bool all_passed = r_fp16_h128.passed && r_bf16_h128.passed && r_fp16_h64.passed;
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return all_passed ? 0 : 1;
|
||||
}
|
||||
635
dispatcher/examples/fmha/cpp/27_padding_permutation_fmha.cpp
Normal file
635
dispatcher/examples/fmha/cpp/27_padding_permutation_fmha.cpp
Normal file
@@ -0,0 +1,635 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 27: Padding, Group Mode, V Col-Major, Permutation Patterns
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Batch padding with cu_seqlen arrays for per-batch variable lengths
|
||||
// 2. Group mode with seqstart_q / seqstart_k buffers
|
||||
// 3. V col-major layout declaration: .vlayout("c")
|
||||
// 4. Permutation patterns: bhsd (iperm=1) vs bshd (iperm=0) strides
|
||||
// 5. GPU execution with basic kernel + batch padding
|
||||
//
|
||||
// Mirrors 01_basic_fmha.cpp for FMHA.
|
||||
|
||||
#include <hip/hip_runtime.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for Example 27: three fp16 / hdim=128 forward signatures that
// differ only in mode ("batch" vs "group") and V layout ("r" vs "c"). All
// three use the same FmhaAlgorithm tiling and are declared for "gfx950".
DECL_FMHA_KERNEL_SET(padding_permutation_kernels,

                     // [1/3] Batch mode, row-major V: the basic forward kernel
                     //       used for GPU execution.
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              // Stage 0 (Q*K^T): m0=seqlen_q, n0=seqlen_k, k0=hdim_q
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              // Stage 1 (Attn*V): n1=hdim_v, k1=seqlen_k, k0max=alignment
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              // Wave counts for stage 0, then stage 1.
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              // Warp tile sizes for stage 0, then stage 1.
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // [2/3] Group mode, row-major V: variable-length sequences.
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("group")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // [3/3] Batch mode, column-major V: layout declaration.
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("c")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
/// CPU reference for the attention forward pass:
/// O = softmax(scale * Q * K^T) * V, with the softmax taken over the key axis.
///
/// The indexing below treats every tensor as densely packed with the head
/// dimension fastest: Q is [batch, nhead, seqlen_q, hdim_q], K is
/// [batch, nhead, seqlen_k, hdim_q], V is [batch, nhead, seqlen_k, hdim_v]
/// and O is [batch, nhead, seqlen_q, hdim_v].
///
/// @param Q,K,V  Input tensors in the layout above.
/// @param O      Output tensor; must already be sized for
///               batch * nhead * seqlen_q * hdim_v elements.
/// @param scale  Factor applied to each Q.K dot product before the softmax.
///
/// Flat offsets are computed in the vector's size_type rather than `int`, so
/// tensors with more than INT_MAX elements do not overflow the index math.
/// The order of floating-point operations is unchanged.
void cpu_attention_fwd(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& O,
                       int batch,
                       int nhead,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       float scale)
{
    using index_t = std::vector<float>::size_type;

    // Scratch row of attention scores, reused for every query position.
    std::vector<float> scores;

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // Linear (batch, head) slot shared by all four tensors.
            const index_t head = static_cast<index_t>(b) * nhead + h;

            for(int sq = 0; sq < seqlen_q; ++sq)
            {
                const index_t q_row = (head * seqlen_q + sq) * hdim_q;
                const index_t o_row = (head * seqlen_q + sq) * hdim_v;

                scores.assign(seqlen_k, 0.0f);
                float max_score = -1e30f;

                // Scaled dot products, tracking the row maximum for a
                // numerically stable softmax.
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    const index_t k_row = (head * seqlen_k + sk) * hdim_q;

                    float dot = 0.0f;
                    for(int d = 0; d < hdim_q; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    scores[sk] = dot * scale;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Exponentiate relative to the maximum, then normalise.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                for(int sk = 0; sk < seqlen_k; ++sk)
                {
                    scores[sk] /= sum_exp;
                }

                // Probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim_v; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen_k; ++sk)
                    {
                        const index_t v_idx = (head * seqlen_k + sk) * hdim_v + dv;
                        acc += scores[sk] * V[v_idx];
                    }
                    O[o_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 27: Padding & Permutation FMHA",
|
||||
"FMHA padding, group mode, V col-major, and permutation patterns");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length (Q and K)");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_flag("--validate", "Validate against CPU reference");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 27: Padding, Group Mode, V Col-Major, Permutation");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("padding_permutation");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
// =========================================================================
|
||||
// Step 2: Batch Padding Pattern
|
||||
// Allocate cu_seqlen_q / cu_seqlen_k buffers with cumulative sums.
|
||||
// In CK's dispatcher, this maps to seqstart_q_ptr / seqstart_k_ptr
|
||||
// and requires group mode to enable per-batch variable sequence lengths.
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 2: Batch Padding Pattern (cu_seqlen)\n";
|
||||
{
|
||||
// Per-batch sequence lengths: batch 0 has seqlen=32, batch 1 has seqlen=48
|
||||
const std::vector<int32_t> seqlens_q = {32, 48};
|
||||
const std::vector<int32_t> seqlens_k = {32, 48};
|
||||
const int num_batches = static_cast<int>(seqlens_q.size());
|
||||
|
||||
// Build cumulative sum arrays: [0, 32, 80]
|
||||
std::vector<int32_t> cu_seqlen_q(num_batches + 1, 0);
|
||||
std::vector<int32_t> cu_seqlen_k(num_batches + 1, 0);
|
||||
for(int i = 0; i < num_batches; ++i)
|
||||
{
|
||||
cu_seqlen_q[i + 1] = cu_seqlen_q[i] + seqlens_q[i];
|
||||
cu_seqlen_k[i + 1] = cu_seqlen_k[i] + seqlens_k[i];
|
||||
}
|
||||
|
||||
const int total_q = cu_seqlen_q.back();
|
||||
const int total_k = cu_seqlen_k.back();
|
||||
const int max_sq = *std::max_element(seqlens_q.begin(), seqlens_q.end());
|
||||
|
||||
std::cout << " Batch seqlens_q: [";
|
||||
for(int i = 0; i < num_batches; ++i)
|
||||
std::cout << (i ? ", " : "") << seqlens_q[i];
|
||||
std::cout << "]\n";
|
||||
std::cout << " cu_seqlen_q: [";
|
||||
for(size_t i = 0; i < cu_seqlen_q.size(); ++i)
|
||||
std::cout << (i ? ", " : "") << cu_seqlen_q[i];
|
||||
std::cout << "]\n";
|
||||
|
||||
GpuBuffer<int32_t> cu_sq_dev(num_batches + 1);
|
||||
GpuBuffer<int32_t> cu_sk_dev(num_batches + 1);
|
||||
cu_sq_dev.copy_from_host(cu_seqlen_q.data());
|
||||
cu_sk_dev.copy_from_host(cu_seqlen_k.data());
|
||||
|
||||
// Group mode traits for variable-length sequences
|
||||
fmha_fwd_traits pad_traits{};
|
||||
pad_traits.hdim_q = hdim;
|
||||
pad_traits.hdim_v = hdim;
|
||||
pad_traits.data_type = "fp16";
|
||||
pad_traits.is_group_mode = true;
|
||||
pad_traits.is_v_rowmajor = true;
|
||||
pad_traits.has_logits_soft_cap = false;
|
||||
pad_traits.mask_type = mask_enum::no_mask;
|
||||
pad_traits.bias_type = bias_enum::no_bias;
|
||||
pad_traits.has_lse = false;
|
||||
pad_traits.has_dropout = false;
|
||||
pad_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args pad_args{};
|
||||
pad_args.seqlen_q = total_q;
|
||||
pad_args.seqlen_k = total_k;
|
||||
pad_args.batch = num_batches;
|
||||
pad_args.max_seqlen_q = max_sq;
|
||||
pad_args.hdim_q = hdim;
|
||||
pad_args.hdim_v = hdim;
|
||||
pad_args.nhead_q = nhead;
|
||||
pad_args.nhead_k = nhead;
|
||||
pad_args.scale_s = scale;
|
||||
|
||||
// cu_seqlen_q_ptr / cu_seqlen_k_ptr (seqstart_q / seqstart_k in CK)
|
||||
pad_args.seqstart_q_ptr = cu_sq_dev.get();
|
||||
pad_args.seqstart_k_ptr = cu_sk_dev.get();
|
||||
|
||||
auto pad_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(pad_traits, pad_args), gfx_arch));
|
||||
std::cout << " Batch padding plan valid: " << (pad_plan.is_valid() ? "yes" : "no") << "\n";
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Step 3: Group Mode Pattern
|
||||
// Group mode uses seqstart_q / seqstart_k arrays to define variable
|
||||
// sequence boundaries. Each batch element can have a different length.
|
||||
// traits.is_group_mode = true
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 3: Group Mode Pattern (seqstart)\n";
|
||||
{
|
||||
fmha_fwd_traits group_traits{};
|
||||
group_traits.hdim_q = hdim;
|
||||
group_traits.hdim_v = hdim;
|
||||
group_traits.data_type = "fp16";
|
||||
group_traits.is_group_mode = true;
|
||||
group_traits.is_v_rowmajor = true;
|
||||
group_traits.has_logits_soft_cap = false;
|
||||
group_traits.mask_type = mask_enum::no_mask;
|
||||
group_traits.bias_type = bias_enum::no_bias;
|
||||
group_traits.has_lse = false;
|
||||
group_traits.has_dropout = false;
|
||||
group_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
const std::vector<int32_t> seqstart_q = {0, 64, 192};
|
||||
const std::vector<int32_t> seqstart_k = {0, 128, 256};
|
||||
const int num_batches = static_cast<int>(seqstart_q.size()) - 1;
|
||||
const int total_q = seqstart_q.back();
|
||||
const int max_sq = 128;
|
||||
|
||||
GpuBuffer<int32_t> ss_q_dev(seqstart_q.size());
|
||||
GpuBuffer<int32_t> ss_k_dev(seqstart_k.size());
|
||||
ss_q_dev.copy_from_host(seqstart_q.data());
|
||||
ss_k_dev.copy_from_host(seqstart_k.data());
|
||||
|
||||
fmha_fwd_args group_args{};
|
||||
group_args.seqlen_q = total_q;
|
||||
group_args.seqlen_k = seqstart_k.back();
|
||||
group_args.batch = num_batches;
|
||||
group_args.max_seqlen_q = max_sq;
|
||||
group_args.hdim_q = hdim;
|
||||
group_args.hdim_v = hdim;
|
||||
group_args.nhead_q = nhead;
|
||||
group_args.nhead_k = nhead;
|
||||
group_args.scale_s = scale;
|
||||
group_args.seqstart_q_ptr = ss_q_dev.get();
|
||||
group_args.seqstart_k_ptr = ss_k_dev.get();
|
||||
|
||||
std::cout << " seqstart_q: [0, 64, 192] -> batches of length 64 and 128\n";
|
||||
std::cout << " seqstart_k: [0, 128, 256] -> KV of length 128 and 128\n";
|
||||
|
||||
auto group_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(group_traits, group_args), gfx_arch));
|
||||
std::cout << " Group mode plan valid: " << (group_plan.is_valid() ? "yes" : "no") << "\n";
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Step 4: V Col-Major Declaration
|
||||
// .vlayout("c") declares V in column-major layout (seqlen_k x hdim_v
|
||||
// stored column-first). This affects how the kernel reads V.
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 4: V Col-Major Layout\n";
|
||||
{
|
||||
fmha_fwd_traits vcol_traits{};
|
||||
vcol_traits.hdim_q = hdim;
|
||||
vcol_traits.hdim_v = hdim;
|
||||
vcol_traits.data_type = "fp16";
|
||||
vcol_traits.is_group_mode = false;
|
||||
vcol_traits.is_v_rowmajor = false;
|
||||
vcol_traits.has_logits_soft_cap = false;
|
||||
vcol_traits.mask_type = mask_enum::no_mask;
|
||||
vcol_traits.bias_type = bias_enum::no_bias;
|
||||
vcol_traits.has_lse = false;
|
||||
vcol_traits.has_dropout = false;
|
||||
vcol_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args vcol_args{};
|
||||
vcol_args.batch = batch;
|
||||
vcol_args.seqlen_q = seqlen;
|
||||
vcol_args.seqlen_k = seqlen;
|
||||
vcol_args.max_seqlen_q = seqlen;
|
||||
vcol_args.hdim_q = hdim;
|
||||
vcol_args.hdim_v = hdim;
|
||||
vcol_args.nhead_q = nhead;
|
||||
vcol_args.nhead_k = nhead;
|
||||
vcol_args.scale_s = scale;
|
||||
|
||||
std::cout << " V row-major (.vlayout(\"r\")): stride_v = hdim, "
|
||||
"contiguous along head dimension\n";
|
||||
std::cout << " V col-major (.vlayout(\"c\")): stride_v = seqlen_k, "
|
||||
"contiguous along sequence dimension\n";
|
||||
std::cout << " traits.is_v_rowmajor = false\n";
|
||||
|
||||
auto vcol_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(vcol_traits, vcol_args), gfx_arch));
|
||||
std::cout << " V col-major plan valid: " << (vcol_plan.is_valid() ? "yes" : "no") << "\n";
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Step 5: Permutation Patterns (bhsd vs bshd)
|
||||
// bhsd layout (iperm=1): stride_q = hdim, nhead_stride_q = seqlen*hdim
|
||||
// memory: [batch, head, seq, dim]
|
||||
// bshd layout (iperm=0): stride_q = nhead*hdim, nhead_stride_q = hdim
|
||||
// memory: [batch, seq, head, dim]
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 5: Permutation Patterns\n";
|
||||
{
|
||||
std::cout << " bhsd layout (iperm=1):\n";
|
||||
std::cout << " stride_q = hdim = " << hdim << "\n";
|
||||
std::cout << " nhead_stride_q = seqlen * hdim = " << seqlen * hdim << "\n";
|
||||
std::cout << " batch_stride_q = nhead * seqlen * hdim = " << nhead * seqlen * hdim
|
||||
<< "\n";
|
||||
std::cout << " memory order: [batch, head, seq, dim]\n";
|
||||
|
||||
std::cout << "\n bshd layout (iperm=0):\n";
|
||||
std::cout << " stride_q = nhead * hdim = " << nhead * hdim << "\n";
|
||||
std::cout << " nhead_stride_q = hdim = " << hdim << "\n";
|
||||
std::cout << " batch_stride_q = seqlen * nhead * hdim = " << seqlen * nhead * hdim
|
||||
<< "\n";
|
||||
std::cout << " memory order: [batch, seq, head, dim]\n";
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Step 6: GPU Execution with basic kernel + batch padding
|
||||
// Run the batch-mode kernel with a non-tile-aligned seqlen to exercise
|
||||
// the .padding(true, true, true, true) capability.
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 6: GPU Execution (batch mode, seqlen=" << seqlen << ")\n";
|
||||
|
||||
fmha_fwd_traits fwd_traits{};
|
||||
fwd_traits.hdim_q = hdim;
|
||||
fwd_traits.hdim_v = hdim;
|
||||
fwd_traits.data_type = "fp16";
|
||||
fwd_traits.is_group_mode = false;
|
||||
fwd_traits.is_v_rowmajor = true;
|
||||
fwd_traits.has_logits_soft_cap = false;
|
||||
fwd_traits.mask_type = mask_enum::no_mask;
|
||||
fwd_traits.bias_type = bias_enum::no_bias;
|
||||
fwd_traits.has_lse = false;
|
||||
fwd_traits.has_dropout = false;
|
||||
fwd_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t k_elems = q_elems;
|
||||
const int64_t v_elems = q_elems;
|
||||
const int64_t o_elems = q_elems;
|
||||
|
||||
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
|
||||
<< "]\n";
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(q_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(k_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(v_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(o_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(q_elems);
|
||||
std::vector<FmhaDataType> k_host(k_elems);
|
||||
std::vector<FmhaDataType> v_host(v_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.q_ptr = q_dev.get();
|
||||
fwd_args.k_ptr = k_dev.get();
|
||||
fwd_args.v_ptr = v_dev.get();
|
||||
fwd_args.o_ptr = o_dev.get();
|
||||
|
||||
fwd_args.bias_ptr = nullptr;
|
||||
fwd_args.q_descale_ptr = nullptr;
|
||||
fwd_args.k_descale_ptr = nullptr;
|
||||
fwd_args.v_descale_ptr = nullptr;
|
||||
fwd_args.rand_val_ptr = nullptr;
|
||||
fwd_args.lse_ptr = nullptr;
|
||||
fwd_args.sink_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fwd_args.seqlen_q = seqlen;
|
||||
fwd_args.seqlen_k = seqlen;
|
||||
fwd_args.batch = batch;
|
||||
fwd_args.max_seqlen_q = seqlen;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.scale_s = scale;
|
||||
fwd_args.logits_soft_cap = 0.0f;
|
||||
|
||||
// bhsd layout strides (iperm=1)
|
||||
fwd_args.stride_q = hdim;
|
||||
fwd_args.stride_k = hdim;
|
||||
fwd_args.stride_v = hdim;
|
||||
fwd_args.stride_bias = 0;
|
||||
fwd_args.stride_randval = 0;
|
||||
fwd_args.stride_o = hdim;
|
||||
|
||||
fwd_args.nhead_stride_q = seqlen * hdim;
|
||||
fwd_args.nhead_stride_k = seqlen * hdim;
|
||||
fwd_args.nhead_stride_v = seqlen * hdim;
|
||||
fwd_args.nhead_stride_bias = 0;
|
||||
fwd_args.nhead_stride_randval = 0;
|
||||
fwd_args.nhead_stride_lse = 0;
|
||||
fwd_args.nhead_stride_o = seqlen * hdim;
|
||||
fwd_args.nhead_stride_q_descale = 0;
|
||||
fwd_args.nhead_stride_k_descale = 0;
|
||||
fwd_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fwd_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_bias = 0;
|
||||
fwd_args.batch_stride_randval = 0;
|
||||
fwd_args.batch_stride_lse = 0;
|
||||
fwd_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_q_descale = 0;
|
||||
fwd_args.batch_stride_k_descale = 0;
|
||||
fwd_args.batch_stride_v_descale = 0;
|
||||
|
||||
fwd_args.window_size_left = -1;
|
||||
fwd_args.window_size_right = -1;
|
||||
fwd_args.sink_size = 0;
|
||||
fwd_args.mask_type = 0;
|
||||
fwd_args.min_seqlen_q = 0;
|
||||
fwd_args.p_drop = 0.0f;
|
||||
fwd_args.s_randval = false;
|
||||
fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fwd_args.block_scale_size_q = 0;
|
||||
fwd_args.block_scale_size_kv = 0;
|
||||
|
||||
float time_ms = 0.0f;
|
||||
try
|
||||
{
|
||||
time_ms = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " ERROR: " << e.what() << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
|
||||
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
|
||||
// Step 7: Validate
|
||||
std::cout << "\nStep 7: Validate\n";
|
||||
std::vector<FmhaDataType> o_host(o_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
|
||||
|
||||
bool passed = (nonzero > 0);
|
||||
|
||||
if(args.has("--validate"))
|
||||
{
|
||||
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
|
||||
for(int64_t i = 0; i < q_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < k_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < v_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
cpu_attention_fwd(
|
||||
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
|
||||
|
||||
double max_abs_err = 0.0;
|
||||
double max_rel_err = 0.0;
|
||||
int errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
|
||||
for(int64_t i = 0; i < o_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
|
||||
max_abs_err = std::max(max_abs_err, abs_err);
|
||||
max_rel_err = std::max(max_rel_err, rel_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++errors;
|
||||
}
|
||||
|
||||
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
|
||||
std::cout << " Max rel error: " << max_rel_err << "\n";
|
||||
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
|
||||
passed = (errors == 0);
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
489
dispatcher/examples/fmha/cpp/28_bwd_masks_fmha.cpp
Normal file
489
dispatcher/examples/fmha/cpp/28_bwd_masks_fmha.cpp
Normal file
@@ -0,0 +1,489 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 28: FMHA Backward with Causal Mask
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Forward kernel with top_left causal mask + LSE
|
||||
// 2. Backward kernel families (bwd_dot_do_o, bwd_dq_dk_dv, bwd_convert_dq) with causal mask
|
||||
// 3. GPU forward execution with causal mask validation
|
||||
// 4. Backward 3-stage plan display
|
||||
//
|
||||
// Backward kernels use planning only -- actual backward GPU execution requires
|
||||
// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for Example 28: one forward kernel plus the three backward
// families, all fp16 / batch mode / hdim=128 with the "top_left" causal mask
// and declared for "gfx950".
DECL_FMHA_KERNEL_SET(bwd_masks_fmha_kernels,

                     // Forward: causal mask (top_left) with LSE for backward
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .lse(true)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // Backward stage 1: dot(dO, O) with causal mask
                     .add(FmhaSignature()
                              .family("bwd_dot_do_o")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(64)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(0)
                              .tile_k1(0)
                              .tile_k0max(0)
                              .padding(true, true, true, true)
                              .selection_rank(0),
                          "gfx950")

                     // Backward stage 2: compute dQ, dK, dV with causal mask
                     .add(FmhaSignature()
                              .family("bwd_dq_dk_dv")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(16)
                              .tile_n0(128)
                              .tile_k0(128)
                              .tile_n1(16)
                              .tile_k1(128)
                              .tile_k0max(32)
                              .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
                              .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
                              .padding(true, true, true, true)
                              .max_seq_len_q(0)
                              .selection_rank(0),
                          "gfx950")

                     // Backward stage 3: convert accumulated dQ from fp32 to fp16
                     .add(FmhaSignature()
                              .family("bwd_convert_dq")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(64)
                              .tile_n0(128)
                              .tile_k0(0)
                              .tile_n1(0)
                              .tile_k1(0)
                              .tile_k0max(0)
                              .padding(true, true, true, true)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
/// CPU reference for causal (top-left) attention forward with log-sum-exp.
///
/// For each query position sq, keys with sk > sq are masked by setting their
/// score to -1e30 before the softmax. Writes the attention output to O and
/// the per-row value max_score + log(sum_exp) to LSE.
///
/// The indexing below treats Q, K, V and O as densely packed
/// [batch, nhead, seqlen, hdim] and LSE as [batch, nhead, seqlen].
///
/// @param Q,K,V  Input tensors in the layout above.
/// @param O      Output tensor; must already hold batch*nhead*seqlen*hdim elements.
/// @param LSE    Output log-sum-exp; must already hold batch*nhead*seqlen elements.
/// @param scale  Factor applied to each Q.K dot product before masking.
///
/// Flat offsets are computed in the vector's size_type rather than `int`, so
/// tensors with more than INT_MAX elements do not overflow the index math.
/// The order of floating-point operations is unchanged.
void cpu_attention_fwd_causal(const std::vector<float>& Q,
                              const std::vector<float>& K,
                              const std::vector<float>& V,
                              std::vector<float>& O,
                              std::vector<float>& LSE,
                              int batch,
                              int nhead,
                              int seqlen,
                              int hdim,
                              float scale)
{
    using index_t = std::vector<float>::size_type;

    // Scratch row of attention scores, reused for every query position.
    std::vector<float> scores;

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            // First sequence row of this (batch, head) slot.
            const index_t head_row0 = (static_cast<index_t>(b) * nhead + h) * seqlen;

            for(int sq = 0; sq < seqlen; ++sq)
            {
                const index_t q_row = (head_row0 + sq) * hdim;

                scores.assign(seqlen, 0.0f);
                float max_score = -1e30f;

                for(int sk = 0; sk < seqlen; ++sk)
                {
                    const index_t k_row = (head_row0 + sk) * hdim;

                    float dot = 0.0f;
                    for(int d = 0; d < hdim; ++d)
                    {
                        dot += Q[q_row + d] * K[k_row + d];
                    }
                    float s = dot * scale;

                    // top_left causal: mask if sk > sq
                    if(sk > sq)
                        s = -1e30f;

                    scores[sk] = s;
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Exponentiate relative to the row maximum.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                LSE[head_row0 + sq] = max_score + std::log(sum_exp);

                for(int sk = 0; sk < seqlen; ++sk)
                    scores[sk] /= sum_exp;

                // Probability-weighted sum of the V rows.
                for(int dv = 0; dv < hdim; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen; ++sk)
                    {
                        acc += scores[sk] * V[(head_row0 + sk) * hdim + dv];
                    }
                    O[q_row + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 28: FMHA Backward with Masks",
|
||||
"Causal mask forward (GPU) + backward plan");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
print_header("Example 28: FMHA Backward with Causal Mask");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("bwd_masks_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(&registry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// Step 2: Plan backward (3-stage) with causal mask
|
||||
std::cout << "\nStep 2: Plan Backward (causal mask)\n";
|
||||
|
||||
fmha_bwd_traits bwd_traits{};
|
||||
bwd_traits.hdim_q = hdim;
|
||||
bwd_traits.hdim_v = hdim;
|
||||
bwd_traits.data_type = "fp16";
|
||||
bwd_traits.is_group_mode = false;
|
||||
bwd_traits.mask_type = mask_enum::mask_top_left;
|
||||
bwd_traits.bias_type = bias_enum::no_bias;
|
||||
bwd_traits.has_dbias = false;
|
||||
bwd_traits.has_dropout = false;
|
||||
bwd_traits.is_store_randval = false;
|
||||
bwd_traits.is_deterministic = false;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = batch;
|
||||
bwd_args.seqlen_q = seqlen;
|
||||
bwd_args.seqlen_k = seqlen;
|
||||
bwd_args.max_seqlen_q = seqlen;
|
||||
bwd_args.max_seqlen_k = seqlen;
|
||||
bwd_args.hdim_q = hdim;
|
||||
bwd_args.hdim_v = hdim;
|
||||
bwd_args.nhead_q = nhead;
|
||||
bwd_args.nhead_k = nhead;
|
||||
|
||||
auto bwd_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
|
||||
|
||||
if(bwd_plan.is_valid() && bwd_plan.stages.size() >= 2)
|
||||
{
|
||||
std::cout << " Backward plan stages (" << bwd_plan.stages.size() << "):\n";
|
||||
for(const auto& stage : bwd_plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << " Backward plan: INVALID or single-stage (expected 3 stages)\n";
|
||||
std::cout << " This is expected -- backward planning shows the pattern\n";
|
||||
}
|
||||
|
||||
// Step 3: Run forward on GPU with causal mask
|
||||
std::cout << "\nStep 3: Run Forward (causal mask, GPU)\n";
|
||||
|
||||
const int64_t qkv_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
|
||||
GpuBuffer<float> lse_dev(lse_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
lse_dev.zero();
|
||||
|
||||
fmha_fwd_traits fwd_traits{};
|
||||
fwd_traits.hdim_q = hdim;
|
||||
fwd_traits.hdim_v = hdim;
|
||||
fwd_traits.data_type = "fp16";
|
||||
fwd_traits.is_group_mode = false;
|
||||
fwd_traits.is_v_rowmajor = true;
|
||||
fwd_traits.has_logits_soft_cap = false;
|
||||
fwd_traits.mask_type = mask_enum::mask_top_left;
|
||||
fwd_traits.bias_type = bias_enum::no_bias;
|
||||
fwd_traits.has_lse = true;
|
||||
fwd_traits.has_dropout = false;
|
||||
fwd_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.q_ptr = q_dev.get();
|
||||
fwd_args.k_ptr = k_dev.get();
|
||||
fwd_args.v_ptr = v_dev.get();
|
||||
fwd_args.o_ptr = o_dev.get();
|
||||
fwd_args.lse_ptr = lse_dev.get();
|
||||
|
||||
fwd_args.bias_ptr = nullptr;
|
||||
fwd_args.q_descale_ptr = nullptr;
|
||||
fwd_args.k_descale_ptr = nullptr;
|
||||
fwd_args.v_descale_ptr = nullptr;
|
||||
fwd_args.rand_val_ptr = nullptr;
|
||||
fwd_args.sink_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fwd_args.seqlen_q = seqlen;
|
||||
fwd_args.seqlen_k = seqlen;
|
||||
fwd_args.batch = batch;
|
||||
fwd_args.max_seqlen_q = seqlen;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.scale_s = scale;
|
||||
fwd_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fwd_args.stride_q = hdim;
|
||||
fwd_args.stride_k = hdim;
|
||||
fwd_args.stride_v = hdim;
|
||||
fwd_args.stride_bias = 0;
|
||||
fwd_args.stride_randval = 0;
|
||||
fwd_args.stride_o = hdim;
|
||||
|
||||
fwd_args.nhead_stride_q = seqlen * hdim;
|
||||
fwd_args.nhead_stride_k = seqlen * hdim;
|
||||
fwd_args.nhead_stride_v = seqlen * hdim;
|
||||
fwd_args.nhead_stride_bias = 0;
|
||||
fwd_args.nhead_stride_randval = 0;
|
||||
fwd_args.nhead_stride_lse = seqlen;
|
||||
fwd_args.nhead_stride_o = seqlen * hdim;
|
||||
fwd_args.nhead_stride_q_descale = 0;
|
||||
fwd_args.nhead_stride_k_descale = 0;
|
||||
fwd_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fwd_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_bias = 0;
|
||||
fwd_args.batch_stride_randval = 0;
|
||||
fwd_args.batch_stride_lse = nhead * seqlen;
|
||||
fwd_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_q_descale = 0;
|
||||
fwd_args.batch_stride_k_descale = 0;
|
||||
fwd_args.batch_stride_v_descale = 0;
|
||||
|
||||
fwd_args.window_size_left = -1;
|
||||
fwd_args.window_size_right = 0;
|
||||
fwd_args.sink_size = 0;
|
||||
fwd_args.mask_type = 1; // top_left
|
||||
fwd_args.min_seqlen_q = 0;
|
||||
fwd_args.p_drop = 0.0f;
|
||||
fwd_args.s_randval = false;
|
||||
fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fwd_args.block_scale_size_q = 0;
|
||||
fwd_args.block_scale_size_kv = 0;
|
||||
|
||||
bool fwd_passed = false;
|
||||
try
|
||||
{
|
||||
float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
|
||||
std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time
|
||||
<< " ms\n";
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (fwd_time * 1e-3) / 1e12;
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
fwd_passed = true;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " Forward ERROR: " << e.what() << "\n";
|
||||
}
|
||||
|
||||
// Step 4: Validate forward output
|
||||
std::cout << "\nStep 4: Validate Forward Output\n";
|
||||
|
||||
if(fwd_passed)
|
||||
{
|
||||
std::vector<FmhaDataType> o_host(qkv_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
std::vector<float> lse_host(lse_elems);
|
||||
lse_dev.copy_to_host(lse_host.data());
|
||||
|
||||
std::vector<float> q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
std::vector<float> o_ref(qkv_elems, 0.0f);
|
||||
std::vector<float> lse_ref(lse_elems, 0.0f);
|
||||
cpu_attention_fwd_causal(
|
||||
q_f32, k_f32, v_f32, o_ref, lse_ref, batch, nhead, seqlen, hdim, scale);
|
||||
|
||||
double max_o_err = 0.0;
|
||||
int o_errors = 0;
|
||||
const double rtol = 1e-2;
|
||||
const double atol = 1e-2;
|
||||
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
{
|
||||
float gpu_val = static_cast<float>(o_host[i]);
|
||||
float ref_val = o_ref[i];
|
||||
double abs_err = std::abs(gpu_val - ref_val);
|
||||
max_o_err = std::max(max_o_err, abs_err);
|
||||
if(abs_err > atol + rtol * std::abs(ref_val))
|
||||
++o_errors;
|
||||
}
|
||||
|
||||
double max_lse_err = 0.0;
|
||||
int lse_reasonable = 0;
|
||||
for(int64_t i = 0; i < lse_elems; ++i)
|
||||
{
|
||||
if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f)
|
||||
++lse_reasonable;
|
||||
max_lse_err =
|
||||
std::max(max_lse_err, static_cast<double>(std::abs(lse_host[i] - lse_ref[i])));
|
||||
}
|
||||
|
||||
std::cout << " Output max abs error: " << std::scientific << max_o_err << "\n";
|
||||
std::cout << " Output errors: " << o_errors << " / " << qkv_elems << "\n";
|
||||
std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n";
|
||||
std::cout << " LSE max error: " << std::scientific << max_lse_err << "\n";
|
||||
|
||||
fwd_passed = (o_errors == 0) && (lse_reasonable == lse_elems);
|
||||
}
|
||||
|
||||
// Step 5: Show backward API pattern
|
||||
std::cout << "\nStep 5: Backward API Pattern (traits + args)\n";
|
||||
std::cout << " bwd_traits.mask_type = mask_top_left\n";
|
||||
std::cout << " bwd_traits.bias_type = no_bias\n";
|
||||
std::cout << " bwd_traits.has_dropout = false\n";
|
||||
std::cout << " bwd_traits.is_deterministic = false\n";
|
||||
std::cout << " bwd_args.window_size_left = -1\n";
|
||||
std::cout << " bwd_args.window_size_right = 0 (causal)\n";
|
||||
std::cout << " bwd_args.mask_type = 1 (top_left)\n";
|
||||
std::cout << " Backward plan resolves to " << bwd_plan.stages.size() << " stage(s)\n";
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (fwd_passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return fwd_passed ? 0 : 1;
|
||||
}
|
||||
615
dispatcher/examples/fmha/cpp/29_bwd_bias_dropout_fmha.cpp
Normal file
615
dispatcher/examples/fmha/cpp/29_bwd_bias_dropout_fmha.cpp
Normal file
@@ -0,0 +1,615 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 29: FMHA Backward with ALiBi Bias + Dropout
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Forward kernel with alibi bias + dropout + LSE
|
||||
// 2. Backward kernel families with alibi + dropout
|
||||
// 3. GPU forward execution with alibi bias, validates output
|
||||
// 4. Backward plan with all features enabled
|
||||
// 5. How deterministic mode affects the backward plan
|
||||
//
|
||||
// Backward kernels use planning only -- actual backward GPU execution requires
|
||||
// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
DECL_FMHA_KERNEL_SET(bwd_bias_dropout_fmha_kernels,
|
||||
// Forward: alibi bias + dropout + LSE
|
||||
.add(FmhaSignature()
|
||||
.family("fwd")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.vlayout("r")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("alibi")
|
||||
.lse(true)
|
||||
.dropout(true)
|
||||
.qscale("no"),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(128)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(128)
|
||||
.tile_k1(32)
|
||||
.tile_k0max(128)
|
||||
.wave_m0(4)
|
||||
.wave_n0(1)
|
||||
.wave_k0(1)
|
||||
.wave_m1(4)
|
||||
.wave_n1(1)
|
||||
.wave_k1(1)
|
||||
.warp_m0(32)
|
||||
.warp_n0(32)
|
||||
.warp_k0(16)
|
||||
.warp_m1(32)
|
||||
.warp_n1(32)
|
||||
.warp_k1(16)
|
||||
.pipeline("qr_async")
|
||||
.padding(true, true, true, true)
|
||||
.alignments(128, 128)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
|
||||
// Backward stage 1: dot(dO, O) with alibi + dropout (non-deterministic)
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_dot_do_o")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("alibi")
|
||||
.dropout(true)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(false),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(0)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
|
||||
// Backward stage 2: dQ, dK, dV with alibi + dropout (non-deterministic)
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_dq_dk_dv")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("alibi")
|
||||
.dropout(true)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(false),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(16)
|
||||
.tile_n0(128)
|
||||
.tile_k0(128)
|
||||
.tile_n1(16)
|
||||
.tile_k1(128)
|
||||
.tile_k0max(32)
|
||||
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
|
||||
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
|
||||
.padding(true, true, true, true)
|
||||
.max_seq_len_q(0)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
|
||||
// Backward stage 3: convert dQ with alibi + dropout (non-deterministic)
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_convert_dq")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("alibi")
|
||||
.dropout(true)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(false),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(0)
|
||||
.tile_n1(0)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
|
||||
// Deterministic variants for comparison
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_dot_do_o")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("alibi")
|
||||
.dropout(true)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(true),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(32)
|
||||
.tile_n1(0)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_dq_dk_dv")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("alibi")
|
||||
.dropout(true)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(true),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(16)
|
||||
.tile_n0(128)
|
||||
.tile_k0(128)
|
||||
.tile_n1(16)
|
||||
.tile_k1(128)
|
||||
.tile_k0max(32)
|
||||
.wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
|
||||
.warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
|
||||
.padding(true, true, true, true)
|
||||
.max_seq_len_q(0)
|
||||
.selection_rank(0),
|
||||
"gfx950")
|
||||
|
||||
.add(FmhaSignature()
|
||||
.family("bwd_convert_dq")
|
||||
.dtype("fp16")
|
||||
.mode("batch")
|
||||
.hdim(128)
|
||||
.mask("no")
|
||||
.bias("alibi")
|
||||
.dropout(true)
|
||||
.dbias(false)
|
||||
.store_randval(false)
|
||||
.deterministic(true),
|
||||
FmhaAlgorithm()
|
||||
.tile_m0(64)
|
||||
.tile_n0(128)
|
||||
.tile_k0(0)
|
||||
.tile_n1(0)
|
||||
.tile_k1(0)
|
||||
.tile_k0max(0)
|
||||
.padding(true, true, true, true)
|
||||
.selection_rank(0),
|
||||
"gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
// CPU reference for the forward attention pass with an ALiBi-style bias and
// no mask (dropout is not modelled).
//
// Q, K, V and O are indexed as dense row-major [batch, nhead, seqlen, hdim]
// buffers and LSE as [batch, nhead, seqlen]. The score for query row `sq` and
// key `sk` of head `h` is
//     scale * dot(Q[sq], K[sk]) + alibi_slopes[h] * (sk - sq)
// followed by a softmax over the keys. The probability-weighted sum of V rows
// is written to O and the log-sum-exp of the scores to LSE.
//
// Parameters:
//   Q, K, V  - input tensors (read only)
//   O        - output tensor, must already hold batch*nhead*seqlen*hdim floats
//   LSE      - log-sum-exp output, must already hold batch*nhead*seqlen floats
//   batch, nhead, seqlen, hdim - tensor extents
//   scale    - factor applied to each raw Q.K dot product
//   alibi_slopes - one slope per head; indexed with h in [0, nhead)
//
// Offsets are computed in 64 bits so that large problem sizes do not overflow
// 32-bit arithmetic.
void cpu_attention_fwd_alibi(const std::vector<float>& Q,
                             const std::vector<float>& K,
                             const std::vector<float>& V,
                             std::vector<float>& O,
                             std::vector<float>& LSE,
                             int batch,
                             int nhead,
                             int seqlen,
                             int hdim,
                             float scale,
                             const std::vector<float>& alibi_slopes)
{
    // Nothing to compute; also keeps the scratch allocation below well-formed.
    if(seqlen <= 0)
        return;

    const int64_t head_elems = static_cast<int64_t>(seqlen) * hdim;

    // Scratch row reused for every query; each entry is rewritten per row.
    std::vector<float> scores(seqlen, 0.0f);

    for(int b = 0; b < batch; ++b)
    {
        for(int h = 0; h < nhead; ++h)
        {
            const float slope       = alibi_slopes[h];
            const int64_t head      = static_cast<int64_t>(b) * nhead + h;
            const int64_t head_base = head * head_elems;

            for(int sq = 0; sq < seqlen; ++sq)
            {
                const int64_t q_base = head_base + static_cast<int64_t>(sq) * hdim;

                // Pass 1: biased scores and their running maximum.
                float max_score = -1e30f;
                for(int sk = 0; sk < seqlen; ++sk)
                {
                    const int64_t k_base = head_base + static_cast<int64_t>(sk) * hdim;
                    float dot            = 0.0f;
                    for(int d = 0; d < hdim; ++d)
                        dot += Q[q_base + d] * K[k_base + d];

                    // Linear positional bias proportional to the signed
                    // key/query offset.
                    scores[sk] = dot * scale + slope * static_cast<float>(sk - sq);
                    max_score  = std::max(max_score, scores[sk]);
                }

                // Pass 2: exponentiate relative to the maximum for stability.
                float sum_exp = 0.0f;
                for(int sk = 0; sk < seqlen; ++sk)
                {
                    scores[sk] = std::exp(scores[sk] - max_score);
                    sum_exp += scores[sk];
                }

                LSE[head * seqlen + sq] = max_score + std::log(sum_exp);

                // Pass 3: normalise to probabilities.
                for(int sk = 0; sk < seqlen; ++sk)
                    scores[sk] /= sum_exp;

                // Pass 4: O[sq, :] = sum_sk P[sq, sk] * V[sk, :].
                for(int dv = 0; dv < hdim; ++dv)
                {
                    float acc = 0.0f;
                    for(int sk = 0; sk < seqlen; ++sk)
                        acc += scores[sk] * V[head_base + static_cast<int64_t>(sk) * hdim + dv];
                    O[q_base + dv] = acc;
                }
            }
        }
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 29: FMHA Backward with Bias + Dropout",
|
||||
"ALiBi bias + dropout forward (GPU) + backward plan with deterministic mode");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "64", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 64);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
print_header("Example 29: FMHA Backward with ALiBi Bias + Dropout");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("bwd_bias_dropout_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(&registry);
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 3);
|
||||
|
||||
// Step 2: Plan backward (non-deterministic) with alibi + dropout
|
||||
std::cout << "\nStep 2: Plan Backward (non-deterministic, alibi + dropout)\n";
|
||||
|
||||
fmha_bwd_traits bwd_traits{};
|
||||
bwd_traits.hdim_q = hdim;
|
||||
bwd_traits.hdim_v = hdim;
|
||||
bwd_traits.data_type = "fp16";
|
||||
bwd_traits.is_group_mode = false;
|
||||
bwd_traits.mask_type = mask_enum::no_mask;
|
||||
bwd_traits.bias_type = bias_enum::alibi;
|
||||
bwd_traits.has_dbias = false;
|
||||
bwd_traits.has_dropout = true;
|
||||
bwd_traits.is_store_randval = false;
|
||||
bwd_traits.is_deterministic = false;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = batch;
|
||||
bwd_args.seqlen_q = seqlen;
|
||||
bwd_args.seqlen_k = seqlen;
|
||||
bwd_args.max_seqlen_q = seqlen;
|
||||
bwd_args.max_seqlen_k = seqlen;
|
||||
bwd_args.hdim_q = hdim;
|
||||
bwd_args.hdim_v = hdim;
|
||||
bwd_args.nhead_q = nhead;
|
||||
bwd_args.nhead_k = nhead;
|
||||
|
||||
auto nondet_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
|
||||
|
||||
if(nondet_plan.is_valid() && nondet_plan.stages.size() >= 2)
|
||||
{
|
||||
std::cout << " Non-deterministic plan stages (" << nondet_plan.stages.size() << "):\n";
|
||||
for(const auto& stage : nondet_plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << " Non-deterministic plan: INVALID or single-stage\n";
|
||||
}
|
||||
|
||||
// Step 2b: Plan backward (deterministic) with alibi + dropout
|
||||
std::cout << "\nStep 2b: Plan Backward (deterministic, alibi + dropout)\n";
|
||||
|
||||
fmha_bwd_traits det_traits = bwd_traits;
|
||||
det_traits.is_deterministic = true;
|
||||
|
||||
auto det_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch));
|
||||
|
||||
if(det_plan.is_valid() && det_plan.stages.size() >= 2)
|
||||
{
|
||||
std::cout << " Deterministic plan stages (" << det_plan.stages.size() << "):\n";
|
||||
for(const auto& stage : det_plan.stages)
|
||||
{
|
||||
std::cout << " " << to_string(stage.family) << " -> " << stage.kernel_id << "\n";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << " Deterministic plan: INVALID or single-stage\n";
|
||||
}
|
||||
|
||||
std::cout << "\n Deterministic mode difference:\n";
|
||||
std::cout << " Non-det: dQ accumulated via atomic adds (faster, non-reproducible)\n";
|
||||
std::cout << " Det: dQ accumulated with split-stride (slower, bit-reproducible)\n";
|
||||
|
||||
// Step 3: Run forward on GPU with alibi bias + dropout
|
||||
std::cout << "\nStep 3: Run Forward (alibi + dropout, GPU)\n";
|
||||
|
||||
const int64_t qkv_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
|
||||
const int64_t lse_elems = static_cast<int64_t>(batch) * nhead * seqlen;
|
||||
const int64_t randval_elems = static_cast<int64_t>(batch) * nhead * seqlen * seqlen;
|
||||
|
||||
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
|
||||
GpuBuffer<float> lse_dev(lse_elems);
|
||||
GpuBuffer<uint8_t> rand_val_dev(randval_elems);
|
||||
|
||||
// ALiBi slopes: geometric series
|
||||
std::vector<float> alibi_slopes_host(nhead);
|
||||
for(int h = 0; h < nhead; ++h)
|
||||
alibi_slopes_host[h] = -std::pow(2.0f, -(8.0f * (h + 1) / nhead));
|
||||
|
||||
GpuBuffer<float> alibi_slopes_dev(nhead);
|
||||
alibi_slopes_dev.copy_from_host(alibi_slopes_host.data());
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
o_dev.zero();
|
||||
lse_dev.zero();
|
||||
rand_val_dev.zero();
|
||||
|
||||
std::cout << " ALiBi slopes: [";
|
||||
for(int h = 0; h < nhead; ++h)
|
||||
{
|
||||
if(h > 0)
|
||||
std::cout << ", ";
|
||||
std::cout << std::fixed << std::setprecision(4) << alibi_slopes_host[h];
|
||||
}
|
||||
std::cout << "]\n";
|
||||
|
||||
fmha_fwd_traits fwd_traits{};
|
||||
fwd_traits.hdim_q = hdim;
|
||||
fwd_traits.hdim_v = hdim;
|
||||
fwd_traits.data_type = "fp16";
|
||||
fwd_traits.is_group_mode = false;
|
||||
fwd_traits.is_v_rowmajor = true;
|
||||
fwd_traits.has_logits_soft_cap = false;
|
||||
fwd_traits.mask_type = mask_enum::no_mask;
|
||||
fwd_traits.bias_type = bias_enum::alibi;
|
||||
fwd_traits.has_lse = true;
|
||||
fwd_traits.has_dropout = true;
|
||||
fwd_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.q_ptr = q_dev.get();
|
||||
fwd_args.k_ptr = k_dev.get();
|
||||
fwd_args.v_ptr = v_dev.get();
|
||||
fwd_args.o_ptr = o_dev.get();
|
||||
fwd_args.lse_ptr = lse_dev.get();
|
||||
|
||||
fwd_args.bias_ptr = alibi_slopes_dev.get();
|
||||
fwd_args.rand_val_ptr = rand_val_dev.get();
|
||||
fwd_args.q_descale_ptr = nullptr;
|
||||
fwd_args.k_descale_ptr = nullptr;
|
||||
fwd_args.v_descale_ptr = nullptr;
|
||||
fwd_args.sink_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fwd_args.seqlen_q = seqlen;
|
||||
fwd_args.seqlen_k = seqlen;
|
||||
fwd_args.batch = batch;
|
||||
fwd_args.max_seqlen_q = seqlen;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.scale_s = scale;
|
||||
fwd_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fwd_args.stride_q = hdim;
|
||||
fwd_args.stride_k = hdim;
|
||||
fwd_args.stride_v = hdim;
|
||||
fwd_args.stride_bias = 0; // alibi: per-head slope, no spatial stride
|
||||
fwd_args.stride_randval = seqlen;
|
||||
fwd_args.stride_o = hdim;
|
||||
|
||||
fwd_args.nhead_stride_q = seqlen * hdim;
|
||||
fwd_args.nhead_stride_k = seqlen * hdim;
|
||||
fwd_args.nhead_stride_v = seqlen * hdim;
|
||||
fwd_args.nhead_stride_bias = 1; // alibi: stride between slopes
|
||||
fwd_args.nhead_stride_randval = seqlen * seqlen;
|
||||
fwd_args.nhead_stride_lse = seqlen;
|
||||
fwd_args.nhead_stride_o = seqlen * hdim;
|
||||
fwd_args.nhead_stride_q_descale = 0;
|
||||
fwd_args.nhead_stride_k_descale = 0;
|
||||
fwd_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fwd_args.batch_stride_q = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_k = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_v = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_bias = 0; // alibi slopes shared across batch
|
||||
fwd_args.batch_stride_randval = nhead * seqlen * seqlen;
|
||||
fwd_args.batch_stride_lse = nhead * seqlen;
|
||||
fwd_args.batch_stride_o = nhead * seqlen * hdim;
|
||||
fwd_args.batch_stride_q_descale = 0;
|
||||
fwd_args.batch_stride_k_descale = 0;
|
||||
fwd_args.batch_stride_v_descale = 0;
|
||||
|
||||
fwd_args.window_size_left = -1;
|
||||
fwd_args.window_size_right = -1;
|
||||
fwd_args.sink_size = 0;
|
||||
fwd_args.mask_type = 0;
|
||||
fwd_args.min_seqlen_q = 0;
|
||||
fwd_args.p_drop = 0.2f;
|
||||
fwd_args.s_randval = true;
|
||||
fwd_args.drop_seed_offset = std::make_pair(uint64_t(42), uint64_t(0));
|
||||
fwd_args.block_scale_size_q = 0;
|
||||
fwd_args.block_scale_size_kv = 0;
|
||||
|
||||
bool fwd_passed = false;
|
||||
try
|
||||
{
|
||||
float fwd_time = dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
|
||||
std::cout << " Forward time: " << std::fixed << std::setprecision(4) << fwd_time
|
||||
<< " ms\n";
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
|
||||
double tflops = static_cast<double>(problem.num_ops()) / (fwd_time * 1e-3) / 1e12;
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
|
||||
fwd_passed = true;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " Forward ERROR: " << e.what() << "\n";
|
||||
}
|
||||
|
||||
// Step 4: Validate forward output (without dropout reference -- just check non-zero + LSE)
|
||||
std::cout << "\nStep 4: Validate Forward Output\n";
|
||||
|
||||
if(fwd_passed)
|
||||
{
|
||||
std::vector<FmhaDataType> o_host(qkv_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
std::cout << " Non-zero outputs: " << nonzero << " / " << qkv_elems << "\n";
|
||||
|
||||
std::vector<float> lse_host(lse_elems);
|
||||
lse_dev.copy_to_host(lse_host.data());
|
||||
|
||||
int lse_reasonable = 0;
|
||||
for(int64_t i = 0; i < lse_elems; ++i)
|
||||
{
|
||||
if(std::isfinite(lse_host[i]) && std::abs(lse_host[i]) < 100.0f)
|
||||
++lse_reasonable;
|
||||
}
|
||||
std::cout << " LSE reasonable: " << lse_reasonable << " / " << lse_elems << "\n";
|
||||
|
||||
std::cout << " LSE sample [0..3]: ";
|
||||
for(int i = 0; i < std::min<int>(4, lse_elems); ++i)
|
||||
std::cout << std::fixed << std::setprecision(4) << lse_host[i] << " ";
|
||||
std::cout << "\n";
|
||||
|
||||
fwd_passed = (nonzero > 0) && (lse_reasonable == lse_elems);
|
||||
|
||||
// ALiBi reference (without dropout) for sanity check on bias effect
|
||||
std::vector<float> q_f32(qkv_elems), k_f32(qkv_elems), v_f32(qkv_elems);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
q_f32[i] = static_cast<float>(q_host[i]);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
k_f32[i] = static_cast<float>(k_host[i]);
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
v_f32[i] = static_cast<float>(v_host[i]);
|
||||
|
||||
std::vector<float> o_ref(qkv_elems, 0.0f);
|
||||
std::vector<float> lse_ref(lse_elems, 0.0f);
|
||||
cpu_attention_fwd_alibi(q_f32,
|
||||
k_f32,
|
||||
v_f32,
|
||||
o_ref,
|
||||
lse_ref,
|
||||
batch,
|
||||
nhead,
|
||||
seqlen,
|
||||
hdim,
|
||||
scale,
|
||||
alibi_slopes_host);
|
||||
|
||||
// LSE should be close (dropout doesn't change LSE in the CK implementation --
|
||||
// LSE is computed before dropout is applied to the attention weights)
|
||||
double max_lse_err = 0.0;
|
||||
for(int64_t i = 0; i < lse_elems; ++i)
|
||||
max_lse_err =
|
||||
std::max(max_lse_err, static_cast<double>(std::abs(lse_host[i] - lse_ref[i])));
|
||||
|
||||
std::cout << " LSE vs alibi ref (no dropout) max error: " << std::scientific << max_lse_err
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
// Step 5: Show backward API pattern with all features
|
||||
std::cout << "\nStep 5: Backward API Pattern (all features)\n";
|
||||
std::cout << " bwd_traits.bias_type = alibi\n";
|
||||
std::cout << " bwd_traits.has_dropout = true\n";
|
||||
std::cout << " bwd_traits.is_store_randval = false\n";
|
||||
std::cout << " bwd_traits.has_dbias = false (alibi has no learnable params)\n";
|
||||
std::cout << "\n Non-deterministic plan: " << nondet_plan.stages.size() << " stage(s)\n";
|
||||
std::cout << " Deterministic plan: " << det_plan.stages.size() << " stage(s)\n";
|
||||
std::cout << "\n Key backward args for dropout:\n";
|
||||
std::cout << " bwd_args.p_drop = 0.2\n";
|
||||
std::cout << " bwd_args.p_undrop = 1.0 / (1.0 - p_drop) = 1.25\n";
|
||||
std::cout << " bwd_args.drop_seed_offset = {42, 0} (must match fwd)\n";
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (fwd_passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return fwd_passed ? 0 : 1;
|
||||
}
|
||||
449
dispatcher/examples/fmha/cpp/30_bwd_benchmark_fmha.cpp
Normal file
449
dispatcher/examples/fmha/cpp/30_bwd_benchmark_fmha.cpp
Normal file
@@ -0,0 +1,449 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 30: FMHA Backward Benchmark
|
||||
//
|
||||
// Demonstrates:
|
||||
// 1. Forward kernel for benchmark (with LSE for backward planning)
|
||||
// 2. Multiple problem sizes: sweep batch x seqlen
|
||||
// 3. GPU forward execution for each size with timing
|
||||
// 4. Backward plan for each size
|
||||
// 5. Summary table: Batch | SeqLen | Fwd(ms) | BwdPlan | FwdTFLOPS | Status
|
||||
//
|
||||
// Backward kernels use planning only -- actual backward GPU execution requires
|
||||
// all 3 stages to compile, and bwd_dq_dk_dv has tile structure issues on gfx950.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for the backward benchmark example. Declares four fp16, batch-mode,
// hdim-128 kernels for gfx950 with no mask, no bias and no dropout:
//   - one forward kernel that also emits LSE,
//   - the three backward stages (dot(dO, O), dQ/dK/dV, convert dQ), all
//     non-deterministic.
DECL_FMHA_KERNEL_SET(bwd_bench_fmha_kernels,
                     // Forward: basic fp16 with LSE for backward
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(true)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")

                     // Backward stage 1: dot(dO, O)
                     .add(FmhaSignature()
                              .family("bwd_dot_do_o")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(64)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(0)
                              .tile_k1(0)
                              .tile_k0max(0)
                              .padding(true, true, true, true)
                              .selection_rank(0),
                          "gfx950")

                     // Backward stage 2: dQ, dK, dV
                     .add(FmhaSignature()
                              .family("bwd_dq_dk_dv")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(16)
                              .tile_n0(128)
                              .tile_k0(128)
                              .tile_n1(16)
                              .tile_k1(128)
                              .tile_k0max(32)
                              .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
                              .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
                              .padding(true, true, true, true)
                              .max_seq_len_q(0)
                              .selection_rank(0),
                          "gfx950")

                     // Backward stage 3: convert dQ
                     .add(FmhaSignature()
                              .family("bwd_convert_dq")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(64)
                              .tile_n0(128)
                              .tile_k0(0)
                              .tile_n1(0)
                              .tile_k1(0)
                              .tile_k0max(0)
                              .padding(true, true, true, true)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
namespace {
|
||||
|
||||
using FmhaDataType = ck_tile::fp16_t;
|
||||
|
||||
struct BenchResult
|
||||
{
|
||||
int batch;
|
||||
int seqlen;
|
||||
float fwd_ms;
|
||||
double fwd_tflops;
|
||||
int bwd_stages;
|
||||
bool bwd_valid;
|
||||
bool fwd_passed;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 30: FMHA Backward Benchmark",
|
||||
"Sweep batch x seqlen, forward GPU + backward plan");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--nhead", "8", "Number of heads");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_option("--warmup", "2", "Warmup iterations per size");
|
||||
args.add_option("--repeat", "3", "Benchmark repetitions per size");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int nhead = args.get_int("--nhead", 8);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const int warmup = args.get_int("--warmup", 2);
|
||||
const int repeat = args.get_int("--repeat", 3);
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
|
||||
|
||||
print_header("Example 30: FMHA Backward Benchmark");
|
||||
|
||||
// Step 1: Register kernels
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaKernelSetRegistry::instance().print();
|
||||
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("bwd_bench_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
|
||||
// Problem sizes to sweep
|
||||
struct ProblemSize
|
||||
{
|
||||
int batch;
|
||||
int seqlen;
|
||||
};
|
||||
|
||||
ProblemSize sizes[] = {
|
||||
{8, 128},
|
||||
{4, 256},
|
||||
{2, 512},
|
||||
{1, 1024},
|
||||
{1, 2048},
|
||||
{1, 4096},
|
||||
};
|
||||
|
||||
std::vector<BenchResult> results;
|
||||
|
||||
// Step 2: Sweep problem sizes
|
||||
std::cout << "\nStep 2: Sweep Problem Sizes\n";
|
||||
|
||||
for(const auto& sz : sizes)
|
||||
{
|
||||
std::cout << "\n --- batch=" << sz.batch << ", seqlen=" << sz.seqlen << " ---\n";
|
||||
|
||||
const int64_t qkv_elems = static_cast<int64_t>(sz.batch) * nhead * sz.seqlen * hdim;
|
||||
const int64_t lse_elems = static_cast<int64_t>(sz.batch) * nhead * sz.seqlen;
|
||||
|
||||
BenchResult res{};
|
||||
res.batch = sz.batch;
|
||||
res.seqlen = sz.seqlen;
|
||||
|
||||
// Allocate buffers
|
||||
GpuBuffer<FmhaDataType> q_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> k_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> v_dev(qkv_elems);
|
||||
GpuBuffer<FmhaDataType> o_dev(qkv_elems);
|
||||
GpuBuffer<float> lse_dev(lse_elems);
|
||||
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
|
||||
|
||||
std::vector<FmhaDataType> q_host(qkv_elems), k_host(qkv_elems), v_host(qkv_elems);
|
||||
for(auto& x : q_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : k_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
for(auto& x : v_host)
|
||||
x = FmhaDataType(dist(rng));
|
||||
|
||||
q_dev.copy_from_host(q_host.data());
|
||||
k_dev.copy_from_host(k_host.data());
|
||||
v_dev.copy_from_host(v_host.data());
|
||||
|
||||
// Forward traits/args
|
||||
fmha_fwd_traits fwd_traits{};
|
||||
fwd_traits.hdim_q = hdim;
|
||||
fwd_traits.hdim_v = hdim;
|
||||
fwd_traits.data_type = "fp16";
|
||||
fwd_traits.is_group_mode = false;
|
||||
fwd_traits.is_v_rowmajor = true;
|
||||
fwd_traits.has_logits_soft_cap = false;
|
||||
fwd_traits.mask_type = mask_enum::no_mask;
|
||||
fwd_traits.bias_type = bias_enum::no_bias;
|
||||
fwd_traits.has_lse = true;
|
||||
fwd_traits.has_dropout = false;
|
||||
fwd_traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.q_ptr = q_dev.get();
|
||||
fwd_args.k_ptr = k_dev.get();
|
||||
fwd_args.v_ptr = v_dev.get();
|
||||
fwd_args.o_ptr = o_dev.get();
|
||||
fwd_args.lse_ptr = lse_dev.get();
|
||||
|
||||
fwd_args.bias_ptr = nullptr;
|
||||
fwd_args.q_descale_ptr = nullptr;
|
||||
fwd_args.k_descale_ptr = nullptr;
|
||||
fwd_args.v_descale_ptr = nullptr;
|
||||
fwd_args.rand_val_ptr = nullptr;
|
||||
fwd_args.sink_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_q_ptr = nullptr;
|
||||
fwd_args.block_scale_seqstart_k_ptr = nullptr;
|
||||
|
||||
fwd_args.seqlen_q = sz.seqlen;
|
||||
fwd_args.seqlen_k = sz.seqlen;
|
||||
fwd_args.batch = sz.batch;
|
||||
fwd_args.max_seqlen_q = sz.seqlen;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.scale_s = scale;
|
||||
fwd_args.logits_soft_cap = 0.0f;
|
||||
|
||||
fwd_args.stride_q = hdim;
|
||||
fwd_args.stride_k = hdim;
|
||||
fwd_args.stride_v = hdim;
|
||||
fwd_args.stride_bias = 0;
|
||||
fwd_args.stride_randval = 0;
|
||||
fwd_args.stride_o = hdim;
|
||||
|
||||
fwd_args.nhead_stride_q = sz.seqlen * hdim;
|
||||
fwd_args.nhead_stride_k = sz.seqlen * hdim;
|
||||
fwd_args.nhead_stride_v = sz.seqlen * hdim;
|
||||
fwd_args.nhead_stride_bias = 0;
|
||||
fwd_args.nhead_stride_randval = 0;
|
||||
fwd_args.nhead_stride_lse = sz.seqlen;
|
||||
fwd_args.nhead_stride_o = sz.seqlen * hdim;
|
||||
fwd_args.nhead_stride_q_descale = 0;
|
||||
fwd_args.nhead_stride_k_descale = 0;
|
||||
fwd_args.nhead_stride_v_descale = 0;
|
||||
|
||||
fwd_args.batch_stride_q = nhead * sz.seqlen * hdim;
|
||||
fwd_args.batch_stride_k = nhead * sz.seqlen * hdim;
|
||||
fwd_args.batch_stride_v = nhead * sz.seqlen * hdim;
|
||||
fwd_args.batch_stride_bias = 0;
|
||||
fwd_args.batch_stride_randval = 0;
|
||||
fwd_args.batch_stride_lse = nhead * sz.seqlen;
|
||||
fwd_args.batch_stride_o = nhead * sz.seqlen * hdim;
|
||||
fwd_args.batch_stride_q_descale = 0;
|
||||
fwd_args.batch_stride_k_descale = 0;
|
||||
fwd_args.batch_stride_v_descale = 0;
|
||||
|
||||
fwd_args.window_size_left = -1;
|
||||
fwd_args.window_size_right = -1;
|
||||
fwd_args.sink_size = 0;
|
||||
fwd_args.mask_type = 0;
|
||||
fwd_args.min_seqlen_q = 0;
|
||||
fwd_args.p_drop = 0.0f;
|
||||
fwd_args.s_randval = false;
|
||||
fwd_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
|
||||
fwd_args.block_scale_size_q = 0;
|
||||
fwd_args.block_scale_size_kv = 0;
|
||||
|
||||
// Warmup
|
||||
dispatcher.set_benchmarking(true);
|
||||
dispatcher.set_timing(1, 1);
|
||||
try
|
||||
{
|
||||
for(int w = 0; w < warmup; ++w)
|
||||
{
|
||||
o_dev.zero();
|
||||
lse_dev.zero();
|
||||
dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
|
||||
}
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " Warmup ERROR: " << e.what() << "\n";
|
||||
res.fwd_passed = false;
|
||||
results.push_back(res);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Benchmark
|
||||
dispatcher.set_timing(0, 1);
|
||||
float total_ms = 0.0f;
|
||||
bool ok = true;
|
||||
for(int r = 0; r < repeat; ++r)
|
||||
{
|
||||
o_dev.zero();
|
||||
lse_dev.zero();
|
||||
try
|
||||
{
|
||||
total_ms += dispatcher.run_fwd(fwd_traits, fwd_args, nullptr);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << " Bench ERROR: " << e.what() << "\n";
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(ok)
|
||||
{
|
||||
res.fwd_ms = total_ms / static_cast<float>(repeat);
|
||||
|
||||
auto problem =
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(fwd_traits, fwd_args), gfx_arch);
|
||||
res.fwd_tflops = static_cast<double>(problem.num_ops()) / (res.fwd_ms * 1e-3) / 1e12;
|
||||
|
||||
// Sanity check output
|
||||
std::vector<FmhaDataType> o_host(qkv_elems);
|
||||
o_dev.copy_to_host(o_host.data());
|
||||
int nonzero = 0;
|
||||
for(int64_t i = 0; i < qkv_elems; ++i)
|
||||
{
|
||||
if(static_cast<float>(o_host[i]) != 0.0f)
|
||||
++nonzero;
|
||||
}
|
||||
res.fwd_passed = (nonzero > 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
res.fwd_passed = false;
|
||||
}
|
||||
|
||||
// Backward plan for this size
|
||||
fmha_bwd_traits bwd_traits{};
|
||||
bwd_traits.hdim_q = hdim;
|
||||
bwd_traits.hdim_v = hdim;
|
||||
bwd_traits.data_type = "fp16";
|
||||
bwd_traits.is_group_mode = false;
|
||||
bwd_traits.mask_type = mask_enum::no_mask;
|
||||
bwd_traits.bias_type = bias_enum::no_bias;
|
||||
bwd_traits.has_dbias = false;
|
||||
bwd_traits.has_dropout = false;
|
||||
bwd_traits.is_store_randval = false;
|
||||
bwd_traits.is_deterministic = false;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = sz.batch;
|
||||
bwd_args.seqlen_q = sz.seqlen;
|
||||
bwd_args.seqlen_k = sz.seqlen;
|
||||
bwd_args.max_seqlen_q = sz.seqlen;
|
||||
bwd_args.max_seqlen_k = sz.seqlen;
|
||||
bwd_args.hdim_q = hdim;
|
||||
bwd_args.hdim_v = hdim;
|
||||
bwd_args.nhead_q = nhead;
|
||||
bwd_args.nhead_k = nhead;
|
||||
|
||||
auto bwd_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(bwd_traits, bwd_args), gfx_arch));
|
||||
|
||||
res.bwd_valid = bwd_plan.is_valid() && bwd_plan.stages.size() >= 2;
|
||||
res.bwd_stages = static_cast<int>(bwd_plan.stages.size());
|
||||
|
||||
std::cout << " Fwd: " << std::fixed << std::setprecision(4) << res.fwd_ms << " ms, "
|
||||
<< std::setprecision(2) << res.fwd_tflops << " TFLOPS"
|
||||
<< " | Bwd plan: " << res.bwd_stages << " stages"
|
||||
<< (res.bwd_valid ? " (valid)" : " (invalid)") << "\n";
|
||||
|
||||
results.push_back(res);
|
||||
}
|
||||
|
||||
// Step 3: Summary table
|
||||
std::cout << "\nStep 3: Summary\n\n";
|
||||
std::cout << " " << std::setw(7) << "Batch" << " | " << std::setw(7) << "SeqLen" << " | "
|
||||
<< std::setw(10) << "Fwd(ms)" << " | " << std::setw(8) << "BwdPlan" << " | "
|
||||
<< std::setw(10) << "FwdTFLOPS" << " | " << std::setw(6) << "Status" << "\n";
|
||||
std::cout << " " << std::string(60, '-') << "\n";
|
||||
|
||||
bool all_passed = true;
|
||||
for(const auto& r : results)
|
||||
{
|
||||
std::cout << " " << std::setw(7) << r.batch << " | " << std::setw(7) << r.seqlen << " | "
|
||||
<< std::fixed << std::setprecision(4) << std::setw(10) << r.fwd_ms << " | "
|
||||
<< std::setw(5) << r.bwd_stages << "stg" << " | " << std::setprecision(2)
|
||||
<< std::setw(10) << r.fwd_tflops << " | " << std::setw(6)
|
||||
<< (r.fwd_passed ? "PASS" : "FAIL") << "\n";
|
||||
if(!r.fwd_passed)
|
||||
all_passed = false;
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (all_passed ? "PASS" : "FAIL") << "\n";
|
||||
print_separator();
|
||||
|
||||
return all_passed ? 0 : 1;
|
||||
}
|
||||
118
dispatcher/examples/fmha/cpp/31_logits_soft_cap_fmha.cpp
Normal file
118
dispatcher/examples/fmha/cpp/31_logits_soft_cap_fmha.cpp
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 31: FMHA Forward with Logits Soft Cap
|
||||
//
|
||||
// Demonstrates forward kernel with logits_soft_cap enabled. The soft cap
|
||||
// applies: scores_capped = tanh(scores/cap) * cap, which prevents extreme
|
||||
// attention logits from causing numerical instability while preserving
|
||||
// gradients. Planning only.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for the logits-soft-cap example: a single fp16, batch-mode, hdim-128
// forward kernel for gfx950 with no mask/bias/LSE/dropout and `.logits(true)`.
DECL_FMHA_KERNEL_SET(logits_soft_cap_fmha_kernels,
                     // Forward with logits soft cap: tanh(scores/cap)*cap
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("no")
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no")
                              .logits(true), // enables logits_soft_cap path
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 31: FMHA Logits Soft Cap", "Forward with tanh(scores/cap)*cap");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 31: FMHA Logits Soft Cap");
|
||||
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("logits_soft_cap_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
std::cout << "\nStep 2: Plan\n";
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_logits_soft_cap = true; // runtime: cap > 0 means soft cap applied
|
||||
traits.mask_type = mask_enum::no_mask;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.batch = batch;
|
||||
fwd_args.seqlen_q = seqlen;
|
||||
fwd_args.seqlen_k = seqlen;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.logits_soft_cap = 30.0f; // cap value; apply tanh(scores/30)*30
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch));
|
||||
std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n";
|
||||
|
||||
std::cout << "\nStep 3: Logits Soft Cap\n";
|
||||
std::cout << " Formula: scores_capped = tanh(scores/cap) * cap\n";
|
||||
std::cout << " Prevents extreme logits while preserving gradients.\n";
|
||||
|
||||
print_separator();
|
||||
return 0;
|
||||
}
|
||||
119
dispatcher/examples/fmha/cpp/32_sink_tokens_fmha.cpp
Normal file
119
dispatcher/examples/fmha/cpp/32_sink_tokens_fmha.cpp
Normal file
@@ -0,0 +1,119 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 32: FMHA Forward with Sink Tokens
|
||||
//
|
||||
// Demonstrates forward kernel with sink tokens enabled. Sink tokens keep the
|
||||
// first K positions always visible to all queries (StreamingLLM-style). Used
|
||||
// with causal mask: positions [0, sink_size) are never masked. Planning only.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for the sink-token example: a single fp16, batch-mode, hdim-128
// forward kernel for gfx950 with a top-left mask and `.sink(true)`.
DECL_FMHA_KERNEL_SET(sink_tokens_fmha_kernels,
                     // Forward with sink: first K tokens always visible
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("top_left") // causal required with sink
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no")
                              .sink(true), // enables sink token path
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 32: FMHA Sink Tokens", "Forward with first K tokens always visible");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_option("--sink", "4", "Number of sink tokens");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const int sink_size = args.get_int("--sink", 4);
|
||||
|
||||
print_header("Example 32: FMHA Sink Tokens");
|
||||
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("sink_tokens_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
std::cout << "\nStep 2: Plan\n";
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.has_sink = true;
|
||||
traits.mask_type = mask_enum::mask_top_left;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.batch = batch;
|
||||
fwd_args.seqlen_q = seqlen;
|
||||
fwd_args.seqlen_k = seqlen;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.sink_size = sink_size;
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch));
|
||||
std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n";
|
||||
|
||||
std::cout << "\nStep 3: Sink Tokens\n";
|
||||
std::cout << " First " << sink_size << " tokens always visible to all queries.\n";
|
||||
std::cout << " Used with causal mask for StreamingLLM-style long-context.\n";
|
||||
|
||||
print_separator();
|
||||
return 0;
|
||||
}
|
||||
256
dispatcher/examples/fmha/cpp/33_bwd_deterministic_fmha.cpp
Normal file
256
dispatcher/examples/fmha/cpp/33_bwd_deterministic_fmha.cpp
Normal file
@@ -0,0 +1,256 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 33: FMHA Backward Deterministic vs Non-Deterministic
|
||||
//
|
||||
// Demonstrates two backward kernel sets: one deterministic (bit-identical
|
||||
// results across runs) and one non-deterministic (faster, atomic reductions).
|
||||
// Planning only.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for the deterministic-vs-non-deterministic backward example.
// All entries are fp16, batch-mode, hdim-128, top-left mask, gfx950:
//   - one forward kernel with LSE,
//   - the three backward stages with deterministic(true), selection_rank 0,
//   - the same three backward stages with deterministic(false), selection_rank 1.
DECL_FMHA_KERNEL_SET(bwd_deterministic_fmha_kernels,
                     // Forward: causal + LSE for backward
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .lse(true)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")
                     // Backward: deterministic (bit-identical across runs)
                     .add(FmhaSignature()
                              .family("bwd_dot_do_o")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(true),
                          FmhaAlgorithm()
                              .tile_m0(64)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(0)
                              .tile_k1(0)
                              .tile_k0max(0)
                              .padding(true, true, true, true)
                              .selection_rank(0),
                          "gfx950")
                     .add(FmhaSignature()
                              .family("bwd_dq_dk_dv")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(true),
                          FmhaAlgorithm()
                              .tile_m0(16)
                              .tile_n0(128)
                              .tile_k0(128)
                              .tile_n1(16)
                              .tile_k1(128)
                              .tile_k0max(32)
                              .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
                              .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
                              .padding(true, true, true, true)
                              .max_seq_len_q(0)
                              .selection_rank(0),
                          "gfx950")
                     .add(FmhaSignature()
                              .family("bwd_convert_dq")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(true),
                          FmhaAlgorithm()
                              .tile_m0(64)
                              .tile_n0(128)
                              .tile_k0(0)
                              .tile_n1(0)
                              .tile_k1(0)
                              .tile_k0max(0)
                              .padding(true, true, true, true)
                              .selection_rank(0),
                          "gfx950")
                     // Backward: non-deterministic (faster, atomic reductions)
                     .add(FmhaSignature()
                              .family("bwd_dot_do_o")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(64)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(0)
                              .tile_k1(0)
                              .tile_k0max(0)
                              .padding(true, true, true, true)
                              .selection_rank(1),
                          "gfx950")
                     .add(FmhaSignature()
                              .family("bwd_dq_dk_dv")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(16)
                              .tile_n0(128)
                              .tile_k0(128)
                              .tile_n1(16)
                              .tile_k1(128)
                              .tile_k0max(32)
                              .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
                              .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
                              .padding(true, true, true, true)
                              .max_seq_len_q(0)
                              .selection_rank(1),
                          "gfx950")
                     .add(FmhaSignature()
                              .family("bwd_convert_dq")
                              .dtype("fp16")
                              .mode("batch")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .dropout(false)
                              .dbias(false)
                              .store_randval(false)
                              .deterministic(false),
                          FmhaAlgorithm()
                              .tile_m0(64)
                              .tile_n0(128)
                              .tile_k0(0)
                              .tile_n1(0)
                              .tile_k1(0)
                              .tile_k0max(0)
                              .padding(true, true, true, true)
                              .selection_rank(1),
                          "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 33: FMHA Backward Deterministic",
|
||||
"Deterministic vs non-deterministic backward");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 33: FMHA Backward Deterministic");
|
||||
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("bwd_deterministic_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
std::cout << "\nStep 2: Plan (deterministic)\n";
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
fmha_bwd_traits det_traits{};
|
||||
det_traits.hdim_q = hdim;
|
||||
det_traits.hdim_v = hdim;
|
||||
det_traits.data_type = "fp16";
|
||||
det_traits.is_group_mode = false;
|
||||
det_traits.mask_type = mask_enum::mask_top_left;
|
||||
det_traits.bias_type = bias_enum::no_bias;
|
||||
det_traits.has_dbias = false;
|
||||
det_traits.has_dropout = false;
|
||||
det_traits.is_store_randval = false;
|
||||
det_traits.is_deterministic = true;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = batch;
|
||||
bwd_args.seqlen_q = seqlen;
|
||||
bwd_args.seqlen_k = seqlen;
|
||||
bwd_args.hdim_q = hdim;
|
||||
bwd_args.hdim_v = hdim;
|
||||
bwd_args.nhead_q = nhead;
|
||||
bwd_args.nhead_k = nhead;
|
||||
|
||||
auto det_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch));
|
||||
std::cout << " Deterministic plan valid: " << (det_plan.is_valid() ? "yes" : "no") << "\n";
|
||||
|
||||
std::cout << "\nStep 3: Plan (non-deterministic)\n";
|
||||
det_traits.is_deterministic = false;
|
||||
auto nondet_plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(det_traits, bwd_args), gfx_arch));
|
||||
std::cout << " Non-deterministic plan valid: " << (nondet_plan.is_valid() ? "yes" : "no")
|
||||
<< "\n";
|
||||
|
||||
std::cout << "\nStep 4: Deterministic Mode\n";
|
||||
std::cout << " deterministic=true: bit-identical across runs (reproducible).\n";
|
||||
std::cout << " deterministic=false: faster, uses atomic reductions.\n";
|
||||
|
||||
print_separator();
|
||||
return 0;
|
||||
}
|
||||
183
dispatcher/examples/fmha/cpp/34_bwd_gqa_fmha.cpp
Normal file
183
dispatcher/examples/fmha/cpp/34_bwd_gqa_fmha.cpp
Normal file
@@ -0,0 +1,183 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 34: FMHA Backward with GQA (Grouped Query Attention)
|
||||
//
|
||||
// Demonstrates backward with nhead_q=8, nhead_k=2 (4:1 ratio). GQA is a
|
||||
// runtime property: each KV head is shared by multiple Q heads. Backward
|
||||
// handles head indexing via nhead_stride_dk/dv so dK/dV are reduced across
|
||||
// the Q-head group. Planning only.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for Example 34. Declares four kernel configurations, all fp16,
// batch mode, hdim 128, top-left mask, no bias, no dropout, for "gfx950":
//   1. family "fwd"             (qr_async pipeline)
//   2. family "bwd_dot_do_o"
//   3. family "bwd_dq_dk_dv"
//   4. family "bwd_convert_dq"
// The three bwd_* signatures are non-deterministic (deterministic(false)).
DECL_FMHA_KERNEL_SET(bwd_gqa_fmha_kernels,
                     // 1) Forward kernel.
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("top_left")
                              .bias("no")
                              .lse(true)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950")
                         // 2) Backward: bwd_dot_do_o family.
                         .add(FmhaSignature()
                                  .family("bwd_dot_do_o")
                                  .dtype("fp16")
                                  .mode("batch")
                                  .hdim(128)
                                  .mask("top_left")
                                  .bias("no")
                                  .dropout(false)
                                  .dbias(false)
                                  .store_randval(false)
                                  .deterministic(false),
                              FmhaAlgorithm()
                                  .tile_m0(64)
                                  .tile_n0(128)
                                  .tile_k0(32)
                                  .tile_n1(0)
                                  .tile_k1(0)
                                  .tile_k0max(0)
                                  .padding(true, true, true, true)
                                  .selection_rank(0),
                              "gfx950")
                         // 3) Backward: bwd_dq_dk_dv family (uses the 9-argument
                         //    wave(...) / warp(...) setters instead of per-field ones).
                         .add(FmhaSignature()
                                  .family("bwd_dq_dk_dv")
                                  .dtype("fp16")
                                  .mode("batch")
                                  .hdim(128)
                                  .mask("top_left")
                                  .bias("no")
                                  .dropout(false)
                                  .dbias(false)
                                  .store_randval(false)
                                  .deterministic(false),
                              FmhaAlgorithm()
                                  .tile_m0(16)
                                  .tile_n0(128)
                                  .tile_k0(128)
                                  .tile_n1(16)
                                  .tile_k1(128)
                                  .tile_k0max(32)
                                  .wave(1, 4, 1, 4, 1, 1, 1, 4, 1)
                                  .warp(16, 16, 32, 16, 16, 16, 16, 16, 16)
                                  .padding(true, true, true, true)
                                  .max_seq_len_q(0)
                                  .selection_rank(0),
                              "gfx950")
                         // 4) Backward: bwd_convert_dq family.
                         .add(FmhaSignature()
                                  .family("bwd_convert_dq")
                                  .dtype("fp16")
                                  .mode("batch")
                                  .hdim(128)
                                  .mask("top_left")
                                  .bias("no")
                                  .dropout(false)
                                  .dbias(false)
                                  .store_randval(false)
                                  .deterministic(false),
                              FmhaAlgorithm()
                                  .tile_m0(64)
                                  .tile_n0(128)
                                  .tile_k0(0)
                                  .tile_n1(0)
                                  .tile_k1(0)
                                  .tile_k0max(0)
                                  .padding(true, true, true, true)
                                  .selection_rank(0),
                              "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 34: FMHA Backward GQA", "nhead_q=8, nhead_k=2 (4:1 ratio)");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead_q", "8", "Query heads");
|
||||
args.add_option("--nhead_k", "2", "KV heads (GQA ratio = nhead_q/nhead_k)");
|
||||
args.add_option("--seqlen", "128", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead_q = args.get_int("--nhead_q", 8);
|
||||
const int nhead_k = args.get_int("--nhead_k", 2);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
|
||||
print_header("Example 34: FMHA Backward GQA");
|
||||
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("bwd_gqa_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
std::cout << "\nStep 2: Plan (GQA nhead_q=" << nhead_q << ", nhead_k=" << nhead_k << ")\n";
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
fmha_bwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.mask_type = mask_enum::mask_top_left;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_dbias = false;
|
||||
traits.has_dropout = false;
|
||||
traits.is_store_randval = false;
|
||||
traits.is_deterministic = false;
|
||||
|
||||
fmha_bwd_args bwd_args{};
|
||||
bwd_args.batch = batch;
|
||||
bwd_args.seqlen_q = seqlen;
|
||||
bwd_args.seqlen_k = seqlen;
|
||||
bwd_args.hdim_q = hdim;
|
||||
bwd_args.hdim_v = hdim;
|
||||
bwd_args.nhead_q = nhead_q;
|
||||
bwd_args.nhead_k = nhead_k;
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, bwd_args), gfx_arch));
|
||||
std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n";
|
||||
|
||||
std::cout << "\nStep 3: GQA Backward Head Indexing\n";
|
||||
std::cout << " Q heads " << nhead_q << ", KV heads " << nhead_k
|
||||
<< " -> each KV head shared by " << (nhead_q / nhead_k) << " Q heads.\n";
|
||||
std::cout << " dK/dV reduced across Q-head group via nhead_stride.\n";
|
||||
|
||||
print_separator();
|
||||
return 0;
|
||||
}
|
||||
121
dispatcher/examples/fmha/cpp/35_generic_mask_fmha.cpp
Normal file
121
dispatcher/examples/fmha/cpp/35_generic_mask_fmha.cpp
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// Example 35: FMHA Forward with Generic/Window Mask
|
||||
//
|
||||
// Demonstrates forward kernel with generic (window) mask. Uses
|
||||
// window_size_left and window_size_right: for each query i, attend only to
|
||||
// keys in [i - left, i + right]. -1 means unbounded. Planning only.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
|
||||
// Kernel set for Example 35: a single forward kernel configuration
// (fp16, batch mode, hdim 128, mask "generic", no bias, no LSE, no dropout)
// using the qr_async pipeline, declared for "gfx950".
DECL_FMHA_KERNEL_SET(generic_mask_fmha_kernels,
                     // Forward with generic/window mask
                     .add(FmhaSignature()
                              .family("fwd")
                              .dtype("fp16")
                              .mode("batch")
                              .vlayout("r")
                              .hdim(128)
                              .mask("generic") // window mask via left/right
                              .bias("no")
                              .lse(false)
                              .dropout(false)
                              .qscale("no"),
                          FmhaAlgorithm()
                              .tile_m0(128)
                              .tile_n0(128)
                              .tile_k0(32)
                              .tile_n1(128)
                              .tile_k1(32)
                              .tile_k0max(128)
                              .wave_m0(4)
                              .wave_n0(1)
                              .wave_k0(1)
                              .wave_m1(4)
                              .wave_n1(1)
                              .wave_k1(1)
                              .warp_m0(32)
                              .warp_n0(32)
                              .warp_k0(16)
                              .warp_m1(32)
                              .warp_n1(32)
                              .warp_k1(16)
                              .pipeline("qr_async")
                              .padding(true, true, true, true)
                              .alignments(128, 128)
                              .selection_rank(0),
                          "gfx950"));
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 35: FMHA Generic Mask", "Window mask via left/right params");
|
||||
args.add_option("--arch", "gfx950", "GPU architecture");
|
||||
args.add_option("--batch", "2", "Batch size");
|
||||
args.add_option("--nhead", "4", "Number of heads");
|
||||
args.add_option("--seqlen", "128", "Sequence length");
|
||||
args.add_option("--hdim", "128", "Head dimension");
|
||||
args.add_option("--window_left", "64", "Window size left (-1=unbounded)");
|
||||
args.add_option("--window_right", "0", "Window size right (-1=unbounded)");
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
const std::string gfx_arch = args.get("--arch", "gfx950");
|
||||
const int batch = args.get_int("--batch", 2);
|
||||
const int nhead = args.get_int("--nhead", 4);
|
||||
const int seqlen = args.get_int("--seqlen", 128);
|
||||
const int hdim = args.get_int("--hdim", 128);
|
||||
const int window_size_left = args.get_int("--window_left", 64);
|
||||
const int window_size_right = args.get_int("--window_right", 0);
|
||||
|
||||
print_header("Example 35: FMHA Generic Mask");
|
||||
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
FmhaRegistry registry;
|
||||
registry.set_name("generic_mask_fmha");
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
std::cout << "\nStep 2: Plan\n";
|
||||
FmhaDispatcher dispatcher(®istry);
|
||||
fmha_fwd_traits traits{};
|
||||
traits.hdim_q = hdim;
|
||||
traits.hdim_v = hdim;
|
||||
traits.data_type = "fp16";
|
||||
traits.is_group_mode = false;
|
||||
traits.is_v_rowmajor = true;
|
||||
traits.mask_type = mask_enum::window_generic;
|
||||
traits.bias_type = bias_enum::no_bias;
|
||||
traits.has_lse = false;
|
||||
traits.has_dropout = false;
|
||||
traits.qscale_type = quant_scale_enum::no_scale;
|
||||
|
||||
fmha_fwd_args fwd_args{};
|
||||
fwd_args.batch = batch;
|
||||
fwd_args.seqlen_q = seqlen;
|
||||
fwd_args.seqlen_k = seqlen;
|
||||
fwd_args.nhead_q = nhead;
|
||||
fwd_args.nhead_k = nhead;
|
||||
fwd_args.hdim_q = hdim;
|
||||
fwd_args.hdim_v = hdim;
|
||||
fwd_args.window_size_left = window_size_left;
|
||||
fwd_args.window_size_right = window_size_right;
|
||||
|
||||
auto plan = dispatcher.plan(
|
||||
FmhaProblem::from_invocation(FmhaInvocation::make(traits, fwd_args), gfx_arch));
|
||||
std::cout << " Plan valid: " << (plan.is_valid() ? "yes" : "no") << "\n";
|
||||
|
||||
std::cout << "\nStep 3: Window Mask Params\n";
|
||||
std::cout << " window_size_left=" << window_size_left
|
||||
<< ", window_size_right=" << window_size_right << "\n";
|
||||
std::cout << " Query i attends to keys in [i-left, i+right]. -1 = unbounded.\n";
|
||||
|
||||
print_separator();
|
||||
return 0;
|
||||
}
|
||||
259
dispatcher/examples/fmha/python/01_basic_fmha.py
Normal file
259
dispatcher/examples/fmha/python/01_basic_fmha.py
Normal file
@@ -0,0 +1,259 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 01: Basic FMHA with Multiple Kernels
|
||||
|
||||
Demonstrates:
|
||||
1. Building a Registry with multiple kernel configurations
|
||||
2. Parallel JIT compilation via registry.build()
|
||||
3. Running each kernel and validating output against CPU reference
|
||||
4. Comparing performance across kernels
|
||||
|
||||
Usage:
|
||||
python3 01_basic_fmha.py
|
||||
python3 01_basic_fmha.py --dtype bf16
|
||||
python3 01_basic_fmha.py --size 256
|
||||
python3 01_basic_fmha.py --num-kernels 4
|
||||
python3 01_basic_fmha.py --workers 4
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelSpec,
|
||||
FmhaRegistry,
|
||||
FmhaProblem,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
spec_to_config,
|
||||
)
|
||||
|
||||
|
||||
# FmhaKernelSpec fields:
|
||||
# name -- human-readable kernel identifier
|
||||
# hdim -- head dimension (hdim_q = hdim_v for symmetric attention)
|
||||
# pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
|
||||
# tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
|
||||
# tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
|
||||
# tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
|
||||
#
|
||||
# spec_to_config() fills in Stage 1 automatically:
|
||||
# tile_n1 = hdim, tile_k1 = tile_k0, tile_k0max = hdim
|
||||
# wave/warp use sensible defaults (4x1x1 wave, 32x32x16 warp)
|
||||
# Kernel configurations built and compared by this example. All use hdim=128;
# they differ in pipeline ("qr_async" vs "qr") and in the Stage 0 tile sizes.
KERNEL_SPECS = [
    # Async pipelines -- different tile_m0 x tile_n0 combos
    FmhaKernelSpec(
        name="async_128x128_k32",
        hdim=128,
        pipeline="qr_async",
        tile_m0=128,
        tile_n0=128,
        tile_k0=32,
    ),
    FmhaKernelSpec(
        name="async_128x64_k32",
        hdim=128,
        pipeline="qr_async",
        tile_m0=128,
        tile_n0=64,
        tile_k0=32,
    ),
    FmhaKernelSpec(
        name="async_64x128_k32",
        hdim=128,
        pipeline="qr_async",
        tile_m0=64,
        tile_n0=128,
        tile_k0=32,
    ),
    FmhaKernelSpec(
        name="async_64x64_k32",
        hdim=128,
        pipeline="qr_async",
        tile_m0=64,
        tile_n0=64,
        tile_k0=32,
    ),
    # Synchronous pipelines
    FmhaKernelSpec(
        name="sync_128x128_k32",
        hdim=128,
        pipeline="qr",
        tile_m0=128,
        tile_n0=128,
        tile_k0=32,
    ),
    FmhaKernelSpec(
        name="sync_64x128_k32",
        hdim=128,
        pipeline="qr",
        tile_m0=64,
        tile_n0=128,
        tile_k0=32,
    ),
    FmhaKernelSpec(
        name="sync_128x64_k32",
        hdim=128,
        pipeline="qr",
        tile_m0=128,
        tile_n0=64,
        tile_k0=32,
    ),
    # Different tile_k0 (K dimension of Q*K^T)
    FmhaKernelSpec(
        name="async_128x128_k64",
        hdim=128,
        pipeline="qr_async",
        tile_m0=128,
        tile_n0=128,
        tile_k0=64,
    ),
    FmhaKernelSpec(
        name="async_64x128_k64",
        hdim=128,
        pipeline="qr_async",
        tile_m0=64,
        tile_n0=128,
        tile_k0=64,
    ),
]
|
||||
|
||||
|
||||
def main():
    """Build several FMHA kernels, run each one and validate against the CPU reference.

    Steps: register the selected ``KERNEL_SPECS`` in a registry, JIT-build them
    via ``registry.build()``, run every successfully built kernel on one fixed
    problem (batch 2, 8 heads, hdim 128, sequence length from ``--size``),
    compare the output with ``cpu_attention_fwd`` and print a summary.

    Returns:
        int: 0 when every kernel passed, 1 when no kernel was built or at
        least one kernel was skipped, failed to run, or exceeded the error
        threshold.
    """
    parser = argparse.ArgumentParser(description="Basic FMHA with Multiple Kernels")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--size", type=int, default=128, help="Sequence length")
    parser.add_argument("--num-kernels", type=int, default=0, help="0 = all")
    parser.add_argument(
        "--workers", type=int, default=0, help="Max parallel JIT workers (0 = auto)"
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Example 01: Basic FMHA with Multiple Kernels")
    print("=" * 70)

    # A non-positive --num-kernels selects the full list.
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS

    # Step 1: Build registry
    print(
        f"\n {len(specs)} kernel configurations, dtype={args.dtype}, arch={args.arch}"
    )
    print(f"\n {'#':<3} {'Name':<24} {'Tile':<14} {'Pipeline':<12}")
    print(" " + "-" * 56)
    for i, s in enumerate(specs, 1):
        print(
            f" {i:<3} {s.name:<24} {s.tile_m0}x{s.tile_n0}x{s.tile_k0:<6} {s.pipeline:<12}"
        )

    reg = FmhaRegistry(name="basic_fmha")
    for s in specs:
        reg.register_kernel(spec_to_config(s, args.dtype, args.arch))

    # Step 2: Parallel JIT build via registry.build()
    # 0 (or negative) workers is passed on as None.
    workers = args.workers if args.workers > 0 else None
    print(
        f"\n--- Parallel JIT Build ({len(specs)} kernels{f', workers={workers}' if workers else ''}) ---"
    )

    t0 = time.perf_counter()
    setups = reg.build(verbose=False, max_workers=workers)
    jit_build_s = time.perf_counter() - t0

    built = sum(1 for s in setups if s.success)
    print(f" Built: {built}/{len(specs)} kernels in {jit_build_s:.1f} s")

    if built == 0:
        print(" ERROR: No kernels built")
        return 1

    # Step 3: Run each kernel and validate
    seqlen = args.size
    prob = FmhaProblem(
        batch=2,
        nhead_q=8,
        nhead_k=8,
        seqlen_q=seqlen,
        seqlen_k=seqlen,
        hdim_q=128,
        hdim_v=128,
    )

    print(
        f"\n--- Running Kernels (B={prob.batch} H={prob.nhead_q} S={seqlen} D={prob.hdim_q}) ---"
    )

    # Fixed seed so every kernel sees the same inputs.
    # NOTE(review): inputs are always float16, even with --dtype bf16 --
    # confirm the runner handles the conversion.
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
    K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
    V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)
    O_ref = cpu_attention_fwd(
        Q.astype(np.float32),
        K.astype(np.float32),
        V.astype(np.float32),
        prob.scale,
    )

    print(
        f"\n {'#':<3} {'Name':<24} {'Pipeline':<12} {'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':<6}"
    )
    print(" " + "-" * 80)

    # Each entry: (name, passed, time_ms, tflops, max_err).
    # NOTE(review): assumes reg.build() returns setups in registration order
    # so zip(specs, setups) pairs them correctly -- confirm.
    results = []
    for i, (spec, setup) in enumerate(zip(specs, setups), 1):
        if not setup.success or setup.runner is None:
            print(
                f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}"
            )
            results.append((spec.name, False, 0.0, 0.0, 0.0))
            continue

        runner = setup.runner
        try:
            res = runner.run(Q, K, V, prob)
            if not res.success:
                print(
                    f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {'---':>10} {'---':>10} {'---':>10} {'FAIL':<6}"
                )
                results.append((spec.name, False, 0.0, 0.0, 0.0))
                continue

            max_err = float(np.abs(res.output.astype(np.float32) - O_ref).max())
            ok = max_err < 1e-2
            tag = "PASS" if ok else "FAIL"
            print(
                f" {i:<3} {spec.name:<24} {spec.pipeline:<12} {res.time_ms:>10.4f} {res.tflops:>10.2f} {max_err:>10.2e} {tag:<6}"
            )
            results.append((spec.name, ok, res.time_ms, res.tflops, max_err))
        finally:
            # Release the runner on every path, including a failed run
            # (previously the failure path skipped cleanup).
            runner.cleanup()

    # Step 4: Summary
    passed = sum(1 for r in results if r[1])
    failed = len(results) - passed
    valid = [r for r in results if r[1]]

    print("\n" + "=" * 70)
    print(f" Results: {passed}/{len(results)} passed")
    print(
        f" Problem: B={prob.batch} H={prob.nhead_q} S={seqlen} D={prob.hdim_q}, dtype={args.dtype}"
    )
    print(f" JIT time: {jit_build_s:.1f} s (parallel)")
    if valid:
        # Best kernel = highest TFLOPS among the passing ones.
        best = max(valid, key=lambda x: x[3])
        print(f" Best: {best[0]} ({best[3]:.2f} TFLOPS)")
    print(f" Status: {'PASS' if failed == 0 else 'FAIL'}")
    print("=" * 70)

    return 0 if failed == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
148
dispatcher/examples/fmha/python/02_multi_shape.py
Normal file
148
dispatcher/examples/fmha/python/02_multi_shape.py
Normal file
@@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 02: Multi-Shape FMHA
|
||||
|
||||
Runs FMHA forward with a single kernel across multiple problem shapes
|
||||
(varying batch, sequence length, and head count).
|
||||
|
||||
Usage:
|
||||
python3 02_multi_shape.py
|
||||
python3 02_multi_shape.py --help
|
||||
python3 02_multi_shape.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelSpec,
|
||||
FmhaProblem,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
spec_to_config,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Run one FMHA forward kernel over a list of problem shapes.

    Builds a single "qr_async" hdim-128 kernel via ``setup_fmha_dispatcher``,
    runs it on each shape in a fixed table (varying batch, head count and
    sequence length) and prints per-shape time/TFLOPS plus an aggregate.

    Returns:
        int: 1 when the dispatcher setup fails, otherwise 0 (individual shape
        failures are reported in the table but do not change the exit code).
    """
    parser = argparse.ArgumentParser(
        description="Multi-Shape FMHA Example - runs multiple shapes",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 02_multi_shape.py # Default FP16
python3 02_multi_shape.py --dtype bf16 # BF16 FMHA
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch",
        default=detect_gpu_arch(),
        help="Target architecture (auto-detected from rocminfo)",
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Example 02: Multi-Shape FMHA")
    print("=" * 70)

    # Step 1: Setup dispatcher
    print("\nStep 1: Setup Dispatcher")

    # FmhaKernelSpec fields:
    # name -- human-readable kernel identifier
    # hdim -- head dimension (hdim_q = hdim_v)
    # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
    # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
    # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
    # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
    spec = FmhaKernelSpec(name="multi_shape", hdim=128, pipeline="qr_async")
    config = spec_to_config(spec, dtype=args.dtype, arch=args.arch)

    setup = setup_fmha_dispatcher(config, verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    runner = setup.runner
    print(f" Library: {setup.library_path}")
    print(f" Build: {setup.build_time_s:.1f} s")

    # Step 2: Run batch of different shapes
    print("\nStep 2: Run Shapes")

    shapes = [
        # (batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim)
        (1, 4, 4, 64, 64, 128),
        (2, 8, 8, 128, 128, 128),
        (4, 8, 8, 128, 128, 128),
        (1, 16, 16, 256, 256, 128),
        (2, 8, 8, 256, 256, 128),
        (1, 8, 8, 512, 512, 128),
        (2, 4, 4, 512, 512, 128),
        (1, 8, 8, 1024, 1024, 128),
    ]

    print(f"\n {'#':<3} {'Shape':<36} {'Time(ms)':>10} {'TFLOPS':>10} {'Status':>8}")
    print(" " + "-" * 70)

    # Accumulated over successful runs only; used for the aggregate TFLOPS.
    total_ops = 0
    total_time = 0.0

    for idx, (b, hq, hk, sq, sk, d) in enumerate(shapes, 1):
        prob = FmhaProblem(
            batch=b,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=sq,
            seqlen_k=sk,
            hdim_q=d,
            hdim_v=d,
        )
        shape_str = f"B{b}_Hq{hq}_Hk{hk}_S{sq}x{sk}_D{d}"

        # Per-shape seed: reproducible, but different data for each shape.
        # NOTE(review): inputs are always float16, even with --dtype bf16 --
        # confirm the runner handles the conversion.
        np.random.seed(42 + idx)
        Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)

        result = runner.run(Q, K, V, prob)

        if result.success:
            total_ops += prob.num_ops
            total_time += result.time_ms
            print(
                f" {idx:<3} {shape_str:<36} {result.time_ms:>10.4f} {result.tflops:>10.2f} {'OK':>8}"
            )
        else:
            print(f" {idx:<3} {shape_str:<36} {'N/A':>10} {'N/A':>10} {'Error':>8}")

    print(" " + "-" * 70)

    if total_time > 0:
        # total_time is in milliseconds; convert to seconds for ops/s.
        avg_tflops = (total_ops / 1e12) / (total_time / 1000)
        print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS")

    runner.cleanup()

    print("\n" + "=" * 70)
    print("Multi-Shape FMHA complete!")
    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
170
dispatcher/examples/fmha/python/03_benchmark.py
Normal file
170
dispatcher/examples/fmha/python/03_benchmark.py
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 03: FMHA Benchmark
|
||||
|
||||
Performance benchmarking with warmup and repeated iterations across
|
||||
multiple (batch, sequence length) configurations.
|
||||
|
||||
Usage:
|
||||
python3 03_benchmark.py
|
||||
python3 03_benchmark.py --help
|
||||
python3 03_benchmark.py --warmup 5 --repeat 20
|
||||
python3 03_benchmark.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelSpec,
|
||||
FmhaProblem,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
spec_to_config,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Benchmark one FMHA forward kernel over a grid of (batch, seqlen) sizes.

    Builds a single fp16 "qr_async" hdim-128 kernel, then for every
    configuration runs ``--warmup`` untimed iterations followed by
    ``--repeat`` timed iterations, and prints min/avg/max time and the TFLOPS
    derived from the average time. Ends with the mean and peak TFLOPS.

    Returns:
        int: 1 when the dispatcher setup fails, otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description="FMHA Benchmark Example - performance testing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 03_benchmark.py # Default benchmark suite
python3 03_benchmark.py --warmup 5 # More warmup iterations
python3 03_benchmark.py --repeat 20 # More benchmark iterations
""",
    )
    parser.add_argument(
        "--arch",
        default=detect_gpu_arch(),
        help="Target architecture (auto-detected from rocminfo)",
    )
    parser.add_argument(
        "--warmup", type=int, default=3, help="Warmup iterations (default: 3)"
    )
    parser.add_argument(
        "--repeat", type=int, default=10, help="Benchmark iterations (default: 10)"
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Example 03: FMHA Benchmark")
    print("=" * 70)

    # Step 1: Setup dispatcher with compute-optimized config
    print("\nStep 1: Setup Dispatcher")

    # FmhaKernelSpec fields:
    # name -- human-readable kernel identifier
    # hdim -- head dimension (hdim_q = hdim_v)
    # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
    # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
    # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
    # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
    spec = FmhaKernelSpec(name="benchmark", hdim=128, pipeline="qr_async")
    config = spec_to_config(spec, dtype="fp16", arch=args.arch)

    setup = setup_fmha_dispatcher(config, verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    runner = setup.runner
    print(f" Library: {setup.library_path}")
    print(f" Build: {setup.build_time_s:.1f} s")

    # Step 2: Benchmark
    print("\nStep 2: Benchmark")

    # (batch, seqlen) pairs; heads and hdim are fixed below.
    bench_configs = [
        (1, 128),
        (1, 256),
        (1, 512),
        (1, 1024),
        (1, 2048),
        (2, 128),
        (2, 256),
        (2, 512),
        (2, 1024),
        (4, 128),
        (4, 256),
        (4, 512),
        (8, 128),
        (8, 256),
    ]

    print(f" Warmup: {args.warmup}, Repeat: {args.repeat}\n")

    print(
        f" {'Batch':>5} {'SeqLen':>7} | {'Min(ms)':>10} {'Avg(ms)':>10} {'Max(ms)':>10} | {'TFLOPS':>10}"
    )
    print(" " + "-" * 62)

    all_tflops = []

    for batch, seqlen in bench_configs:
        prob = FmhaProblem(
            batch=batch,
            nhead_q=8,
            nhead_k=8,
            seqlen_q=seqlen,
            seqlen_k=seqlen,
            hdim_q=128,
            hdim_v=128,
        )

        np.random.seed(42)
        Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)

        # Warmup runs: results are intentionally discarded.
        for _ in range(args.warmup):
            runner.run(Q, K, V, prob)

        # Timed runs: only successful iterations contribute a sample.
        times = []
        for _ in range(args.repeat):
            result = runner.run(Q, K, V, prob)
            if result.success:
                times.append(result.time_ms)

        if times:
            min_time = min(times)
            avg_time = sum(times) / len(times)
            max_time = max(times)
            # avg_time is in milliseconds; convert to seconds for ops/s.
            tflops = prob.num_ops / (avg_time * 1e-3) / 1e12
            all_tflops.append(tflops)
            print(
                f" {batch:>5} {seqlen:>7} | {min_time:>10.4f} {avg_time:>10.4f} {max_time:>10.4f} | {tflops:>10.2f}"
            )
        else:
            # No successful timed iteration (also the case when --repeat is 0).
            print(
                f" {batch:>5} {seqlen:>7} | {'---':>10} {'---':>10} {'---':>10} | {'FAIL':>10}"
            )

    runner.cleanup()

    # Summary
    print("\n" + "=" * 70)
    print("Summary")
    print("=" * 70)

    if all_tflops:
        print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS")
        print(f" Peak: {max(all_tflops):.2f} TFLOPS")

    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
176
dispatcher/examples/fmha/python/04_validation.py
Normal file
176
dispatcher/examples/fmha/python/04_validation.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 04: FMHA Validation
|
||||
|
||||
Validates GPU FMHA against CPU reference across multiple test cases
|
||||
including standard shapes, GQA ratios, and edge cases.
|
||||
|
||||
Usage:
|
||||
python3 04_validation.py
|
||||
python3 04_validation.py --help
|
||||
python3 04_validation.py --dtype bf16
|
||||
python3 04_validation.py --rtol 1e-2
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelSpec,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
spec_to_config,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Validate the GPU FMHA kernel against the CPU reference on a fixed test matrix.

    Parses --dtype/--rtol/--atol/--arch, builds a single hdim=128 "qr_async"
    kernel through setup_fmha_dispatcher, runs every entry of the test matrix
    through the resulting runner and compares each output against
    cpu_attention_fwd using FmhaValidator.

    Returns:
        0 when every test case passes; 1 when the dispatcher setup fails or
        at least one test case fails.
    """
    parser = argparse.ArgumentParser(
        description="FMHA Validation Example - validates GPU results against CPU",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 04_validation.py # Default FP16 validation
python3 04_validation.py --dtype bf16 # BF16 validation
python3 04_validation.py --rtol 1e-2 # Relaxed tolerance
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--rtol", type=float, default=1e-2, help="Relative tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--arch",
        default=detect_gpu_arch(),
        help="Target architecture (auto-detected from rocminfo)",
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Example 04: FMHA Validation")
    print("=" * 70)

    # Step 1: Setup dispatcher
    print("\nStep 1: Setup Dispatcher")

    # FmhaKernelSpec fields:
    # name -- human-readable kernel identifier
    # hdim -- head dimension (hdim_q = hdim_v)
    # pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
    # tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
    # tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
    # tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
    spec = FmhaKernelSpec(name="validation", hdim=128, pipeline="qr_async")
    config = spec_to_config(spec, dtype=args.dtype, arch=args.arch)

    setup = setup_fmha_dispatcher(config, verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    runner = setup.runner
    passed = 0
    failed = 0

    # Everything that uses the runner lives inside try/finally so the runner
    # is released even when a test case raises unexpectedly.
    try:
        print(f" Library: {setup.library_path}")
        print(f" Build: {setup.build_time_s:.1f} s")

        # Step 2: Run validation tests
        print("\nStep 2: Validation Tests")

        validator = FmhaValidator(rtol=args.rtol, atol=args.atol)

        # (name, batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim)
        test_cases = [
            ("Small", 1, 4, 4, 64, 64, 128),
            ("Medium", 2, 8, 8, 128, 128, 128),
            ("Large", 1, 8, 8, 256, 256, 128),
            ("Long-seq", 1, 4, 4, 512, 512, 128),
            ("Non-square", 2, 4, 4, 64, 256, 128),
            ("GQA-2:1", 2, 8, 4, 128, 128, 128),
            ("GQA-4:1", 1, 16, 4, 128, 128, 128),
            ("GQA-8:1", 1, 16, 2, 64, 64, 128),
            ("Single-query", 1, 4, 4, 1, 128, 128),
            ("Batched", 4, 8, 8, 128, 128, 128),
        ]

        print(f"\n {'#':<3} {'Test':<14} {'Shape':<30} {'MaxErr':>10} {'Status':>8}")
        print(" " + "-" * 70)

        for idx, (name, b, hq, hk, sq, sk, d) in enumerate(test_cases, 1):
            prob = FmhaProblem(
                batch=b,
                nhead_q=hq,
                nhead_k=hk,
                seqlen_q=sq,
                seqlen_k=sk,
                hdim_q=d,
                hdim_v=d,
            )
            shape_str = f"B{b}_Hq{hq}_Hk{hk}_S{sq}x{sk}"

            # Per-case seed keeps every row reproducible on its own.
            np.random.seed(42 + idx)
            # NOTE(review): host tensors are always float16, even when
            # --dtype bf16 is selected -- confirm the runner converts them.
            Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
            K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
            V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)

            result = runner.run(Q, K, V, prob)
            if not result.success:
                print(
                    f" {idx:<3} {name:<14} {shape_str:<30} {'GPU Err':>10} {'FAILED':>8}"
                )
                failed += 1
                continue

            # The reference is computed in float32 from the same inputs.
            O_ref = cpu_attention_fwd(
                Q.astype(np.float32),
                K.astype(np.float32),
                V.astype(np.float32),
                prob.scale,
            )
            is_valid, max_abs, _ = validator.check(result.output, O_ref)

            if is_valid:
                print(
                    f" {idx:<3} {name:<14} {shape_str:<30} {max_abs:>10.2e} {'PASSED':>8}"
                )
                passed += 1
            else:
                print(
                    f" {idx:<3} {name:<14} {shape_str:<30} {max_abs:>10.2e} {'FAILED':>8}"
                )
                failed += 1
    finally:
        runner.cleanup()

    # Summary
    print("\n" + "=" * 70)
    total = passed + failed
    print(f" Results: {passed}/{total} passed")
    print(f" Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}")
    print(f" Status: {'PASS' if failed == 0 else 'FAIL'}")
    print("=" * 70)

    return 0 if failed == 0 else 1
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
219
dispatcher/examples/fmha/python/05_numpy_integration.py
Normal file
219
dispatcher/examples/fmha/python/05_numpy_integration.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 05: NumPy Integration
|
||||
|
||||
Shows how to create a GPU-accelerated attention wrapper that works
|
||||
seamlessly with NumPy arrays, hiding all HIP memory management.
|
||||
|
||||
Usage:
|
||||
python3 05_numpy_integration.py
|
||||
python3 05_numpy_integration.py --help
|
||||
python3 05_numpy_integration.py --seqlen 256
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def fmha_matmul(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float = None,
    runner=None,
) -> np.ndarray:
    """GPU-accelerated scaled dot-product attention via FMHA dispatcher.

    The problem descriptor is derived from the shapes of Q and V; all three
    inputs are cast to float16 before being handed to the runner.

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q] float16/float32
        K: [batch, nhead_k, seqlen_k, hdim_q] float16/float32
        V: [batch, nhead_k, seqlen_k, hdim_v] float16/float32
        scale: accepted for interface compatibility only. This wrapper does
            not forward it to the problem descriptor, so a non-None value is
            ignored (a RuntimeWarning is emitted).
        runner: runner obtained from setup_fmha_dispatcher (required).

    Returns:
        O: the runner's output array
           (presumably [batch, nhead_q, seqlen_q, hdim_v] float16 -- TODO confirm).

    Raises:
        ValueError: if no runner is supplied.
        RuntimeError: if the runner reports a failed launch.
    """
    # Fail early with a clear message instead of an AttributeError on None.
    if runner is None:
        raise ValueError(
            "fmha_matmul requires a runner; pass runner=setup.runner "
            "from setup_fmha_dispatcher(config)"
        )

    if scale is not None:
        import warnings

        warnings.warn(
            "fmha_matmul: 'scale' is not forwarded to the kernel and is ignored",
            RuntimeWarning,
            stacklevel=2,
        )

    batch, nhead_q, seqlen_q, hdim_q = Q.shape
    _, nhead_k, seqlen_k, hdim_v = V.shape

    prob = FmhaProblem(
        batch=batch,
        nhead_q=nhead_q,
        nhead_k=nhead_k,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        hdim_q=hdim_q,
        hdim_v=hdim_v,
    )

    result = runner.run(
        Q.astype(np.float16), K.astype(np.float16), V.astype(np.float16), prob
    )
    if not result.success:
        raise RuntimeError(f"GPU FMHA failed: {result.error}")
    return result.output
|
||||
|
||||
|
||||
def main():
    """Demonstrate the NumPy-facing fmha_matmul wrapper and check it against the CPU.

    JIT-builds one fp16 kernel for the requested head dimension, runs a plain
    multi-head attention call and a GQA call (nhead_q=8, nhead_k=2) through
    fmha_matmul, and compares both outputs with cpu_attention_fwd using
    np.allclose and the --rtol/--atol tolerances.

    Returns:
        0 when both comparisons match; 1 when the JIT build fails or either
        comparison does not match.
    """
    parser = argparse.ArgumentParser(
        description="NumPy Integration Example - GPU-accelerated attention wrapper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 05_numpy_integration.py # Default
python3 05_numpy_integration.py --seqlen 256 # Longer sequences
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=4)
    parser.add_argument("--seqlen", type=int, default=64)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--rtol", type=float, default=1e-2)
    parser.add_argument("--atol", type=float, default=1e-2)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 05: NumPy Integration")
    print("=" * 70)

    # Step 1: JIT-compile FMHA kernel
    print("\nStep 1: JIT-Compile FMHA Dispatcher")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
        return 1
    runner = setup.runner

    np_dtype = np.float16

    def random_heads(nhead):
        # One [batch, nhead, seqlen, hdim] tensor drawn from the global NumPy RNG.
        return (
            np.random.randn(args.batch, nhead, args.seqlen, args.hdim) * 0.5
        ).astype(np_dtype)

    # Release the runner even if one of the GPU calls raises; the other FMHA
    # examples in this series call runner.cleanup() as well.
    try:
        print(f" JIT build: {setup.build_time_s:.1f}s")
        print(f" Arch: {args.arch}")

        # Step 2: Demo -- simple attention call
        print("\n" + "=" * 70)
        print("Step 2: Simple Attention Call")
        print("=" * 70)

        np.random.seed(42)
        # Draw order (Q, K, V) matters for reproducibility with a fixed seed.
        Q = random_heads(args.nhead)
        K = random_heads(args.nhead)
        V = random_heads(args.nhead)

        out = fmha_matmul(Q, K, V, runner=runner)
        print(f" Q: {Q.shape} -> O: {out.shape}")
        print(f" Output range: [{out.min():.4f}, {out.max():.4f}]")
        print(f" Output sum: {out.sum():.4f}")

        # Step 3: Validate against CPU reference
        print("\n" + "=" * 70)
        print("Step 3: Validate Against CPU Reference")
        print("=" * 70)

        prob = FmhaProblem(
            batch=args.batch,
            nhead_q=args.nhead,
            nhead_k=args.nhead,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        O_ref = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            prob.scale,
        )

        diff = np.abs(out.astype(np.float32) - O_ref)
        max_abs = float(diff.max())
        # The small epsilon avoids division by zero for reference values of 0.
        max_rel = float((diff / (np.abs(O_ref) + 1e-6)).max())
        match = np.allclose(
            out.astype(np.float32), O_ref, atol=args.atol, rtol=args.rtol
        )

        print(f" Max abs error: {max_abs:.6e}")
        print(f" Max rel error: {max_rel:.6e}")
        print(f" Match: {match}")

        # Step 4: Demo -- multi-head attention with GQA
        print("\n" + "=" * 70)
        print("Step 4: GQA Attention (nhead_q=8, nhead_k=2)")
        print("=" * 70)

        nhead_q, nhead_k = 8, 2
        Q_gqa = random_heads(nhead_q)
        K_gqa = random_heads(nhead_k)
        V_gqa = random_heads(nhead_k)

        O_gqa = fmha_matmul(Q_gqa, K_gqa, V_gqa, runner=runner)

        prob_gqa = FmhaProblem(
            batch=args.batch,
            nhead_q=nhead_q,
            nhead_k=nhead_k,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        O_gqa_ref = cpu_attention_fwd(
            Q_gqa.astype(np.float32),
            K_gqa.astype(np.float32),
            V_gqa.astype(np.float32),
            prob_gqa.scale,
        )
        gqa_match = np.allclose(
            O_gqa.astype(np.float32), O_gqa_ref, atol=args.atol, rtol=args.rtol
        )

        print(f" Q: {Q_gqa.shape}, K: {K_gqa.shape}, V: {V_gqa.shape}")
        print(f" O: {O_gqa.shape}")
        print(f" Match: {gqa_match}")
    finally:
        runner.cleanup()

    # Summary
    print("\n" + "=" * 70)
    print("NumPy Integration Pattern:")
    print("=" * 70)
    print(" 1. setup = setup_fmha_dispatcher(config)")
    print(" 2. O = fmha_matmul(Q, K, V, runner=setup.runner)")
    print("=" * 70)

    return 0 if match and gqa_match else 1
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
220
dispatcher/examples/fmha/python/06_json_export.py
Normal file
220
dispatcher/examples/fmha/python/06_json_export.py
Normal file
@@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 06: JSON Export
|
||||
|
||||
Builds an FMHA kernel via setup_fmha_dispatcher, then exports the
|
||||
registry configuration to JSON for inspection or reuse.
|
||||
|
||||
Usage:
|
||||
python3 06_json_export.py
|
||||
python3 06_json_export.py --help
|
||||
python3 06_json_export.py --output fmha_kernels.json
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
setup_fmha_dispatcher,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Define three FMHA kernel configs, JIT-build the first, and dump all to JSON.

    The build result is informational only: the JSON export runs whether or
    not the build succeeded. The file named by --output is overwritten.

    Returns:
        Always 0.
    """
    cli = argparse.ArgumentParser(
        description="JSON Export Example - export FMHA registry to JSON",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 06_json_export.py # Default output
python3 06_json_export.py --output fmha_kernels.json # Custom file
""",
    )
    cli.add_argument(
        "--output",
        "-o",
        default="fmha_kernels.json",
        help="Output JSON file (default: fmha_kernels.json)",
    )
    cli.add_argument("--arch", default=detect_gpu_arch())
    args = cli.parse_args()

    rule = "=" * 70

    print(rule)
    print("Example 06: JSON Export")
    print(rule)

    # Step 1: Define FMHA kernel configurations
    print("\nStep 1: Define Kernel Configurations")

    # Wave config per stage -- identical for every kernel below.
    wave_layout = dict(
        wave_m0=4, wave_n0=1, wave_k0=1, wave_m1=4, wave_n1=1, wave_k1=1
    )
    # Warp tile per stage used by the two qr_async kernels.
    warp_32x32 = dict(
        warp_m0=32, warp_n0=32, warp_k0=16, warp_m1=32, warp_n1=32, warp_k1=16
    )

    configs = [
        FmhaKernelConfig(
            family="fwd",
            data_type="fp16",
            hdim_q=128,
            hdim_v=128,
            pipeline="qr_async",
            # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q
            tile_m0=128,
            tile_n0=128,
            tile_k0=32,
            # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment
            tile_n1=128,
            tile_k1=32,
            tile_k0max=128,
            **wave_layout,
            **warp_32x32,
            gfx_arch=args.arch,
        ),
        FmhaKernelConfig(
            family="fwd",
            data_type="fp16",
            hdim_q=128,
            hdim_v=128,
            pipeline="qr",
            tile_m0=64,
            tile_n0=128,
            tile_k0=32,
            tile_n1=128,
            tile_k1=32,
            tile_k0max=128,
            **wave_layout,
            warp_m0=16,
            warp_n0=16,
            warp_k0=32,
            warp_m1=16,
            warp_n1=16,
            warp_k1=16,
            pad_s=False,
            pad_sk=False,
            pad_d=True,
            pad_dv=True,
            gfx_arch=args.arch,
        ),
        FmhaKernelConfig(
            family="fwd",
            data_type="fp16",
            hdim_q=64,
            hdim_v=64,
            pipeline="qr_async",
            tile_m0=128,
            tile_n0=64,
            tile_k0=32,
            tile_n1=64,
            tile_k1=32,
            tile_k0max=64,
            **wave_layout,
            **warp_32x32,
            gfx_arch=args.arch,
        ),
    ]

    for position, kernel_cfg in enumerate(configs, 1):
        print(
            f" [{position}] {kernel_cfg.name}: pipeline={kernel_cfg.pipeline}, hdim={kernel_cfg.hdim_q}"
        )

    # Step 2: Build via setup_fmha_dispatcher
    print("\n" + rule)
    print("Step 2: Build Kernel (JIT)")
    print(rule)

    # Only the first config is actually compiled; the rest are export-only.
    setup = setup_fmha_dispatcher(configs[0], verbose=True)
    if not setup.success:
        print(f" Build skipped/failed: {setup.error}")
        print(" (Proceeding with config export only)")
    else:
        print(f" Built: {setup.library_path}")
        print(f" Time: {setup.build_time_s:.1f} s")

    # Step 3: Export to JSON
    print("\n" + rule)
    print("Step 3: Export to JSON")
    print(rule)

    export_data = {
        "registry": "fmha_export",
        "arch": args.arch,
        "kernel_count": len(configs),
        "kernels": [
            {
                "name": kernel_cfg.name,
                "family": kernel_cfg.family,
                "data_type": kernel_cfg.data_type,
                "hdim_q": kernel_cfg.hdim_q,
                "hdim_v": kernel_cfg.hdim_v,
                "pipeline": kernel_cfg.pipeline,
                "tile": list(kernel_cfg.tile),
                "wave": list(kernel_cfg.wave),
                "warp": list(kernel_cfg.warp),
                "padding": list(kernel_cfg.padding),
                "mode": kernel_cfg.mode,
                "target": kernel_cfg.gfx_arch,
                # Embed the codegen description as parsed JSON, not a string.
                "codegen_json": json.loads(kernel_cfg.to_codegen_json()),
            }
            for kernel_cfg in configs
        ],
    }

    json_str = json.dumps(export_data, indent=2)

    destination = Path(args.output)
    destination.write_text(json_str)
    print(f" Saved to: {args.output}")

    file_size = destination.stat().st_size
    print(f" File size: {file_size:,} bytes")
    print(f" Kernel count: {len(configs)}")

    # Step 4: Preview
    print("\n" + rule)
    print("Step 4: JSON Preview")
    print(rule)
    # Show at most the first 500 characters, with an ellipsis when truncated.
    preview = json_str[:500] + ("\n ..." if len(json_str) > 500 else "")
    print(preview)

    print("\n" + rule)
    print("JSON Export complete!")
    print(rule)

    return 0
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
256
dispatcher/examples/fmha/python/07_stress_test.py
Normal file
256
dispatcher/examples/fmha/python/07_stress_test.py
Normal file
@@ -0,0 +1,256 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 07: Stress Test - Multiple FMHA Kernels with Validation
|
||||
|
||||
Generates many FmhaKernelSpec configurations across pipelines, head
|
||||
dimensions, and tile sizes (all fp16), registers them in an FmhaRegistry, builds
|
||||
all in parallel, and validates each against a CPU reference.
|
||||
|
||||
Usage:
|
||||
python3 07_stress_test.py
|
||||
python3 07_stress_test.py --help
|
||||
python3 07_stress_test.py --num-kernels 4
|
||||
python3 07_stress_test.py --workers 8
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelSpec,
|
||||
FmhaProblem,
|
||||
FmhaRegistry,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
spec_to_config,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
# FmhaKernelSpec fields:
|
||||
# name -- human-readable kernel identifier
|
||||
# hdim -- head dimension (hdim_q = hdim_v)
|
||||
# pipeline -- "qr_async" (async prefetch) or "qr" (synchronous)
|
||||
# tile_m0 -- Stage 0 tile along seqlen_q (Q*K^T M dimension)
|
||||
# tile_n0 -- Stage 0 tile along seqlen_k (Q*K^T N dimension)
|
||||
# tile_k0 -- Stage 0 tile along hdim_q (Q*K^T K dimension)
|
||||
# Kernel pool for the stress test, generated from a compact table.
# Columns: (name, hdim, pipeline, tile_m0, tile_n0); tile_k0 is 32 for every entry.
KERNEL_SPECS: List[FmhaKernelSpec] = [
    FmhaKernelSpec(
        name=spec_name,
        hdim=head_dim,
        pipeline=pipe,
        tile_m0=m0,
        tile_n0=n0,
        tile_k0=32,
    )
    for spec_name, head_dim, pipe, m0, n0 in (
        # qr_async pipeline -- various tile sizes
        ("qr_async_h128_t128", 128, "qr_async", 128, 128),
        ("qr_async_h128_t64", 128, "qr_async", 64, 128),
        ("qr_async_h64_t128", 64, "qr_async", 128, 64),
        ("qr_async_h64_t64", 64, "qr_async", 64, 64),
        # qr pipeline -- various tile sizes
        ("qr_h128_t128", 128, "qr", 128, 128),
        ("qr_h128_t64", 128, "qr", 64, 128),
        ("qr_h64_t128", 64, "qr", 128, 64),
        ("qr_h64_t64", 64, "qr", 64, 64),
    )
]
|
||||
|
||||
|
||||
def print_spec_table(specs: List[FmhaKernelSpec]):
    """Print a numbered table of kernel specs (name, pipeline, hdim, stage-0 tiles)."""
    divider = " " + "-" * 70
    heading = (
        f"\n {'#':<3} {'Name':<25} {'Pipeline':<12} {'Hdim':>5} "
        f"{'TileM':>6} {'TileN':>6} {'TileK':>6}"
    )
    print(heading)
    print(divider)
    # Rows are numbered from 1 to match the build/validation listings.
    row_no = 0
    for entry in specs:
        row_no += 1
        row = (
            f" {row_no:<3} {entry.name:<25} {entry.pipeline:<12} {entry.hdim:>5} "
            f"{entry.tile_m0:>6} {entry.tile_n0:>6} {entry.tile_k0:>6}"
        )
        print(row)
    print(divider)
|
||||
|
||||
|
||||
def main():
    """Build a pool of FMHA kernels in parallel and validate them against the CPU.

    Registers the selected KERNEL_SPECS (fp16) in an FmhaRegistry, builds them,
    then runs every successfully built kernel whose head dimension matches the
    fixed validation problem (hdim 128) and compares the output with
    cpu_attention_fwd. Kernels with a different head dimension are reported as
    SKIP and do not count as failures.

    Returns:
        0 when no validated kernel failed; 1 when nothing could be built or at
        least one kernel failed to run or validate.
    """
    parser = argparse.ArgumentParser(
        description="FMHA Stress Test - multiple kernels with validation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 07_stress_test.py # Test all kernels
python3 07_stress_test.py --num-kernels 4 # First 4 only
python3 07_stress_test.py --workers 8 # 8 parallel compile workers
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument(
        "--num-kernels", type=int, default=0, help="Number of kernels to test (0 = all)"
    )
    parser.add_argument(
        "--workers", type=int, default=0, help="Max parallel build workers (0 = auto)"
    )
    parser.add_argument("--rtol", type=float, default=1e-2)
    parser.add_argument("--atol", type=float, default=1e-2)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 07: FMHA Stress Test - Multiple Kernels")
    print("=" * 70)

    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS

    print(f"\n Arch: {args.arch}")
    print(f" Kernels: {len(specs)}")
    print_spec_table(specs)

    # Step 1: Register all in FmhaRegistry and build
    print("\n" + "=" * 70)
    print(" JIT BUILD")
    print("=" * 70)

    reg = FmhaRegistry("stress_test")
    for spec in specs:
        cfg = spec_to_config(spec, dtype="fp16", arch=args.arch)
        reg.register_kernel(cfg)

    # 0 on the command line means "let the registry pick" (passed as None).
    workers = args.workers if args.workers > 0 else None
    print(f"\n Building {len(reg)} kernels (workers={workers or 'auto'}) ...")

    t0 = time.perf_counter()
    build_results = reg.build(verbose=False, max_workers=workers)
    build_time = time.perf_counter() - t0

    built = sum(1 for r in build_results if r.success)
    print(f" Built: {built}/{len(specs)} in {build_time:.1f} s")

    for i, r in enumerate(build_results, 1):
        # str() guards against a failed result whose error is not a string.
        tag = "OK" if r.success else f"FAIL: {str(r.error)[:50]}"
        name = r.config.name if r.config else f"kernel_{i}"
        print(f" [{i}] {name}: {tag}")

    if built == 0:
        print("\n No kernels built -- aborting")
        return 1

    # Step 2: Validate each built kernel
    print("\n" + "=" * 70)
    print(" VALIDATION")
    print("=" * 70)

    prob = FmhaProblem(
        batch=2, nhead_q=4, nhead_k=4, seqlen_q=64, seqlen_k=64, hdim_q=128, hdim_v=128
    )

    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
    K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
    V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)
    # One shared CPU reference: every validated kernel sees the same inputs.
    O_ref = cpu_attention_fwd(
        Q.astype(np.float32), K.astype(np.float32), V.astype(np.float32), prob.scale
    )

    validator = FmhaValidator(rtol=args.rtol, atol=args.atol)

    print(
        f"\n Problem: B={prob.batch} Hq={prob.nhead_q} Sq={prob.seqlen_q} D={prob.hdim_q}"
    )
    print(f"\n {'#':<3} {'Name':<35} {'Time':>8} {'MaxErr':>10} {'Status':<6}")
    print(" " + "-" * 66)

    total_pass = 0
    total_fail = 0

    for i, r in enumerate(build_results, 1):
        name = r.config.name if r.config else f"kernel_{i}"

        if not r.success or r.runner is None:
            print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'SKIP':<6}")
            continue

        # From here on the runner exists, so it must be released on every
        # path: skipped, failed, validated, or an unexpected exception.
        try:
            hdim = r.config.hdim_q if r.config else 128
            if hdim != prob.hdim_q:
                # The validation problem is fixed at hdim 128; other kernels
                # are built but not exercised here.
                print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'SKIP':<6}")
                continue

            res = r.runner.run(Q, K, V, prob)
            if not res.success:
                print(f" {i:<3} {name:<35} {'---':>8} {'---':>10} {'FAIL':<6}")
                total_fail += 1
                continue

            ok, max_abs, _ = validator.check(res.output, O_ref)
            tag = "PASS" if ok else "FAIL"
            print(f" {i:<3} {name:<35} {res.time_ms:>7.4f}ms {max_abs:>10.2e} {tag:<6}")

            if ok:
                total_pass += 1
            else:
                total_fail += 1
        finally:
            r.runner.cleanup()

    # Summary
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    print(f"\n Total: {len(specs)}")
    print(f" Built: {built}")
    print(f" Passed: {total_pass}")
    print(f" Failed: {total_fail}")
    print(f" Build time: {build_time:.1f} s")
    print(f" Tolerance: rtol={args.rtol}, atol={args.atol}")

    if total_fail == 0 and total_pass > 0:
        print("\n *** ALL VALIDATED KERNELS PASSED ***")
    elif total_fail > 0:
        print(f"\n *** {total_fail} KERNELS FAILED ***")

    print("=" * 70)

    return 0 if total_fail == 0 else 1
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
348
dispatcher/examples/fmha/python/08_heuristics.py
Normal file
348
dispatcher/examples/fmha/python/08_heuristics.py
Normal file
@@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 08: Kernel Selection Heuristics
|
||||
|
||||
Demonstrates how to build multiple FMHA kernels with different tile
|
||||
sizes and select the best kernel for a given problem. Shows that
|
||||
smaller tiles tend to be better for short sequences while larger tiles
|
||||
are better for long sequences.
|
||||
|
||||
Usage:
|
||||
python3 08_heuristics.py
|
||||
python3 08_heuristics.py --help
|
||||
python3 08_heuristics.py --arch gfx950
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaRegistry,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class TileProfile:
    """Pairs one FMHA kernel configuration with a display label and a size bucket."""

    # Human-readable name shown in the listings.
    label: str
    # Kernel configuration that gets registered and built.
    config: FmhaKernelConfig
    # Size bucket matched by the selection heuristic: "small", "medium", "large".
    category: str
|
||||
|
||||
|
||||
def build_tile_profiles(arch: str) -> List[TileProfile]:
    """Return the four fp16/hdim-128 forward kernel profiles used by this example.

    The profiles differ only in pipeline, stage-0 tile (tile_m0 x tile_n0),
    warp tile and, for the "qr" variant, explicit padding flags; every other
    setting is shared.

    Args:
        arch: value passed as gfx_arch to every FmhaKernelConfig.
    """
    # Settings common to all profiles.
    shared = dict(
        family="fwd",
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q -- K tile is fixed.
        tile_k0=32,
        # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment
        tile_n1=128,
        tile_k1=32,
        tile_k0max=128,
        # Wave config per stage
        wave_m0=4,
        wave_n0=1,
        wave_k0=1,
        wave_m1=4,
        wave_n1=1,
        wave_k1=1,
        gfx_arch=arch,
    )
    # Warp tile per stage: one layout for the small profile, one for the rest.
    warp_16 = dict(
        warp_m0=16, warp_n0=16, warp_k0=16, warp_m1=16, warp_n1=16, warp_k1=16
    )
    warp_32 = dict(
        warp_m0=32, warp_n0=32, warp_k0=16, warp_m1=32, warp_n1=32, warp_k1=16
    )
    qr_padding = dict(pad_s=False, pad_sk=False, pad_d=True, pad_dv=True)

    # (label, category, per-profile keyword overrides)
    variants = (
        (
            "small_64x64",
            "small",
            dict(pipeline="qr_async", tile_m0=64, tile_n0=64, **warp_16),
        ),
        (
            "medium_128x128",
            "medium",
            dict(pipeline="qr_async", tile_m0=128, tile_n0=128, **warp_32),
        ),
        (
            "large_128x256",
            "large",
            dict(pipeline="qr_async", tile_m0=128, tile_n0=256, **warp_32),
        ),
        (
            "medium_qr_128x128",
            "medium",
            dict(pipeline="qr", tile_m0=128, tile_n0=128, **warp_32, **qr_padding),
        ),
    )

    return [
        TileProfile(
            label=label,
            category=category,
            config=FmhaKernelConfig(**shared, **overrides),
        )
        for label, category, overrides in variants
    ]
|
||||
|
||||
|
||||
def select_kernel_heuristic(seqlen: int, profiles: List[TileProfile]) -> TileProfile:
    """Pick the first profile whose category fits the sequence length.

    seqlen <= 64 maps to "small", seqlen <= 256 to "medium", anything longer
    to "large". When no profile carries the wanted category, the first
    profile of the whole list is returned instead.
    """
    wanted = "large"
    # Upper bounds are inclusive and checked in ascending order.
    for upper_bound, bucket in ((64, "small"), (256, "medium")):
        if seqlen <= upper_bound:
            wanted = bucket
            break

    # An empty match list is falsy, so `or` falls back to all profiles.
    pool = [entry for entry in profiles if entry.category == wanted] or profiles
    return pool[0]
|
||||
|
||||
|
||||
def _make_inputs(seqlen: int):
    """Build the benchmark problem for ``seqlen`` plus seeded fp16 Q/K/V.

    Returns a ``(prob, Q, K, V)`` tuple. The NumPy global RNG is re-seeded
    with 42 on every call so each sequence length gets reproducible data.
    """
    prob = FmhaProblem(
        batch=2,
        nhead_q=8,
        nhead_k=8,
        seqlen_q=seqlen,
        seqlen_k=seqlen,
        hdim_q=128,
        hdim_v=128,
    )
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
    K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
    V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)
    return prob, Q, K, V


def main():
    """Build a pool of tile profiles, benchmark them, and demo the heuristic.

    Step 1 registers and builds one kernel per profile, step 2 runs every
    built kernel over a list of sequence lengths and reports the fastest,
    step 3 runs only the kernel chosen by ``select_kernel_heuristic``.

    Returns:
        0 on completion, 1 when no kernel could be built.
    """
    parser = argparse.ArgumentParser(
        description="FMHA Heuristics - kernel selection by problem size",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 08_heuristics.py
python3 08_heuristics.py --arch gfx950
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    args = parser.parse_args()

    print("=" * 70)
    print("Example 08: Kernel Selection Heuristics")
    print("=" * 70)

    # Step 1: Build kernel pool
    print("\nStep 1: Build Kernel Pool")
    profiles = build_tile_profiles(args.arch)

    reg = FmhaRegistry("heuristic_pool")
    for p in profiles:
        reg.register_kernel(p.config)

    print(f" Profiles: {len(profiles)}")
    for i, p in enumerate(profiles, 1):
        tile_str = f"{p.config.tile[0]}x{p.config.tile[1]}"
        print(
            f" [{i}] {p.label:<25} tile={tile_str:<10} pipeline={p.config.pipeline}"
        )

    print("\n Building kernels ...")
    build_results = reg.build(verbose=False)
    built = sum(1 for r in build_results if r.success)
    print(f" Built: {built}/{len(profiles)}")

    # build_results is indexed in registration order, i.e. parallel to profiles.
    for i, r in enumerate(build_results):
        tag = "OK" if r.success else f"FAIL: {r.error[:40]}"
        print(f" [{i + 1}] {profiles[i].label}: {tag}")

    if built == 0:
        print(" No kernels built -- aborting")
        return 1

    # Step 2: Run each kernel on multiple sequence lengths
    print("\n" + "=" * 70)
    print("Step 2: Benchmark Across Sequence Lengths")
    print("=" * 70)

    test_seqlens = [32, 64, 128, 256, 512]

    # Table header, printed piecewise: one column per profile plus "Best".
    print(f"\n {'SeqLen':>7}", end="")
    for p in profiles:
        print(f" | {p.label:>18}", end="")
    print(f" | {'Best':>18}")
    print(" " + "-" * (10 + 21 * len(profiles) + 22))

    for seqlen in test_seqlens:
        prob, Q, K, V = _make_inputs(seqlen)

        row = f" {seqlen:>7}"
        best_tflops = 0.0
        best_label = "---"

        for p, r in zip(profiles, build_results):
            if not r.success or r.runner is None:
                row += f" | {'N/A':>18}"
                continue

            res = r.runner.run(Q, K, V, prob)
            if res.success:
                cell = f"{res.tflops:.2f} TFLOPS"
                row += f" | {cell:>18}"
                if res.tflops > best_tflops:
                    best_tflops = res.tflops
                    best_label = p.label
            else:
                row += f" | {'ERR':>18}"

        row += f" | {best_label:>18}"
        print(row)

    # Step 3: Demonstrate heuristic selection
    print("\n" + "=" * 70)
    print("Step 3: Heuristic Selection Demo")
    print("=" * 70)

    print(f"\n {'SeqLen':>7} {'Selected':>25} {'TFLOPS':>10} {'Status':<6}")
    print(" " + "-" * 55)

    for seqlen in test_seqlens:
        selected = select_kernel_heuristic(seqlen, profiles)
        idx = profiles.index(selected)
        r = build_results[idx]

        if not r.success or r.runner is None:
            print(f" {seqlen:>7} {selected.label:>25} {'---':>10} {'SKIP':<6}")
            continue

        prob, Q, K, V = _make_inputs(seqlen)

        res = r.runner.run(Q, K, V, prob)
        if res.success:
            print(f" {seqlen:>7} {selected.label:>25} {res.tflops:>10.2f} {'OK':<6}")
        else:
            print(f" {seqlen:>7} {selected.label:>25} {'---':>10} {'FAIL':<6}")

    # Cleanup
    for r in build_results:
        if r.runner:
            r.runner.cleanup()

    print("\n" + "=" * 70)
    print("Heuristic Insight:")
    print(" - Small tiles: low overhead for short sequences")
    print(" - Large tiles: high throughput for long sequences")
    print(" - Pipeline choice also matters (qr vs qr_async)")
    print("=" * 70)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
298
dispatcher/examples/fmha/python/09_multi_registry.py
Normal file
298
dispatcher/examples/fmha/python/09_multi_registry.py
Normal file
@@ -0,0 +1,298 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 09: Multiple Registries
|
||||
|
||||
Creates separate FmhaRegistry instances for different optimization
|
||||
targets (latency vs throughput), builds both, runs the same problem
|
||||
through each, and compares results.
|
||||
|
||||
Usage:
|
||||
python3 09_multi_registry.py
|
||||
python3 09_multi_registry.py --help
|
||||
python3 09_multi_registry.py --arch gfx950
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaRegistry,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
def make_latency_config(arch: str) -> FmhaKernelConfig:
    """Return the fp16 forward config used for the "latency" registry.

    Uses the ``qr`` pipeline with a 64x128 stage-0 tile and 16x16 warp
    tiles, i.e. the smaller of the two configurations in this example.

    Args:
        arch: GPU architecture string, forwarded as ``gfx_arch``.
    """
    return FmhaKernelConfig(
        family="fwd",
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        pipeline="qr",
        # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q
        tile_m0=64,
        tile_n0=128,
        tile_k0=32,
        # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment
        tile_n1=128,
        tile_k1=32,
        tile_k0max=128,
        # Wave config per stage
        wave_m0=4,
        wave_n0=1,
        wave_k0=1,
        wave_m1=4,
        wave_n1=1,
        wave_k1=1,
        # Warp tile per stage
        warp_m0=16,
        warp_n0=16,
        warp_k0=32,
        warp_m1=16,
        warp_n1=16,
        warp_k1=16,
        pad_s=False,
        pad_sk=False,
        pad_d=True,
        pad_dv=True,
        gfx_arch=arch,
    )
|
||||
|
||||
|
||||
def make_throughput_config(arch: str) -> FmhaKernelConfig:
    """Return the fp16 forward config used for the "throughput" registry.

    Uses the ``qr_async`` pipeline with a 128x128 stage-0 tile and 32x32
    warp tiles. Padding flags are not passed, so they keep whatever
    defaults ``FmhaKernelConfig`` defines.

    Args:
        arch: GPU architecture string, forwarded as ``gfx_arch``.
    """
    return FmhaKernelConfig(
        family="fwd",
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        pipeline="qr_async",
        # Stage 0 tile
        tile_m0=128,
        tile_n0=128,
        tile_k0=32,
        # Stage 1 tile
        tile_n1=128,
        tile_k1=32,
        tile_k0max=128,
        # Wave config per stage
        wave_m0=4,
        wave_n0=1,
        wave_k0=1,
        wave_m1=4,
        wave_n1=1,
        wave_k1=1,
        # Warp tile per stage
        warp_m0=32,
        warp_n0=32,
        warp_k0=16,
        warp_m1=32,
        warp_n1=32,
        warp_k1=16,
        gfx_arch=arch,
    )
|
||||
|
||||
|
||||
def main():
    """Build a latency and a throughput registry and compare them.

    Both registries hold one kernel each. Every test problem is run through
    both kernels and validated against ``cpu_attention_fwd``; a detailed
    report is then printed for one fixed problem.

    Returns:
        1 when neither kernel is usable or any validated output mismatches
        the CPU reference, otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description="Multiple FMHA Registries - latency vs throughput",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 09_multi_registry.py
python3 09_multi_registry.py --arch gfx950
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--rtol", type=float, default=1e-2)
    parser.add_argument("--atol", type=float, default=1e-2)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 09: Multiple Registries")
    print("=" * 70)

    # Step 1: Define optimization-specific configs
    print("\nStep 1: Define Optimization Targets")

    latency_cfg = make_latency_config(args.arch)
    throughput_cfg = make_throughput_config(args.arch)

    print(f" Latency config: {latency_cfg.name}")
    print(f" pipeline={latency_cfg.pipeline}, tile={latency_cfg.tile[:2]}")
    print(f" Throughput config: {throughput_cfg.name}")
    print(f" pipeline={throughput_cfg.pipeline}, tile={throughput_cfg.tile[:2]}")

    # Step 2: Create separate registries
    print("\n" + "=" * 70)
    print("Step 2: Create and Build Registries")
    print("=" * 70)

    latency_reg = FmhaRegistry("latency")
    latency_reg.register_kernel(latency_cfg)

    throughput_reg = FmhaRegistry("throughput")
    throughput_reg.register_kernel(throughput_cfg)

    print(f"\n Building 'latency' registry ({len(latency_reg)} kernel) ...")
    t0 = time.perf_counter()
    latency_results = latency_reg.build(verbose=False)
    lat_build_time = time.perf_counter() - t0

    print(f" Building 'throughput' registry ({len(throughput_reg)} kernel) ...")
    t0 = time.perf_counter()
    throughput_results = throughput_reg.build(verbose=False)
    thr_build_time = time.perf_counter() - t0

    # A registry is usable only if its first build succeeded AND produced a
    # runner; without the runner check the .runner.run() calls below would
    # fail with AttributeError on None.
    lat_ok = (
        bool(latency_results)
        and bool(latency_results[0].success)
        and latency_results[0].runner is not None
    )
    thr_ok = (
        bool(throughput_results)
        and bool(throughput_results[0].success)
        and throughput_results[0].runner is not None
    )

    print(f"\n Latency: {'OK' if lat_ok else 'FAIL'} ({lat_build_time:.1f} s)")
    print(f" Throughput: {'OK' if thr_ok else 'FAIL'} ({thr_build_time:.1f} s)")

    if not lat_ok and not thr_ok:
        print(" No kernels built -- aborting")
        return 1

    # Step 3: Run same problem through both
    print("\n" + "=" * 70)
    print("Step 3: Run Same Problem Through Both Registries")
    print("=" * 70)

    # (batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim, description)
    test_configs = [
        (2, 4, 4, 64, 64, 128, "small"),
        (2, 8, 8, 128, 128, 128, "medium"),
        (2, 8, 8, 256, 256, 128, "large"),
    ]

    validator = FmhaValidator(rtol=args.rtol, atol=args.atol)

    print(f"\n {'Problem':<12} {'Latency':>18} {'Throughput':>18} {'Match':<6}")
    print(" " + "-" * 60)

    all_match = True

    for batch, hq, hk, sq, sk, hdim, desc in test_configs:
        prob = FmhaProblem(
            batch=batch,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=sq,
            seqlen_k=sk,
            hdim_q=hdim,
            hdim_v=hdim,
        )

        np.random.seed(42)
        Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)

        O_ref = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            prob.scale,
        )

        lat_cell = "N/A"
        thr_cell = "N/A"
        # Only outputs that were actually produced are validated; a kernel
        # that is unavailable or fails to run leaves this flag untouched.
        results_match = True

        if lat_ok:
            res_lat = latency_results[0].runner.run(Q, K, V, prob)
            if res_lat.success:
                lat_cell = f"{res_lat.tflops:.2f} TFLOPS"
                ok, _, _ = validator.check(res_lat.output, O_ref)
                if not ok:
                    results_match = False

        if thr_ok:
            res_thr = throughput_results[0].runner.run(Q, K, V, prob)
            if res_thr.success:
                thr_cell = f"{res_thr.tflops:.2f} TFLOPS"
                ok, _, _ = validator.check(res_thr.output, O_ref)
                if not ok:
                    results_match = False

        if not results_match:
            all_match = False

        tag = "YES" if results_match else "NO"
        print(f" {desc:<12} {lat_cell:>18} {thr_cell:>18} {tag:<6}")

    # Step 4: Detailed comparison on a single problem
    print("\n" + "=" * 70)
    print("Step 4: Detailed Comparison (B=2 H=8 S=128 D=128)")
    print("=" * 70)

    prob = FmhaProblem(
        batch=2,
        nhead_q=8,
        nhead_k=8,
        seqlen_q=128,
        seqlen_k=128,
        hdim_q=128,
        hdim_v=128,
    )
    np.random.seed(123)
    Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
    K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
    V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)
    O_ref = cpu_attention_fwd(
        Q.astype(np.float32),
        K.astype(np.float32),
        V.astype(np.float32),
        prob.scale,
    )

    for name, results, ok in [
        ("Latency", latency_results, lat_ok),
        ("Throughput", throughput_results, thr_ok),
    ]:
        if not ok:
            print(f"\n {name}: not available")
            continue
        res = results[0].runner.run(Q, K, V, prob)
        if not res.success:
            print(f"\n {name}: execution failed")
            continue
        valid, max_abs, max_rel = validator.check(res.output, O_ref)
        print(f"\n {name}:")
        print(f" Time: {res.time_ms:.4f} ms")
        print(f" TFLOPS: {res.tflops:.2f}")
        print(f" Max Abs: {max_abs:.2e}")
        print(f" Max Rel: {max_rel:.2e}")
        print(f" Valid: {valid}")

    # Cleanup
    for results in [latency_results, throughput_results]:
        for r in results:
            if r.runner:
                r.runner.cleanup()

    # Summary
    print("\n" + "=" * 70)
    print("Multi-Registry Pattern:")
    print("=" * 70)
    print(" 1. Create FmhaRegistry per optimization target")
    print(" 2. Register target-specific FmhaKernelConfig in each")
    print(" 3. Build both registries")
    print(" 4. Route problems to the best registry")
    print(" 5. Compare results for correctness")
    print("=" * 70)

    return 0 if all_match else 1


if __name__ == "__main__":
    sys.exit(main())
|
||||
262
dispatcher/examples/fmha/python/10_advanced_benchmark.py
Normal file
262
dispatcher/examples/fmha/python/10_advanced_benchmark.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 10: Advanced FMHA Benchmarking
|
||||
|
||||
Benchmarks FMHA forward across multiple problem sizes with configurable
|
||||
warmup, repeat, and cache-flush settings. Reports min/avg/max/median
|
||||
time and TFLOPS for each problem.
|
||||
|
||||
Usage:
|
||||
python3 10_advanced_benchmark.py
|
||||
python3 10_advanced_benchmark.py --warmup 10 --repeat 50
|
||||
python3 10_advanced_benchmark.py --flush-cache
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
setup_fmha_dispatcher,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
    """Parse the benchmark command line.

    Returns:
        The ``argparse.Namespace`` with ``warmup``, ``repeat``,
        ``flush_cache``, ``arch`` and ``lib`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Advanced FMHA benchmarking with full parameter control",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 10_advanced_benchmark.py # Defaults
python3 10_advanced_benchmark.py --warmup 10 --repeat 50 # More samples
python3 10_advanced_benchmark.py --flush-cache # Flush L2
""",
    )
    parser.add_argument(
        "--warmup", type=int, default=5, help="Number of warmup iterations (default: 5)"
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=20,
        help="Number of timed iterations (default: 20)",
    )
    parser.add_argument(
        "--flush-cache",
        action="store_true",
        help="Allocate a scratch buffer between runs to flush GPU cache",
    )
    parser.add_argument(
        "--arch", default=detect_gpu_arch(), help="GPU architecture (auto-detected)"
    )
    parser.add_argument(
        "--lib", default=None, help="Path to prebuilt .so (JIT-builds if omitted)"
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
# Problems benchmarked by this example, one tuple per row.
PROBLEM_TABLE = [
    # (batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim, label)
    (1, 8, 8, 64, 64, 128, "tiny"),
    (2, 8, 8, 128, 128, 128, "small"),
    (2, 16, 16, 256, 256, 128, "medium"),
    (4, 16, 16, 512, 512, 128, "large"),
    (2, 32, 32, 1024, 1024, 128, "xlarge"),
    # nhead_q != nhead_k: 32 query heads share 8 key/value heads.
    (1, 32, 8, 256, 256, 128, "GQA-4:1"),
]
|
||||
|
||||
|
||||
def flush_gpu_cache():
    """Allocate and sum a 32 MiB random scratch buffer between timed runs.

    Intended as a cache flush for ``--flush-cache``.

    NOTE(review): the buffer is a NumPy array in host memory and nothing in
    this function touches the device, so whether it actually evicts GPU L2
    lines is unverified -- confirm, or replace with a device-side flush.
    It also advances NumPy's global RNG state.
    """
    scratch = np.random.randint(0, 255, size=32 * 1024 * 1024, dtype=np.uint8)
    # Reading every element forces the buffer to be touched, not just reserved.
    _ = scratch.sum()
|
||||
|
||||
|
||||
def run_benchmark(
    runner, prob: FmhaProblem, warmup: int, repeat: int, flush_cache: bool
) -> list:
    """Time ``runner`` on ``prob`` and return the per-iteration times in ms.

    Random fp16 Q/K/V inputs are generated once from NumPy's global RNG and
    reused for every iteration.

    Args:
        runner: Object exposing ``run(Q, K, V, prob)`` whose result has
            ``success`` and ``time_ms`` attributes.
        prob: Problem description; supplies the Q/K/V shapes.
        warmup: Number of untimed iterations executed first.
        repeat: Number of timed iterations.
        flush_cache: When true, ``flush_gpu_cache()`` is called before each
            timed iteration.

    Returns:
        Times of the successful timed iterations. Failed iterations are
        dropped, so the list may be shorter than ``repeat`` or empty.
    """
    Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float16)
    K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float16)
    V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float16)

    for _ in range(warmup):
        runner.run(Q, K, V, prob)

    times = []
    for _ in range(repeat):
        if flush_cache:
            flush_gpu_cache()
        result = runner.run(Q, K, V, prob)
        if result.success:
            times.append(result.time_ms)
    return times
|
||||
|
||||
|
||||
def main():
    """JIT-build one fp16 forward kernel and benchmark it on PROBLEM_TABLE.

    For every problem the min/avg/max/median time is printed, with TFLOPS
    derived from the median. A summary and a parameter reference follow.

    Returns:
        1 when the kernel build fails, otherwise 0.
    """
    args = parse_args()

    print("=" * 70)
    print("Example 10: Advanced FMHA Benchmarking")
    print("=" * 70)

    print("\nBenchmark Configuration:")
    print(f" Warmup: {args.warmup} iterations")
    print(f" Repeat: {args.repeat} iterations")
    print(f" Flush Cache: {args.flush_cache}")
    print(f" Arch: {args.arch}")
    print(f" Problems: {len(PROBLEM_TABLE)}")

    # Step 1: Load or JIT-build kernel
    print("\n" + "=" * 70)
    print("Step 1: Load / Build Kernel")
    print("=" * 70)

    print(" JIT building kernel...")
    config = FmhaKernelConfig(
        family="fwd",
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        pipeline="qr_async",
        # Stage 0 (Q*K^T): seqlen_q x seqlen_k x hdim_q
        tile_m0=128,
        tile_n0=128,
        tile_k0=32,
        # Stage 1 (Attn*V): hdim_v x seqlen_k x alignment
        tile_n1=128,
        tile_k1=32,
        tile_k0max=128,
        # Wave config per stage
        wave_m0=4,
        wave_n0=1,
        wave_k0=1,
        wave_m1=4,
        wave_n1=1,
        wave_k1=1,
        # Warp tile per stage
        warp_m0=32,
        warp_n0=32,
        warp_k0=16,
        warp_m1=32,
        warp_n1=32,
        warp_k1=16,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config, verbose=True)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
        return 1
    runner = setup.runner

    # From here on the runner is live: release it even if benchmarking raises.
    try:
        print(f" JIT built: {setup.library_path} ({setup.build_time_s:.1f} s)")

        print(f" Kernels: {runner.kernel_count}")

        # Step 2: Benchmark all problems
        print("\n" + "=" * 70)
        print("Step 2: Benchmark Results")
        print("=" * 70)

        header = (
            f" {'Label':<10} {'Shape':^30} "
            f"{'Min':>8} {'Avg':>8} {'Max':>8} {'Med':>8} {'TFLOPS':>8}"
        )
        print(f"\n{header}")
        print(" " + "-" * 85)

        all_results = []
        # Seeded once; run_benchmark draws its inputs from the global RNG.
        np.random.seed(42)

        for batch, hq, hk, sq, sk, hdim, label in PROBLEM_TABLE:
            prob = FmhaProblem(
                batch=batch,
                nhead_q=hq,
                nhead_k=hk,
                seqlen_q=sq,
                seqlen_k=sk,
                hdim_q=hdim,
                hdim_v=hdim,
            )
            shape_str = f"B{batch}_Hq{hq}_Hk{hk}_S{sq}_D{hdim}"

            times = run_benchmark(
                runner, prob, args.warmup, args.repeat, args.flush_cache
            )

            if not times:
                print(
                    f" {label:<10} {shape_str:^30} {'FAIL':>8} {'---':>8} "
                    f"{'---':>8} {'---':>8} {'---':>8}"
                )
                continue

            t_min = min(times)
            t_max = max(times)
            t_avg = sum(times) / len(times)
            t_med = float(np.median(times))

            # ms -> s, then ops/s -> tera-ops/s; guard against a zero median.
            tflops = prob.num_ops / (t_med * 1e-3) / 1e12 if t_med > 0 else 0

            print(
                f" {label:<10} {shape_str:^30} "
                f"{t_min:>7.3f}ms {t_avg:>7.3f}ms {t_max:>7.3f}ms {t_med:>7.3f}ms "
                f"{tflops:>7.2f}"
            )

            all_results.append((label, shape_str, t_min, t_avg, t_max, t_med, tflops))

        # Summary
        print("\n" + "=" * 70)
        print(" SUMMARY")
        print("=" * 70)

        if all_results:
            best = max(all_results, key=lambda r: r[6])
            print(f"\n Best TFLOPS: {best[6]:.2f} ({best[0]}: {best[1]})")
            avg_tflops = sum(r[6] for r in all_results) / len(all_results)
            print(f" Avg TFLOPS: {avg_tflops:.2f}")
            print(f" Problems run: {len(all_results)}/{len(PROBLEM_TABLE)}")
        else:
            print("\n No successful benchmarks")

        print(
            f"\n Settings: warmup={args.warmup}, repeat={args.repeat}, "
            f"flush_cache={args.flush_cache}"
        )

        print("\n" + "=" * 70)
        print("BENCHMARK PARAMETERS REFERENCE")
        print("=" * 70)
        print("""
--warmup N Warmup iterations (results discarded)
Higher = more stable results, longer run
Default: 5

--repeat N Timed iterations
Higher = more accurate statistics
Default: 20

--flush-cache Flush GPU L2 cache between iterations
Use for memory-bandwidth measurements
Default: off

--arch ARCH GPU architecture (e.g. gfx950)
Auto-detected from rocminfo
""")
        print("=" * 70)
    finally:
        runner.cleanup()

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
188
dispatcher/examples/fmha/python/11_bf16_fmha.py
Normal file
188
dispatcher/examples/fmha/python/11_bf16_fmha.py
Normal file
@@ -0,0 +1,188 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 11: BF16 Forward Attention
|
||||
|
||||
Demonstrates:
|
||||
1. BF16 data generation and handling
|
||||
2. GPU execution attempt with prebuilt kernel (fp16-only)
|
||||
3. CPU reference computation in float32
|
||||
4. BF16-specific tolerance validation (atol=1e-2)
|
||||
|
||||
The prebuilt library contains only fp16 kernels. This example shows the API
|
||||
pattern for bf16 and gracefully falls back to CPU reference when the GPU
|
||||
kernel does not support bf16.
|
||||
|
||||
Usage:
|
||||
python3 11_bf16_fmha.py
|
||||
python3 11_bf16_fmha.py --batch 4 --seqlen 256
|
||||
python3 11_bf16_fmha.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def to_bf16(arr: np.ndarray) -> np.ndarray:
    """Convert an array to bfloat16 bit patterns stored as ``uint16``.

    The input is cast to float32 and the upper 16 bits of each value are
    kept. The lower 16 bits are dropped by the shift, i.e. the conversion
    truncates rather than rounds.

    Args:
        arr: Numeric array of any shape.

    Returns:
        ``uint16`` array of the same shape holding the bf16 bit patterns.
    """
    f32 = arr.astype(np.float32)
    # Reinterpret the float32 bytes as integers so the bits can be shifted.
    u32 = f32.view(np.uint32)
    return (u32 >> 16).astype(np.uint16)
|
||||
|
||||
|
||||
def bf16_to_f32(arr_u16: np.ndarray) -> np.ndarray:
    """Expand bfloat16 bit patterns (``uint16``) back to float32 values.

    Each pattern becomes the upper 16 bits of a float32; the lower 16 bits
    are zero.

    Args:
        arr_u16: Array of bf16 bit patterns, e.g. as produced by ``to_bf16``.

    Returns:
        float32 array of the same shape.
    """
    # Widen first so the left shift cannot overflow 16 bits.
    u32 = arr_u16.astype(np.uint32) << 16
    return u32.view(np.float32)
|
||||
|
||||
|
||||
def main():
    """Demonstrate the bf16 data path against a float32 CPU reference.

    Inputs are quantised to bf16 bit patterns and expanded back to float32.
    The GPU kernel is configured with ``data_type="fp16"`` and receives the
    bf16-quantised values cast to fp16. The CPU reference is always
    computed; the GPU output is validated against it when available.

    Returns:
        Always 0.
    """
    parser = argparse.ArgumentParser(description="BF16 Forward Attention")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 11: BF16 Forward Attention")
    print("=" * 70)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    print(
        f"\n Problem: B={prob.batch} H={prob.nhead_q} S={prob.seqlen_q} D={prob.hdim_q}"
    )
    print(" Dtype: bfloat16")
    print(f" Arch: {args.arch}")
    print(f" Scale: {prob.scale:.6f}")

    # --- Generate bf16 data ---
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)

    Q_bf16 = to_bf16(Q_f32)
    K_bf16 = to_bf16(K_f32)
    V_bf16 = to_bf16(V_f32)

    # Round-trip back to float32 so the reference sees bf16-quantised values.
    Q_bf16_f32 = bf16_to_f32(Q_bf16)
    K_bf16_f32 = bf16_to_f32(K_bf16)
    V_bf16_f32 = bf16_to_f32(V_bf16)

    print(f"\n Q bf16 range: [{Q_bf16_f32.min():.4f}, {Q_bf16_f32.max():.4f}]")
    print(f" K bf16 range: [{K_bf16_f32.min():.4f}, {K_bf16_f32.max():.4f}]")
    print(f" V bf16 range: [{V_bf16_f32.min():.4f}, {V_bf16_f32.max():.4f}]")

    bf16_quant_err = np.abs(Q_f32 - Q_bf16_f32).max()
    print(f" BF16 quantization error: {bf16_quant_err:.2e}")

    # --- GPU execution attempt ---
    print("\n--- GPU Execution ---")
    gpu_output = None
    gpu_time = None
    runner = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        Q_fp16 = Q_bf16_f32.astype(np.float16)
        K_fp16 = K_bf16_f32.astype(np.float16)
        V_fp16 = V_bf16_f32.astype(np.float16)
        result = runner.run(Q_fp16, K_fp16, V_fp16, prob)
        if result.success:
            gpu_output = result.output
            gpu_time = result.time_ms
            print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS")
            print(" Note: Ran as fp16 (JIT kernel); native bf16 kernel not compiled")
        else:
            print(" GPU: Kernel does not support bf16 (expected)")

    # --- CPU reference (always computed) ---
    print("\n--- CPU Reference (float32 with bf16-quantized inputs) ---")
    O_ref = cpu_attention_fwd(Q_bf16_f32, K_bf16_f32, V_bf16_f32, prob.scale)
    print(f" Output range: [{O_ref.min():.4f}, {O_ref.max():.4f}]")
    print(f" Output shape: {O_ref.shape}")

    # --- Validation ---
    print("\n--- Validation ---")
    validator = FmhaValidator(rtol=1e-2, atol=1e-2)

    print(f"\n {'Check':<30} {'MaxAbs':>10} {'MaxRel':>10} {'Status':>8}")
    print(" " + "-" * 62)

    if gpu_output is not None:
        ok, max_abs, max_rel = validator.check(gpu_output, O_ref)
        tag = "PASS" if ok else "FAIL"
        print(
            f" {'GPU vs CPU (bf16 tol)':<30} {max_abs:>10.2e} {max_rel:>10.2e} {tag:>8}"
        )
    else:
        print(f" {'GPU vs CPU (bf16 tol)':<30} {'N/A':>10} {'N/A':>10} {'SKIP':>8}")

    strict_val = FmhaValidator(rtol=1e-5, atol=1e-5)
    ok_strict, ma_strict, mr_strict = strict_val.check(
        O_ref.astype(np.float16),
        O_ref,
    )
    print(
        f" {'fp16(ref) vs f32(ref)':<30} {ma_strict:>10.2e} {mr_strict:>10.2e} {'PASS' if ok_strict else 'INFO':>8}"
    )

    O_ref_from_f32 = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale)
    bf16_impact = float(np.abs(O_ref - O_ref_from_f32).max())
    print(
        f" {'bf16 vs f32 input impact':<30} {bf16_impact:>10.2e} {'':>10} {'INFO':>8}"
    )

    # Release the runner only after its output has been validated.
    if runner is not None:
        runner.cleanup()

    # --- Summary ---
    # Compare against None: a measured time of 0.0 ms is still a measurement.
    if gpu_time is not None:
        gpu_cell = "%.4f ms" % gpu_time
    else:
        gpu_cell = "N/A (bf16 kernel not in prebuilt)"

    print("\n" + "=" * 70)
    print(" Dtype: bfloat16 (7-bit mantissa vs fp16's 10-bit)")
    print(" Tolerance: atol=1e-2 (relaxed for bf16 precision)")
    print(f" GPU: {gpu_cell}")
    print(" CPU ref: Computed with bf16-quantized inputs")
    print(" BF16 range: Larger exponent range (±3.4e38) vs fp16 (±65504)")
    status = "PASS" if gpu_output is not None else "DEMO"
    print(f" Status: {status}")
    print("=" * 70)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
239
dispatcher/examples/fmha/python/12_masks_fmha.py
Normal file
239
dispatcher/examples/fmha/python/12_masks_fmha.py
Normal file
@@ -0,0 +1,239 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 12: Attention Masks
|
||||
|
||||
Demonstrates all 5 mask types supported by the FMHA dispatcher:
|
||||
1. no_mask (0) -- Full attention, no masking
|
||||
2. top_left (1) -- Causal mask aligned to top-left corner
|
||||
3. bottom_right (2) -- Causal mask aligned to bottom-right corner
|
||||
4. sliding_window -- Local attention within a fixed window
|
||||
5. generic -- Arbitrary user-defined mask pattern
|
||||
|
||||
For each mask type, this example:
|
||||
- Creates an FmhaProblem
|
||||
- Attempts GPU execution via prebuilt kernel
|
||||
- Computes CPU reference with the mask applied
|
||||
- Validates results
|
||||
|
||||
Usage:
|
||||
python3 12_masks_fmha.py
|
||||
python3 12_masks_fmha.py --seqlen 256
|
||||
python3 12_masks_fmha.py --window-size 64
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
FmhaValidator,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
# Mask name -> integer id used by this example.
MASK_TYPES = {
    "no_mask": 0,
    "top_left": 1,
    "bottom_right": 2,
    "sliding_window": 3,
    "generic": 4,
}
|
||||
|
||||
|
||||
def make_causal_mask_top_left(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a causal mask anchored at the top-left corner.

    Query row ``i`` may attend to key columns ``j <= i``.

    Returns:
        float32 array of shape ``(seqlen_q, seqlen_k)``; 1.0 = attend,
        0.0 = masked.
    """
    row = np.arange(seqlen_q).reshape(-1, 1)
    col = np.arange(seqlen_k).reshape(1, -1)
    # Broadcasting the column vector against the row vector yields the grid.
    return (col <= row).astype(np.float32)
|
||||
|
||||
|
||||
def make_causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a causal mask anchored at the bottom-right corner.

    The diagonal is shifted by ``seqlen_k - seqlen_q`` so that the last
    query row can attend to every key: row ``i`` sees columns
    ``j <= i + (seqlen_k - seqlen_q)``.

    Returns:
        float32 array of shape ``(seqlen_q, seqlen_k)``; 1.0 = attend,
        0.0 = masked.
    """
    offset = seqlen_k - seqlen_q
    row = np.arange(seqlen_q).reshape(-1, 1)
    col = np.arange(seqlen_k).reshape(1, -1)
    return (col <= row + offset).astype(np.float32)
|
||||
|
||||
|
||||
def make_sliding_window_mask(seqlen_q: int, seqlen_k: int, window: int) -> np.ndarray:
    """Build a bottom-right-aligned causal mask limited to a local window.

    With ``offset = seqlen_k - seqlen_q``, query row ``i`` may attend to the
    ``window`` key columns ending at ``i + offset``, i.e.
    ``i + offset - window + 1 <= j <= i + offset``.

    Returns:
        float32 array of shape ``(seqlen_q, seqlen_k)``; 1.0 = attend,
        0.0 = masked.
    """
    row = np.arange(seqlen_q).reshape(-1, 1)
    col = np.arange(seqlen_k).reshape(1, -1)
    offset = seqlen_k - seqlen_q
    upper = row + offset
    lower = upper - window + 1
    return ((col <= upper) & (col >= lower)).astype(np.float32)
|
||||
|
||||
|
||||
def make_generic_mask(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a checkerboard mask used to demonstrate an arbitrary pattern.

    Position ``(i, j)`` is visible exactly when ``i + j`` is even.

    Args:
        seqlen_q: number of query positions (rows).
        seqlen_k: number of key positions (columns).

    Returns:
        float32 array of shape ``[seqlen_q, seqlen_k]`` holding 1.0 for
        visible positions and 0.0 for masked ones.
    """
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    return ((q_idx + k_idx) % 2 == 0).astype(np.float32)
|
||||
|
||||
|
||||
def cpu_masked_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    mask: np.ndarray,
) -> np.ndarray:
    """CPU reference: scaled dot-product attention with an arbitrary mask.

    Args:
        Q: queries, ``[batch, nhead, seqlen_q, hdim]``.
        K: keys, 4-D with the key positions on axis 2 (it is transposed with
            ``(0, 1, 3, 2)`` before the matmul).
        V: values, multiplied by the ``[..., seqlen_q, seqlen_k]``
            probabilities.
        scale: factor applied to the raw ``Q @ K^T`` scores.
        mask: ``[seqlen_q, seqlen_k]``; entries > 0 are visible. The mask is
            broadcast over batch and head.

    Returns:
        ``softmax(mask(Q @ K^T * scale)) @ V``. A query row with no visible
        key contributes an all-zero output row (instead of a uniform average
        of V, which is what a finite fill value would produce).
    """
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    visible = (mask > 0)[np.newaxis, np.newaxis, :, :]
    scores = np.where(visible, scores, -np.inf)

    row_max = scores.max(axis=-1, keepdims=True)
    # Fully masked rows have row_max == -inf; use 0 there so the subtraction
    # below yields -inf (exp -> 0) rather than inf - inf = NaN.
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)

    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    # denom is 0 only for fully masked rows; leave their probabilities at 0.
    probs = np.divide(weights, denom, out=np.zeros_like(weights), where=denom > 0)
    return np.matmul(probs, V)
|
||||
|
||||
|
||||
def main():
    """Run the attention-mask example and print a per-mask report.

    Builds one FmhaProblem from the command-line sizes, tries to obtain a GPU
    runner, computes a CPU reference for every mask type, and validates the
    GPU output against the CPU reference for ``no_mask`` only (the other mask
    rows are reported as DEMO).

    Returns:
        0 when every validated row passed (DEMO rows count as passed),
        1 when at least one validation failed.
    """
    parser = argparse.ArgumentParser(description="Attention Masks")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen-q", type=int, default=128)
    parser.add_argument("--seqlen-k", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--window-size", type=int, default=32)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 12: Attention Masks")
    print("=" * 70)

    sq, sk = args.seqlen_q, args.seqlen_k
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} Sq={sq} Sk={sk} D={args.hdim}")
    print(f" Window: {args.window_size}")

    # --- Generate data ---
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_fp16 = Q_f32.astype(np.float16)
    K_fp16 = K_f32.astype(np.float16)
    V_fp16 = V_f32.astype(np.float16)

    # --- Try GPU runner ---
    runner = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f"\n GPU runner loaded (JIT build: {setup.build_time_s:.1f}s)")
    else:
        print(f"\n GPU runner not available: {setup.error}")

    # --- Build masks ---
    masks = {
        "no_mask": np.ones((sq, sk), dtype=np.float32),
        "top_left": make_causal_mask_top_left(sq, sk),
        "bottom_right": make_causal_mask_bottom_right(sq, sk),
        "sliding_window": make_sliding_window_mask(sq, sk, args.window_size),
        "generic": make_generic_mask(sq, sk),
    }

    validator = FmhaValidator(rtol=1e-2, atol=1e-2)

    print(
        f"\n {'#':<3} {'MaskType':<18} {'ID':<4} {'Density':>8} {'GPUStatus':<12} {'CPURef':<8} {'MaxErr':>10} {'Status':>8}"
    )
    print(" " + "-" * 76)

    # The GPU call takes no mask argument, so its inputs are identical for
    # every table row: run it once and reuse the result.
    # NOTE(review): assumes runner.run is deterministic for fixed inputs.
    gpu_res = None
    if runner is not None:
        gpu_res = runner.run(Q_fp16, K_fp16, V_fp16, prob)

    results = []
    for i, (name, mask) in enumerate(masks.items(), 1):
        mask_id = MASK_TYPES[name]
        density = mask.sum() / mask.size * 100

        # GPU attempt (prebuilt only supports no_mask)
        gpu_status = "N/A"
        gpu_out = None
        if gpu_res is not None:
            if gpu_res.success:
                gpu_out = gpu_res.output
                gpu_status = "OK" if name == "no_mask" else "no_mask*"
            else:
                gpu_status = "unsupported"

        # CPU reference with mask
        O_ref = cpu_masked_attention(Q_f32, K_f32, V_f32, prob.scale, mask)
        cpu_status = "OK"

        # Validate: only no_mask is comparable with the unmasked GPU output.
        if gpu_out is not None and name == "no_mask":
            ok, max_abs, _ = validator.check(gpu_out, O_ref)
            tag = "PASS" if ok else "FAIL"
            err_str = f"{max_abs:.2e}"
        else:
            ok = True
            tag = "DEMO"
            err_str = "---"

        print(
            f" {i:<3} {name:<18} {mask_id:<4} {density:>7.1f}% {gpu_status:<12} {cpu_status:<8} {err_str:>10} {tag:>8}"
        )
        results.append((name, ok))

    # --- Mask visualization ---
    print("\n--- Mask Patterns (first 8x8 corner) ---")
    view_size = min(8, sq, sk)
    for name, mask in masks.items():
        corner = mask[:view_size, :view_size]
        print(f"\n {name}:")
        for r in range(view_size):
            row_str = " ".join(
                "█" if corner[r, c] > 0 else "·" for c in range(view_size)
            )
            print(f" {row_str}")

    # --- Summary ---
    all_ok = all(ok for _, ok in results)
    print("\n" + "=" * 70)
    print(f" Mask types tested: {len(masks)}")
    print(" no_mask: Full attention (all positions visible)")
    print(" top_left: Causal from top-left (autoregressive)")
    print(" bottom_right: Causal from bottom-right (kv-padded)")
    print(f" sliding_window: Local window of {args.window_size} keys")
    print(" generic: Arbitrary (checkerboard demo)")
    print(" GPU: Prebuilt supports no_mask only")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)

    # Propagate a validation failure through the process exit status.
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())
|
||||
235
dispatcher/examples/fmha/python/13_bias_fmha.py
Normal file
235
dispatcher/examples/fmha/python/13_bias_fmha.py
Normal file
@@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 13: Attention Bias
|
||||
|
||||
Demonstrates bias types supported by the FMHA dispatcher:
|
||||
1. no_bias -- Standard attention without bias
|
||||
2. elementwise -- Add a [seqlen_q, seqlen_k] bias matrix to attention scores
|
||||
3. alibi -- Attention with Linear Biases (ALiBi) positional encoding
|
||||
|
||||
For each bias type:
|
||||
- Creates an FmhaProblem and bias tensor
|
||||
- Attempts GPU execution (prebuilt: no_bias only)
|
||||
- Computes CPU reference with bias applied before softmax
|
||||
- Validates output
|
||||
|
||||
Usage:
|
||||
python3 13_bias_fmha.py
|
||||
python3 13_bias_fmha.py --seqlen 256
|
||||
python3 13_bias_fmha.py --nhead 16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def get_alibi_slopes(nhead: int) -> np.ndarray:
    """Compute one ALiBi slope per attention head.

    The slopes form the geometric sequence ``2^(-8/n * k)`` for
    ``k = 1..n`` where ``n`` is ``nhead``.

    NOTE(review): this closed form is applied for every ``nhead``; whether
    non-power-of-two head counts should use a different scheme is not
    decided here -- confirm against the kernel's slope convention.

    Args:
        nhead: number of attention heads.

    Returns:
        float32 array of shape ``[nhead]``.
    """
    ratio = 2.0 ** (-8.0 / nhead)
    return np.array([ratio ** (i + 1) for i in range(nhead)], dtype=np.float32)
|
||||
|
||||
|
||||
def make_alibi_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Create the per-head ALiBi bias ``slope[h] * (j - i)``.

    ``i`` is the query row and ``j`` the key column. The value is negative
    for keys before the query (``j < i``), zero on the diagonal and positive
    for keys after the query (``j > i``); no causal masking is applied here.

    Args:
        nhead: number of attention heads (one slope per head).
        seqlen_q: number of query positions.
        seqlen_k: number of key positions.

    Returns:
        float32 array of shape ``[nhead, seqlen_q, seqlen_k]``.
    """
    slopes = get_alibi_slopes(nhead)
    q_idx = np.arange(seqlen_q).reshape(-1, 1)
    k_idx = np.arange(seqlen_k).reshape(1, -1)
    rel_pos = k_idx - q_idx  # [seqlen_q, seqlen_k], signed key-minus-query offset
    bias = slopes.reshape(-1, 1, 1) * rel_pos.reshape(1, seqlen_q, seqlen_k)
    return bias.astype(np.float32)
|
||||
|
||||
|
||||
def make_elementwise_bias(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Create a relative-position elementwise bias ``-0.1 * |i - j|``.

    Args:
        seqlen_q: number of query positions (rows, index ``i``).
        seqlen_k: number of key positions (columns, index ``j``).

    Returns:
        float32 array of shape ``[seqlen_q, seqlen_k]``.
    """
    q_idx = np.arange(seqlen_q, dtype=np.float32).reshape(-1, 1)
    k_idx = np.arange(seqlen_k, dtype=np.float32).reshape(1, -1)
    distance = np.abs(q_idx - k_idx)
    return (-0.1 * distance).astype(np.float32)
|
||||
|
||||
|
||||
def cpu_biased_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    bias: np.ndarray,
) -> np.ndarray:
    """CPU reference: attention with an additive bias applied before softmax.

    Args:
        Q: queries, ``[batch, nhead, seqlen_q, hdim]``.
        K: keys, 4-D with the key positions on axis 2 (it is transposed with
            ``(0, 1, 3, 2)`` before the matmul).
        V: values, multiplied by the ``[..., seqlen_q, seqlen_k]``
            probabilities.
        scale: factor applied to the raw ``Q @ K^T`` scores.
        bias: broadcastable to ``[batch, nhead, seqlen_q, seqlen_k]``.

    Returns:
        ``softmax(Q @ K^T * scale + bias) @ V``.
    """
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale + bias
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    probs = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(probs, V)
|
||||
|
||||
|
||||
def main():
    """Run the attention-bias example and print a per-bias report.

    Builds one FmhaProblem, tries to obtain a GPU runner, computes a CPU
    reference for the ``no_bias``, ``elementwise`` and ``alibi`` biases, and
    validates the GPU output against the CPU reference for ``no_bias`` only
    (the other rows are reported as DEMO).

    Returns:
        0 when every validated row passed (DEMO rows count as passed),
        1 when at least one validation failed.
    """
    parser = argparse.ArgumentParser(description="Attention Bias")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 13: Attention Bias")
    print("=" * 70)

    sq = sk = args.seqlen
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={sq} D={args.hdim}")

    # --- Generate data ---
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_fp16 = Q_f32.astype(np.float16)
    K_fp16 = K_f32.astype(np.float16)
    V_fp16 = V_f32.astype(np.float16)

    # --- Try GPU runner ---
    runner = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f" GPU runner loaded (JIT build: {setup.build_time_s:.1f}s)")
    else:
        print(f" GPU runner not available: {setup.error}")

    # --- Build bias tensors ---
    # Each bias is reshaped so it broadcasts against [batch, nhead, sq, sk].
    bias_configs = [
        ("no_bias", np.zeros((1, 1, sq, sk), dtype=np.float32)),
        ("elementwise", make_elementwise_bias(sq, sk)[np.newaxis, np.newaxis, :, :]),
        ("alibi", make_alibi_bias(args.nhead, sq, sk)[np.newaxis, :, :, :]),
    ]

    validator = FmhaValidator(rtol=1e-2, atol=1e-2)

    print(
        f"\n {'#':<3} {'BiasType':<14} {'BiasRange':>20} {'GPUStatus':<12} {'MaxErr':>10} {'Status':>8}"
    )
    print(" " + "-" * 72)

    # The GPU call takes no bias argument, so its inputs are identical for
    # every table row: run it once and reuse the result.
    # NOTE(review): assumes runner.run is deterministic for fixed inputs.
    gpu_res = None
    if runner is not None:
        gpu_res = runner.run(Q_fp16, K_fp16, V_fp16, prob)

    results = []
    for i, (name, bias) in enumerate(bias_configs, 1):
        bias_min, bias_max = float(bias.min()), float(bias.max())
        bias_range = f"[{bias_min:.3f}, {bias_max:.3f}]"

        # GPU attempt
        gpu_status = "N/A"
        gpu_out = None
        if gpu_res is not None:
            if gpu_res.success:
                gpu_out = gpu_res.output
                gpu_status = "OK" if name == "no_bias" else "no_bias*"
            else:
                gpu_status = "unsupported"

        # CPU reference with bias
        O_ref = cpu_biased_attention(Q_f32, K_f32, V_f32, prob.scale, bias)

        # Validate: only no_bias is comparable with the unbiased GPU output.
        if gpu_out is not None and name == "no_bias":
            ok, max_abs, _ = validator.check(gpu_out, O_ref)
            tag = "PASS" if ok else "FAIL"
            err_str = f"{max_abs:.2e}"
        else:
            ok = True
            tag = "DEMO"
            err_str = "---"

        print(
            f" {i:<3} {name:<14} {bias_range:>20} {gpu_status:<12} {err_str:>10} {tag:>8}"
        )
        results.append((name, ok))

    # --- Show ALiBi details ---
    print("\n--- ALiBi Details ---")
    slopes = get_alibi_slopes(args.nhead)
    print(f" Heads: {args.nhead}")
    print(f" Slopes: {', '.join(f'{s:.4f}' for s in slopes[: min(8, len(slopes))])}")
    if len(slopes) > 8:
        print(f" ... ({len(slopes)} total)")
    print(" Effect: Nearby tokens get higher scores, distant tokens penalized")
    print(" Formula: bias[h,i,j] = slope[h] * (j - i)")

    alibi_bias = make_alibi_bias(args.nhead, sq, sk)
    print("\n Head 0 bias corner (4x4):")
    # Clamp the preview to the sequence length so --seqlen < 4 does not
    # index past the end of the bias matrix.
    view_size = min(4, sq, sk)
    corner = alibi_bias[0, :view_size, :view_size]
    for r in range(view_size):
        row_str = " ".join(f"{corner[r, c]:>7.3f}" for c in range(view_size))
        print(f" {row_str}")

    # --- Show impact of bias on attention ---
    print("\n--- Bias Impact Analysis ---")
    O_no_bias = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale)
    for name, bias in bias_configs:
        O_biased = cpu_biased_attention(Q_f32, K_f32, V_f32, prob.scale, bias)
        diff = float(np.abs(O_biased - O_no_bias).max())
        print(f" {name:<14} max output shift: {diff:.4e}")

    # --- Summary ---
    all_ok = all(ok for _, ok in results)
    print("\n" + "=" * 70)
    print(" Bias types: no_bias, elementwise, alibi")
    print(" no_bias: Standard attention (baseline)")
    print(" elementwise: Position-distance bias [-0.1 * |i-j|]")
    print(" alibi: Linear position bias per head (no learned params)")
    print(" GPU: Prebuilt supports no_bias only")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)

    # Propagate a validation failure through the process exit status.
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())
|
||||
245
dispatcher/examples/fmha/python/14_dropout_fmha.py
Normal file
245
dispatcher/examples/fmha/python/14_dropout_fmha.py
Normal file
@@ -0,0 +1,245 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 14: Attention Dropout with LSE
|
||||
|
||||
Demonstrates:
|
||||
1. Dropout applied to attention probabilities
|
||||
2. Log-sum-exp (LSE) storage for numerical stability
|
||||
3. Statistical validation (dropout is stochastic)
|
||||
4. Reproducibility with seed control
|
||||
|
||||
Dropout zeros out attention weights with probability p_drop, then scales
|
||||
remaining weights by 1/(1-p_drop) to preserve expected value.
|
||||
LSE stores log(sum(exp(scores))) per query position for backward pass.
|
||||
|
||||
Usage:
|
||||
python3 14_dropout_fmha.py
|
||||
python3 14_dropout_fmha.py --p-drop 0.3
|
||||
python3 14_dropout_fmha.py --seqlen 256 --seed 123
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def cpu_attention_with_dropout(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    p_drop: float,
    seed: int,
) -> tuple:
    """CPU reference: attention with dropout and LSE output.

    The softmax probabilities are computed first; the log-sum-exp is taken
    from the raw scaled scores, so it does not depend on ``p_drop`` or
    ``seed``. Dropout then zeroes each probability where a uniform draw is
    below ``p_drop`` and rescales the survivors by ``1 / (1 - p_drop)``
    (the rescale factor is 0 when ``p_drop >= 1``).

    Args:
        Q: queries, ``[batch, nhead, seqlen_q, hdim]``.
        K: keys, 4-D with the key positions on axis 2.
        V: values, multiplied by the dropped-out probabilities.
        scale: factor applied to the raw ``Q @ K^T`` scores.
        p_drop: drop probability.
        seed: seed for the private ``np.random.RandomState`` that draws the
            drop mask; the global NumPy RNG is not touched.

    Returns:
        (O, P_dropped, lse)
        O: [batch, nhead, seqlen_q, hdim_v]
        P_dropped: [batch, nhead, seqlen_q, seqlen_k] attention weights after dropout
        lse: [batch, nhead, seqlen_q] log-sum-exp of scores (float32)
    """
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    probs = weights / denom

    # log(sum(exp(s))) = log(sum(exp(s - max))) + max
    lse = (np.log(denom.squeeze(-1)) + row_max.squeeze(-1)).astype(np.float32)

    rng = np.random.RandomState(seed)
    keep_mask = (rng.rand(*probs.shape) >= p_drop).astype(np.float32)
    if p_drop < 1.0:
        rescale = 1.0 / (1.0 - p_drop)
    else:
        rescale = 0.0
    P_dropped = probs * keep_mask * rescale

    out = np.matmul(P_dropped, V)
    return out, P_dropped, lse
|
||||
|
||||
|
||||
def main():
    """Run the dropout/LSE example and print its analysis sections.

    Sections: an optional GPU baseline run (without dropout), a CPU sweep over
    several drop rates, an LSE comparison between dropout and no dropout, a
    multi-seed statistical check, and -- when GPU output is available -- a
    GPU-vs-CPU comparison of the no-dropout result.

    Returns:
        0 unless the GPU-vs-CPU comparison ran and failed, in which case 1.
    """
    parser = argparse.ArgumentParser(description="Attention Dropout with LSE")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--p-drop", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 14: Attention Dropout with LSE")
    print("=" * 70)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    print(
        f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}"
    )
    print(f" p_drop: {args.p_drop}")
    print(f" Seed: {args.seed}")
    print(f" LSE shape: [{prob.batch}, {prob.nhead_q}, {prob.seqlen_q}]")

    # --- Generate data ---
    np.random.seed(args.seed)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_fp16 = Q_f32.astype(np.float16)
    K_fp16 = K_f32.astype(np.float16)
    V_fp16 = V_f32.astype(np.float16)

    # --- GPU execution attempt ---
    print("\n--- GPU Execution ---")
    gpu_output = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
        if res.success:
            gpu_output = res.output
            print(f" GPU (no dropout): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS")
            print(" Note: JIT kernel runs without dropout; shown for baseline")
        else:
            print(" GPU: Kernel returned failure")

    # --- CPU reference: no dropout (baseline) ---
    print("\n--- CPU Reference ---")
    O_no_drop = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale)

    # --- CPU reference: with dropout ---
    drop_rates = [0.0, 0.1, args.p_drop, 0.5]

    print(
        f"\n {'p_drop':>8} {'OutMean':>10} {'OutStd':>10} {'MaxDiff':>10} {'DropFrac':>10}"
    )
    print(" " + "-" * 52)

    for p in drop_rates:
        O_drop, P_dropped, lse = cpu_attention_with_dropout(
            Q_f32,
            K_f32,
            V_f32,
            prob.scale,
            p,
            args.seed,
        )

        # Fraction of attention weights that are exactly zero after dropout.
        total_weights = P_dropped.size
        zeros = (P_dropped == 0).sum()
        actual_drop_frac = zeros / total_weights

        diff = float(np.abs(O_drop - O_no_drop).max())
        print(
            f" {p:>8.2f} {O_drop.mean():>10.4f} {O_drop.std():>10.4f} "
            f"{diff:>10.2e} {actual_drop_frac:>10.2%}"
        )

    # --- LSE analysis ---
    print("\n--- LSE (Log-Sum-Exp) Analysis ---")
    _, _, lse = cpu_attention_with_dropout(
        Q_f32,
        K_f32,
        V_f32,
        prob.scale,
        args.p_drop,
        args.seed,
    )
    print(f" LSE shape: {lse.shape}")
    print(f" LSE range: [{lse.min():.4f}, {lse.max():.4f}]")
    print(f" LSE mean: {lse.mean():.4f}")
    print(" LSE is independent of dropout (computed from raw scores)")

    lse_nodrop = cpu_attention_with_dropout(
        Q_f32,
        K_f32,
        V_f32,
        prob.scale,
        0.0,
        args.seed,
    )[2]
    lse_diff = float(np.abs(lse - lse_nodrop).max())
    print(f" LSE diff (drop vs no-drop): {lse_diff:.2e} (should be 0)")

    # --- Statistical validation ---
    # Average the outputs of several independently seeded dropout runs and
    # compare the mean with the no-dropout output.
    print("\n--- Statistical Validation ---")
    n_trials = 5
    outputs = []
    for trial in range(n_trials):
        O_t, _, _ = cpu_attention_with_dropout(
            Q_f32,
            K_f32,
            V_f32,
            prob.scale,
            args.p_drop,
            args.seed + trial,
        )
        outputs.append(O_t)

    O_mean = np.mean(outputs, axis=0)
    O_std = np.std(outputs, axis=0)

    mean_diff = float(np.abs(O_mean - O_no_drop).max())
    max_std = float(O_std.max())

    print(f" Trials: {n_trials}")
    print(f" Mean vs no-drop: {mean_diff:.4e} (should be small)")
    print(f" Max output stddev: {max_std:.4e}")
    print(" E[dropout(P)] = P (unbiased estimator)")

    gpu_ok = True
    if gpu_output is not None:
        validator = FmhaValidator(rtol=1e-2, atol=1e-2)
        gpu_ok, max_abs, _ = validator.check(gpu_output, O_no_drop)
        print(
            f"\n GPU vs CPU (no-drop): max_err={max_abs:.2e}, {'PASS' if gpu_ok else 'FAIL'}"
        )

    # --- Summary ---
    print("\n" + "=" * 70)
    print(f" Dropout: p_drop={args.p_drop}, seed={args.seed}")
    print(
        f" LSE: Stored for backward pass (shape [{prob.batch},{prob.nhead_q},{prob.seqlen_q}])"
    )
    print(" Key: Dropout is stochastic; validate statistically, not exactly")
    print(" GPU: Prebuilt kernel does not support dropout")
    print(" Status: DEMO")
    print("=" * 70)

    # Propagate a failed GPU-vs-CPU comparison through the exit status.
    return 0 if gpu_ok else 1


if __name__ == "__main__":
    sys.exit(main())
|
||||
217
dispatcher/examples/fmha/python/15_gqa_fmha.py
Normal file
217
dispatcher/examples/fmha/python/15_gqa_fmha.py
Normal file
@@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 15: Grouped-Query Attention (GQA / MQA)
|
||||
|
||||
Demonstrates GQA with various nhead_q:nhead_k ratios:
|
||||
- 1:1 (MHA) -- Standard multi-head attention
|
||||
- 2:1 -- Each KV head serves 2 query heads
|
||||
- 4:1 -- Each KV head serves 4 query heads
|
||||
- 8:1 -- Each KV head serves 8 query heads
|
||||
- 16:1 (MQA) -- Single KV head serves all query heads
|
||||
|
||||
GQA reduces KV cache memory and bandwidth while maintaining quality.
|
||||
CPU reference uses np.repeat to expand K,V heads to match Q heads.
|
||||
|
||||
Usage:
|
||||
python3 15_gqa_fmha.py
|
||||
python3 15_gqa_fmha.py --nhead-q 32
|
||||
python3 15_gqa_fmha.py --seqlen 256
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Run the GQA/MQA example and print timing, error and memory tables.

    For every ratio in 1, 2, 4, 8, 16 that divides ``--nhead-q``, builds an
    FmhaProblem with ``nhead_k = nhead_q // ratio``, computes a CPU reference
    and -- when a GPU runner is available and the run succeeds -- validates
    the GPU output against it. Afterwards prints a KV-cache size table and a
    CPU-only check comparing GQA against attention with explicitly repeated
    K/V heads.

    Returns:
        0 when every validated ratio passed (rows without GPU output count as
        passed), 1 when at least one validation failed.
    """
    parser = argparse.ArgumentParser(description="GQA / MQA Attention")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead-q", type=int, default=16)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 15: Grouped-Query Attention (GQA / MQA)")
    print("=" * 70)

    hq = args.nhead_q

    # Keep only the ratios that split the query heads evenly.
    gqa_ratios = [ratio for ratio in (1, 2, 4, 8, 16) if hq % ratio == 0]

    print(f"\n nhead_q: {hq}")
    print(f" Ratios: {', '.join(f'{r}:1' for r in gqa_ratios)}")
    print(f" Problem: B={args.batch} S={args.seqlen} D={args.hdim}")

    # --- Try GPU runner ---
    runner = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f" GPU: Loaded (JIT build: {setup.build_time_s:.1f}s)")
    else:
        print(f" GPU: Not available ({setup.error})")

    validator = FmhaValidator(rtol=1e-2, atol=1e-2)

    print(
        f"\n {'#':<3} {'Ratio':<8} {'nhead_q':>8} {'nhead_k':>8} {'KV_save':>8} "
        f"{'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':>8}"
    )
    print(" " + "-" * 82)

    results = []
    for i, ratio in enumerate(gqa_ratios, 1):
        hk = hq // ratio
        kv_saving = (1.0 - hk / hq) * 100

        prob = FmhaProblem(
            batch=args.batch,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )

        np.random.seed(42 + i)
        Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
        K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float16)
        V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float16)

        O_ref = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            prob.scale,
        )

        # GPU attempt
        time_str = "---"
        tflops_str = "---"
        gpu_out = None
        if runner is not None:
            res = runner.run(Q, K, V, prob)
            if res.success:
                gpu_out = res.output
                time_str = f"{res.time_ms:.4f}"
                tflops_str = f"{res.tflops:.2f}"

        if gpu_out is not None:
            ok, max_abs, _ = validator.check(gpu_out, O_ref)
            tag = "PASS" if ok else "FAIL"
            err_str = f"{max_abs:.2e}"
        else:
            ok = True
            tag = "DEMO"
            err_str = "---"
            max_abs = 0.0

        label = f"{ratio}:1"
        if ratio == 1:
            label += " MHA"
        elif hk == 1:
            label += " MQA"

        print(
            f" {i:<3} {label:<8} {hq:>8} {hk:>8} {kv_saving:>7.0f}% "
            f"{time_str:>10} {tflops_str:>10} {err_str:>10} {tag:>8}"
        )
        results.append((ratio, hk, ok, max_abs))

    # --- Memory analysis ---
    print("\n--- KV Cache Memory Analysis ---")
    # Bytes for K plus V (factor 2) stored as fp16 (2 bytes per element).
    base_kv_size = args.batch * hq * args.seqlen * args.hdim * 2 * 2  # K+V, fp16

    print(f"\n {'Ratio':<8} {'nhead_k':>8} {'KV Size':>12} {'Savings':>10}")
    print(" " + "-" * 42)

    for ratio in gqa_ratios:
        hk = hq // ratio
        kv_size = args.batch * hk * args.seqlen * args.hdim * 2 * 2
        saving = (1.0 - kv_size / base_kv_size) * 100
        if kv_size < 1024 * 1024:
            size_str = f"{kv_size / 1024:.1f} KB"
        else:
            size_str = f"{kv_size / (1024 * 1024):.2f} MB"
        print(f" {ratio}:1{'':<4} {hk:>8} {size_str:>12} {saving:>9.0f}%")

    # --- GQA correctness: verify np.repeat equivalence ---
    print("\n--- GQA Equivalence Check ---")
    prob_gqa = FmhaProblem(
        batch=1,
        nhead_q=8,
        nhead_k=2,
        seqlen_q=64,
        seqlen_k=64,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    np.random.seed(99)
    Q_g = (np.random.randn(*prob_gqa.q_shape()) * 0.1).astype(np.float32)
    K_g = (np.random.randn(*prob_gqa.k_shape()) * 0.1).astype(np.float32)
    V_g = (np.random.randn(*prob_gqa.v_shape()) * 0.1).astype(np.float32)

    O_gqa = cpu_attention_fwd(Q_g, K_g, V_g, prob_gqa.scale)

    # Expand the 2 KV heads to 8 by repeating each one 4 times along the
    # head axis, then run plain multi-head attention on the expanded tensors.
    K_exp = np.repeat(K_g, 4, axis=1)
    V_exp = np.repeat(V_g, 4, axis=1)
    prob_mha = FmhaProblem(
        batch=1,
        nhead_q=8,
        nhead_k=8,
        seqlen_q=64,
        seqlen_k=64,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )
    O_mha = cpu_attention_fwd(Q_g, K_exp, V_exp, prob_mha.scale)

    equiv_err = float(np.abs(O_gqa - O_mha).max())
    print(f" GQA(4:1) vs MHA(expanded): max_err = {equiv_err:.2e}")
    print(" cpu_attention_fwd handles GQA internally via np.repeat")

    # --- Summary ---
    all_ok = all(ok for _, _, ok, _ in results)
    print("\n" + "=" * 70)
    print(f" GQA ratios tested: {len(gqa_ratios)}")
    print(" MHA (1:1): All heads have unique KV (baseline)")
    print(" GQA (N:1): N query heads share one KV head")
    print(" MQA (H:1): All query heads share single KV head (max saving)")
    print(" GPU: Prebuilt kernel supports GQA via nhead_q != nhead_k")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)

    # Propagate a validation failure through the process exit status.
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())
|
||||
267
dispatcher/examples/fmha/python/16_splitkv_fmha.py
Normal file
267
dispatcher/examples/fmha/python/16_splitkv_fmha.py
Normal file
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 16: Split-KV Attention and Paged KV Cache
|
||||
|
||||
Demonstrates:
|
||||
1. Split-KV: partitioning KV across multiple GPU splits for long sequences
|
||||
2. Two-stage execution plan: split (per-partition attention) + combine (merge)
|
||||
3. Paged KV cache with configurable page_block_size
|
||||
4. CPU reference for split-KV correctness verification
|
||||
|
||||
Split-KV is critical for long-context inference where seqlen_k >> seqlen_q
|
||||
(decoding with long history). Each split processes a chunk of KV independently,
|
||||
then partial results are combined with log-sum-exp correction.
|
||||
|
||||
Usage:
|
||||
python3 16_splitkv_fmha.py
|
||||
python3 16_splitkv_fmha.py --num-splits 4
|
||||
python3 16_splitkv_fmha.py --seqlen-k 2048 --page-size 128
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def cpu_splitkv_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    num_splits: int,
) -> tuple:
    """CPU reference: split-KV attention with LSE-based combining.

    Stage 1 (split):   Compute partial attention for each KV chunk
    Stage 2 (combine): Merge partial results using log-sum-exp correction

    Args:
        Q: queries, indexed as [batch, nhead_q, seqlen_q, hdim].
        K: keys, indexed as [batch, nhead_k, seqlen_k, hdim].
        V: values, indexed as [batch, nhead_k, seqlen_k, hdim_v].
        scale: factor applied to Q @ K^T before the softmax.
        num_splits: number of KV partitions; must be >= 1.

    Returns: (O_final, partial_Os, partial_lses)
        O_final:      [batch, nhead_q, seqlen_q, hdim_v] float32
        partial_Os:   [num_splits, batch, nhead_q, seqlen_q, hdim_v] float32
        partial_lses: [num_splits, batch, nhead_q, seqlen_q] float32
                      (-inf for splits that received no keys)

    Raises:
        ValueError: if num_splits < 1, or if nhead_q is not a multiple of
            nhead_k (GQA/MQA head grouping is impossible).
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be >= 1, got {num_splits}")

    batch, nhead, seqlen_q, hdim = Q.shape
    nhead_k = K.shape[1]

    # GQA / MQA: expand the shared KV heads so every query head has a
    # matching KV head before the batched matmul.
    if nhead != nhead_k:
        if nhead_k == 0 or nhead % nhead_k != 0:
            raise ValueError(
                f"nhead_q ({nhead}) must be a multiple of nhead_k ({nhead_k})"
            )
        ratio = nhead // nhead_k
        K = np.repeat(K, ratio, axis=1)
        V = np.repeat(V, ratio, axis=1)

    seqlen_k = K.shape[2]
    hdim_v = V.shape[3]

    # Ceiling division so the last chunk absorbs the remainder.
    chunk_size = (seqlen_k + num_splits - 1) // num_splits

    partial_Os = np.zeros(
        (num_splits, batch, nhead, seqlen_q, hdim_v), dtype=np.float32
    )
    # -inf marks "no keys seen"; such splits get zero weight in stage 2.
    partial_lses = np.full(
        (num_splits, batch, nhead, seqlen_q), -np.inf, dtype=np.float32
    )

    # Stage 1: independent softmax-attention over each KV chunk
    for s in range(num_splits):
        k_start = s * chunk_size
        k_end = min(k_start + chunk_size, seqlen_k)
        if k_start >= seqlen_k:
            # Ceiling division can leave trailing splits without any keys.
            break

        K_chunk = K[:, :, k_start:k_end, :]
        V_chunk = V[:, :, k_start:k_end, :]

        S = np.matmul(Q, K_chunk.transpose(0, 1, 3, 2)) * scale
        S_max = S.max(axis=-1, keepdims=True)  # subtract for numerical stability
        S_exp = np.exp(S - S_max)
        S_sum = S_exp.sum(axis=-1, keepdims=True)

        partial_Os[s] = np.matmul(S_exp / S_sum, V_chunk)
        partial_lses[s] = np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)

    # Stage 2: Combine using LSE correction
    global_lse = np.max(partial_lses, axis=0)  # [batch, nhead, seqlen_q]

    O_final = np.zeros((batch, nhead, seqlen_q, hdim_v), dtype=np.float32)
    weight_sum = np.zeros((batch, nhead, seqlen_q), dtype=np.float32)

    for s in range(num_splits):
        # exp(lse_s - lse_max) is proportional to the softmax mass of split s.
        correction = np.exp(partial_lses[s] - global_lse)
        correction = correction[..., np.newaxis]
        O_final += partial_Os[s] * correction
        weight_sum += correction.squeeze(-1)

    O_final = O_final / weight_sum[..., np.newaxis]
    return O_final, partial_Os, partial_lses
|
||||
|
||||
|
||||
def make_page_table(batch: int, seqlen_k: int, page_size: int) -> tuple:
    """Build an identity-mapped page table for a paged KV cache.

    Every sequence receives the same number of pages (enough to hold
    seqlen_k positions, rounding up), and physical page ids are handed
    out consecutively, one row of the table per batch entry.

    Returns: (page_table, num_pages_per_seq, total_pages)
    """
    # Round up: a partially filled last page still occupies a whole page.
    n_per_seq = (seqlen_k + page_size - 1) // page_size
    n_total = batch * n_per_seq

    physical_ids = np.arange(n_total, dtype=np.int32)
    table = np.reshape(physical_ids, (batch, n_per_seq))
    return table, n_per_seq, n_total
|
||||
|
||||
|
||||
def main():
    """Run the split-KV / paged-KV demo.

    Steps: parse CLI arguments, build random fp32/fp16 Q/K/V, compute the
    full-attention CPU reference, try the GPU kernel, compare the CPU
    split-KV reference against full attention for several split counts,
    print paged-KV layout statistics, and print a summary.

    The summary status reflects the split-KV comparisons actually performed
    instead of being hard-coded to PASS.

    Returns:
        0 (process exit code), matching the other examples in this directory.
    """
    parser = argparse.ArgumentParser(description="Split-KV and Paged KV Cache")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead-q", type=int, default=16)
    parser.add_argument("--nhead-k", type=int, default=16)
    parser.add_argument(
        "--seqlen-q", type=int, default=1, help="Typically 1 for decoding"
    )
    parser.add_argument("--seqlen-k", type=int, default=1024)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--num-splits", type=int, default=0, help="0 = test multiple")
    parser.add_argument("--page-size", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 16: Split-KV Attention and Paged KV Cache")
    print("=" * 70)

    sq, sk = args.seqlen_q, args.seqlen_k
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead_q,
        nhead_k=args.nhead_k,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    print(
        f"\n Problem: B={prob.batch} Hq={prob.nhead_q} Hk={prob.nhead_k} "
        f"Sq={sq} Sk={sk} D={args.hdim}"
    )
    print(f" Use case: Decoding (Sq={sq} << Sk={sk})")

    # --- Generate data ---
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    Q_fp16 = Q_f32.astype(np.float16)
    K_fp16 = K_f32.astype(np.float16)
    V_fp16 = V_f32.astype(np.float16)

    # --- Full attention reference ---
    O_full = cpu_attention_fwd(Q_f32, K_f32, V_f32, prob.scale)

    # --- GPU attempt ---
    print("\n--- GPU Execution ---")
    gpu_output = None
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
        if res.success:
            gpu_output = res.output
            print(f" GPU (full): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS")
        else:
            print(" GPU: Kernel returned failure")

    # --- Split-KV with various num_splits ---
    print("\n--- Split-KV Execution Plan ---")
    split_configs = [args.num_splits] if args.num_splits > 0 else [1, 2, 3, 4, 8]
    # More splits than keys would leave every extra split empty; drop them.
    split_configs = [s for s in split_configs if s <= sk]

    validator = FmhaValidator(rtol=1e-5, atol=1e-5)

    print("\n Plan stages:")
    print(" Stage 1 (split): Compute partial O and LSE per KV chunk")
    print(" Stage 2 (combine): Merge with exp(lse_i - lse_max) correction")

    print(
        f"\n {'#':<3} {'Splits':>7} {'ChunkSz':>8} {'Stage1':>8} {'Stage2':>8} "
        f"{'MaxErr':>10} {'Status':>8}"
    )
    print(" " + "-" * 58)

    # Tracks whether every split-KV comparison passed; drives the summary.
    all_ok = True
    for i, ns in enumerate(split_configs, 1):
        chunk_size = (sk + ns - 1) // ns

        O_split, partial_Os, partial_lses = cpu_splitkv_attention(
            Q_f32,
            K_f32,
            V_f32,
            prob.scale,
            ns,
        )

        ok, max_abs, _ = validator.check(O_split, O_full)
        all_ok = all_ok and bool(ok)
        tag = "PASS" if ok else "FAIL"

        print(
            f" {i:<3} {ns:>7} {chunk_size:>8} {'split':>8} {'combine':>8} "
            f"{max_abs:>10.2e} {tag:>8}"
        )

    # --- Paged KV Cache ---
    print("\n--- Paged KV Cache ---")
    page_sizes = [64, 128, 256]

    print(
        f"\n {'PageSize':>9} {'Pages/Seq':>10} {'TotalPages':>11} {'Utilization':>12}"
    )
    print(" " + "-" * 46)

    for ps in page_sizes:
        pt, pps, tp = make_page_table(args.batch, sk, ps)
        used_slots = args.batch * sk
        total_slots = tp * ps
        util = used_slots / total_slots * 100
        print(f" {ps:>9} {pps:>10} {tp:>11} {util:>11.1f}%")

    print(f"\n Page table example (batch=0, page_size={args.page_size}):")
    pt, pps, _ = make_page_table(args.batch, sk, args.page_size)
    pages_str = ", ".join(str(p) for p in pt[0, : min(8, pps)])
    if pps > 8:
        pages_str += f" ... ({pps} pages)"
    print(f" [{pages_str}]")
    print(" Maps logical KV positions -> physical page indices")

    # --- GPU validation if available ---
    if gpu_output is not None:
        print("\n--- GPU vs Full-Attention Reference ---")
        # fp16 GPU output is compared with a looser tolerance than the
        # fp32 CPU-vs-CPU check above.
        val = FmhaValidator(rtol=1e-2, atol=1e-2)
        ok, max_abs, max_rel = val.check(gpu_output, O_full)
        print(
            f" max_abs={max_abs:.2e}, max_rel={max_rel:.2e}, {'PASS' if ok else 'FAIL'}"
        )

    # --- Summary ---
    if all_ok:
        status = "PASS (CPU split-KV matches full attention)"
    else:
        status = "FAIL (CPU split-KV deviates from full attention)"

    print("\n" + "=" * 70)
    print(f" Split-KV: Partitions seqlen_k={sk} across splits")
    print(" Plan: 2-stage (split partial O/LSE -> combine with correction)")
    print(f" Paged KV: page_block_size={args.page_size} ({pps} pages/seq)")
    print(" Use case: Long-context decoding (Sq << Sk)")
    print(" GPU: Prebuilt kernel runs full attention (no split-KV)")
    print(f" Status: {status}")
    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
362
dispatcher/examples/fmha/python/17_appendkv_fmha.py
Normal file
362
dispatcher/examples/fmha/python/17_appendkv_fmha.py
Normal file
@@ -0,0 +1,362 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 17: AppendKV with RoPE Integration
|
||||
|
||||
Demonstrates:
|
||||
1. KV cache append operation (new tokens added to existing cache)
|
||||
2. RoPE (Rotary Position Embedding) integration:
|
||||
- Interleaved: pairs (x0,x1), (x2,x3), ... rotated together
|
||||
- Half-rotated: first half and second half rotated
|
||||
3. Paged KV cache with page_block_size and cache_batch_idx
|
||||
4. CPU reference for RoPE-transformed KV append
|
||||
|
||||
AppendKV is the first stage of a decode step: new K,V tokens are
|
||||
RoPE-transformed and appended to the paged cache before attention.
|
||||
|
||||
Usage:
|
||||
python3 17_appendkv_fmha.py
|
||||
python3 17_appendkv_fmha.py --rope interleaved
|
||||
python3 17_appendkv_fmha.py --seqlen-new 4 --page-size 64
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def make_rotary_cos_sin(
    max_seqlen: int,
    hdim: int,
    base: float = 10000.0,
) -> tuple:
    """Precompute the rotary-embedding cosine and sine lookup tables.

    Row p holds the rotation angles for position p; column i uses the
    inverse frequency base ** -(i / (hdim // 2)).

    Returns: (cos_table, sin_table) each of shape [max_seqlen, hdim//2]
    """
    half_dim = hdim // 2

    # Geometric frequency ladder: 1, base^(-1/half), base^(-2/half), ...
    exponents = np.arange(0, half_dim, dtype=np.float32) / half_dim
    inv_freq = 1.0 / (base**exponents)

    positions = np.arange(max_seqlen, dtype=np.float32)
    angles = np.outer(positions, inv_freq)

    cos_table = np.cos(angles).astype(np.float32)
    sin_table = np.sin(angles).astype(np.float32)
    return cos_table, sin_table
|
||||
|
||||
|
||||
def apply_rope_interleaved(
    x: np.ndarray, cos: np.ndarray, sin: np.ndarray, start_pos: int
) -> np.ndarray:
    """Rotate adjacent element pairs (x0,x1), (x2,x3), ... with RoPE.

    The token at index t along the second-to-last axis uses table row
    start_pos + t.

    x: [..., seqlen, hdim]
    cos, sin: [max_seqlen, hdim//2]
    """
    n_tokens, width = x.shape[-2], x.shape[-1]
    stop = start_pos + n_tokens

    # Prepend singleton axes so the table rows broadcast over x's leading dims.
    bcast_shape = (1,) * (x.ndim - 2) + (n_tokens, width // 2)
    c = cos[start_pos:stop, :].reshape(bcast_shape)
    s = sin[start_pos:stop, :].reshape(bcast_shape)

    evens = x[..., 0::2]
    odds = x[..., 1::2]

    rotated = np.empty_like(x)
    rotated[..., 0::2] = evens * c - odds * s
    rotated[..., 1::2] = odds * c + evens * s
    return rotated
|
||||
|
||||
|
||||
def apply_rope_half_rotated(
    x: np.ndarray, cos: np.ndarray, sin: np.ndarray, start_pos: int
) -> np.ndarray:
    """Rotate element i of the first half against element i of the second half.

    The token at index t along the second-to-last axis uses table row
    start_pos + t.

    x: [..., seqlen, hdim]
    cos, sin: [max_seqlen, hdim//2]
    """
    n_tokens, width = x.shape[-2], x.shape[-1]
    mid = width // 2
    stop = start_pos + n_tokens

    # Prepend singleton axes so the table rows broadcast over x's leading dims.
    bcast_shape = (1,) * (x.ndim - 2) + (n_tokens, mid)
    c = cos[start_pos:stop, :].reshape(bcast_shape)
    s = sin[start_pos:stop, :].reshape(bcast_shape)

    lower = x[..., :mid]
    upper = x[..., mid:]

    rotated = np.empty_like(x)
    rotated[..., :mid] = lower * c - upper * s
    rotated[..., mid:] = upper * c + lower * s
    return rotated
|
||||
|
||||
|
||||
def cpu_append_kv(
    k_cache: np.ndarray,
    v_cache: np.ndarray,
    k_new: np.ndarray,
    v_new: np.ndarray,
    cache_seqlen: int,
    rope_fn,
    cos: np.ndarray,
    sin: np.ndarray,
) -> tuple:
    """CPU reference for appending new K/V tokens behind the cached ones.

    Only K is passed through rope_fn (called as
    rope_fn(k_new, cos, sin, cache_seqlen)); V is stored as given. When
    rope_fn is None, K is stored as given too. The input caches are left
    untouched: the results are written into copies.

    k_cache/v_cache: [batch, nhead, max_seqlen, hdim]
    k_new/v_new: [batch, nhead, seqlen_new, hdim]

    Returns: (k_cache_updated, v_cache_updated)
    """
    # Destination window along the sequence axis.
    dst = slice(cache_seqlen, cache_seqlen + k_new.shape[2])

    k_to_store = k_new if rope_fn is None else rope_fn(k_new, cos, sin, cache_seqlen)

    k_updated = k_cache.copy()
    v_updated = v_cache.copy()
    k_updated[:, :, dst, :] = k_to_store
    v_updated[:, :, dst, :] = v_new
    return k_updated, v_updated
|
||||
|
||||
|
||||
def make_paged_cache(
    batch: int, nhead: int, total_pages: int, page_size: int, hdim: int
) -> tuple:
    """Allocate a zero-filled paged KV cache with an identity page mapping.

    Pages are split evenly across the batch (total_pages // batch each) and
    numbered consecutively; cache_batch_idx maps each batch entry to itself.

    Returns: (k_pages, v_pages, page_table, cache_batch_idx)
    """
    page_shape = (total_pages, nhead, page_size, hdim)
    k_pages = np.zeros(page_shape, dtype=np.float32)
    v_pages = np.zeros(page_shape, dtype=np.float32)

    per_seq = total_pages // batch
    physical_ids = np.arange(total_pages, dtype=np.int32)
    page_table = np.reshape(physical_ids, (batch, per_seq))

    cache_batch_idx = np.arange(batch, dtype=np.int32)
    return k_pages, v_pages, page_table, cache_batch_idx
|
||||
|
||||
|
||||
def main():
    """Run the AppendKV + RoPE demo.

    Steps: parse CLI arguments, build RoPE tables and random new K/V tokens,
    compare RoPE modes, append the new tokens to a randomly pre-filled
    cache on the CPU, print a paged-cache layout, try a GPU forward
    attention over the appended cache, and print how RoPE output varies
    with position.

    Fixes relative to the previous revision:
      * The GPU attention now runs on the cache *after* the append. Before,
        it sliced the untouched input cache, whose appended positions were
        still zero (cpu_append_kv returns copies).
      * The page count is rounded up, so a max_seqlen that is not a
        multiple of page_size still gets a page for its tail.

    Returns:
        0 (process exit code).
    """
    parser = argparse.ArgumentParser(description="AppendKV with RoPE Integration")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=16)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument(
        "--seqlen-new", type=int, default=1, help="New tokens to append"
    )
    parser.add_argument(
        "--cache-seqlen", type=int, default=512, help="Existing cache length"
    )
    parser.add_argument("--max-seqlen", type=int, default=2048)
    parser.add_argument("--page-size", type=int, default=128)
    parser.add_argument(
        "--rope", default="both", choices=["interleaved", "half", "none", "both"]
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Example 17: AppendKV with RoPE Integration")
    print("=" * 70)

    print(f"\n Batch: {args.batch}")
    print(f" Heads: {args.nhead}")
    print(f" HDim: {args.hdim}")
    print(f" New tokens: {args.seqlen_new}")
    print(f" Cache len: {args.cache_seqlen}")
    print(f" Max seqlen: {args.max_seqlen}")
    print(f" Page size: {args.page_size}")

    # --- Generate RoPE tables ---
    cos, sin = make_rotary_cos_sin(args.max_seqlen, args.hdim)
    print("\n RoPE base: 10000.0")
    print(f" Cos/Sin: [{args.max_seqlen}, {args.hdim // 2}]")

    # --- Generate new KV data ---
    np.random.seed(42)
    k_new = (
        np.random.randn(args.batch, args.nhead, args.seqlen_new, args.hdim) * 0.1
    ).astype(np.float32)
    v_new = (
        np.random.randn(args.batch, args.nhead, args.seqlen_new, args.hdim) * 0.1
    ).astype(np.float32)

    # --- RoPE comparison ---
    # argparse `choices` guarantees at least one entry ends up in this list.
    rope_modes = []
    if args.rope in ("interleaved", "both"):
        rope_modes.append(("interleaved", apply_rope_interleaved))
    if args.rope in ("half", "both"):
        rope_modes.append(("half_rotated", apply_rope_half_rotated))
    if args.rope == "none":
        rope_modes.append(("none", None))

    print("\n--- RoPE Modes ---")
    print(f"\n {'Mode':<16} {'K_new range':>20} {'K_rope range':>20} {'MaxDiff':>10}")
    print(" " + "-" * 70)

    for mode_name, rope_fn in rope_modes:
        if rope_fn is not None:
            k_roped = rope_fn(k_new, cos, sin, args.cache_seqlen)
        else:
            k_roped = k_new

        k_range = f"[{k_new.min():.4f}, {k_new.max():.4f}]"
        kr_range = f"[{k_roped.min():.4f}, {k_roped.max():.4f}]"
        diff = float(np.abs(k_roped - k_new).max())
        print(f" {mode_name:<16} {k_range:>20} {kr_range:>20} {diff:>10.4f}")

    # --- KV Cache Append ---
    print("\n--- KV Cache Append ---")
    k_cache = np.zeros(
        (args.batch, args.nhead, args.max_seqlen, args.hdim), dtype=np.float32
    )
    v_cache = np.zeros(
        (args.batch, args.nhead, args.max_seqlen, args.hdim), dtype=np.float32
    )

    np.random.seed(0)
    k_cache[:, :, : args.cache_seqlen, :] = (
        np.random.randn(args.batch, args.nhead, args.cache_seqlen, args.hdim) * 0.1
    ).astype(np.float32)
    v_cache[:, :, : args.cache_seqlen, :] = (
        np.random.randn(args.batch, args.nhead, args.cache_seqlen, args.hdim) * 0.1
    ).astype(np.float32)

    new_len = args.cache_seqlen + args.seqlen_new
    # Hold the most recent appended caches; the GPU step below consumes them.
    k_up, v_up = k_cache, v_cache
    for mode_name, rope_fn in rope_modes:
        k_up, v_up = cpu_append_kv(
            k_cache,
            v_cache,
            k_new,
            v_new,
            args.cache_seqlen,
            rope_fn,
            cos,
            sin,
        )
        k_appended = k_up[:, :, args.cache_seqlen : new_len, :]
        print(f"\n {mode_name}:")
        print(f" Cache after append: positions [0, {new_len})")
        print(f" New K range: [{k_appended.min():.4f}, {k_appended.max():.4f}]")
        print(
            f" Cache unchanged: {np.array_equal(k_up[:, :, : args.cache_seqlen, :], k_cache[:, :, : args.cache_seqlen, :])}"
        )

    # --- Paged KV Cache ---
    print("\n--- Paged KV Cache Layout ---")
    # Round up so a trailing partial page is still allocated.
    pages_per_seq = (args.max_seqlen + args.page_size - 1) // args.page_size
    total_pages = pages_per_seq * args.batch
    k_pages, v_pages, page_table, cache_batch_idx = make_paged_cache(
        args.batch,
        args.nhead,
        total_pages,
        args.page_size,
        args.hdim,
    )

    print(f" Total pages: {total_pages}")
    print(f" Pages per seq: {pages_per_seq}")
    print(f" Page size: {args.page_size}")
    print(f" K pages shape: {k_pages.shape}")
    print(f" Page table: {page_table.shape}")
    print(f" cache_batch_idx: {cache_batch_idx}")

    # Logical position of the first appended token -> (page, offset).
    current_page = args.cache_seqlen // args.page_size
    offset_in_page = args.cache_seqlen % args.page_size
    print(f"\n Append position: page={current_page}, offset={offset_in_page}")
    print(f" Physical page idx (batch 0): {page_table[0, current_page]}")

    # --- GPU attempt ---
    print("\n--- GPU Execution ---")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        prob = FmhaProblem(
            batch=args.batch,
            nhead_q=args.nhead,
            nhead_k=args.nhead,
            seqlen_q=args.seqlen_new,
            seqlen_k=new_len,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        Q_fp16 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float16)
        # Use the appended caches (last RoPE mode processed), not the
        # pre-append inputs, so the new tokens actually take part.
        K_full = k_up[:, :, :new_len, :].astype(np.float16)
        V_full = v_up[:, :, :new_len, :].astype(np.float16)
        res = runner.run(Q_fp16, K_full, V_full, prob)
        if res.success:
            print(
                f" Attention after append: {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS"
            )
        else:
            # NOTE(review): nesting of the "Note:" line is inferred; the
            # scraped source lost its indentation. Confirm against upstream.
            print(" GPU: Kernel returned failure (appendkv not supported)")
            print(" Note: Prebuilt kernel does not support appendkv family")

    # --- RoPE position-dependency visualization ---
    print("\n--- RoPE Position Dependency ---")
    positions = [0, 128, 512, 1024]
    test_vec = np.ones((1, 1, 1, args.hdim), dtype=np.float32) * 0.1

    for rope_name, rope_fn in rope_modes:
        if rope_fn is None:
            continue
        print(f"\n {rope_name} (first 4 dims of rotated unit vector):")
        print(f" {'Position':>10} {'dim0':>8} {'dim1':>8} {'dim2':>8} {'dim3':>8}")
        for pos in positions:
            # Skip positions that fall outside the precomputed tables.
            if pos < args.max_seqlen:
                rotated = rope_fn(test_vec, cos, sin, pos)
                dims = rotated[0, 0, 0, :4]
                print(
                    f" {pos:>10} {dims[0]:>8.4f} {dims[1]:>8.4f} {dims[2]:>8.4f} {dims[3]:>8.4f}"
                )

    # --- Summary ---
    print("\n" + "=" * 70)
    print(
        f" AppendKV: Append {args.seqlen_new} new tokens at position {args.cache_seqlen}"
    )
    print(f" RoPE modes: {', '.join(m for m, _ in rope_modes)}")
    print(f" Paged cache: {total_pages} pages x {args.page_size} slots")
    print(" Pipeline: appendkv -> fwd_pagedkv (2-stage decode)")
    print(" GPU: Prebuilt supports fwd only (appendkv needs JIT)")
    print(" Status: DEMO")
    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
299
dispatcher/examples/fmha/python/18_backward_fmha.py
Normal file
299
dispatcher/examples/fmha/python/18_backward_fmha.py
Normal file
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 18: Backward Pass (dQ, dK, dV)
|
||||
|
||||
Demonstrates:
|
||||
1. Forward pass to obtain O and LSE
|
||||
2. Backward pass computing gradients dQ, dK, dV from dO
|
||||
3. Three-stage backward plan:
|
||||
- Stage 1 (dot_do_o): Compute D = rowsum(dO * O)
|
||||
- Stage 2 (dq_dk_dv): Compute dQ, dK, dV using D and LSE
|
||||
- Stage 3 (convert_dq): Optional dtype conversion for dQ
|
||||
4. CPU reference with analytical gradients
|
||||
5. Gradient checking via finite differences
|
||||
|
||||
Usage:
|
||||
python3 18_backward_fmha.py
|
||||
python3 18_backward_fmha.py --seqlen 128
|
||||
python3 18_backward_fmha.py --check-grad --eps 1e-3
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def cpu_attention_fwd_with_lse(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
) -> tuple:
    """Softmax attention forward pass that also exposes P and the LSE.

    When Q has more heads than K/V (axis 1), each KV head is repeated
    nhead_q // nhead_k times so the head axes line up.

    Returns: (O, P, lse)
    """
    heads_q, heads_kv = Q.shape[1], K.shape[1]
    if heads_q != heads_kv:
        group = heads_q // heads_kv
        K = np.repeat(K, group, axis=1)
        V = np.repeat(V, group, axis=1)

    logits = np.matmul(Q, np.transpose(K, (0, 1, 3, 2))) * scale

    # Shift by the row maximum before exponentiating for numerical stability.
    row_max = logits.max(axis=-1, keepdims=True)
    unnormalized = np.exp(logits - row_max)
    row_sum = unnormalized.sum(axis=-1, keepdims=True)

    P = unnormalized / row_sum
    out = np.matmul(P, V)

    # log-sum-exp of the scaled logits, one value per query row.
    lse = np.log(row_sum.squeeze(-1)) + row_max.squeeze(-1)
    return out, P, lse.astype(np.float32)
|
||||
|
||||
|
||||
def cpu_attention_bwd(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
) -> tuple:
    """Analytical attention gradients on the CPU.

    Given the forward output `out`, the attention weights `P` and the
    upstream gradient `dO`:

    Stage 1: D_i = sum_j(dO_ij * O_ij) (per query position)
    Stage 2: dS = P * (dO @ V^T - D)
             dQ = dS @ K * scale
             dK = dS^T @ Q * scale
             dV = P^T @ dO

    Returns: (dQ, dK, dV, D)
    """
    swap_last_two = (0, 1, 3, 2)

    # Stage 1: dot_do_o -- keep the trailing axis so it broadcasts against dP.
    row_dot = np.sum(dO * out, axis=-1, keepdims=True)

    # Stage 2: dq_dk_dv
    grad_P = np.matmul(dO, V.transpose(swap_last_two))
    grad_S = P * (grad_P - row_dot)

    grad_Q = np.matmul(grad_S, K) * scale
    grad_K = np.matmul(grad_S.transpose(swap_last_two), Q) * scale
    grad_V = np.matmul(P.transpose(swap_last_two), dO)

    return grad_Q, grad_K, grad_V, row_dot.squeeze(-1)
|
||||
|
||||
|
||||
def finite_difference_check(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    dO: np.ndarray,
    scale: float,
    eps: float = 1e-3,
    param_name: str = "Q",
    max_checks: int = 5,
) -> float:
    """Verify gradients via finite differences on a few elements.

    The scalar loss is sum(O * dO), so the analytical gradient of the loss
    with respect to Q/K/V is what cpu_attention_bwd returns for upstream
    gradient dO. A handful of randomly chosen elements of the selected
    parameter are perturbed by +/- eps and the central difference is
    compared with the analytical value.

    Args:
        Q, K, V: attention inputs. The selected one is perturbed in place
            and restored element by element; pass copies if that matters.
        dO: upstream gradient, same shape as the attention output.
        scale: softmax scale forwarded to the forward/backward references.
        eps: perturbation step.
        param_name: which input to check: "Q", "K" or "V".
        max_checks: upper bound on the number of elements sampled (drawn
            with np.random.choice, i.e. from NumPy's global RNG state).

    Returns:
        The largest relative error |fd - analytical| / (|fd| + 1e-8) over
        the sampled elements.

    Raises:
        KeyError: if param_name is not "Q", "K" or "V".
    """
    param_map = {"Q": Q, "K": K, "V": V}
    param = param_map[param_name]

    # One forward + one backward pass is enough; the analytical gradients do
    # not change while individual elements are probed below.
    O_ref, P_ref, _ = cpu_attention_fwd_with_lse(Q, K, V, scale)
    grads = cpu_attention_bwd(Q, K, V, O_ref, dO, P_ref, scale)

    grad_map = {"Q": 0, "K": 1, "V": 2}
    analytical_grad = grads[grad_map[param_name]]

    max_err = 0.0
    flat_indices = np.random.choice(
        param.size, min(max_checks, param.size), replace=False
    )

    for flat_idx in flat_indices:
        idx = np.unravel_index(flat_idx, param.shape)
        orig = param[idx]

        try:
            param[idx] = orig + eps
            O_plus = cpu_attention_fwd(Q, K, V, scale)
            loss_plus = (O_plus * dO).sum()

            param[idx] = orig - eps
            O_minus = cpu_attention_fwd(Q, K, V, scale)
            loss_minus = (O_minus * dO).sum()
        finally:
            # Always undo the perturbation, even if the forward pass raises.
            param[idx] = orig

        # Central difference: error is O(eps^2) in exact arithmetic.
        fd_grad = (loss_plus - loss_minus) / (2 * eps)
        an_grad = analytical_grad[idx]
        err = abs(fd_grad - an_grad) / (abs(fd_grad) + 1e-8)
        max_err = max(max_err, err)

    return max_err
|
||||
|
||||
|
||||
def main():
    """Walk through the CPU backward pass for attention and report on it.

    Stages shown: forward pass (O, P, LSE), dot_do_o, dq_dk_dv, the optional
    fp32 -> fp16 conversion of dQ, gradient statistics, an optional
    finite-difference check (--check-grad), a GPU forward run with
    validation, and a description of the backward plan.

    Returns:
        0 (process exit code).
    """
    cli = argparse.ArgumentParser(description="Backward Pass (dQ, dK, dV)")
    cli.add_argument("--arch", default=detect_gpu_arch())
    cli.add_argument("--batch", type=int, default=1)
    cli.add_argument("--nhead", type=int, default=4)
    cli.add_argument("--seqlen", type=int, default=64)
    cli.add_argument("--hdim", type=int, default=128)
    cli.add_argument(
        "--check-grad", action="store_true", help="Run finite-difference check"
    )
    cli.add_argument(
        "--eps", type=float, default=1e-3, help="Finite-difference epsilon"
    )
    args = cli.parse_args()

    banner = "=" * 70
    print(banner)
    print("Example 18: Backward Pass (dQ, dK, dV)")
    print(banner)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}")

    # --- Generate data ---
    # Draw order (Q, K, V, dO) is part of the reproducible random stream.
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)

    # --- Forward pass ---
    print("\n--- Stage 0: Forward Pass ---")
    out, P, lse = cpu_attention_fwd_with_lse(Q, K, V, prob.scale)
    print(f" O shape: {out.shape}")
    print(f" O range: [{out.min():.4f}, {out.max():.4f}]")
    print(f" LSE shape: {lse.shape}")
    print(f" LSE range: [{lse.min():.4f}, {lse.max():.4f}]")
    print(f" P sparsity (< 1e-6): {(P < 1e-6).sum() / P.size * 100:.1f}%")

    # --- Backward pass (3 stages) ---
    print("\n--- Stage 1: dot_do_o (D = rowsum(dO * O)) ---")
    D_full = (dO * out).sum(axis=-1)
    print(f" D shape: {D_full.shape}")
    print(f" D range: [{D_full.min():.6f}, {D_full.max():.6f}]")

    print("\n--- Stage 2: dq_dk_dv ---")
    dQ, dK, dV, D = cpu_attention_bwd(Q, K, V, out, dO, P, prob.scale)
    print(f" dQ shape: {dQ.shape}, range: [{dQ.min():.4e}, {dQ.max():.4e}]")
    print(f" dK shape: {dK.shape}, range: [{dK.min():.4e}, {dK.max():.4e}]")
    print(f" dV shape: {dV.shape}, range: [{dV.min():.4e}, {dV.max():.4e}]")

    print("\n--- Stage 3: convert_dq (optional fp32 -> fp16) ---")
    dQ_fp16 = dQ.astype(np.float16)
    # Round-trip through fp16 to measure what the conversion loses.
    convert_err = float(np.abs(dQ - dQ_fp16.astype(np.float32)).max())
    print(f" dQ fp32 -> fp16 max error: {convert_err:.2e}")

    # --- Gradient norms ---
    print("\n--- Gradient Statistics ---")
    print(
        f"\n {'Param':<6} {'L2 Norm':>12} {'Max Abs':>12} {'Mean Abs':>12} {'Shape'}"
    )
    print(" " + "-" * 60)
    for name, grad in zip(("dQ", "dK", "dV"), (dQ, dK, dV)):
        magnitude = np.abs(grad)
        l2 = float(np.sqrt((grad**2).sum()))
        ma = float(magnitude.max())
        mean_a = float(magnitude.mean())
        print(f" {name:<6} {l2:>12.4e} {ma:>12.4e} {mean_a:>12.4e} {grad.shape}")

    # --- Finite difference check ---
    if args.check_grad:
        print(f"\n--- Finite Difference Gradient Check (eps={args.eps}) ---")
        for pname in ("Q", "K", "V"):
            # Fresh copies: the checker perturbs its inputs in place.
            err = finite_difference_check(
                Q.copy(),
                K.copy(),
                V.copy(),
                dO,
                prob.scale,
                eps=args.eps,
                param_name=pname,
                max_checks=5,
            )
            tag = "PASS" if err < 1e-2 else "FAIL"
            print(f" d{pname}: max_rel_err = {err:.4e} {tag}")

    # --- GPU forward attempt ---
    print("\n--- GPU Execution ---")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        res = runner.run(
            Q.astype(np.float16), K.astype(np.float16), V.astype(np.float16), prob
        )
        if res.success:
            print(f" Forward GPU: {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS")
            validator = FmhaValidator(rtol=1e-2, atol=1e-2)
            ok, ma, _ = validator.check(res.output, out)
            print(f" Forward validation: max_err={ma:.2e}, {'PASS' if ok else 'FAIL'}")
        else:
            print(" Forward GPU: Kernel returned failure")
    else:
        print(f" JIT build failed: {setup.error}")
    # NOTE(review): nesting of the next line is inferred (the scraped source
    # lost its indentation); it is treated as unconditional. Confirm upstream.
    print(" Backward GPU: Not available (requires bwd family kernel)")

    # --- Backward plan structure ---
    print("\n--- Backward Plan Structure ---")
    print(" Stage 1: dot_do_o")
    print(f" Input: dO [{prob.o_shape()}], O [{prob.o_shape()}]")
    print(f" Output: D [{prob.batch}, {prob.nhead_q}, {prob.seqlen_q}]")
    print(" Stage 2: dq_dk_dv")
    print(" Input: Q, K, V, dO, LSE, D")
    print(" Output: dQ, dK, dV (in accumulator precision)")
    print(" Stage 3: convert_dq")
    print(" Input: dQ (fp32)")
    print(" Output: dQ (fp16)")

    # --- Summary ---
    print("\n" + banner)
    print(" Forward: O = softmax(Q @ K^T / sqrt(d)) @ V")
    print(" Backward: 3-stage plan (dot_do_o -> dq_dk_dv -> convert_dq)")
    print(f" Gradients: dQ [{dQ.shape}], dK [{dK.shape}], dV [{dV.shape}]")
    print(" GPU: Prebuilt supports forward only")
    print(" Status: DEMO")
    print(banner)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
344
dispatcher/examples/fmha/python/19_padding_fmha.py
Normal file
344
dispatcher/examples/fmha/python/19_padding_fmha.py
Normal file
@@ -0,0 +1,344 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 19: Batch Padding and Group Mode
|
||||
|
||||
Demonstrates:
|
||||
1. Batch mode with effective lengths (q_eff_lens, kv_eff_lens)
|
||||
- Padded to max length but only effective positions contribute
|
||||
2. Group mode with physical padding strides (s_qpad, s_kpad)
|
||||
- Variable-length sequences packed contiguously
|
||||
- seqstart pointers mark boundaries
|
||||
3. Comparing batch vs group mode memory efficiency
|
||||
|
||||
In batch mode, each sequence in the batch is padded to the same max length.
|
||||
In group mode, sequences are packed without padding using offset pointers,
|
||||
saving memory for batches with high length variance.
|
||||
|
||||
Usage:
|
||||
python3 19_padding_fmha.py
|
||||
python3 19_padding_fmha.py --batch 8
|
||||
python3 19_padding_fmha.py --max-seqlen 512
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def cpu_batch_padded_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    q_eff_lens: np.ndarray,
    kv_eff_lens: np.ndarray,
) -> np.ndarray:
    """CPU reference for padded batch attention with per-sequence effective lengths.

    Each batch entry is truncated to its effective query / key-value length,
    run through ``cpu_attention_fwd`` on its own, and the result is written
    back into the leading rows of a zero-initialised padded output.

    Args:
        Q: padded queries, indexed as [batch, nhead, max_seqlen_q, hdim].
        K: padded keys, sliced along axis 2 by ``kv_eff_lens``.
        V: padded values, sliced along axis 2 by ``kv_eff_lens``.
        scale: softmax scale forwarded to ``cpu_attention_fwd``.
        q_eff_lens: effective query length of each batch entry.
        kv_eff_lens: effective key/value length of each batch entry.

    Returns:
        float32 array of shape [batch, nhead, max_seqlen_q, V.shape[3]];
        rows at or beyond a sequence's effective query length stay zero.
    """
    n_seqs, n_heads, padded_q = Q.shape[0], Q.shape[1], Q.shape[2]
    result = np.zeros((n_seqs, n_heads, padded_q, V.shape[3]), dtype=np.float32)

    for idx in range(n_seqs):
        valid_q = q_eff_lens[idx]
        valid_kv = kv_eff_lens[idx]
        # Keep a leading batch axis of size 1 for the reference kernel.
        entry = slice(idx, idx + 1)
        attended = cpu_attention_fwd(
            Q[entry, :, :valid_q, :],
            K[entry, :, :valid_kv, :],
            V[entry, :, :valid_kv, :],
            scale,
        )
        result[idx, :, :valid_q, :] = attended[0]

    return result
|
||||
|
||||
|
||||
def pack_group_mode(
    Q_batch: np.ndarray,
    K_batch: np.ndarray,
    V_batch: np.ndarray,
    q_lens: np.ndarray,
    kv_lens: np.ndarray,
) -> tuple:
    """Pack padded batch tensors into contiguous group-mode tensors.

    The first ``q_lens[b]`` query rows and ``kv_lens[b]`` key/value rows of
    every batch entry are copied back-to-back along the sequence axis of a
    single batch-1 tensor.

    Returns: (Q_packed, K_packed, V_packed, seqstart_q, seqstart_k), where the
    seqstart arrays are int32, have ``batch + 1`` entries, start at 0 and hold
    the running offsets of each sequence inside the packed tensors.
    """
    n_seqs = Q_batch.shape[0]
    n_heads = Q_batch.shape[1]
    d_qk = Q_batch.shape[3]
    d_v = V_batch.shape[3]

    packed_q_len = int(q_lens.sum())
    packed_k_len = int(kv_lens.sum())

    Q_packed = np.zeros((1, n_heads, packed_q_len, d_qk), dtype=Q_batch.dtype)
    K_packed = np.zeros((1, n_heads, packed_k_len, d_qk), dtype=K_batch.dtype)
    V_packed = np.zeros((1, n_heads, packed_k_len, d_v), dtype=V_batch.dtype)

    # Running start offsets; the last element is where the next sequence goes.
    starts_q = [0]
    starts_k = [0]
    for idx in range(n_seqs):
        len_q = int(q_lens[idx])
        len_k = int(kv_lens[idx])
        q_begin, k_begin = starts_q[-1], starts_k[-1]
        q_end, k_end = q_begin + len_q, k_begin + len_k

        Q_packed[0, :, q_begin:q_end, :] = Q_batch[idx, :, :len_q, :]
        K_packed[0, :, k_begin:k_end, :] = K_batch[idx, :, :len_k, :]
        V_packed[0, :, k_begin:k_end, :] = V_batch[idx, :, :len_k, :]

        starts_q.append(q_end)
        starts_k.append(k_end)

    seqstart_q = np.array(starts_q, dtype=np.int32)
    seqstart_k = np.array(starts_k, dtype=np.int32)
    return Q_packed, K_packed, V_packed, seqstart_q, seqstart_k
|
||||
|
||||
|
||||
def cpu_group_attention(
    Q_packed: np.ndarray,
    K_packed: np.ndarray,
    V_packed: np.ndarray,
    scale: float,
    seqstart_q: np.ndarray,
    seqstart_k: np.ndarray,
    batch: int,
) -> np.ndarray:
    """CPU reference for group-mode attention over packed sequences.

    Sequence ``b`` occupies rows ``seqstart_q[b]:seqstart_q[b + 1]`` of the
    packed queries and rows ``seqstart_k[b]:seqstart_k[b + 1]`` of the packed
    keys/values. Each sequence is attended independently through
    ``cpu_attention_fwd`` and written into the same query rows of the output.

    Q_packed: [1, nhead, total_q, hdim]

    Returns a float32 array of shape [1, nhead, total_q, V_packed.shape[3]].
    """
    out_shape = (1, Q_packed.shape[1], Q_packed.shape[2], V_packed.shape[3])
    O_packed = np.zeros(out_shape, dtype=np.float32)

    for idx in range(batch):
        q_rows = slice(seqstart_q[idx], seqstart_q[idx + 1])
        kv_rows = slice(seqstart_k[idx], seqstart_k[idx + 1])

        attended = cpu_attention_fwd(
            Q_packed[:, :, q_rows, :],
            K_packed[:, :, kv_rows, :],
            V_packed[:, :, kv_rows, :],
            scale,
        )
        O_packed[0, :, q_rows, :] = attended[0]

    return O_packed
|
||||
|
||||
|
||||
def main():
    """Run Example 19: compare padded batch mode with packed group mode.

    Generates variable-length sequences, computes the CPU reference in both
    layouts, cross-validates them per sequence, attempts a GPU run on the fully
    padded tensors, and prints a memory/compute comparison.

    Returns:
        0 when the batch and group CPU results agree for every sequence,
        1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Batch Padding and Group Mode")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--max-seqlen", type=int, default=256)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Non-positive sizes would make the memory/compute ratios divide by zero.
    if min(args.batch, args.nhead, args.max_seqlen, args.hdim) < 1:
        parser.error("--batch, --nhead, --max-seqlen and --hdim must be positive")

    print("=" * 70)
    print("Example 19: Batch Padding and Group Mode")
    print("=" * 70)

    batch = args.batch
    nhead = args.nhead
    max_sq = max_sk = args.max_seqlen
    hdim = args.hdim
    scale = 1.0 / (hdim**0.5)

    # --- Variable-length sequences ---
    np.random.seed(args.seed)
    # randint needs low < high; clamp the lower bound so --max-seqlen < 32 works.
    min_len = min(32, max_sq)
    # Sorted descending; .copy() turns the reversed view into a contiguous array.
    q_eff_lens = np.sort(
        np.random.randint(min_len, max_sq + 1, size=batch).astype(np.int32)
    )[::-1].copy()
    kv_eff_lens = np.sort(
        np.random.randint(min_len, max_sk + 1, size=batch).astype(np.int32)
    )[::-1].copy()

    print(f"\n Batch: {batch}")
    print(f" Max seqlen: {max_sq}")
    print(f" HDim: {hdim}")
    print(f"\n {'Seq#':<6} {'q_len':>8} {'kv_len':>8} {'q_pad%':>8} {'kv_pad%':>8}")
    print(" " + "-" * 42)
    for b in range(batch):
        q_pad = (1.0 - q_eff_lens[b] / max_sq) * 100
        kv_pad = (1.0 - kv_eff_lens[b] / max_sk) * 100
        print(
            f" {b:<6} {q_eff_lens[b]:>8} {kv_eff_lens[b]:>8} {q_pad:>7.1f}% {kv_pad:>7.1f}%"
        )

    # --- Generate padded data ---
    Q_padded = (np.random.randn(batch, nhead, max_sq, hdim) * 0.1).astype(np.float32)
    K_padded = (np.random.randn(batch, nhead, max_sk, hdim) * 0.1).astype(np.float32)
    V_padded = (np.random.randn(batch, nhead, max_sk, hdim) * 0.1).astype(np.float32)

    # === BATCH MODE ===
    print("\n--- Batch Mode (padded) ---")
    O_batch = cpu_batch_padded_attention(
        Q_padded,
        K_padded,
        V_padded,
        scale,
        q_eff_lens,
        kv_eff_lens,
    )

    # 4 bytes per element: the tensors above are float32.
    batch_mem = batch * nhead * (max_sq + 2 * max_sk) * hdim * 4
    print(f" Q/K/V layout: [{batch}, {nhead}, {max_sq}, {hdim}]")
    print(f" Memory (Q+K+V): {batch_mem / 1024:.1f} KB")
    print(
        f" Wasted (avg): {(1.0 - q_eff_lens.mean() / max_sq) * 100:.1f}% (padding overhead)"
    )

    # === GROUP MODE ===
    print("\n--- Group Mode (packed) ---")
    Q_packed, K_packed, V_packed, seqstart_q, seqstart_k = pack_group_mode(
        Q_padded,
        K_padded,
        V_padded,
        q_eff_lens,
        kv_eff_lens,
    )

    total_q = int(q_eff_lens.sum())
    total_k = int(kv_eff_lens.sum())
    group_mem = nhead * (total_q + 2 * total_k) * hdim * 4

    print(f" Q_packed: [1, {nhead}, {total_q}, {hdim}]")
    print(f" K_packed: [1, {nhead}, {total_k}, {hdim}]")
    print(f" seqstart_q: {seqstart_q}")
    print(f" seqstart_k: {seqstart_k}")
    print(f" Memory (Q+K+V): {group_mem / 1024:.1f} KB")
    print(f" Saving vs batch: {(1.0 - group_mem / batch_mem) * 100:.1f}%")

    # Physical padding strides
    s_qpad = total_q
    s_kpad = total_k
    print("\n Physical strides:")
    print(f" s_qpad = {s_qpad} (total Q tokens)")
    print(f" s_kpad = {s_kpad} (total KV tokens)")

    O_group = cpu_group_attention(
        Q_packed,
        K_packed,
        V_packed,
        scale,
        seqstart_q,
        seqstart_k,
        batch,
    )

    # --- Cross-validate batch vs group ---
    print("\n--- Batch vs Group Validation ---")
    print(f"\n {'Seq#':<6} {'q_len':>8} {'MaxErr':>10} {'Status':>8}")
    print(" " + "-" * 36)

    all_ok = True
    for b in range(batch):
        ql = q_eff_lens[b]
        qs = seqstart_q[b]
        O_b_batch = O_batch[b, :, :ql, :]
        O_b_group = O_group[0, :, qs : qs + ql, :]
        max_err = float(np.abs(O_b_batch - O_b_group).max())
        ok = max_err < 1e-5
        all_ok = all_ok and ok
        print(f" {b:<6} {ql:>8} {max_err:>10.2e} {'PASS' if ok else 'FAIL':>8}")

    # --- GPU attempt ---
    print("\n--- GPU Execution ---")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        prob = FmhaProblem(
            batch=batch,
            nhead_q=nhead,
            nhead_k=nhead,
            seqlen_q=max_sq,
            seqlen_k=max_sk,
            hdim_q=hdim,
            hdim_v=hdim,
        )
        Q_fp16 = Q_padded.astype(np.float16)
        K_fp16 = K_padded.astype(np.float16)
        V_fp16 = V_padded.astype(np.float16)
        res = runner.run(Q_fp16, K_fp16, V_fp16, prob)
        if res.success:
            print(f" GPU (full padded): {res.time_ms:.4f} ms, {res.tflops:.2f} TFLOPS")
            print(
                " Note: GPU runs full padded attention; effective-length masking needs kernel support"
            )
        else:
            print(" GPU: Kernel returned failure")

    # --- Memory analysis ---
    print("\n--- Memory Efficiency Analysis ---")
    print(f"\n {'Metric':<24} {'Batch Mode':>14} {'Group Mode':>14} {'Ratio':>8}")
    print(" " + "-" * 64)

    batch_tokens_q = batch * max_sq
    group_tokens_q = total_q
    batch_tokens_k = batch * max_sk
    group_tokens_k = total_k

    # Attention work is per sequence (q_len * kv_len), so padded batch mode
    # costs batch * max_sq * max_sk score entries, not the product of the two
    # batch-wide token counts. Python ints avoid np.int32 overflow.
    batch_pairs = batch * max_sq * max_sk
    group_pairs = sum(int(q) * int(k) for q, k in zip(q_eff_lens, kv_eff_lens))

    print(
        f" {'Q tokens':<24} {batch_tokens_q:>14} {group_tokens_q:>14} {group_tokens_q / batch_tokens_q:>7.2f}x"
    )
    print(
        f" {'KV tokens':<24} {batch_tokens_k:>14} {group_tokens_k:>14} {group_tokens_k / batch_tokens_k:>7.2f}x"
    )
    print(
        f" {'Memory (KB)':<24} {batch_mem / 1024:>14.1f} {group_mem / 1024:>14.1f} {group_mem / batch_mem:>7.2f}x"
    )
    print(
        f" {'Compute (tokens)':<24} {batch_pairs:>14} {group_pairs:>14} "
        f"{group_pairs / batch_pairs:>7.2f}x"
    )

    # --- Summary ---
    print("\n" + "=" * 70)
    print(" Batch mode: Padded to max_seqlen, uses q_eff_lens/kv_eff_lens")
    print(" Group mode: Packed contiguously, uses seqstart pointers")
    print(f" Strides: s_qpad={s_qpad}, s_kpad={s_kpad}")
    print(f" Memory save: {(1.0 - group_mem / batch_mem) * 100:.1f}% with group mode")
    print(f" Batch==Group: {'PASS' if all_ok else 'FAIL'} (identical results)")
    print(" GPU: Prebuilt supports batch mode only")
    print(f" Status: {'PASS' if all_ok else 'FAIL'}")
    print("=" * 70)

    return 0 if all_ok else 1
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
120
dispatcher/examples/fmha/python/20_fp8_fmha.py
Normal file
120
dispatcher/examples/fmha/python/20_fp8_fmha.py
Normal file
@@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 20: FP8 FMHA Forward
|
||||
|
||||
Demonstrates FP8 data types (fp8bf16, fp8fp32) for FMHA forward
|
||||
with quantization scale (pertensor, blockscale).
|
||||
|
||||
Note: FP8 requires a kernel compiled with fp8bf16/fp8fp32 dtype.
|
||||
The prebuilt library has fp16 only, so this example shows the
|
||||
API pattern and CPU reference.
|
||||
|
||||
Usage:
|
||||
python3 20_fp8_fmha.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaKernelConfig,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
# FP8 variants listed by the example.
# Each entry is (data_type, qscale, human-readable description).
FP8_CONFIGS = [
    # Per-tensor quantization scale
    ("fp8bf16", "pertensor", "FP8 with BF16 output, per-tensor scale"),
    ("fp8fp32", "pertensor", "FP8 with FP32 output, per-tensor scale"),
    # Block quantization scale
    ("fp8bf16", "blockscale", "FP8 with BF16 output, block scale"),
]
|
||||
|
||||
|
||||
def main():
    """Run Example 20: list the FP8 FMHA configurations and run an FP16 baseline.

    Builds a kernel config for every entry of ``FP8_CONFIGS`` (none of them is
    executed here), prints the tolerance reference table, then JIT-builds and
    runs an fp16 kernel and compares it with the fp32 CPU reference.

    Returns:
        Always 0.
    """
    parser = argparse.ArgumentParser(description="FP8 FMHA Example")
    parser.add_argument("--arch", default=detect_gpu_arch())
    args = parser.parse_args()

    print("=" * 70)
    print("Example 20: FP8 FMHA Forward")
    print("=" * 70)

    prob = FmhaProblem(
        batch=2, nhead_q=4, nhead_k=4, seqlen_q=64, seqlen_k=64, hdim_q=128, hdim_v=128
    )

    print(f"\n Arch: {args.arch}")
    print(f" Shape: B={prob.batch} H={prob.nhead_q} S={prob.seqlen_q} D={prob.hdim_q}")

    # CPU reference (fp32 baseline)
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    O_ref = cpu_attention_fwd(Q, K, V, prob.scale)

    print("\n--- FP8 Configurations ---\n")
    print(f" {'#':<3} {'Dtype':<12} {'QScale':<12} {'Description':<45} {'Status':<6}")
    print(" " + "-" * 80)

    for i, (dtype, qscale, desc) in enumerate(FP8_CONFIGS, 1):
        # The config is only constructed to show the API pattern; it is not run.
        _cfg = FmhaKernelConfig(
            data_type=dtype,
            hdim_q=128,
            hdim_v=128,
            qscale=qscale,
            gfx_arch=args.arch,
        )

        # FP8 kernels need dedicated compilation
        status = "CPU-OK"
        print(f" {i:<3} {dtype:<12} {qscale:<12} {desc:<45} {status:<6}")

    # Show FP8 tolerance expectations
    print("\n--- FP8 Tolerance Reference ---")
    print(" fp8bf16: rtol=1e-2, atol=1.8e-1")
    print(" fp8fp32: rtol=1e-2, atol=1.8e-1")
    print(" fp8 raw: rtol=0, atol=16 (or 32 for >240 range)")

    # Run basic fp16 for comparison if prebuilt available
    print("\n--- FP16 Baseline (prebuilt) ---")
    config_fp16 = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config_fp16)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")
        Q16 = Q.astype(np.float16)
        K16 = K.astype(np.float16)
        V16 = V.astype(np.float16)
        result = runner.run(Q16, K16, V16, prob)
        if result.success:
            max_err = float(np.abs(result.output.astype(np.float32) - O_ref).max())
            print(f" FP16 baseline: {result.time_ms:.4f} ms, max_err={max_err:.2e}")
        else:
            # Report the failed run instead of silently printing nothing.
            print(" FP16 baseline: Kernel returned failure")

    print(f"\n{'=' * 70}")
    print(f" FP8 kernel configs demonstrated: {len(FP8_CONFIGS)}")
    print(" Note: Build fp8bf16/fp8fp32 kernels for GPU execution")
    print(" Status: PASS")
    print(f"{'=' * 70}")
    return 0
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
235
dispatcher/examples/fmha/python/21_logits_soft_cap_fmha.py
Normal file
235
dispatcher/examples/fmha/python/21_logits_soft_cap_fmha.py
Normal file
@@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 21: Logits Soft Cap FMHA
|
||||
|
||||
Demonstrates the logits soft cap feature, which prevents attention logits
|
||||
from growing unboundedly by applying: tanh(scores / soft_cap) * soft_cap
|
||||
before the softmax. This technique is used in models like Gemma-2 to
|
||||
stabilize training at large scale.
|
||||
|
||||
The prebuilt library does not include a logits_soft_cap kernel, so this
|
||||
example validates the CPU reference implementation and shows the API
|
||||
pattern for when a compiled kernel with logits=True is available.
|
||||
|
||||
Usage:
|
||||
python3 21_logits_soft_cap_fmha.py
|
||||
python3 21_logits_soft_cap_fmha.py --soft-cap 30.0
|
||||
python3 21_logits_soft_cap_fmha.py --seqlen 256
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def cpu_attention_fwd_logits_soft_cap(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    soft_cap: float,
) -> np.ndarray:
    """CPU reference attention with a soft cap applied to the logits.

    The scaled scores are squashed before the softmax with:
        scores = tanh(scores / soft_cap) * soft_cap

    When Q has more heads than K/V, the K/V heads are repeated
    ``nhead_q // nhead_k`` times along the head axis first.

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q] float32
        K: [batch, nhead_k, seqlen_k, hdim_q] float32
        V: [batch, nhead_k, seqlen_k, hdim_v] float32
        scale: softmax scaling factor (1/sqrt(hdim_q))
        soft_cap: logits soft cap value (e.g. 50.0)

    Returns:
        O: [batch, nhead_q, seqlen_q, hdim_v] float32
    """
    heads_q, heads_kv = Q.shape[1], K.shape[1]
    if heads_q != heads_kv:
        group = heads_q // heads_kv
        K = np.repeat(K, group, axis=1)
        V = np.repeat(V, group, axis=1)

    logits = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    logits = np.tanh(logits / soft_cap) * soft_cap

    # Softmax over the key axis, shifted by the row maximum for stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, V)
|
||||
|
||||
|
||||
def show_soft_cap_effect(scale: float, soft_cap: float):
    """Print a table showing how the soft cap changes a fixed set of raw scores.

    For each sample score the table lists the raw value, the value after
    multiplying by ``scale``, the value after ``tanh(x / soft_cap) * soft_cap``,
    and how much the magnitude shrank.
    """
    samples = np.array(
        [-100, -50, -20, -10, -5, 0, 5, 10, 20, 50, 100], dtype=np.float32
    )
    after_scale = samples * scale
    after_cap = np.tanh(after_scale / soft_cap) * soft_cap

    print(f"\n Soft cap effect (scale={scale:.4f}, soft_cap={soft_cap:.1f}):")
    print(
        f" {'Raw Score':>12} {'After Scale':>14} {'After Cap':>12} {'Reduction':>12}"
    )
    print(" " + "-" * 54)
    for idx in range(len(samples)):
        raw = samples[idx]
        scaled_val = after_scale[idx]
        capped_val = after_cap[idx]
        # A zero score cannot shrink; report 0 for it.
        if abs(scaled_val) > 0:
            shrink = abs(scaled_val) - abs(capped_val)
        else:
            shrink = 0
        print(f" {raw:>12.1f} {scaled_val:>14.4f} {capped_val:>12.4f} {shrink:>12.4f}")
|
||||
|
||||
|
||||
def main():
    """Run Example 21: logits soft cap CPU reference plus an uncapped GPU baseline.

    Steps: show the soft cap transformation, compare capped and uncapped CPU
    attention, sweep several cap values, check that a very large cap reproduces
    uncapped attention, print the GPU API pattern, and run the standard fp16
    kernel as a baseline.

    Returns:
        0 when the CPU self-consistency check passes, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Logits Soft Cap FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 21_logits_soft_cap_fmha.py # Default soft_cap=50
python3 21_logits_soft_cap_fmha.py --soft-cap 30.0 # Tighter cap
python3 21_logits_soft_cap_fmha.py --seqlen 256
""",
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument(
        "--soft-cap", type=float, default=50.0, help="Logits soft cap value"
    )
    args = parser.parse_args()

    # The cap is used as a divisor (scores / soft_cap); zero would yield NaNs.
    if args.soft_cap <= 0:
        parser.error("--soft-cap must be positive")

    print("=" * 70)
    print("Example 21: Logits Soft Cap FMHA")
    print("=" * 70)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    # Step 1: Demonstrate the soft cap transformation
    print("\nStep 1: Soft Cap Transformation")
    show_soft_cap_effect(prob.scale, args.soft_cap)

    # Step 2: CPU reference comparison -- with vs without soft cap
    print("\nStep 2: CPU Reference (with vs without soft cap)")

    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.5).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.5).astype(np.float32)
    V = (np.random.randn(*prob.v_shape()) * 0.5).astype(np.float32)

    O_no_cap = cpu_attention_fwd(Q, K, V, prob.scale)
    O_capped = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, args.soft_cap)

    diff = np.abs(O_no_cap - O_capped)
    print(f"\n Shape: {prob.q_shape()}")
    print(f" Soft cap: {args.soft_cap}")
    print(f" Output range (no cap): [{O_no_cap.min():.4f}, {O_no_cap.max():.4f}]")
    print(f" Output range (capped): [{O_capped.min():.4f}, {O_capped.max():.4f}]")
    print(f" Max diff (cap effect): {diff.max():.6e}")
    print(f" Mean diff (cap effect): {diff.mean():.6e}")

    # Step 3: Validate across different soft_cap values
    print("\nStep 3: Soft Cap Sweep")

    soft_cap_values = [10.0, 20.0, 30.0, 50.0, 100.0, 500.0]
    # Tight tolerance: used for fp32 CPU-vs-CPU comparison only.
    validator = FmhaValidator(rtol=1e-4, atol=1e-4)

    print(
        f"\n {'SoftCap':>10} {'OutRange':>20} {'vs NoCap MaxDiff':>18} {'vs NoCap MeanDiff':>18}"
    )
    print(" " + "-" * 70)

    for sc in soft_cap_values:
        O_sc = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, sc)
        d = np.abs(O_no_cap - O_sc)
        out_range = f"[{O_sc.min():.4f}, {O_sc.max():.4f}]"
        print(f" {sc:>10.1f} {out_range:>20} {d.max():>18.6e} {d.mean():>18.6e}")

    # Step 4: Self-consistency -- large soft_cap should approach no-cap result
    print("\nStep 4: Self-Consistency Check")

    O_large_cap = cpu_attention_fwd_logits_soft_cap(Q, K, V, prob.scale, 1e6)
    ok, max_abs, _ = validator.check(O_large_cap, O_no_cap)
    print(
        f" soft_cap=1e6 vs no_cap: max_err={max_abs:.2e} -> {'PASS' if ok else 'FAIL'}"
    )

    # Step 5: GPU API pattern (requires logits=True kernel)
    print("\nStep 5: GPU Kernel Pattern")
    print(" NOTE: The prebuilt library does not include a logits_soft_cap kernel.")
    print(" To run on GPU, compile a kernel with logits=True in the signature:")
    print()
    print(" config = FmhaKernelConfig(")
    print(" family='fwd', data_type='fp16', hdim_q=128, hdim_v=128,")
    print(" pipeline='qr_async',")
    print(" )")
    print(' # In codegen JSON, set: "logits": true')
    print()
    print(" The dispatcher will pass logits_soft_cap to the kernel arguments.")

    # Step 6: GPU run with standard kernel (no soft cap) for baseline
    print("\nStep 6: GPU Baseline (standard kernel, no soft cap)")

    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")

        Q_f16 = Q.astype(np.float16)
        K_f16 = K.astype(np.float16)
        V_f16 = V.astype(np.float16)

        result = runner.run(Q_f16, K_f16, V_f16, prob)
        if result.success:
            # fp16 kernel output vs fp32 reference needs a looser tolerance
            # than the 1e-4 used for the CPU-only checks above.
            gpu_validator = FmhaValidator(rtol=1e-2, atol=1e-2)
            ok_gpu, max_abs_gpu, _ = gpu_validator.check(result.output, O_no_cap)
            print(
                f" GPU (no cap): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")

    # Summary
    print("\n" + "=" * 70)
    print(" Logits soft cap: tanh(scores / cap) * cap before softmax")
    print(f" Large cap -> standard attention (verified: max_err={max_abs:.2e})")
    print(" Small cap -> output variance reduced, stabilizes training")
    print("=" * 70)

    # Signal a failed CPU self-consistency check through the exit code.
    return 0 if ok else 1
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
315
dispatcher/examples/fmha/python/22_sink_tokens_fmha.py
Normal file
315
dispatcher/examples/fmha/python/22_sink_tokens_fmha.py
Normal file
@@ -0,0 +1,315 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 22: Sink Token Attention FMHA
|
||||
|
||||
Demonstrates sink token attention where the first N "sink" tokens are
|
||||
always attended to regardless of the causal mask. This technique is used
|
||||
in StreamingLLM and similar approaches to keep a few initial tokens as
|
||||
attention anchors during long-context generation.
|
||||
|
||||
Mask format: t:left,right,sink -- a causal mask (top-left or bottom-right)
|
||||
where the first 'sink' positions are always unmasked.
|
||||
|
||||
The prebuilt library does not include a sink token kernel, so this
|
||||
example validates the CPU reference and shows the API pattern.
|
||||
|
||||
Usage:
|
||||
python3 22_sink_tokens_fmha.py
|
||||
python3 22_sink_tokens_fmha.py --sink-tokens 8
|
||||
python3 22_sink_tokens_fmha.py --seqlen 256 --window 64
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def make_causal_mask(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Standard causal (top-left) mask: attend only to positions <= current.

    Returns a float32 array of shape [seqlen_q, seqlen_k] where entry (i, j)
    is 1.0 when ``j <= i`` and 0.0 otherwise.
    """
    # Broadcast a column of query indices against a row of key indices
    # instead of filling the mask with a Python double loop.
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    return (k_idx <= q_idx).astype(np.float32)
|
||||
|
||||
|
||||
def make_causal_sink_mask(
    seqlen_q: int,
    seqlen_k: int,
    num_sink: int,
) -> np.ndarray:
    """Causal mask with sink tokens: always attend to first num_sink positions.

    For each query position i:
    - Always attend to positions [0, num_sink) (sink tokens)
    - Also attend to positions [j] where j <= i (standard causal)

    Returns a float32 array of shape [seqlen_q, seqlen_k] holding 1.0 for
    attended positions and 0.0 elsewhere.
    """
    # Vectorized form of the per-element rule ``j < num_sink or j <= i``.
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    attend = (k_idx < num_sink) | (k_idx <= q_idx)
    return attend.astype(np.float32)
|
||||
|
||||
|
||||
def make_sliding_window_sink_mask(
    seqlen_q: int,
    seqlen_k: int,
    window: int,
    num_sink: int,
) -> np.ndarray:
    """Sliding window mask with sink tokens.

    For each query position i:
    - Always attend to positions [0, num_sink) (sink tokens)
    - Attend to positions in [i - window + 1, i] (sliding window)

    Returns a float32 array of shape [seqlen_q, seqlen_k] holding 1.0 for
    attended positions and 0.0 elsewhere.
    """
    # Vectorized form of ``j < num_sink or (i - window + 1 <= j <= i)``.
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    in_window = (k_idx >= q_idx - window + 1) & (k_idx <= q_idx)
    attend = (k_idx < num_sink) | in_window
    return attend.astype(np.float32)
|
||||
|
||||
|
||||
def cpu_attention_fwd_masked(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    mask: np.ndarray,
) -> np.ndarray:
    """CPU reference: attention with explicit mask.

    When Q has more heads than K/V, the K/V heads are repeated
    ``nhead_q // nhead_k`` times along the head axis first. A query row whose
    mask is entirely zero produces an all-zero output row (instead of
    attending uniformly to every key).

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q] float32
        K: [batch, nhead_k, seqlen_k, hdim_q] float32
        V: [batch, nhead_k, seqlen_k, hdim_v] float32
        scale: softmax scale
        mask: [seqlen_q, seqlen_k] binary mask (1=attend, 0=ignore)

    Returns:
        O: [batch, nhead_q, seqlen_q, hdim_v] float32
    """
    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    if nhead_q != nhead_k:
        ratio = nhead_q // nhead_k
        K = np.repeat(K, ratio, axis=1)
        V = np.repeat(V, ratio, axis=1)

    # Broadcast the [seqlen_q, seqlen_k] mask over batch and head axes.
    keep = mask[np.newaxis, np.newaxis, :, :] > 0

    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    neg_inf = np.finfo(np.float32).min
    S = np.where(keep, S, neg_inf)

    S_max = S.max(axis=-1, keepdims=True)
    # Force masked positions to exactly zero weight. Without this, a fully
    # masked row has S == S_max everywhere and exp(0) == 1 for every key.
    S_exp = np.where(keep, np.exp(S - S_max), 0.0)
    denom = S_exp.sum(axis=-1, keepdims=True)
    # Fully masked rows have a zero denominator; divide by 1 to keep them zero.
    P = S_exp / np.where(denom > 0, denom, 1.0)
    return np.matmul(P, V)
|
||||
|
||||
|
||||
def print_mask(mask: np.ndarray, name: str, max_display: int = 16):
    """Print the top-left corner of a 2-D mask, one text row per query position.

    At most ``max_display`` rows and columns are shown; positive entries are
    drawn as "1" and all others as ".".
    """
    n_rows, n_cols = mask.shape
    shown_rows = min(n_rows, max_display)
    shown_cols = min(n_cols, max_display)

    print(f"\n {name} ({n_rows}x{n_cols}, showing {shown_rows}x{shown_cols}):")
    for r in range(shown_rows):
        cells = []
        for c in range(shown_cols):
            if mask[r, c] > 0:
                cells.append("1")
            else:
                cells.append(".")
        print(f" q{r:02d}: " + "".join(cells))
|
||||
|
||||
|
||||
def main():
    """Walk through sink-token attention masks on CPU, then run a GPU baseline.

    Steps: visualize the mask patterns, compare CPU reference outputs for
    each mask, report attention density, sweep the sink-token count, print
    the kernel-configuration hint, and finally run the unmasked GPU kernel
    against the unmasked CPU reference. Always returns 0.
    """
    parser = argparse.ArgumentParser(
        description="Sink Token Attention FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    for flag, default in (
        ("--batch", 2),
        ("--nhead", 8),
        ("--seqlen", 128),
        ("--hdim", 128),
    ):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument(
        "--sink-tokens", type=int, default=4, help="Number of sink tokens"
    )
    parser.add_argument("--window", type=int, default=32, help="Sliding window size")
    args = parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("Example 22: Sink Token Attention FMHA")
    print(banner)

    # Self-attention: query and key sequence lengths are the same.
    sq = sk = args.seqlen
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    # Step 1: Visualize mask patterns
    print("\nStep 1: Mask Patterns")

    causal = make_causal_mask(sq, sk)
    causal_sink = make_causal_sink_mask(sq, sk, args.sink_tokens)
    window_sink = make_sliding_window_sink_mask(sq, sk, args.window, args.sink_tokens)

    vis_size = min(16, sq)
    previews = (
        (causal, "Causal (standard)"),
        (causal_sink, f"Causal + {args.sink_tokens} sink tokens"),
        (window_sink, f"Window({args.window}) + {args.sink_tokens} sink tokens"),
    )
    for pattern, title in previews:
        print_mask(pattern[:vis_size, :vis_size], title, vis_size)

    # Step 2: CPU reference for each mask type
    print("\n\nStep 2: CPU Reference Comparison")

    # Fixed seed; Q, K, V are drawn in this order so runs are reproducible.
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)
    V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)

    O_no_mask = cpu_attention_fwd(Q, K, V, prob.scale)
    O_causal = cpu_attention_fwd_masked(Q, K, V, prob.scale, causal)
    O_causal_sink = cpu_attention_fwd_masked(Q, K, V, prob.scale, causal_sink)
    O_window_sink = cpu_attention_fwd_masked(Q, K, V, prob.scale, window_sink)

    masks_and_outputs = [
        ("No mask", O_no_mask),
        ("Causal", O_causal),
        (f"Causal+sink({args.sink_tokens})", O_causal_sink),
        (f"Window({args.window})+sink({args.sink_tokens})", O_window_sink),
    ]

    print(f"\n {'Mask Type':<30} {'Output Range':>20} {'vs NoMask MaxDiff':>18}")
    print(" " + "-" * 70)
    for name, out in masks_and_outputs:
        d = np.abs(out - O_no_mask).max()
        out_range = f"[{out.min():.4f}, {out.max():.4f}]"
        print(f" {name:<30} {out_range:>20} {d:>18.6e}")

    # Step 3: Verify sink tokens effect
    print("\nStep 3: Sink Token Effect Analysis")

    diff_causal_vs_sink = np.abs(O_causal - O_causal_sink)
    print(" Causal vs Causal+Sink:")
    print(f" Max diff: {diff_causal_vs_sink.max():.6e}")
    print(f" Mean diff: {diff_causal_vs_sink.mean():.6e}")

    # Each mask holds 1.0 where attention is allowed, so its sum is the
    # number of attended (query, key) pairs.
    n_attend_causal = causal.sum()
    n_attend_sink = causal_sink.sum()
    n_attend_window = window_sink.sum()
    print("\n Attention density:")
    print(
        f" Causal: {n_attend_causal:>8.0f} / {sq * sk} ({100 * n_attend_causal / (sq * sk):.1f}%)"
    )
    print(
        f" Causal+sink: {n_attend_sink:>8.0f} / {sq * sk} ({100 * n_attend_sink / (sq * sk):.1f}%)"
    )
    print(
        f" Window+sink: {n_attend_window:>8.0f} / {sq * sk} ({100 * n_attend_window / (sq * sk):.1f}%)"
    )

    # Step 4: Sweep sink token count
    print("\nStep 4: Sink Token Sweep")

    sink_counts = [0, 1, 2, 4, 8, 16]
    validator = FmhaValidator(rtol=1e-4, atol=1e-4)

    print(
        f"\n {'Sinks':>6} {'Density':>10} {'vs Causal MaxDiff':>20} {'vs NoMask MaxDiff':>20}"
    )
    print(" " + "-" * 60)

    for ns in sink_counts:
        # Sink counts larger than the key length are skipped.
        if ns <= sk:
            m = make_causal_sink_mask(sq, sk, ns)
            O_s = cpu_attention_fwd_masked(Q, K, V, prob.scale, m)
            d_causal = np.abs(O_s - O_causal).max()
            d_nomask = np.abs(O_s - O_no_mask).max()
            density = 100 * m.sum() / (sq * sk)
            print(f" {ns:>6} {density:>9.1f}% {d_causal:>20.6e} {d_nomask:>20.6e}")

    # Step 5: GPU API pattern
    print("\nStep 5: GPU Kernel Pattern")
    print(" NOTE: The prebuilt library does not include a sink token kernel.")
    print(" To compile a sink-enabled kernel, use:")
    print()
    print(" FmhaSignature()")
    print(" .mask('top_left') // causal mask required with sink")
    print(" .sink(true) // enable sink tokens")
    print()
    print(" At runtime, pass sink count via the mask spec: 't:left,right,sink'")
    print(
        f" Example: 't:0,0,{args.sink_tokens}' for causal + {args.sink_tokens} sink tokens"
    )

    # Step 6: GPU baseline (no mask, no sink)
    print("\nStep 6: GPU Baseline (standard kernel, no mask)")

    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")

        Q_f16 = Q.astype(np.float16)
        K_f16 = K.astype(np.float16)
        V_f16 = V.astype(np.float16)

        result = runner.run(Q_f16, K_f16, V_f16, prob)
        if result.success:
            ok, max_abs, _ = validator.check(result.output, O_no_mask)
            print(
                f" GPU (no mask): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs:.2e} {'PASS' if ok else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")
    else:
        print(f" JIT build failed: {setup.error}")

    # Summary
    print("\n" + banner)
    print(" Sink token attention: first N tokens always attended regardless of mask")
    print(" Use case: StreamingLLM, long-context generation with attention anchors")
    print(" Sink tokens preserve global context that causal masking would discard")
    print(banner)

    return 0
|
||||
|
||||
|
||||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
|
||||
406
dispatcher/examples/fmha/python/23_batch_prefill_fmha.py
Normal file
406
dispatcher/examples/fmha/python/23_batch_prefill_fmha.py
Normal file
@@ -0,0 +1,406 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 23: Batch Prefill FMHA for SGLang/vLLM
|
||||
|
||||
Demonstrates batch prefill with paged KV-cache, as used in serving
|
||||
frameworks like SGLang and vLLM. Shows the KV page table configuration
|
||||
(kv_indptr, kv_page_indices, kv_last_page_lens) for both:
|
||||
- SGLang: 1D page table with indirect page lookup
|
||||
- vLLM: 2D block table with per-sequence page arrays
|
||||
|
||||
This example builds the page table metadata on CPU and validates the
|
||||
attention computation. The prebuilt library only supports the basic
|
||||
forward kernel, so the page table logic is demonstrated via CPU reference.
|
||||
|
||||
Usage:
|
||||
python3 23_batch_prefill_fmha.py
|
||||
python3 23_batch_prefill_fmha.py --page-size 64
|
||||
python3 23_batch_prefill_fmha.py --num-seqs 8
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def build_sglang_page_table(
    seq_lens_k: list,
    page_size: int,
    nhead_k: int,
    hdim: int,
) -> dict:
    """Build SGLang-style 1D page table for paged KV-cache.

    Pages are handed out sequentially: each sequence receives the next
    ceil(seqlen / page_size) page IDs, stored back to back in one flat
    array, with kv_indptr marking where each sequence's pages begin and end.

    Args:
        seq_lens_k: KV length (in tokens) of every sequence in the batch.
        page_size: tokens per page; must be positive.
        nhead_k: number of KV heads (only used for kv_data_shape).
        hdim: head dimension (only used for kv_data_shape).

    Returns dict with:
        kv_indptr: [num_seqs + 1] cumulative page counts (int32)
        kv_page_indices: [total_pages] global page IDs (int32)
        kv_last_page_lens: [num_seqs] tokens in last page of each seq;
            0 for a sequence that owns no pages
        num_total_pages: total pages allocated
        kv_data_shape: (total_pages, 2, nhead_k, page_size, hdim)
        layout: the string "sglang_1d"

    Raises:
        ValueError: if page_size is not positive.
    """
    if page_size <= 0:
        raise ValueError(f"page_size must be positive, got {page_size}")

    num_seqs = len(seq_lens_k)
    kv_indptr = np.zeros(num_seqs + 1, dtype=np.int32)
    last_page_lens = np.zeros(num_seqs, dtype=np.int32)
    page_indices_list = []

    page_counter = 0
    for i, seqlen in enumerate(seq_lens_k):
        num_pages = (seqlen + page_size - 1) // page_size  # ceil division
        kv_indptr[i + 1] = kv_indptr[i] + num_pages
        page_indices_list.extend(range(page_counter, page_counter + num_pages))
        if num_pages > 0:
            last_page_lens[i] = seqlen - (num_pages - 1) * page_size
        # else: an empty sequence has no last page; the entry stays 0 rather
        # than the misleading seqlen + page_size the plain formula would give.
        page_counter += num_pages

    kv_page_indices = np.array(page_indices_list, dtype=np.int32)
    total_pages = page_counter

    return {
        "kv_indptr": kv_indptr,
        "kv_page_indices": kv_page_indices,
        "kv_last_page_lens": last_page_lens,
        "num_total_pages": total_pages,
        "kv_data_shape": (total_pages, 2, nhead_k, page_size, hdim),
        "layout": "sglang_1d",
    }
|
||||
|
||||
|
||||
def build_vllm_block_table(
    seq_lens_k: list,
    page_size: int,
    nhead_k: int,
    hdim: int,
) -> dict:
    """Build vLLM-style 2D block table for paged KV-cache.

    The table is a [num_seqs, max_blocks_per_seq] array whose row i lists
    the page IDs of sequence i, padded with -1. Page IDs are handed out
    sequentially in batch order.

    Args:
        seq_lens_k: KV length (in tokens) of every sequence in the batch.
        page_size: tokens per page; must be positive.
        nhead_k: number of KV heads (only used for kv_data_shape).
        hdim: head dimension (only used for kv_data_shape).

    Returns dict with:
        block_table: [num_seqs, max_blocks] page IDs (-1 = unused); has
            shape (0, 0) for an empty batch
        kv_last_page_lens: [num_seqs] tokens in last page of each seq;
            0 for a sequence that owns no pages
        num_total_pages: total pages allocated
        kv_data_shape: (total_pages, 2, nhead_k, page_size, hdim)
        layout: the string "vllm_2d"

    Raises:
        ValueError: if page_size is not positive.
    """
    if page_size <= 0:
        raise ValueError(f"page_size must be positive, got {page_size}")

    num_seqs = len(seq_lens_k)
    pages_per_seq = [(s + page_size - 1) // page_size for s in seq_lens_k]
    # default=0 keeps an empty batch from raising inside max().
    max_blocks = max(pages_per_seq, default=0)

    block_table = np.full((num_seqs, max_blocks), -1, dtype=np.int32)
    last_page_lens = np.zeros(num_seqs, dtype=np.int32)

    page_counter = 0
    for i, (seqlen, num_pages) in enumerate(zip(seq_lens_k, pages_per_seq)):
        if num_pages <= 0:
            # Empty sequence: no pages, and its last-page length stays 0.
            continue
        block_table[i, :num_pages] = np.arange(
            page_counter, page_counter + num_pages, dtype=np.int32
        )
        page_counter += num_pages
        last_page_lens[i] = seqlen - (num_pages - 1) * page_size

    return {
        "block_table": block_table,
        "kv_last_page_lens": last_page_lens,
        "num_total_pages": page_counter,
        "kv_data_shape": (page_counter, 2, nhead_k, page_size, hdim),
        "layout": "vllm_2d",
    }
|
||||
|
||||
|
||||
def scatter_kv_to_pages(
    K: np.ndarray,
    V: np.ndarray,
    page_table: dict,
    page_size: int,
) -> np.ndarray:
    """Split one sequence's contiguous K and V into fixed-size pages.

    Args:
        K: [nhead_k, seqlen_k, hdim] (single sequence)
        V: [nhead_k, seqlen_k, hdim]
        page_table: accepted for interface symmetry but never read here; the
            pages are returned in sequence-local order and the caller is
            responsible for placing them in the global pool
        page_size: tokens per page

    Returns:
        float32 array [num_pages, 2, nhead_k, page_size, hdim], where index 0
        of the second axis holds K and index 1 holds V. Slots of the last
        page beyond seqlen_k are left as zeros.
    """
    heads, tokens, width = K.shape
    page_count = (tokens + page_size - 1) // page_size

    paged = np.zeros((page_count, 2, heads, page_size, width), dtype=np.float32)
    for slot, first in enumerate(range(0, tokens, page_size)):
        last = min(first + page_size, tokens)
        used = last - first
        for plane, source in enumerate((K, V)):
            paged[slot, plane, :, :used, :] = source[:, first:last, :]

    return paged
|
||||
|
||||
|
||||
def gather_kv_from_pages(
    kv_pool: np.ndarray,
    page_indices: np.ndarray,
    seqlen_k: int,
    page_size: int,
) -> tuple:
    """Gather K,V from paged KV pool back to contiguous arrays.

    Args:
        kv_pool: indexed as kv_pool[page, 0 or 1, head, slot, dim], with
            index 0 of the second axis read as K and index 1 as V
        page_indices: page IDs of this sequence, in token order. Entries
            beyond the pages needed to cover seqlen_k (for example the
            padding of a fixed-width block-table row) are ignored.
        seqlen_k: number of tokens to reconstruct
        page_size: tokens per page

    Returns:
        K: [nhead_k, seqlen_k, hdim] float32
        V: [nhead_k, seqlen_k, hdim] float32
        Tokens not covered by any supplied page are left as zeros.
    """
    nhead_k = kv_pool.shape[2]
    hdim = kv_pool.shape[4]
    K = np.zeros((nhead_k, seqlen_k, hdim), dtype=np.float32)
    V = np.zeros((nhead_k, seqlen_k, hdim), dtype=np.float32)

    for p, page_idx in enumerate(page_indices):
        start = p * page_size
        if start >= seqlen_k:
            # Every token is already filled. Processing a surplus index would
            # produce a negative slice length and read the wrong part of a
            # page (or fail to broadcast), so stop here.
            break
        end = min(start + page_size, seqlen_k)
        length = end - start
        K[:, start:end, :] = kv_pool[page_idx, 0, :, :length, :]
        V[:, start:end, :] = kv_pool[page_idx, 1, :, :length, :]

    return K, V
|
||||
|
||||
|
||||
def main():
    """Demonstrate paged-KV batch prefill metadata on CPU, then a GPU baseline.

    Steps: build the SGLang 1D page table and the vLLM 2D block table, check
    that scattering K/V into the pool and gathering them back is lossless,
    compute a CPU attention reference per sequence, compare paged and
    contiguous memory use, print the kernel-configuration hint, and run a
    contiguous single-sequence GPU baseline. Returns 0.

    Exits through argparse's error path when --num-seqs or --page-size is
    smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Batch Prefill FMHA for SGLang/vLLM",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--nhead-q", type=int, default=16)
    parser.add_argument("--nhead-k", type=int, default=4, help="KV heads (GQA)")
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument("--page-size", type=int, default=16)
    parser.add_argument("--num-seqs", type=int, default=4, help="Sequences in batch")
    args = parser.parse_args()

    # Reject values that would otherwise surface later as a ZeroDivisionError
    # or as max() being called on an empty list.
    if args.num_seqs < 1:
        parser.error("--num-seqs must be at least 1")
    if args.page_size < 1:
        parser.error("--page-size must be at least 1")

    print("=" * 70)
    print("Example 23: Batch Prefill FMHA (SGLang/vLLM)")
    print("=" * 70)

    # Repeat the four base lengths cyclically so any --num-seqs works (the
    # usage text advertises --num-seqs 8). Slicing the base lists capped the
    # batch at four entries while later loops still ran to num_seqs, which
    # raised IndexError. For num_seqs <= 4 the lengths are unchanged.
    base_lens_q = [32, 64, 16, 48]
    base_lens_k = [256, 512, 128, 384]
    seq_lens_q = [base_lens_q[i % len(base_lens_q)] for i in range(args.num_seqs)]
    seq_lens_k = [base_lens_k[i % len(base_lens_k)] for i in range(args.num_seqs)]

    # Step 1: SGLang page table
    print("\nStep 1: SGLang 1D Page Table")

    sglang_pt = build_sglang_page_table(
        seq_lens_k,
        args.page_size,
        args.nhead_k,
        args.hdim,
    )

    print(f" Page size: {args.page_size}")
    print(f" Total pages: {sglang_pt['num_total_pages']}")
    print(f" KV pool shape: {sglang_pt['kv_data_shape']}")
    print(f" kv_indptr: {sglang_pt['kv_indptr']}")
    print(
        f" kv_page_indices: {sglang_pt['kv_page_indices'][:20]}{'...' if len(sglang_pt['kv_page_indices']) > 20 else ''}"
    )
    print(f" last_page_lens: {sglang_pt['kv_last_page_lens']}")

    print("\n Per-sequence breakdown:")
    print(f" {'Seq':>5} {'SeqQ':>6} {'SeqK':>6} {'Pages':>6} {'LastLen':>8}")
    print(" " + "-" * 35)
    for i in range(args.num_seqs):
        n_pages = sglang_pt["kv_indptr"][i + 1] - sglang_pt["kv_indptr"][i]
        print(
            f" {i:>5} {seq_lens_q[i]:>6} {seq_lens_k[i]:>6} {n_pages:>6} {sglang_pt['kv_last_page_lens'][i]:>8}"
        )

    # Step 2: vLLM block table
    print("\nStep 2: vLLM 2D Block Table")

    vllm_pt = build_vllm_block_table(
        seq_lens_k,
        args.page_size,
        args.nhead_k,
        args.hdim,
    )

    print(f" Block table shape: {vllm_pt['block_table'].shape}")
    print(f" Total pages: {vllm_pt['num_total_pages']}")
    for i in range(args.num_seqs):
        row = vllm_pt["block_table"][i]
        valid = row[row >= 0]  # drop the -1 padding
        print(f" Seq {i}: pages={valid.tolist()}")

    # Step 3: Validate scatter/gather round-trip
    print("\nStep 3: KV Page Scatter/Gather Validation")

    np.random.seed(42)
    validator = FmhaValidator(rtol=1e-5, atol=1e-5)

    total_pages = sglang_pt["num_total_pages"]
    kv_pool = np.zeros(
        (total_pages, 2, args.nhead_k, args.page_size, args.hdim),
        dtype=np.float32,
    )

    all_Q, all_K, all_V, all_O_ref = [], [], [], []

    for i in range(args.num_seqs):
        sq, sk = seq_lens_q[i], seq_lens_k[i]
        Q_i = np.random.randn(args.nhead_q, sq, args.hdim).astype(np.float32) * 0.3
        K_i = np.random.randn(args.nhead_k, sk, args.hdim).astype(np.float32) * 0.3
        V_i = np.random.randn(args.nhead_k, sk, args.hdim).astype(np.float32) * 0.3

        # This sequence's slice of the flat page-index array.
        start_page = sglang_pt["kv_indptr"][i]
        end_page = sglang_pt["kv_indptr"][i + 1]
        page_indices = sglang_pt["kv_page_indices"][start_page:end_page]

        # Split into local pages, then place each one at its global page ID.
        pages = scatter_kv_to_pages(K_i, V_i, page_indices, args.page_size)
        for p_local, p_global in enumerate(page_indices):
            kv_pool[p_global] = pages[p_local]

        K_rt, V_rt = gather_kv_from_pages(kv_pool, page_indices, sk, args.page_size)

        k_ok = np.allclose(K_i, K_rt, atol=1e-7)
        v_ok = np.allclose(V_i, V_rt, atol=1e-7)
        print(
            f" Seq {i}: K round-trip={'OK' if k_ok else 'FAIL'} "
            f"V round-trip={'OK' if v_ok else 'FAIL'}"
        )

        all_Q.append(Q_i)
        all_K.append(K_i)
        all_V.append(V_i)

    # Step 4: CPU attention per-sequence
    print("\nStep 4: CPU Attention per Sequence (from Paged KV)")

    print(f"\n {'Seq':>5} {'SeqQ':>6} {'SeqK':>6} {'OutRange':>22} {'Scale':>10}")
    print(" " + "-" * 50)

    for i in range(args.num_seqs):
        sq, sk = seq_lens_q[i], seq_lens_k[i]
        Q_i = all_Q[i][np.newaxis]  # [1, nhead_q, sq, hdim]
        K_i = all_K[i][np.newaxis]  # [1, nhead_k, sk, hdim]
        V_i = all_V[i][np.newaxis]  # [1, nhead_k, sk, hdim]

        if args.nhead_q != args.nhead_k:
            # GQA: replicate each KV head across its group of query heads.
            ratio = args.nhead_q // args.nhead_k
            K_i_exp = np.repeat(K_i, ratio, axis=1)
            V_i_exp = np.repeat(V_i, ratio, axis=1)
        else:
            K_i_exp, V_i_exp = K_i, V_i

        scale = 1.0 / (args.hdim**0.5)
        O_i = cpu_attention_fwd(Q_i, K_i_exp, V_i_exp, scale)
        all_O_ref.append(O_i)

        out_range = f"[{O_i.min():.4f}, {O_i.max():.4f}]"
        print(f" {i:>5} {sq:>6} {sk:>6} {out_range:>22} {scale:>10.4f}")

    # Step 5: Memory layout comparison
    print("\nStep 5: Memory Layout Analysis")

    # Factor 2 covers K and V; factor 4 is the byte size of a float32.
    contiguous_bytes = sum(2 * args.nhead_k * sk * args.hdim * 4 for sk in seq_lens_k)
    paged_bytes = total_pages * 2 * args.nhead_k * args.page_size * args.hdim * 4
    overhead = (paged_bytes - contiguous_bytes) / contiguous_bytes * 100

    print(f" Contiguous KV: {contiguous_bytes / 1024:.1f} KB")
    print(f" Paged KV pool: {paged_bytes / 1024:.1f} KB")
    print(f" Overhead: {overhead:.1f}% (due to page padding)")
    print(f" Pages used: {total_pages}")
    print(f" Avg tokens/seq: {sum(seq_lens_k) / args.num_seqs:.0f}")

    # Step 6: GPU API pattern
    print("\nStep 6: GPU Kernel Configuration")
    print(" NOTE: The prebuilt library uses basic forward kernels.")
    print(" For batch prefill, compile a kernel with:")
    print()
    print(" FmhaSignature()")
    print(" .family('batch_prefill')")
    print(" .mode('group')")
    print(" .paged_kv(true)")
    print(" .kv_cache('vectorized', 'sglang', page_size)")
    print(" .lse(true)")
    print()
    print(" FmhaKernelConfig codegen JSON:")
    print(" 'family': 'batch_prefill',")
    print(" 'mode': 'group',")
    print(" 'paged_kv': true,")
    print(" 'kv_memory_layout': 'vectorized',")
    print(" 'kv_lookup_table': 'sglang' or 'vllm',")
    print(f" 'page_size': {args.page_size}")

    # Step 7: GPU baseline (contiguous, no paging)
    print("\nStep 7: GPU Baseline (contiguous KV, single sequence)")

    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")

        prob = FmhaProblem(
            batch=1,
            nhead_q=args.nhead_q,
            nhead_k=args.nhead_k,
            seqlen_q=64,
            seqlen_k=256,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )
        Q_gpu = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float16)
        K_gpu = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float16)
        V_gpu = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float16)

        result = runner.run(Q_gpu, K_gpu, V_gpu, prob)
        if result.success:
            # The reference is computed from the same fp16-rounded inputs the
            # GPU received, upcast to float32.
            O_ref = cpu_attention_fwd(
                Q_gpu.astype(np.float32),
                K_gpu.astype(np.float32),
                V_gpu.astype(np.float32),
                prob.scale,
            )
            ok, max_abs, _ = validator.check(result.output, O_ref)
            print(
                f" GPU baseline: time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs:.2e} {'PASS' if ok else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")

    # Summary
    print("\n" + "=" * 70)
    print(" Batch prefill: serves multiple prefill requests in a single kernel launch")
    print(" SGLang: 1D page table (kv_indptr + kv_page_indices)")
    print(" vLLM: 2D block table [num_seqs, max_blocks]")
    print(
        f" Page size {args.page_size} -> {overhead:.1f}% memory overhead vs contiguous"
    )
    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
|
||||
250
dispatcher/examples/fmha/python/24_vlayout_col_fmha.py
Normal file
250
dispatcher/examples/fmha/python/24_vlayout_col_fmha.py
Normal file
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 24: Column-Major V Layout FMHA
|
||||
|
||||
Demonstrates column-major (vlayout="c") vs row-major (vlayout="r") for
|
||||
the V tensor. In row-major, V is [batch, nhead, seqlen_k, hdim_v]; in
|
||||
column-major, V is [batch, nhead, hdim_v, seqlen_k].
|
||||
|
||||
Column-major V can improve performance when hdim_v access patterns
|
||||
benefit from the transposed layout (e.g., certain tile sizes or memory
|
||||
coalescing characteristics on specific GPU architectures).
|
||||
|
||||
The prebuilt library uses row-major V. This example shows both layouts
|
||||
with CPU reference and validates correctness.
|
||||
|
||||
Usage:
|
||||
python3 24_vlayout_col_fmha.py
|
||||
python3 24_vlayout_col_fmha.py --seqlen 512
|
||||
python3 24_vlayout_col_fmha.py --batch 4
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def cpu_attention_fwd_vlayout_col(
    Q: np.ndarray,
    K: np.ndarray,
    V_col: np.ndarray,
    scale: float,
) -> np.ndarray:
    """CPU reference: attention where V arrives in column-major layout.

    The last two axes of V_col are swapped to recover the row-major layout,
    and the result is handed to cpu_attention_fwd unchanged.

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q] float32 (row-major)
        K: [batch, nhead_k, seqlen_k, hdim_q] float32 (row-major)
        V_col: [batch, nhead_k, hdim_v, seqlen_k] float32 (column-major)
        scale: softmax scale

    Returns:
        O: [batch, nhead_q, seqlen_q, hdim_v] float32
    """
    # Swap (hdim_v, seqlen_k) -> (seqlen_k, hdim_v).
    row_major_axes = (0, 1, 3, 2)
    return cpu_attention_fwd(Q, K, V_col.transpose(row_major_axes), scale)
|
||||
|
||||
|
||||
def analyze_strides(name: str, arr: np.ndarray, dim_names: list):
    """Print a tensor's shape and its per-dimension strides in elements.

    Byte strides are converted to element counts by dividing by the item
    size. The last named dimension is tagged "(contiguous)" when its stride
    is exactly one element.
    """
    itemsize = arr.itemsize
    strides_elems = tuple(nbytes // itemsize for nbytes in arr.strides)

    print(f" {name}:")
    print(f" Shape: {arr.shape}")
    print(f" Strides: {strides_elems} (elements)")

    last_axis = len(dim_names) - 1
    for axis, (dname, s) in enumerate(zip(dim_names, strides_elems)):
        if axis == last_axis and s == 1:
            contiguous = "(contiguous)"
        else:
            contiguous = ""
        print(f" {dname}: stride={s} {contiguous}")
|
||||
|
||||
|
||||
def main():
    """Compare row-major and column-major V layouts on CPU, then run a GPU baseline.

    Steps: print the strides of both layouts, check that the CPU reference
    gives the same output from either layout, print the access-pattern
    table, sweep several problem shapes, print the kernel-configuration
    hint, and run the row-major GPU kernel. Always returns 0.
    """
    parser = argparse.ArgumentParser(
        description="Column-Major V Layout FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    for flag, default in (
        ("--batch", 2),
        ("--nhead", 8),
        ("--seqlen", 128),
        ("--hdim", 128),
    ):
        parser.add_argument(flag, type=int, default=default)
    args = parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("Example 24: Column-Major V Layout FMHA")
    print(banner)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    # Step 1: Layout comparison
    print("\nStep 1: V Tensor Layouts")

    np.random.seed(42)
    V_row = np.ascontiguousarray(
        (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)
    )
    # Materialize the transposed copy so its strides reflect the new layout.
    V_col = np.ascontiguousarray(V_row.transpose(0, 1, 3, 2))

    layouts = (
        ("V row-major [B, H, SeqK, Hdim]", V_row, ["batch", "nhead", "seqlen_k", "hdim_v"]),
        ("V col-major [B, H, Hdim, SeqK]", V_col, ["batch", "nhead", "hdim_v", "seqlen_k"]),
    )
    for title, tensor, axis_names in layouts:
        analyze_strides(title, tensor, axis_names)

    print("\n Row-major: last dim is hdim_v -> sequential hdim access per token")
    print(" Col-major: last dim is seqlen_k -> sequential token access per hdim")

    # Step 2: CPU reference for both layouts
    print("\nStep 2: CPU Reference (both layouts)")

    Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)

    O_from_row = cpu_attention_fwd(Q, K, V_row, prob.scale)
    O_from_col = cpu_attention_fwd_vlayout_col(Q, K, V_col, prob.scale)

    validator = FmhaValidator(rtol=1e-5, atol=1e-5)
    ok, max_abs, _max_rel = validator.check(O_from_row, O_from_col)

    print(
        f" O from row-major V: shape={O_from_row.shape} "
        f"range=[{O_from_row.min():.4f}, {O_from_row.max():.4f}]"
    )
    print(
        f" O from col-major V: shape={O_from_col.shape} "
        f"range=[{O_from_col.min():.4f}, {O_from_col.max():.4f}]"
    )
    print(f" Max abs error: {max_abs:.2e}")
    print(f" Match: {'PASS' if ok else 'FAIL'}")

    # Step 3: Memory access pattern analysis
    print("\nStep 3: Memory Access Pattern Analysis")

    tile_sizes = [(128, 128), (64, 128), (128, 64)]
    print("\n For P @ V matmul (P: [sq, sk] x V: [sk, hdim_v]):")
    print(f" {'Tile(M,N)':>12} {'V Row Accesses':>18} {'V Col Accesses':>18}")
    print(" " + "-" * 52)

    # The two stride descriptions are the same for every tile size listed.
    row_access = f"sk_stride={args.hdim}"
    col_access = "sk_stride=1"
    for tm, tn in tile_sizes:
        print(f" {f'{tm}x{tn}':>12} {row_access:>18} {col_access:>18}")

    print("\n Row-major V: coalesced reads when accessing hdim_v (inner loop)")
    print(" Col-major V: coalesced reads when accessing seqlen_k (inner loop)")
    print(" Optimal layout depends on tile shape and GPU memory subsystem")

    # Step 4: Shape sweep with both layouts
    print("\nStep 4: Correctness Sweep")

    # Each entry is (batch, nhead, seqlen_q, seqlen_k, hdim).
    shapes = [
        (1, 4, 64, 64, 64),
        (2, 8, 128, 128, 128),
        (1, 8, 256, 256, 128),
        (2, 4, 128, 128, 64),
        (1, 16, 64, 64, 128),
    ]

    print(f"\n {'Shape':<32} {'MaxErr':>12} {'Status':>8}")
    print(" " + "-" * 55)

    all_ok = True
    for b, h, sq, sk, d in shapes:
        Q_t = (np.random.randn(b, h, sq, d) * 0.3).astype(np.float32)
        K_t = (np.random.randn(b, h, sk, d) * 0.3).astype(np.float32)
        V_r = (np.random.randn(b, h, sk, d) * 0.3).astype(np.float32)
        V_c = np.ascontiguousarray(V_r.transpose(0, 1, 3, 2))

        scale = 1.0 / (d**0.5)
        O_r = cpu_attention_fwd(Q_t, K_t, V_r, scale)
        O_c = cpu_attention_fwd_vlayout_col(Q_t, K_t, V_c, scale)

        ok_t, max_abs_t, _ = validator.check(O_r, O_c)
        all_ok = all_ok and ok_t
        shape_str = f"B{b}_H{h}_S{sq}x{sk}_D{d}"
        print(f" {shape_str:<32} {max_abs_t:>12.2e} {'PASS' if ok_t else 'FAIL':>8}")

    # Step 5: GPU API pattern
    print("\nStep 5: GPU Kernel Configuration")
    print(" NOTE: The prebuilt library uses row-major V (vlayout='r').")
    print(" For column-major V, compile a kernel with vlayout='c':")
    print()
    print(" FmhaSignature()")
    print(" .vlayout('c') // column-major V: [B, H, Hdim, SeqK]")
    print()
    print(" FmhaKernelConfig(vlayout='c', ...)")

    # Step 6: GPU baseline (row-major)
    print("\nStep 6: GPU Baseline (row-major V)")

    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")

        Q_f16 = Q.astype(np.float16)
        K_f16 = K.astype(np.float16)
        V_f16 = V_row.astype(np.float16)

        result = runner.run(Q_f16, K_f16, V_f16, prob)
        if result.success:
            ok_gpu, max_abs_gpu, _ = validator.check(result.output, O_from_row)
            print(
                f" GPU (row-major V): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")
    else:
        print(f" JIT build failed: {setup.error}")

    # Summary
    print("\n" + banner)
    print(" vlayout='r': V is [B, H, SeqK, Hdim] (default, row-major)")
    print(" vlayout='c': V is [B, H, Hdim, SeqK] (column-major)")
    print(
        f" Both layouts produce identical results (verified: {'PASS' if all_ok else 'FAIL'})"
    )
    print(" Choice depends on upstream memory layout and GPU tile access patterns")
    print(banner)

    return 0
|
||||
|
||||
|
||||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
|
||||
262
dispatcher/examples/fmha/python/25_permutation_fmha.py
Normal file
262
dispatcher/examples/fmha/python/25_permutation_fmha.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 25: Input/Output Permutation FMHA
|
||||
|
||||
Demonstrates different memory layouts for Q/K/V/O tensors via
|
||||
input permutation (iperm) and output permutation (operm):
|
||||
|
||||
iperm=0 (bshd): [batch, seqlen, nhead, hdim] -- used by some frameworks
|
||||
iperm=1 (bhsd): [batch, nhead, seqlen, hdim] -- standard/default
|
||||
|
||||
operm=0 (bshd): O is [batch, seqlen, nhead, hdim]
|
||||
operm=1 (bhsd): O is [batch, nhead, seqlen, hdim]
|
||||
|
||||
The prebuilt library uses bhsd layout (iperm=1, operm=1). This example
|
||||
shows how to convert between layouts and validates correctness.
|
||||
|
||||
Usage:
|
||||
python3 25_permutation_fmha.py
|
||||
python3 25_permutation_fmha.py --seqlen 256
|
||||
python3 25_permutation_fmha.py --batch 4
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def bhsd_to_bshd(x: np.ndarray) -> np.ndarray:
    """Permute [batch, nhead, seqlen, hdim] into [batch, seqlen, nhead, hdim].

    Only the two middle axes are exchanged. The value returned is whatever
    ``x.transpose`` produces; no explicit copy is made here.
    """
    axis_order = (0, 2, 1, 3)
    return x.transpose(*axis_order)
|
||||
|
||||
|
||||
def bshd_to_bhsd(x: np.ndarray) -> np.ndarray:
    """Permute [batch, seqlen, nhead, hdim] into [batch, nhead, seqlen, hdim].

    Exchanging axes 1 and 2 is its own inverse, so this uses the same axis
    order as the opposite conversion.
    """
    axis_order = (0, 2, 1, 3)
    return x.transpose(*axis_order)
|
||||
|
||||
|
||||
def cpu_attention_fwd_bshd(
    Q_bshd: np.ndarray,
    K_bshd: np.ndarray,
    V_bshd: np.ndarray,
    scale: float,
    operm: int = 0,
) -> np.ndarray:
    """Run the CPU attention reference on sequence-first (bshd) inputs.

    Q, K and V are permuted to bhsd, handed to ``cpu_attention_fwd``, and the
    result is permuted back to bshd only when ``operm`` is 0.

    Args:
        Q_bshd: [batch, seqlen_q, nhead_q, hdim_q] float32
        K_bshd: [batch, seqlen_k, nhead_k, hdim_q] float32
        V_bshd: [batch, seqlen_k, nhead_k, hdim_v] float32
        scale: softmax scale
        operm: 0 -> output bshd, any other value -> output bhsd

    Returns:
        O: attention output in the requested layout
    """
    q_heads_first, k_heads_first, v_heads_first = (
        bshd_to_bhsd(tensor) for tensor in (Q_bshd, K_bshd, V_bshd)
    )
    result_bhsd = cpu_attention_fwd(q_heads_first, k_heads_first, v_heads_first, scale)

    # The reference produces bhsd; convert only for the bshd output request.
    return bhsd_to_bshd(result_bhsd) if operm == 0 else result_bhsd
|
||||
|
||||
|
||||
def describe_layout(arr: np.ndarray, layout_name: str, dim_names: list):
    """Print the shape, strides and C-contiguity of ``arr`` under a heading.

    Strides are reported in elements (byte stride divided by ``arr.itemsize``).
    One extra line is printed per (dimension name, stride) pair; pairing stops
    at the shorter of ``dim_names`` and the stride tuple.
    """
    bytes_per_elem = arr.itemsize
    elem_strides = tuple(byte_stride // bytes_per_elem for byte_stride in arr.strides)
    c_contiguous = arr.flags["C_CONTIGUOUS"]

    summary_lines = (
        f" {layout_name}:",
        f" Shape: {arr.shape}",
        f" Strides: {elem_strides} (elements)",
        f" Contiguous: {c_contiguous}",
    )
    for line in summary_lines:
        print(line)

    for dim_label, elem_stride in zip(dim_names, elem_strides):
        print(f" {dim_label:>8}: stride={elem_stride}")
|
||||
|
||||
|
||||
def main():
    """Walk through the iperm/operm layout combinations for FMHA.

    Steps: print layout/stride details, validate all four input/output
    permutation combinations against the CPU reference, print stride and
    conversion-cost notes, then attempt one fp16 GPU run in bhsd layout.

    Returns:
        0 in all cases; validation outcomes are only reported in the output.
    """
    parser = argparse.ArgumentParser(
        description="Input/Output Permutation FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=128)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 25: Input/Output Permutation FMHA")
    print("=" * 70)

    B, H, S, D = args.batch, args.nhead, args.seqlen, args.hdim
    prob = FmhaProblem(
        batch=B,
        nhead_q=H,
        nhead_k=H,
        seqlen_q=S,
        seqlen_k=S,
        hdim_q=D,
        hdim_v=D,
    )

    # Step 1: Layout definitions
    print("\nStep 1: Layout Definitions")

    np.random.seed(42)
    Q_bhsd = np.ascontiguousarray(
        (np.random.randn(B, H, S, D) * 0.3).astype(np.float32)
    )
    Q_bshd = np.ascontiguousarray(bhsd_to_bshd(Q_bhsd))

    describe_layout(Q_bhsd, "bhsd (iperm=1)", ["batch", "nhead", "seqlen", "hdim"])
    describe_layout(Q_bshd, "bshd (iperm=0)", ["batch", "seqlen", "nhead", "hdim"])

    print("\n Key difference:")
    print(" bhsd: heads are contiguous -> good for per-head parallelism")
    print(" bshd: tokens are contiguous -> good for sequence parallelism")

    # Step 2: All permutation combinations
    print("\nStep 2: All Permutation Combinations (CPU Reference)")

    K_bhsd = (np.random.randn(B, H, S, D) * 0.3).astype(np.float32)
    V_bhsd = (np.random.randn(B, H, S, D) * 0.3).astype(np.float32)
    K_bshd = np.ascontiguousarray(bhsd_to_bshd(K_bhsd))
    V_bshd = np.ascontiguousarray(bhsd_to_bshd(V_bhsd))

    O_ref_bhsd = cpu_attention_fwd(Q_bhsd, K_bhsd, V_bhsd, prob.scale)
    O_ref_bshd = bhsd_to_bshd(O_ref_bhsd)

    # Tight tolerance: both sides of the CPU comparisons are float32.
    validator = FmhaValidator(rtol=1e-5, atol=1e-5)

    # Each entry carries its input permutation explicitly. Inferring the layout
    # from the tensor shape is ambiguous when seqlen == nhead.
    combos = [
        ("iperm=1 operm=1", "bhsd->bhsd", 1, Q_bhsd, K_bhsd, V_bhsd, 1, O_ref_bhsd),
        ("iperm=1 operm=0", "bhsd->bshd", 1, Q_bhsd, K_bhsd, V_bhsd, 0, O_ref_bshd),
        ("iperm=0 operm=1", "bshd->bhsd", 0, Q_bshd, K_bshd, V_bshd, 1, O_ref_bhsd),
        ("iperm=0 operm=0", "bshd->bshd", 0, Q_bshd, K_bshd, V_bshd, 0, O_ref_bshd),
    ]

    print(
        f"\n {'Config':<18} {'Transform':<14} {'OutShape':>24} {'MaxErr':>12} {'Status':>8}"
    )
    print(" " + "-" * 80)

    all_ok = True
    for name, transform, iperm, Q_in, K_in, V_in, operm, O_expected in combos:
        if iperm == 1:
            # Inputs are already bhsd; only the output may need permuting.
            O_out = cpu_attention_fwd(Q_in, K_in, V_in, prob.scale)
            if operm == 0:
                O_out = bhsd_to_bshd(O_out)
        else:
            O_out = cpu_attention_fwd_bshd(Q_in, K_in, V_in, prob.scale, operm)

        ok, max_abs, _ = validator.check(O_out, O_expected)
        all_ok = all_ok and ok
        print(
            f" {name:<18} {transform:<14} {str(O_out.shape):>24} {max_abs:>12.2e} {'PASS' if ok else 'FAIL':>8}"
        )

    # Step 3: Stride comparison table
    print("\nStep 3: Stride Comparison")

    print(f"\n For B={B}, H={H}, S={S}, D={D}:")
    print(f" {'Layout':>8} {'Dim Order':>16} {'Strides':>28} {'hdim contiguous':>18}")
    print(" " + "-" * 74)

    # Element strides of packed tensors in each layout.
    bhsd_strides = (H * S * D, S * D, D, 1)
    bshd_strides = (S * H * D, H * D, D, 1)

    print(f" {'bhsd':>8} {'B,H,S,D':>16} {str(bhsd_strides):>28} {'Yes':>18}")
    print(f" {'bshd':>8} {'B,S,H,D':>16} {str(bshd_strides):>28} {'Yes':>18}")

    print("\n Stride analysis:")
    print(f" bhsd: advancing 1 token = skip {D} elements (hdim)")
    print(f" bshd: advancing 1 token = skip {H * D} elements (nhead * hdim)")
    print(f" bhsd: advancing 1 head = skip {S * D} elements (seqlen * hdim)")
    print(f" bshd: advancing 1 head = skip {D} elements (hdim)")

    # Step 4: Conversion cost
    print("\nStep 4: Layout Conversion Cost")

    tensor_bytes = B * H * S * D * 4  # 4 bytes per float32 element
    print(f" Tensor size: {tensor_bytes / 1024:.1f} KB (float32)")
    print(" bhsd <-> bshd conversion: transpose(0,2,1,3) + contiguous copy")
    print(
        " If upstream provides bshd and kernel wants bhsd, conversion costs ~2x memory bandwidth"
    )
    print(" Using iperm parameter avoids this copy by adjusting kernel strides")

    # Step 5: GPU run (bhsd, default layout)
    print("\nStep 5: GPU Run (bhsd layout, iperm=1)")

    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if not setup.success:
        print(f" JIT build failed: {setup.error}")
    else:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")

        Q_f16 = Q_bhsd.astype(np.float16)
        K_f16 = K_bhsd.astype(np.float16)
        V_f16 = V_bhsd.astype(np.float16)

        result = runner.run(Q_f16, K_f16, V_f16, prob)
        if result.success:
            # The GPU consumes fp16 inputs while the reference is float32, so
            # use a looser tolerance than the float32 CPU comparisons above.
            gpu_validator = FmhaValidator(rtol=1e-2, atol=1e-2)
            ok_gpu, max_abs_gpu, _ = gpu_validator.check(result.output, O_ref_bhsd)
            print(
                f" GPU (bhsd): time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs_gpu:.2e} {'PASS' if ok_gpu else 'FAIL'}"
            )
        else:
            print(f" GPU error: {result.error}")

    # Step 6: Kernel configuration for bshd
    print("\nStep 6: GPU Kernel Configuration for bshd")
    print(" The prebuilt library uses bhsd (iperm=1, operm=1).")
    print(" For bshd input/output, the kernel adjusts internal strides:")
    print()
    print(" iperm=0: kernel reads Q,K,V as [B, S, H, D] with stride_head=D")
    print(" iperm=1: kernel reads Q,K,V as [B, H, S, D] with stride_seq=D")
    print(" operm=0: kernel writes O as [B, S, H, D]")
    print(" operm=1: kernel writes O as [B, H, S, D]")

    # Summary
    print("\n" + "=" * 70)
    print(" iperm=0 (bshd): [B, S, H, D] -- sequence-first layout")
    print(" iperm=1 (bhsd): [B, H, S, D] -- head-first layout (default)")
    print(f" All 4 combinations validated: {'PASS' if all_ok else 'FAIL'}")
    print(" Use iperm/operm to match upstream/downstream layout without copies")
    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
|
||||
268
dispatcher/examples/fmha/python/26_hdim_variety_fmha.py
Normal file
268
dispatcher/examples/fmha/python/26_hdim_variety_fmha.py
Normal file
@@ -0,0 +1,268 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 26: Head Dimension Variety FMHA
|
||||
|
||||
Demonstrates FMHA with multiple head dimensions (32, 64, 128, 256) and
|
||||
asymmetric hdim (hdim_q != hdim_v). Different head dimensions require
|
||||
different tile sizes and kernel configurations for optimal performance.
|
||||
|
||||
The prebuilt library supports hdim=128 only. This example validates all
|
||||
head dimensions via CPU reference and runs GPU for hdim=128.
|
||||
|
||||
Usage:
|
||||
python3 26_hdim_variety_fmha.py
|
||||
python3 26_hdim_variety_fmha.py --seqlen 256
|
||||
python3 26_hdim_variety_fmha.py --batch 4
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
|
||||
def recommended_tile(hdim: int) -> str:
    """Return the tile-configuration string listed for ``hdim``.

    Head dimensions 32, 64, 128 and 256 have fixed entries; any other value
    yields the placeholder ``"auto (hdim=<value>)"``.
    """
    known_tiles = {
        32: "128x128x32x32x32x32",
        64: "128x64x32x64x32x64",
        128: "128x128x32x128x32x128",
        256: "128x128x32x256x32x256",
    }
    if hdim in known_tiles:
        return known_tiles[hdim]
    return f"auto (hdim={hdim})"
|
||||
|
||||
|
||||
def compute_flops(
    batch: int, nhead_q: int, sq: int, sk: int, hdim_q: int, hdim_v: int
) -> int:
    """Return the FMHA FLOP count used by this example's tables.

    The value is ``2 * batch * nhead_q * sq * sk * (hdim_q + hdim_v)``, so the
    query and value head dimensions may differ.
    """
    score_elements = batch * nhead_q * sq * sk
    hdim_total = hdim_q + hdim_v
    return 2 * score_elements * hdim_total
|
||||
|
||||
|
||||
def main():
    """Tabulate FMHA configurations for several head dimensions.

    Steps: print a shape/tile/FLOP table for symmetric head dimensions, run
    the CPU reference for each of them, run it for asymmetric
    (hdim_q != hdim_v) shapes, attempt one fp16 GPU run at hdim=128, then
    print summary tables and example kernel configurations.

    Returns:
        0 in all cases.
    """
    cli = argparse.ArgumentParser(
        description="Head Dimension Variety FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument("--arch", default=detect_gpu_arch())
    for flag, default_value in (("--batch", 2), ("--nhead", 8), ("--seqlen", 128)):
        cli.add_argument(flag, type=int, default=default_value)
    args = cli.parse_args()

    banner = "=" * 70
    print(banner)
    print("Example 26: Head Dimension Variety FMHA")
    print(banner)

    validator = FmhaValidator(rtol=1e-2, atol=1e-2)

    def build_problem(head_dim_q, head_dim_v):
        # Only the head dimensions vary; everything else comes from the CLI.
        return FmhaProblem(
            batch=args.batch,
            nhead_q=args.nhead,
            nhead_k=args.nhead,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=head_dim_q,
            hdim_v=head_dim_v,
        )

    def draw_inputs(problem):
        # Q, K, V are drawn in that order from the global NumPy RNG.
        return tuple(
            (np.random.randn(*shape_of()) * 0.3).astype(np.float32)
            for shape_of in (problem.q_shape, problem.k_shape, problem.v_shape)
        )

    def flops_for(head_dim_q, head_dim_v):
        return compute_flops(
            args.batch, args.nhead, args.seqlen, args.seqlen, head_dim_q, head_dim_v
        )

    # Step 1: Symmetric head dimensions
    print("\nStep 1: Symmetric Head Dimensions (hdim_q == hdim_v)")

    hdims = [32, 64, 128, 256]

    print(f"\n {'hdim':>6} {'Shape':>30} {'Tile Config':>30} {'FLOPs':>14}")
    print(" " + "-" * 84)

    for head_dim in hdims:
        shape_label = f"B{args.batch}_H{args.nhead}_S{args.seqlen}_D{head_dim}"
        tile_label = recommended_tile(head_dim)
        flop_count = flops_for(head_dim, head_dim)
        print(f" {head_dim:>6} {shape_label:>30} {tile_label:>30} {flop_count:>14,}")

    # Step 2: CPU validation for each hdim
    print("\nStep 2: CPU Validation")

    np.random.seed(42)

    print(
        f"\n {'hdim_q':>7} {'hdim_v':>7} {'Scale':>10} {'OutRange':>22} {'SelfCheck':>10}"
    )
    print(" " + "-" * 60)

    # hdim -> (Q, K, V, reference output, problem); reused by the GPU step.
    cpu_results = {}
    for head_dim in hdims:
        problem = build_problem(head_dim, head_dim)
        q_in, k_in, v_in = draw_inputs(problem)

        reference = cpu_attention_fwd(q_in, k_in, v_in, problem.scale)

        all_finite = np.all(np.isfinite(reference))
        value_range = f"[{reference.min():.4f}, {reference.max():.4f}]"
        print(
            f" {head_dim:>7} {head_dim:>7} {problem.scale:>10.4f} {value_range:>22} {'OK' if all_finite else 'NaN!':>10}"
        )

        cpu_results[head_dim] = (q_in, k_in, v_in, reference, problem)

    # Step 3: Asymmetric head dimensions
    print("\nStep 3: Asymmetric Head Dimensions (hdim_q != hdim_v)")

    asymmetric_configs = [
        (128, 64, "Large Q, small V: more attention capacity, compact output"),
        (64, 128, "Small Q, large V: compact attention, rich output"),
        (128, 256, "Standard Q, very large V: high-capacity value projection"),
        (256, 128, "Large Q, standard V: wide attention field"),
        (32, 128, "Tiny Q, standard V: minimal attention compute"),
    ]

    print(
        f"\n {'hdim_q':>7} {'hdim_v':>7} {'Q Shape':>22} {'O Shape':>22} {'MaxErr vs self':>16}"
    )
    print(" " + "-" * 78)

    for head_dim_q, head_dim_v, _note in asymmetric_configs:
        problem = build_problem(head_dim_q, head_dim_v)
        q_in, k_in, v_in = draw_inputs(problem)

        # The reference is evaluated twice on identical inputs and the two
        # results are compared with each other.
        first_pass = cpu_attention_fwd(q_in, k_in, v_in, problem.scale)
        second_pass = cpu_attention_fwd(q_in, k_in, v_in, problem.scale)
        self_err = float(np.abs(first_pass - second_pass).max())

        print(
            f" {head_dim_q:>7} {head_dim_v:>7} {str(problem.q_shape()):>22} {str(problem.o_shape()):>22} {self_err:>16.2e}"
        )

    print("\n Asymmetric hdim notes:")
    for head_dim_q, head_dim_v, note in asymmetric_configs:
        print(f" hdim_q={head_dim_q}, hdim_v={head_dim_v}: {note}")

    # Step 4: GPU validation (hdim=128)
    print("\nStep 4: GPU Validation (hdim=128)")

    kernel_config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=128,
        hdim_v=128,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(kernel_config)
    # Both stay 0.0 unless the GPU run succeeds; the summary checks gpu_tflops.
    gpu_tflops = 0.0
    gpu_time = 0.0
    if setup.success:
        runner = setup.runner
        print(f" JIT build: {setup.build_time_s:.1f}s")

        q_in, k_in, v_in, reference, problem = cpu_results[128]
        q_half, k_half, v_half = (t.astype(np.float16) for t in (q_in, k_in, v_in))

        result = runner.run(q_half, k_half, v_half, problem)
        if result.success:
            passed, max_abs, _ = validator.check(result.output, reference)
            print(
                f" GPU hdim=128: time={result.time_ms:.4f}ms TFLOPS={result.tflops:.2f} "
                f"max_err={max_abs:.2e} {'PASS' if passed else 'FAIL'}"
            )

            gpu_tflops = result.tflops
            gpu_time = result.time_ms
        else:
            print(f" GPU error: {result.error}")
    else:
        print(f" JIT build failed: {setup.error}")

    # Step 5: Performance projection table
    print("\nStep 5: Performance Summary Table")

    print(
        f"\n {'hdim_q':>7} | {'hdim_v':>7} | {'FLOPs':>14} | {'Tile':>24} | {'GPU Support':>12}"
    )
    rule = " " + "-" * 78
    print(rule)

    for head_dim in hdims:
        flop_count = flops_for(head_dim, head_dim)
        tile_label = recommended_tile(head_dim)
        if head_dim == 128:
            support = "prebuilt"
        else:
            support = "needs JIT"
        print(
            f" {head_dim:>7} | {head_dim:>7} | {flop_count:>14,} | {tile_label:>24} | {support:>12}"
        )

    print(rule)

    # Only the first three asymmetric shapes are listed in this table.
    for head_dim_q, head_dim_v, _ in asymmetric_configs[:3]:
        flop_count = flops_for(head_dim_q, head_dim_v)
        support = "needs JIT"
        print(
            f" {head_dim_q:>7} | {head_dim_v:>7} | {flop_count:>14,} | {'asymmetric':>24} | {support:>12}"
        )

    # Step 6: Kernel configuration per hdim
    print("\nStep 6: Kernel Configuration Per Head Dimension")
    print(" Each hdim requires a dedicated compiled kernel:")
    print()
    per_hdim_examples = (
        " hdim=32: FmhaKernelConfig(hdim_q=32, hdim_v=32, "
        "tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=32, tile_k1=32, tile_k0max=32)",
        " hdim=64: FmhaKernelConfig(hdim_q=64, hdim_v=64, "
        "tile_m0=128, tile_n0=64, tile_k0=32, tile_n1=64, tile_k1=32, tile_k0max=64)",
        " hdim=128: FmhaKernelConfig(hdim_q=128, hdim_v=128, "
        "tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=128, tile_k1=32, tile_k0max=128)",
        " hdim=256: FmhaKernelConfig(hdim_q=256, hdim_v=256, "
        "tile_m0=128, tile_n0=128, tile_k0=32, tile_n1=256, tile_k1=32, tile_k0max=256)",
    )
    for example_line in per_hdim_examples:
        print(example_line)
    print()
    print(" Asymmetric: FmhaKernelConfig(hdim_q=128, hdim_v=64, ...)")
    print(" tile_n1 tracks hdim_v; tile_k0max tracks hdim_q")

    # Summary
    print("\n" + banner)
    print(f" Supported symmetric hdims: {hdims}")
    print(" Asymmetric hdim (hdim_q != hdim_v): fully supported")
    print(" Tile sizes scale with hdim; larger hdim needs wider tiles")
    if gpu_tflops > 0:
        print(f" GPU baseline (hdim=128): {gpu_tflops:.2f} TFLOPS @ {gpu_time:.4f} ms")
    print(banner)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
|
||||
373
dispatcher/examples/fmha/python/27_backward_dropout_fmha.py
Normal file
373
dispatcher/examples/fmha/python/27_backward_dropout_fmha.py
Normal file
@@ -0,0 +1,373 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 27: Backward Pass with Dropout FMHA
|
||||
|
||||
Demonstrates the FMHA backward pass with dropout. The backward pass
|
||||
computes dQ, dK, dV given dO (gradient of the output). When dropout is
|
||||
applied during forward, the same dropout mask must be replayed during
|
||||
backward for correctness.
|
||||
|
||||
Key concepts:
|
||||
- Deterministic mode (no atomics): reproducible gradients, may be slower
|
||||
- Non-deterministic mode: uses atomicAdd for dQ, faster but non-reproducible
|
||||
- store_randval: optionally store the dropout random values for debugging
|
||||
|
||||
The prebuilt library only has a forward kernel. This example validates
|
||||
the backward CPU reference and shows the API pattern.
|
||||
|
||||
Usage:
|
||||
python3 27_backward_dropout_fmha.py
|
||||
python3 27_backward_dropout_fmha.py --dropout 0.2
|
||||
python3 27_backward_dropout_fmha.py --seqlen 128 --deterministic
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
def cpu_attention_fwd_dropout(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    dropout_p: float,
    seed: int = 42,
) -> tuple:
    """CPU reference forward pass with dropout applied to the attention weights.

    When K/V carry fewer heads than Q, each K/V head is repeated
    ``nhead_q // nhead_k`` times along axis 1. The dropout mask is drawn from
    ``np.random.RandomState(seed)``, so a fixed seed reproduces the same mask.

    Returns:
        O: [B, H, Sq, Dv] output
        P_drop: [B, H, Sq, Sk] attention weights after dropout and rescaling
        lse: [B, H, Sq] log-sum-exp of the scaled scores along the key axis
        drop_mask: [B, H, Sq, Sk] float32 mask, 1.0 where the weight is kept
    """
    q_heads, kv_heads = Q.shape[1], K.shape[1]
    if q_heads != kv_heads:
        group_size = q_heads // kv_heads
        K, V = (np.repeat(tensor, group_size, axis=1) for tensor in (K, V))

    scores = (Q @ K.transpose(0, 1, 3, 2)) * scale

    # Softmax along the key axis, shifted by the row maximum before exp.
    row_max = scores.max(axis=-1, keepdims=True)
    shifted_exp = np.exp(scores - row_max)
    row_sum = shifted_exp.sum(axis=-1, keepdims=True)
    probs = shifted_exp / row_sum

    lse = np.log(row_sum.squeeze(-1)) + row_max.squeeze(-1)

    generator = np.random.RandomState(seed)
    keep_mask = (generator.rand(*probs.shape) >= dropout_p).astype(np.float32)
    # Kept weights are scaled by 1 / (1 - p); for p >= 1 the scale is 0.
    if dropout_p < 1.0:
        rescale = 1.0 / (1.0 - dropout_p)
    else:
        rescale = 0.0
    dropped_probs = probs * keep_mask * rescale

    return dropped_probs @ V, dropped_probs, lse, keep_mask
|
||||
|
||||
|
||||
def cpu_attention_bwd_dropout(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    lse: np.ndarray,
    scale: float,
    dropout_p: float,
    drop_mask: np.ndarray,
    deterministic: bool = False,
) -> tuple:
    """CPU reference: backward with dropout.

    The softmax probabilities are recomputed from Q and K, and the forward
    dropout mask is re-applied to them.

    Args:
        Q: [B, H, Sq, Dq] float32
        K: [B, H, Sk, Dq] float32 (already GQA-expanded if needed)
        V: [B, H, Sk, Dv] float32
        out: [B, H, Sq, Dv] float32 (forward output, computed with dropout)
        dO: [B, H, Sq, Dv] float32 (output gradient)
        lse: [B, H, Sq] log-sum-exp from forward. Accepted for signature
            parity but not read here; the softmax is recomputed instead.
        scale: softmax scale
        dropout_p: dropout probability
        drop_mask: [B, H, Sq, Sk] binary mask from forward
        deterministic: accepted for signature parity but not read here; the
            computation is the same for either value.

    Returns:
        dQ: [B, H, Sq, Dq]
        dK: [B, H, Sk, Dq]
        dV: [B, H, Sk, Dv]
    """
    # Same rescale factor as the forward pass: 1 / (1 - p), or 0 when p >= 1.
    drop_scale = 1.0 / (1.0 - dropout_p) if dropout_p < 1.0 else 0.0

    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    S_max = S.max(axis=-1, keepdims=True)
    # Evaluate the shifted exponential once and reuse it for the normalizer.
    S_exp = np.exp(S - S_max)
    P = S_exp / S_exp.sum(axis=-1, keepdims=True)

    P_drop = P * drop_mask * drop_scale

    # out = P_drop @ V, so dV = P_drop^T @ dO and dP_drop = dO @ V^T.
    dV = np.matmul(P_drop.transpose(0, 1, 3, 2), dO)
    dP_drop = np.matmul(dO, V.transpose(0, 1, 3, 2))

    # Back through the dropout mask and rescale.
    dP = dP_drop * drop_mask * drop_scale

    # Softmax backward: D is the row-wise sum of dO * out.
    D = (dO * out).sum(axis=-1, keepdims=True)
    dS = P * (dP - D) * scale

    dQ = np.matmul(dS, K)
    dK = np.matmul(dS.transpose(0, 1, 3, 2), Q)

    return dQ, dK, dV
|
||||
|
||||
|
||||
def main():
    """Demonstrate the FMHA backward pass with dropout on the CPU reference.

    Steps: forward pass with dropout, backward pass, finite-difference spot
    checks of dQ/dK/dV, a deterministic vs non-deterministic comparison, a
    dropout-probability sweep, and a printed description of the GPU backward
    kernel stages.

    Returns:
        0 in all cases.
    """
    cli = argparse.ArgumentParser(
        description="Backward Pass with Dropout FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument("--arch", default=detect_gpu_arch())
    for flag, default_value in (
        ("--batch", 2),
        ("--nhead", 8),
        ("--seqlen", 64),
        ("--hdim", 128),
    ):
        cli.add_argument(flag, type=int, default=default_value)
    cli.add_argument("--dropout", type=float, default=0.1, help="Dropout probability")
    cli.add_argument(
        "--deterministic", action="store_true", help="Use deterministic mode"
    )
    args = cli.parse_args()

    banner = "=" * 70
    print(banner)
    print("Example 27: Backward Pass with Dropout FMHA")
    print(banner)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    # Step 1: Forward with dropout
    print("\nStep 1: Forward Pass with Dropout")

    np.random.seed(42)
    # Q, K, V are drawn in that order from the global NumPy RNG.
    Q, K, V = (
        (np.random.randn(*shape_of()) * 0.3).astype(np.float32)
        for shape_of in (prob.q_shape, prob.k_shape, prob.v_shape)
    )

    O_nodrop = cpu_attention_fwd(Q, K, V, prob.scale)
    O_drop, P_drop, lse, drop_mask = cpu_attention_fwd_dropout(
        Q, K, V, prob.scale, args.dropout, seed=42
    )

    print(f" Shape: {prob.q_shape()}")
    print(f" Dropout: p={args.dropout}")
    print(
        f" Drop mask: {drop_mask.sum():.0f}/{drop_mask.size} kept "
        f"({100 * drop_mask.mean():.1f}%, expected {100 * (1 - args.dropout):.1f}%)"
    )
    print(f" O (no drop): range=[{O_nodrop.min():.4f}, {O_nodrop.max():.4f}]")
    print(f" O (dropout): range=[{O_drop.min():.4f}, {O_drop.max():.4f}]")
    print(f" LSE shape: {lse.shape}")

    # Step 2: Backward pass
    print("\nStep 2: Backward Pass")

    np.random.seed(123)
    dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)

    # Positional arguments shared by every backward call on the main problem.
    bwd_args = (Q, K, V, O_drop, dO, lse, prob.scale, args.dropout, drop_mask)
    dQ, dK, dV = cpu_attention_bwd_dropout(*bwd_args, deterministic=args.deterministic)

    for label, gradient in (("dQ", dQ), ("dK", dK), ("dV", dV)):
        print(
            f" {label} shape: {gradient.shape} range=[{gradient.min():.6f}, {gradient.max():.6f}]"
        )
    print(f" Deterministic: {args.deterministic}")

    # Step 3: Verify gradient correctness via finite differences
    print("\nStep 3: Gradient Verification (Finite Differences)")

    eps = 1e-3
    num_checks = 5
    rng = np.random.RandomState(99)

    print(f"\n Checking {num_checks} random elements per tensor:")
    print(
        f" {'Tensor':>8} {'Index':>24} {'Analytic':>14} {'Numerical':>14} {'RelErr':>12}"
    )
    print(" " + "-" * 76)

    forward_inputs = (Q, K, V)
    for slot, (tensor_name, grad) in enumerate((("dQ", dQ), ("dK", dK), ("dV", dV))):
        param = forward_inputs[slot]
        for _ in range(num_checks):
            idx = tuple(rng.randint(0, extent) for extent in param.shape)

            # Central difference on one element of the tensor under test.
            nudged_up = param.copy()
            nudged_up[idx] += eps
            nudged_down = param.copy()
            nudged_down[idx] -= eps

            inputs_up = list(forward_inputs)
            inputs_up[slot] = nudged_up
            inputs_down = list(forward_inputs)
            inputs_down[slot] = nudged_down

            # seed=42 reproduces the dropout mask used in Step 1.
            O_p = cpu_attention_fwd_dropout(
                *inputs_up, prob.scale, args.dropout, seed=42
            )[0]
            O_m = cpu_attention_fwd_dropout(
                *inputs_down, prob.scale, args.dropout, seed=42
            )[0]

            # Scalar objective is sum(O * dO).
            numerical = ((O_p * dO).sum() - (O_m * dO).sum()) / (2 * eps)
            analytic = grad[idx]

            rel_err = abs(analytic - numerical) / (abs(numerical) + 1e-8)
            idx_str = str(idx)
            print(
                f" {tensor_name:>8} {idx_str:>24} {analytic:>14.6f} {numerical:>14.6f} {rel_err:>12.2e}"
            )

    # Step 4: Deterministic vs non-deterministic comparison
    print("\nStep 4: Deterministic vs Non-Deterministic")

    grads_det = cpu_attention_bwd_dropout(*bwd_args, deterministic=True)
    grads_ndet = cpu_attention_bwd_dropout(*bwd_args, deterministic=False)

    validator = FmhaValidator(rtol=1e-5, atol=1e-5)

    for name, g_det, g_ndet in zip(("dQ", "dK", "dV"), grads_det, grads_ndet):
        ok, max_abs, _ = validator.check(g_det, g_ndet)
        print(
            f" {name}: det vs non-det max_err={max_abs:.2e} {'MATCH' if ok else 'DIFFER'}"
        )

    print("\n NOTE: In CPU reference both modes are identical.")
    print(" On GPU, non-deterministic mode uses atomicAdd for dQ accumulation,")
    print(" which can cause tiny floating-point differences across runs.")

    # Step 5: Dropout probability sweep
    print("\nStep 5: Dropout Probability Sweep")

    probs = [0.0, 0.1, 0.2, 0.3, 0.5]
    print(
        f"\n {'p':>6} {'|dQ| mean':>12} {'|dK| mean':>12} {'|dV| mean':>12} {'Kept%':>8}"
    )
    print(" " + "-" * 54)

    for p in probs:
        O_p, _, _, sweep_mask = cpu_attention_fwd_dropout(
            Q, K, V, prob.scale, p, seed=42
        )
        dQ_p, dK_p, dV_p = cpu_attention_bwd_dropout(
            Q, K, V, O_p, dO, lse, prob.scale, p, sweep_mask
        )
        kept = 100 * sweep_mask.mean()
        print(
            f" {p:>6.2f} {np.abs(dQ_p).mean():>12.6f} {np.abs(dK_p).mean():>12.6f} "
            f"{np.abs(dV_p).mean():>12.6f} {kept:>7.1f}%"
        )

    # Step 6: GPU API pattern
    print("\nStep 6: GPU Backward Kernel Configuration")
    print(" NOTE: The prebuilt library only has a forward kernel.")
    print(" FMHA backward requires 3 kernel stages:")
    print()
    print(" Stage 1: bwd_dot_do_o -- compute D = rowsum(dO * O)")
    print(" Stage 2: bwd_dq_dk_dv -- compute dQ, dK, dV")
    print(" Stage 3: bwd_convert_dq -- convert accumulated dQ")
    print()
    print(" With dropout, the signature requires:")
    print(" .dropout(true)")
    print(" .store_randval(false) // or true to save random values")
    print(f" .deterministic({'true' if args.deterministic else 'false'})")

    # Summary
    print("\n" + banner)
    print(" Backward with dropout: replays same mask from forward pass")
    print(" Deterministic mode: reproducible but potentially slower on GPU")
    print(" 3-stage backward: dot_do_o -> dq_dk_dv -> convert_dq")
    print(banner)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
|
||||
360
dispatcher/examples/fmha/python/28_backward_dbias_fmha.py
Normal file
360
dispatcher/examples/fmha/python/28_backward_dbias_fmha.py
Normal file
@@ -0,0 +1,360 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 28: Backward Bias Gradient (dbias) FMHA
|
||||
|
||||
Demonstrates computing the gradient of the elementwise attention bias
|
||||
during the backward pass. When forward attention uses:
|
||||
S = Q @ K^T * scale + bias
|
||||
the backward pass must compute:
|
||||
dbias = sum over batch of (dS)
|
||||
where dS is the gradient with respect to the pre-softmax scores S.
|
||||
|
||||
This is useful for learnable relative position biases (e.g., ALiBi
|
||||
training, T5-style relative position embeddings).
|
||||
|
||||
The prebuilt library only has a forward kernel. This example validates
|
||||
the dbias CPU reference and shows the API pattern.
|
||||
|
||||
Usage:
|
||||
python3 28_backward_dbias_fmha.py
|
||||
python3 28_backward_dbias_fmha.py --seqlen 128
|
||||
python3 28_backward_dbias_fmha.py --bias-type alibi
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaProblem,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
def make_elementwise_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Create a simple elementwise attention bias [nhead, seqlen_q, seqlen_k].

    Entry ``[h, i, j]`` is ``-0.1 * |i - j| * (h + 1) / nhead``: a distance
    penalty whose strength grows linearly with the head index.

    Args:
        nhead: number of heads (first axis of the result).
        seqlen_q: number of query positions (second axis).
        seqlen_k: number of key positions (third axis).

    Returns:
        float32 array of shape ``(nhead, seqlen_q, seqlen_k)``.
    """
    # |i - j| for every (query, key) pair, built once via broadcasting.
    distance = np.abs(np.arange(seqlen_q)[:, None] - np.arange(seqlen_k)[None, :])
    # (h + 1) for every head, shaped to broadcast over the distance grid.
    head_factor = np.arange(1, nhead + 1)[:, None, None]
    # Same operation order as the scalar formula, evaluated in float64 and
    # rounded to float32 once at the end.
    bias = -0.1 * distance[None, :, :] * head_factor / nhead
    return bias.astype(np.float32)
|
||||
|
||||
|
||||
def make_alibi_bias(nhead: int, seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Create ALiBi-style attention bias [nhead, seqlen_q, seqlen_k].

    ALiBi adds a linear penalty proportional to distance:
        bias[h, i, j] = -slope_h * |i - j|
    where slope_h = 2 ** (-8 * (h + 1) / nhead) decreases geometrically
    across heads.

    Args:
        nhead: number of heads (first axis of the result).
        seqlen_q: number of query positions (second axis).
        seqlen_k: number of key positions (third axis).

    Returns:
        float32 array of shape ``(nhead, seqlen_q, seqlen_k)``.
    """
    # Per-head slopes; only nhead scalars, so a comprehension is fine here.
    slopes = np.array([2 ** (-(8 * (h + 1) / nhead)) for h in range(nhead)])
    # |i - j| for every (query, key) pair, built once via broadcasting.
    distance = np.abs(np.arange(seqlen_q)[:, None] - np.arange(seqlen_k)[None, :])
    bias = (-slopes).reshape(nhead, 1, 1) * distance[None, :, :]
    return bias.astype(np.float32)
|
||||
|
||||
|
||||
def cpu_attention_fwd_bias(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    bias: np.ndarray,
) -> tuple:
    """Reference forward attention with an additive bias, keeping intermediates.

    Computes ``softmax(Q @ K^T * scale + bias) @ V`` with a max-subtracted
    (numerically stable) softmax.

    Args:
        Q: [B, H, Sq, Dq]
        K: [B, H, Sk, Dq]
        V: [B, H, Sk, Dv]
        scale: multiplier applied to Q @ K^T before the bias is added
        bias: [H, Sq, Sk] broadcast over batch

    Returns:
        O: [B, H, Sq, Dv]
        P: [B, H, Sq, Sk] attention probabilities
        lse: [B, H, Sq] log-sum-exp
    """
    q_heads, kv_heads = Q.shape[1], K.shape[1]
    if q_heads != kv_heads:
        # Grouped-query case: replicate each K/V head for its group of Q heads.
        group = q_heads // kv_heads
        K = np.repeat(K, group, axis=1)
        V = np.repeat(V, group, axis=1)

    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    scores = scores + bias[np.newaxis, :, :, :]

    # Stable softmax along the key axis.
    row_max = scores.max(axis=-1, keepdims=True)
    unnormalized = np.exp(scores - row_max)
    denom = unnormalized.sum(axis=-1, keepdims=True)
    probs = unnormalized / denom

    log_sum_exp = np.log(denom.squeeze(-1)) + row_max.squeeze(-1)
    return np.matmul(probs, V), probs, log_sum_exp
|
||||
|
||||
|
||||
def cpu_attention_bwd_dbias(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
    bias: np.ndarray,
) -> tuple:
    """CPU backward computing dQ, dK, dV, and dbias.

    Args:
        Q, K, V: forward inputs [B, H, Sq/Sk, D]; K and V may carry fewer
            heads than Q (grouped-query attention).
        out: forward output [B, Hq, Sq, Dv]
        dO: output gradient [B, Hq, Sq, Dv]
        P: attention probabilities [B, Hq, Sq, Sk]
        scale: softmax scale
        bias: [Hq, Sq, Sk] attention bias (not read here: the gradient depends
            on it only through P; kept for signature symmetry with the forward)

    Returns:
        dQ: [B, Hq, Sq, Dq]
        dK: [B, Hk, Sk, Dq]  same shape as the K that was passed in
        dV: [B, Hk, Sk, Dv]  same shape as the V that was passed in
        dbias: [Hq, Sq, Sk] summed over batch dimension
    """
    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    ratio = 1
    if nhead_q != nhead_k:
        # Expand K/V so every Q head has its own copy for the matmuls below.
        ratio = nhead_q // nhead_k
        K = np.repeat(K, ratio, axis=1)
        V = np.repeat(V, ratio, axis=1)

    dV = np.matmul(P.transpose(0, 1, 3, 2), dO)

    dP = np.matmul(dO, V.transpose(0, 1, 3, 2))

    # D is the per-query-row correction term of the softmax Jacobian.
    D = (dO * out).sum(axis=-1, keepdims=True)
    # Gradient w.r.t. the pre-softmax scores (Q K^T * scale + bias). The bias
    # enters additively, so this is also the per-sample bias gradient.
    dS_scores = P * (dP - D)
    # Gradient w.r.t. Q K^T picks up the softmax scale.
    dS = dS_scores * scale

    dQ = np.matmul(dS, K)
    dK = np.matmul(dS.transpose(0, 1, 3, 2), Q)

    # Reduce over batch directly from the unscaled gradient; multiplying by
    # scale and dividing again loses precision and yields NaN for scale == 0.
    dbias = dS_scores.sum(axis=0)

    if ratio > 1:
        # np.repeat laid the copies of each K/V head out consecutively, so the
        # gradients of one original head are the sum over its `ratio` copies.
        batch, _, seqlen_k, hdim_k = dK.shape
        dK = dK.reshape(batch, nhead_k, ratio, seqlen_k, hdim_k).sum(axis=2)
        dV = dV.reshape(batch, nhead_k, ratio, seqlen_k, dV.shape[-1]).sum(axis=2)

    return dQ, dK, dV, dbias
|
||||
|
||||
|
||||
def main():
    """Walk through the dbias example on the CPU reference path.

    Steps: build the bias, run the biased forward pass, run the backward pass,
    verify dbias with central finite differences, print per-head statistics,
    show how dbias changes with batch size, and print the GPU API pattern.

    Returns:
        0 when every finite-difference probe matches the analytic dbias,
        1 otherwise, so the process exit status reflects the check.
    """
    parser = argparse.ArgumentParser(
        description="Backward Bias Gradient (dbias) FMHA Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen", type=int, default=64)
    parser.add_argument("--hdim", type=int, default=128)
    parser.add_argument(
        "--bias-type", choices=["elementwise", "alibi"], default="elementwise"
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Example 28: Backward Bias Gradient (dbias) FMHA")
    print("=" * 70)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    # Step 1: Create bias
    print(f"\nStep 1: Create {args.bias_type.title()} Bias")

    if args.bias_type == "alibi":
        bias = make_alibi_bias(args.nhead, args.seqlen, args.seqlen)
    else:
        bias = make_elementwise_bias(args.nhead, args.seqlen, args.seqlen)

    print(f" Bias shape: {bias.shape}")
    print(f" Bias range: [{bias.min():.4f}, {bias.max():.4f}]")
    print(f" Bias type: {args.bias_type}")

    for h in range(min(4, args.nhead)):
        print(
            f" Head {h}: range=[{bias[h].min():.4f}, {bias[h].max():.4f}] "
            f"mean={bias[h].mean():.4f}"
        )

    # Step 2: Forward pass with bias
    print("\nStep 2: Forward Pass with Bias")

    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.3).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.3).astype(np.float32)
    V = (np.random.randn(*prob.v_shape()) * 0.3).astype(np.float32)

    O_nobias = cpu_attention_fwd(Q, K, V, prob.scale)
    O_bias, P, _ = cpu_attention_fwd_bias(Q, K, V, prob.scale, bias)

    diff = np.abs(O_nobias - O_bias)
    print(f" O (no bias): range=[{O_nobias.min():.4f}, {O_nobias.max():.4f}]")
    print(f" O (biased): range=[{O_bias.min():.4f}, {O_bias.max():.4f}]")
    print(f" Bias effect: max_diff={diff.max():.6e} mean_diff={diff.mean():.6e}")

    # Step 3: Backward pass with dbias
    print("\nStep 3: Backward Pass (dQ, dK, dV, dbias)")

    np.random.seed(123)
    dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)

    dQ, dK, dV, dbias = cpu_attention_bwd_dbias(
        Q,
        K,
        V,
        O_bias,
        dO,
        P,
        prob.scale,
        bias,
    )

    print(f" dQ shape: {dQ.shape} range=[{dQ.min():.6f}, {dQ.max():.6f}]")
    print(f" dK shape: {dK.shape} range=[{dK.min():.6f}, {dK.max():.6f}]")
    print(f" dV shape: {dV.shape} range=[{dV.min():.6f}, {dV.max():.6f}]")
    print(f" dbias shape: {dbias.shape} range=[{dbias.min():.6f}, {dbias.max():.6f}]")

    # Step 4: Verify dbias via finite differences
    print("\nStep 4: dbias Gradient Verification (Finite Differences)")

    eps = 1e-3
    num_checks = 8
    rng = np.random.RandomState(99)

    # The probe differences two scalar losses that each sum the whole output
    # tensor. In float32 the rounding of those sums is comparable to the
    # perturbation being measured, so the reference side runs in float64.
    Q64 = Q.astype(np.float64)
    K64 = K.astype(np.float64)
    V64 = V.astype(np.float64)
    dO64 = dO.astype(np.float64)
    bias64 = bias.astype(np.float64)

    print(
        f"\n {'Index':>20} {'Analytic':>14} {'Numerical':>14} {'RelErr':>12} {'Status':>8}"
    )
    print(" " + "-" * 72)

    all_grad_ok = True
    for _ in range(num_checks):
        h = rng.randint(0, args.nhead)
        i = rng.randint(0, args.seqlen)
        j = rng.randint(0, args.seqlen)

        bias_plus = bias64.copy()
        bias_plus[h, i, j] += eps
        bias_minus = bias64.copy()
        bias_minus[h, i, j] -= eps

        O_p, _, _ = cpu_attention_fwd_bias(Q64, K64, V64, prob.scale, bias_plus)
        O_m, _, _ = cpu_attention_fwd_bias(Q64, K64, V64, prob.scale, bias_minus)

        # Central difference of the scalar loss L = sum(O * dO).
        numerical = ((O_p * dO64).sum() - (O_m * dO64).sum()) / (2 * eps)
        analytic = dbias[h, i, j]

        rel_err = abs(analytic - numerical) / (abs(numerical) + 1e-8)
        ok = rel_err < 1e-2
        all_grad_ok = all_grad_ok and ok
        idx_str = f"({h},{i},{j})"
        print(
            f" {idx_str:>20} {analytic:>14.6f} {numerical:>14.6f} {rel_err:>12.2e} {'OK' if ok else 'FAIL':>8}"
        )

    # Step 5: dbias structure analysis
    print("\nStep 5: dbias Structure Analysis")

    print("\n Per-head dbias statistics:")
    print(f" {'Head':>6} {'Mean':>12} {'Std':>12} {'Min':>12} {'Max':>12}")
    print(" " + "-" * 56)

    for h in range(min(8, args.nhead)):
        db_h = dbias[h]
        print(
            f" {h:>6} {db_h.mean():>12.6f} {db_h.std():>12.6f} "
            f"{db_h.min():>12.6f} {db_h.max():>12.6f}"
        )

    # Step 6: Batch size effect on dbias
    print("\nStep 6: Batch Size Effect on dbias")
    print(" dbias = sum of per-sample dS / scale over batch dimension")
    print(" Larger batch -> dbias aggregates more gradient signal")

    batch_sizes = [1, 2, 4, 8]
    print(
        f"\n {'Batch':>6} {'|dbias| mean':>14} {'|dbias| max':>14} {'dbias std':>14}"
    )
    print(" " + "-" * 52)

    # Draws continue from the global RNG stream seeded in Step 3.
    for b in batch_sizes:
        Q_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype(
            np.float32
        )
        K_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype(
            np.float32
        )
        V_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.3).astype(
            np.float32
        )
        dO_b = (np.random.randn(b, args.nhead, args.seqlen, args.hdim) * 0.1).astype(
            np.float32
        )

        O_b, P_b, _ = cpu_attention_fwd_bias(Q_b, K_b, V_b, prob.scale, bias)
        _, _, _, dbias_b = cpu_attention_bwd_dbias(
            Q_b,
            K_b,
            V_b,
            O_b,
            dO_b,
            P_b,
            prob.scale,
            bias,
        )
        print(
            f" {b:>6} {np.abs(dbias_b).mean():>14.6f} {np.abs(dbias_b).max():>14.6f} "
            f"{dbias_b.std():>14.6f}"
        )

    # Step 7: GPU API pattern
    print("\nStep 7: GPU Kernel Configuration")
    print(" NOTE: The prebuilt library only has a forward kernel without bias.")
    print(" For backward with dbias, compile kernels with:")
    print()
    print(" Forward: FmhaSignature().bias('bias') // elementwise bias")
    print(" Backward: FmhaSignature()")
    print(" .family('bwd_dq_dk_dv')")
    print(" .bias('bias')")
    print(" .dbias(true) // enable dbias computation")
    print()
    print(" In codegen JSON:")
    print(" 'bias': 'bias', // forward: elementwise bias")
    print(" 'dbias': true, // backward: compute bias gradient")

    # Summary
    print("\n" + "=" * 70)
    print(" dbias = sum_batch(P * (dP - D)) (gradient of elementwise bias)")
    print(f" Shape: [{args.nhead}, {args.seqlen}, {args.seqlen}] (same as bias)")
    print(f" Gradient check: {'PASS' if all_grad_ok else 'FAIL'}")
    print(" Use case: learnable relative position biases (ALiBi, T5, etc.)")
    print("=" * 70)

    # Report a failed gradient check through the exit status, matching the
    # other examples, which return non-zero on FAIL.
    return 0 if all_grad_ok else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
146
dispatcher/examples/fmha/python/29_sweep_seqlen.py
Normal file
146
dispatcher/examples/fmha/python/29_sweep_seqlen.py
Normal file
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 29: Sweep Sequence Length
|
||||
|
||||
Demonstrates how FMHA performance scales with sequence length.
|
||||
FMHA has O(n^2) compute in seqlen (Q*K^T), so TFLOPS should increase
|
||||
with longer sequences as the GPU becomes better utilized.
|
||||
|
||||
Fixed: batch=2, nhead=8, hdim=128
|
||||
Sweep: seqlen in [32, 64, 128, 256, 512, 1024, 2048]
|
||||
|
||||
Usage:
|
||||
python3 29_sweep_seqlen.py
|
||||
python3 29_sweep_seqlen.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
BATCH = 2
|
||||
NHEAD = 8
|
||||
HDIM = 128
|
||||
SEQLENS = [32, 64, 128, 256, 512, 1024, 2048]
|
||||
|
||||
|
||||
def main():
    """Run the FMHA kernel over every length in SEQLENS and report scaling.

    JIT-compiles one fp16 kernel, validates each GPU result against the CPU
    reference, prints a per-length table plus a first-vs-last TFLOPS
    comparison, and returns 0 only if every sweep point passed (1 on a failed
    build or any failing point).
    """
    cli = argparse.ArgumentParser(description="Sweep Sequence Length FMHA")
    cli.add_argument("--arch", default=detect_gpu_arch())
    args = cli.parse_args()

    rule = "=" * 70
    print(rule)
    print("Example 29: Sweep Sequence Length")
    print(rule)

    print(f"\n Fixed: batch={BATCH}, nhead={NHEAD}, hdim={HDIM}")
    print(f" Sweep: seqlen in {SEQLENS}")
    print(f" Arch: {args.arch}")

    checker = FmhaValidator(rtol=1e-2, atol=1e-2)

    # Step 1: JIT-compile FMHA kernel
    print("\nStep 1: JIT-Compile FMHA Kernel")
    build = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=HDIM,
            hdim_v=HDIM,
            gfx_arch=args.arch,
        )
    )
    if not build.success:
        print(f" JIT build failed: {build.error}")
        return 1
    kernel = build.runner
    print(f" JIT build: {build.build_time_s:.1f}s")

    # Step 2: Sweep
    print("\nStep 2: Sequence Length Sweep")

    header = f" {'SeqLen':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}"
    print(f"\n{header}")
    print(" " + "-" * 60)

    np.random.seed(42)
    rows = []

    def draw_fp16(shape):
        # One tensor from the global RNG stream, scaled and cast to fp16.
        return (np.random.randn(*shape) * 0.1).astype(np.float16)

    for seqlen in SEQLENS:
        problem = FmhaProblem(
            batch=BATCH,
            nhead_q=NHEAD,
            nhead_k=NHEAD,
            seqlen_q=seqlen,
            seqlen_k=seqlen,
            hdim_q=HDIM,
            hdim_v=HDIM,
        )

        Q = draw_fp16(problem.q_shape())
        K = draw_fp16(problem.k_shape())
        V = draw_fp16(problem.v_shape())

        reference = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            problem.scale,
        )

        outcome = kernel.run(Q, K, V, problem)
        if outcome.success:
            max_err = float(np.abs(outcome.output.astype(np.float32) - reference).max())
            ok, _, _ = checker.check(outcome.output, reference)
            tag = "PASS" if ok else "FAIL"
            print(
                f" {seqlen:>8} | {outcome.time_ms:>10.4f} | {outcome.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}"
            )
            rows.append((seqlen, ok, outcome.time_ms, outcome.tflops, max_err))
        else:
            print(
                f" {seqlen:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}"
            )
            rows.append((seqlen, False, 0.0, 0.0, 0.0))

    # Step 3: Scaling analysis
    print("\nStep 3: Scaling Analysis")
    usable = [(s, tf) for s, ok, _, tf, _ in rows if ok and tf > 0]
    if len(usable) >= 2:
        (s0, tf0), (s_last, tf_last) = usable[0], usable[-1]
        print(f" Shortest (seqlen={s0}): {tf0:.2f} TFLOPS")
        print(f" Longest (seqlen={s_last}): {tf_last:.2f} TFLOPS")
        print(f" Speedup: {tf_last / tf0:.1f}x TFLOPS improvement")
        print(" Note: Longer sequences expose more parallelism to the GPU")

    # Summary
    passed = len([row for row in rows if row[1]])
    everything_passed = passed == len(rows)
    print("\n" + rule)
    print(f" Results: {passed}/{len(rows)} passed")
    print(f" Fixed: B={BATCH} H={NHEAD} D={HDIM}")
    print(f" Sweep: seqlen={SEQLENS}")
    print(f" Status: {'PASS' if everything_passed else 'FAIL'}")
    print(rule)

    if everything_passed:
        return 0
    return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
151
dispatcher/examples/fmha/python/30_sweep_batch.py
Normal file
151
dispatcher/examples/fmha/python/30_sweep_batch.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 30: Sweep Batch Size
|
||||
|
||||
Demonstrates how FMHA performance scales with batch size.
|
||||
FMHA compute scales linearly with batch, so time should increase
|
||||
linearly while TFLOPS remains roughly constant once the GPU is saturated.
|
||||
|
||||
Fixed: seqlen=128, nhead=8, hdim=128
|
||||
Sweep: batch in [1, 2, 4, 8, 16, 32]
|
||||
|
||||
Usage:
|
||||
python3 30_sweep_batch.py
|
||||
python3 30_sweep_batch.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
SEQLEN = 128
|
||||
NHEAD = 8
|
||||
HDIM = 128
|
||||
BATCHES = [1, 2, 4, 8, 16, 32]
|
||||
|
||||
|
||||
def main():
    """Run the FMHA kernel over every size in BATCHES and report linearity.

    JIT-compiles one fp16 kernel, validates each GPU result against the CPU
    reference, prints a per-batch table plus a time-vs-batch linearity factor,
    and returns 0 only if every sweep point passed (1 on a failed build or any
    failing point).
    """
    cli = argparse.ArgumentParser(description="Sweep Batch Size FMHA")
    cli.add_argument("--arch", default=detect_gpu_arch())
    args = cli.parse_args()

    rule = "=" * 70
    print(rule)
    print("Example 30: Sweep Batch Size")
    print(rule)

    print(f"\n Fixed: seqlen={SEQLEN}, nhead={NHEAD}, hdim={HDIM}")
    print(f" Sweep: batch in {BATCHES}")
    print(f" Arch: {args.arch}")

    checker = FmhaValidator(rtol=1e-2, atol=1e-2)

    # Step 1: JIT-compile FMHA kernel
    print("\nStep 1: JIT-Compile FMHA Kernel")
    build = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=HDIM,
            hdim_v=HDIM,
            gfx_arch=args.arch,
        )
    )
    if not build.success:
        print(f" JIT build failed: {build.error}")
        return 1
    kernel = build.runner
    print(f" JIT build: {build.build_time_s:.1f}s")

    # Step 2: Sweep
    print("\nStep 2: Batch Size Sweep")

    header = f" {'Batch':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}"
    print(f"\n{header}")
    print(" " + "-" * 60)

    np.random.seed(42)
    rows = []

    def draw_fp16(shape):
        # One tensor from the global RNG stream, scaled and cast to fp16.
        return (np.random.randn(*shape) * 0.1).astype(np.float16)

    for batch in BATCHES:
        problem = FmhaProblem(
            batch=batch,
            nhead_q=NHEAD,
            nhead_k=NHEAD,
            seqlen_q=SEQLEN,
            seqlen_k=SEQLEN,
            hdim_q=HDIM,
            hdim_v=HDIM,
        )

        Q = draw_fp16(problem.q_shape())
        K = draw_fp16(problem.k_shape())
        V = draw_fp16(problem.v_shape())

        reference = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            problem.scale,
        )

        outcome = kernel.run(Q, K, V, problem)
        if outcome.success:
            max_err = float(np.abs(outcome.output.astype(np.float32) - reference).max())
            ok, _, _ = checker.check(outcome.output, reference)
            tag = "PASS" if ok else "FAIL"
            print(
                f" {batch:>8} | {outcome.time_ms:>10.4f} | {outcome.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}"
            )
            rows.append((batch, ok, outcome.time_ms, outcome.tflops, max_err))
        else:
            print(
                f" {batch:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}"
            )
            rows.append((batch, False, 0.0, 0.0, 0.0))

    # Step 3: Linearity analysis
    print("\nStep 3: Linear Scaling Analysis")
    usable = [(b, t, tf) for b, ok, t, tf, _ in rows if ok and t > 0]
    if len(usable) >= 2:
        b0, t0, tf0 = usable[0]
        b_last, t_last, tf_last = usable[-1]
        batch_ratio = b_last / b0
        time_ratio = t_last / t0
        # 1.0 means time grew exactly in proportion to batch.
        linearity = time_ratio / batch_ratio

        print(
            f" Batch {b0} -> {b_last}: {batch_ratio:.0f}x batch, {time_ratio:.1f}x time"
        )
        print(f" Linearity factor: {linearity:.2f} (1.0 = perfect linear scaling)")
        print(f" TFLOPS range: {tf0:.2f} - {tf_last:.2f}")

    # Summary
    passed = len([row for row in rows if row[1]])
    everything_passed = passed == len(rows)
    print("\n" + rule)
    print(f" Results: {passed}/{len(rows)} passed")
    print(f" Fixed: S={SEQLEN} H={NHEAD} D={HDIM}")
    print(f" Sweep: batch={BATCHES}")
    print(f" Status: {'PASS' if everything_passed else 'FAIL'}")
    print(rule)

    if everything_passed:
        return 0
    return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
172
dispatcher/examples/fmha/python/31_sweep_nhead.py
Normal file
172
dispatcher/examples/fmha/python/31_sweep_nhead.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 31: Sweep Number of Heads (MHA + GQA)
|
||||
|
||||
Demonstrates FMHA performance across different head counts, including
|
||||
Grouped Query Attention (GQA) where nhead_q > nhead_k.
|
||||
|
||||
Part 1 - MHA sweep: nhead_q == nhead_k
|
||||
Part 2 - GQA variants: nhead_q != nhead_k (multiple Q heads share K/V)
|
||||
|
||||
Fixed: batch=2, seqlen=128, hdim=128
|
||||
|
||||
Usage:
|
||||
python3 31_sweep_nhead.py
|
||||
python3 31_sweep_nhead.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
BATCH = 2
|
||||
SEQLEN = 128
|
||||
HDIM = 128
|
||||
|
||||
MHA_NHEADS = [1, 2, 4, 8, 16, 32]
|
||||
GQA_CONFIGS = [
|
||||
(8, 1, "GQA 8:1"),
|
||||
(16, 4, "GQA 4:1"),
|
||||
(32, 8, "GQA 4:1"),
|
||||
]
|
||||
|
||||
|
||||
def run_sweep(runner, validator, configs, label):
    """Run the kernel for each (nhead_q, nhead_k) pair and print a result table.

    Reseeds the global NumPy RNG with 42, then for every pair builds random
    fp16 inputs, computes the CPU reference, runs ``runner`` and checks the
    output with ``validator``.

    Args:
        runner: object whose ``run(Q, K, V, prob)`` result exposes ``success``,
            ``output``, ``time_ms`` and ``tflops``.
        validator: object whose ``check(output, reference)`` returns a 3-tuple
            with the pass flag first.
        configs: iterable of ``(nhead_q, nhead_k)`` pairs.
        label: sweep name supplied by the caller; not used in this function.

    Returns:
        List of ``(nhead_q, nhead_k, ok, time_ms, tflops, max_err)`` tuples,
        one per configuration, in input order.
    """
    header = f" {'nhead_q':>8} | {'nhead_k':>8} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<6}"
    print(f"\n{header}")
    print(" " + "-" * 70)

    np.random.seed(42)
    rows = []

    def draw_fp16(shape):
        # One tensor from the global RNG stream, scaled and cast to fp16.
        return (np.random.randn(*shape) * 0.1).astype(np.float16)

    for nhead_q, nhead_k in configs:
        problem = FmhaProblem(
            batch=BATCH,
            nhead_q=nhead_q,
            nhead_k=nhead_k,
            seqlen_q=SEQLEN,
            seqlen_k=SEQLEN,
            hdim_q=HDIM,
            hdim_v=HDIM,
        )

        Q = draw_fp16(problem.q_shape())
        K = draw_fp16(problem.k_shape())
        V = draw_fp16(problem.v_shape())

        reference = cpu_attention_fwd(
            Q.astype(np.float32),
            K.astype(np.float32),
            V.astype(np.float32),
            problem.scale,
        )

        outcome = runner.run(Q, K, V, problem)
        if outcome.success:
            max_err = float(np.abs(outcome.output.astype(np.float32) - reference).max())
            ok, _, _ = validator.check(outcome.output, reference)
            tag = "PASS" if ok else "FAIL"
            print(
                f" {nhead_q:>8} | {nhead_k:>8} | {outcome.time_ms:>10.4f} | {outcome.tflops:>10.2f} | {max_err:>10.2e} | {tag:<6}"
            )
            rows.append(
                (nhead_q, nhead_k, ok, outcome.time_ms, outcome.tflops, max_err)
            )
        else:
            print(
                f" {nhead_q:>8} | {nhead_k:>8} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<6}"
            )
            rows.append((nhead_q, nhead_k, False, 0.0, 0.0, 0.0))

    return rows
|
||||
|
||||
|
||||
def main():
    """Sweep head counts for MHA (nhead_q == nhead_k) and GQA (nhead_q > nhead_k).

    JIT-compiles one fp16 kernel, runs both sweeps through ``run_sweep``,
    prints the best-TFLOPS configuration of each family, and returns 0 only if
    every configuration passed (1 on a failed build or any failing point).
    """
    cli = argparse.ArgumentParser(description="Sweep Number of Heads FMHA")
    cli.add_argument("--arch", default=detect_gpu_arch())
    args = cli.parse_args()

    rule = "=" * 70
    print(rule)
    print("Example 31: Sweep Number of Heads (MHA + GQA)")
    print(rule)

    print(f"\n Fixed: batch={BATCH}, seqlen={SEQLEN}, hdim={HDIM}")
    print(f" Arch: {args.arch}")

    checker = FmhaValidator(rtol=1e-2, atol=1e-2)

    # Step 1: JIT-compile FMHA kernel
    print("\nStep 1: JIT-Compile FMHA Kernel")
    build = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=HDIM,
            hdim_v=HDIM,
            gfx_arch=args.arch,
        )
    )
    if not build.success:
        print(f" JIT build failed: {build.error}")
        return 1
    kernel = build.runner
    print(f" JIT build: {build.build_time_s:.1f}s")

    # Step 2: MHA sweep (nhead_q == nhead_k)
    print("\nStep 2: MHA Sweep (nhead_q == nhead_k)")
    mha_rows = run_sweep(kernel, checker, [(n, n) for n in MHA_NHEADS], "MHA")

    # Step 3: GQA sweep (nhead_q > nhead_k)
    print("\nStep 3: GQA Sweep (nhead_q > nhead_k)")
    print(" GQA: multiple Q heads share fewer K/V heads")
    gqa_pairs = [(entry[0], entry[1]) for entry in GQA_CONFIGS]
    gqa_rows = run_sweep(kernel, checker, gqa_pairs, "GQA")

    # Step 4: Comparison
    print("\nStep 4: MHA vs GQA Comparison")
    every_row = mha_rows + gqa_rows

    def with_throughput(rows):
        # Keep (nhead_q, nhead_k, tflops) for rows that passed and were timed.
        return [(nq, nk, tf) for nq, nk, ok, _, tf, _ in rows if ok and tf > 0]

    def tflops_of(entry):
        return entry[2]

    mha_timed = with_throughput(mha_rows)
    gqa_timed = with_throughput(gqa_rows)

    if mha_timed:
        top_mha = max(mha_timed, key=tflops_of)
        print(f" Best MHA: nhead={top_mha[0]}, {top_mha[2]:.2f} TFLOPS")
    if gqa_timed:
        top_gqa = max(gqa_timed, key=tflops_of)
        print(
            f" Best GQA: nhead_q={top_gqa[0]}, nhead_k={top_gqa[1]}, {top_gqa[2]:.2f} TFLOPS"
        )
        print(f" GQA saves K/V memory: {top_gqa[0]}:{top_gqa[1]} ratio")

    # Summary
    passed = len([row for row in every_row if row[2]])
    total = len(every_row)
    print("\n" + rule)
    print(f" Results: {passed}/{total} passed")
    print(f" Fixed: B={BATCH} S={SEQLEN} D={HDIM}")
    print(f" MHA: nhead={MHA_NHEADS}")
    print(f" GQA: {[(nq, nk) for nq, nk, _ in GQA_CONFIGS]}")
    print(f" Status: {'PASS' if passed == total else 'FAIL'}")
    print(rule)

    if passed == total:
        return 0
    return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
178
dispatcher/examples/fmha/python/32_sweep_hdim.py
Normal file
178
dispatcher/examples/fmha/python/32_sweep_hdim.py
Normal file
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 32: Sweep Head Dimension
|
||||
|
||||
Demonstrates FMHA across different head dimensions (32, 64, 128, 256).
|
||||
The prebuilt library only supports hdim=128; other head dimensions are
|
||||
validated via CPU reference only.
|
||||
|
||||
Fixed: batch=2, nhead=8, seqlen=128
|
||||
Sweep: hdim in [32, 64, 128, 256]
|
||||
|
||||
Usage:
|
||||
python3 32_sweep_hdim.py
|
||||
python3 32_sweep_hdim.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
FmhaValidator,
|
||||
cpu_attention_fwd,
|
||||
detect_gpu_arch,
|
||||
setup_fmha_dispatcher,
|
||||
)
|
||||
|
||||
BATCH = 2
|
||||
NHEAD = 8
|
||||
SEQLEN = 128
|
||||
HDIMS = [32, 64, 128, 256]
|
||||
GPU_SUPPORTED_HDIM = 128
|
||||
|
||||
|
||||
def main():
    """Sweep head dimensions: CPU reference for all, GPU only for hdim=128.

    Tries to JIT-compile a kernel for GPU_SUPPORTED_HDIM; if that fails the
    sweep still runs, with every row reported as "CPU only". Rows that are not
    run on the GPU are recorded as passed. Returns 0 when every row passed,
    1 otherwise.
    """
    cli = argparse.ArgumentParser(description="Sweep Head Dimension FMHA")
    cli.add_argument("--arch", default=detect_gpu_arch())
    args = cli.parse_args()

    rule = "=" * 70
    print(rule)
    print("Example 32: Sweep Head Dimension")
    print(rule)

    print(f"\n Fixed: batch={BATCH}, nhead={NHEAD}, seqlen={SEQLEN}")
    print(f" Sweep: hdim in {HDIMS}")
    print(f" Arch: {args.arch}")
    print(f" Note: Only hdim={GPU_SUPPORTED_HDIM} runs on GPU (prebuilt lib)")

    checker = FmhaValidator(rtol=1e-2, atol=1e-2)

    # Step 1: JIT-compile FMHA kernel (hdim=128)
    print("\nStep 1: JIT-Compile FMHA Kernel")
    build = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=GPU_SUPPORTED_HDIM,
            hdim_v=GPU_SUPPORTED_HDIM,
            gfx_arch=args.arch,
        )
    )
    if build.success:
        kernel = build.runner
        print(f" JIT build: {build.build_time_s:.1f}s")
    else:
        # A failed build is not fatal here: fall back to CPU-only rows.
        kernel = None
        print(f" JIT build failed: {build.error}")
        print(" Will run CPU reference only")

    # Step 2: CPU reference for all hdims
    print("\nStep 2: CPU Reference for All Head Dimensions")

    np.random.seed(42)
    per_hdim = {}

    print(
        f"\n {'hdim':>6} | {'Scale':>8} | {'FLOPs':>14} | {'O Range':>22} | {'Finite':<6}"
    )
    print(" " + "-" * 66)

    def draw_fp32(shape):
        # One tensor from the global RNG stream, scaled and cast to fp32.
        return (np.random.randn(*shape) * 0.1).astype(np.float32)

    for hdim in HDIMS:
        problem = FmhaProblem(
            batch=BATCH,
            nhead_q=NHEAD,
            nhead_k=NHEAD,
            seqlen_q=SEQLEN,
            seqlen_k=SEQLEN,
            hdim_q=hdim,
            hdim_v=hdim,
        )

        Q = draw_fp32(problem.q_shape())
        K = draw_fp32(problem.k_shape())
        V = draw_fp32(problem.v_shape())

        reference = cpu_attention_fwd(Q, K, V, problem.scale)
        is_finite = bool(np.all(np.isfinite(reference)))
        o_range = f"[{reference.min():.4f}, {reference.max():.4f}]"

        print(
            f" {hdim:>6} | {problem.scale:>8.4f} | {problem.num_ops:>14,} | {o_range:>22} | {'OK' if is_finite else 'NaN!':<6}"
        )
        per_hdim[hdim] = (Q, K, V, reference, problem)

    # Step 3: GPU sweep (only hdim=128 supported)
    print("\nStep 3: GPU Sweep")

    header = f" {'hdim':>6} | {'Time(ms)':>10} | {'TFLOPS':>10} | {'MaxErr':>10} | {'Status':<10}"
    print(f"\n{header}")
    print(" " + "-" * 60)

    rows = []

    for hdim in HDIMS:
        Q, K, V, reference, problem = per_hdim[hdim]

        runs_on_gpu = kernel is not None and hdim == GPU_SUPPORTED_HDIM
        if not runs_on_gpu:
            print(
                f" {hdim:>6} | {'---':>10} | {'---':>10} | {'---':>10} | {'CPU only':<10}"
            )
            rows.append((hdim, True, 0.0, 0.0, 0.0))
            continue

        outcome = kernel.run(
            Q.astype(np.float16),
            K.astype(np.float16),
            V.astype(np.float16),
            problem,
        )
        if outcome.success:
            max_err = float(np.abs(outcome.output.astype(np.float32) - reference).max())
            ok, _, _ = checker.check(outcome.output, reference)
            tag = "PASS" if ok else "FAIL"
            print(
                f" {hdim:>6} | {outcome.time_ms:>10.4f} | {outcome.tflops:>10.2f} | {max_err:>10.2e} | {tag:<10}"
            )
            rows.append((hdim, ok, outcome.time_ms, outcome.tflops, max_err))
        else:
            print(
                f" {hdim:>6} | {'---':>10} | {'---':>10} | {'---':>10} | {'FAIL':<10}"
            )
            rows.append((hdim, False, 0.0, 0.0, 0.0))

    # Step 4: hdim analysis
    print("\nStep 4: Head Dimension Analysis")
    print(" Each hdim requires a dedicated compiled kernel:")
    for hdim in HDIMS:
        if hdim == GPU_SUPPORTED_HDIM:
            gpu_status = "prebuilt"
        else:
            gpu_status = "needs JIT"
        tile_hint = f"tile_k0max={hdim}"
        print(f" hdim={hdim:>3}: {gpu_status:<10} ({tile_hint})")

    print("\n Compute scales linearly with hdim (via Q*K^T and attn*V).")
    print(" Larger hdim = more work per token, fewer tokens processed per CU.")

    # Summary
    passed = len([row for row in rows if row[1]])
    total = len(rows)
    print("\n" + rule)
    print(f" Results: {passed}/{total} passed")
    print(f" Fixed: B={BATCH} H={NHEAD} S={SEQLEN}")
    print(f" Sweep: hdim={HDIMS}")
    print(f" GPU: hdim={GPU_SUPPORTED_HDIM} only (prebuilt)")
    print(f" Status: {'PASS' if passed == total else 'FAIL'}")
    print(rule)

    if passed == total:
        return 0
    return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
271
dispatcher/examples/fmha/python/33_bwd_masks_fmha.py
Normal file
271
dispatcher/examples/fmha/python/33_bwd_masks_fmha.py
Normal file
@@ -0,0 +1,271 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 33: Backward Pass with Causal Masks
|
||||
|
||||
Demonstrates the FMHA backward pass with causal mask variants:
|
||||
1. no_mask -- Full attention (baseline)
|
||||
2. top_left -- Causal mask aligned to top-left corner
|
||||
3. bottom_right -- Causal mask aligned to bottom-right corner
|
||||
|
||||
For each mask type:
|
||||
- Forward: out = softmax(mask(Q @ K^T * scale)) @ V
|
||||
- Backward: dQ, dK, dV via analytical gradients through the masked softmax
|
||||
|
||||
CPU backward reference:
|
||||
dP = dO @ V^T
|
||||
D = rowsum(dO * out) (per-query-position scalar)
|
||||
dS = P * (dP - D)
|
||||
dQ = scale * dS @ K
|
||||
dK = scale * dS^T @ Q
|
||||
dV = P^T @ dO
|
||||
|
||||
Usage:
|
||||
python3 33_bwd_masks_fmha.py
|
||||
python3 33_bwd_masks_fmha.py --seqlen-q 128 --seqlen-k 192
|
||||
python3 33_bwd_masks_fmha.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
setup_fmha_dispatcher,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
def make_causal_mask_top_left(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a top-left aligned causal mask.

    Query position ``i`` may attend to key positions ``j`` with ``j <= i``.

    Args:
        seqlen_q: number of query positions (mask rows).
        seqlen_k: number of key positions (mask columns).

    Returns:
        float32 array of shape ``[seqlen_q, seqlen_k]`` holding 1.0 where
        attention is allowed and 0.0 where it is masked out.
    """
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    allowed = k_idx <= q_idx
    return allowed.astype(np.float32)
|
||||
|
||||
|
||||
def make_causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> np.ndarray:
    """Build a bottom-right aligned causal mask.

    The diagonal is shifted by ``seqlen_k - seqlen_q`` so that the last query
    row can see every key: query ``i`` may attend to keys ``j`` with
    ``j <= i + (seqlen_k - seqlen_q)``.

    Args:
        seqlen_q: number of query positions (mask rows).
        seqlen_k: number of key positions (mask columns).

    Returns:
        float32 array of shape ``[seqlen_q, seqlen_k]`` holding 1.0 where
        attention is allowed and 0.0 where it is masked out.  When
        ``seqlen_q > seqlen_k`` the leading rows are entirely 0.0.
    """
    shift = seqlen_k - seqlen_q
    q_idx = np.arange(seqlen_q)[:, np.newaxis]
    k_idx = np.arange(seqlen_k)[np.newaxis, :]
    allowed = k_idx <= (q_idx + shift)
    return allowed.astype(np.float32)
|
||||
|
||||
|
||||
def cpu_masked_fwd_with_intermediates(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    mask: np.ndarray,
) -> tuple:
    """Forward pass with mask, returning out, P, and LSE for backward.

    Masked positions are excluded from the softmax exactly (their score is set
    to -inf), so every masked entry of ``P`` is 0.  A query row whose keys are
    all masked (e.g. bottom-right causal with seqlen_q > seqlen_k) produces an
    all-zero ``P`` row, an all-zero ``out`` row and ``lse = -inf`` instead of a
    spurious uniform distribution.

    Args:
        Q: [B, H, Sq, D]  K: [B, H, Sk, D]  V: [B, H, Sk, Dv]
        scale: multiplier applied to Q @ K^T before the softmax.
        mask: [Sq, Sk] broadcast over batch and head; entries > 0 are kept.

    Returns: (out, P, lse)
    """
    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    keep = mask[np.newaxis, np.newaxis, :, :] > 0
    S = np.where(keep, S, -np.inf)

    S_max = S.max(axis=-1, keepdims=True)
    # A fully masked row has max == -inf; subtracting it would give nan,
    # so shift such rows by 0 instead (their exponentials are all 0 anyway).
    S_max = np.where(np.isfinite(S_max), S_max, 0.0)

    S_exp = np.exp(S - S_max)
    S_sum = S_exp.sum(axis=-1, keepdims=True)
    # Guard the division for fully masked rows (sum == 0) so P stays 0.
    P = S_exp / np.where(S_sum > 0, S_sum, 1.0)
    out = np.matmul(P, V)

    # log(0) == -inf is the intended LSE of an empty row; silence the warning.
    with np.errstate(divide="ignore"):
        lse = (np.log(S_sum.squeeze(-1)) + S_max.squeeze(-1)).astype(np.float32)
    return out, P, lse
|
||||
|
||||
|
||||
def cpu_masked_bwd(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
) -> tuple:
    """Reference backward pass through (masked) softmax attention on the CPU.

    The mask is not applied here: it is expected to be baked into ``P``
    already (masked positions carry P=0), which zeroes their contribution.

    Args:
        Q, K, V: forward inputs (4-D, last two axes are sequence and head dim).
        out: forward output matching ``dO`` in shape.
        dO: upstream gradient of ``out``.
        P: attention probabilities from the forward pass.
        scale: softmax scale used in the forward pass.

    Returns: (dQ, dK, dV, D) where ``D = rowsum(dO * out)`` with the reduced
    axis removed.
    """
    # Per-query scalar shared by every key position of that row.
    row_dot = np.sum(dO * out, axis=-1, keepdims=True)

    grad_P = dO @ np.transpose(V, (0, 1, 3, 2))
    grad_S = P * (grad_P - row_dot)

    dQ = (grad_S @ K) * scale
    dK = (np.transpose(grad_S, (0, 1, 3, 2)) @ Q) * scale
    dV = np.transpose(P, (0, 1, 3, 2)) @ dO

    return dQ, dK, dV, row_dot.squeeze(-1)
|
||||
|
||||
|
||||
def main():
    """Run the CPU-reference demo of the FMHA backward pass under causal masks.

    Steps:
      1. Parse CLI options and build the problem description.
      2. Try to JIT-compile a forward kernel (informational only; on failure
         the demo continues with the CPU reference).
      3. For each mask (no_mask, top_left, bottom_right) run the CPU forward
         and backward and print gradient statistics.
      4. Print per-mask stage details, L2 norms, mask patterns and the
         intended GPU API usage.

    Returns:
        0 always; this example reports numbers but does no pass/fail check.
    """
    parser = argparse.ArgumentParser(description="Backward Pass with Causal Masks")
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--seqlen-q", type=int, default=64)
    parser.add_argument("--seqlen-k", type=int, default=64)
    parser.add_argument("--hdim", type=int, default=128)
    args = parser.parse_args()

    print("=" * 70)
    print("Example 33: Backward Pass with Causal Masks")
    print("=" * 70)

    sq, sk = args.seqlen_q, args.seqlen_k
    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=sq,
        seqlen_k=sk,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} Sq={sq} Sk={sk} D={args.hdim}")
    print(f" Scale: {prob.scale:.6f}")
    print(f" Arch: {args.arch}")

    # --- JIT compile a basic fp16 fwd kernel for the requested head dim ---
    print("\n--- JIT Compilation ---")
    config = FmhaKernelConfig(
        data_type="fp16",
        hdim_q=args.hdim,
        hdim_v=args.hdim,
        gfx_arch=args.arch,
    )
    setup = setup_fmha_dispatcher(config)
    if setup.success:
        print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s")
        print(f" Library: {setup.library_path}")
        print(" Note: Backward requires family='bwd' kernel (separate JIT)")
    else:
        print(f" JIT build: {setup.error}")
        print(" Continuing with CPU reference only")

    # --- Generate data (fixed seed so the printed numbers are reproducible) ---
    np.random.seed(42)
    Q = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    dO = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)

    # --- Build masks ---
    masks = {
        "no_mask": np.ones((sq, sk), dtype=np.float32),
        "top_left": make_causal_mask_top_left(sq, sk),
        "bottom_right": make_causal_mask_bottom_right(sq, sk),
    }

    # --- Per-mask forward + backward ---
    print(
        f"\n {'Mask':<16} {'Density':>8} | {'|dQ|':>10} {'|dK|':>10} {'|dV|':>10}"
        f" | {'dQ vs base':>10} {'dK vs base':>10} {'dV vs base':>10}"
    )
    print(" " + "-" * 98)

    base_grads = None  # gradients of the first mask ("no_mask"), used as baseline
    all_grads = {}
    all_probs = {}  # attention probabilities per mask, reused in the details section

    for name, mask in masks.items():
        density = mask.sum() / mask.size * 100

        out, P, _ = cpu_masked_fwd_with_intermediates(Q, K, V, prob.scale, mask)
        dQ, dK, dV, D = cpu_masked_bwd(Q, K, V, out, dO, P, prob.scale)

        dq_norm = float(np.abs(dQ).mean())
        dk_norm = float(np.abs(dK).mean())
        dv_norm = float(np.abs(dV).mean())

        if base_grads is None:
            base_grads = (dQ, dK, dV)
            diff_str = f"{'---':>10} {'---':>10} {'---':>10}"
        else:
            dq_diff = float(np.abs(dQ - base_grads[0]).max())
            dk_diff = float(np.abs(dK - base_grads[1]).max())
            dv_diff = float(np.abs(dV - base_grads[2]).max())
            diff_str = f"{dq_diff:>10.2e} {dk_diff:>10.2e} {dv_diff:>10.2e}"

        print(
            f" {name:<16} {density:>7.1f}% | {dq_norm:>10.4e} {dk_norm:>10.4e} {dv_norm:>10.4e}"
            f" | {diff_str}"
        )
        all_grads[name] = (dQ, dK, dV, D)
        all_probs[name] = P

    # --- Detailed backward breakdown for each mask ---
    print("\n--- Backward Stage Details ---")

    for name in masks:
        dQ, dK, dV, D = all_grads[name]
        # Reuse P from the sweep above instead of rerunning the forward pass.
        P = all_probs[name]

        print(f"\n [{name}]")
        print(" Stage 1 (dot_do_o): D = rowsum(dO * out)")
        print(f" D shape: {D.shape}, range: [{D.min():.6f}, {D.max():.6f}]")
        print(" Stage 2 (dq_dk_dv):")
        print(f" dQ range: [{dQ.min():.4e}, {dQ.max():.4e}]")
        print(f" dK range: [{dK.min():.4e}, {dK.max():.4e}]")
        print(f" dV range: [{dV.min():.4e}, {dV.max():.4e}]")

        p_sparsity = (P < 1e-9).sum() / P.size * 100
        print(f" P sparsity (< 1e-9): {p_sparsity:.1f}%")

    # --- Gradient norm comparison across masks ---
    print("\n--- Gradient L2 Norms ---")
    print(f"\n {'Mask':<16} {'||dQ||_2':>12} {'||dK||_2':>12} {'||dV||_2':>12}")
    print(" " + "-" * 54)

    for name in masks:
        dQ, dK, dV, _ = all_grads[name]
        l2_dq = float(np.sqrt((dQ**2).sum()))
        l2_dk = float(np.sqrt((dK**2).sum()))
        l2_dv = float(np.sqrt((dV**2).sum()))
        print(f" {name:<16} {l2_dq:>12.4e} {l2_dk:>12.4e} {l2_dv:>12.4e}")

    # --- Mask pattern visualization ---
    print("\n--- Mask Patterns (first 8x8 corner) ---")
    view = min(8, sq, sk)  # clamp so short sequences do not index out of range
    for name, mask in masks.items():
        corner = mask[:view, :view]
        print(f"\n {name}:")
        for r in range(view):
            row_str = " ".join("█" if corner[r, c] > 0 else "·" for c in range(view))
            print(f" {row_str}")

    # --- Backward API pattern ---
    print("\n--- Backward GPU API Pattern ---")
    print(" The GPU backward for masked attention would use:")
    print(" FmhaKernelConfig(family='bwd', mask='top_left', ...)")
    print(" 3-stage backward plan:")
    print(" Stage 1: bwd_dot_do_o -- D = rowsum(dO * out)")
    print(" Stage 2: bwd_dq_dk_dv -- compute dQ, dK, dV with mask")
    print(" Stage 3: bwd_convert_dq -- optional dtype conversion")

    # --- Summary ---
    print("\n" + "=" * 70)
    print(" Mask variants: no_mask, top_left, bottom_right")
    print(" Backward math: dP = dO @ V^T, dS = P*(dP - D)")
    print(" dQ = scale*dS@K, dK = scale*dS^T@Q, dV = P^T@dO")
    print(" Causal effect: Masked positions get P=0, zeroing their gradient flow")
    print(" GPU: Requires bwd-family JIT kernel with mask support")
    print(" Status: DEMO")
    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
277
dispatcher/examples/fmha/python/34_bwd_gqa_fmha.py
Normal file
277
dispatcher/examples/fmha/python/34_bwd_gqa_fmha.py
Normal file
@@ -0,0 +1,277 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 34: Backward Pass with GQA (Grouped-Query Attention)
|
||||
|
||||
Demonstrates the FMHA backward pass when nhead_q != nhead_k.
|
||||
GQA groups multiple query heads per KV head. The backward pass
|
||||
must account for this by:
|
||||
- Expanding K/V heads via np.repeat for dQ computation
|
||||
- Summing dK/dV over query head groups back to KV head count
|
||||
|
||||
Tested GQA ratios: 1:1 (MHA), 2:1, 4:1, 8:1
|
||||
|
||||
CPU backward reference:
|
||||
K_exp = repeat(K, ratio) # [B, Hq, Sk, D]
|
||||
V_exp = repeat(V, ratio) # [B, Hq, Sk, Dv]
|
||||
dQ = scale * (P * (dO@V_exp^T - D)) @ K_exp
|
||||
dK_exp = scale * (P * (dO@V_exp^T - D))^T @ Q
|
||||
dV_exp = P^T @ dO
|
||||
dK = sum_over_groups(dK_exp) # [B, Hk, Sk, D]
|
||||
dV = sum_over_groups(dV_exp) # [B, Hk, Sk, Dv]
|
||||
|
||||
Usage:
|
||||
python3 34_bwd_gqa_fmha.py
|
||||
python3 34_bwd_gqa_fmha.py --nhead-q 32
|
||||
python3 34_bwd_gqa_fmha.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
setup_fmha_dispatcher,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
def cpu_fwd_with_intermediates(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
) -> tuple:
    """CPU attention forward that also returns the softmax intermediates.

    When Q has more heads (axis 1) than K, K and V are expanded with
    ``np.repeat`` along the head axis by ``nhead_q // nhead_k`` so that each
    KV head serves a contiguous group of query heads (GQA).

    Args:
        Q: queries, heads on axis 1.
        K, V: keys / values, heads on axis 1 (possibly fewer than Q).
        scale: multiplier applied to Q @ K^T before the softmax.

    Returns: (out, P, lse) with ``lse`` cast to float32.
    """
    heads_q = Q.shape[1]
    heads_kv = K.shape[1]
    if heads_q != heads_kv:
        group = heads_q // heads_kv
        K, V = np.repeat(K, group, axis=1), np.repeat(V, group, axis=1)

    scores = (Q @ np.transpose(K, (0, 1, 3, 2))) * scale
    # Subtract the row max before exponentiating for numerical stability.
    row_max = np.max(scores, axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    denom = np.sum(weights, axis=-1, keepdims=True)

    P = weights / denom
    out = P @ V
    lse = (np.log(denom.squeeze(-1)) + row_max.squeeze(-1)).astype(np.float32)
    return out, P, lse
|
||||
|
||||
|
||||
def cpu_bwd_gqa(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
    nhead_q: int,
    nhead_k: int,
) -> tuple:
    """Reference attention backward with grouped-query head sharing.

    ``P`` must already be computed on the expanded head count
    [B, Hq, Sq, Sk]; ``K`` and ``V`` are passed unexpanded [B, Hk, Sk, D].
    Gradients are first computed per query head, then dK/dV are summed over
    the ``nhead_q // nhead_k`` query heads that share each KV head.

    Returns: (dQ, dK, dV) where dK/dV have shape [B, Hk, Sk, ...]
    """
    group = nhead_q // nhead_k
    batch = Q.shape[0]
    seqlen_k, dim_qk = K.shape[2], K.shape[3]
    dim_v = V.shape[3]

    # Expand KV heads so every query head has its own (shared-value) copy.
    K_full = np.repeat(K, group, axis=1)
    V_full = np.repeat(V, group, axis=1)

    row_dot = np.sum(dO * out, axis=-1, keepdims=True)
    grad_P = dO @ np.transpose(V_full, (0, 1, 3, 2))
    grad_S = P * (grad_P - row_dot)

    dQ = (grad_S @ K_full) * scale
    dK_full = (np.transpose(grad_S, (0, 1, 3, 2)) @ Q) * scale
    dV_full = np.transpose(P, (0, 1, 3, 2)) @ dO

    # Fold the expanded head axis back into (kv_head, group) and reduce the group.
    dK = dK_full.reshape(batch, nhead_k, group, seqlen_k, dim_qk).sum(axis=2)
    dV = dV_full.reshape(batch, nhead_k, group, seqlen_k, dim_v).sum(axis=2)

    return dQ, dK, dV
|
||||
|
||||
|
||||
def main():
    """Run the CPU-reference demo of the FMHA backward pass under GQA.

    For every ratio in {1, 2, 4, 8} that divides ``--nhead-q`` the demo builds
    a problem with ``nhead_k = nhead_q // ratio``, runs the CPU forward and
    backward, prints gradient statistics, and then cross-checks the GQA
    gradients against an MHA backward on explicitly expanded K/V.

    Returns:
        0 always; the equivalence check is printed but does not affect the
        exit status.
    """
    cli = argparse.ArgumentParser(description="Backward Pass with GQA")
    cli.add_argument("--arch", default=detect_gpu_arch())
    cli.add_argument("--batch", type=int, default=2)
    cli.add_argument("--nhead-q", type=int, default=16)
    cli.add_argument("--seqlen", type=int, default=64)
    cli.add_argument("--hdim", type=int, default=128)
    args = cli.parse_args()

    banner = "=" * 70
    print(banner)
    print("Example 34: Backward Pass with GQA")
    print(banner)

    hq = args.nhead_q

    # Keep only ratios that split the query heads evenly into >= 1 KV head.
    gqa_ratios = [r for r in (1, 2, 4, 8) if hq % r == 0 and hq // r >= 1]

    print(f"\n nhead_q: {hq}")
    print(f" Ratios: {', '.join(f'{r}:1' for r in gqa_ratios)}")
    print(f" Problem: B={args.batch} S={args.seqlen} D={args.hdim}")

    # --- JIT compile a basic fp16 fwd kernel for the requested head dim ---
    print("\n--- JIT Compilation ---")
    setup = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=args.hdim,
            hdim_v=args.hdim,
            gfx_arch=args.arch,
        )
    )
    if not setup.success:
        print(f" JIT build: {setup.error}")
        print(" Continuing with CPU reference only")
    else:
        print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s")
        print(" Note: Backward GQA requires bwd-family kernel (separate JIT)")

    # --- Sweep GQA ratios ---
    print("\n--- Backward Gradients per GQA Ratio ---")
    print(
        f"\n {'#':<3} {'Ratio':<8} {'Hq':>4} {'Hk':>4} "
        f"| {'|dQ| mean':>10} {'|dK| mean':>10} {'|dV| mean':>10} "
        f"| {'dK shape':>18} {'dV shape':>18}"
    )
    print(" " + "-" * 104)

    all_results = {}

    for i, ratio in enumerate(gqa_ratios, 1):
        hk = hq // ratio
        prob = FmhaProblem(
            batch=args.batch,
            nhead_q=hq,
            nhead_k=hk,
            seqlen_q=args.seqlen,
            seqlen_k=args.seqlen,
            hdim_q=args.hdim,
            hdim_v=args.hdim,
        )

        # Distinct seed per ratio; tensors are drawn in the order Q, K, V, dO.
        np.random.seed(42 + i)
        Q, K, V, dO = (
            (np.random.randn(*shape) * 0.1).astype(np.float32)
            for shape in (prob.q_shape(), prob.k_shape(), prob.v_shape(), prob.o_shape())
        )

        out, P, lse = cpu_fwd_with_intermediates(Q, K, V, prob.scale)
        dQ, dK, dV = cpu_bwd_gqa(Q, K, V, out, dO, P, prob.scale, hq, hk)

        dq_mean = float(np.abs(dQ).mean())
        dk_mean = float(np.abs(dK).mean())
        dv_mean = float(np.abs(dV).mean())

        if ratio == 1:
            suffix = " MHA"
        elif hk == 1:
            suffix = " MQA"
        else:
            suffix = ""
        label = f"{ratio}:1" + suffix

        print(
            f" {i:<3} {label:<8} {hq:>4} {hk:>4} "
            f"| {dq_mean:>10.4e} {dk_mean:>10.4e} {dv_mean:>10.4e} "
            f"| {str(dK.shape):>18} {str(dV.shape):>18}"
        )
        all_results[ratio] = (dQ, dK, dV, Q, K, V, out, dO, P, prob)

    # --- Verify GQA backward via expanded MHA ---
    print("\n--- GQA Backward Equivalence Check ---")
    print(" Verifying: GQA bwd == MHA bwd with expanded K/V, then summed")

    # ratio == 1 is plain MHA, so there is nothing to compare against.
    for ratio in (r for r in gqa_ratios if r != 1):
        dQ_gqa, dK_gqa, dV_gqa, Q, K, V, out, dO, P, prob = all_results[ratio]
        hk = hq // ratio

        K_exp = np.repeat(K, ratio, axis=1)
        V_exp = np.repeat(V, ratio, axis=1)

        O_mha, P_mha, _ = cpu_fwd_with_intermediates(Q, K_exp, V_exp, prob.scale)
        # Passing hq for both head counts makes this an ordinary MHA backward.
        dQ_mha, dK_mha, dV_mha = cpu_bwd_gqa(
            Q, K_exp, V_exp, O_mha, dO, P_mha, prob.scale, hq, hq
        )

        B = Q.shape[0]
        Sk = K.shape[2]
        dK_mha_grouped = dK_mha.reshape(B, hk, ratio, Sk, K.shape[3]).sum(axis=2)
        dV_mha_grouped = dV_mha.reshape(B, hk, ratio, Sk, V.shape[3]).sum(axis=2)

        dq_err = float(np.abs(dQ_gqa - dQ_mha).max())
        dk_err = float(np.abs(dK_gqa - dK_mha_grouped).max())
        dv_err = float(np.abs(dV_gqa - dV_mha_grouped).max())

        tag = "PASS" if max(dq_err, dk_err, dv_err) < 1e-5 else "FAIL"
        print(
            f" Ratio {ratio}:1 -- dQ err={dq_err:.2e} dK err={dk_err:.2e} "
            f"dV err={dv_err:.2e} {tag}"
        )

    # --- Gradient accumulation analysis ---
    print("\n--- Head-Group Gradient Accumulation ---")
    print(" When ratio > 1, dK/dV are summed over query heads in each group.")
    print(" Higher ratio -> more terms summed -> larger gradient magnitudes.\n")

    print(f" {'Ratio':<8} {'||dK||_2':>12} {'||dV||_2':>12} {'dK/dV ratio':>12}")
    print(" " + "-" * 48)

    for ratio in gqa_ratios:
        dQ, dK, dV, *_ = all_results[ratio]
        l2_dk = float(np.sqrt((dK**2).sum()))
        l2_dv = float(np.sqrt((dV**2).sum()))
        # Small epsilon avoids division by zero if dV vanishes.
        dk_dv_ratio = l2_dk / (l2_dv + 1e-12)
        print(f" {ratio}:1{'':<4} {l2_dk:>12.4e} {l2_dv:>12.4e} {dk_dv_ratio:>12.2f}")

    # --- Backward GPU API pattern ---
    for line in (
        "\n--- Backward GPU API Pattern ---",
        " GPU backward with GQA dispatches with nhead_q != nhead_k.",
        " The dq_dk_dv kernel handles head grouping internally:",
        " - dQ: computed per query head (no grouping needed)",
        " - dK, dV: accumulated across head groups via atomicAdd",
        " or multi-buffer reduction (deterministic mode)",
    ):
        print(line)

    # --- Summary ---
    print("\n" + "=" * 70)
    print(f" GQA ratios tested: {len(gqa_ratios)}")
    print(" Backward math: expand K/V -> compute grads -> sum dK/dV")
    print(" Equivalence: GQA bwd == MHA(expanded) bwd + group sum")
    print(" GPU: Requires bwd-family JIT kernel")
    print(" Status: DEMO")
    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
270
dispatcher/examples/fmha/python/35_bwd_bf16_fmha.py
Normal file
270
dispatcher/examples/fmha/python/35_bwd_bf16_fmha.py
Normal file
@@ -0,0 +1,270 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 35: Backward Pass with BF16 Data Type
|
||||
|
||||
Demonstrates the FMHA backward pass with bfloat16 precision.
|
||||
|
||||
BF16 differences from FP16:
|
||||
- 8-bit exponent (same as fp32) vs fp16's 5-bit
|
||||
- 7-bit mantissa vs fp16's 10-bit
|
||||
- Larger dynamic range but lower precision
|
||||
|
||||
Tolerance guidance for backward:
|
||||
- fp16 bwd: rtol=1.6e-2 typically sufficient
|
||||
- bf16 bwd: rtol=3.2e-2 for hdim > 128 (less mantissa precision)
|
||||
- bf16 bwd: rtol=2.0e-2 for hdim <= 128
|
||||
|
||||
CPU backward reference is computed in float32, then compared against
|
||||
gradients computed from fp16- and bf16-quantized inputs to measure the precision impact.
|
||||
|
||||
Usage:
|
||||
python3 35_bwd_bf16_fmha.py
|
||||
python3 35_bwd_bf16_fmha.py --hdim 256
|
||||
python3 35_bwd_bf16_fmha.py --arch gfx942
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from fmha_utils import (
|
||||
FmhaKernelConfig,
|
||||
FmhaProblem,
|
||||
setup_fmha_dispatcher,
|
||||
detect_gpu_arch,
|
||||
cpu_attention_bwd,
|
||||
)
|
||||
|
||||
|
||||
def to_bf16(arr: np.ndarray) -> np.ndarray:
    """Convert float32 -> bfloat16 (stored as uint16 with bf16 bit pattern).

    The input is cast to float32 and the upper 16 bits of each value's bit
    pattern are kept.  The lower 16 mantissa bits are simply dropped
    (truncation); no rounding is applied.

    Args:
        arr: array of values convertible to float32.

    Returns:
        uint16 array of the same shape holding the bf16 bit patterns.
    """
    as_f32 = arr.astype(np.float32)
    raw_bits = as_f32.view(np.uint32)
    high_half = np.right_shift(raw_bits, 16)
    return high_half.astype(np.uint16)
|
||||
|
||||
|
||||
def bf16_to_f32(arr_u16: np.ndarray) -> np.ndarray:
    """Convert bfloat16 (uint16) -> float32.

    Each 16-bit pattern is widened to 32 bits and placed in the upper half of
    a float32 word; the lower 16 bits are zero.

    Args:
        arr_u16: array of bf16 bit patterns.

    Returns:
        float32 array of the same shape.
    """
    widened = arr_u16.astype(np.uint32)
    raw_bits = np.left_shift(widened, 16)
    return raw_bits.view(np.float32)
|
||||
|
||||
|
||||
def cpu_fwd_with_intermediates(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
) -> tuple:
    """CPU attention forward that also returns the softmax intermediates.

    Computes ``P = softmax(Q @ K^T * scale)`` over the last axis and
    ``out = P @ V``.  Inputs are 4-D with the last two axes being sequence
    and head dimension.

    Returns: (out, P, lse) where ``lse`` is the per-row log-sum-exp of the
    scaled scores, cast to float32.
    """
    logits = (Q @ np.transpose(K, (0, 1, 3, 2))) * scale

    # Shift by the row max so the exponentials cannot overflow.
    peak = np.max(logits, axis=-1, keepdims=True)
    unnormalized = np.exp(logits - peak)
    total = np.sum(unnormalized, axis=-1, keepdims=True)

    P = unnormalized / total
    out = P @ V
    lse = (np.log(total.squeeze(-1)) + peak.squeeze(-1)).astype(np.float32)
    return out, P, lse
|
||||
|
||||
|
||||
def get_bwd_tolerance(dtype: str, hdim: int) -> tuple:
    """Return the recommended ``(rtol, atol)`` for validating a backward pass.

    Args:
        dtype: data type name; only ``"bf16"`` gets special treatment, every
            other value uses the fp16 tolerance.
        hdim: head dimension; bf16 is relaxed further when ``hdim > 128``.

    Returns:
        A ``(rtol, atol)`` pair; both entries are always equal.
    """
    if dtype != "bf16":
        return 1.6e-2, 1.6e-2
    tol = 3.2e-2 if hdim > 128 else 2.0e-2
    return tol, tol
|
||||
|
||||
|
||||
def main():
    """Run the CPU-reference demo comparing fp16 and bf16 backward precision.

    The same float32 tensors are quantized to fp16 and to bf16 (both stored
    back as float32), the backward pass is evaluated on each quantized set,
    and the resulting gradients are compared with each other and with a
    float32 reference using the tolerances from ``get_bwd_tolerance``.

    Returns:
        0 always; PASS/FAIL lines are printed but do not affect the exit
        status.
    """
    cli = argparse.ArgumentParser(description="Backward Pass with BF16")
    cli.add_argument("--arch", default=detect_gpu_arch())
    cli.add_argument("--batch", type=int, default=2)
    cli.add_argument("--nhead", type=int, default=8)
    cli.add_argument("--seqlen", type=int, default=64)
    cli.add_argument("--hdim", type=int, default=128)
    args = cli.parse_args()

    banner = "=" * 70
    print(banner)
    print("Example 35: Backward Pass with BF16")
    print(banner)

    prob = FmhaProblem(
        batch=args.batch,
        nhead_q=args.nhead,
        nhead_k=args.nhead,
        seqlen_q=args.seqlen,
        seqlen_k=args.seqlen,
        hdim_q=args.hdim,
        hdim_v=args.hdim,
    )

    print(f"\n Problem: B={prob.batch} H={prob.nhead_q} S={args.seqlen} D={args.hdim}")
    print(f" Scale: {prob.scale:.6f}")
    print(f" Arch: {args.arch}")

    # --- JIT compile a basic fp16 fwd kernel for the requested head dim ---
    print("\n--- JIT Compilation ---")
    setup = setup_fmha_dispatcher(
        FmhaKernelConfig(
            data_type="fp16",
            hdim_q=args.hdim,
            hdim_v=args.hdim,
            gfx_arch=args.arch,
        )
    )
    if not setup.success:
        print(f" JIT build: {setup.error}")
        print(" Continuing with CPU reference only")
    else:
        print(f" Fwd kernel compiled: {setup.build_time_s:.1f}s")
        print(
            " Note: Native bf16 bwd kernel requires separate JIT with data_type='bf16'"
        )

    # --- Generate data in both dtypes ---
    # Tensors are drawn in the order Q, K, V, dO from a fixed seed.
    np.random.seed(42)
    Q_f32 = (np.random.randn(*prob.q_shape()) * 0.1).astype(np.float32)
    K_f32 = (np.random.randn(*prob.k_shape()) * 0.1).astype(np.float32)
    V_f32 = (np.random.randn(*prob.v_shape()) * 0.1).astype(np.float32)
    dO_f32 = (np.random.randn(*prob.o_shape()) * 0.1).astype(np.float32)
    originals = (Q_f32, K_f32, V_f32, dO_f32)

    # Round-trip through each 16-bit format so the values carry its
    # quantization error while the math below still runs in float32.
    Q_fp16, K_fp16, V_fp16, dO_fp16 = [
        t.astype(np.float16).astype(np.float32) for t in originals
    ]
    Q_bf16, K_bf16, V_bf16, dO_bf16 = [bf16_to_f32(to_bf16(t)) for t in originals]

    # --- Quantization error comparison ---
    print("\n--- Quantization Error ---")
    print(
        f"\n {'Tensor':<6} {'FP16 quant err':>16} {'BF16 quant err':>16} {'BF16/FP16':>10}"
    )
    print(" " + "-" * 52)

    quant_rows = (
        ("Q", Q_f32, Q_fp16, Q_bf16),
        ("K", K_f32, K_fp16, K_bf16),
        ("V", V_f32, V_fp16, V_bf16),
        ("dO", dO_f32, dO_fp16, dO_bf16),
    )
    for name, orig, fp16, bf16 in quant_rows:
        fp16_err = float(np.abs(orig - fp16).max())
        bf16_err = float(np.abs(orig - bf16).max())
        # Epsilon keeps the ratio finite if fp16 happens to be exact.
        ratio = bf16_err / (fp16_err + 1e-15)
        print(f" {name:<6} {fp16_err:>16.2e} {bf16_err:>16.2e} {ratio:>10.1f}x")

    # --- Backward with both dtypes ---
    print("\n--- Backward Gradients: FP16 vs BF16 Inputs ---")

    quantized_inputs = {
        "fp16": (Q_fp16, K_fp16, V_fp16, dO_fp16),
        "bf16": (Q_bf16, K_bf16, V_bf16, dO_bf16),
    }

    grad_results = {}
    for dtype_name, (Q_d, K_d, V_d, dO_d) in quantized_inputs.items():
        out, P, lse = cpu_fwd_with_intermediates(Q_d, K_d, V_d, prob.scale)
        grad_results[dtype_name] = cpu_attention_bwd(
            Q_d, K_d, V_d, out, dO_d, P, prob.scale
        )
        # Normalize to a plain 3-tuple, as the sections below unpack it that way.
        dQ, dK, dV = grad_results[dtype_name]
        grad_results[dtype_name] = (dQ, dK, dV)

    print(f"\n {'Dtype':<6} {'|dQ| mean':>12} {'|dK| mean':>12} {'|dV| mean':>12}")
    print(" " + "-" * 48)
    for dtype_name in ["fp16", "bf16"]:
        dQ, dK, dV = grad_results[dtype_name]
        print(
            f" {dtype_name:<6} {np.abs(dQ).mean():>12.4e} "
            f"{np.abs(dK).mean():>12.4e} {np.abs(dV).mean():>12.4e}"
        )

    # --- Cross-dtype gradient difference ---
    print("\n--- FP16 vs BF16 Backward Difference ---")
    dQ_fp, dK_fp, dV_fp = grad_results["fp16"]
    dQ_bf, dK_bf, dV_bf = grad_results["bf16"]

    print(
        f"\n {'Grad':<6} {'Max abs diff':>14} {'Mean abs diff':>14} {'Max rel diff':>14}"
    )
    print(" " + "-" * 52)
    grad_pairs = (("dQ", dQ_fp, dQ_bf), ("dK", dK_fp, dK_bf), ("dV", dV_fp, dV_bf))
    for name, g_fp, g_bf in grad_pairs:
        abs_diff = np.abs(g_fp - g_bf)
        max_abs = float(abs_diff.max())
        mean_abs = float(abs_diff.mean())
        # Relative to the fp16 gradient; epsilon guards exact zeros.
        max_rel = float((abs_diff / (np.abs(g_fp) + 1e-8)).max())
        print(f" {name:<6} {max_abs:>14.4e} {mean_abs:>14.4e} {max_rel:>14.4e}")

    # --- Tolerance analysis for different hdims ---
    print("\n--- Recommended Backward Tolerances ---")
    print(f"\n {'Dtype':<6} {'hdim':>6} {'rtol':>10} {'atol':>10} {'Note'}")
    print(" " + "-" * 54)
    for dtype in ["fp16", "bf16"]:
        for hdim in [64, 128, 256]:
            rtol, atol = get_bwd_tolerance(dtype, hdim)
            relaxed = dtype == "bf16" and hdim > 128
            note = "<-- relaxed for large hdim" if relaxed else ""
            print(f" {dtype:<6} {hdim:>6} {rtol:>10.1e} {atol:>10.1e} {note}")

    # --- Validate backward with appropriate tolerances ---
    print("\n--- Validation Against F32 Reference ---")
    out_f32, P_f32, _ = cpu_fwd_with_intermediates(Q_f32, K_f32, V_f32, prob.scale)
    dQ_ref, dK_ref, dV_ref = cpu_attention_bwd(
        Q_f32, K_f32, V_f32, out_f32, dO_f32, P_f32, prob.scale
    )

    for dtype_name in ["fp16", "bf16"]:
        rtol, atol = get_bwd_tolerance(dtype_name, args.hdim)
        dQ, dK, dV = grad_results[dtype_name]

        print(f"\n [{dtype_name}] rtol={rtol:.1e}, atol={atol:.1e}")
        checks = (("dQ", dQ, dQ_ref), ("dK", dK, dK_ref), ("dV", dV, dV_ref))
        for gname, g, g_ref in checks:
            max_err = float(np.abs(g - g_ref).max())
            ok = bool(np.allclose(g, g_ref, rtol=rtol, atol=atol))
            print(f" {gname}: max_err={max_err:.4e} {'PASS' if ok else 'FAIL'}")

    # --- BF16 backward GPU API pattern ---
    for line in (
        "\n--- BF16 Backward GPU API Pattern ---",
        " Native bf16 backward kernel:",
        " FmhaKernelConfig(family='bwd', data_type='bf16', ...)",
        " Internal accumulation stays in fp32 for numerical stability.",
        " Stage 3 (convert_dq) converts fp32 accumulator -> bf16 output.",
        " BF16 advantage: wider dynamic range prevents overflow in",
        " intermediate products (S = Q @ K^T) for large sequences.",
    ):
        print(line)

    # --- Summary ---
    print("\n" + "=" * 70)
    print(" Data types: fp16 (10-bit mantissa) vs bf16 (7-bit mantissa)")
    print(" Tolerances: bf16 bwd needs ~2x relaxed rtol vs fp16")
    rtol_used, _ = get_bwd_tolerance("bf16", args.hdim)
    print(f" Current hdim: {args.hdim} -> bf16 rtol={rtol_used:.1e}")
    print(" GPU: Requires bwd-family JIT kernel with data_type='bf16'")
    print(" Status: DEMO")
    print("=" * 70)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user