Merge branch 'develop' into tianxing/unified-attention

This commit is contained in:
Illia Silin
2026-01-19 15:33:37 -08:00
committed by GitHub
578 changed files with 38474 additions and 13390 deletions

12
.github/CODEOWNERS vendored
View File

@@ -1,8 +1,8 @@
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron
# Documentation files
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
# Header directory for Doxygen documentation
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron

View File

@@ -54,7 +54,7 @@ jobs:
with:
repository: "ROCm/TheRock"
path: "TheRock"
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit
- name: Setup ccache
run: |
@@ -78,8 +78,9 @@ jobs:
run: |
git config --global --add safe.directory '*'
# Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo
rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch
git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch
rm ./TheRock/patches/amd-mainline/rocm-libraries/0003-Find-rocm_smi-via-config-files.patch
rm ./TheRock/patches/amd-mainline/rocm-libraries/0007-Remove-Windows-third_party_dlls-copying-code.patch
# git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch
- name: Install python deps
run: |

View File

@@ -51,7 +51,7 @@ jobs:
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: "ROCm/TheRock"
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit
- name: Run setup test environment workflow
uses: './.github/actions/setup_test_environment'

View File

@@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit
- name: "Configuring CI options"
env:

View File

@@ -5,12 +5,16 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## (Unreleased) Composable Kernel 1.3.0
### Added
* Added preshuffleB support for abquant mode in blockscale GEMM.
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
* Added streamingllm sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
* Added FP8 KV cache support for FMHA batch prefill.
* Added support for gfx1153 target.
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
* Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
### Changed
@@ -40,6 +44,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added top-k sigmoid kernel in CK_TILE
* Added the blockscale 2D support for CK_TILE GEMM.
* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types
* Added reduce and multi reduction kernels
### Changed

View File

@@ -31,11 +31,12 @@ endif()
# Default installation path
if(NOT WIN32)
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
else()
set(CMAKE_INSTALL_PREFIX "C:/dist/TheRock" CACHE PATH "")
endif()
set(version 1.2.0)
# Check support for CUDA/HIP in Cmake
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
project(composable_kernel VERSION ${version} LANGUAGES CXX)
include(CTest)
option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
@@ -162,7 +163,13 @@ execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMI
configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h)
set(ROCM_SYMLINK_LIBS OFF)
find_package(ROCM REQUIRED PATHS /opt/rocm)
if (WIN32)
find_package(ROCmCMakeBuildTools REQUIRED PATHS C:/dist/TheRock)
set(HIP_PLATFORM "amd" CACHE STRING "HIP platform")
else()
find_package(ROCM REQUIRED PATHS /opt/rocm)
endif()
include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers)
@@ -189,7 +196,10 @@ if(GPU_TARGETS)
else()
set(USER_GPU_TARGETS 0)
endif()
find_package(hip REQUIRED)
enable_language(HIP)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")

View File

@@ -2,7 +2,7 @@ ARG BASE_DOCKER="rocm/pytorch:latest"
FROM $BASE_DOCKER
ARG AITER_BRANCH="main"
ARG CK_AITER_BRANCH="develop"
RUN pip install pandas zmq einops ninja && \
RUN pip install pandas zmq einops ninja tabulate && \
pip install numpy==1.26.2 && \
sudo mkdir /home/jenkins && \
sudo mkdir /home/jenkins/workspace && \

144
Jenkinsfile vendored
View File

