diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index ccdfb0f6fb..bd597344ea 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @shumway @vidyasagar-amd # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd diff --git a/.gitignore b/.gitignore index 2e7ed0c77c..058d180a3c 100644 --- a/.gitignore +++ b/.gitignore @@ -69,6 +69,9 @@ build*/ # Python cache __pycache__/ +.cache/ + + test_data/* script/*.csv script/*.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d6700ae05b..e4e85651f6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,3 +12,27 @@ repos: verbose: false language: script types: [c++] + - id: remove-exec-bit + name: Remove executable bit from non-executable files + entry: script/remove_exec_bit.sh + language: script + types_or: [c++, text] + verbose: true + - id: ruff-check + name: Ruff Linter + entry: ruff check --fix + language: python + types: [python] + additional_dependencies: [ruff] + - id: ruff-format + name: Ruff Formatter + entry: ruff format + language: python + types: [python] + additional_dependencies: [ruff] + - id: run-remod-if-ck-tile-changed + name: Run remod.py if ck_tile files changed + entry: script/remod_for_ck_tile.sh + language: script + always_run: true + pass_filenames: false diff --git a/CHANGELOG.md b/CHANGELOG.md index aecf16d83d..17f9455feb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,18 +13,23 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM -* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types +* Added support for Multiple D GEMM +* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) * Added benchmarking support for tile engine GEMM. +* Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. +* Added int8 support for CK_TILE GEMM. ### Optimized * Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166) * Added Vectorize Transpose optimization for CK Tile (#2131) +* Added the asynchronous copy for gfx950 (#2425) ### Fixes diff --git a/CMakeLists.txt b/CMakeLists.txt index aab74f3069..6e032a30cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -308,6 +308,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) +option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -321,6 +322,12 @@ if(USE_OPT_GFX11) message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") endif() +if(ENABLE_ASM_DUMP) + add_compile_options(--save-temps) + add_compile_options(-Wno-gnu-line-marker) + message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) @@ -627,7 +634,7 @@ option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) add_subdirectory(library) -if(NOT GPU_ARCHS AND USER_GPU_TARGETS) +if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name diff --git a/Dockerfile b/Dockerfile index 4ce1b85fc0..4dd03c3b90 100644 --- a/Dockerfile +++ b/Dockerfile @@ -77,7 +77,6 @@ RUN git clone https://github.com/ccache/ccache.git && \ wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip && \ gunzip /usr/local/bin/ninja.gz && \ chmod a+x /usr/local/bin/ninja && \ - git clone https://github.com/nico/ninjatracing.git && \ #Install ClangBuildAnalyzer git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \ cd ClangBuildAnalyzer/ && \ diff --git a/Jenkinsfile b/Jenkinsfile index 1cb1a6ca6c..50c15701a7 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -225,7 +225,9 @@ def cmake_build(Map conf=[:]){ def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","") def prefixpath = conf.get("prefixpath","/opt/rocm") def setup_args = conf.get("setup_args","") - + // make sure all unit tests always run on develop branch + def runAllUnitTests = (env.BRANCH_NAME == "develop") ? true : params.RUN_ALL_UNIT_TESTS + if (prefixpath != "/usr/local"){ setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " } @@ -343,15 +345,19 @@ def cmake_build(Map conf=[:]){ def build_cmd def execute_cmd = conf.get("execute_cmd", "") if(!setup_args.contains("NO_CK_BUILD")){ - if (setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE){ + def cmake_flags = params.NINJA_FTIME_TRACE ? "-O3 -ftime-trace" : "-O3" + if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" - setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """) - build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") - } - else{ - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") - build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}") } + setup_cmd = conf.get( + "setup_cmd", + """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" ${cmake_flags} " .. """ + ) + build_cmd = conf.get( + "build_cmd", + "${build_envs} ninja -j${nt} ${config_targets}" + ) + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} @@ -372,14 +378,23 @@ def cmake_build(Map conf=[:]){ //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){ - sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" + if (params.NINJA_FTIME_TRACE) { + echo "running ninja ftime trace" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" + archiveArtifacts "clang_build_analysis.log" + } + sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace.json" archiveArtifacts "ck_build_trace.json" - archiveArtifacts "clang_build_analysis.log" + // do not run unit tests when building instances only if(!params.BUILD_INSTANCES_ONLY){ - sh "ninja check" + if (!runAllUnitTests){ + sh "../script/launch_tests.sh" + } + else{ + sh "ninja check" + } } if(params.BUILD_INSTANCES_ONLY){ // build deb packages @@ -393,7 +408,12 @@ def cmake_build(Map conf=[:]){ else{ // run unit tests unless building library for all targets if (!params.BUILD_INSTANCES_ONLY){ - sh "make check" + if (!runAllUnitTests){ + sh "../script/launch_tests.sh" + } + else{ + sh "ninja check" + } } } } @@ -793,10 +813,10 @@ def process_results(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true + 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" @@ -859,8 +879,8 @@ pipeline { description: "Run the cppcheck static analysis (default: OFF)") booleanParam( name: "RUN_PERFORMANCE_TESTS", - defaultValue: true, - description: "Run the performance tests (default: ON)") + defaultValue: false, + description: "Run the performance tests (default: OFF)") booleanParam( name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS", defaultValue: false, @@ -893,10 +913,26 @@ pipeline { name: "BUILD_GFX908", defaultValue: false, description: "Build CK and run tests on gfx908 (default: OFF)") + booleanParam( + name: "BUILD_GFX90A", + defaultValue: true, + description: "Build CK and run tests on gfx90a (default: ON)") + booleanParam( + name: "BUILD_GFX942", + defaultValue: true, + description: "Build CK and run tests on gfx942 (default: ON)") booleanParam( name: "BUILD_GFX950", defaultValue: false, description: "Build CK and run tests on gfx950 (default: OFF)") + booleanParam( + name: "BUILD_GFX10", + defaultValue: true, + description: "Build CK and run tests on gfx10 (default: ON)") + booleanParam( + name: "BUILD_GFX11", + defaultValue: true, + description: "Build CK and run tests on gfx11 (default: ON)") booleanParam( name: "BUILD_GFX12", defaultValue: true, @@ -905,6 +941,10 @@ pipeline { name: "NINJA_BUILD_TRACE", defaultValue: false, description: "Generate a ninja build trace (default: OFF)") + booleanParam( + name: "NINJA_FTIME_TRACE", + defaultValue: false, + description: "Generate a detailed time trace (default: OFF)") booleanParam( name: "BUILD_LEGACY_OS", defaultValue: false, @@ -913,6 +953,10 @@ pipeline { name: "RUN_INDUCTOR_TESTS", defaultValue: true, description: "Run inductor codegen tests (default: ON)") + booleanParam( + name: "RUN_ALL_UNIT_TESTS", + defaultValue: false, + description: "Run all unit tests (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -1025,7 +1069,7 @@ pipeline { { when { beforeAgent true - expression { params.RUN_CODEGEN_TESTS.toBoolean() } + expression { params.RUN_CODEGEN_TESTS.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } agent{ label rocmnode("gfx90a")} environment{ @@ -1185,9 +1229,16 @@ pipeline { agent{ label rocmnode("gfx90a") } environment{ setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make benchmark_gemm -j && \ - ./bin/benchmark_gemm """ + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx90a" \ + -D GEMM_DATATYPE="fp8;fp16" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j64 benchmark_gemm_fp8 && \ + ./bin/benchmark_gemm_fp8 && \ + ninja -j64 benchmark_gemm_fp16 && \ + ./bin/benchmark_gemm_fp16 """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1203,9 +1254,16 @@ pipeline { agent{ label rocmnode("gfx942") } environment{ setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make benchmark_gemm -j && \ - ./bin/benchmark_gemm """ + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx942" \ + -D GEMM_DATATYPE="fp8;fp16" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j128 benchmark_gemm_fp8 && \ + ./bin/benchmark_gemm_fp8 && \ + ninja -j128 benchmark_gemm_fp16 && \ + ./bin/benchmark_gemm_fp16 """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1261,7 +1319,7 @@ pipeline { { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { (params.BUILD_GFX942.toBoolean() || params.RUN_FULL_QA.toBoolean()) && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx942") } environment{ @@ -1299,7 +1357,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, docker_name: "rocm/composable_kernel-private:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } @@ -1328,7 +1386,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { params.BUILD_GFX90A.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx90a") } environment{ @@ -1352,14 +1410,20 @@ pipeline { expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx942") } - environment{ - execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER="${build_compiler()}" \ - -D CMAKE_BUILD_TYPE=Release \ - -D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """ - } steps{ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + script { + def execute_args = params.NINJA_FTIME_TRACE ? + """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """ : + """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ + + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + } cleanWs() } } @@ -1367,7 +1431,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { params.BUILD_GFX10.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx1030") } environment{ @@ -1388,7 +1452,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx1101") } environment{ diff --git a/client_example/32_gemm_mx/CMakeLists.txt b/client_example/32_gemm_mx/CMakeLists.txt new file mode 100644 index 0000000000..558986bf5a --- /dev/null +++ b/client_example/32_gemm_mx/CMakeLists.txt @@ -0,0 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx950") + add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp) + target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/32_gemm_mx/gemm_mx_fp8.cpp b/client_example/32_gemm_mx/gemm_mx_fp8.cpp new file mode 100644 index 0000000000..6e14bf2a5f --- /dev/null +++ b/client_example/32_gemm_mx/gemm_mx_fp8.cpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using CDataType = ck::half_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; +template +inline constexpr bool is_same_v = ck::is_same::value; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AScaleLayout = Row; +using BScaleLayout = Col; + +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + mem_size_ = mem_size; + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mem_size_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t KBatch = 1; + + /* Require by mx type*/ + constexpr ck::index_t ScaleBlockSize = 32; // scaling block size + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideC = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + /* Scale stride Calculation */ + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + return static_cast(col); + else + return static_cast(row); + } + else + return static_cast(stride); + }; + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize; + auto Scale_Stride_AM = + f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + SimpleDeviceMem a_scale_device_buf( + sizeof(XDataType) * + f_matrix_space_size(Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + SimpleDeviceMem b_scale_device_buf( + sizeof(XDataType) * + f_matrix_space_size(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + using DeviceOp = + ck::tensor_operation::device::DeviceGemmMX; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v + + sizeof(BDataType) * K * N / ck::packed_size_v + + sizeof(CDataType) * M * N + + sizeof(XDataType) * M * K / ScaleBlockSize + + sizeof(XDataType) * N * K / ScaleBlockSize; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/README.md b/client_example/README.md index d9f793434d..34c6733d05 100644 --- a/client_example/README.md +++ b/client_example/README.md @@ -14,8 +14,10 @@ cd client_example/build cmake \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \ +-D GPU_TARGETS="gfx908;gfx90a" \ .. ``` +You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s). ### Build client example ```bash diff --git a/cmake/ShardInstantiation.cmake b/cmake/ShardInstantiation.cmake new file mode 100644 index 0000000000..47a5d0c48c --- /dev/null +++ b/cmake/ShardInstantiation.cmake @@ -0,0 +1,116 @@ +# Function to generate templated instantiation functions and caller function. + +# In order to reduce build times, we split the instantiation of template functions into multiple files. +# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions, +# which can be placed the TEMPLATE_FILE (typically a .in file). + +# This CMake function generates the instantiation functions and a caller function that calls all the instantiation +# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of +# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard, +# and generates a caller function that calls all the instantiation functions. + +# The explicit instatiation pattern requires the use of `extern template` to avoid implicit instantiation +# of the template functions in the caller function, and that code is automatically generated by this function. + +# In addition to the user-supplied template, this CMake function uses two generic templates: +# +# 1. `instantiate_shard.in`: This is the template for the instantiation functions. +# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions. + +# This function takes the following arguments: +# +# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCE_NAMES}`). +# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions. +# - NUM_SHARDS: The number of shards to generate. +# - OUTPUT_DIR: The build directory where the generated source files will be placed. +# - SRC_LIST: The list of source files to which the generated source files will be added. + + +function(generate_sharded_instantiations) + cmake_parse_arguments( + GEN_SHARDED + # No boolean arguments + "" + # Single-value arguments + "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST" + # No multi-value arguments. + "" + ${ARGN} + ) + if (NOT GEN_SHARDED_INSTANCES_NAME) + message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_TEMPLATE_FILE) + message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_NUM_SHARDS) + message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations") + endif() + if(NOT GEN_SHARDED_OUTPUT_DIR) + message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_SRC_LIST) + message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations") + endif() + + file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR}) + + + set(GENERATED_SOURCE_FILES "") + set(EXTERN_TEMPLATE_STATEMENTS "") + set(CALL_STATEMENTS "") + message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") + + set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}") + + # Generate the inc file with the template function defintions. + # This include file will hold the template function definitions and a using alias for all the shard + # instantiation functions. + configure_file( + "${GEN_SHARDED_TEMPLATE_FILE}" + "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc" + @ONLY + ) + + # Generate the sharded instantiation functions. + # This is where the build parallelization happens. + # Each of these source files will contain a single instantiation function for a shard, + # which will be called sequentially by the caller function. + set(INC_DIR "${GEN_SHARDED_INC_DIR}") + math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1") + foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID}) + set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}") + set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp") + set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in") + configure_file( + "${SHARD_FUNCTION_TEMPLATE}" + "${SHARD_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}") + set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>") + list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)") + list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)") + endforeach() + + # Join the include statements, the extern template declarations, and the call statements each + # into a single string for variable substitution in the caller function. + string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}") + string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}") + string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}") + + # Generate the caller function. + set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp") + set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in") + configure_file( + "${FUNCTION_TEMPLATE}" + "${CALLER_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}") + + # Add the generated source files to the list of source files. + # This allows the generated source files to be included in the build. + list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES}) + set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE) +endfunction() \ No newline at end of file diff --git a/cmake/call_shard.in b/cmake/call_shard.in new file mode 100644 index 0000000000..daba79b055 --- /dev/null +++ b/cmake/call_shard.in @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { + +@EXTERN_TEMPLATE_STATEMENTS@; + +void add_@INSTANCES@( + @INSTANCES@& instances) { +@CALL_STATEMENTS@; +} + +} // namespace ck::tensor_operation::device::instance diff --git a/cmake/instantiate_shard.in b/cmake/instantiate_shard.in new file mode 100644 index 0000000000..dbc0af17a9 --- /dev/null +++ b/cmake/instantiate_shard.in @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { +template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>( + @INSTANCES@& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index 4367aabc95..4c8019f8d3 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -945,11 +945,9 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../../include/ck/tensor_operation/gpu/grid \ - ../../include/ck/tensor_operation/gpu/block \ - ../../include/ck/tensor_operation/gpu/thread \ +INPUT = ../../include \ + ../../include/ck/ \ ../../library/include/ck/library/utility \ - ../../include/ck/wrapper \ ../../include/ck_tile # This tag can be used to specify the character encoding of the source files @@ -1849,7 +1847,7 @@ MATHJAX_CODEFILE = # The default value is: YES. # This tag requires that the tag GENERATE_HTML is set to YES. -SEARCHENGINE = YES +SEARCHENGINE = NO # When the SERVER_BASED_SEARCH tag is enabled the search engine will be # implemented using a web server instead of a web client using JavaScript. There @@ -2406,7 +2404,7 @@ TAGFILES = # tag file that is based on the input files it reads. See section "Linking to # external documentation" for more information about the usage of tag files. -GENERATE_TAGFILE = +GENERATE_TAGFILE = html/tagfile.xml # If the ALLEXTERNALS tag is set to YES, all external class will be listed in # the class index. If set to NO, only the inherited external classes will be @@ -2653,7 +2651,7 @@ DIR_GRAPH_MAX_DEPTH = 1 # The default value is: png. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_IMAGE_FORMAT = png +DOT_IMAGE_FORMAT = svg # If DOT_IMAGE_FORMAT is set to svg, then this option can be set to YES to # enable generation of interactive SVG images that allow zooming and panning. @@ -2665,7 +2663,7 @@ DOT_IMAGE_FORMAT = png # The default value is: NO. # This tag requires that the tag HAVE_DOT is set to YES. -INTERACTIVE_SVG = NO +INTERACTIVE_SVG = YES # The DOT_PATH tag can be used to specify the path where the dot tool can be # found. If left blank, it is assumed the dot tool can be found in the path. diff --git a/docs/index.rst b/docs/index.rst index 4cc26a1d3e..89a5e3e836 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -36,7 +36,9 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab * :doc:`Composable Kernel custom types <./reference/Composable_Kernel_custom_types>` * :doc:`Composable Kernel vector utilities <./reference/Composable_Kernel_vector_utilities>` * :ref:`wrapper` - * :doc:`Composable Kernel complete class list <./doxygen/html/annotated>` + * :doc:`Composable Kernel API reference <./doxygen/html/namespace_c_k>` + * :doc:`CK Tile API reference <./doxygen/html/namespaceck__tile>` + * :doc:`Composable Kernel complete API class list <./doxygen/html/annotated>` To contribute to the documentation refer to `Contributing to ROCm `_. diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 489a448860..beedb4e867 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ rocm-docs-core[api_reference]==1.20.1 -sphinxcontrib-bibtex==2.6.3 +sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 14e74b2a6f..e8aa02aa01 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -182,7 +182,7 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data -pybtex==0.24.0 +pybtex==0.25.1 # via # pybtex-docutils # sphinxcontrib-bibtex @@ -244,9 +244,7 @@ rpds-py==0.24.0 # jsonschema # referencing six==1.17.0 - # via - # pybtex - # python-dateutil + # via python-dateutil smmap==5.0.2 # via gitdb snowballstemmer==2.2.0 @@ -278,7 +276,7 @@ sphinx-notfound-page==1.1.0 # via rocm-docs-core sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-bibtex==2.6.3 +sphinxcontrib-bibtex==2.6.5 # via -r requirements.in sphinxcontrib-devhelp==2.0.0 # via sphinx diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 36f1860e4f..b9748aabda 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -43,7 +43,13 @@ endforeach() set(GEMM_OPTIONS) list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") set(BLOCKSCALE_GEMM_OPTIONS) -list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1") +check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) +check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) +if(HAS_MISCHED_BOTTOMUP) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1") +elseif(HAS_MISCHED_PRERA_DIRECTION) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup") +endif() check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL) if(HAS_MAX_OCCUPANCY_EXPERIMENTAL) list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental) diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 3188ba142c..6a3986ea32 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -139,6 +139,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool PerTokenQuant = true; static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -169,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -197,7 +198,7 @@ int main(int argc, char* argv[]) { // use default case } - else if(argc == 3) + else if(argc == 4) { // use default case do_verification = std::stoi(argv[1]); @@ -238,7 +239,8 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0, 0}; + constexpr auto StrideDs = PerTokenQuant ? std::array{1, 1, 0} + : std::array{0, 0, 0}; ck::index_t KBatch = 1; @@ -279,8 +281,10 @@ int main(int argc, char* argv[]) Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d0_t_n( + HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 34c54a7e12..35c5d18d50 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -10,6 +10,9 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) # add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) # add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) +add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp6) + add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_fp4) @@ -22,17 +25,40 @@ add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns) add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp) add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns) +add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4) + +add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4) + +add_example_executable(example_moe_gemm1_xdl_mx_fp4_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bpreshuffle) + +add_example_executable(example_moe_gemm2_xdl_mx_fp4_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bpreshuffle) + set(FP4_MXGEMM_OPTIONS) list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1") example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) -example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) -example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +# mx moe B no-shuffling + scale shuffling example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) +# mx moe B no-shuffling + scale shuffling (async loads) +example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) + +# mx moe B shuffling + scale shuffling (async loads) +example_compile_options(example_moe_gemm1_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) + set(FP8_MXGEMM_OPTIONS) list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS}) example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS}) + +set(FP6_MXGEMM_OPTIONS) +list(APPEND FP6_MXGEMM_OPTIONS -mavx512f) +example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS}) diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 1f01e1c7be..6ce10817ff 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -245,6 +245,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); }; + if(K % ck::packed_size_v != 0 || K % ck::packed_size_v != 0) + { + throw std::runtime_error("wrong! K must be multiple of packed size."); + }; + // Hardcode scale layouts as per pipeline assumptions // TODO: Allow user to specify scale layouts using AScaleLayout = Row; @@ -292,12 +297,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto a_data_element = [](float x) { if constexpr(ck::is_same_v) return ck::type_convert(ck::float2_t(x)); + else if constexpr(ck::packed_size_v == 32) + return ck::type_convert(ck::float32_t(x)); + else if constexpr(ck::packed_size_v == 16) + return ck::type_convert(ck::float16_t(x)); else return ck::type_convert(x); }; auto b_data_element = [](float x) { if constexpr(ck::is_same_v) return ck::type_convert(ck::float2_t(x)); + else if constexpr(ck::packed_size_v == 32) + return ck::type_convert(ck::float32_t(x)); + else if constexpr(ck::packed_size_v == 16) + return ck::type_convert(ck::float16_t(x)); else return ck::type_convert(x); }; @@ -307,30 +320,35 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c switch(config.init_method) { case 0: // Initializations for development and debugging - ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); + + ck::utils::FillConstant{a_data_element(0.5f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); + ck::utils::FillConstant{b_data_element(2.0f)}(*b_k_n); ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n_scale); + if(config.verbosity > 0) { - std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A = {0.5}" << std::endl; std::cout << "Init A scale = {2.0}" << std::endl; - std::cout << "Init B = {0.5}" << std::endl; - std::cout << "Init B scale = {1.0}" << std::endl; + std::cout << "Init B = {2.0}" << std::endl; + std::cout << "Init B scale = {0.5}" << std::endl; std::cout << "Expect C = {K}" << std::endl; } break; case 1: - a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5] - b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5] + a_m_k.GenerateTensorDistr( + int_distr{-5, 5}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-5,5] + b_k_n->GenerateTensorDistr(int_distr{-5, 5}); // Z[-5,5] static_assert(ck::is_same_v); - a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2} + a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} break; case 2: - a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0}); + a_m_k.GenerateTensorDistr( + float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2] a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0}); diff --git a/example/67_gemm_microscaling/gemm_mx_fp6.cpp b/example/67_gemm_microscaling/gemm_mx_fp6.cpp new file mode 100644 index 0000000000..615980082d --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp6.cpp @@ -0,0 +1,99 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f6x16_pk_t; +using BDataType = ck::f6x16_pk_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / ck::packed_size_v; // K dimension size per block + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Number of threads per block + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 1, // AK1 number of elements to read at a time when transferring from global memory to LDS + 1, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp new file mode 100644 index 0000000000..aaf0cb3891 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t BlockSize = 256; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, BlockSize, + MPerBlock, NPerBlock, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{0.5f}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.5f}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp new file mode 100644 index 0000000000..08ed8e11fb --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -0,0 +1,574 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// B preshuffle +void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + I64 tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K_pk; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 64, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + preShuffleBuffer(b0_e_n_k.mData.data(), + b0_preshuffled.mData.data(), + N * 2 * experts, + K, + device_op.GetPreShuffleParameters()); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp new file mode 100644 index 0000000000..1b8a7a16e3 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -0,0 +1,542 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 8: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + // d2_e_n.savetxt("weight.txt", "int"); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 6718581a50..829bf9af24 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -158,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp new file mode 100644 index 0000000000..efbd0f0c03 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -0,0 +1,584 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// B preshuffle +void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + I64 tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K_pk; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 8, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = 13; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 8: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + // A, B Scale preshuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + preShuffleBuffer(b0_e_n_k.mData.data(), + b0_preshuffled.mData.data(), + N * experts, + K, + device_op.GetPreShuffleParameters()); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 4fc8b0b4c9..1b004ec100 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,7 +1,7 @@ # validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv") +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING - "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") + "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") if(FMHA_FWD_ENABLE_APIS STREQUAL "all") set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) endif() @@ -17,11 +17,30 @@ if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) list(APPEND FMHA_FWD_ENABLE_APIS "fwd") endif() +file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py +) +# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change +set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") + string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +set(FMHA_FWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FMHA_FWD_APIS} + # --filter fmha_fwd... +) +set(FMHA_BWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api bwd + --receipt 3 + # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... +) + # generate a list of kernels, but not actually emit files at config sta execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) @@ -29,8 +48,8 @@ if(ret AND NOT ret EQUAL 0) endif() execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3 + COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) @@ -44,14 +63,16 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} ) add_custom_command( OUTPUT ${FMHA_BWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3 + COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") @@ -73,7 +94,7 @@ target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) - set(FMHA_FWD_FAST_EXP2 true) + set(FMHA_FWD_FAST_EXP2 true) endif() set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) @@ -82,9 +103,9 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... because they are auto-generated if(FMHA_FWD_FAST_EXP2) - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero) @@ -102,6 +123,13 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) endif() +# conditionally enable call to the pagedkv_prefill API in fmha_fwd example +if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) +endif() + # conditionally specify the use of OCP_FP8 if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 12414a20ed..72109a660b 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -71,6 +71,7 @@ args: -drop_seed seed for random number generator (default:1) -drop_offset offset for random number generator (default:0) -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) + -num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) ``` diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 5b9d5742b4..9e15a822ef 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -122,6 +122,7 @@ PIPELINE_ENUM_MAP = { "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 80b64f918a..c251460a9a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -169,7 +169,7 @@ template 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 7cbbdb9034..37a1b7329b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -282,18 +282,19 @@ class FmhaFwdApiPool: # TODO: do we need to check duplication? if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + self.pool[trait.dtype][hdim].append(copy.copy(trait)) @property def api(self) -> str: per_dtypes=str() for i, dtype in enumerate(self.pool.keys()): per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' @@ -306,7 +307,7 @@ class FmhaFwdApiPool: F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -435,18 +436,20 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1), + (192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1), + (256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -454,7 +457,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! @@ -463,7 +466,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl pipelines = [] if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - if hdim == 256: + if hdim == 256 and hdim_v == 256: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) @@ -507,15 +510,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if d == None: continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()): + for pipeline in get_pipelines(dtype, hdim, hdim_v): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if hdim == 192 and tile.F_bn1 == 128: + if (hdim, hdim_v) == (192, 128) or hdim == 160: # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 3ae0e28be3..2d2d71555d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + # 160: 160, 256: 256 } @@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d '64' : FmhaFwdSplitKVCombineTileSize(32, -1), ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + ### '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if dtype in ['fp16', 'bf16']: for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): # TODO: use async pipeline when compiler is more stable - if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: + if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]: # if True: pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py new file mode 100644 index 0000000000..650ebaf80e --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -0,0 +1,585 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_FWD_PAGEDKV_PIPELINE_MAP = { + "qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, //lse + {F_pagedkv}, //pagedkv + {F_squant}, + {F_occupancy}, + {F_skip}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdPagedKVKernel; + +using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + +#include + +template<> +float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + return fmha_fwd_pagedkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + pagedkv : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + skip : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_pagedkv : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_skip : str # true/false + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + + if self.F_pagedkv == 't' : n += '_pagedkv' + else: n += '_npagedkv' + + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + pagedkv=self.F_pipeline.F_pagedkv, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + elif dtype in ['fp8', 'bf8']: + # TODO + None + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + # if pipeline.F_pagedkv == 'f': + # continue + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' : + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index eaf99529f3..b6de5ea621 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_bwd.hpp" #include "ck_tile/host.hpp" @@ -355,7 +355,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); - assert(slopes.size() == nhead); + assert(slopes.size() == static_cast(nhead)); if(bias.rank_info == 0) { // alibi in 1*h @@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(p_drop > 0) { - p_hp_host_ref.ForEach( - [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + p_dropped_hp_host_ref = p_hp_host_ref; randval_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); }); ck_tile::reference_batched_dropout( p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); - p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); } else { - p_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_hp_host_ref.template CopyAsType(); } // O = P * V @@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser) } // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) - ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { - AccDataType do_dot_o = 0; - for(int o = 0; o < hdim_v; o++) - { - auto idx_gmo = idx_gmn; - idx_gmo[2] = o; - do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * - ck_tile::type_convert(o_host_refs[wb](idx_gmo)); - } - self(idx_gmn) = ck_tile::type_convert( - p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); - }); + ck_tile::make_ParallelTensorFunctor( + [&](auto i0, auto i1, auto i2) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * + ck_tile::type_convert(o_host_refs[wb](i0, i1, o)); + } + ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert( + p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); + }, + ds_hp_host_ref.mDesc.get_lengths()[0], + ds_hp_host_ref.mDesc.get_lengths()[1], + ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); if(use_dbias) { - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + dbias_host_ref = ds_hp_host_ref.template CopyAsType(); } - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + ds_lp_host_ref = ds_hp_host_ref.template CopyAsType(); // dV = P_drop^T@dO^T // dV = P^T@dO^T w/o dropout diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp old mode 100644 new mode 100755 index bb1f495c4e..e9403f4698 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" @@ -178,50 +178,30 @@ auto get_elimit(std::string init_method) } } -int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split if(batch_nhead_mblocks >= 0.8f * num_SMs) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + max_splits = std::min({max_splits, num_SMs}); float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, - // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks - // (i.e. it's 11 splits anyway). - // So we check if the number of blocks per split is the same as the previous num_splits. - auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || - ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); - }; for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) { - efficiency.push_back(0.f); - } - else - { - float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if(eff > max_efficiency) - { - max_efficiency = eff; - } - efficiency.push_back(eff); + max_efficiency = eff; } + efficiency.push_back(eff); } for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) - { - continue; - } if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) { // printf("num_splits chosen = %d\n", num_splits); @@ -234,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int override_num_splits_if_necessary( int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) { + (void)hdim_v; int device; auto status = hipGetDevice(&device); if(status != hipSuccess) @@ -250,15 +231,13 @@ int override_num_splits_if_necessary( // tile size should match the generate.py const int kM0 = 64; - const int kN1 = hdim_v; const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); - const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); if(num_splits < 1 && p_drop == 0.0f) { return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); } return num_splits; @@ -344,7 +323,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ + CK_TILE_FMHA_FWD_PAGEDKV_API)) if(0 < page_block_size) { std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" @@ -360,7 +340,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API) if(use_cache_batch_idx) { std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" @@ -542,8 +522,8 @@ bool run(const ck_tile::ArgParser& arg_parser) max_seqlen_k = real_seqlen_k; } - flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + - static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); + flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + + static_cast(2) * mask.get_unmaskarea() * hdim_v); num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + sizeof(KDataType) * real_seqlen_k * hdim_q + @@ -568,7 +548,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cerr << "num_splits greater than 128 is not supported" << std::endl; return false; } -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < p_drop && (1 < num_splits || use_kvcache)) { std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" @@ -823,7 +803,7 @@ bool run(const ck_tile::ArgParser& arg_parser) << (is_rotary_interleaved ? "inter" : "half") << ")"; } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(1 < num_splits) { std::cout << ", num_splits:" << num_splits; @@ -864,6 +844,11 @@ bool run(const ck_tile::ArgParser& arg_parser) { traits.has_dropout = (p_drop > 0.0f); } + else if constexpr(std::is_same_v>) + { + traits.use_pagedkv = use_kvcache; + } } }; @@ -1072,6 +1057,17 @@ bool run(const ck_tile::ArgParser& arg_parser) args.split_stride_lse_acc = split_stride_lse_acc; args.split_stride_o_acc = split_stride_o_acc; } + else if constexpr(std::is_same_v>) + { + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + } } }; @@ -1093,7 +1089,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const float fwd_ave_time = [&] { #if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < num_splits || use_kvcache) + if(1 < num_splits && use_kvcache) { fmha_fwd_splitkv_traits fmha_splitkv_traits; init_traits(fmha_splitkv_traits); @@ -1103,6 +1099,18 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); } +#endif +#if CK_TILE_FMHA_FWD_PAGEDKV_API + if(use_kvcache) + { + fmha_fwd_pagedkv_traits fmha_pagedkv_traits; + init_traits(fmha_pagedkv_traits); + + fmha_fwd_pagedkv_args fmha_pagedkv_args; + init_args(fmha_pagedkv_args); + + return fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + } #endif fmha_fwd_traits fmha_traits; init_traits(fmha_traits); @@ -1258,7 +1266,7 @@ bool run(const ck_tile::ArgParser& arg_parser) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < page_block_size) { if(i_perm) { k_host_ref.ForEach([&](auto& self, auto i) { @@ -1309,7 +1317,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < page_block_size) { if(is_v_rowmajor) { if(i_perm) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 5ce56d48b5..81dda692ea 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -178,6 +178,86 @@ struct fmha_fwd_args drop_seed_offset; }; +struct fmha_fwd_pagedkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + // + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; +}; + struct fmha_fwd_splitkv_args { const void* q_ptr; @@ -501,6 +581,114 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } } +template +auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + // FmhaKernel::PrintParameters(kargs, args.batch); + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + template auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) { @@ -715,102 +903,102 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, #if 0 // we assume page_block_size=1 for now args.kv_last_page_lens, args.page_block_size, #endif - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_k, - args.batch_stride_v, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, #if 0 // we assume page_block_size=1 for now args.kv_last_page_lens, args.page_block_size, #endif - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); @@ -870,6 +1058,57 @@ struct fmha_fwd_traits_ template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +struct fmha_fwd_pagedkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kIsPagedKV = kIsPagedKV_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; +}; + +template +float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args); + template (0)); + ck_tile::index_t x_end = std::min(i_y + x, seqlen_k); + if(x_end > x_start) + { + area += (x_end - x_start); + } + } + return area; + } friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) { mi.serialize(os); diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 0238a125dc..d77582630a 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 4c16f13cef..da37159aeb 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -30,7 +30,7 @@ args: -stride_c Tensor C stride (default:0) -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) -e Absolute error tolerance (default:1e-5) - -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16) -warmup number of iterations before benchmark the kernel (default:10) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index de9608bcb4..80c18cdb87 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -12,15 +12,20 @@ #include "ck_tile/host.hpp" #include "gemm_utils.hpp" -template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + bool Persistent, + typename CDEElementWise> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { if constexpr(Persistent) std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; @@ -53,8 +58,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using CodegenGemmTraits = ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; const auto Run = [&](const auto memory_operation_) { @@ -63,9 +70,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -131,12 +141,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -147,24 +157,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } else { - if(a_layout == "R" && b_layout == "R") + if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } + else if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}); + } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -202,15 +212,24 @@ int run_gemm_example(int argc, char* argv[]) return run_gemm_example_prec_type( a_layout, b_layout, argc, argv); } - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + else if(data_type == "i8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } } -#endif else { throw std::runtime_error("Unsupported data type for this operation !!!"); diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index aec5f6a116..2157397f1d 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -1,4 +1,3 @@ - // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. @@ -14,78 +13,28 @@ #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 -#endif - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#else -#error "unsupported CK_TILE_PIPELINE_DEFAULT value" -#endif - -struct GemmConfig +template +constexpr ck_tile::index_t get_k_warp_tile() { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 32; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 8; - - static constexpr bool DoubleSmemBuffer = false; -#endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - // Compute friendly for Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; - - static constexpr bool DoubleSmemBuffer = false; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 32; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool DoubleSmemBuffer = true; +#if defined(__gfx950__) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; #endif +} +struct GemmConfigBase +{ static constexpr bool kPadM = false; static constexpr bool kPadN = false; static constexpr bool kPadK = false; @@ -99,6 +48,169 @@ struct GemmConfig static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; template @@ -150,6 +262,15 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using BDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using CDataType = int32_t; +}; + template struct DataTypeTraits; @@ -165,6 +286,12 @@ struct DataTypeTraits static constexpr const char* name = "fp64"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + template <> struct DataTypeTraits { @@ -195,6 +322,51 @@ struct DataTypeTraits static constexpr const char* name = "pk_int4_t"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -223,10 +395,13 @@ auto create_args(int argc, char* argv[]) // host API template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); + bool Persistent = false, + typename CDEElementWise> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index bf455a6415..d3ef974d91 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -30,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template ; - using GemmPipeline = GEMM_PIPELINE; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; const ck_tile::index_t K = tensor.get_length(0); const ck_tile::index_t N = tensor.get_length(1); @@ -144,13 +146,31 @@ void permute_vectors_i4x4_b(Tensor& tensor) } } -template + typename DsLayout, + typename CLayout, + bool Persistent, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s); + +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -165,41 +185,50 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_repeat, bool persistent) { - ck_tile::GemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_C = stride_C; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + {}, + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C}; float ave_time; if(persistent) { - ave_time = gemm_calc( + ave_time = gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); } else { - ave_time = gemm_calc( + ave_time = gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); } @@ -222,7 +251,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, return ave_time; } -template {-1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); } else if(init_method == 1) { @@ -302,7 +332,8 @@ int run_gemm_example_with_layouts(int argc, ck_tile::HostTensor b_k_n_dev = b_k_n; if constexpr(GemmConfig::PermuteB) { - permute_tensor_b( - a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat, - persistent); + invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat, + persistent); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; @@ -376,29 +415,19 @@ int run_gemm_example_with_layouts(int argc, // Restore input for B for gpu reference b_k_n_dev_buf.ToDevice(b_k_n.data()); } + + // memory on host to store gpu reference result ck_tile::HostTensor c_m_n_gpu_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + // memory on device to store gpu reference result ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); - ADataType* d_A; - BDataType* d_B; - CDataType* d_C; - - ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes())); - ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes())); - ck_tile::hip_check_error( - hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes())); - - ck_tile::hip_check_error(hipMemcpy(d_A, - a_m_k_dev_buf.GetDeviceBuffer(), - a_m_k.get_element_space_size_in_bytes(), - hipMemcpyHostToDevice)); - ck_tile::hip_check_error(hipMemcpy(d_B, - b_k_n_dev_buf.GetDeviceBuffer(), - b_k_n.get_element_space_size_in_bytes(), - hipMemcpyHostToDevice)); + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), - d_C, - c_m_n_dev_result.get_element_space_size_in_bytes(), - hipMemcpyDeviceToHost)); - - ck_tile::hip_check_error(hipFree(d_A)); - ck_tile::hip_check_error(hipFree(d_B)); - ck_tile::hip_check_error(hipFree(d_C)); - c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + const float max_accumulated_value = *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 3a7cc93df8..c2c3fc1fa4 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -13,15 +13,20 @@ #include "gemm_utils.hpp" #include "run_gemm_example.inc" -template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + typename DsLayout, + typename ELayout, + bool Persistent, + typename CDEElementWise> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -30,31 +35,36 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& sequence, GemmConfig::PermuteA, GemmConfig::PermuteB>; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; + ELayout, + GemmConfig::NumWaveGroups>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + Persistent, + GemmConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; @@ -68,7 +78,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + memory_operation, + GemmConfig::NumWaveGroups>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -154,7 +169,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& // clear c mem if(args.k_batch > 1) hipGetErrorString(hipMemsetAsync( - args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; ave_time = ck_tile::launch_kernel_preprocess( s, @@ -190,11 +205,13 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& }; BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -204,12 +221,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -222,22 +239,22 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "R") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Row{}, Row{}); } else if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -247,6 +264,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } } +template