mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline
This commit is contained in:
12
.github/CODEOWNERS
vendored
12
.github/CODEOWNERS
vendored
@@ -1,8 +1,8 @@
|
||||
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli
|
||||
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @shumway @vidyasagar-amd
|
||||
# Documentation files
|
||||
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD
|
||||
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD
|
||||
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD
|
||||
# Header directory for Doxygen documentation
|
||||
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -68,3 +68,6 @@ build*/
|
||||
|
||||
# Python cache
|
||||
__pycache__/
|
||||
|
||||
.cache/
|
||||
|
||||
|
||||
@@ -12,9 +12,27 @@ repos:
|
||||
verbose: false
|
||||
language: script
|
||||
types: [c++]
|
||||
- id: remove-exec-bit
|
||||
name: Remove executable bit from non-executable files
|
||||
entry: script/remove_exec_bit.sh
|
||||
language: script
|
||||
types_or: [c++, text]
|
||||
verbose: true
|
||||
- id: remove-exec-bit
|
||||
name: Remove executable bit from non-executable files
|
||||
entry: script/remove_exec_bit.sh
|
||||
language: script
|
||||
types_or: [c++, text]
|
||||
verbose: true
|
||||
- id: ruff-check
|
||||
name: Ruff Linter
|
||||
entry: ruff check --fix
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: [ruff]
|
||||
- id: ruff-format
|
||||
name: Ruff Formatter
|
||||
entry: ruff format
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: [ruff]
|
||||
- id: run-remod-if-ck-tile-changed
|
||||
name: Run remod.py if ck_tile files changed
|
||||
entry: script/remod_for_ck_tile.sh
|
||||
language: script
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
|
||||
@@ -14,19 +14,22 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
|
||||
* Added support for Stream-K version of mixed fp8/bf16 GEMM
|
||||
* Added support for Multiple D GEMM
|
||||
* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types
|
||||
* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types
|
||||
* Added support for FP16 2:4 structured sparsity to universal GEMM.
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
* Added logit soft-capping support for fMHA forward kernels.
|
||||
* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv)
|
||||
* Added benchmarking support for tile engine GEMM.
|
||||
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
* Added int8 support for CK_TILE GEMM.
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166)
|
||||
* Added Vectorize Transpose optimization for CK Tile (#2131)
|
||||
* Added the asynchronous copy for gfx950 (#2425)
|
||||
|
||||
|
||||
### Fixes
|
||||
|
||||
@@ -634,7 +634,7 @@ option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
|
||||
add_subdirectory(library)
|
||||
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
|
||||
@@ -77,7 +77,6 @@ RUN git clone https://github.com/ccache/ccache.git && \
|
||||
wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip && \
|
||||
gunzip /usr/local/bin/ninja.gz && \
|
||||
chmod a+x /usr/local/bin/ninja && \
|
||||
git clone https://github.com/nico/ninjatracing.git && \
|
||||
#Install ClangBuildAnalyzer
|
||||
git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \
|
||||
cd ClangBuildAnalyzer/ && \
|
||||
|
||||
137
Jenkinsfile
vendored
137
Jenkinsfile
vendored
@@ -225,18 +225,15 @@ def cmake_build(Map conf=[:]){
|
||||
def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","")
|
||||
def prefixpath = conf.get("prefixpath","/opt/rocm")
|
||||
def setup_args = conf.get("setup_args","")
|
||||
|
||||
// make sure all unit tests always run on develop branch
|
||||
def runAllUnitTests = (env.BRANCH_NAME == "develop") ? true : params.RUN_ALL_UNIT_TESTS
|
||||
|
||||
if (prefixpath != "/usr/local"){
|
||||
setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} "
|
||||
}
|
||||
|
||||
def build_type_debug = (conf.get("build_type",'release') == 'debug')
|
||||
|
||||
// use special compiler for gfx950
|
||||
if ( check_arch() == 7){
|
||||
compiler = "/llvm-project/build/bin/clang++"
|
||||
}
|
||||
|
||||
//cmake_env can overwrite default CXX variables.
|
||||
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
|
||||
|
||||
@@ -343,15 +340,19 @@ def cmake_build(Map conf=[:]){
|
||||
def build_cmd
|
||||
def execute_cmd = conf.get("execute_cmd", "")
|
||||
if(!setup_args.contains("NO_CK_BUILD")){
|
||||
if (setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE){
|
||||
def cmake_flags = params.NINJA_FTIME_TRACE ? "-O3 -ftime-trace" : "-O3"
|
||||
if (params.NINJA_BUILD_TRACE) {
|
||||
echo "running ninja build trace"
|
||||
setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """)
|
||||
build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}")
|
||||
}
|
||||
else{
|
||||
setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
|
||||
build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}")
|
||||
}
|
||||
setup_cmd = conf.get(
|
||||
"setup_cmd",
|
||||
"""${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" ${cmake_flags} " .. """
|
||||
)
|
||||
build_cmd = conf.get(
|
||||
"build_cmd",
|
||||
"${build_envs} ninja -j${nt} ${config_targets}"
|
||||
)
|
||||
|
||||
cmd = conf.get("cmd", """
|
||||
${setup_cmd}
|
||||
${build_cmd}
|
||||
@@ -372,14 +373,23 @@ def cmake_build(Map conf=[:]){
|
||||
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
|
||||
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
|
||||
if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){
|
||||
sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log"
|
||||
if (params.NINJA_FTIME_TRACE) {
|
||||
echo "running ninja ftime trace"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log"
|
||||
archiveArtifacts "clang_build_analysis.log"
|
||||
}
|
||||
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace.json"
|
||||
archiveArtifacts "ck_build_trace.json"
|
||||
archiveArtifacts "clang_build_analysis.log"
|
||||
|
||||
// do not run unit tests when building instances only
|
||||
if(!params.BUILD_INSTANCES_ONLY){
|
||||
sh "ninja check"
|
||||
if (!runAllUnitTests){
|
||||
sh "../script/launch_tests.sh"
|
||||
}
|
||||
else{
|
||||
sh "ninja check"
|
||||
}
|
||||
}
|
||||
if(params.BUILD_INSTANCES_ONLY){
|
||||
// build deb packages
|
||||
@@ -393,7 +403,12 @@ def cmake_build(Map conf=[:]){
|
||||
else{
|
||||
// run unit tests unless building library for all targets
|
||||
if (!params.BUILD_INSTANCES_ONLY){
|
||||
sh "make check"
|
||||
if (!runAllUnitTests){
|
||||
sh "../script/launch_tests.sh"
|
||||
}
|
||||
else{
|
||||
sh "ninja check"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -793,10 +808,10 @@ def process_results(Map conf=[:]){
|
||||
}
|
||||
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=false
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : ""
|
||||
|
||||
@@ -859,8 +874,8 @@ pipeline {
|
||||
description: "Run the cppcheck static analysis (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_PERFORMANCE_TESTS",
|
||||
defaultValue: true,
|
||||
description: "Run the performance tests (default: ON)")
|
||||
defaultValue: false,
|
||||
description: "Run the performance tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS",
|
||||
defaultValue: false,
|
||||
@@ -893,10 +908,26 @@ pipeline {
|
||||
name: "BUILD_GFX908",
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx908 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX90A",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx90a (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX942",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx942 (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX950",
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx950 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX10",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx10 (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX11",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx11 (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX12",
|
||||
defaultValue: true,
|
||||
@@ -905,6 +936,10 @@ pipeline {
|
||||
name: "NINJA_BUILD_TRACE",
|
||||
defaultValue: false,
|
||||
description: "Generate a ninja build trace (default: OFF)")
|
||||
booleanParam(
|
||||
name: "NINJA_FTIME_TRACE",
|
||||
defaultValue: false,
|
||||
description: "Generate a detailed time trace (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_LEGACY_OS",
|
||||
defaultValue: false,
|
||||
@@ -913,6 +948,10 @@ pipeline {
|
||||
name: "RUN_INDUCTOR_TESTS",
|
||||
defaultValue: true,
|
||||
description: "Run inductor codegen tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "RUN_ALL_UNIT_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run all unit tests (default: OFF)")
|
||||
}
|
||||
environment{
|
||||
dbuser = "${dbuser}"
|
||||
@@ -1025,7 +1064,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CODEGEN_TESTS.toBoolean() }
|
||||
expression { params.RUN_CODEGEN_TESTS.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
environment{
|
||||
@@ -1189,9 +1228,12 @@ pipeline {
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm && \
|
||||
./bin/benchmark_gemm """
|
||||
ninja -j64 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j64 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1211,9 +1253,12 @@ pipeline {
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j128 benchmark_gemm && \
|
||||
./bin/benchmark_gemm """
|
||||
ninja -j128 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j128 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1269,7 +1314,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { (params.BUILD_GFX942.toBoolean() || params.RUN_FULL_QA.toBoolean()) && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
@@ -1302,12 +1347,12 @@ pipeline {
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "rocm/composable_kernel-private:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
@@ -1336,7 +1381,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { params.BUILD_GFX90A.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
@@ -1360,14 +1405,20 @@ pipeline {
|
||||
expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
script {
|
||||
def execute_args = params.NINJA_FTIME_TRACE ?
|
||||
""" cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """ :
|
||||
""" cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
|
||||
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
}
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
@@ -1375,7 +1426,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { params.BUILD_GFX10.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1030") }
|
||||
environment{
|
||||
@@ -1396,7 +1447,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1101") }
|
||||
environment{
|
||||
|
||||
4
client_example/32_gemm_mx/CMakeLists.txt
Normal file
4
client_example/32_gemm_mx/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx950")
|
||||
add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp)
|
||||
target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
330
client_example/32_gemm_mx/gemm_mx_fp8.cpp
Normal file
330
client_example/32_gemm_mx/gemm_mx_fp8.cpp
Normal file
@@ -0,0 +1,330 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
template <typename X, typename Y>
|
||||
inline constexpr bool is_same_v = ck::is_same<X, Y>::value;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AScaleLayout = Row;
|
||||
using BScaleLayout = Col;
|
||||
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
mem_size_ = mem_size;
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
std::size_t mem_size_;
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
/* Require by mx type*/
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
M = std::stoi(argv[1]);
|
||||
N = std::stoi(argv[2]);
|
||||
K = std::stoi(argv[3]);
|
||||
|
||||
StrideA = std::stoi(argv[4]);
|
||||
StrideB = std::stoi(argv[5]);
|
||||
StrideC = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_matrix_space_size =
|
||||
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
|
||||
using Layout = decltype(layout);
|
||||
|
||||
if constexpr(std::is_same<Layout, Row>::value)
|
||||
{
|
||||
return (nRow - 1) * stride + nCol;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (nCol - 1) * stride + nRow;
|
||||
}
|
||||
};
|
||||
|
||||
/* Scale stride Calculation */
|
||||
auto f_get_default_stride =
|
||||
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
return static_cast<ck::index_t>(col);
|
||||
else
|
||||
return static_cast<ck::index_t>(row);
|
||||
}
|
||||
else
|
||||
return static_cast<ck::index_t>(stride);
|
||||
};
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize;
|
||||
auto Scale_Stride_AM =
|
||||
f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{});
|
||||
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
|
||||
|
||||
SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
|
||||
SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
|
||||
SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
|
||||
SimpleDeviceMem a_scale_device_buf(
|
||||
sizeof(XDataType) *
|
||||
f_matrix_space_size(Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
|
||||
SimpleDeviceMem b_scale_device_buf(
|
||||
sizeof(XDataType) *
|
||||
f_matrix_space_size(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
|
||||
|
||||
using DeviceOp =
|
||||
ck::tensor_operation::device::DeviceGemmMX<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
XPackedDataType,
|
||||
BDataType,
|
||||
XPackedDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop =
|
||||
std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
|
||||
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> +
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * M * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * N * K / ScaleBlockSize;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
if(found)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -14,8 +14,10 @@ cd client_example/build
|
||||
cmake \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \
|
||||
-D GPU_TARGETS="gfx908;gfx90a" \
|
||||
..
|
||||
```
|
||||
You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s).
|
||||
|
||||
### Build client example
|
||||
```bash
|
||||
|
||||
116
cmake/ShardInstantiation.cmake
Normal file
116
cmake/ShardInstantiation.cmake
Normal file
@@ -0,0 +1,116 @@
|
||||
# Function to generate templated instantiation functions and caller function.
|
||||
|
||||
# In order to reduce build times, we split the instantiation of template functions into multiple files.
|
||||
# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions,
|
||||
# which can be placed the TEMPLATE_FILE (typically a .in file).
|
||||
|
||||
# This CMake function generates the instantiation functions and a caller function that calls all the instantiation
|
||||
# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of
|
||||
# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard,
|
||||
# and generates a caller function that calls all the instantiation functions.
|
||||
|
||||
# The explicit instatiation pattern requires the use of `extern template` to avoid implicit instantiation
|
||||
# of the template functions in the caller function, and that code is automatically generated by this function.
|
||||
|
||||
# In addition to the user-supplied template, this CMake function uses two generic templates:
|
||||
#
|
||||
# 1. `instantiate_shard.in`: This is the template for the instantiation functions.
|
||||
# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions.
|
||||
|
||||
# This function takes the following arguments:
|
||||
#
|
||||
# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCE_NAMES}`).
|
||||
# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions.
|
||||
# - NUM_SHARDS: The number of shards to generate.
|
||||
# - OUTPUT_DIR: The build directory where the generated source files will be placed.
|
||||
# - SRC_LIST: The list of source files to which the generated source files will be added.
|
||||
|
||||
|
||||
function(generate_sharded_instantiations)
|
||||
cmake_parse_arguments(
|
||||
GEN_SHARDED
|
||||
# No boolean arguments
|
||||
""
|
||||
# Single-value arguments
|
||||
"INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST"
|
||||
# No multi-value arguments.
|
||||
""
|
||||
${ARGN}
|
||||
)
|
||||
if (NOT GEN_SHARDED_INSTANCES_NAME)
|
||||
message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
if (NOT GEN_SHARDED_TEMPLATE_FILE)
|
||||
message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
if (NOT GEN_SHARDED_NUM_SHARDS)
|
||||
message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
if(NOT GEN_SHARDED_OUTPUT_DIR)
|
||||
message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
if (NOT GEN_SHARDED_SRC_LIST)
|
||||
message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
|
||||
file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR})
|
||||
|
||||
|
||||
set(GENERATED_SOURCE_FILES "")
|
||||
set(EXTERN_TEMPLATE_STATEMENTS "")
|
||||
set(CALL_STATEMENTS "")
|
||||
message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}")
|
||||
|
||||
set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}")
|
||||
|
||||
# Generate the inc file with the template function defintions.
|
||||
# This include file will hold the template function definitions and a using alias for all the shard
|
||||
# instantiation functions.
|
||||
configure_file(
|
||||
"${GEN_SHARDED_TEMPLATE_FILE}"
|
||||
"${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc"
|
||||
@ONLY
|
||||
)
|
||||
|
||||
# Generate the sharded instantiation functions.
|
||||
# This is where the build parallelization happens.
|
||||
# Each of these source files will contain a single instantiation function for a shard,
|
||||
# which will be called sequentially by the caller function.
|
||||
set(INC_DIR "${GEN_SHARDED_INC_DIR}")
|
||||
math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1")
|
||||
foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID})
|
||||
set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}")
|
||||
set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp")
|
||||
set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in")
|
||||
configure_file(
|
||||
"${SHARD_FUNCTION_TEMPLATE}"
|
||||
"${SHARD_FUNCTION_PATH}"
|
||||
@ONLY
|
||||
)
|
||||
list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}")
|
||||
set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>")
|
||||
list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)")
|
||||
list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)")
|
||||
endforeach()
|
||||
|
||||
# Join the include statements, the extern template declarations, and the call statements each
|
||||
# into a single string for variable substitution in the caller function.
|
||||
string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}")
|
||||
string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}")
|
||||
string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}")
|
||||
|
||||
# Generate the caller function.
|
||||
set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp")
|
||||
set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in")
|
||||
configure_file(
|
||||
"${FUNCTION_TEMPLATE}"
|
||||
"${CALLER_FUNCTION_PATH}"
|
||||
@ONLY
|
||||
)
|
||||
list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}")
|
||||
|
||||
# Add the generated source files to the list of source files.
|
||||
# This allows the generated source files to be included in the build.
|
||||
list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES})
|
||||
set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
15
cmake/call_shard.in
Normal file
15
cmake/call_shard.in
Normal file
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "@INSTANCES@.inc"
|
||||
|
||||
namespace ck::tensor_operation::device::instance {
|
||||
|
||||
@EXTERN_TEMPLATE_STATEMENTS@;
|
||||
|
||||
void add_@INSTANCES@(
|
||||
@INSTANCES@& instances) {
|
||||
@CALL_STATEMENTS@;
|
||||
}
|
||||
|
||||
} // namespace ck::tensor_operation::device::instance
|
||||
@@ -68,3 +68,6 @@ endif()
|
||||
|
||||
target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})
|
||||
target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS})
|
||||
target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0)
|
||||
target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0)
|
||||
|
||||
|
||||
9
cmake/instantiate_shard.in
Normal file
9
cmake/instantiate_shard.in
Normal file
@@ -0,0 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "@INSTANCES@.inc"
|
||||
|
||||
namespace ck::tensor_operation::device::instance {
|
||||
template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>(
|
||||
@INSTANCES@& instances);
|
||||
} // namespace ck::tensor_operation::device::instance
|
||||
@@ -945,11 +945,9 @@ WARN_LOGFILE =
|
||||
# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
|
||||
# Note: If this tag is empty the current directory is searched.
|
||||
|
||||
INPUT = ../../include/ck/tensor_operation/gpu/grid \
|
||||
../../include/ck/tensor_operation/gpu/block \
|
||||
../../include/ck/tensor_operation/gpu/thread \
|
||||
INPUT = ../../include \
|
||||
../../include/ck/ \
|
||||
../../library/include/ck/library/utility \
|
||||
../../include/ck/wrapper \
|
||||
../../include/ck_tile
|
||||
|
||||
# This tag can be used to specify the character encoding of the source files
|
||||
@@ -1849,7 +1847,7 @@ MATHJAX_CODEFILE =
|
||||
# The default value is: YES.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
SEARCHENGINE = YES
|
||||
SEARCHENGINE = NO
|
||||
|
||||
# When the SERVER_BASED_SEARCH tag is enabled the search engine will be
|
||||
# implemented using a web server instead of a web client using JavaScript. There
|
||||
@@ -2406,7 +2404,7 @@ TAGFILES =
|
||||
# tag file that is based on the input files it reads. See section "Linking to
|
||||
# external documentation" for more information about the usage of tag files.
|
||||
|
||||
GENERATE_TAGFILE =
|
||||
GENERATE_TAGFILE = html/tagfile.xml
|
||||
|
||||
# If the ALLEXTERNALS tag is set to YES, all external class will be listed in
|
||||
# the class index. If set to NO, only the inherited external classes will be
|
||||
@@ -2653,7 +2651,7 @@ DIR_GRAPH_MAX_DEPTH = 1
|
||||
# The default value is: png.
|
||||
# This tag requires that the tag HAVE_DOT is set to YES.
|
||||
|
||||
DOT_IMAGE_FORMAT = png
|
||||
DOT_IMAGE_FORMAT = svg
|
||||
|
||||
# If DOT_IMAGE_FORMAT is set to svg, then this option can be set to YES to
|
||||
# enable generation of interactive SVG images that allow zooming and panning.
|
||||
@@ -2665,7 +2663,7 @@ DOT_IMAGE_FORMAT = png
|
||||
# The default value is: NO.
|
||||
# This tag requires that the tag HAVE_DOT is set to YES.
|
||||
|
||||
INTERACTIVE_SVG = NO
|
||||
INTERACTIVE_SVG = YES
|
||||
|
||||
# The DOT_PATH tag can be used to specify the path where the dot tool can be
|
||||
# found. If left blank, it is assumed the dot tool can be found in the path.
|
||||
|
||||
@@ -36,7 +36,9 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab
|
||||
* :doc:`Composable Kernel custom types <./reference/Composable_Kernel_custom_types>`
|
||||
* :doc:`Composable Kernel vector utilities <./reference/Composable_Kernel_vector_utilities>`
|
||||
* :ref:`wrapper`
|
||||
* :doc:`Composable Kernel complete class list <./doxygen/html/annotated>`
|
||||
* :doc:`Composable Kernel API reference <./doxygen/html/namespace_c_k>`
|
||||
* :doc:`CK Tile API reference <./doxygen/html/namespaceck__tile>`
|
||||
* :doc:`Composable Kernel complete API class list <./doxygen/html/annotated>`
|
||||
|
||||
To contribute to the documentation refer to `Contributing to ROCm <https://rocm.docs.amd.com/en/latest/contribute/contributing.html>`_.
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core[api_reference]==1.20.1
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
sphinxcontrib-bibtex==2.6.5
|
||||
|
||||
@@ -182,7 +182,7 @@ ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pybtex==0.24.0
|
||||
pybtex==0.25.1
|
||||
# via
|
||||
# pybtex-docutils
|
||||
# sphinxcontrib-bibtex
|
||||
@@ -244,9 +244,7 @@ rpds-py==0.24.0
|
||||
# jsonschema
|
||||
# referencing
|
||||
six==1.17.0
|
||||
# via
|
||||
# pybtex
|
||||
# python-dateutil
|
||||
# via python-dateutil
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
snowballstemmer==2.2.0
|
||||
@@ -278,7 +276,7 @@ sphinx-notfound-page==1.1.0
|
||||
# via rocm-docs-core
|
||||
sphinxcontrib-applehelp==2.0.0
|
||||
# via sphinx
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
sphinxcontrib-bibtex==2.6.5
|
||||
# via -r requirements.in
|
||||
sphinxcontrib-devhelp==2.0.0
|
||||
# via sphinx
|
||||
|
||||
@@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
|
||||
@@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
|
||||
if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132)
|
||||
if(hip_VERSION_FLAT LESS_EQUAL 600342132)
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
check_cxx_compiler_flag("-mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1" HAS_MAX_ILP_SCHEDULING_STRATEGY)
|
||||
if(HAS_MAX_ILP_SCHEDULING_STRATEGY)
|
||||
@@ -31,8 +31,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
example_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
endif()
|
||||
set(GEMM_OPTIONS)
|
||||
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
@@ -40,16 +39,25 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(GEMM_OPTIONS)
|
||||
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
set(BLOCKSCALE_GEMM_OPTIONS)
|
||||
set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
set(BLOCKSCALE_GEMM_OPTIONS )
|
||||
check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP)
|
||||
check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION)
|
||||
if(HAS_MISCHED_BOTTOMUP)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1")
|
||||
elseif(HAS_MISCHED_PRERA_DIRECTION)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup")
|
||||
|
||||
if(hip_VERSION_FLAT LESS 600443483 OR hip_VERSION_FLAT GREATER_EQUAL 700000000)
|
||||
if(HAS_MISCHED_BOTTOMUP)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1")
|
||||
elseif(HAS_MISCHED_PRERA_DIRECTION)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup")
|
||||
endif()
|
||||
else()
|
||||
if(HAS_MISCHED_BOTTOMUP)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1")
|
||||
elseif(HAS_MISCHED_PRERA_DIRECTION)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-prera-direction=bottomup")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL)
|
||||
if(HAS_MAX_OCCUPANCY_EXPERIMENTAL)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental)
|
||||
|
||||
@@ -139,6 +139,7 @@ static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
static constexpr bool PerTokenQuant = true;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
@@ -169,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
|
||||
@@ -197,7 +198,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 3)
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
@@ -238,7 +239,8 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
constexpr auto StrideDs = PerTokenQuant ? std::array<ck::index_t, NumDTensor>{1, 1, 0}
|
||||
: std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
@@ -279,8 +281,10 @@ int main(int argc, char* argv[])
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<D0DataType> d0_t_n(
|
||||
HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(
|
||||
HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
@@ -10,6 +10,12 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
|
||||
# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
|
||||
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp6)
|
||||
|
||||
add_example_executable(example_gemm_mx_bf6 gemm_mx_bf6.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_bf6)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
|
||||
|
||||
@@ -22,17 +28,41 @@ add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns)
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns)
|
||||
|
||||
add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4)
|
||||
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4)
|
||||
|
||||
add_example_executable(example_moe_gemm1_xdl_mx_fp4_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bpreshuffle)
|
||||
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bpreshuffle)
|
||||
|
||||
set(FP4_MXGEMM_OPTIONS)
|
||||
list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1")
|
||||
example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
# mx moe B no-shuffling + scale shuffling
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
# mx moe B no-shuffling + scale shuffling (async loads)
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
# mx moe B shuffling + scale shuffling (async loads)
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
set(FP8_MXGEMM_OPTIONS)
|
||||
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
|
||||
set(FP6_MXGEMM_OPTIONS)
|
||||
list(APPEND FP6_MXGEMM_OPTIONS -mavx512f)
|
||||
example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS})
|
||||
|
||||
@@ -8,14 +8,16 @@ Custom verification parameters:
|
||||
# arg2: initialization (0=constant values, 1=integer values, 2=decimal values)
|
||||
# arg3: time kernel (0=no, 1=yes)
|
||||
# arg4: verbosity (0=no info, 1=verbose info)
|
||||
# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC
|
||||
# arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC
|
||||
# arg11: KBatch
|
||||
# arg12: warmup runs pre-timing
|
||||
# arg13: repeat run count for timing
|
||||
./bin/example_gemm_mx_fp8 1 1 0 1
|
||||
```
|
||||
|
||||
Custom tensor shapes:
|
||||
```bash
|
||||
./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1
|
||||
./bin/example_gemm_mx_fp8 1 2 1 0 256 256 512 -1 -1 -1 1 10 10
|
||||
```
|
||||
|
||||
Default invocation:
|
||||
|
||||
101
example/67_gemm_microscaling/gemm_mx_bf6.cpp
Normal file
101
example/67_gemm_microscaling/gemm_mx_bf6.cpp
Normal file
@@ -0,0 +1,101 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::bf6x16_pk_t;
|
||||
using BDataType = ck::bf6x16_pk_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 16; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 bf6 = 16 bf6x16_pk_t
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XPackedDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XPackedDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
1, // AK1
|
||||
1, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
1, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
1, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XPackedDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -100,8 +100,11 @@ bool parse_cmd_args(int argc,
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
|
||||
<< "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg11: KBatch" << std::endl;
|
||||
<< "arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg11: KBatch" << std::endl
|
||||
<< "arg12: warmup runs pre-timing" << std::endl
|
||||
<< "arg13: repeat run count for timing" << std::endl;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -245,6 +248,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
if(K % ck::packed_size_v<ADataType> != 0 || K % ck::packed_size_v<BDataType> != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of packed size.");
|
||||
};
|
||||
|
||||
// Hardcode scale layouts as per pipeline assumptions
|
||||
// TODO: Allow user to specify scale layouts
|
||||
using AScaleLayout = Row;
|
||||
@@ -292,12 +300,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
auto a_data_element = [](float x) {
|
||||
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
|
||||
return ck::type_convert<ADataType>(ck::float2_t(x));
|
||||
else if constexpr(ck::packed_size_v<ADataType> == 32)
|
||||
return ck::type_convert<ADataType>(ck::float32_t(x));
|
||||
else if constexpr(ck::packed_size_v<ADataType> == 16)
|
||||
return ck::type_convert<ADataType>(ck::float16_t(x));
|
||||
else
|
||||
return ck::type_convert<ADataType>(x);
|
||||
};
|
||||
auto b_data_element = [](float x) {
|
||||
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
|
||||
return ck::type_convert<BDataType>(ck::float2_t(x));
|
||||
else if constexpr(ck::packed_size_v<BDataType> == 32)
|
||||
return ck::type_convert<BDataType>(ck::float32_t(x));
|
||||
else if constexpr(ck::packed_size_v<BDataType> == 16)
|
||||
return ck::type_convert<BDataType>(ck::float16_t(x));
|
||||
else
|
||||
return ck::type_convert<BDataType>(x);
|
||||
};
|
||||
@@ -307,30 +323,35 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: // Initializations for development and debugging
|
||||
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
|
||||
|
||||
ck::utils::FillConstant<ADataType>{a_data_element(0.5f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
|
||||
|
||||
ck::utils::FillConstant<BDataType>{b_data_element(2.0f)}(*b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
|
||||
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Init A = {1}" << std::endl;
|
||||
std::cout << "Init A = {0.5}" << std::endl;
|
||||
std::cout << "Init A scale = {2.0}" << std::endl;
|
||||
std::cout << "Init B = {0.5}" << std::endl;
|
||||
std::cout << "Init B scale = {1.0}" << std::endl;
|
||||
std::cout << "Init B = {2.0}" << std::endl;
|
||||
std::cout << "Init B scale = {0.5}" << std::endl;
|
||||
std::cout << "Expect C = {K}" << std::endl;
|
||||
}
|
||||
break;
|
||||
|
||||
case 1:
|
||||
a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
|
||||
b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
|
||||
a_m_k.GenerateTensorDistr(
|
||||
int_distr{-5, 5}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-5,5]
|
||||
b_k_n->GenerateTensorDistr(int_distr{-5, 5}); // Z[-5,5]
|
||||
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
|
||||
a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
|
||||
break;
|
||||
|
||||
case 2:
|
||||
a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0});
|
||||
a_m_k.GenerateTensorDistr(
|
||||
float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2]
|
||||
a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
|
||||
|
||||
b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0});
|
||||
|
||||
99
example/67_gemm_microscaling/gemm_mx_fp6.cpp
Normal file
99
example/67_gemm_microscaling/gemm_mx_fp6.cpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f6x16_pk_t;
|
||||
using BDataType = ck::f6x16_pk_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / ck::packed_size_v<ADataType>; // K dimension size per block
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Number of threads per block
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
1, // AK1 number of elements to read at a time when transferring from global memory to LDS
|
||||
1, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
548
example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp
Normal file
548
example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp
Normal file
@@ -0,0 +1,548 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t NPerBlock = 64;
|
||||
static constexpr ck::index_t BlockSize = 256;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, BlockSize,
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
|
||||
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
|
||||
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_k_n_host_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_k_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
|
||||
e_t_k_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{0.5f});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.5f});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k(token_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A/B scale shuffle
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
|
||||
b_scale_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop =
|
||||
// FMA * tokens * N * (Gate+Up) * topk * K +
|
||||
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
|
||||
std::size_t(2) * tokens * N * 2 * topk * K +
|
||||
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
|
||||
sizeof(B0DataType) / 2 * K * N * 2 * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
|
||||
sizeof(EDataType) * tokens * topk * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
float, // CShuffleDataType,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ActOP,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
a1_t_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < valid_size; ++m)
|
||||
{
|
||||
const int fuse_t = sorted_token_ids.mData[m];
|
||||
const int t = fuse_t & 0xffffff;
|
||||
const int topk_id = (fuse_t & 0xff000000) >> 24;
|
||||
|
||||
if(t >= tokens)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_k_n_host_result(t, topk_id, n) =
|
||||
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,574 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
using I64 = int64_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// B preshuffle
|
||||
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
I64 tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K_pk; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 64, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
|
||||
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
// B preshuffle
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_k_n_host_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_k_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
|
||||
e_t_k_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k(token_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A/B scale shuffle
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
|
||||
b_scale_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
preShuffleBuffer(b0_e_n_k.mData.data(),
|
||||
b0_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K,
|
||||
device_op.GetPreShuffleParameters());
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop =
|
||||
// FMA * tokens * N * (Gate+Up) * topk * K +
|
||||
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
|
||||
std::size_t(2) * tokens * N * 2 * topk * K +
|
||||
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
|
||||
sizeof(B0DataType) / 2 * K * N * 2 * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
|
||||
sizeof(EDataType) * tokens * topk * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
float, // CShuffleDataType,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ActOP,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
a1_t_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < valid_size; ++m)
|
||||
{
|
||||
const int fuse_t = sorted_token_ids.mData[m];
|
||||
const int t = fuse_t & 0xffffff;
|
||||
const int topk_id = (fuse_t & 0xff000000) >> 24;
|
||||
|
||||
if(t >= tokens)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_k_n_host_result(t, topk_id, n) =
|
||||
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
542
example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp
Normal file
542
example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp
Normal file
@@ -0,0 +1,542 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
int eids[sorted_tile_num]{};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
if(i < valid_tile_num)
|
||||
{
|
||||
eids[i] = (i * experts) / valid_tile_num;
|
||||
}
|
||||
else
|
||||
{
|
||||
eids[i] = 3;
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<XDataType> a1_t_k_k(
|
||||
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
|
||||
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 8:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
|
||||
// d2_e_n.savetxt("weight.txt", "int");
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
|
||||
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
// FMA * tokens * N * topk * K +
|
||||
// FMA * tokens * N * topk * (K/BlockScale)
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
|
||||
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
D2DataType,
|
||||
float, // using float for Cshuffle type
|
||||
// in reference
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp,
|
||||
MulRoutedWeight,
|
||||
float,
|
||||
float>;
|
||||
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k_k,
|
||||
a1_t_k_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n, // topk weights
|
||||
c_t_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -158,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,584 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
using I64 = int64_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// B preshuffle
|
||||
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
I64 tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K_pk; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
8, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
int eids[sorted_tile_num]{};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
if(i < valid_tile_num)
|
||||
{
|
||||
eids[i] = (i * experts) / valid_tile_num;
|
||||
}
|
||||
else
|
||||
{
|
||||
eids[i] = 3;
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<XDataType> a1_t_k_k(
|
||||
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
|
||||
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
// B preshuffle
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 8:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A, B Scale preshuffle
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
|
||||
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
preShuffleBuffer(b0_e_n_k.mData.data(),
|
||||
b0_preshuffled.mData.data(),
|
||||
N * experts,
|
||||
K,
|
||||
device_op.GetPreShuffleParameters());
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
// FMA * tokens * N * topk * K +
|
||||
// FMA * tokens * N * topk * (K/BlockScale)
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
|
||||
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
D2DataType,
|
||||
float, // using float for Cshuffle type
|
||||
// in reference
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp,
|
||||
MulRoutedWeight,
|
||||
float,
|
||||
float>;
|
||||
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k_k,
|
||||
a1_t_k_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n, // topk weights
|
||||
c_t_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -128,6 +128,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
|
||||
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# validate user-specified fmha_fwd API list
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;fwd_decode")
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;fwd_decode")
|
||||
set(FMHA_FWD_ENABLE_APIS "fwd_decode" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
# set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
# "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
@@ -17,14 +19,30 @@ endforeach()
|
||||
# list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
|
||||
# endif()
|
||||
|
||||
# Filtering kernel
|
||||
# set(KERNEL fmha_fwd_decode_d64_bf16_batch_b16x64x32x64x32x64_r1x4x1_r1x4x1_w16x16x32_w16x16x32_decode_qr_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
|
||||
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
|
||||
)
|
||||
# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
|
||||
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
|
||||
|
||||
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS}
|
||||
# --filter fmha_fwd...
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd
|
||||
--receipt 3
|
||||
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
|
||||
)
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt #--filter ${KERNEL}
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
@@ -32,8 +50,8 @@ if(ret AND NOT ret EQUAL 0)
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
@@ -47,14 +65,16 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} #--filter ${KERNEL}
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
@@ -76,7 +96,7 @@ target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
# NOTE: this is dangerous since will change the whole kernel to flush denormals
|
||||
# WIP with compiler team for an exp2 intrinsic..., then remove this
|
||||
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
|
||||
set(FMHA_FWD_FAST_EXP2 true)
|
||||
set(FMHA_FWD_FAST_EXP2 true)
|
||||
endif()
|
||||
|
||||
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
|
||||
@@ -85,9 +105,9 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
# ... because they are auto-generated
|
||||
if(FMHA_FWD_FAST_EXP2)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
# list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
|
||||
@@ -113,6 +133,13 @@ else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_DECODE_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally enable call to the pagedkv_prefill API in fmha_fwd example
|
||||
if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally specify the use of OCP_FP8
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
|
||||
@@ -71,6 +71,7 @@ args:
|
||||
-drop_seed seed for random number generator (default:1)
|
||||
-drop_offset offset for random number generator (default:0)
|
||||
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
|
||||
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
```
|
||||
|
||||
@@ -123,6 +123,7 @@ PIPELINE_ENUM_MAP = {
|
||||
# "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
# "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
# "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
# "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"decode_qr" : "ck_tile::BlockFmhaPipelineEnum::DECODE_QRKSVS",
|
||||
}
|
||||
|
||||
|
||||
@@ -150,14 +150,14 @@ unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seq
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
auto get_num_blocks = [&](unsigned kM0) {{
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
@@ -490,7 +490,7 @@ class KernelComponentFactory:
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'128' : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -516,13 +516,11 @@ class KernelComponentFactory:
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'128' : [FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),]
|
||||
}
|
||||
else:
|
||||
return None
|
||||
if 128 in result.keys():
|
||||
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
|
||||
return result
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
@@ -536,9 +534,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tiles = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
|
||||
@@ -169,7 +169,7 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
@@ -527,6 +527,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= bias in ['no', 'bias']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= mode == 'batch'
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -114,8 +115,52 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
@@ -131,37 +176,51 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return 'true'
|
||||
else:
|
||||
return f'{self.bool_expr}'
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f'({str(self)}) && ({str(other)})')
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
constraint : CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -218,18 +277,19 @@ class FmhaFwdApiTrait:
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -303,6 +363,7 @@ class FmhaFwdApiPool:
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
@@ -317,25 +378,27 @@ class FmhaFwdApiPool:
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
@@ -429,35 +492,38 @@ class FmhaFwdKernel:
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip)
|
||||
skip=self.F_pipeline.F_skip,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
(32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
(64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
class KernelComponentFactory:
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
### (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
### (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
### (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]:
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
@@ -502,16 +568,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
|
||||
return result
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for pipeline in get_pipelines(dtype, hdim, hdim_v):
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
@@ -551,7 +629,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
cond &= pipeline.F_logits == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
|
||||
@@ -332,6 +332,12 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16, bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
# 160: 160,
|
||||
256: 256
|
||||
}
|
||||
|
||||
@@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
@@ -751,6 +754,15 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16, bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= mode == 'batch'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
|
||||
585
example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
Normal file
585
example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
Normal file
@@ -0,0 +1,585 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
256: 256
|
||||
}
|
||||
|
||||
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
|
||||
"qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse}, //lse
|
||||
{F_pagedkv}, //pagedkv
|
||||
{F_squant},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_pagedkv_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
pagedkv : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_pagedkv : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_skip == 't' : n += '_skip'
|
||||
else: n += '_nskip'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
if self.F_pagedkv == 't' : n += '_pagedkv'
|
||||
else: n += '_npagedkv'
|
||||
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# TODO
|
||||
None
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
# if pipeline.F_pagedkv == 'f':
|
||||
# continue
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' :
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -355,7 +355,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
auto slopes = ck_tile::get_alibi_slopes<AccDataType>(nhead);
|
||||
assert(slopes.size() == nhead);
|
||||
assert(slopes.size() == static_cast<decltype(slopes.size())>(nhead));
|
||||
if(bias.rank_info == 0)
|
||||
{
|
||||
// alibi in 1*h
|
||||
@@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_hp_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); });
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
else
|
||||
{
|
||||
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
@@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
auto idx_gmo = idx_gmn;
|
||||
idx_gmo[2] = o;
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
|
||||
}
|
||||
self(idx_gmn) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
|
||||
});
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
[&](auto i0, auto i1, auto i2) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
|
||||
},
|
||||
ds_hp_host_ref.mDesc.get_lengths()[0],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
|
||||
});
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
}
|
||||
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
|
||||
89
example/ck_tile/01_fmha/fmha_fwd.cpp
Normal file → Executable file
89
example/ck_tile/01_fmha/fmha_fwd.cpp
Normal file → Executable file
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -178,50 +178,30 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
|
||||
}
|
||||
}
|
||||
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
|
||||
{
|
||||
// If we have enough to almost fill the SMs, then just use 1 split
|
||||
if(batch_nhead_mblocks >= 0.8f * num_SMs)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
|
||||
max_splits = std::min({max_splits, num_SMs});
|
||||
float max_efficiency = 0.f;
|
||||
std::vector<float> efficiency;
|
||||
efficiency.reserve(max_splits);
|
||||
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
|
||||
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
|
||||
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
|
||||
// (i.e. it's 11 splits anyway).
|
||||
// So we check if the number of blocks per split is the same as the previous num_splits.
|
||||
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
|
||||
return num_splits == 1 ||
|
||||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
|
||||
};
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
efficiency.push_back(0.f);
|
||||
}
|
||||
else
|
||||
{
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
}
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
|
||||
{
|
||||
// printf("num_splits chosen = %d\n", num_splits);
|
||||
@@ -234,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks,
|
||||
int override_num_splits_if_necessary(
|
||||
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
||||
{
|
||||
(void)hdim_v;
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
@@ -250,15 +231,13 @@ int override_num_splits_if_necessary(
|
||||
|
||||
// tile size should match the generate.py
|
||||
const int kM0 = 64;
|
||||
const int kN1 = hdim_v;
|
||||
|
||||
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
|
||||
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
if(num_splits < 1 && p_drop == 0.0f)
|
||||
{
|
||||
return num_splits_heuristic(
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
|
||||
}
|
||||
|
||||
return num_splits;
|
||||
@@ -344,7 +323,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \
|
||||
CK_TILE_FMHA_FWD_PAGEDKV_API))
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
|
||||
@@ -360,7 +340,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API)
|
||||
if(use_cache_batch_idx)
|
||||
{
|
||||
std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
|
||||
@@ -542,8 +522,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * real_seqlen_q * real_seqlen_k * hdim_q +
|
||||
static_cast<std::size_t>(2) * real_seqlen_q * hdim_v * real_seqlen_k);
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
|
||||
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
|
||||
sizeof(KDataType) * real_seqlen_k * hdim_q +
|
||||
@@ -568,7 +548,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
|
||||
return false;
|
||||
}
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(0 < p_drop && (1 < num_splits || use_kvcache))
|
||||
{
|
||||
std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
|
||||
@@ -832,7 +812,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< (is_rotary_interleaved ? "inter" : "half") << ")";
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_DECODE_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || CK_TILE_FMHA_FWD_DECODE_API
|
||||
if(1 < num_splits)
|
||||
{
|
||||
std::cout << ", num_splits:" << num_splits;
|
||||
@@ -873,6 +853,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
traits.has_dropout = (p_drop > 0.0f);
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_traits,
|
||||
std::decay_t<decltype(traits)>>)
|
||||
{
|
||||
traits.use_pagedkv = use_kvcache;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1081,6 +1066,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.split_stride_lse_acc = split_stride_lse_acc;
|
||||
args.split_stride_o_acc = split_stride_o_acc;
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.block_table_ptr =
|
||||
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
args.page_block_size = page_block_size;
|
||||
args.is_gappy = false; // use 'false' for flash-attention integration
|
||||
|
||||
args.cache_batch_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1102,7 +1098,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
const float fwd_ave_time = [&] {
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(1 < num_splits || use_kvcache)
|
||||
if(1 < num_splits && use_kvcache)
|
||||
{
|
||||
fmha_fwd_splitkv_traits fmha_splitkv_traits;
|
||||
init_traits(fmha_splitkv_traits);
|
||||
@@ -1120,6 +1116,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
init_args(fmha_decode_args);
|
||||
|
||||
return fmha_fwd_decode(fmha_decode_traits, fmha_decode_args, stream_config);
|
||||
#elif CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(use_kvcache)
|
||||
{
|
||||
fmha_fwd_pagedkv_traits fmha_pagedkv_traits;
|
||||
init_traits(fmha_pagedkv_traits);
|
||||
|
||||
fmha_fwd_pagedkv_args fmha_pagedkv_args;
|
||||
init_args(fmha_pagedkv_args);
|
||||
|
||||
return fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config);
|
||||
}
|
||||
#else
|
||||
fmha_fwd_traits fmha_traits;
|
||||
init_traits(fmha_traits);
|
||||
@@ -1276,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(0 < page_block_size) {
|
||||
if(i_perm) {
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
@@ -1327,7 +1334,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
});
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(0 < page_block_size) {
|
||||
if(is_v_rowmajor) {
|
||||
if(i_perm) {
|
||||
|
||||
@@ -178,6 +178,86 @@ struct fmha_fwd_args
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
struct fmha_fwd_pagedkv_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
|
||||
// nullptr.
|
||||
|
||||
const void* cache_batch_idx;
|
||||
|
||||
// the real seqlen_q & seqlen_k are decided by following:
|
||||
// batch mode: seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k
|
||||
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
//
|
||||
// batch mode (kvcache):
|
||||
// seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// group mode (kvcache):
|
||||
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
//
|
||||
// when is_gappy=true:
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// seqstart_k_ptr[b] now store local offset of each batch
|
||||
//
|
||||
// when is_gappy=false:
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
float logits_soft_cap;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
};
|
||||
|
||||
struct fmha_fwd_splitkv_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
@@ -590,6 +670,114 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.is_gappy,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.cache_batch_idx,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
}
|
||||
}();
|
||||
|
||||
// FmhaKernel::PrintParameters(kargs, args.batch);
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
dim3 grids = FmhaKernel::GridSize(
|
||||
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 grids =
|
||||
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
{
|
||||
@@ -966,102 +1154,102 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1121,6 +1309,57 @@ struct fmha_fwd_traits_
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kIsPagedKV_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
struct fmha_fwd_pagedkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args);
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
@@ -1333,6 +1572,27 @@ struct fmha_fwd_traits
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_fwd_pagedkv_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
bool has_logits_soft_cap;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse = false;
|
||||
bool use_pagedkv = true;
|
||||
bool do_fp8_static_quant = false;
|
||||
bool skip_min_seqlen_q = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits&,
|
||||
fmha_fwd_pagedkv_args&,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_fwd_splitkv_traits
|
||||
{
|
||||
int hdim_q;
|
||||
|
||||
21
example/ck_tile/01_fmha/mask.hpp
Normal file → Executable file
21
example/ck_tile/01_fmha/mask.hpp
Normal file → Executable file
@@ -21,6 +21,8 @@ enum class mask_enum
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
|
||||
@@ -42,6 +44,8 @@ struct mask_info
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
tmp.seqlen_q = seqlen_q;
|
||||
tmp.seqlen_k = seqlen_k;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
@@ -148,7 +152,22 @@ struct mask_info
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
ck_tile::index_t get_unmaskarea() const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
return seqlen_q * seqlen_k;
|
||||
ck_tile::index_t area = 0;
|
||||
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
||||
{
|
||||
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
||||
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
||||
if(x_end > x_start)
|
||||
{
|
||||
area += (x_end - x_start);
|
||||
}
|
||||
}
|
||||
return area;
|
||||
}
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
|
||||
@@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
|
||||
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
|
||||
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
|
||||
@@ -30,7 +30,7 @@ args:
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
|
||||
@@ -212,6 +212,11 @@ int run_gemm_example(int argc, char* argv[])
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::int8_t, ck_tile::int8_t, int32_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
@@ -15,6 +14,40 @@
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE 5
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
@@ -34,6 +67,7 @@ struct GemmConfigBase
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -90,7 +124,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
@@ -109,7 +143,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
@@ -128,7 +162,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 32 : 128;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
@@ -151,7 +185,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
@@ -170,7 +204,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
@@ -189,13 +223,57 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -245,6 +323,15 @@ struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using BDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CDataType = int32_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
@@ -260,6 +347,12 @@ struct DataTypeTraits<double>
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
@@ -290,6 +383,12 @@ struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
@@ -329,6 +428,16 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline =
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
|
||||
294
example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Normal file
294
example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Normal file
@@ -0,0 +1,294 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshufle_1>(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
// Return a non-zero code to indicate failure
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
@@ -251,6 +251,22 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
@@ -284,6 +300,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool persistent = arg_parser.get_int("persistent");
|
||||
|
||||
const bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
@@ -297,8 +315,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
@@ -316,7 +334,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(GemmConfig::UseStructuredSparsity)
|
||||
if(!preshuffle && GemmConfig::UseStructuredSparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
@@ -326,33 +344,43 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
static_assert(!GemmConfig::PermuteA, "Not implemented");
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<GemmConfig,
|
||||
decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(b_k_n_dev);
|
||||
}
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
|
||||
// shuffled buffer B for device implementation
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
std::cout << "Permute for this DataType is not implemented." << std::endl;
|
||||
return false;
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<GemmConfig,
|
||||
decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(b_k_n_dev);
|
||||
}
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
std::cout << "Permute for this DataType is not implemented." << std::endl;
|
||||
return false;
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
@@ -415,29 +443,23 @@ int run_gemm_example_with_layouts(int argc,
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
// memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
// memory on device to store gpu reference result
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(
|
||||
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_A,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a_m_k.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
@@ -447,16 +469,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
|
||||
d_C,
|
||||
c_m_n_dev_result.get_element_space_size_in_bytes(),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
|
||||
@@ -59,7 +59,8 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
@@ -71,7 +72,6 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
@@ -92,6 +92,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -101,7 +102,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
@@ -112,6 +113,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -135,7 +137,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
@@ -214,8 +216,21 @@ template <typename GemmConfig,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
|
||||
}
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
@@ -299,6 +314,13 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
|
||||
@@ -35,7 +35,7 @@ struct Reduce2dShape
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
WarpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
|
||||
@@ -15,13 +15,14 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "0", "cold iter")
|
||||
.insert("repeat", "1", "hot iter");
|
||||
.insert("repeat", "1", "hot iter")
|
||||
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, int USEModelSensitive>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
@@ -81,8 +82,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
false, // kSaveInvRms
|
||||
false, // kSaveUnquant
|
||||
kTwoPass,
|
||||
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
|
||||
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant
|
||||
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
|
||||
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP, // fuse quant
|
||||
static_cast<ck_tile::Rmsnorm2dSensitiveEnum>(
|
||||
USEModelSensitive)>;
|
||||
|
||||
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
|
||||
GammaDataType,
|
||||
@@ -97,7 +100,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>;
|
||||
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass<Problem>;
|
||||
|
||||
using Pipeline =
|
||||
std::conditional_t<(PipelineTraits::kUseModelSensitiveRMSNorm ==
|
||||
ck_tile::Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL ||
|
||||
PipelineTraits::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
|
||||
std::conditional_t<PipelineTraits::kTwoPass,
|
||||
TwoPassPipeline,
|
||||
OnePassPipeline>, // kUseModelSensitiveRMSNorm
|
||||
// == 0
|
||||
T5PassPipeline>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::
|
||||
Default2DEpilogueProblem<ComputeDataType, YDataType, false, PipelineTraits::kPadN, false>;
|
||||
@@ -172,7 +185,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
<< ", s:" << USEModelSensitive << ", valid:" << (pass ? "y" : "n") << std::flush
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
@@ -184,10 +198,19 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
if(use_model_sensitive_rmsnorm == 0) // 0: for no specific RMSNorm
|
||||
{
|
||||
return run<ck_tile::half_t, 0>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(use_model_sensitive_rmsnorm == 1) // 1: for T5-like RMSNorm
|
||||
{
|
||||
return run<ck_tile::half_t, 1>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -65,7 +65,8 @@ template <typename XDataType_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
ck_tile::index_t kFusedQuant_ = 0,
|
||||
ck_tile::index_t kUseModelSensitiveRMSNorm_ = 0>
|
||||
struct rmsnorm2d_fwd_traits_
|
||||
{
|
||||
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
||||
@@ -74,22 +75,22 @@ struct rmsnorm2d_fwd_traits_
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -97,13 +98,13 @@ struct rmsnorm2d_fwd_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -127,8 +128,9 @@ struct rmsnorm2d_fwd_traits_
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
static constexpr ck_tile::index_t kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
@@ -146,7 +148,8 @@ template <typename XDataType_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
int kFusedQuant_,
|
||||
int kUseModelSensitiveRMSNorm_>
|
||||
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
SmoothScaleDataType_,
|
||||
@@ -162,7 +165,8 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
kSaveUnquant_,
|
||||
kTwoPass_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
kFusedQuant_,
|
||||
kUseModelSensitiveRMSNorm_>;
|
||||
"""
|
||||
|
||||
API_COMMON_HEADER = """
|
||||
@@ -197,7 +201,8 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
Traits_::kSaveUnquant,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant),
|
||||
static_cast<ck_tile::Rmsnorm2dSensitiveEnum>(Traits_::kUseModelSensitiveRMSNorm)>;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
|
||||
@@ -213,7 +218,13 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass<PipelineProblem>;
|
||||
|
||||
using Pipeline = std::conditional_t<
|
||||
(Traits_::kUseModelSensitiveRMSNorm == 0 || Traits_::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
|
||||
std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>, // kUseModelSensitiveRMSNorm == 0
|
||||
T5PassPipeline
|
||||
>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
|
||||
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
||||
@@ -387,12 +398,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
F_use_model_sensitive_rmsnorm : int
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}'
|
||||
return t_
|
||||
|
||||
# string when calling this kernel
|
||||
@@ -413,6 +425,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
F_saveunquant : bool
|
||||
F_use_model_sensitive_rmsnorm : int
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
|
||||
@property
|
||||
@@ -426,6 +439,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
||||
if self.F_saveunquant:
|
||||
nnn = nnn + '_saveunquant'
|
||||
if self.F_use_model_sensitive_rmsnorm == 0:
|
||||
nnn = nnn + '_nsm'
|
||||
elif self.F_use_model_sensitive_rmsnorm == 1:
|
||||
nnn = nnn + '_t5ml'
|
||||
return nnn
|
||||
|
||||
@property
|
||||
@@ -481,9 +498,9 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
f_sweep_cond = _sweep_cond, f_use_model_sensitive_rmsnorm = ins.F_use_model_sensitive_rmsnorm)
|
||||
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
||||
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
||||
#inner_str = inner_str + vec_str
|
||||
@@ -516,85 +533,149 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
bool_list = [False, True]
|
||||
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]}
|
||||
h_trait_dicts = {
|
||||
0: {
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep srm
|
||||
'64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 0)],
|
||||
'768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 0)]
|
||||
},
|
||||
1: {
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep srm
|
||||
'64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 32, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 1)],
|
||||
'768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
|
||||
}
|
||||
}
|
||||
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
if (fused_quant == 0 and save_unquant == True):
|
||||
continue # save_unquant should always be false when there is no quant enabled
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_UnquantYDataType = prec_i
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
h_.F_kSaveUnquant = save_unquant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs))
|
||||
|
||||
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
|
||||
current_trait_dict = h_trait_dicts[model_sensitive_flag]
|
||||
for hs_key in current_trait_dict:
|
||||
hs = current_trait_dict[hs_key]
|
||||
current_n = hs_key
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
if (fused_quant == 0 and save_unquant == True):
|
||||
continue # save_unquant should always be false when there is no quant enabled
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_UnquantYDataType = prec_i
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
h_.F_kSaveUnquant = save_unquant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, h_.F_use_model_sensitive_rmsnorm, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self) -> None:
|
||||
@@ -712,4 +793,4 @@ if __name__ == "__main__":
|
||||
if args.list_blobs:
|
||||
list_blobs(args)
|
||||
else:
|
||||
gen_blobs(args)
|
||||
gen_blobs(args)
|
||||
|
||||
@@ -52,7 +52,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -66,15 +67,16 @@ template <typename InDataType,
|
||||
bool SaveUnquant>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
|
||||
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
@@ -194,10 +196,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
|
||||
|
||||
rmsnorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant};
|
||||
rmsnorm2d_fwd_traits traits{prec_i,
|
||||
prec_o,
|
||||
prec_sm,
|
||||
prec_sy,
|
||||
SaveRms,
|
||||
SaveUnquant,
|
||||
fused_add,
|
||||
fused_quant,
|
||||
use_model_sensitive_rmsnorm};
|
||||
|
||||
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
|
||||
@@ -64,6 +64,8 @@ struct rmsnorm2d_fwd_traits
|
||||
bool save_unquant;
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
|
||||
int use_model_sensitive_rmsnorm = 0; // 0: Use default RMSNorm; 1: Use T5-like implementation
|
||||
};
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -1,37 +1,74 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
# 0: for no specific RMSNorm
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
|
||||
# 1: for T5-like RMSNorm
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
@@ -5,29 +5,32 @@ for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -p
|
||||
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -36,8 +39,11 @@ done
|
||||
for fquant in ""
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
@@ -49,22 +49,22 @@ struct smoothquant_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -72,13 +72,13 @@ struct smoothquant_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -35,7 +35,20 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
.insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf")
|
||||
.insert(
|
||||
"moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...")
|
||||
#else
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
#endif
|
||||
.insert("ci",
|
||||
"1",
|
||||
"clear workspace inside API or not(if \"0\", require manually clear outside)")
|
||||
.insert(
|
||||
"dispatch",
|
||||
"0",
|
||||
"dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel")
|
||||
.insert("local_eid",
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
@@ -88,10 +101,17 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int unit_size = args.get_int("unit");
|
||||
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim");
|
||||
int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes");
|
||||
#else
|
||||
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
|
||||
#endif
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
bool clear_inside = args.get_int("ci") != 0;
|
||||
int dispatch_policy = args.get_int("dispatch");
|
||||
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
@@ -149,11 +169,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
|
||||
// for simplicity, below buffer allocate 2 dword
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::HostTensor<int8_t> moe_buf_host(
|
||||
{static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
|
||||
moe_buf_elem_bytes});
|
||||
auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast<std::size_t>(0)
|
||||
: moe_buf_host.get_element_space_size_in_bytes();
|
||||
#else
|
||||
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
|
||||
auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
|
||||
: moe_buf_host.get_element_space_size_in_bytes();
|
||||
#endif
|
||||
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
|
||||
#else
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
|
||||
#endif
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
|
||||
|
||||
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
|
||||
@@ -176,7 +211,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
@@ -184,29 +219,31 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
ck_tile::index_t workspace_size =
|
||||
moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
if(workspace_size != 0 && clear_inside == false)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
|
||||
moe_sorting_trait trait{
|
||||
index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))};
|
||||
moe_sorting_args karg
|
||||
{
|
||||
topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size,
|
||||
num_experts, topk,
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_interm_dim, moe_buf_elem_bytes
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
|
||||
#endif
|
||||
};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
@@ -219,7 +256,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
#if 0
|
||||
{
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
moe_sorting_ws.FromDevice(ws_host.data());
|
||||
|
||||
int * p_mesh = reinterpret_cast<int*>(ws_host.data());
|
||||
@@ -268,7 +305,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
|
||||
printf("[%s|%s|%s|%d]tokens:%d",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
|
||||
dispatch_policy,
|
||||
tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
printf("(%d)", local_tokens);
|
||||
@@ -280,6 +322,19 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
|
||||
}
|
||||
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
printf("moe_buf:%lu(%d,%d), ",
|
||||
static_cast<uint64_t>(moe_buf_bytes),
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes);
|
||||
#else
|
||||
|
||||
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
|
||||
#endif
|
||||
}
|
||||
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
else
|
||||
@@ -294,7 +349,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_weights_dev.FromDevice(sorted_weights_host.data());
|
||||
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
|
||||
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
moe_buf_dev.FromDevice(moe_buf_host.data());
|
||||
}
|
||||
@@ -340,6 +395,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
// if(is_local_token)
|
||||
{
|
||||
auto t_ = is_local_token ? local_tokens : tokens;
|
||||
bool _f = t_ == sorted_id_cnt_host.mData[1];
|
||||
rtn &= _f;
|
||||
if(!_f)
|
||||
{
|
||||
printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -347,9 +412,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
rtn = false;
|
||||
}
|
||||
|
||||
if(moe_buf_size)
|
||||
if(moe_buf_bytes)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
|
||||
#else
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
#endif
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0)
|
||||
{
|
||||
return moe_sorting_mp(t, a, s);
|
||||
}
|
||||
@@ -293,6 +293,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
@@ -302,6 +303,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
@@ -314,6 +316,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
@@ -323,6 +326,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
@@ -330,6 +334,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
|
||||
[&]() { \
|
||||
using problem_ = \
|
||||
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
|
||||
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
@@ -338,6 +353,22 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
|
||||
if(t.clear_workspace_inside_api)
|
||||
{
|
||||
if(is_local_token)
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
@@ -345,6 +376,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
@@ -354,6 +386,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
@@ -405,7 +438,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
return -1;
|
||||
}
|
||||
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
|
||||
}
|
||||
|
||||
@@ -10,8 +10,14 @@
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
bool clear_workspace_inside_api; // if true, no need clear workspace outsize (will take care of
|
||||
// it inside API)
|
||||
int dispatch_policy; // 0 - let the API choose kernel for you. 1 - always use single kerenl. 2 -
|
||||
// always use mp kernel NOTE: moe_sorting_get_workspace_size() need use
|
||||
// same dispatch_policy value. it will be undefined behavior if ppl using
|
||||
// different value when get ws and call the kernel
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
@@ -22,6 +28,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
// if return non zero, means need workspace, you need to allocate a GPU buffer
|
||||
// and set to moe_sorting_args.p_ws
|
||||
// NOTE: workspace size are required to clear zero before use the API
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy);
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# #!/bin/sh
|
||||
|
||||
EXE=./build/bin/tile_example_moe_sorting
|
||||
MOE_BUF="12"
|
||||
|
||||
if [ "x$MOE_BUF" = "x1" ] ; then
|
||||
$EXE -t=80 -e=17 -moe_buf_size=16
|
||||
$EXE -t=111 -e=117 -moe_buf_size=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_size=1024
|
||||
@@ -42,3 +44,46 @@ $EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940
|
||||
else
|
||||
$EXE -t=80 -e=17 -moe_buf_interm_dim=16 -moe_buf_elem_bytes=4
|
||||
$EXE -t=111 -e=117 -moe_buf_interm_dim=4 -moe_buf_elem_bytes=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_interm_dim=1024 -moe_buf_elem_bytes=1
|
||||
$EXE -t=99 -e=120 -moe_buf_interm_dim=10244 -moe_buf_elem_bytes=2
|
||||
$EXE -t=175 -e=64 -k=8
|
||||
$EXE -t=65 -e=8 -k=2
|
||||
$EXE -t=1 -e=25
|
||||
$EXE -t=31 -e=19 -k=15
|
||||
$EXE -t=81 -e=37 -k=7
|
||||
$EXE -t=23 -e=1 -k=1
|
||||
$EXE -t=127 -e=99 -k=19
|
||||
$EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
$EXE -t=11 -e=256 -k=5
|
||||
$EXE -t=64 -e=455 -k=8
|
||||
$EXE -t=777 -e=802 -k=99
|
||||
$EXE -t=4097 -e=906 -k=51
|
||||
$EXE -t=128 -e=32 -k=5 -local_t=6 -moe_buf_interm_dim=262144
|
||||
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
|
||||
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
|
||||
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
|
||||
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
|
||||
$EXE -t=128 -e=128 -k=6 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
|
||||
$EXE -t=8192 -e=32 -k=5 -local_t=11 -moe_buf_interm_dim=163840
|
||||
$EXE -t=8192 -e=32 -k=8 -local_t=12 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
|
||||
$EXE -t=8192 -e=256 -k=5 -local_t=13 -moe_buf_interm_dim=163840
|
||||
$EXE -t=8192 -e=256 -k=8 -local_t=8 -moe_buf_interm_dim=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -local_t=4 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=4
|
||||
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
|
||||
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
|
||||
$EXE -t=99 -local_t=93 -e=121 -local_t=4 -moe_buf_interm_dim=10244
|
||||
$EXE -t=536 -local_t=345 -e=802 -k=99
|
||||
$EXE -t=331 -local_t=39 -e=83 -k=33
|
||||
$EXE -t=765 -local_t=654 -e=783 -k=8
|
||||
$EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -local_t=2 -moe_buf_interm_dim=133940 -moe_buf_elem_bytes=1
|
||||
|
||||
fi
|
||||
|
||||
@@ -38,22 +38,22 @@ struct moe_smoothquant_traits_
|
||||
using InputType = ck_tile::remove_cvref_t<InputType_>;
|
||||
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -61,13 +61,13 @@ struct moe_smoothquant_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
|
||||
@@ -24,23 +25,28 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
}();
|
||||
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
|
||||
auto a0 = fused_moesorting_args{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
|
||||
o_data_bytes // index_t moe_buf_bytes;
|
||||
auto a0 = fused_moesorting_args
|
||||
{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
a.stride_token, o_data_bytes,
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) *
|
||||
a.stride_token* o_data_bytes // index_t moe_buf_bytes;
|
||||
#endif
|
||||
};
|
||||
|
||||
auto t1 = fused_moegemm_traits{t.prec_i,
|
||||
|
||||
@@ -413,5 +413,6 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
@@ -399,7 +399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size =
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
@@ -1,8 +1,149 @@
|
||||
# Grouped CShuffle GEMM
|
||||
# Grouped Gemm
|
||||
|
||||
This folder contains example for Grouped GEMM using ck_tile tile-programming implementation.
|
||||
Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
## Quick Tour for New Users
|
||||
|
||||
The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads.
|
||||
|
||||
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
|
||||
|
||||
### Parsing Arguments
|
||||
The example takes three arguments: `group_count`, `repeat`, and `warmup`:
|
||||
- `group_count`: the number of GEMM operations in the group,
|
||||
- `repeat`: the number of times to repeat the kernel for benchmarking
|
||||
- `warmup`: the number of iterations before the actual kernel run time measure.
|
||||
|
||||
```cpp
|
||||
// Example
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
```
|
||||
In the next step, the input parameters `Ms`, `Ns`, `Ks`, as well as the corresponding `stride_As`, `stride_Bs`, and `stride_Cs` are either provided from the comand line or generated by default. Since one or more input data sets are expected for `A` and `B`, each parameter is stored in a `std::vector`. The size of the `vector` is defined by `group_count`.
|
||||
|
||||
```cpp
|
||||
// Example
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
```
|
||||
Where:
|
||||
- `Ms` is the M dimension of each GEMM.
|
||||
- `Ns` is the N dimension of each GEMM.
|
||||
- `Ks` is the K dimension of each GEMM.
|
||||
- `stride_As` is the stride values for matrix A.
|
||||
- `stride_Bs` is the stride values for matrix B.
|
||||
- `stride_Cs` is the stride values for matrix C.
|
||||
|
||||
### HostTensor and Device Memory Buffers (for CPU and GPU)
|
||||
Each parameter `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs` and `stride_Cs` contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations.
|
||||
The next step is to properly load the input values. For each input matrix, `A` and `B`, and for each output matrix, `C`, you need to create both `HostTensor` and `DeviceMemory`, where:
|
||||
- `HostTensor` represents the matrix data on the host (CPU). It stores the data before they are transferred to the device for computation.
|
||||
- `DeviceMemory` represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation.
|
||||
|
||||
#### HostTensor Buffers (for CPU)
|
||||
In the first step, create `HostTensor` for `A`, `B`, `C`. `HostTensor` allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors:
|
||||
```cpp
|
||||
// Example
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
```
|
||||
Where:
|
||||
- `a_m_k_tensors` is the vector of `HostTensor` objects for matrix `A` (with dimensions `M × K`). Each tensor stores the data for single GEMM operation.
|
||||
- `b_k_n_tensors` is the vector of `HostTensor` objects for matrix `B` (with dimensions `K × N`).
|
||||
- `c_m_n_tensors` is the vector of `HostTensor` objects for matrix `C` (the output matrix with dimensions `M × N`).
|
||||
|
||||
The `std::vector` container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to `group_count`.
|
||||
|
||||
#### Device Memory Buffers (for GPU)
|
||||
Now it's time to allocate memory on the device (GPU) and transfer the data from `HostTensor` to `DeviceMemory` for actual computation..
|
||||
```cpp
|
||||
// Example
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
```
|
||||
Where:
|
||||
- `a_m_k_dev_buf` is the buffer used for storing matrix A on the GPU.
|
||||
- `b_k_n_dev_buf` is the buffer used for storing matrix B on the GPU.
|
||||
- `c_m_n_dev_buf` is the buffer used for storing the result matrix C on the GPU.
|
||||
|
||||
## Prepare data
|
||||
In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., `FillUniformDistribution`), or user data can be used to populate the tensors. Descriptors also need to be create for each input tensor.
|
||||
|
||||
Use `get_default_stride` to get the strides for A, B, and C. `get_default_stride` is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. Template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four params `row`, `col`, `stride` and `bool_constant<is_row_major>`. If the stride is explicitly provided (`stride != 0`), the stride is returned as-is. If the stride is not provided (`stride == 0`), the function computes the default stride. For the Row-major order (`is_row_major == true`), the stride is set to the number of columns (col). For the column-major order (`is_row_major == false`), the stride is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays, where the user may not specify the stride explicitly. It ensures a natural default stride based on the chosen storage order.
|
||||
|
||||
```cpp
|
||||
// Example, API
|
||||
template <bool is_row_major>
|
||||
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant<is_row_major>) {
|
||||
// code
|
||||
}
|
||||
```
|
||||
|
||||
Where:
|
||||
- `is_row_major` is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false).
|
||||
- `row` is the number of rows in the matrix.
|
||||
- `col` is the number of columns in the matrix.
|
||||
- `stride` is the current stride (the distance between consecutive elements in memory).
|
||||
- `bool_constant<is_row_major>` is a tag type that helps in differentiating behavior at compile-time.
|
||||
|
||||
Next host descriptors for each of the input tensors, A, B, and C are created. Use the `f_host_tensor_descriptor` function defined below. This function takes four parameters, row, col, stride, and layout, and returns a HostTensorDescriptor based on the specified layout.
|
||||
|
||||
```cpp
|
||||
// Example for tensor A
|
||||
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)))
|
||||
```
|
||||
|
||||
After creating the host_tensors, create `deviceMem` for each tensor `A`, `B`, and `C`, and then transfer the data to the device. The `get_element_space_size_in_bytes()` function is used to get the buffer size in bytes. Use `ToDevice()` to transfer data from the host to the device. The data that was previously generated (`a_m_k_tensors[i].data()`) is passed as a parameter to `ToDevice()`.
|
||||
|
||||
The final step before running the GEMM operation is to retrieve the pointers to the buffers of `A`, `B`, and `C` stored on the device using `->GetDeviceBuffer()` and pack them into a shared container. For example: `gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]})`, where `gemm_descs` is `std::vector<grouped_gemm_kargs> gemm_descs` ([Code](https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc#L221)). The container should include values such as:
|
||||
```cpp
|
||||
struct GroupedGemmHostArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
};
|
||||
```
|
||||
The data prepared in this way can be passed to the `invoke_gemm` function. This is a templated function that also takes three template parameters: `ALayout`, `BLayout`, and `CLayout`:
|
||||
```cpp
|
||||
// Example, API
|
||||
template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
```
|
||||
`invoke_gemm` returns the run time in milliseconds. The workspace memory required for computation is allocated. Workspace memory on the GPU refers to temporary memory buffers allocated when some operations are run. This extra space is needed to hold GEMM descriptions. The following structure can be used to allocate workspace:
|
||||
|
||||
```cpp
|
||||
// Example
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(GetWorkspaceSize(args));
|
||||
```
|
||||
Finally the arguments are passed to group_gemm and the kernel is launched.
|
||||
```cpp
|
||||
// API
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
```
|
||||
All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched.
|
||||
|
||||
## Build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
@@ -16,10 +157,17 @@ This will result in an executable `build/bin/tile_example_grouped_gemm`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-a_layout Tensor A layout (default:R)
|
||||
-b_layout Tensor B layout (default:R)
|
||||
-c_layout Tensor C layout (default:R)
|
||||
-v 0. No validation, 1. Validation on CPU
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-Ms M dimensions - (Default: empty).
|
||||
-Ns N dimensions - (Default: empty).
|
||||
-Ks K dimensions - (Default: empty).
|
||||
-stride_As Tensor A strides - (Default: empty).
|
||||
-stride_Bs Tensor B strides - (Default: empty).
|
||||
-stride_Cs Tensor C strides - (Default: empty).
|
||||
-a_layout A tensor data layout - (Default: Row).
|
||||
-b_layout B tensor data layout - (Default: Col).
|
||||
-c_layout C tensor data layout - (Default: Row).
|
||||
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10).
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100).
|
||||
-group_count Group count. (Default: 16).
|
||||
```
|
||||
|
||||
@@ -116,9 +116,12 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
|
||||
@@ -3,6 +3,4 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
|
||||
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
|
||||
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -Wno-unused-local-typedef)
|
||||
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -Wno-unused-local-typedef)
|
||||
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -13,48 +13,94 @@
|
||||
#include "flatmm_basic.hpp"
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool persistent,
|
||||
typename CDEElementWise>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using FlatmmConfig = FlatmmConfig<ADataType>;
|
||||
using CodegenFlatmmShape = ck_tile::TileFlatmmShape<
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenFlatmmShape>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -64,11 +110,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
@@ -86,14 +129,15 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName()
|
||||
<< CodegenPipelineProblem::GetName() << " grid: {" << grids.x << ", "
|
||||
<< grids.y << ", " << grids.z << "}"
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time{0};
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
@@ -111,7 +155,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
@@ -122,7 +166,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
@@ -139,18 +183,28 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename FlatmmConfig>
|
||||
int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -163,24 +217,27 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -196,9 +253,29 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
return !run_flatmm_example(argc, argv);
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 2)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -12,26 +12,81 @@
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
// GEMM config with 32x132 warp tile
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
|
||||
#endif
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#else
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
template <>
|
||||
@@ -109,43 +164,6 @@ struct is_8bit_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
struct FlatmmConfig
|
||||
{
|
||||
#if defined(USING_MFMA_16x16x32)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 8;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
|
||||
#endif
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -163,8 +181,11 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value");
|
||||
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
@@ -174,7 +195,8 @@ template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -18,6 +18,10 @@ constexpr const char* DataTypeToString()
|
||||
{
|
||||
return "bf8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
return "bf16";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "unknown";
|
||||
@@ -69,13 +73,31 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool persistent,
|
||||
typename CDEElementWise>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_shuffle_dev_buf,
|
||||
ck_tile::DeviceMem& c_dev_buf,
|
||||
@@ -89,22 +111,32 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::FlatmmHostArgs args;
|
||||
args.a_ptr = a_dev_buf.GetDeviceBuffer();
|
||||
args.b_shuffle_ptr = b_shuffle_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_dev_buf.GetDeviceBuffer();
|
||||
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C};
|
||||
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time =
|
||||
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
float ave_time = flatmm_calc<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -120,7 +152,11 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_flatmm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -131,11 +167,10 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
using FlatmmConfig = FlatmmConfig<ADataType>;
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
@@ -145,10 +180,11 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
// persistent not added
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -162,8 +198,26 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
// TODO: add different init types
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_host.SetZero();
|
||||
b_origin_host.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
@@ -173,23 +227,30 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
|
||||
// do pre-shuffle
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
|
||||
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
invoke_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
bool pass = true;
|
||||
|
||||
@@ -1 +1,6 @@
|
||||
add_executable(tile_example_gemm_multi_d_fp16 EXCLUDE_FROM_ALL gemm_multi_d_fp16.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
target_compile_options(tile_example_gemm_multi_d_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
4
example/ck_tile/20_grouped_convolution/CMakeLists.txt
Normal file
4
example/ck_tile/20_grouped_convolution/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
set(EXAMPLE_CONV_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
@@ -0,0 +1,207 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
|
||||
#include "run_grouped_convolution_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_fwd_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("weight_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }
|
||||
@@ -0,0 +1,108 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
ck_tile::index_t fill_spatial_dimensions(std::vector<ck_tile::index_t>& filter_spatial_lengths,
|
||||
std::vector<ck_tile::index_t>& image_spatial_lengths,
|
||||
std::vector<ck_tile::index_t>& strides,
|
||||
std::vector<ck_tile::index_t>& dilations,
|
||||
std::vector<ck_tile::index_t>& lpads,
|
||||
std::vector<ck_tile::index_t>& rpads,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
|
||||
constexpr ck_tile::index_t non_sp_dims = 3;
|
||||
const ck_tile::index_t n_dim_sp = arg_parser.get_str("in_layout").size() - non_sp_dims;
|
||||
|
||||
if(!(n_dim_sp >= 1 && n_dim_sp <= 3))
|
||||
{
|
||||
throw std::runtime_error("Wrong layout!\n");
|
||||
}
|
||||
|
||||
if(n_dim_sp == 3)
|
||||
{
|
||||
filter_spatial_lengths.push_back(arg_parser.get_int("z"));
|
||||
image_spatial_lengths.push_back(arg_parser.get_int("d"));
|
||||
strides.push_back(arg_parser.get_int("stride_d"));
|
||||
dilations.push_back(arg_parser.get_int("dilation_d"));
|
||||
lpads.push_back(arg_parser.get_int("lpad_d"));
|
||||
rpads.push_back(arg_parser.get_int("rpad_d"));
|
||||
}
|
||||
if(n_dim_sp >= 2)
|
||||
{
|
||||
filter_spatial_lengths.push_back(arg_parser.get_int("y"));
|
||||
image_spatial_lengths.push_back(arg_parser.get_int("h"));
|
||||
strides.push_back(arg_parser.get_int("stride_h"));
|
||||
dilations.push_back(arg_parser.get_int("dilation_h"));
|
||||
lpads.push_back(arg_parser.get_int("lpad_h"));
|
||||
rpads.push_back(arg_parser.get_int("rpad_h"));
|
||||
}
|
||||
filter_spatial_lengths.push_back(arg_parser.get_int("x"));
|
||||
image_spatial_lengths.push_back(arg_parser.get_int("w"));
|
||||
strides.push_back(arg_parser.get_int("stride_w"));
|
||||
dilations.push_back(arg_parser.get_int("dilation_w"));
|
||||
lpads.push_back(arg_parser.get_int("lpad_w"));
|
||||
rpads.push_back(arg_parser.get_int("rpad_w"));
|
||||
|
||||
return n_dim_sp;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("g", "2", "group dimension")
|
||||
.insert("n", "32", "n dimension")
|
||||
.insert("k", "32", "k dimension")
|
||||
.insert("c", "32", "c dimension")
|
||||
|
||||
.insert("d", "64", "d dimension")
|
||||
.insert("h", "64", "h dimension")
|
||||
.insert("w", "64", "w dimension")
|
||||
|
||||
.insert("z", "4", "z dimension")
|
||||
.insert("y", "4", "y dimension")
|
||||
.insert("x", "4", "x dimension")
|
||||
|
||||
.insert("stride_d", "1", "d stride")
|
||||
.insert("stride_h", "1", "h stride")
|
||||
.insert("stride_w", "1", "w stride")
|
||||
|
||||
.insert("dilation_d", "1", "d dilation")
|
||||
.insert("dilation_h", "1", "h dilation")
|
||||
.insert("dilation_w", "1", "w dilation")
|
||||
|
||||
.insert("lpad_d", "0", "left pad for d dimension")
|
||||
.insert("lpad_h", "0", "left pad for h dimension")
|
||||
.insert("lpad_w", "0", "left pad for w dimension")
|
||||
|
||||
.insert("rpad_d", "0", "right pad for d dimension")
|
||||
.insert("rpad_h", "0", "right pad for h dimension")
|
||||
.insert("rpad_w", "0", "right pad for w dimension")
|
||||
|
||||
.insert("in_layout", "NHWGC", "Input image layout - NHWGC by default")
|
||||
.insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default")
|
||||
.insert("out_layout", "NHWGK", "Output image layout - NHWGK by default")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s);
|
||||
@@ -0,0 +1,206 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(InDataType) < sizeof(WeiDataType), InDataType, WeiDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<OutDataType, OutDataType, OutDataType>(kbatch);
|
||||
const auto atol_split_k =
|
||||
ck_tile::get_absolute_threshold<OutDataType, OutDataType, OutDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_fwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_grouped_conv_fwd_example_with_layouts(
|
||||
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
std::vector<ck_tile::index_t> filter_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> image_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> strides;
|
||||
std::vector<ck_tile::index_t> dilations;
|
||||
std::vector<ck_tile::index_t> lpads;
|
||||
std::vector<ck_tile::index_t> rpads;
|
||||
|
||||
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads,
|
||||
arg_parser);
|
||||
|
||||
ck_tile::conv::ConvParam conv_param{num_dim_sp,
|
||||
arg_parser.get_int("g"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("c"),
|
||||
filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(input);
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{-5.f, 5.f}(weight);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<InDataType>{}(input);
|
||||
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
|
||||
}
|
||||
else
|
||||
{
|
||||
input.SetZero();
|
||||
weight.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
|
||||
|
||||
input_dev_buf.ToDevice(input.data());
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.SetZero();
|
||||
|
||||
ck_tile::GroupedConvHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_fwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
output_dev_buf.FromDevice(output.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
|
||||
output_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_grouped_conv_fwd<NDimSpatial, InDataType, WeiDataType, OutDataType>(
|
||||
input,
|
||||
weight,
|
||||
output_host_ref,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_);
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(output,
|
||||
output_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
throw std::runtime_error("Unsupported gpu verification !!!");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
add_executable(test_copy_kernel EXCLUDE_FROM_ALL test_copy.cpp)
|
||||
target_compile_options(test_copy_kernel PRIVATE
|
||||
-mllvm -enable-noalias-to-md-conversion=0
|
||||
)
|
||||
@@ -1,117 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <cstring>
|
||||
#include "test_copy.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "64", "m dimension")
|
||||
.insert("n", "8", "n dimension")
|
||||
.insert("id", "0", "warp to use")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "50", "cold iter")
|
||||
.insert("repeat", "100", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t warp_id = arg_parser.get_int("id");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n});
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n});
|
||||
|
||||
// ck_tile::FillConstant<XDataType>{1.f}(x_host);
|
||||
ck_tile::half_t value = 1;
|
||||
for(int i = 0; i < m; i++)
|
||||
{
|
||||
value = 1;
|
||||
for(int j = 0; j < n; j++)
|
||||
{
|
||||
x_host(i, j) = value++;
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWaves = ck_tile::sequence<2, 1>;
|
||||
using BlockTile = ck_tile::sequence<64, 8>;
|
||||
using WaveTile = ck_tile::sequence<64, 8>;
|
||||
using Vector = ck_tile::sequence<1, 4>;
|
||||
|
||||
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
|
||||
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
|
||||
using Kernel = ck_tile::TileCopy<Problem>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 128;
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
std::cout << "block size " << kBlockSize << std::endl;
|
||||
std::cout << "warp SIze " << ck_tile::get_warp_size() << std::endl;
|
||||
std::cout << "warps per block _M " << Shape::WarpPerBlock_M << " " << Shape::WarpPerBlock_N
|
||||
<< std::endl;
|
||||
std::cout << "Block waves: " << BlockWaves::at(ck_tile::number<0>{}) << " "
|
||||
<< BlockWaves::at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << " Wave Groups: " << Shape::WaveGroups << std::endl;
|
||||
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
m,
|
||||
n,
|
||||
warp_id));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
y_buf.FromDevice(y_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(y_host_dev, x_host);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
9
example/ck_tile/37_transpose/CMakeLists.txt
Normal file
9
example/ck_tile/37_transpose/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
set(TARGET_NAME tile_example_transpose)
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL transpose_example.cpp transpose_api.cpp)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(tile_example_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS})
|
||||
|
||||
27
example/ck_tile/37_transpose/README.md
Normal file
27
example/ck_tile/37_transpose/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Batched Transpose
|
||||
This folder contains example for transpose load for architecture gfx950. This transpose load has some constraints in input tile distribution.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Make the transpose executable
|
||||
make tile_example_transpose -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_transpose`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-N input batch size (default:2)
|
||||
-C input channel size. (default:64)
|
||||
-H input height size. (default:1)
|
||||
-W input width size. (default:64)
|
||||
-v whether do CPU validation or not (default: 1)
|
||||
-layout_in input tensor data layout - NCHW by default
|
||||
-layout_out output tensor data layout - NHWC by default
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-k_name t to 1 will print kernel name (default:0)
|
||||
```
|
||||
120
example/ck_tile/37_transpose/batched_transpose_kernel.hpp
Normal file
120
example/ck_tile/37_transpose/batched_transpose_kernel.hpp
Normal file
@@ -0,0 +1,120 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedTransposeHostArgs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
index_t batch;
|
||||
index_t height;
|
||||
index_t width;
|
||||
// index_t dim_blocks;
|
||||
index_t dim_stride;
|
||||
index_t dim_block_h;
|
||||
index_t dim_block_w;
|
||||
};
|
||||
|
||||
template <typename Pipeline_>
|
||||
struct BatchedTransposeKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using Type = typename Problem::DataType;
|
||||
|
||||
struct BatchedTransposeKargs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
index_t batch;
|
||||
index_t height;
|
||||
index_t width;
|
||||
index_t dim_stride;
|
||||
};
|
||||
|
||||
using Kargs = BatchedTransposeKargs;
|
||||
using Hargs = BatchedTransposeHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
size_t grid_size_x = h.dim_block_w;
|
||||
size_t grid_size_y = h.dim_block_h;
|
||||
size_t grid_size_z = h.batch;
|
||||
return dim3(grid_size_x, grid_size_y, grid_size_z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_input = h.p_input;
|
||||
k.p_output = h.p_output;
|
||||
k.batch = h.batch;
|
||||
k.height = h.height;
|
||||
k.width = h.width;
|
||||
k.dim_stride = h.dim_stride;
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
__shared__ char smem[Pipeline::GetSmemSize()];
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock;
|
||||
|
||||
const auto iDim = blockIdx.z;
|
||||
const auto x_m_n = [&]() {
|
||||
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
|
||||
make_tuple(kargs.height, kargs.width),
|
||||
make_tuple(kargs.width, 1),
|
||||
number<Pipeline::GetVectorSize()>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(x_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock);
|
||||
|
||||
const auto y_n_m = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
|
||||
make_tuple(kargs.width, kargs.height),
|
||||
make_tuple(kargs.height, 1),
|
||||
number<Pipeline::GetVectorSize()>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(y_dram_naive,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
|
||||
|
||||
auto y_block_window = make_tile_window(
|
||||
y_n_m,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
|
||||
|
||||
Pipeline{}(x_block_window, y_block_window, smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
149
example/ck_tile/37_transpose/block_transpose.hpp
Normal file
149
example/ck_tile/37_transpose/block_transpose.hpp
Normal file
@@ -0,0 +1,149 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "transpose_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Layout_, index_t kRow, index_t kCol>
|
||||
struct TransposeTraits
|
||||
{
|
||||
static constexpr index_t kLeadDim = kCol;
|
||||
static constexpr index_t kSecondDim = kRow;
|
||||
};
|
||||
|
||||
template <index_t kRow, index_t kCol>
|
||||
struct TransposeTraits<tensor_layout::gemm::ColumnMajor, kRow, kCol>
|
||||
{
|
||||
static constexpr index_t kLeadDim = kRow;
|
||||
static constexpr index_t kSecondDim = kCol;
|
||||
};
|
||||
|
||||
// supports 2D transpose which will store to lds, then use ds_read_b*_tr_b* instruction to get the
|
||||
// transposed data; Layout in TransposePipelineProblem is the original layout of the data in the
|
||||
// global memory
|
||||
template <typename DataType_,
|
||||
typename Layout_,
|
||||
index_t kBlockSize_,
|
||||
index_t kRowWarps_, // how many warps in row direction
|
||||
index_t kColWarps_, // how many warps in col direction
|
||||
index_t kRowPerBlock_, // row number per block
|
||||
index_t kColPerBlock_, // col number per block
|
||||
index_t kRowPerXdl_, // row number per xdl ops
|
||||
index_t kColPerXdl_> // col number per xdl ops
|
||||
struct TransposePipelineProblem
|
||||
{
|
||||
static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_,
|
||||
"the block size is not correct!");
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using Layout = remove_cvref_t<Layout_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kLeadNumWarps =
|
||||
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kLeadDim;
|
||||
static constexpr index_t kSecondNumWarps =
|
||||
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kSecondDim;
|
||||
static constexpr index_t kLeadSizePerBlock =
|
||||
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kLeadDim;
|
||||
static constexpr index_t kSecondSizePerBlock =
|
||||
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kSecondDim;
|
||||
static constexpr index_t kLeadSizePerXdl =
|
||||
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kLeadDim;
|
||||
static constexpr index_t kSecondSizePerXdl =
|
||||
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kSecondDim;
|
||||
|
||||
static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
|
||||
static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
|
||||
|
||||
static_assert(kLeadSizePerBlock % kLeadNumWarps == 0,
|
||||
"block dim should be divided by warp dim!");
|
||||
static_assert(kSecondSizePerBlock % kSecondNumWarps == 0,
|
||||
"block dim should be divided by warp dim!");
|
||||
// how many rows/cols implemented in one warp
|
||||
static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps;
|
||||
static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps;
|
||||
|
||||
static_assert(kLeadSizePerWarp % kLeadSizePerXdl == 0,
|
||||
"warp dim should be divided by xdl dim!");
|
||||
static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0,
|
||||
"warp dim should be divided by xdl dim!");
|
||||
|
||||
// warp rows/cols is divided into xdl.
|
||||
static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl;
|
||||
static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl;
|
||||
|
||||
static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0,
|
||||
"xdl dim should be divided by quad dim!");
|
||||
static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0,
|
||||
"xdl dim should be divided by quad dim!");
|
||||
// xdl rows/cols is divided into quadrants.
|
||||
static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim;
|
||||
static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim;
|
||||
|
||||
static constexpr index_t kIterationsInSecondDim =
|
||||
kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size();
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = TransposePolicy>
|
||||
struct BlockTranspose
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using DataType = remove_cvref_t<typename Problem::DataType>;
|
||||
using Layout = remove_cvref_t<typename Problem::Layout>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
|
||||
static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
|
||||
|
||||
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename InputTileWindow, typename OutputTileWindow>
|
||||
CK_TILE_DEVICE void operator()(const InputTileWindow& input_window,
|
||||
OutputTileWindow& output_window,
|
||||
void* __restrict__ p_smem)
|
||||
{
|
||||
auto input_tile_window =
|
||||
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
|
||||
auto output_tile_window =
|
||||
make_tile_window(output_window, Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
DataType* p_lds_ptr = static_cast<DataType*>(p_smem);
|
||||
constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor<Problem>();
|
||||
auto input_lds_block =
|
||||
make_tensor_view<address_space_enum::lds>(p_lds_ptr, in_lds_block_desc);
|
||||
|
||||
constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor<Problem>();
|
||||
auto output_lds_block =
|
||||
make_tensor_view<address_space_enum::lds>(p_lds_ptr, out_lds_block_desc);
|
||||
|
||||
auto copy_to_lds_window =
|
||||
make_tile_window(input_lds_block,
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0});
|
||||
auto load_from_lds_window =
|
||||
make_tile_window(output_lds_block,
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeLdsLoadTileDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(input_tile_window);
|
||||
|
||||
store_tile(copy_to_lds_window, x);
|
||||
block_sync_lds();
|
||||
|
||||
auto y = load_tile_transpose(load_from_lds_window);
|
||||
|
||||
store_tile(output_tile_window, y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
59
example/ck_tile/37_transpose/transpose_api.cpp
Normal file
59
example/ck_tile/37_transpose/transpose_api.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "transpose_example.hpp"
|
||||
#include <iostream>
|
||||
|
||||
template <typename ts_type,
|
||||
ck_tile::index_t block_x,
|
||||
ck_tile::index_t block_y,
|
||||
ck_tile::index_t warp_x,
|
||||
ck_tile::index_t warp_y>
|
||||
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
|
||||
{
|
||||
uint32_t dim_block_h = (a.height + block_y - 1) / block_y;
|
||||
uint32_t dim_block_w = (a.width + block_x - 1) / block_x;
|
||||
uint32_t dim_stride = a.height * a.width;
|
||||
|
||||
a.dim_stride = dim_stride;
|
||||
a.dim_block_h = dim_block_h;
|
||||
a.dim_block_w = dim_block_w;
|
||||
|
||||
using ts_problem = ck_tile::TransposePipelineProblem<ts_type,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
64,
|
||||
1,
|
||||
1,
|
||||
block_y,
|
||||
block_x,
|
||||
warp_y,
|
||||
warp_x>;
|
||||
using ts_pipeline = ck_tile::BlockTranspose<ts_problem>;
|
||||
|
||||
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
|
||||
|
||||
auto kargs = kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = kernel::GridSize(a);
|
||||
constexpr dim3 blocks = kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float batched_transpose(batched_transpose_trait t,
|
||||
batched_transpose_kargs a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
if(t.type == "fp16")
|
||||
{
|
||||
return batched_transpose_dispatch<ck_tile::fp16_t, 16, 32, 16, 32>(a, s);
|
||||
}
|
||||
else if(t.type == "fp8")
|
||||
{
|
||||
return batched_transpose_dispatch<ck_tile::fp8_t, 16, 64, 16, 64>(a, s);
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
257
example/ck_tile/37_transpose/transpose_example.cpp
Normal file
257
example/ck_tile/37_transpose/transpose_example.cpp
Normal file
@@ -0,0 +1,257 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "transpose_example.hpp"
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
|
||||
{
|
||||
auto len = x.get_lengths();
|
||||
assert(len.size() == 4);
|
||||
std::cout << "[";
|
||||
for(size_t i = 0; i < len[0]; i++)
|
||||
{
|
||||
std::cout << i << ": [";
|
||||
for(size_t j = 0; j < len[1]; j++)
|
||||
{
|
||||
std::cout << j << ": [";
|
||||
for(size_t k = 0; k < len[2]; k++)
|
||||
{
|
||||
std::cout << k << ": [";
|
||||
for(size_t v = 0; v < len[3]; v++)
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
|
||||
{
|
||||
auto m =
|
||||
ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
|
||||
|
||||
std::cout << m;
|
||||
if(v != len[3] - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
|
||||
}
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "--------------------" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "whether do CPU validation or not")
|
||||
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
|
||||
.insert("N", "2", "input batch size. ")
|
||||
.insert("C", "64", "input channel size.")
|
||||
.insert("H", "1", "input height size.")
|
||||
.insert("W", "64", "input width size. ")
|
||||
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
|
||||
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "t to 1 will print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
bool run_batched_transpose(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string prec = args.get_str("pr");
|
||||
int N = args.get_int("N");
|
||||
int C = args.get_int("C");
|
||||
int H = args.get_int("H");
|
||||
int W = args.get_int("W");
|
||||
std::string layout_in = args.get_str("layout_in");
|
||||
std::string layout_out = args.get_str("layout_out");
|
||||
int seed = args.get_int("seed");
|
||||
|
||||
int dim_in[4], dim_out[4];
|
||||
int stride_dim_in[4], stride_dim_out[4];
|
||||
bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
|
||||
bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
|
||||
assert(nchw2nhwc != nhwc2nchw);
|
||||
(void)nhwc2nchw;
|
||||
|
||||
dim_in[0] = N;
|
||||
dim_in[1] = nchw2nhwc ? C : H;
|
||||
dim_in[2] = nchw2nhwc ? H : W;
|
||||
dim_in[3] = nchw2nhwc ? W : C;
|
||||
dim_out[0] = N;
|
||||
dim_out[1] = nchw2nhwc ? H : C;
|
||||
dim_out[2] = nchw2nhwc ? W : H;
|
||||
dim_out[3] = nchw2nhwc ? C : W;
|
||||
stride_dim_in[0] = C * H * W;
|
||||
stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
|
||||
stride_dim_in[2] = nchw2nhwc ? W : C;
|
||||
stride_dim_in[3] = 1;
|
||||
stride_dim_out[0] = C * H * W;
|
||||
stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
|
||||
stride_dim_out[2] = nchw2nhwc ? C : W;
|
||||
stride_dim_out[3] = 1;
|
||||
|
||||
if(seed < 0)
|
||||
{
|
||||
seed = std::time(nullptr);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<Type> x_host(
|
||||
{dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
|
||||
{stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
|
||||
ck_tile::HostTensor<Type> y_host(
|
||||
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
|
||||
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
|
||||
|
||||
ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
auto trait = batched_transpose_trait{prec, layout_in};
|
||||
|
||||
uint32_t height = nchw2nhwc ? C : H * W;
|
||||
uint32_t width = nchw2nhwc ? H * W : C;
|
||||
|
||||
batched_transpose_kargs karg = [&]() {
|
||||
batched_transpose_kargs a_;
|
||||
a_.p_input = x_dev.GetDeviceBuffer();
|
||||
a_.p_output = y_dev.GetDeviceBuffer();
|
||||
a_.batch = N;
|
||||
a_.height = height;
|
||||
a_.width = width;
|
||||
return a_;
|
||||
}();
|
||||
|
||||
ck_tile::stream_config sc{nullptr, true};
|
||||
|
||||
auto ms = batched_transpose(trait, karg, sc);
|
||||
|
||||
std::size_t num_operations = N * C * H * (W - 1);
|
||||
std::size_t num_bytes = N * C * H * W * sizeof(Type);
|
||||
|
||||
float ave_time = ms * 1E-3;
|
||||
float gb_per_sec = num_bytes / ms * 1.E-6;
|
||||
float tflops = static_cast<float>(num_operations) / ms * 1.E-6;
|
||||
|
||||
std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
|
||||
<< ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
|
||||
<< " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops"
|
||||
<< gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
|
||||
prec.c_str(),
|
||||
N,
|
||||
C,
|
||||
H,
|
||||
W,
|
||||
layout_in.c_str(),
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
fflush(stdout);
|
||||
|
||||
if(ms < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
y_dev.FromDevice(y_host.data());
|
||||
|
||||
bool rtn = true;
|
||||
if(validate)
|
||||
{
|
||||
// this host buffer will not copy to GPU, so no need use stride
|
||||
ck_tile::HostTensor<Type> y_ref(
|
||||
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
|
||||
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
|
||||
|
||||
ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);
|
||||
|
||||
auto [rtol, atol] = get_elimit<Type>("");
|
||||
|
||||
rtn &= ck_tile::check_err(
|
||||
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
|
||||
}
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string prec = args.get_str("pr");
|
||||
|
||||
bool r = true;
|
||||
if(prec.compare("fp16") == 0)
|
||||
{
|
||||
r &= run_batched_transpose<ck_tile::fp16_t>(args);
|
||||
}
|
||||
else if(prec.compare("fp8") == 0)
|
||||
{
|
||||
r &= run_batched_transpose<ck_tile::fp8_t>(args);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type: " << prec << std::endl;
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
27
example/ck_tile/37_transpose/transpose_example.hpp
Normal file
27
example/ck_tile/37_transpose/transpose_example.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "batched_transpose_kernel.hpp"
|
||||
#include "block_transpose.hpp"
|
||||
#include "transpose_policy.hpp"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#pragma once
|
||||
|
||||
struct batched_transpose_trait
|
||||
{
|
||||
std::string type;
|
||||
std::string layout;
|
||||
};
|
||||
|
||||
struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float batched_transpose(batched_transpose_trait t,
|
||||
batched_transpose_kargs a,
|
||||
ck_tile::stream_config s);
|
||||
151
example/ck_tile/37_transpose/transpose_policy.hpp
Normal file
151
example/ck_tile/37_transpose/transpose_policy.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct TransposePolicy
|
||||
{
|
||||
static constexpr auto TileAccessPattern = tile_distribution_pattern::thread_raked;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSize()
|
||||
{
|
||||
return 16 / sizeof(typename Problem::DataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return integer_least_multiple(
|
||||
sizeof(typename Problem::DataType) *
|
||||
MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType);
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
SecondDimPerBlock,
|
||||
LeadDimPerBlock,
|
||||
VecLoadSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
|
||||
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<remove_cvref_t<decltype(input_dstr)>,
|
||||
typename Problem::DataType>::OutDstrEncode;
|
||||
constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
return block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kSecondDimPerBlock>{},
|
||||
number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}),
|
||||
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
|
||||
number<kVectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
|
||||
constexpr index_t kVectorSize = 8 / sizeof(typename Problem::DataType);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kSecondDimPerBlock>{},
|
||||
number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}),
|
||||
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
|
||||
number<kVectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution()
|
||||
{
|
||||
using DataType = typename Problem::DataType;
|
||||
|
||||
// Extract base dimensions from the traits
|
||||
constexpr index_t kBaseLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
|
||||
constexpr index_t kBaseSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
|
||||
|
||||
// Calculate block-level dimensions
|
||||
constexpr index_t kLead = Problem::kLeadSizePerXdl;
|
||||
constexpr index_t kSecond = Problem::kSecondSizePerXdl;
|
||||
constexpr index_t kLeadIterPerWarp = Problem::kLeadXdlNumPerWarp;
|
||||
constexpr index_t kSecondIterPerWarp = Problem::kSecondXdlNumPerWarp;
|
||||
constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
|
||||
constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
|
||||
|
||||
// Calculate repetitions of base pattern
|
||||
constexpr index_t kLeadRepetitions = kLead / kBaseLeadDim;
|
||||
constexpr index_t kSecondRepetitions = kSecond / kBaseSecondDim;
|
||||
constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
|
||||
constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
|
||||
|
||||
constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
|
||||
kSecondDimStrSub,
|
||||
kSecondDimIterations,
|
||||
kLeadRepetitions,
|
||||
1>();
|
||||
|
||||
constexpr auto input_tile_encode =
|
||||
InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
|
||||
kLeadIterPerWarp,
|
||||
kSecondIterPerWarp,
|
||||
kLeadNumWarps,
|
||||
kSecondNumWarps>();
|
||||
constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
|
||||
return block_dstr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -19,5 +19,6 @@ add_subdirectory(16_batched_gemm)
|
||||
add_subdirectory(17_grouped_gemm)
|
||||
add_subdirectory(18_flatmm)
|
||||
add_subdirectory(19_gemm_multi_d)
|
||||
add_subdirectory(20_grouped_convolution)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
add_subdirectory(36_copy)
|
||||
add_subdirectory(37_transpose)
|
||||
|
||||
@@ -244,12 +244,6 @@
|
||||
// workaround: compiler issue on gfx908
|
||||
#define CK_WORKAROUND_SWDEV_388832 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_TEMP_DISABLE_FP4_TESTS 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ struct HostTensorDescriptor
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
|
||||
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
|
||||
{
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
@@ -550,12 +550,77 @@ struct Tensor
|
||||
auto dis_ = dis; // copy
|
||||
g_.discard(ib_begin * BLOCK_SIZE * ck::packed_size_v<T>);
|
||||
auto t_fn = [&]() {
|
||||
if constexpr(ck::packed_size_v<T> == 1)
|
||||
// As user can pass integer distribution in dis, we must ensure that the correct
|
||||
// constructor/converter is called at all times. For f4/f6/f8 types, to ensure
|
||||
// correct results, we convert from float to the target type. In these cases
|
||||
// integer constructors are interpreted as direct initialization of the internal
|
||||
// storage with binary values instead of treating integers as subset of floats.
|
||||
if constexpr(ck::is_same_v<T, ck::f8_t> || ck::is_same_v<T, ck::bf8_t>)
|
||||
return ck::type_convert<T>(static_cast<float>(fn(dis_(g_))));
|
||||
else if constexpr(ck::packed_size_v<T> == 1)
|
||||
return ck::type_convert<T>(fn(dis_(g_)));
|
||||
else if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(
|
||||
ck::float2_t{ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_)))})};
|
||||
else if constexpr(ck::is_same_v<T, ck::f6x32_pk_t> ||
|
||||
ck::is_same_v<T, ck::bf6x32_pk_t>)
|
||||
{
|
||||
return ck::type_convert<T>(
|
||||
ck::float32_t{ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_)))});
|
||||
}
|
||||
else if constexpr(ck::is_same_v<T, ck::f6x16_pk_t> ||
|
||||
ck::is_same_v<T, ck::bf6x16_pk_t>)
|
||||
{
|
||||
return ck::type_convert<T>(
|
||||
ck::float16_t{ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_)))});
|
||||
}
|
||||
else
|
||||
static_assert(false, "Unsupported packed size for T");
|
||||
};
|
||||
@@ -600,12 +665,12 @@ struct Tensor
|
||||
ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx)
|
||||
T& operator()(const std::vector<std::size_t>& idx)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
const T& operator()(const std::vector<std::size_t>& idx) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
@@ -66,9 +66,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
static constexpr index_t AMmaKStride = KPack;
|
||||
static constexpr index_t BMmaKStride = KPack;
|
||||
|
||||
//> store rows/cols into thread registers in chunks of 16
|
||||
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
|
||||
static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA);
|
||||
// store rows/cols into thread registers in chunks of 16 for FP8
|
||||
// e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
|
||||
// or in chunks of 32 / APackedSize for FP6/FP4
|
||||
static constexpr index_t KThreadChunk = (APackedSize == 1) ? 16 : 32 / APackedSize;
|
||||
|
||||
static_assert(APackedSize == BPackedSize, "APackedSize must be equal to BPackedSize for now");
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
|
||||
@@ -122,7 +122,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KGroup;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
@@ -154,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
@@ -291,14 +290,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -391,15 +388,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -483,14 +477,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -596,7 +588,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
|
||||
@@ -1,919 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack> : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Base::APackedSize;
|
||||
using Base::BPackedSize;
|
||||
using Base::ComputePackedSize;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple4 = typename Base::Tuple4;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_blockwise_copy_up,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
const BGridBuffer& b_grid_buf_up,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
CThreadBuffer& c_thread_buf_up,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
BScaleThreadTransfer& b_scale_thread_copy_up,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleGridBuffer& b_scale_grid_buf_up,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_desc;
|
||||
ignore = b_block_buf;
|
||||
ignore = a_scale_grid_buf;
|
||||
ignore = b_scale_grid_buf;
|
||||
ignore = b_scale_grid_buf_up;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs_up;
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 0
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales to buf 0
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 1
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I1));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales to buf 1
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
c_thread_buf_up.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf));
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_assert(
|
||||
0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk *
|
||||
xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prefetch a_scales
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(mfma_reg_buf));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I1][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from a_scale_grid to a_scale_thread
|
||||
static constexpr auto a_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from b_scale_grid to b_scale_thread_buf
|
||||
static constexpr auto b_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
protected:
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
// using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user