@@ -574,6 +574,8 @@ def cmake_build(Map conf=[:]){
def setup_cmd
def build_cmd
def execute_cmd = conf.get("execute_cmd", "")
//check the node gpu architecture
def arch_name = check_arch_name()
if(!setup_args.contains("NO_CK_BUILD")){
if (params.NINJA_BUILD_TRACE) {
echo "running ninja build trace"
@@ -646,15 +648,15 @@ def cmake_build(Map conf=[:]){
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${check_arch_name()}.json"
archiveArtifacts "ck_build_trace_${check_arch_name()}.json"
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${check_arch_name()}.json"
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json"
archiveArtifacts "ck_build_trace_${arch_name}.json"
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json"
if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){
if (params.NINJA_FTIME_TRACE) {
echo "running ClangBuildAnalyzer"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${check_arch_name()}.log"
archiveArtifacts "clang_build_analysis_${check_arch_name()}.log"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log"
archiveArtifacts "clang_build_analysis_${arch_name}.log"
}
@@ -672,8 +674,8 @@ def cmake_build(Map conf=[:]){
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}"
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
}
}
if(params.BUILD_INSTANCES_ONLY){
@@ -699,16 +701,14 @@ def cmake_build(Map conf=[:]){
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}"
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
}
}
}
}
}
//check the node gpu architecture
def arch_name = check_arch_name()
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
archiveArtifacts "perf_fmha_*.log"
@@ -811,41 +811,12 @@ def Build_CK(Map conf=[:]){
archiveArtifacts "perf_*.log"
stash includes: "perf_**.log", name: "perf_log_${arch}"
}
// disable performance tests on gfx1030 for now.
//else if ( arch == "gfx10"){
// run basic tests on gfx1030
// echo "Run gemm performance tests"
// sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10"
// archiveArtifacts "perf_onnx_gemm_gfx10.log"
// stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10"
//}
else if ( arch == "gfx11"){
// run basic tests on gfx11
else if ( arch != "gfx10"){
// run basic tests on gfx11/gfx12/gfx908/gfx950, but not on gfx10, it takes too long
echo "Run gemm performance tests"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11"
archiveArtifacts "perf_onnx_gemm_gfx11.log"
stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11"
}
else if ( arch == "gfx120" ){
// run basic tests on gfx12
echo "Run gemm performance tests"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12"
archiveArtifacts "perf_onnx_gemm_gfx12.log"
stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12"
}
else if ( arch == "gfx908" ){
// run basic tests on gfx908
echo "Run performance tests"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908"
archiveArtifacts "perf_onnx_gemm_gfx908.log"
stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908"
}
else if ( arch == "gfx950" ){
// run basic tests on gfx950
echo "Run performance tests"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950"
archiveArtifacts "perf_onnx_gemm_gfx950.log"
stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} ${arch}"
archiveArtifacts "perf_onnx_gemm_*.log"
stash includes: "perf_onnx_gemm_**.log", name: "perf_log_${arch}"
}
}
}
@@ -1046,9 +1017,10 @@ def run_aiter_tests(Map conf=[:]){
sh "rocminfo"
sh "python3 --version"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py"
//sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" //temporarily disable
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_batch_prefill.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py"
@@ -1201,8 +1173,8 @@ pipeline {
description: "Run the ck_tile FMHA tests (default: OFF)")
booleanParam(
name: "RUN_TILE_ENGINE_BASIC_TESTS",
defaultValue: false,
description: "Run the tile_engine_basic tests (default: OFF)")
defaultValue: true,
description: "Run the tile_engine_basic tests (default: ON)")
booleanParam(
name: "RUN_TILE_ENGINE_GEMM_TESTS",
defaultValue: false,
@@ -1346,21 +1318,15 @@ pipeline {
agent{ label rocmnode("nogpu") }
environment{
setup_args = "NO_CK_BUILD"
execute_cmd = "(cd .. && git ls-files \'*.h\' \
\'*.hpp\' \
\'*.cpp\' \
\'*.h.in\' \
\'*.hpp.in\' \
\'*.cpp.in\' \
\'*.cl\' \
| grep -v 'build/' \
| grep -v 'include/rapidjson' \
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\') && \
execute_cmd = """cd .. && \
find . -type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.cl' \\) \
-not -path '*/build/*' -not -path '*/include/rapidjson/*' | \
xargs -P 8 -I{} sh -c 'clang-format-18 -style=file {} | diff -u - {} || (echo "ERROR: {} needs formatting" && exit 1)' && \
/cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \
-D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \
-D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \
-U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \
--file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"
--file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"""
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd)
@@ -1376,17 +1342,10 @@ pipeline {
agent{ label rocmnode("nogpu") }
environment{
setup_args = "NO_CK_BUILD"
execute_cmd = "(cd .. && git ls-files \
\'*.h\' \
\'*.hpp\' \
\'*.cpp\' \
\'*.h.in\' \
\'*.hpp.in\' \
\'*.cpp.in\' \
\'*.cl\' \
| grep -v 'build/' \
| grep -v 'include/rapidjson' \
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\')"
execute_cmd = """cd .. && \
find . -type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.cl' \\) \
-not -path '*/build/*' -not -path '*/include/rapidjson/*' | \
xargs -P 8 -I{} sh -c 'clang-format-18 -style=file {} | diff -u - {} || (echo "ERROR: {} needs formatting" && exit 1)'"""
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd)
@@ -1469,8 +1428,8 @@ pipeline {
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \
./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases"""
make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \
./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases"""
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
@@ -1650,7 +1609,10 @@ pipeline {
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
-D GEMM_PRESHUFFLE_CONFIG_FILE="default_ci_config.json" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all """
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
@@ -1667,37 +1629,6 @@ pipeline {
}
parallel
{
stage("Run TILE_ENGINE_GEMM Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx90a" \
-D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
-D GEMM_STREAMK_DATATYPE="fp8;fp16" \
-D GEMM_STREAMK_LAYOUT="rcr" \
-D GEMM_MULTI_D_DATATYPE="fp16" \
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run TILE_ENGINE_GEMM Tests on gfx942")
{
when {
@@ -1787,7 +1718,10 @@ pipeline {
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = """ -DGPU_TARGETS="gfx942" -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
// SLES15 is a legacy platform with limited C++20 ecosystem support (older system libraries,
// standard library implementation). While the ROCm compiler supports C++20, the experimental
// CK Builder requires full C++20 feature support that does not be reliably available on SLES15.
setup_args = """ -DGPU_TARGETS="gfx942" -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 -DCK_EXPERIMENTAL_BUILDER=OFF """
execute_args = " "
}
steps{

View File

@@ -137,6 +137,22 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
```
**[See Note on -j](#notes)**
### Building for Windows
Install TheRock and run CMake configure as
```bash
cmake \
-D CMAKE_PREFIX_PATH="C:/dist/TheRock" \
-D CMAKE_CXX_COMPILER="C:/dist/TheRock/bin/hipcc.exe" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx1151" \
-G Ninja \
..
```
Use Ninja to build either the whole library or individual targets.
## Optional post-install steps
* Build examples and tests:

View File

@@ -1,35 +1,13 @@
.. meta::
:description: Composable Kernel CK Tile buffer views
:keywords: composable kernel, CK, CK Tile, ROCm, API, buffer view, raw memory
.. _ck_tile_buffer_views:
CK Tile buffer view
=======================
Buffer view is an abstraction that provides structured access to memory. The ``buffer_view`` class is exposed in ``include/ck_tile/core/tensor/buffer_view.hpp``.
Buffer view serves as the foundation for :ref:`ck_tile_tensor_views`. BufferView handles memory addressing and type safety, while TensorView builds upon this to add multi-dimensional coordinates (shape and strides).
Buffer view provides the following advantages:
* A unified interface across global, shared, and register memory
* Address spaces encoded in types, taking advantage of compile-time type checking
* Configurable handling of invalid values, out-of-bounds operations, and conditional access patterns
* Atomic operations for parallel algorithms
* AMD GPU-specific optimizations
* Automatic application of appropriate memory ordering constraints and cache control directives based on the target address space and operation type
[TO DO: do we want to say more about these items? There wasn't a lot of detail in the original text, so I put them in a list for now]
Buffer Views - Raw Memory Access
Address Space Usage Patterns
----------------------------
[TO DO: explain in words what the diagram shows]
..
Original mermaid diagram (edit here, then run update_diagrams.py)
..
Original mermaid diagram (edit here, then run update_diagrams.py)
@@ -66,18 +44,26 @@ Address Space Usage Patterns
style Compute fill:#e0e7ff,stroke:#4338ca,stroke-width:2px
.. image:: diagrams/buffer_views_1.svg
:alt: Diagram
:align: center
C++ Implementation
------------------
**File**: ``include/ck_tile/core/tensor/buffer_view.hpp``
Basic Creation
~~~~~~~~~~~~~~
[TO DO: remove "modern C++ template metaprogramming" and "zero-overhead abstraction"]
By encoding critical properties such as buffer size and address space as template parameters, BufferView transforms what would traditionally be runtime decisions into compile-time constants. This design philosophy enables the compiler to perform aggressive optimizations, including constant propagation, loop unrolling, and instruction selection, that would be impossible with runtime parameters.
[TO DO: might want to move the implementation details to a separate section under "reference"]
The use of compile-time constants extends beyond mere optimization. When the buffer size is encoded in the type system using constructs like ``number<8>{}``, the compiler can statically verify that array accesses are within bounds, eliminate unnecessary bounds checks, and even restructure algorithms to better match the known data dimensions. This compile-time knowledge propagates through the entire computation, enabling optimizations at every level of the abstraction hierarchy.
The address space template parameter represents another crucial design decision. By making the memory space part of the type system, BufferView ensures that operations appropriate for one memory space cannot be accidentally applied to another. This type safety prevents common errors such as attempting atomic operations on register memory or using global memory synchronization primitives on local memory. The compiler enforces these constraints at compile time, transforming potential runtime errors into compile-time diagnostics.
.. code-block:: cpp
@@ -98,7 +84,6 @@ Basic Creation
buffer_size // number of elements
);
// Implementation detail: The actual C++ template is:
// template <address_space_enum BufferAddressSpace,
// typename T,
@@ -123,17 +108,14 @@ Basic Creation
static_assert(space == address_space_enum::global, "Should be global memory");
}
[TO DO: add details and remove unnecessary comments; the "implementation detail" comment can be moved out and either placed outside and explained further, or just removed, depending on what we want to do]
Out-of-Bounds Handling
~~~~~~~~~~~~~~~~~~~~~~
[TO DO: might want to put this implementation detail in the reference section]
Traditional approaches to bounds checking often involve conditional branches that can severely impact performance on GPU architectures, where divergent execution paths within a warp lead to serialization. BufferView's approach sidesteps this problem through two carefully designed modes that maintain performance while providing predictable behavior.
Buffer view uses two modes, zero value mode and custom value mode, that can prevent serialization during bounds checking.
The Zero Value Mode leverages the mathematical property that zero often serves as a neutral element in computations. When an access falls outside the valid buffer range, this mode returns numerical zero without branching. This approach proves particularly effective for algorithms like convolution, where out-of-bounds accesses naturally correspond to zero-padding. The branchless implementation ensures that all threads in a warp follow the same execution path, maintaining the SIMD efficiency that is crucial for GPU performance.
Zero value mode returns zero without branching when an access falls outside the valid buffer range. This is useful in convolutions where out-of-bounds accesses correspond to zero-padding.
Custom value mode returns a custom value without branching when an access falls outside the valid buffer range. Custom value mode accommodates algorithms that require specific values for boundary conditions.
[TO DO: there were two examples of custom value mode that I removed. I removed them because unlike for zero value mode where the example was convolution, the example was vague in custom value. Is there a more specific example of where custom value would be used?]
The Custom Value Mode extends this concept by letting developers specify arbitrary sentinel values for invalid accesses. This flexibility accommodates algorithms that require specific values for boundary conditions, such as using negative infinity for maximum operations or special markers for missing data. The implementation maintains the same branchless characteristics, using conditional move instructions or predicated execution to avoid divergent control flow.
.. code-block:: cpp
@@ -158,92 +140,39 @@ Custom value mode returns a custom value without branching when an access falls
data, buffer_size, custom_invalid);
}
When ``InvalidElementUseNumericalZeroValue`` is set to true, the system uses zero value mode for out of bounds checking. When ``InvalidElementUseNumericalZeroValue`` is set to false, custom value mode is used. Zero value mode is used by default.
.. note::
Zero or custom invalid value is only returned for complete invalid values or out of bound access, for example when the first address of the vector is invalid. Partial out of bounds access during vector reads will not return useful results.
.. code-block:: cpp
// Create data array
constexpr size_t buffer_size = 8;
float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
float custom_invalid = 13.0f;
// Create global memory buffer view with zero invalid value mode (default)
auto buffer_view = make_buffer_view<address_space_enum::global>(data, buffer_size, custom_invalid);
// Invalid element access with is_valid_element=false
// Returns custom_invalid due to custom invalid value mode
auto invalid_value = buffer_view.template get<float>(0, 0, false);
printf("Invalid element: %.1f\n", invalid_value.get(0));
// Out of bounds access - AMD buffer addressing handles bounds checking
// Will return custom_invalid when accessing beyond buffer_size
auto oob_value = buffer_view.template get<float>(0, 100, true);
printf("Out of bounds: %.1f\n", oob_value.get(0));
Get Operations
--------------
[TO DO: might want to put this implementation detail in the reference section]
Scalar Access
~~~~~~~~~~~~~
The signature for the ``buffer_view`` ``get()`` takes four parameters:
The get operations in BufferView form the cornerstone of memory access patterns in CK Tile. These operations embody a advanced understanding of GPU memory systems and the patterns that lead to optimal performance. The scalar access interface incorporates multiple layers of optimization and safety mechanisms that work together to provide both performance and correctness.
``i``: the primary offset into the buffer expressed in terms of elements of type T rather than raw bytes.
The parameter structure of scalar access operations reflects careful design choices aimed at maximizing flexibility while maintaining efficiency. The base index parameter ``i`` represents the primary offset into the buffer, expressed in terms of elements of type T rather than raw bytes. This type-aware indexing prevents common errors related to pointer arithmetic and ensures that vector types are handled correctly. The additional ``linear_offset`` parameter provides fine-grained control over the final access location, enabling complex access patterns without requiring expensive index calculations in the kernel code.
``linear_offset``: [TO DO: what is this?]
The ``is_valid_element`` parameter provides a solution to conditional memory access. Rather than using traditional if-statements that would cause warp divergence, this boolean parameter enables predicated execution where the memory access occurs unconditionally but the result is conditionally used. This approach maintains uniform control flow across all threads in a warp, preserving the SIMD execution model that is fundamental to GPU performance.
``is_valid_element``: [TO DO: what is this?]
The invalid value modes provide a mechanism for handling the boundary conditions that arise in parallel algorithms. When ``InvalidElementUseNumericalZeroValue`` is set to true, the system returns zero for any invalid access, whether due to the ``is_valid_element`` flag or out-of-bounds indexing. This mode is important for algorithms where zero serves as a natural extension value, such as in image processing with zero-padding or sparse matrix operations where missing elements are implicitly zero.
[TO DO: the last param, that's the out of bounds handling, yes?
.. code:: cpp
The custom invalid value mode, activated when ``InvalidElementUseNumericalZeroValue`` is false, offers additional flexibility for algorithms with specific boundary requirements. This mode returns a user-specified value for invalid accesses, accommodating use cases such as sentinel values in sorting algorithms, infinity values in optimization problems, or special markers in data processing pipelines. The implementation ensures that this flexibility comes without performance penalty, using the same branchless execution strategies as the zero mode.
get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
Out-of-bounds handling leverages AMD GPU hardware capabilities to provide safety with minimal impact to performance. When AMD buffer addressing is enabled, the hardware automatically clamps memory accesses to valid ranges, preventing the segmentation faults that would occur on CPU systems. This hardware-assisted bounds checking operates at wire speed, adding no overhead to the memory access path while ensuring that kernels cannot corrupt memory outside their allocated regions.
Vector Access
~~~~~~~~~~~~~
[TO DO: need some context around the code]
Vector memory operations represent one of the most critical optimizations available in modern GPU programming, and BufferView's vector access interface exposes this capability. By using template parameters to specify vector types through constructs like ``ext_vector_t<float, N>``, the interface enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction. This vectorization is crucial for :ref:`ck_tile_load_store_traits`, which automatically selects optimal access patterns.
[TO DO: code chunks need to have detail and explanation so that the reader can see what they're trying to demonstrate.]
The significance of vector operations extends beyond bandwidth improvements. GPUs are designed with wide memory buses that can transfer 128, 256, or even 512 bits per transaction. When scalar operations access only 32 bits at a time, they utilize only a fraction of this available bandwidth. Vector operations align with these wide buses, enabling full bandwidth utilization and reducing the total number of memory transactions required.
The implementation of vector access maintains the same parameter structure as scalar operations, providing consistency across the API while automatically handling the complexities of multi-element transfers. The system manages alignment requirements, ensures that vector loads and stores use the optimal hardware instructions, and handles cases where vector operations extend beyond buffer boundaries. This transparent handling of edge cases allows developers to use vector operations confidently without manual boundary checks or special-case code for partial vectors.
.. code-block:: cpp
// Create buffer view
float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
auto buffer_view = make_buffer_view<address_space_enum::global>(data, 8);
// Simple get - compile-time bounds checking when possible
auto value_buf = buffer_view.template get<float>(0,1,true); //get the buffer from the buffer view
float value = value_buf.get(0); //get the value from the buffer
// Get with valid flag - branchless conditional access
bool valid_flag = false;
value_buf = buffer_view.template get<float>(0,1,valid_flag);
value = value_buf.get(0);
// Returns 0 valid_flag is false
// vectorized get
using float2 = ext_vector_t<float, 2>;
auto vector_buf = buffer_view.template get<float2>(0, 0, true);
// Loads 2 floats in a single instruction
float val1 = vector_buf.get(0);
float val2 = vector_buf.get(1);
}
``ext_vector_t<float, N>`` enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction.
[TO DO: what is it actually doing? When does one use scalars vs vectors? Is it application specific or are there ]
Scalar vs Vectorized Memory Access
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
..
Original mermaid diagram (edit here, then run update_diagrams.py)
..
Original mermaid diagram (edit here, then run update_diagrams.py)
@@ -287,8 +216,9 @@ The signature for the ``buffer_view`` ``get()`` takes four parameters:
Understanding BufferView Indexing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[TO DO: an explanation of the diagram is needed]
..
Original mermaid diagram (edit here, then run update_diagrams.py)
..
Original mermaid diagram (edit here, then run update_diagrams.py)
@@ -335,14 +265,69 @@ Understanding BufferView Indexing
.. image:: diagrams/buffer_views_3.svg
:alt: Diagram
:align: center
C++ Get Operations
~~~~~~~~~~~~~~~~~~
.. code-block:: cpp
__device__ void example_get_operations()
{
// Create buffer view
float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
auto buffer_view = make_buffer_view<address_space_enum::global>(data, 8);
// Simple get - compile-time bounds checking when possible
auto value_buf = buffer_view.template get<float>(0,1,true); //get the buffer from the buffer view
float value = value_buf.get(0); //get the value from the buffer
// Get with valid flag - branchless conditional access
bool valid_flag = false;
value_buf = buffer_view.template get<float>(0,1,valid_flag);
value = value_buf.get(0);
// Returns 0 valid_flag is false
// vectorized get
using float2 = ext_vector_t<float, 2>;
auto vector_buf = buffer_view.template get<float2>(0, 0, true);
// Loads 2 floats in a single instruction
float val1 = vector_buf.get(0);
float val2 = vector_buf.get(1);
}
Custom Value Return Mode for OOB & Invalid Access
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: cpp
void scalar_get_operations_example() {
// Create data array
constexpr size_t buffer_size = 8;
float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
float custom_invalid = 13.0f;
// Create global memory buffer view with zero invalid value mode (default)
auto buffer_view = make_buffer_view<address_space_enum::global>(data, buffer_size, custom_invalid);
// Invalid element access with is_valid_element=false
// Returns custom_invalid due to custom invalid value mode
auto invalid_value = buffer_view.template get<float>(0, 0, false);
printf("Invalid element: %.1f\n", invalid_value.get(0));
// Out of bounds access - AMD buffer addressing handles bounds checking
// Will return custom_invalid when accessing beyond buffer_size
auto oob_value = buffer_view.template get<float>(0, 100, true);
printf("Out of bounds: %.1f\n", oob_value.get(0));
}
.. note::
Partial Out Of Bound (OOB) access during vector reads will return 'junk' values for the OOB access. Zero or custom invalid value is only returned for complete invalid/OOB access, in other words, it is only returned when the first address of the vector is invalid.
Update Operations
-----------------
Update operations modify the buffer content. The ``set()`` method writes a value to a specific location.
.. code-block:: cpp
void scalar_set_operations_example() {
@@ -373,8 +358,6 @@ Update operations modify the buffer content. The ``set()`` method writes a value
Atomic Operations
-----------------
[TO DO: this needs information]
Atomic vs Non-Atomic Operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -441,3 +424,21 @@ C++ Atomic Operations
__syncthreads();
}
Summary
-------
BufferView abstracts GPU memory hierarchies behind a concise interface. The approach is intended to keep overhead small while enabling optimizations that are otherwise awkward in low-level code.
BufferView offers a unified interface across global, shared, and register memory. Using the same API for each space can lower cognitive overhead, reduce certain classes of mistakes, and support code reuse via template parameters.
Address spaces are encoded in types so that common errors are reported at compile time. Consistent with CK Tiles zero-overhead design aim, compile-time checks are favored over runtime guards. The C++ type system enforces memory-space constraints and can make valid cases more amenable to compiler optimization.
BufferView supports configurable handling of invalid values, optional runtime bounds checks, and conditional access patterns. It also provides atomic operations for thread-safe updates. These features are intended to cover common edge cases without adding unnecessary overhead.
By hiding the complexity of different memory spaces while exposing the operations needed for high-performance GPU computing, BufferView establishes a pattern that the rest of CK Tile follows: compile-time abstractions that enhance rather than compromise performance. The :ref:`ck_tile_tensor_views` and :ref:`ck_tile_distribution` add capability while maintaining the efficiency established at the base. For hardware-specific details about memory hierarchies, see :ref:`ck_tile_gpu_basics`.
Next Steps
----------
Continue to :ref:`ck_tile_tensor_views` to learn how to build structured tensor views on top of buffer views.

View File

@@ -1,2 +1,2 @@
rocm-docs-core[api_reference]==1.31.1
rocm-docs-core[api_reference]==1.31.3
sphinxcontrib-bibtex==2.6.5

View File

@@ -237,7 +237,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core[api-reference]==1.31.1
rocm-docs-core[api-reference]==1.31.3
# via -r requirements.in
rpds-py==0.24.0
# via

View File

@@ -149,3 +149,7 @@ add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3)
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale)
add_example_executable(example_gemm_wmma_fp8_bpreshuffle gemm_wmma_fp8_bpreshuffle.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_bpreshuffle)
add_example_executable(example_gemm_wmma_fp16_bpreshuffle gemm_wmma_fp16_bpreshuffle.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_bpreshuffle)

View File

@@ -0,0 +1,70 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include <cstddef>
#include <iostream>
#include <type_traits>
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ComputeTypeA = F16;
using ComputeTypeB = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr int KPack = 8; // int4 -> 32, fp8 -> 16, fp16 -> 8
// clang-format off
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
128,
32, 128, 128,
8, 8,
16, 16,
2, 2,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 1, S<1, 16, 1, 8>, S<4, 4, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB>;
// clang-format on
#include "run_gemm_wmma_bpreshuffle_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,72 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include <cstddef>
#include <iostream>
#include <type_traits>
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using ADataType = F8;
using BDataType = F8;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ComputeTypeA = F8;
using ComputeTypeB = F8;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr int KPack = 16; // int4 -> 32, fp8 -> 16, fp16 -> 8
// clang-format off
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
32, 128, 256,
16, 16,
16, 16,
2, 1,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB>;
// clang-format on
#include "run_gemm_wmma_bpreshuffle_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,206 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{0, 2});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n_preshuffled: " << b_k_n_preshuffled.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// do GEMM
auto device_op = DeviceOpInstance{};
// weight pre-shuffle
int NPerWmma = device_op.GetPreShuffleParameters();
int KLane = ck::get_warp_size() / NPerWmma;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NPerWmma
// N, K -> N0 K0 KLane NPerWmma KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / NPerWmma;
int n1 = n % NPerWmma;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NPerWmma * KLane * K0 + k0 * KPack * NPerWmma * KLane +
k1 * KPack * NPerWmma + n1 * KPack + k2;
b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k);
}
}
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data());
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!device_op.IsSupportedArgument(argument))
{
std::cerr << device_op.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 50, false, 1});
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
invoker.Run(argument, StreamConfig{nullptr, false, 0});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
}
if(config.time_kernel)
{
ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
return pass;
}
bool run_gemm_splitk_example(int argc, char* argv[])
{
ProblemSizeSplitK problem_size{3840, 4096, 4096, 4096, 4096, 4096, 1};
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
}

View File

@@ -119,7 +119,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;

View File

@@ -119,7 +119,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;

View File

@@ -31,7 +31,7 @@ class SimpleAppArgs
bool do_verification = true;
int data_type = 1;
int init_method = 2;
bool time_kernel = true;
bool time_kernel = false;
public:
void show_usage(const char* cmd)

View File

@@ -31,7 +31,7 @@ class SimpleAppArgs
bool do_verification = true;
int data_type = 1;
int init_method = 2;
bool time_kernel = true;
bool time_kernel = false;
public:
void show_usage(const char* cmd)

View File

@@ -31,7 +31,7 @@ class SimpleAppArgs
bool do_verification = true;
int data_type = 1;
int init_method = 2;
bool time_kernel = true;
bool time_kernel = false;
public:
void show_usage(const char* cmd)

View File

@@ -53,7 +53,7 @@ int main(int argc, char* argv[])
{
do_verification = true;
init_method = 1;
time_kernel = true;
time_kernel = false;
}
else if(argc == 4)
{

View File

@@ -27,10 +27,11 @@ using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ActivationOp = PassThrough;
@@ -125,11 +126,11 @@ int main(int /* argc */, char* /* argv */[])
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
}
};

View File

@@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)

View File

@@ -90,7 +90,7 @@ struct ExecutionConfig final
bool do_verification = true;
int init_method = 1;
int k_batch = 128;
bool time_kernel = true;
bool time_kernel = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)

View File

@@ -0,0 +1,76 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <ck/utility/data_type.hpp>
#include <ck/utility/tuple.hpp>
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
using ::ck::DeviceMem;
using ::ck::hip_check_error;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType, DDataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using DsLayout = ck::Tuple<DLayout, DLayout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr int NumDs = 2;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>;
// clang-format on
#include "run_grouped_gemm_multiple_d_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }

View File

@@ -71,339 +71,6 @@ using DeviceGemmInstance =
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
#include "run_grouped_gemm_multiple_d_example.inc"
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<std::vector<ck::index_t>> stride_Ds;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
using GemmDesc = ck::tensor_operation::device::GemmDesc;
// GEMM shape
std::vector<GemmDesc> gemm_descs;
std::vector<KernelArguments> ggemm_kargs;
std::vector<void*> p_Cs;
std::vector<const void*> p_As;
std::vector<const void*> p_Bs;
std::vector<std::array<const void*, NumDs>> p_Ds = {};
gemm_descs.reserve(group_count);
ggemm_kargs.reserve(group_count);
p_As.reserve(group_count);
p_Bs.reserve(group_count);
p_Ds.reserve(group_count);
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_result_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_result_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
d_tensors_device.resize(group_count); // reserve and update vector size
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
d_tensors.push_back(d_tens);
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
sizeof(BDataType) * b_tensors[i].GetElementSize() +
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
}
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
}
}
}
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
for(int j = 0; j < NumDs; ++j)
{
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
}
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
for(int j = 0; j < NumDs; ++j)
{
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
}
c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Ds.push_back(
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
// The device op does not have to know M problem size at lunch time.
gemm_descs.push_back({0,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
ggemm_kargs.push_back(
{a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
ggemm_kargs.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false, 1});
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
auto karg = ggemm_kargs[i];
auto dev_res_tensor =
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
d_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
cde_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc < 10)
{
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(252);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
problem_size.stride_Ds.push_back({});
for(int j = 0; j < NumDs; ++j)
{
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
}
}
std::cout
<< "Usage:\n"
<< "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "... setting default values." << std::endl;
}
else
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.Ms = argToIntArray(argv[4]);
problem_size.Ns = argToIntArray(argv[5]);
problem_size.Ks = argToIntArray(argv[6]);
problem_size.stride_As = argToIntArray(argv[7]);
problem_size.stride_Bs = argToIntArray(argv[8]);
problem_size.stride_Cs = argToIntArray(argv[9]);
for(int j = 0; j < NumDs; ++j)
{
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
}
problem_size.group_count = problem_size.Ms.size();
}
return !run_grouped_gemm(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }

View File

@@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// clang-format on

View File

@@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// clang-format on

View File

@@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[])
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: async hargs (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: async hargs (0=no, 1=yes)\n");
printf("arg5: group count (default=16)\n");
#if defined(EXAMPLE_USE_SPLITK)
printf("arg6: k-batch count (default=1)\n");

View File

@@ -0,0 +1,341 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<std::vector<ck::index_t>> stride_Ds;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
using GemmDesc = ck::tensor_operation::device::GemmDesc;
// GEMM shape
std::vector<GemmDesc> gemm_descs;
std::vector<KernelArguments> ggemm_kargs;
std::vector<void*> p_Cs;
std::vector<const void*> p_As;
std::vector<const void*> p_Bs;
std::vector<std::array<const void*, NumDs>> p_Ds = {};
gemm_descs.reserve(group_count);
ggemm_kargs.reserve(group_count);
p_As.reserve(group_count);
p_Bs.reserve(group_count);
p_Ds.reserve(group_count);
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_result_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_result_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
d_tensors_device.resize(group_count); // reserve and update vector size
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
d_tensors.push_back(d_tens);
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
sizeof(BDataType) * b_tensors[i].GetElementSize() +
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
}
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
}
}
}
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
for(int j = 0; j < NumDs; ++j)
{
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
}
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
for(int j = 0; j < NumDs; ++j)
{
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
}
c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Ds.push_back(
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
// The device op does not have to know M problem size at lunch time.
gemm_descs.push_back({0,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
ggemm_kargs.push_back(
{a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
ggemm_kargs.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false, 1});
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
auto karg = ggemm_kargs[i];
auto dev_res_tensor =
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
d_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
cde_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
bool run_grouped_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc < 10)
{
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(252);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
problem_size.stride_Ds.push_back({});
for(int j = 0; j < NumDs; ++j)
{
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
}
}
std::cout
<< "Usage:\n"
<< "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "... setting default values." << std::endl;
}
else
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.Ms = argToIntArray(argv[4]);
problem_size.Ns = argToIntArray(argv[5]);
problem_size.Ks = argToIntArray(argv[6]);
problem_size.stride_As = argToIntArray(argv[7]);
problem_size.stride_Bs = argToIntArray(argv[8]);
problem_size.stride_Cs = argToIntArray(argv[9]);
for(int j = 0; j < NumDs; ++j)
{
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
}
problem_size.group_count = problem_size.Ms.size();
}
return run_grouped_gemm(problem_size, config);
}

View File

@@ -268,7 +268,7 @@ int main()
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
}
bool time_kernel = true;
bool time_kernel = false;
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

View File

@@ -302,7 +302,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -108,7 +108,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -105,7 +105,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -81,7 +81,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// CGEMM shape
ck::index_t M = 1024;

View File

@@ -65,7 +65,7 @@ class SimpleAppArgs
bool do_verification = true;
int init_method = 2;
bool time_kernel = true;
bool time_kernel = false;
public:
void show_usage(const char* cmd)

View File

@@ -27,7 +27,7 @@ struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
};
template <typename DataType>

View File

@@ -17,7 +17,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
@@ -69,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G1_M2_N3_K1::Argument;
float Run(const Argument& arg)
{
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c;
};
make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M3_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
@@ -353,16 +217,18 @@ int main(int argc, char* argv[])
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceBatchedContraction_G1_M2_N3_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
@@ -399,7 +265,13 @@ int main(int argc, char* argv[])
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
if(!pass)
{
return 1;
}
}
return 0;

View File

@@ -17,6 +17,8 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::make_ParallelTensorFunctor;
@@ -67,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKNN;
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 1 && NumDimM == 3 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G1_M3_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G1_M3_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
@@ -353,17 +219,18 @@ int main(int argc, char* argv[])
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceBatchedContraction_G1_M3_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
@@ -400,7 +267,13 @@ int main(int argc, char* argv[])
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
if(!pass)
{
return 1;
}
}
return 0;

View File

@@ -3,3 +3,4 @@
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_bias_e_permute_wmma_v3_fp16 batched_gemm_bias_e_permute_wmma_v3_fp16.cpp)

View File

@@ -106,352 +106,5 @@ using DeviceOpInstanceKKNN =
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G2_M2_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
ck::index_t G0 = 1;
ck::index_t G1 = 2;
ck::index_t M0 = 4;
ck::index_t M1 = 128;
ck::index_t N0 = 16;
ck::index_t N1 = 256;
ck::index_t K0 = 2048;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
G0 = std::stoi(argv[4]);
G1 = std::stoi(argv[5]);
M0 = std::stoi(argv[6]);
M1 = std::stoi(argv[7]);
N0 = std::stoi(argv[8]);
N1 = std::stoi(argv[9]);
K0 = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
exit(0);
}
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{});
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{});
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{});
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t G =
ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});
ck::index_t M = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
std::size_t flop = std::size_t(2) * G * M * N * K;
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
{
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n1)
{
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
d_gs_ms_ns(g0, g1, m0, m1, n0, n1));
}
}
}
}
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
}
return 0;
}
#include "run_batched_gemm_bias_e_permute_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); }

View File

@@ -0,0 +1,111 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::make_ParallelTensorFunctor;
using ::ck::Tensor;
using Row = ck::tensor_layout::gemm::RowMajor;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceOpInstanceKKNN =
ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
ASpec,
BSpec,
DESpec,
128,
64,
64,
64,
4,
4,
16,
16,
1,
4,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
false,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
false,
1,
1,
S<1, 64, 1, 2>,
S<8, 8>>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
#include "run_batched_gemm_bias_e_permute_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); }

View File

@@ -67,340 +67,5 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G2_M2_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t G0 = 1;
ck::index_t G1 = 2;
ck::index_t M0 = 4;
ck::index_t M1 = 256;
ck::index_t N0 = 16;
ck::index_t N1 = 128;
ck::index_t K0 = 64;
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{});
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{});
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{});
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t G =
ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});
ck::index_t M = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * G * M * N * K;
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
{
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n1)
{
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
d_gs_ms_ns(g0, g1, m0, m1, n0, n1));
}
}
}
}
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
}
return 0;
}
#include "run_batched_gemm_bias_e_permute_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); }

View File

@@ -0,0 +1,350 @@
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G2_M2_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int run_batched_gemm_bias_e_permute_example(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t G0 = 1;
ck::index_t G1 = 2;
ck::index_t M0 = 4;
ck::index_t M1 = 128;
ck::index_t N0 = 16;
ck::index_t N1 = 256;
ck::index_t K0 = 2048;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
G0 = std::stoi(argv[4]);
G1 = std::stoi(argv[5]);
M0 = std::stoi(argv[6]);
M1 = std::stoi(argv[7]);
N0 = std::stoi(argv[8]);
N1 = std::stoi(argv[9]);
K0 = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
exit(0);
}
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{});
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{});
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{});
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t G =
ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});
ck::index_t M = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
std::size_t flop = std::size_t(2) * G * M * N * K;
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{});
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
{
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n1)
{
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
d_gs_ms_ns(g0, g1, m0, m1, n0, n1));
}
}
}
}
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
}
return 1;
}

View File

@@ -92,7 +92,7 @@ struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
};
#define DefaultConvParam \

View File

@@ -92,7 +92,7 @@ struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
};
#define DefaultConvParam \

View File

@@ -40,7 +40,7 @@ class SimpleAppArgs
bool do_verification = true;
int init_method = 2;
bool time_kernel = true;
bool time_kernel = false;
public:
SimpleAppArgs()

View File

@@ -44,7 +44,7 @@ struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 2;
bool time_kernel = true;
bool time_kernel = false;
};
template <ck::index_t... Is>

View File

@@ -56,7 +56,7 @@ template<> struct emb_kernel<ck::half_t, 8192> { using kernel_type = DeviceInsta
int main(int argc, char* argv[])
{
bool time_kernel = true;
bool time_kernel = false;
ck::index_t num_rows = 65536;
constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{};

View File

@@ -195,7 +195,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;

View File

@@ -86,7 +86,7 @@ using DeviceGroupedConvNDFwdInstance =
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -87,7 +87,7 @@ using DeviceGroupedConvNDFwdInstance =
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -90,7 +90,7 @@ using DeviceGroupedConvNDFwdInstance =
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance =
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance =
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -12,7 +12,7 @@ int run_groupnorm_fwd_example(int argc, char* argv[])
ck::index_t C = 128;
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
bool log_kernel = true;
if(argc == 1)

View File

@@ -53,7 +53,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
std::vector<std::size_t> nchw = {16, 128, 32, 64};

View File

@@ -46,7 +46,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -49,7 +49,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -121,7 +121,7 @@ void reference_scale_permute_amax(Tensor<InputDataType>& input,
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
const float scale = 2.f;

View File

@@ -58,7 +58,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
if(argc == 1)
{

View File

@@ -84,7 +84,7 @@ void host_elementwise2D(HostTensorC& C,
int main(int argc, char* argv[])
{
bool do_verification = true;
bool time_kernel = true;
bool time_kernel = false;
ck::index_t M = 48 * 256;
ck::index_t N = 1024;

View File

@@ -31,8 +31,9 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
using A0DataType = F16;
using B0DataType = F16;
@@ -139,11 +140,11 @@ int main(int argc, char* argv[])
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
}
};

View File

@@ -205,7 +205,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// GEMM shape
ck::index_t N = 4096;

View File

@@ -193,7 +193,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
#if 1
// GEMM shape
ck::index_t N = 4096;

View File

@@ -119,7 +119,7 @@ static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_an
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.
#if 1
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MPerBlock = 64;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
@@ -156,7 +156,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
@@ -171,7 +172,8 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>;
#endif
// clang-format on
@@ -182,12 +184,14 @@ int main(int argc, char* argv[])
bool time_kernel = true;
#if 1
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t N = 1536;
ck::index_t K = 4096;
// ck::index_t N = 4096;
// ck::index_t K = 6144;
// ck::index_t N = 128;
// ck::index_t K = 512;
ck::index_t experts = 8;
ck::index_t topk = 2;
ck::index_t experts = 16;
ck::index_t topk = 8;
// ck::index_t sorted_tile_num = 515;
// ck::index_t valid_tile_num = 512;
// ck::index_t tokens = 208;
@@ -196,9 +200,9 @@ int main(int argc, char* argv[])
// ck::index_t sorted_tile_num = 259;
// ck::index_t valid_tile_num = 256;
// ck::index_t tokens = 4096;
ck::index_t sorted_tile_num = 2;
ck::index_t valid_tile_num = 2;
ck::index_t tokens = 32;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 16;
ck::index_t tokens = 4;
#else
// deepseek
ck::index_t N = 2048;
@@ -209,7 +213,7 @@ int main(int argc, char* argv[])
ck::index_t sorted_tile_num = 261;
ck::index_t valid_tile_num = 256;
#endif
ck::index_t KBatch = 6;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case

View File

@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -185,7 +185,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -188,7 +188,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// tokens = 1
// topk = 1

View File

@@ -164,7 +164,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -5,6 +5,8 @@
int run_gemm_example(int argc, char* argv[])
{
using Bypass = ck::tensor_layout::BypassLayoutVerification;
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
@@ -64,11 +66,11 @@ int run_gemm_example(int argc, char* argv[])
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return ck::HostTensorDescriptor({row, col}, {stride, 1_uz});
return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
}
else
{
return ck::HostTensorDescriptor({row, col}, {1_uz, stride});
return ck::HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
}
};

View File

@@ -178,7 +178,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -178,7 +178,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -208,7 +208,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -204,7 +204,7 @@ int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
// per expert:
// GEMM shape

View File

@@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);

View File

@@ -6,6 +6,7 @@
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
@@ -13,11 +14,11 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
}
};

View File

@@ -6,6 +6,7 @@
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
@@ -13,11 +14,11 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
}
};

View File

@@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);

View File

@@ -6,6 +6,7 @@
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
@@ -13,11 +14,11 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& c
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
}
};

View File

@@ -6,6 +6,7 @@
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
@@ -13,11 +14,11 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& c
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
}
};

View File

@@ -6,6 +6,35 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/library/include
)
if(WIN32)
# On Windows, HIP uses -nostdlib which prevents C runtime linking
# We need legacy_stdio_definitions.lib to provide vfprintf and other legacy C functions
# This is mainly needed for the getopt library.
set(LEGACY_STDIO_SEARCH_PATHS)
# Try to use Visual C++ Tools environment variable (if build executes from Visual Studio Developer Command Prompt)
if(DEFINED ENV{VCToolsInstallDir})
list(APPEND LEGACY_STDIO_SEARCH_PATHS "$ENV{VCToolsInstallDir}/lib/x64")
endif()
# Fallback: Search common Visual Studio installation locations
file(GLOB MSVC_LIB_PATHS "C:/Program Files/Microsoft Visual Studio/*/*/VC/Tools/MSVC/*/lib/x64")
list(APPEND LEGACY_STDIO_SEARCH_PATHS ${MSVC_LIB_PATHS})
# Use find_library to locate the library
find_library(LEGACY_STDIO_LIB legacy_stdio_definitions
PATHS ${LEGACY_STDIO_SEARCH_PATHS}
NO_DEFAULT_PATH
)
if(LEGACY_STDIO_LIB)
message(STATUS "Found legacy_stdio_definitions.lib: ${LEGACY_STDIO_LIB}")
add_link_options("SHELL:-Xlinker \"${LEGACY_STDIO_LIB}\"")
else()
message(WARNING "Could not find legacy_stdio_definitions.lib - examples may fail to link.")
endif()
endif()
add_custom_target(examples)
@@ -216,6 +245,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
add_dependencies(examples ${EXAMPLE_NAME})
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)

View File

@@ -36,6 +36,19 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
}
KV_LOOKUP_TABLE_ENUM_MAP = {
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
@@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
@@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_lse},
{F_dropout},
{F_qscale},
{F_occupancy}>;
{F_occupancy},
false,
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
@@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
fmha_variant_{F_idx},
fmha_mask_{F_idx},
false,
{F_page_size},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
@@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} =
using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
#include <iostream>
@@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -230,12 +248,15 @@ class FmhaFwdApiTrait:
dpad: str
dvpad: str
constraint: CppConstraint
kv_memory_layout: str
kv_lookup_table: str
page_size: int = 1 # page block size
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
)
@property
@@ -322,6 +343,8 @@ class FmhaFwdPipeline:
F_dropout: str #
F_qscale: str # no/pertensor
F_mask: str # value from MASK_MAP
F_kv_memory_layout: str #
F_kv_lookup_table: str #
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
@@ -382,6 +405,8 @@ class FmhaFwdPipeline:
n += f"_{self.F_qscale}"
else:
n += "_nqscale"
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
return n
@@ -440,6 +465,13 @@ class FmhaFwdApiPool:
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
trait.kv_memory_layout
],
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
trait.kv_lookup_table
],
F_page_size=trait.page_size,
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -497,6 +529,7 @@ class FmhaFwdKernel:
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
F_page_size: int = 1 # page block size
@property
def template(self) -> str:
@@ -534,17 +567,24 @@ class FmhaFwdKernel:
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
F_occupancy=self.F_tile.F_occupancy,
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
self.F_pipeline.F_kv_memory_layout
],
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
self.F_pipeline.F_kv_lookup_table
],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
@@ -578,6 +618,9 @@ class FmhaFwdKernel:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
page_size=self.F_page_size,
)
@@ -604,23 +647,42 @@ class KernelComponentFactory:
pipelines = []
if dtype in ["fp16", "bf16"]:
qscale = "no"
for logits, mask, bias, lse, dropout in itertools.product(
for (
logits,
mask,
bias,
lse,
dropout,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
elif dtype in ["fp8bf16"]:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
for (
logits,
qscale,
mask,
bias,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
else:
assert False
return pipelines
@@ -672,69 +734,75 @@ def get_fwd_blobs(
or pipeline.F_logits == "f"
):
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
# Generate kernels for both page_size=16 and page_size=1024
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_page_size=page_size,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)

View File

@@ -315,7 +315,7 @@ class FmhaFwdApiTrait:
assert False
def seqtune(self, max_bm0: int) -> str:
if self.bm0 == max_bm0:
if self.bm0 == max_bm0 or self.bm0 == 64:
return "true/*fall back to largest tile*/"
else:
return f"a.seqlen_q <= {self.bm0}"
@@ -847,6 +847,11 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
(problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
and kernel_ctx.tile.F_bm0 != 128
)
or (
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
and kernel_ctx.pipeline.tag != "qr_async"
and kernel_ctx.tile.F_bk0 == 64
)
):
# non qr_async_trload only support km0=128 tile size when hdim is not 128
# non qr_async only support kn0=128 tile size when hdim is 128
@@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],

View File

@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
"Comma-separated list of length 'b'. If empty, no override.")
.insert("init_sink", "0", "value to init the output tensor sink value for validation");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
int init_sink_value = arg_parser.get_int("init_sink");
ck_tile::stream_config stream_config{nullptr,
true,
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
init_method,
seed,
do_validation,
init_sink_value,
stream_config,
json);
}

Some files were not shown because too many files have changed in this diff Show More